In [None]:
import math
from pathlib import Path

import numpy as np
import pandas as pd
import torch

pd.set_option("display.max_columns", 50)

# Seed both RNGs up front; the Langevin sampler draws its noise from torch.
torch.manual_seed(7)
np.random.seed(7)

DATA_PATH = Path("2026_MCM_Problem_C_Data.csv")
OUTPUT_DIR = Path("outputs")  # the season-loop cell writes its CSV results here
OUTPUT_DIR.mkdir(exist_ok=True)

# Hyperparameters read by compute_energy, langevin_sampler and the season loop.
CONFIG = {
    # Gaussian prior: y ~ N(a + beta * z, sigma^2) on active cells,
    # a ~ N(0, sigma_a^2), beta ~ N(beta0, sigma_beta^2).
    "sigma": 1.0,
    "sigma_a": 1.0,
    "beta0": 0.5,  # also the starting value of beta in every chain
    "sigma_beta": 1.0,
    # Soft elimination / placement constraints.
    "delta": 1e-3,  # margin required by the squared-hinge penalties
    "tau": 0.1,  # temperature of smooth_rank (smaller -> closer to hard ranks)
    "lambda": 100.0,  # weight of the weekly elimination penalty
    "lambdaF": 100.0,  # weight of the final-week placement penalty
    # Langevin sampler.
    "eps": 1e-3,  # step size; the injected noise has scale sqrt(2 * eps)
    "steps": 50,  # iterations per chain
    "burnin": 10,  # iterations before the first kept sample
    "thinning": 5,  # after burn-in, keep every `thinning`-th iterate
    "chains": 1,  # independent chains, each restarted from zeros
    "allow_bottom2": True,  # bottom-two rule; the season loop only applies it to seasons >= 28
}


In [None]:
import re

raw = pd.read_csv(DATA_PATH)

# Judge-score columns are the ones whose names start like "week<W>_judge<K>_score".
week_cols = [c for c in raw.columns if re.match(r"week\d+_judge\d+_score", c)]
# Sorted, de-duplicated week numbers that have at least one judge-score column.
week_nums = sorted({int(re.search(r"week(\d+)_", c).group(1)) for c in week_cols})
max_week = max(week_nums)  # NOTE(review): not referenced by the visible cells

# Week number -> that week's judge-score columns (read by compute_judge_totals).
# The trailing "_" in the prefix keeps "week1_" from also matching "week11_...".
# NOTE(review): season_preprocess assumes these weeks are exactly 1..len(week_nums).
week_to_cols = {
    week: [c for c in week_cols if c.startswith(f"week{week}_")]
    for week in week_nums
}


def compute_judge_totals(df, week_map=None):
    """Sum the judges' scores per week for every row of ``df``.

    Parameters
    ----------
    df : pd.DataFrame
        One row per contestant; must contain every column listed in ``week_map``.
    week_map : dict[int, list[str]], optional
        Week number -> judge-score columns for that week. Defaults to the
        notebook-level ``week_to_cols`` built from the raw data, so existing
        calls behave exactly as before.

    Returns
    -------
    pd.DataFrame
        Same index as ``df``, one column per week (in ``week_map`` order).
        Missing (NaN) scores are skipped, so a week with no recorded score
        totals 0.
    """
    if week_map is None:
        week_map = week_to_cols
    return pd.DataFrame(
        {week: df[cols].sum(axis=1, skipna=True) for week, cols in week_map.items()}
    )


def compute_rank_desc(series):
    """Rank ``series`` so its largest value gets rank 1; tied values share their average rank."""
    ranked = series.rank(method="average", ascending=False)
    return ranked


def season_preprocess(df_season):
    """Build the per-season arrays used by the sampler and the consistency checks.

    Rows of every returned array follow the row order of ``df_season``.

    Returns a dict with:

    * ``J`` — weekly judge totals (rows = contestants, columns = ``week_nums`` order);
    * ``active`` — ``J > 0``;
    * ``last_nonzero`` — 1-based position of each contestant's last active week (0 if never);
    * ``final_week`` — the largest ``last_nonzero`` in the season;
    * ``elim_weeks`` — for each week 1..final_week-1, the row positions whose last
      active week is that week (any contestant whose scores stop early is listed);
    * ``placement`` and ``names`` — copied from the ``placement`` / ``celebrity_name`` columns.
    """
    weekly_totals = compute_judge_totals(df_season).loc[:, week_nums]
    J = weekly_totals.to_numpy(dtype=float)
    active = J > 0

    # Replace each active cell by its 1-based column position; the row max is the last active week.
    week_positions = np.arange(1, len(week_nums) + 1)
    last_nonzero = np.where(active, week_positions, 0).max(axis=1)
    final_week = int(last_nonzero.max())

    elim_weeks = {}
    for week in range(1, final_week):
        elim_weeks[week] = np.flatnonzero(last_nonzero == week).tolist()

    return dict(
        J=J,
        active=active,
        last_nonzero=last_nonzero,
        final_week=final_week,
        elim_weeks=elim_weeks,
        placement=df_season["placement"].to_numpy(),
        names=df_season["celebrity_name"].to_numpy(),
    )


