In [None]:
import os

from ibis.interactive import *
from snowflake.ml.modeling.preprocessing import OneHotEncoder
from snowflake.snowpark import Session, functions as F, types as T
from snowflake.snowpark import Window
from snowflake.snowpark.functions import when, lit

from connect import ibis_connect, snowpark_connect

# ibis connection used by the ibis-based cells below.
con = ibis_connect()

# Load the data into tables

In [None]:
import zipfile
from pathlib import Path

import polars as pl
from ibis.interactive import *


def process_file(con, file, zf):
    """Load one CSV member of an open zip archive into a backend table.

    Args:
        con: ibis connection that receives the table.
        file: archive member (its ``filename`` is used to derive the table name).
        zf: the open ``zipfile.ZipFile`` that contains ``file``.

    Returns:
        The table expression returned by ``con.create_table``.
    """
    # Table name is derived from the member's file name by a helper defined
    # elsewhere (presumably upper-casing and prefixing — verify against it).
    target_name = convert_to_all_caps_and_prefix(file.filename)
    # "utf8-lossy" keeps the load going when a file contains invalid bytes.
    mem = ibis.memtable(pl.read_csv(zf.read(file), encoding="utf8-lossy"))
    # Upper-case every column name before writing.
    upper_cased = mem.rename({name.upper(): name for name in mem.columns})
    return con.create_table(target_name, upper_cased, overwrite=True)


 # `__file__` is not defined inside a notebook kernel. Keep the original
 # location when it is available (script execution) and otherwise fall back to
 # the parent of the working directory — assumes the notebook is started from
 # its own folder; adjust if the kernel runs elsewhere.
 _project_root = (
     Path(__file__).parents[1] if "__file__" in globals() else Path.cwd().parent
 )

 with zipfile.ZipFile(
     _project_root / "data" / "mm-learning-mania-2024.zip",
     "r",
 ) as zf:
     files = zf.filelist
     for file in files:
         # Directory entries have no CSV payload to load.
         if file.is_dir():
             continue
         # One backend table per archive member.
         t = process_file(con, file, zf)

In [None]:
def combine_stats(con, table_name: str, schema: str):
    """Reshape a winner/loser results table into one row per team per game.

    Each game row is emitted twice: once from the winner's side (``Won=1``)
    and once from the loser's side (``Won=0``). Side-specific columns lose
    their leading ``W``/``L`` so both halves share column names.

    Args:
        con: ibis connection.
        table_name: name of the results table.
        schema: schema that holds the table.

    Returns:
        The union of the winner-side and loser-side tables.
    """
    # Keep both raw scores under side-independent names so they survive the
    # prefix stripping below.
    games = con.table(table_name, schema=schema).mutate(
        W1Column=_.WScore, W2Column=_.LScore
    )

    # Winner's view: every W-prefixed column, flagged as a win.
    winner_rows = games.select("Season", "DayNum", s.startswith("W")).mutate(Won=1)
    winner_keep = ("Won", "WLoc", "W1Column", "W2Column")
    winner_renames = {}
    for name in winner_rows.columns:
        if name.startswith("W") and name not in winner_keep:
            # rename() maps new name -> old name.
            winner_renames[name[1:]] = name
    winner_rows = winner_rows.rename(winner_renames)

    # Loser's view: L-prefixed columns plus the shared location/score columns.
    loser_rows = games.select(
        "Season", "DayNum", "WLoc", "W1Column", "W2Column", s.startswith("L")
    ).mutate(Won=0)
    loser_keep = ("Won", "WLoc")
    loser_renames = {}
    for name in loser_rows.columns:
        if name.startswith("L") and name not in loser_keep:
            loser_renames[name[1:]] = name
    loser_rows = loser_rows.rename(loser_renames)

    return winner_rows.union(loser_rows)


def flatten_regions(con, table_name: str, schema: str):
    """Turn the wide ``Region*`` columns of a seasons table into long form.

    Args:
        con: ibis connection.
        table_name: name of the seasons table.
        schema: schema that holds the table.

    Returns:
        A table with one row per season and region, where ``Region`` is the
        original column name with the "Region" text removed and ``RegionName``
        is that column's value. ``DayZero`` is dropped.
    """
    seasons = con.table(table_name, schema=schema)
    # pivot_longer yields generic "name"/"value" columns.
    long_form = seasons.pivot_longer(s.startswith("Region"))
    labelled = long_form.rename({"Region": "name", "RegionName": "value"})
    # Strip the "Region" text from the former column names.
    short_codes = labelled.mutate(Region=_.Region.replace("Region", ""))
    return short_codes.drop("DayZero")

In [None]:
# Regular-season detailed results reshaped to one row per team per game.
m_reg = combine_stats(con, "MREGULARSEASONDETAILEDRESULTS", schema="MEN")

In [None]:
def _margin_stats(won: int, prefix: str):
    """Median and mean score gap per season/team over games with the given Won flag."""
    games = m_reg.filter(_.Won == won).mutate(ScoreDiff=_.W1Column - _.W2Column)
    return games.group_by(["Season", "TeamID"]).agg(
        **{
            f"{prefix}MarginMedian": _.ScoreDiff.median(),
            f"{prefix}MarginMean": _.ScoreDiff.mean(),
        }
    )


# Margin statistics over the games a team won and over the games it lost.
w_margin = _margin_stats(1, "Win")
l_margin = _margin_stats(0, "Lose")

# NOTE(review): this join uses the default join type, so a team with no wins
# or no losses in a season presumably drops out here — confirm that is intended.
m_season_margin = w_margin.join(l_margin, ["Season", "TeamID"])

In [None]:
# Secondary-tournament summary per season/team: number of wins and mean score.
ctw = (
    combine_stats(con, "MSECONDARYTOURNEYCOMPACTRESULTS", "MEN")
    .group_by(["Season", "TeamID"])
    .agg(CTWins=_.Won.sum(), AverageCTScore=_.Score.mean())
)

In [None]:
# Per season/team mean, median and standard deviation of every numeric column
# of `m_reg` (DayNum excluded). Aggregates whose names start with Won_,
# Season_ or TeamID_ are dropped afterwards.
season_stats = (
    m_reg.drop("DayNum")
    .group_by(["Season", "TeamID"])
    .agg(s.across(s.numeric(), dict(mean=_.mean(), median=_.median(), stddev=_.std())))
    .drop(s.startswith("Won_"), s.startswith("Season_"), s.startswith("TeamID_"))
)

In [None]:
# Wins per season/team split by the WLoc value: sum `Won` per location, prefix
# the location with "WLoc" so the pivoted columns start with it, pivot to one
# column per location and replace missing counts with 0.
hna = (
    m_reg.group_by(["Season", "TeamID", "WLoc"])
    .agg(WinCount=_.Won.sum())
    .mutate(WLoc="WLoc" + _.WLoc)
    .pivot_wider(names_from="WLoc", values_from="WinCount")
    .mutate(s.across(s.startswith("WLoc"), ibis.coalesce(_, 0)))
)

