In [1]:
from nba_api.stats.static import teams
import pandas as pd
import sqlite3

In [87]:
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, confusion_matrix, f1_score
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.linear_model import LogisticRegression

In [90]:
from interpret import show
from interpret.data import ClassHistogram
from interpret.glassbox import ExplainableBoostingClassifier, ClassificationTree, DecisionListClassifier
from interpret.perf import ROC

In [2]:
team_data=teams.get_teams()

In [5]:
connection = sqlite3.connect('data/nba.db')

In [65]:
df=pd.read_sql("select * from ATL", connection)

In [23]:
for idx,t in enumerate(team_data):
    if idx==0:
        continue
    print(t['abbreviation'])
    df=df.append(pd.read_sql("select * from "+str(t['abbreviation']), connection))


BOS
CLE
NOP
CHI
DAL
DEN
GSW
HOU
LAC
LAL
MIA
MIL
MIN
BKN
NYK
ORL
IND
PHI
PHX
POR
SAC
SAS
OKC
TOR
UTA
MEM
WAS
DET
CHA


In [66]:
df

Unnamed: 0,index,SEASON_ID,TEAM_ID,TEAM_ABBREVIATION,TEAM_NAME,GAME_ID,GAME_DATE,MATCHUP,WL,MIN,...,OREB,DREB,REB,AST,STL,BLK,TOV,PF,PLUS_MINUS,opponent
0,0,22019,1610612737,ATL,Atlanta Hawks,0021900969,2020-03-11,ATL vs. NYK,L,265,...,15,38,53,26,6.0,3,17,25,0.0,NYK
1,1,22019,1610612737,ATL,Atlanta Hawks,0021900957,2020-03-09,ATL vs. CHA,W,290,...,12,41,53,33,1.0,5,15,26,5.0,CHA
2,2,22019,1610612737,ATL,Atlanta Hawks,0021900943,2020-03-07,ATL @ MEM,L,240,...,14,32,46,23,9.0,2,14,24,-17.0,MEM
3,3,22019,1610612737,ATL,Atlanta Hawks,0021900930,2020-03-06,ATL @ WAS,L,239,...,6,30,36,25,9.0,4,17,25,-6.0,WAS
4,4,22019,1610612737,ATL,Atlanta Hawks,0021900905,2020-03-02,ATL vs. MEM,L,239,...,16,27,43,20,6.0,8,17,21,-39.0,MEM
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3307,3307,21983,1610612737,ATL,Atlanta Hawks,0028300058,1983-11-06,ATL @ MIL,L,240,...,11,26,37,18,4.0,5,14,24,,MIL
3308,3308,21983,1610612737,ATL,Atlanta Hawks,0028300041,1983-11-04,ATL vs. CHI,W,240,...,19,27,46,31,14.0,13,18,27,,CHI
3309,3309,21983,1610612737,ATL,Atlanta Hawks,0028300027,1983-11-01,ATL vs. WAS,W,240,...,12,29,41,20,7.0,10,16,34,,WAS
3310,3310,21983,1610612737,ATL,Atlanta Hawks,0028300014,1983-10-29,ATL vs. DET,W,240,...,27,21,48,28,14.0,7,23,35,,DET


In [67]:
df['target']=df.WL.apply(lambda x: 1 if x=='W' else 0)

In [70]:
(df.target.value_counts()/len(df))[1]

0.49033816425120774

In [71]:
df=df.astype({'GAME_DATE':'datetime64'}).sort_values('GAME_DATE')

In [72]:
df=df.rename(columns={'MIN':'minutes'})

In [73]:
numerical_features=['minutes', 'PTS', 'FGM', 'FGA',
       'FG_PCT', 'FG3M', 'FG3A', 'FG3_PCT', 'FTM', 'FTA', 'FT_PCT', 'OREB',
       'DREB', 'REB', 'AST', 'STL', 'BLK', 'TOV', 'PF', 'PLUS_MINUS']

In [74]:
for f in numerical_features:
    df[f]=df[f].rolling(window=10).mean()

In [75]:
df

