In [2]:
import os

import numpy as np
import pandas as pd
import pyspark.sql.functions as F
import regex as re

from IPython.display import display
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import *
from pyspark.ml.classification import GBTClassifier, LinearSVC, LogisticRegression, RandomForestClassifier
from typing import *

In [3]:
spark = SparkSession \
    .builder \
    .appName('group2nba') \
    .getOrCreate()

In [4]:
path_main = '/project/ds5559/group2nba'

T = TypeVar('T')

In [5]:
class ML_CV():
    __slots__: List[str] = [
          'Model'
        , 'HyperParameters'
    ]
    
    def __init__(self, model: T, hyper_params: Dict[str, List[T]]): # todo: touch up
        self.Model = model
        self.HyperParameters = hyper_params

In [6]:
FIELDS: Dict[str, T] = {
      'Url': StringType
    , 'GameType': StringType
    , 'Location': StringType
    , 'Date': StringType
    , 'Time': StringType
    , 'WinningTeam': StringType
    , 'Quarter': IntegerType
    , 'SecLeft': IntegerType
    , 'AwayTeam': StringType
    , 'AwayPlay': StringType
    , 'AwayScore': IntegerType
    , 'HomeTeam': StringType
    , 'HomePlay': StringType
    , 'HomeScore': IntegerType
    , 'Shooter': StringType
    , 'ShotType': StringType
    , 'ShotOutcome': IntegerType
    , 'ShotDist': IntegerType
    , 'Assister': StringType
    , 'Blocker': StringType
    , 'FoulType': StringType
    , 'Fouler': StringType
    , 'Fouled': StringType
    , 'Rebounder': StringType
    , 'ReboundType': IntegerType
    , 'ViolationPlayer': StringType
    , 'ViolationType': StringType
    , 'TimeoutTeam': StringType
    , 'FreeThrowShooter': StringType
    , 'FreeThrowOutcome': IntegerType
    , 'FreeThrowNum': StringType
    , 'EnterGame': StringType
    , 'LeaveGame': StringType
    , 'TurnoverPlayer': StringType
    , 'TurnoverType': StringType
    , 'TurnoverCause': StringType
    , 'TurnoverCauser': StringType
    , 'JumpballAwayPlayer': StringType
    , 'JumpballHomePlayer': StringType
    , 'JumpballPoss': StringType
}

POSSESSION_PLAYS: List[str] = [
      'Shooter'
    , 'Assister'
    , 'Fouled'
    , 'Rebounder'
    , 'ViolationPlayer'
    , 'FreeThrowShooter'
    , 'TurnOverPlayer'
    , 'JumpballPoss'
]

MODEL_FIELDS: List[str] = [
      'Date'
    , 'HomeTeam'
    , 'AwayTeam'
    , 'Team'
    , 'Year'
    , 'Won'
    , 'ScoreDiff'
    , 'Quarter'
    , 'SecLeftTotal'
    , 'LogSecLeftTotal'
    , 'SecLeftTotalInverse'
    , 'HasPossession'
]

DICT_ML: Dict[str, ML_CV] = {
      'GradientBoost': ML_CV(GBTClassifier, {
          'featureSubsetStrategy': ['all', 'sqrt', 'onethird', 'log2']
        , 'maxBins': [2, 3]
        , 'maxDepth': [100, 500, 1000]
        , 'weightCol': ['ScoreDiff', 'SecLeftTotal', 'HasPossession']
    })
    , 'LinearSVC': ML_CV(RandomForestClassifier, {
          'aggregationDepth': [10, 20, 50, 100]
        , 'maxIter': [10, 20, 50, 100]
        , 'weightCol': ['ScoreDiff', 'SecLeftTotal', 'HasPossession']
    })
    , 'LogisticRegression': ML_CV(LogisticRegression, {
          'maxIter': [10, 20, 50, 100]
        , 'regParam': [0.1, 0.5, 0.01, 0.05]
        , 'weightCol': ['ScoreDiff', 'SecLeftTotal', 'HasPossession']
    })
    , 'RandomForest': ML_CV(RandomForestClassifier, {
          'maxBins': [2, 3]
        , 'numTrees': [100, 500, 1000, 2000]
        , 'maxDepth': [100, 500, 1000]
        , 'weightCol': ['ScoreDiff', 'SecLeftTotal', 'HasPossession']
    })
}

