<div class="alert alert-danger">
    <h4 style="font-weight: bold; font-size: 28px;">LSTM with Box Score Features</h4>
    <h5 style="font-weight: bold; font-size: 24px;">Test Set using Rolling Window</h5>
    <p style="font-size: 20px;">NBA API Seasons 2021-22 to 2023-24</p>
</div>

<a name="Models"></a>

# Setup

[Return to top](#Models)

In [1]:
import sys
from pathlib import Path
# get current working directory
cwd = %pwd
# add shared_code directory to Python sys.path
sys.path.append(str(Path(cwd).parent / "shared_code"))
# import all libraries in shared_code directory 'imports.py' file
from imports import *
%matplotlib inline

IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html


In [2]:
import statsmodels.api as sm
import statsmodels.tsa.api as smt
import statsmodels.graphics.api as smg
from statsmodels.tsa.stattools import acf, pacf
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.stattools import adfuller
from statsmodels.tsa.seasonal import seasonal_decompose
from statsmodels.graphics.gofplots import qqplot
from scipy.signal import periodogram

import pmdarima as pm
from pmdarima.model_selection import train_test_split as pm_train_test_split
from pmdarima.model_selection import RollingForecastCV
from pmdarima.metrics import smape 
from pmdarima.utils import tsdisplay
from pmdarima.preprocessing import BoxCoxEndogTransformer
from arch import arch_model

from sklearn import neighbors, decomposition, metrics, preprocessing
from sklearn import model_selection
from sklearn.model_selection import train_test_split, cross_val_score 
from sklearn import metrics

# Data

[Return to top](#Models)

Data splits:

- Define NBA Season 2021-22 as the TRAINING set: regular season is 2021-10-19 to 2022-04-10. 
- Define NBA Season 2022-23 as the VALIDATION set: regular season is 2022-10-18 to 2023-04-09.
- Define NBA Season 2023-24 as the TESTING set: regular season is 2023-10-24 to 2024-04-14.

In [39]:
# load, filter (by time) and scale data
# `utl.load_and_scale_data` is a project helper that is not defined in this
# notebook (presumably made available by the star import in the setup cell --
# TODO confirm). The call returns four objects: three frames that share the
# same ROLL_* feature columns but carry different targets (TOTAL_PTS,
# PLUS_MINUS and GAME_RESULT respectively, per the previews below), plus
# `test_set_obs`, whose meaning cannot be determined from this notebook.
pts_scaled_df, pm_scaled_df, res_scaled_df, test_set_obs = utl.load_and_scale_data(
    input_data='../../data/processed/nba_team_matchups_rolling_box_scores_2022_2024_r05.csv',
    seasons_to_keep=['2021-22', '2022-23', '2023-24'],  # train / validation / test seasons
    training_season='2021-22',  # presumably the season the scaler is fitted on -- verify in utl
    feature_prefixes=['ROLL_'],  # select the rolling box-score columns as features
    scaler_type='minmax', 
    scale_target=False  # leave the target column in its original units
)

Season 2021-22: 1186 games
Season 2022-23: 1181 games
Season 2023-24: 692 games
Total number of games across sampled seasons: 3059 games


In [40]:
# define number of games in seasons
# Hard-coded copies of the per-season counts printed by the load cell above;
# they must be updated by hand if the input file or season filter changes.
season_22_ngames = 1186  # season 2021-22 (training set)
season_23_ngames = 1181  # season 2022-23 (validation set)

In [4]:
# preview the total-points frame: ROLL_* features plus the TOTAL_PTS target, indexed by GAME_DATE
# NOTE(review): the saved output shows feature values such as 124.0 although
# min-max scaling is requested in the load cell; this cell's execution count (4)
# predates the load cell's (39), so the output may be stale -- re-run to confirm.
pts_scaled_df.head()

Unnamed: 0_level_0,ROLL_HOME_PTS,ROLL_HOME_FGM,ROLL_HOME_FGA,ROLL_HOME_FG_PCT,ROLL_HOME_FG3M,ROLL_HOME_FG3A,ROLL_HOME_FG3_PCT,ROLL_HOME_FTM,ROLL_HOME_FTA,ROLL_HOME_FT_PCT,ROLL_HOME_OREB,ROLL_HOME_DREB,ROLL_HOME_REB,ROLL_HOME_AST,ROLL_HOME_STL,ROLL_HOME_BLK,ROLL_HOME_TOV,ROLL_HOME_PF,ROLL_AWAY_PTS,ROLL_AWAY_FGM,ROLL_AWAY_FGA,ROLL_AWAY_FG_PCT,ROLL_AWAY_FG3M,ROLL_AWAY_FG3A,ROLL_AWAY_FG3_PCT,ROLL_AWAY_FTM,ROLL_AWAY_FTA,ROLL_AWAY_FT_PCT,ROLL_AWAY_OREB,ROLL_AWAY_DREB,ROLL_AWAY_REB,ROLL_AWAY_AST,ROLL_AWAY_STL,ROLL_AWAY_BLK,ROLL_AWAY_TOV,ROLL_AWAY_PF,TOTAL_PTS
GAME_DATE,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1
2021-10-23,124.0,42.0,87.0,0.483,16.0,38.0,0.421,24.0,31.0,0.774,13.0,33.0,46.0,26.0,18.0,12.0,18.0,22.0,112.0,42.0,87.0,0.483,15.0,29.0,0.517,13.0,16.0,0.813,9.0,33.0,42.0,26.0,7.0,5.0,16.0,20.0,185
2021-10-23,83.0,30.0,97.0,0.309,7.0,34.0,0.206,16.0,22.0,0.727,19.0,35.0,54.0,14.0,10.0,4.0,19.0,21.0,87.0,31.0,93.0,0.333,13.0,43.0,0.302,12.0,13.0,0.923,10.0,40.0,50.0,16.0,7.0,3.0,15.0,21.0,198
2021-10-23,121.0,45.0,93.0,0.484,12.0,35.0,0.343,19.0,22.0,0.864,9.0,40.0,49.0,25.0,5.0,5.0,12.0,22.0,115.0,42.0,86.0,0.488,10.0,32.0,0.313,21.0,28.0,0.75,7.0,40.0,47.0,31.0,8.0,2.0,11.0,22.0,239
2021-10-23,123.0,49.0,98.0,0.5,13.0,30.0,0.433,12.0,18.0,0.667,13.0,30.0,43.0,32.0,8.0,3.0,8.0,22.0,95.0,32.0,84.0,0.381,12.0,42.0,0.286,19.0,29.0,0.655,5.0,33.0,38.0,19.0,6.0,0.0,15.0,26.0,232
2021-10-24,124.0,48.0,95.0,0.505,17.0,38.0,0.447,11.0,14.0,0.786,10.0,44.0,54.0,29.0,12.0,10.0,17.0,18.0,134.0,48.0,117.0,0.41,21.0,57.0,0.368,17.0,23.0,0.739,15.0,41.0,56.0,34.0,13.0,9.0,18.0,24.0,204


In [5]:
# preview the plus-minus frame: same ROLL_* features, PLUS_MINUS as the target column
pm_scaled_df.head()

Unnamed: 0_level_0,ROLL_HOME_PTS,ROLL_HOME_FGM,ROLL_HOME_FGA,ROLL_HOME_FG_PCT,ROLL_HOME_FG3M,ROLL_HOME_FG3A,ROLL_HOME_FG3_PCT,ROLL_HOME_FTM,ROLL_HOME_FTA,ROLL_HOME_FT_PCT,ROLL_HOME_OREB,ROLL_HOME_DREB,ROLL_HOME_REB,ROLL_HOME_AST,ROLL_HOME_STL,ROLL_HOME_BLK,ROLL_HOME_TOV,ROLL_HOME_PF,ROLL_AWAY_PTS,ROLL_AWAY_FGM,ROLL_AWAY_FGA,ROLL_AWAY_FG_PCT,ROLL_AWAY_FG3M,ROLL_AWAY_FG3A,ROLL_AWAY_FG3_PCT,ROLL_AWAY_FTM,ROLL_AWAY_FTA,ROLL_AWAY_FT_PCT,ROLL_AWAY_OREB,ROLL_AWAY_DREB,ROLL_AWAY_REB,ROLL_AWAY_AST,ROLL_AWAY_STL,ROLL_AWAY_BLK,ROLL_AWAY_TOV,ROLL_AWAY_PF,PLUS_MINUS
GAME_DATE,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1
2021-10-23,124.0,42.0,87.0,0.483,16.0,38.0,0.421,24.0,31.0,0.774,13.0,33.0,46.0,26.0,18.0,12.0,18.0,22.0,112.0,42.0,87.0,0.483,15.0,29.0,0.517,13.0,16.0,0.813,9.0,33.0,42.0,26.0,7.0,5.0,16.0,20.0,7.0
2021-10-23,83.0,30.0,97.0,0.309,7.0,34.0,0.206,16.0,22.0,0.727,19.0,35.0,54.0,14.0,10.0,4.0,19.0,21.0,87.0,31.0,93.0,0.333,13.0,43.0,0.302,12.0,13.0,0.923,10.0,40.0,50.0,16.0,7.0,3.0,15.0,21.0,-8.0
2021-10-23,121.0,45.0,93.0,0.484,12.0,35.0,0.343,19.0,22.0,0.864,9.0,40.0,49.0,25.0,5.0,5.0,12.0,22.0,115.0,42.0,86.0,0.488,10.0,32.0,0.313,21.0,28.0,0.75,7.0,40.0,47.0,31.0,8.0,2.0,11.0,22.0,29.0
2021-10-23,123.0,49.0,98.0,0.5,13.0,30.0,0.433,12.0,18.0,0.667,13.0,30.0,43.0,32.0,8.0,3.0,8.0,22.0,95.0,32.0,84.0,0.381,12.0,42.0,0.286,19.0,29.0,0.655,5.0,33.0,38.0,19.0,6.0,0.0,15.0,26.0,-10.0
2021-10-24,124.0,48.0,95.0,0.505,17.0,38.0,0.447,11.0,14.0,0.786,10.0,44.0,54.0,29.0,12.0,10.0,17.0,18.0,134.0,48.0,117.0,0.41,21.0,57.0,0.368,17.0,23.0,0.739,15.0,41.0,56.0,34.0,13.0,9.0,18.0,24.0,-10.0


In [6]:
# preview the game-result frame: same ROLL_* features, GAME_RESULT (0/1 in the saved output) as the target column
res_scaled_df.head()

Unnamed: 0_level_0,ROLL_HOME_PTS,ROLL_HOME_FGM,ROLL_HOME_FGA,ROLL_HOME_FG_PCT,ROLL_HOME_FG3M,ROLL_HOME_FG3A,ROLL_HOME_FG3_PCT,ROLL_HOME_FTM,ROLL_HOME_FTA,ROLL_HOME_FT_PCT,ROLL_HOME_OREB,ROLL_HOME_DREB,ROLL_HOME_REB,ROLL_HOME_AST,ROLL_HOME_STL,ROLL_HOME_BLK,ROLL_HOME_TOV,ROLL_HOME_PF,ROLL_AWAY_PTS,ROLL_AWAY_FGM,ROLL_AWAY_FGA,ROLL_AWAY_FG_PCT,ROLL_AWAY_FG3M,ROLL_AWAY_FG3A,ROLL_AWAY_FG3_PCT,ROLL_AWAY_FTM,ROLL_AWAY_FTA,ROLL_AWAY_FT_PCT,ROLL_AWAY_OREB,ROLL_AWAY_DREB,ROLL_AWAY_REB,ROLL_AWAY_AST,ROLL_AWAY_STL,ROLL_AWAY_BLK,ROLL_AWAY_TOV,ROLL_AWAY_PF,GAME_RESULT
GAME_DATE,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1
2021-10-23,124.0,42.0,87.0,0.483,16.0,38.0,0.421,24.0,31.0,0.774,13.0,33.0,46.0,26.0,18.0,12.0,18.0,22.0,112.0,42.0,87.0,0.483,15.0,29.0,0.517,13.0,16.0,0.813,9.0,33.0,42.0,26.0,7.0,5.0,16.0,20.0,1
2021-10-23,83.0,30.0,97.0,0.309,7.0,34.0,0.206,16.0,22.0,0.727,19.0,35.0,54.0,14.0,10.0,4.0,19.0,21.0,87.0,31.0,93.0,0.333,13.0,43.0,0.302,12.0,13.0,0.923,10.0,40.0,50.0,16.0,7.0,3.0,15.0,21.0,0
2021-10-23,121.0,45.0,93.0,0.484,12.0,35.0,0.343,19.0,22.0,0.864,9.0,40.0,49.0,25.0,5.0,5.0,12.0,22.0,115.0,42.0,86.0,0.488,10.0,32.0,0.313,21.0,28.0,0.75,7.0,40.0,47.0,31.0,8.0,2.0,11.0,22.0,1
2021-10-23,123.0,49.0,98.0,0.5,13.0,30.0,0.433,12.0,18.0,0.667,13.0,30.0,43.0,32.0,8.0,3.0,8.0,22.0,95.0,32.0,84.0,0.381,12.0,42.0,0.286,19.0,29.0,0.655,5.0,33.0,38.0,19.0,6.0,0.0,15.0,26.0,0
2021-10-24,124.0,48.0,95.0,0.505,17.0,38.0,0.447,11.0,14.0,0.786,10.0,44.0,54.0,29.0,12.0,10.0,17.0,18.0,134.0,48.0,117.0,0.41,21.0,57.0,0.368,17.0,23.0,0.739,15.0,41.0,56.0,34.0,13.0,9.0,18.0,24.0,0


# Rolling Average

In [None]:
# n-game rolling average
ngame = 10

# Work on a positional (0..N-1) copy of the target so the x-axis is the game
# number. The shared frame is deliberately NOT modified: resetting its index
# in place would permanently drop the GAME_DATE index shown in the preview
# cells above.
total_pts = pts_scaled_df['TOTAL_PTS'].reset_index(drop=True)

# calculate n-game rolling mean (the first ngame-1 values are NaN because the
# window is not yet full)
rolling_mean = total_pts.rolling(window=ngame).mean()

fig, ax = plt.subplots(figsize=(20, 6))
ax.plot(total_pts.index, total_pts, label='Total Points')
ax.plot(rolling_mean.index, rolling_mean, color='red', label=f'{ngame}-Game Moving Average')
ax.set_title('Total Points by Game Number')
ax.set_xlabel('Game Number')
ax.set_ylabel('Total Points')
ax.legend()
fig.tight_layout()
plt.show()

# LSTM

- Need to get the raw data (no rolling averages) but have it normalized.
- Need to have the model informed about team identity
  - through dummies?
  - through separate team embeddings?
  - through separate team sequences? 
- Need to set up the right dimensions for training and prediction.
- Need to upgrade Python to 3.9 or 3.10 and install TensorFlow into the environment.
- Need to look at previous LSTM code to see if the below can be upgraded.
- Need to functionalize things as much as possible
- Also need to do ARIMAX for the other outcomes, although this is tricky, if not impossible, for the game winner (a binary outcome).

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, LSTM, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt

# create a toy dataset: 100 sequential draws from a standard normal,
# stored as a column vector of shape (100, 1); seeded for repeatability
np.random.seed(42)
data = np.random.standard_normal(size=(100, 1))

# build supervised (window, next value) pairs from a sequence
def create_sequences(data, n_steps):
    """Slice a sequence into overlapping windows and their next-step targets.

    Parameters
    ----------
    data : sequence or ndarray
        Ordered observations; anything that supports ``len`` and slicing.
    n_steps : int
        Number of consecutive observations in each input window.

    Returns
    -------
    tuple of (ndarray, ndarray)
        ``X`` holds the windows ``data[i:i + n_steps]`` and ``y`` the value
        that follows each window, ``data[i + n_steps]``, for
        ``i = 0 .. len(data) - n_steps - 1``. Both arrays are empty when
        ``n_steps >= len(data)``.
    """
    n_windows = len(data) - n_steps
    windows = [data[start:start + n_steps] for start in range(n_windows)]
    targets = [data[start + n_steps] for start in range(n_windows)]
    return np.array(windows), np.array(targets)

# prepare data
n_steps = 5  # look-back window: number of past points fed to the LSTM per sample
X, y = create_sequences(data, n_steps)
# create_sequences already returns ndarrays; only the float32 cast is needed
X = X.astype('float32')
y = y.astype('float32')

# Seed TensorFlow's global RNG so TensorFlow-side randomness (e.g. weight
# initialisation) is repeatable on a fresh kernel; the NumPy seed set for the
# toy data does not cover it.
tf.random.set_seed(42)

# define model: one univariate LSTM layer followed by a single linear output
inputs = Input(shape=(n_steps, 1))
lstm_layer = LSTM(50, activation='relu')(inputs)
outputs = Dense(1)(lstm_layer)
model = Model(inputs=inputs, outputs=outputs)

# compile and train
model.compile(optimizer=Adam(learning_rate=0.01), loss='mean_squared_error')
# keep the History object so the loss curve can be inspected later (this also
# stops its repr from being printed as the cell output)
history = model.fit(X, y, epochs=200, verbose=0)

In [None]:
# predict the next 10 points
n_future = 10
# Seed the forecast with the n_steps observations that immediately precede the
# final n_future points, so the forecasts line up one-to-one with
# data[-n_future:], which the comparison plot below treats as the true values.
# NOTE: those points were also used as training targets above, so this is an
# in-sample comparison, not a genuine hold-out evaluation.
future = data[-(n_future + n_steps):-n_future].astype('float32').reshape((1, n_steps, 1))
predictions = []

for _ in range(n_future):
    pred = model.predict(future, verbose=0)  # one sample, one output unit -> shape (1, 1)
    predictions.append(pred[0, 0])
    # Slide the window forward: drop the oldest step and append the new
    # prediction. It must be reshaped to (1, 1, 1) -- wrapping it as [[pred]]
    # yields a 4-D array that cannot be joined to the 3-D window.
    future = np.concatenate([future[:, 1:, :], pred.reshape(1, 1, 1)], axis=1)

In [None]:
# true future values, assuming we know them for comparison
true_future = data[-n_future:]
# both series are drawn over the last n_future positions of the toy sequence
time_steps = range(len(data) - n_future, len(data))

# plot predictions vs true values
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(time_steps, true_future.flatten(), marker='o', label='True Values')
ax.plot(time_steps, predictions, marker='x', label='Predictions')
ax.set_title('LSTM Predictions vs True Values')
ax.set_xlabel('Time Step')
ax.set_ylabel('Value')
ax.legend()
plt.show()