In [1]:
import sys
import os
import importlib
from collections import Counter
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# --- 1. Path Setup ---
# Try to locate the repository root by searching upward for a 'src' directory (or .git)
def find_repo_root(start_path=None, marker_dirs=('src', '.git')):
    """Walk upward from a directory until a repository marker is found.

    Args:
        start_path: Directory to start from. When falsy, the current working
            directory is used.
        marker_dirs: Names of sub-directories whose presence identifies the
            repository root.

    Returns:
        The absolute path of the first directory (starting at ``start_path``
        and moving towards the filesystem root) that contains at least one of
        ``marker_dirs`` as a directory, or ``None`` if none does.
    """
    current = os.path.abspath(start_path or os.getcwd())
    previous = None
    # dirname() of the filesystem root is the root itself, so the walk ends
    # once moving up no longer changes the path.
    while current != previous:
        for marker in marker_dirs:
            if os.path.isdir(os.path.join(current, marker)):
                return current
        previous, current = current, os.path.dirname(current)
    return None

# Locate the repository root; if no marker directory is found on the way up,
# fall back to the hardcoded path that works on nuvolos.
repo_root = find_repo_root()
repo_root = "/files/pixlball" if repo_root is None else repo_root

# Put the repository root at the front of the import path (only once) so the
# `src.*` imports below can be resolved.
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)
print(f"Using repo_root: {repo_root}")

# --- 2. Project Module Imports ---
# Import all project modules using clean names
import src.config as config
import src.dataset as dataset
import src.train as train
import src.evaluate as evaluate
import src.data as data
import src.losses as losses
import src.model as model
import src.utils as utils

# --- 3. Module Reloading (CRITICAL for Notebook Development) ---
# Re-execute every project module so edits made on disk are picked up without
# restarting the kernel. The tuple preserves the reload order:
# Config/Utils -> Data/Model/Losses -> Dataset/Train/Evaluate
for _project_module in (config, utils, data, model, losses, dataset, train, evaluate):
    importlib.reload(_project_module)
del _project_module  # keep the notebook namespace free of the loop variable

# --- 4. Direct Imports (For clean code in subsequent cells) ---
# Import essential classes and functions needed for the pipeline steps

# Configuration
from src.config import DEVICE 

# Data/Dataset Classes
from src.dataset import PitchDatasetMultiTask, ContextPitchDatasetMultiTask

# Training Functions
from src.train import train_model_base_threat

# Evaluation/Helpers
from src.evaluate import evaluate_model_base_threat
from src.losses import get_model_criteria
from src.model import TinyCNN_MultiTask_Threat
from src.utils import get_sequence_lengths

# --- Final Check ---
# Report the DEVICE value imported from src.config so the run's device is
# visible in the notebook output.
print(f"Using device: {DEVICE}")

Using repo_root: c:\Users\jonas\Desktop\repos\pixlball
Using device: cpu


In [2]:
# Read the two raw parquet tables (events and 360 data) from <repo_root>/data
# using the fastparquet engine.
data_events = pd.read_parquet(
    os.path.join(repo_root, "data", "events_data.parquet"),
    engine="fastparquet",
)
data_360 = pd.read_parquet(
    os.path.join(repo_root, "data", "sb360_data.parquet"),
    engine="fastparquet",
)

## Prepare Event Data

In [3]:
# Run the project's event loader (src/data.py) on the raw events table.
# NOTE(review): the cell output shows that it prints an event count and the
# distribution of `nn_target` outcomes; presumably it also attaches the target
# columns used later (`nn_target`, `goal_flag`) — confirm in src/data.py.
df_with_targets = data.event_data_loader(data_events)

2462 events.
counts of each outcome nn_target
Keep Possession    70920
Lose Possession    27465
Shot                4764
Name: count, dtype: int64


## Prepare 360 Data

In [4]:
# Two-step preparation of the raw 360 table with helpers from src/data.py.
# NOTE(review): judging by the helper names, step 1 assigns each row to a pitch
# grid cell and step 2 aggregates those cells into the per-event layers used by
# the network — confirm against src/data.py.
df_360 = data.assign_grid_cells(data_360)
nn_final = data.aggregate_nn_layers_vectorized(df_360)

# Finalize the NN DataFrame

In [5]:
# Combine the event targets with the aggregated 360 layers into the frame the
# network is trained on. Adjust the columns / context flags depending on the
# model being trained.
nn_dataset = data.prepare_nn_dataset(
    df_with_targets,
    nn_final,
    target_cols=['nn_target', 'goal_flag'],
    context_cols=True,
    keep_context_ids=True,
)

# Final Data Preparation for the Neural Network

In [6]:
# Re-execute src/data.py so that edits made since the first cell are picked up
# before its helpers are called below.
importlib.reload(data)

