In [1]:
import pandas as pd
import torch
from torch_frame import TensorFrame, stype
from torch_frame.nn import (
    StypeWiseFeatureEncoder,
    EmbeddingEncoder,
    LinearBucketEncoder,
)
from torch_frame.data import Dataset
from torch.nn import LayerNorm
import torch.nn.functional as F

Let's start by creating initial embeddings for our top 150 players using Torch Frame

In [2]:
# load our players data in from the CSV
players = pd.read_csv("../data/player_features.csv")

In [3]:
# channels controls the dimension size each column will have for our player rows after encoding
channels = 128

# set the stypes for each column in our data
col_to_stype = {
    "player_id": stype.numerical,
    "current_rank": stype.numerical,
    "hand": stype.categorical,
    "dob": stype.numerical,
    "height": stype.numerical,
    "country_num": stype.categorical
}

# 2) Build a Dataset and materialize it -> computes col_stats and a TensorFrame
ds = Dataset(df=players, col_to_stype=col_to_stype).materialize()
tf_players = ds.tensor_frame
col_stats  = ds.col_stats
col_names_dict = tf_players.col_names_dict

# 3) Create the stype-wise encoder with the computed stats
stype_encoder_dict = {
    stype.categorical: EmbeddingEncoder(),
    stype.numerical:  LinearBucketEncoder(post_module=LayerNorm(channels)),
}

encoder = StypeWiseFeatureEncoder(
    out_channels=channels,
    col_stats=col_stats,
    col_names_dict=col_names_dict,
    stype_encoder_dict=stype_encoder_dict,
)

# 4) Encode
x, _meta = encoder(tf_players)  # x: [batch, num_cols, channels]

player_emb = x.mean(dim=1) # simple average pooling over columns for now. we can get fancier later on

Now that we have our initial player embeddings, we can grab our edges to make our graph

In [4]:
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GATConv

In [5]:
import pandas as pd
import torch
import torch.nn.functional as F
from torch_geometric.data import Data

edges = pd.read_csv("../data/edges.csv")

# --- node indices ---
src = torch.from_numpy(edges["winner_idx"].to_numpy()).long()
dst = torch.from_numpy(edges["loser_idx"].to_numpy()).long()

# --- categorical encodings ---

# surface is already 0/1/2 from SURFACE_MAP -> just one-hot
surface = torch.from_numpy(edges["surface"].to_numpy()).long()
surface_oh = F.one_hot(surface, num_classes=3).float()           # [E, 3]

# map tourney_level (e.g., 'A', 'M', 'G', ...) to ints then one-hot
lvl_values = sorted(edges["tourney_level"].unique())
lvl2id = {lvl: i for i, lvl in enumerate(lvl_values)}
tourney_level_idx = torch.from_numpy(edges["tourney_level"].map(lvl2id).to_numpy()).long()
tourney_level_oh = F.one_hot(tourney_level_idx, num_classes=len(lvl2id)).float()  # [E, L]

# map round ('RR', 'R32', 'R16', 'QF', 'SF', 'F', ...) to ints then one-hot
rnd_values = sorted(edges["round"].unique())
rnd2id = {rnd: i for i, rnd in enumerate(rnd_values)}
round_idx = torch.from_numpy(edges["round"].map(rnd2id).to_numpy()).long()
round_oh = F.one_hot(round_idx, num_classes=len(rnd2id)).float()  # [E, R]

# --- numerical features: best_of, days_ago ---

best_of = torch.from_numpy(edges["best_of"].to_numpy()).float().unsqueeze(1)  # [E, 1]
# (optional) normalize best_of; not strictly necessary, but harmless
best_of = (best_of - best_of.mean()) / (best_of.std() + 1e-6)

days = torch.from_numpy(edges["days_ago"].to_numpy()).float().unsqueeze(1)    # [E, 1]
days = (days - days.mean()) / (days.std() + 1e-6)

# --- full edge_attr: concat everything ---
# shape: [E, 3 + L + R + 1 + 1]
edge_attr = torch.cat(
    [surface_oh, tourney_level_oh, round_oh, best_of, days],
    dim=1
)

E = edge_attr.size(0)

# --- build bidirectional edges ---

# original direction: winner -> loser
edge_index_fwd = torch.stack([src, dst], dim=0)  # [2, E]
edge_attr_fwd = edge_attr

