# Step 1: Imports

In [1]:
from typing import List, Tuple
import numpy as np
import pandas as pd

pd.set_option("display.max_columns", None)

def load_raw_data(filename: str) -> pd.DataFrame:
    mens_filepath = f"/kaggle/input/warmup-round-march-machine-learning-mania-2023/M{filename}.csv"
    weomens_filepath = f"/kaggle/input/warmup-round-march-machine-learning-mania-2023/W{filename}.csv"
    df_mens = pd.read_csv(mens_filepath)
    df_mens["Gender"] = 0
    df_weomens = pd.read_csv(weomens_filepath)
    df_weomens["Gender"] = 1
    return pd.concat([df_mens, df_weomens])

def process_detailed_results(detailed_results: pd.DataFrame) -> pd.DataFrame:
    df = detailed_results.copy()
    df = clean_detailed_results(df)
    df = game_to_team_conversion(df)
    df = enrich_team_results(df)
    df = transform_team_results(df)
    return df

def clean_detailed_results(df: pd.DataFrame) -> pd.DataFrame:
    return df.drop(["WLoc", "DayNum", "Gender"], axis=1)

def game_to_team_conversion(game_results: pd.DataFrame) -> pd.DataFrame:
    winners = rename_columns(game_results, "W")
    loosers = rename_columns(game_results, "L")
    team_results = pd.concat((winners, loosers))
    team_results.drop(["TeamIDOpp"], axis=1, inplace=True)
    return team_results

def enrich_team_results(df: pd.DataFrame) -> pd.DataFrame:
    df["Dur"] = 40 + 5 * df["NumOT"]
    df.drop(["NumOT"], axis=1, inplace=True)
    return df

def transform_team_results(df: pd.DataFrame) -> pd.DataFrame:
    df = df.groupby(["Season", "TeamID"]).mean()
    df["FGP"] =  df["FGM"] / df["FGA"]
    df["FGP3"] =  df["FGM3"] / df["FGA3"]
    df["FTP"] =  df["FTM"] / df["FTA"]
    return df.reset_index()
    
def rename_columns(df: pd.DataFrame, team_prefix: str) -> pd.DataFrame:
    df = df.copy()
    df.columns =  (rename_column(column_name, team_prefix) for column_name in df.columns)
    return df

def rename_column(column_name: str, team_prefix: str) -> pd.DataFrame:
    if team_prefix == "W":
        opponent_prefix = "L"
    elif team_prefix == "L":
        opponent_prefix = "W"
    else:
        raise ValueError
    if column_name.startswith(team_prefix):
        column_name = column_name.lstrip(team_prefix)
    elif column_name.startswith(opponent_prefix):
        column_name = f"{column_name.lstrip(opponent_prefix)}Opp"
    return column_name

def split_winner_and_looser_columns(df: pd.DataFrame) -> Tuple[List[str], List[str]]:
    winner_columns = [name for name in df.columns if not name.startswith("L")]
    looser_columns = [name for name in df.columns if not name.startswith("W")]
    return winner_columns, looser_columns

def clean_column_names(df: pd.DataFrame) -> List[str]:
    column_names = [
        name[1:] if 
        name.startswith("L") or name.startswith("W")
        else name 
        for name in df.columns
    ]
    return column_names

def process_seeds(df_in: pd.DataFrame) -> pd.DataFrame:
    df = df_in.copy()
    mask = df["Season"] > 2002
    df = df[mask]
    df["Seed"] = df["Seed"].str.replace(r"\D+","", regex=True)
    df["Seed"] = df["Seed"].astype(int)
    return df



def process_rankings(df_in:pd.DataFrame) -> pd.DataFrame:
    df = df_in.copy()
    mask = df["RankingDayNum"] == df["RankingDayNum"].max()
    df = df[mask]
    df.drop(["SystemName", "RankingDayNum"], axis=1, inplace=True)
    df = df.groupby(["Season", "TeamID"]).agg("median")
    return df.reset_index()


def merge_features(
    season_features: pd.DataFrame, 
    tournament_features: pd.DataFrame, 
    seed_features: pd.DataFrame, 
    ranking_features: pd.DataFrame
) -> pd.DataFrame:
    features = pd.merge(
        season_features,
        tournament_features,
        how="inner",
        on=["Season", "TeamID"],
        suffixes=("Reg", "Tou")
    )
    features = features.merge(
        seed_features,
        how="inner",
        on=["Season", "TeamID"]
    )
    features = features.merge(
        ranking_features,
        how="left",
        on=["Season", "TeamID"],
    )
    return features

def get_outcomes(df):
    input_rows = df.to_records()
    output_rows = []
    for input_row in input_rows:
        output_rows.extend(parse_row(input_row))
    out_df = pd.DataFrame(output_rows)
    return out_df

def parse_row(row):
    season = row['Season']
    winning_team_id = row['WTeamID']
    losing_team_id = row['LTeamID']
    if winning_team_id < losing_team_id:
        small_id = winning_team_id
        big_id = losing_team_id
        outcome = True
    elif losing_team_id < winning_team_id:
        small_id = losing_team_id
        big_id = winning_team_id
        outcome = False
    records = [
        {
            "ID": f"{season}_{small_id}_{big_id}",
            'Season': season,
            'LowID': small_id,
            'HighID': big_id,
            'Win': outcome
        },
        {
            "ID": f"{season}_{big_id}_{small_id}",
            'Season': season,
            'LowID': big_id,
            'HighID': small_id,
            'Win': not outcome
        },
    ]
    return records

def merge_outcomes_with_features(outcomes: pd.DataFrame, features: pd.DataFrame, how: str = "inner") -> pd.DataFrame:
    data = pd.merge(
        outcomes, 
        features, 
        how=how, 
        left_on=["Season", "HighID"], 
        right_on=["Season", "TeamID"]
    )
    data = pd.merge(
        data, 
        features, 
        how=how, 
        left_on=["Season", "LowID"], 
        right_on=["Season", "TeamID"],
        suffixes=("High", "Low")
    )
    data.drop(
        ["Season", "HighID", "LowID","TeamIDHigh","TeamIDLow"], 
        axis=1, 
        inplace=True
    )
    data.set_index("ID", inplace=True)
    return data

