In [1]:
import sys
import os
import importlib
from collections import Counter
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# --- 1. Path Setup ---
# Absolute path to repo root (adjust if necessary)
repo_root = "/files/pixlball"
if repo_root not in sys.path:
    sys.path.insert(0, repo_root) 

# --- 2. Project Module Imports ---
# Import all project modules using clean names
import src.config as config
import src.dataset as dataset
import src.train as train
import src.evaluate as evaluate
import src.data as data
import src.losses as losses
import src.model as model
import src.utils as utils

# --- 3. Module Reloading (CRITICAL for Notebook Development) ---
# Reload dependencies in order: Config/Utils -> Data/Losses/Model -> Train/Dataset/Evaluate
importlib.reload(config)
importlib.reload(utils)
importlib.reload(data)
importlib.reload(model)
importlib.reload(losses) 
importlib.reload(dataset)
importlib.reload(train)
importlib.reload(evaluate)

# --- 4. Direct Imports (For clean code in subsequent cells) ---
# Import essential classes and functions needed for the pipeline steps

# Configuration
from src.config import DEVICE 

# Data/Dataset Classes
from src.dataset import PitchDatasetMultiTask, TemporalPitchDataset, ContextPitchDatasetMultiTask, FusionPitchDataset

# Training Functions
from src.train import train_model_base, train_model_lstm, train_model_context, train_model_lstm_fused

# Evaluation/Helpers
from src.evaluate import evaluate_model_base, evaluate_model_lstm, evaluate_model_context, evaluate_model_lstm_fused
from src.losses import get_model_criteria
from src.model import TinyCNN_MultiTask, HybridCNN_LSTM, TinyCNN_LSTM_Fused
from src.utils import get_sequence_lengths

# --- Final Check ---
print(f"Using device: {DEVICE}")

Using device: cpu


In [2]:
data_events = pd.read_parquet(os.path.join(repo_root, "data", "events_data.parquet"), engine="fastparquet")
data_360 = pd.read_parquet(os.path.join(repo_root, "data", "sb360_data.parquet"), engine="fastparquet")

In [3]:
admin_events = [
        'Starting XI', 'Half Start', 'Half End', 'Player On', 'Player Off',
        'Substitution', 'Tactical Shift', 'Referee Ball-Drop', 'Injury Stoppage',
        'Bad Behaviour', 'Shield'
    ]

cleaned_df = data.drop_events(data_events, rows_to_drop=admin_events)

1278 events.


In [4]:
# -----------------------------
# Example usage
# -----------------------------
columns_to_drop = ['clearance_body_part',
                   'clearance_head',
                   'clearance_left_foot',
                   'clearance_other',
                   'clearance_right_foot',
                   'shot_technique',
                   'substitution_replacement_id',
                   'substitution_replacement',
                   'substitution_outcome',
                   'shot_saved_off_target',
                   'pass_miscommunication',
                   'goalkeeper_shot_saved_off_target',
                   'goalkeeper_punched_out',
                   'shot_first_time',
                   'shot_first_time',
                   'shot_body_part',
                   'related_events',
                   'pass_shot_assist', 
                   'pass_straight', 
                   'pass_switch', 
                   'pass_technique', 
                   'pass_through_ball',
                   'goalkeeper_body_part',
                   'goalkeeper_end_location', 
                   'goalkeeper_outcome', 
                   'goalkeeper_position', 
                   'goalkeeper_technique', 
                   'goalkeeper_type', 
                   'goalkeeper_penalty_saved_to_post', 
                   'goalkeeper_shot_saved_to_post', 
                   'goalkeeper_lost_out', 
                   'goalkeeper_Clear', 
                   'goalkeeper_In Play Safe',
                   'shot_key_pass_id',
                   'shot_one_on_one',
                   'shot_end_location',
                   'shot_type',
                   'pass_angle',
                   'pass_body_part',
                   'pass_type',
                   'pass_length',
                   'pass_outswinging',
                   'pass_inswinging',
                   'pass_cross', 
                   'pass_cut_back', 
                   'pass_deflected', 
                   'pass_goal_assist', 
                   'pass_recipient', 
                   'pass_recipient_id', 
                   'pass_assisted_shot_id', 
                   'pass_no_touch', 
                   'pass_end_location', 
                   'pass_aerial_won',
                   'pass_height',
                   'substitution_outcome_id',
                   'tactics',
                   'block_deflection',
                   'dribble_no_touch',
                   'shot_open_goal', 
                   'shot_saved_to_post',
                   'shot_redirect', 
                   'shot_follows_dribble',
                   'period',
                   'injury_stoppage_in_chanin',
                   'block_save_block',
                   'ball recovery_offensive',


                   ]
