In [1]:
%load_ext autoreload
%autoreload 2

In [4]:
import typing

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
from collections import defaultdict

from tqdm import tqdm
tqdm.pandas()  # Enables tqdm's pandas integration (used below via DataFrame.progress_apply)
from collections import defaultdict

In [20]:
# Exploded test-set sequences, read straight from the project's GCS bucket.
exploded_test_df = pd.read_csv(
    "gs://pitch-sequencing/sequence_data/full_sequence_data/v2/kitchensink/exploded/large_cur_test.csv"
)
# Pitch arsenal data, read from the same bucket.
arsenal_df = pd.read_csv(
    "gs://pitch-sequencing/arsenal_data/pitch_arsenal_data.csv"
)

In [21]:
import pitch_sequencing.ml.tokenizers.vocab as vocab
import pitch_sequencing.ml.data.generators as gen

from pitch_sequencing.ml.data.sequences import PitchSequenceDataset, CSVSequenceDataGenPlan, collate_interleaved_and_target, CSVSequenceInput
from pitch_sequencing.ml.tokenizers.pitch_arsenal import PitchArsenalLookupTable
from pitch_sequencing.ml.tokenizers.pitch_sequence import PitchSequenceTokenizer, SequenceInfo, SequenceID
from pitch_sequencing.ml.models.last_pitch import LastPitchTransformerModel
from pitch_sequencing.io.join import join_paths
from pitch_sequencing.io.gcs import save_model_to_gcs

# Arsenal lookup built from the arsenal CSV loaded in an earlier cell.
arsenal_lookup_table = PitchArsenalLookupTable(arsenal_df)

# One SequenceInfo per non-interleaved input: sequence id, a size argument
# (presumably the maximum sequence length -- TODO confirm) and the vocabularies it draws from.
sequential_sequence_infos = [
    SequenceInfo(
        SequenceID.ARSENAL,
        arsenal_lookup_table.max_arsenal_size,
        vocab_ids=[vocab.VocabID.PITCHES],
    ),
    SequenceInfo(
        SequenceID.HANDEDNESS,
        2,
        vocab_ids=[vocab.VocabID.HANDEDNESS],
    ),
    SequenceInfo(
        SequenceID.ON_BASE,
        3,
        vocab_ids=[vocab.VocabID.BOOLEAN],
    ),
]
# Generators that derive each of the sequences above from a dataframe row.
sequential_sequence_gen_plans = [
    CSVSequenceDataGenPlan(
        SequenceID.ARSENAL,
        gen.ArsenalCSVGenerator(arsenal_lookup_table),
    ),
    CSVSequenceDataGenPlan(
        SequenceID.HANDEDNESS,
        gen.HandednessCSVGenerator(),
    ),
    CSVSequenceDataGenPlan(
        SequenceID.ON_BASE,
        gen.OnBaseCSVGenerator(),
    ),
]

# Interleaved sequence size is hardcoded to 63 for now.
interleaved_sequence_infos = SequenceInfo(
    SequenceID.INTERLEAVED,
    63,
    vocab_ids=[vocab.VocabID.PITCHES, vocab.VocabID.COUNTS],
)
# The interleaved inputs are read directly from dataframe columns: counts first, then pitches.
interleaved_sequence_gen_plans = [
    CSVSequenceDataGenPlan(
        SequenceID.INTERLEAVED,
        gen.DirectCSVLookupGenerator('count_sequence'),
    ),
    CSVSequenceDataGenPlan(
        SequenceID.PITCHES,
        gen.DirectCSVLookupGenerator('input_pitch_sequence'),
    ),
]

tokenizer = PitchSequenceTokenizer(
    sequential_sequence_infos,
    interleaved_sequence_infos,
    [vocab.PITCH_VOCAB, vocab.HANDEDNESS_VOCAB, vocab.BOOLEAN_VOCAB, vocab.COUNT_VOCAB],
)
test_dataset = PitchSequenceDataset(
    exploded_test_df,
    tokenizer,
    sequential_sequence_gen_plans,
    interleaved_sequence_gen_plans,
    target_df_key='target_pitch',
)
collate_fn = collate_interleaved_and_target
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
)

