# Run all Tiny CNNs

This file runs all CNNs that are 2d, with and without context features

In [1]:
# 1. Load standard libraries FIRST
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import sys
import importlib
import inspect
import os

# 2. Verify standard libraries are healthy
print(f"Pandas version: {pd.__version__}")

# --- 1. Path Setup ---
# Try to locate the repository root by searching upward for a 'src' directory (or .git)
def find_repo_root(start_path=None, marker_dirs=('src', '.git')):
    p = os.path.abspath(start_path or os.getcwd())
    while True:
        if any(os.path.isdir(os.path.join(p, m)) for m in marker_dirs):
            return p
        parent = os.path.dirname(p)
        if parent == p:
            return None
        p = parent

repo_root = find_repo_root()
# Fallback to previous hardcoded path working on nuvolos
if repo_root is None:
    repo_root = "/files/pixlball"

if repo_root not in sys.path:
    sys.path.insert(0, repo_root)
print(f"Using repo_root: {repo_root}")

import src.data as data
import src.model as model
import src.train as train
import src.config as config
import src.dataset as dataset
import src.evaluate as evaluate
import src.utils as utils

from src.config import DEVICE 


# 4. Force a clean reload of your specific logic
importlib.reload(data)
importlib.reload(train)
importlib.reload(evaluate)
importlib.reload(model)
importlib.reload(dataset)
importlib.reload(utils)


# 5. THE SMOKE TEST
print("Signature check:", inspect.signature(data.prepare_nn_dataset))

Pandas version: 2.0.3
Using repo_root: /files/pixlball
Signature check: (events_df, nn_layers_df, target_cols=['nn_target'], id_col='id', context_cols=False, temporal_context=True, keep_context_ids=False)


In [2]:
data_events = pd.read_parquet(os.path.join(repo_root, "data", "events_data.parquet"), engine="fastparquet")
data_360 = pd.read_parquet(os.path.join(repo_root, "data", "sb360_data.parquet"), engine="fastparquet")

In [3]:
df_with_targets = data.event_data_loader(data_events)
df_with_targets = data.add_ball_trajectory_features(df_with_targets)

16316 events.
counts of each outcome nn_target
Keep Possession    58092
Lose Possession    26565
Shot                4638
Name: count, dtype: int64


## Prepare 360 Data

In [4]:
df_360 = data.assign_grid_cells(data_360)
nn_final = data.aggregate_nn_layers_vectorized(df_360)

## Finalize Df

In [5]:
nn_dataset = data.prepare_nn_dataset(df_with_targets, nn_final, target_cols=['nn_target', 'goal_flag'], context_cols = True, keep_context_ids = True ) # adjust cols depending on model
nn_dataset = data.add_context_cols(nn_dataset)
nn_dataset = data.add_target_as_int(nn_dataset)
nn_dataset, vector_names = data.add_ball_coordinates(nn_dataset)

In [6]:
nn_dataset

