In [1]:
# ==========================
# 🧰 Standard Library Imports
# ==========================
import warnings

# ==========================
# 🧮 Core Data Science Stack
# ==========================
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import log_loss, roc_auc_score, RocCurveDisplay
from sklearn.calibration import CalibrationDisplay

# ==========================
# 🏀 Sports Analytics APIs
# ==========================
from basketball_dataset.nba_tracking_data_15_16 import NbaTracking
from basketball_dataset.dataset_operations import *

# ==========================
# 📊 Visualization Tools
# ==========================
from mplsoccer import Pitch
from matplotlib.patches import Arc, Rectangle, Circle
from matplotlib.animation import FuncAnimation

# ==========================
# 🧠 Machine Learning / Models
# ==========================
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, Dataset, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, GATConv, GraphConv
from torch_geometric.utils import to_networkx

# ==========================
# 🔗 Graphs / Utilities
# ==========================
import networkx as nx
import tqdm

# ==========================
# ⚠️ Warning Filters
# ==========================
# NOTE(review): this silences *every* warning for the rest of the session, including
# pandas SettingWithCopy / deprecation warnings that can point to real problems.
# Consider ignoring specific categories instead, e.g.
# warnings.filterwarnings("ignore", category=FutureWarning).
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


# ***IMPORTANT*** You still need to pull the larger data files

Please run the following in your terminal:

```bash
git lfs install
git lfs pull
```

# Now we will use basketball tracking data to predict the probability of conceding points off a turnover

In [61]:
from basketball_dataset.nba_tracking_data_15_16 import NbaTracking
import pandas as pd
import numpy as np

# Download (if needed) and prepare the "tiny" NbaTracking configuration under
# basketball_dataset/data.
# NOTE(review): "tiny" presumably selects a small subset of games; confirm against
# NbaTracking's available configs.
dataset = NbaTracking(config_name="tiny")
dataset.download_and_prepare("basketball_dataset/data")

In [62]:
# Materialise the prepared dataset; later cells work with its "train" split.
nba_dataset = dataset.as_dataset()

In [63]:
# Pandas view of the train split: one row per event, with nested dict columns
# (event_info, primary_info, secondary_info, visitor, home) and the event's
# tracking `moments`.
nba_data = nba_dataset["train"].to_pandas()

In [64]:
# Peek at the first few events
nba_data.head()

Unnamed: 0,gameid,gamedate,event_info,primary_info,secondary_info,visitor,home,moments
0,21500333,2015-12-11,"{'id': '1', 'type': 10, 'possession_team_id': ...","{'team': 'home', 'player_id': 101133.0, 'team_...","{'team': 'away', 'player_id': 202355.0, 'team_...","{'name': 'Miami Heat', 'teamid': 1610612748, '...","{'name': 'Indiana Pacers', 'teamid': 161061275...","[{'quarter': 1, 'game_clock': 707.65, 'shot_cl..."
1,21500333,2015-12-11,"{'id': '2', 'type': 1, 'possession_team_id': 1...","{'team': 'home', 'player_id': 101133.0, 'team_...","{'team': 'home', 'player_id': 101145.0, 'team_...","{'name': 'Miami Heat', 'teamid': 1610612748, '...","{'name': 'Indiana Pacers', 'teamid': 161061275...","[{'quarter': 1, 'game_clock': 707.65, 'shot_cl..."
2,21500333,2015-12-11,"{'id': '3', 'type': 1, 'possession_team_id': 1...","{'team': 'away', 'player_id': 2547.0, 'team_id...","{'team': None, 'player_id': 0.0, 'team_id': nan}","{'name': 'Miami Heat', 'teamid': 1610612748, '...","{'name': 'Indiana Pacers', 'teamid': 161061275...","[{'quarter': 1, 'game_clock': 694.65, 'shot_cl..."
3,21500333,2015-12-11,"{'id': '4', 'type': 2, 'possession_team_id': 1...","{'team': 'home', 'player_id': 101145.0, 'team_...","{'team': None, 'player_id': 0.0, 'team_id': nan}","{'name': 'Miami Heat', 'teamid': 1610612748, '...","{'name': 'Indiana Pacers', 'teamid': 161061275...","[{'quarter': 1, 'game_clock': 675.65, 'shot_cl..."
4,21500333,2015-12-11,"{'id': '5', 'type': 4, 'possession_team_id': 1...","{'team': 'away', 'player_id': 202355.0, 'team_...","{'team': None, 'player_id': 0.0, 'team_id': nan}","{'name': 'Miami Heat', 'teamid': 1610612748, '...","{'name': 'Indiana Pacers', 'teamid': 161061275...","[{'quarter': 1, 'game_clock': 675.65, 'shot_cl..."


# To start we will have to find the moment the turnover occurs for each turnover event

In [65]:
%matplotlib tk
import matplotlib.pyplot as plt
from matplotlib.patches import Arc, Rectangle, Circle
from matplotlib.animation import FuncAnimation
import numpy as np
from basketball_dataset.dataset_operations import *


from itertools import islice