# reverse: loser -> winner
edge_index_rev = torch.stack([dst, src], dim=0)  # [2, E]
edge_attr_rev = edge_attr.clone()                # same match-level features

# concat both directions
edge_index = torch.cat([edge_index_fwd, edge_index_rev], dim=1)   # [2, 2E]
edge_attr_bidir = torch.cat([edge_attr_fwd, edge_attr_rev], dim=0)  # [2E, feat_dim]

# direction/type: 0 = "won-against" (winner->loser), 1 = "lost-to" (loser->winner)
edge_type = torch.cat([
    torch.zeros(E, dtype=torch.long),
    torch.ones(E, dtype=torch.long)
], dim=0)  # [2E]

# --- split masks (per *original* edge), then expanded to bidirectional ---

split_map = {"train": 0, "val": 1, "test": 2}
edge_split = torch.from_numpy(edges["split"].map(split_map).to_numpy()).long()  # [E]

train_mask = edge_split == 0
val_mask   = edge_split == 1
test_mask  = edge_split == 2

# for the bidirectional edges, just duplicate the masks
train_mask_bidir = torch.cat([train_mask, train_mask], dim=0)  # [2E]
val_mask_bidir   = torch.cat([val_mask,   val_mask],   dim=0)
test_mask_bidir  = torch.cat([test_mask,  test_mask],  dim=0)

# --- final PyG graph (message-passing graph; you can subset edge_index per phase) ---

g = Data(
    x=player_emb.detach(),                  # [N, C] from TorchFrame encoder
    edge_index=edge_index,         # [2, 2E]
    edge_attr=edge_attr_bidir,     # [2E, F]
    edge_type=edge_type,           # [2E]
    train_mask=train_mask_bidir,   # [2E]
    val_mask=val_mask_bidir,       # [2E]
    test_mask=test_mask_bidir,     # [2E]
)

Before moving on, let's do a quick test to see how we preform before we use a GNN, let's get a baseline

In [6]:
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

def build_match_tensors(edges_df, z, split: str):
    df = edges_df[edges_df["split"] == split].reset_index(drop=True)

    w_idx = torch.tensor(df["winner_idx"].to_numpy(), dtype=torch.long)
    l_idx = torch.tensor(df["loser_idx"].to_numpy(), dtype=torch.long)

    # --- player embeddings ---
    z_w = z[w_idx]  # [M, D]
    z_l = z[l_idx]  # [M, D]

    # --- surface one-hot ---
    surface = torch.tensor(df["surface"].to_numpy(), dtype=torch.long)
    surface_oh = F.one_hot(surface, num_classes=3).float()  # [M, 3]

    # --- tourney_level one-hot (reuse same mapping as for edge_attr) ---
    lvl_values = sorted(edges_df["tourney_level"].unique())
    lvl2id = {lvl: i for i, lvl in enumerate(lvl_values)}
    lvl_idx = torch.tensor(df["tourney_level"].map(lvl2id).to_numpy(), dtype=torch.long)
    lvl_oh = F.one_hot(lvl_idx, num_classes=len(lvl2id)).float()  # [M, L]

    # --- round one-hot ---
    rnd_values = sorted(edges_df["round"].unique())
    rnd2id = {rnd: i for i, rnd in enumerate(rnd_values)}
    rnd_idx = torch.tensor(df["round"].map(rnd2id).to_numpy(), dtype=torch.long)
    rnd_oh = F.one_hot(rnd_idx, num_classes=len(rnd2id)).float()  # [M, R]

    # --- best_of (numeric, normalized) ---
    best_of = torch.tensor(df["best_of"].to_numpy(), dtype=torch.float32).unsqueeze(1)
    best_of = (best_of - best_of.mean()) / (best_of.std() + 1e-6)

    # --- days_ago (numeric, normalized) ---
    days = torch.tensor(df["days_ago"].to_numpy(), dtype=torch.float32).unsqueeze(1)
    days = (days - days.mean()) / (days.std() + 1e-6)

    # --- combine match-level features ---
    match_feat = torch.cat(
        [surface_oh, lvl_oh, rnd_oh, best_of, days], dim=1
    )  # [M, F_match]

    # --- construct positive & negative examples ---
    # positive: (winner, loser) → label 1
    X_pos = torch.cat([z_w, z_l, match_feat], dim=1)
    y_pos = torch.ones(X_pos.size(0), dtype=torch.long)

    # negative: (loser, winner) → label 0
    X_neg = torch.cat([z_l, z_w, match_feat], dim=1)
    y_neg = torch.zeros(X_neg.size(0), dtype=torch.long)

    X = torch.cat([X_pos, X_neg], dim=0)  # [2M, 2D + F_match]
    y = torch.cat([y_pos, y_neg], dim=0)  # [2M]

    return X, y