In [None]:


def compute_z(J, active):
    """Standardise ``J`` within each week (column) over that week's active rows.

    Uses the population standard deviation (numpy's default ``ddof=0``).
    Inactive cells stay 0, and so does every cell of a week whose active
    scores have zero spread.
    """
    z = np.zeros_like(J, dtype=float)
    for t in range(J.shape[1]):
        week_mask = active[:, t]
        if not week_mask.any():
            continue
        scores = J[week_mask, t]
        spread = scores.std()
        # A zero-spread week keeps its pre-filled zeros.
        if spread != 0:
            z[week_mask, t] = (scores - scores.mean()) / spread
    return z


def softmax_masked(y, mask):
    """Softmax of ``y`` along dim 0 restricted to ``mask``; entries outside the mask are 0.

    Masked-out logits are replaced by -1e9 before the softmax, so they take no
    probability mass whenever at least one entry is unmasked.
    """
    logits = y.masked_fill(~mask, -1e9)
    return torch.softmax(logits, dim=0) * mask


def smooth_rank(y_t, mask_t, tau):
    """Differentiable descending rank of ``y_t`` among the entries selected by ``mask_t``.

    For an active entry i the rank is ``1 + sum_{j != i} sigmoid((y_j - y_i) / tau)``,
    so the largest value gets rank ~1. It is computed as ``1 + sum_j (...) - 0.5``
    because the j == i term is ``sigmoid(0) = 0.5``. Inactive entries get 0.
    Smaller ``tau`` makes the sigmoids sharper, i.e. closer to hard ranks.
    """
    active_idx = mask_t.nonzero(as_tuple=True)[0]
    ranks = torch.zeros_like(y_t)
    if active_idx.numel() == 0:
        return ranks
    vals = y_t[active_idx]
    # pairwise[i, j] = sigmoid((vals[j] - vals[i]) / tau)
    pairwise = torch.sigmoid((vals.unsqueeze(0) - vals.unsqueeze(1)) / tau)
    ranks[active_idx] = 1.0 + pairwise.sum(dim=1) - 0.5
    return ranks


def compute_ranks(J, active):
    """Per-week descending ranks of ``J`` among active rows (1 = largest, ties averaged).

    Returns a float array shaped like ``J``; inactive cells and weeks with no
    active rows stay 0.
    """
    ranks = np.zeros_like(J, dtype=float)
    for t in range(J.shape[1]):
        week_mask = active[:, t]
        if week_mask.any():
            week_values = pd.Series(J[week_mask, t])
            ranks[week_mask, t] = compute_rank_desc(week_values).to_numpy()
    return ranks