# --- Load a single event ---
# filter_candidate_events returns the candidate events for modelling. Judging by the
# later `turnovers_and_shots` frame, these are shots as well as turnovers.
filtered_events = filter_candidate_events(nba_dataset["train"])
event_num = 3  # 1-based position of the candidate event to inspect
if event_num < 1:
    raise ValueError("event_num is 1-based and must be >= 1")
# islice returns the event_num-th item whether filter_candidate_events returns a list
# or a generator. Repeatedly calling next(iter(...)) only advances a generator; on a
# list it would return the first item every time.
event = next(islice(filtered_events, event_num - 1, None))

print(event['gameid'])
print(event['event_info']['id'])
print(event['event_info']['desc_home'])
print(event['event_info']['desc_away'])
# --- Setup figure ---
fig, ax = plt.subplots(figsize=(10, 10))

# --- Draw court ---
def draw_full_court(ax, color='black', lw=1):
    """Draw a full basketball court on `ax` in the 50 x 94 tracking coordinate frame.

    The court spans x in [0, 50] and y in [0, 94], with baselines at y=0 and y=94.
    Each half is generated from one baseline-relative description and reflected
    across half court, so the two ends are exact mirror images.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on. Its limits, equal aspect ratio and hidden axis are set here.
    color : str
        Line colour for every court marking.
    lw : float
        Line width for every court marking.
    """
    court_length, court_width = 94, 50
    mid_x = court_width / 2
    style = dict(lw=lw, color=color)

    def half_court(flip):
        """Markings for one half; `flip=False` is the y=0 end, `flip=True` the y=94 end."""
        def y(dist, height=0):
            # Reflect y -> court_length - y. For a rectangle the anchor is its lower
            # edge, so the reflected anchor also has to subtract the height.
            return court_length - dist - height if flip else dist

        def arc_span(theta1, theta2):
            # Reflecting across a horizontal line maps angle t to 360 - t.
            return (360 - theta2, 360 - theta1) if flip else (theta1, theta2)

        def rect(x, dist, width, height, **kw):
            return Rectangle((x, y(dist, height)), width, height, **style, **kw)

        def arc(dist, size, theta1, theta2, **kw):
            t1, t2 = arc_span(theta1, theta2)
            return Arc((mid_x, y(dist)), size, size, theta1=t1, theta2=t2, **style, **kw)

        return [
            Circle((mid_x, y(4.75)), .85, fill=False, **style),  # basket (three-point arc centre)
            rect(17, 0, 16, 19, fill=False),                      # outer lane
            rect(19, 0, 12, 19, fill=False),                      # inner lane
            arc(19, 12, 0, 180),                                  # free-throw circle, court side
            arc(19, 12, 180, 360, linestyle='dashed'),            # free-throw circle, lane side
            rect(3, 0, 0, 14),                                    # corner three-point lines
            rect(47, 0, 0, 14),
            arc(4.75, 47.5, 22.5, 157.5),                         # three-point arc
        ]

    court_elements = [
        Rectangle((0, 0), court_width, court_length, fill=False, **style),  # boundary
        Rectangle((0, court_length / 2), court_width, 0, **style),          # half-court line
        Circle((mid_x, court_length / 2), 6, fill=False, **style),          # centre circle
        *half_court(flip=False),
        *half_court(flip=True),
    ]
    for patch in court_elements:
        ax.add_patch(patch)
    ax.set_xlim(0, court_width)
    ax.set_ylim(0, court_length)
    ax.set_aspect('equal')
    ax.axis('off')

draw_full_court(ax)

# --- Extract data ---
moments = event["moments"]                            # tracking frames animated below
event_team = event["primary_info"]["team_id"]         # team of the event's primary player
event_player_id = event["primary_info"]["player_id"]  # primary player, drawn in purple

# --- Create plot handles for animation ---
# Empty artists that update() fills in each frame: ball (black), the primary
# player's teammates (green), the opponents (red) and the primary player (purple).
ball, = ax.plot([], [], 'o', color='k', markersize=8)
players_attack, = ax.plot([], [], 'o', color='green', markersize=10)
players_defend, = ax.plot([], [], 'o', color='red', markersize=10)
handler, = ax.plot([], [], 'o', color='purple', markersize=12)
# --- Prepare frame update function ---
def update(frame):
    """FuncAnimation callback that redraws one tracking moment.

    Reads `moments[frame]` and splits the players into:
    - the primary event player, drawn with `handler`;
    - that player's teammates (`event_team`), drawn with `players_attack`;
    - everyone else, drawn with `players_defend`.
    The ball goes to `ball`. Returns the four artists so blitting can redraw them.
    """
    snapshot = moments[frame]
    team_xs, team_ys = [], []
    opp_xs, opp_ys = [], []
    focus_x = focus_y = None

    for person in snapshot['player_coordinates']:
        px, py = person['x'], person['y']
        if person['playerid'] == event_player_id:
            focus_x, focus_y = px, py
            continue
        if person['teamid'] == event_team:
            team_xs.append(px)
            team_ys.append(py)
        else:
            opp_xs.append(px)
            opp_ys.append(py)

    ball_pos = snapshot['ball_coordinates']
    ball.set_data([ball_pos['x']], [ball_pos['y']])
    players_attack.set_data(team_xs, team_ys)
    players_defend.set_data(opp_xs, opp_ys)

    # Hide the handler marker when the primary player is not in this moment
    if focus_x is None:
        handler.set_data([], [])
    else:
        handler.set_data([focus_x], [focus_y])

    return ball, players_attack, players_defend, handler

