In [46]:
import pandas as pd
import torch
from torch_frame import TensorFrame, stype
from torch_frame.nn import (
    StypeWiseFeatureEncoder,
    EmbeddingEncoder,
    LinearBucketEncoder,
)
from torch_frame.data import Dataset
from torch.nn import LayerNorm
import torch.nn.functional as F

Let's start by creating initial embeddings for our top 150 players using PyTorch Frame (`torch_frame`).

In [47]:
# Read the per-player feature table; each row is turned into one player embedding below.
players_csv = "../data/player_features.csv"
players = pd.read_csv(players_csv)

In [48]:
# channels controls the embedding width each column gets after encoding
channels = 128

# The encoders below are randomly initialised; seed so Restart & Run All reproduces
# the same player embeddings (and the same downstream GNN initialisation).
SEED = 42
torch.manual_seed(SEED)

# Semantic type (stype) of each column, which decides the encoder torch_frame uses for it.
# NOTE(review): player_id is encoded as a numerical feature. An ID's magnitude usually
# carries no meaning and can let the model memorise identities. Confirm this is intended
# or drop the column.
col_to_stype = {
    "player_id": stype.numerical,
    "current_rank": stype.numerical,
    "dob": stype.numerical,
    "height": stype.numerical,
    "country_num": stype.categorical,
}

# 1) Build a Dataset and materialize it -> computes col_stats and a TensorFrame
ds = Dataset(df=players, col_to_stype=col_to_stype).materialize()
tf_players = ds.tensor_frame
col_stats = ds.col_stats
col_names_dict = tf_players.col_names_dict

# 2) Create the stype-wise encoder with the computed stats
stype_encoder_dict = {
    stype.categorical: EmbeddingEncoder(),
    stype.numerical: LinearBucketEncoder(post_module=LayerNorm(channels)),
}

encoder = StypeWiseFeatureEncoder(
    out_channels=channels,
    col_stats=col_stats,
    col_names_dict=col_names_dict,
    stype_encoder_dict=stype_encoder_dict,
)

# 3) Encode. This encoder is never trained in this notebook, so its output is used as
# fixed node features. Running it without autograd keeps player_emb (later g.x) free of
# an attached graph. Otherwise a second loss.backward() in a GNN training loop would fail,
# because the encoder's part of the graph is freed after the first backward.
with torch.no_grad():
    x, _meta = encoder(tf_players)  # x: [batch, num_cols, channels]

# Simple average pooling over columns -> one vector per player row. We can get fancier later on.
player_emb = x.mean(dim=1)

Now that we have our initial player embeddings, we can load the match edges and build our graph.

In [49]:
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GATConv

In [51]:
# Match edges: winner_idx / loser_idx are node indices into player_emb (which becomes g.x).
edges = pd.read_csv("../data/edges.csv")

src = torch.from_numpy(edges["winner_idx"].to_numpy()).long()
dst = torch.from_numpy(edges["loser_idx"].to_numpy()).long()

# Every edge endpoint must be a valid row of the node feature matrix. Fail fast here with a
# readable message instead of an opaque indexing error inside the GNN.
num_players = player_emb.size(0)
if src.numel() > 0:
    endpoints = torch.cat([src, dst])
    lo, hi = int(endpoints.min()), int(endpoints.max())
    assert lo >= 0 and hi < num_players, (
        f"edge endpoints must lie in [0, {num_players}), got [{lo}, {hi}]"
    )

# original edge_attr: surface one-hot + z-scored days
num_surfaces = 3  # surface codes must lie in [0, num_surfaces); F.one_hot raises otherwise
surface = torch.from_numpy(edges["surface"].to_numpy()).long()
surface_oh = F.one_hot(surface, num_classes=num_surfaces).float()   # [E, 3]
days = torch.from_numpy(edges["days_ago"].to_numpy()).float().unsqueeze(1)
# torch's default (unbiased) std is NaN for fewer than 2 values. Fall back to 0 so a tiny
# edge list standardises to zeros instead of filling edge_attr with NaNs.
days_std = days.std() if days.size(0) > 1 else days.new_zeros(())
days = (days - days.mean()) / (days_std + 1e-6)                     # [E, 1]
edge_attr = torch.cat([surface_oh, days], dim=1)                    # [E, 4]

