In [1]:
import sys
import os
import importlib
from collections import Counter
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# --- 1. Path Setup ---
# Absolute path to repo root (adjust if necessary)
repo_root = "/files/pixlball"
if repo_root not in sys.path:
    sys.path.insert(0, repo_root) 

# --- 2. Project Module Imports ---
# Import all project modules using clean names
import src.config as config
import src.dataset as dataset
import src.train as train
import src.evaluate as evaluate
import src.data as data
import src.losses as losses
import src.model as model
import src.utils as utils

# --- 3. Module Reloading (CRITICAL for Notebook Development) ---
# Reload dependencies in order: Config/Utils -> Data/Losses/Model -> Train/Dataset/Evaluate
importlib.reload(config)
importlib.reload(utils)
importlib.reload(data)
importlib.reload(model)
importlib.reload(losses) 
importlib.reload(dataset)
importlib.reload(train)
importlib.reload(evaluate)

# --- 4. Direct Imports (For clean code in subsequent cells) ---
# Import essential classes and functions needed for the pipeline steps

# Configuration
from src.config import DEVICE 

# Data/Dataset Classes
from src.dataset import PitchDatasetMultiTask, TemporalPitchDataset, ContextPitchDatasetMultiTask, FusionPitchDataset

# Training Functions
from src.train import train_model_base_threat

# Evaluation/Helpers
from src.evaluate import evaluate_model_base_threat
from src.losses import get_model_criteria
from src.model import TinyCNN_MultiTask_Threat
from src.utils import get_sequence_lengths

# --- Final Check ---
print(f"Using device: {DEVICE}")

Using device: cpu


In [2]:
data_events = pd.read_parquet(os.path.join(repo_root, "data", "events_data.parquet"), engine="fastparquet")
data_360 = pd.read_parquet(os.path.join(repo_root, "data", "sb360_data.parquet"), engine="fastparquet")

In [3]:
admin_events = [
        'Starting XI', 'Half Start', 'Half End', 'Player On', 'Player Off',
        'Substitution', 'Tactical Shift', 'Referee Ball-Drop', 'Injury Stoppage',
        'Bad Behaviour', 'Shield'
    ]

cleaned_df = data.drop_events(data_events, rows_to_drop=admin_events)

1278 events.


In [4]:
# -----------------------------
# Example usage
# -----------------------------
columns_to_drop = ['clearance_body_part',
                   'clearance_head',
                   'clearance_left_foot',
                   'clearance_other',
                   'clearance_right_foot',
                   'shot_technique',
                   'substitution_replacement_id',
                   'substitution_replacement',
                   'substitution_outcome',
                   'shot_saved_off_target',
                   'pass_miscommunication',
                   'goalkeeper_shot_saved_off_target',
                   'goalkeeper_punched_out',
                   'shot_first_time',
                   'shot_first_time',
                   'shot_body_part',
                   'related_events',
                   'pass_shot_assist', 
                   'pass_straight', 
                   'pass_switch', 
                   'pass_technique', 
                   'pass_through_ball',
                   'goalkeeper_body_part',
                   'goalkeeper_end_location', 
                   'goalkeeper_outcome', 
                   'goalkeeper_position', 
                   'goalkeeper_technique', 
                   'goalkeeper_type', 
                   'goalkeeper_penalty_saved_to_post', 
                   'goalkeeper_shot_saved_to_post', 
                   'goalkeeper_lost_out', 
                   'goalkeeper_Clear', 
                   'goalkeeper_In Play Safe',
                   'shot_key_pass_id',
                   'shot_one_on_one',
                   'shot_end_location',
                   'shot_type',
                   'pass_angle',
                   'pass_body_part',
                   'pass_type',
                   'pass_length',
                   'pass_outswinging',
                   'pass_inswinging',
                   'pass_cross', 
                   'pass_cut_back', 
                   'pass_deflected', 
                   'pass_goal_assist', 
                   'pass_recipient', 
                   'pass_recipient_id', 
                   'pass_assisted_shot_id', 
                   'pass_no_touch', 
                   'pass_end_location', 
                   'pass_aerial_won',
                   'pass_height',
                   'substitution_outcome_id',
                   'tactics',
                   'block_deflection',
                   'dribble_no_touch',
                   'shot_open_goal', 
                   'shot_saved_to_post',
                   'shot_redirect', 
                   'shot_follows_dribble',
                   'period',
                   'injury_stoppage_in_chanin',
                   'block_save_block',
                   'ball recovery_offensive',


                   ]