# --- Animate ---
# One frame per tracking moment, 60 ms apart; blit=True redraws only the artists
# that update() returns. Keep the `anim` reference alive: per matplotlib's
# FuncAnimation docs, an animation that is garbage-collected stops running.
anim = FuncAnimation(fig, update, frames=len(moments), interval=60, blit=True)
plt.show()

0021500333
6
Mahinmi STEAL (1 STL)
Whiteside Lost Ball Turnover (P1.T1)


# Find the turnover

In [66]:
# Locate the turnover moment within this event:
#   1. wait until the primary event player has the ball (ball within GAIN_RADIUS),
#   2. then until the ball is more than LOSS_RADIUS away from them (possession lost),
#   3. then stop at the first moment the ball speed drops below CONTROL_SPEED
#      (someone has controlled the loose ball).
# `moment` is left pointing at that moment and is plotted in the next cell.
GAIN_RADIUS = 2      # same units as the x/y court coordinates
LOSS_RADIUS = 5
CONTROL_SPEED = 3    # same units as ball_coordinates["speed"]

# Loop-invariant: the player and team the event is attributed to
# (the next cell also uses these names).
event_team = event["primary_info"]["team_id"]
event_player = event["primary_info"]["player_id"]

handler_has_ball = False
lost_possession = False
handler_x = handler_y = None

for moment in moments:
    # The primary player's position in *this* moment; None if they are not tracked.
    handler_pos = None
    for player_coord in moment['player_coordinates']:
        if player_coord['playerid'] == event_player:
            handler_pos = (player_coord["x"], player_coord["y"])
    if handler_pos is not None:
        handler_x, handler_y = handler_pos

    ball_x = moment['ball_coordinates']['x']
    ball_y = moment['ball_coordinates']['y']
    ball_speed = moment["ball_coordinates"]["speed"]

    if not lost_possession:
        # Possession can only be judged when the player is tracked in this moment;
        # never compare against an undefined or stale position.
        if handler_pos is None:
            continue
        ball_dist = np.hypot(ball_x - handler_pos[0], ball_y - handler_pos[1])
        if not handler_has_ball:
            if ball_dist < GAIN_RADIUS:
                print("gained possesion")
                handler_has_ball = True
        elif ball_dist > LOSS_RADIUS:
            lost_possession = True
            print("lost_possession")
    elif ball_speed < CONTROL_SPEED:
        print("ball controlled")
        break
else:
    print("No controlled-ball moment found; `moment` is the event's last moment.")

gained possesion
lost_possession
ball controlled


# Turnover Frame

In [67]:
# Snapshot of the court at the detected turnover moment.
# Colours: purple = primary event player, green = their teammates, red = opponents,
# black = ball.
fig, ax = plt.subplots(figsize=(20, 20))
draw_full_court(ax=ax)

for player_coord in moment['player_coordinates']:
    if player_coord['playerid'] == event_player:
        handler_x, handler_y = player_coord["x"], player_coord["y"]
        c = 'purple'
    else:
        c = "green" if player_coord["teamid"] == event_team else "r"
    ax.plot(player_coord['x'], player_coord['y'], 'o', markersize=10, color=c)

ball_x, ball_y = moment['ball_coordinates']['x'], moment['ball_coordinates']['y']
ax.plot(ball_x, ball_y, 'o', markersize=10, color='k')
plt.show()

In [68]:
# Rebuild the candidate events for the whole train split. The earlier `filtered_events`
# may already be partly consumed if filter_candidate_events returns a generator
# (TODO confirm), so it is recomputed rather than reused.
filtered_events = filter_candidate_events(nba_dataset["train"]) ### This cell takes about 1 min 30 sec to run
turnovers_and_shots = pd.DataFrame(filtered_events)

In [69]:
# Flatten the nested event_info dicts into columns and attach the primary player's
# team as `event_team`. For turnovers this is the team charged with the turnover.
events = pd.json_normalize(turnovers_and_shots["event_info"]).assign(
    event_team=turnovers_and_shots["primary_info"].apply(
        lambda prim_info: prim_info["team_id"]
    )
)

In [70]:
SHOT_WINDOW_SECONDS = 10  # game-clock seconds after the turnover that count

# .copy(): boolean indexing gives a frame pandas may treat as a view of `events`;
# copying makes the column assignment below unambiguous (no SettingWithCopy).
game_turnovers = events[events["event_type"] == "turnover"].copy()
made_shots = events[events["event_type"] == "made shot"]


