# Load Data

In [6]:
from pathlib import Path
import polars as pl
import datetime as dt
from tools.database import LoadFromLocal

DATA_ROOT = './data'

def csv_to_df(filename: str, dtype: dict | None = None, pdates: bool = True, sep: str = ',') -> pl.LazyFrame:
    """Read ``DATA_ROOT/filename`` into memory and return it as a LazyFrame.

    The file is parsed eagerly with ``pl.read_csv`` and then wrapped with
    ``.lazy()`` so later cells can build query plans on it.

    Args:
        filename: CSV file name, relative to ``DATA_ROOT``.
        dtype: optional ``{column: dtype}`` overrides, passed as ``schema_overrides``.
        pdates: passed as ``try_parse_dates``.
        sep: field separator.

    Returns:
        A LazyFrame over the fully loaded table.
    """
    return pl.read_csv(
        Path(DATA_ROOT, filename),
        try_parse_dates=pdates,
        truncate_ragged_lines=True,
        # None: infer the schema from all rows rather than only the first few.
        infer_schema_length=None,
        schema_overrides=dtype,
        separator=sep,
    ).lazy()

In [7]:
# Each table is parsed eagerly by csv_to_df and exposed as a LazyFrame.
# teams.csv is the only semicolon-separated file; csv_to_df parses dates by default.
TEAMS = csv_to_df(filename='teams.csv', sep=';')
SHIFTS = csv_to_df(filename='shifts.csv')
ROSTERS = csv_to_df(filename='rosters.csv')
PLAYS = csv_to_df(filename='plays.csv')
PLAY_DETAILS = csv_to_df(filename='play_details.csv')
GAMES = csv_to_df(filename='games.csv')

# Data tables

In [16]:
(TEAMS.collect().shape, TEAMS.head(1).collect())  # (rows, cols) and a one-row preview

((32, 9),
 shape: (1, 9)
 ┌────────┬───────┬────────┬───────────┬───┬────────────────────┬─────────┬─────────┬───────────────┐
 │ teamid ┆ name  ┆ abbrev ┆ placename ┆ … ┆ darklogo           ┆ color1  ┆ color2  ┆ fullname      │
 │ ---    ┆ ---   ┆ ---    ┆ ---       ┆   ┆ ---                ┆ ---     ┆ ---     ┆ ---           │
 │ i64    ┆ str   ┆ str    ┆ str       ┆   ┆ str                ┆ str     ┆ str     ┆ str           │
 ╞════════╪═══════╪════════╪═══════════╪═══╪════════════════════╪═════════╪═════════╪═══════════════╡
 │ 24     ┆ Ducks ┆ ANA    ┆ Anaheim   ┆ … ┆ https://assets.nhl ┆ #00004E ┆ #FFB81C ┆ Anaheim Ducks │
 │        ┆       ┆        ┆           ┆   ┆ e.com/logos/…      ┆         ┆         ┆               │
 └────────┴───────┴────────┴───────────┴───┴────────────────────┴─────────┴─────────┴───────────────┘)

In [17]:
(SHIFTS.collect().shape, SHIFTS.head(1).collect())  # (rows, cols) and a one-row preview

((1851750, 13),
 shape: (1, 13)
 ┌──────────┬────────────┬──────────┬──────────┬───┬────────────┬──────────┬────────┬──────────┐
 │ id       ┆ detailCode ┆ duration ┆ playerId ┆ … ┆ gameId     ┆ hexValue ┆ teamId ┆ typeCode │
 │ ---      ┆ ---        ┆ ---      ┆ ---      ┆   ┆ ---        ┆ ---      ┆ ---    ┆ ---      │
 │ i64      ┆ i64        ┆ str      ┆ i64      ┆   ┆ i64        ┆ str      ┆ i64    ┆ i64      │
 ╞══════════╪════════════╪══════════╪══════════╪═══╪════════════╪══════════╪════════╪══════════╡
 │ 12438645 ┆ 0          ┆ 01:01    ┆ 8471699  ┆ … ┆ 2022020365 ┆ #6F263D  ┆ 21     ┆ 517      │
 └──────────┴────────────┴──────────┴──────────┴───┴────────────┴──────────┴────────┴──────────┘)

In [18]:
(PLAYS.collect().shape, PLAYS.head(1).collect())  # (rows, cols) and a one-row preview

