In [2]:
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from pathlib import Path

def load_asset(asset):
    freq = '5min'
    DATA_DIR = Path('../dataset')
    path = DATA_DIR / f"{asset}_{freq}.csv"

    df = pd.read_csv(path, index_col=0)
    df['timestamp'] = pd.to_datetime(df["system_time"])
    df['timestamp'] = df['timestamp'].dt.round('min')
    df = df.sort_values('timestamp').set_index('timestamp')

    return df

df_ADA = load_asset("ADA")
df_BTC = load_asset("BTC")
df_ETH = load_asset("ETH")

# Align by timestamp
df = df_ADA[['midpoint']] \
    .join(df_BTC[['midpoint']], rsuffix='_BTC') \
    .join(df_ETH[['midpoint']], rsuffix='_ETH')

# You will add features next
class CryptoDataset(Dataset):
    def __init__(self, dataframe, transform=None, pre_transform=None):
        self.dataframe = dataframe
        super().__init__(None, transform, pre_transform)

    def len(self):
        return len(self.dataframe) - 1

    def get(self, idx):
        row = self.dataframe.iloc[idx]
        next_row = self.dataframe.iloc[idx + 1]

        x = torch.tensor(row.values, dtype=torch.float).unsqueeze(0)
        y = torch.tensor(next_row.values, dtype=torch.float).unsqueeze(0)

        edge_index = torch.tensor([[0], [0]], dtype=torch.long)

        data = Data(x=x, y=y, edge_index=edge_index)
        return data

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def add_features(df):
    df = df.copy()
    
    df["imbalance_trades"] = (df["buys"] - df["sells"]) / (df["buys"] + df["sells"] + 1e-6)
    
    df["bid_liq_near"] = df[[f"bids_notional_{i}" for i in range(3)]].sum(axis=1)
    df["ask_liq_near"] = df[[f"asks_notional_{i}" for i in range(3)]].sum(axis=1)
    
    df["lob_imbalance"] = (df["bid_liq_near"] - df["ask_liq_near"]) / \
                          (df["bid_liq_near"] + df["ask_liq_near"] + 1e-6)
    
    keep = [
        "midpoint", "spread", "buys", "sells",
        "imbalance_trades",
        "bid_liq_near", "ask_liq_near", "lob_imbalance",
        "bids_distance_0", "bids_distance_1", "bids_distance_2",
        "asks_distance_0", "asks_distance_1", "asks_distance_2",
    ]
    
    return df[keep]


In [4]:
class MultiAssetDataset(Dataset):
    def __init__(self, df_ADA, df_BTC, df_ETH, window=24):
        super().__init__()
        self.window = window
        
        A = add_features(df_ADA)
        B = add_features(df_BTC)
        C = add_features(df_ETH)
        
        merged = A.join(B, rsuffix="_BTC").join(C, rsuffix="_ETH")
        self.features = merged.dropna()
        
        midpoint = self.features["midpoint"]
        self.y = (midpoint.shift(-1) > midpoint).astype(int)
        self.y = self.y.iloc[window:]
        
        self.features = self.features.iloc[:-1]  # align

    def len(self):
        return len(self.y)

    def get(self, idx):
        x_window = self.features.iloc[idx:idx+self.window]
        
        # for each timestep: 3 nodes Ã— feature_dim
        node_feats = []
        for t in range(self.window):
            row = x_window.iloc[t]
            ada = row[[c for c in row.index if not c.endswith("_BTC") and not c.endswith("_ETH")]].values
            btc = row[[c for c in row.index if c.endswith("_BTC")]].values
            eth = row[[c for c in row.index if c.endswith("_ETH")]].values
            node_feats.append(np.vstack([ada, btc, eth]))
        
        x = torch.tensor(np.array(node_feats), dtype=torch.float)

        edge_index = torch.tensor([[0,1,0,2,1,2],[1,0,2,0,2,1]], dtype=torch.long)

        return Data(
            x=x, 
            edge_index=edge_index, 
            y=torch.tensor(self.y.iloc[idx], dtype=torch.long)
        )


In [5]:
import torch
import torch.nn as nn
from torch_geometric.nn import GATv2Conv

