In [None]:
import sys
import os
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
sys.path.append(os.path.abspath(os.path.join('..', 'src')))
%matplotlib inline
%load_ext autoreload
%autoreload 2
from data import ApiFetcher
from utils import distribution_calculating, check_distribution

api = ApiFetcher(2015, 2025)
df = api.get_dataframe('leaguegamelog')
df['total_pts'] = df['home_pts'] + df['away_pts']

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
class DynamicTeamStateModel(nn.Module):
    """
    Dynamic Team State Embeddings for NBA total points prediction.
    - Predykcja: wyłącznie z pary stanów (s_home, s_away) + home_flag. Zero dostępu do bieżących statystyk.
    - Aktualizacja stanu: po meczu, z użyciem sygnałów z tego meczu (np. own_pts, opp_pts, total, margin).
    - Loss poza klasą: MAE/Huber na przewidywanym totalu.
    """
    def __init__(
        self,
        num_teams: int,
        state_dim: int = 64,
        predictor_hidden: int = 128,
        signals_dim: int = 4,  # np. [own_pts, opp_pts, total, margin]
        use_layernorm: bool = True,
        predictor_dropout: float = 0.15
    ):
        super().__init__()
        self.state_dim = state_dim
        self.signals_dim = signals_dim

        # Learnable initial state per team (używane do inicjalizacji i przeniesienia między sezonami)
        self.team_init = nn.Embedding(num_teams, state_dim)

        # Home/away bias może zostać wyuczony implicite, ale dodajemy home_flag do predyktora
        # Interakcje: concat[s_h, s_a, s_h*s_a, |s_h - s_a|, home_flag]
        pred_in = state_dim * 4 + 1
        self.predictor = nn.Sequential(
            nn.Linear(pred_in, predictor_hidden),
            nn.ReLU(),
            nn.Dropout(predictor_dropout),
            nn.Linear(predictor_hidden, predictor_hidden // 2),
            nn.ReLU(),
            nn.Linear(predictor_hidden // 2, 1)
        )

        # Updater: GRUCell z wejściem [signals, opponent_state, home_flag] i stanem = current_state
        self.updater = nn.GRUCell(
            input_size=signals_dim + state_dim + 1,
            hidden_size=state_dim
        )

        # Stabilizacja stanów
        self.use_layernorm = use_layernorm
        if use_layernorm:
            self.state_norm = nn.LayerNorm(state_dim)

        # Sezonowy carry-over (uczona mieszanka poprzedniego stanu z stanem inicjalnym)
        self._alpha = nn.Parameter(torch.tensor(0.0))  # sigmoid(alpha) w [0,1]

    # ========== Predykcja ==========
    @torch.no_grad()
    def init_states(self, team_ids: torch.Tensor) -> torch.Tensor:
        """
        Zwraca stan początkowy drużyn (przed pierwszym meczem / na starcie sezonu).
        team_ids: LongTensor [B]
        return: FloatTensor [B, state_dim]
        """
        s0 = self.team_init(team_ids)
        return self.state_norm(s0) if self.use_layernorm else s0

    def predict_total(self, s_home: torch.Tensor, s_away: torch.Tensor, home_flag: torch.Tensor) -> torch.Tensor:
        """
        Predykcja totalu wyłącznie na podstawie stanów drużyn i home_flag.
        s_home, s_away: [B, state_dim]
        home_flag: [B, 1] (float 0/1)
        return: [B]
        """
        interaction = s_home * s_away
        diff = torch.abs(s_home - s_away)
        x = torch.cat([s_home, s_away, interaction, diff, home_flag], dim=-1)
        y = self.predictor(x).squeeze(-1)
        return y

    # Alias forward -> predykcja (bez aktualizacji)
    def forward(self, s_home: torch.Tensor, s_away: torch.Tensor, home_flag: torch.Tensor) -> torch.Tensor:
        return self.predict_total(s_home, s_away, home_flag)

    # ========== Aktualizacja stanów po meczu ==========
    def update_team_state(
        self,
        s_team: torch.Tensor,         # [B, D] stan drużyny PRZED meczem
        s_opp: torch.Tensor,          # [B, D] stan przeciwnika PRZED meczem
        home_flag: torch.Tensor,      # [B, 1] 1 dla gospodarza, 0 dla gościa (z perspektywy s_team)
        signals: torch.Tensor         # [B, signals_dim] np. [own_pts, opp_pts, total, margin_from_team_view]
    ) -> torch.Tensor:
        """
        Aktualizuje stan pojedynczej drużyny po meczu, używając tylko danych, które zaszły (no leakage).
        Zależność od przeciwnika jest dozwolona (jego stan przedmeczowy).
        """
        upd_in = torch.cat([signals, s_opp, home_flag], dim=-1)
        s_new = self.updater(upd_in, s_team)
        if self.use_layernorm:
            s_new = self.state_norm(s_new)
        return s_new

    def update_both_teams(
        self,
        s_home: torch.Tensor, s_away: torch.Tensor,
        signals_home: torch.Tensor, signals_away: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Wygodny helper: aktualizuje oba stany na podstawie sygnałów z meczu.
        signals_home, signals_away: [B, signals_dim] (np. [home_pts, away_pts, total, margin])
        Z perspektywy home: home_flag=1; z perspektywy away: home_flag=0.
        """
        B = s_home.size(0)
        home_flag_home = torch.ones(B, 1, device=s_home.device, dtype=s_home.dtype)
        home_flag_away = torch.zeros(B, 1, device=s_home.device, dtype=s_home.dtype)

        s_home_new = self.update_team_state(s_home, s_away, home_flag_home, signals_home)
        s_away_new = self.update_team_state(s_away, s_home, home_flag_away, signals_away)
        return s_home_new, s_away_new

    # ========== Sezonowy carry-over ==========
    def carry_over(
        self,
        prev_state: torch.Tensor,     # [B, D] stan z końca poprzedniego sezonu
        team_ids: torch.Tensor,       # [B]
        reset_mask: torch.Tensor      # [B, 1] bool/float: 1 gdy start nowego sezonu, 0 w przeciwnym razie
    ) -> torch.Tensor:
        """
        Przy starcie sezonu miksuje poprzedni stan z inicjalnym: s = a*s_prev + (1-a)*s_init.
        Poza startem sezonu zwraca prev_state.
        """
        a = torch.sigmoid(self._alpha)  # w [0,1]
        s_init = self.team_init(team_ids)
        if self.use_layernorm:
            s_init = self.state_norm(s_init)

        mixed = a * prev_state + (1.0 - a) * s_init
        # zastosuj tylko tam, gdzie reset_mask = 1
        return mixed * reset_mask + prev_state * (1.0 - reset_mask)

In [49]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
df_sorted = df.sort_values('date').reset_index(drop=True)
split_idx = int(len(df_sorted) * 0.9)
df_train = df_sorted.iloc[:split_idx].reset_index(drop=True)
df_test  = df_sorted.iloc[split_idx:].reset_index(drop=True)

# --- normalizatory dla updatera: tylko [own_pts, opp_pts] ---
def _mk_signals_stats(df_):
    h = df_['home_pts'].values
    a = df_['away_pts'].values
    mu = np.array([h.mean(), a.mean()], dtype=float)
    sd = np.array([h.std()+1e-6, a.std()+1e-6], dtype=float)
    return mu, sd

sig_mu, sig_sd = _mk_signals_stats(df_train)

def norm_sig_own_opp(own_pts, opp_pts):
    x = np.array([own_pts, opp_pts], dtype=float)
    x = (x - sig_mu) / sig_sd
    return x

# --- normalizacja targetu (do loss) ---
tot_mu = float(df_train['total_pts'].mean())
tot_sd = float(df_train['total_pts'].std() + 1e-6)

num_teams = df['home_team_id'].nunique() + 1
model = DynamicTeamStateModel(
    num_teams=num_teams,
    state_dim=64,
    predictor_hidden=128,
    signals_dim=2,
    use_layernorm=True,
    predictor_dropout=0.15
).to(device)
loss_fn = nn.SmoothL1Loss(beta=10.0)  # nieco szerszy próg Huber
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)
grad_clip = 1.0
tbptt_K = 12
epochs = 10
lambda_delta = 1e-4  # regularyzacja zmian stanów
states = {}

tot_mu_t = torch.tensor(tot_mu, device=device, dtype=torch.float32)
tot_sd_t = torch.tensor(tot_sd, device=device, dtype=torch.float32)

In [51]:
def get_state(team_idx: int) -> torch.Tensor:
    s = states.get(team_idx)
    if s is None:
        team_id_tensor = torch.tensor([team_idx], dtype=torch.long, device=device)
        s0 = model.init_states(team_id_tensor)[0]  # [D]
        states[team_idx] = s0
        return s0
    return s

def detach_all_states():
    for k in list(states.keys()):
        states[k] = states[k].detach()

In [None]:
for epoch in range(epochs):
    model.train()
    pending_losses = []
    states.clear()
    for row in df_train.itertuples(index=False):
        # pre-game states (używaj home_team_id/away_team_id)
        s_h = get_state(int(row.home_team_id)).unsqueeze(0).to(device)
        s_a = get_state(int(row.away_team_id)).unsqueeze(0).to(device)

        home_flag = torch.ones(1, 1, device=device, dtype=s_h.dtype)
        y_pred = model(s_h, s_a, home_flag)
        y_true = torch.tensor([row.total_pts], device=device, dtype=s_h.dtype)

        y_pred_std = (y_pred - tot_mu_t) / tot_sd_t
        y_true_std = (y_true - tot_mu_t) / tot_sd_t
        base_loss = loss_fn(y_pred_std, y_true_std)

        h_pts = float(row.home_pts); a_pts = float(row.away_pts)
        sig_home = torch.tensor(norm_sig_own_opp(h_pts, a_pts), device=device, dtype=s_h.dtype).unsqueeze(0)
        sig_away = torch.tensor(norm_sig_own_opp(a_pts, h_pts), device=device, dtype=s_h.dtype).unsqueeze(0)
        assert sig_home.shape[-1] == model.signals_dim and sig_away.shape[-1] == model.signals_dim

        s_h_new, s_a_new = model.update_both_teams(s_h, s_a, sig_home, sig_away)

        delta_reg = ((s_h_new - s_h).pow(2).mean() + (s_a_new - s_a).pow(2).mean()) * lambda_delta
        loss = base_loss + delta_reg
        pending_losses.append(loss)

        states[int(row.home_team_id)] = s_h_new.squeeze(0)
        states[int(row.away_team_id)] = s_a_new.squeeze(0)

        if len(pending_losses) >= tbptt_K:
            total_loss = torch.stack(pending_losses).mean()
            optimizer.zero_grad(set_to_none=True)
            total_loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            pending_losses.clear()
            detach_all_states()
    if pending_losses:
        total_loss = torch.stack(pending_losses).mean()
        optimizer.zero_grad(set_to_none=True)
        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        pending_losses.clear()
        detach_all_states()
    print(f"Epoch {epoch+1}/{epochs} done.")

Epoch 1/10 done.
Epoch 2/10 done.
Epoch 3/10 done.
Epoch 4/10 done.
Epoch 5/10 done.
Epoch 6/10 done.
Epoch 7/10 done.
Epoch 8/10 done.
Epoch 9/10 done.
Epoch 10/10 done.


In [None]:
@torch.no_grad()
def evaluate_mae_mse(df_eval, model, device):
    model.eval()
    states_eval = {}

    def get_state_eval(team_idx: int):
        s = states_eval.get(team_idx)
        if s is None:
            t = torch.tensor([team_idx], dtype=torch.long, device=device)
            s0 = model.init_states(t)[0]
            states_eval[team_idx] = s0
            return s0
        return s

    mae_sum, mse_sum, n = 0.0, 0.0, 0
    for row in df_eval.sort_values('date').itertuples(index=False):
        s_h = get_state_eval(int(row.home_team_id)).unsqueeze(0).to(device)
        s_a = get_state_eval(int(row.away_team_id)).unsqueeze(0).to(device)

        y_pred = model(s_h, s_a, torch.ones(1,1,device=device,dtype=s_h.dtype))
        y_true = torch.tensor([row.total_pts], device=device, dtype=s_h.dtype)

        err = (y_pred - y_true)
        mae_sum += torch.abs(err).item()
        mse_sum += (err.pow(2)).item()
        n += 1

        # update po meczu (te same normalizatory co w train)
        h_pts, a_pts = float(row.home_pts), float(row.away_pts)
        sig_home = torch.tensor(norm_sig_own_opp(h_pts, a_pts), device=device, dtype=s_h.dtype).unsqueeze(0)
        sig_away = torch.tensor(norm_sig_own_opp(a_pts, h_pts), device=device, dtype=s_h.dtype).unsqueeze(0)
        assert sig_home.shape[-1] == model.signals_dim and sig_away.shape[-1] == model.signals_dim

        s_h_new, s_a_new = model.update_both_teams(s_h, s_a, sig_home, sig_away)
        states_eval[int(row.home_team_id)] = s_h_new.squeeze(0).detach()
        states_eval[int(row.away_team_id)] = s_a_new.squeeze(0).detach()

    mae = mae_sum / max(n, 1)
    mse = mse_sum / max(n, 1)
    return mae, mse

# baseline: stała średnia z train
@torch.no_grad()
def baseline_mae_mse(df_train, df_eval):
    import numpy as np
    mu = float(df_train['total_pts'].mean())
    y = df_eval['total_pts'].to_numpy(dtype=float)
    mae = np.abs(y - mu).mean()
    mse = ((y - mu) ** 2).mean()
    return mae, mse

test_mae, test_mse = evaluate_mae_mse(df_test, model, device)
base_mae, base_mse = baseline_mae_mse(df_train, df_test)
print(f"Baseline MAE: {base_mae:.2f} | Baseline RMSE: {base_mse**0.5:.2f}")
print(f"Test MAE: {test_mae:.2f} | Test RMSE: {test_mse**0.5:.2f}")

Baseline MAE: 17.19 | Baseline RMSE: 21.43
Test MAE: 24.30 | Test RMSE: 39.04
