In [1]:
%load_ext autoreload
%autoreload 2

In [31]:
import typing

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
from collections import defaultdict

from tqdm import tqdm
tqdm.pandas()  # This enables pandas integration

In [25]:
exploded_test_df = pd.read_csv('gs://pitch-sequencing/sequence_data/full_sequence_data/exploded/large_cur_test.csv')
arsenal_df = pd.read_csv("gs://pitch-sequencing/arsenal_data/pitch_arsenal_data.csv")

In [5]:
from pitch_sequencing.ml.data.pitch_arsenal import PitchArsenalSequenceDataset, collate_interleaved_and_target
from pitch_sequencing.ml.tokenizers.pitch_arsenal import ArsenalSequenceTokenizer, PitchArsenalLookupTable

arsenal_lookup_table = PitchArsenalLookupTable(arsenal_df)

# Hardcode 63 for now.
tokenizer = ArsenalSequenceTokenizer(arsenal_lookup_table.max_arsenal_size, max_pitch_count_seq_len=63)
exploded_test_dataset = PitchArsenalSequenceDataset(exploded_test_df, tokenizer, arsenal_lookup_table)
collate_fn = collate_interleaved_and_target

In [6]:
from pitch_sequencing.ml.models.last_pitch import LastPitchTransformerModel
import gcsfs 

trained_model = LastPitchTransformerModel(tokenizer.vocab_size(), d_model=64, nhead=4, num_layers=2)

fs = gcsfs.GCSFileSystem()

model_path = "gs://pitch-sequencing/training_runs/pitcharsenal_training_job_20241015230328/final/model.pth"
with fs.open(model_path, "rb") as f:
   trained_model.load_state_dict(torch.load(f, map_location=torch.device('cpu') ))

In [7]:
class LastPitchPredictorWithArsenal:
    def __init__(self, model: LastPitchTransformerModel, tokenizer: ArsenalSequenceTokenizer):
        self.model = model
        self.tokenizer = tokenizer
    
    def get_next_pitch_probs_ids(self, id_seq, attn_mask):
        self.model.eval()  # Ensure the model is in evaluation mode
        with torch.no_grad():
            logits = self.model(id_seq, src_mask=attn_mask)
            # The output shape should be [1, vocab_size]
            logits = logits.squeeze(0)  # Remove batch dimension if present
            if logits.dim() > 1:
                logits = logits[-1]  # Take the last prediction if multiple outputs

            probabilities = torch.softmax(logits, dim=0)

            return probabilities
    
    def get_next_pitch_probs(self, pitch_sequence, count_sequence, arsenal_sequence):
        encoded_seq, attn_mask = self.tokenizer.tokenize(pitch_sequence, count_sequence, arsenal_sequence)
        input_seq = torch.tensor(encoded_seq, dtype=torch.long).unsqueeze(0)
        attn_mask = torch.tensor(attn_mask, dtype=torch.bool).unsqueeze(0)
        
        
        return self.get_next_pitch_probs_ids(input_seq, attn_mask)

    def predict_next_pitch_ids(self, id_sequence, attn_mask):
        probabilities = self.get_next_pitch_probs_ids(id_sequence, attn_mask)
        predicted_idx = torch.argmax(probabilities).item()
        return predicted_idx

    def predict_next_pitch(self, pitch_sequence, count_sequence, arsenal_sequence):
        probabilities = self.get_next_pitch_probs(pitch_sequence, count_sequence, arsenal_sequence)
        predicted_idx = torch.argmax(probabilities).item()
        return self.tokenizer.get_pitch_for_id(predicted_idx)

In [9]:
predictor = LastPitchPredictorWithArsenal(trained_model, tokenizer)

In [18]:
from pitch_sequencing.ml.tokenizers.pitch_sequence import ORDERED_PITCHES

probs = predictor.get_next_pitch_probs("KN,KN,KN", "0-0,0-1,1-1,1-2", arsenal_sequence='KN,FF,SI')

for pitch in ORDERED_PITCHES:
    id = tokenizer.get_id_for_pitch(pitch)
    print(f"{pitch}: {probs[id]:.4f}")

CB: 0.0034
KN: 0.9425
FC: 0.0005
FS: 0.0058
CH: 0.0020
FF: 0.0362
SL: 0.0000
PO: 0.0001
SI: 0.0080
ST: 0.0016


