In [1]:
import torch
import torch.nn as nn
# Display the installed PyTorch version so the outputs below can be tied to it.
torch.__version__


'2.1.0'

In [2]:
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence



import sys
# Put the parent directory at the front of the import search path. The membership
# check keeps re-runs of this cell from inserting the same entry again.
# NOTE(review): presumably the project-local `datasets` module lives there -- confirm.
if ".." not in sys.path:
    sys.path.insert(0, "..")

from datasets import PlayByPlayDataset

# Build the dataset from the pickled play-by-play feature file.
# NOTE(review): ".." and this path are relative, so they presumably resolve
# against the kernel's working directory -- confirm before running from elsewhere.
dataset = PlayByPlayDataset("../data/nfl-big-data-bowl-2024/play_by_play_features.pkl")


In [28]:
# Inspect one sample (index 2): the keys it exposes and its "play_features" entry.
list(dataset[2].keys()), dataset[2]["play_features"]


(['offense_geometric',
  'defense_geometric',
  'offense_raw',
  'defense_raw',
  'ball_carrier_raw',
  'play_features',
  'game_id',
  'play_id',
  'player_tracking',
  'event_timeseries',
  'players_on_the_field',
  'tackle_successful',
  'yards_after_contact'],
 ballCarrierId                        47857
 ballCarrierDisplayName    Devin Singletary
 quarter                                  1
 down                                     2
 yardsToGo                                3
 possessionTeam                         BUF
 defensiveTeam                           LA
 yardlineSide                           BUF
 yardlineNumber                          45
 gameClock                            13:15
 preSnapHomeScore                         0
 preSnapVisitorScore                      0
 passResult                               C
 absoluteYardlineNumber                  65
 prePenaltyPlayResult                     6
 playResult                               6
 offenseFormation              

In [48]:
# Sentinel written into padded time steps so they can be told apart from real data.
# NOTE(review): 123456789 has no exact float32 representation (nearest value is
# 123456792); if the features are ever cast to float32, compare against the
# sentinel in the tensor's own dtype rather than against this Python float.
PAD_VALUE = 123456789.
# Per-sample entries that are converted to tensors, padded along time, and
# concatenated (in this order) along the feature axis.
TIME_SERIES_KEYS = ["offense_geometric", "offense_raw", "defense_geometric", "defense_raw", "ball_carrier_raw", "event_timeseries"]
TARGET_KEY = "yards_after_contact"
TREATMENT_KEY = "tackle_successful"
STATIC_KEYS = [] # future: play features and on-field player info

from collections import defaultdict

def collate_padded_play_data(batch):
    """Collate a list of play samples into a single padded batch.

    Parameters
    ----------
    batch : list of mapping
        One mapping per sample. Every key in ``TIME_SERIES_KEYS`` must be
        present and its value must expose ``.to_numpy()``. The resulting
        arrays are padded along their first axis and concatenated along
        ``dim=2``, so each must be 2-D and, within a sample, all keys must
        share the same number of rows. ``TARGET_KEY`` and ``TREATMENT_KEY``
        values are passed through unchanged when present; all other keys are
        ignored.

    Returns
    -------
    dict
        ``"time_series_features"``: tensor of shape
        ``(len(batch), longest sequence, total feature columns)`` where
        shorter sequences are filled with ``PAD_VALUE``.
        ``"features"``: list of converted ``"play_features"`` entries; it is
        empty unless ``"play_features"`` is listed in ``STATIC_KEYS``.
        ``"target"`` / ``"treatment"``: lists of the raw per-sample values.

    Raises
    ------
    ValueError
        If ``batch`` is empty.
    KeyError
        If a sample lacks one of ``TIME_SERIES_KEYS``.
    """
    if len(batch) == 0:
        raise ValueError("collate_padded_play_data received an empty batch")

    # Build the lookup sets once instead of re-concatenating lists per key.
    tensor_keys = frozenset(TIME_SERIES_KEYS + STATIC_KEYS)
    passthrough_keys = frozenset((TARGET_KEY, TREATMENT_KEY))

    batchdict = defaultdict(list)
    for item in batch:
        # Fail early with a readable message; otherwise a missing key only
        # shows up later as a batch-size mismatch inside torch.cat.
        missing = [k for k in TIME_SERIES_KEYS if k not in item]
        if missing:
            raise KeyError(f"sample is missing time-series keys: {missing}")
        for k, v in item.items():
            if k in tensor_keys:
                batchdict[k].append(torch.from_numpy(v.to_numpy()))
            elif k in passthrough_keys:
                batchdict[k].append(v)

    # Pad each feature group to the longest sequence in the batch, then join
    # the groups along the feature axis.
    padded_groups = [
        pad_sequence(batchdict[k], batch_first=True, padding_value=PAD_VALUE)
        for k in TIME_SERIES_KEYS
    ]
    X_padded = torch.cat(padded_groups, dim=2)

    return {
        "time_series_features": X_padded,
        # Placeholder: stays [] until "play_features" is added to STATIC_KEYS.
        "features": batchdict["play_features"],
        # Use the constants so the returned values follow the configured keys.
        "target": batchdict[TARGET_KEY],
        "treatment": batchdict[TREATMENT_KEY],
    }


# Seed a dedicated generator for the shuffle so the batch inspected below (and
# its padded length) is the same on every Restart & Run All.
SHUFFLE_SEED = 0
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_padded_play_data,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)
# Pull a single batch to sanity-check the collate function's output.
batch = next(iter(dataloader))


In [49]:
# Axes follow the collate function: (batch, padded time steps, feature columns
# concatenated across TIME_SERIES_KEYS).
batch["time_series_features"].shape


torch.Size([8, 39, 248])