In [None]:
# Long-form (Season, Region, RegionName) lookup built from the seasons table.
flattened_regions = flatten_regions(con, "MSEASONS", schema="MEN")

# Tournament seeds with the region name attached. The region code is the first
# character of the seed string. The table name matches the seeds table read
# through Snowpark later in this notebook (MEN.MNCAATOURNEYSEEDS).
m_regions = (
    con.table("MNCAATOURNEYSEEDS", "MEN")
    .mutate(Region=_.Seed[0])
    .join(flattened_regions, ["Season", "Region"])
)

In [None]:
# Assemble one feature row per (Season, TeamID). Only the secondary-tournament
# join (`ctw`) is a left join; the others use the default join type, so
# presumably teams missing from any of those inputs drop out — TODO confirm
# that is intended. Columns ending in "_right" (join duplicates) are dropped.
season_joined = (
    season_stats.join(hna, ["Season", "TeamID"])
    .join(m_season_margin, ["Season", "TeamID"])
    .join(ctw, ["Season", "TeamID"], how="left")
    .join(m_regions, ["Season", "TeamID"])
).drop(s.endswith("_right"))

season_joined

In [None]:
#season_joined = season_joined.cache()

In [None]:
# Display the joined season features.
season_joined

In [None]:
# If we wanted to see how far these teams made it each year...
# `Round` is the position of each game within a team's season, ordered by
# DayNum (a row number per Season/TeamID window).

tournament_progress = (
    combine_stats(con, "MNCAATOURNEYCOMPACTRESULTS", "MEN")
    .select("Season", "DayNum", "TeamID")
    .mutate(
        Round=ibis.row_number().over(
            ibis.window(group_by=["Season", "TeamID"], order_by=_.DayNum)
        )
    )
)

In [None]:
# Deliberately hacky: Snowpark has to reuse the ibis connection so that cached
# tables are visible from both APIs. Reusing it also avoids
# "SnowparkSessionException: (1409): More than one active session is detected."


@classmethod
def from_ibis(cls, con) -> Session:
    """Return a Snowpark session built on the raw connection of an ibis backend.

    Args:
        con: ibis backend whose ``con`` attribute is handed to Snowpark.

    Returns:
        The existing session for that connection, or a newly created one.
    """
    builder = Session.builder.config("connection", con.con)
    return builder.getOrCreate()


# Attach the constructor to Session so it can be called as Session.from_ibis(...).
Session.from_ibis = from_ibis

session = Session.from_ibis(con)

In [None]:
# Display the Snowpark session object.
session

In [None]:
# Compile the ibis expression to SQL and run it through the Snowpark session.
session.sql(ibis.to_sql(season_joined)).show()

In [None]:
# Tournament seeds as a Snowpark DataFrame, with every column name upper-cased.
seeds = session.table('MEN.MNCAATOURNEYSEEDS')
for col in seeds.columns:
    seeds = seeds.withColumnRenamed(col, col.upper())
seeds.show()

In [None]:
## Tournament data for training
# Compact tournament results with upper-cased column names, reduced to the
# season, both team ids, both scores and the day number.

tourney = session.table('MEN.MNCAATOURNEYCOMPACTRESULTS')
for col in tourney.columns:
    tourney = tourney.withColumnRenamed(col, col.upper())
tourney = tourney.select('SEASON','WTEAMID','LTEAMID','WSCORE','LSCORE','DAYNUM')

In [None]:
# Split the SEED string: SEED_REGION is its first character; SEED_VALUE is the
# remainder with lowercase letters removed, cast to an integer.
seed_value = (
    seeds
    .with_column("SEED_REGION", F.substring(F.col("SEED"), 1, 1))
    .with_column(
        "SEED_VALUE", F.substring(F.col("SEED"), 2, F.length(F.col("SEED")) - 1)
    )
    .select("SEASON", "TEAMID", "SEED_REGION", "SEED_VALUE")
    .with_column(
        "SEED_VALUE",
        F.cast(F.regexp_replace(F.col("SEED_VALUE"), "[a-z]", ""), T.IntegerType()),
    )
)

seed_value.show()

In [None]:
# Tournament progress (not used for training; shown for illustration).

# One row per team per tournament game, whether the team won or lost it.
df_w = tourney.select("SEASON", "WTEAMID", "DAYNUM").with_column_renamed("WTEAMID", "TEAMID")
df_l = tourney.select("SEASON", "LTEAMID", "DAYNUM").with_column_renamed("LTEAMID", "TEAMID")
df_union = df_w.union(df_l)

# Number each team's games within a season in DayNum order.
df_union = df_union.with_column(
    "ROUND_NUMBER",
    F.row_number().over(Window.partition_by(["SEASON", "TEAMID"]).order_by("DAYNUM")),
)

# Human-readable label for game numbers 1-6; any other number has no label.
df_union = df_union.with_column(
    "ROUND_NAME",
    F.when(F.col("ROUND_NUMBER") == F.lit(1), "Round of 64")
    .when(F.col("ROUND_NUMBER") == F.lit(2), "Round of 32")
    .when(F.col("ROUND_NUMBER") == F.lit(3), "Sweet 16")
    .when(F.col("ROUND_NUMBER") == F.lit(4), "Elite Eight")
    .when(F.col("ROUND_NUMBER") == F.lit(5), "Final Four")
    .when(F.col("ROUND_NUMBER") == F.lit(6), "National Championship"),
)

df_union = df_union.select("SEASON", "TEAMID", "ROUND_NUMBER", "ROUND_NAME","DAYNUM")
df_union.show()

In [None]:
# Get the last game (highest ROUND_NUMBER) for each team and season.
# This will also help us get the round for each game played.
max_round = (
    df_union.with_column(
        "ROW_NUM",
        # Rank a team's games from latest to earliest; row 1 is the last game.
        F.row_number().over(
            Window.partition_by(["SEASON", "TEAMID"]).order_by(
                F.col("ROUND_NUMBER").desc()
            )
        ),
    )
    .filter(F.col("ROW_NUM") == 1)
    .select("SEASON", "TEAMID", "ROUND_NUMBER", "ROUND_NAME","DAYNUM")
)

# Show one team's deepest game for its 10 most recent tournament seasons.
max_round.filter(F.col('TEAMID')==1276).sort(F.col('season').desc()).show(10)

In [None]:
## ROUNDS & SEEDS
# Attach the winner's game number to every tournament game by matching on
# winner id, season and day number.
tourney_rounds = tourney.join(
    df_union,
    (tourney.col("wteamid") == df_union.col("teamid"))
    & (tourney.col("season") == df_union.col("season"))
    & (tourney.col("daynum") == df_union.col("daynum")),
).select(
    # SEASON and DAYNUM exist on both sides of the join, so take them
    # explicitly from `tourney` and restore their plain names.
    tourney.col("SEASON").alias("SEASON"),
    "WTEAMID",
    "LTEAMID",
    "WSCORE",
    "LSCORE",
    "ROUND_NUMBER",
    tourney.col("DAYNUM").alias("DAYNUM"),
)

tourney_rounds.show()