cleaned_df = data.drop_columns(cleaned_df, columns_to_drop)

# add lookahead outcome
df_with_targets = data.assign_lookahead_outcomes(cleaned_df, lookahead=6)


counts of each outcome nn_target
Keep Possession    71251
Lose Possession    28252
Shot                4830
Name: count, dtype: int64


# Prepare 360 Data

In [5]:
df_360 = data.assign_grid_cells(data_360)
nn_final = data.aggregate_nn_layers_vectorized(df_360)

# Finalize NN Df

In [6]:
nn_dataset = data.prepare_nn_dataset(df_with_targets, nn_final, target_cols=['nn_target', 'goal_flag'], context_cols = True, keep_context_ids = True ) # adjust cols depending on model

# Neural Network final Data Prep

In [7]:
context_cols = [
    'under_pressure', 
    'counterpress', 
    'dribble_nutmeg'
]

# Impute NaN values with 0.0 (float)
# This assumes NaN means the event was NOT under pressure, NOT a counterpress, etc.
nn_dataset[context_cols] = nn_dataset[context_cols].fillna(0.0)


target_map = {"Keep Possession": 0, "Lose Possession": 1, "Shot": 2}

# Apply mapping
nn_dataset['nn_target_int'] = nn_dataset['nn_target'].map(target_map)

# Check
print(nn_dataset[['nn_target', 'nn_target_int']].head())

         nn_target  nn_target_int
0  Keep Possession              0
1  Keep Possession              0
2  Keep Possession              0
3  Keep Possession              0
4  Keep Possession              0


In [8]:
from sklearn.model_selection import train_test_split
layer_columns = ["ball_layer", "teammates_layer", "opponents_layer"]

# Define the three arrays to split
X_features = nn_dataset[layer_columns] # Example feature set
event_targets = nn_dataset['nn_target_int'].values
goal_flags = nn_dataset['goal_flag'].values

VALIDATION_SIZE = 0.20
RANDOM_SEED = 42

# CRITICAL: Assign the 6 returned arrays to 6 descriptive variables
(
    X_train, 
    X_val, 
    y_event_train, 
    y_event_val, 
    y_goal_train, 
    y_goal_val
) = train_test_split(
    X_features, 
    event_targets, 
    goal_flags,
    test_size=VALIDATION_SIZE, 
    random_state=RANDOM_SEED,
    stratify=event_targets 
)

# -------------------------------------------------------------
# 2. Instantiate the two PitchDataset objects (using the 6 arrays)
# -------------------------------------------------------------

# Training Dataset (uses all three 'train' arrays)
train_dataset = PitchDatasetMultiTask(
    X_train, 
    y_event_train, 
    y_goal_train
)

# Validation Dataset (uses all three 'val' arrays)
validation_dataset = PitchDatasetMultiTask(
    X_val, 
    y_event_val, 
    y_goal_val
)

print(f"Total training samples: {len(train_dataset)}")
print(f"Total validation samples: {len(validation_dataset)}")

Total training samples: 72552
Total validation samples: 18138


# The Goal Multi Task CNN

In [9]:
# -----------------------------
# Check device
# -----------------------------
print(f"Using device: {DEVICE}")

# ------------------------------------
# 1. Define input columns & targets
# ------------------------------------
# This assumes nn_dataset is already loaded and processed in previous cells.

# Ensure labels are in the correct format
event_targets = nn_dataset['nn_target_int'].values   # 0=keep, 1=lose, 2=shot (int)
# CRITICAL: Goal flags must be float for BCEWithLogitsLoss
goal_flags = nn_dataset['goal_flag'].values.astype(np.float32) 

# -----------------------------
# 2. Prepare dataset (Static Input)
# -----------------------------
# PitchDatasetMultiTask correctly uses the 3 static layer columns.
# ------------------------------------
# 3. Compute class weights and positive weight
# ------------------------------------

