# Run all Tiny CNNs

This file runs all CNNs that are 2d, with and without context features

In [16]:
# 1. Load standard libraries FIRST
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import sys
import importlib
import inspect
import os

# 2. Verify standard libraries are healthy
print(f"Pandas version: {pd.__version__}")

# --- 1. Path Setup ---
# Try to locate the repository root by searching upward for a 'src' directory (or .git)
def find_repo_root(start_path=None, marker_dirs=('src', '.git')):
    p = os.path.abspath(start_path or os.getcwd())
    while True:
        if any(os.path.isdir(os.path.join(p, m)) for m in marker_dirs):
            return p
        parent = os.path.dirname(p)
        if parent == p:
            return None
        p = parent

repo_root = find_repo_root()
# Fallback to previous hardcoded path working on nuvolos
if repo_root is None:
    repo_root = "/files/pixlball"

if repo_root not in sys.path:
    sys.path.insert(0, repo_root)
print(f"Using repo_root: {repo_root}")

import src.data as data
import src.model as model
import src.train as train
import src.config as config
import src.dataset as dataset
import src.evaluate as evaluate
import src.utils as utils

from src.config import DEVICE 


# 4. Force a clean reload of your specific logic
importlib.reload(data)
importlib.reload(train)
importlib.reload(evaluate)
importlib.reload(model)
importlib.reload(dataset)
importlib.reload(utils)


# 5. THE SMOKE TEST
print("Signature check:", inspect.signature(data.prepare_nn_dataset))

Pandas version: 2.3.3
Using repo_root: c:\Users\jonas\Desktop\repos\pixlball
Signature check: (events_df, nn_layers_df, target_cols=['nn_target'], id_col='id', context_cols=False, temporal_context=True, keep_context_ids=False)


In [2]:
data_events = pd.read_parquet(os.path.join(repo_root, "data", "events_data.parquet"), engine="fastparquet")
data_360 = pd.read_parquet(os.path.join(repo_root, "data", "sb360_data.parquet"), engine="fastparquet")

In [3]:
df_with_targets = data.event_data_loader(data_events)
df_with_targets = data.add_ball_trajectory_features(df_with_targets)

16316 events.
counts of each outcome nn_target
Keep Possession    58092
Lose Possession    26565
Shot                4638
Name: count, dtype: int64


## Prepare 360 Data

In [4]:
df_360 = data.assign_grid_cells(data_360)
nn_final = data.aggregate_nn_layers_vectorized(df_360)

## Finalize Df

In [5]:
nn_dataset = data.prepare_nn_dataset(df_with_targets, nn_final, target_cols=['nn_target', 'goal_flag'], context_cols = True, keep_context_ids = True ) # adjust cols depending on model
nn_dataset = data.add_context_cols(nn_dataset)
nn_dataset = data.add_target_as_int(nn_dataset)
nn_dataset, vector_names = data.add_ball_coordinates(nn_dataset)

In [6]:
merged_dataset = pd.merge(nn_dataset, data_events, on = 'id', how = 'left')

In [7]:
merged_dataset

