In [None]:
import pandas as pd
import numpy as np
import math
import json
import time
import pickle

In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

from sklearn.manifold import TSNE

In [None]:

from models import trans_single_frame_13_6_nine, trans_single_frame_13_6_hw
from utils import return_pos_embeddings, return_pos_df
from dataset import SeqContextDataset, SingleFrameContextDataset

In [None]:
# Seed torch's global RNG (model weight initialisation and, presumably, DataLoader shuffling)
# so runs are repeatable. NumPy is seeded separately in the id-split cell.
torch.manual_seed(1)

In [None]:
# Training hyper-parameters and data location.
MAX_FRAME_ID = 164
TRAIN_SPLIT = 0.8       # fraction of play ids used for training; the remainder is halved into val/test
LEARNING_RATE = 1e-4
NUM_EPOCHS = 5

DATA_DIR = "seq_clipped_sorted_data"

NORM = True             # normalise input features with training-set statistics

# Prefer the Metal (mps) backend when it is available, otherwise run on the CPU.
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device = {DEVICE}")

In [None]:
SAVE_DIR = "trans_single_frame_13_6_hw"  # sub-directory of saved_models/ for this run's artefacts

In [None]:
# Positional-embedding table later passed to the model as `pos_df`; columns '0' and '1'
# are rescaled in the next cell.
new_embeddings = return_pos_df()

In [None]:
# Centre embedding columns '0' and '1' on their mean and divide by their range.
# The statistics are named emb_* (not max/min) so the Python builtins max()/min()
# are not shadowed for every later cell and function in this notebook.
# NOTE: not idempotent - re-running this cell rescales already-rescaled values.
emb_cols = ['0', '1']
emb_max = new_embeddings.loc[:, emb_cols].max(axis=0).values
emb_min = new_embeddings.loc[:, emb_cols].min(axis=0).values
emb_mean = new_embeddings.loc[:, emb_cols].mean(axis=0).values
emb_range = emb_max - emb_min
# A constant column has zero range; divide by 1 instead of producing inf/NaN
# (mirrors the torch.where guard used for the feature normalisation).
emb_range = np.where(emb_range == 0, 1, emb_range)
new_embeddings.loc[:, emb_cols] = (new_embeddings.loc[:, emb_cols] - emb_mean) / emb_range

In [None]:
# Transformer over single frames, using the rescaled positional-embedding table.
model = trans_single_frame_13_6_hw(
    pos_df=new_embeddings,
    feature_embed_size=128,
    dropout=0.35,
    num_encoder_layers=4,
    num_att_heads=32,
)


In [None]:
#model.load_state_dict(torch.load("./saved_models/trans_single_frame_12_12_pos/weights/trans_all_players.pt"))


In [None]:

# Adam, with the learning rate halved every time the scheduler is stepped
# (run_epoch steps it once per epoch).
from torch.optim.lr_scheduler import ExponentialLR

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
lr_scheduler = ExponentialLR(optimizer, gamma=0.5)


In [None]:
# Load every game/play id, shuffle reproducibly, then cut into train / val / test.
with open(f'cleaned_data/{DATA_DIR}/game_play_id.json') as id_file:
    list_IDS = json.load(id_file)

np.random.seed(1)
np.random.shuffle(list_IDS)

# Split points: train takes the first TRAIN_SPLIT fraction; val and test share the
# remainder equally. `val_percent` is the index where the val slice ends (not a percentage).
num_ids = len(list_IDS)
train_end = int(num_ids * TRAIN_SPLIT)
val_percent = int(num_ids * (TRAIN_SPLIT + ((1 - TRAIN_SPLIT) / 2)))
train_IDS, val_IDS, test_IDS = np.split(list_IDS, [train_end, val_percent])

In [None]:
# Sequence-level loaders: large shuffled batches for training, small ones for validation.
params = dict(batch_size=256, shuffle=True, num_workers=0)
val_params = dict(batch_size=4, shuffle=True, num_workers=0)

# Generators
seq_training_set = SeqContextDataset(train_IDS, data_dir=DATA_DIR)
seq_training_generator = torch.utils.data.DataLoader(seq_training_set, **params)

seq_validation_set = SeqContextDataset(val_IDS, data_dir=DATA_DIR)
seq_validation_generator = torch.utils.data.DataLoader(seq_validation_set, **val_params)