# NOTE(review): `nn_dataset` is rebound to its own transformed value, so
# re-running this cell applies both helpers a second time — confirm that they
# are idempotent, or restart from the `prepare_nn_dataset` cell before re-running.
# `add_target_as_int` presumably creates the `nn_target_int` column used for
# stratification in the next cell — confirm in src/data.py.
nn_dataset = data.add_context_cols(nn_dataset)
nn_dataset = data.add_target_as_int(nn_dataset)

In [7]:
# NOTE: `train_test_split` is imported once in the notebook's first (imports)
# cell, so it is not re-imported here.

# 1. Split parameters
VALIDATION_SIZE = 0.20   # fraction of rows held out for validation
RANDOM_SEED = 42         # fixed seed so the split is reproducible
layer_columns = ["ball_layer", "teammates_layer", "opponents_layer"]

# 2. Split the entire DataFrame first.
# Splitting whole rows keeps features, event targets and goal flags bundled
# together; stratifying on the integer event label keeps the class proportions
# the same in the training and validation parts.
train_df, val_df = train_test_split(
    nn_dataset,
    test_size=VALIDATION_SIZE,
    random_state=RANDOM_SEED,
    stratify=nn_dataset['nn_target_int'],
)

# 3. Wrap each split in a PitchDatasetMultiTask.
# Arguments are positional, in the order: layer features, event targets,
# goal flags.
train_dataset = PitchDatasetMultiTask(
    train_df[layer_columns],             # 1st argument: features
    train_df['nn_target_int'].values,    # 2nd argument: event targets
    train_df['goal_flag'].values,        # 3rd argument: goal flags
)

validation_dataset = PitchDatasetMultiTask(
    val_df[layer_columns],
    val_df['nn_target_int'].values,
    val_df['goal_flag'].values,
)

print(f"Total training samples: {len(train_dataset)}")
print(f"Total validation samples: {len(validation_dataset)}")

Total training samples: 72117
Total validation samples: 18030


# The Goal Multi-Task CNN

In [8]:
# Obtain the two loss weights used for training: one for the event
# classification task and one positive-class weight for the goal task.
# NOTE(review): the weights are computed from the full `nn_dataset` (training
# and validation rows together), not from `train_df` — confirm this is intended.
class_weights_event, goal_pos_weight = utils.get_multitask_loss_weights(nn_dataset, DEVICE)

In [9]:

# Show the positive-class weight that is passed to the goal loss below
# (`.item()` extracts the Python scalar from the weight object).
print(f"Goal Positive Weight (0/1 ratio): {goal_pos_weight.item():.2f}")

# ------------------------------------
# 4. Train model (Using the dedicated base function)
# ------------------------------------
# Only the training split is passed in; the validation split is kept for the
# evaluation cell. Remaining training settings are left at the defaults of
# `train_model_base_threat` (see src/train.py).
print("Starting training for Static CNN Baseline...")
baseline_model = train_model_base_threat(
    dataset=train_dataset, 
    event_class_weights=class_weights_event, 
    goal_pos_weight=goal_pos_weight
)
print("Training complete.")



Goal Positive Weight (0/1 ratio): 5.00
Starting training for Static CNN Baseline...


Base CNN Threat Epoch 1: 100%|██████████| 2254/2254 [00:37<00:00, 60.70it/s, ev_loss=1.0336, loss=1.7891] 

Training complete.





In [10]:
# -----------------------------
# 5. Evaluate model
# -----------------------------
# Score the trained baseline on the held-out validation split. The helper
# prints its own report (see the cell output), so the returned value is only
# stored in `metrics` and not printed again.
metrics = evaluate_model_base_threat(baseline_model, validation_dataset)



--- Event Outcome Metrics ---
Event Accuracy: 0.5128674431503051
Event Balanced Accuracy: 0.5821633311450194
Event Confusion Matrix:
 [[6520 4194 1774]
 [1687 2029  941]
 [ 105   82  698]]
              precision    recall  f1-score   support

           0       0.78      0.52      0.63     12488
           1       0.32      0.44      0.37      4657
           2       0.20      0.79      0.32       885

    accuracy                           0.51     18030
   macro avg       0.44      0.58      0.44     18030
weighted avg       0.64      0.51      0.55     18030


--- Goal Prediction (xG) Metrics ---
Goal Accuracy: 0.8723163841807909
Goal Balanced Accuracy: 0.5
Goal AUC-ROC Score: 0.6118001742399926
Goal Confusion Matrix:
 [[772   0]
 [113   0]]
              precision    recall  f1-score   support

         0.0       0.87      1.00      0.93       772
         1.0       0.00      0.00      0.00       113

    accuracy                           0.87       885
   macro avg       0.44  