cleaned_df = data.drop_columns(cleaned_df, columns_to_drop)

# add lookahead outcome
df_with_targets = data.assign_lookahead_outcomes(cleaned_df, lookahead=6)


counts of each outcome nn_target
Keep Possession    71251
Lose Possession    28252
Shot                4830
Name: count, dtype: int64


# Prepare 360 Data

In [5]:
df_360 = data.assign_grid_cells(data_360)
nn_final = data.aggregate_nn_layers_vectorized(df_360)

# Finalize NN Df

In [6]:
nn_dataset = data.prepare_nn_dataset(df_with_targets, nn_final, target_cols=['nn_target', 'goal_flag'], context_cols = True, keep_context_ids = True ) # adjust cols depending on model

# Neural Network final Data Prep

In [7]:
context_cols = [
    'under_pressure', 
    'counterpress', 
    'dribble_nutmeg'
]

# Impute NaN values with 0.0 (float)
# This assumes NaN means the event was NOT under pressure, NOT a counterpress, etc.
nn_dataset[context_cols] = nn_dataset[context_cols].fillna(0.0)


target_map = {"Keep Possession": 0, "Lose Possession": 1, "Shot": 2}

# Apply mapping
nn_dataset['nn_target_int'] = nn_dataset['nn_target'].map(target_map)

# Check
print(nn_dataset[['nn_target', 'nn_target_int']].head())

         nn_target  nn_target_int
0  Keep Possession              0
1  Keep Possession              0
2  Keep Possession              0
3  Keep Possession              0
4  Keep Possession              0


# The Goal Multi Task CNN

In [8]:
# -----------------------------
# Check device
# -----------------------------
print(f"Using device: {DEVICE}")

# ------------------------------------
# 1. Define input columns & targets
# ------------------------------------
# This assumes nn_dataset is already loaded and processed in previous cells.
layer_columns = ["ball_layer", "teammates_layer", "opponents_layer"]

# Ensure labels are in the correct format
event_targets = nn_dataset['nn_target_int'].values   # 0=keep, 1=lose, 2=shot (int)
# CRITICAL: Goal flags must be float for BCEWithLogitsLoss
goal_flags = nn_dataset['goal_flag'].values.astype(np.float32) 

# -----------------------------
# 2. Prepare dataset (Static Input)
# -----------------------------
# PitchDatasetMultiTask correctly uses the 3 static layer columns.
train_dataset = PitchDatasetMultiTask(nn_dataset[layer_columns], event_targets, goal_flags)

# ------------------------------------
# 3. Compute class weights and positive weight
# ------------------------------------

# A. Event Weights (Multi-Class) - For CrossEntropyLoss
event_counts = Counter(event_targets)
total_events = len(event_targets)

# Using inverse frequency: total / count
class_weights_event = torch.tensor(
    [total_events / event_counts.get(c, 1) for c in range(len(event_counts))],
    dtype=torch.float32
).to(DEVICE)

# B. Goal Positive Weight (Binary) - For BCEWithLogitsLoss
goal_counts = Counter(goal_flags)
# CRITICAL: pos_weight = (Number of Negative Samples) / (Number of Positive Samples)
# Here: pos_weight = (Number of No Goals) / (Number of Goals)
goal_pos_weight = torch.tensor(
    goal_counts.get(0.0, 1) / goal_counts.get(1.0, 1),
    dtype=torch.float32
).to(DEVICE)

print(f"Goal Positive Weight (0/1 ratio): {goal_pos_weight.item():.2f}")

# ------------------------------------
# 4. Train model (Using the dedicated base function)
# ------------------------------------
print("Starting training for Static CNN Baseline...")
baseline_model = train_model_base(
    dataset=train_dataset, 
    event_class_weights=class_weights_event, 
    goal_pos_weight=goal_pos_weight
)
print("Training complete.")



Using device: cpu


Base Epoch 1:   0%|          | 0/9069 [00:00<?, ?it/s]

Goal Positive Weight (0/1 ratio): 168.51
Starting training for Static CNN Baseline...