def conceded_shot_after(turnover):
    """Return True if the opposing team made a shot within SHOT_WINDOW_SECONDS.

    game_clock counts down within a quarter, so "after the turnover" means a smaller
    clock value in the same game and quarter. Made shots by the team that committed
    the turnover (its `event_team`) are not counted, because the target is points
    *conceded* off the turnover.
    """
    clock = turnover["game_clock"]
    in_window = made_shots[
        (made_shots["game_id"] == turnover["game_id"])
        & (made_shots["quarter"] == turnover["quarter"])
        & (made_shots["game_clock"] < clock)
        & (made_shots["game_clock"] >= clock - SHOT_WINDOW_SECONDS)
        & (made_shots["event_team"] != turnover["event_team"])
    ]
    return len(in_window) >= 1


# Build the whole column at once: bool dtype, and it works for an empty frame too.
game_turnovers["made_shot_after"] = [
    conceded_shot_after(turnover) for _, turnover in game_turnovers.iterrows()
]
game_turnovers.head()

Unnamed: 0,id,type,possession_team_id,desc_home,desc_away,direction,quarter,game_clock,shot_clock,event_type,game_id,event_moment.quarter,event_moment.game_clock,event_moment.shot_clock,event_moment.ball_coordinates.x,event_moment.ball_coordinates.y,event_moment.ball_coordinates.z,event_moment.ball_coordinates.speed,event_moment.ball_coordinates.dir_x,event_moment.ball_coordinates.dir_y,event_moment.player_coordinates,event_team,made_shot_after
2,6,5,1610612748.0,Mahinmi STEAL (1 STL),Whiteside Lost Ball Turnover (P1.T1),right,1,651.2,2.79,turnover,21500333,1.0,651.2,2.79,37.959,88.421,3.952,2.727,0.226,0.974,"[{'teamid': 1610612754, 'playerid': 201588, 'x...",1610612748.0,False
4,15,5,1610612748.0,Mahinmi STEAL (2 STL),Dragic Bad Pass Turnover (P1.T2),right,1,607.99,7.29,turnover,21500333,1.0,607.99,7.29,14.62,92.354,1.156,2.002,-0.776,0.631,"[{'teamid': 1610612754, 'playerid': 201588, 'x...",1610612748.0,False
8,35,5,1610612754.0,G. Hill Bad Pass Turnover (P1.T2),Bosh STEAL (1 STL),left,1,477.92,19.82,turnover,21500333,1.0,477.92,19.82,25.086,5.736,10.18,0.255,0.773,-0.634,"[{'teamid': 1610612754, 'playerid': 201588, 'x...",1610612754.0,False
15,64,5,1610612748.0,Ellis STEAL (1 STL),Winslow Bad Pass Turnover (P1.T4),right,1,278.72,23.59,turnover,21500333,1.0,278.72,23.59,24.621,67.785,6.527,2.981,-0.597,-0.802,"[{'teamid': 1610612754, 'playerid': 201588, 'x...",1610612748.0,False
16,76,5,1610612748.0,Hill STEAL (1 STL),Wade Bad Pass Turnover (P2.T5),right,1,218.53,19.23,turnover,21500333,1.0,218.53,19.23,2.153,56.939,4.523,6.109,0.125,0.992,"[{'teamid': 1610612754, 'playerid': 201155, 'x...",1610612748.0,False


# Now we'd do this for every game we have data for, but I have already collected the results in `turnovers.csv`

In [72]:
import pandas as pd
# Pre-collected turnovers for every game, presumably with the same columns as
# game_turnovers (including made_shot_after); TODO confirm. After the CSV round-trip
# the event_moment.player_coordinates column is a string, which
# build_graph_from_turnover parses back into player dicts.
turnovers = pd.read_csv("turnovers.csv")

# Now we will build the features representing the turnover

We represent the situation immediately after the turnover as a graph. That graph is then the input to a Graph Neural Network, which outputs the probability that the turnover leads to a made shot.

In [73]:
import torch
from torch_geometric.data import Data
import ast  # for safely parsing stringified lists/dicts
import re

# numpy scalar reprs such as "np.float64(1.2345)" or "np.int64(7)" that appear in the
# stringified player coordinates read back from CSV
_NUMPY_SCALAR_RE = re.compile(r"np\.(?:float|int)\d*\(([^)]+)\)")


def _parse_player_coordinates(raw):
    """Return the list of player dicts, parsing their string repr if necessary."""
    if not isinstance(raw, str):
        return raw
    # Strip numpy scalar wrappers ("np.float64(1.2345)" -> "1.2345"), because
    # ast.literal_eval only accepts literals, never calls. Values such as a bare
    # `nan` are still not literals and raise ValueError, so callers can skip the row.
    return ast.literal_eval(_NUMPY_SCALAR_RE.sub(r"\1", raw))