# A. Event Weights (Multi-Class) - For CrossEntropyLoss
event_counts = Counter(event_targets)
total_events = len(event_targets)

# Using inverse frequency: total / count
class_weights_event = torch.tensor(
    [total_events / event_counts.get(c, 1) for c in range(len(event_counts))],
    dtype=torch.float32
).to(DEVICE)

# B. Goal Positive Weight (Binary) - For BCEWithLogitsLoss
goal_counts = Counter(goal_flags)

STABLE_GOAL_POS_WEIGHT = 5.0
goal_pos_weight = torch.tensor(STABLE_GOAL_POS_WEIGHT, dtype=torch.float32).to(config.DEVICE)

print(f"Goal Positive Weight (0/1 ratio): {goal_pos_weight.item():.2f}")

# ------------------------------------
# 4. Train model (Using the dedicated base function)
# ------------------------------------
print("Starting training for Static CNN Baseline...")
baseline_model = train_model_base_threat(
    dataset=train_dataset, 
    event_class_weights=class_weights_event, 
    goal_pos_weight=goal_pos_weight
)
print("Training complete.")



Using device: cpu
Goal Positive Weight (0/1 ratio): 5.00
Starting training for Static CNN Baseline...


Base CNN Threat Epoch 1: 100%|██████████| 2268/2268 [00:13<00:00, 169.97it/s, event_loss=0.954, loss=0.954]
Base CNN Threat Epoch 2: 100%|██████████| 2268/2268 [00:12<00:00, 177.08it/s, event_loss=0.464, loss=1.15] 
Base CNN Threat Epoch 3: 100%|██████████| 2268/2268 [00:12<00:00, 178.20it/s, event_loss=0.628, loss=1.24] 
Base CNN Threat Epoch 4: 100%|██████████| 2268/2268 [00:12<00:00, 181.81it/s, event_loss=1.13, loss=1.13]  
Base CNN Threat Epoch 5: 100%|██████████| 2268/2268 [00:12<00:00, 175.41it/s, event_loss=0.607, loss=0.773]
Base CNN Threat Epoch 6: 100%|██████████| 2268/2268 [00:13<00:00, 163.37it/s, event_loss=1.12, loss=1.12]  
Base CNN Threat Epoch 7: 100%|██████████| 2268/2268 [00:13<00:00, 167.39it/s, event_loss=0.173, loss=0.596]
Base CNN Threat Epoch 8: 100%|██████████| 2268/2268 [00:12<00:00, 174.72it/s, event_loss=1.9, loss=5.14]   
Base CNN Threat Epoch 9: 100%|██████████| 2268/2268 [00:12<00:00, 177.71it/s, event_loss=1.03, loss=1.03]  
Base CNN Threat Epoch 10: 10

Training complete.





In [10]:
# -----------------------------
# 5. Evaluate model
# -----------------------------
metrics = evaluate_model_base_threat(baseline_model, validation_dataset)
# print(metrics)



--- Event Outcome Metrics ---
Event Accuracy: 0.5772962840445474
Event Balanced Accuracy: 0.5991002910321375
Event Confusion Matrix:
 [[7712 3398 1422]
 [1848 2105  764]
 [ 118  117  654]]
              precision    recall  f1-score   support

           0       0.80      0.62      0.69     12532
           1       0.37      0.45      0.41      4717
           2       0.23      0.74      0.35       889

    accuracy                           0.58     18138
   macro avg       0.47      0.60      0.48     18138
weighted avg       0.66      0.58      0.60     18138


--- Goal Prediction (xG) Metrics ---
Goal Accuracy: 0.8312710911136107
Goal Balanced Accuracy: 0.7092140979331866
Goal AUC-ROC Score: 0.7490862846517194
Goal Confusion Matrix:
 [[673  94]
 [ 56  66]]
              precision    recall  f1-score   support

         0.0       0.92      0.88      0.90       767
         1.0       0.41      0.54      0.47       122

    accuracy                           0.83       889
   macro a

