# Imports

In [1]:
from sportsreference.mlb.schedule import Schedule
from sportsreference.mlb.teams import Teams
from tqdm.notebook import tqdm

from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.model_selection import RandomizedSearchCV

import pandas as pd
import numpy as np
import os

pd.set_option('display.max_columns', None)

# Load Data

We first need to load in the data. We can get the sports reference data by running `get_game_boxes.py` file in the SportsReference folder. Thne let's create one giant csv that has all of the games from all seasons.

In [2]:
SR_BOXES = './datasets/sports_reference_boxes/'

def getAllBoxes(boxes_dir=SR_BOXES):
    """Function that returns all of the game box scores"""
    
    dfs = []
    for year in os.listdir(boxes_dir):
        year_dir = os.path.join(boxes_dir, year)
        for file in os.listdir(year_dir):
            team_df = pd.read_csv(os.path.join(year_dir, file), index_col=0)
            team_df.index.name = 'game_id'
            dfs.append(team_df)

    return pd.concat(dfs).sort_values(by=['date', 'time']).drop_duplicates()

# Compute Moving Average

In [3]:
# Get all the boxscores in one dataframe
allBoxes = getAllBoxes()

# Get all the team abbreviations, this is how we will identify teams
TEAMS = set(allBoxes['winning_abbr']).union(set(allBoxes['losing_abbr']))

# Create an away column and a home column
allBoxes['away'] = np.where(allBoxes['winner'] == 'Away',
                            allBoxes['winning_abbr'],
                            allBoxes['losing_abbr'])

allBoxes['home'] = np.where(allBoxes['winner'] == 'Home',
                            allBoxes['winning_abbr'],
                            allBoxes['losing_abbr'])

# labels we are going to keep, we get rid of any overages because taking the average
# of an average is generally incorrect
keeping_labels = ['date',
                  'away',
                  'home',
                  'winner',
                  'away_at_bats',
                  'away_runs',
                  'away_hits',
                  'away_rbi',
                  'away_earned_runs',
                  'away_bases_on_balls',
                  'away_strikeouts',
                  'away_plate_appearances',
                  'away_pitches',
                  'away_strikes',
                  'away_base_out_runs_added',
                  'away_putouts',
                  'away_assists',
                  'away_innings_pitched',
                  'away_home_runs',
                  'away_strikes_by_contact',
                  'away_strikes_swinging',
                  'away_strikes_looking',
                  'away_grounded_balls',
                  'away_fly_balls',
                  'away_line_drives',
                  'away_game_score',
                  'away_base_out_runs_saved',
                  'home_at_bats',
                  'home_runs',
                  'home_hits',
                  'home_rbi',
                  'home_earned_runs',
                  'home_bases_on_balls',
                  'home_strikeouts',
                  'home_plate_appearances',
                  'home_pitches',
                  'home_strikes',
                  'home_base_out_runs_added',
                  'home_putouts',
                  'home_assists',
                  'home_innings_pitched',
                  'home_home_runs',
                  'home_strikes_by_contact',
                  'home_strikes_swinging',
                  'home_strikes_looking',
                  'home_grounded_balls',
                  'home_fly_balls',
                  'home_line_drives',
                  'home_game_score',
                  'home_base_out_runs_saved']

# Drop any columns that are null
allBoxes.dropna(axis=1, inplace=True)

# Keep only certain columns and reorder
allBoxes = allBoxes.filter(items=keeping_labels)
allBoxes = allBoxes[keeping_labels]
allBoxes.rename(columns={'away_home_runs': 'away_homeruns', 'home_home_runs': 'home_homeruns'}, inplace=True)

print(f'There are {allBoxes.shape[0]} examples')
allBoxes

There are 15475 examples


