In [18]:
import sys
import os
import importlib
from collections import Counter
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# --- 1. Path Setup ---
# Absolute path to repo root (adjust if necessary)
repo_root = "/files/pixlball"
if repo_root not in sys.path:
    sys.path.insert(0, repo_root) 

# --- 2. Project Module Imports ---
# Import all project modules using clean names
import src.config as config
import src.dataset as dataset
import src.train as train
import src.evaluate as evaluate
import src.data as data
import src.losses as losses
import src.model as model
import src.utils as utils

# --- 3. Module Reloading (CRITICAL for Notebook Development) ---
# Reload dependencies in order: Config/Utils -> Data/Losses/Model -> Train/Dataset/Evaluate
importlib.reload(config)
importlib.reload(utils)
importlib.reload(data)
importlib.reload(model)
importlib.reload(losses) 
importlib.reload(dataset)
importlib.reload(train)
importlib.reload(evaluate)

# --- 4. Direct Imports (For clean code in subsequent cells) ---
# Import essential classes and functions needed for the pipeline steps

# Configuration
from src.config import DEVICE 

# Data/Dataset Classes
from src.dataset import PitchDatasetMultiTask, TemporalPitchDataset, ContextPitchDatasetMultiTask, FusionPitchDataset

# Training Functions
from src.train import train_model_base_threat, train_model_context_threat

# Evaluation/Helpers
from src.evaluate import evaluate_model_base_threat, evaluate_model_context_threat
from src.losses import get_model_criteria, FocalLossThreat
from src.model import TinyCNN_MultiTask_Threat
from src.utils import get_sequence_lengths

# --- Final Check ---
print(f"Using device: {DEVICE}")

Using device: cpu


In [2]:
data_events = pd.read_parquet(os.path.join(repo_root, "data", "events_data.parquet"), engine="fastparquet")
data_360 = pd.read_parquet(os.path.join(repo_root, "data", "sb360_data.parquet"), engine="fastparquet")

In [3]:
admin_events = [
        'Starting XI', 'Half Start', 'Half End', 'Player On', 'Player Off',
        'Substitution', 'Tactical Shift', 'Referee Ball-Drop', 'Injury Stoppage',
        'Bad Behaviour', 'Shield', 'Goal Keeper'
    ]

cleaned_df = data.drop_events(data_events, rows_to_drop=admin_events)

2462 events.


In [4]:
# -----------------------------
# Example usage
# -----------------------------
columns_to_drop = ['clearance_body_part',
                   'clearance_head',
                   'clearance_left_foot',
                   'clearance_other',
                   'clearance_right_foot',
                   'shot_technique',
                   'substitution_replacement_id',
                   'substitution_replacement',
                   'substitution_outcome',
                   'shot_saved_off_target',
                   'pass_miscommunication',
                   'goalkeeper_shot_saved_off_target',
                   'goalkeeper_punched_out',
                   'shot_first_time',
                   'shot_first_time',
                   'shot_body_part',
                   'related_events',
                   'pass_shot_assist', 
                   'pass_straight', 
                   'pass_switch', 
                   'pass_technique', 
                   'pass_through_ball',
                   'goalkeeper_body_part',
                   'goalkeeper_end_location', 
                   'goalkeeper_outcome', 
                   'goalkeeper_position', 
                   'goalkeeper_technique', 
                   'goalkeeper_type', 
                   'goalkeeper_penalty_saved_to_post', 
                   'goalkeeper_shot_saved_to_post', 
                   'goalkeeper_lost_out', 
                   'goalkeeper_Clear', 
                   'goalkeeper_In Play Safe',
                   'shot_key_pass_id',
                   'shot_one_on_one',
                   'shot_end_location',
                   'shot_type',
                   'pass_angle',
                   'pass_body_part',
                   'pass_type',
                   'pass_length',
                   'pass_outswinging',
                   'pass_inswinging',
                   'pass_cross', 
                   'pass_cut_back', 
                   'pass_deflected', 
                   'pass_goal_assist', 
                   'pass_recipient', 
                   'pass_recipient_id', 
                   'pass_assisted_shot_id', 
                   'pass_no_touch', 
                   'pass_end_location', 
                   'pass_aerial_won',
                   'pass_height',
                   'substitution_outcome_id',
                   'tactics',
                   'block_deflection',
                   'dribble_no_touch',
                   'shot_open_goal', 
                   'shot_saved_to_post',
                   'shot_redirect', 
                   'shot_follows_dribble',
                   'period',
                   'injury_stoppage_in_chanin',
                   'block_save_block',
                   'ball recovery_offensive',


                   ]
cleaned_df = data.drop_columns(cleaned_df, columns_to_drop)