In [None]:
# Attach the winning team's numeric seed to every game.
tourney_seeds = tourney_rounds.join(
    seed_value,
    (tourney_rounds.col("wteamid") == seed_value.col("teamid"))
    & (tourney_rounds.col("season") == seed_value.col("season")),
).select(
    # SEASON exists on both sides of the join; take it from the games side.
    tourney_rounds.col("SEASON").alias("SEASON"),
    "WTEAMID",
    "LTEAMID",
    "WSCORE",
    "LSCORE",
    "ROUND_NUMBER",
    # The day-number column is named DAYNUM throughout this notebook.
    "DAYNUM",
    seed_value.col("SEED_VALUE").alias("w_team_seed"),
)
tourney_seeds.show()

In [None]:
# Materialise the winner-seeded games, then attach the losing team's seed.
tourney_seeds = tourney_seeds.cache_result()
tourney_seeds = tourney_seeds.join(
    seed_value,
    (tourney_seeds.col("lteamid") == seed_value.col("teamid"))
    & (tourney_seeds.col("season") == seed_value.col("season")),
).select(
    tourney_seeds.col("SEASON").alias("SEASON"),
    "WTEAMID",
    "LTEAMID",
    "WSCORE",
    "LSCORE",
    "ROUND_NUMBER",
    # The day-number column is named DAYNUM throughout this notebook.
    "DAYNUM",
    # Publish the two seeds under mirrored W_/L_ names: the winner/loser flip
    # later in the notebook swaps the first letter of each column name and then
    # re-selects the original names, which only works for symmetric pairs.
    tourney_seeds.col("w_team_seed").alias("W_SEED"),
    seed_value.col("SEED_VALUE").alias("l_SEED"),
)

tourney_seeds.show()

In [None]:
# Preview only: combined score of both teams (the result is not assigned).
tourney_seeds.with_column("total_score",F.col("WSCORE")+F.col("LSCORE")).show()

In [None]:
## Add in conference names, uppercase column headers and values and one hot encode
# Team/conference membership with upper-cased column names.
conf = session.table('MEN.MTEAMCONFERENCES')
for col in conf.columns:
    conf = conf.withColumnRenamed(col, col.upper())

def fix_values(column):
    """Return an expression that upper-cases `column` and replaces every run
    of non-alphanumeric characters with a single underscore."""
    return F.upper(F.regexp_replace(F.col(column), "[^a-zA-Z0-9]+", "_"))

conf = conf.with_column("CONFABBREV", fix_values("CONFABBREV"))
# Prefix the join keys so they do not collide with the tournament columns.
conf = conf.with_column_renamed("SEASON", "C_SEASON")
conf = conf.with_column_renamed("TEAMID", "C_TEAMID")

conf.show()

In [None]:
# Attach the winning team's conference as W_CONF.
tourney_conf = tourney_seeds.join(
    conf,
    (tourney_seeds.col("wteamid") == conf.col("C_teamid"))
    & (tourney_seeds.col("season") == conf.col("C_season"))).drop("C_SEASON","C_TEAMID").with_column_renamed("CONFABBREV", "W_CONF")
tourney_conf.show()

In [None]:
# Attach the losing team's conference as l_CONF. The join keys are taken from
# `tourney_conf`, the frame actually being joined.
tourney_final = (
    tourney_conf.join(
        conf,
        (tourney_conf.col("lteamid") == conf.col("C_teamid"))
        & (tourney_conf.col("season") == conf.col("C_season")),
    )
    .drop("C_SEASON", "C_TEAMID")
    .with_column_renamed("CONFABBREV", "l_CONF")
)
tourney_final.show()

In [None]:
# We now have the final tournament data (tourney_final). Bring the per-season
# team features (season_joined) into the same Snowpark session so the two can
# be joined.
season = session.sql(ibis.to_sql(season_joined))
season.to_pandas().head()

In [None]:
# Upper-case every column name of the season features.
for col in season.columns:
    season = season.withColumnRenamed(col, col.upper())
    
season.to_pandas().head()

In [None]:
# Remove the SEED and REGIONNAME columns from the season features.
season = season.drop('SEED','REGIONNAME')

In [None]:
# Preview the tournament frame.
tourney_final.to_pandas().head()

In [None]:
# Two copies of the season features, every column prefixed with W_ or L_, so
# they can be joined once for the winner and once for the loser.
season_w = season.select(
    *[F.col(col).alias(f"W_{col}") for col in season.columns]
)

season_l = season.select(
    *[F.col(col).alias(f"L_{col}") for col in season.columns]
)
season_w.to_pandas().head()

In [None]:
# Preview the L_-prefixed season features.
season_l.to_pandas().head()

In [None]:
# Row count of the W_-prefixed season features.
season_w.count()

In [None]:
# Row count of the L_-prefixed season features.
season_l.count()

In [None]:
# Row count of the tournament frame before joining the season features.
tourney_final.count()

In [None]:
# Join the season features onto each tournament game twice: once for the
# winner (W_ columns) and once for the loser (L_ columns). The prefixed join
# keys are dropped after each join.
final = (
    tourney_final.join(
        season_w,
        on=(
            (tourney_final.WTEAMID == season_w.W_TEAMID)
            & (tourney_final.SEASON == season_w.W_SEASON)
        ),
    )
    .drop("W_TEAMID", "W_SEASON")
    .join(
        season_l,
        on=(
            (tourney_final.LTEAMID == season_l.L_TEAMID)
            & (tourney_final.SEASON == season_l.L_SEASON)
        ),
    )
    .drop("L_TEAMID", "L_SEASON")
)

# Compare with tourney_final.count() to see whether the joins changed the row count.
final.count()

In [None]:
# Preview the joined feature frame.
final.to_pandas().head()

In [None]:
# Persist the joined features, replacing any existing MEN.FINAL_FEATURES table.
final.write.save_as_table(
    "MEN.FINAL_FEATURES", mode="overwrite"
)

In [None]:
# Work from the in-memory frame; the saved table could be read back instead.
df = final #session.table("MEN.FINAL_FEATURES")

In [None]:
# Build the mirror image of every game: each column whose name starts with "W"
# takes the matching "L" name and vice versa, then the columns are put back in
# the original order so the frame lines up with `df`.
_swap_first_letter = {"W": "L", "L": "W"}

new_cols = {
    c: (_swap_first_letter[c[:1]] + c[1:]) if c[:1] in _swap_first_letter else c
    for c in df.columns
}

_mirrored = [F.col(c).alias(new_cols[c]) for c in df.columns]
df_flipped = df.select(_mirrored).select(*df.columns)

In [None]:
# Sanity check: one original row stacked on one flipped row.
df.limit(1).union_all(df_flipped.limit(1)).to_pandas().head()

In [None]:
# Label the original rows 1 and the flipped rows 0, then stack them.
df = df.with_column("WIN_INDICATOR", F.lit(1))
df_flipped = df_flipped.with_column("WIN_INDICATOR", F.lit(0))

df = df.union_all(df_flipped)