In [11]:
# Assuming the result of your evaluation is stored here:
print("--- Event Classification Probabilities (P_outcome) ---")
event_probs = metrics['event_probs']
print(f"Shape of Event Probabilities (N, 3): {event_probs.shape}")

# Average predicted probability for each class (overall confidence)
avg_P_keep = np.mean(event_probs[:, 0])
avg_P_lose = np.mean(event_probs[:, 1])
avg_P_shot = np.mean(event_probs[:, 2])

print(f"Average Predicted P(Keep Possession): {avg_P_keep:.4f}")
print(f"Average Predicted P(Lose Possession): {avg_P_lose:.4f}")
print(f"Average Predicted P(Shot): {avg_P_shot:.4f}")

print("\n--- Goal Prediction Probabilities (xG) ---")
goal_probs = metrics['goal_probs']
print(f"Number of Shots Evaluated: {goal_probs.shape[0]}")

# Average Predicted xG
avg_xg = np.mean(goal_probs)
print(f"Average Predicted xG per Shot: {avg_xg:.4f}")

# Total Predicted xG (sum of all probabilities for the shot events)
total_xg = np.sum(goal_probs)
print(f"Total Predicted xG for all Shots: {total_xg:.2f}")

# The actual number of goals scored in the test set (True Goals)
true_goals = np.sum(metrics['goal_labels'])
print(f"Actual Goals Scored (True Goals): {true_goals:.2f}")

# Print AUC Score (Should be in your metrics dictionary now)
print(f"Goal Prediction AUC-ROC Score: {metrics.get('goal_auc', 'N/A')}")

--- Event Classification Probabilities (P_outcome) ---
Shape of Event Probabilities (N, 3): (18138, 3)
Average Predicted P(Keep Possession): 0.4683
Average Predicted P(Lose Possession): 0.3879
Average Predicted P(Shot): 0.1438

--- Goal Prediction Probabilities (xG) ---
Number of Shots Evaluated: 889
Average Predicted xG per Shot: 0.1902
Total Predicted xG for all Shots: 169.10
Actual Goals Scored (True Goals): 122.00
Goal Prediction AUC-ROC Score: 0.7490862846517194


In [12]:
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from tqdm import tqdm

# Define the columns that cause the Parquet error
LAYER_COLUMNS_TO_DROP = ["ball_layer", "teammates_layer", "opponents_layer"]

def predict_and_save_probabilities(
    model, 
    full_dataset, 
    original_df: pd.DataFrame, 
    output_filepath: str,
    device: str = 'cpu',
    batch_size: int = 1024
) -> pd.DataFrame:
    """
    Runs the model, computes probabilities, assigns them to the original DataFrame,
    drops the problematic pitch layer columns, and saves the result as a Parquet file.
    """
    print(f"Starting prediction on {len(full_dataset)} samples...")

    model.eval() 
    model.to(device)
    pred_loader = DataLoader(full_dataset, batch_size=batch_size, shuffle=False)
    
    event_probs_list = []
    goal_probs_list = []
    
    with torch.no_grad():
        for X, _, _ in tqdm(pred_loader, desc="Predicting Probabilities"):
            # If using Contextual Model, adjust the unpacking: for X, ctx, _, _ in ...
            X = X.to(device)
            event_logits, goal_logits = model(X)
            
            event_probs = F.softmax(event_logits, dim=1) 
            event_probs_list.append(event_probs.cpu().numpy())
            
            goal_probs = torch.sigmoid(goal_logits)
            goal_probs_list.append(goal_probs.cpu().numpy())

    all_event_probs = np.concatenate(event_probs_list, axis=0)
    all_goal_probs = np.concatenate(goal_probs_list, axis=0).flatten()

    # 1. Assign new columns to the original DataFrame
    result_df = original_df.copy()
    
    result_df.loc[:, 'P_Lose'] = all_event_probs[:, 1]
    result_df.loc[:, 'P_Keep'] = all_event_probs[:, 0]
    result_df.loc[:, 'P_Shot'] = all_event_probs[:, 2]
    result_df.loc[:, 'xG'] = all_goal_probs
    
    # 2. CRITICAL FIX: Drop the complex object columns before saving!
    # These columns contain lists-of-lists (the pitch layers) which Parquet cannot serialize.
    columns_to_keep = [col for col in result_df.columns if col not in LAYER_COLUMNS_TO_DROP]
    final_df_to_save = result_df[columns_to_keep]

    # 3. Save the enriched DataFrame to Parquet
    final_df_to_save.to_parquet(output_filepath, index=False)
    
    print(f"\n✅ Prediction complete. Data saved to: {output_filepath}")
    return final_df_to_save