In [7]:
edges = pd.read_csv("../data/edges.csv")  # same DataFrame used to build g

player_emb_detached = player_emb.detach()
X_train, y_train = build_match_tensors(edges, player_emb_detached, split="train")
X_val,   y_val   = build_match_tensors(edges, player_emb_detached, split="val")
X_test,  y_test  = build_match_tensors(edges, player_emb_detached, split="test")

In [8]:
from torch.utils.data import Dataset, DataLoader

class MatchDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y
    def __len__(self):
        return self.X.size(0)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

train_ds = MatchDataset(X_train, y_train)
val_ds   = MatchDataset(X_val,   y_val)
test_ds  = MatchDataset(X_test,  y_test)

train_loader = DataLoader(train_ds, batch_size=256, shuffle=True)  
val_loader   = DataLoader(val_ds,   batch_size=512, shuffle=True)
test_loader  = DataLoader(test_ds,  batch_size=512, shuffle=True)


In [9]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# ---- tiny prediction head ----

class MatchMLP(nn.Module):
    def __init__(self, in_dim, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),  # 2 classes: 0 = (loser,winner), 1 = (winner,loser)
        )

    def forward(self, X):
        return self.net(X)  # returns logits [B, 2]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

in_dim = X_train.size(1)
model = MatchMLP(in_dim=in_dim, hidden_dim=128).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

