In [107]:
import polars as pl

import dataclasses
import pandas as pd
import functools

In [95]:
submission_probs = pd.read_csv('../data/kaggle_2025/cleaned/submissions/submission_base_mens_with_kenpom.csv')
submission_probs

Unnamed: 0,ID,Pred
0,2025_1101_1102,0.632365
1,2025_1101_1103,0.325837
2,2025_1101_1104,0.079927
3,2025_1101_1105,0.793165
4,2025_1101_1106,0.609938
...,...,...
131402,2025_3477_3479,0.910224
131403,2025_3477_3480,0.782747
131404,2025_3478_3479,0.401514
131405,2025_3478_3480,0.192509


In [12]:
mens_team = pd.read_csv('../data/kaggle_2025/raw/MTeams.csv')
mens_team

Unnamed: 0,TeamID,TeamName,FirstD1Season,LastD1Season
0,1101,Abilene Chr,2014,2025
1,1102,Air Force,1985,2025
2,1103,Akron,1985,2025
3,1104,Alabama,1985,2025
4,1105,Alabama A&M,2000,2025
...,...,...,...,...
375,1476,Stonehill,2023,2025
376,1477,East Texas A&M,2023,2025
377,1478,Le Moyne,2024,2025
378,1479,Mercyhurst,2025,2025


In [5]:
mens_team_spellings = pd.read_csv('../data/kaggle_2025/raw/MTeamSpellings.csv')
mens_team_spellings

Unnamed: 0,TeamNameSpelling,TeamID
0,a&m-corpus chris,1394
1,a&m-corpus christi,1394
2,abilene chr,1101
3,abilene christian,1101
4,abilene-christian,1101
...,...,...
1173,youngstown st,1464
1174,youngstown st.,1464
1175,youngstown state,1464
1176,youngstown-st,1464


In [47]:
mens_team_spellings[mens_team_spellings['TeamNameSpelling'].str.contains('ole miss')]

Unnamed: 0,TeamNameSpelling,TeamID
725,ole miss,1279


In [48]:
matchups_raw = [
    # SOUTH Region (Left side, top to bottom)
    [1120, 1106],    # Auburn (1) vs. Alabama State (16)
    [1257, 1166],    # Louisville (8) vs. Creighton (9)
    [1276, 1471],    # Michigan (5) vs. UC San Diego (12)
    [1401, 1463],    # Texas A&M (4) vs. Yale (13)
    [1279, 1314],    # Ole Miss (6) vs. North Carolina (11)
    [1235, 1252],    # Iowa State (3) vs. Lipscomb (14)
    [1266, 1307],    # Marquette (7) vs. New Mexico (10)
    [1277, 1136],    # Michigan State (2) vs. Bryant (15)
    
    # EAST Region (Right side, top to bottom)
    [1181, 1291],    # Duke (1) vs. Mount St. Mary's (16)
    [1280, 1124],    # Mississippi State (8) vs. Baylor (9)
    [1332, 1251],    # Oregon (5) vs. Liberty (12)
    [1112, 1103],    # Arizona (4) vs. Akron (13)
    [1140, 1433],    # BYU (6) vs. VCU (11)
    [1458, 1285],    # Wisconsin (3) vs. Montana (14)
    [1388, 1435],    # Saint Mary's (7) vs. Vanderbilt (10)
    [1104, 1352],    # Alabama (2) vs. Robert Morris (15)
    
    # WEST Region (Bottom left, top to bottom)
    [1196, 1313],    # Florida (1) vs. Norfolk State (16)
    [1163, 1328],    # UConn (8) vs. Oklahoma (9)
    [1272, 1161],    # Memphis (5) vs. Colorado State (12)
    [1268, 1213],    # Maryland (4) vs. Grand Canyon (13)
    [1281, 1179],    # Missouri (6) vs. Drake (11)
    [1403, 1423],    # Texas Tech (3) vs. UNC Wilmington (14)
    [1242, 1116],    # Kansas (7) vs. Arkansas (10)
    [1385, 1303],    # St. John's (2) vs. Omaha (15)
    
    # MIDWEST Region (Bottom right, top to bottom)
    [1222, 1188],    # Houston (1) vs. SIU Edwardsville (16)
    [1211, 1208],    # Gonzaga (8) vs. Georgia (9)
    [1155, 1270],    # Clemson (5) vs. McNeese (12)
    [1345, 1219],    # Purdue (4) vs. High Point (13)
    [1228, 1462],    # Illinois (6) vs. Xavier (11)
    [1246, 1407],    # Kentucky (3) vs. Troy (14)
    [1417, 1429],    # UCLA (7) vs. Utah State (10)
    [1397, 1459]     # Tennessee (2) vs. Wofford (15)
]