# Step 2: Load the data

## Season Detailed Results

In [2]:
RegularSeasonDetailedResults = load_raw_data("RegularSeasonDetailedResults")
RegularSeasonDetailedResults.tail()

Unnamed: 0,Season,DayNum,WTeamID,WScore,LTeamID,LScore,WLoc,NumOT,WFGM,WFGA,WFGM3,WFGA3,WFTM,WFTA,WOR,WDR,WAst,WTO,WStl,WBlk,WPF,LFGM,LFGA,LFGM3,LFGA3,LFTM,LFTA,LOR,LDR,LAst,LTO,LStl,LBlk,LPF,Gender
70673,2023,127,3415,63,3142,54,N,0,20,56,2,10,21,30,11,23,9,14,5,2,22,17,43,4,13,16,21,1,21,5,15,7,2,24,1
70674,2023,127,3424,71,3361,68,H,0,23,55,2,12,23,33,10,23,12,14,7,6,16,28,58,5,10,7,10,7,24,12,13,4,1,26,1
70675,2023,127,3455,65,3378,53,A,0,24,51,6,13,11,13,2,28,14,11,7,2,12,19,60,4,22,11,13,8,22,11,12,2,6,12,1
70676,2023,127,3461,65,3161,56,H,0,25,57,5,17,10,16,13,35,15,14,3,2,13,21,55,9,24,5,8,1,20,12,8,8,4,17,1
70677,2023,127,3477,65,3230,62,A,0,23,50,3,13,16,19,12,20,10,9,6,0,11,22,51,8,22,10,12,8,15,12,9,5,1,13,1


In [3]:
season_features = process_detailed_results(RegularSeasonDetailedResults)
season_features.tail()

Unnamed: 0,Season,TeamID,Score,ScoreOpp,FGM,FGA,FGM3,FGA3,FTM,FTA,OR,DR,Ast,TO,Stl,Blk,PF,FGMOpp,FGAOpp,FGM3Opp,FGA3Opp,FTMOpp,FTAOpp,OROpp,DROpp,AstOpp,TOOpp,StlOpp,BlkOpp,PFOpp,Dur,FGP,FGP3,FTP
12130,2023,3473,54.64,72.04,19.76,50.32,6.4,21.0,8.72,13.16,6.36,19.24,11.88,19.28,6.32,1.92,16.04,27.68,59.68,4.84,16.84,11.84,16.0,10.64,22.04,14.12,13.72,10.8,2.92,16.4,40.0,0.392687,0.304762,0.662614
12131,2023,3474,57.115385,71.730769,20.038462,58.076923,4.423077,16.692308,12.615385,18.538462,8.461538,20.538462,6.538462,15.153846,7.230769,2.346154,21.192308,25.115385,54.961538,5.0,15.346154,16.5,23.423077,9.423077,27.0,14.115385,16.038462,6.576923,3.346154,17.769231,40.0,0.345033,0.264977,0.680498
12132,2023,3475,62.384615,65.846154,21.807692,55.230769,4.423077,15.538462,14.346154,18.692308,9.0,23.769231,13.0,19.923077,7.538462,2.807692,18.538462,22.230769,56.076923,6.461538,20.807692,14.923077,21.346154,9.269231,21.653846,12.307692,17.307692,11.153846,3.192308,19.576923,40.192308,0.394847,0.284653,0.76749
12133,2023,3476,59.178571,65.428571,21.571429,55.821429,6.178571,20.25,9.857143,13.607143,8.678571,22.571429,14.0,14.535714,3.892857,2.714286,14.964286,25.178571,58.428571,5.0,15.821429,10.071429,14.464286,8.785714,22.964286,11.785714,11.214286,7.892857,3.214286,16.142857,40.178571,0.386436,0.305115,0.724409
12134,2023,3477,64.83871,69.935484,23.064516,58.677419,3.677419,14.451613,15.032258,22.258065,11.483871,23.709677,11.83871,16.612903,8.064516,3.225806,15.580645,26.548387,61.354839,6.580645,19.83871,10.258065,14.645161,9.483871,23.451613,15.580645,15.870968,8.903226,4.16129,18.677419,40.0,0.393073,0.254464,0.675362


## Tournament Detailed Results

In [4]:
NCAATourneyDetailedResults = load_raw_data("NCAATourneyDetailedResults")
NCAATourneyDetailedResults.tail()

Unnamed: 0,Season,DayNum,WTeamID,WScore,LTeamID,LScore,WLoc,NumOT,WFGM,WFGA,WFGM3,WFGA3,WFTM,WFTA,WOR,WDR,WAst,WTO,WStl,WBlk,WPF,LFGM,LFGA,LFGM3,LFGA3,LFTM,LFTA,LOR,LDR,LAst,LTO,LStl,LBlk,LPF,Gender
755,2022,147,3163,91,3301,87,N,2,37,77,5,21,12,20,12,23,10,7,5,2,16,32,66,7,23,16,19,6,30,20,13,4,7,16,1
756,2022,147,3257,62,3276,50,N,0,25,58,5,15,7,9,6,20,12,11,15,4,17,16,46,3,14,15,20,10,24,9,21,6,2,12,1
757,2022,151,3163,63,3390,58,N,0,21,57,5,14,16,20,12,30,14,19,5,2,16,23,66,4,23,8,13,11,23,10,11,11,3,16,1
758,2022,151,3376,72,3257,59,N,0,27,57,6,17,12,17,8,24,19,14,11,4,11,27,63,1,8,4,7,11,18,5,15,13,2,17,1
759,2022,153,3376,64,3163,49,N,0,22,60,3,16,17,26,18,23,9,14,6,4,11,22,54,4,16,1,4,3,16,14,14,4,5,21,1