def eval_loader(loader, model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for Xb, yb in loader:
            Xb = Xb.to(device)
            yb = yb.to(device)
            logits = model(Xb)
            preds = logits.argmax(dim=1)
            correct += (preds == yb).sum().item()
            total += yb.size(0)
    return correct / total if total > 0 else 0.0

num_epochs = 20  # bump this as needed

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    total_examples = 0

    for Xb, yb in train_loader:
        Xb = Xb.to(device)
        yb = yb.to(device)

        optimizer.zero_grad()
        logits = model(Xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * Xb.size(0)
        total_examples += Xb.size(0)

    avg_loss = total_loss / total_examples
    val_acc = eval_loader(val_loader, model)
    print(f"Epoch {epoch+1}: loss={avg_loss:.4f}, val_acc={val_acc:.4f}")

test_acc = eval_loader(test_loader, model)
print("Final test accuracy:", test_acc)


Epoch 1: loss=0.6878, val_acc=0.5880
Epoch 2: loss=0.6781, val_acc=0.6106
Epoch 3: loss=0.6689, val_acc=0.6242
Epoch 4: loss=0.6598, val_acc=0.6309
Epoch 5: loss=0.6518, val_acc=0.6388
Epoch 6: loss=0.6450, val_acc=0.6377
Epoch 7: loss=0.6400, val_acc=0.6371
Epoch 8: loss=0.6356, val_acc=0.6337
Epoch 9: loss=0.6321, val_acc=0.6332
Epoch 10: loss=0.6291, val_acc=0.6270
Epoch 11: loss=0.6272, val_acc=0.6292
Epoch 12: loss=0.6251, val_acc=0.6247
Epoch 13: loss=0.6235, val_acc=0.6264
Epoch 14: loss=0.6221, val_acc=0.6253
Epoch 15: loss=0.6206, val_acc=0.6298
Epoch 16: loss=0.6203, val_acc=0.6281
Epoch 17: loss=0.6192, val_acc=0.6230
Epoch 18: loss=0.6180, val_acc=0.6247
Epoch 19: loss=0.6171, val_acc=0.6281
Epoch 20: loss=0.6166, val_acc=0.6225
Final test accuracy: 0.628668171557562


In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv  

class TwoWayEdgeAwareGAT(nn.Module):
    def __init__(self, in_ch, hidden, out_ch, edge_dim):
        super().__init__()
        # separate params per direction
        self.win1 = GATv2Conv(in_ch, hidden, edge_dim=edge_dim, add_self_loops=True)
        self.los1 = GATv2Conv(in_ch, hidden, edge_dim=edge_dim, add_self_loops=True)
        self.win2 = GATv2Conv(hidden, out_ch, edge_dim=edge_dim, add_self_loops=True)
        self.los2 = GATv2Conv(hidden, out_ch, edge_dim=edge_dim, add_self_loops=True)
        self.combine = nn.Linear(out_ch * 2, out_ch)  # concat → linear

    def forward(self, x, edge_index, edge_attr, edge_type):
        # split edges by direction
        idx_win = (edge_type == 0).nonzero(as_tuple=False).view(-1)
        idx_los = (edge_type == 1).nonzero(as_tuple=False).view(-1)

        ei_win = edge_index[:, idx_win]
        ei_los = edge_index[:, idx_los]
        ea_win = edge_attr[idx_win]
        ea_los = edge_attr[idx_los]

        z_win = F.relu(self.win1(x, ei_win, ea_win))
        z_los = F.relu(self.los1(x, ei_los, ea_los))
        z_win = self.win2(z_win, ei_win, ea_win)
        z_los = self.los2(z_los, ei_los, ea_los)

        # combine the two message streams
        z = torch.cat([z_win, z_los], dim=-1)   # or: z = z_win + z_los
        z = self.combine(z)
        return z

edge_dim = g.edge_attr.shape[1]
gnn = TwoWayEdgeAwareGAT(in_ch=player_emb.shape[1], hidden=128, out_ch=128, edge_dim=edge_dim)
z = gnn(g.x, g.edge_index, g.edge_attr, g.edge_type)


In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import GATv2Conv  

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ========= 1. Precompute per-match tensors ONCE =========

# row indices per split (on CPU)
train_rows = torch.tensor(edges.index[edges["split"] == "train"].to_numpy(), dtype=torch.long)
val_rows   = torch.tensor(edges.index[edges["split"] == "val"].to_numpy(),   dtype=torch.long)
test_rows  = torch.tensor(edges.index[edges["split"] == "test"].to_numpy(),  dtype=torch.long)

# player indices per match (move to device)
winner_idx_all = torch.tensor(edges["winner_idx"].to_numpy(), dtype=torch.long, device=device)
loser_idx_all  = torch.tensor(edges["loser_idx"].to_numpy(),  dtype=torch.long, device=device)

# --- categorical encodings computed ONCE ---

# surface: already 0/1/2
surface_all = torch.tensor(edges["surface"].to_numpy(), dtype=torch.long, device=device)
surface_oh  = F.one_hot(surface_all, num_classes=3).float()  # [M, 3]

# tourney_level → ids → one-hot
lvl_values = sorted(edges["tourney_level"].unique())
lvl2id = {lvl: i for i, lvl in enumerate(lvl_values)}
tourney_level_idx = torch.tensor(
    edges["tourney_level"].map(lvl2id).to_numpy(), dtype=torch.long, device=device
)
tourney_level_oh = F.one_hot(tourney_level_idx, num_classes=len(lvl2id)).float()  # [M, L]

# round → ids → one-hot
rnd_values = sorted(edges["round"].unique())
rnd2id = {rnd: i for i, rnd in enumerate(rnd_values)}
round_idx = torch.tensor(
    edges["round"].map(rnd2id).to_numpy(), dtype=torch.long, device=device
)
round_oh = F.one_hot(round_idx, num_classes=len(rnd2id)).float()  # [M, R]

# --- numeric features (normed once) ---

best_of_all = torch.tensor(edges["best_of"].to_numpy(), dtype=torch.float32, device=device).unsqueeze(1)
days_all    = torch.tensor(edges["days_ago"].to_numpy(), dtype=torch.float32, device=device).unsqueeze(1)

best_of_all = (best_of_all - best_of_all.mean()) / (best_of_all.std() + 1e-6)
days_all    = (days_all    - days_all.mean())    / (days_all.std()    + 1e-6)

# --- final per-match features (constant) ---

match_feat_all = torch.cat([surface_oh, tourney_level_oh, round_oh, best_of_all, days_all], dim=1)
match_feat_dim = match_feat_all.size(1)  # 3 + |lvl| + |round| + 1 + 1

# message-passing edges: USE ONLY TRAIN EDGES to avoid leakage
train_edge_mask = g.train_mask.to(device)         # [2E] bool
edge_index_train = g.edge_index[:, train_edge_mask].to(device)
edge_attr_train  = g.edge_attr[train_edge_mask].to(device)
edge_type_train  = g.edge_type[train_edge_mask].to(device)

# ========= 2. Dataset that holds ONLY row indices =========

class MatchIndexDataset(Dataset):
    def __init__(self, row_indices: torch.Tensor):
        self.row_indices = row_indices
    def __len__(self):
        return self.row_indices.size(0)
    def __getitem__(self, idx):
        return self.row_indices[idx]

train_ds = MatchIndexDataset(train_rows)
val_ds   = MatchIndexDataset(val_rows)
test_ds  = MatchIndexDataset(test_rows)

train_loader = DataLoader(train_ds, batch_size=256, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=512, shuffle=False)
test_loader  = DataLoader(test_ds,  batch_size=512, shuffle=False)

# ========= 3. GNN: two-way edge-aware GAT =========

from torch_geometric.nn import GINEConv

from torch_geometric.nn import GINEConv

class TwoWayEdgeAwareGIN(nn.Module):
    def __init__(self, in_ch, hidden, out_ch, edge_dim, dropout=0.1):
        super().__init__()

        def make_mlp(in_dim, out_dim):
            return nn.Sequential(
                nn.Linear(in_dim, hidden),
                nn.ReLU(),
                nn.Linear(hidden, out_dim),
            )

        # GINE for each direction
        self.win_conv = GINEConv(make_mlp(in_ch, out_ch), edge_dim=edge_dim)
        self.los_conv = GINEConv(make_mlp(in_ch, out_ch), edge_dim=edge_dim)

        # combine directional embeddings
        self.combine = nn.Linear(out_ch * 2, out_ch)

        # project original node features to same dim for residual
        if in_ch == out_ch:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Linear(in_ch, out_ch)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, edge_attr, edge_type):
        # split edges
        idx_win = (edge_type == 0).nonzero(as_tuple=False).view(-1)
        idx_los = (edge_type == 1).nonzero(as_tuple=False).view(-1)

        ei_win = edge_index[:, idx_win]
        ei_los = edge_index[:, idx_los]
        ea_win = edge_attr[idx_win]
        ea_los = edge_attr[idx_los]

        # neighbor messages
        z_win = self.win_conv(x, ei_win, ea_win)  # [N, out_ch]
        z_los = self.los_conv(x, ei_los, ea_los)  # [N, out_ch]

        # directional combine
        z_dir = torch.cat([z_win, z_los], dim=-1)   # [N, 2*out_ch]
        z_dir = self.combine(z_dir)                # [N, out_ch]

        # residual + nonlinearity + dropout
        z = z_dir + self.skip(x)                   # include original node vec
        z = F.relu(z)
        z = self.dropout(z)

        return z  # [N, out_ch]


# ========= 4. Prediction head MLP =========
class MatchMLP(nn.Module):
    def __init__(self, in_dim, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),  # 0 = (loser,winner), 1 = (winner,loser)
        )

    def forward(self, X):
        return self.net(X)