Unnamed: 0,index,SEASON_ID,TEAM_ID,TEAM_ABBREVIATION,TEAM_NAME,GAME_ID,GAME_DATE,MATCHUP,WL,minutes,...,DREB,REB,AST,STL,BLK,TOV,PF,PLUS_MINUS,opponent,target
3311,3311,21983,1610612737,ATL,Atlanta Hawks,0028300005,1983-10-28,ATL @ NJN,L,,...,,,,,,,,,NJN,0
3310,3310,21983,1610612737,ATL,Atlanta Hawks,0028300014,1983-10-29,ATL vs. DET,W,,...,,,,,,,,,DET,1
3309,3309,21983,1610612737,ATL,Atlanta Hawks,0028300027,1983-11-01,ATL vs. WAS,W,,...,,,,,,,,,WAS,1
3308,3308,21983,1610612737,ATL,Atlanta Hawks,0028300041,1983-11-04,ATL vs. CHI,W,,...,,,,,,,,,CHI,1
3307,3307,21983,1610612737,ATL,Atlanta Hawks,0028300058,1983-11-06,ATL @ MIL,L,,...,,,,,,,,,MIL,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4,4,22019,1610612737,ATL,Atlanta Hawks,0021900905,2020-03-02,ATL vs. MEM,L,244.8,...,32.8,43.0,27.0,7.3,5.2,13.4,21.8,-4.8,MEM,0
3,3,22019,1610612737,ATL,Atlanta Hawks,0021900930,2020-03-06,ATL @ WAS,L,239.5,...,31.9,41.6,26.6,7.3,4.8,13.7,21.9,-5.9,WAS,0
2,2,22019,1610612737,ATL,Atlanta Hawks,0021900943,2020-03-07,ATL @ MEM,L,239.6,...,32.4,43.1,25.6,7.7,4.5,13.9,22.5,-6.7,MEM,0
1,1,22019,1610612737,ATL,Atlanta Hawks,0021900957,2020-03-09,ATL vs. CHA,W,244.8,...,33.6,44.6,26.5,6.7,4.5,13.7,22.5,-4.0,CHA,1


In [76]:
df.dropna(inplace=True)

In [77]:
df

Unnamed: 0,index,SEASON_ID,TEAM_ID,TEAM_ABBREVIATION,TEAM_NAME,GAME_ID,GAME_DATE,MATCHUP,WL,minutes,...,DREB,REB,AST,STL,BLK,TOV,PF,PLUS_MINUS,opponent,target
2164,2164,21996,1610612737,ATL,Atlanta Hawks,0029600120,1996-11-16,ATL @ CHI,L,240.4,...,27.4,39.8,16.3,8.1,5.1,16.6,19.0,-2.2,CHI,0
2163,2163,21996,1610612737,ATL,Atlanta Hawks,0029600132,1996-11-19,ATL @ CLE,L,240.3,...,27.5,39.9,16.6,7.8,5.1,15.9,19.0,-1.9,CLE,0
2162,2162,21996,1610612737,ATL,Atlanta Hawks,0029600150,1996-11-21,ATL @ MIL,W,240.3,...,27.9,40.1,16.4,8.1,5.6,15.4,19.4,0.1,MIL,1
2161,2161,21996,1610612737,ATL,Atlanta Hawks,0029600162,1996-11-23,ATL @ TOR,W,240.3,...,27.7,39.7,16.5,7.6,5.2,16.0,19.4,-1.4,TOR,1
2160,2160,21996,1610612737,ATL,Atlanta Hawks,0029600181,1996-11-26,ATL vs. VAN,W,240.0,...,28.9,41.1,16.8,7.3,5.7,16.0,18.7,-1.5,VAN,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4,4,22019,1610612737,ATL,Atlanta Hawks,0021900905,2020-03-02,ATL vs. MEM,L,244.8,...,32.8,43.0,27.0,7.3,5.2,13.4,21.8,-4.8,MEM,0
3,3,22019,1610612737,ATL,Atlanta Hawks,0021900930,2020-03-06,ATL @ WAS,L,239.5,...,31.9,41.6,26.6,7.3,4.8,13.7,21.9,-5.9,WAS,0
2,2,22019,1610612737,ATL,Atlanta Hawks,0021900943,2020-03-07,ATL @ MEM,L,239.6,...,32.4,43.1,25.6,7.7,4.5,13.9,22.5,-6.7,MEM,0
1,1,22019,1610612737,ATL,Atlanta Hawks,0021900957,2020-03-09,ATL vs. CHA,W,244.8,...,33.6,44.6,26.5,6.7,4.5,13.7,22.5,-4.0,CHA,1


In [78]:
pd.get_dummies(df.opponent)