Unnamed: 0_level_0,date,away,home,winner,away_at_bats,away_runs,away_hits,away_rbi,away_earned_runs,away_bases_on_balls,away_strikeouts,away_plate_appearances,away_pitches,away_strikes,away_base_out_runs_added,away_putouts,away_assists,away_innings_pitched,away_homeruns,away_strikes_by_contact,away_strikes_swinging,away_strikes_looking,away_grounded_balls,away_fly_balls,away_line_drives,away_game_score,away_base_out_runs_saved,home_at_bats,home_runs,home_hits,home_rbi,home_earned_runs,home_bases_on_balls,home_strikeouts,home_plate_appearances,home_pitches,home_strikes,home_base_out_runs_added,home_putouts,home_assists,home_innings_pitched,home_homeruns,home_strikes_by_contact,home_strikes_swinging,home_strikes_looking,home_grounded_balls,home_fly_balls,home_line_drives,home_game_score,home_base_out_runs_saved
game_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1
COL/COL201504100,"Friday, April 10, 2015",CHC,COL,Home,30,1,5,1,1.0,4,7,37,150,76,-3.9,24,8,8.0,1,37,18,21,11,13,7,41,-0.6,34,5,11,5,5.62,3,7,37,140,92,0.6,27,12,9.0,0,56,13,23,9,18,10,51,3.9
TEX/TEX201504100,"Friday, April 10, 2015",HOU,TEX,Away,34,5,10,5,5.0,2,7,37,119,79,0.4,27,15,9.0,0,45,15,19,14,14,7,60,3.6,32,1,6,1,1.00,2,8,34,125,82,-3.6,27,10,9.0,2,41,16,25,15,9,7,45,-0.4
CHA/CHA201504100,"Friday, April 10, 2015",MIN,CHW,Away,36,6,10,5,5.0,9,9,46,112,112,1.8,27,10,9.0,0,60,26,26,13,14,7,80,4.2,28,0,3,0,0.00,2,7,30,190,74,-4.2,27,9,9.0,1,43,12,19,10,11,4,48,-1.8
BAL/BAL201504100,"Friday, April 10, 2015",TOR,BAL,Away,39,12,16,12,12.0,3,6,47,130,103,7.5,27,13,9.0,2,63,14,26,15,22,10,47,-0.5,36,5,13,4,5.00,2,2,38,162,79,0.5,27,10,9.0,0,51,9,19,13,21,12,13,-7.5
CLE/CLE201504100,"Friday, April 10, 2015",DET,CLE,Away,43,8,18,7,6.0,3,9,47,139,117,3.3,27,7,9.0,0,67,13,37,13,22,11,43,0.7,34,4,10,4,4.00,3,7,37,178,86,-0.7,27,9,9.0,1,50,5,31,8,19,9,20,-3.3
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
PIT/PIT202009090,"Wednesday, September 9, 2020",CHW,PIT,Away,37,8,11,8,8.0,3,10,41,137,106,3.4,27,10,9.0,0,61,23,22,14,14,11,68,3.6,31,1,4,1,0.00,5,5,37,163,85,-3.6,27,9,9.0,2,45,12,28,13,13,6,29,-3.4
TEX/TEX202009090,"Wednesday, September 9, 2020",LAA,TEX,Home,35,3,7,2,2.0,4,8,39,155,109,-2.5,24,8,8.0,0,64,13,32,13,14,9,35,-2.1,32,7,9,6,7.88,3,10,37,163,97,2.1,27,9,9.0,1,52,17,28,13,10,3,47,2.5
ATL/ATL202009090,"Wednesday, September 9, 2020",MIA,ATL,Home,38,9,13,9,9.0,4,7,43,240,121,4.0,24,7,8.0,7,67,26,28,9,23,9,17,-24.6,47,29,23,28,30.38,9,9,58,197,141,24.6,27,8,9.0,2,83,26,32,14,25,16,14,-4.0
NYN/NYN202009090,"Wednesday, September 9, 2020",BAL,NYM,Home,40,6,14,6,6.0,3,7,46,137,96,1.6,24,11,8.0,4,55,21,20,13,20,13,37,-3.1,33,7,10,7,7.88,2,8,36,164,93,3.1,27,9,9.0,1,55,16,22,13,12,6,25,-1.6


In [4]:
new_labels = ['date',
                  'away',
                  'home',
                  'winner',
                  'at_bats',
                  'runs',
                  'hits',
                  'rbi',
                  'earned_runs',
                  'bases_on_balls',
                  'strikeouts',
                  'plate_appearances',
                  'pitches',
                  'strikes',
                  'base_out_runs_added',
                  'putouts',
                  'assists',
                  'innings_pitched',
                  'homeruns',
                  'strikes_by_contact',
                  'strikes_swinging',
                  'strikes_looking',
                  'grounded_balls',
                  'fly_balls',
                  'line_drives',
                  'game_score',
                  'base_out_runs_saved']