# ========= 5. Build batch features from z + row indices (no duplicate encodings) =========

def make_batch_features(z, batch_rows):
    """
    z: [N, D] node embeddings from GNN
    batch_rows: [B] indices into edges/matches
    returns X_batch: [2B, feat_dim], y_batch: [2B]
    """
    batch_rows = batch_rows.to(device)

    w_idx = winner_idx_all[batch_rows]   # [B]
    l_idx = loser_idx_all[batch_rows]    # [B]

    z_w = z[w_idx]  # [B, D]
    z_l = z[l_idx]  # [B, D]

    mf = match_feat_all[batch_rows]  # [B, match_feat_dim]

    # positive: (winner, loser) → label 1
    X_pos = torch.cat([z_w, z_l, mf], dim=1)
    y_pos = torch.ones(X_pos.size(0), dtype=torch.long, device=device)

    # negative: (loser, winner) → label 0
    X_neg = torch.cat([z_l, z_w, mf], dim=1)
    y_neg = torch.zeros(X_neg.size(0), dtype=torch.long, device=device)

    X_batch = torch.cat([X_pos, X_neg], dim=0)
    y_batch = torch.cat([y_pos, y_neg], dim=0)
    return X_batch, y_batch

# ========= 6. Instantiate models, optimizer, etc. =========

edge_dim = g.edge_attr.shape[1]
gnn_hidden = 128
gnn_out = 128

gnn = TwoWayEdgeAwareGIN(
    in_ch=g.x.shape[1],
    hidden=gnn_hidden,
    out_ch=gnn_out,
    edge_dim=edge_dim,
).to(device)

