# GNN + LSTM Tutorial on Synthetic Financial Graph Data

**Objective:** Modeling 3 assets (A, B, C) with dynamic graph structure and predicting the next 5-minute direction of Asset A (Up/Down).

---

## üìå –ß–∞—Å—Ç—å 1. –ò–º–ø–æ—Ä—Ç –±–∏–±–ª–∏–æ—Ç–µ–∫

In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# –î–ª—è –≥—Ä–∞—Ñ–æ–≤ (–ø—Ä–æ—Å—Ç–∞—è —Ä–µ–∞–ª–∏–∑–∞—Ü–∏—è –±–µ–∑ PyTorch Geometric)
# –ú—ã —Å–∞–º–∏ –±—É–¥–µ–º –æ–±—Ä–∞–±–∞—Ç—ã–≤–∞—Ç—å edge_index + aggregation


## üìå –ß–∞—Å—Ç—å 2. –ì–µ–Ω–µ—Ä–∞—Ü–∏—è —Å–∏–Ω—Ç–µ—Ç–∏—á–µ—Å–∫–∏—Ö –¥–∞–Ω–Ω—ã—Ö

In [2]:
"""
–ú—ã —Å–≥–µ–Ω–µ—Ä–∏—Ä—É–µ–º —Å–∏–Ω—Ç–µ—Ç–∏—á–µ—Å–∫–∏–µ –≤—Ä–µ–º–µ–Ω–Ω—ã–µ —Ä—è–¥—ã –∞–∫—Ç–∏–≤–æ–≤ A, B, C:
price_A, price_B, price_C ‚Äî —Å–ª—É—á–∞–π–Ω—ã–µ random walk
volume ‚Äî —Å–ª—É—á–∞–π–Ω–æ, —Å –Ω–µ–±–æ–ª—å—à–∏–º —Ç—Ä–µ–Ω–¥–æ–º
correlations ‚Äî –∏–º–∏—Ç–∏—Ä—É–µ–º –∑–∞–≤–∏—Å–∏–º–æ—Å—Ç—å –º–µ–∂–¥—É –∞–∫—Ç–∏–≤–∞–º–∏
direction ‚Äî –∑–Ω–∞–∫ —Ä–∞–∑–Ω–∏—Ü—ã —Ü–µ–Ω
–î–∞–Ω–Ω—ã–µ –∫–∞–∂–¥—ã–µ 5 –º–∏–Ω—É—Ç.
"""
np.random.seed(42)
T = 2000  # –∫–æ–ª–∏—á–µ—Å—Ç–≤–æ —Ç–∞–π–º—Å—Ç–µ–ø–æ–≤

# --- –°–∏–Ω—Ç–µ—Ç–∏—á–µ—Å–∫–∏–µ —Ü–µ–Ω—ã –∫–∞–∫ random walk ---
price_A = np.cumsum(np.random.normal(0, 0.5, T)) + 100
price_B = np.cumsum(np.random.normal(0, 0.4, T)) + 80
price_C = np.cumsum(np.random.normal(0, 0.3, T)) + 120

# --- Volume ---
vol_A = np.abs(np.random.normal(50, 10, T))
vol_B = np.abs(np.random.normal(60, 12, T))
vol_C = np.abs(np.random.normal(55, 15, T))

# --- –ö–æ—Ä—Ä–µ–ª—è—Ü–∏–∏ (–ø—Ä–æ—Å—Ç–∞—è –∏–º–∏—Ç–∞—Ü–∏—è) ---
corr_AB = 0.5 + 0.1 * np.sin(np.linspace(0, 20, T))
corr_AC = 0.3 + 0.05 * np.cos(np.linspace(0, 15, T))
corr_CB = 0.4 + 0.08 * np.sin(np.linspace(0, 10, T))

# --- Log-return correlation: –∏–º–∏—Ç–∏—Ä—É–µ–º –∫–∞–∫ –Ω–µ–±–æ–ª—å—à–æ–µ —à—É–º–Ω–æ–µ –æ—Ç–∫–ª–æ–Ω–µ–Ω–∏–µ ---
corr_lr_AB = corr_AB + np.random.normal(0, 0.02, T)
corr_lr_AC = corr_AC + np.random.normal(0, 0.02, T)
corr_lr_CB = corr_CB + np.random.normal(0, 0.02, T)

# --- Direction ---
# –ò—Å–ø–æ–ª—å–∑—É–µ–º price_A[0] –≤ prepend –¥–ª—è —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∏—è —Ä–∞–∑–º–µ—Ä–∞ T
direction_A = np.sign(np.diff(price_A, prepend=price_A[0]))
direction_B = np.sign(np.diff(price_B, prepend=price_B[0]))
direction_C = np.sign(np.diff(price_C, prepend=price_C[0]))