In [None]:
# Peek at the first few (shuffled) training ids.
train_IDS[:5]

In [None]:
# Training-set normalisation statistics; the list stays empty when NORM is off,
# which the batch-processing functions treat as "do not normalise".
norm_list = []
if NORM:
    (training_mean, training_var,
     normalization_mask, min_max_diff) = seq_training_set.get_normalize()
    norm_list.extend((training_mean, training_var, normalization_mask, min_max_diff))

In [None]:
# Shapes of the normalisation statistics. These names are only bound when NORM is on,
# so report that instead of raising NameError when normalisation is disabled.
if norm_list:
    print(f"training_mean shape = {training_mean.shape}")
    print(f"normalization_mask shape = {normalization_mask.shape}")
    print(f"min_max_diff shape = {min_max_diff.shape}")
else:
    print("Normalisation disabled (NORM is False); no statistics to report.")


In [None]:
# Persist the normalisation statistics next to the weights (presumably for reuse at inference time).
norm_list_path = f"saved_models/{SAVE_DIR}/weights/norm_list.pickle"
with open(norm_list_path, mode='wb') as norm_file:
    pickle.dump(norm_list, norm_file)


In [None]:
# Pull every feature/label/context/id array out of the sequence datasets
# (presumably per frame, since they feed SingleFrameContextDataset next).
(all_train_features, all_train_labels,
 all_train_context_vectors, all_train_ids) = seq_training_set.get_all_features_and_labels()
(all_validation_features, all_validation_labels,
 all_val_context_vectors, all_val_ids) = seq_validation_set.get_all_features_and_labels()

In [None]:
# Single-frame datasets and loaders built from the flattened arrays (same loader params as training).
full_training_set = SingleFrameContextDataset(
    all_train_features, all_train_labels, all_train_context_vectors, all_train_ids
)
full_train_gen = torch.utils.data.DataLoader(full_training_set, **params)

full_validation_set = SingleFrameContextDataset(
    all_validation_features, all_validation_labels, all_val_context_vectors, all_val_ids
)
full_validation_gen = torch.utils.data.DataLoader(full_validation_set, **params)

In [None]:
# Fetch the last single-frame sample to sanity-check the 6-tuple the dataset returns.
last_index = len(full_training_set) - 1
(local_batch, local_labels, local_context,
 local_lengths, local_player_ids, local_ids) = full_training_set[last_index]

In [None]:
# Display the player ids of the sample fetched in the previous cell.
local_player_ids

In [None]:
# Number of single-frame samples per split and the resulting number of loader batches.
n_train_samples = len(full_training_set)
print(f"Len training set = {n_train_samples}")
n_val_samples = len(full_validation_set)
print(f"Len validation set = {n_val_samples}")

n_train_batches = len(full_train_gen)
print(f"Number of train batches = {n_train_batches}")
n_val_batches = len(full_validation_gen)
print(f"Number of val batches = {n_val_batches}")

In [None]:
# Drop the notebook's references to the sequence datasets and flattened arrays. Memory is
# only reclaimed if nothing else (e.g. a DataLoader or the single-frame datasets) still
# references these objects.
del seq_training_set, seq_validation_set
del all_train_features, all_train_labels, all_validation_features, all_validation_labels