In [26]:
exploded_test_df['target_pitch'] = exploded_test_df['pitch_sequence'].apply(lambda x: x.split(',')[-1])
exploded_test_df['setup_count'] = exploded_test_df['count_sequence'].apply(lambda x: x.split(',')[-1])
exploded_test_df['input_pitch_sequence'] = exploded_test_df['pitch_sequence'].apply(lambda x: ','.join(x.split(',')[:-1]))

In [27]:
exploded_test_df = pd.merge(exploded_test_df, arsenal_df, left_on='pitcher_id', right_on='pitcher', how='left')
exploded_test_df.head(10)
exploded_test_df.columns

Index(['pitch_sequence', 'count_sequence', 'zone_sequence', 'p_throws',
       'stand', 'pitcher_id', 'batter_id', 'at_bat_number', 'target_pitch',
       'setup_count', 'input_pitch_sequence', 'pitch_counts', 'pitcher',
       'pitch_arsenal', 'pitch_arsenal_csv', 'arsenal_size'],
      dtype='object')

In [28]:
exploded_test_df.head(10)

Unnamed: 0,pitch_sequence,count_sequence,zone_sequence,p_throws,stand,pitcher_id,batter_id,at_bat_number,target_pitch,setup_count,input_pitch_sequence,pitch_counts,pitcher,pitch_arsenal,pitch_arsenal_csv,arsenal_size
0,"CH,SI","0-0,1-0",116,R,R,112526,572039,44,SI,1-0,CH,"{'CB': 0, 'CH': 470, 'FC': 69, 'FF': 815, 'FS'...",112526,"['SI', 'CH', 'FF', 'SL', 'FC']","SI,CH,FF,SL,FC",5
1,"CH,SI,SI","0-0,1-0,1-1",11611,R,R,112526,572039,44,SI,1-1,"CH,SI","{'CB': 0, 'CH': 470, 'FC': 69, 'FF': 815, 'FS'...",112526,"['SI', 'CH', 'FF', 'SL', 'FC']","SI,CH,FF,SL,FC",5
2,"CH,CH","0-0,1-0",144,R,L,543037,624512,37,CH,1-0,CH,"{'CB': 3300, 'CH': 1885, 'FC': 461, 'FF': 1136...",543037,"['FF', 'SL', 'CB', 'CH', 'SI', 'PO', 'FC']","FF,SL,CB,CH,SI,PO,FC",7
3,"FF,FF","0-0,0-1",111,R,L,642121,544369,84,FF,0-1,FF,"{'CB': 33, 'CH': 757, 'FC': 2, 'FF': 1226, 'FS...",642121,"['FF', 'SL', 'CH', 'CB', 'FC']","FF,SL,CH,CB,FC",5
4,"FF,FF,FF","0-0,0-1,0-2",11111,R,L,642121,544369,84,FF,0-2,"FF,FF","{'CB': 33, 'CH': 757, 'FC': 2, 'FF': 1226, 'FS...",642121,"['FF', 'SL', 'CH', 'CB', 'FC']","FF,SL,CH,CB,FC",5
5,"FF,FF,FF,CH","0-0,0-1,0-2,1-2",1111113,R,L,642121,544369,84,CH,1-2,"FF,FF,FF","{'CB': 33, 'CH': 757, 'FC': 2, 'FF': 1226, 'FS...",642121,"['FF', 'SL', 'CH', 'CB', 'FC']","FF,SL,CH,CB,FC",5
6,"FF,FC","0-0,1-0",1212,R,R,608379,575929,28,FC,1-0,FF,"{'CB': 1287, 'CH': 3507, 'FC': 2385, 'FF': 534...",608379,"['FC', 'SI', 'CH', 'FF', 'CB', 'PO']","FC,SI,CH,FF,CB,PO",6
7,"FF,FC,FC","0-0,1-0,1-1",12129,R,R,608379,575929,28,FC,1-1,"FF,FC","{'CB': 1287, 'CH': 3507, 'FC': 2385, 'FF': 534...",608379,"['FC', 'SI', 'CH', 'FF', 'CB', 'PO']","FC,SI,CH,FF,CB,PO",6
8,"FF,FC,FC,CB","0-0,1-0,1-1,1-2",121296,R,R,608379,575929,28,CB,1-2,"FF,FC,FC","{'CB': 1287, 'CH': 3507, 'FC': 2385, 'FF': 534...",608379,"['FC', 'SI', 'CH', 'FF', 'CB', 'PO']","FC,SI,CH,FF,CB,PO",6
9,"FF,FC,FC,CB,FF","0-0,1-0,1-1,1-2,1-2",12129613,R,R,608379,575929,28,FF,1-2,"FF,FC,FC,CB","{'CB': 1287, 'CH': 3507, 'FC': 2385, 'FF': 534...",608379,"['FC', 'SI', 'CH', 'FF', 'CB', 'PO']","FC,SI,CH,FF,CB,PO",6