def build_graph_from_turnover(row):
    """
    Convert a single turnover event into a PyTorch Geometric ``Data`` object.

    Each player and the ball are nodes:
    - players on the same team are fully connected (both directions);
    - every player is connected to the ball (both directions).

    Parameters
    ----------
    row : pandas Series or dict-like
        Flattened turnover columns: ``event_team`` and the ``event_moment.*``
        ball/player coordinates. ``event_moment.player_coordinates`` may be a list
        of player dicts or its string repr (as read back from CSV).

    Returns
    -------
    Data
        ``x`` : float tensor, one row per player followed by one row for the ball,
            columns ``[x, y, speed, dir_x, dir_y, is_ball, is_team_A]``. Team A is
            ``event_team``, the team responsible for the turnover.
        ``edge_index`` : long tensor ``[2, num_edges]``.
        ``edge_attr`` : float tensor ``[num_edges, 1]`` holding the (x, y)
            Euclidean distance between the two endpoints.

    Raises
    ------
    ValueError
        If the players do not belong to exactly two teams, if ``event_team`` is not
        one of them, or if the coordinate string cannot be parsed.
    """
    event_team = int(row["event_team"])  # the team responsible for the turnover
    players = _parse_player_coordinates(row["event_moment.player_coordinates"])

    team_ids = list(set(p["teamid"] for p in players))
    if len(team_ids) != 2:
        raise ValueError(f"Expected 2 teams, got {team_ids}")
    # Without this check, a mismatched team id would silently mark every player as
    # is_team_A = 0.
    if event_team not in team_ids:
        raise ValueError(f"event_team {event_team} not among player teams {team_ids}")

    # --- Node features: players first, ball last ---
    node_features = [
        [
            p["x"], p["y"],
            float(p["speed"]), float(p["dir_x"]), float(p["dir_y"]),
            0,                                 # is_ball
            int(p["teamid"] == event_team),    # is_team_A
        ]
        for p in players
    ]
    node_features.append([
        row["event_moment.ball_coordinates.x"],
        row["event_moment.ball_coordinates.y"],
        row["event_moment.ball_coordinates.speed"],
        row["event_moment.ball_coordinates.dir_x"],
        row["event_moment.ball_coordinates.dir_y"],
        1,  # is_ball
        0,  # is_team_A
    ])
    x = torch.tensor(node_features, dtype=torch.float)

    # --- Edges ---
    ball_idx = len(players)
    edges_src, edges_dst = [], []
    # (1) all ordered pairs of distinct teammates
    for i, pi in enumerate(players):
        for j, pj in enumerate(players):
            if i != j and pi["teamid"] == pj["teamid"]:
                edges_src.append(i)
                edges_dst.append(j)
    # (2) every player <-> ball
    for i in range(len(players)):
        edges_src.extend([i, ball_idx])
        edges_dst.extend([ball_idx, i])
    edge_index = torch.tensor([edges_src, edges_dst], dtype=torch.long)

    # --- Edge feature: planar distance between endpoints (feature columns 0-1) ---
    edge_attr = (x[edge_index[0], :2] - x[edge_index[1], :2]).norm(dim=1, keepdim=True)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

# Let's visualize our features now

In [74]:
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx

def visualize_graph(data):
    """Draw a turnover graph at the players' and ball's court positions.

    Node positions come from feature columns 0-1 (x, y) of ``data.x``. Colours:
    red for the ball (is_ball, column 5), blue for team A (is_team_A, column 6)
    and green for the other team. Edges are drawn undirected, with node indices
    as labels.
    """
    def node_colour(idx):
        if data.x[idx, 5] == 1:
            return 'red'
        return 'blue' if data.x[idx, 6] == 1 else 'green'

    node_ids = range(data.num_nodes)
    layout = {idx: (data.x[idx, 0].item(), data.x[idx, 1].item()) for idx in node_ids}
    palette = [node_colour(idx) for idx in node_ids]
    nx_graph = to_networkx(data, to_undirected=True)

    plt.figure(figsize=(6, 6))
    nx.draw(nx_graph, layout, node_color=palette, with_labels=True, node_size=400, edge_color='gray')
    plt.title("Graph Connectivity (Players + Ball)")
    plt.show()

In [75]:
# Build and draw the graph for the first pre-collected turnover; `graph` is reused
# later for a single-event prediction.
graph = build_graph_from_turnover(turnovers.iloc[0])
visualize_graph(graph)

# Now we want to build our turnover model. We can use a Graph Neural Network to predict the probability that our turnover situation (a graph) will lead to a made shot

🏀 Turnover → Shot Made Prediction Pipeline
Overview

This pipeline builds a Graph Neural Network (GNN) to estimate the probability that a turnover event in a basketball game will eventually lead to a made shot. Each turnover is represented as a graph snapshot of player and ball positions, velocities, and directions at the moment of the turnover.

🔧 1. Data Preparation

Each row in the turnover dataset represents a unique game event. For every event:

The player coordinates and ball coordinates are extracted and parsed.

Speeds and directional unit vectors are precomputed from the tracking data.

Each player and the ball are represented as nodes in a graph.

We then connect nodes to form edges:

Players are connected to their teammates (intra-team edges).

Every player is connected to the ball (player–ball edges).

Optionally (not implemented in this notebook), edges could also be created between players within a fixed spatial radius of one another.

🧩 2. Graph Construction

For each turnover, a PyTorch Geometric Data object is created containing:

Node features:

$$[x,\ y,\ \text{speed},\ \text{dir}_x,\ \text{dir}_y,\ \text{is\_ball},\ \text{is\_team}_A]$$

There is one row per player plus one row for the ball. $\text{is\_team}_A$ marks players on the team that committed the turnover.