Unnamed: 0,id,ball_layer,teammates_layer,opponents_layer,nn_target,goal_flag,ball_trajectory_vector,match_id,possession,under_pressure,...,dribble_nutmeg,nn_target_int,x1,y1,x2,y2,x3,y3,x4,y4
0,8b621ae4-ea81-415c-af41-9669db9bdd93,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Keep Possession,0,"[61.0, 40.1, 61.0, 40.1, 61.0, 40.1, 61.0, 40.1]",4020846,2,0.0,...,0.0,0,0.508333,0.50125,0.508333,0.50125,0.508333,0.50125,0.508333,0.50125
1,4706efbe-767c-45aa-9351-09528a77d135,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Keep Possession,0,"[26.4, 43.3, 61.0, 40.1, 26.4, 43.3, 26.4, 43.3]",4020846,2,0.0,...,0.0,0,0.220000,0.54125,0.508333,0.50125,0.220000,0.54125,0.220000,0.54125
2,084b9a88-4efa-4947-b94d-b89face472be,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Keep Possession,0,"[26.4, 43.3, 26.4, 43.3, 61.0, 40.1, 26.4, 43.3]",4020846,2,1.0,...,0.0,0,0.220000,0.54125,0.220000,0.54125,0.508333,0.50125,0.220000,0.54125
3,27fa7d4d-d637-4487-98e2-5c078ad600c7,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Keep Possession,0,"[28.5, 43.8, 26.4, 43.3, 26.4, 43.3, 61.0, 40.1]",4020846,2,1.0,...,0.0,0,0.237500,0.54750,0.220000,0.54125,0.220000,0.54125,0.508333,0.50125
4,764d437f-f799-4489-a38f-69fbb219a6fa,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Keep Possession,0,"[83.6, 59.0, 28.5, 43.8, 26.4, 43.3, 26.4, 43.3]",4020846,2,0.0,...,0.0,0,0.696667,0.73750,0.237500,0.54750,0.220000,0.54125,0.220000,0.54125
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
77907,4ae0db6f-063c-4947-a514-f24c8af42a12,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Lose Possession,0,"[70.8, 63.3, 70.8, 63.3, 55.3, 44.3, 54.5, 43.3]",3998841,182,1.0,...,0.0,1,0.590000,0.79125,0.590000,0.79125,0.460833,0.55375,0.454167,0.54125
77908,690d8b4c-7fa4-461b-a24d-ebe97fb2c9ae,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Lose Possession,0,"[50.3, 14.4, 70.8, 63.3, 70.8, 63.3, 55.3, 44.3]",3998841,182,0.0,...,0.0,1,0.419167,0.18000,0.590000,0.79125,0.590000,0.79125,0.460833,0.55375
77909,86456b25-4108-4725-b2fb-1bb263cf55c9,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Lose Possession,0,"[69.8, 65.7, 50.3, 14.4, 70.8, 63.3, 70.8, 63.3]",3998841,182,1.0,...,0.0,1,0.581667,0.82125,0.419167,0.18000,0.590000,0.79125,0.590000,0.79125
77910,1415a3dd-034f-46c4-b070-11180a26ed38,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Lose Possession,0,"[7.0, 46.2, 69.6, 67.3, 69.8, 65.7, 50.3, 14.4]",3998841,184,0.0,...,0.0,1,0.058333,0.57750,0.580000,0.84125,0.581667,0.82125,0.419167,0.18000


## Weights

In [7]:
layer_columns = ["ball_layer", "teammates_layer", "opponents_layer"]
class_weights_event, goal_pos_weight = utils.get_multitask_loss_weights(nn_dataset, DEVICE)

print(f"Goal Positive Weight (0/1 ratio): {goal_pos_weight.item():.2f}")

Goal Positive Weight (0/1 ratio): 5.00


# Prepare Datasets for CNNs

In [8]:
import numpy as np
from sklearn.model_selection import train_test_split


# 1. Define your split parameters
VALIDATION_SIZE = 0.20
RANDOM_SEED = 42
layer_columns = ["ball_layer", "teammates_layer", "opponents_layer"]

# 2. Split the entire DataFrame first
# This keeps features, event targets, and goal flags bundled together
train_df, val_df = train_test_split(
    nn_dataset, 
    test_size=VALIDATION_SIZE, 
    random_state=RANDOM_SEED, 
    stratify=nn_dataset['nn_target_int']
)

## Run the Baseline Model

In [9]:

# Training Dataset extraction - Pass only the values in the correct order
train_dataset = dataset.PitchDatasetMultiTask(
    train_df[layer_columns],             # This maps to the 1st argument (features)
    train_df['nn_target_int'].values,    # This maps to the 2nd argument (events)
    train_df['goal_flag'].values         # This maps to the 3rd argument (goals)
)

# Validation Dataset extraction
validation_dataset = dataset.PitchDatasetMultiTask(
    val_df[layer_columns], 
    val_df['nn_target_int'].values, 
    val_df['goal_flag'].values
)

print(f"Total training samples: {len(train_dataset)}")
print(f"Total validation samples: {len(validation_dataset)}")
print(f"Goal Positive Weight (0/1 ratio): {goal_pos_weight.item():.2f}")


Total training samples: 62329
Total validation samples: 15583
Goal Positive Weight (0/1 ratio): 5.00


In [10]:

# ------------------------------------
# 4. Train model (Using the dedicated base function)
# ------------------------------------
print("Starting training for Static CNN Baseline...")
baseline_model = train.train_model_base_threat(
    dataset=train_dataset, 
    event_class_weights=class_weights_event, 
    goal_pos_weight=goal_pos_weight
)
print("Training complete.")

Base CNN Threat Epoch 1:   1%|▏         | 28/1948 [00:00<00:06, 277.24it/s, ev_loss=1.0353, loss=1.2662]

Starting training for Static CNN Baseline...