In [152]:
@dataclasses.dataclass(slots=True)
class Team:
    _id: int
    name: str

    @classmethod
    def from_id(cls, id: int):
        id_map = mens_team[mens_team['TeamID'] == id]
        assert len(id_map) == 1, f"id_map: {id_map}"
        return cls(id, id_map.TeamName.item())

    def __hash__(self):
        return self._id

def get_win_prob(team1: Team, team2: Team, probs: pd.DataFrame = submission_probs):
    if team1._id < team2._id:
        prob = submission_probs[submission_probs['ID'] == f"2025_{team1._id}_{team2._id}"]
        assert len(prob) == 1, f"prob: {prob}"
        return prob.Pred.item()
    else:
        prob = submission_probs[submission_probs['ID'] == f"2025_{team2._id}_{team1._id}"]
        assert len(prob) == 1, f"prob: {prob}"
        return 1 - prob.Pred.item()

Team.from_id(1345)

Team(_id=1345, name='Purdue')

In [153]:
matchups = [
    [Team.from_id(matchup[0]), Team.from_id(matchup[1])] 
    if matchup[0] < matchup[1]
    else [Team.from_id(matchup[1]), Team.from_id(matchup[0])] 
    for matchup in matchups_raw
]

In [161]:
@dataclasses.dataclass
class Game:
    team1: Team
    team2: Team
    
    @functools.cached_property
    def win_prob(self) -> float:
        return get_win_prob(self.team1, self.team2)

    @functools.cached_property
    def winner(self) -> Team:
        return self.team1 if self.win_prob > .5 else self.team2

    @property
    def teams(self) -> list[Team]:
        return [self.team1, self.team2]

    @property 
    def all_win_probs(self)-> dict[Team, float]:
        return {
            self.team1: self.win_prob,
            self.team2: 1 - self.win_prob
        }

@dataclasses.dataclass
class HyperGame:
    prev_game1: Game | HyperGame
    prev_game2: Game | HyperGame

    @property
    def deterministic_win_prob(self):
        game1_winner = self.prev_game1.winner
        game2_winner = self.prev_game2.winner
        return Game(game1_winner, game2_winner).win_prob

    @functools.cached_property
    def winner(self):
        game1_winner = self.prev_game1.winner
        game2_winner = self.prev_game2.winner
        game = Game(game1_winner, game2_winner)
        if game.win_prob > .5:
            return game.team1
        return game.team2

    @functools.cached_property
    def teams(self):
        return self.prev_game1.teams + self.prev_game2.teams
    
    @functools.cached_property
    def all_win_probs(self) -> dict[Team, float]:
        win_probs = {}
        for team in self.prev_game1.all_win_probs:
            for team2 in self.prev_game2.all_win_probs:
                prob_game_happens = self.prev_game1.all_win_probs[team] * self.prev_game2.all_win_probs[team2]
                win_probs[team] = win_probs.get(team, 0) + get_win_prob(team, team2) * prob_game_happens
                win_probs[team2] = win_probs.get(team2, 0) + get_win_prob(team2, team) * prob_game_happens
        return win_probs
            
            
    

In [162]:
first_round = [
    Game(matchup[0], matchup[1]) for matchup in matchups
]
for i, game in enumerate(first_round):
    print(game.team1, game.team2, game.win_prob)
    if i % 8 == 7:
        print()