In [5]:
tournament_features = process_detailed_results(NCAATourneyDetailedResults)
tournament_features["Season"] += 1
tournament_features.tail()

Unnamed: 0,Season,TeamID,Score,ScoreOpp,FGM,FGA,FGM3,FGA3,FTM,FTA,OR,DR,Ast,TO,Stl,Blk,PF,FGMOpp,FGAOpp,FGM3Opp,FGA3Opp,FTMOpp,FTAOpp,OROpp,DROpp,AstOpp,TOOpp,StlOpp,BlkOpp,PFOpp,Dur,FGP,FGP3,FTP
2034,2023,3426,71.0,78.0,27.0,62.0,7.0,16.0,10.0,16.0,7.0,26.0,11.0,9.0,6.0,5.0,20.0,28.0,65.0,7.0,20.0,15.0,26.0,11.0,27.0,6.0,7.0,2.0,6.0,14.0,40.0,0.435484,0.4375,0.625
2035,2023,3428,74.0,73.5,27.5,53.5,10.5,25.5,8.5,12.0,6.0,25.5,16.5,17.5,3.0,1.5,14.5,31.5,67.0,3.5,14.5,7.0,12.5,9.5,18.5,10.5,8.5,9.0,4.5,12.0,40.0,0.514019,0.411765,0.708333
2036,2023,3437,55.0,60.5,20.5,58.5,6.0,23.5,8.0,12.0,7.5,19.5,11.0,10.0,8.5,5.0,18.0,22.5,55.5,4.0,15.0,11.5,16.5,11.5,29.5,12.5,15.0,7.0,1.0,12.5,40.0,0.350427,0.255319,0.666667
2037,2023,3439,81.0,84.0,30.0,60.0,6.0,17.0,15.0,20.0,7.0,28.0,13.0,10.0,1.0,1.0,13.0,30.0,64.0,15.0,38.0,9.0,11.0,2.0,25.0,19.0,3.0,7.0,1.0,18.0,40.0,0.5,0.352941,0.75
2038,2023,3450,40.0,50.0,14.0,56.0,3.0,22.0,9.0,14.0,8.0,31.0,5.0,17.0,2.0,2.0,23.0,14.0,53.0,3.0,14.0,19.0,22.0,6.0,30.0,7.0,11.0,6.0,3.0,11.0,40.0,0.25,0.136364,0.642857


## Tournement Seeds

In [6]:
NCAATourneySeeds = load_raw_data("NCAATourneySeeds")
NCAATourneySeeds.tail()

Unnamed: 0,Season,Seed,TeamID,Gender
1535,2022,Z12,3125,1
1536,2022,Z13,3138,1
1537,2022,Z14,3110,1
1538,2022,Z15,3218,1
1539,2022,Z16,3107,1


In [7]:
seed_features = process_seeds(NCAATourneySeeds)
seed_features.tail()

Unnamed: 0,Season,Seed,TeamID,Gender
1535,2022,12,3125,1
1536,2022,13,3138,1
1537,2022,14,3110,1
1538,2022,15,3218,1
1539,2022,16,3107,1


## Team Rankings

In [8]:
MMasseyOrdinals = pd.read_csv("/kaggle/input/warmup-round-march-machine-learning-mania-2023/MMasseyOrdinals.csv")
MMasseyOrdinals.tail()

Unnamed: 0,Season,RankingDayNum,SystemName,TeamID,OrdinalRank
4922144,2023,128,WOL,1473,332
4922145,2023,128,WOL,1474,166
4922146,2023,128,WOL,1475,260
4922147,2023,128,WOL,1476,301
4922148,2023,128,WOL,1477,303


In [9]:
ranking_features = process_rankings(MMasseyOrdinals)
ranking_features.tail()

Unnamed: 0,Season,TeamID,OrdinalRank
6533,2022,1468,188.0
6534,2022,1469,272.0
6535,2022,1470,225.0
6536,2022,1471,263.0
6537,2022,1472,312.5


## Merge features

In [10]:
features = merge_features(season_features, tournament_features, seed_features, ranking_features)
features.head()

