In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Optional listing of the attached input files. Disabled by default so the
# notebook does not walk the whole input tree on every run.
LIST_INPUT_FILES = False

if LIST_INPUT_FILES:
    for dirname, _, filenames in os.walk('/kaggle/input'):
        for filename in filenames:
            print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# functions

In [None]:
import os

# Only the essential PyG wheels (installed offline; no internet required)
# Directory of the attached utility dataset that ships the wheels.
PYG_WHEEL_DIR = "/kaggle/usr/lib/torch_geometric/pyg_wheels"  # adjust folder name if needed

core_wheels = [
    "pyg_lib-0.4.0+pt26cu124-cp311-cp311-linux_x86_64.whl",
    "torch_scatter-2.1.2+pt26cu124-cp311-cp311-linux_x86_64.whl",
    "torch_sparse-0.6.18+pt26cu124-cp311-cp311-linux_x86_64.whl",
    "torch_cluster-1.6.3+pt26cu124-cp311-cp311-linux_x86_64.whl",
    "torch_spline_conv-1.2.2+pt26cu124-cp311-cp311-linux_x86_64.whl",
    "torch_geometric-2.6.1-py3-none-any.whl"
]

import subprocess
import sys

for whl in core_wheels:
    path = os.path.join(PYG_WHEEL_DIR, whl)
    print("Installing:", os.path.basename(path))
    if not os.path.exists(path):
        print(f"  WARNING: wheel not found, skipped: {path}")
        continue
    # Run pip through the kernel's own interpreter (no shell involved) and
    # report a non-zero exit status instead of silently ignoring it.
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "--no-deps", "-q", path],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        print(f"  WARNING: pip exited with {result.returncode} for {whl}:\n{result.stderr.strip()}")

# Drop the dataset directory from sys.path.
# NOTE(review): presumably this makes `import torch_geometric` resolve to the
# installed wheel rather than the raw dataset folder. Confirm.
sys.path = [p for p in sys.path if "usr/lib/torch_geometric" not in p]

import torch, torch_geometric
from torch_geometric.nn import GATv2Conv  # NOTE(review): not referenced by the model code in this notebook

In [None]:
# ============================================================
# FULL PIPELINE: From raw tracking → temporal embeddings
# ============================================================
import pandas as pd
import numpy as np
import glob
from tqdm import tqdm
import torch
import torch.nn as nn

