In [107]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.preprocessing import LabelEncoder
from nba_api.stats.endpoints import leaguegamefinder
from nba_api.stats.static import teams
import shap
import matplotlib.pyplot as plt
import seaborn as sns

# Setting a style for seaborn, optional
sns.set_style('whitegrid')

In [108]:
# Get all NBA teams info
nba_teams = teams.get_teams()

# Create a dictionary mapping team abbreviations to their team IDs
team_abbr_to_id = {team['abbreviation']: team['id'] for team in nba_teams}

# Initialize an empty DataFrame to store all game data
all_games = pd.DataFrame()

# Loop through all teams and fetch their games
for team in nba_teams:
    team_id = team['id']
    gamefinder = leaguegamefinder.LeagueGameFinder(team_id_nullable=team_id)
    games = gamefinder.get_data_frames()[0]
    all_games = pd.concat([all_games, games], ignore_index=True)

print(all_games.columns)

Index(['SEASON_ID', 'TEAM_ID', 'TEAM_ABBREVIATION', 'TEAM_NAME', 'GAME_ID',
       'GAME_DATE', 'MATCHUP', 'WL', 'MIN', 'PTS', 'FGM', 'FGA', 'FG_PCT',
       'FG3M', 'FG3A', 'FG3_PCT', 'FTM', 'FTA', 'FT_PCT', 'OREB', 'DREB',
       'REB', 'AST', 'STL', 'BLK', 'TOV', 'PF', 'PLUS_MINUS'],
      dtype='object')


In [116]:
all_games

Unnamed: 0,SEASON_ID,TEAM_ID,TEAM_ABBREVIATION,TEAM_NAME,GAME_ID,GAME_DATE,MATCHUP,WL,MIN,PTS,...,STL,BLK,TOV,PF,PLUS_MINUS,WIN,Points_Per_Game,OPPONENT_TEAM_ID,HOME_GAME,LAST_GAME_RESULT
0,22024,0,ATL,Atlanta Hawks,0022401229,2024-12-14,ATL @ MIL,L,241,102.0,...,3.0,6,13,20,-8.0,0,101.475348,12,0,0.0
1,22024,0,ATL,Atlanta Hawks,0022401202,2024-12-11,ATL @ NYK,W,239,108.0,...,5.0,7,10,13,8.0,1,101.475348,15,0,0.0
2,22024,0,ATL,Atlanta Hawks,0022400350,2024-12-08,ATL vs. DEN,L,239,111.0,...,12.0,2,12,22,-30.0,0,101.475348,6,1,1.0
3,22024,0,ATL,Atlanta Hawks,0022400334,2024-12-06,ATL vs. LAL,W,265,134.0,...,10.0,5,17,24,2.0,1,101.475348,10,1,0.0
4,22024,0,ATL,Atlanta Hawks,0022400323,2024-12-04,ATL @ MIL,W,240,119.0,...,16.0,3,11,25,15.0,1,101.475348,12,0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
106909,21988,29,CHH,Charlotte Hornets,0028800062,1988-11-12,CHH @ ATL,L,238,111.0,...,12.0,2,19,26,,0,99.927599,0,0,0.0
106910,21988,29,CHH,Charlotte Hornets,0028800052,1988-11-11,CHH @ WAS,L,240,87.0,...,9.0,1,23,26,,0,99.927599,27,0,0.0
106911,21988,29,CHH,Charlotte Hornets,0028800024,1988-11-08,CHH vs. LAC,W,240,117.0,...,9.0,1,17,31,,1,99.927599,9,1,0.0
106912,21988,29,CHH,Charlotte Hornets,0028800015,1988-11-05,CHH @ DET,L,240,85.0,...,8.0,6,11,21,,0,99.927599,28,0,1.0


In [109]:
# Convert GAME_DATE to datetime
all_games['GAME_DATE'] = pd.to_datetime(all_games['GAME_DATE'])

# Create a win column: 1 if W, else 0
all_games['WIN'] = all_games['WL'].apply(lambda x: 1 if x == 'W' else 0)

# Convert PTS to float
all_games['PTS'] = all_games['PTS'].astype(float)

# Calculate average points per game by team
all_games['Points_Per_Game'] = all_games.groupby('TEAM_ID')['PTS'].transform('mean')

def get_opponent_team_id(matchup, team_abbr_to_id, team_id):
    # Matchup looks like "LAL vs. BOS" or "LAL @ BOS"
    if '@' in matchup:
        opponent_abbr = matchup.split(' @ ')[-1]
    else:
        opponent_abbr = matchup.split(' vs. ')[-1]
    return team_abbr_to_id.get(opponent_abbr, team_id)

# Create OPPONENT_TEAM_ID column
all_games['OPPONENT_TEAM_ID'] = all_games.apply(
    lambda row: get_opponent_team_id(row['MATCHUP'], team_abbr_to_id, row['TEAM_ID']), axis=1
)

# HOME_GAME column: 1 if 'vs.' in matchup, else 0
all_games['HOME_GAME'] = all_games['MATCHUP'].apply(lambda x: 1 if 'vs.' in x else 0)

# LAST_GAME_RESULT: previous game's WIN result for each team
all_games['LAST_GAME_RESULT'] = all_games.groupby('TEAM_ID')['WIN'].shift(1).fillna(0)


In [110]:
le = LabelEncoder()
all_games['TEAM_ID'] = le.fit_transform(all_games['TEAM_ID'])
all_games['OPPONENT_TEAM_ID'] = le.fit_transform(all_games['OPPONENT_TEAM_ID'])


In [111]:
X = all_games[['TEAM_ID', 'OPPONENT_TEAM_ID', 'Points_Per_Game', 'HOME_GAME', 'LAST_GAME_RESULT']]
y = all_games['WIN']

# 80% training, 20% testing
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)


In [112]:
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)


In [113]:
y_pred = model.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))
print(classification_report(y_test, y_pred))


Accuracy: 0.5859327503156714
              precision    recall  f1-score   support

           0       0.58      0.59      0.59     10692
           1       0.59      0.58      0.58     10691

    accuracy                           0.59     21383
   macro avg       0.59      0.59      0.59     21383
weighted avg       0.59      0.59      0.59     21383



In [114]:
feature_importances = pd.DataFrame(model.feature_importances_,
                                   index=X_train.columns,
                                   columns=['importance']).sort_values('importance', ascending=False)
print("Feature Importances:\n", feature_importances)


Feature Importances:
                   importance
OPPONENT_TEAM_ID    0.481924
HOME_GAME           0.292751
Points_Per_Game     0.096894
TEAM_ID             0.086160
LAST_GAME_RESULT    0.042272


In [115]:
team_abbr = 'LAL'
opponent_abbr = 'BOS'
average_points_per_game = 110.5

new_data = pd.DataFrame({
    'TEAM_ID': [le.transform([team_abbr_to_id[team_abbr]])[0]],
    'OPPONENT_TEAM_ID': [le.transform([team_abbr_to_id[opponent_abbr]])[0]],
    'Points_Per_Game': [average_points_per_game],
    'HOME_GAME': [1],            # Assuming a home game for Lakers
    'LAST_GAME_RESULT': [1]      # Let's say they won their last game
})

predictions = model.predict(new_data)
prediction_probabilities = model.predict_proba(new_data)

print("Predictions: ", predictions)
print("Prediction Probabilities: ", prediction_probabilities)


Predictions:  [0]
Prediction Probabilities:  [[0.53105066 0.46894934]]