Edge index:
Defines which nodes are connected.

Edge attributes:
The Euclidean (x, y) distance between the two connected nodes, stored as `edge_attr`. Note that the GCN model below does not use edge attributes.

These graphs are then batched into a dataset ready for GNN training.

🧠 3. Model Training

We define a Graph Neural Network (GNN) model, such as a GCNConv or GraphSAGE network, to learn from the turnover graphs.

Each graph is passed through the GNN and mean-pooled into a graph-level embedding. A small two-layer classifier head maps this embedding to two logits (no made shot / made shot), and their softmax gives:

$$P(\text{made shot after turnover}) \in [0, 1]$$

Training optimizes this probability with a two-class cross-entropy loss, which is equivalent to binary cross-entropy on the made-shot probability.

📈 4. Evaluation & Inference

During evaluation, each turnover graph is fed into the trained model to obtain:

$$P(\text{shot made}) = \operatorname{softmax}\big(f_\theta(G)\big)_1$$

Here $f_\theta(G)$ denotes the two logits the model produces for graph $G$, and index 1 is the made-shot class.

The model outputs the likelihood that the turnover will result in a made shot shortly afterwards. The labeling cell in this notebook uses a 10-second game-clock window.

🚀 Summary

Run Pipeline Steps

Load and preprocess turnover data.

Build a graph for each event (players + ball).

Batch graphs into a PyTorch Geometric dataset.

Train a GNN to predict the post-turnover made-shot probability.

Evaluate model accuracy and interpret learned spatial–temporal patterns.

In [76]:
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, Dataset, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, GATConv, GraphConv
import numpy as np
from sklearn.model_selection import train_test_split
import tqdm


# Use the GPU when available; models and batches below are moved to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Wrapper: turn a dataframe row into a PyG Data object carrying its label
def row_to_data(row):
    """
    Build the turnover graph for `row` and attach its graph-level label.

    The label is ``made_shot_after`` converted to 0/1; a missing column counts as 0.
    It is stored as ``graph.y`` (long tensor ``[label]``). The row's ``id`` (or None)
    is stored as ``graph.event_id``.

    Raises
    ------
    ValueError
        If the label value is NaN/None or an unrecognised string. Such rows are
        rejected rather than counted as positives: ``bool(float('nan'))`` and
        ``bool("False")`` are both True. Errors from build_graph_from_turnover
        propagate unchanged.
    """
    graph = build_graph_from_turnover(row)  # returns torch_geometric.data.Data
    lbl = _made_shot_label(row.get("made_shot_after", False))
    graph.y = torch.tensor([lbl], dtype=torch.long)  # graph-level label
    graph.event_id = row.get("id", None)
    return graph


def _made_shot_label(value):
    """Convert a made_shot_after value (bool, 0/1, or "True"/"False" text) to 0 or 1."""
    if isinstance(value, str):
        text = value.strip().lower()
        if text in ("true", "1"):
            return 1
        if text in ("false", "0"):
            return 0
        raise ValueError(f"Unrecognised made_shot_after label: {value!r}")
    if pd.isna(value):
        raise ValueError("Missing made_shot_after label")
    return int(bool(value))

# --- Create dataset (list of Data objects) ---
def build_dataset_from_df(df, row_to_data_fn, max_items=None):
    """Apply `row_to_data_fn` to each row of `df`, collecting the results in order.

    Rows whose conversion raises are reported ("Skipping row <index> due to <error>")
    and left out. When `max_items` is truthy, iteration stops as soon as that many
    results have been collected.
    """
    built = []
    for row_idx, record in df.iterrows():
        try:
            result = row_to_data_fn(record)
        except Exception as err:
            # best-effort: keep going, but surface the failure for debugging
            print(f"Skipping row {row_idx} due to {err}")
        else:
            built.append(result)
        if max_items and len(built) >= max_items:
            break
    return built