# Example: Run the function again
# final_df_with_probs = predict_and_save_probabilities(...) # using the fixed function

# --- Example Usage (Requires your context setup) ---
# NOTE: You will need to create the 'full_dataset' object here:
full_dataset = PitchDatasetMultiTask(nn_dataset[layer_columns], event_targets, goal_flags) 

final_df_with_probs = predict_and_save_probabilities(
    model=baseline_model,
    full_dataset=full_dataset,
    original_df=nn_dataset.copy(),
    output_filepath='baseline_cnn_predictions.parquet',
    device=DEVICE,
    batch_size=1024
)

Starting prediction on 90690 samples...


Predicting Probabilities: 100%|██████████| 89/89 [00:01<00:00, 62.03it/s]



✅ Prediction complete. Data saved to: baseline_cnn_predictions.parquet


In [13]:
final_df_with_probs

COLS_TO_DROP = ['match_id', 'possession', 'under_pressure', 'counter_press', 'dribble_nutmeg'] # Assuming these were also duplicated

# 2. Create a clean version of the predictions DF
df_preds_clean = final_df_with_probs.drop(columns=COLS_TO_DROP, errors='ignore')

In [14]:
df_merged = pd.merge(data_events, df_preds_clean, on='id', how='inner')

In [15]:
df_merged

Unnamed: 0,50_50,ball_receipt_outcome,ball_recovery_recovery_failure,carry_end_location,clearance_aerial_won,clearance_body_part,clearance_head,clearance_left_foot,clearance_other,clearance_right_foot,...,goalkeeper_lost_out,shot_follows_dribble,nn_target,goal_flag,counterpress_y,nn_target_int,P_Lose,P_Keep,P_Shot,xG
0,,,,,,,,,,,...,,,Keep Possession,0,0.0,0,0.448441,0.549075,0.002484,0.000005
1,,,,,,,,,,,...,,,Keep Possession,0,0.0,0,0.208276,0.791583,0.000140,0.000005
2,,,,,,,,,,,...,,,Keep Possession,0,0.0,0,0.445620,0.388160,0.166219,0.653905
3,,,,,,,,,,,...,,,Lose Possession,0,0.0,1,0.481475,0.506040,0.012485,0.004109
4,,,,,,,,,,,...,,,Keep Possession,0,0.0,0,0.590819,0.378169,0.031012,0.000002
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
90685,"{'outcome': {'id': 1, 'name': 'Lost'}}",,,,,,,,,,...,,,Lose Possession,0,0.0,1,0.466086,0.533402,0.000512,0.033038
90686,"{'outcome': {'id': 1, 'name': 'Lost'}}",,,,,,,,,,...,,,Keep Possession,0,1.0,0,0.316368,0.365292,0.318340,0.000060
90687,"{'outcome': {'id': 4, 'name': 'Won'}}",,,,,,,,,,...,,,Keep Possession,0,0.0,0,0.457215,0.444297,0.098488,0.001303
90688,,,,,,,,,,,...,,,Lose Possession,0,0.0,1,0.558619,0.441226,0.000154,0.002911


In [16]:
# 1. Define the team name
TEAM_NAME = "Switzerland Women's"

# 2. Filter the DataFrame where Switzerland Women's is either the home or away team
switzerland_matches = df_merged[
    (df_merged['team'] == TEAM_NAME)]

# 3. Get all unique match IDs
unique_match_ids = switzerland_matches['match_id'].unique()

# 4. Print the result
print(unique_match_ids)

[4018356 3998852 3998844 3998837]


In [17]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