Base Epoch 1: 100%|██████████| 9069/9069 [00:35<00:00, 253.56it/s, event_loss=1.07, loss=1.07, shot_loss=0]    
Base Epoch 2: 100%|██████████| 9069/9069 [00:36<00:00, 248.05it/s, event_loss=0.907, loss=0.907, shot_loss=0]   

Training complete.





In [9]:

# -----------------------------
# 5. Evaluate model
# -----------------------------
metrics = evaluate_model_base(baseline_model, train_dataset)
# print(metrics)


Event Accuracy: 0.6381629727643621
Event Balanced Accuracy: 0.5365020106336923
Event Confusion Matrix:
 [[54592     0  8067]
 [18597     0  4987]
 [ 1164     0  3283]]
              precision    recall  f1-score   support

           0       0.73      0.87      0.80     62659
           1       0.00      0.00      0.00     23584
           2       0.20      0.74      0.32      4447

    accuracy                           0.64     90690
   macro avg       0.31      0.54      0.37     90690
weighted avg       0.52      0.64      0.57     90690

Goal Accuracy: 0.12030582415111311
Goal Balanced Accuracy: 0.5
Goal Confusion Matrix:
 [[   0 3912]
 [   0  535]]
              precision    recall  f1-score   support

         0.0       0.00      0.00      0.00      3912
         1.0       0.12      1.00      0.21       535

    accuracy                           0.12      4447
   macro avg       0.06      0.50      0.11      4447
weighted avg       0.01      0.12      0.03      4447



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


# The LSTM

In [10]:
# Get your existing targets
event_targets = nn_dataset['nn_target_int'].values
goal_flags = nn_dataset['goal_flag'].values.astype(np.float32) # Ensure targets are float

# Using inverse frequency: total / count
class_weights_event = torch.tensor(
    [total_events / event_counts.get(c, 1) for c in range(len(event_counts))],
    dtype=torch.float32
).to(DEVICE)

# B. Goal Positive Weight (Binary) - For BCEWithLogitsLoss
goal_counts = Counter(goal_flags)
# CRITICAL: pos_weight = (Number of Negative Samples) / (Number of Positive Samples)
# Here: pos_weight = (Number of No Goals) / (Number of Goals)
goal_pos_weight = torch.tensor(
    goal_counts.get(0.0, 1) / goal_counts.get(1.0, 1),
    dtype=torch.float32
).to(DEVICE)


In [11]:
# -----------------------------
# Prepare Sequential Dataset
# -----------------------------
print("Preparing 5D Sequential Dataset for LSTM...")

nn_dataset = data.prepare_nn_dataset(df_with_targets, nn_final, target_cols=['nn_target', 'goal_flag', 'possession']) # adjust cols depending on model

target_map = {"Keep Possession": 0, "Lose Possession": 1, "Shot": 2}

# Apply mapping
nn_dataset['nn_target_int'] = nn_dataset['nn_target'].map(target_map)

# Check
print(nn_dataset[['nn_target', 'nn_target_int']].head())

windows = data.build_temporal_windows_with_mask(nn_dataset)

# Assuming 'windows' variable holds the output of data.build_temporal_windows_with_mask()
# Shape: (Num_Events, T, 4, H, W)
temporal_dataset = TemporalPitchDataset(
    windows=windows, 
    event_labels=event_targets, 
    goal_flags=goal_flags
)



Preparing 5D Sequential Dataset for LSTM...
         nn_target  nn_target_int
0  Keep Possession              0
1  Keep Possession              0
2  Keep Possession              0
3  Keep Possession              0
4  Keep Possession              0


In [12]:
# -----------------------------
# Train Hybrid CNN-LSTM Model
# -----------------------------
print("Starting training for Hybrid CNN-LSTM...")
lstm_model = train_model_lstm(
    dataset=temporal_dataset, 
    event_class_weights=class_weights_event, 
    goal_pos_weight=goal_pos_weight
)
print("Hybrid CNN-LSTM Training complete.")



LSTM Epoch 1:   0%|          | 14/9069 [00:00<01:05, 137.40it/s, event_loss=1.25, loss=2.11, shot_loss=0.172]

Starting training for Hybrid CNN-LSTM...


LSTM Epoch 1: 100%|██████████| 9069/9069 [01:11<00:00, 127.53it/s, event_loss=1.17, loss=19.1, shot_loss=3.59]
LSTM Epoch 2: 100%|██████████| 9069/9069 [01:11<00:00, 127.53it/s, event_loss=0.95, loss=0.95, shot_loss=0]   