def generateTeamStats(dataframe, teams=TEAMS):
    """Returns a dictionary, where each key is a team and each
    value is a dataframe for their stats"""
    
    team_stats_dict = dict()
    
    for team in tqdm(teams):
        away_games = dataframe.loc[(dataframe["away"] == team)].copy()
        home_games = dataframe.loc[(dataframe["home"] == team)].copy()        
        
        home_games.drop(home_games.filter(regex="away_").columns, axis=1, inplace=True)
        home_games.columns = new_labels
        
        away_games.drop(away_games.filter(regex="home_").columns, axis=1, inplace=True)
        away_games.columns = new_labels
        
        # Shape for both must match same columns
        assert away_games.shape[1] == home_games.shape[1]
        
        # Join the home games and away games, sort by date
        team_stats = pd.concat([home_games, away_games])
        team_stats.drop_duplicates(inplace=True)
        team_stats.dropna(inplace=True)
        team_stats.sort_values(by=["date"], inplace=True)
        team_stats_dict[team] = team_stats   
    
    return team_stats_dict

In [5]:
team_stats = generateTeamStats(allBoxes)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))




In [6]:
team_stats['NYY']

Unnamed: 0_level_0,date,away,home,winner,at_bats,runs,hits,rbi,earned_runs,bases_on_balls,strikeouts,plate_appearances,pitches,strikes,base_out_runs_added,putouts,assists,innings_pitched,homeruns,strikes_by_contact,strikes_swinging,strikes_looking,grounded_balls,fly_balls,line_drives,game_score,base_out_runs_saved
game_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1
NYA/NYA201504100,"Friday, April 10, 2015",BOS,NYY,Away,67,5,14,5,2.37,6,13,75,332,176,-4.3,57,19,19.0,1,97,16,63,30,25,7,40,3.3
NYA/NYA201404110,"Friday, April 11, 2014",BOS,NYY,Away,33,2,7,2,2.00,2,9,35,149,92,-2.2,27,8,9.0,2,53,10,29,14,10,4,56,0.2
NYA/NYA201904120,"Friday, April 12, 2019",CHW,NYY,Away,24,6,7,6,6.00,4,7,30,137,77,2.8,19,3,6.1,4,46,12,19,5,13,8,23,-5.5
DET/DET201804130,"Friday, April 13, 2018",NYY,DET,Away,38,8,11,7,7.00,4,8,44,161,90,3.3,27,5,9.0,2,48,12,30,17,14,6,54,-1.3
NYA/NYA201704140,"Friday, April 14, 2017",STL,NYY,Home,32,4,9,3,3.38,4,10,36,153,78,-0.4,27,10,9.0,1,43,18,17,7,15,6,54,1.9
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
NYA/NYA201909040,"Wednesday, September 4, 2019",TEX,NYY,Home,30,4,6,4,4.50,5,10,35,133,98,-0.3,27,6,9.0,1,56,19,23,9,11,4,55,3.9
OAK/OAK201809050,"Wednesday, September 5, 2018",NYY,OAK,Home,31,2,5,2,2.00,3,7,34,127,94,-2.5,24,6,8.0,0,51,12,31,12,12,3,26,-4.0
NYA/NYA201609070,"Wednesday, September 7, 2016",TOR,NYY,Home,33,2,9,2,2.25,2,12,35,143,84,-2.2,27,16,9.0,0,40,23,21,8,13,7,59,4.7
NYA/NYA201509090,"Wednesday, September 9, 2015",BAL,NYY,Away,31,3,4,3,3.00,0,9,32,162,82,-1.4,27,12,9.0,1,44,13,25,9,13,4,50,-0.6


In [7]:
def computeMA(span, team_stats):
    """Compute moving averages for each team"""
    
    team_ma = dict()
    
    for team, dataframe in tqdm(team_stats.items(), unit='teams'):
        
        df = dataframe.copy()

        # Compute the moving averages for the appropriate columns
        for col in df.columns:
            if col in {'date', 'away', 'home', 'winner'}:
                continue

            # Exponential moving average
            df[f"{col} EMA"] = df.loc[:, col].ewm(span=span, adjust=False).mean()
            df[f"{col} EMA"] = df[f"{col} EMA"].shift(1)
        
        df["batting_average EMA"] = df["hits EMA"] / df["at_bats EMA"]
        df["era EMA"] = (df["earned_runs EMA"] * 9) / df["innings_pitched EMA"]
        
        # Drop any rows with NULL values and save the dataframe
        df.dropna(inplace=True)
        df.drop_duplicates(inplace=True)
        
        team_ma[team] = df
    
    return team_ma

In [8]:
team_ma = computeMA(span=5, team_stats=team_stats)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))