# add lookahead outcome
df_with_targets = data.assign_lookahead_outcomes(cleaned_df, lookahead=6)


counts of each outcome nn_target
Keep Possession    70920
Lose Possession    27465
Shot                4764
Name: count, dtype: int64


In [20]:
from src.data import add_ball_trajectory_features
df_with_targets_2 = add_ball_trajectory_features(df_with_targets)

In [21]:
df_with_targets_2

Unnamed: 0,50_50,ball_receipt_outcome,ball_recovery_recovery_failure,carry_end_location,clearance_aerial_won,counterpress,dribble_nutmeg,dribble_outcome,duel_outcome,duel_type,...,shot_deflected,bad_behaviour_card,block_offensive,foul_committed_offensive,foul_committed_penalty,foul_won_penalty,ball_recovery_offensive,nn_target,goal_flag,ball_trajectory_vector
0,,,,,,,,,,,...,,,,,,,,Lose Possession,0,"[61.0, 40.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
1,,,,,,,,,,,...,,,,,,,,Lose Possession,0,"[46.6, 41.5, 61.0, 40.1, 0.0, 0.0, 0.0, 0.0]"
2,,,,"[46.4, 41.6]",,,,,,,...,,,,,,,,Lose Possession,0,"[46.6, 41.5, 46.6, 41.5, 61.0, 40.1, 0.0, 0.0]"
3,,,,,,,,,,,...,,,,,,,,Lose Possession,0,"[72.9, 40.9, 46.6, 41.5, 46.6, 41.5, 61.0, 40.1]"
4,,,,,,,,,,,...,,,,,,,,Lose Possession,0,"[46.4, 41.6, 72.9, 40.9, 46.6, 41.5, 46.6, 41.5]"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
103144,,,,,,,,,,,...,,,,,,,,Shot,1,"[108.1, 40.1, 107.7, 40.1, 108.1, 40.1, 108.1,..."
103145,,,,,,,,,,,...,,,,,,,,Shot,0,"[107.8, 40.1, 108.1, 40.1, 107.7, 40.1, 108.1,..."
103146,,,,,,,,,,,...,,,,,,,,Shot,0,"[107.9, 40.1, 107.8, 40.1, 108.1, 40.1, 107.7,..."
103147,,,,,,,,,,,...,,,,,,,,Shot,0,"[107.9, 40.1, 107.9, 40.1, 107.8, 40.1, 108.1,..."


# Prepare 360 Data

In [5]:
df_360 = data.assign_grid_cells(data_360)
nn_final = data.aggregate_nn_layers_vectorized(df_360)

# Finalize NN Df

In [6]:
nn_dataset = data.prepare_nn_dataset(df_with_targets, nn_final, target_cols=['nn_target', 'goal_flag'], context_cols = True, keep_context_ids = True ) # adjust cols depending on model

# Neural Network final Data Prep

In [7]:
context_cols = [
    'under_pressure', 
    'counterpress', 
    'dribble_nutmeg'
]

# Impute NaN values with 0.0 (float)
# This assumes NaN means the event was NOT under pressure, NOT a counterpress, etc.
nn_dataset[context_cols] = nn_dataset[context_cols].fillna(0.0)


target_map = {"Keep Possession": 0, "Lose Possession": 1, "Shot": 2}

# Apply mapping
nn_dataset['nn_target_int'] = nn_dataset['nn_target'].map(target_map)



# Add context Columns

FINAL_CONTEXTUAL_FEATURES = [
    'under_pressure', 
    'counterpress', 
    'dribble_nutmeg'
]

context_df = nn_dataset[FINAL_CONTEXTUAL_FEATURES].copy().fillna(0.0) 



# Check
print(nn_dataset.head())


                                     id  \
0  8b621ae4-ea81-415c-af41-9669db9bdd93   
1  4706efbe-767c-45aa-9351-09528a77d135   
2  084b9a88-4efa-4947-b94d-b89face472be   
3  27fa7d4d-d637-4487-98e2-5c078ad600c7   
4  764d437f-f799-4489-a38f-69fbb219a6fa   

                                          ball_layer  \
0  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....   
1  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....   
2  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....   
3  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....   
4  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....   

                                     teammates_layer  \
0  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....   
1  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....   
2  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....   
3  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....   
4  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....   

                                     opponents_layer        nn_targ

# The Goal Multi Task CNN

In [8]:
# ------------------------------------
# 1. Define input columns & targets
# ------------------------------------
# This assumes nn_dataset is already loaded and processed in previous cells.
layer_columns = ["ball_layer", "teammates_layer", "opponents_layer"]