Hybrid CNN-LSTM Training complete.





In [13]:
# -----------------------------
# Evaluate LSTM Model
# -----------------------------
print("\nEvaluating LSTM Model on training data...")
lstm_metrics = evaluate_model_lstm(lstm_model, temporal_dataset)
# print(lstm_metrics)


Evaluating LSTM Model on training data...
Event Accuracy: 0.6909141029882016
Event Balanced Accuracy: 0.3333333333333333
Event Confusion Matrix:
 [[62659     0     0]
 [23584     0     0]
 [ 4447     0     0]]
              precision    recall  f1-score   support

           0       0.69      1.00      0.82     62659
           1       0.00      0.00      0.00     23584
           2       0.00      0.00      0.00      4447

    accuracy                           0.69     90690
   macro avg       0.23      0.33      0.27     90690
weighted avg       0.48      0.69      0.56     90690

Goal Accuracy: 0.12030582415111311
Goal Balanced Accuracy: 0.5
Goal Confusion Matrix:
 [[   0 3912]
 [   0  535]]
              precision    recall  f1-score   support

         0.0       0.00      0.00      0.00      3912
         1.0       0.12      1.00      0.21       535

    accuracy                           0.12      4447
   macro avg       0.06      0.50      0.11      4447
weighted avg       0.0

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


# Contextual CNN

In [14]:
nn_dataset

Unnamed: 0,id,ball_layer,teammates_layer,opponents_layer,nn_target,goal_flag,possession,nn_target_int
0,8b621ae4-ea81-415c-af41-9669db9bdd93,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Keep Possession,0,2,0
1,4706efbe-767c-45aa-9351-09528a77d135,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Keep Possession,0,2,0
2,084b9a88-4efa-4947-b94d-b89face472be,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Keep Possession,0,2,0
3,27fa7d4d-d637-4487-98e2-5c078ad600c7,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Keep Possession,0,2,0
4,764d437f-f799-4489-a38f-69fbb219a6fa,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Keep Possession,0,2,0
...,...,...,...,...,...,...,...,...
90685,a3f71026-727b-40ea-a4cc-e639174912aa,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Lose Possession,0,182,1
90686,690d8b4c-7fa4-461b-a24d-ebe97fb2c9ae,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Lose Possession,0,182,1
90687,86456b25-4108-4725-b2fb-1bb263cf55c9,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Lose Possession,0,182,1
90688,1415a3dd-034f-46c4-b070-11180a26ed38,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Lose Possession,0,184,1


In [15]:
# --- 1. Define Context Features ---
# Based on your previous steps, you used 3 binary features.
NUM_CONTEXT_FEATURES = 3 

FINAL_CONTEXTUAL_FEATURES = [
    'under_pressure', 
    'counterpress', 
    'dribble_nutmeg'
]

nn_dataset = data.prepare_nn_dataset(df_with_targets, nn_final, target_cols=['nn_target', 'goal_flag'], context_cols = True, keep_context_ids = True ) # adjust cols depending on model

target_map = {"Keep Possession": 0, "Lose Possession": 1, "Shot": 2}

# Apply mapping
nn_dataset['nn_target_int'] = nn_dataset['nn_target'].map(target_map)

# Ensure the context DataFrame is prepared (imputed NaN with 0.0)
context_df = nn_dataset[FINAL_CONTEXTUAL_FEATURES].copy().fillna(0.0) 

# Re-check the number of features just to be safe
print(f"Number of Contextual Features being used: {context_df.shape[1]}") 
print(f"Contextual Features Head:\n{context_df.head()}")

# --- 2. Define Inputs and Targets (Same as before) ---
layer_columns = ["ball_layer", "teammates_layer", "opponents_layer"]
event_targets = nn_dataset['nn_target_int'].values
goal_flags = nn_dataset['goal_flag'].values.astype(np.float32) 

# --- 3. Create Contextual Dataset Instance ---
# CRITICAL: PitchDatasetMultiTask must be updated to accept the 4th input
train_dataset_context = ContextPitchDatasetMultiTask(
    nn_layers_df=nn_dataset[layer_columns], 
    event_targets=event_targets, 
    goal_flags=goal_flags,
    contextual_features_df=context_df # NEW ARGUMENT
)