In [22]:
from pitch_sequencing.ml.models.last_pitch import LastPitchTransformerModel
import gcsfs 

# Build the model skeleton, then fill it with the trained weights from GCS.
# NOTE(review): d_model / nhead / num_layers presumably have to match the values
# used by the training run that produced the checkpoint below -- confirm.
trained_model = LastPitchTransformerModel(tokenizer.vocab_size(), d_model=64, nhead=4, num_layers=2)

fs = gcsfs.GCSFileSystem()

model_path = "gs://pitch-sequencing/training_runs/kitchensink_training_job_20241017232358_hands_on_base/final/model.pth"
with fs.open(model_path, "rb") as f:
    # The result is passed to load_state_dict, so only a state dict is needed.
    # weights_only=True restricts unpickling to tensors and plain containers
    # instead of executing arbitrary pickled code from the remote file.
    # map_location keeps everything on CPU regardless of where it was saved.
    trained_model.load_state_dict(
        torch.load(f, map_location=torch.device('cpu'), weights_only=True)
    )

In [23]:
class LastPitchPredictor:
    """Inference helper that pairs a trained model with the tokenizer used to encode its inputs."""

    def __init__(self, model: LastPitchTransformerModel, tokenizer: PitchSequenceTokenizer):
        """Store the model and tokenizer; nothing else is done at construction time."""
        self.model = model
        self.tokenizer = tokenizer

    def get_next_pitch_probs_ids(self, id_seq, attn_mask):
        """Run the model on already-tokenized input and return a probability vector.

        Args:
            id_seq: Token-id tensor handed to the model as its first argument.
                ``get_next_pitch_probs`` passes a long tensor with a leading
                dimension of 1.
            attn_mask: Mask tensor forwarded to the model as ``src_mask``.

        Returns:
            The softmax (over dimension 0) of the model output after its leading
            dimension has been squeezed. If the squeezed output still has more
            than one dimension, only its last entry along the first dimension is
            kept, so a leading dimension larger than 1 yields probabilities for
            the final row only.
        """
        # Switches the model to evaluation mode on every call.
        self.model.eval()
        with torch.no_grad():
            raw_scores = self.model(id_seq, src_mask=attn_mask).squeeze(0)
            if raw_scores.dim() > 1:
                # Still multi-dimensional after the squeeze: keep the final entry.
                raw_scores = raw_scores[-1]
            return torch.softmax(raw_scores, dim=0)

    def get_next_pitch_probs(self, pitch_sequence, count_sequence, arsenal_sequence, handedness_sequence, on_base_info_sequence):
        """Tokenize raw sequences and return the model's probability vector for them.

        The arsenal, handedness and on-base sequences are wrapped as
        ``CSVSequenceInput`` objects; the count and pitch sequences are passed
        to the tokenizer as the second argument, in that order.
        """
        sequential_inputs = [
            CSVSequenceInput(SequenceID.ARSENAL, arsenal_sequence),
            CSVSequenceInput(SequenceID.HANDEDNESS, handedness_sequence),
            CSVSequenceInput(SequenceID.ON_BASE, on_base_info_sequence),
        ]
        token_ids, mask_values = self.tokenizer.tokenize(
            sequential_inputs,
            [count_sequence, pitch_sequence],
        )
        # Add a leading dimension of size 1 to both tensors before calling the model.
        batched_ids = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
        batched_mask = torch.tensor(mask_values, dtype=torch.bool).unsqueeze(0)
        return self.get_next_pitch_probs_ids(batched_ids, batched_mask)

    def predict_next_pitch_ids(self, id_sequence, attn_mask):
        """Return the index (as a Python int) of the highest-probability entry for tokenized input."""
        return torch.argmax(self.get_next_pitch_probs_ids(id_sequence, attn_mask)).item()

    def predict_next_pitch(self, pitch_sequence, count_sequence, arsenal_sequence, handedness_sequence, on_base_info_sequence):
        """Return the tokenizer's token for the highest-probability entry given raw sequences."""
        next_pitch_probs = self.get_next_pitch_probs(
            pitch_sequence,
            count_sequence,
            arsenal_sequence,
            handedness_sequence,
            on_base_info_sequence,
        )
        best_id = torch.argmax(next_pitch_probs).item()
        return self.tokenizer.get_token_for_id(best_id)