# Ensure labels are in the correct format
event_targets = nn_dataset['nn_target_int'].values   # 0=keep, 1=lose, 2=shot (int)
# CRITICAL: Goal flags must be float for BCEWithLogitsLoss
goal_flags = nn_dataset['goal_flag'].values.astype(np.float32) 

# ------------------------------------
# 3. Compute class weights and positive weight
# ------------------------------------

# A. Event Weights (Multi-Class) - For CrossEntropyLoss
event_counts = Counter(event_targets)
total_events = len(event_targets)

# Using inverse frequency: total / count
class_weights_event = torch.tensor(
    [total_events / event_counts.get(c, 1) for c in range(len(event_counts))],
    dtype=torch.float32
).to(DEVICE)

# B. Goal Positive Weight (Binary) - For BCEWithLogitsLoss
goal_counts = Counter(goal_flags)

STABLE_GOAL_POS_WEIGHT = 3.0
goal_pos_weight = torch.tensor(STABLE_GOAL_POS_WEIGHT, dtype=torch.float32).to(config.DEVICE)

print(f"Goal Positive Weight (0/1 ratio): {goal_pos_weight.item():.2f}")


Goal Positive Weight (0/1 ratio): 3.00


# Preparing the Context CNN

In [9]:
import numpy as np
from sklearn.model_selection import train_test_split

layer_columns = ["ball_layer", "teammates_layer", "opponents_layer"]
VALIDATION_SIZE = 0.20
RANDOM_SEED = 42

# --- 1. Define ALL inputs and targets ---
# Input 1: The Grid Layers (X_features)
X_features = nn_dataset[layer_columns].reset_index(drop=True)

# Input 2: The Contextual 1D Features (X_context)
# CRITICAL: Ensure this DataFrame is aligned with X_features
X_context = context_df.reset_index(drop=True)

# Targets
event_targets = nn_dataset['nn_target_int'].values
goal_flags = nn_dataset['goal_flag'].values.astype(np.float32)



# CRITICAL: Assign the 8 returned arrays/DataFrames to 8 descriptive variables
(
    X_feat_train,      # 1. Grid Layers (Train)
    X_feat_val,        # 2. Grid Layers (Validation)
    X_ctx_train,       # 3. Context Features (Train)
    X_ctx_val,         # 4. Context Features (Validation)
    y_event_train,     # 5. Event Targets (Train)
    y_event_val,       # 6. Event Targets (Validation)
    y_goal_train,      # 7. Goal Targets (Train)
    y_goal_val         # 8. Goal Targets (Validation)
) = train_test_split(
    X_features,        # Input 1
    X_context,         # Input 2 (NEW)
    event_targets,     # Input 3
    goal_flags,        # Input 4
    test_size=VALIDATION_SIZE, 
    random_state=RANDOM_SEED,
    stratify=event_targets # Stratify only on the multi-class target
)

# --- 3. Instantiate the two Contextual Dataset objects ---

# Training Dataset (uses four 'train' splits)
train_dataset_context = ContextPitchDatasetMultiTask(
    nn_layers_df=X_feat_train,          # Grid Layers (Train)
    event_targets=y_event_train,        # Event Targets (Train)
    goal_flags=y_goal_train,            # Goal Targets (Train)
    contextual_features_df=X_ctx_train  # Context Features (Train)
)

# Validation Dataset (uses four 'val' splits)
validation_dataset_context = ContextPitchDatasetMultiTask(
    nn_layers_df=X_feat_val,            # Grid Layers (Validation)
    event_targets=y_event_val,          # Event Targets (Validation)
    goal_flags=y_goal_val,              # Goal Targets (Validation)
    contextual_features_df=X_ctx_val    # Context Features (Validation)
)

print(f"Total training samples: {len(train_dataset_context)}")
print(f"Total validation samples: {len(validation_dataset_context)}")

Total training samples: 72117
Total validation samples: 18030


In [14]:
# Assuming event_class_weights and goal_pos_weight are defined from previous cells
NUM_CONTEXT_FEATURES = 3 

print("Starting training for Contextual CNN Baseline...")

# Modified the Function in Loss to take correct loss function -> needs to be changed for baseline model again

context_baseline_model = train_model_context_threat(
    dataset=train_dataset_context, 
    event_class_weights=class_weights_event, # Use your calculated weights
    goal_pos_weight=goal_pos_weight,         # Use your calculated pos_weight
    num_context_features=NUM_CONTEXT_FEATURES
)

print("\nContextual CNN Training complete.")

Context CNN Epoch 1:   0%|          | 0/2254 [00:00<?, ?it/s]

Starting training for Contextual CNN Baseline...