# ============================================================
# 1. Normalization
# ============================================================
def normalize_field_direction(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with plays oriented in a common direction.

    For rows whose ``play_direction`` is "left" (case-insensitive):
      * ``x`` and ``ball_land_x`` are mirrored as ``120 - value``;
      * ``o`` and ``dir`` are shifted by 180 degrees (mod 360).
    For every row, ``y`` and ``ball_land_y`` are shifted by -26.65 so the
    midline sits at 0. Columns that are absent are skipped. The input frame
    is not modified.

    NOTE(review): x is mirrored but y is not, while angles are rotated by
    180 degrees. These transforms are not the same geometric operation, so
    angles and positions may disagree after normalization. Confirm this
    matches the training pipeline before changing anything; the saved model
    was trained on this exact transform.
    """
    out = df.copy()
    is_left = out["play_direction"].str.lower().eq("left")
    present = set(out.columns)

    for col in (c for c in ("x", "ball_land_x") if c in present):
        out.loc[is_left, col] = 120 - out.loc[is_left, col]

    for col in (c for c in ("o", "dir") if c in present):
        out.loc[is_left, col] = (out.loc[is_left, col] + 180) % 360

    for col in (c for c in ("y", "ball_land_y") if c in present):
        out[col] = out[col] - 26.65

    return out

# ============================================================
# 4. Temporal Encoder (Transformer)
# ============================================================
class TemporalTransformer(nn.Module):
    """Encode each player's recent motion history into a single embedding.

    Args:
        in_dim:   number of per-frame input features F.
        d_model:  embedding width D.
        n_heads, n_layers, dropout: settings of the pre-norm TransformerEncoder.
        max_len:  longest supported history length K, i.e. the size of the
                  learned positional-embedding table. The default of 60 keeps
                  the parameter shape of existing checkpoints.
    """
    def __init__(self, in_dim=8, d_model=128, n_heads=4, n_layers=2, dropout=0.1, max_len=60):
        super().__init__()
        self.max_len = max_len

        self.input_proj = nn.Linear(in_dim, d_model)

        # Learned absolute positional embedding, one row per history slot.
        self.pos_emb = nn.Parameter(torch.randn(1, max_len, d_model))

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            batch_first=True,
            norm_first=True
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)

    def forward(self, x):
        """
        x : (P, K, F) per-player history.
        Returns (P, D): the encoder output at the last time step.

        Raises ValueError if x is not 3-D or K exceeds ``max_len``. The
        positional table can only be truncated, not extended.
        """
        if x.dim() != 3:
            raise ValueError(f"expected x of shape (P, K, F), got {tuple(x.shape)}")
        seq_len = x.size(1)
        if seq_len > self.max_len:
            raise ValueError(
                f"history length {seq_len} exceeds positional table size max_len={self.max_len}"
            )

        h = self.input_proj(x) + self.pos_emb[:, :seq_len, :]
        out = self.encoder(h)          # (P, K, D)
        return out[:, -1, :]           # (P, D)

In [None]:
import pandas as pd
import numpy as np
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm

def build_interaction_graphs(df_input: pd.DataFrame, K: int = 6):
    """
    Build K-nearest-neighbour interaction graphs per play at the throw frame.

    For each play, every player's last input frame (by ``frame_id``) becomes
    a node. Each node is connected to its K nearest other players by
    Euclidean distance on (x, y); a node is never its own neighbour, even
    when two players share identical coordinates. Plays with fewer than two
    players are skipped. If ``vx``/``vy`` are missing they are derived as
    per-player frame-to-frame differences of x/y (0 on the first frame).

    Parameters
    ----------
    df_input : tracking rows with game_id, play_id, nfl_id, frame_id, x, y,
        player_side, player_role (and optionally vx, vy). Not modified.
    K : neighbours per node; capped at ``n_players - 1`` for small plays.

    Returns
    -------
    dict[(game_id, play_id)] = {
        "nodes": DataFrame[nfl_id, x, y, vx, vy, player_side, player_role],
        "edges": DataFrame[src_id, dst_id, dx, dy, dvx, dvy, dist, dv,
                           cos_bear, sin_bear, ally_flag]
    }
    Relative features are destination minus source. ``dist``/``dv`` are the
    magnitudes (with 1e-6 under the root), ``cos_bear``/``sin_bear`` encode
    the bearing of (dx, dy), and ``ally_flag`` is 1 when both players share
    ``player_side``.
    """
    edge_columns = ["src_id", "dst_id", "dx", "dy", "dvx", "dvy",
                    "dist", "dv", "cos_bear", "sin_bear", "ally_flag"]
    player_keys = ["game_id", "play_id", "nfl_id"]

    # Sort once so "last row per player" really is the last frame.
    df = df_input.sort_values(player_keys + ["frame_id"])

    # ------------------------------------------------------------------
    # 1. Ensure velocity columns exist
    # ------------------------------------------------------------------
    if "vx" not in df.columns or "vy" not in df.columns:
        by_player = df.groupby(player_keys)
        vx = by_player["x"].diff().fillna(0)
        vy = by_player["y"].diff().fillna(0)
        df = df.assign(vx=vx, vy=vy)

    # ------------------------------------------------------------------
    # 2. Throw frame = last input frame of each player
    # ------------------------------------------------------------------
    throw_frame = df.groupby(player_keys).tail(1).reset_index(drop=True)

    # ------------------------------------------------------------------
    # 3. Build graph per play
    # ------------------------------------------------------------------
    graphs = {}
    plays = throw_frame.groupby(["game_id", "play_id"])

    for (gid, pid), play_df in tqdm(plays, desc="Building KNN graphs per play"):
        nodes = play_df[
            ["nfl_id", "x", "y", "vx", "vy", "player_side", "player_role"]
        ].reset_index(drop=True)

        n_players = len(nodes)
        if n_players < 2:
            continue  # skip incomplete plays

        n_nbrs = min(K, n_players - 1)
        if n_nbrs < 1:
            graphs[(gid, pid)] = {"nodes": nodes, "edges": pd.DataFrame(columns=edge_columns)}
            continue

        coords = nodes[["x", "y"]].to_numpy()
        knn = NearestNeighbors(n_neighbors=n_nbrs, algorithm="ball_tree").fit(coords)
        # Querying without X returns neighbours of the fitted points with
        # each point excluded from its own list, which is robust to duplicates.
        nbr_idx = knn.kneighbors(return_distance=False)      # (n_players, n_nbrs)

        src = np.repeat(np.arange(n_players), n_nbrs)
        dst = nbr_idx.reshape(-1)

        xs = nodes["x"].to_numpy(dtype=float)
        ys = nodes["y"].to_numpy(dtype=float)
        vxs = nodes["vx"].to_numpy(dtype=float)
        vys = nodes["vy"].to_numpy(dtype=float)
        ids = nodes["nfl_id"].to_numpy()
        sides = nodes["player_side"].to_numpy()

        dx = xs[dst] - xs[src]
        dy = ys[dst] - ys[src]
        dvx = vxs[dst] - vxs[src]
        dvy = vys[dst] - vys[src]
        bearing = np.arctan2(dy, dx)

        edges = pd.DataFrame({
            "src_id": ids[src],
            "dst_id": ids[dst],
            "dx": dx,
            "dy": dy,
            "dvx": dvx,
            "dvy": dvy,
            "dist": np.sqrt(dx * dx + dy * dy + 1e-6),      # distance magnitude
            "dv": np.sqrt(dvx * dvx + dvy * dvy + 1e-6),    # relative speed magnitude
            "cos_bear": np.cos(bearing),
            "sin_bear": np.sin(bearing),
            "ally_flag": (sides[src] == sides[dst]).astype(int),
        }, columns=edge_columns)

        graphs[(gid, pid)] = {"nodes": nodes, "edges": edges}

    return graphs

In [None]:
class SpatialTransformer(nn.Module):
    """Mix player embeddings across a play using a Transformer encoder.

    Edge features are projected to ``d_model``. For each source node, the
    projected vectors of its outgoing edges are summed and divided by the
    total node count P, so non-edges count as zeros in the average. The
    result is added to that node's embedding, and all P nodes then pass
    through a Transformer encoder with full (unmasked) attention.

    Args:
        d_model, n_heads, n_layers, dropout: encoder settings.
        edge_dim: width of ``edge_attr`` (9 for the graph features built here).
    """
    def __init__(self, d_model=128, n_heads=4, n_layers=2, dropout=0.1, edge_dim=9):
        super().__init__()

        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            batch_first=True,
            norm_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)

        # Project edge_attr -> d_model so it can be added to node embeddings.
        self.edge_proj = nn.Linear(edge_dim, d_model)

    def forward(self, h_nodes, edge_index, edge_attr):
        """
        h_nodes    : (P, D) node embeddings
        edge_index : (2, E) long, rows are (src, dst) node positions
        edge_attr  : (E, edge_dim)
        Returns (P, D).

        The per-source aggregation is a scatter-sum (index_add) divided by P.
        This gives the same result as averaging a dense (P, P, D) edge
        matrix over destinations, without allocating that matrix. Duplicate
        (src, dst) pairs are summed; the graph builder never produces them.
        """
        num_nodes = h_nodes.size(0)
        src = edge_index[0]

        e = self.edge_proj(edge_attr)                                      # (E, D)
        edge_sum = e.new_zeros((num_nodes, e.size(-1))).index_add(0, src, e)
        edge_bias = edge_sum / num_nodes                                   # (P, D)

        # TransformerEncoder accepts masks but not arbitrary additive biases,
        # so the edge information is folded into the node embeddings instead.
        h = (h_nodes + edge_bias).unsqueeze(0)                             # (1, P, D)
        h = self.encoder(h)                                                # (1, P, D)
        return h.squeeze(0)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ============================================================
# 3.3 Role-Specific Adapter Module
# ============================================================

class RoleSpecificAdapters(nn.Module):
    """
    Per-role MLP adapters applied to context embeddings.

    Each name in ``role_names`` (default: Targeted Receiver, Defensive
    Coverage, Passer, Other Route Runner) gets its own
    Linear-ReLU-Linear-LayerNorm adapter. Players whose role is not one of
    those names go through a shared ``default_adapter``.

    Input : h_context -> (N_players, embed_dim)
            role_ids  -> N_players role names, or integer indices resolved
                         through ``role_mapping``
    Output: h_adapted -> (N_players, embed_dim)
    """
    def __init__(self, embed_dim=128, hidden_dim=128, role_names=None):
        super().__init__()
        if role_names is None:
            role_names = ["Targeted Receiver", "Defensive Coverage", "Passer", "Other Route Runner"]
        self.role_names = role_names
        self.n_roles = len(role_names)

        # small MLP adapter per role
        self.adapters = nn.ModuleDict({
            name: nn.Sequential(
                nn.Linear(embed_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, embed_dim),
                nn.LayerNorm(embed_dim)
            )
            for name in role_names
        })

        # shared fallback (for unseen / undefined roles)
        self.default_adapter = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
            nn.LayerNorm(embed_dim)
        )

    def forward(self, h_context, role_ids, role_mapping):
        """
        h_context    : (N, D)
        role_ids     : sequence (list / Series) of role names or integer indices
        role_mapping : {role_name: idx}, used to turn integer ids back into names.
                       If several names share an index, the first in
                       iteration order wins.
        Returns (len(role_ids), D). Row i is adapted from ``h_context[i]``.

        Rows are grouped per adapter so each MLP runs once on a batch instead
        of once per player; results are then restored to the input order.
        """
        roles = list(role_ids)
        if not roles:
            return h_context.new_zeros((0, h_context.size(-1)))

        # Reverse map built once per call.
        idx_to_name = {}
        for name, idx in role_mapping.items():
            idx_to_name.setdefault(idx, name)

        rows_per_adapter = {}          # adapter key (None = default) -> row indices
        for i, r in enumerate(roles):
            role_name = idx_to_name.get(r) if isinstance(r, (int, np.integer)) else r
            key = role_name if role_name in self.adapters else None
            rows_per_adapter.setdefault(key, []).append(i)

        order, outputs = [], []
        for key, rows in rows_per_adapter.items():
            adapter = self.default_adapter if key is None else self.adapters[key]
            idx = torch.as_tensor(rows, dtype=torch.long, device=h_context.device)
            outputs.append(adapter(h_context.index_select(0, idx)))
            order.append(idx)

        stacked = torch.cat(outputs, dim=0)
        # argsort of the concatenated row indices is the inverse permutation.
        return stacked[torch.argsort(torch.cat(order))]

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

# ----------------------------
# Helpers: time embeddings
# ----------------------------
class TimeEmbedding(nn.Module):
    """
    Embed normalized times tau in [0, 1].

    For k = 1..n_freq the features are sin(2*pi*k*tau) followed by
    cos(2*pi*k*tau). A single Linear layer maps these 2*n_freq values to
    emb_dim.
    """
    def __init__(self, emb_dim=64, n_freq=8):
        super().__init__()
        self.n_freq = n_freq
        self.proj = nn.Linear(2 * n_freq, emb_dim)

    def forward(self, tau):
        """tau: (P, T) -> (P, T, emb_dim)."""
        n_players, n_steps = tau.shape  # enforces a 2-D input
        # Angular frequencies 2*pi*1 .. 2*pi*n_freq (the zero frequency is skipped).
        harmonics = torch.arange(1, self.n_freq + 1, device=tau.device).float()
        omega = 2.0 * np.pi * harmonics
        phase = tau.unsqueeze(-1) * omega                          # (P, T, n_freq)
        features = torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)
        return self.proj(features)


# ----------------------------
# Two-stream decoder
# ----------------------------
class TwoStreamDecoder(nn.Module):
    """
    Decode per-step displacements from role embeddings with two MLP streams.

    Stream A ("goal drift") sees the role embedding, the time embedding and
    the goal features. Stream B ("interaction correction") sees only the
    role embedding and the time embedding. Their outputs are summed.

    Inputs per play:
      h_role:    (P, D)  role-specific embeddings
      goal_feat: (P, G)  per-player goal features [dx0, dy0, dist0, ux, uy]
      tau_seq:   (P, T)  normalized time values in [0, 1]
      horizon:   (P,)    number of valid steps per player
    Outputs:
      dxy:  (P, N_max, 2) per-step (x, y) increments. End2EndModel.forward_batch
            accumulates them with cumsum.
      mask: (P, N_max)    1.0 for steps < horizon, 0.0 afterwards

    tau_seq is right-padded with 1.0 (or truncated) to exactly N_max steps,
    so outputs always have N_max steps regardless of the input T.
    """
    def __init__(self, d_model=128, time_dim=64, goal_dim=5, hidden=256, N_max=30):
        super().__init__()
        self.goal_dim = goal_dim
        self.N_max = N_max
        self.time_emb = TimeEmbedding(emb_dim=time_dim, n_freq=8)

        width_a = d_model + time_dim + goal_dim
        width_b = d_model + time_dim

        # Stream A: smooth drift toward the goal
        self.streamA = nn.Sequential(
            nn.Linear(width_a, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 2)
        )

        # Stream B: local interaction correction
        self.streamB = nn.Sequential(
            nn.Linear(width_b, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 2)
        )

    def forward(self, h_role, goal_feat, tau_seq, horizon):
        """
        h_role:    (P, D)
        goal_feat: (P, G)
        tau_seq:   (P, T), any T (padded with ones / truncated to N_max)
        horizon:   (P,)
        Returns (dxy, mask) with shapes (P, N_max, 2) and (P, N_max).
        """
        n_players, width = h_role.shape
        n_given = tau_seq.shape[1]

        # Variable horizon lengths (curriculum training) are normalized to N_max.
        if n_given < self.N_max:
            filler = torch.ones(
                (n_players, self.N_max - n_given),
                dtype=tau_seq.dtype,
                device=tau_seq.device
            )
            tau_seq = torch.cat([tau_seq, filler], dim=1)
        elif n_given > self.N_max:
            tau_seq = tau_seq[:, :self.N_max]
        n_steps = self.N_max

        t_emb = self.time_emb(tau_seq)                                        # (P, T, time_dim)

        # Broadcast the static per-player inputs along the time axis.
        h_rep = h_role.unsqueeze(1).expand(n_players, n_steps, width)         # (P, T, D)
        g_rep = goal_feat.unsqueeze(1).expand(n_players, n_steps, goal_feat.size(1))

        drift = self.streamA(torch.cat([h_rep, t_emb, g_rep], dim=-1))       # (P, T, 2)
        correction = self.streamB(torch.cat([h_rep, t_emb], dim=-1))          # (P, T, 2)
        dxy = drift + correction

        step_idx = torch.arange(n_steps, device=h_role.device).unsqueeze(0).expand(n_players, n_steps)
        mask = (step_idx < horizon.unsqueeze(1)).float()                      # (P, T)

        return dxy, mask


# ----------------------------
# Per-play tensor builder
# ----------------------------
def prepare_play_decoder_inputs(play_graph_nodes: pd.DataFrame,
                                df_in_norm: pd.DataFrame,
                                game_id: int, play_id: int,
                                N_max: int = 30,
                                device: str = "cpu"):
    """
    Build the TwoStreamDecoder inputs for one play.

    Args:
      play_graph_nodes: node table with at least nfl_id, x, y (throw-frame
                        positions); its row order defines the player order P.
      df_in_norm:       normalized input rows containing game_id, play_id,
                        nfl_id, frame_id, ball_land_x, ball_land_y,
                        num_frames_output.
      game_id, play_id: the play to extract.
      N_max:            number of decoder steps.
      device:           torch device for the returned tensors.

    Returns:
      x0y0:     (P, 2) float32 node positions
      h0_goal:  (P, 5) float32 [dx0, dy0, dist0, ux, uy] toward the landing
                spot, where dist0 includes +1e-6 and (ux, uy) = (dx0, dy0) / dist0
      tau_seq:  (P, N_max) float32, step t (1-based) / horizon, clamped to
                <= 1; the denominator is clamped to >= 1
      horizon:  (P,) long, last num_frames_output per player (0 if missing),
                capped at N_max
      meta:     {"ball_land": (bx, by)} taken from the last row of the play
    """
    nodes = play_graph_nodes.reset_index(drop=True).copy()
    n_players = len(nodes)

    def as_f32(values):
        return torch.tensor(values, dtype=torch.float32, device=device)

    x0y0 = as_f32(nodes[["x", "y"]].to_numpy())

    in_play = (df_in_norm.game_id == game_id) & (df_in_norm.play_id == play_id)
    play_rows = df_in_norm[in_play]
    land_x = float(play_rows["ball_land_x"].iloc[-1])
    land_y = float(play_rows["ball_land_y"].iloc[-1])

    # Goal vector from each player's throw-frame position to the landing spot.
    dx0 = as_f32((land_x - nodes["x"]).to_numpy())
    dy0 = as_f32((land_y - nodes["y"]).to_numpy())
    dist0 = torch.sqrt(dx0 ** 2 + dy0 ** 2) + 1e-6
    h0_goal = torch.stack([dx0, dy0, dist0, dx0 / dist0, dy0 / dist0], dim=-1)

    # num_frames_output from each player's last input row, aligned to node order.
    last_nfo = (
        play_rows.sort_values(["nfl_id", "frame_id"])
                 .groupby("nfl_id")["num_frames_output"]
                 .last()
    )
    horizon_np = last_nfo.reindex(nodes["nfl_id"]).fillna(0).to_numpy(dtype=np.int64)
    horizon = torch.tensor(np.minimum(horizon_np, N_max), dtype=torch.long, device=device)

    steps = torch.arange(1, N_max + 1, device=device).float()
    t_grid = steps.unsqueeze(0).expand(n_players, N_max)
    denom = horizon.unsqueeze(1).float().clamp(min=1.0)
    tau_seq = (t_grid / denom).clamp(max=1.0)

    return x0y0, h0_goal, tau_seq, horizon, dict(ball_land=(land_x, land_y))

In [None]:
def normalize_output_like_input(df_out: pd.DataFrame, df_in: pd.DataFrame) -> pd.DataFrame:
    """Normalize output positions the same way the input x/y are normalized.

    Uses ``play_direction`` from ``df_in`` (joined on game_id, play_id).
    Rows of "left" plays (case-insensitive) get ``x -> 120 - x``. Every row
    gets ``y -> y - 26.65``. Rows whose play is not in ``df_in`` keep their x.
    Returns a new frame with the same columns as ``df_out``.
    """
    directions = df_in[["game_id", "play_id", "play_direction"]].drop_duplicates()
    directions = directions.assign(is_left=directions["play_direction"].str.lower().eq("left"))
    directions = directions.drop(columns="play_direction")

    result = df_out.copy().merge(directions, on=["game_id", "play_id"], how="left")
    flip = result["is_left"] == True  # noqa: E712 - is_left may be NaN after the left join
    result.loc[flip, "x"] = 120 - result.loc[flip, "x"]
    result["y"] = result["y"] - 26.65
    return result.drop(columns=["is_left"])

In [None]:
def precompute_graph_tensors(graphs, df_in_norm, N_max=30):
    """Convert per-play graph DataFrames into ready-to-batch CPU tensors.

    Args:
      graphs:     dict[(game_id, play_id)] -> {"nodes": DataFrame, "edges": DataFrame}
                  as produced by build_interaction_graphs.
      df_in_norm: normalized input rows, forwarded to prepare_play_decoder_inputs.
      N_max:      number of decoder steps.

    Returns:
      dict[(game_id, play_id)] -> {
        "player_ids": list of nfl_id sorted ascending (defines node order),
        "node_xy":    (P, 2) float32,
        "roles":      list of player_role, len P,
        "edge_index": (2, E) long, positions into player_ids,
        "edge_attr":  (E, 9) float32,
        "global_ctx": (3,) float32 [mean x, mean y, 0.42],
        "goal_feat":  (P, 5), "tau_seq": (P, N_max), "horizon": (P,)
      }
    Edges whose endpoints are not among the play's nodes are dropped. A play
    with no edges gets empty (2, 0) / (0, 9) tensors.
    """
    edge_feature_cols = ["dx", "dy", "dvx", "dvy", "dist", "dv", "cos_bear", "sin_bear", "ally_flag"]
    graphs_fast = {}

    for (gid, pid), g in graphs.items():

        nodes = g["nodes"].sort_values("nfl_id").reset_index(drop=True)
        edges = g["edges"]

        player_ids = nodes["nfl_id"].tolist()

        # -------- node xy --------
        node_xy = torch.tensor(nodes[["x", "y"]].to_numpy(), dtype=torch.float32)

        # -------- roles ----------
        roles = nodes["player_role"].tolist()

        # -------- edges ----------
        position = {nid: i for i, nid in enumerate(player_ids)}
        has_edges = (not edges.empty) and {"src_id", "dst_id"}.issubset(edges.columns)
        if has_edges:
            keep = edges["src_id"].isin(list(position)) & edges["dst_id"].isin(list(position))
            edges = edges[keep]
            src_idx = edges["src_id"].map(position).to_numpy(dtype=np.int64)
            dst_idx = edges["dst_id"].map(position).to_numpy(dtype=np.int64)
            # Stack into one ndarray first; building a tensor from a list of
            # ndarrays is very slow.
            edge_index = torch.from_numpy(np.stack([src_idx, dst_idx]))
            edge_attr = torch.tensor(edges[edge_feature_cols].to_numpy(np.float32))
        else:
            edge_index = torch.zeros((2, 0), dtype=torch.long)
            edge_attr = torch.zeros((0, len(edge_feature_cols)), dtype=torch.float32)

        # -------- global context --------
        # NOTE(review): the constant 0.42 is undocumented here; End2EndModel
        # does not read global_ctx. Confirm whether it is still needed.
        global_ctx = torch.tensor([
            nodes["x"].mean(),
            nodes["y"].mean(),
            0.42
        ], dtype=torch.float32)

        # -------- decoder inputs (goal_feat, tau_seq, horizon) --------
        _x0y0, goal_feat, tau_seq, horizon, _meta = prepare_play_decoder_inputs(
            nodes, df_in_norm, gid, pid, N_max=N_max, device="cpu"
        )

        graphs_fast[(gid, pid)] = {
            "player_ids": player_ids,
            "node_xy": node_xy,              # (P,2)
            "roles": roles,                  # list len P
            "edge_index": edge_index,        # (2,E)
            "edge_attr": edge_attr,          # (E,9)
            "global_ctx": global_ctx,        # (3,)
            "goal_feat": goal_feat,          # (P,5)
            "tau_seq": tau_seq,              # (P,N_max)
            "horizon": horizon,              # (P,)
        }

    return graphs_fast

# graphs_fast = precompute_graph_tensors(graphs, df_in_norm, N_max=30)

In [None]:
# =========================
# Utilities / device selection
# =========================
import os, random, math
import numpy as np
import torch
import time
from torch import nn
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR
from contextlib import nullcontext

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# =========================================================
# End-to-end model wrapper
# (temporal encoder + spatial transformer + role adapters + two-stream decoder)
# =========================================================
class End2EndModel(nn.Module):
    """Full trajectory model.

    Pipeline: TemporalTransformer (history -> node embedding), then
    SpatialTransformer (mix across players), then RoleSpecificAdapters, then
    TwoStreamDecoder (per-step displacements).

    The arguments ``d_model``, ``gat_edge_dim``, ``gat_heads`` and
    ``gat_layers`` are accepted but unused; they are kept so existing call
    sites keep working.
    """
    def __init__(self, in_dim=8, embed_dim=128, d_model=128, time_dim=64, goal_dim=5,
                 hidden=256, N_max=30, gat_edge_dim=9, gat_heads=4, gat_layers=2,
                 role_names=("Targeted Receiver", "Defensive Coverage", "Passer", "Other Route Runner")):
        super().__init__()
        self.encoder = TemporalTransformer(in_dim=in_dim, d_model=embed_dim)
        self.spatial = SpatialTransformer(d_model=embed_dim)
        self.adapters = RoleSpecificAdapters(embed_dim=embed_dim, hidden_dim=embed_dim,
                                             role_names=list(role_names))
        self.decoder = TwoStreamDecoder(d_model=embed_dim, time_dim=time_dim,
                                        goal_dim=goal_dim, hidden=hidden, N_max=N_max)
        self.role_names = list(role_names)
        self.role_map = {r: i for i, r in enumerate(self.role_names)}
        self.N_max = N_max

    def forward_batch(self, batch, N_max_curr):
        """
        Run all players of a (possibly multi-play) batch in one pass.

        batch keys used: x_hist (P,K,F), node_xy (P,2), edge_index (2,E),
        edge_attr (E,9), goal_feat (P,5), tau_seq (P,>=N_max_curr),
        horizon (P,), roles (list len P).
        N_max_curr: steps actually supervised; tau_seq is sliced to it and
        horizon is capped at it.

        Returns:
          xy_pred: (P, T, 2) absolute positions = node_xy + cumsum(dxy),
                   zeroed where the mask is 0
          mask:    (P, T)
        T is the decoder's N_max, because the decoder pads tau_seq back to N_max.
        Inputs are moved to the device of this model's parameters, so the
        model also works when it is kept off the global DEVICE.
        """
        device = next(self.parameters()).device

        x_hist     = batch["x_hist"].to(device).float()                                    # (P,K,F)
        node_xy    = batch["node_xy"].to(device).float()                                   # (P,2)
        edge_index = batch["edge_index"].to(device).long()                                 # (2,E)
        edge_attr  = batch["edge_attr"].to(device).float()                                 # (E,9)
        goal_feat  = batch["goal_feat"].to(device).float()                                 # (P,5)
        tau_seq    = batch["tau_seq"][:, :N_max_curr].to(device).float()                   # (P,T)
        horizon    = torch.clamp(batch["horizon"], max=N_max_curr).to(device).long()       # (P,)
        roles      = batch["roles"]                                                        # list len P

        h_nodes = self.encoder(x_hist)                                   # (P,D)
        h_ctx = self.spatial(h_nodes, edge_index, edge_attr)             # (P,D)
        h_role = self.adapters(h_ctx, roles, self.role_map)              # (P,D)

        dxy, mask = self.decoder(h_role, goal_feat, tau_seq, horizon)    # (P,T,2), (P,T)

        xy_pred = node_xy.unsqueeze(1) + torch.cumsum(dxy, dim=1)
        xy_pred = xy_pred * mask.unsqueeze(-1)

        return xy_pred, mask

In [None]:
def unnormalize_field_direction(preds: pd.DataFrame, raw_input: pd.DataFrame) -> pd.DataFrame:
    """Map normalized predictions back to raw field coordinates.

    ``preds`` needs game_id, play_id, x, y. ``raw_input`` supplies the
    original ``play_direction`` per (game_id, play_id). Every row gets
    ``y -> y + 26.65``. Rows of "left" plays (case-insensitive) also get
    ``x -> 120 - x``. Returns a new frame with ``preds``' columns.
    """
    directions = raw_input[["game_id", "play_id", "play_direction"]].drop_duplicates()
    left_flag = directions["play_direction"].str.lower() == "left"
    dir_map = directions.drop(columns="play_direction").assign(is_left=left_flag)

    restored = preds.merge(dir_map, on=["game_id", "play_id"], how="left")
    restored["y"] = restored["y"] + 26.65
    left_rows = restored["is_left"] == True  # noqa: E712 - is_left may be NaN after the left join
    restored.loc[left_rows, "x"] = 120 - restored.loc[left_rows, "x"]

    return restored.drop(columns=["is_left"])

In [None]:
import os
import pandas as pd
import polars as pl
import torch
import numpy as np
from sklearn.neighbors import NearestNeighbors

from torch import nn

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Checkpoint from the attached model dataset (a plain state_dict).
MODEL_WEIGHTS_PATH = "/kaggle/input/final-model-seed3-stt/pytorch/default/1/final_model_seed3_stt.pt"

model = End2EndModel(
    in_dim=8, embed_dim=128, d_model=128, time_dim=64, goal_dim=5, hidden=256, N_max=30
).to(DEVICE)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code. It also states the intent
# explicitly on torch versions where it is not the default.
state_dict = torch.load(MODEL_WEIGHTS_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
model.eval()

In [None]:
_FALLBACK_X = 60.0     # raw-coordinate x used when a player has no prediction and no observed position
_FALLBACK_Y = 26.65    # raw-coordinate y midline (the same offset used to center y)
_HIST_FEATURES = ["x", "y", "vx", "vy", "s", "a", "dir", "o"]


def _history_tensor(play_df: pd.DataFrame, player_ids, k_hist: int = 10) -> torch.Tensor:
    """Stack the last ``k_hist`` feature rows of each player as a (P, k_hist, 8) float32 tensor on DEVICE.

    ``play_df`` holds the normalized rows of one play with vx/vy already
    computed. Feature columns that are absent are filled with 0.0. Players
    with fewer than ``k_hist`` frames are front-padded with their first frame.
    """
    missing = [c for c in _HIST_FEATURES if c not in play_df.columns]
    if missing:
        play_df = play_df.assign(**{c: 0.0 for c in missing})

    hist_list = []
    for nid in player_ids:
        rows = play_df[play_df["nfl_id"] == nid].sort_values("frame_id")
        tail = rows[_HIST_FEATURES].to_numpy(np.float32)[-k_hist:]
        if tail.shape[0] < k_hist:
            pad = np.repeat(tail[:1], k_hist - tail.shape[0], axis=0)
            tail = np.vstack([pad, tail])
        hist_list.append(tail)
    return torch.tensor(np.stack(hist_list, axis=0), device=DEVICE)


def _fill_missing_xy(merged: pd.DataFrame, raw_input_pd: pd.DataFrame) -> pd.DataFrame:
    """Fill NaN x/y with each player's last observed raw position, then with the field centre."""
    last_xy = (
        raw_input_pd.sort_values("frame_id")
                    .groupby("nfl_id")[["x", "y"]]
                    .last()
    )
    merged["x"] = merged["x"].fillna(merged["nfl_id"].map(last_xy["x"])).fillna(_FALLBACK_X)
    merged["y"] = merged["y"].fillna(merged["nfl_id"].map(last_xy["y"])).fillna(_FALLBACK_Y)
    return merged


def predict(test: pl.DataFrame, test_input: pl.DataFrame) -> pl.DataFrame:
    """Predict raw-coordinate (x, y) for every row of ``test``.

    Called by the inference server. ``test_input`` holds the pre-throw
    tracking rows; ``test`` lists the rows to predict, keyed by nfl_id (and
    frame_id when present). Only the first play graph built from
    ``test_input`` is used, so each call is expected to cover a single play.

    Returns a polars DataFrame with columns x, y aligned row-for-row with
    ``test``. Rows without a model prediction fall back to the player's last
    observed position, or to the field centre if that is also missing.
    """
    test_pd = test.to_pandas()
    raw_input_pd = test_input.to_pandas()

    df = normalize_field_direction(raw_input_pd)
    df = df.sort_values(["game_id", "play_id", "nfl_id", "frame_id"])
    df["vx"] = df.groupby(["game_id", "play_id", "nfl_id"])["x"].diff().fillna(0)
    df["vy"] = df.groupby(["game_id", "play_id", "nfl_id"])["y"].diff().fillna(0)

    graphs = build_interaction_graphs(df, K=6)
    if not graphs:
        # No play had at least two players, so the model cannot run.
        merged = _fill_missing_xy(test_pd.assign(x=np.nan, y=np.nan), raw_input_pd)
        print(f"No graph built; returning fallback positions for {len(test_pd)} rows")
        return pl.DataFrame(merged[["x", "y"]])

    (gid, pid), graph = next(iter(graphs.items()))
    G = precompute_graph_tensors({(gid, pid): graph}, df_in_norm=df, N_max=30)[(gid, pid)]
    player_ids = G["player_ids"]            # sorted nfl_id: defines row order of the model output

    play_df = df[(df["game_id"] == gid) & (df["play_id"] == pid)]
    x_hist = _history_tensor(play_df, player_ids, k_hist=10)

    batch = {
        "x_hist":     x_hist,                       # (P,K,F)
        "node_xy":    G["node_xy"].to(DEVICE),      # (P,2)
        "edge_index": G["edge_index"].to(DEVICE),   # (2,E)
        "edge_attr":  G["edge_attr"].to(DEVICE),    # (E,9)
        "global_ctx": G["global_ctx"].to(DEVICE),   # (3,)
        "goal_feat":  G["goal_feat"].to(DEVICE),    # (P,5)
        "tau_seq":    G["tau_seq"].to(DEVICE),      # (P,30)
        "horizon":    G["horizon"].to(DEVICE),      # (P,)
        "roles":      G["roles"],                   # list len P
    }

    with torch.no_grad():
        xy_pred, _mask = model.forward_batch(batch, N_max_curr=30)
    xy_pred = xy_pred.cpu().numpy()                 # (P, n_steps, 2), normalized coordinates
    n_steps = xy_pred.shape[1]

    if "frame_id" in test_pd.columns:
        # Decoder step i is the (i+1)-th step after the throw (tau = (i+1)/horizon).
        # Assumes test frame_id counts output frames from 1. TODO confirm
        # against the competition data description. Frames past the decoder
        # length reuse the last step.
        frame_ids = np.sort(test_pd["frame_id"].dropna().unique()).astype(np.int64)
        steps = np.clip(frame_ids - 1, 0, n_steps - 1)
        preds = pd.DataFrame({
            "nfl_id": np.repeat(np.asarray(player_ids), len(frame_ids)),
            "frame_id": np.tile(frame_ids, len(player_ids)),
            "x": xy_pred[:, steps, 0].reshape(-1),
            "y": xy_pred[:, steps, 1].reshape(-1),
        })
        merge_keys = ["nfl_id", "frame_id"]
    else:
        preds = pd.DataFrame({
            "nfl_id": player_ids,
            "x": xy_pred[:, 0, 0],
            "y": xy_pred[:, 0, 1],
        })
        merge_keys = ["nfl_id"]

    # Undo field normalization using the original play_direction.
    preds["game_id"] = gid
    preds["play_id"] = pid
    preds = unnormalize_field_direction(preds, raw_input_pd)

    # A left merge keeps test's row order; (nfl_id[, frame_id]) is unique in preds.
    merged = test_pd.merge(preds[merge_keys + ["x", "y"]], on=merge_keys, how="left")
    merged = _fill_missing_xy(merged, raw_input_pd)

    print(f"‚úÖ Play {gid}-{pid}: expected {len(test_pd)} rows, predicted {len(merged)}")

    return pl.DataFrame(merged[["x", "y"]])

In [None]:
import kaggle_evaluation.nfl_inference_server

inference_server = kaggle_evaluation.nfl_inference_server.NFLInferenceServer(predict)

# When KAGGLE_IS_COMPETITION_RERUN is set (presumably the scoring rerun) the
# server is started with serve(); otherwise the local gateway is run against
# the competition input directory.
is_competition_rerun = bool(os.getenv("KAGGLE_IS_COMPETITION_RERUN"))
if not is_competition_rerun:
    inference_server.run_local_gateway(('/kaggle/input/nfl-big-data-bowl-2026-prediction/',))
else:
    inference_server.serve()