# --- DataFrame ---
df = pd.DataFrame({
    "price_A": price_A,
    "price_B": price_B,
    "price_C": price_C,
    "vol_A": vol_A,
    "vol_B": vol_B,
    "vol_C": vol_C,
    "corr_AB": corr_AB,
    "corr_AC": corr_AC,
    "corr_CB": corr_CB,
    "corr_lr_AB": corr_lr_AB,
    "corr_lr_AC": corr_lr_AC,
    "corr_lr_CB": corr_lr_CB,
    "dir_A": direction_A,
    "dir_B": direction_B,
    "dir_C": direction_C,
})

print("Shape:", df.shape)
df.head()

Shape: (2000, 15)


Unnamed: 0,price_A,price_B,price_C,vol_A,vol_B,vol_C,corr_AB,corr_AC,corr_CB,corr_lr_AB,corr_lr_AC,corr_lr_CB,dir_A,dir_B,dir_C
0,100.248357,79.729929,119.740952,38.859186,59.603697,44.822579,0.5,0.35,0.4,0.523496,0.345515,0.417586,0.0,0.0,0.0
1,100.179225,79.672121,119.731591,43.690692,53.956197,50.417508,0.501,0.349999,0.4004,0.463421,0.345782,0.401186,-1.0,-1.0,-1.0
2,100.503069,79.355153,119.736996,40.579398,57.931501,46.039284,0.502001,0.349994,0.4008,0.495445,0.394703,0.382554,1.0,-1.0,1.0
3,101.264584,79.231969,119.878785,44.520042,68.576781,56.656271,0.503001,0.349987,0.401201,0.502168,0.366174,0.391951,1.0,-1.0,1.0
4,101.147507,78.474523,119.468728,47.858497,75.33428,72.957678,0.504001,0.349977,0.401601,0.504319,0.39959,0.408458,-1.0,-1.0,-1.0


## üìå –ß–∞—Å—Ç—å 3. –ü–æ–¥–≥–æ—Ç–æ–≤–∫–∞ –¥–∞–Ω–Ω—ã—Ö –≤ —Ñ–æ—Ä–º–∞—Ç –≥—Ä–∞—Ñ–æ–≤

In [3]:
"""
–£ –Ω–∞—Å –≥—Ä–∞—Ñ –∏–∑ 3 —É–∑–ª–æ–≤: A(0), B(1), C(2).
Adjacency (edge_index) - –ù–µ–æ—Ä–∏–µ–Ω—Ç–∏—Ä–æ–≤–∞–Ω–Ω—ã–π –ø–æ–ª–Ω—ã–π –≥—Ä–∞—Ñ:
0 ‚Üî 1 (A-B)
0 ‚Üî 2 (A-C)
1 ‚Üî 2 (B-C)
"""
edge_index = torch.tensor([
    [0, 0, 1, 1, 2, 2],  # source
    [1, 2, 0, 2, 0, 1]   # target
], dtype=torch.long)

print("edge_index:\n", edge_index)

"""
Node feature vector size = 7 –¥–ª—è –∫–∞–∂–¥–æ–≥–æ –∞–∫—Ç–∏–≤–∞:
price, volume, dir, correlations to others
–ú—ã –±—É–¥–µ–º —Å–æ–±–∏—Ä–∞—Ç—å node feature vector –∞–≤—Ç–æ–º–∞—Ç–∏—á–µ—Å–∫–∏ –≤ Dataset()
"""
IN_DIM = 7

edge_index:
 tensor([[0, 0, 1, 1, 2, 2],
        [1, 2, 0, 2, 0, 1]])


## üìå –ß–∞—Å—Ç—å 4. Dataset –¥–ª—è GNN + LSTM

In [4]:
class GraphWindowDataset(Dataset):
    """
    Creates sliding windows of graph snapshots:
    Input: window of size K (sequence of K graph features)
    Target: direction of A at time t+1 (next_dir = self.df.iloc[idx + self.window].dir_A)
    """
    def __init__(self, df, window=12):
        super().__init__()
        self.df = df
        self.window = window
        self.N = len(df)

    def __len__(self):
        # N - window - 1 to ensure target at t+window exists (needs index t+window)
        return self.N - self.window - 1

    def __getitem__(self, idx):
        """
        –í–æ–∑–≤—Ä–∞—â–∞–µ—Ç –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç—å –æ–∫–æ–Ω –≥—Ä–∞—Ñ–∞:
        x_seq: list of K [3 x feature_dim] tensors
        """
        x_seq = []
        for t in range(idx, idx + self.window):
            row = self.df.iloc[t]

            # Node features per asset (size 7)
            # A: price, vol, dir, corr_AB, corr_AC, corr_lr_AB, corr_lr_AC
            feat_A = [row.price_A, row.vol_A, row.dir_A,
                      row.corr_AB, row.corr_AC,
                      row.corr_lr_AB, row.corr_lr_AC]

            # B: price, vol, dir, corr_AB, corr_CB, corr_lr_AB, corr_lr_CB
            feat_B = [row.price_B, row.vol_B, row.dir_B,
                      row.corr_AB, row.corr_CB,
                      row.corr_lr_AB, row.corr_lr_CB]

            # C: price, vol, dir, corr_AC, corr_CB, corr_lr_AC, corr_lr_CB
            feat_C = [row.price_C, row.vol_C, row.dir_C,
                      row.corr_AC, row.corr_CB,
                      row.corr_lr_AC, row.corr_lr_CB]

            x_t = torch.tensor([feat_A, feat_B, feat_C], dtype=torch.float32)
            x_seq.append(x_t)

        # Target = direction of A at t+window (class: 0=down/zero, 1=up)
        next_dir = self.df.iloc[idx + self.window].dir_A
        target = 1 if next_dir > 0 else 0

        return x_seq, target