Context CNN Epoch 1: 100%|██████████| 2254/2254 [00:10<00:00, 213.82it/s, event_loss=1.87, loss=4.41, shot_loss=1.69]  
Context CNN Epoch 2: 100%|██████████| 2254/2254 [00:09<00:00, 226.28it/s, event_loss=3.92, loss=5.87, shot_loss=1.3]   
Context CNN Epoch 3: 100%|██████████| 2254/2254 [00:10<00:00, 223.12it/s, event_loss=1.2, loss=1.2, shot_loss=0]       
Context CNN Epoch 4: 100%|██████████| 2254/2254 [00:10<00:00, 215.20it/s, event_loss=1.57, loss=1.57, shot_loss=0]      
Context CNN Epoch 5: 100%|██████████| 2254/2254 [00:10<00:00, 219.74it/s, event_loss=1.23, loss=1.63, shot_loss=0.266]  
Context CNN Epoch 6: 100%|██████████| 2254/2254 [00:10<00:00, 210.07it/s, event_loss=1.79, loss=1.97, shot_loss=0.122]   
Context CNN Epoch 7: 100%|██████████| 2254/2254 [00:10<00:00, 211.53it/s, event_loss=1.75, loss=1.78, shot_loss=0.0156]  
Context CNN Epoch 8: 100%|██████████| 2254/2254 [00:10<00:00, 224.72it/s, event_loss=1.28, loss=1.28, shot_loss=0]       
Context CNN Epoch 9: 100%|██████


Contextual CNN Training complete.





In [15]:
# Assuming evaluate_model_context is imported and available

print("\nEvaluating Contextual CNN Model...")

metrics = evaluate_model_context_threat(
    model=context_baseline_model, 
    dataset=validation_dataset_context # Evaluate on the contextual dataset
)




Evaluating Contextual CNN Model...

--- Event Outcome Metrics ---
Event Accuracy: 0.5253466444814199
Event Balanced Accuracy: 0.6189850992521696
Event Confusion Matrix:
 [[6140 4751 1597]
 [1297 2622  738]
 [  68  107  710]]
              precision    recall  f1-score   support

           0       0.82      0.49      0.61     12488
           1       0.35      0.56      0.43      4657
           2       0.23      0.80      0.36       885

    accuracy                           0.53     18030
   macro avg       0.47      0.62      0.47     18030
weighted avg       0.67      0.53      0.55     18030


--- Goal Prediction (xG) Metrics ---
Goal Accuracy: 0.8598870056497175
Goal Balanced Accuracy: 0.776158925214361
Goal AUC-ROC Score: 0.8352515016736209
Goal Confusion Matrix:
 [[686  86]
 [ 38  75]]
              precision    recall  f1-score   support

         0.0       0.95      0.89      0.92       772
         1.0       0.47      0.66      0.55       113

    accuracy                 

In [None]:
import numpy as np
# Assuming metrics contains the result from evaluate_model_context_threat

event_probs = metrics['event_probs']

print("P(Keep) | P(Lose) | P(Shot)")
print("-------------------------------")
print(event_probs[:5])

# You can look at the average predicted probability for the Shot class across all events:
avg_p_shot = np.mean(event_probs[:, 2])
print(f"\nAverage Predicted P(Shot) across all events: {avg_p_shot:.4f}")

In [None]:
import numpy as np
import pandas as pd
# Assuming metrics contains the result from evaluate_model_context_threat

print("--- Goal Prediction Probabilities (xG) Analysis ---")

goal_probs = metrics['goal_probs']
goal_labels = metrics['goal_labels'] # Actual outcome (0=No Goal, 1=Goal)

print(f"Number of Shots Evaluated: {len(goal_probs)}")

# 1. Total xG vs. Actual Goals
total_predicted_xg = np.sum(goal_probs)
total_true_goals = np.sum(goal_labels)
avg_xg_per_shot = np.mean(goal_probs)

print(f"\nTotal Predicted xG: {total_predicted_xg:.2f}")
print(f"Total True Goals Scored: {total_true_goals:.2f}")
print(f"Average Predicted xG per Shot: {avg_xg_per_shot:.4f}")

# 2. Calibration Check (Optional but helpful)
# Compare the average predicted xG for shots that were goals vs. shots that were misses.

# Create a DataFrame for easy slicing
xg_df = pd.DataFrame({'xg': goal_probs, 'goal': goal_labels})

avg_xg_goal = xg_df[xg_df['goal'] == 1]['xg'].mean()
avg_xg_miss = xg_df[xg_df['goal'] == 0]['xg'].mean()

print("\n-- Calibration Check --")
print(f"Average xG for True Goals (should be high): {avg_xg_goal:.4f}")
print(f"Average xG for Missed Shots (should be low): {avg_xg_miss:.4f}")