# --- Simple GNN model for graph classification ---
class GraphClassifier(torch.nn.Module):
    """Graph-level classifier: stacked GCNConv layers -> mean pooling -> 2-layer MLP.

    Parameters
    ----------
    in_channels : int
        Number of node features (columns of ``data.x``).
    hidden : int
        Width of every GCN layer; the MLP's hidden layer has ``hidden // 2`` units.
    num_layers : int
        Number of GCNConv layers (at least one is always built).
    dropout : float
        Dropout probability after every GCN layer and after the first MLP layer.

    ``forward(data)`` returns unnormalised logits of shape ``[num_graphs, 2]``.
    Column 1 is the positive class.
    """

    def __init__(self, in_channels, hidden=64, num_layers=3, dropout=0.2):
        super().__init__()
        # GCNConv layers; they can be swapped for GATConv or GraphConv.
        self.convs = torch.nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden))
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(hidden, hidden))
        self.lin1 = torch.nn.Linear(hidden, hidden // 2)
        self.lin2 = torch.nn.Linear(hidden // 2, 2)  # binary classification (2 classes)
        self.dropout = dropout

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        # Message passing. Only edge_index is passed to the convolutions, so any
        # `edge_attr` on the graph (e.g. distances) is not used by this model.
        for conv in self.convs:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        # A graph passed directly (not through a DataLoader) has no batch vector;
        # assign all its nodes to graph 0 so it pools to a single [1, hidden] row.
        batch = getattr(data, "batch", None)
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        x = global_mean_pool(x, batch)  # (num_graphs, hidden)

        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin2(x)  # logits

# --- Training / evaluation helpers ---
def train_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over `loader`.

    Each batch is moved to `device`, scored, and used for one optimizer step on the
    mean cross-entropy. Returns ``(mean per-sample loss, accuracy)`` over every
    sample seen in the epoch.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        logits = model(batch)  # [batch_size, num_classes]
        targets = batch.y.view(-1).to(device)
        batch_loss = F.cross_entropy(logits, targets)
        batch_loss.backward()
        optimizer.step()

        batch_size = targets.size(0)
        # cross_entropy averages within the batch; re-weight so the epoch figure is
        # a per-sample mean even when the last batch is smaller.
        loss_sum += batch_loss.item() * batch_size
        n_correct += (logits.argmax(dim=1) == targets).sum().item()
        n_seen += batch_size
    return loss_sum / n_seen, n_correct / n_seen

@torch.no_grad()
def evaluate(model, loader, device):
    """Score `model` on every batch of `loader` without updating it.

    Returns ``(mean per-sample cross-entropy, accuracy, predictions, labels)``.
    ``predictions`` and ``labels`` are flat Python lists in loader order.
    """
    model.eval()
    summed_loss = 0.0
    predicted, actual = [], []
    for batch in loader:
        batch = batch.to(device)
        logits = model(batch)
        targets = batch.y.view(-1).to(device)
        summed_loss += F.cross_entropy(logits, targets, reduction='sum').item()
        predicted += logits.argmax(dim=1).cpu().tolist()
        actual += targets.cpu().tolist()
    n_seen = len(actual)
    n_correct = sum(p == t for p, t in zip(predicted, actual))
    return summed_loss / n_seen, n_correct / n_seen, predicted, actual

# --- Putting it together ---
def run_pipeline(turnovers_df, batch_size=16, max_items=None, epochs=20, lr=1e-3,
                 hidden=64, num_layers=3, test_size=0.2, stratify=False, seed=None):
    """
    Build turnover graphs, split them, and train a GraphClassifier on `device`.

    Parameters
    ----------
    turnovers_df : DataFrame
        Turnover rows accepted by row_to_data. Rows that fail to convert are
        reported and skipped.
    batch_size, epochs, lr :
        DataLoader batch size, number of training epochs, and Adam learning rate.
    max_items : int or None
        Optional cap on the number of graphs built (falsy means no cap).
    hidden, num_layers :
        GraphClassifier width and number of GCN layers.
    test_size : float
        Fraction of graphs held out. The split always uses random_state=42. The
        held-out set doubles as the per-epoch "Val" report.
    stratify : bool
        If True, keep the label ratio equal across the train/held-out split.
    seed : int or None
        If given, seeds torch's RNG before the loaders and model are created, so
        weight initialisation, dropout and shuffling are reproducible.

    Returns
    -------
    (model, train_loader, test_loader)

    Raises
    ------
    RuntimeError
        If no graphs could be built.
    """
    # 1) build graphs
    graphs = build_dataset_from_df(turnovers_df, row_to_data, max_items=max_items)
    if len(graphs) == 0:
        raise RuntimeError("No graphs were created. Check build_graph_from_turnover and rows.")
    print(f"Built {len(graphs)} graphs")

    # 2) split
    # NOTE(review): in the recorded run the validation accuracy stayed at a constant
    # 0.791 across epochs, which is consistent with the model predicting only the
    # majority class. Consider stratify=True and/or class weighting.
    stratify_labels = [int(g.y.view(-1)[0]) for g in graphs] if stratify else None
    train_idx, test_idx = train_test_split(
        np.arange(len(graphs)),
        test_size=test_size,
        random_state=42,
        stratify=stratify_labels,
    )
    train_graphs = [graphs[i] for i in train_idx]
    test_graphs = [graphs[i] for i in test_idx]

    if seed is not None:
        torch.manual_seed(seed)

    # 3) DataLoaders
    train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_graphs, batch_size=batch_size, shuffle=False)

    # 4) model
    in_ch = graphs[0].x.shape[1]
    model = GraphClassifier(in_ch, hidden=hidden, num_layers=num_layers).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

    # 5) train loop
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, device)
        val_loss, val_acc, _, _ = evaluate(model, test_loader, device)
        print(f"Epoch {epoch:02d} | Train loss {train_loss:.4f}, acc {train_acc:.3f} | Val loss {val_loss:.4f}, acc {val_acc:.3f}")

    return model, train_loader, test_loader

In [77]:
# Train with the default settings on all pre-collected turnovers. Rows whose
# coordinates cannot be parsed are reported ("Skipping row ...") and left out.
model,train_loader,test_loader = run_pipeline(turnovers)