In [9]:
team_ma['NYY']

Unnamed: 0_level_0,date,away,home,winner,at_bats,runs,hits,rbi,earned_runs,bases_on_balls,strikeouts,plate_appearances,pitches,strikes,base_out_runs_added,putouts,assists,innings_pitched,homeruns,strikes_by_contact,strikes_swinging,strikes_looking,grounded_balls,fly_balls,line_drives,game_score,base_out_runs_saved,at_bats EMA,runs EMA,hits EMA,rbi EMA,earned_runs EMA,bases_on_balls EMA,strikeouts EMA,plate_appearances EMA,pitches EMA,strikes EMA,base_out_runs_added EMA,putouts EMA,assists EMA,innings_pitched EMA,homeruns EMA,strikes_by_contact EMA,strikes_swinging EMA,strikes_looking EMA,grounded_balls EMA,fly_balls EMA,line_drives EMA,game_score EMA,base_out_runs_saved EMA,batting_average EMA,era EMA
game_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1
NYA/NYA201404110,"Friday, April 11, 2014",BOS,NYY,Away,33,2,7,2,2.00,2,9,35,149,92,-2.2,27,8,9.0,2,53,10,29,14,10,4,56,0.2,67.000000,5.000000,14.000000,5.000000,2.370000,6.000000,13.000000,75.000000,332.000000,176.000000,-4.300000,57.000000,19.000000,19.000000,1.000000,97.000000,16.000000,63.000000,30.000000,25.000000,7.000000,40.000000,3.300000,0.208955,1.122632
NYA/NYA201904120,"Friday, April 12, 2019",CHW,NYY,Away,24,6,7,6,6.00,4,7,30,137,77,2.8,19,3,6.1,4,46,12,19,5,13,8,23,-5.5,55.666667,4.000000,11.666667,4.000000,2.246667,4.666667,11.666667,61.666667,271.000000,148.000000,-3.600000,47.000000,15.333333,15.666667,1.333333,82.333333,14.000000,51.666667,24.666667,20.000000,6.000000,45.333333,2.266667,0.209581,1.290638
DET/DET201804130,"Friday, April 13, 2018",NYY,DET,Away,38,8,11,7,7.00,4,8,44,161,90,3.3,27,5,9.0,2,48,12,30,17,14,6,54,-1.3,45.111111,4.666667,10.111111,4.666667,3.497778,4.444444,10.111111,51.111111,226.333333,124.333333,-1.466667,37.666667,11.222222,12.477778,2.222222,70.222222,13.333333,40.777778,18.111111,17.666667,6.666667,37.888889,-0.322222,0.224138,2.522885
NYA/NYA201704140,"Friday, April 14, 2017",STL,NYY,Home,32,4,9,3,3.38,4,10,36,153,78,-0.4,27,10,9.0,1,43,18,17,7,15,6,54,1.9,42.740741,5.777778,10.407407,5.444444,4.665185,4.296296,9.407407,48.740741,204.555556,112.888889,0.122222,34.111111,9.148148,11.318519,2.148148,62.814815,12.888889,37.185185,17.740741,16.444444,6.444444,43.259259,-0.648148,0.243501,3.709555
NYA/NYA201604150,"Friday, April 15, 2016",SEA,NYY,Away,33,1,6,1,1.00,7,10,40,152,93,-3.7,27,9,9.0,1,42,15,36,9,14,3,38,-2.3,39.160494,5.185185,9.938272,4.629630,4.236790,4.197531,9.604938,44.493827,187.370370,101.259259,-0.051852,31.740741,9.432099,10.545679,1.765432,56.209877,14.592593,30.456790,14.160494,15.962963,6.296296,46.839506,0.201235,0.253783,3.615804
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
NYA/NYA201909040,"Wednesday, September 4, 2019",TEX,NYY,Home,30,4,6,4,4.50,5,10,35,133,98,-0.3,27,6,9.0,1,56,19,23,9,11,4,55,3.9,35.055519,4.969086,9.237461,4.940335,4.893283,6.485995,6.716166,41.851855,161.304227,101.666179,0.332742,28.572878,8.796664,9.524293,1.363653,59.885060,13.073873,28.929487,12.204387,16.372799,5.699875,55.218911,-0.422413,0.263509,4.623917
OAK/OAK201809050,"Wednesday, September 5, 2018",NYY,OAK,Home,31,2,5,2,2.00,3,7,34,127,94,-2.5,24,6,8.0,0,51,12,31,12,12,3,26,-4.0,33.370346,4.646057,8.158307,4.626890,4.762188,5.990664,7.810777,39.567903,151.869485,100.444119,0.121828,28.048585,7.864443,9.349528,1.242435,58.590040,15.049249,26.952992,11.136258,14.581866,5.133250,55.145941,1.018392,0.244478,4.584156
NYA/NYA201609070,"Wednesday, September 7, 2016",TOR,NYY,Home,33,2,9,2,2.25,2,12,35,143,84,-2.2,27,16,9.0,0,40,23,21,8,13,7,59,4.7,32.580230,3.764038,7.105538,3.751260,3.841459,4.993776,7.540518,37.711936,143.579656,98.296079,-0.752115,26.699057,7.242962,8.899686,0.828290,56.060027,14.032832,28.301994,11.424172,13.721244,4.422167,45.430627,-0.654406,0.218094,3.884759
NYA/NYA201509090,"Wednesday, September 9, 2015",BAL,NYY,Away,31,3,4,3,3.00,0,9,32,162,82,-1.4,27,12,9.0,1,44,13,25,9,13,4,50,-0.6,32.720154,3.176026,7.737025,3.167507,3.310973,3.995850,9.027012,36.807957,143.386438,93.530720,-1.234743,26.799371,10.161975,8.933124,0.552193,50.706685,17.021888,25.867996,10.282781,13.480829,5.281444,49.953751,1.130396,0.236461,3.335760