class TemporalGNN(nn.Module):
    def __init__(self, node_features, gnn_hidden=64, lstm_hidden=64):
        super().__init__()
        
        self.gnn = GATv2Conv(node_features, gnn_hidden, heads=2, concat=False)
        self.lstm = nn.LSTM(gnn_hidden*3, lstm_hidden, batch_first=True)
        self.out = nn.Linear(lstm_hidden, 2)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        
        # x: [T, 3, F]
        T = x.size(0)
        
        gnn_outputs = []
        for t in range(T):
            h = self.gnn(x[t], edge_index)  # [3, hidden]
            gnn_outputs.append(h.reshape(1, -1))
        
        h = torch.cat(gnn_outputs, dim=0).unsqueeze(0)  # [1, T, 3*hidden]
        out, _ = self.lstm(h)
        final = out[:, -1, :]
        
        return self.out(final)


In [7]:
dataset = MultiAssetDataset(df_ADA, df_BTC, df_ETH)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

model = TemporalGNN(node_features=14)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    total = 0
    for batch in loader:
        optim.zero_grad()
        pred = model(batch)
        loss = criterion(pred, batch.y)
        loss.backward()
        optim.step()
        total += loss.item()
    print("Epoch", epoch, "loss", total / len(loader))


Epoch 0 loss 0.7007795692650144
Epoch 1 loss 0.6970860327343027
Epoch 2 loss 0.6957824012119613
Epoch 3 loss 0.6954976568678896
Epoch 4 loss 0.6972960884431879
Epoch 5 loss 0.6973480282161764
Epoch 6 loss 0.697908770163616


KeyboardInterrupt: 

In [16]:
from torch_geometric_temporal.signal import DynamicGraphTemporalSignal
import numpy as np

def build_temporal_graph(df_ADA, df_BTC, df_ETH, window=24):
    A = add_features(df_ADA)
    B = add_features(df_BTC)
    C = add_features(df_ETH)

    merged = A.join(B, rsuffix="_BTC").join(C, rsuffix="_ETH").dropna()

    # labels
    midpoint = merged["midpoint"]
    y = (midpoint.shift(-1) > midpoint).astype(int).dropna()
    merged = merged.iloc[:-1]

    X = []
    Y = []
    E = []
    EI = []

    # fixed fully connected graph
    edge_index = np.array([[0,1,0,2,1,2],[1,0,2,0,2,1]])

    for i in range(len(merged)-window):
        block = merged.iloc[i:i+window]
        # create node features per timestep
        X_t = []
        for t in range(window):
            row = block.iloc[t]

            ada = row[[c for c in block.columns if not (c.endswith("_BTC") or c.endswith("_ETH"))]].values
            btc = row[[c for c in block.columns if c.endswith("_BTC")]].values
            eth = row[[c for c in block.columns if c.endswith("_ETH")]].values

            X_t.append(np.vstack([ada, btc, eth]))  # shape: [3, F]

        X.append(np.array(X_t))          # [T, 3, F]
        EI.append(edge_index)            # constant edges
        Y.append(y.iloc[i+window-1])

    return DynamicGraphTemporalSignal(
        edge_indices=EI,
        edge_weights=[None]*len(EI),
        features=X,
        targets=np.array(Y)
    )


In [21]:
import torch
import torch.nn as nn
from torch_geometric_temporal.nn import A3TGCN

class PriceDirectionTGAT(nn.Module):
    def __init__(self, node_features, out_channels=32, periods=24):
        super().__init__()
        self.tgat = A3TGCN(in_channels=node_features, out_channels=out_channels, periods=periods)
        self.fc = nn.Linear(out_channels, 2)

    def forward(self, x, edge_index):
        # x shape: [T, num_nodes, F]
        h = self.tgat(x, edge_index)  # output per timestep: [num_nodes, out_channels]
        h_last = h[-1].reshape(-1)    # flatten nodes at last step
        return self.fc(h_last)


In [22]:
from torch_geometric_temporal.signal import temporal_signal_split

dataset = build_temporal_graph(df_ADA, df_BTC, df_ETH, window=24)
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.8)

model = PriceDirectionTGAT(node_features=14)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    losses = []
    for snapshot in train_dataset:
        x = torch.tensor(snapshot.x, dtype=torch.float)
        edge_idx = torch.tensor(snapshot.edge_index, dtype=torch.long)
        y = torch.tensor([snapshot.y], dtype=torch.long)

        optimizer.zero_grad()
        pred = model(x, edge_idx)
        loss = criterion(pred, y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        
    print(f"Epoch {epoch} | loss={np.mean(losses):.4f}")


  x = torch.tensor(snapshot.x, dtype=torch.float)
  edge_idx = torch.tensor(snapshot.edge_index, dtype=torch.long)


RuntimeError: mat1 and mat2 shapes cannot be multiplied (24x3 and 14x32)