def get_possession_sequence(
    df_merged: pd.DataFrame, 
    match_id: int, 
    possession_id: int
) -> pd.DataFrame:
    """
    Filters the merged DataFrame (raw event data + model predictions)
    to return a single, chronologically sorted possession sequence.

    Args:
        df_merged (pd.DataFrame): The pre-loaded DataFrame containing all merged data.
        match_id (int): The match identifier to filter on.
        possession_id (int): The possession identifier to filter on.

    Returns:
        pd.DataFrame: A filtered and sorted sequence DataFrame.
    """
    
    # 1. Filter for the Specific Possession Sequence
    # This filters the full DataFrame down to just the events of interest
    df_sequence = df_merged[
        (df_merged['match_id'] == match_id) & 
        (df_merged['possession_id'] == possession_id)
    ].copy()
    
    if df_sequence.empty:
        print(f"⚠️ Warning: Sequence not found for Match ID {match_id}, Possession ID {possession_id}.")
        return pd.DataFrame()

    # 2. Sort the sequence chronologically
    # Assumes a column like 'event_sequence_in_possession' or a reliable timestamp exists.
    # If not, use df_sequence.sort_values(by='timestamp', inplace=True)
    if 'event_sequence_in_possession' in df_sequence.columns:
        df_sequence.sort_values(by='event_sequence_in_possession', inplace=True)
    
    # 3. Create a clean chronological index for plotting (X-axis)
    df_sequence['seq_index'] = np.arange(len(df_sequence))

    print(f"Sequence extracted with {len(df_sequence)} events, ready for plotting.")
    return df_sequence

# --- Next Step: Visualization Function ---

def plot_possession_threat_stack(df_sequence: pd.DataFrame, title_suffix: str = ""):
    """
    Generates a Stacked Area Chart for the Event Head probabilities (P_outcome).
    """
    if df_sequence.empty:
        print("Cannot plot: DataFrame is empty.")
        return

    events = df_sequence['seq_index']
    
    # Ensure probabilities are present and in the correct order for stacking (Lose at the bottom)
    # The stackplot inherently calculates the cumulative lines you requested.
    y_lose = df_sequence['P_Lose'].values
    y_keep = df_sequence['P_Keep'].values
    y_shot = df_sequence['P_Shot'].values
    
    fig, ax = plt.subplots(figsize=(12, 6))

    ax.stackplot(
        events,
        y_lose,
        y_keep,
        y_shot,
        labels=['P(Lose Possession)', 'P(Keep Possession)', 'P(Shot)'],
        colors=['#ff7f0e', '#1f77b4', '#2ca02c'], # Orange, Blue, Green
        alpha=0.8
    )

    # Add xG values as a secondary line plot for context
    ax.plot(events, df_sequence['xG'].values, color='red', linestyle='--', linewidth=2, label='xG (P(Goal) | Shot)')

    # --- Add Labels and Title ---
    match_id = df_sequence['match_id'].iloc[0]
    possession_id = df_sequence['possession_id'].iloc[0]
    
    ax.set_xlabel(f"Event Index (Relative to Possession Start) | Total Events: {len(df_sequence)}", fontsize=12)
    ax.set_ylabel("Probability / Risk Profile")
    ax.set_title(f"Threat Model Output (P_outcome) for Match {match_id}, Possession {possession_id} {title_suffix}", fontsize=14)
    
    ax.legend(loc='upper right', frameon=True)
    ax.set_ylim(0, 1.0)
    ax.grid(True, linestyle='--', alpha=0.6)
    
    # Set X-ticks clearly for every 5th event, or just the start/end if the possession is very long
    if len(events) < 30:
        ax.set_xticks(events[::2])
    
    plt.tight_layout()
    plt.show()

# --- Example Usage (How you would run this in your notebook) ---

# 1. ASSUME df_merged IS AVAILABLE
# 2. Define your target sequence
# TARGET_MATCH = 12345
# TARGET_POSSESSION = 50

# 3. Get the sequence data
# sequence_data = get_possession_sequence(
#     df_merged=df_merged,
#     match_id=TARGET_MATCH,
#     possession_id=TARGET_POSSESSION
# )

# 4. Plot the results
# if not sequence_data.empty:
#     plot_possession_threat_stack(sequence_data)