In [10]:
def mergeMAWithSchedule(schedule, team_ma_dict):
    
    schedule = schedule.filter(items=['date', 'away', 'home', 'winner'])
    
    aways = []
    homes = []
    
    for team, df in tqdm(team_ma_dict.items()):
        
        # Merge in visiting team stats
        away = schedule.join(df.loc[df["away"] == team].drop(columns=['date', 'away', 'home', 'winner']))
        away.dropna(inplace=True)
        away.drop(away.loc[:, 'at_bats':'base_out_runs_saved'].columns, axis=1, inplace=True)
        away_cols = {col: f'away_{col}' for col in away.columns[4:]}
        away.rename(columns=away_cols, inplace=True) 
        
        # Merge in home team stats
        home = schedule.join(df.loc[df["home"] == team].drop(columns=['date', 'away', 'home', 'winner']))
        home.dropna(inplace=True)
        home.drop(home.loc[:, 'at_bats':'base_out_runs_saved'].columns, axis=1, inplace=True)
        home_cols = {col: f'home_{col}' for col in home.columns[4:]}
        home.rename(columns=home_cols, inplace=True) 
        
        homes.append(home)
        aways.append(away)
        
        assert home.shape[1] == away.shape[1]
        

    return pd.merge(pd.concat(homes),
                    pd.concat(aways).drop(columns=['date', 'away', 'home', 'winner']),
                    left_index=True,
                    right_index=True)


In [11]:
mergeMAWithSchedule(allBoxes, team_ma)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))




