In [51]:
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
from matplotlib import animation
import matplotlib.patches as patches

In [52]:
import tensorflow as tf
import tensorflow.keras.layers as tfl
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

In [53]:
from IPython.display import HTML

In [None]:
from animate import AnimateFeature
from models import createLSTMModel

In [54]:
# Instantiate the LSTM architecture from models.createLSTMModel; its trained weights are
# restored from a checkpoint in the following cell.
model = createLSTMModel()

In [55]:
# Restore model weights from the checkpoint whose path ends in weights_epochs<N>.
# The epoch tag is a named constant so a different checkpoint can be selected by
# changing a single value. load_weights returns a CheckpointLoadStatus.
WEIGHTS_EPOCH = 4
model_string = f"./rnn_model_unnorm/weights/weights_epochs{WEIGHTS_EPOCH}"
model.load_weights(model_string)

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x12c9d75b0>

In [56]:
# Padded number of frames per play; x_val is later reshaped to
# (plays, MAX_PLAY_LENGTH, 23, features).
MAX_PLAY_LENGTH = 203
SEQ_DATA_DIR = "./seq_unnorm_data"

x_val = np.load(f"{SEQ_DATA_DIR}/x_val.npy")
y_val = np.load(f"{SEQ_DATA_DIR}/y_val.npy")

# Features and labels are paired row-for-row in later cells (true_sack column),
# so both arrays must describe the same number of plays.
assert x_val.shape[0] == y_val.shape[0], (x_val.shape, y_val.shape)


In [59]:
# LSTM input: for every frame, drop the 4 leading id columns (gameId, playId, frameId, nflId)
# of each of the 23 entity rows and flatten the remaining 7 features into one 161-wide vector.
x_val_input = (
    x_val
    .reshape(-1, MAX_PLAY_LENGTH, 23, 11)[..., 4:]
    .reshape(-1, MAX_PLAY_LENGTH, 23 * 7)
)

In [63]:
# Model outputs for the validation sequences; presumably (plays, frames, 3) holding
# no_sack_prob, sack_prob, pred_time per frame — TODO confirm against models.createLSTMModel.
preds = model.predict(x=x_val_input)



In [69]:
# Per-entity layout of the raw features: (plays, MAX_PLAY_LENGTH frames, 23 entity rows, features).
# .copy() is required: reshape can return a *view* of x_val, and four_dim is modified in place
# (adj_y shift) afterwards, which would otherwise silently modify x_val as well.
four_dim = x_val.reshape(x_val.shape[0], MAX_PLAY_LENGTH, 23, -1).copy()

In [70]:
# Shift adj_y (feature index 6 in col_names order) by half of 53.3 (= 26.65); presumably this
# moves y from a midline-centred origin to the sideline — TODO confirm. Zero-padded rows are
# shifted too (they show adj_y == 26.65 in the DataFrame head).
# NOTE: in-place and not idempotent — re-running this cell shifts adj_y again.
four_dim[..., 6] += 53.3 / 2

In [73]:
# Broadcast each frame's prediction vector to all 23 entity rows of that frame so it can be
# concatenated with four_dim along the last (feature) axis. Sizes are taken from preds itself
# rather than hard-coded play/frame counts, so any validation split size works; a preds array
# that is not (plays, frames, outputs) now fails loudly instead of being silently mis-reshaped.
test_preds = np.repeat(preds[:, :, np.newaxis, :], 23, axis=2)
assert test_preds.shape[:3] == four_dim.shape[:3], (test_preds.shape, four_dim.shape)

In [78]:
# Append the broadcast predictions as extra feature columns on every entity row.
val_with_preds = np.concatenate((four_dim, test_preds), axis=-1)

val_with_preds.shape

(1300, 203, 23, 14)

In [82]:
# 11 per-entity feature columns of four_dim followed by the 3 broadcast model outputs.
col_names = [
    'gameId', 'playId', 'frameId', 'nflId', 'team_indicator',
    'adj_x', 'adj_y', 's', 'a', 'adj_o', 'adj_dir',
    'no_sack_prob', 'sack_prob', 'pred_time',
]

# One row per (play, frame, entity).
val_with_preds_df = pd.DataFrame(val_with_preds.reshape(-1, val_with_preds.shape[-1]), columns=col_names)

# astype never converts in place (copy=False only avoids copying data), so the result must be
# assigned back; otherwise gameId silently stays float64 in the frame.
val_with_preds_df['gameId'] = val_with_preds_df['gameId'].astype('int64')
val_with_preds_df.gameId


0                   0
1                   0
2                   0
3                   0
4                   0
              ...    
6069695    2021102408
6069696    2021102408
6069697    2021102408
6069698    2021102408
6069699    2021102408
Name: gameId, Length: 6069700, dtype: int64