Unnamed: 0,Season,TeamID,ScoreReg,ScoreOppReg,FGMReg,FGAReg,FGM3Reg,FGA3Reg,FTMReg,FTAReg,ORReg,DRReg,AstReg,TOReg,StlReg,BlkReg,PFReg,FGMOppReg,FGAOppReg,FGM3OppReg,FGA3OppReg,FTMOppReg,FTAOppReg,OROppReg,DROppReg,AstOppReg,TOOppReg,StlOppReg,BlkOppReg,PFOppReg,DurReg,FGPReg,FGP3Reg,FTPReg,ScoreTou,ScoreOppTou,FGMTou,FGATou,FGM3Tou,FGA3Tou,FTMTou,FTATou,ORTou,DRTou,AstTou,TOTou,StlTou,BlkTou,PFTou,FGMOppTou,FGAOppTou,FGM3OppTou,FGA3OppTou,FTMOppTou,FTAOppTou,OROppTou,DROppTou,AstOppTou,TOOppTou,StlOppTou,BlkOppTou,PFOppTou,DurTou,FGPTou,FGP3Tou,FTPTou,Seed,Gender,OrdinalRank
0,2004,1104,72.206897,67.448276,24.896552,55.0,7.137931,18.62069,15.275862,21.862069,11.310345,24.0,11.862069,13.275862,6.448276,3.103448,18.344828,23.275862,55.931034,7.275862,20.862069,13.62069,19.482759,11.517241,22.103448,11.344828,13.448276,5.0,3.413793,20.172414,40.689655,0.452665,0.383333,0.698738,62.0,67.0,22.0,52.0,5.0,12.0,13.0,16.0,9.0,20.0,13.0,8.0,2.0,6.0,21.0,19.0,49.0,7.0,18.0,22.0,26.0,13.0,22.0,15.0,8.0,1.0,2.0,17.0,40.0,0.423077,0.416667,0.8125,8,0,33.0
1,2004,1112,87.517241,78.413793,31.517241,65.0,7.37931,19.206897,17.103448,21.758621,13.896552,25.689655,18.896552,14.965517,6.965517,4.551724,16.068966,28.896552,65.62069,8.103448,22.37931,12.517241,17.758621,13.586207,21.655172,17.448276,15.62069,7.241379,2.137931,19.448276,40.0,0.484881,0.384201,0.786054,84.75,73.75,31.0,67.75,7.75,20.75,15.0,19.25,13.5,30.25,18.75,13.5,9.25,4.5,15.5,27.0,66.75,8.25,23.75,11.5,16.0,12.0,25.5,15.75,15.5,8.5,4.25,19.0,42.5,0.457565,0.373494,0.779221,9,0,35.0
2,2004,1140,72.185185,65.37037,25.185185,51.740741,5.37037,15.555556,16.444444,22.555556,9.962963,23.444444,14.666667,14.0,7.185185,3.259259,19.111111,23.555556,52.222222,5.740741,16.777778,12.518519,18.259259,8.740741,19.962963,12.518519,14.666667,6.740741,2.111111,19.740741,40.185185,0.486757,0.345238,0.729064,53.0,58.0,20.0,64.0,2.0,17.0,11.0,13.0,15.0,26.0,11.0,11.0,8.0,4.0,22.0,17.0,52.0,4.0,14.0,20.0,27.0,12.0,29.0,8.0,14.0,3.0,8.0,16.0,40.0,0.3125,0.117647,0.846154,12,0,34.0
3,2004,1153,76.9,62.9,27.033333,59.1,7.233333,18.8,15.6,23.366667,13.133333,24.133333,16.033333,12.1,7.9,5.066667,19.433333,20.333333,52.7,6.133333,18.833333,16.1,22.666667,11.1,22.5,10.433333,17.066667,5.433333,1.966667,20.333333,40.166667,0.457417,0.384752,0.667618,69.0,74.0,26.0,66.0,10.0,27.0,7.0,10.0,13.0,22.0,13.0,10.0,7.0,6.0,24.0,20.0,47.0,6.0,14.0,28.0,37.0,8.0,28.0,12.0,12.0,2.0,2.0,15.0,40.0,0.393939,0.37037,0.7,4,0,12.0
4,2004,1163,79.090909,63.939394,30.121212,62.636364,6.393939,15.969697,12.454545,20.212121,15.30303,29.484848,18.242424,13.787879,6.151515,8.727273,16.060606,23.757576,64.363636,5.575758,17.242424,10.848485,16.121212,13.727273,21.575758,12.121212,12.424242,7.030303,3.575758,18.363636,40.151515,0.48089,0.40038,0.616192,73.666667,69.666667,26.0,63.666667,4.666667,12.333333,17.0,24.333333,14.333333,27.333333,11.0,11.333333,4.666667,5.666667,16.666667,23.666667,65.0,5.0,19.333333,17.333333,22.0,15.666667,26.333333,12.0,10.666667,5.0,5.0,20.333333,40.0,0.408377,0.378378,0.69863,2,0,8.0


## Build Dataset

In [11]:
from sklearn.model_selection import train_test_split

data = load_raw_data("NCAATourneyCompactResults")
data_train, data_valid = train_test_split(data, random_state=0)

outcomes_train = get_outcomes(data_train)
outcomes_valid = get_outcomes(data_valid)

In [12]:
features_train = merge_outcomes_with_features(outcomes_train, features)
features_valid = merge_outcomes_with_features(outcomes_valid, features)
print(features_train.shape)
features_train.tail()

(1100, 135)