In [7]:
@F.udf(IntegerType())
def get_has_possession(team: str, plays: List[str]) -> int:
    '''Get: whether the current team has possession of the ball'''
    return int(bool(team) and any([bool(x) for x in plays]))


@F.udf(IntegerType())
def get_score_diff(score1: int, score2: int) -> int:
    '''Get: score differential relative to the specified team'''
    return score1 - score2


@F.udf(IntegerType())
def get_secleft_total(quarter: int, maxquarter: int, sec: int) -> int:
    '''Get: SecLeft by Quarter, accounting for OT'''
    if quarter < 5:
        return ((maxquarter - 4) * 300) + ((4 - quarter) * 720) + sec
    else:
        return ((maxquarter - quarter) * 300) + sec
    
@F.udf(StringType())
def get_team(team: str) -> str:
    '''Get: the specified team'''
    return team
    
    
@F.udf(IntegerType())
def get_won(winner: str, team: str) -> int:
    '''Get: whether the specified team won'''
    return int(winner == team)
    
    
@F.udf(IntegerType())
def get_year(date: str) -> int:
    '''Get: Year of game took place in'''
    return int(re.match(r'[A-Z][a-z]+ \d+ (\d{4})', date).groups()[0])

@F.udf(IntegerType())
def get_year(date: str) -> int:
    '''Get: Year of game took place in'''
    return int(re.match(r'[A-Z][a-z]+ \d+ (\d{4})', date).groups()[0])

@F.udf(IntegerType())
def get_assist(assister: str, team: str, winner: str) -> int:
    '''Get: Boolean for assist metric'''
    return int(len(assister)>0 & winner==team)



In [8]:
MODEL_FIELDS: List[str] = [
      'Date'
    , 'HomeTeam'
    , 'AwayTeam'
    , 'Team'
    , 'Year'
    , 'Won'
    , 'ScoreDiff'
    , 'Quarter'
    , 'SecLeftTotal'
    , 'LogSecLeftTotal'
    , 'SecLeftTotalInverse'
    , 'HasPossession'
]
    
dict_bool_cols = {'Assister':'assist', 
                  'TurnoverPlayer':'turnover', 
                  'Blocker':'block', 
                  'Fouler':'foul', 
                  'Rebounder':'rebound', 
                  'ShotOutcome':'shotOnGoal', 
                  'FreeThrowOutcome':'freeThrow'
                 }

for value in dict_bool_cols.values():
    MODEL_FIELDS.append(f'{value}_team_cnt')
    MODEL_FIELDS.append(f'{value}_opponent_cnt')
#     MODEL_FIELDS.append(f'{value}_diff)
    
print(MODEL_FIELDS)

['Date', 'HomeTeam', 'AwayTeam', 'Team', 'Year', 'Won', 'ScoreDiff', 'Quarter', 'SecLeftTotal', 'LogSecLeftTotal', 'SecLeftTotalInverse', 'HasPossession', 'assist_team_cnt', 'assist_opponent_cnt', 'turnover_team_cnt', 'turnover_opponent_cnt', 'block_team_cnt', 'block_opponent_cnt', 'foul_team_cnt', 'foul_opponent_cnt', 'rebound_team_cnt', 'rebound_opponent_cnt', 'shotOnGoal_team_cnt', 'shotOnGoal_opponent_cnt', 'freeThrow_team_cnt', 'freeThrow_opponent_cnt']