In [85]:
# Ground-truth label at index 1 of each y_val triple (the sack_prob position), repeated for all
# 23 entity rows of a frame so it lines up with the flattened DataFrame rows.
val_with_preds_df['true_sack'] = np.tile(y_val, reps=(1, 1, 23)).reshape(-1, 3)[..., 1]

In [86]:
# The leading rows are zero padding (gameId == 0).
val_with_preds_df.head(5)

Unnamed: 0,gameId,playId,frameId,nflId,team_indicator,adj_x,adj_y,s,a,adj_o,adj_dir,no_sack_prob,sack_prob,pred_time,true_sack
0,0.0,0.0,0.0,0.0,0.0,0.0,26.65,0.0,0.0,0.0,0.0,0.500505,0.499495,-0.005946,0.0
1,0.0,0.0,0.0,0.0,0.0,0.0,26.65,0.0,0.0,0.0,0.0,0.500505,0.499495,-0.005946,0.0
2,0.0,0.0,0.0,0.0,0.0,0.0,26.65,0.0,0.0,0.0,0.0,0.500505,0.499495,-0.005946,0.0
3,0.0,0.0,0.0,0.0,0.0,0.0,26.65,0.0,0.0,0.0,0.0,0.500505,0.499495,-0.005946,0.0
4,0.0,0.0,0.0,0.0,0.0,0.0,26.65,0.0,0.0,0.0,0.0,0.500505,0.499495,-0.005946,0.0


In [87]:
# Drop zero-padded rows (they carry gameId == 0) and make gameId an integer. Filtering and the
# dtype change happen in one chain producing a new frame, so nothing is assigned into a filtered
# slice (which is what raised SettingWithCopyWarning).
val_with_preds_no_padding_df = (
    val_with_preds_df
    .loc[val_with_preds_df['gameId'] != 0]
    .astype({'gameId': 'int64'})
)


In [88]:
# First real (non-padding) rows.
val_with_preds_no_padding_df.head(5)

Unnamed: 0,gameId,playId,frameId,nflId,team_indicator,adj_x,adj_y,s,a,adj_o,adj_dir,no_sack_prob,sack_prob,pred_time,true_sack
4140,2021101704,2327.0,6.0,43290.0,1.0,83.53,29.44,0.01,0.01,357.97,247.63,0.525297,0.474703,-0.937201,0.0
4141,2021101704,2327.0,6.0,43299.0,2.0,76.4,44.71,1.93,1.42,235.91,21.0,0.525297,0.474703,-0.937201,0.0
4142,2021101704,2327.0,6.0,43350.0,2.0,67.91,35.25,0.32,0.69,170.93,150.43,0.525297,0.474703,-0.937201,0.0
4143,2021101704,2327.0,6.0,43453.0,1.0,80.54,30.92,0.09,0.54,26.79,291.61,0.525297,0.474703,-0.937201,0.0
4144,2021101704,2327.0,6.0,43455.0,2.0,78.7,30.61,0.04,0.45,206.57,204.82,0.525297,0.474703,-0.937201,0.0


In [89]:
# Rows whose ground-truth true_sack label is 1. Drawn from the padding-free frame, where gameId
# is already int64, so no padding rows can appear and no dtype fix-up (an assignment into a
# query() result, which raised SettingWithCopyWarning) is needed.
true_sack_df = val_with_preds_no_padding_df.query("true_sack == 1")

true_sack_df.head()


Unnamed: 0,gameId,playId,frameId,nflId,team_indicator,adj_x,adj_y,s,a,adj_o,adj_dir,no_sack_prob,sack_prob,pred_time,true_sack
54855,2021101704,3012.0,6.0,38544.0,1.0,13.21,27.53,0.37,1.87,151.07,190.98,0.340334,0.659666,0.78158,1.0
54856,2021101704,3012.0,6.0,38553.0,2.0,15.27,32.76,0.31,1.42,345.71,104.15,0.340334,0.659666,0.78158,1.0
54857,2021101704,3012.0,6.0,40171.0,1.0,13.36,32.21,0.24,1.33,150.47,195.5,0.340334,0.659666,0.78158,1.0
54858,2021101704,3012.0,6.0,41619.0,2.0,14.6,29.53,0.19,0.72,335.83,157.79,0.340334,0.659666,0.78158,1.0
54859,2021101704,3012.0,6.0,42500.0,2.0,15.76,24.78,0.0,0.0,358.96,308.54,0.340334,0.659666,0.78158,1.0


In [90]:
# One row per (gameId, playId): the distinct sack plays available as examples.
true_sack_df.drop_duplicates(subset=['gameId', 'playId'], keep='first')