In [24]:
# Wrap the loaded model and its tokenizer in the prediction helper used by the cells below.
predictor = LastPitchPredictor(trained_model, tokenizer)

In [25]:
from pitch_sequencing.ml.tokenizers.pitch_sequence import ORDERED_PITCHES

# Spot-check a single hand-written example: three pitches thrown so far and four counts.
probs = predictor.get_next_pitch_probs("KN,KN,KN", "0-0,0-1,1-1,1-2", arsenal_sequence='KN,FF,SI', handedness_sequence="L,R", on_base_info_sequence="F,F,F")

# Print the probability assigned to each pitch type.
# The loop variable is deliberately not called `id`: at notebook top level that
# would shadow the builtin for every later cell.
for pitch in ORDERED_PITCHES:
    pitch_token_id = tokenizer.get_id_for_token(pitch)
    print(f"{pitch}: {probs[pitch_token_id]:.4f}")

CB: 0.0051
KN: 0.9639
FC: 0.0001
FS: 0.0013
CH: 0.0033
FF: 0.0191
SL: 0.0000
PO: 0.0003
SI: 0.0061
ST: 0.0009


In [36]:
# Left-join the arsenal columns onto every test row (pitcher_id -> pitcher).
exploded_test_with_arsenals_df = pd.merge(exploded_test_df, arsenal_df, left_on='pitcher_id', right_on='pitcher', how='left')
print(len(exploded_test_with_arsenals_df))
print(len(exploded_test_df))
# Fail loudly instead of relying on eyeballing the two printed lengths:
# duplicate keys in arsenal_df would make the left join add rows.
assert len(exploded_test_with_arsenals_df) == len(exploded_test_df), \
    "merge changed the number of test rows; check arsenal_df for duplicate 'pitcher' keys"
# The former mid-cell `.head(10)` was removed: only a cell's last expression is
# displayed, so it had no visible effect.
exploded_test_with_arsenals_df.columns

339233
339233


Index(['pitch_sequence', 'count_sequence', 'zone_sequence', 'p_throws',
       'stand', 'pitcher_id', 'batter_id', 'at_bat_number', 'game_date',
       'at_bat_pitch_number', 'target_pitch', 'setup_count',
       'input_pitch_sequence', 'events', 'zone', 'outs_when_up', 'type',
       'bb_type', 'on_3b', 'on_2b', 'on_1b', 'pitch_counts', 'pitcher',
       'pitch_arsenal', 'pitch_arsenal_csv', 'arsenal_size'],
      dtype='object')

In [41]:
# Preview the first ten rows of the merged frame.
exploded_test_with_arsenals_df.head(10)