## üìå –ß–∞—Å—Ç—å 5. Train/Val/Test split –ø–æ –≤—Ä–µ–º–µ–Ω–∏

In [5]:
train_len = int(0.7 * len(df))
val_len = int(0.15 * len(df))
test_len = len(df) - train_len - val_len

train_df = df.iloc[:train_len]
val_df = df.iloc[train_len : train_len + val_len]
test_df = df.iloc[train_len + val_len :]

print(f"Train size: {len(train_df)}")
print(f"Validation size: {len(val_df)}")
print(f"Test size: {len(test_df)}")

window = 12

train_dataset = GraphWindowDataset(train_df, window)
val_dataset   = GraphWindowDataset(val_df, window)
test_dataset  = GraphWindowDataset(test_df, window)

print(f"Train dataset samples: {len(train_dataset)}")
print(f"Validation dataset samples: {len(val_dataset)}")
print(f"Test dataset samples: {len(test_dataset)}")

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
val_loader   = DataLoader(val_dataset, batch_size=32)
test_loader  = DataLoader(test_dataset, batch_size=32)

Train size: 1400
Validation size: 300
Test size: 300
Train dataset samples: 1387
Validation dataset samples: 287
Test dataset samples: 287


## üìå –ß–∞—Å—Ç—å 6. –ü—Ä–æ—Å—Ç–∞—è GNN-–∞–≥—Ä–µ–≥–∞—Ü–∏—è (mean aggregation)

In [6]:
class SimpleGNNLayer(nn.Module):
    """
    Implements a basic message-passing layer:
    - mean aggregation of neighbor features
    - concatenation of node + aggregated
    - linear projection + ReLU
    """
    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Input: [node_features, aggregated_neighbor_features] -> size in_dim * 2
        self.lin = nn.Linear(in_dim * 2, out_dim)

    def forward(self, x, edge_index):
        # x: [3, in_dim]
        row, col = edge_index  # row=source, col=target
        num_nodes = x.size(0)

        # Initialize aggregator
        agg = torch.zeros_like(x)

        # Compute messages (Mean Aggregation)
        for v in range(num_nodes):
            # Find indices where 'v' is the target node
            incoming_edges_indices = (col == v).nonzero(as_tuple=True)[0]

            if len(incoming_edges_indices) > 0:
                # Get features of source nodes (neighbors)
                neighbor_features = x[row[incoming_edges_indices]]
                # Compute mean
                agg[v] = neighbor_features.mean(dim=0)

        # Update step (Concatenate node features with aggregated features)
        new_x = torch.cat([x, agg], dim=1)
        
        # Linear projection + Activation
        return F.relu(self.lin(new_x))

## üìå –ß–∞—Å—Ç—å 7. –ú–æ–¥–µ–ª—å GNN + LSTM

In [7]:
class GNN_LSTM_model(nn.Module):
    """
    1) For each time step in window: run GNN, produce node embeddings z_t
    2) Run LSTM over time **only for node A (index 0)**.
    3) Classification head for final state of node A's LSTM.
    """
    def __init__(self, in_dim, gnn_dim, lstm_dim):
        super().__init__()
        self.gnn1 = SimpleGNNLayer(in_dim, gnn_dim)
        self.gnn2 = SimpleGNNLayer(gnn_dim, gnn_dim)

        # LSTM for node A separately
        self.lstm = nn.LSTMCell(gnn_dim, lstm_dim)

        self.classifier = nn.Linear(lstm_dim, 2) # 2 classes: 0 (down/zero), 1 (up)

    def forward(self, x_seq, edge_index):
        # x_seq: list length K (window size) of shape [3, in_dim]

        # Initialize hidden states
        h = None
        c = None
        
        # Process time sequence
        for x_t in x_seq:
            # GNN layers (same weights across time)
            z = self.gnn1(x_t, edge_index)
            z = self.gnn2(z, edge_index)

            # Get embedding for node A (index 0)
            node_A_embed = z[0] # shape: [gnn_dim]

            # Initialize h, c dynamically if first step
            if h is None:
                # Initialize with zeros, same shape as node_A_embed or lstm_dim
                h = torch.zeros(self.lstm.hidden_size, device=node_A_embed.device)
                c = torch.zeros(self.lstm.hidden_size, device=node_A_embed.device)

            # Update LSTM state for node A
            h, c = self.lstm(node_A_embed, (h, c))

        # Final classification based on the last hidden state h
        logits = self.classifier(h)
        return logits