In [9]:
def build_model_win_percent(df: DataFrame) -> DataFrame:
    '''Constucts a model for predicting the win percent on a play by play basis'''
    from pyspark.sql import Window

    # Build Cumulative Features
    altered = df.withColumn('team', F.when(F.length(F.col("HomePlay")) >0,'home').otherwise('away')) \
                .withColumn('date_location', F.concat(F.col("Date"),F.col("Location")))

    for key in dict_bool_cols:
        altered = build_boolean_features(altered, key, dict_bool_cols[key])

    windowVal = (Window.partitionBy('date_location')
                       .orderBy(df_train.Quarter.asc(),df_train.SecLeft.desc())
                       .rangeBetween(Window.unboundedPreceding, 0)
                )

    for value in dict_bool_cols.values():
        altered = altered.withColumn(f'{value}_home_cnt', F.sum(F.col(value+'_home')).over(windowVal))
        altered = altered.withColumn(f'{value}_away_cnt', F.sum(F.col(value+'_away')).over(windowVal))
    
    # Adjustments for Both Sides
    altered = (altered
        .groupBy(['Date', 'HomeTeam', 'AwayTeam'])
        .agg(F.max('Quarter').alias('MaxQuarter'))
        .join(altered, ['Date', 'HomeTeam', 'AwayTeam'])
        .withColumn('SecLeftTotal', get_secleft_total('Quarter', 'MaxQuarter', 'SecLeft'))
        .withColumn('LogSecLeftTotal', F.log(F.col('SecLeftTotal') + 1))
        .withColumn('SecLeftTotalInverse', 1/(F.col('SecLeftTotal') + 1))
        .withColumn('Year', get_year('Date'))
    )

    # Build Home
    home = altered \
        .withColumn('Won', get_won('WinningTeam', 'HomeTeam')) \
        .withColumn('ScoreDiff', get_score_diff('HomeScore', 'AwayScore')) \
        .withColumn('HasPossession', get_has_possession('HomePlay', F.array(*POSSESSION_PLAYS))) \
        .withColumn('Team', get_team('HomeTeam'))
    
    for value in dict_bool_cols.values():
        home = home.withColumn(f'{value}_team_cnt', F.col(f'{value}_home_cnt')) \
                   .withColumn(f'{value}_opponent_cnt', F.col(f'{value}_away_cnt'))
        
    home = home.select(MODEL_FIELDS)
    
    # Build Away
    away = altered \
        .withColumn('Won', get_won('WinningTeam', 'AwayTeam')) \
        .withColumn('ScoreDiff', get_score_diff('AwayScore', 'HomeScore')) \
        .withColumn('HasPossession', get_has_possession('AwayPlay', F.array(*POSSESSION_PLAYS))) \
        .withColumn('Team', get_team('AwayTeam'))
    
    for value in dict_bool_cols.values():
        away = away.withColumn(f'{value}_team_cnt', F.col(f'{value}_away_cnt')) \
                   .withColumn(f'{value}_opponent_cnt', F.col(f'{value}_home_cnt'))
        
    away = away.select(MODEL_FIELDS)
    
    final = home.union(away) \
                .withColumn('SecLeftTotalInverseTimesScoreDiff', F.col('SecLeftTotalInverse')*F.col('ScoreDiff'))

    # Add 'Diff' values (Team - Opponent)
    for value in dict_bool_cols.values():
        final = final.withColumn(f'{value}_diff', get_score_diff(f'{value}_team_cnt', f'{value}_opponent_cnt'))
    
    return final

In [10]:
def build_boolean_features(df: DataFrame, columnName, columnNewName) -> DataFrame:
    bool_away = F.when((F.length(F.col(columnName)) > 0) & (F.col("team") == 'away'), 1).otherwise(0)
    bool_home = F.when((F.length(F.col(columnName)) > 0) & (F.col("team") == 'home'), 1).otherwise(0)
    
    altered = df.withColumn(f'{columnNewName}_away', bool_away) \
                .withColumn(f'{columnNewName}_home', bool_home)
    
    return altered

def build_cumulative_features(df: DataFrame, colName, windowVal) -> DataFrame:
    return df.withColumn(f'{colName}_cnt', F.sum(F.col(colName)).over(windowVal))

In [41]:
def cross_validate(df: DataFrame, ml_method: str, features: List[str], k_folds: int = 10) -> DataFrame:
    '''...'''
    method = DICT_ML[ml_method]
    
    pipeline = Pipeline(stages = [
          VectorAssembler(inputCols = features, outputCol = 'features')
        , method.Model(featuresCol = 'features', labelCol = 'Won') # todo: make response a constant
    ])
    
    param_grid = ParamGridBuilder()
    for attr, params in method.HyperParameters.items():
        param_grid = param_grid.addGrid(getattr(pipeline.stages[1], attr), params)
    param_grid.build()
    
    cv_model = CrossValidator(
          estimator = pipeline
        , estimatorParamMaps = param_grid
        , evaluator = BinaryClassificationEvaluator()
        , numFolds = k_folds
    ).setParallelism(4).fit(df)
    
    # todo: return results as pandas dataframe...
    # write hyperparams to json


In [11]:
schema = StructType([StructField(k, v()) for k, v in FIELDS.items()])

df_train = spark.read \
    .format('csv') \
    .option('header', True) \
    .schema(schema) \
    .load(f'{path_main}/clean_no_overtime_data/clean_train01.csv')