In [None]:
# One-hot encode both conference columns. The first category of each is
# dropped, the input columns are replaced by the encoded ones, and categories
# unseen during fit are ignored.
OHE = OneHotEncoder(
    input_cols=["W_CONF","L_CONF"],
    output_cols=["W_CONF","L_CONF"],
    drop_input_cols=True,
    drop="first",
    handle_unknown="ignore",
)

final_train = OHE.fit(df).transform(df)
final_train.show()

In [None]:
# Row count of the encoded training frame.
final_train.count()

In [None]:
# Persist the training frame, replacing any existing MEN.FINAL_TRAIN table.
final_train.write.save_as_table(
    "MEN.FINAL_TRAIN", mode="overwrite"
)

In [None]:
from snowflake.snowpark.functions import when, lit

# Derive the tournament round from the day number:
#   134-135 -> 0 (play-in), 136-137 -> 1, 138-139 -> 2, 143-144 -> 3,
#   145-146 -> 4, 152 -> 5, anything else -> 6.
# Round 4 is played over two days (145 and 146) like the earlier rounds, so it
# is matched as a range; matching day 145 alone sent day-146 games to 6.
final = final.with_column(
    "ROUND",
    when((final.daynum >= 134) & (final.daynum <= 135), lit(0))
    .when((final.daynum >= 136) & (final.daynum <= 137), lit(1))
    .when((final.daynum >= 138) & (final.daynum <= 139), lit(2))
    .when((final.daynum >= 143) & (final.daynum <= 144), lit(3))
    .when((final.daynum >= 145) & (final.daynum <= 146), lit(4))
    .when(final.daynum == 152, lit(5))
    .otherwise(lit(6)),
)

final.filter(F.col("ROUND") == 0).show()

# START TRAINING

In [1]:
import ast
import json
import warnings

import pandas as pd
from snowflake.ml.modeling.impute import SimpleImputer
from snowflake.ml.modeling.metrics import accuracy_score
from snowflake.ml.modeling.model_selection import GridSearchCV
from snowflake.ml.modeling.preprocessing import OneHotEncoder
from snowflake.ml.modeling.xgboost import XGBClassifier
from snowflake.ml.registry import Registry
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
from snowflake.snowpark import Session
from snowflake.snowpark import types as T
from snowflake.snowpark.functions import col
from snowflake.snowpark import functions as F

# Silence UserWarning messages for the rest of the notebook.
warnings.simplefilter(action="ignore", category=UserWarning)

In [2]:
# Snowpark session for the training section, configured from SnowflakeLoginOptions().
session = Session.builder.configs(SnowflakeLoginOptions()).getOrCreate()

SnowflakeLoginOptions() is in private preview since 0.2.0. Do not use it in production. 


In [3]:
# Read the persisted training table back.
final = session.table('MEN.FINAL_TRAIN')

In [5]:
# Columns with null values and their respective counts.
# Runs one count query per column, so this is slow on wide tables.
null_counts = {}
for col_name in final.columns:
    n_null = final.where(F.col(col_name).is_null()).count()
    if n_null > 0:
        null_counts[col_name] = n_null
null_counts

{'W_CTWINS': 2604,
 'W_AVERAGECTSCORE': 2604,
 'L_CTWINS': 2604,
 'L_AVERAGECTSCORE': 2604}

In [6]:
# Drop the four columns reported above as containing nulls.
final = final.drop(['W_CTWINS','W_AVERAGECTSCORE','L_CTWINS','L_AVERAGECTSCORE'])

In [7]:
# Drop the per-location win-count columns (presumably VARIANT-typed, per the
# trailing note — TODO confirm).
final = final.drop(['W_WLOCN','W_WLOCH','W_WLOCA','L_WLOCN','L_WLOCH','L_WLOCA']) #variants

In [8]:
# Hyper-parameter grid for the search; only n_estimators is varied, the other
# candidates are currently disabled.
parameters = {
    "n_estimators": [100, 200, 300, 400, 500],
    # "learning_rate": [0.1, 0.2, 0.3, 0.4, 0.5],
    # "max_depth": list(range(3, 6, 1)),
    # "min_child_weight": list(range(1, 6, 1)),
}

In [9]:
# Time-based split: seasons up to and including 2021 train, later seasons test.
# Both halves are materialised with cache_result().
train = final.filter(F.col('SEASON') <= 2021).cache_result()
test = final.filter(F.col('SEASON') > 2021).cache_result()

In [10]:
# Switch to the MM_L warehouse before fitting.
session.use_warehouse('MM_L')

In [50]:
# List the columns that remain after removing identifiers, scores, round and region columns.
train.drop(['SEASON','WIN_INDICATOR','lteamid','wteamid','wscore','lscore','round','w_team_region','l_region','w_region','round_number']).columns