((767648, 10),
 shape: (1, 10)
 ┌────────────┬─────────┬────────┬────────────┬───┬──────────┬─────────────┬───────────┬────────────┐
 │ gameId     ┆ eventId ┆ period ┆ periodType ┆ … ┆ typeCode ┆ typeDescKey ┆ sortOrder ┆ homeTeamDe │
 │ ---        ┆ ---     ┆ ---    ┆ ---        ┆   ┆ ---      ┆ ---         ┆ ---       ┆ fendingSid │
 │ i64        ┆ i64     ┆ i64    ┆ str        ┆   ┆ i64      ┆ str         ┆ i64       ┆ e          │
 │            ┆         ┆        ┆            ┆   ┆          ┆             ┆           ┆ ---        │
 │            ┆         ┆        ┆            ┆   ┆          ┆             ┆           ┆ str        │
 ╞════════════╪═════════╪════════╪════════════╪═══╪══════════╪═════════════╪═══════════╪════════════╡
 │ 2022020365 ┆ 51      ┆ 1      ┆ REG        ┆ … ┆ 520      ┆ period-star ┆ 8         ┆ left       │
 │            ┆         ┆        ┆            ┆   ┆          ┆ t           ┆           ┆            │
 └────────────┴─────────┴────────┴────────────┴───┴

In [19]:
(PLAY_DETAILS.collect().shape, PLAY_DETAILS.head(1).collect())  # (rows, cols) and a one-row preview

((767648, 37),
 shape: (1, 37)
 ┌────────────┬─────────┬──────────────┬─────────────┬───┬─────────────┬────────┬────────┬──────────┐
 │ gameId     ┆ eventId ┆ assist1Playe ┆ assist1Play ┆ … ┆ winningPlay ┆ xCoord ┆ yCoord ┆ zoneCode │
 │ ---        ┆ ---     ┆ rId          ┆ erTotal     ┆   ┆ erId        ┆ ---    ┆ ---    ┆ ---      │
 │ i64        ┆ i64     ┆ ---          ┆ ---         ┆   ┆ ---         ┆ i64    ┆ i64    ┆ str      │
 │            ┆         ┆ i64          ┆ i64         ┆   ┆ i64         ┆        ┆        ┆          │
 ╞════════════╪═════════╪══════════════╪═════════════╪═══╪═════════════╪════════╪════════╪══════════╡
 │ 2022020365 ┆ 51      ┆ null         ┆ null        ┆ … ┆ null        ┆ null   ┆ null   ┆ null     │
 └────────────┴─────────┴──────────────┴─────────────┴───┴─────────────┴────────┴────────┴──────────┘)

In [20]:
(ROSTERS.collect().shape, ROSTERS.head(1).collect())  # (rows, cols) and a one-row preview

((97052, 6),
 shape: (1, 6)
 ┌────────┬──────────┬───────────────┬──────────────┬─────────────────────────────────┬────────────┐
 │ teamId ┆ playerId ┆ sweaterNumber ┆ positionCode ┆ headshot                        ┆ gameId     │
 │ ---    ┆ ---      ┆ ---           ┆ ---          ┆ ---                             ┆ ---        │
 │ i64    ┆ i64      ┆ i64           ┆ str          ┆ str                             ┆ i64        │
 ╞════════╪══════════╪═══════════════╪══════════════╪═════════════════════════════════╪════════════╡
 │ 7      ┆ 8467950  ┆ 41            ┆ G            ┆ https://assets.nhle.com/mugs/n… ┆ 2022020365 │
 └────────┴──────────┴───────────────┴──────────────┴─────────────────────────────────┴────────────┘)

In [21]:
(GAMES.collect().shape, GAMES.head(1).collect())  # (rows, cols) and a one-row preview

((2427, 7),
 shape: (1, 7)
 ┌────────────┬──────────┬──────────┬──────────────┬──────────────────┬────────────┬────────────┐
 │ id         ┆ season   ┆ gameType ┆ startTimeUTC ┆ venueTimezone    ┆ awayTeamId ┆ homeTeamId │
 │ ---        ┆ ---      ┆ ---      ┆ ---          ┆ ---              ┆ ---        ┆ ---        │
 │ i64        ┆ i64      ┆ i64      ┆ datetime[μs, ┆ str              ┆ i64        ┆ i64        │
 │            ┆          ┆          ┆ UTC]         ┆                  ┆            ┆            │
 ╞════════════╪══════════╪══════════╪══════════════╪══════════════════╪════════════╪════════════╡
 │ 2022020365 ┆ 20222023 ┆ 2        ┆ 2022-12-02   ┆ America/New_York ┆ 21         ┆ 7          │
 │            ┆          ┆          ┆ 00:00:00 UTC ┆                  ┆            ┆            │
 └────────────┴──────────┴──────────┴──────────────┴──────────────────┴────────────┴────────────┘)

In [22]:
# One row per play with its detail columns attached, matched on (gameId, eventId).
# Columns present in both frames (e.g. typeCode) get a `_right` suffix from the join.
COMBINED_PLAYS = PLAYS.join(PLAY_DETAILS, how='inner', on=['gameId', 'eventId'])
COMBINED_PLAYS.head(1).collect()

gameId,eventId,period,periodType,timeInPeriod,situationCode,typeCode,typeDescKey,sortOrder,homeTeamDefendingSide,assist1PlayerId,assist1PlayerTotal,assist2PlayerId,assist2PlayerTotal,awaySOG,awayScore,blockingPlayerId,committedByPlayerId,descKey,discreteClip,drawnByPlayerId,duration,eventOwnerTeamId,goalieInNetId,highlightClip,highlightClipFr,highlightClipSharingUrl,highlightClipSharingUrlFr,hitteePlayerId,hittingPlayerId,homeSOG,homeScore,losingPlayerId,playerId,reason,scoringPlayerId,scoringPlayerTotal,secondaryReason,shootingPlayerId,shotType,typeCode_right,winningPlayerId,xCoord,yCoord,zoneCode
i64,i64,i64,str,str,i64,i64,str,i64,str,i64,i64,i64,i64,i64,i64,i64,i64,str,i64,i64,i64,i64,i64,i64,i64,str,str,i64,i64,i64,i64,i64,i64,str,i64,i64,str,i64,str,str,i64,i64,i64,str
2022020365,51,1,"""REG""","""00:00""",1551,520,"""period-start""",8,"""left""",,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,


### Investigating Net location mapping

It seems that `homeTeamDefendingSide` isn't a very reliable indicator of net location mapping,
and possibly `zoneCode` isn't either.

In [None]:
# How often the home team is recorded as defending each side, over all combined plays.
# (A collected DataFrame is indexed with [...] — calling it like a function raises TypeError.)
COMBINED_PLAYS.collect()['homeTeamDefendingSide'].value_counts()

homeTeamDefendingSide,count
str,u32
"""right""",392658
"""left""",374990


In [56]:
# For every event with coordinates, check whether the sign of xCoord agrees with the
# net position implied by homeTeamDefendingSide.
# NOTE(review): netX already flips with isHomeEvent (awayTeamNetX == -homeTeamNetX), so
# isHomeEventId * netX equals homeTeamNetX on every row. isConsistent therefore reduces
# to homeTeamNetX * xCoord >= 0, independent of who owns the event — confirm intended.
(
    COMBINED_PLAYS
    .filter(pl.col('xCoord').is_not_null())
    .join(GAMES, left_on="gameId", right_on="id")
    .with_columns(
        isHomeEvent=pl.col('eventOwnerTeamId') == pl.col('homeTeamId'),
        homeTeamNetX=pl.when(pl.col('homeTeamDefendingSide') == pl.lit('left')).then(pl.lit(100)).otherwise(pl.lit(-100)),
        awayTeamNetX=pl.when(pl.col('homeTeamDefendingSide') == pl.lit('left')).then(pl.lit(-100)).otherwise(pl.lit(100)),
    )
    .with_columns(
        netX=pl.when(pl.col('isHomeEvent')).then(pl.col('homeTeamNetX')).otherwise(pl.col('awayTeamNetX')),
        isHomeEventId=pl.when(pl.col('isHomeEvent')).then(pl.lit(1)).otherwise(pl.lit(-1)),
    )
    # Reference xCoord through pl.col: depending on the polars version, a bare string
    # passed to an arithmetic method can be parsed as a string literal, not a column.
    .with_columns(isConsistent=pl.col('isHomeEventId').mul(pl.col('netX')).mul(pl.col('xCoord')) >= 0)
    .select([
        'xCoord', 'isHomeEvent', 'netX', 'typeDescKey',
        'zoneCode', 'isConsistent',
    ])
    # Swap in .filter(pl.col('isConsistent').not_()) to inspect the inconsistent rows.
    .filter(pl.col('isConsistent'))
    .collect()
)

xCoord,isHomeEvent,netX,typeDescKey,zoneCode,isConsistent
i64,bool,i32,str,str,bool
0,true,100,"""faceoff""","""N""",true
47,true,100,"""shot-on-goal""","""O""",true
74,true,100,"""blocked-shot""","""D""",true
76,true,100,"""missed-shot""","""O""",true
98,false,-100,"""hit""","""D""",true
…,…,…,…,…,…
2,true,100,"""shot-on-goal""","""N""",true
5,false,-100,"""giveaway""","""N""",true
20,true,100,"""faceoff""","""N""",true
0,false,-100,"""faceoff""","""N""",true


# xG Attempt 1: Logistic Regression

In [14]:
# setting up some constants
# typeDescKey values counted as a shot attempt (goals included).
SHOT_TYPE_DESC_KEY = [
    'shot-on-goal',
    'blocked-shot',
    'missed-shot',
    'goal',
]

## Feature engineering

- The model doesn't do well with too many parameters.
- This is a first attempt at a basic model.

Features to use:

| feature | description | is parameter |
| ------- | ----------- | ------------ |
| `typeDescKey` | event desc | y |
| `sortOrder` | canonical order of events | n |
| `period` | periodId | y |
| `periodType` | REG / OT (filter out SO) | n |
| `eventOwnerTeamId` | event owner team id | y |
| `xCoord` | rink xcoords (N/S) | n |
| `yCoord` | rink ycoords (E/W) | n |

Features to add:

| feature | description | is parameter |
| ------- | ----------- | ------------ |
| `deltaTime` | time since last event (any) | n |
| `deltaTimeShot` | time since last shot | n |
| `deltaDist` | distance since last event (any) | n |
| `deltaDistNet` | distance from owner team's net | n |
| `shotAngle` | angle of shot from center North/South line | n |
| `deltaAngle` | change in angle since last shot | n |
| `rebound` | 0 / 1 | n |
| `rush` | 0 / 1 | n |
| `goal` | 0 / 1 | n |
| `shot` | 0 / 1 | n |
| `block` | 0 / 1 | n |
| `miss` | 0 / 1 | n |
| `isHome` | True / False | n |
| `sumHomeGoal` | cumsum | n |
| `sumHomeShot` | cumsum | n |
| `sumHomeBlock` | cumsum | n |
| `sumHomeMiss` | cumsum | n |
| `homeCf` |  shots + blocks + misses | y |
| `sumAwayGoal` | cumsum | n |
| `sumAwayShot` | cumsum | n |
| `sumAwayBlock` | cumsum | n |
| `sumAwayMiss` | cumsum | n |
| `awayCf` |  shots + blocks + misses | y |
| `gDiff` | homeGoal - awayGoal | n |
| `homeSkt` | home num of skaters | n |
| `homeGoalie` | if home goalie in net, then `1`, else `0` | n |
| `homeEn` | invert `homeGoalie` | y |
| `awaySkt` | away num of skaters | n |
| `awayGoalie` | if away goalie in net, then `1`, else `0` | n |
| `awayEn` | invert `awayGoalie` | y |
| `otLength` | if gametype is 2, then `5 * 60`, else `20 * 60` | n |
| `timeInPeriodSec` | computed time from start of period | y |
| `timeSinceStartSec` | computed time from start of game | n |

In [None]:
# NOTE(review): this section uses lowercase column names (`typedesckey`, `situationcode`,
# `gameid`, ...) and a `games` frame. Neither matches the CSV-loaded frames above: PLAYS
# uses camelCase (`typeDescKey`, `gameId`) and only `GAMES` is defined. Presumably they
# come from `LoadFromLocal` — confirm before a Restart & Run All.
RELEVANT_FIELDS = [ 'id', 'eventid', 'typedesckey']
plays = PLAYS.with_columns(
    goal=pl.when(pl.col('typedesckey') == 'goal').then(1).otherwise(0),
    shot=pl.when(pl.col('typedesckey').is_in(SHOT_TYPE_DESC_KEY)).then(1).otherwise(0),
    # Casting an integer code to str drops a leading 0 (651 instead of "0651"). Pad back
    # to 4 characters so later positional slicing, '0' prefix/suffix checks and the
    # "0101"-style shootout codes all work.
    situationcode=pl.col('situationcode').cast(str).str.zfill(4),
    # NOTE(review): assumes timeinperiod/timeremaining are time values holding MM:SS read
    # as HH:MM (so hour() is minutes and minute() is seconds) — confirm against the loader.
    timeinperiod_s=(pl.col("timeinperiod").dt.hour().cast(pl.Int64) * 60) + pl.col("timeinperiod").dt.minute().cast(pl.Int64),
    # Cast before multiplying: dt.hour() is Int8, so hour() * 60 can overflow Int8.
    timeremaining_s=(pl.col("timeremaining").dt.hour().cast(pl.Int64) * 60) + pl.col("timeremaining").dt.minute().cast(pl.Int64),
).with_columns(
    # Period p starts (p - 1) * 1200 s into the game. Periods 1-3 are 1200 s, OT only
    # follows period 3, and non-regular-season OT periods are also 20 min, so this holds
    # for every period except a regular-season shootout, whose clock time isn't used.
    # With a 600 s step for OT, period 4 would start at 1800 s, before period 3 ended.
    timefromstart_s=pl.col('timeinperiod_s').add(pl.col('period').sub(1).mul(1200)),
).join(games, on=['gameid'])
# collect() so the cell shows rows rather than the lazy query plan.
plays.filter(pl.col('gameid') == 2024020900).tail(10).collect()

In [None]:
# Plays with their detail columns attached, matched on (eventid, gameid).
# `plays` is already lazy, so .lazy() simply returns it unchanged.
cplays = plays.lazy().join(PLAY_DETAILS, how='inner', on=['eventid', 'gameid'])

In [None]:
# Shift start/end in seconds within the period, and shift start from the start of the game.
# NOTE(review): assumes lowercase columns and starttime/endtime as time values holding MM:SS
# read as HH:MM. The CSV-loaded SHIFTS above shows e.g. `duration` as a "01:01" string —
# confirm against the loader.
shifts = SHIFTS.with_columns(
    start_s=(pl.col("starttime").dt.hour().cast(pl.Int64) * 60) + pl.col("starttime").dt.minute().cast(pl.Int64),
    end_s=(pl.col("endtime").dt.hour().cast(pl.Int64) * 60) + pl.col("endtime").dt.minute().cast(pl.Int64)
).with_columns(
    # Period p starts (p - 1) * 1200 s into the game: periods 1-3 are 20 min and OT only
    # follows period 3. A 600 s step for OT would place period 4 before period 3 ended.
    timefromstart_s=pl.col('start_s').add(pl.col('period').sub(1).mul(1200)),
)
# collect() so the cell shows rows rather than the lazy query plan.
shifts.head().collect()

## Feature extraction

We will attempt to make our own version of MoneyPuck's xG model.

**Parameters**:

1. `deltadnet` Shot Distance From Net
2. `deltat` Time Since Last Game Event
3. `shottype` Shot Type (Slap, Wrist, Backhand, etc)
4. `speed` = `deltad / deltat` Speed From Previous Event
5. `angle` Shot Angle
6. `lastxcoord` East-West Location on Ice of Last Event Before the Shot
7. `rebound_deltaomega` = `deltaangle / deltat` If Rebound, difference in shot angle divided by time since last shot
    - `rebound`
8. `lastevent` Last Event That Happened Before the Shot (Faceoff, Hit, etc)
9. `oppskt` Other team's # of skaters on ice
    - `oppmoi` Other team's total # of players on ice (skaters, plus the goalie unless the net is empty)
10. `xcoord` East-West Location on Ice of Shot
11. `ppteam` Team with Man Advantage Situation ('home'/'away'/null)
12. `timepp` Time since current Powerplay started
13. `deltad` Distance From Previous Event
14. `ycoord` North-South Location on Ice of Shot
15. `enteam` Team with Empty Net Situation ('home'/'away'/null)

And we have the columns `shot` and `goal` as outcome indicators.

In [None]:
# Running score and shots-on-goal per game, carried forward onto every event.
# Rows are put in event order first, (gameid, sortorder), because both the cumulative max
# here and the shift(...).over(...) features built later depend on row order. A single
# multi-key sort is needed: sorting by sortorder and then again by gameid (a non-stable
# sort by default) does not keep the sortorder order within a game.
features = (
    cplays
    .sort(['gameid', 'sortorder'])
    .with_columns(
        # Score/SOG are null on events such as period-start; treat null as 0 and keep
        # the running maximum within each game.
        homescore=pl.col('homescore').fill_null(0).cum_max().over('gameid'),
        awayscore=pl.col('awayscore').fill_null(0).cum_max().over('gameid'),
        homesog=pl.col('homesog').fill_null(0).cum_max().over('gameid'),
        awaysog=pl.col('awaysog').fill_null(0).cum_max().over('gameid'),
    )
)

In [None]:
# EVEN_STRENGTH = ["1551", "1441", "1331", "0440", "0660"]
SHOOTOUT_CODES = ["0101", "1010"]

pl.Config().set_tbl_rows(30)
# Per-event features built from the previous event in the same game and period, and
# from the 4-character situationcode string.
# NOTE(review): situationcode is sliced by position, so it must be a zero-padded
# 4-character string (an integer code cast to str loses its leading 0) — confirm upstream.
# NOTE(review): en/skt always read the first half of the code and oppen/oppskt the second
# half, whoever owns the event. Presumably these are fixed sides rather than
# shooter/opponent — confirm.
cfeatures = features.with_columns(
    # Previous event's values within the same game and period (row order = event order).
    # lasttime_s is 0 for the first event of a period, so that event's deltat is the
    # whole elapsed game time.
    lasttime_s=pl.col('timefromstart_s').shift(1).over(['gameid', 'period']).fill_null(0),
    # Distance along x to the nearer rink end at |x| = 100.
    # NOTE(review): goal lines are conventionally at |x| = 89 on the NHL grid — confirm 100.
    deltax_fromnet=pl.when(pl.col('xcoord') < 0).then(100 + pl.col('xcoord')).otherwise(100 - pl.col('xcoord')),
    lastxcoord=pl.col('xcoord').shift(1).over(['gameid', 'period']).fill_null(0),
    lastycoord=pl.col('ycoord').shift(1).over(['gameid', 'period']).fill_null(0),
    lastteam=pl.col('eventownerteamid').shift(1).over(['gameid', 'period']),
    lastevent=pl.col('typedesckey').shift(1).over(['gameid', 'period']),
    # Empty-net flags: a '0' at either end of the code, excluding shootout codes and the
    # game-end event. The pattern must be the string '0' because the column is a string.
    en=pl.col('situationcode').str.starts_with('0').and_(pl.col('situationcode').is_in(SHOOTOUT_CODES).not_()).and_(pl.col('typedesckey') != 'game-end'),
    oppen=pl.col('situationcode').str.ends_with('0').and_(pl.col('situationcode').is_in(SHOOTOUT_CODES).not_()).and_(pl.col('typedesckey') != 'game-end'),
    skt=pl.col('situationcode').str.slice(1, 1).str.to_integer(),
    oppskt=pl.col('situationcode').str.slice(2, 1).str.to_integer(),
).with_columns(
    changepossession=pl.col('eventownerteamid') != pl.col('lastteam'),
    deltadx2=(pl.col('xcoord').sub(pl.col('lastxcoord'))).pow(2),
    deltady2=(pl.col('ycoord').sub(pl.col('lastycoord'))).pow(2),
    deltadnet=(pl.col('deltax_fromnet').pow(2).add(pl.col('ycoord').pow(2))).pow(0.5),
    deltat=pl.col('timefromstart_s').sub(pl.col('lasttime_s')),
    angle=pl.arctan2(pl.col('ycoord').abs(), pl.col('deltax_fromnet').abs()),
    # moi = men on ice: skaters, plus the goalie unless the net is empty
    moi=pl.col('skt').add(pl.when(pl.col('en')).then(0).otherwise(1)),
    oppmoi=pl.col('oppskt').add(pl.when(pl.col('oppen')).then(0).otherwise(1)),
).with_columns(
    deltad=(pl.col('deltadx2').add(pl.col('deltady2'))).pow(0.5),
    pp=pl.col('moi') < pl.col('oppmoi'),
    # NOTE(review): this replaces the empty-net flag above with "more players on ice than
    # the opponent" (moi already consumed the original flag). Downstream cells see this
    # meaning of `en` — confirm intended.
    en=pl.col('moi') > pl.col('oppmoi'),
).with_columns(
    # deltat can be 0, making speed inf (or NaN when deltad is also 0).
    speed=pl.col('deltad').truediv(pl.col('deltat')),
).lazy()

# TODO: rebound, rebound_deltaomega, timepp
cfeatures.select([
    # *CPLAYS_KEYS,
    'eventid', 'gameid', 'sortorder',
    'period', 'xcoord', 'ycoord', 'angle', 'lastxcoord',
    # 'deltad', 'deltadnet', 'deltadx2', 'deltady2', 'speed', 
    'situationcode', 'eventownerteamid', 'typedesckey',
    # 'timefromstart_s', 'timeinperiod', 'timeinperiod_s',
    # 'en', 'oppen', 'skt', 'oppskt', 'pp',
    # 'homescore', 'awayscore', 'hometeamid','awayteamid',
    # 'goal', 
]).head(10).collect()

In [None]:
# Integer code per distinct shottype, used as a model feature.
# unique() does not guarantee output order, so codes taken from row position after
# unique() alone can change between runs. Sorting first makes the codes reproducible.
# Null shot types are dropped: a null key never matches in a join by default, so those
# rows get a null shottypeid (later filled with 0) either way.
shottypes = (
    PLAY_DETAILS
    .select(['shottype'])
    .unique()
    .drop_nulls()
    .sort('shottype')
    .with_row_index('shottypeid', offset=1)
)
shottypes.collect()

## Exploration

In [None]:
import matplotlib.pyplot as plt
import draw

In [None]:
# Team id -> abbreviation / primary colour lookup (TEAMS is already a LazyFrame).
team_id_abbrev = TEAMS.select('teamid', 'abbrev', 'color1')

In [None]:
# Goals only, tagged with the scoring team's abbreviation (`team`) and colour (`color`).
goal_loc_shot = (
    cfeatures
    .filter(pl.col('goal') == 1)
    .join(
        team_id_abbrev.with_columns(
            pl.col('abbrev').alias('team'),
            pl.col('color1').alias('color'),
        ),
        left_on="eventownerteamid",
        right_on="teamid",
    )
    .select([
        'gameid', 'xcoord', 'ycoord', 'shottype', 'goal', 'team', 'color',
        'lastxcoord', 'lastycoord', 'lasttime_s', 'timefromstart_s',
        'deltad', 'deltadnet', 'deltat', 'changepossession',
    ])
)
goal_loc_shot.collect()

In [None]:
def draw_xy_shot(df: pl.DataFrame, title: str, ax = plt, filename: str | None = None):
    """Scatter-plot events at (xcoord, ycoord), coloured by df['color'], on rink-sized axes.

    Args:
        df: collected frame with 'xcoord', 'ycoord' and 'color' columns.
        title: plot title.
        ax: either the pyplot module (default; draws on the current figure) or a
            matplotlib Axes.
        filename: if given, the figure is saved to ./diagrams/<filename>.
    """
    ax.scatter(df['xcoord'], df['ycoord'], c=df['color'], s=2)
    if ax is plt:
        ax.xlim(draw.RINK[0])
        ax.ylim(draw.RINK[1])
        ax.gca().set_aspect('equal')
        ax.title(title)
        fig = ax.gcf()
    else:
        ax.set_xlim(draw.RINK[0])
        ax.set_ylim(draw.RINK[1])
        ax.set_aspect('equal')
        ax.set_title(title)
        fig = ax.get_figure()

    if filename:
        # Save through the Figure: an Axes object has no savefig().
        fig.savefig(f'./diagrams/{filename}')
draw_xy_shot(goal_loc_shot.collect(), "All goals scored (x,y) up to game 930", plt)

In [None]:
def draw_teams_goals(df: pl.LazyFrame, label: str):
    """Draw one rink scatter per team (via draw_xy_shot) on a 4-column grid of subplots.

    Args:
        df: lazy frame with a 'team' abbreviation column plus the columns draw_xy_shot
            plots ('xcoord', 'ycoord', 'color').
        label: %-format string for each subplot title; receives the team abbreviation.

    Returns:
        (fig, axes): the Figure and its 2-D array of Axes.
    """
    # TEAMS is a LazyFrame, which can't be indexed by column name; collect it first.
    team_abbrevs = TEAMS.collect()['abbrev'].to_list()
    ncols = 4
    # Enough rows for every team (8 rows for 32 teams); squeeze=False keeps axes 2-D.
    nrows = max(1, -(-len(team_abbrevs) // ncols))
    fig, axes = plt.subplots(nrows, ncols, squeeze=False)
    for idx, abbrev in enumerate(team_abbrevs):
        r, c = divmod(idx, ncols)
        draw_xy_shot(
            df.filter(pl.col('team') == abbrev).collect(),
            label % abbrev,
            axes[r, c],
        )
    fig.set_figheight(12)
    fig.set_figwidth(18)
    return fig, axes

In [None]:
fig, axes = draw_teams_goals(df=goal_loc_shot, label="Goals scored until game 930 by %s")

In [None]:
# "Rush" goal: the previous event's x coordinate has the opposite sign to the shot's.
rush_goals = (
    goal_loc_shot
    .with_columns(is_rush=(pl.col('xcoord') * pl.col('lastxcoord')) < 0)
    .filter(pl.col('is_rush'))
)
fig, axes = draw_teams_goals(rush_goals, "Rush goals by %s")

In [None]:
# Rush-goal count per team, largest first, drawn in each team's primary colour.
rush_goals_by_team = (
    rush_goals.collect()['team']
    .value_counts()
    .join(team_id_abbrev.collect(), left_on='team', right_on='abbrev')
    .select(['team', 'count', 'color1'])
    .sort('count', descending=True)
)
fig, ax = plt.subplots()
ax.bar(
    rush_goals_by_team['team'],
    rush_goals_by_team['count'],
    width=0.8,
    color=rush_goals_by_team['color1'],
)
fig.set_figwidth(18)

In [None]:
# NETFRONT_THRESHOLD = 20
REBOUND_S_THRESHOLD = 4
# Rebound goal: previous event's x has the same sign as the shot's, the same team owned
# it, and it happened less than REBOUND_S_THRESHOLD seconds earlier.
rebound_goals = (
    goal_loc_shot
    .with_columns(
        is_rebound=(
            ((pl.col('xcoord') * pl.col('lastxcoord')) > 0)
            & ~pl.col('changepossession')
            & (pl.col('deltat') < REBOUND_S_THRESHOLD)
        )
    )
    .filter(pl.col('is_rebound'))
)
fig, axes = draw_teams_goals(rebound_goals, "rebound goals by %s")

In [None]:
# Rebound-goal count per team, largest first, drawn in each team's primary colour.
rebound_goals_by_team = (
    rebound_goals.collect()['team']
    .value_counts()
    .join(team_id_abbrev.collect(), left_on='team', right_on='abbrev')
    .select(['team', 'count', 'color1'])
    .sort('count', descending=True)
)
fig, ax = plt.subplots()
ax.bar(
    rebound_goals_by_team['team'],
    rebound_goals_by_team['count'],
    width=0.8,
    color=rebound_goals_by_team['color1'],
)
fig.set_figwidth(18)

## Logistic Regression Model

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import minmax_scale, StandardScaler

# Build features and label from ONE collected frame so X and y rows are aligned.
# Collecting cfeatures twice (once through a join, once without) gives no guarantee
# that both results come back in the same row order.
model_frame = (
    cfeatures
    .join(shottypes, on='shottype', how="left")
    .select([
        'deltadnet',
        'deltat',
        'shot',
        'shottypeid',
        'speed',
        'angle',
        'lastxcoord',
        'oppskt',
        'pp',
        'en',
        'deltad',
        'xcoord',
        'ycoord',
        'goal',
    ])
    .with_columns(
        # speed is inf when deltat == 0; treat that as 0 (NaN is filled below).
        speed=pl.when(pl.col('speed').is_infinite()).then(pl.lit(0)).otherwise(pl.col('speed')),
        pp=pl.col('pp').cast(int),
        en=pl.col('en').cast(int),
    )
    .fill_nan(0)
    .fill_null(0)
    .collect()
)

X = model_frame.drop('goal').to_pandas()
y = model_frame.select(['goal']).to_pandas()
# X.shape, y.shape
y['goal'].value_counts()

In [None]:
# Rescale the listed columns to [0, 1]; the other columns (e.g. angle, deltad) keep their
# original scale.
# NOTE(review): min/max are taken over the full X (before the train/test split), and X is
# overwritten in place.
normalized_cols = [
    'xcoord', 'ycoord', 'lastxcoord', 'oppskt',
    'speed', 'deltadnet', 'deltat',
]
X[normalized_cols] = minmax_scale(X[normalized_cols])
X.head()

In [None]:
# Hold out 20% of rows for evaluation. Goal labels are presumably a small minority (see
# y['goal'].value_counts() above), so stratify on them to keep the goal rate equal in
# both splits. Fix the seed so Restart & Run All reproduces the same split.
SPLIT_SEED = 42
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=SPLIT_SEED, stratify=y['goal']
)

In [None]:
feature_names = X_train.columns
num_ft = X_train.shape[1]  # number of feature columns

In [None]:
import numpy as np
np.set_printoptions(suppress=True)
# Standardised copy of the training features, shown for inspection.
# NOTE(review): the model below is fitted on the unscaled X_train, not X_scaled. If scaling
# is wanted, reuse this fitted scaler on the test split (scaler.transform(X_test)) rather
# than refitting it there.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_train)
X_scaled[:5]

In [None]:
# Logistic regression on the unscaled training features; the large max_iter presumably
# gives the solver room to converge on them. fit() returns the estimator itself, so
# `model` is the same object as `clf1`.
clf1 = LogisticRegression(max_iter=100000)
model = clf1.fit(X_train, y_train['goal'])

In [None]:
# Hard 0/1 predictions for the held-out rows. For an expected-goals value per event,
# model.predict_proba(X_test)[:, 1] gives the predicted probability of a goal.
y_est = model.predict(X=X_test)

In [None]:
results_X_test = pl.DataFrame(X_test)
results_y_test = pl.DataFrame(y_test)
# Predicted labels under the column name goal_pred; count how many goals were predicted.
results_y_est = pl.DataFrame({'goal_pred': y_est})
results_y_est['goal_pred'].value_counts()