Unnamed: 0,id,ball_layer,teammates_layer,opponents_layer,nn_target,goal_flag,ball_trajectory_vector,match_id_x,possession_x,under_pressure_x,...,ball_recovery_offensive,dribble_no_touch,block_save_block,goalkeeper_penalty_saved_to_post,goalkeeper_shot_saved_to_post,shot_open_goal,shot_saved_to_post,shot_redirect,goalkeeper_lost_out,shot_follows_dribble
0,8b621ae4-ea81-415c-af41-9669db9bdd93,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Keep Possession,0,"[61.0, 40.1, 61.0, 40.1, 61.0, 40.1, 61.0, 40.1]",4020846,2,0.0,...,,,,,,,,,,
1,4706efbe-767c-45aa-9351-09528a77d135,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Keep Possession,0,"[26.4, 43.3, 61.0, 40.1, 26.4, 43.3, 26.4, 43.3]",4020846,2,0.0,...,,,,,,,,,,
2,084b9a88-4efa-4947-b94d-b89face472be,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Keep Possession,0,"[26.4, 43.3, 26.4, 43.3, 61.0, 40.1, 26.4, 43.3]",4020846,2,1.0,...,,,,,,,,,,
3,27fa7d4d-d637-4487-98e2-5c078ad600c7,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Keep Possession,0,"[28.5, 43.8, 26.4, 43.3, 26.4, 43.3, 61.0, 40.1]",4020846,2,1.0,...,,,,,,,,,,
4,764d437f-f799-4489-a38f-69fbb219a6fa,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Keep Possession,0,"[83.6, 59.0, 28.5, 43.8, 26.4, 43.3, 26.4, 43.3]",4020846,2,0.0,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
77907,4ae0db6f-063c-4947-a514-f24c8af42a12,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Lose Possession,0,"[70.8, 63.3, 70.8, 63.3, 55.3, 44.3, 54.5, 43.3]",3998841,182,1.0,...,,,,,,,,,,
77908,690d8b4c-7fa4-461b-a24d-ebe97fb2c9ae,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Lose Possession,0,"[50.3, 14.4, 70.8, 63.3, 70.8, 63.3, 55.3, 44.3]",3998841,182,0.0,...,,,,,,,,,,
77909,86456b25-4108-4725-b2fb-1bb263cf55c9,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Lose Possession,0,"[69.8, 65.7, 50.3, 14.4, 70.8, 63.3, 70.8, 63.3]",3998841,182,1.0,...,,,,,,,,,,
77910,1415a3dd-034f-46c4-b070-11180a26ed38,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Lose Possession,0,"[7.0, 46.2, 69.6, 67.3, 69.8, 65.7, 50.3, 14.4]",3998841,184,0.0,...,,,,,,,,,,


## Weights

In [8]:
layer_columns = ["ball_layer", "teammates_layer", "opponents_layer"]
class_weights_event, goal_pos_weight = utils.get_multitask_loss_weights(nn_dataset, DEVICE)

print(f"Goal Positive Weight (0/1 ratio): {goal_pos_weight.item():.2f}")

Goal Positive Weight (0/1 ratio): 5.00


# Prepare Datasets for CNNs

In [9]:
import numpy as np
from sklearn.model_selection import train_test_split


# 1. Define your split parameters
VALIDATION_SIZE = 0.20
RANDOM_SEED = 42
layer_columns = ["ball_layer", "teammates_layer", "opponents_layer"]

# 2. Split the entire DataFrame first
# This keeps features, event targets, and goal flags bundled together
train_df, val_df = train_test_split(
    nn_dataset, 
    test_size=VALIDATION_SIZE, 
    random_state=RANDOM_SEED, 
    stratify=nn_dataset['nn_target_int']
)

## Run the Baseline Model

In [28]:

# Training Dataset extraction - Pass only the values in the correct order
train_dataset = dataset.PitchDatasetMultiTask(
    train_df[layer_columns],             # This maps to the 1st argument (features)
    train_df['nn_target_int'].values,    # This maps to the 2nd argument (events)
    train_df['goal_flag'].values         # This maps to the 3rd argument (goals)
)

# Validation Dataset extraction
validation_dataset = dataset.PitchDatasetMultiTask(
    val_df[layer_columns], 
    val_df['nn_target_int'].values, 
    val_df['goal_flag'].values
)

print(f"Total training samples: {len(train_dataset)}")
print(f"Total validation samples: {len(validation_dataset)}")
print(f"Goal Positive Weight (0/1 ratio): {goal_pos_weight.item():.2f}")


Total training samples: 62329
Total validation samples: 15583
Goal Positive Weight (0/1 ratio): 5.00


In [29]:

# ------------------------------------
# 4. Train model (Using the dedicated base function)
# ------------------------------------
print("Starting training for Static CNN Baseline...")
baseline_model = train.train_model_base_threat(
    dataset=train_dataset, 
    event_class_weights=class_weights_event, 
    goal_pos_weight=goal_pos_weight
)
print("Training complete.")

Starting training for Static CNN Baseline...


Base CNN Threat Epoch 1: 100%|██████████| 1948/1948 [00:30<00:00, 64.17it/s, ev_loss=1.1297, loss=1.1297] 

Training complete.





In [30]:
metrics = evaluate.evaluate_model_base_threat(baseline_model, validation_dataset)


--- Event Outcome Metrics ---
Event Accuracy: 0.5218507347750754
Event Balanced Accuracy: 0.5893328239053665
Event Confusion Matrix:
 [[5313 3710 1202]
 [1558 2163  783]
 [  92  106  656]]
              precision    recall  f1-score   support

           0       0.76      0.52      0.62     10225
           1       0.36      0.48      0.41      4504
           2       0.25      0.77      0.38       854

    accuracy                           0.52     15583
   macro avg       0.46      0.59      0.47     15583
weighted avg       0.62      0.52      0.55     15583


--- Goal Prediction (xG) Metrics ---
Goal Accuracy: 0.8711943793911007
Goal Balanced Accuracy: 0.5
Goal AUC-ROC Score: 0.6419354838709677
Goal Confusion Matrix:
 [[744   0]
 [110   0]]
              precision    recall  f1-score   support

         0.0       0.87      1.00      0.93       744
         1.0       0.00      0.00      0.00       110

    accuracy                           0.87       854
   macro avg       0.44  

## Run the Context Model

In [31]:
context_features = ['under_pressure', 'counterpress', 'dribble_nutmeg']

# Training Dataset extraction - Pass only the values in the correct order
train_dataset_context = dataset.ContextPitchDatasetMultiTask(
    train_df[layer_columns],             # This maps to the 1st argument (features)
    train_df['nn_target_int'].values,    # This maps to the 2nd argument (events)
    train_df['goal_flag'].values,
    train_df[context_features]        # This maps to the 3rd argument (goals)
)

# Validation Dataset extraction
validation_dataset_context = dataset.ContextPitchDatasetMultiTask(
    val_df[layer_columns], 
    val_df['nn_target_int'].values, 
    val_df['goal_flag'].values,
    val_df[context_features]  
)

print(f"Total training samples: {len(train_dataset_context)}")
print(f"Total validation samples: {len(validation_dataset_context)}")

Total training samples: 62329
Total validation samples: 15583


In [32]:
# Assuming event_class_weights and goal_pos_weight are defined from previous cells
NUM_CONTEXT_FEATURES = 3 

print("Starting training for Contextual CNN Baseline...")

# Modified the Function in Loss to take correct loss function -> needs to be changed for baseline model again

context_baseline_model = train.train_model_context_threat(
    dataset=train_dataset_context, 
    event_class_weights=class_weights_event, # Use your calculated weights
    goal_pos_weight=goal_pos_weight,         # Use your calculated pos_weight
    num_context_features=NUM_CONTEXT_FEATURES
)

print("\nContextual CNN Training complete.")

Starting training for Contextual CNN Baseline...


Context CNN Epoch 1: 100%|██████████| 1948/1948 [00:35<00:00, 54.13it/s, ev_loss=0.1727, loss=0.1727, sh_loss=0.0000] 


Contextual CNN Training complete.





In [33]:
# Assuming evaluate_model_context is imported and available

print("\nEvaluating Contextual CNN Model...")

metrics = evaluate.evaluate_model_context_threat(
    model=context_baseline_model, 
    dataset=validation_dataset_context # Evaluate on the contextual dataset
)



Evaluating Contextual CNN Model...