Unnamed: 0_level_0,Win,ScoreRegHigh,ScoreOppRegHigh,FGMRegHigh,FGARegHigh,FGM3RegHigh,FGA3RegHigh,FTMRegHigh,FTARegHigh,ORRegHigh,DRRegHigh,AstRegHigh,TORegHigh,StlRegHigh,BlkRegHigh,PFRegHigh,FGMOppRegHigh,FGAOppRegHigh,FGM3OppRegHigh,FGA3OppRegHigh,FTMOppRegHigh,FTAOppRegHigh,OROppRegHigh,DROppRegHigh,AstOppRegHigh,TOOppRegHigh,StlOppRegHigh,BlkOppRegHigh,PFOppRegHigh,DurRegHigh,FGPRegHigh,FGP3RegHigh,FTPRegHigh,ScoreTouHigh,ScoreOppTouHigh,FGMTouHigh,FGATouHigh,FGM3TouHigh,FGA3TouHigh,FTMTouHigh,FTATouHigh,ORTouHigh,DRTouHigh,AstTouHigh,TOTouHigh,StlTouHigh,BlkTouHigh,PFTouHigh,FGMOppTouHigh,FGAOppTouHigh,FGM3OppTouHigh,FGA3OppTouHigh,FTMOppTouHigh,FTAOppTouHigh,OROppTouHigh,DROppTouHigh,AstOppTouHigh,TOOppTouHigh,StlOppTouHigh,BlkOppTouHigh,PFOppTouHigh,DurTouHigh,FGPTouHigh,FGP3TouHigh,FTPTouHigh,SeedHigh,GenderHigh,OrdinalRankHigh,ScoreRegLow,ScoreOppRegLow,FGMRegLow,FGARegLow,FGM3RegLow,FGA3RegLow,FTMRegLow,FTARegLow,ORRegLow,DRRegLow,AstRegLow,TORegLow,StlRegLow,BlkRegLow,PFRegLow,FGMOppRegLow,FGAOppRegLow,FGM3OppRegLow,FGA3OppRegLow,FTMOppRegLow,FTAOppRegLow,OROppRegLow,DROppRegLow,AstOppRegLow,TOOppRegLow,StlOppRegLow,BlkOppRegLow,PFOppRegLow,DurRegLow,FGPRegLow,FGP3RegLow,FTPRegLow,ScoreTouLow,ScoreOppTouLow,FGMTouLow,FGATouLow,FGM3TouLow,FGA3TouLow,FTMTouLow,FTATouLow,ORTouLow,DRTouLow,AstTouLow,TOTouLow,StlTouLow,BlkTouLow,PFTouLow,FGMOppTouLow,FGAOppTouLow,FGM3OppTouLow,FGA3OppTouLow,FTMOppTouLow,FTAOppTouLow,OROppTouLow,DROppTouLow,AstOppTouLow,TOOppTouLow,StlOppTouLow,BlkOppTouLow,PFOppTouLow,DurTouLow,FGPTouLow,FGP3TouLow,FTPTouLow,SeedLow,GenderLow,OrdinalRankLow
ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1,Unnamed: 62_level_1,Unnamed: 63_level_1,Unnamed: 64_level_1,Unnamed: 65_level_1,Unnamed: 66_level_1,Unnamed: 67_level_1,Unnamed: 68_level_1,Unnamed: 69_level_1,Unnamed: 70_level_1,Unnamed: 71_level_1,Unnamed: 72_level_1,Unnamed: 73_level_1,Unnamed: 74_level_1,Unnamed: 75_level_1,Unnamed: 76_level_1,Unnamed: 77_level_1,Unnamed: 78_level_1,Unnamed: 79_level_1,Unnamed: 80_level_1,Unnamed: 81_level_1,Unnamed: 82_level_1,Unnamed: 83_level_1,Unnamed: 84_level_1,Unnamed: 85_level_1,Unnamed: 86_level_1,Unnamed: 87_level_1,Unnamed: 88_level_1,Unnamed: 89_level_1,Unnamed: 90_level_1,Unnamed: 91_level_1,Unnamed: 92_level_1,Unnamed: 93_level_1,Unnamed: 94_level_1,Unnamed: 95_level_1,Unnamed: 96_level_1,Unnamed: 97_level_1,Unnamed: 98_level_1,Unnamed: 99_level_1,Unnamed: 100_level_1,Unnamed: 101_level_1,Unnamed: 102_level_1,Unnamed: 103_level_1,Unnamed: 104_level_1,Unnamed: 105_level_1,Unnamed: 106_level_1,Unnamed: 107_level_1,Unnamed: 108_level_1,Unnamed: 109_level_1,Unnamed: 110_level_1,Unnamed: 111_level_1,Unnamed: 112_level_1,Unnamed: 113_level_1,Unnamed: 114_level_1,Unnamed: 115_level_1,Unnamed: 116_level_1,Unnamed: 117_level_1,Unnamed: 118_level_1,Unnamed: 119_level_1,Unnamed: 120_level_1,Unnamed: 121_level_1,Unnamed: 122_level_1,Unnamed: 123_level_1,Unnamed: 124_level_1,Unnamed: 125_level_1,Unnamed: 126_level_1,Unnamed: 127_level_1,Unnamed: 128_level_1,Unnamed: 129_level_1,Unnamed: 130_level_1,Unnamed: 131_level_1,Unnamed: 132_level_1,Unnamed: 133_level_1,Unnamed: 134_level_1,Unnamed: 135_level_1
2012_1462_1323,True,66.484848,61.515152,22.909091,53.060606,6.424242,19.515152,14.242424,20.121212,10.272727,23.757576,14.484848,10.030303,5.181818,3.424242,13.757576,23.333333,56.636364,4.939394,15.909091,9.909091,14.0,11.242424,23.212121,12.424242,10.818182,5.454545,2.939394,17.969697,40.606061,0.431753,0.329193,0.707831,63.0,63.5,20.0,54.5,7.0,23.5,16.0,22.5,10.5,29.0,14.0,11.0,4.5,5.0,18.0,22.5,56.0,8.0,20.0,10.5,15.5,8.5,27.5,12.5,10.0,4.0,6.5,19.5,40.0,0.366972,0.297872,0.711111,7,0,37.0,70.787879,67.454545,25.121212,55.484848,4.878788,13.969697,15.666667,22.636364,10.848485,25.575758,13.393939,12.727273,6.090909,3.757576,18.848485,22.818182,57.363636,6.272727,20.484848,15.545455,20.878788,11.090909,23.333333,13.272727,12.727273,6.454545,3.545455,19.363636,40.454545,0.452758,0.349241,0.692102,55.0,66.0,21.0,51.0,2.0,13.0,11.0,15.0,11.0,19.0,11.0,15.0,3.0,2.0,19.0,24.0,45.0,5.0,12.0,13.0,19.0,5.0,21.0,13.0,15.0,6.0,2.0,18.0,40.0,0.411765,0.153846,0.733333,10,0,57.0
2015_1455_1242,True,71.205882,64.676471,24.176471,54.970588,5.823529,15.529412,17.029412,23.647059,11.823529,26.117647,13.264706,12.794118,6.529412,5.029412,17.617647,22.941176,58.147059,5.735294,18.558824,13.058824,19.529412,12.205882,22.088235,11.205882,11.705882,6.352941,5.117647,19.411765,40.147059,0.439807,0.375,0.720149,68.5,64.5,25.5,55.5,2.5,11.5,15.0,21.0,13.5,25.5,12.0,14.0,5.5,4.5,18.0,21.5,52.0,6.0,20.0,15.5,21.0,8.0,20.0,10.0,11.5,9.0,2.5,18.0,40.0,0.459459,0.217391,0.714286,2,0,9.0,68.580645,55.774194,24.16129,54.483871,6.903226,19.225806,13.354839,19.322581,11.612903,23.258065,13.645161,9.354839,7.064516,3.774194,16.612903,19.354839,48.322581,4.935484,14.354839,12.129032,17.806452,8.612903,21.806452,8.645161,13.290323,3.967742,3.483871,17.677419,40.322581,0.443458,0.35906,0.691152,70.0,57.5,25.0,49.5,8.0,20.0,12.0,19.5,7.5,27.0,15.5,10.5,5.0,4.5,19.0,19.5,54.0,6.5,23.0,12.0,20.0,11.5,21.5,8.0,8.5,6.5,2.0,16.0,40.0,0.505051,0.4,0.615385,7,0,17.0
2015_1308_1242,False,71.205882,64.676471,24.176471,54.970588,5.823529,15.529412,17.029412,23.647059,11.823529,26.117647,13.264706,12.794118,6.529412,5.029412,17.617647,22.941176,58.147059,5.735294,18.558824,13.058824,19.529412,12.205882,22.088235,11.205882,11.705882,6.352941,5.117647,19.411765,40.147059,0.439807,0.375,0.720149,68.5,64.5,25.5,55.5,2.5,11.5,15.0,21.0,13.5,25.5,12.0,14.0,5.5,4.5,18.0,21.5,52.0,6.0,20.0,15.5,21.0,8.0,20.0,10.0,11.5,9.0,2.5,18.0,40.0,0.459459,0.217391,0.714286,2,0,9.0,67.967742,59.741935,23.580645,50.83871,4.677419,12.83871,16.129032,22.806452,12.322581,22.709677,11.774194,13.870968,6.322581,3.645161,16.129032,22.16129,52.290323,3.612903,12.290323,11.806452,17.0,10.064516,18.741935,11.032258,13.0,5.387097,2.193548,19.193548,40.322581,0.463832,0.364322,0.707214,69.0,73.0,26.0,65.0,5.0,16.0,12.0,20.0,14.0,27.0,13.0,13.0,4.0,4.0,22.0,23.0,59.0,6.0,17.0,21.0,27.0,10.0,30.0,9.0,11.0,8.0,9.0,18.0,45.0,0.4,0.3125,0.6,15,0,107.5
2015_1326_1433,True,72.485714,65.514286,24.914286,59.314286,8.085714,23.657143,14.571429,22.228571,12.314286,23.114286,12.628571,10.657143,9.657143,4.314286,19.228571,22.685714,52.428571,6.371429,18.542857,13.771429,19.828571,10.257143,26.514286,12.885714,16.171429,5.857143,3.342857,18.714286,40.857143,0.420039,0.341787,0.655527,75.0,77.0,30.0,65.0,8.0,21.0,7.0,11.0,11.0,18.0,13.0,11.0,11.0,5.0,21.0,27.0,51.0,6.0,16.0,17.0,23.0,10.0,25.0,16.0,17.0,6.0,2.0,16.0,45.0,0.461538,0.380952,0.636364,7,0,29.0,75.818182,62.393939,27.757576,57.151515,6.818182,18.333333,13.484848,19.878788,11.272727,24.575758,15.424242,11.30303,7.787879,5.090909,16.030303,22.545455,55.575758,6.69697,21.030303,10.606061,15.212121,11.212121,21.757576,12.909091,14.666667,5.030303,2.848485,17.969697,40.151515,0.485684,0.371901,0.678354,59.0,60.0,24.0,50.0,3.0,12.0,8.0,12.0,3.0,25.0,12.0,14.0,10.0,4.0,16.0,22.0,49.0,3.0,13.0,13.0,17.0,4.0,24.0,12.0,13.0,10.0,0.0,14.0,40.0,0.48,0.25,0.666667,10,0,25.0
2015_1433_1326,False,75.818182,62.393939,27.757576,57.151515,6.818182,18.333333,13.484848,19.878788,11.272727,24.575758,15.424242,11.30303,7.787879,5.090909,16.030303,22.545455,55.575758,6.69697,21.030303,10.606061,15.212121,11.212121,21.757576,12.909091,14.666667,5.030303,2.848485,17.969697,40.151515,0.485684,0.371901,0.678354,59.0,60.0,24.0,50.0,3.0,12.0,8.0,12.0,3.0,25.0,12.0,14.0,10.0,4.0,16.0,22.0,49.0,3.0,13.0,13.0,17.0,4.0,24.0,12.0,13.0,10.0,0.0,14.0,40.0,0.48,0.25,0.666667,10,0,25.0,72.485714,65.514286,24.914286,59.314286,8.085714,23.657143,14.571429,22.228571,12.314286,23.114286,12.628571,10.657143,9.657143,4.314286,19.228571,22.685714,52.428571,6.371429,18.542857,13.771429,19.828571,10.257143,26.514286,12.885714,16.171429,5.857143,3.342857,18.714286,40.857143,0.420039,0.341787,0.655527,75.0,77.0,30.0,65.0,8.0,21.0,7.0,11.0,11.0,18.0,13.0,11.0,11.0,5.0,21.0,27.0,51.0,6.0,16.0,17.0,23.0,10.0,25.0,16.0,17.0,6.0,2.0,16.0,45.0,0.461538,0.380952,0.636364,7,0,29.0