Unnamed: 0,BKN,BOS,CHA,CHH,CHI,CLE,DAL,DEN,DET,DLS,...,PHI,PHX,POR,SAC,SAS,SEA,TOR,UTA,VAN,WAS
2164,0,0,0,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2163,0,0,0,0,0,1,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2162,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2161,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,1,0,0,0
2160,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
2,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,0,0,1,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [79]:
t=df[numerical_features].shift(1)

In [80]:
t['target']=df.target

In [81]:
t.dropna(inplace=True)

In [82]:
t

Unnamed: 0,minutes,PTS,FGM,FGA,FG_PCT,FG3M,FG3A,FG3_PCT,FTM,FTA,...,OREB,DREB,REB,AST,STL,BLK,TOV,PF,PLUS_MINUS,target
2163,240.4,87.2,30.2,74.2,0.4061,7.8,21.1,0.3560,19.0,25.6,...,12.4,27.4,39.8,16.3,8.1,5.1,16.6,19.0,-2.2,0
2162,240.3,85.4,30.0,74.8,0.4000,8.1,21.8,0.3614,17.3,23.9,...,12.4,27.5,39.9,16.6,7.8,5.1,15.9,19.0,-1.9,1
2161,240.3,84.9,29.9,74.5,0.4003,7.4,20.5,0.3372,17.7,24.8,...,12.2,27.9,40.1,16.4,8.1,5.6,15.4,19.4,0.1,1
2160,240.3,84.6,29.9,73.3,0.4087,6.9,20.3,0.3125,17.9,25.1,...,12.0,27.7,39.7,16.5,7.6,5.2,16.0,19.4,-1.4,1
2159,240.0,83.0,29.7,73.7,0.4038,6.1,19.3,0.2994,17.5,25.1,...,12.2,28.9,41.1,16.8,7.3,5.7,16.0,18.7,-1.5,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4,244.9,122.0,42.9,92.1,0.4667,13.8,37.7,0.3658,22.4,26.4,...,9.8,33.5,43.3,27.1,7.4,5.2,13.1,21.8,-1.4,0
3,244.8,120.1,42.6,93.4,0.4583,13.4,37.3,0.3582,21.5,25.4,...,10.2,32.8,43.0,27.0,7.3,5.2,13.4,21.8,-4.8,0
2,239.5,117.3,42.3,91.2,0.4664,13.2,36.5,0.3607,19.5,23.5,...,9.7,31.9,41.6,26.6,7.3,4.8,13.7,21.9,-5.9,0
1,239.6,114.8,40.8,91.9,0.4467,12.9,37.7,0.3440,20.3,24.6,...,10.7,32.4,43.1,25.6,7.7,4.5,13.9,22.5,-6.7,1


In [85]:
for idx,team in enumerate(team_data):
    if idx==0:
        continue
    
    df=pd.read_sql("select * from "+str(team['abbreviation']), connection)
    df=df.astype({'GAME_DATE':'datetime64'}).sort_values('GAME_DATE')
    df=df.rename(columns={'MIN':'minutes'})
    df['target']=df.WL.apply(lambda x: 1 if x=='W' else 0)
    print(team['abbreviation'],(df.target.value_counts()/len(df))[1])
    df.dropna(inplace=True)
    temp=df[numerical_features].shift(1)
    temp['target']=df.target
    temp.dropna(inplace=True)
    t=t.append(temp)



BOS 0.5579399141630901
CLE 0.494026284348865
NOP 0.4589082183563287
CHI 0.5310850439882698
DAL 0.5120285120285121
DEN 0.4848851269649335
GSW 0.47944377267230953
HOU 0.5703399765533411
LAC 0.4206471494607088
LAL 0.5974208017942249
MIA 0.5229846768820786
MIL 0.4797460701330109
MIN 0.4035608308605341
BKN 0.42110091743119266
NYK 0.4639423076923077
ORL 0.4813473379210431
IND 0.5173951828724354
PHI 0.4612146722790138
PHX 0.5304295942720764
POR 0.5503256364712847
SAC 0.42383900928792567
SAS 0.6132533561839475
OKC 0.5553254437869822
TOR 0.48802786242925555
UTA 0.5728070175438597
MEM 0.42679127725856697
WAS 0.42014210688909487
DET 0.5322959483264826
CHA 0.43735676088617265


In [86]:
t