## üìå –ß–∞—Å—Ç—å 8. –û–±—É—á–µ–Ω–∏–µ –º–æ–¥–µ–ª–∏

In [8]:
IN_DIM = 7   # –∫–æ–ª-–≤–æ –ø—Ä–∏–∑–Ω–∞–∫–æ–≤ –¥–ª—è —É–∑–ª–∞
gnn_dim = 16
lstm_dim = 16

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GNN_LSTM_model(IN_DIM, gnn_dim, lstm_dim).to(device)
edge_index = edge_index.to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

print(f"Using device: {device}")
print(model)

# --- Training loop ---

def run_epoch(loader):
    model.train()
    total_loss = 0
    total_correct = 0
    total = 0
    
    for x_seq_batch, target_batch in loader:
        # x_seq_batch: list of len window, containing tensors of shape [batch_size, 3, in_dim]
        optimizer.zero_grad()
        
        batch_size = len(target_batch)
        
        # Because x_seq is a list, we must iterate over the batch to compute logits
        batch_logits = []
        for i in range(batch_size):
            # x_seq[i] is a list of K tensors, each [3, in_dim]
            # Need to re-format x_seq_batch to get the sequence for i-th sample
            x_seq_sample = [x_t[i].to(device) for x_t in x_seq_batch]
            logits = model(x_seq_sample, edge_index)
            batch_logits.append(logits.unsqueeze(0))

        batch_logits = torch.cat(batch_logits, dim=0)
        target_batch = target_batch.to(device)
        
        loss = criterion(batch_logits, target_batch)

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * batch_size
        pred = batch_logits.argmax(dim=1)
        total_correct += (pred == target_batch).sum().item()
        total += batch_size

    return total_loss / total, total_correct / total

def evaluate(loader):
    model.eval() # Set model to evaluation mode
    total_correct = 0
    total = 0
    
    with torch.no_grad():
        for x_seq_batch, target_batch in loader:
            batch_size = len(target_batch)
            
            for i in range(batch_size):
                # Extract sample sequence and target
                x_seq_sample = [x_t[i].to(device) for x_t in x_seq_batch]
                target = target_batch[i].item()
                
                logits = model(x_seq_sample, edge_index)
                pred = logits.argmax().item()
                
                if pred == target:
                    total_correct += 1
                total += 1

    return total_correct / total

# --- Train ---

EPOCHS = 8

for epoch in range(EPOCHS):
    loss, acc = run_epoch(train_loader)
    val_acc = evaluate(val_loader)

    print(f"Epoch {epoch+1:02d} | Loss={loss:.4f} | Train Acc={acc:.3f} | Val Acc={val_acc:.3f}")

print("\n--- Testing ---")
test_acc = evaluate(test_loader)

print("Final Test Accuracy:", f"{test_acc:.4f}")

Using device: cpu
GNN_LSTM_model(
  (gnn1): SimpleGNNLayer(
    (lin): Linear(in_features=14, out_features=16, bias=True)
  )
  (gnn2): SimpleGNNLayer(
    (lin): Linear(in_features=32, out_features=16, bias=True)
  )
  (lstm): LSTMCell(16, 16)
  (classifier): Linear(in_features=16, out_features=2, bias=True)
)
Epoch 01 | Loss=0.6939 | Train Acc=0.519 | Val Acc=0.537
Epoch 02 | Loss=0.6928 | Train Acc=0.505 | Val Acc=0.523
Epoch 03 | Loss=0.6938 | Train Acc=0.517 | Val Acc=0.523
Epoch 04 | Loss=0.6933 | Train Acc=0.501 | Val Acc=0.523
Epoch 05 | Loss=0.6931 | Train Acc=0.505 | Val Acc=0.519
Epoch 06 | Loss=0.6925 | Train Acc=0.515 | Val Acc=0.516
Epoch 07 | Loss=0.6926 | Train Acc=0.517 | Val Acc=0.523
Epoch 08 | Loss=0.6926 | Train Acc=0.502 | Val Acc=0.523

--- Testing ---
Final Test Accuracy: 0.5017