In [13]:
y_train = features_train["Win"]
X_train = features_train.drop("Win", axis=1)
y_valid = features_valid["Win"]
X_valid = features_valid.drop("Win", axis=1)

# Step 4: Train a model


### Setup Hyperparameter Tuning
See https://www.kaggle.com/prashant111/a-guide-on-xgboost-hyperparameters-tuning

In [14]:
from sklearn.metrics import brier_score_loss, roc_auc_score, confusion_matrix
import lightgbm

model = lightgbm.LGBMClassifier(objective="binary", min_data_in_leaf=400)
model.fit(X_train, y_train)

preds_train = model.predict(X_train)
preds_valid = model.predict(X_valid)

print("Training Score:", model.score(X_train, y_train))
print("Validation Score:", model.score(X_valid, y_valid))
print("Brier Score Validation:", brier_score_loss(y_valid, preds_valid))
print("ROC AUC Validation:", roc_auc_score(y_valid, preds_valid))
print("Parameters")
print(*(f"- {key}: {value}" for key, value in model.get_params(deep=True).items()), sep="\n")
print("Features")
print(*(f"- {name}: {imp}" for name, imp in sorted(zip(model.feature_name_, model.feature_importances_), key=lambda x: x[1], reverse=True)), sep="\n")


Training Score: 0.7472727272727273
Validation Score: 0.7021857923497268
Brier Score Validation: 0.2978142076502732
ROC AUC Validation: 0.7021857923497269
Parameters
- boosting_type: gbdt
- class_weight: None
- colsample_bytree: 1.0
- importance_type: split
- learning_rate: 0.1
- max_depth: -1
- min_child_samples: 20
- min_child_weight: 0.001
- min_split_gain: 0.0
- n_estimators: 100
- n_jobs: -1
- num_leaves: 31
- objective: binary
- random_state: None
- reg_alpha: 0.0
- reg_lambda: 0.0
- silent: warn
- subsample: 1.0
- subsample_for_bin: 200000
- subsample_freq: 0
- min_data_in_leaf: 400
Features
- SeedHigh: 15
- SeedLow: 14
- AstTouHigh: 7
- AstTouLow: 7
- FGMRegHigh: 5
- StlOppRegHigh: 5
- FGMRegLow: 5
- StlOppRegLow: 5
- ORRegHigh: 4
- AstOppRegHigh: 3
- ORRegLow: 3
- AstOppRegLow: 3
- AstRegHigh: 2
- DurRegHigh: 2
- DROppTouHigh: 2
- ScoreRegLow: 2
- AstRegLow: 2
- DurRegLow: 2
- DROppTouLow: 2
- ScoreRegHigh: 1
- DROppRegHigh: 1
- FTPRegHigh: 1
- TOTouHigh: 1
- FTPTouHigh: 1
- FG