print(f"Contextual Dataset Size: {len(train_dataset_context)}")
print("Contextual Dataset ready.")

Number of Contextual Features being used: 3
Contextual Features Head:
   under_pressure  counterpress  dribble_nutmeg
0             0.0           0.0             0.0
1             0.0           0.0             0.0
2             1.0           0.0             0.0
3             1.0           0.0             0.0
4             0.0           0.0             0.0
Contextual Dataset Size: 90690
Contextual Dataset ready.


In [16]:
nn_dataset

Unnamed: 0,id,ball_layer,teammates_layer,opponents_layer,nn_target,goal_flag,match_id,possession,under_pressure,counterpress,dribble_nutmeg,nn_target_int
0,8b621ae4-ea81-415c-af41-9669db9bdd93,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Keep Possession,0,4020846,2,,,,0
1,4706efbe-767c-45aa-9351-09528a77d135,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Keep Possession,0,4020846,2,,,,0
2,084b9a88-4efa-4947-b94d-b89face472be,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Keep Possession,0,4020846,2,1.0,,,0
3,27fa7d4d-d637-4487-98e2-5c078ad600c7,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Keep Possession,0,4020846,2,1.0,,,0
4,764d437f-f799-4489-a38f-69fbb219a6fa,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Keep Possession,0,4020846,2,,,,0
...,...,...,...,...,...,...,...,...,...,...,...,...
90685,a3f71026-727b-40ea-a4cc-e639174912aa,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Lose Possession,0,3998841,182,,1.0,,1
90686,690d8b4c-7fa4-461b-a24d-ebe97fb2c9ae,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Lose Possession,0,3998841,182,,1.0,,1
90687,86456b25-4108-4725-b2fb-1bb263cf55c9,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Lose Possession,0,3998841,182,1.0,,,1
90688,1415a3dd-034f-46c4-b070-11180a26ed38,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0....",Lose Possession,0,3998841,184,,,,1


In [17]:
# Assuming event_class_weights and goal_pos_weight are defined from previous cells

print("Starting training for Contextual CNN Baseline...")

context_baseline_model = train_model_context(
    dataset=train_dataset_context, 
    event_class_weights=class_weights_event, # Use your calculated weights
    goal_pos_weight=goal_pos_weight,         # Use your calculated pos_weight
    num_context_features=NUM_CONTEXT_FEATURES
)

print("\nContextual CNN Training complete.")

Context CNN Epoch 1:   0%|          | 0/9069 [00:00<?, ?it/s, event_loss=1.26, loss=2.51, shot_loss=0.25] 

Starting training for Contextual CNN Baseline...


Context CNN Epoch 1: 100%|██████████| 9069/9069 [00:39<00:00, 231.88it/s, event_loss=0.867, loss=19.2, shot_loss=3.67]
Context CNN Epoch 2: 100%|██████████| 9069/9069 [00:39<00:00, 230.11it/s, event_loss=0.394, loss=17.9, shot_loss=3.49] 


Contextual CNN Training complete.





In [18]:
# Assuming evaluate_model_context is imported and available

print("\nEvaluating Contextual CNN Model...")

context_metrics = evaluate_model_context(
    model=context_baseline_model, 
    dataset=train_dataset_context # Evaluate on the contextual dataset
)

# You can now compare 'context_metrics' with your 'baseline_metrics'
# print(context_metrics)


Evaluating Contextual CNN Model...
Event Accuracy: 0.5433013562686073
Event Balanced Accuracy: 0.5640790903902526
Event Confusion Matrix:
 [[36198 18667  7794]
 [ 9485 10004  4095]
 [  580   797  3070]]
              precision    recall  f1-score   support

           0       0.78      0.58      0.66     62659
           1       0.34      0.42      0.38     23584
           2       0.21      0.69      0.32      4447

    accuracy                           0.54     90690
   macro avg       0.44      0.56      0.45     90690
weighted avg       0.64      0.54      0.57     90690

Goal Accuracy: 0.12030582415111311
Goal Balanced Accuracy: 0.5
Goal Confusion Matrix:
 [[   0 3912]
 [   0  535]]
              precision    recall  f1-score   support

         0.0       0.00      0.00      0.00      3912
         1.0       0.12      1.00      0.21       535

    accuracy                           0.12      4447
   macro avg       0.06      0.50      0.11      4447