Skipping row 851 due to malformed node or string on line 1: <ast.Name object at 0x0000017568BEBF50>
Skipping row 1157 due to malformed node or string on line 1: <ast.Name object at 0x0000017568614E10>
Skipping row 1291 due to malformed node or string on line 1: <ast.Name object at 0x00000175685BBE50>
Built 1409 graphs
Epoch 01 | Train loss 0.5675, acc 0.769 | Val loss 0.5180, acc 0.791
Epoch 02 | Train loss 0.5327, acc 0.789 | Val loss 0.5194, acc 0.791
Epoch 03 | Train loss 0.5391, acc 0.789 | Val loss 0.5379, acc 0.791
Epoch 04 | Train loss 0.5356, acc 0.789 | Val loss 0.5165, acc 0.791
Epoch 05 | Train loss 0.5250, acc 0.789 | Val loss 0.5235, acc 0.791
Epoch 06 | Train loss 0.5297, acc 0.789 | Val loss 0.5220, acc 0.791
Epoch 07 | Train loss 0.5201, acc 0.789 | Val loss 0.5289, acc 0.791
Epoch 08 | Train loss 0.5341, acc 0.789 | Val loss 0.5154, acc 0.791
Epoch 09 | Train loss 0.5216, acc 0.789 | Val loss 0.5168, acc 0.791
Epoch 10 | Train loss 0.5210, acc 0.789 | Val loss 0.5169, 

# Now let's predict the probability that the turnover graph we visualized above (the first row of `turnovers`) leads to a made shot

In [78]:
# Score the single graph built from turnovers.iloc[0] earlier.
model.eval()  # disable dropout for inference
with torch.no_grad():
    data = graph.to(device)
    # `graph` is not batched; this relies on the model pooling it as one graph.
    logits = model(data)
    # [0] = the only graph, [1] = class 1, i.e. made_shot_after (see row_to_data)
    prob = torch.nn.functional.softmax(logits, dim=1)[0][1]
    print("Probability of leading to made shot:", prob.cpu().numpy())

Probability of leading to made shot: 0.23363149


# Let's play around with different turnovers

In [79]:
TURNOVER_NUM = 1  # 0-based position within game_turnovers
if not 0 <= TURNOVER_NUM < len(game_turnovers):
    raise IndexError(f"TURNOVER_NUM must be between 0 and {len(game_turnovers) - 1}")

# Take the flattened row and the raw event (with its tracking moments) by the same
# index label, so the graph, the probability and the animation all describe one
# event. game_turnovers keeps the row labels of `events`, which is built row-for-row
# from turnovers_and_shots.
test_turnover_df = game_turnovers.iloc[TURNOVER_NUM]
test_turnover_event = turnovers_and_shots.loc[test_turnover_df.name]
test_turnover_moments = test_turnover_event["moments"]

test_graph = build_graph_from_turnover(test_turnover_df)
model.eval()
with torch.no_grad():
    test_logits = model(test_graph.to(device))
    # [0] = the only graph, [1] = the made-shot class
    test_prob = torch.nn.functional.softmax(test_logits, dim=1)[0, 1].item()

# --- Extract data ---
# Separate names so the globals used by the earlier animation cell are not overwritten.
test_event_team = test_turnover_event["primary_info"]["team_id"]
test_event_player_id = test_turnover_event["primary_info"]["player_id"]

fig, ax = plt.subplots(figsize=(10, 10))
draw_full_court(ax)

# --- Create plot handles ---
test_ball, = ax.plot([], [], 'o', color='k', markersize=8)
test_players_attack, = ax.plot([], [], 'o', color='green', markersize=10)
test_players_defend, = ax.plot([], [], 'o', color='red', markersize=10)
test_handler, = ax.plot([], [], 'o', color='purple', markersize=12)

ax.set_title(f"Turnover Num {TURNOVER_NUM} | Prob(Shot): {test_prob:.3f}")


def test_update(frame):
    """FuncAnimation callback: redraw the ball and players for moment `frame`.

    The primary event player is drawn with `test_handler`, their teammates with
    `test_players_attack`, and the opponents with `test_players_defend`.
    Returns the updated artists for blitting.
    """
    moment = test_turnover_moments[frame]
    attack_x, attack_y, defend_x, defend_y = [], [], [], []
    hx, hy = None, None

    for p in moment['player_coordinates']:
        if p['playerid'] == test_event_player_id:
            hx, hy = p['x'], p['y']
        elif p['teamid'] == test_event_team:
            attack_x.append(p['x'])
            attack_y.append(p['y'])
        else:
            defend_x.append(p['x'])
            defend_y.append(p['y'])

    test_ball.set_data(
        [moment['ball_coordinates']['x']],
        [moment['ball_coordinates']['y']]
    )
    test_players_attack.set_data(attack_x, attack_y)
    test_players_defend.set_data(defend_x, defend_y)

    if hx is not None:
        test_handler.set_data([hx], [hy])
    else:
        test_handler.set_data([], [])

    return test_ball, test_players_attack, test_players_defend, test_handler


# --- Animate --- (keep `anim` referenced so the animation keeps running)
anim = FuncAnimation(fig, test_update, frames=len(test_turnover_moments), interval=60, blit=True)
plt.show()