In [15]:
outcomes = get_outcomes(data)
features_full = merge_outcomes_with_features(outcomes, features)
y = features_full["Win"]
X = features_full.drop("Win", axis=1)
model.fit(X, y)
preds = model.predict(X)
print("Training Score:", model.score(X, y))
print("Brier Score:", brier_score_loss(y, preds))
print("ROC AUC:", roc_auc_score(y, preds))
print("Parameters")
print(*(f"- {key}: {value}" for key, value in model.get_params(deep=True).items()), sep="\n")
print("Features")
print(*(f"- {name}: {imp}" for name, imp in sorted(zip(model.feature_name_, model.feature_importances_), key=lambda x: x[1], reverse=True)), sep="\n")


Training Score: 0.7639836289222374
Brier Score: 0.23601637107776263
ROC AUC: 0.7639836289222374
Parameters
- boosting_type: gbdt
- class_weight: None
- colsample_bytree: 1.0
- importance_type: split
- learning_rate: 0.1
- max_depth: -1
- min_child_samples: 20
- min_child_weight: 0.001
- min_split_gain: 0.0
- n_estimators: 100
- n_jobs: -1
- num_leaves: 31
- objective: binary
- random_state: None
- reg_alpha: 0.0
- reg_lambda: 0.0
- silent: warn
- subsample: 1.0
- subsample_for_bin: 200000
- subsample_freq: 0
- min_data_in_leaf: 400
Features
- SeedHigh: 19
- SeedLow: 19
- FGMRegLow: 10
- FGMOppRegLow: 10
- FGMRegHigh: 9
- FGMOppRegHigh: 9
- FGPRegHigh: 6
- AstTouHigh: 6
- FGPRegLow: 6
- AstTouLow: 6
- StlOppRegHigh: 5
- TOTouHigh: 5
- TOTouLow: 5
- BlkRegHigh: 4
- DurRegHigh: 4
- DROppTouHigh: 4
- BlkRegLow: 4
- StlOppRegLow: 4
- DROppTouLow: 4
- ScoreRegHigh: 3
- FTMTouHigh: 3
- FTPTouHigh: 3
- ScoreRegLow: 3
- ORRegLow: 3
- DurRegLow: 3
- FTPRegLow: 3
- FTPTouLow: 3
- FGA3RegHigh: 2
-

# Step 5: Submit to the competition

We"ll begin by using the trained model to generate predictions, which we"ll save to a CSV file.

In [16]:
SampleSubmissionWarmup = pd.read_csv("/kaggle/input/warmup-round-march-machine-learning-mania-2023/SampleSubmissionWarmup.csv")

print(SampleSubmissionWarmup.shape)
SampleSubmissionWarmup.tail()

(614319, 2)


Unnamed: 0,ID,Pred
614314,2022_3469_3471,0.5
614315,2022_3469_3472,0.5
614316,2022_3470_3471,0.5
614317,2022_3470_3472,0.5
614318,2022_3471_3472,0.5


In [17]:
def get_submission_outcomes(sample_submission: pd.DataFrame) -> pd.DataFrame:
    df = sample_submission.copy()
    df.drop("Pred", axis=1, inplace=True)
    df[["Season", "LowID", "HighID"]] = df["ID"].str.split("_", expand=True)
    df[["Season", "LowID", "HighID"]] = df[["Season", "LowID", "HighID"]].astype(int)
    return df

In [18]:
submission_outcomes = get_submission_outcomes(SampleSubmissionWarmup)
print(submission_outcomes.shape)
submission_outcomes.tail()

(614319, 4)


Unnamed: 0,ID,Season,LowID,HighID
614314,2022_3469_3471,2022,3469,3471
614315,2022_3469_3472,2022,3469,3472
614316,2022_3470_3471,2022,3470,3471
614317,2022_3470_3472,2022,3470,3472
614318,2022_3471_3472,2022,3471,3472


In [19]:
X_submission = merge_outcomes_with_features(submission_outcomes, features, how="left").fillna(0)
print(X_submission.shape)
X_submission.tail()

(614319, 134)