weighted avg       0.01      

# LSTM CNN with Context

In [19]:
import numpy as np # Ensure NumPy is available

# Assuming you previously defined these:
NUM_CONTEXT_FEATURES = 3 
FINAL_CONTEXTUAL_FEATURES = ['under_pressure', 'counterpress', 'dribble_nutmeg']

# 1. Prepare 1D Context Features (from the static context_df)
# Assuming 'context_df' is defined and NaN-imputed in a previous cell
context_df = nn_dataset[FINAL_CONTEXTUAL_FEATURES].copy().fillna(0.0)
context_features_array = context_df.values # Shape: (N_events, 3)

# 2. Determine the Sequence Length (T)
# CRITICAL FIX: The list 'windows' contains 'N' items. The length of the first item is T.
if isinstance(windows, list) and len(windows) > 0:
    # Use array conversion to safely get the shape of the first sequence item
    T = np.array(windows[0]).shape[0] 
else:
    # If windows is already a NumPy array, use its shape
    T = windows.shape[1] 

print(f"Sequence Length (T): {T}")

# 3. Create the Context Sequence (T-frames for each event)
# Repeat the static context features (N, 3) across the T dimension: (N, T, 3)
# context_features_array is (N, 3). We insert a T-dimension and repeat the feature vector T times.
context_sequence = np.repeat(context_features_array[:, np.newaxis, :], T, axis=1)

print(f"Context Sequence Shape: {context_sequence.shape} (N, T, F)")

# --- 4. Create Fused Dataset Instance (CRITICAL: Reload and Import) ---
# ... rest of your import and dataset creation code ...

fused_dataset = FusionPitchDataset(
    windows=windows, # The FusionPitchDataset class will handle the list to array conversion internally
    contextual_features=context_sequence, 
    event_labels=event_targets, 
    goal_flags=goal_flags
)

print(f"Fused Dataset Size: {len(fused_dataset)}")

Sequence Length (T): 4
Context Sequence Shape: (90690, 4, 3) (N, T, F)
Fused Dataset Size: 90690


In [20]:
# Assuming class_weights_event and goal_pos_weight are defined

print("Starting training for TinyCNN_LSTM_Fused Model...")

fused_lstm_model = train_model_lstm_fused(
    dataset=fused_dataset, 
    event_class_weights=class_weights_event,
    goal_pos_weight=goal_pos_weight,
    num_context_features=NUM_CONTEXT_FEATURES # Pass the number of 1D features
)

print("\nTinyCNN_LSTM_Fused Training complete.")

Fused LSTM Epoch 1:   0%|          | 15/9069 [00:00<01:04, 139.51it/s, event_loss=1.08, loss=1.08, shot_loss=0]    

Starting training for TinyCNN_LSTM_Fused Model...


Fused LSTM Epoch 1: 100%|██████████| 9069/9069 [01:08<00:00, 131.63it/s, event_loss=0.955, loss=0.955, shot_loss=0] 
Fused LSTM Epoch 2: 100%|██████████| 9069/9069 [01:15<00:00, 120.43it/s, event_loss=1.01, loss=1.01, shot_loss=0]   


TinyCNN_LSTM_Fused Training complete.





In [21]:
print("\nEvaluating Fused LSTM Model on training data...")

fused_metrics = evaluate_model_lstm_fused(
    model=fused_lstm_model, 
    dataset=fused_dataset
)


Evaluating Fused LSTM Model on training data...
Event Accuracy: 0.6178409968022935
Event Balanced Accuracy: 0.3699842548543841
Event Confusion Matrix:
 [[47874 14785     0]
 [15426  8158     0]
 [ 3170  1277     0]]
              precision    recall  f1-score   support

           0       0.72      0.76      0.74     62659
           1       0.34      0.35      0.34     23584
           2       0.00      0.00      0.00      4447

    accuracy                           0.62     90690
   macro avg       0.35      0.37      0.36     90690
weighted avg       0.59      0.62      0.60     90690

Goal Accuracy: 0.12030582415111311
Goal Balanced Accuracy: 0.5
Goal Confusion Matrix:
 [[   0 3912]
 [   0  535]]
              precision    recall  f1-score   support

         0.0       0.00      0.00      0.00      3912
         1.0       0.12      1.00      0.21       535

    accuracy                           0.12      4447
   macro avg       0.06      0.50      0.11      4447
weighted avg    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