Unnamed: 0,pitch_sequence,count_sequence,zone_sequence,p_throws,stand,pitcher_id,batter_id,at_bat_number,game_date,at_bat_pitch_number,...,type,bb_type,on_3b,on_2b,on_1b,pitch_counts,pitcher,pitch_arsenal,pitch_arsenal_csv,arsenal_size
0,"CB,SL","0-0,0-1",913,L,R,666129,621028,35,2023-05-14,2,...,S,,,571976.0,,"{'CB': 517, 'CH': 459, 'FC': 448, 'FF': 864, '...",666129,"['CB', 'FF', 'CH', 'SL', 'SI', 'FC']","CB,FF,CH,SL,SI,FC",6
1,"CB,SL,SI","0-0,0-1,0-2",9132,L,R,666129,621028,35,2023-05-14,3,...,S,,,571976.0,,"{'CB': 517, 'CH': 459, 'FC': 448, 'FF': 864, '...",666129,"['CB', 'FF', 'CH', 'SL', 'SI', 'FC']","CB,FF,CH,SL,SI,FC",6
2,"CB,SL,SI,SL","0-0,0-1,0-2,0-2",913214,L,R,666129,621028,35,2023-05-14,4,...,B,,,571976.0,,"{'CB': 517, 'CH': 459, 'FC': 448, 'FF': 864, '...",666129,"['CB', 'FF', 'CH', 'SL', 'SI', 'FC']","CB,FF,CH,SL,SI,FC",6
3,"CB,SL,SI,SL,SI","0-0,0-1,0-2,0-2,1-2",9132144,L,R,666129,621028,35,2023-05-14,5,...,X,ground_ball,,571976.0,,"{'CB': 517, 'CH': 459, 'FC': 448, 'FF': 864, '...",666129,"['CB', 'FF', 'CH', 'SL', 'SI', 'FC']","CB,FF,CH,SL,SI,FC",6
4,"SI,CH","0-0,0-1",411,R,L,453562,520471,13,2018-08-12,2,...,B,,,,595978.0,"{'CB': 1322, 'CH': 1196, 'FC': 0, 'FF': 134, '...",453562,"['SI', 'CB', 'SL', 'CH', 'FF']","SI,CB,SL,CH,FF",5
5,"SI,CH,SI","0-0,0-1,1-1",4113,R,L,453562,520471,13,2018-08-12,3,...,S,,,,595978.0,"{'CB': 1322, 'CH': 1196, 'FC': 0, 'FF': 134, '...",453562,"['SI', 'CB', 'SL', 'CH', 'FF']","SI,CB,SL,CH,FF",5
6,"SI,CH,SI,CB","0-0,0-1,1-1,1-2",41138,R,L,453562,520471,13,2018-08-12,4,...,X,fly_ball,,,595978.0,"{'CB': 1322, 'CH': 1196, 'FC': 0, 'FF': 134, '...",453562,"['SI', 'CB', 'SL', 'CH', 'FF']","SI,CB,SL,CH,FF",5
7,"FF,CB","0-0,0-1",914,R,L,656546,606466,19,2019-06-20,2,...,S,,,,,"{'CB': 893, 'CH': 126, 'FC': 0, 'FF': 3636, 'F...",656546,"['FF', 'CB', 'SL', 'CH', 'SI', 'FS']","FF,CB,SL,CH,SI,FS",6
8,"FF,CB,FF","0-0,0-1,0-2",91414,R,L,656546,606466,19,2019-06-20,3,...,B,,,,,"{'CB': 893, 'CH': 126, 'FC': 0, 'FF': 3636, 'F...",656546,"['FF', 'CB', 'SL', 'CH', 'SI', 'FS']","FF,CB,SL,CH,SI,FS",6
9,"FF,CB,FF,CB","0-0,0-1,0-2,1-2",9141411,R,L,656546,606466,19,2019-06-20,4,...,S,,,,,"{'CB': 893, 'CH': 126, 'FC': 0, 'FF': 3636, 'F...",656546,"['FF', 'CB', 'SL', 'CH', 'SI', 'FS']","FF,CB,SL,CH,SI,FS",6


In [64]:
# Per-input generators that turn one dataframe row into the CSV strings the predictor expects.
arsenal_seq_gen = gen.ArsenalCSVGenerator(arsenal_lookup_table)
handedness_seq_gen = gen.HandednessCSVGenerator()
on_base_seq_gen = gen.OnBaseCSVGenerator()
count_seq_gen = gen.DirectCSVLookupGenerator('count_sequence')
pitch_seq_gen = gen.DirectCSVLookupGenerator('input_pitch_sequence')