Unnamed: 0_level_0,ScoreRegHigh,ScoreOppRegHigh,FGMRegHigh,FGARegHigh,FGM3RegHigh,FGA3RegHigh,FTMRegHigh,FTARegHigh,ORRegHigh,DRRegHigh,AstRegHigh,TORegHigh,StlRegHigh,BlkRegHigh,PFRegHigh,FGMOppRegHigh,FGAOppRegHigh,FGM3OppRegHigh,FGA3OppRegHigh,FTMOppRegHigh,FTAOppRegHigh,OROppRegHigh,DROppRegHigh,AstOppRegHigh,TOOppRegHigh,StlOppRegHigh,BlkOppRegHigh,PFOppRegHigh,DurRegHigh,FGPRegHigh,FGP3RegHigh,FTPRegHigh,ScoreTouHigh,ScoreOppTouHigh,FGMTouHigh,FGATouHigh,FGM3TouHigh,FGA3TouHigh,FTMTouHigh,FTATouHigh,ORTouHigh,DRTouHigh,AstTouHigh,TOTouHigh,StlTouHigh,BlkTouHigh,PFTouHigh,FGMOppTouHigh,FGAOppTouHigh,FGM3OppTouHigh,FGA3OppTouHigh,FTMOppTouHigh,FTAOppTouHigh,OROppTouHigh,DROppTouHigh,AstOppTouHigh,TOOppTouHigh,StlOppTouHigh,BlkOppTouHigh,PFOppTouHigh,DurTouHigh,FGPTouHigh,FGP3TouHigh,FTPTouHigh,SeedHigh,GenderHigh,OrdinalRankHigh,ScoreRegLow,ScoreOppRegLow,FGMRegLow,FGARegLow,FGM3RegLow,FGA3RegLow,FTMRegLow,FTARegLow,ORRegLow,DRRegLow,AstRegLow,TORegLow,StlRegLow,BlkRegLow,PFRegLow,FGMOppRegLow,FGAOppRegLow,FGM3OppRegLow,FGA3OppRegLow,FTMOppRegLow,FTAOppRegLow,OROppRegLow,DROppRegLow,AstOppRegLow,TOOppRegLow,StlOppRegLow,BlkOppRegLow,PFOppRegLow,DurRegLow,FGPRegLow,FGP3RegLow,FTPRegLow,ScoreTouLow,ScoreOppTouLow,FGMTouLow,FGATouLow,FGM3TouLow,FGA3TouLow,FTMTouLow,FTATouLow,ORTouLow,DRTouLow,AstTouLow,TOTouLow,StlTouLow,BlkTouLow,PFTouLow,FGMOppTouLow,FGAOppTouLow,FGM3OppTouLow,FGA3OppTouLow,FTMOppTouLow,FTAOppTouLow,OROppTouLow,DROppTouLow,AstOppTouLow,TOOppTouLow,StlOppTouLow,BlkOppTouLow,PFOppTouLow,DurTouLow,FGPTouLow,FGP3TouLow,FTPTouLow,SeedLow,GenderLow,OrdinalRankLow
ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1,Unnamed: 62_level_1,Unnamed: 63_level_1,Unnamed: 64_level_1,Unnamed: 65_level_1,Unnamed: 66_level_1,Unnamed: 67_level_1,Unnamed: 68_level_1,Unnamed: 69_level_1,Unnamed: 70_level_1,Unnamed: 71_level_1,Unnamed: 72_level_1,Unnamed: 73_level_1,Unnamed: 74_level_1,Unnamed: 75_level_1,Unnamed: 76_level_1,Unnamed: 77_level_1,Unnamed: 78_level_1,Unnamed: 79_level_1,Unnamed: 80_level_1,Unnamed: 81_level_1,Unnamed: 82_level_1,Unnamed: 83_level_1,Unnamed: 84_level_1,Unnamed: 85_level_1,Unnamed: 86_level_1,Unnamed: 87_level_1,Unnamed: 88_level_1,Unnamed: 89_level_1,Unnamed: 90_level_1,Unnamed: 91_level_1,Unnamed: 92_level_1,Unnamed: 93_level_1,Unnamed: 94_level_1,Unnamed: 95_level_1,Unnamed: 96_level_1,Unnamed: 97_level_1,Unnamed: 98_level_1,Unnamed: 99_level_1,Unnamed: 100_level_1,Unnamed: 101_level_1,Unnamed: 102_level_1,Unnamed: 103_level_1,Unnamed: 104_level_1,Unnamed: 105_level_1,Unnamed: 106_level_1,Unnamed: 107_level_1,Unnamed: 108_level_1,Unnamed: 109_level_1,Unnamed: 110_level_1,Unnamed: 111_level_1,Unnamed: 112_level_1,Unnamed: 113_level_1,Unnamed: 114_level_1,Unnamed: 115_level_1,Unnamed: 116_level_1,Unnamed: 117_level_1,Unnamed: 118_level_1,Unnamed: 119_level_1,Unnamed: 120_level_1,Unnamed: 121_level_1,Unnamed: 122_level_1,Unnamed: 123_level_1,Unnamed: 124_level_1,Unnamed: 125_level_1,Unnamed: 126_level_1,Unnamed: 127_level_1,Unnamed: 128_level_1,Unnamed: 129_level_1,Unnamed: 130_level_1,Unnamed: 131_level_1,Unnamed: 132_level_1,Unnamed: 133_level_1,Unnamed: 134_level_1
2022_3469_3471,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2022_3469_3472,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2022_3470_3471,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2022_3470_3472,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2022_3471_3472,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [20]:
# Use the model to generate predictions

predictions = model.predict(X_submission)

# Save the predictions to a CSV file
output = pd.DataFrame({"ID": X_submission.index,
                       "Pred": predictions})
output["Pred"] = output["Pred"].astype(int)
output.to_csv("submission.csv", index=False)
print(output.shape)
output.describe()

(614319, 2)


Unnamed: 0,Pred
count,614319.0
mean,0.075843
std,0.264747
min,0.0
25%,0.0
50%,0.0
75%,0.0
max,1.0