Unnamed: 0_level_0,date,away,home,winner,home_at_bats EMA,home_runs EMA,home_hits EMA,home_rbi EMA,home_earned_runs EMA,home_bases_on_balls EMA,home_strikeouts EMA,home_plate_appearances EMA,home_pitches EMA,home_strikes EMA,home_base_out_runs_added EMA,home_putouts EMA,home_assists EMA,home_innings_pitched EMA,home_homeruns EMA,home_strikes_by_contact EMA,home_strikes_swinging EMA,home_strikes_looking EMA,home_grounded_balls EMA,home_fly_balls EMA,home_line_drives EMA,home_game_score EMA,home_base_out_runs_saved EMA,home_batting_average EMA,home_era EMA,away_at_bats EMA,away_runs EMA,away_hits EMA,away_rbi EMA,away_earned_runs EMA,away_bases_on_balls EMA,away_strikeouts EMA,away_plate_appearances EMA,away_pitches EMA,away_strikes EMA,away_base_out_runs_added EMA,away_putouts EMA,away_assists EMA,away_innings_pitched EMA,away_homeruns EMA,away_strikes_by_contact EMA,away_strikes_swinging EMA,away_strikes_looking EMA,away_grounded_balls EMA,away_fly_balls EMA,away_line_drives EMA,away_game_score EMA,away_base_out_runs_saved EMA,away_batting_average EMA,away_era EMA
game_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1
WAS/WAS201904120,"Friday, April 12, 2019",PIT,WSN,Away,35.000000,2.666667,8.666667,2.666667,1.566667,2.000000,8.666667,37.333333,140.333333,90.333333,-1.433333,25.666667,14.666667,8.400000,0.666667,48.666667,19.333333,23.000000,12.333333,14.333333,6.333333,45.333333,-1.233333,0.247619,1.678571,32.000000,4.666667,8.000000,4.666667,4.333333,2.000000,10.666667,38.333333,131.666667,94.333333,0.500000,26.000000,11.666667,8.666667,0.666667,52.333333,19.333333,22.666667,13.333333,11.000000,4.333333,51.666667,1.333333,0.250000,4.500000
WAS/WAS201804130,"Friday, April 13, 2018",COL,WSN,Away,35.333333,2.777778,8.111111,2.777778,1.944444,2.333333,7.777778,38.555556,149.222222,93.222222,-1.822222,27.111111,13.444444,8.933333,1.111111,52.777778,17.222222,23.666667,10.888889,17.222222,7.555556,55.222222,-0.955556,0.229560,1.958955,43.555556,4.000000,11.000000,3.444444,3.720000,2.777778,13.333333,47.222222,172.222222,122.222222,-1.766667,34.666667,11.666667,11.477778,0.222222,68.888889,21.444444,32.555556,13.222222,17.888889,9.222222,46.777778,2.622222,0.252551,2.916941
WAS/WAS201704140,"Friday, April 14, 2017",PHI,WSN,Home,33.555556,2.185185,6.740741,2.185185,1.629630,2.888889,8.518519,37.370370,144.814815,96.148148,-2.414815,27.074074,10.962963,8.955556,1.074074,55.518519,17.481481,23.444444,10.592593,15.148148,6.703704,57.481481,0.229630,0.200883,1.637717,32.481481,4.740741,9.000000,4.592593,5.000000,3.888889,8.629630,37.629630,137.666667,90.851852,0.640741,27.000000,9.518519,9.000000,0.444444,45.259259,15.666667,29.925926,11.259259,13.037037,5.333333,63.259259,2.981481,0.277081,5.000000
WAS/WAS201504170,"Friday, April 17, 2015",PHI,WSN,Home,37.024691,4.637860,10.329218,4.304527,4.057613,3.172840,9.119342,40.831276,140.251029,102.176955,0.193416,27.699588,10.094650,9.202469,1.032922,57.341564,19.769547,25.197531,12.263374,15.954733,9.312757,65.325103,1.990947,0.278982,3.968339,32.658436,2.884774,6.666667,2.818930,2.955556,3.395062,8.279835,36.613169,153.851852,90.934156,-1.604115,27.000000,9.786008,9.000000,0.864198,48.559671,16.518519,25.855967,12.004115,12.572016,5.592593,49.115226,0.058436,0.204133,2.955556
WAS/WAS201404180,"Friday, April 18, 2014",STL,WSN,Home,36.349794,5.425240,10.219479,4.869684,4.955075,3.115226,8.079561,40.220850,132.500686,97.451303,1.195610,27.466392,9.063100,9.134979,0.688615,54.561043,16.846365,26.131687,13.508916,14.969822,7.875171,69.883402,2.093964,0.281143,4.881859,34.086420,6.773663,9.172840,6.707819,6.499671,3.135802,8.839506,38.440329,129.037037,89.193416,2.637449,26.555556,9.098765,8.851852,1.279835,50.728395,13.213992,25.251029,11.218107,14.827160,6.905350,58.032922,1.331687,0.269105,6.608452
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
SLN/SLN201709270,"Wednesday, September 27, 2017",CHC,STL,Away,31.872484,3.474334,6.537918,3.100035,3.408303,2.840585,9.739400,36.186176,155.864936,85.362927,-0.959445,25.361290,9.845772,8.449211,0.858718,43.617068,16.217363,25.529557,10.791618,11.984430,7.704555,46.278743,-1.015426,0.205127,3.630484,33.501098,4.079908,8.337577,3.996872,3.674166,3.857448,7.016306,38.703838,147.974153,90.501062,-0.375983,26.962301,11.543967,8.987434,0.995711,50.762169,18.076645,21.662247,14.266465,13.315560,7.046046,57.689823,0.373626,0.248875,3.679303
SLN/SLN201609280,"Wednesday, September 28, 2016",CIN,STL,Away,31.581656,2.649556,6.025279,2.400023,2.605535,2.893723,8.492933,35.457451,141.576624,85.241951,-1.839630,25.907526,9.230515,8.632807,0.905812,43.078046,15.478242,26.686371,10.194412,13.322953,7.136370,45.852496,-0.810284,0.190784,2.716361,32.983020,3.236586,6.592777,3.235216,3.188611,2.780439,7.569478,36.307062,150.498683,89.504881,-1.291852,26.763660,10.710498,8.921220,1.030214,51.190428,16.566250,21.748712,9.862105,16.034506,7.934178,48.199650,0.284110,0.199884,3.216768
SLN/SLN201409030,"Wednesday, September 3, 2014",PIT,STL,Home,31.721104,2.099704,6.350186,1.600016,2.070357,2.929149,7.328622,35.638300,149.717749,80.494634,-2.359753,26.271684,10.153676,8.755205,0.603875,44.052030,12.652162,23.790914,10.796275,13.881969,7.090913,48.901664,0.259811,0.200188,2.128244,33.860350,5.890401,10.309897,5.444331,6.321962,2.611901,7.617456,37.619345,141.446007,93.250475,1.715959,27.082433,11.541484,8.958342,0.799707,53.985930,14.995077,24.269464,13.983415,12.726015,7.167146,48.972720,0.368803,0.304483,6.351360
SLN/SLN201909040,"Wednesday, September 4, 2019",SFG,STL,Away,32.991438,3.881394,7.955611,3.511116,3.878624,2.645674,6.245518,37.300237,141.138592,87.628040,-0.236223,25.784203,10.786275,8.594135,0.845592,51.978379,11.230270,24.419530,13.680378,13.779843,7.434345,51.748641,-0.074871,0.241142,4.061796,30.976109,2.946392,7.572282,2.932194,3.151349,2.446402,7.577137,33.731358,130.373307,83.404936,-1.244214,25.984421,9.532099,8.632359,1.207933,49.226898,15.969191,19.177455,10.841997,12.857046,6.817660,56.829946,0.610403,0.244456,3.285561