#     .load(f'{path_main}/clean_train_data/*')

display(df_train.count())
display(df_train.printSchema())
display(df_train.head(2))

97672

root
 |-- Url: string (nullable = true)
 |-- GameType: string (nullable = true)
 |-- Location: string (nullable = true)
 |-- Date: string (nullable = true)
 |-- Time: string (nullable = true)
 |-- WinningTeam: string (nullable = true)
 |-- Quarter: integer (nullable = true)
 |-- SecLeft: integer (nullable = true)
 |-- AwayTeam: string (nullable = true)
 |-- AwayPlay: string (nullable = true)
 |-- AwayScore: integer (nullable = true)
 |-- HomeTeam: string (nullable = true)
 |-- HomePlay: string (nullable = true)
 |-- HomeScore: integer (nullable = true)
 |-- Shooter: string (nullable = true)
 |-- ShotType: string (nullable = true)
 |-- ShotOutcome: integer (nullable = true)
 |-- ShotDist: integer (nullable = true)
 |-- Assister: string (nullable = true)
 |-- Blocker: string (nullable = true)
 |-- FoulType: string (nullable = true)
 |-- Fouler: string (nullable = true)
 |-- Fouled: string (nullable = true)
 |-- Rebounder: string (nullable = true)
 |-- ReboundType: integer (nullable = tru

None

[Row(Url='/boxscores/202012220BRK.html', GameType='regular', Location='Barclays Center Brooklyn New York', Date='December 22 2020', Time='7:00 PM', WinningTeam='BRK', Quarter=1, SecLeft=710, AwayTeam='GSW', AwayPlay=None, AwayScore=0, HomeTeam='BRK', HomePlay='Turnover by D. Jordan (bad pass)', HomeScore=0, Shooter=None, ShotType=None, ShotOutcome=None, ShotDist=None, Assister=None, Blocker=None, FoulType=None, Fouler=None, Fouled=None, Rebounder=None, ReboundType=None, ViolationPlayer=None, ViolationType=None, TimeoutTeam=None, FreeThrowShooter=None, FreeThrowOutcome=None, FreeThrowNum=None, EnterGame=None, LeaveGame=None, TurnoverPlayer='D. Jordan - jordade01', TurnoverType='bad pass', TurnoverCause=None, TurnoverCauser=None, JumpballAwayPlayer=None, JumpballHomePlayer=None, JumpballPoss=None),
 Row(Url='/boxscores/202012220BRK.html', GameType='regular', Location='Barclays Center Brooklyn New York', Date='December 22 2020', Time='7:00 PM', WinningTeam='BRK', Quarter=1, SecLeft=698, A

In [13]:
df_adjusted = build_model_win_percent(df_train)

display(df_adjusted.count())
df_adjusted.printSchema()
df_adjusted.head(10)

195344

root
 |-- Date: string (nullable = true)
 |-- HomeTeam: string (nullable = true)
 |-- AwayTeam: string (nullable = true)
 |-- Team: string (nullable = true)
 |-- Year: integer (nullable = true)
 |-- Won: integer (nullable = true)
 |-- ScoreDiff: integer (nullable = true)
 |-- Quarter: integer (nullable = true)
 |-- SecLeftTotal: integer (nullable = true)
 |-- LogSecLeftTotal: double (nullable = true)
 |-- SecLeftTotalInverse: double (nullable = true)
 |-- HasPossession: integer (nullable = true)
 |-- assist_team_cnt: long (nullable = true)
 |-- assist_opponent_cnt: long (nullable = true)
 |-- turnover_team_cnt: long (nullable = true)
 |-- turnover_opponent_cnt: long (nullable = true)
 |-- block_team_cnt: long (nullable = true)
 |-- block_opponent_cnt: long (nullable = true)
 |-- foul_team_cnt: long (nullable = true)
 |-- foul_opponent_cnt: long (nullable = true)
 |-- rebound_team_cnt: long (nullable = true)
 |-- rebound_opponent_cnt: long (nullable = true)
 |-- shotOnGoal_team_cnt: lon

[Row(Date='December 25 2020', HomeTeam='MIA', AwayTeam='NOP', Team='MIA', Year=2020, Won=1, ScoreDiff=0, Quarter=1, SecLeftTotal=2880, LogSecLeftTotal=7.9658927350845286, SecLeftTotalInverse=0.0003471017007983339, HasPossession=0, assist_team_cnt=0, assist_opponent_cnt=0, turnover_team_cnt=0, turnover_opponent_cnt=0, block_team_cnt=0, block_opponent_cnt=0, foul_team_cnt=0, foul_opponent_cnt=0, rebound_team_cnt=0, rebound_opponent_cnt=0, shotOnGoal_team_cnt=0, shotOnGoal_opponent_cnt=0, freeThrow_team_cnt=0, freeThrow_opponent_cnt=0, SecLeftTotalInverseTimesScoreDiff=0.0, assist_diff=0, turnover_diff=0, block_diff=0, foul_diff=0, rebound_diff=0, shotOnGoal_diff=0, freeThrow_diff=0),
 Row(Date='December 25 2020', HomeTeam='MIA', AwayTeam='NOP', Team='MIA', Year=2020, Won=1, ScoreDiff=3, Quarter=1, SecLeftTotal=2863, LogSecLeftTotal=7.9599745280805365, SecLeftTotalInverse=0.00034916201117318437, HasPossession=1, assist_team_cnt=1, assist_opponent_cnt=0, turnover_team_cnt=0, turnover_oppon

In [43]:
df_adjusted.filter(df_adjusted.SecLeftTotal < 30) \
            .show(20, truncate=False)

In [14]:
path_write = f'{path_main}/clean_no_overtime_data'
df_adjusted.write.csv(f'{path_write}/clean_train_fully_processed.csv')

In [16]:
FIELDS: Dict[str, T] = {
      'Date': StringType
    , 'HomeTeam': StringType
    , 'AwayTeam': StringType
    , 'Team': StringType
    , 'Year': IntegerType
    , 'Won': IntegerType
    
    , 'ScoreDiff': IntegerType
    , 'Quarter': IntegerType
    , 'SecLeftTotal': IntegerType
    , 'LogSecLeftTotal': DoubleType
    , 'SecLeftTotalInverse': DoubleType
    
    , 'HasPossession': IntegerType
    , 'assist_team_cnt': LongType
    , 'assist_opponent_cnt': LongType
    , 'turnover_team_cnt': LongType
    , 'turnover_opponent_cnt': LongType
    , 'block_team_cnt': LongType
    , 'block_opponent_cnt': LongType
    
    , 'foul_team_cnt': LongType
    , 'foul_opponent_cnt': LongType
    , 'rebound_team_cnt': LongType
    , 'rebound_opponent_cnt': LongType
    , 'shotOnGoal_team_cnt': LongType
    , 'shotOnGoal_opponent_cnt': LongType
    , 'freeThrow_team_cnt': LongType
    , 'freeThrow_opponent_cnt': LongType
    
    , 'SecLeftTotalInverseTimesScoreDiff': DoubleType
    , 'assist_diff': IntegerType
    , 'turnover_diff': IntegerType
    , 'block_diff': IntegerType
    , 'foul_diff': IntegerType
    , 'rebound_diff': IntegerType
    , 'shotOnGoal_diff': IntegerType
    , 'freeThrow_diff': IntegerType
}

schema = StructType([StructField(k, v()) for k, v in FIELDS.items()])
    
df_train = spark.read \
    .format('csv') \
    .option('header', True) \
    .schema(schema) \
    .load(f'{path_write}/clean_train_fully_processed.csv')
#     .load(f'{path_main}/clean_train_data/*')

display(df_train.count())
display(df_train.printSchema())
display(df_train.head(2))

195086

root
 |-- Date: string (nullable = true)
 |-- HomeTeam: string (nullable = true)
 |-- AwayTeam: string (nullable = true)
 |-- Team: string (nullable = true)
 |-- Year: integer (nullable = true)
 |-- Won: integer (nullable = true)
 |-- ScoreDiff: integer (nullable = true)
 |-- Quarter: integer (nullable = true)
 |-- SecLeftTotal: integer (nullable = true)
 |-- LogSecLeftTotal: double (nullable = true)
 |-- SecLeftTotalInverse: double (nullable = true)
 |-- HasPossession: integer (nullable = true)
 |-- assist_team_cnt: long (nullable = true)
 |-- assist_opponent_cnt: long (nullable = true)
 |-- turnover_team_cnt: long (nullable = true)
 |-- turnover_opponent_cnt: long (nullable = true)
 |-- block_team_cnt: long (nullable = true)
 |-- block_opponent_cnt: long (nullable = true)
 |-- foul_team_cnt: long (nullable = true)
 |-- foul_opponent_cnt: long (nullable = true)
 |-- rebound_team_cnt: long (nullable = true)
 |-- rebound_opponent_cnt: long (nullable = true)
 |-- shotOnGoal_team_cnt: lon

None

[Row(Date='December 27 2020', HomeTeam='LAL', AwayTeam='MIN', Team='LAL', Year=2020, Won=1, ScoreDiff=0, Quarter=1, SecLeftTotal=2880, LogSecLeftTotal=7.9658927350845286, SecLeftTotalInverse=0.0003471017007983339, HasPossession=0, assist_team_cnt=0, assist_opponent_cnt=0, turnover_team_cnt=0, turnover_opponent_cnt=0, block_team_cnt=0, block_opponent_cnt=0, foul_team_cnt=0, foul_opponent_cnt=0, rebound_team_cnt=0, rebound_opponent_cnt=0, shotOnGoal_team_cnt=0, shotOnGoal_opponent_cnt=0, freeThrow_team_cnt=0, freeThrow_opponent_cnt=0, SecLeftTotalInverseTimesScoreDiff=0.0, assist_diff=0, turnover_diff=0, block_diff=0, foul_diff=0, rebound_diff=0, shotOnGoal_diff=0, freeThrow_diff=0),
 Row(Date='December 27 2020', HomeTeam='LAC', AwayTeam='DAL', Team='LAC', Year=2020, Won=0, ScoreDiff=2, Quarter=1, SecLeftTotal=2861, LogSecLeftTotal=7.959275960116396, SecLeftTotalInverse=0.00034940600978336826, HasPossession=1, assist_team_cnt=0, assist_opponent_cnt=0, turnover_team_cnt=0, turnover_oppone

In [None]:
cv_results = pd.DataFrame({}) # todo: schema

for method in ML_METHODS.keys():
    cv_results = cv_results.union(cross_validate(df_train, method, ['ScoreDiff', 'TimeElapsed', 'HasPossession']))
    
path_write = f'{path_main}/ml_models'
    
cv_results.to_csv(f'{path_write}/cv_results_{len(os.listdir(path_write))}.csv')

In [46]:
results = [ 
    (
        [
            {key.name: paramValue} 
            for key, paramValue 
            in zip(
                params.keys(),
                params.values()
            ) 
        ], metric
    )
    for params, metric in zip(
        cvModel.getEstimatorParamMaps(),
        cvModel.avgMetrics 
    )
] 

sorted(results,
        key=lambda el: el[1], 
       reverse=True)

[([{'maxIter': 2}, {'regParam': 0.05}], 0.8138699584461815),
 ([{'maxIter': 2}, {'regParam': 0.3}], 0.8138023496031435),
 ([{'maxIter': 2}, {'regParam': 0.01}], 0.8137753927038403),
 ([{'maxIter': 10}, {'regParam': 0.3}], 0.8136894482435892),
 ([{'maxIter': 50}, {'regParam': 0.3}], 0.813687641785139),
 ([{'maxIter': 50}, {'regParam': 0.05}], 0.8136648414061434),
 ([{'maxIter': 10}, {'regParam': 0.05}], 0.8136636961010186),
 ([{'maxIter': 50}, {'regParam': 0.01}], 0.8136527452998163),
 ([{'maxIter': 10}, {'regParam': 0.01}], 0.8136517440875349)]

In [17]:
def build_pipeline(listInputCols):
    from pyspark.ml import feature as ft
    from pyspark.ml import Pipeline

    # Build the Pipeline
    print('build the pipeline')

    featuresCreator = ft.VectorAssembler(
        inputCols=listInputCols,
        outputCol='vectors'
    )
    
    sScaler = ft.StandardScaler(
        withMean=True, 
        withStd=True, 
        inputCol='vectors', 
        outputCol='features'
    )

    pipeline = Pipeline(
        stages=[
            featuresCreator,
            sScaler
        ])

    return pipeline

def build_logistic_grid(logistic):
    import pyspark.ml.tuning as tune
    
    # Build the CV
    print('build the cv')

    grid = tune.ParamGridBuilder() \
                .addGrid(logistic.maxIter, [2, 10, 50]) \
                .addGrid(logistic.regParam, [0.01, 0.05, 0.3]) \
                .build()
    
    return grid

def build_cross_validator(model, evaluator, grid):
    import pyspark.ml.tuning as tune
    
    cv = tune.CrossValidator( 
        estimator=model, 
        estimatorParamMaps=grid, 
        evaluator=evaluator
    )
    
    return cv

def build_logistic_model_and_evaluator(target):
    import pyspark.ml.evaluation as ev
    from pyspark.ml.classification import GBTClassifier, LinearSVC, LogisticRegression, RandomForestClassifier
    
    logistic = LogisticRegression(labelCol=target)
    # pipeline = Pipeline(stages[featureCreater, logistic])

    evaluator = ev.BinaryClassificationEvaluator(rawPredictionCol='probability', labelCol=target)
    
    return logistic, evaluator

In [48]:
def evaluate_cv_model(pipeline, train_data, test_data, cv, listInputCols):
    
    # Fit the Model
    print('build data transformer')
    data_transformer = pipeline.fit(train_data)
    
    print('fit CV model')
    cvModel = cv.fit(data_transformer.transform(train_data))
    
    print('transform test data')
    data_train = data_transformer.transform(test_data)
    
    print('Evaluate model against test data')
    predictions = cvModel.transform(data_train)
    
    print(evaluator.evaluate(predictions, {evaluator.metricName: 'areaUnderROC'}))
    print(evaluator.evaluate(predictions, {evaluator.metricName: 'areaUnderPR'}))
    print(listInputCols)
    print(cvModel.bestModel.coefficients)
    
    return predictions
    
# predictions.filter(predictions.SecLeftTotal < 5).select(['HomeTeam','AwayTeam', 'Team', 'Won',  'ScoreDiff', 'SecLeftTotal', 'probability', 'prediction']).show(20, truncate=False)

In [50]:
# assist_diff: integer (nullable = true)
#  |-- turnover_diff: integer (nullable = true)
#  |-- block_diff: integer (nullable = true)
#  |-- foul_diff: integer (nullable = true)
#  |-- rebound_diff: integer (nullable = true)
#  |-- shotOnGoal_diff: integer (nullable = true)
#  |-- freeThrow_diff: integer (nullable = true)

train_data, test_data = df_adjusted.randomSplit([0.7, 0.3], seed=123) # LogSecLeftTotal

listInputCols = ['ScoreDiff', 'SecLeftTotalInverse', 'SecLeftTotalInverseTimesScoreDiff', 'shotOnGoal_diff', 'rebound_diff']

pipeline = build_pipeline(listInputCols=listInputCols)
model, evaluator = build_logistic_model_and_evaluator(target='Won')
cv = build_cross_validator(
    model=model, 
    evaluator=evaluator, 
    grid=build_logistic_grid(model)
)

build the pipeline
build the cv


In [None]:
predictions = evaluate_cv_model(
    pipeline=pipeline, 
    train_data=train_data, 
    test_data=test_data, 
    cv=cv,
    listInputCols=listInputCols
)

build data transformer
fit CV model


In [19]:
(predictions.filter((predictions.SecLeftTotal < 30) # 
#                      & (predictions.ScoreDiff < 5) 
#                      & (predictions.ScoreDiff > -5)
                   )
            .select(['HomeTeam','AwayTeam', 'Team', 'Won',  'ScoreDiff', 'SecLeftTotal', 'SecLeftTotalInverseTimesScoreDiff', 'SecLeftTotalInverse', 'probability', 'prediction']) \
            .show(40, truncate=False)
)

+--------+--------+----+---+---------+------------+---------------------------------+--------------------+------------------------------------------+----------+
|HomeTeam|AwayTeam|Team|Won|ScoreDiff|SecLeftTotal|SecLeftTotalInverseTimesScoreDiff|SecLeftTotalInverse |probability                               |prediction|
+--------+--------+----+---+---------+------------+---------------------------------+--------------------+------------------------------------------+----------+
|BRK     |GSW     |BRK |1  |29       |10          |2.6363636363636362               |0.09090909090909091 |[0.009942409216720627,0.9900575907832794] |1.0       |
|BOS     |MIL     |BOS |1  |-1       |8           |-0.1111111111111111              |0.1111111111111111  |[0.5410565183017384,0.4589434816982616]   |0.0       |
|BOS     |MIL     |BOS |1  |2        |0           |2.0                              |1.0                 |[0.25786382455710233,0.7421361754428977]  |1.0       |
|BOS     |MIL     |BOS |1  |2     