Unnamed: 0,gameId,playId,frameId,nflId,team_indicator,adj_x,adj_y,s,a,adj_o,adj_dir,no_sack_prob,sack_prob,pred_time,true_sack
54855,2021101704,3012.0,6.0,38544.0,1.0,13.21,27.53,0.37,1.87,151.07,190.98,0.340334,0.659666,0.781580,1.0
190394,2021101705,206.0,6.0,38622.0,1.0,72.54,19.49,0.00,0.00,30.20,158.84,0.922996,0.077004,1.524419,1.0
218684,2021101705,585.0,6.0,39971.0,1.0,46.43,10.10,1.59,0.28,156.49,338.23,0.528918,0.471082,-0.782943,1.0
325956,2021101705,2035.0,6.0,39971.0,1.0,62.91,9.78,0.17,0.84,156.73,110.77,0.169863,0.830137,-0.473978,1.0
334604,2021101705,2194.0,6.0,38622.0,1.0,76.70,34.33,0.00,0.00,353.46,290.30,0.642858,0.357142,-0.859047,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5970731,2021102407,3881.0,6.0,37266.0,2.0,71.50,29.39,0.40,0.96,313.34,188.42,0.752078,0.247922,0.252005,1.0
5984899,2021102408,131.0,6.0,37077.0,1.0,71.18,45.30,0.00,0.00,8.63,11.26,0.623120,0.376880,-0.663405,1.0
5998699,2021102408,273.0,6.0,33566.0,2.0,82.22,18.73,0.00,0.14,6.54,2.99,0.876439,0.123561,-0.368974,1.0
6002563,2021102408,353.0,6.0,37077.0,1.0,78.86,39.31,0.08,0.77,350.92,312.09,0.665748,0.334252,0.853871,1.0


In [98]:
# Example play with oscillating predicted sack probabilities.
# playId stays a float because the combined DataFrame stores playId as float64.
test_game_id, test_play_id = 2021101705, 206.0

In [99]:
# success example

# test_game_id = 2021102407
# test_play_id = 2614.0

In [100]:

# All entity rows, across every frame, of the selected example play.
play_df = val_with_preds_no_padding_df.query(
    "gameId == @test_game_id and playId == @test_play_id"
)
print(f"play_df shape = {play_df.shape}")

play_df.head()

play_df shape = (1035, 15)


Unnamed: 0,gameId,playId,frameId,nflId,team_indicator,adj_x,adj_y,s,a,adj_o,adj_dir,no_sack_prob,sack_prob,pred_time,true_sack
190394,2021101705,206.0,6.0,38622.0,1.0,72.54,19.49,0.0,0.0,30.2,158.84,0.922996,0.077004,1.524419,1.0
190395,2021101705,206.0,6.0,39947.0,1.0,72.19,26.23,0.09,0.71,27.68,105.23,0.922996,0.077004,1.524419,1.0
190396,2021101705,206.0,6.0,41300.0,2.0,67.02,26.76,0.0,0.0,285.87,293.55,0.922996,0.077004,1.524419,1.0
190397,2021101705,206.0,6.0,41308.0,2.0,70.25,36.13,0.02,0.02,204.25,134.83,0.922996,0.077004,1.524419,1.0
190398,2021101705,206.0,6.0,41483.0,2.0,66.17,39.24,0.23,0.22,233.61,16.9,0.922996,0.077004,1.524419,1.0


In [101]:
# Last rows of the play (final frame).
play_df.tail(n=5)

Unnamed: 0,gameId,playId,frameId,nflId,team_indicator,adj_x,adj_y,s,a,adj_o,adj_dir,no_sack_prob,sack_prob,pred_time,true_sack
191424,2021101705,206.0,50.0,52449.0,1.0,70.36,14.77,7.56,1.96,199.97,283.86,0.776602,0.223398,-0.505819,1.0
191425,2021101705,206.0,50.0,52498.0,2.0,79.18,21.88,5.79,4.15,285.03,292.96,0.776602,0.223398,-0.505819,1.0
191426,2021101705,206.0,50.0,53556.0,1.0,57.45,39.98,7.21,2.99,95.67,68.05,0.776602,0.223398,-0.505819,1.0
191427,2021101705,206.0,50.0,53624.0,2.0,71.49,26.91,1.53,0.24,180.14,277.28,0.776602,0.223398,-0.505819,1.0
191428,2021101705,206.0,50.0,0.0,3.0,78.97,21.27,5.08,2.82,0.0,0.0,0.776602,0.223398,-0.505819,1.0


In [102]:
# Build the play animation with animate.AnimateFeature and render its `ani` attribute
# (presumably a matplotlib Animation — TODO confirm) as an inline JS/HTML player.
animated_play = AnimateFeature(play_df)
HTML(data=animated_play.ani.to_jshtml())

In [48]:
# from importlib import reload

# def importOrReload(module_name, *names):
#     import sys

#     if module_name in sys.modules:
#         reload(sys.modules[module_name])
#     else:
#         __import__(module_name, fromlist=names)

#     for name in names:
#         globals()[name] = getattr(sys.modules[module_name], name)
        
# importOrReload("temp", "AnimatePlay")