# build reverse (loser -> winner)
src_rev, dst_rev = dst, src
edge_attr_rev = edge_attr.clone()  # same match features on reverse edge

# concatenate both directions
edge_index = torch.cat(
    [torch.stack([src, dst], dim=0),
     torch.stack([src_rev, dst_rev], dim=0)],
    dim=1,
)  # [2, 2E]

# direction/type: 0 = "won-against" (winner->loser), 1 = "lost-to" (loser->winner)
edge_type = torch.cat([
    torch.zeros(src.size(0), dtype=torch.long),
    torch.ones(src_rev.size(0), dtype=torch.long),
], dim=0)  # [2E]

# duplicate edge attributes to match 2E edges
edge_attr_bidir = torch.cat([edge_attr, edge_attr_rev], dim=0)   # [2E, 4]

g = Data(
    x=player_emb,                # [N, C] from TorchFrame encoder
    edge_index=edge_index,       # [2, 2E]
    edge_attr=edge_attr_bidir,   # [2E, 4]
    edge_type=edge_type,         # [2E]
)

In [52]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv  

class TwoWayEdgeAwareGAT(nn.Module):
    """Two-layer, edge-aware GATv2 model with separate weights per match direction.

    Edges with ``edge_type == 0`` ("won-against") and ``edge_type == 1`` ("lost-to")
    are run through independent two-layer GATv2Conv stacks. The two resulting node
    embeddings are concatenated and projected back to ``out_ch`` by a linear layer.
    Edges with any other ``edge_type`` value are ignored.

    Args:
        in_ch: Width of the input node features ``x``.
        hidden: Output width of the first GATv2Conv in each stream.
        out_ch: Output width of the second GATv2Conv and of the final projection.
        edge_dim: Width of ``edge_attr``; given to every GATv2Conv so attention can
            use the edge features.
    """

    def __init__(self, in_ch, hidden, out_ch, edge_dim):
        super().__init__()
        # Separate parameters per direction.
        # NOTE(review): with add_self_loops=False a node's own features are not part of
        # its aggregation in a stream. Nodes with no incoming edges of a given type get
        # no neighbour signal from that stream. Confirm this is intended.
        self.win1 = GATv2Conv(in_ch, hidden, edge_dim=edge_dim, add_self_loops=False)
        self.los1 = GATv2Conv(in_ch, hidden, edge_dim=edge_dim, add_self_loops=False)
        self.win2 = GATv2Conv(hidden, out_ch, edge_dim=edge_dim, add_self_loops=False)
        self.los2 = GATv2Conv(hidden, out_ch, edge_dim=edge_dim, add_self_loops=False)
        self.combine = nn.Linear(out_ch * 2, out_ch)  # concat -> linear

    def forward(self, x, edge_index, edge_attr, edge_type):
        """Embed nodes using both edge directions.

        Args:
            x: Node features, ``[N, in_ch]``.
            edge_index: ``[2, E]`` edge list.
            edge_attr: ``[E, edge_dim]`` edge features aligned with ``edge_index`` columns.
            edge_type: ``[E]`` direction label per edge (0 = won-against, 1 = lost-to).

        Returns:
            ``[N, out_ch]`` node embeddings.
        """
        # Boolean masks select each direction's edges directly (same order as nonzero()).
        win_mask = edge_type == 0
        los_mask = edge_type == 1
        ei_win, ea_win = edge_index[:, win_mask], edge_attr[win_mask]
        ei_los, ea_los = edge_index[:, los_mask], edge_attr[los_mask]

        z_win = F.relu(self.win1(x, ei_win, ea_win))
        z_los = F.relu(self.los1(x, ei_los, ea_los))
        z_win = self.win2(z_win, ei_win, ea_win)
        z_los = self.los2(z_los, ei_los, ea_los)

        # Combine the two message streams (z_win + z_los would be a cheaper alternative).
        return self.combine(torch.cat([z_win, z_los], dim=-1))

# Build the two-stream GAT with widths read off the graph itself, then run one forward pass.
edge_dim = g.edge_attr.size(1)
gnn = TwoWayEdgeAwareGAT(
    in_ch=g.x.size(1),
    hidden=128,
    out_ch=128,
    edge_dim=edge_dim,
)
z = gnn(g.x, g.edge_index, g.edge_attr, g.edge_type)  # [N, 128] player embeddings