head_in_dim = 2 * gnn_out + match_feat_dim
head = MatchMLP(in_dim=head_in_dim, hidden_dim=128).to(device)

params = list(gnn.parameters()) + list(head.parameters())
criterion = nn.CrossEntropyLoss()
gnn_lr  = 1e-4   # smaller LR for GNN
head_lr = 1e-5   # larger LR for prediction head

optimizer = torch.optim.Adam(
    [
        {"params": gnn.parameters(),  "lr": gnn_lr},
        {"params": head.parameters(), "lr": head_lr},
    ],
    weight_decay=1e-5,   # optional
)

# ========= 7. Eval helper (compute z once per split) =========

def eval_split(loader, gnn, head):
    gnn.eval()
    head.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        z = gnn(
            g.x.to(device),
            edge_index_train,
            edge_attr_train,
            edge_type_train,
        )
        for batch_rows in loader:
            batch_rows = batch_rows.squeeze(-1) if batch_rows.ndim > 1 else batch_rows
            Xb, yb = make_batch_features(z, batch_rows)
            logits = head(Xb)
            preds = logits.argmax(dim=1)
            correct += (preds == yb).sum().item()
            total += yb.size(0)
    return correct / total if total > 0 else 0.0

# ========= 8. Training loop (end-to-end) =========

num_epochs = 100
batch_size = 256

for epoch in range(num_epochs):
    gnn.train()
    head.train()

    # shuffle match indices each epoch
    perm = torch.randperm(train_rows.size(0))
    train_rows_shuffled = train_rows[perm].to(device)
    

    total_loss = 0.0
    total_examples = 0

    # build the total loss from mini-batches of matches
    for start in range(0, train_rows_shuffled.size(0), batch_size):
        end = start + batch_size
        batch_rows = train_rows_shuffled[start:end]
        optimizer.zero_grad()
        z = gnn(
            g.x.to(device),
            edge_index_train,
            edge_attr_train,
            edge_type_train,
        )  # [N, gnn_out]  # [B]

        Xb, yb = make_batch_features(z, batch_rows)   # [2B, feat_dim], [2B]
        logits = head(Xb)
        loss = criterion(logits, yb)                  # scalar

        # accumulate loss *without* breaking the graph
        total_loss = total_loss + loss * yb.size(0)
        total_examples += yb.size(0)
        loss.backward()
        optimizer.step()

    avg_loss = total_loss / total_examples

    # single backward through the whole graph for this epoch

    # evaluation
    val_acc = eval_split(val_loader, gnn, head)
    print(f"Epoch {epoch+1}: loss={avg_loss.item():.4f}, val_acc={val_acc:.4f}")

test_acc = eval_split(test_loader, gnn, head)
print("Final test accuracy:", test_acc)


Using device: cpu
Epoch 1: loss=0.6679, val_acc=0.6005
Epoch 2: loss=0.6283, val_acc=0.6089
Epoch 3: loss=0.6196, val_acc=0.6185
Epoch 4: loss=0.6191, val_acc=0.6146
Epoch 5: loss=0.6163, val_acc=0.6140
Epoch 6: loss=0.6161, val_acc=0.6179
Epoch 7: loss=0.6158, val_acc=0.6185
Epoch 8: loss=0.6151, val_acc=0.6225
Epoch 9: loss=0.6143, val_acc=0.6202
Epoch 10: loss=0.6144, val_acc=0.6163
Epoch 11: loss=0.6144, val_acc=0.6163
Epoch 12: loss=0.6124, val_acc=0.6174
Epoch 13: loss=0.6136, val_acc=0.6185
Epoch 14: loss=0.6113, val_acc=0.6168
Epoch 15: loss=0.6121, val_acc=0.6168
Epoch 16: loss=0.6120, val_acc=0.6191
Epoch 17: loss=0.6105, val_acc=0.6140
Epoch 18: loss=0.6136, val_acc=0.6179
Epoch 19: loss=0.6105, val_acc=0.6163
Epoch 20: loss=0.6106, val_acc=0.6129
Epoch 21: loss=0.6095, val_acc=0.6168
Epoch 22: loss=0.6092, val_acc=0.6174
Epoch 23: loss=0.6102, val_acc=0.6106
Epoch 24: loss=0.6112, val_acc=0.6185
Epoch 25: loss=0.6092, val_acc=0.6196
Epoch 26: loss=0.6099, val_acc=0.6179
Epo