In [54]:
# %pip install torch torch_geometric pytorch_lightning wandb scikit-learn

In [55]:
import torch
from torch_geometric.data import Data
import pandas as pd
import wandb
from sklearn.preprocessing import MultiLabelBinarizer
import ast

# Load the data

In [111]:
df = pd.read_csv('../get_lineup_target_score/nba_with_lineup_score.csv')

In [112]:
df.size

112860

In [113]:
df.dropna(subset=['net_score'], inplace=True)

In [114]:
df.size

91808

In [86]:
df['home_lineup'] = df['home_lineup'].apply(ast.literal_eval)
df['away_lineup'] = df['away_lineup'].apply(ast.literal_eval)

# All lineups
all_lineups = df['home_lineup'].tolist() + df['away_lineup'].tolist()

# Get unique lineups and encode them as binary vectors (each node = lineup)
unique_lineups = list(set(tuple(lineup) for lineup in all_lineups))
lineup2id = {lineup: idx for idx, lineup in enumerate(unique_lineups)}

mlb = MultiLabelBinarizer()
x = torch.tensor(mlb.fit_transform(unique_lineups), dtype=torch.float)  # Node features

# Edges from lower to higher node id
edge_index = []
edge_attr = []

for _, row in df.iterrows():
    src = lineup2id[tuple(row['home_lineup'])]
    dst = lineup2id[tuple(row['away_lineup'])]
    
    # Ensure edge goes from lower to higher ID
    low, high = sorted((src, dst))
    edge_index.append([low, high])
    edge_attr.append([row['net_score']])

edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
edge_attr = torch.tensor(edge_attr, dtype=torch.float)

# Final graph
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
print(data)

Data(x=[1990, 656], edge_index=[2, 2970], edge_attr=[2970, 1])


In [115]:
def create_lineup_graph(df):
    # Convert lineups to sorted tuples for consistency
    df['home_lineup'] = df['home_lineup'].apply(lambda x: tuple(sorted(ast.literal_eval(x))))
    df['away_lineup'] = df['away_lineup'].apply(lambda x: tuple(sorted(ast.literal_eval(x))))
    
    # Create unique lineup nodes
    all_lineups = pd.concat([df['home_lineup'], df['away_lineup']]).unique()
    lineup2idx = {lineup: idx for idx, lineup in enumerate(all_lineups)}
    
    # Create empty node features (required by PyG)
    x = torch.ones(len(all_lineups), 1)  # Placeholder features
    
    # Create directed edges with score-based direction
    edge_index = []
    edge_attr = []
    
    for _, row in df.iterrows():
        home = row['home_lineup']
        away = row['away_lineup']
        home_idx = lineup2idx[home]
        away_idx = lineup2idx[away]
        
        # Determine edge direction based on normalized scores
        if row['normalized_home_score'] > row['normalized_away_score']:
            src, dst = home_idx, away_idx
        else:
            src, dst = away_idx, home_idx
            
        edge_index.append([src, dst])
        edge_attr.append(abs(row['net_score']))
    
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.float).unsqueeze(1)
    
    return Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        num_nodes=len(all_lineups))

In [116]:
data = create_lineup_graph(df)

In [117]:
data

Data(x=[1973, 1], edge_index=[2, 2416], edge_attr=[2416, 1], num_nodes=1973)

# GNN Model

In [118]:
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.nn import GINEConv
from torch_geometric.loader import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