Base CNN Threat Epoch 1: 100%|██████████| 1948/1948 [00:06<00:00, 288.17it/s, ev_loss=0.9304, loss=0.9304]
Base CNN Threat Epoch 2: 100%|██████████| 1948/1948 [00:06<00:00, 281.16it/s, ev_loss=0.4944, loss=0.8110]
Base CNN Threat Epoch 3: 100%|██████████| 1948/1948 [00:07<00:00, 270.64it/s, ev_loss=0.8479, loss=0.8479]
Base CNN Threat Epoch 4: 100%|██████████| 1948/1948 [00:06<00:00, 287.95it/s, ev_loss=0.5173, loss=0.7546]
Base CNN Threat Epoch 5: 100%|██████████| 1948/1948 [00:06<00:00, 280.84it/s, ev_loss=0.6150, loss=0.6150]
Base CNN Threat Epoch 6: 100%|██████████| 1948/1948 [00:06<00:00, 289.02it/s, ev_loss=0.6079, loss=0.8093]
Base CNN Threat Epoch 7: 100%|██████████| 1948/1948 [00:06<00:00, 289.39it/s, ev_loss=0.4902, loss=0.8407]
Base CNN Threat Epoch 8: 100%|██████████| 1948/1948 [00:06<00:00, 287.51it/s, ev_loss=0.6557, loss=0.8509]
Base CNN Threat Epoch 9: 100%|██████████| 1948/1948 [00:06<00:00, 281.54it/s, ev_loss=0.8057, loss=0.8057]
Base CNN Threat Epoch 10: 100%|██████

Training complete.





In [11]:
metrics = evaluate.evaluate_model_base_threat(baseline_model, validation_dataset)


--- Event Outcome Metrics ---
Event Accuracy: 0.5754989411538215
Event Balanced Accuracy: 0.6461207906504299
Event Confusion Matrix:
 [[5642 3714  869]
 [1310 2643  551]
 [  56  115  683]]
              precision    recall  f1-score   support

           0       0.81      0.55      0.65     10225
           1       0.41      0.59      0.48      4504
           2       0.32      0.80      0.46       854

    accuracy                           0.58     15583
   macro avg       0.51      0.65      0.53     15583
weighted avg       0.66      0.58      0.59     15583


--- Goal Prediction (xG) Metrics ---
Goal Accuracy: 0.8501170960187353
Goal Balanced Accuracy: 0.7319281524926686
Goal AUC-ROC Score: 0.7937316715542522
Goal Confusion Matrix:
 [[663  81]
 [ 47  63]]
              precision    recall  f1-score   support

         0.0       0.93      0.89      0.91       744
         1.0       0.44      0.57      0.50       110

    accuracy                           0.85       854
   macro a

## Run the Context Model

In [12]:
context_features = ['under_pressure', 'counterpress', 'dribble_nutmeg']

# Training Dataset extraction - Pass only the values in the correct order
train_dataset_context = dataset.ContextPitchDatasetMultiTask(
    train_df[layer_columns],             # This maps to the 1st argument (features)
    train_df['nn_target_int'].values,    # This maps to the 2nd argument (events)
    train_df['goal_flag'].values,
    train_df[context_features]        # This maps to the 3rd argument (goals)
)

# Validation Dataset extraction
validation_dataset_context = dataset.ContextPitchDatasetMultiTask(
    val_df[layer_columns], 
    val_df['nn_target_int'].values, 
    val_df['goal_flag'].values,
    val_df[context_features]  
)

print(f"Total training samples: {len(train_dataset_context)}")
print(f"Total validation samples: {len(validation_dataset_context)}")

Total training samples: 62329
Total validation samples: 15583


In [13]:
# Assuming event_class_weights and goal_pos_weight are defined from previous cells
NUM_CONTEXT_FEATURES = 3 

print("Starting training for Contextual CNN Baseline...")

# Modified the Function in Loss to take correct loss function -> needs to be changed for baseline model again

context_baseline_model = train.train_model_context_threat(
    dataset=train_dataset_context, 
    event_class_weights=class_weights_event, # Use your calculated weights
    goal_pos_weight=goal_pos_weight,         # Use your calculated pos_weight
    num_context_features=NUM_CONTEXT_FEATURES
)

print("\nContextual CNN Training complete.")