def _predict_pitch_for_row(row):
    """Build every input sequence from `row` and return the predictor's pitch token."""
    return predictor.predict_next_pitch(
        pitch_sequence=pitch_seq_gen.generate_csv_sequence_from_df_row(row),
        count_sequence=count_seq_gen.generate_csv_sequence_from_df_row(row),
        arsenal_sequence=arsenal_seq_gen.generate_csv_sequence_from_df_row(row),
        handedness_sequence=handedness_seq_gen.generate_csv_sequence_from_df_row(row),
        on_base_info_sequence=on_base_seq_gen.generate_csv_sequence_from_df_row(row),
    )


# One model call per row, with a tqdm progress bar; the result is stored on the merged frame.
exploded_test_with_arsenals_df['predicted_pitch'] = exploded_test_with_arsenals_df.progress_apply(
    _predict_pitch_for_row,
    axis=1,
)


100%|██████████| 339233/339233 [10:32<00:00, 536.05it/s]


In [65]:
# Per-class precision/recall/F1 over the whole test set; zero_division=0 reports 0
# for classes that have no predicted samples.
overall_report = classification_report(
    exploded_test_with_arsenals_df['target_pitch'],
    exploded_test_with_arsenals_df['predicted_pitch'],
    zero_division=0,
)
print(overall_report)

              precision    recall  f1-score   support

          CB       0.40      0.19      0.26     33460
          CH       0.36      0.20      0.25     40972
          FC       0.48      0.23      0.31     22519
          FF       0.49      0.73      0.58    115034
          FS       0.44      0.34      0.39      6747
          KN       0.80      0.96      0.87       274
          PO       0.00      0.00      0.00        22
          SI       0.49      0.55      0.51     54468
          SL       0.46      0.36      0.40     58441
          ST       0.49      0.34      0.40      7296

    accuracy                           0.47    339233
   macro avg       0.44      0.39      0.40    339233
weighted avg       0.46      0.47      0.44    339233



In [66]:
def _csv_contains_pitch(pitch, csv_sequence):
    """Return True when `pitch` is one of the comma-separated tokens in `csv_sequence`.

    Splitting first makes this an exact token comparison; a plain `pitch in csv_sequence`
    is a substring test on the raw string and would also match partial tokens.
    """
    return pitch in csv_sequence.split(',')


# Was the predicted pitch part of the pitcher's arsenal?
exploded_test_with_arsenals_df['predicted_pitch_in_arsenal'] = exploded_test_with_arsenals_df.progress_apply(
    lambda x: _csv_contains_pitch(x['predicted_pitch'], x['pitch_arsenal_csv']), axis=1)
# Had the predicted pitch already appeared in the input sequence?
exploded_test_with_arsenals_df['predicted_pitch_in_sequence'] = exploded_test_with_arsenals_df.progress_apply(
    lambda x: _csv_contains_pitch(x['predicted_pitch'], x['input_pitch_sequence']), axis=1)
# Had the true next pitch already appeared in the input sequence?
exploded_test_with_arsenals_df['target_pitch_in_sequence'] = exploded_test_with_arsenals_df.progress_apply(
    lambda x: _csv_contains_pitch(x['target_pitch'], x['input_pitch_sequence']), axis=1)


100%|██████████| 339233/339233 [00:03<00:00, 108095.14it/s]
100%|██████████| 339233/339233 [00:02<00:00, 124466.75it/s]
100%|██████████| 339233/339233 [00:02<00:00, 137809.41it/s]


In [67]:
# Fraction of test rows for which each boolean flag column is True.
total_rows = len(exploded_test_with_arsenals_df)
for label, flag_column in (
    ("Target Pitch In Sequence:               ", 'target_pitch_in_sequence'),
    ("Predicted Pitch Seen In Input Sequence: ", 'predicted_pitch_in_sequence'),
    ("Predicted Pitch In Arsenal              ", 'predicted_pitch_in_arsenal'),
):
    flagged_rows = len(exploded_test_with_arsenals_df[exploded_test_with_arsenals_df[flag_column]])
    print(f"{label}{flagged_rows / total_rows:.4f}")

Target Pitch In Sequence:               0.6091
Predicted Pitch Seen In Input Sequence: 0.8100
Predicted Pitch In Arsenal              1.0000