In [None]:
def process_single_frame_no_seq_batch(model, device, batch, labels, context, player_ids, norm_list: "list | None" = None):
    """Compute the mean NLL loss of ``model`` on one batch of single frames.

    Args:
        model: network called as ``model(frames, pos_embeddings, context)``; must expose
            ``pos_df``. Its output is fed to ``F.nll_loss``, so it is treated as
            per-class log-probabilities.
        device: device every input is moved to.
        batch: frame features; reshaped here to ``(-1, 23, 11)``.
        labels: per-class target scores; the argmax over the last axis is the target class.
        context: context vectors, cast to ``float32``.
        player_ids: passed to ``return_pos_embeddings`` to build positional embeddings.
        norm_list: optional ``[training_mean, training_var, normalization_mask, min_max_diff]``.
            When non-empty, the masked features are centred on ``training_mean`` and divided
            by ``min_max_diff`` (zero ranges replaced by 1); ``training_var`` is unused.
            ``None`` or ``[]`` disables normalisation.

    Returns:
        Scalar loss tensor (``reduction='mean'``), still attached to the autograd graph.
    """
    if norm_list is None:
        norm_list = []

    local_batch = batch.to(device)
    index_labels = torch.argmax(labels, -1).to(device)  # (batch_size,)
    context = context.to(torch.float32).to(device)

    if len(norm_list) != 0:
        training_mean, _, normalization_mask, min_max_diff = norm_list
        training_mean = training_mean.to(device)
        normalization_mask = normalization_mask.to(device)
        min_max_diff = min_max_diff.to(device)
        feature_range = min_max_diff[:, normalization_mask]
        # Mean-centre, then divide by the min-max range; constant features divide by 1, not 0.
        local_batch[:, normalization_mask] = (
            local_batch[:, normalization_mask] - training_mean[normalization_mask].reshape(1, -1)
        ) / torch.where(feature_range == 0, 1, feature_range)

    # Hard-coded layout: 23 entities (presumably players + ball) x 11 features each.
    local_batch = local_batch.reshape(-1, 23, 11)

    pos_embeddings = return_pos_embeddings(model.pos_df, player_ids).to(device)  # [batch_size, 23, embed_dim]

    output = model(local_batch, pos_embeddings, context).to(device)  # (batch_size, 23)

    return F.nll_loss(output, index_labels, reduction='mean')