Team(_id=1106, name='Alabama St') Team(_id=1120, name='Auburn') 0.0362119842689986
Team(_id=1166, name='Creighton') Team(_id=1257, name='Louisville') 0.4204182155068385
Team(_id=1276, name='Michigan') Team(_id=1471, name='UC San Diego') 0.6203590595649396
Team(_id=1401, name='Texas A&M') Team(_id=1463, name='Yale') 0.7135302169085129
Team(_id=1279, name='Mississippi') Team(_id=1314, name='North Carolina') 0.520322990091003
Team(_id=1235, name='Iowa St') Team(_id=1252, name='Lipscomb') 0.8042768248663442
Team(_id=1266, name='Marquette') Team(_id=1307, name='New Mexico') 0.5784697017777151
Team(_id=1136, name='Bryant') Team(_id=1277, name='Michigan St') 0.1501657075342823

Team(_id=1181, name='Duke') Team(_id=1291, name="Mt St Mary's") 0.960376676385308
Team(_id=1124, name='Baylor') Team(_id=1280, name='Mississippi St') 0.5240621770168928
Team(_id=1251, name='Liberty') Team(_id=1332, name='Oregon') 0.435892646295553
Team(_id=1103, name='Akron') Team(_id=1112, name='Arizona') 0.1904230721

In [163]:
second_round = [
    HyperGame(first_round[i], first_round[i+1])
    for i in range(0, len(first_round), 2)
]
for i, game in enumerate(second_round):
    if game.prev_game1.winner._id < game.prev_game2.winner._id:
        print(game.prev_game1.winner, game.prev_game2.winner, game.deterministic_win_prob)
    else:
        print(game.prev_game2.winner, game.prev_game1.winner, game.deterministic_win_prob)
        
    if i % 4 == 3:
        print()

Team(_id=1120, name='Auburn') Team(_id=1257, name='Louisville') 0.7203726843016799
Team(_id=1276, name='Michigan') Team(_id=1401, name='Texas A&M') 0.4810873451006361
Team(_id=1235, name='Iowa St') Team(_id=1279, name='Mississippi') 0.3873700440146389
Team(_id=1266, name='Marquette') Team(_id=1277, name='Michigan St') 0.4354246620419457

Team(_id=1124, name='Baylor') Team(_id=1181, name='Duke') 0.7548755008970636
Team(_id=1112, name='Arizona') Team(_id=1332, name='Oregon') 0.3487589640900486
Team(_id=1140, name='BYU') Team(_id=1458, name='Wisconsin') 0.4964876567807687
Team(_id=1104, name='Alabama') Team(_id=1388, name="St Mary's CA") 0.34658671784548967

Team(_id=1163, name='Connecticut') Team(_id=1196, name='Florida') 0.7218734895685577
Team(_id=1161, name='Colorado St') Team(_id=1268, name='Maryland') 0.3316639448711342
Team(_id=1281, name='Missouri') Team(_id=1403, name='Texas Tech') 0.4317886933778651
Team(_id=1242, name='Kansas') Team(_id=1385, name="St John's") 0.506510818687336

In [164]:
third_round = [
    HyperGame(second_round[i], second_round[i+1])
    for i in range(0, len(second_round), 2)
]
for i, game in enumerate(third_round):
    if game.prev_game1.winner._id < game.prev_game2.winner._id:
        print(game.prev_game1.winner, game.prev_game2.winner, game.deterministic_win_prob)
    else:
        print(game.prev_game2.winner, game.prev_game1.winner, game.deterministic_win_prob)
        
    if i % 2 == 1:
        print()

Team(_id=1120, name='Auburn') Team(_id=1401, name='Texas A&M') 0.7231637673379046
Team(_id=1235, name='Iowa St') Team(_id=1277, name='Michigan St') 0.5421543559329141

Team(_id=1112, name='Arizona') Team(_id=1181, name='Duke') 0.6707293414696457
Team(_id=1104, name='Alabama') Team(_id=1458, name='Wisconsin') 0.40513502549691016

Team(_id=1196, name='Florida') Team(_id=1268, name='Maryland') 0.6249945850181572
Team(_id=1242, name='Kansas') Team(_id=1403, name='Texas Tech') 0.5872700538518312