Context CNN Epoch 1:   1%|▏         | 29/1948 [00:00<00:06, 281.56it/s, ev_loss=1.0662, loss=1.7442, sh_loss=1.3558]

Starting training for Contextual CNN Baseline...


Context CNN Epoch 1: 100%|██████████| 1948/1948 [00:07<00:00, 259.33it/s, ev_loss=0.2388, loss=0.5261, sh_loss=0.5746]
Context CNN Epoch 2: 100%|██████████| 1948/1948 [00:07<00:00, 256.06it/s, ev_loss=0.1685, loss=0.4538, sh_loss=0.5706]
Context CNN Epoch 3: 100%|██████████| 1948/1948 [00:07<00:00, 257.23it/s, ev_loss=0.2583, loss=0.6675, sh_loss=0.8183]
Context CNN Epoch 4: 100%|██████████| 1948/1948 [00:07<00:00, 260.86it/s, ev_loss=0.2049, loss=0.3845, sh_loss=0.3592]
Context CNN Epoch 5: 100%|██████████| 1948/1948 [00:07<00:00, 258.39it/s, ev_loss=0.6288, loss=0.9339, sh_loss=0.6101] 
Context CNN Epoch 6: 100%|██████████| 1948/1948 [00:07<00:00, 257.87it/s, ev_loss=0.1364, loss=0.4177, sh_loss=0.5628] 
Context CNN Epoch 7: 100%|██████████| 1948/1948 [00:07<00:00, 259.70it/s, ev_loss=0.1262, loss=0.4124, sh_loss=0.5724] 
Context CNN Epoch 8: 100%|██████████| 1948/1948 [00:07<00:00, 261.16it/s, ev_loss=0.5675, loss=0.9270, sh_loss=0.7191] 
Context CNN Epoch 9: 100%|██████████| 1948/1


Contextual CNN Training complete.





In [14]:
# Assuming evaluate_model_context is imported and available

print("\nEvaluating Contextual CNN Model...")

metrics = evaluate.evaluate_model_context_threat(
    model=context_baseline_model, 
    dataset=validation_dataset_context # Evaluate on the contextual dataset
)



Evaluating Contextual CNN Model...

--- Event Outcome Metrics ---
Event Accuracy: 0.3361355323108516
Event Balanced Accuracy: 0.5738914305762289
Event Confusion Matrix:
 [[ 842 8067 1316]
 [ 107 3697  700]
 [   4  151  699]]
              precision    recall  f1-score   support

           0       0.88      0.08      0.15     10225
           1       0.31      0.82      0.45      4504
           2       0.26      0.82      0.39       854

    accuracy                           0.34     15583
   macro avg       0.48      0.57      0.33     15583
weighted avg       0.68      0.34      0.25     15583


--- Goal Prediction (xG) Metrics ---
Goal Accuracy: 0.8793911007025761
Goal Balanced Accuracy: 0.7758431085043989
Goal AUC-ROC Score: 0.8258431085043989
Goal Confusion Matrix:
 [[681  63]
 [ 40  70]]
              precision    recall  f1-score   support

         0.0       0.94      0.92      0.93       744
         1.0       0.53      0.64      0.58       110

    accuracy                

## Run the Ball Vector Context Model

In [15]:
# -------------------------------------------------------------
# 3. Extract the arrays and Instantiate the Datasets (FIXED)
# -------------------------------------------------------------
context_features = ['under_pressure', 'counterpress', 'dribble_nutmeg']

# Training Dataset extraction - Pass only the values in the correct order
train_dataset_temporal_context = dataset.ContextBallVectorPitchDatasetMultiTask(
    train_df[layer_columns],             # This maps to the 1st argument (features)
    train_df['nn_target_int'].values,    # This maps to the 2nd argument (events)
    train_df['goal_flag'].values,
    train_df[vector_names]        # This maps to the 3rd argument (goals)
)

# Validation Dataset extraction
validation_dataset_temporal_context = dataset.ContextBallVectorPitchDatasetMultiTask(
    val_df[layer_columns], 
    val_df['nn_target_int'].values, 
    val_df['goal_flag'].values,
    val_df[vector_names]  
)

print(f"Total training samples: {len(train_dataset_temporal_context)}")
print(f"Total validation samples: {len(validation_dataset_temporal_context)}")

Total training samples: 62329
Total validation samples: 15583


In [16]:
# Assuming event_class_weights and goal_pos_weight are defined from previous cells
NUM_CONTEXT_FEATURES = 8 

