# Preprocess the data:

## Imports:

In [1]:
import py7zr
from datasets import load_dataset, load_dataset_builder

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os

In [3]:
import numpy as np

## Constants:

In [26]:
NUM_PROC = os.cpu_count()
D = 33

## Look at dataset:

In [76]:
ds_builder = load_dataset_builder("dcayton/nba_tracking_data_15_16", 'full')

In [77]:
ds_builder.info.description

'This dataset is designed to give further easy access to tracking data.\nBy merging all .7z files into one large .json file, access is easier to retrieve all information at once.\n'

In [78]:
ds_builder.info.features

{'gameid': Value(dtype='string', id=None),
 'gamedate': Value(dtype='string', id=None),
 'event_info': {'id': Value(dtype='string', id=None),
  'type': Value(dtype='int64', id=None),
  'possession_team_id': Value(dtype='float64', id=None),
  'desc_home': Value(dtype='string', id=None),
  'desc_away': Value(dtype='string', id=None)},
 'primary_info': {'team': Value(dtype='string', id=None),
  'player_id': Value(dtype='float64', id=None),
  'team_id': Value(dtype='float64', id=None)},
 'secondary_info': {'team': Value(dtype='string', id=None),
  'player_id': Value(dtype='float64', id=None),
  'team_id': Value(dtype='float64', id=None)},
 'visitor': {'name': Value(dtype='string', id=None),
  'teamid': Value(dtype='int64', id=None),
  'abbreviation': Value(dtype='string', id=None),
  'players': [{'lastname': Value(dtype='string', id=None),
    'firstname': Value(dtype='string', id=None),
    'playerid': Value(dtype='int64', id=None),
    'jersey': Value(dtype='string', id=None),
    'posit

In [80]:
for split_name, split_info in ds_builder.info.splits.items():
    print(f"{split_name}: {split_info.num_examples} examples")

AttributeError: 'NoneType' object has no attribute 'items'

In [93]:
ds_builder.info.num_examples

AttributeError: 'DatasetInfo' object has no attribute 'num_examples'

## Load Dataset:

In [5]:
# Looks like there is only a train split (loading without split yields a dictionary with only "train": data not split already)
dataset = load_dataset("dcayton/nba_tracking_data_15_16", name = "tiny", split="train", trust_remote_code=True)

In [4]:
dataset[0]['moments'][0]

{'quarter': 1,
 'game_clock': 707.65,
 'shot_clock': 24.0,
 'ball_coordinates': {'x': 5.48747, 'y': 24.25562, 'z': 4.88724},
 'player_coordinates': [{'teamid': 1610612754,
   'playerid': 201588,
   'x': 22.84608,
   'y': 43.37048,
   'z': 0.0},
  {'teamid': 1610612754,
   'playerid': 101133,
   'x': 6.9984,
   'y': 27.64791,
   'z': 0.0},
  {'teamid': 1610612754,
   'playerid': 101145,
   'x': 15.35528,
   'y': 7.98835,
   'z': 0.0},
  {'teamid': 1610612754,
   'playerid': 202730,
   'x': 28.86235,
   'y': 19.02946,
   'z': 0.0},
  {'teamid': 1610612754,
   'playerid': 202331,
   'x': 9.36919,
   'y': 46.68742,
   'z': 0.0},
  {'teamid': 1610612748,
   'playerid': 2548,
   'x': 10.42233,
   'y': 12.393,
   'z': 0.0},
  {'teamid': 1610612748,
   'playerid': 2547,
   'x': 18.48075,
   'y': 22.48179,
   'z': 0.0},
  {'teamid': 1610612748,
   'playerid': 2736,
   'x': 6.93975,
   'y': 37.22873,
   'z': 0.0},
  {'teamid': 1610612748,
   'playerid': 201609,
   'x': 10.11883,
   'y': 34.4858,

In [5]:
type(dataset[0]['moments'])

list

In [None]:
len(dataset[0]['moments'])

150

## Downsample dataset:

In [33]:
example = dataset[0]

In [None]:
list(map(lambda x : x.pop('lastname', None), x.pop('firstname', None), x, example['home']['players']))

[(None, None, {'playerid': 101133, 'jersey': '28', 'position': 'C'}),
 (None, None, {'playerid': 101139, 'jersey': '0', 'position': 'F-G'}),
 (None, None, {'playerid': 101145, 'jersey': '11', 'position': 'G'}),
 (None, None, {'playerid': 201155, 'jersey': '2', 'position': 'G'}),
 (None, None, {'playerid': 201588, 'jersey': '3', 'position': 'G'}),
 (None, None, {'playerid': 201941, 'jersey': '27', 'position': 'C'}),
 (None, None, {'playerid': 201978, 'jersey': '10', 'position': 'F'}),
 (None, None, {'playerid': 202331, 'jersey': '13', 'position': 'F'}),
 (None, None, {'playerid': 202730, 'jersey': '5', 'position': 'F-C'}),
 (None, None, {'playerid': 203524, 'jersey': '44', 'position': 'F'}),
 (None, None, {'playerid': 203922, 'jersey': '40', 'position': 'G-F'}),
 (None, None, {'playerid': 1626167, 'jersey': '33', 'position': 'F-C'}),
 (None, None, {'playerid': 1626202, 'jersey': '1', 'position': 'G'})]

In [15]:
example['gameid']=0

In [17]:
dataset[0]['gameid']

'0021500333'

In [23]:
del example['event_info']['desc_away']

In [None]:
    # del example['gamedate']
    # del example['event_info']['desc_home']
    # del example['event_info']['desc_away']
    # del example['primary_info']['team']
    # del example['primary_info']['team_id']
    # del example['secondary_info']['team']
    # del example['secondary_info']['team_id']
    # del example['visitor']['name']
    # del example['visitor']['teamid']
    # del example['visitor']['abbreviation']
    # del example['visitor']['players'][:]['lastname']
    # del example['visitor']['players'][:]['firstname']
    # del example['visitor']['players'][:]['number']
    # del example['home']['name']
    # del example['home']['teamid']
    # del example['home']['abbreviation']
    # del example['home']['players']
    # del example['home']['players'][:]['firstname']
    # del example['home']['players'][:]['number']
    # del example['moments']['quarter']
    # del example['moments']['game_clock']
    # del example['moments']['shot_clock']
    # del example['moments']['player_coordinates']['teamid']

In [6]:
def downsample(example):
    example['moments'] = example['moments'][::10]
    return example

In [7]:
def prune_example(example):
    return {
        # top‐level ID
        "n_gameid": example["gameid"],

        # only the fields we care about, renaming `id` → `eventid`
        "n_event_info": {
            "id": example["event_info"]["id"],
            "type": example["event_info"]["type"],
            "possession_team_id": example["event_info"]["possession_team_id"],
        },

        # only player_id from primary and secondary
        "n_primary_info": {
            "player_id": example["primary_info"]["player_id"]
        },
        "n_secondary_info": {
            "player_id": example["secondary_info"]["player_id"]
        },

        # visitor: keep only each player’s position
        "n_visitor": {
            "players": [
                {"position": p["position"]}
                for p in example["visitor"]["players"]
            ]
        },

        # home: keep only playerid and position
        "n_home": {
            "players": [
                {"playerid": p["playerid"], "position": p["position"]}
                for p in example["home"]["players"]
            ]
        },

        # downsampled moments: only ball_coords and each player’s id+xyz
        "n_moments": [
            {
                "ball_coordinates": {
                    "x": m["ball_coordinates"]["x"],
                    "y": m["ball_coordinates"]["y"],
                    "z": m["ball_coordinates"]["z"],
                },
                "player_coordinates": [
                    {
                        "playerid": pc["playerid"],
                        "x": pc["x"],
                        "y": pc["y"],
                        "z": pc["z"],
                    }
                    for pc in m["player_coordinates"]
                ],
            }
            for m in example["moments"]
        ],
    }


In [8]:
type(downsample(dataset[0]))

dict

In [9]:
downsample(dataset[0])['moments_ds']

[{'quarter': 1,
  'game_clock': 707.65,
  'shot_clock': 24.0,
  'ball_coordinates': {'x': 5.48747, 'y': 24.25562, 'z': 4.88724},
  'player_coordinates': [{'teamid': 1610612754,
    'playerid': 201588,
    'x': 22.84608,
    'y': 43.37048,
    'z': 0.0},
   {'teamid': 1610612754,
    'playerid': 101133,
    'x': 6.9984,
    'y': 27.64791,
    'z': 0.0},
   {'teamid': 1610612754,
    'playerid': 101145,
    'x': 15.35528,
    'y': 7.98835,
    'z': 0.0},
   {'teamid': 1610612754,
    'playerid': 202730,
    'x': 28.86235,
    'y': 19.02946,
    'z': 0.0},
   {'teamid': 1610612754,
    'playerid': 202331,
    'x': 9.36919,
    'y': 46.68742,
    'z': 0.0},
   {'teamid': 1610612748,
    'playerid': 2548,
    'x': 10.42233,
    'y': 12.393,
    'z': 0.0},
   {'teamid': 1610612748,
    'playerid': 2547,
    'x': 18.48075,
    'y': 22.48179,
    'z': 0.0},
   {'teamid': 1610612748,
    'playerid': 2736,
    'x': 6.93975,
    'y': 37.22873,
    'z': 0.0},
   {'teamid': 1610612748,
    'playeri

In [10]:
len(downsample(dataset[0])['moments_ds'])

15

In [8]:
%%time
dataset_dsed = dataset.map(downsample, num_proc=NUM_PROC)

Map (num_proc=16): 100%|██████████| 2219/2219 [00:18<00:00, 120.96 examples/s]

CPU times: user 1.38 s, sys: 797 ms, total: 2.18 s
Wall time: 18.9 s





In [57]:
len(dataset_dsed[0]['moments'])

15

In [98]:
dataset_dsed.column_names

['gameid',
 'gamedate',
 'event_info',
 'primary_info',
 'secondary_info',
 'visitor',
 'home',
 'moments']

In [9]:
dataset_prcsd  = dataset_dsed.map(prune_example, num_proc=NUM_PROC, remove_columns=dataset_dsed.column_names)

Map (num_proc=16): 100%|██████████| 2219/2219 [00:03<00:00, 711.73 examples/s]


In [10]:
dataset_prcsd = dataset_prcsd.rename_columns({'n_gameid':'gameid',
                              'n_event_info':'event_info',
                              'n_primary_info':'primary_info',
                              'n_secondary_info':'secondary_info',
                              'n_visitor':'visitor',
                              'n_home':'home',
                              'n_moments':'moments'})

In [111]:
dataset_prcsd.features

{'gameid': Value(dtype='string', id=None),
 'event_info': {'id': Value(dtype='string', id=None),
  'possession_team_id': Value(dtype='float64', id=None),
  'type': Value(dtype='int64', id=None)},
 'primary_info': {'player_id': Value(dtype='float64', id=None)},
 'secondary_info': {'player_id': Value(dtype='float64', id=None)},
 'visitor': {'players': [{'position': Value(dtype='string', id=None)}]},
 'home': {'players': [{'playerid': Value(dtype='int64', id=None),
    'position': Value(dtype='string', id=None)}]},
 'moments': [{'ball_coordinates': {'x': Value(dtype='float64', id=None),
    'y': Value(dtype='float64', id=None),
    'z': Value(dtype='float64', id=None)},
   'player_coordinates': [{'playerid': Value(dtype='int64', id=None),
     'x': Value(dtype='float64', id=None),
     'y': Value(dtype='float64', id=None),
     'z': Value(dtype='float64', id=None)}]}]}

In [12]:
dataset[0]['visitor']

{'name': 'Miami Heat',
 'teamid': 1610612748,
 'abbreviation': 'MIA',
 'players': [{'lastname': 'Andersen',
   'firstname': 'Chris',
   'playerid': 2365,
   'jersey': '11',
   'position': 'F-C'},
  {'lastname': 'Stoudemire',
   'firstname': "Amar'e",
   'playerid': 2405,
   'jersey': '5',
   'position': 'F-C'},
  {'lastname': 'Bosh',
   'firstname': 'Chris',
   'playerid': 2547,
   'jersey': '1',
   'position': 'F'},
  {'lastname': 'Wade',
   'firstname': 'Dwyane',
   'playerid': 2548,
   'jersey': '3',
   'position': 'G'},
  {'lastname': 'Haslem',
   'firstname': 'Udonis',
   'playerid': 2617,
   'jersey': '40',
   'position': 'F'},
  {'lastname': 'Deng',
   'firstname': 'Luol',
   'playerid': 2736,
   'jersey': '9',
   'position': 'F'},
  {'lastname': 'Udrih',
   'firstname': 'Beno',
   'playerid': 2757,
   'jersey': '19',
   'position': 'G'},
  {'lastname': 'Green',
   'firstname': 'Gerald',
   'playerid': 101123,
   'jersey': '14',
   'position': 'G'},
  {'lastname': 'Dragic',
   '

In [10]:
dataset.features

{'gameid': Value(dtype='string', id=None),
 'gamedate': Value(dtype='string', id=None),
 'event_info': {'id': Value(dtype='string', id=None),
  'type': Value(dtype='int64', id=None),
  'possession_team_id': Value(dtype='float64', id=None),
  'desc_home': Value(dtype='string', id=None),
  'desc_away': Value(dtype='string', id=None)},
 'primary_info': {'team': Value(dtype='string', id=None),
  'player_id': Value(dtype='float64', id=None),
  'team_id': Value(dtype='float64', id=None)},
 'secondary_info': {'team': Value(dtype='string', id=None),
  'player_id': Value(dtype='float64', id=None),
  'team_id': Value(dtype='float64', id=None)},
 'visitor': {'name': Value(dtype='string', id=None),
  'teamid': Value(dtype='int64', id=None),
  'abbreviation': Value(dtype='string', id=None),
  'players': [{'lastname': Value(dtype='string', id=None),
    'firstname': Value(dtype='string', id=None),
    'playerid': Value(dtype='int64', id=None),
    'jersey': Value(dtype='string', id=None),
    'posit

In [108]:
# Build a list of lengths
lengths = [ len(moments_list) for moments_list in dataset_prcsd['moments'] ]

# Compute mean and (sample) standard deviation
mean_length = np.mean(lengths)
sd_length   = np.std(lengths, ddof=1)

In [109]:
mean_length

np.float64(45.27129337539432)

In [110]:
sd_length

np.float64(22.597157044348098)

In [None]:
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

ImportError: cannot import name 'DataCollatorWithPadding' from 'torch.utils.data' (/home/konstantin/miniconda3/lib/python3.13/site-packages/torch/utils/data/__init__.py)

In [89]:
train_split = int(0.9 * len(datased_prcsd))

In [90]:
train_split

1997

In [92]:
train_dataloader = DataLoader(
    datased_prcsd[:train_split], shuffle=True, batch_size=8
)
test_dataloader = DataLoader(
    datased_prcsd[train_split:], batch_size=8
)

## Building the embeddings together with ChatGPT:

In [11]:
def flatten_example(ex):
    # ex['moments'] is a list of dicts: { 'ball_coordinates':…, 'player_coordinates': […] }
    seq = []
    for m in ex['moments']:
        # 1) ball coords
        bx, by, bz = m['ball_coordinates']['x'], m['ball_coordinates']['y'], m['ball_coordinates']['z']
        # 2) players: sort by playerid to keep order consistent
        pcs = sorted(m['player_coordinates'], key=lambda p: p['playerid'])
        coords = []
        for p in pcs:
            coords.extend([p['x'], p['y'], p['z']])
        seq.append([bx, by, bz] + coords)   # length = 3 + 3*#players
    return {'sequence': seq}

In [113]:
flatten_example(dataset_prcsd[0])

{'sequence': [[5.48747,
   24.25562,
   4.88724,
   18.48075,
   22.48179,
   0.0,
   10.42233,
   12.393,
   0.0,
   6.93975,
   37.22873,
   0.0,
   6.9984,
   27.64791,
   0.0,
   15.35528,
   7.98835,
   0.0,
   22.84608,
   43.37048,
   0.0,
   10.11883,
   34.4858,
   0.0,
   9.36919,
   46.68742,
   0.0,
   5.65447,
   25.04955,
   0.0,
   28.86235,
   19.02946,
   0.0],
  [3.91592,
   25.58263,
   4.07934,
   18.94992,
   22.25667,
   0.0,
   11.54876,
   12.57746,
   0.0,
   7.64769,
   37.71717,
   0.0,
   8.59038,
   29.12755,
   0.0,
   20.27971,
   9.1354,
   0.0,
   25.13357,
   42.09495,
   0.0,
   9.71629,
   34.73978,
   0.0,
   12.5674,
   45.28627,
   0.0,
   4.6167,
   25.41067,
   0.0,
   31.30909,
   19.18481,
   0.0],
  [2.17345,
   27.25701,
   4.59159,
   20.05491,
   21.77041,
   0.0,
   13.06183,
   12.89166,
   0.0,
   9.01332,
   39.07315,
   0.0,
   11.73506,
   30.23825,
   0.0,
   25.35531,
   10.16031,
   0.0,
   27.78324,
   40.36676,
   0.0,
   8.7961

In [49]:
dataset_fl = dataset_prcsd.map(flatten_example, num_proc=NUM_PROC, remove_columns=dataset_prcsd.column_names)

In [17]:
dataset_fl.features

{'sequence': Sequence(feature=Sequence(feature=Value(dtype='float64', id=None), length=-1, id=None), length=-1, id=None)}

In [60]:
dataset_fl.set_format(type='torch', columns=['sequence'])

In [72]:
for x in dataset_fl['sequence']:
    print(x)

tensor([[ 5.4875e+00,  2.4256e+01,  4.8872e+00,  1.8481e+01,  2.2482e+01,
          0.0000e+00,  1.0422e+01,  1.2393e+01,  0.0000e+00,  6.9398e+00,
          3.7229e+01,  0.0000e+00,  6.9984e+00,  2.7648e+01,  0.0000e+00,
          1.5355e+01,  7.9883e+00,  0.0000e+00,  2.2846e+01,  4.3370e+01,
          0.0000e+00,  1.0119e+01,  3.4486e+01,  0.0000e+00,  9.3692e+00,
          4.6687e+01,  0.0000e+00,  5.6545e+00,  2.5050e+01,  0.0000e+00,
          2.8862e+01,  1.9029e+01,  0.0000e+00],
        [ 3.9159e+00,  2.5583e+01,  4.0793e+00,  1.8950e+01,  2.2257e+01,
          0.0000e+00,  1.1549e+01,  1.2577e+01,  0.0000e+00,  7.6477e+00,
          3.7717e+01,  0.0000e+00,  8.5904e+00,  2.9128e+01,  0.0000e+00,
          2.0280e+01,  9.1354e+00,  0.0000e+00,  2.5134e+01,  4.2095e+01,
          0.0000e+00,  9.7163e+00,  3.4740e+01,  0.0000e+00,  1.2567e+01,
          4.5286e+01,  0.0000e+00,  4.6167e+00,  2.5411e+01,  0.0000e+00,
          3.1309e+01,  1.9185e+01,  0.0000e+00],
        [ 2.17

KeyboardInterrupt: 

In [50]:
dataset_fl = dataset_fl.filter(lambda ex: len(ex["sequence"]) > 1)

Filter: 100%|██████████| 2219/2219 [00:01<00:00, 1254.61 examples/s]


In [51]:
dataset_fl.set_format(type="numpy", columns=["sequence"])

In [14]:
from torch.nn.utils.rnn import pad_sequence

In [None]:
def collate_fn(batch):
    # batch: list of {"sequence": Tensor|list, "mask": list|Tensor?}
    seqs = []
    masks = []
    for item in batch['sequence']:
        # 1) get the sequence, convert if needed
        # if not isinstance(item, torch.Tensor):
        #     item = torch.tensor(item, dtype=torch.float32)
        seqs.append(item)

        # 2) build the mask from lengths if the dataset didn't give one
        if "mask" in item:
            m = item["mask"]
            if not isinstance(m, torch.Tensor):
                m = torch.tensor(m, dtype=torch.bool)
        else:
            m = torch.ones(seq.size(0), dtype=torch.bool)
        masks.append(m)

    # 3) pad the sequences & masks
    padded_seqs = pad_sequence(seqs, batch_first=True, padding_value=0.0)   # [B, T_max, D]
    padded_masks = pad_sequence(masks, batch_first=True, padding_value=False)  # [B, T_max]

    return {"sequence": padded_seqs, "mask": padded_masks}

In [92]:
def collate_fn(batch):
    padded_seqs = pad_sequence(batch['sequence'], batch_first=True, padding_value=0.0)   # [B, T_max, D]
    return {"sequence": padded_seqs}

In [82]:
batches = dataset_fl.batch(batch_size=8)

In [103]:
batches

Dataset({
    features: ['sequence'],
    num_rows: 260
})

In [84]:
type(batches[0])

dict

In [91]:
len(batches[0]['sequence'][1])

48

In [95]:
batch_cltd = collate_fn(batches[0])

In [96]:
len(batch_cltd['sequence'])

8

In [97]:
for seq in batch_cltd['sequence']:
    print(len(seq))

88
88
88
88
88
88
88
88


In [38]:
def collate_fn(batch):
    # 1) Convert every seq to a FloatTensor
    seqs = [
        torch.as_tensor(item["sequence"], dtype=torch.float32)
        for item in batch
    ]
    # 2) Record original lengths
    lengths = [s.size(0) for s in seqs]

    # 3) Pad to the max length in this batch
    padded_seqs = pad_sequence(seqs, batch_first=True, padding_value=0.0)  # [B, T_max, D]

    # 4) Build a mask of real vs. pad
    #    True where t < original length
    mask = (
        torch.arange(padded_seqs.size(1))[None, :]
        < torch.tensor(lengths)[:, None]
    )  # [B, T_max]

    return {"sequence": padded_seqs, "mask": mask}

In [15]:
def collate_fn(batch):
    # batch is a list of {"sequence": Tensor[T_i, D]}
    seqs = [item['sequence'] for item in batch]
    # pad to the longest in the batch
    padded = pad_sequence(seqs, batch_first=True)  # -> [B, T_max, D]
    # optionally build an attention mask:
    lengths = torch.tensor([s.size(0) for s in seqs], dtype=torch.long)
    mask = (torch.arange(padded.size(1))[None, :] < lengths[:, None])
    return {'sequence': padded, 'mask': mask}

In [41]:
def collate_fn(batch):
    # 1) Convert every sequence to a FloatTensor
    seqs = [ torch.as_tensor(item["sequence"]).float() for item in batch ]
    
    # 2) Remember each original length
    lengths = [s.size(0) for s in seqs]
    
    # 3) Pad to the max length in the batch
    padded_seqs = pad_sequence(seqs, batch_first=True, padding_value=0.0)  # [B, T_max, D]
    
    # 4) Build a mask (True = real frame, False = pad)
    #    Use the same device as padded_seqs
    device = padded_seqs.device
    lengths_tensor = torch.tensor(lengths, device=device)
    mask = torch.arange(padded_seqs.size(1), device=device)[None, :] < lengths_tensor[:, None]
    
    return { "sequence": padded_seqs, "mask": mask }

In [46]:
def collate_fn(batch):
    seqs = []
    lengths = []
    for item in batch:
        raw = item["sequence"]              # either a NumPy array or a nested list
        arr = np.asarray(raw, dtype=np.float32)  
        # arr.shape == [T_i, D]
        tensor = torch.from_numpy(arr)      # float32 by default
        seqs.append(tensor)
        lengths.append(tensor.size(0))

    # pad to the longest in this batch
    padded = pad_sequence(seqs, batch_first=True, padding_value=0.0)  # [B, T_max, D]

    # build mask
    device = padded.device
    lengths = torch.tensor(lengths, device=device)
    mask = torch.arange(padded.size(1), device=device)[None, :] < lengths[:, None]  # [B, T_max]

    return {"sequence": padded, "mask": mask}

In [54]:
def collate_fn(batch):
    seqs = []
    lengths = []
    for item in batch:
        raw = item["sequence"]
        # 1) If HF gave us a numpy object array, convert to a list of lists
        if isinstance(raw, np.ndarray) and raw.dtype == object:
            raw = raw.tolist()
        # 2) Now raw should be a nested list: List[T_i][D]
        arr = np.array(raw, dtype=np.float32)   # shape = (T_i, D)
        # 3) Wrap as Tensor
        tensor = torch.from_numpy(arr)          # torch.Size([T_i, D])
        seqs.append(tensor)
        lengths.append(tensor.size(0))

    # 4) Pad all sequences to the batch max length
    padded = pad_sequence(seqs, batch_first=True, padding_value=0.0)  # [B, T_max, D]

    # 5) Build a boolean mask of real vs. padded frames
    device = padded.device
    lengths = torch.tensor(lengths, device=device)
    mask = torch.arange(padded.size(1), device=device)[None, :] < lengths[:, None]  # [B, T_max]

    return {"sequence": padded, "mask": mask}

In [55]:
from torch.utils.data import DataLoader
loader = DataLoader(dataset_fl, batch_size=32, shuffle=True, collate_fn=collate_fn)

In [98]:
loader = DataLoader(
    dataset_fl,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=2,
    pin_memory=True
)

In [101]:
loader.batch_sampler

<torch.utils.data.sampler.BatchSampler at 0x7f3814611410>

In [105]:
for batch in loader:
    print(batch)

KeyboardInterrupt: 

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f383bf06660>
Traceback (most recent call last):
  File "/home/konstantin/miniconda3/envs/torch311/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/home/konstantin/miniconda3/envs/torch311/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
 

    Exception ignored in:   <function _MultiProcessingDataLoaderIter.__del__ at 0x7f383bf06660>^
^^Traceback (most recent call last):
^  File "/home/konstantin/miniconda3/envs/torch311/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
^    ^self._shutdown_workers()^
^  File "/home/konstantin/miniconda3/envs/torch311/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
^    ^if w.is_alive():^
^  
   File "/home/konstantin/miniconda3/envs/torch311/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
      assert self._parent_pid == os.getpid(), 'can only test a child process'
    ^ ^^  ^ ^ ^^  ^ ^ ^^^^^
^  File "/home/konstantin/miniconda3/envs/torch311/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
^^    ^assert self._parent_pid == os.getpid(), 'can only test a child process'^
^^ ^ ^ ^ ^ ^  ^ ^ ^^  ^^^^^^^^^^^^^^^^^^^^^^^^^^
^AssertionError^: ^^can only test a child process
^^Excep

KeyboardInterrupt: 

KeyboardInterrupt: 

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

In [19]:
# 1) Encoder: projects raw features and runs the Transformer stack
class PlayEncoder(nn.Module):
    def __init__(self, in_dim, emb_dim, nhead=4, nlayers=3):
        super().__init__()
        self.input_proj = nn.Linear(in_dim, emb_dim)
        layer = nn.TransformerEncoderLayer(
            d_model=emb_dim,
            nhead=nhead,
            dim_feedforward=emb_dim*4,
            batch_first=False
        )
        self.transformer = nn.TransformerEncoder(
            layer,
            num_layers=nlayers,
            norm=nn.LayerNorm(emb_dim)
        )

    def forward(self, x, mask):
        """
        x:    [B, T, D] raw sequence
        mask: [B, T]    True=real, False=pad
        → returns H [B, T, E]
        """
        # project to model dim
        x = self.input_proj(x)           # [B, T, E]
        x = x.transpose(0, 1)            # [T, B, E]
        pad_mask = ~mask                  # [B, T]
        out = self.transformer(x, src_key_padding_mask=pad_mask)
        return out.transpose(0, 1)       # [B, T, E]

In [28]:
# 2) Decoder: from hidden states to next-frame features
class NextMomentDecoder(nn.Module):
    def __init__(self, emb_dim, out_dim):
        super().__init__()
        self.output_proj = nn.Linear(emb_dim, out_dim)

    def forward(self, h):
        """
        h:    [B, T, E] hidden states
        → pred [B, T, D] next-frame predictions
        """
        return self.output_proj(h)

In [29]:
# 3) Assemble end-to-end
class NextMomentModel(nn.Module):
    def __init__(self, in_dim, emb_dim, nhead=4, nlayers=3):
        super().__init__()
        self.encoder = PlayEncoder(in_dim, emb_dim, nhead, nlayers)
        self.decoder = NextMomentDecoder(emb_dim, in_dim)

    def forward(self, x, mask):
        # Encode then decode
        h    = self.encoder(x, mask)     # [B, T, E]
        pred = self.decoder(h)           # [B, T, D]
        return pred

In [30]:
# 4) Training loop
def train(model, loader, epochs=5, lr=1e-4, device="cpu"):
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(1, epochs+1):
        model.train()
        total_loss = 0.0

        for batch in loader:
            x    = batch['sequence'].to(device)  # [B, T, D]
            mask = batch['mask'].to(device)      # [B, T]

            # prepare inputs (all but last) and targets (all but first)
            x_in  = x[:, :-1, :]    # [B, T-1, D]
            x_tgt = x[:,  1:, :]    # [B, T-1, D]
            m     = mask[:, :-1]    # [B, T-1]

            # forward
            pred = model(x_in, m)   # [B, T-1, D]

            # compute MSE only on real frames
            loss_mat = F.mse_loss(pred, x_tgt, reduction='none')  # [B, T-1, D]
            loss_seq = loss_mat.sum(-1) * m.float()                # [B, T-1]
            loss = loss_seq.sum() / m.sum()                        # scalar

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(loader)
        print(f"Epoch {epoch}/{epochs} — avg MSE: {avg_loss:.4f}")

In [99]:
# 5) Usage
# Assume `loader` is your DataLoader with collate_fn from before,
# and D = feature dim of your sequence (e.g. seq.shape[-1]).
model = NextMomentModel(in_dim=33, emb_dim=256, nhead=4, nlayers=3)
train(model, loader, epochs=10, lr=1e-4, device="cuda" if torch.cuda.is_available() else "cpu")



TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/konstantin/miniconda3/envs/torch311/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 351, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/konstantin/miniconda3/envs/torch311/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_130931/3831389929.py", line 2, in collate_fn
    padded_seqs = pad_sequence(batch['sequence'], batch_first=True, padding_value=0.0)   # [B, T_max, D]
                               ~~~~~^^^^^^^^^^^^
TypeError: list indices must be integers or slices, not str