# Training  a Random Forest Classifier

In [12]:
data = mergeMAWithSchedule(allBoxes, computeMA(14, generateTeamStats(allBoxes)))

data['target'] = np.where(data['winner'] == 'Home', 1, 0)

data.drop(columns=['date', 'away', 'home', 'winner'], inplace=True)
data.reset_index(drop=True, inplace=True)

data.head()

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))




Unnamed: 0,home_at_bats EMA,home_runs EMA,home_hits EMA,home_rbi EMA,home_earned_runs EMA,home_bases_on_balls EMA,home_strikeouts EMA,home_plate_appearances EMA,home_pitches EMA,home_strikes EMA,home_base_out_runs_added EMA,home_putouts EMA,home_assists EMA,home_innings_pitched EMA,home_homeruns EMA,home_strikes_by_contact EMA,home_strikes_swinging EMA,home_strikes_looking EMA,home_grounded_balls EMA,home_fly_balls EMA,home_line_drives EMA,home_game_score EMA,home_base_out_runs_saved EMA,home_batting_average EMA,home_era EMA,away_at_bats EMA,away_runs EMA,away_hits EMA,away_rbi EMA,away_earned_runs EMA,away_bases_on_balls EMA,away_strikeouts EMA,away_plate_appearances EMA,away_pitches EMA,away_strikes EMA,away_base_out_runs_added EMA,away_putouts EMA,away_assists EMA,away_innings_pitched EMA,away_homeruns EMA,away_strikes_by_contact EMA,away_strikes_swinging EMA,away_strikes_looking EMA,away_grounded_balls EMA,away_fly_balls EMA,away_line_drives EMA,away_game_score EMA,away_base_out_runs_saved EMA,away_batting_average EMA,away_era EMA,target
0,33.2,1.666667,7.066667,1.666667,1.226667,1.4,8.866667,34.733333,129.333333,88.933333,-2.373333,24.666667,16.066667,8.16,0.266667,46.466667,20.933333,21.8,11.533333,12.933333,5.533333,48.733333,-0.733333,0.212851,1.352941,32.6,5.466667,9.2,5.466667,5.333333,2.0,11.466667,39.933333,134.266667,96.533333,1.28,26.6,11.866667,8.866667,0.266667,52.133333,21.533333,22.866667,12.533333,12.2,4.733333,51.266667,1.853333,0.282209,5.413534,0
1,33.573333,1.844444,7.057778,1.844444,1.423111,1.613333,8.484444,35.568889,134.355556,90.275556,-2.403556,25.377778,15.391111,8.405333,0.497778,48.404444,19.875556,22.226667,11.062222,14.275556,6.128889,52.235556,-0.688889,0.21022,1.523794,37.848889,4.6,11.0,4.351111,4.8168,2.884444,9.613333,41.115556,152.835556,104.275556,-0.330667,29.986667,11.706667,9.964444,0.115556,61.182222,16.431111,26.928889,10.795556,17.822222,9.555556,48.884444,3.215556,0.290629,4.350589,0
2,33.096889,1.731852,6.650074,1.731852,1.366696,1.931556,8.686519,35.493037,134.574815,91.838815,-2.563081,25.594074,14.138963,8.484622,0.564741,50.083852,19.625481,22.329778,10.920593,13.838815,5.97837,53.537481,-0.25037,0.200927,1.449713,29.702815,4.511407,8.109333,4.411259,4.912,4.051556,6.61363,35.922963,135.437333,86.649185,0.687407,27.0,10.435852,9.0,0.751111,48.869926,11.005333,26.773926,12.699259,11.141037,5.282667,62.472593,2.924148,0.273016,4.912,1
3,34.637219,2.84748,8.168278,2.714147,2.439874,2.313035,8.897874,37.41477,135.302861,95.452265,-1.450492,26.29066,13.055488,8.72845,0.673072,52.045204,20.100917,23.456589,11.651467,14.350043,7.183754,57.857042,0.658166,0.235824,2.515781,30.50567,3.753013,7.317677,3.67779,4.030791,3.789835,7.038682,35.924359,143.035153,87.892055,-0.329903,27.0,10.407373,9.0,0.830835,49.497855,12.595117,25.799082,12.738555,11.292601,5.363425,56.843858,1.721693,0.239879,4.030791,1
4,34.68559,3.401149,8.412507,3.15226,3.014558,2.40463,8.511491,37.626134,132.862479,94.45863,-0.830427,26.385239,12.248089,8.764656,0.583329,51.639177,18.887462,24.062377,12.231272,14.170037,6.892587,60.676103,0.877077,0.242536,3.095503,35.286479,5.333647,9.580792,5.258424,5.18383,3.221242,10.137948,39.266056,118.907828,92.421844,1.094862,25.684053,8.211038,8.561351,1.589094,48.782997,15.770784,27.868064,13.99658,11.569063,5.613514,53.059336,0.031643,0.271515,5.449428,1