Unnamed: 0,minutes,PTS,FGM,FGA,FG_PCT,FG3M,FG3A,FG3_PCT,FTM,FTA,...,OREB,DREB,REB,AST,STL,BLK,TOV,PF,PLUS_MINUS,target
2163,240.4,87.2,30.2,74.2,0.4061,7.8,21.1,0.3560,19.0,25.6,...,12.4,27.4,39.8,16.3,8.1,5.1,16.6,19.0,-2.2,0
2162,240.3,85.4,30.0,74.8,0.4000,8.1,21.8,0.3614,17.3,23.9,...,12.4,27.5,39.9,16.6,7.8,5.1,15.9,19.0,-1.9,1
2161,240.3,84.9,29.9,74.5,0.4003,7.4,20.5,0.3372,17.7,24.8,...,12.2,27.9,40.1,16.4,8.1,5.6,15.4,19.4,0.1,1
2160,240.3,84.6,29.9,73.3,0.4087,6.9,20.3,0.3125,17.9,25.1,...,12.0,27.7,39.7,16.5,7.6,5.2,16.0,19.4,-1.4,1
2159,240.0,83.0,29.7,73.7,0.4038,6.1,19.3,0.2994,17.5,25.1,...,12.2,28.9,41.1,16.8,7.3,5.7,16.0,18.7,-1.5,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4,239.0,85.0,34.0,96.0,0.3540,8.0,35.0,0.2290,9.0,10.0,...,12.0,35.0,47.0,22.0,7.0,8.0,10.0,18.0,-8.0,0
3,240.0,103.0,39.0,82.0,0.4760,12.0,29.0,0.4140,13.0,18.0,...,14.0,32.0,46.0,25.0,4.0,2.0,17.0,16.0,-1.0,0
2,241.0,112.0,38.0,77.0,0.4940,13.0,33.0,0.3940,23.0,28.0,...,9.0,25.0,34.0,24.0,7.0,2.0,13.0,23.0,-2.0,1
1,240.0,108.0,37.0,70.0,0.5290,15.0,32.0,0.4690,19.0,25.0,...,6.0,24.0,30.0,29.0,7.0,1.0,17.0,15.0,9.0,0


In [88]:
X=t[numerical_features]

In [89]:
y=t.target

In [91]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=117)

In [92]:

ebm = ExplainableBoostingClassifier(feature_names=None,
    feature_types=None,
    max_bins=100,
    max_interaction_bins=32,
    binning='uniform',
    mains='all',
    interactions=0,
    outer_bags=16,
    inner_bags=0,
    learning_rate=0.1,
    validation_size=0.1,
    early_stopping_rounds=150,
    early_stopping_tolerance=0.0001,
    max_rounds=5000,
    max_leaves=3,
    min_samples_leaf=2,
    n_jobs=-2,
    random_state=117)
ebm.fit(X_train, y_train)   

ExplainableBoostingClassifier(binning='uniform', early_stopping_rounds=150,
                              feature_names=['minutes', 'PTS', 'FGM', 'FGA',
                                             'FG_PCT', 'FG3M', 'FG3A',
                                             'FG3_PCT', 'FTM', 'FTA', 'FT_PCT',
                                             'OREB', 'DREB', 'REB', 'AST',
                                             'STL', 'BLK', 'TOV', 'PF',
                                             'PLUS_MINUS'],
                              feature_types=['continuous', 'continuous',
                                             'continuous', 'continuous',
                                             'continuous', 'continuous',
                                             'continuous', 'continuous',
                                             'continuous', 'continuous',
                                             'continuous', 'continuous',
                                             'conti

In [93]:
ebm_perf = ROC(ebm.predict_proba).explain_perf(X_test, y_test, name='EBM')
show(ebm_perf)

In [96]:
ebm_global = ebm.explain_global()
show(ebm_global)

In [94]:
y_pred_ebm=ebm.predict(X_test)
y_pred_proba_ebm = ebm.predict_proba(X_test)[::,1]

In [95]:
print('Test Accuracy: '+str(accuracy_score(y_test, y_pred_ebm)))
print('Train Accuracy: '+str(accuracy_score(y_train, y_pred=ebm.predict(X_train))))

Test Accuracy: 0.5453421906380349
Train Accuracy: 0.5505194855355142


In [97]:
baseline_model=LogisticRegression().fit(X_train,y_train)

In [98]:
y_pred=baseline_model.predict(X_test)
y_pred_proba = baseline_model.predict_proba(X_test)[::,1]

In [99]:
print('Test Accuracy: '+str(accuracy_score(y_test, y_pred)))
print('Train Accuracy: '+str(accuracy_score(y_train, y_pred=baseline_model.predict(X_train))))

Test Accuracy: 0.5421751892476441
Train Accuracy: 0.5443976671430226