In [119]:
class LineupGINE(pl.LightningModule):
    def __init__(self, input_dim, edge_feature_dim, hidden_dim=32, lr=0.001, num_conv_layers=2, num_linear_layers=2, dropout=0.5):
        super(LineupGINE, self).__init__()
        self.save_hyperparameters()
        self.lr = lr

        # Convolutional layers
        self.conv_layers = nn.ModuleList()
        for i in range(num_conv_layers):
            in_channels = input_dim if i == 0 else hidden_dim
            mlp = nn.Sequential(
                nn.Linear(in_channels, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            )
            self.conv_layers.append(
                GINEConv(mlp, edge_dim=edge_feature_dim)
            )

        self.dropout = nn.Dropout(dropout)
            
        # Decoder linear layers
        self.decoder = self._build_decoder(hidden_dim * 2, num_linear_layers)

    def _build_decoder(self, decoder_input_dim, num_layers):
        layers = []
        in_dim = decoder_input_dim
        
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(in_dim, self.hparams.hidden_dim))
            layers.append(nn.ReLU())
            in_dim = self.hparams.hidden_dim
            
        layers.append(nn.Linear(in_dim, 1))
        return nn.Sequential(*layers)

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        for conv in self.conv_layers:
            x = conv(x, edge_index, edge_attr)
            x = F.relu(x)
            x = self.dropout(x)
        
        return x
    
    def training_step(self, batch, batch_idx):
        embeddings = self(batch)

        src, tgt = batch.edge_index
        src_embeddings = embeddings[src]
        tgt_embeddings = embeddings[tgt]
        edge_feature_input = torch.cat([src_embeddings, tgt_embeddings], dim=1)

        pred_scores = self.decoder(edge_feature_input)

        print(f"Predicted scores: {pred_scores}")
        print(f"Actual scores: {batch.edge_attr}")

        loss = F.mse_loss(pred_scores, batch.edge_attr)
        self.log('train_loss', loss, prog_bar=True, logger=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        embeddings = self(batch)

        src, tgt = batch.edge_index
        src_embeddings = embeddings[src]
        tgt_embeddings = embeddings[tgt]
        edge_feature_input = torch.cat([src_embeddings, tgt_embeddings], dim=1)
        pred_scores = self.decoder(edge_feature_input)
        loss = F.mse_loss(pred_scores, batch.edge_attr)
        self.log('val_loss', loss, prog_bar=True, logger=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        embeddings = self(batch)

        src, tgt = batch.edge_index
        src_embeddings = embeddings[src]
        tgt_embeddings = embeddings[tgt]
        edge_feature_input = torch.cat([src_embeddings, tgt_embeddings], dim=1)
        pred_scores = self.decoder(edge_feature_input)
        loss = F.mse_loss(pred_scores, batch.edge_attr)
        self.log('test_loss', loss, prog_bar=True, logger=True)
        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


In [120]:
batch_size = 32
train_data, val_data, test_data = RandomLinkSplit(num_val=0.1, num_test=0.1)(data)
train_loader = DataLoader([train_data], batch_size=batch_size, shuffle=True)
val_loader = DataLoader([val_data], batch_size=batch_size, shuffle=False)
test_loader = DataLoader([test_data], batch_size=batch_size, shuffle=False)

In [121]:
sweep_config = {
    'method': 'bayes',  # bayes, grid, or random
    'metric': {
        'name': 'val_loss',
        'goal': 'minimize'   
    },
    'parameters': {
        'hidden_dim': {
            'values': [32, 64, 128]
        },
        'num_conv_layers': {
            'values': [2, 3, 4]
        },
        'num_linear_layers': {
            'values': [1, 2, 3]
        },
        'lr': {
            'distribution': 'log_uniform',
            'min': 1e-4,
            'max': 1e-2
        },
        'dropout': {
            'values': [0.0, 0.2, 0.4]
        }
    }
}

In [122]:
def train_sweep():
    with wandb.init() as run:
        config = wandb.config
        model = LineupGINE(
            input_dim=10,  # Number of features per lineup
            edge_feature_dim=1,  # Number of features per edge (e.g., score difference)
            hidden_dim=config.hidden_dim,
            lr=config.lr,
            num_conv_layers=config.num_conv_layers,
            num_linear_layers=config.num_linear_layers,
            dropout=config.dropout
        )

        trainer = pl.Trainer(
            max_epochs=50,
            logger=pl.loggers.WandbLogger(),
            callbacks=[
                pl.callbacks.EarlyStopping(monitor='val_loss', patience=10),
                pl.callbacks.ModelCheckpoint(monitor='val_loss')
            ]
        )
        trainer.fit(model, train_loader, val_loader)
        trainer.test(model, test_loader)
        wandb.finish()

In [123]:
model = LineupGINE(
    input_dim=10,  # Number of features per node
    edge_feature_dim=1,  # Number of features per edge
    hidden_dim=32,
    lr=0.001,
    num_conv_layers=2,
    num_linear_layers=2
)

trainer = pl.Trainer(
    max_epochs=10
)

trainer.fit(model, train_loader, val_loader)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name        | Type       | Params | Mode 
---------------------------------------------------
0 | conv_layers | ModuleList | 3.6 K  | train
1 | dropout     | Dropout    | 0      | train
2 | decoder     | Sequential | 2.1 K  | train
---------------------------------------------------
5.7 K     Trainable params
0         Non-trainable params
5.7 K     Total params
0.023     Total estimated model params size (MB)
20        Modules in train mode
0         Modules in eval mode


                                                                           

c:\Users\rokaa\egyetem\basketball_lineup_analysis\.venv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
c:\Users\rokaa\egyetem\basketball_lineup_analysis\.venv\lib\site-packages\pytorch_lightning\utilities\data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 1973. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
c:\Users\rokaa\egyetem\basketball_lineup_analysis\.venv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
c:\Users\rokaa\egyetem\basketball_lineup_analysi

Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s] Predicted scores: tensor([[ 0.1084],
        [-0.0343],
        [ 0.0463],
        ...,
        [ 0.0119],
        [-0.0467],
        [ 0.0262]], grad_fn=<AddmmBackward0>)
Actual scores: tensor([[0.0000],
        [0.0319],
        [0.1154],
        ...,
        [0.1500],
        [0.0786],
        [0.0036]])
Epoch 1:   0%|          | 0/1 [00:00<?, ?it/s, v_num=18, train_loss=nan.0, val_loss=nan.0]        Predicted scores: tensor([[nan],
        [nan],
        [nan],
        ...,
        [nan],
        [nan],
        [nan]], grad_fn=<AddmmBackward0>)
Actual scores: tensor([[0.0000],
        [0.0319],
        [0.1154],
        ...,
        [0.1500],
        [0.0786],
        [0.0036]])
Epoch 2:   0%|          | 0/1 [00:00<?, ?it/s, v_num=18, train_loss=nan.0, val_loss=nan.0]        Predicted scores: tensor([[nan],
        [nan],
        [nan],
        ...,
        [nan],
        [nan],
        [nan]], grad_fn=<AddmmBackward0>)
Actual scores: 

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 1/1 [00:00<00:00,  7.74it/s, v_num=18, train_loss=nan.0, val_loss=nan.0]


In [124]:
trainer.test(model, test_loader)

c:\Users\rokaa\egyetem\basketball_lineup_analysis\.venv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 66.61it/s] 
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss                   nan
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': nan}]

In [None]:
sweep_id = wandb.sweep(sweep_config, project="lineup_gine_sweep")
wandb.agent(sweep_id, train_sweep, count=10)