In [None]:
def process_single_frame_no_seq_val_batch(model, device, batch, labels, context, lengths, player_ids, norm_list: "list | None" = None):
    """Evaluate ``model`` frame by frame on a batch of padded sequences.

    For each sequence ``i``, the first ``lengths[i] - 1`` frames are fed to the model as one
    batch of single frames. Each frame's prediction is compared with the target class taken
    from ``labels`` at frame ``lengths[i] - 1``.

    Args:
        model: called as ``model(frames, pos_embeddings, context)``; must expose ``pos_df``.
            Its output goes through ``F.nll_loss``, so it is treated as per-class log-probabilities.
        device: device all tensors are moved to.
        batch: padded sequence features; each sequence slice is reshaped to ``(-1, 23, 11)``.
        labels: per-frame target scores; the argmax over the last axis gives the class index.
        context: per-frame context vectors, cast to ``float32``.
        lengths: number of valid frames in each sequence.
        player_ids: sliced per sequence (``player_ids[i:i+1]``) and passed to ``return_pos_embeddings``.
        norm_list: optional ``[training_mean, training_var, normalization_mask, min_max_diff]``.
            When non-empty, the masked features are centred on ``training_mean`` and divided
            by ``min_max_diff`` (zero ranges replaced by 1). ``None`` or ``[]`` disables it.

    Returns:
        ``(batch_seq_loss, val_metrics_dict)``. ``batch_seq_loss`` is the *sum* over sequences
        of each sequence's mean NLL. Every metric is likewise summed over the sequences, so
        callers divide by the number of sequences to get averages.

    NOTE(review): sequences with ``lengths[i] <= 1`` leave no frame to score and are not handled.
    """
    if norm_list is None:
        norm_list = []

    local_batch = batch.to(device)
    index_labels = torch.argmax(labels, -1).to(device)  # (batch_size, seq_len)
    local_lengths = lengths.to(device)
    local_context = context.to(torch.float32).to(device)

    # Normalisation statistics are the same for every sequence: prepare them once, not per sequence.
    normalize = len(norm_list) != 0
    normalization_mask = centre = scale = None
    if normalize:
        training_mean, _, normalization_mask, min_max_diff = norm_list
        training_mean = training_mean.to(device)
        normalization_mask = normalization_mask.to(device)
        min_max_diff = min_max_diff.to(device)
        centre = training_mean[normalization_mask].reshape(1, -1)
        feature_range = min_max_diff[:, normalization_mask]
        scale = torch.where(feature_range == 0, 1, feature_range)  # constant features divide by 1

    batch_seq_loss = 0  # sum over sequences of each sequence's mean per-frame loss
    val_metrics_dict = {"quarter_output_pred": 0,
                        "halfway_output_pred": 0,
                        "three_quarter_output_pred": 0,
                        "final_output_pred": 0,
                        "correct_tackler_identified_w_highest_prob_anytime": 0,
                        "correct_tackler_had_highest_average_prob": 0,
                        "correct_tackler_average_prob": 0}

    for i in range(local_batch.shape[0]):
        seq_len = int(local_lengths[i]) - 1          # number of frames scored for this sequence
        single_target = index_labels[i, seq_len]     # scalar: class at frame lengths[i]-1
        frames = local_batch[i, :seq_len, :].reshape(-1, 23, 11)  # (seq_len, 23, 11)
        seq_context = local_context[i, :seq_len, :]
        pos_embeddings = return_pos_embeddings(model.pos_df, player_ids[i:i + 1]).repeat(seq_len, 1, 1).to(device)

        if normalize:
            frames = frames.reshape(-1, 23 * 11)
            frames[:, normalization_mask] = (frames[:, normalization_mask] - centre) / scale
            frames = frames.reshape(-1, 23, 11)

        with torch.no_grad():
            output = model(frames, pos_embeddings, seq_context).to(device)  # (seq_len, num_classes)

        batch_seq_loss += F.nll_loss(output, index_labels[i, :seq_len], reduction='mean')

        # 1 if the target is the arg-max class on at least one frame.
        val_metrics_dict['correct_tackler_identified_w_highest_prob_anytime'] += (
            ((output.argmax(dim=1) == single_target).sum() > 0).float().item()
        )

        # NOTE(review): exp(mean log-prob) is the *geometric* mean of the per-frame
        # probabilities, not their arithmetic average, despite the metric's name.
        avg_tackle_probs_over_seq = torch.exp(output.mean(dim=0))  # (num_classes,)
        val_metrics_dict['correct_tackler_average_prob'] += avg_tackle_probs_over_seq[single_target].item()
        val_metrics_dict['correct_tackler_had_highest_average_prob'] += (
            (avg_tackle_probs_over_seq.argmax() == single_target).float().item()
        )

        # 1 or 0: whether the frame at each checkpoint of the sequence is classified correctly.
        n_frames = output.shape[0]
        checkpoints = {
            'quarter_output_pred': n_frames // 4,
            'halfway_output_pred': n_frames // 2,
            'three_quarter_output_pred': (n_frames // 4) * 3,
            'final_output_pred': -1,
        }
        for key, frame_idx in checkpoints.items():
            val_metrics_dict[key] += (output[frame_idx, :].argmax() == single_target).float().item()

    return batch_seq_loss, val_metrics_dict

In [None]:

def run_epoch(model, device, train_generator, val_generator, optimizer, num_val_batches=5, lr_scheduler=None, norm_list=None):
    """Train ``model`` for one pass over ``train_generator``, then score a few validation batches.

    Training uses single-frame batches (``process_single_frame_no_seq_batch``). Validation
    replays whole sequences frame by frame (``process_single_frame_no_seq_val_batch``), so the
    two generators yield differently shaped batches.

    Args:
        model: network to train. It is moved to ``device``, put in train mode for the training
            pass, and left in eval mode on return.
        device: device used for all computation.
        train_generator: iterable of single-frame batches
            ``(features, labels, context, lengths, player_ids, ids)``.
        val_generator: iterable of sequence batches with the same 6-tuple layout.
        optimizer: stepped once per training batch.
        num_val_batches: maximum number of validation batches to score. Fewer are used if
            ``val_generator`` is exhausted first.
        lr_scheduler: optional scheduler, stepped once after the training pass.
        norm_list: normalisation statistics forwarded to both batch functions (``None`` -> ``[]``).

    Returns:
        ``(avg_total_loss, avg_val_loss, train_loss_hist, val_loss_hist, val_metrics_dict)``:
        - ``avg_total_loss``: mean training-batch loss.
        - ``avg_val_loss`` and ``val_metrics_dict``: averaged over the validation sequences
          actually scored.
        - ``train_loss_hist`` and ``val_loss_hist``: per-batch loss histories.
    """
    if norm_list is None:
        norm_list = []

    model.to(device)
    # The validation pass below switches to eval mode. Switch back here so dropout etc.
    # are active in every epoch, not only the first one.
    model.train()

    # --- train ---
    train_loss_hist = []
    total_train_loss = 0.0
    for batch_index, (local_batch, local_labels, local_context, local_lengths, local_player_ids, local_ids) in enumerate(train_generator):
        optimizer.zero_grad()
        batch_train_loss = process_single_frame_no_seq_batch(model, device, local_batch, local_labels, local_context, local_player_ids, norm_list)

        loss_value = batch_train_loss.item()
        train_loss_hist.append(loss_value)
        total_train_loss += loss_value

        batch_train_loss.backward()
        optimizer.step()

        if batch_index % 250 == 0:
            print(f"batch {batch_index} train loss = {batch_train_loss}")

    # `or 1` (not max()) guards an empty generator without relying on a possibly shadowed builtin.
    avg_total_loss = total_train_loss / (len(train_loss_hist) or 1)
    if lr_scheduler is not None:
        lr_scheduler.step()

    # --- validate ---
    model.eval()
    val_metrics_dict = dict.fromkeys(("quarter_output_pred",
                                      "halfway_output_pred",
                                      "three_quarter_output_pred",
                                      "final_output_pred",
                                      "correct_tackler_identified_w_highest_prob_anytime",
                                      "correct_tackler_had_highest_average_prob",
                                      "correct_tackler_average_prob"), 0)
    val_loss_hist = []
    total_val_loss = 0.0
    num_val_sequences = 0  # sequences actually scored; the last batch may be partial

    with torch.no_grad():
        # zip() stops after num_val_batches without fetching an extra batch, and also stops
        # cleanly if the generator runs out first.
        for _, (local_batch, local_labels, local_context, local_lengths, local_player_ids, local_ids) in zip(range(num_val_batches), val_generator):
            batch_seq_loss, batch_metrics = process_single_frame_no_seq_val_batch(model, device, local_batch, local_labels, local_context, local_lengths, local_player_ids, norm_list)
            val_loss_hist.append(batch_seq_loss.item())
            total_val_loss += batch_seq_loss.item()
            num_val_sequences += local_batch.shape[0]
            for key in val_metrics_dict:
                val_metrics_dict[key] += batch_metrics[key]

    denom = num_val_sequences or 1
    avg_val_loss = total_val_loss / denom
    for key in val_metrics_dict:
        val_metrics_dict[key] /= denom

    return avg_total_loss, avg_val_loss, train_loss_hist, val_loss_hist, val_metrics_dict

In [None]:
# with open(f"saved_models/{SAVE_DIR}/train_arrs/trans_all_players_val_batch_loss_hist.pickle", 'rb') as fp:
#     val_batch_loss_hist = pickle.load(fp)
# with open(f"saved_models/{SAVE_DIR}/train_arrs/trans_all_players_train_batch_loss_hist.pickle", 'rb') as fp:
#     train_batch_loss_hist = pickle.load(fp)


In [None]:
# Per-batch loss histories, accumulated across every epoch of the training loop.
print("Starting training...")
train_batch_loss_hist, val_batch_loss_hist = [], []


In [None]:
# Train for NUM_EPOCHS epochs, logging validation metrics and checkpointing after each one.
total_start_time = time.time()

for epoch_index in range(NUM_EPOCHS):
    epoch_results = run_epoch(
        model, DEVICE, full_train_gen, seq_validation_generator, optimizer,
        lr_scheduler=lr_scheduler, num_val_batches=14, norm_list=norm_list,
    )
    train_loss, val_loss, tr_hist, val_hist, val_metrics_dict = epoch_results

    train_batch_loss_hist.extend(tr_hist)
    val_batch_loss_hist.extend(val_hist)

    print(f"Epoch={epoch_index}: train_loss={train_loss}, val_loss={val_loss}. LR={lr_scheduler.get_last_lr()}")
    print(f"val metrics dict = ")
    print(f"{list(val_metrics_dict.keys())}")
    print(f"{np.array(list(val_metrics_dict.values())).round(3)}")
    print(f"#######################")

    checkpoint = dict(
        epoch=epoch_index,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        train_loss=train_loss,
        val_loss=val_loss,
    )
    torch.save(checkpoint, f"./saved_models/{SAVE_DIR}/weights/checks/model_checkpoint_{epoch_index}.pt")

total_end_time = time.time()
print(f"Finished training {NUM_EPOCHS} epochs in {round((total_end_time - total_start_time)/60, 3)} min")

In [None]:
# Final weights, plus the per-batch loss histories gathered during training.
torch.save(model.state_dict(), f"./saved_models/{SAVE_DIR}/weights/checks/trans_all_players.pt")

hist_outputs = {
    f"saved_models/{SAVE_DIR}/train_arrs/trans_all_players_train_batch_loss_hist.pickle": train_batch_loss_hist,
    f"saved_models/{SAVE_DIR}/train_arrs/trans_all_players_val_batch_loss_hist.pickle": val_batch_loss_hist,
}
for hist_path, hist_values in hist_outputs.items():
    with open(hist_path, 'wb') as hist_file:
        pickle.dump(hist_values, hist_file)


In [None]:
# Held-out test split: a sequence-level loader (frame-by-frame evaluation via run_test_epoch)
# plus a single-frame view of the same data.
test_params = {'batch_size': 4,
               'shuffle': True,
               'num_workers': 0}
seq_test_set = SeqContextDataset(test_IDS, data_dir=DATA_DIR)
# Use test_params (previously val_params): the test loader must not silently follow
# changes to the validation settings.
seq_test_generator = torch.utils.data.DataLoader(seq_test_set, **test_params)
all_test_features, all_test_labels, all_test_context_vectors, all_test_ids = seq_test_set.get_all_features_and_labels()
full_test_set = SingleFrameContextDataset(all_test_features, all_test_labels, all_test_context_vectors, all_test_ids)
full_test_gen = torch.utils.data.DataLoader(full_test_set, **test_params)

In [None]:
# Record which play ids form the held-out test split, alongside this run's other artefacts.
# (Previously hard-coded to another model's directory, trans_single_frame_13_6_nine.)
np.save(f"saved_models/{SAVE_DIR}/test_IDS.npy", test_IDS)

In [None]:

def run_test_epoch(model, device, test_generator, norm_list=None):
    """Score ``model`` on every sequence batch of ``test_generator`` without gradients.

    Args:
        model: network to evaluate; moved to ``device`` and put in eval mode.
        device: device used for all computation.
        test_generator: iterable of sequence batches
            ``(features, labels, context, lengths, player_ids, ids)``, as consumed by
            ``process_single_frame_no_seq_val_batch``.
        norm_list: normalisation statistics forwarded to the batch function (``None`` -> ``[]``).

    Returns:
        ``(avg_val_loss, val_loss_hist, val_metrics_dict)``. The loss and metrics are averaged
        over the number of sequences actually scored, so a smaller final batch is counted
        correctly. ``val_loss_hist`` holds each batch's summed loss.
    """
    if norm_list is None:
        norm_list = []

    model.to(device)
    model.eval()

    # Fresh accumulator owned by this call. It must not read or mutate any notebook-level
    # `val_metrics_dict`, such as the one left behind by the training loop.
    val_metrics_dict = dict.fromkeys(("quarter_output_pred",
                                      "halfway_output_pred",
                                      "three_quarter_output_pred",
                                      "final_output_pred",
                                      "correct_tackler_identified_w_highest_prob_anytime",
                                      "correct_tackler_had_highest_average_prob",
                                      "correct_tackler_average_prob"), 0)
    val_loss_hist = []
    total_loss = 0.0
    num_sequences = 0

    with torch.no_grad():
        for local_batch, local_labels, local_context, local_lengths, local_player_ids, local_ids in test_generator:
            batch_seq_loss, batch_metrics = process_single_frame_no_seq_val_batch(
                model, device, local_batch, local_labels, local_context, local_lengths, local_player_ids, norm_list)
            val_loss_hist.append(batch_seq_loss.item())
            total_loss += batch_seq_loss.item()
            num_sequences += local_batch.shape[0]
            for key in val_metrics_dict:
                val_metrics_dict[key] += batch_metrics[key]

    denom = num_sequences or 1  # avoid ZeroDivisionError on an empty generator
    avg_val_loss = total_loss / denom
    for key in val_metrics_dict:
        val_metrics_dict[key] /= denom

    return avg_val_loss, val_loss_hist, val_metrics_dict

In [None]:

test_loss, test_hist, test_metrics_dict = run_test_epoch(model, DEVICE, seq_test_generator, norm_list=norm_list)

# Single pass over the held-out split. No epoch index is reported: `epoch_index` only
# exists if the training loop ran in this kernel and says nothing about the test pass.
print(f"test loss={test_loss}")
print(f"test metrics dict = ")
print(f"{list(test_metrics_dict.keys())}")
print(f"{np.array(list(test_metrics_dict.values())).round(3)}")
print(f"#######################")