--- Event Outcome Metrics ---
Event Accuracy: 0.27164217416415326
Event Balanced Accuracy: 0.5245733517470116
Event Confusion Matrix:
 [[   0 8408 1817]
 [   0 3565  939]
 [   0  186  668]]
              precision    recall  f1-score   support

           0       0.00      0.00      0.00     10225
           1       0.29      0.79      0.43      4504
           2       0.20      0.78      0.31       854

    accuracy                           0.27     15583
   macro avg       0.16      0.52      0.25     15583
weighted avg       0.10      0.27      0.14     15583


--- Goal Prediction (xG) Metrics ---
Goal Accuracy: 0.8711943793911007
Goal Balanced Accuracy: 0.5
Goal AUC-ROC Score: 0.6087732160312805
Goal Confusion Matrix:
 [[744   0]
 [110   0]]
              precision    recall  f1-score   support

         0.0       0.87      1.00      0.93       744
         1.0       0.00      0.00      0.00       110

    accuracy                           0.8

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


## Run the Ball Vector Context Model

In [34]:
# -------------------------------------------------------------
# 3. Extract the arrays and Instantiate the Datasets (FIXED)
# -------------------------------------------------------------
context_features = ['under_pressure', 'counterpress', 'dribble_nutmeg']

# Training Dataset extraction - Pass only the values in the correct order
train_dataset_temporal_context = dataset.ContextBallVectorPitchDatasetMultiTask(
    train_df[layer_columns],             # This maps to the 1st argument (features)
    train_df['nn_target_int'].values,    # This maps to the 2nd argument (events)
    train_df['goal_flag'].values,
    train_df[vector_names]        # This maps to the 3rd argument (goals)
)

# Validation Dataset extraction
validation_dataset_temporal_context = dataset.ContextBallVectorPitchDatasetMultiTask(
    val_df[layer_columns], 
    val_df['nn_target_int'].values, 
    val_df['goal_flag'].values,
    val_df[vector_names]  
)

print(f"Total training samples: {len(train_dataset_temporal_context)}")
print(f"Total validation samples: {len(validation_dataset_temporal_context)}")

Total training samples: 62329
Total validation samples: 15583


In [35]:
# Assuming event_class_weights and goal_pos_weight are defined from previous cells
NUM_CONTEXT_FEATURES = 8 

print("Starting training for Contextual CNN Baseline...")

# Modified the Function in Loss to take correct loss function -> needs to be changed for baseline model again

context_temporal_model = train.train_model_context_threat(
    dataset=train_dataset_temporal_context, 
    event_class_weights=class_weights_event, # Use your calculated weights
    goal_pos_weight=goal_pos_weight,         # Use your calculated pos_weight
    num_context_features=NUM_CONTEXT_FEATURES
)

print("\nContextual CNN Training complete.")

Starting training for Contextual CNN Baseline...


Context CNN Epoch 1: 100%|██████████| 1948/1948 [00:38<00:00, 50.08it/s, ev_loss=0.1783, loss=0.1783, sh_loss=0.0000]  


Contextual CNN Training complete.





In [36]:
# Assuming evaluate_model_context is imported and available

print("\nEvaluating Contextual CNN Model...")

metrics = evaluate.evaluate_model_context_threat(
    model=context_temporal_model, 
    dataset=validation_dataset_temporal_context # Evaluate on the contextual dataset
)


Evaluating Contextual CNN Model...

--- Event Outcome Metrics ---
Event Accuracy: 0.28864788551626774
Event Balanced Accuracy: 0.5293049582291408
Event Confusion Matrix:
 [[ 367 8179 1679]
 [  89 3462  953]
 [   4  181  669]]
              precision    recall  f1-score   support

           0       0.80      0.04      0.07     10225
           1       0.29      0.77      0.42      4504
           2       0.20      0.78      0.32       854

    accuracy                           0.29     15583
   macro avg       0.43      0.53      0.27     15583
weighted avg       0.62      0.29      0.19     15583