['W_CONF_ACC',
 'W_CONF_AEC',
 'W_CONF_A_SUN',
 'W_CONF_A_TEN',
 'W_CONF_BIG_EAST',
 'W_CONF_BIG_SKY',
 'W_CONF_BIG_SOUTH',
 'W_CONF_BIG_TEN',
 'W_CONF_BIG_TWELVE',
 'W_CONF_BIG_WEST',
 'W_CONF_CAA',
 'W_CONF_CUSA',
 'W_CONF_HORIZON',
 'W_CONF_IVY',
 'W_CONF_MAAC',
 'W_CONF_MAC',
 'W_CONF_MEAC',
 'W_CONF_MID_CONT',
 'W_CONF_MVC',
 'W_CONF_MWC',
 'W_CONF_NEC',
 'W_CONF_OVC',
 'W_CONF_PAC_TEN',
 'W_CONF_PAC_TWELVE',
 'W_CONF_PATRIOT',
 'W_CONF_SEC',
 'W_CONF_SOUTHERN',
 'W_CONF_SOUTHLAND',
 'W_CONF_SUMMIT',
 'W_CONF_SUN_BELT',
 'W_CONF_SWAC',
 'W_CONF_WAC',
 'W_CONF_WCC',
 'L_CONF_ACC',
 'L_CONF_AEC',
 'L_CONF_A_SUN',
 'L_CONF_A_TEN',
 'L_CONF_BIG_EAST',
 'L_CONF_BIG_SKY',
 'L_CONF_BIG_SOUTH',
 'L_CONF_BIG_TEN',
 'L_CONF_BIG_TWELVE',
 'L_CONF_BIG_WEST',
 'L_CONF_CAA',
 'L_CONF_CUSA',
 'L_CONF_HORIZON',
 'L_CONF_IVY',
 'L_CONF_MAAC',
 'L_CONF_MAC',
 'L_CONF_MEAC',
 'L_CONF_MID_CONT',
 'L_CONF_MVC',
 'L_CONF_MWC',
 'L_CONF_NEC',
 'L_CONF_OVC',
 'L_CONF_PAC_TEN',
 'L_CONF_PAC_TWELVE',
 'L_C

In [14]:
# Grid search over `parameters` for an XGBClassifier, scored by accuracy.
# Inputs are all columns of `train` except those removed in the drop() call;
# the label is WIN_INDICATOR and predictions go to PRED_WIN_INDICATOR.
all_rounds = GridSearchCV(
    estimator=XGBClassifier(),
    param_grid=parameters,
    n_jobs=-1,
    scoring="accuracy",
    input_cols=train.drop(['SEASON','WIN_INDICATOR','lteamid','wteamid','wscore','lscore','round','w_team_region','l_region','w_region','round_number',"W_CONF","L_CONF"]).columns,
    label_cols="WIN_INDICATOR",
    output_cols="PRED_WIN_INDICATOR",
)

# Train
all_rounds.fit(train)

The version of package 'snowflake-snowpark-python' in the local environment is 1.13.0, which does not fit the criteria for the requirement 'snowflake-snowpark-python<2'. Your UDF might not work when the package version is different between the server and your local environment.
Package 'fastparquet' is not installed in the local environment. Your UDF might not work when the package is installed on the server but not on your local environment.
The version of package 'pyarrow' in the local environment is 15.0.1, which does not fit the criteria for the requirement 'pyarrow<14'. Your UDF might not work when the package version is different between the server and your local environment.
The version of package 'cachetools' in the local environment is 5.3.3, which does not fit the criteria for the requirement 'cachetools<6'. Your UDF might not work when the package version is different between the server and your local environment.


<snowflake.ml.modeling.model_selection.grid_search_cv.GridSearchCV at 0x7fc408f3a8f0>

In [15]:
# Switch to the wh_xs warehouse for prediction.
session.use_warehouse('wh_xs')

In [16]:
# Round-1 test games that the model predicts as wins.
# NOTE(review): only rows with PRED_WIN_INDICATOR == 1 are kept, so any
# accuracy computed on `result` measures how often a predicted win was a real
# win, not accuracy over all games — confirm this is the intended metric.
result = all_rounds.predict(test).filter(F.col("PRED_WIN_INDICATOR") == 1).filter(F.col("ROUND") == 1)
result.count()

64

In [17]:
# Score the filtered predictions per test season and over both seasons.
_scoring_frames = {
    "2022": result.filter(result.season == 2022),
    "2023": result.filter(result.season == 2023),
    "total": result,
}

for _label, _frame in _scoring_frames.items():
    accuracy = accuracy_score(
        df=_frame,
        y_true_col_names="WIN_INDICATOR",
        y_pred_col_names="PRED_WIN_INDICATOR",
    )
    print(f"Accuracy {_label}: {accuracy}")

Accuracy 2022: 0.78125
Accuracy 2023: 0.75
Accuracy total: 0.765625


# Predicting the bracket and the Final Four (we want to get as many Final Four teams as possible)

## Results: play-in games

In [47]:
# Class probabilities for the 2023 play-in games (ROUND == 0).
result = all_rounds.predict_proba(test).filter(F.col('ROUND') == 0).filter(F.col("season") == 2023)#.filter(F.col("PRED_WIN_INDICATOR") == 1)

In [41]:
# Show every column of the play-in predictions (very wide output).
result.show()

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [48]:
# Only the team ids and the two class probabilities; the probability columns
# are lower-case, hence the quoted identifiers.
result.select('wteamid','lteamid','"predict_proba_1"','"predict_proba_0"').show()

---------------------------------------------------------------------
|"WTEAMID"  |"LTEAMID"  |"predict_proba_1"    |"predict_proba_0"    |
---------------------------------------------------------------------
|1280       |1338       |0.2209777981042862   |0.779022216796875    |
|1305       |1113       |0.715787410736084    |0.284212589263916    |
|1369       |1394       |0.25784918665885925  |0.7421507835388184   |
|1338       |1280       |0.809890627861023    |0.19010937213897705  |
|1113       |1305       |0.8797614574432373   |0.1202385425567627   |
|1394       |1369       |0.2758517563343048   |0.7241482734680176   |
|1411       |1192       |0.05455866828560829  |0.9454413056373596   |
|1192       |1411       |0.6336548328399658   |0.3663451671600342   |
---------------------------------------------------------------------



In [33]:
# 2023 play-in games that the model predicts as wins, with team names attached.
result = all_rounds.predict(test).filter(F.col('ROUND') == 0).filter(F.col("season") == 2023).filter(F.col("PRED_WIN_INDICATOR") == 1)

# Team lookup with upper-cased column names.
teams = session.table('mteams')
for col in teams.columns:
    teams = teams.withColumnRenamed(col, col.upper())

result = result.select('season','lteamID','l_seed','l_region','wteamid','w_seed','w_region','win_indicator','pred_win_indicator')

# First lookup: the losing side's team name.
res_teamsl = (
    result.join(teams, result.col("LTEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "lteam_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

# Second lookup: the winning side's team name.
res_teams = (
    res_teamsl.join(teams, result.col("WTEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "wteam_name")
    .select('season','lteamid','wteamid','LTEAM_NAME','l_seed','l_region','WTEAM_NAME','w_seed','w_region','WIN_INDICATOR','PRED_WIN_INDICATOR')
)

res_teams.sort(F.col("wteam_name")).show()

-------------------------------------------------------------------------------------------------------------------------------------------------------------
|"SEASON"  |"LTEAMID"  |"WTEAMID"  |"LTEAM_NAME"    |"L_SEED"  |"L_REGION"  |"WTEAM_NAME"  |"W_SEED"  |"W_REGION"  |"WIN_INDICATOR"  |"PRED_WIN_INDICATOR"  |
-------------------------------------------------------------------------------------------------------------------------------------------------------------
|2023      |1305       |1113       |Nevada          |11        |Z           |Arizona St    |11        |Z           |1                |1.0                   |
|2023      |1411       |1192       |TX Southern     |16        |W           |F Dickinson   |16        |W           |1                |1.0                   |
|2023      |1113       |1305       |Arizona St      |11        |Z           |Nevada        |11        |Z           |0                |1.0                   |
|2023      |1280       |1338       |Mississippi St  

## Results round 1

In [36]:
# 2023 round-1 games that the model predicts as wins, with team names attached.
result = all_rounds.predict(test.filter(F.col('ROUND') == 1)).filter(F.col("season") == 2023).filter(F.col("PRED_WIN_INDICATOR") == 1)

# Team lookup with upper-cased column names.
teams = session.table('mteams')
for col in teams.columns:
    teams = teams.withColumnRenamed(col, col.upper())

result = result.select('season','lteamID','l_seed','l_region','wteamid','w_seed','w_region','win_indicator','pred_win_indicator')

# First lookup: the losing side's team name.
res_teamsl = (
    result.join(teams, result.col("LTEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "lteam_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

# Second lookup: the winning side's team name.
res_teams = (
    res_teamsl.join(teams, result.col("WTEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "wteam_name")
    .select('season','lteamid','wteamid','LTEAM_NAME','l_seed','l_region','WTEAM_NAME','w_seed','w_region','WIN_INDICATOR','PRED_WIN_INDICATOR')
)

print('The round of 32')
res_teams.sort("w_region").count()

The round of 32


32

In [38]:
# Show all 32 predicted round-1 winners, ordered by the winner's region.
res_teams.sort("w_region").show(32)

-----------------------------------------------------------------------------------------------------------------------------------------------------------------
|"SEASON"  |"LTEAMID"  |"WTEAMID"  |"LTEAM_NAME"      |"L_SEED"  |"L_REGION"  |"WTEAM_NAME"    |"W_SEED"  |"W_REGION"  |"WIN_INDICATOR"  |"PRED_WIN_INDICATOR"  |
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
|2023      |1436       |1266       |Vermont           |15        |W           |Marquette       |2         |W           |1                |1.0                   |
|2023      |1344       |1246       |Providence        |11        |W           |Kentucky        |6         |W           |1                |1.0                   |
|2023      |1192       |1345       |F Dickinson       |16        |W           |Purdue          |1         |W           |0                |1.0                   |
|2023      |1418       |1397

In [None]:
# Predict every round-1 test game with the fitted grid search (`all_rounds`).
result = all_rounds.predict(test.filter(F.col("ROUND") == 1))

# Team lookup with upper-cased column names.
teams = session.table("mteams")
for col in teams.columns:
    teams = teams.withColumnRenamed(col, col.upper())

# Keep the real 2023 outcomes (WIN_INDICATOR == 1 rows) and their predictions.
# NOTE(review): other cells in this notebook select these columns as
# w_seed / w_region; confirm which names the training table really has.
result = (
    result.select(
        "season",
        "lteamID",
        "l_seed",
        "l_region",
        "wteamid",
        "w_team_seed",
        "w_team_region",
        "win_indicator",
        "pred_win_indicator",
    )
    .filter(F.col("WIN_INDICATOR") == 1)
    .filter(F.col("season") == 2023)
)

# First lookup: the losing side's team name.
res_teamsl = (
    result.join(teams, result.col("LTEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "lteam_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

# Second lookup: the winning side's team name. The result is the round-1 frame
# that the following cells read as `round_1_winners`.
round_1_winners = (
    res_teamsl.join(teams, result.col("WTEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "wteam_name")
    .select(
        "season",
        "lteamid",
        "wteamid",
        "LTEAM_NAME",
        "l_seed",
        "l_region",
        "WTEAM_NAME",
        "w_team_seed",
        "w_team_region",
        "WIN_INDICATOR",
        "PRED_WIN_INDICATOR",
    )
)

In [None]:
# Inspect the round-1 games with names and predictions.
round_1_winners.show()

In [None]:
from snowflake.snowpark.functions import when, lit

# Turn each round-1 game into the model's pick. When the prediction agrees
# with the actual result (WIN_INDICATOR == PRED_WIN_INDICATOR) the W_TEAM_*
# columns take the actual winner's id, name, seed and region; otherwise they
# take the loser's. The L_TEAM_* columns hold the other side. Only the eight
# W_TEAM_* / L_TEAM_* columns are kept.
round_1_winners = round_1_winners.with_column(
    "W_TEAM_ID",
    when((round_1_winners.WIN_INDICATOR == round_1_winners.PRED_WIN_INDICATOR), round_1_winners.wteamid)
    .otherwise(round_1_winners.lteamID),
).with_column(
    "L_TEAM_ID",
    when((round_1_winners.WIN_INDICATOR != round_1_winners.PRED_WIN_INDICATOR), round_1_winners.wteamid)
    .otherwise(round_1_winners.lteamID),
).with_column(
    "W_TEAM_NAME",
    when((round_1_winners.WIN_INDICATOR == round_1_winners.PRED_WIN_INDICATOR), round_1_winners.wteam_name)
    .otherwise(round_1_winners.lteam_name),
).with_column(
    "L_TEAM_NAME",
    when((round_1_winners.WIN_INDICATOR != round_1_winners.PRED_WIN_INDICATOR), round_1_winners.wteam_name)
    .otherwise(round_1_winners.lteam_name),
).with_column(
    "W_TEAM_SEED",
    when((round_1_winners.WIN_INDICATOR == round_1_winners.PRED_WIN_INDICATOR), round_1_winners.w_team_seed)
    .otherwise(round_1_winners.l_seed),
).with_column(
    "L_TEAM_SEED",
    when((round_1_winners.WIN_INDICATOR != round_1_winners.PRED_WIN_INDICATOR), round_1_winners.w_team_seed)
    .otherwise(round_1_winners.l_seed)
).with_column(
    "W_TEAM_REGION",
    when((round_1_winners.WIN_INDICATOR == round_1_winners.PRED_WIN_INDICATOR), round_1_winners.w_team_region)
    .otherwise(round_1_winners.l_region)
).with_column(
    "L_TEAM_REGION",
    when((round_1_winners.WIN_INDICATOR != round_1_winners.PRED_WIN_INDICATOR), round_1_winners.w_team_region)
    .otherwise(round_1_winners.l_region)
).select("W_TEAM_ID","L_TEAM_ID","W_TEAM_NAME","L_TEAM_NAME","W_TEAM_SEED","L_TEAM_SEED","W_TEAM_REGION","L_TEAM_REGION")

In [None]:
# Two separately cached copies of the predicted round-1 winners, so the frame
# can be joined with itself.
df1 = round_1_winners.cache_result()
df2 = round_1_winners.cache_result()

# Round-2 pairings inside a region: a winner holding one of the left seeds
# meets a winner holding one of the right seeds. Each game appears once, with
# the left-hand team on the df1 side.
_round_2_pods = (
    ((1, 16), (8, 9)),
    ((4, 13), (5, 12)),
    ((3, 14), (6, 11)),
    ((2, 15), (7, 10)),
)

_same_region = df1.w_team_region == df2.w_team_region
_pod_match = None
for _left_seeds, _right_seeds in _round_2_pods:
    _pod = df1.w_team_seed.isin(list(_left_seeds)) & df2.w_team_seed.isin(
        list(_right_seeds)
    )
    _pod_match = _pod if _pod_match is None else (_pod_match | _pod)

second_round_matchups = df1.join(df2, _same_region & _pod_match, "inner").select(
    (df1.W_TEAM_NAME).alias("Wteam_name"),
    (df1.W_TEAM_REGION).alias("w_team_region"),
    (df1.W_TEAM_ID).alias("WTeamID2"),
    (df2.W_TEAM_NAME).alias("lTeam_name"),
    (df2.W_TEAM_REGION).alias("l_region"),
    (df2.W_TEAM_ID).alias("lteamID2"),
)

# Display the second round matchups
second_round_matchups.sort("L_REGION").show(20)

In [None]:
# Keep only the two team ids of each second-round pairing.
second_round = second_round_matchups.select('WTEAMID2','LTEAMID2')
second_round.show(16)

In [None]:
# 2023 round-1 rows of the test set; used below as the source of team features.
test_2023 = test.filter(F.col("season") == 2023).filter(F.col("round") == 1)

In [None]:
# Loser-side features: every column that is not W_-prefixed, minus the
# winner id/score and the outcome/round columns.
l_team = [col for col in test_2023.columns if not col.startswith('W_')]
l_team_data = test_2023[l_team].drop('WTEAMID','WSCORE','WIN_INDICATOR','ROUND','ROUND_NUMBER')

# Winner-side features: every column that is not L_-prefixed, minus the
# loser id/score, season and the outcome/round columns.
w_team = [col for col in test_2023.columns if not col.startswith('L_')]
w_team_data = test_2023[w_team].drop('lTEAMID','lSCORE','SEASON','ROUND_NUMBER','ROUND','WIN_INDICATOR')

# Feature rows for the second-round pairings: L-side features come from the
# LTEAMID2 team's row, W-side features from the WTEAMID2 team's row.
games_32 = second_round.join(l_team_data,l_team_data.lteamid == second_round.LTEAMID2).join(w_team_data,w_team_data.wteamid == second_round.WTEAMID2,"inner").drop('LTEAMID2','WTEAMID2')

In [None]:
# Number of feature rows built for the second round.
games_32.count()

In [None]:
# Predict the second-round pairings with the fitted grid search (`all_rounds`).
result = all_rounds.predict(games_32)

# Team lookup with upper-cased column names.
teams = session.table("mteams")
for col in teams.columns:
    teams = teams.withColumnRenamed(col, col.upper())

result = (
    result.select(
        "season",
        "lteamID",
        "l_seed",
        "l_region",
        "wteamid",
        "w_seed",
        "w_region",
        "pred_win_indicator",
    )
)

# First lookup: the L-side team name.
res_teamsl = (
    result.join(teams, result.col("LTEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "lteam_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

# Second lookup: the W-side team name.
round_2_winners = (
    res_teamsl.join(teams, result.col("WTEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "wteam_name")
    .select(
        "season",
        "lteamid",
        "wteamid",
        "LTEAM_NAME",
        "l_seed",
        "l_region",
        "WTEAM_NAME",
        "w_seed",
        "w_region",
        "pred_win_indicator",
    )
)

round_2_winners.sort("W_REGION").show(16)

In [None]:
# Turn each second-round prediction into the model's pick. When
# PRED_WIN_INDICATOR is 1 the W_TEAM_* columns take the W-side id, name, seed
# and region; otherwise they take the L-side values. The L_TEAM_* columns take
# the W-side values only when the prediction is 0. Only the eight
# W_TEAM_* / L_TEAM_* columns are kept.
round_2_winners = round_2_winners.with_column(
    "W_TEAM_ID",
    when((round_2_winners.PRED_WIN_INDICATOR == 1), round_2_winners.wteamid)
    .otherwise(round_2_winners.lteamID),
).with_column(
    "L_TEAM_ID",
    when((round_2_winners.PRED_WIN_INDICATOR == 0), round_2_winners.wteamid)
    .otherwise(round_2_winners.lteamID),
).with_column(
    "W_TEAM_NAME",
    when((round_2_winners.PRED_WIN_INDICATOR == 1), round_2_winners.wteam_name)
    .otherwise(round_2_winners.lteam_name),
).with_column(
    "L_TEAM_NAME",
    when((round_2_winners.PRED_WIN_INDICATOR == 0), round_2_winners.wteam_name)
    .otherwise(round_2_winners.lteam_name),
).with_column(
    "W_TEAM_SEED",
    when((round_2_winners.PRED_WIN_INDICATOR == 1), round_2_winners.w_seed)
    .otherwise(round_2_winners.l_seed),
).with_column(
    "L_TEAM_SEED",
    when((round_2_winners.PRED_WIN_INDICATOR == 0), round_2_winners.w_seed)
    .otherwise(round_2_winners.l_seed)
).with_column(
    "W_TEAM_REGION",
    when((round_2_winners.PRED_WIN_INDICATOR == 1), round_2_winners.w_region)
    .otherwise(round_2_winners.l_region)
).with_column(
    "L_TEAM_REGION",
    when((round_2_winners.PRED_WIN_INDICATOR == 0), round_2_winners.w_region)
    .otherwise(round_2_winners.l_region)
).select("W_TEAM_ID","L_TEAM_ID","W_TEAM_NAME","L_TEAM_NAME","W_TEAM_SEED","L_TEAM_SEED","W_TEAM_REGION","L_TEAM_REGION")

In [None]:
# Number of predicted second-round results.
round_2_winners.count()

In [None]:
# Predicted second-round winners ordered by region.
round_2_winners.sort("W_TEAM_REGION").show(20)

In [None]:
from snowflake.snowpark.functions import col

# Self-join of the round-2 winners: the frame is cached twice so the join has
# two separate DataFrame objects whose columns can be referenced independently.
df1_distinct = round_2_winners.cache_result()
df2_distinct = round_2_winners.cache_result()

# Bracket structure within a region: the team that emerges from the
# {1, 16, 8, 9} seed pod meets the team from the {5, 12, 4, 13} pod, and the
# team from the {6, 11, 3, 14} pod meets the team from the {7, 10, 2, 15} pod.
# Only one orientation of each pairing is listed (df1 = first pod, df2 = second
# pod) so every game appears once rather than twice.
upper_half_pairing = (
    df1_distinct.w_team_seed.isin([1, 16, 8, 9])
    & df2_distinct.w_team_seed.isin([5, 12, 4, 13])
)
lower_half_pairing = (
    df1_distinct.w_team_seed.isin([6, 11, 3, 14])
    & df2_distinct.w_team_seed.isin([7, 10, 2, 15])
)

# Pair winners from the same region whose seeds fall in opposing pods.
# NOTE(review): assumes W_TEAM_SEED holds the numeric seed (1-16) — confirm.
sweet_16_matchups = df1_distinct.join(
    df2_distinct,
    (df1_distinct.W_TEAM_REGION == df2_distinct.W_TEAM_REGION)
    & (upper_half_pairing | lower_half_pairing),
    "inner",
).select(
    # df1 side fills the "w" slot of the next game, df2 side fills the "l" slot.
    (df1_distinct.W_TEAM_NAME).alias("Wteam_name"),
    (df1_distinct.W_TEAM_REGION).alias("w_team_region"),
    (df1_distinct.W_TEAM_ID).alias("WTeamID2"),
    (df2_distinct.W_TEAM_NAME).alias("lTeam_name"),
    (df2_distinct.W_TEAM_REGION).alias("l_region"),
    (df2_distinct.W_TEAM_ID).alias("lteamID2"),
)

## Sweet 16 matchups

In [None]:
# Display the Sweet 16 pairings ordered by region (show() with its default row limit).
sweet_16_matchups.sort("w_team_region").show()

# Predict Elite 8

In [None]:
# Keep only the two team-id columns of each Sweet 16 pairing; they are the keys
# used to attach per-team feature columns when games_16 is built.
third_round = sweet_16_matchups.select('WTEAMID2','LTEAMID2')

In [None]:
# "l"-slot view of test_2023: every column that is not W_-prefixed, minus the
# "w"-slot id/score and the game-level columns listed in the drop().
l_team = [name for name in test_2023.columns if not name.startswith('W_')]
l_team_data = test_2023[l_team].drop(
    'WTEAMID', 'WSCORE', 'WIN_INDICATOR', 'ROUND', 'ROUND_NUMBER'
)

# "w"-slot view: every column that is not L_-prefixed.  SEASON is dropped on
# this side only, so it reaches the joined frame from the "l"-slot view alone.
w_team = [name for name in test_2023.columns if not name.startswith('L_')]
w_team_data = test_2023[w_team].drop(
    'lTEAMID', 'lSCORE', 'SEASON', 'ROUND_NUMBER', 'ROUND', 'WIN_INDICATOR'
)

# Attach both views to each Sweet 16 pairing by team id, then discard the
# pairing's own id columns.
# NOTE(review): assumes test_2023 has one row per team id on each side;
# otherwise these joins multiply rows — confirm.
games_16 = (
    third_round
    .join(l_team_data, l_team_data.lteamid == third_round.LTEAMID2)
    .join(w_team_data, w_team_data.wteamid == third_round.WTEAMID2, "inner")
    .drop('LTEAMID2', 'WTEAMID2')
)

In [None]:
# Score the Sweet 16 feature rows with the fitted search object.
result = grid_search.predict(games_16)

# Team lookup table with every column name upper-cased.
# The loop variable is deliberately not named `col`: a module-level loop
# variable of that name would rebind `col` imported from
# snowflake.snowpark.functions for every later cell.
teams = session.table("mteams")
for column_name in teams.columns:
    teams = teams.withColumnRenamed(column_name, column_name.upper())

# Keep only the identifying columns and the prediction.
result = result.select(
    "season",
    "lteamID",
    "l_seed",
    "l_region",
    "wteamid",
    "w_seed",
    "w_region",
    "pred_win_indicator",
)

# First lookup: attach the name of the team in the "l" slot.
res_teamsl = (
    result.join(teams, result.col("LTEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "lteam_name")
    .drop("firstd1season", "lastd1season")
)
# Cache the intermediate result before joining against `teams` a second time.
res_teamsl = res_teamsl.cache_result()

# Second lookup: attach the name of the team in the "w" slot and keep only the
# columns used downstream.
# NOTE(review): the join key is taken from `result`, not from `res_teamsl`
# (the frame actually being joined) — confirm it resolves as intended.
round_3_winners = (
    res_teamsl.join(teams, result.col("WTEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "wteam_name")
    .select(
        "season",
        "lteamid",
        "wteamid",
        "LTEAM_NAME",
        "l_seed",
        "l_region",
        "WTEAM_NAME",
        "w_seed",
        "w_region",
        "pred_win_indicator",
    )
)

# Preview up to 16 rows, ordered by W_REGION.
round_3_winners.sort("W_REGION").show(16)

In [None]:
# Re-orient every predicted game: the W_TEAM_* columns take their value from the
# "w" slot when PRED_WIN_INDICATOR == 1 and from the "l" slot otherwise; the
# L_TEAM_* columns take the "w" slot when PRED_WIN_INDICATOR == 0 and the "l"
# slot otherwise.
#
# Each tuple is (W_* output name, L_* output name, "w"-slot column, "l"-slot column).
r3_games = round_3_winners  # the frame as it was before re-orientation
r3_pick = r3_games.PRED_WIN_INDICATOR
r3_slots = [
    ("W_TEAM_ID", "L_TEAM_ID", r3_games.wteamid, r3_games.lteamID),
    ("W_TEAM_NAME", "L_TEAM_NAME", r3_games.wteam_name, r3_games.lteam_name),
    ("W_TEAM_SEED", "L_TEAM_SEED", r3_games.w_seed, r3_games.l_seed),
    ("W_TEAM_REGION", "L_TEAM_REGION", r3_games.w_region, r3_games.l_region),
]

r3_oriented = r3_games
for w_alias, l_alias, w_source, l_source in r3_slots:
    r3_oriented = r3_oriented.with_column(
        w_alias, when((r3_pick == 1), w_source).otherwise(l_source)
    ).with_column(
        l_alias, when((r3_pick == 0), w_source).otherwise(l_source)
    )

# Only the re-oriented columns are kept; the original slot columns are dropped.
round_3_winners = r3_oriented.select(
    "W_TEAM_ID",
    "L_TEAM_ID",
    "W_TEAM_NAME",
    "L_TEAM_NAME",
    "W_TEAM_SEED",
    "L_TEAM_SEED",
    "W_TEAM_REGION",
    "L_TEAM_REGION",
)

In [None]:
# Display round_3_winners (show() with its default row limit).
round_3_winners.show()

In [None]:
# NOTE(review): `df1` / `df2` are not assigned in this part of the notebook;
# presumably they are two cached copies of the previous round's winners —
# confirm they are defined (and current) on a fresh Restart & Run All.
elite_eight_same_region = df1.w_team_region == df2.w_team_region
elite_eight_other_team = df1.W_TEAM_ID != df2.W_TEAM_ID

# Pair every team with each different team from its own region.
# NOTE(review): if df1 and df2 hold the same rows, the symmetric `!=` test
# yields each pairing twice (A vs B and B vs A) — confirm that is intended.
elite_eight_matchups = df1.join(
    df2,
    elite_eight_same_region & elite_eight_other_team,
    "inner",
).select(
    df1.W_TEAM_NAME.alias("team1_name"),
    df1.W_TEAM_REGION.alias("team1_region"),
    df1.W_TEAM_ID.alias("team1_ID"),
    df2.W_TEAM_NAME.alias("team2_name"),
    df2.W_TEAM_REGION.alias("team2_region"),
    df2.W_TEAM_ID.alias("team2_ID"),
)

## Final Four

In [None]:
# NOTE(review): `df1` / `df2` are not assigned in this part of the notebook —
# confirm they hold the regional winners when this cell runs.
final_four_left_region = df1.w_team_region
final_four_right_region = df2.w_team_region

# Region 1 is paired with Region 2 and Region 3 with Region 4; both orderings
# of each pairing are accepted.
# NOTE(review): assumes the region column holds the literal labels
# 'Region 1' .. 'Region 4' — confirm against the data.
final_four_paired_regions = (
    ((final_four_left_region == 'Region 1') & (final_four_right_region == 'Region 2'))
    | ((final_four_left_region == 'Region 2') & (final_four_right_region == 'Region 1'))
    | ((final_four_left_region == 'Region 3') & (final_four_right_region == 'Region 4'))
    | ((final_four_left_region == 'Region 4') & (final_four_right_region == 'Region 3'))
)

# The explicit "different regions" test is kept alongside the pairing test.
# NOTE(review): accepting both orderings means each pairing can appear twice
# (once per orientation) if df1 and df2 hold the same rows — confirm.
final_four_matchups = df1.join(
    df2,
    (final_four_left_region != final_four_right_region) & final_four_paired_regions,
    "inner",
).select(
    df1.W_TEAM_NAME.alias("team1_name"),
    df1.W_TEAM_REGION.alias("team1_region"),
    df1.W_TEAM_ID.alias("team1_ID"),
    df2.W_TEAM_NAME.alias("team2_name"),
    df2.W_TEAM_REGION.alias("team2_region"),
    df2.W_TEAM_ID.alias("team2_ID"),
)