In [13]:
X_train, X_test, y_train, y_test = train_test_split(data.drop(columns=['target']),
                                                    data.filter(items=['target']).pop('target'),
                                                    test_size = .2)

In [None]:
# Number of trees in random forest
n_estimators = [i for i in range(100, 1100, 100)]

# Number of features to consider at every split
max_features = ['auto', 'sqrt']

# Maximum number of levels in tree
max_depth = [i for i in range(1, 21)]
max_depth.append(None)

# Minimum number of samples required to split a node
min_samples_split = [2, 5, 10]

# Minimum number of samples required at each leaf node
min_samples_leaf = [1, 2, 4]

# Method of selecting samples for training each tree
bootstrap = [True, False]

# Create the random grid
random_param_grid = {'n_estimators': n_estimators,
                      'max_features': max_features,
                      'max_depth': max_depth,
                      'min_samples_split': min_samples_split,
                      'min_samples_leaf': min_samples_leaf,
                      'bootstrap': bootstrap}

# Use the random grid to find the best hyperparameters
# First create the base model to tune
rf = RandomForestClassifier()

# Random search of parameters, using 3 fold cross validation, 
# search across 100 different combinations, and use all available cores
rf_random = RandomizedSearchCV(estimator = rf,
                               param_distributions = random_param_grid,
                               n_iter = 50,
                               cv = 3,
                               verbose = 2,
                               random_state = 42,
                               n_jobs = -1)

# Fit the random search model
rf_random.fit(X_train, y_train)

In [None]:
rf_random.best_params_

In [14]:
best_params = {'n_estimators': 100,
               'min_samples_split': 5,
               'min_samples_leaf': 2,
               'max_features': 'sqrt',
               'max_depth': 1,
               'bootstrap': True}

In [15]:
rf = RandomForestClassifier(**best_params)
rf.fit(X_train, y_train)
rf.score(X_train, y_train)

0.5329883570504528

In [16]:
rf.score(X_test, y_test)

0.5384864165588615