print("Starting training for Contextual CNN Baseline...")

# Modified the Function in Loss to take correct loss function -> needs to be changed for baseline model again

context_temporal_model = train.train_model_context_threat(
    dataset=train_dataset_temporal_context, 
    event_class_weights=class_weights_event, # Use your calculated weights
    goal_pos_weight=goal_pos_weight,         # Use your calculated pos_weight
    num_context_features=NUM_CONTEXT_FEATURES
)

print("\nContextual CNN Training complete.")

Context CNN Epoch 1:   1%|          | 24/1948 [00:00<00:08, 236.52it/s, ev_loss=0.4578, loss=1.5891, sh_loss=2.2627]

Starting training for Contextual CNN Baseline...


Context CNN Epoch 1: 100%|██████████| 1948/1948 [00:08<00:00, 235.45it/s, ev_loss=0.2425, loss=0.5247, sh_loss=0.5645]
Context CNN Epoch 2: 100%|██████████| 1948/1948 [00:08<00:00, 231.61it/s, ev_loss=0.5180, loss=1.2169, sh_loss=1.3979]
Context CNN Epoch 3: 100%|██████████| 1948/1948 [00:08<00:00, 233.63it/s, ev_loss=0.1588, loss=0.1588, sh_loss=0.0000]
Context CNN Epoch 4: 100%|██████████| 1948/1948 [00:08<00:00, 230.31it/s, ev_loss=0.5370, loss=0.6523, sh_loss=0.2307] 
Context CNN Epoch 5: 100%|██████████| 1948/1948 [00:08<00:00, 235.81it/s, ev_loss=0.2491, loss=0.2491, sh_loss=0.0000] 
Context CNN Epoch 6: 100%|██████████| 1948/1948 [00:08<00:00, 233.28it/s, ev_loss=0.2535, loss=0.2971, sh_loss=0.0872]
Context CNN Epoch 7: 100%|██████████| 1948/1948 [00:08<00:00, 235.45it/s, ev_loss=0.1888, loss=0.2136, sh_loss=0.0496] 
Context CNN Epoch 8: 100%|██████████| 1948/1948 [00:08<00:00, 235.58it/s, ev_loss=0.1566, loss=0.1566, sh_loss=0.0000] 
Context CNN Epoch 9: 100%|██████████| 1948/1


Contextual CNN Training complete.





In [17]:
# Assuming evaluate_model_context is imported and available

print("\nEvaluating Contextual CNN Model...")

metrics = evaluate.evaluate_model_context_threat(
    model=context_temporal_model, 
    dataset=validation_dataset_temporal_context # Evaluate on the contextual dataset
)


Evaluating Contextual CNN Model...

--- Event Outcome Metrics ---
Event Accuracy: 0.3512802412885837
Event Balanced Accuracy: 0.5801356962909571
Event Confusion Matrix:
 [[1113 7845 1267]
 [ 152 3662  690]
 [   7  148  699]]
              precision    recall  f1-score   support

           0       0.88      0.11      0.19     10225
           1       0.31      0.81      0.45      4504
           2       0.26      0.82      0.40       854

    accuracy                           0.35     15583
   macro avg       0.48      0.58      0.35     15583
weighted avg       0.68      0.35      0.28     15583


--- Goal Prediction (xG) Metrics ---
Goal Accuracy: 0.8548009367681498
Goal Balanced Accuracy: 0.7462365591397849
Goal AUC-ROC Score: 0.8013318670576736
Goal Confusion Matrix:
 [[664  80]
 [ 44  66]]
              precision    recall  f1-score   support

         0.0       0.94      0.89      0.91       744
         1.0       0.45      0.60      0.52       110

    accuracy                

# Run 3D CNN

In [18]:
voxels_list = data.generate_temporal_voxels(nn_dataset, lookback=3)

In [19]:
# 1. Generate the 4D Voxels (Channels, Time, Height, Width)
# Lookback 3 = 4 frames total (t, t-1, t-2, t-3)
voxels_list = data.generate_temporal_voxels(nn_dataset, lookback=3)

# 2. Add as a column
nn_dataset['temporal_voxel'] = voxels_list

# 3. Create your Train/Test Split
# We stratify on the event target to keep class balance
train_df, test_df = train_test_split(
    nn_dataset, 
    test_size=0.2, 
    random_state=42, 
    stratify=nn_dataset['nn_target_int']
)

