In [1]:
# Load IPython's autoreload extension and use mode 2, so edits to imported
# modules are picked up on the next cell execution without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import typing

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader

In [19]:
# Remote CSV inputs for this notebook.
EXPLODED_TEST_CSV = "gs://pitch-sequencing/sequence_data/full_sequence_data/exploded/large_cur_test.csv"
ARSENAL_CSV = "gs://pitch-sequencing/arsenal_data/pitch_arsenal_data.csv"

# Read both tables into DataFrames.
exploded_test_df = pd.read_csv(EXPLODED_TEST_CSV)
arsenal_df = pd.read_csv(ARSENAL_CSV)

# Preview the first rows of the sequence table (head() defaults to 5 rows).
exploded_test_df.head()

Unnamed: 0,pitch_sequence,count_sequence,zone_sequence,p_throws,stand,pitcher_id,batter_id,at_bat_number
0,"CH,SI","0-0,1-0",116,R,R,112526,572039,44
1,"CH,SI,SI","0-0,1-0,1-1",11611,R,R,112526,572039,44
2,"CH,CH","0-0,1-0",144,R,L,543037,624512,37
3,"FF,FF","0-0,0-1",111,R,L,642121,544369,84
4,"FF,FF,FF","0-0,0-1,0-2",11111,R,L,642121,544369,84


In [33]:
from pitch_sequencing.ml.data.pitch_arsenal import PitchArsenalSequenceDataset, collate_interleaved_and_target
from pitch_sequencing.ml.tokenizers.pitch_arsenal import ArsenalSequenceTokenizer, PitchArsenalLookupTable
from pitch_sequencing.ml.models.last_pitch import LastPitchTransformerModel

# Seed torch's global RNG before the model is constructed so the randomly
# initialised weights (and the logits printed by the cells below) are the same
# on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Lookup table built from the arsenal CSV; the tokenizer is sized from its
# `max_arsenal_size`.
arsenal_lookup_table = PitchArsenalLookupTable(arsenal_df)

# Hardcoded for now. Presumably the longest pitch/count sequence the tokenizer
# supports -- TODO confirm against ArsenalSequenceTokenizer and derive from data.
MAX_PITCH_COUNT_SEQ_LEN = 63
tokenizer = ArsenalSequenceTokenizer(
    arsenal_lookup_table.max_arsenal_size,
    max_pitch_count_seq_len=MAX_PITCH_COUNT_SEQ_LEN,
)

# Dataset over the exploded test frame, sharing the tokenizer and lookup table.
exploded_test_dataset = PitchArsenalSequenceDataset(exploded_test_df, tokenizer, arsenal_lookup_table)

# Small transformer configuration used for this smoke test.
model = LastPitchTransformerModel(tokenizer.vocab_size(), d_model=64, nhead=4, num_layers=2)

# Alias kept so DataLoader construction elsewhere can refer to `collate_fn`.
collate_fn = collate_interleaved_and_target
loss = nn.CrossEntropyLoss()

In [34]:
# Tokenize one hand-written example. The three strings are presumably the pitch
# sequence, the count sequence and the pitcher's arsenal -- TODO confirm the
# argument order against ArsenalSequenceTokenizer.tokenize.
tokenized_seq, padding_mask = tokenizer.tokenize(
    "CB,FF",
    "0-0,0-1,1-1",
    'CB,FF,SL,SI',
)

# Show the token ids and the padding mask, one per line, followed by their lengths.
print(tokenized_seq, padding_mask, sep="\n")
print(len(tokenized_seq), len(padding_mask), sep="\n")

# Wrap each list in an outer list so the model receives a batch of one.
model(torch.tensor([tokenized_seq]), torch.tensor([padding_mask]))

[1, 2, 4, 9, 10, 12, 0, 0, 0, 0, 0, 3, 14, 4, 15, 9, 18, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[False, False, False, False, False, False, True, True, True, True, True, False, False, False, False, False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True]
75
75


tensor([[ 0.3231,  0.1977,  0.6200,  0.2015,  0.7211,  0.0818, -0.4977, -0.5869,
         -0.2521,  0.2730, -0.4603, -0.7118,  1.2122, -0.0908, -0.2656, -0.0957,
         -0.1569,  0.1959, -0.0886, -0.0035, -0.4848, -0.1200, -0.4860,  0.1505,
          0.0718, -1.0243]], grad_fn=<AddmmBackward0>)

In [38]:
from dataclasses import asdict

# Take the first (input, target) pair from the dataset and call unsqueeze(0) on
# each part, input first. Presumably this adds a leading batch axis on the
# project's input dataclass as it does on a tensor -- TODO confirm.
input, target = (part.unsqueeze(0) for part in exploded_test_dataset[0])

# Expand the input dataclass into keyword arguments for the model's forward pass.
model(**asdict(input))

tensor([[ 0.2185,  0.3749,  0.1458,  0.1055,  0.8463,  0.1516, -0.3850, -0.5745,
         -0.2320,  0.3702, -0.5755, -0.6878,  1.1523, -0.0229, -0.3789, -0.0253,
         -0.1036,  0.2726, -0.1879,  0.1412, -0.3743, -0.1512, -0.4431,  0.1153,
          0.2696, -0.9759]], grad_fn=<AddmmBackward0>)

In [37]:
from dataclasses import asdict

# Batch the test dataset four examples at a time using the project collate function.
test_dataloader = DataLoader(
    dataset=exploded_test_dataset,
    batch_size=4,
    collate_fn=collate_fn,
)

# Smoke test: run the forward pass on the first batch only, then stop.
for batch in test_dataloader:
    input, target = batch
    batch_kwargs = asdict(input)
    print(model(**batch_kwargs))
    break

tensor([[ 0.2167,  0.2715,  0.4210,  0.0115,  0.8873,  0.2128, -0.3073, -0.5924,
         -0.3030,  0.2981, -0.6125, -0.7087,  1.0777,  0.0853, -0.4824, -0.1237,
         -0.0903,  0.4137, -0.2437,  0.0896, -0.2602, -0.1749, -0.5908,  0.1689,
          0.1270, -0.9866],
        [ 0.1948,  0.3281,  0.4273,  0.1103,  0.9630,  0.0342, -0.4297, -0.4885,
         -0.1130,  0.2928, -0.6178, -0.6262,  1.1434,  0.2131, -0.3882,  0.0547,
         -0.1792,  0.3132, -0.1230, -0.1384, -0.3694, -0.1657, -0.3893,  0.2374,
          0.2843, -0.9909],
        [ 0.3303,  0.2975,  0.2686,  0.2670,  0.8783, -0.0797, -0.4729, -0.4840,
         -0.4328,  0.6406, -0.6992, -0.6830,  1.2117, -0.0776, -0.3082,  0.0043,
         -0.1419,  0.4666, -0.2790,  0.0514, -0.4386, -0.0709, -0.4661,  0.1286,
          0.2087, -1.0578],
        [ 0.2955,  0.5127,  0.5720,  0.3418,  0.7210,  0.3020, -0.2938, -0.5810,
         -0.2580,  0.2720, -0.7301, -0.7970,  1.1582, -0.0690, -0.4738, -0.1278,
         -0.2445,  0.2857