def compute_energy(params, data, config, season_rule, allow_bottom2=True):
    """Energy of one season's latent parameters: Gaussian prior plus soft elimination constraints.

    This is the scalar that ``langevin_sampler`` differentiates. It is the sum of

    * a quadratic prior tying ``y`` to ``a[:, None] + beta * z`` on active cells,
      with quadratic priors on ``a`` (around 0) and ``beta`` (around ``beta0``);
    * ``lambda`` times squared-hinge penalties asking each week's observed
      eliminations to be the worst contestants under the season's combination rule;
    * ``lambdaF`` times squared-hinge penalties asking the final week's combined
      scores to be ordered like ``placement`` (smaller placement treated as better).

    Parameters
    ----------
    params : tuple
        ``(y, a, beta)``: latent matrix indexed ``y[contestant, week]``,
        per-contestant offsets ``a``, and the coefficient ``beta`` on the judge
        z-scores (a 0-d tensor when built by ``langevin_sampler``).
    data : dict
        Needs ``J_torch``, ``active_torch``, ``z_torch``, ``placement``,
        ``final_week`` and ``elim_weeks`` (keyed by 1-based week number).
        ``rJ_torch`` is read with ``.get`` and only used by the rank rule.
    config : dict
        Reads ``sigma``, ``sigma_a``, ``beta0``, ``sigma_beta``, ``delta``,
        ``tau``, ``lambda`` and ``lambdaF``.
    season_rule : str
        ``"percent"`` combines judge and fan shares; any other value uses the
        rank rule (judge rank + smooth fan rank).
    allow_bottom2 : bool
        Under the rank rule, a week with exactly one elimination only requires the
        eliminated contestant to be among the two largest rank sums.

    Returns
    -------
    torch.Tensor
        Scalar energy, differentiable w.r.t. ``y``, ``a`` and ``beta``.
    """
    y, a, beta = params
    J = data["J_torch"]
    active = data["active_torch"]
    z = data["z_torch"]
    rJ = data.get("rJ_torch")
    placement = data["placement"]
    final_week = data["final_week"]
    elim_weeks = data["elim_weeks"]
    sigma = config["sigma"]
    sigma_a = config["sigma_a"]
    beta0 = config["beta0"]
    sigma_beta = config["sigma_beta"]
    delta = config["delta"]
    tau = config["tau"]
    lam = config["lambda"]
    lamF = config["lambdaF"]

    # Prior: squared residual of y around a[i] + beta * z[i, t], counted only on
    # active cells, plus quadratic terms pulling a toward 0 and beta toward beta0.
    resid = (y - a[:, None] - beta * z)
    prior = ((resid[active]) ** 2).sum() / (2 * sigma**2)
    prior = prior + (a**2).sum() / (2 * sigma_a**2) + (beta - beta0) ** 2 / (2 * sigma_beta**2)

    penalty = torch.tensor(0.0, device=y.device)
    penalty_final = torch.tensor(0.0, device=y.device)

    # t is a 0-based week column; elim_weeks is keyed by the 1-based week number t + 1.
    # The last iteration (t == final_week - 1) is the final and uses placement instead.
    for t in range(final_week):
        week_idx = t
        mask_t = active[:, week_idx]
        if not mask_t.any():
            continue
        # Fan vote shares among this week's active contestants (inactive -> 0).
        p_t = softmax_masked(y[:, week_idx], mask_t)
        if season_rule == "percent":
            # Percent rule: combined score = share of the week's judge total + fan share;
            # larger is better.
            J_t = J[:, week_idx]
            total_J = J_t[mask_t].sum()
            q_t = torch.zeros_like(J_t)
            q_t[mask_t] = J_t[mask_t] / (total_J + 1e-12)
            c_t = q_t + p_t
            if week_idx + 1 < final_week:
                elim = elim_weeks.get(week_idx + 1, [])
                if elim:
                    elim_idx = torch.tensor(elim, device=y.device)
                    keep_idx = torch.tensor([i for i in torch.where(mask_t)[0].tolist() if i not in elim], device=y.device)
                    # Zero once every eliminated score is at least delta below every survivor's.
                    max_elim = c_t[elim_idx].max()
                    min_keep = c_t[keep_idx].min()
                    penalty = penalty + torch.relu(max_elim - min_keep + delta) ** 2
            else:
                # Final week: for each pair where i placed better than j, want c_i >= c_j + delta.
                active_idx = torch.where(mask_t)[0]
                for i in range(active_idx.numel()):
                    for j in range(active_idx.numel()):
                        if placement[active_idx[i]] < placement[active_idx[j]]:
                            penalty_final = penalty_final + torch.relu(c_t[active_idx[j]] - c_t[active_idx[i]] + delta) ** 2
        else:
            # Rank rule: judge rank + smooth fan rank, both with 1 = best, so a
            # larger sum is worse and the eliminated should have the largest sums.
            rJ_t = rJ[:, week_idx]
            rF_t = smooth_rank(y[:, week_idx], mask_t, tau)
            s_t = rJ_t + rF_t
            if week_idx + 1 < final_week:
                elim = elim_weeks.get(week_idx + 1, [])
                if elim:
                    elim_idx = torch.tensor(elim, device=y.device)
                    keep_idx = torch.tensor([i for i in torch.where(mask_t)[0].tolist() if i not in elim], device=y.device)
                    if allow_bottom2 and len(elim) == 1:
                        # Bottom-two variant: zero once the eliminated sum is at least
                        # delta above s2, the second-largest active sum.
                        # NOTE(review): if the eliminated contestant *is* the second-largest,
                        # s2 is its own sum and this term is the constant delta**2.
                        s_active = s_t[mask_t]
                        s2 = torch.topk(s_active, k=2).values[-1]
                        penalty = penalty + torch.relu(s2 - s_t[elim_idx][0] + delta) ** 2
                    else:
                        # Zero once every eliminated sum is at least delta above every survivor's.
                        max_keep = s_t[keep_idx].max()
                        min_elim = s_t[elim_idx].min()
                        penalty = penalty + torch.relu(max_keep - min_elim + delta) ** 2
            else:
                # Final week: for each pair where i placed better than j, want s_i <= s_j - delta.
                active_idx = torch.where(mask_t)[0]
                for i in range(active_idx.numel()):
                    for j in range(active_idx.numel()):
                        if placement[active_idx[i]] < placement[active_idx[j]]:
                            penalty_final = penalty_final + torch.relu(s_t[active_idx[i]] - s_t[active_idx[j]] + delta) ** 2

    return prior + lam * penalty + lamF * penalty_final