In [29]:
exploded_test_df['predicted_pitch'] = exploded_test_df.progress_apply(lambda row: predictor.predict_next_pitch(row['input_pitch_sequence'], row['count_sequence'], row['pitch_arsenal_csv']), axis=1)

  output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
100%|██████████| 352058/352058 [08:54<00:00, 658.78it/s]


In [32]:
print(classification_report(exploded_test_df['target_pitch'], exploded_test_df['predicted_pitch'], zero_division=1))

              precision    recall  f1-score   support

          CB       0.39      0.21      0.27     35034
          CH       0.38      0.14      0.21     42632
          FC       0.45      0.29      0.35     24014
          FF       0.48      0.72      0.58    119519
          FS       0.42      0.37      0.39      6892
          KN       0.74      1.00      0.85       293
          PO       1.00      0.00      0.00        23
          SI       0.49      0.51      0.50     56519
          SL       0.43      0.35      0.39     59898
          ST       0.43      0.52      0.47      7234

    accuracy                           0.46    352058
   macro avg       0.52      0.41      0.40    352058
weighted avg       0.45      0.46      0.44    352058



In [33]:
exploded_test_df['predicted_pitch_in_arsenal'] = exploded_test_df.progress_apply(lambda x: x['predicted_pitch'] in x['pitch_arsenal_csv'], axis=1)
exploded_test_df['predicted_pitch_in_sequence'] = exploded_test_df.progress_apply(lambda x: x['predicted_pitch'] in x['input_pitch_sequence'], axis=1)
exploded_test_df['target_pitch_in_sequence'] = exploded_test_df.progress_apply(lambda x: x['target_pitch'] in x['input_pitch_sequence'], axis=1)


100%|██████████| 352058/352058 [00:02<00:00, 135267.71it/s]
100%|██████████| 352058/352058 [00:02<00:00, 143629.84it/s]
100%|██████████| 352058/352058 [00:02<00:00, 141414.32it/s]


In [34]:
print(f"Target Pitch In Sequence:               {len(exploded_test_df[exploded_test_df['target_pitch_in_sequence']])/len(exploded_test_df):.4f}")
print(f"Predicted Pitch Seen In Input Sequence: {len(exploded_test_df[exploded_test_df['predicted_pitch_in_sequence']])/len(exploded_test_df):.4f}")
print(f"Predicted Pitch In Arsenal              {len(exploded_test_df[exploded_test_df['predicted_pitch_in_arsenal']])/len(exploded_test_df):.4f}")

Target Pitch In Sequence:               0.6137
Predicted Pitch Seen In Input Sequence: 0.8295
Predicted Pitch In Arsenal              1.0000


In [37]:
target_pitch_not_seen_df = exploded_test_df[~exploded_test_df['target_pitch_in_sequence']]
print(classification_report(target_pitch_not_seen_df['target_pitch'], target_pitch_not_seen_df['predicted_pitch'], zero_division=0))

              precision    recall  f1-score   support

          CB       0.24      0.11      0.15     18053
          CH       0.11      0.03      0.05     24069
          FC       0.11      0.06      0.08     10801
          FF       0.21      0.44      0.28     32543
          FS       0.22      0.15      0.18      3577
          KN       0.41      0.97      0.58        35
          PO       0.00      0.00      0.00        23
          SI       0.11      0.14      0.13     18868
          SL       0.13      0.09      0.10     25065
          ST       0.16      0.17      0.16      2953

    accuracy                           0.17    135987
   macro avg       0.17      0.22      0.17    135987
weighted avg       0.16      0.17      0.15    135987