--- Goal Prediction (xG) Metrics ---
Goal Accuracy: 0.7985948477751756
Goal Balanced Accuracy: 0.5667888563049853
Goal AUC-ROC Score: 0.5882697947214076
Goal Confusion Matrix:
 [[654  90]
 [ 82  28]]
              precision    recall  f1-score   support

         0.0       0.89      0.88      0.88       744
         1.0       0.24      0.25      0.25       110

    accuracy               

# Run 3D CNN

In [10]:
voxels_list = data.generate_temporal_voxels(nn_dataset, lookback=3)

In [11]:
# 1. Generate the 4D Voxels (Channels, Time, Height, Width)
# Lookback 3 = 4 frames total (t, t-1, t-2, t-3)
voxels_list = data.generate_temporal_voxels(nn_dataset, lookback=3)

# 2. Add as a column
nn_dataset['temporal_voxel'] = voxels_list

# 3. Create your Train/Test Split
# We stratify on the event target to keep class balance
train_df, test_df = train_test_split(
    nn_dataset, 
    test_size=0.2, 
    random_state=42, 
    stratify=nn_dataset['nn_target_int']
)

In [12]:
train_dataset_3d = dataset.VoxelPitchDataset(train_df)
test_dataset_3d = dataset.VoxelPitchDataset(test_df)

# SMOKE TEST: Check the shape of the first item
voxel, event, goal = train_dataset_3d[0]
print(f"Voxel Shape: {voxel.shape}") # MUST be [3, 4, 12, 8]

Voxel Shape: torch.Size([3, 4, 12, 8])


In [17]:
model_3d = train.train_3d_model(
    dataset=train_dataset_3d,
    event_class_weights=class_weights_event, # Ensure these are pre-calculated tensors
    goal_pos_weight=goal_pos_weight,
    epochs=1,
    batch_size=32 
)

3D CNN Epoch 1: 100%|██████████| 1948/1948 [00:48<00:00, 40.14it/s, ev_loss=0.5032, loss=1.5441]


In [18]:
y_ev_true, y_ev_pred, y_goal_true, y_goal_probs = evaluate.get_3d_predictions(model_3d, test_dataset_3d, DEVICE)

In [19]:
from sklearn.metrics import confusion_matrix, classification_report, balanced_accuracy_score, roc_auc_score

print("--- 3D CNN EVENT CONFUSION MATRIX ---")
ev_cm = confusion_matrix(y_ev_true, y_ev_pred)
print(ev_cm)

print("\n--- CLASSIFICATION REPORT ---")
target_names = ['Keep Possession', 'Lose Possession', 'Shot']
print(classification_report(y_ev_true, y_ev_pred, target_names=target_names))

print(f"Balanced Accuracy: {balanced_accuracy_score(y_ev_true, y_ev_pred):.4f}")

--- 3D CNN EVENT CONFUSION MATRIX ---
[[    0 10225     0]
 [    0  4504     0]
 [    0   854     0]]

--- CLASSIFICATION REPORT ---
                 precision    recall  f1-score   support

Keep Possession       0.00      0.00      0.00     10225
Lose Possession       0.29      1.00      0.45      4504
           Shot       0.00      0.00      0.00       854

       accuracy                           0.29     15583
      macro avg       0.10      0.33      0.15     15583
   weighted avg       0.08      0.29      0.13     15583

Balanced Accuracy: 0.3333


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


In [20]:
print("--- 3D CNN GOAL CONFUSION MATRIX ---")
# convert continuous probabilities to binary predictions (threshold = 0.5)
y_goal_pred = (y_goal_probs >= 0.5).astype(int)
g_cm = confusion_matrix(y_goal_true.astype(int), y_goal_pred)
print(g_cm)

--- 3D CNN GOAL CONFUSION MATRIX ---
[[15473     0]
 [  110     0]]


In [21]:
auc_score = roc_auc_score(y_goal_true, y_goal_probs)
print(f"3D CNN Goal AUC-ROC: {auc_score:.4f}")

3D CNN Goal AUC-ROC: 0.4955