def langevin_sampler(data, config, season_rule, allow_bottom2=True):
    """Sample the latent matrix ``y`` with unadjusted Langevin dynamics on ``compute_energy``.

    Each of ``config["chains"]`` chains starts from ``y = 0``, ``a = 0`` and
    ``beta = beta0``. Every step updates ``y``, ``a`` and ``beta``, in that order, by
    ``x <- x - eps * dE/dx - sqrt(2 * eps) * noise`` with standard normal noise.
    There is no accept/reject step. From step ``burnin`` on, every
    ``thinning``-th iterate of ``y`` is kept.

    Returns
    -------
    list[torch.Tensor]
        Detached copies of ``y`` (shape ``data["J"].shape``), pooled over all chains.
    """
    n_contestants, n_weeks = data["J"].shape
    step_size = config["eps"]
    noise_scale = math.sqrt(2 * step_size)
    burnin = config["burnin"]
    thinning = config["thinning"]
    kept = []
    for _ in range(config["chains"]):
        y = torch.zeros((n_contestants, n_weeks), requires_grad=True)
        a = torch.zeros(n_contestants, requires_grad=True)
        beta = torch.tensor(config["beta0"], requires_grad=True)
        params = (y, a, beta)
        for step in range(config["steps"]):
            compute_energy(params, data, config, season_rule, allow_bottom2).backward()
            with torch.no_grad():
                # Same noise-draw order as updating y, then a, then beta.
                for param in params:
                    param -= step_size * param.grad + noise_scale * torch.randn_like(param)
            for param in params:
                param.grad.zero_()
            past_burnin = step >= burnin
            if past_burnin and (step - burnin) % thinning == 0:
                kept.append(y.detach().clone())
    return kept


In [None]:
def _vote_share_samples(sample_stack, active_torch):
    """Turn latent samples of shape ``(S, n, w)`` into fan vote shares of the same shape.

    For every (sample, week), a softmax is taken over contestants after setting
    the logits of inactive contestants to -1e9, so their share is effectively 0.
    """
    shares = []
    for t in range(sample_stack.shape[2]):
        logits = sample_stack[:, :, t].clone()
        logits[:, ~active_torch[:, t]] = -1e9
        shares.append(torch.softmax(logits, dim=1))
    return torch.stack(shares, dim=2)


def _elimination_check(badness, elim, active_idx, bottom2=False):
    """Compare one week's observed elimination with a per-contestant "badness" score.

    Parameters
    ----------
    badness : np.ndarray, shape (n,)
        Larger values mean closer to elimination; only ``active_idx`` rows are read.
    elim : list[int]
        Row indices of the contestants actually eliminated this week.
    active_idx : np.ndarray
        Row indices of the contestants active this week.
    bottom2 : bool
        Single eliminations only: the eliminated contestant just has to be among
        the two largest ``badness`` values.

    Returns
    -------
    (match, margin)
        ``match`` is True when the scores reproduce the observed elimination.
        ``margin > 0`` means the scores imply it strictly.
    """
    bad_active = badness[active_idx]
    if bottom2:
        second_worst = np.sort(bad_active)[-2]
        elim_bad = badness[elim[0]]
        return elim_bad >= second_worst, elim_bad - second_worst
    predicted = active_idx[np.argsort(bad_active)[-len(elim):]]
    match = set(predicted.tolist()) == set(elim)
    elim_set = set(elim)
    max_keep = np.max([badness[i] for i in active_idx if i not in elim_set])
    margin = badness[elim].min() - max_keep
    return match, margin


all_rows = []
summary_rows = []