Team(_id=1222, name='Houston') Team(_id=1345, name='Purdue') 0.7099970192881464
Team(_id=1246, name='Kentucky') Team(_id=1397, name='Tennessee') 0.4367175145643852



In [165]:
fourth_round = [
    HyperGame(third_round[i], third_round[i+1])
    for i in range(0, len(third_round), 2)
]
for i, game in enumerate(fourth_round):
    if game.prev_game1.winner._id < game.prev_game2.winner._id:
        print(game.prev_game1.winner, game.prev_game2.winner, game.deterministic_win_prob)
    else:
        print(game.prev_game2.winner, game.prev_game1.winner, game.deterministic_win_prob)
        
    if i % 2 == 1:
        print()

Team(_id=1120, name='Auburn') Team(_id=1235, name='Iowa St') 0.6596947680819889
Team(_id=1104, name='Alabama') Team(_id=1181, name='Duke') 0.6088516966830781

Team(_id=1196, name='Florida') Team(_id=1403, name='Texas Tech') 0.5761324445579165
Team(_id=1222, name='Houston') Team(_id=1397, name='Tennessee') 0.6463115773180642



In [166]:
fifth_round = [
    HyperGame(fourth_round[0], fourth_round[2]),
    HyperGame(fourth_round[1], fourth_round[3])
]
for i, game in enumerate(fifth_round):
    if game.prev_game1.winner._id < game.prev_game2.winner._id:
        print(game.prev_game1.winner, game.prev_game2.winner, game.deterministic_win_prob)
    else:
        print(game.prev_game2.winner, game.prev_game1.winner, game.deterministic_win_prob)
        

Team(_id=1120, name='Auburn') Team(_id=1196, name='Florida') 0.5585572418090435
Team(_id=1181, name='Duke') Team(_id=1222, name='Houston') 0.4956568196552914


In [167]:
sixth_round = [
    HyperGame(fifth_round[0], fifth_round[1]),
]
for i, game in enumerate(sixth_round):
    if game.prev_game1.winner._id < game.prev_game2.winner._id:
        print(game.prev_game1.winner, game.prev_game2.winner, game.deterministic_win_prob)
    else:
        print(game.prev_game2.winner, game.prev_game1.winner, game.deterministic_win_prob)
        

Team(_id=1120, name='Auburn') Team(_id=1222, name='Houston') 0.482814873303558


In [168]:
for team, prob in sorted(sixth_round[0].all_win_probs.items(), key=lambda team_prob_tuple: -team_prob_tuple[1]):
    print(team, prob)

Team(_id=1120, name='Auburn') 0.1509411608831387
Team(_id=1181, name='Duke') 0.14696783129013896
Team(_id=1222, name='Houston') 0.14372633912922572
Team(_id=1196, name='Florida') 0.0963725771068435
Team(_id=1104, name='Alabama') 0.054247083932949414
Team(_id=1403, name='Texas Tech') 0.0407470391149578
Team(_id=1397, name='Tennessee') 0.03908034519325122
Team(_id=1235, name='Iowa St') 0.0296718524507985
Team(_id=1112, name='Arizona') 0.024828204027356988
Team(_id=1268, name='Maryland') 0.024175978633879568
Team(_id=1277, name='Michigan St') 0.020551690365418658
Team(_id=1458, name='Wisconsin') 0.017750191797743304
Team(_id=1246, name='Kentucky') 0.017239072957106916
Team(_id=1281, name='Missouri') 0.015272873223619868
Team(_id=1385, name="St John's") 0.01456541851338203
Team(_id=1345, name='Purdue') 0.01432953715242205
Team(_id=1211, name='Gonzaga') 0.014180877446131755
Team(_id=1140, name='BYU') 0.01243036424121394
Team(_id=1242, name='Kansas') 0.012343582942767451
Team(_id=1228, name=

In [190]:
for game in sixth_round:
    for team, prob in sorted(game.all_win_probs.items(), key=lambda team_prob_tuple: -team_prob_tuple[1])[:1]:
        print(team, prob)
    print()

Team(_id=1120, name='Auburn') 0.1509411608831387