In [68]:
# Restrict to rows whose true next pitch has not yet appeared in the input sequence,
# then report per-class metrics on that subset only.
unseen_target_mask = ~exploded_test_with_arsenals_df['target_pitch_in_sequence']
target_pitch_not_seen_df = exploded_test_with_arsenals_df[unseen_target_mask]
unseen_target_report = classification_report(
    target_pitch_not_seen_df['target_pitch'],
    target_pitch_not_seen_df['predicted_pitch'],
    zero_division=0,
)
print(unseen_target_report)

              precision    recall  f1-score   support

          CB       0.24      0.09      0.13     17446
          CH       0.24      0.11      0.15     23379
          FC       0.15      0.06      0.09     10258
          FF       0.22      0.45      0.29     32020
          FS       0.30      0.17      0.22      3552
          KN       0.30      0.76      0.43        21
          PO       0.00      0.00      0.00        22
          SI       0.13      0.17      0.15     18233
          SL       0.15      0.10      0.12     24646
          ST       0.18      0.11      0.14      3019

    accuracy                           0.19    132596
   macro avg       0.19      0.20      0.17    132596
weighted avg       0.20      0.19      0.17    132596



## Setup Count Metrics

In [69]:
# Per-class metrics broken down by the count before the target pitch.
# Group the merged frame: that is where 'predicted_pitch' is written. pd.merge
# returned a new frame, so exploded_test_df never receives that column and
# grouping it raises KeyError on a fresh kernel.
for setup, group in exploded_test_with_arsenals_df.groupby('setup_count'):
    print(f"Metrics for Setup Count: {setup}")
    print(classification_report(group['target_pitch'], group['predicted_pitch'], zero_division=0))
    print("\n")

Metrics for Setup Count: 0-1
              precision    recall  f1-score   support

          CB       0.35      0.15      0.21      6829
          CH       0.33      0.22      0.26      7602
          FC       0.45      0.19      0.26      4223
          FF       0.44      0.72      0.54     18421
          FS       0.44      0.30      0.36      1252
          KN       0.87      0.96      0.91        54
          PO       0.00      0.00      0.00         6
          SI       0.43      0.48      0.45      9316
          SL       0.46      0.33      0.38     10844
          ST       0.46      0.31      0.37      1366

    accuracy                           0.43     59913
   macro avg       0.42      0.36      0.38     59913
weighted avg       0.42      0.43      0.40     59913



Metrics for Setup Count: 0-2
              precision    recall  f1-score   support

          CB       0.45      0.29      0.35      4283
          CH       0.40      0.23      0.29      3306
          FC      

## At Bat Pitch Number Metrics

In [70]:
# Per-class metrics broken down by the pitch's position within the at-bat.
# Group the merged frame: that is where 'predicted_pitch' is written. pd.merge
# returned a new frame, so exploded_test_df never receives that column and
# grouping it raises KeyError on a fresh kernel.
for pitch_number, group in exploded_test_with_arsenals_df.groupby('at_bat_pitch_number'):
    print(f"Metrics for Pitch number: {pitch_number}")
    print(classification_report(group['target_pitch'], group['predicted_pitch'], zero_division=0))
    print("\n")

Metrics for Pitch number: 2
              precision    recall  f1-score   support

          CB       0.36      0.14      0.20     10338
          CH       0.35      0.20      0.25     13472
          FC       0.47      0.22      0.30      7624
          FF       0.46      0.74      0.57     34464
          FS       0.41      0.26      0.32      1876
          KN       0.84      0.97      0.90       101
          PO       0.00      0.00      0.00         7
          SI       0.48      0.55      0.51     18506
          SL       0.44      0.29      0.35     17744
          ST       0.45      0.29      0.36      2099

    accuracy                           0.45    106231
   macro avg       0.43      0.37      0.38    106231
weighted avg       0.44      0.45      0.42    106231



Metrics for Pitch number: 3
              precision    recall  f1-score   support

          CB       0.41      0.19      0.26      8856
          CH       0.37      0.21      0.27     11083
          FC       0