for season, df_season in raw.groupby("season"):
    season_info = season_preprocess(df_season)
    J = season_info["J"]
    active = season_info["active"]
    z = compute_z(J, active)
    rJ = compute_ranks(J, active)

    # Seasons 3-27 use the percent rule and all others the rank rule; the
    # bottom-two variant is only switched on from season 28.
    season_rule = "percent" if 3 <= season <= 27 else "rank"
    allow_bottom2 = CONFIG["allow_bottom2"] and season >= 28

    data_torch = {
        "J": J,
        "J_torch": torch.tensor(J, dtype=torch.float32),
        "active": active,
        "active_torch": torch.tensor(active, dtype=torch.bool),
        "z_torch": torch.tensor(z, dtype=torch.float32),
        "rJ_torch": torch.tensor(rJ, dtype=torch.float32),
        "placement": season_info["placement"],
        "final_week": season_info["final_week"],
        "elim_weeks": season_info["elim_weeks"],
    }

    samples = langevin_sampler(data_torch, CONFIG, season_rule, allow_bottom2)
    if not samples:
        continue  # steps <= burnin: nothing was kept for this season

    p_samples = _vote_share_samples(torch.stack(samples), data_torch["active_torch"])  # (S, n, w)
    p_np = p_samples.numpy()

    p_mean = p_samples.mean(dim=0).numpy()
    p_sd = p_samples.std(dim=0).numpy()
    p_ci_low = p_samples.quantile(0.025, dim=0).numpy()
    p_ci_high = p_samples.quantile(0.975, dim=0).numpy()
    rel_unc = (p_ci_high - p_ci_low) / (p_mean + 1e-12)

    if season_rule == "rank":
        # The rank rule adds the judge rank to the FAN RANK (1 = largest share),
        # mirroring smooth_rank in compute_energy. Adding the raw share to the
        # judge rank instead would make a larger fan share count against a contestant.
        rF_mean = compute_ranks(p_mean, active)
        rF_samples = np.stack([compute_ranks(p_s, active) for p_s in p_np])

    # Consistency metrics. "badness" is larger for contestants closer to elimination.
    matches = []
    margins = []
    pc_probs = []

    for week in range(1, season_info["final_week"]):
        elim = season_info["elim_weeks"].get(week, [])
        if not elim:
            continue
        t = week - 1  # elim_weeks is keyed by 1-based week; arrays use 0-based columns
        mask_t = active[:, t]
        if not mask_t.any():
            continue
        active_idx = np.flatnonzero(mask_t)
        if season_rule == "percent":
            # The lowest judge share + fan share is eliminated, so negate it.
            q_t = J[:, t] / (J[mask_t, t].sum() + 1e-12)
            bad_mean = -(q_t + p_mean[:, t])
            bad_samples = -(q_t[None, :] + p_np[:, :, t])
            bottom2 = False
        else:
            bad_mean = rJ[:, t] + rF_mean[:, t]
            bad_samples = rJ[None, :, t] + rF_samples[:, :, t]
            bottom2 = allow_bottom2 and len(elim) == 1

        match, margin = _elimination_check(bad_mean, elim, active_idx, bottom2)
        sample_margins = np.array(
            [_elimination_check(bad_s, elim, active_idx, bottom2)[1] for bad_s in bad_samples]
        )
        matches.append(match)
        margins.append(margin)
        # Posterior probability that the observed elimination is strictly implied.
        pc_probs.append(np.mean(sample_margins > 0))

    acc = float(np.mean(matches)) if matches else np.nan
    mean_margin = float(np.mean(margins)) if margins else np.nan
    mean_pc = float(np.mean(pc_probs)) if pc_probs else np.nan

    summary_rows.append({
        "season": season,
        "rule": season_rule,
        "acc": acc,
        "mean_margin": mean_margin,
        "mean_PC": mean_pc,
    })

    for idx, name in enumerate(season_info["names"]):
        for w in range(season_info["final_week"]):
            if not active[idx, w]:
                continue
            all_rows.append({
                "season": season,
                "week": w + 1,
                "celebrity_name": name,
                "p_mean": p_mean[idx, w],
                "p_sd": p_sd[idx, w],
                "p_ci_low": p_ci_low[idx, w],
                "p_ci_high": p_ci_high[idx, w],
                "rel_unc": rel_unc[idx, w],
            })


pd.DataFrame(all_rows).to_csv(OUTPUT_DIR / "fan_vote_posteriors.csv", index=False)
summary_df = pd.DataFrame(summary_rows)
summary_df.to_csv(OUTPUT_DIR / "season_consistency_summary.csv", index=False)

summary_df.head()
