In [34]:
# %pip install torch torch_geometric pytorch_lightning

In [35]:
import torch
from torch_geometric.data import Data
import pandas as pd

# Load the data

In [36]:
df = pd.read_csv('../get_lineup_target_score/nba_with_lineup_score.csv')

# Step 1: Map each unique lineup to an integer index (node ID)

In [37]:
lineups = pd.unique(df[['home_lineup', 'away_lineup']].values.ravel())
lineup_to_id = {lineup: i for i, lineup in enumerate(lineups)}

# Step 2: Build edge list with direction based on lineup_score

In [38]:
edge_index = []

for _, row in df.iterrows():
    home = row['home_lineup']
    away = row['away_lineup']
    home_score = row['lineup_score']

    # Estimate away score if not directly available
    # For now, let's assume both teams get the same formula
    # You may want to calculate a true 'away_lineup_score' in your preprocessing
    away_score = row.get('away_lineup_score', 0)

    if home_score > away_score:
        edge_index.append([lineup_to_id[away], lineup_to_id[home]])
    elif away_score > home_score:
        edge_index.append([lineup_to_id[home], lineup_to_id[away]])

# Convert to tensor (shape: [2, num_edges])
edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()

In [39]:
# Create an edge feature tensor filled with zeros
edge_attr = torch.zeros((edge_index.size(1), 1))  # Assuming 1 feature per edge

# Step 3: Create node features

In [40]:
# Placeholder: use a simple zero vector for each lineup
# Replace with actual stats (from nba_api) per lineup
num_nodes = len(lineup_to_id)
x = torch.zeros((num_nodes, 10))  # e.g., 10 features per lineup

# Step 4: Create PyG Data object

In [41]:
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

print(data)

Data(x=[43136, 10], edge_index=[2, 46035], edge_attr=[46035, 1])


# GNN Model

In [42]:
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.nn import GINEConv
from torch_geometric.loader import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

In [52]:
class LineupGINE(pl.LightningModule):
    def __init__(self, input_dim, edge_feature_dim, hidden_dim=32, output_dim=16, lr=0.001):
        super(LineupGINE, self).__init__()
        self.save_hyperparameters()
        self.lr = lr

        self.gine1 = GINEConv(
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            ),
            edge_dim=edge_feature_dim,
        )
        self.gine2 = GINEConv(
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, output_dim),
            ),
            edge_dim=edge_feature_dim,
        )
        self.decoder = nn.Sequential(
            nn.Linear(2 * output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        x = self.gine1(x, edge_index, edge_attr)
        x = F.relu(x)
        x = self.gine2(x, edge_index, edge_attr)
        return x
    
    def training_step(self, batch, batch_idx):
        embeddings = self(batch)

        src, tgt = batch.edge_index
        src_embeddings = embeddings[src]
        tgt_embeddings = embeddings[tgt]
        edge_feature_input = torch.cat([src_embeddings, tgt_embeddings], dim=1)
        pred_scores = self.decoder(edge_feature_input)
        loss = F.mse_loss(pred_scores, batch.edge_attr)
        self.log('train_loss', loss, prog_bar=True, logger=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        embeddings = self(batch)

        src, tgt = batch.edge_index
        src_embeddings = embeddings[src]
        tgt_embeddings = embeddings[tgt]
        edge_feature_input = torch.cat([src_embeddings, tgt_embeddings], dim=1)
        pred_scores = self.decoder(edge_feature_input)
        loss = F.mse_loss(pred_scores, batch.edge_attr)
        self.log('val_loss', loss, prog_bar=True, logger=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        embeddings = self(batch)

        src, tgt = batch.edge_index
        src_embeddings = embeddings[src]
        tgt_embeddings = embeddings[tgt]
        edge_feature_input = torch.cat([src_embeddings, tgt_embeddings], dim=1)
        pred_scores = self.decoder(edge_feature_input)
        loss = F.mse_loss(pred_scores, batch.edge_attr)
        self.log('test_loss', loss, prog_bar=True, logger=True)
        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


In [53]:
# Get the total number of edges
num_edges = edge_index.size(1)

# Define split ratios
train_ratio = 0.8
val_ratio = 0.1
test_ratio = 0.1

# Calculate the number of edges for each split
num_train = int(num_edges * train_ratio)
num_val = int(num_edges * val_ratio)
num_test = num_edges - num_train - num_val

# Shuffle the edges
perm = torch.randperm(num_edges)

# Split the edge indices and edge attributes
train_edge_index = edge_index[:, perm[:num_train]]
val_edge_index = edge_index[:, perm[num_train:num_train + num_val]]
test_edge_index = edge_index[:, perm[num_train + num_val:]]

train_edge_attr = edge_attr[perm[:num_train]]
val_edge_attr = edge_attr[perm[num_train:num_train + num_val]]
test_edge_attr = edge_attr[perm[num_train + num_val:]]

# Create train, val, and test Data objects
train_data = Data(x=x, edge_index=train_edge_index, edge_attr=train_edge_attr)
val_data = Data(x=x, edge_index=val_edge_index, edge_attr=val_edge_attr)
test_data = Data(x=x, edge_index=test_edge_index, edge_attr=test_edge_attr)

In [54]:
batch_size = 32
train_loader = DataLoader([train_data], batch_size=batch_size, shuffle=True)
val_loader = DataLoader([val_data], batch_size=batch_size, shuffle=False)
test_loader = DataLoader([test_data], batch_size=batch_size, shuffle=False)

In [55]:
model = LineupGINE(
    input_dim=10,  # Number of features per node
    edge_feature_dim=1,  # Number of features per edge
    hidden_dim=32,
    output_dim=1,
    lr=0.001
)

trainer = pl.Trainer(
    max_epochs=10
)

trainer.fit(model, train_loader, val_loader)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name    | Type       | Params | Mode 
-----------------------------------------------
0 | gine1   | GINEConv   | 1.4 K  | train
1 | gine2   | GINEConv   | 1.2 K  | train
2 | decoder | Sequential | 129    | train
-----------------------------------------------
2.7 K     Trainable params
0         Non-trainable params
2.7 K     Total params
0.011     Total estimated model params size (MB)
18        Modules in train mode
0         Modules in eval mode


                                                                           

c:\Users\rokaa\egyetem\basketball_lineup_analysis\.venv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
c:\Users\rokaa\egyetem\basketball_lineup_analysis\.venv\lib\site-packages\pytorch_lightning\utilities\data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 43136. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Epoch 9: 100%|██████████| 1/1 [00:00<00:00,  7.75it/s, v_num=6, train_loss=0.0201, val_loss=0.00715]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 1/1 [00:00<00:00,  6.80it/s, v_num=6, train_loss=0.0201, val_loss=0.00715]


In [57]:
trainer.test(model, test_loader)

Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 26.31it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss          0.007198105100542307
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.007198105100542307}]