In [20]:
train_dataset_3d = dataset.VoxelPitchDataset(train_df)
test_dataset_3d = dataset.VoxelPitchDataset(test_df)

# SMOKE TEST: Check the shape of the first item
voxel, event, goal = train_dataset_3d[0]
print(f"Voxel Shape: {voxel.shape}") # MUST be [3, 4, 12, 8]

Voxel Shape: torch.Size([3, 4, 12, 8])


In [21]:
model_3d = train.train_3d_model(
    dataset=train_dataset_3d,
    event_class_weights=class_weights_event, # Ensure these are pre-calculated tensors
    goal_pos_weight=goal_pos_weight,
    epochs=15,
    batch_size=32 
)

3D CNN Epoch 1: 100%|██████████| 1948/1948 [00:09<00:00, 195.23it/s, ev_loss=0.3014, loss=0.3014]
3D CNN Epoch 2: 100%|██████████| 1948/1948 [00:10<00:00, 192.82it/s, ev_loss=0.7456, loss=1.0923]
3D CNN Epoch 3: 100%|██████████| 1948/1948 [00:10<00:00, 191.44it/s, ev_loss=0.8550, loss=1.2016]
3D CNN Epoch 4: 100%|██████████| 1948/1948 [00:10<00:00, 191.42it/s, ev_loss=0.3000, loss=0.3000]
3D CNN Epoch 5: 100%|██████████| 1948/1948 [00:10<00:00, 192.91it/s, ev_loss=0.2419, loss=0.2419]
3D CNN Epoch 6: 100%|██████████| 1948/1948 [00:10<00:00, 193.15it/s, ev_loss=0.2275, loss=0.5741]
3D CNN Epoch 7: 100%|██████████| 1948/1948 [00:10<00:00, 188.95it/s, ev_loss=0.1918, loss=0.5383]
3D CNN Epoch 8: 100%|██████████| 1948/1948 [00:10<00:00, 194.30it/s, ev_loss=0.4331, loss=0.7796]
3D CNN Epoch 9: 100%|██████████| 1948/1948 [00:10<00:00, 189.99it/s, ev_loss=0.2688, loss=0.6153]
3D CNN Epoch 10: 100%|██████████| 1948/1948 [00:10<00:00, 193.16it/s, ev_loss=0.4048, loss=0.4048]
3D CNN Epoch 11: 10

In [22]:
y_ev_true, y_ev_pred, y_goal_true, y_goal_probs = evaluate.get_3d_predictions(model_3d, test_dataset_3d, DEVICE)

In [23]:
from sklearn.metrics import confusion_matrix, classification_report, balanced_accuracy_score, roc_auc_score

print("--- 3D CNN EVENT CONFUSION MATRIX ---")
ev_cm = confusion_matrix(y_ev_true, y_ev_pred)
print(ev_cm)

print("\n--- CLASSIFICATION REPORT ---")
target_names = ['Keep Possession', 'Lose Possession', 'Shot']
print(classification_report(y_ev_true, y_ev_pred, target_names=target_names))

print(f"Balanced Accuracy: {balanced_accuracy_score(y_ev_true, y_ev_pred):.4f}")

--- 3D CNN EVENT CONFUSION MATRIX ---
[[ 219 8412 1594]
 [  80 3791  633]
 [  12  445  397]]

--- CLASSIFICATION REPORT ---
                 precision    recall  f1-score   support

Keep Possession       0.70      0.02      0.04     10225
Lose Possession       0.30      0.84      0.44      4504
           Shot       0.15      0.46      0.23       854

       accuracy                           0.28     15583
      macro avg       0.39      0.44      0.24     15583
   weighted avg       0.56      0.28      0.17     15583

Balanced Accuracy: 0.4427


In [24]:
print("--- 3D CNN GOAL CONFUSION MATRIX ---")
# convert continuous probabilities to binary predictions (threshold = 0.5)
y_goal_pred = (y_goal_probs >= 0.5).astype(int)
g_cm = confusion_matrix(y_goal_true.astype(int), y_goal_pred)
print(g_cm)

--- 3D CNN GOAL CONFUSION MATRIX ---
[[15473     0]
 [  110     0]]


In [25]:
auc_score = roc_auc_score(y_goal_true, y_goal_probs)
print(f"3D CNN Goal AUC-ROC: {auc_score:.4f}")

3D CNN Goal AUC-ROC: 0.5054
