In [16]:
# %% [CELL 0] Standalone SHHS1 TEST night-level analysis:
# - Full-night hypnograms (GT vs Pred)
# - Night metrics (TST, SE, SOL, WASO, stage minutes/%)
# - BlandAltman plots (TST, SE, WASO, SOL)
# - Saves CSV + figures

import os, json, math, random, re
from pathlib import Path
from tqdm import tqdm
from collections import Counter

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import accuracy_score, f1_score, cohen_kappa_score, confusion_matrix
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_auc_score, average_precision_score

import matplotlib.pyplot as plt

# ----------------------------
# (A) GPU + seed
# ----------------------------
# IMPORTANT: CUDA_VISIBLE_DEVICES lists the GPU indices this process is allowed to see.
# If you want "the 1st GPU of the machine", set it to "0".
# If your machine has GPUs [0..3], setting "4" is invalid.
# NOTE(review): this is assigned after `import torch`; presumably it still takes effect
# because no CUDA call has been made yet -- confirm, or move it above the torch import.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # change if needed

# Seed every RNG the notebook touches (Python, NumPy, PyTorch CPU and all CUDA devices).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Fall back to CPU when no CUDA device is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("CUDA available:", torch.cuda.is_available())
print("Visible CUDA devices:", torch.cuda.device_count())
print("Using device:", device)
if device.type == "cuda":
    # Index 0 here refers to the first *visible* device, not the physical GPU id.
    print("GPU:", torch.cuda.get_device_name(0))


# %% [CELL 1] Paths + load manifest (SHHS1 test only)

# Root folder holding the manifest, the checkpoints and (later) the analysis outputs.
ROOT = Path("/data2/Akbar1/sleep_stages_Dibatic/shhs_sleepstaging_planA/")
MANIFEST_PATH = ROOT / "manifest_sleepstaging_planA.csv"
assert MANIFEST_PATH.exists(), f"Missing manifest: {MANIFEST_PATH}"

# The manifest is read below for its `cohort`, `split` and `npz_path` columns
# (and optionally `subject_id`).
manifest = pd.read_csv(MANIFEST_PATH)
print("Rows:", len(manifest))
print(manifest.groupby(["cohort","split"]).size())

# Keep only the SHHS1 held-out test rows; .copy() detaches the slice from `manifest`.
df_test = manifest[(manifest.cohort=="SHHS1") & (manifest.split=="test")].copy()
print("SHHS1 TEST subjects:", len(df_test))


# %% [CELL 2] Normalization + learned smoothing helpers (minimal)

# Number of sleep-stage classes predicted by the main head.
NUM_CLASSES = 5
# Integer label -> stage name.
LABELS = {0:"W", 1:"N1", 2:"N2", 3:"N3", 4:"REM"}
# Duration of one scored epoch in seconds (used by every minutes-based metric below).
EPOCH_SEC = 30.0
# Sampling rate in Hz.
FS = 125
# Samples per epoch.
T = 3750  # 30s * 125Hz

def normalize_epochs_zscore(x, eps=1e-6, clip=10.0):
    """Z-score each row (epoch) of a 2-D array independently.

    Args:
        x: array-like of shape (n_epochs, n_samples).
        eps: added to the per-row standard deviation to avoid division by zero.
        clip: symmetric bound applied after scaling; ``None`` disables clipping.

    Returns:
        float32 array with the same shape as ``x``.
    """
    center = np.mean(x, axis=1, keepdims=True)
    scale = np.std(x, axis=1, keepdims=True) + eps
    z = (x - center) / scale
    if clip is None:
        return z.astype(np.float32)
    return np.clip(z, -clip, clip).astype(np.float32)

def apply_learned_smoothing_probs(probs, trans_logits):
    """Mix class probabilities through a row-stochastic transition matrix.

    Args:
        probs: tensor of shape (B, L, C) with per-epoch class probabilities.
        trans_logits: tensor of shape (C, C); each row is softmax-normalised.

    Returns:
        Tensor of shape (B, L, C) equal to ``probs @ softmax(trans_logits, dim=1)``.
    """
    transition = trans_logits.softmax(dim=1)
    return torch.matmul(probs, transition)


# %% [CELL 3] Model definition (your HierSleepTransformerV5_1)

class DropPath(nn.Module):
    """Stochastic depth: randomly zero whole samples of a residual branch.

    During training each sample (first dimension) is kept with probability
    ``1 - drop_prob`` and rescaled by ``1 / (1 - drop_prob)``; in eval mode, or
    when ``drop_prob`` is 0, the input is returned unchanged.
    """

    def __init__(self, drop_prob=0.1):
        super().__init__()
        self.drop_prob = float(drop_prob)

    def forward(self, x):
        if self.training and self.drop_prob != 0.0:
            keep_prob = 1.0 - self.drop_prob
            # One random draw per sample, broadcast over every remaining dimension.
            per_sample = (x.shape[0],) + (1,) * (x.ndim - 1)
            gate = torch.floor(keep_prob + torch.rand(per_sample, device=x.device))
            return x / keep_prob * gate
        return x

class ResConv1D(nn.Module):
    """Residual 1-D conv block: (Conv-BN-GELU-Conv-BN) plus a shortcut, then GELU.

    Args:
        c_in: input channels.
        c_out: output channels.
        k: kernel size of both convolutions (padding is ``k // 2``).
        s: stride of the first convolution and of the shortcut projection.

    Submodule names (``conv``, ``skip``, ``act``) and their order are kept stable
    because checkpoint keys depend on them.
    """

    def __init__(self, c_in, c_out, k, s=1):
        super().__init__()
        pad = k // 2
        main_path = [
            nn.Conv1d(c_in, c_out, k, stride=s, padding=pad),
            nn.BatchNorm1d(c_out),
            nn.GELU(),
            nn.Conv1d(c_out, c_out, k, padding=pad),
            nn.BatchNorm1d(c_out),
        ]
        self.conv = nn.Sequential(*main_path)

        # A 1x1 projection is only needed when the shortcut must change shape.
        if c_in != c_out or s != 1:
            self.skip = nn.Conv1d(c_in, c_out, 1, stride=s)
        else:
            self.skip = nn.Identity()
        self.act = nn.GELU()

    def forward(self, x):
        transformed = self.conv(x)
        shortcut = self.skip(x)
        return self.act(transformed + shortcut)

class EpochEncoder(nn.Module):
    """Encode each raw epoch into one d_model-dimensional embedding.

    Three ResConv1D branches (kernel sizes 7, 15 and 31, stride 4, 128 channels each)
    are averaged over time and concatenated with a 256-d projection of the
    log-magnitude spectrum, then fused by a Linear + LayerNorm + GELU.

    Input:  x of shape (B, L, 1, T).
    Output: tensor of shape (B, L, d_model).
    """
    def __init__(self, d_model=384):
        super().__init__()
        self.branch_short = ResConv1D(1, 128, k=7,  s=4)
        self.branch_mid   = ResConv1D(1, 128, k=15, s=4)
        self.branch_long  = ResConv1D(1, 128, k=31, s=4)

        # 1876 input features: matches T // 2 + 1 for the notebook's T = 3750.
        self.freq_proj = nn.Sequential(
            nn.Linear(1876, 256),
            nn.LayerNorm(256),
            nn.GELU(),
        )

        self.fuse = nn.Sequential(
            nn.Linear(128*3 + 256, d_model),
            nn.LayerNorm(d_model),
            nn.GELU(),
        )

    def forward(self, x):
        # x: (B,L,1,T)
        B, L, _, T_ = x.shape
        # Fold the sequence dimension into the batch so every epoch is encoded independently.
        x = x.view(B*L, 1, T_)

        # Global average pooling over time -> (B*L, 128) per branch.
        zs = self.branch_short(x).mean(-1)
        zm = self.branch_mid(x).mean(-1)
        zl = self.branch_long(x).mean(-1)

        # The spectral features are computed in float32 with CUDA autocast disabled.
        # NOTE(review): torch.cuda.amp.autocast is reported as deprecated in recent
        # PyTorch releases -- confirm against the installed version.
        with torch.cuda.amp.autocast(enabled=False):
            xf32 = x.squeeze(1).float()
            Xf = torch.fft.rfft(xf32, dim=-1)
            mag = torch.abs(Xf)
            # Keep at most the first 1876 bins; an epoch with fewer bins would not
            # match freq_proj's input size.
            mag = mag[:, :1876]
            mag = torch.log1p(mag)
            # Normalise each spectrum by its own mean (1e-6 guards against all-zero input).
            mag = mag / (mag.mean(dim=1, keepdim=True) + 1e-6)

        zf = self.freq_proj(mag)
        # Cast the spectral embedding to the conv branches' dtype before concatenating.
        z = torch.cat([zs, zm, zl, zf.to(zs.dtype)], dim=-1)
        z = self.fuse(z)
        return z.view(B, L, -1)

def rotate_half(x):
    """Rotate interleaved feature pairs: (a, b) -> (-b, a) along the last dim.

    The last dimension is treated as consecutive (even, odd) pairs; the output
    has the same shape as ``x``.
    """
    even, odd = x[..., 0::2], x[..., 1::2]
    pairs = torch.stack([-odd, even], dim=-1)
    return pairs.flatten(-2)

class RoPE(nn.Module):
    """Rotary position embedding applied to interleaved feature pairs.

    Args:
        head_dim: per-head feature size; must be even.
        base: base of the geometric frequency progression.

    ``forward`` expects ``x`` of shape (B, L, H, Dh) and returns the same shape,
    with position ``l`` rotating pair ``d`` by ``l * base ** (-d / (Dh // 2))``.
    """

    def __init__(self, head_dim, base=10000):
        super().__init__()
        assert head_dim % 2 == 0
        self.head_dim = head_dim
        self.base = base

    def forward(self, x):
        # x: (B,L,H,Dh)
        _, seq_len, _, dim = x.shape
        n_pairs = dim // 2

        exponents = torch.arange(n_pairs, device=x.device) / n_pairs
        inv_freq = 1.0 / (self.base ** exponents)
        positions = torch.arange(seq_len, device=x.device)
        angles = torch.einsum("l,d->ld", positions, inv_freq)

        # Shape (1, L, 1, Dh): broadcast over batch and heads, each angle shared by its pair.
        cos = torch.cos(angles)[None, :, None, :].repeat_interleave(2, dim=-1)
        sin = torch.sin(angles)[None, :, None, :].repeat_interleave(2, dim=-1)

        return (x * cos) + (rotate_half(x) * sin)

def _windows(L, w):
    out = []
    s = 0
    while s < L:
        e = min(L, s + w)
        out.append((s, e))
        s = e
    return out

class MultiHeadSelfAttentionRoPE_LocalGlobal(nn.Module):
    """Multi-head self-attention with RoPE and two modes: global or windowed.

    In global mode (``global_attn=True`` or ``window_size >= L``) every position
    attends to the whole sequence. Otherwise the sequence is cut into consecutive
    non-overlapping windows of ``window_size`` positions and attention is computed
    inside each window only.

    Args:
        d_model: model width; split evenly across ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
        window_size: window length used in local mode.
    """
    def __init__(self, d_model=384, n_heads=8, dropout=0.1, window_size=64):
        super().__init__()
        self.n_heads = n_heads
        # Integer division: d_model is assumed to be a multiple of n_heads.
        self.d_head = d_model // n_heads
        self.window_size = int(window_size)

        self.qkv = nn.Linear(d_model, 3*d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.drop = nn.Dropout(dropout)
        self.rope = RoPE(self.d_head)

    def forward(self, x, key_padding_mask=None, global_attn=False):
        # x: (B,L,D), key_padding_mask: (B,L)
        # key_padding_mask is True for positions that may be attended to
        # (False entries are filled with -1e9 below).
        B, L, D = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(B, L, self.n_heads, self.d_head)
        k = k.view(B, L, self.n_heads, self.d_head)
        v = v.view(B, L, self.n_heads, self.d_head)

        # RoPE is applied to queries and keys using absolute positions 0..L-1,
        # before any windowing.
        q = self.rope(q)
        k = self.rope(k)

        q = q.transpose(1, 2)  # (B,H,L,Dh)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # ---- global path: full L x L attention ----
        if global_attn or self.window_size >= L:
            scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)
            # Softmax is computed in float32 and cast back to v's dtype afterwards.
            scores = scores.float()
            if key_padding_mask is not None:
                scores = scores.masked_fill(~key_padding_mask[:, None, None, :], -1e9)
            attn = torch.softmax(scores, dim=-1)
            attn = self.drop(attn).to(v.dtype)
            out = attn @ v
            out = out.transpose(1, 2).contiguous().view(B, L, D)
            return self.proj(out)

        # ---- local path: independent attention inside each window ----
        w = self.window_size
        out = torch.zeros((B, self.n_heads, L, self.d_head), device=x.device, dtype=v.dtype)

        for (s, e) in _windows(L, w):
            qs = q[:, :, s:e, :]
            ks = k[:, :, s:e, :]
            vs = v[:, :, s:e, :]

            scores = (qs @ ks.transpose(-2, -1)) / math.sqrt(self.d_head)
            scores = scores.float()
            if key_padding_mask is not None:
                # Only the mask columns belonging to this window are relevant.
                m = key_padding_mask[:, s:e]
                scores = scores.masked_fill(~m[:, None, None, :], -1e9)

            attn = torch.softmax(scores, dim=-1)
            attn = self.drop(attn).to(vs.dtype)
            out[:, :, s:e, :] = attn @ vs

        out = out.transpose(1, 2).contiguous().view(B, L, D)
        return self.proj(out)

class TransformerBlockLG(nn.Module):
    """Pre-norm transformer block: local/global attention followed by an MLP.

    Both sub-layers are residual and pass through the same DropPath module.

    Args:
        d_model: model width.
        n_heads: number of attention heads.
        drop: dropout used inside attention and the MLP.
        drop_path: stochastic-depth probability for both residual branches.
        window_size: window length forwarded to the attention layer.
    """

    def __init__(self, d_model=384, n_heads=8, drop=0.1, drop_path=0.1, window_size=64):
        super().__init__()
        hidden = 4 * d_model
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttentionRoPE_LocalGlobal(
            d_model, n_heads, drop, window_size=window_size
        )
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden, d_model),
        )
        self.dp = DropPath(drop_path)

    def forward(self, x, mask, global_attn=False):
        attended = self.attn(self.ln1(x), key_padding_mask=mask, global_attn=global_attn)
        x = x + self.dp(attended)
        expanded = self.mlp(self.ln2(x))
        return x + self.dp(expanded)

class HierSleepTransformerV5_1(nn.Module):
    """Sequence sleep-staging model: epoch encoder + stack of local/global blocks.

    Args:
        num_classes: number of sleep stages for the main head.
        d_model: embedding width.
        depth: number of transformer blocks.
        n_heads: attention heads per block.
        dur_bins: output size of the auxiliary duration head.
        window_size: local-attention window length.
        global_every: block ``i`` uses global attention when ``i % global_every == 0``;
            values <= 0 disable the forced global layers.

    ``forward(x, mask)`` returns ``(main_logits, aux_logits, dur_logits)``, each
    produced from the same per-epoch features.
    """

    def __init__(self, num_classes=5, d_model=384, depth=12, n_heads=8,
                 dur_bins=8, window_size=64, global_every=3):
        super().__init__()
        self.num_classes = num_classes
        self.dur_bins = dur_bins
        self.depth = int(depth)
        self.global_every = int(global_every)

        self.encoder = EpochEncoder(d_model)

        # Stochastic-depth rate grows linearly with layer index, up to 0.1 at the last block.
        layers = []
        for layer_idx in range(depth):
            layers.append(
                TransformerBlockLG(
                    d_model=d_model,
                    n_heads=n_heads,
                    drop=0.1,
                    drop_path=0.1*(layer_idx+1)/depth,
                    window_size=window_size,
                )
            )
        self.blocks = nn.ModuleList(layers)
        self.head = nn.Linear(d_model, num_classes)

        # Auxiliary heads: unused at inference but present in the checkpoint.
        self.aux_n1 = nn.Linear(d_model, 2)
        self.aux_dur = nn.Linear(d_model, dur_bins)

        # Logits of the learned class-transition smoothing matrix.
        self.trans_logits = nn.Parameter(torch.zeros(num_classes, num_classes))

    def forward(self, x, mask):
        features = self.encoder(x)
        period = self.global_every
        for layer_idx, block in enumerate(self.blocks):
            go_global = (period > 0) and ((layer_idx % period) == 0)
            features = block(features, mask, global_attn=go_global)
        return self.head(features), self.aux_n1(features), self.aux_dur(features)


# %% [CELL 4] Load checkpoint (EMA applied if available)

# Bin edges for the auxiliary duration head; N edges define N + 1 bins.
DUR_EDGES = (2,5,10,20,40,80,160)
DUR_BINS = len(DUR_EDGES) + 1

# The architecture hyper-parameters must match the ones the checkpoint was trained with.
model = HierSleepTransformerV5_1(
    num_classes=NUM_CLASSES,
    d_model=384,
    depth=12,
    n_heads=8,
    dur_bins=DUR_BINS,
    window_size=64,
    global_every=3
).to(device)

print("Model params (M):", sum(p.numel() for p in model.parameters()) / 1e6)

CKPT_DIR = ROOT / "checkpoints_hier_rope_seq_v5_1"
BEST_CKPT = CKPT_DIR / "BEST_VAL_macroF1.pt"
assert BEST_CKPT.exists(), f"Missing ckpt: {BEST_CKPT}"
print("Loading:", BEST_CKPT)

# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
ckpt = torch.load(BEST_CKPT, map_location="cpu")
sd = ckpt["model_state"]
# strict=False: missing/unexpected keys are only counted and printed, never raised,
# so check the two numbers below are both 0.
missing, unexpected = model.load_state_dict(sd, strict=False)
print("Loaded model_state. missing:", len(missing), "| unexpected:", len(unexpected))

# Optional EMA weights stored next to the raw weights in the checkpoint.
use_ema = bool(ckpt.get("use_ema", False))
ema_shadow = ckpt.get("ema_shadow", None)

if use_ema and isinstance(ema_shadow, dict) and len(ema_shadow) > 0:
    with torch.no_grad():
        curr = model.state_dict()
        applied = 0
        # Overwrite every entry that has an EMA counterpart; shadow keys that the
        # model does not have are skipped silently.
        for k, v in ema_shadow.items():
            if k in curr:
                curr[k].copy_(v.to(curr[k].device).to(curr[k].dtype))
                applied += 1
        # Presumably redundant with the in-place copy_ above (state_dict tensors are
        # expected to alias the model's own) -- kept as in the original.
        model.load_state_dict(curr, strict=False)
    print("Applied EMA shadow tensors:", applied)
else:
    print("EMA shadow not applied (missing/disabled).")

# Inference mode: disables dropout / DropPath and uses BatchNorm running statistics.
model.eval()


# %% [CELL 5] Full-night prediction per file (SHHS1 test)

# When True, softmax probabilities are multiplied by softmax(model.trans_logits)
# before the argmax in predict_fullnight.
USE_LEARNED_SMOOTHING = True  # set False if you want raw softmax only

def infer_subject_id(row):
    """Return a subject identifier for one manifest row.

    Uses the ``subject_id`` column when it exists and is not missing;
    otherwise falls back to the file stem of ``npz_path``.
    """
    has_explicit_id = "subject_id" in row.index and pd.notna(row["subject_id"])
    if not has_explicit_id:
        return Path(row["npz_path"]).stem
    return str(row["subject_id"])

@torch.no_grad()
def predict_fullnight(npz_path: str):
    """Run the model on one full night and return labels, predictions and probabilities.

    The whole night is sent through the model as a single sequence (batch size 1).

    Args:
        npz_path: path to an ``.npz`` file with arrays ``x`` (E, T) and ``y`` (E,).

    Returns:
        Tuple ``(y, pred, probs)``: ``y`` and ``pred`` are int64 arrays of shape (E',),
        ``probs`` has shape (E', C); E' is the number of epochs with label >= 0.

    Raises:
        ValueError: if the file contains no epoch with a label >= 0.
    """
    # NOTE(security): allow_pickle=True lets np.load execute pickled objects; it is
    # kept for compatibility with the existing files, so only load trusted data.
    # The context manager closes the underlying zip file once the arrays are copied out.
    with np.load(npz_path, allow_pickle=True) as d:
        x = d["x"].astype(np.float32)  # (E,T)
        y = d["y"].astype(np.int64)    # (E,)

    # Negative labels mark epochs to skip.
    keep = (y >= 0)
    x = x[keep]
    y = y[keep]

    if len(y) == 0:
        # The model cannot reshape a zero-length sequence; fail with a readable message.
        raise ValueError(f"No valid (label >= 0) epochs in {npz_path}")

    x = normalize_epochs_zscore(x, eps=1e-6, clip=10.0)

    xb = torch.from_numpy(x).unsqueeze(0).unsqueeze(2).to(device)  # (1,E,1,T)
    # No padding in a single-night batch: every position may be attended to.
    mb = torch.ones((1, xb.shape[1]), dtype=torch.bool, device=device)

    main_logits, _, _ = model(xb, mb)
    probs = torch.softmax(main_logits.float(), dim=-1)

    if USE_LEARNED_SMOOTHING:
        probs = apply_learned_smoothing_probs(probs, model.trans_logits)

    pred = torch.argmax(probs, dim=-1).squeeze(0).cpu().numpy()
    probs = probs.squeeze(0).cpu().numpy()

    return y, pred, probs

# One dict per night; every later cell (metrics, plots, add-on cells) reads from `records`.
records = []
for _, row in tqdm(df_test.iterrows(), total=len(df_test), desc="Predict SHHS1 TEST full nights"):
    npz_path = str(row["npz_path"])
    sid = infer_subject_id(row)
    y_true, y_pred, probs = predict_fullnight(npz_path)
    records.append({
        "subject_id": sid,
        "npz_path": npz_path,
        "y_true": y_true,   # (E,) ground-truth stage per kept epoch
        "y_pred": y_pred,   # (E,) predicted stage per kept epoch
        "probs": probs,     # (E,C) class probabilities (after optional smoothing)
    })

print("Collected:", len(records), "subjects")


# %% [CELL 6] Hypnogram plotting + Night metrics + BlandAltman

# Integer label -> stage name (same mapping as LABELS).
STAGE_NAMES = {0:"W", 1:"N1", 2:"N2", 3:"N3", 4:"REM"}
# Stage label -> y coordinate on the hypnogram.
HYPNO_YMAP = {0:4, 1:3, 2:2, 3:1, 4:0}  # W on top
# Tick positions and their labels, listed top to bottom (W ... REM).
HYPNO_YTICKS = [4,3,2,1,0]
HYPNO_YLABELS = ["W","N1","N2","N3","REM"]

def compute_night_metrics(stage_seq):
    """Summarise one night's stage sequence into standard sleep metrics.

    Stage 0 is wake; every other value counts as sleep. All durations are in minutes
    and derived from the epoch count and ``EPOCH_SEC``.

    Returns a dict with:
        TotalTime_min, TST_min, SE_pct (TST / total time), SOL_min (time before the
        first sleep epoch; the whole night if there is no sleep), WASO_min (wake
        minutes from the first sleep epoch to the end), per-stage minutes, and
        N1/N2/N3/REM percentages of TST (NaN when TST is 0).
    """
    stages = np.asarray(stage_seq, dtype=np.int64)

    def to_minutes(n_epochs):
        return float(n_epochs * EPOCH_SEC / 60.0)

    total_min = to_minutes(len(stages))

    asleep = stages != 0
    n_sleep = int(asleep.sum())
    tst_min = to_minutes(n_sleep)

    if total_min > 0:
        se_pct = 100.0 * (tst_min / total_min)
    else:
        se_pct = np.nan

    if n_sleep == 0:
        # Never fell asleep: latency spans the whole recording, no wake-after-onset.
        sol_min, waso_min = total_min, 0.0
    else:
        onset = int(np.argmax(asleep))
        sol_min = to_minutes(onset)
        waso_min = to_minutes(np.count_nonzero(stages[onset:] == 0))

    stage_min = {code: to_minutes(np.count_nonzero(stages == code)) for code in range(5)}

    def share_of_tst(minutes):
        if tst_min > 0:
            return 100.0 * (minutes / tst_min)
        return np.nan

    metrics = {
        "TotalTime_min": total_min,
        "TST_min": tst_min,
        "SE_pct": se_pct,
        "SOL_min": sol_min,
        "WASO_min": waso_min,
    }
    for code, key in enumerate(["W_min", "N1_min", "N2_min", "N3_min", "REM_min"]):
        metrics[key] = stage_min[code]
    for code, key in [(1, "N1_pct"), (2, "N2_pct"), (3, "N3_pct"), (4, "REM_pct")]:
        metrics[key] = share_of_tst(stage_min[code])
    return metrics

def plot_hypnogram_pair(
    y_true, y_pred, subject_id,
    save_path=None,
    max_hours=None,
):
    """Plot ground-truth and predicted hypnograms for one night on shared axes.

    Args:
        y_true: sequence of ground-truth stage labels (keys of ``HYPNO_YMAP``).
        y_pred: sequence of predicted stage labels, same length as ``y_true``.
        subject_id: identifier shown in the title.
        save_path: if given, the figure is written there (dpi=400) and closed;
            otherwise it is shown interactively.
        max_hours: if given, only epochs starting at or before this time are drawn.

    Raises:
        ValueError: if the two sequences differ in length.
        KeyError: if a label is not a key of ``HYPNO_YMAP``.
    """
    # Accept lists as well as arrays (boolean indexing below needs arrays).
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if len(y_true) != len(y_pred):
        raise ValueError(
            f"y_true and y_pred must have the same length, got {len(y_true)} and {len(y_pred)}"
        )

    E = len(y_true)
    t_hours = (np.arange(E) * EPOCH_SEC) / 3600.0

    if max_hours is not None:
        keep = t_hours <= max_hours
        y_true = y_true[keep]
        y_pred = y_pred[keep]
        t_hours = t_hours[keep]

    # Explicit lookup: works on empty input and fails loudly on unknown labels
    # (np.vectorize(dict.get) raises on empty arrays and maps unknown labels to None).
    yt = np.array([HYPNO_YMAP[int(s)] for s in y_true], dtype=int)
    yp = np.array([HYPNO_YMAP[int(s)] for s in y_pred], dtype=int)

    # Slightly shorter height for paper
    fig, ax = plt.subplots(figsize=(14, 3.4))

    ax.step(t_hours, yt, where="post", linewidth=1.6, label="GT")
    ax.step(t_hours, yp, where="post", linewidth=1.3, alpha=0.85, label="Pred")

    ax.set_yticks(HYPNO_YTICKS)
    ax.set_yticklabels(HYPNO_YLABELS)
    ax.set_ylim(-0.5, 4.5)

    ax.set_xlabel("Time (hours)")
    ax.set_ylabel("Stage")
    ax.set_title(f"SHHS1 TEST Hypnogram | {subject_id}")

    ax.grid(True, alpha=0.25)

    # Legend sits outside the axes so it never hides the trace.
    ax.legend(
        loc="center left",
        bbox_to_anchor=(1.01, 0.5),
        frameon=False
    )

    # Leave space on the right for legend
    fig.tight_layout(rect=[0, 0, 0.88, 1])

    if save_path is not None:
        fig.savefig(save_path, dpi=400, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()

def bland_altman(gt, pred, title, xlab, ylab, save_path=None):
    """Draw a Bland-Altman plot (difference vs. mean) for paired measurements.

    Pairs in which either the mean or the difference is not finite are dropped
    before computing the statistics, so one NaN night cannot turn the bias and
    limits of agreement into NaN.

    Args:
        gt: reference values.
        pred: predicted values, paired with ``gt``.
        title: first title line; bias and limits of agreement are appended.
        xlab: x-axis label.
        ylab: y-axis label.
        save_path: if given, the figure is saved there (dpi=400) and closed;
            otherwise it is shown.
    """
    gt = np.asarray(gt, dtype=float)
    pred = np.asarray(pred, dtype=float)
    m = (gt + pred) / 2.0
    d = (pred - gt)

    usable = np.isfinite(m) & np.isfinite(d)
    m = m[usable]
    d = d[usable]

    # With no usable pair the statistics are undefined; report NaN instead of warning.
    bias = float(np.mean(d)) if d.size > 0 else np.nan
    sd = np.std(d, ddof=1) if len(d) > 1 else 0.0
    loa_low = bias - 1.96 * sd
    loa_high = bias + 1.96 * sd

    fig, ax = plt.subplots(figsize=(6.5, 5.2))
    ax.scatter(m, d, s=18, alpha=0.7)
    if np.isfinite(bias):
        ax.axhline(bias, linewidth=1.5)
        ax.axhline(loa_low, linestyle="--", linewidth=1.2)
        ax.axhline(loa_high, linestyle="--", linewidth=1.2)
    ax.set_title(f"{title}\nBias={bias:.2f}, LoA=[{loa_low:.2f}, {loa_high:.2f}]")
    ax.set_xlabel(xlab)
    ax.set_ylabel(ylab)
    ax.grid(True, alpha=0.25)
    fig.tight_layout()

    if save_path is not None:
        fig.savefig(save_path, dpi=400)
        plt.close(fig)
    else:
        plt.show()

print("Plotting + metrics helpers ready.")


# %% [CELL 7] Build night-level dataframe (GT vs Pred) + errors

# One row per night: the same metric computed from ground truth (*_gt) and prediction (*_pred).
rows = []
for r in records:
    gt = compute_night_metrics(r["y_true"])
    pr = compute_night_metrics(r["y_pred"])

    rows.append({
        "subject_id": r["subject_id"],
        "E": len(r["y_true"]),  # number of scored epochs in the night
        "TST_gt": gt["TST_min"],
        "TST_pred": pr["TST_min"],
        "SE_gt": gt["SE_pct"],
        "SE_pred": pr["SE_pct"],
        "SOL_gt": gt["SOL_min"],
        "SOL_pred": pr["SOL_min"],
        "WASO_gt": gt["WASO_min"],
        "WASO_pred": pr["WASO_min"],
        "N1_gt": gt["N1_min"],
        "N1_pred": pr["N1_min"],
        "N2_gt": gt["N2_min"],
        "N2_pred": pr["N2_min"],
        "N3_gt": gt["N3_min"],
        "N3_pred": pr["N3_min"],
        "REM_gt": gt["REM_min"],
        "REM_pred": pr["REM_min"],
        "REMpct_gt": gt["REM_pct"],
        "REMpct_pred": pr["REM_pct"],
        "N3pct_gt": gt["N3_pct"],
        "N3pct_pred": pr["N3_pct"],
    })

df_night = pd.DataFrame(rows)

# add diffs + abs errors
# Sign convention: diff = pred - gt, so a positive value means over-estimation.
for k in ["TST","SE","SOL","WASO","N1","N2","N3","REM","REMpct","N3pct"]:
    df_night[f"{k}_diff"] = df_night[f"{k}_pred"] - df_night[f"{k}_gt"]
    df_night[f"{k}_abs"]  = np.abs(df_night[f"{k}_diff"])

# `display` is provided by IPython/Jupyter; this cell is not runnable as a plain script.
display(df_night.head())

# Mean and sample standard deviation (ddof=1) of the per-night absolute errors.
# N1 and N2 have *_abs columns too but are not printed here.
print("\nNight-level absolute error summary (mean ± std):")
for k in ["TST","SE","SOL","WASO","N3","REM","N3pct","REMpct"]:
    mu = df_night[f"{k}_abs"].mean()
    sd = df_night[f"{k}_abs"].std(ddof=1)
    print(f"{k:6s} | MAE={mu:.2f} ± {sd:.2f}")


# %% [CELL 8] Save outputs (CSV + plots + sample hypnograms)

# All artefacts of this analysis (CSV, figures, JSON) go into one folder.
OUT_BASE = ROOT / "night_level_analysis_shhs1_test"
OUT_BASE.mkdir(parents=True, exist_ok=True)

# Persist the per-night table
csv_path = OUT_BASE / "night_metrics_gt_vs_pred.csv"
df_night.to_csv(csv_path, index=False)
print("Saved CSV:", csv_path)

# One Bland-Altman figure per headline metric: (column prefix, display name, unit)
for ba_key, ba_name, ba_unit in [
    ("TST", "Total Sleep Time", "min"),
    ("SE", "Sleep Efficiency", "%"),
    ("WASO", "WASO", "min"),
    ("SOL", "SOL", "min"),
]:
    bland_altman(
        df_night[f"{ba_key}_gt"].values, df_night[f"{ba_key}_pred"].values,
        title=f"BlandAltman: {ba_name} ({ba_unit})",
        xlab=f"Mean {ba_key} ({ba_unit})",
        ylab=f"{ba_key} diff (Pred - GT, {ba_unit})",
        save_path=OUT_BASE / f"bland_altman_{ba_key}.png",
    )

print("Saved BlandAltman plots to:", OUT_BASE)

# save a few hypnograms (random + best/worst)
def night_acc(y_true, y_pred):
    """Fraction of epochs where prediction equals ground truth for one night.

    Accepts any array-like (lists included). Returns NaN for an empty night.

    Raises:
        ValueError: if the two sequences do not have the same shape.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred must have the same shape, got {y_true.shape} and {y_pred.shape}"
        )
    if y_true.size == 0:
        return np.nan
    return float(np.mean(y_true == y_pred))

acc_list = [(r["subject_id"], night_acc(r["y_true"], r["y_pred"])) for r in records]
# NaN accuracies (nights without epochs) have no defined rank and would make the
# sort order arbitrary, so they are left out of the best/worst ranking.
acc_sorted = sorted((item for item in acc_list if np.isfinite(item[1])), key=lambda x: x[1])

# 2 worst then 2 best; slicing tolerates fewer than 2 (or 4) nights, and the
# membership test avoids saving the same night twice.
pick_ids = []
worst_ids = [subj for subj, _ in acc_sorted[:2]]
best_ids = [subj for subj, _ in acc_sorted[::-1][:2]]
for subj in worst_ids + best_ids:
    if subj not in pick_ids:
        pick_ids.append(subj)

# plus 2 random not in list
rng = np.random.RandomState(0)
all_ids = [r["subject_id"] for r in records]
random_ids = []
if all_ids:
    for x in rng.choice(all_ids, size=min(20, len(all_ids)), replace=False):
        if x not in pick_ids and x not in random_ids:
            random_ids.append(x)
pick_ids += random_ids[:2]

# Index the records once instead of scanning the list for every pick; setdefault keeps
# the first record for a subject_id, matching the original first-match lookup.
records_by_id = {}
for rec in records:
    records_by_id.setdefault(rec["subject_id"], rec)

for sid in pick_ids:
    r = records_by_id[sid]
    fig_path = OUT_BASE / f"hypnogram_{sid}.png"
    plot_hypnogram_pair(r["y_true"], r["y_pred"], subject_id=sid, save_path=fig_path)

print("Saved hypnogram PNGs:", len(pick_ids), "files")
print("Output folder:", OUT_BASE)


# %% [CELL 9] (Optional) Aggregate epoch-level metrics on SHHS1 test (for completeness)

# Pool every epoch of every test night into two flat label vectors.
all_true = np.concatenate([r["y_true"] for r in records])
all_pred = np.concatenate([r["y_pred"] for r in records])

acc = accuracy_score(all_true, all_pred)
mf1 = f1_score(all_true, all_pred, average="macro")
kappa = cohen_kappa_score(all_true, all_pred)
# Rows are true classes, columns are predicted classes, in label order 0..NUM_CLASSES-1.
cm = confusion_matrix(all_true, all_pred, labels=list(range(NUM_CLASSES)))

# Per-class F1 computed one-vs-rest: class i against all other classes.
f1_per = {LABELS[i]: float(f1_score((all_true==i).astype(int), (all_pred==i).astype(int)))
          for i in range(NUM_CLASSES)}

print("\n===== SHHS1 TEST (epoch-level) =====")
print(f"acc      : {acc:.4f}")
print(f"macro_f1 : {mf1:.4f}")
print(f"kappa    : {kappa:.4f}")
print("F1/class :", f1_per)
print("Confusion Matrix labels:", [LABELS[i] for i in range(NUM_CLASSES)])
print(cm)

# Save epoch-level summary
# Everything is converted to plain Python types so json.dump can serialise it.
epoch_summary = {
    "acc": float(acc),
    "macro_f1": float(mf1),
    "kappa": float(kappa),
    "f1_per_class": f1_per,
    "cm": cm.tolist(),
    "n_epochs": int(len(all_true)),
    "checkpoint": str(BEST_CKPT),
    "use_learned_smoothing": bool(USE_LEARNED_SMOOTHING),
}

with open(OUT_BASE / "epoch_level_summary.json", "w") as f:
    json.dump(epoch_summary, f, indent=2)

print("Saved epoch-level summary:", OUT_BASE / "epoch_level_summary.json")


CUDA available: True
Visible CUDA devices: 1
Using device: cuda
GPU: NVIDIA RTX A6000
Rows: 9868
cohort  split        
MESA    external_test    1856
SHHS1   test              548
        train            4380
        val               548
SHHS2   external_test    2536
dtype: int64
SHHS1 TEST subjects: 548
Model params (M): 22.905512
Loading: /data2/Akbar1/sleep_stages_Dibatic/shhs_sleepstaging_planA/checkpoints_hier_rope_seq_v5_1/BEST_VAL_macroF1.pt
Loaded model_state. missing: 0 | unexpected: 0
Applied EMA shadow tensors: 189


Predict SHHS1 TEST full nights: 100%|█████████| 548/548 [02:24<00:00,  3.78it/s]

Collected: 548 subjects
Plotting + metrics helpers ready.





Unnamed: 0,subject_id,E,TST_gt,TST_pred,SE_gt,SE_pred,SOL_gt,SOL_pred,WASO_gt,WASO_pred,...,N2_diff,N2_abs,N3_diff,N3_abs,REM_diff,REM_abs,REMpct_diff,REMpct_abs,N3pct_diff,N3pct_abs
0,shhs1-200018_v1,1004,305.5,308.5,60.856574,61.454183,112.5,65.0,84.0,128.5,...,-29.0,29.0,13.0,13.0,14.5,14.5,4.529864,4.529864,4.13436,4.13436
1,shhs1-200021_v1,992,319.0,331.5,64.314516,66.834677,0.0,0.0,177.0,164.5,...,34.0,34.0,-2.0,2.0,-12.0,12.0,-4.264127,4.264127,-1.087959,1.087959
2,shhs1-200035_v1,1004,373.0,392.0,74.302789,78.087649,31.0,0.0,98.0,110.0,...,0.5,0.5,-6.0,6.0,24.5,24.5,5.68474,5.68474,-2.57017,2.57017
3,shhs1-200039_v1,879,367.0,377.0,83.503982,85.779295,23.5,23.5,49.0,39.0,...,-18.5,18.5,23.5,23.5,1.0,1.0,-0.500871,0.500871,5.546802,5.546802
4,shhs1-200043_v1,759,365.5,363.0,96.310935,95.652174,0.0,0.0,14.0,16.5,...,-65.5,65.5,13.5,13.5,30.5,30.5,8.499244,8.499244,3.862214,3.862214


  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0, flags=flags)



Night-level absolute error summary (mean ± std):
TST    | MAE=10.12 ± 12.80
SE     | MAE=2.17 ± 2.79
SOL    | MAE=6.00 ± 16.12
WASO   | MAE=10.85 ± 14.81
N3     | MAE=15.52 ± 15.89
REM    | MAE=10.36 ± 12.99
N3pct  | MAE=4.20 ± 4.29
REMpct | MAE=2.74 ± 3.49
Saved CSV: /data2/Akbar1/sleep_stages_Dibatic/shhs_sleepstaging_planA/night_level_analysis_shhs1_test/night_metrics_gt_vs_pred.csv
Saved BlandAltman plots to: /data2/Akbar1/sleep_stages_Dibatic/shhs_sleepstaging_planA/night_level_analysis_shhs1_test
Saved hypnogram PNGs: 6 files
Output folder: /data2/Akbar1/sleep_stages_Dibatic/shhs_sleepstaging_planA/night_level_analysis_shhs1_test

===== SHHS1 TEST (epoch-level) =====
acc      : 0.8656
macro_f1 : 0.8109
kappa    : 0.8144
F1/class : {'W': 0.9216522245838207, 'N1': 0.5367519943346316, 'N2': 0.8737516650298247, 'N3': 0.8319557044407884, 'REM': 0.8902049606416158}
Confusion Matrix labels: ['W', 'N1', 'N2', 'N3', 'REM']
[[109925   5916   3969    418   2185]
 [  1901  12885   3317    

In [17]:
# %% [ADD-ON CELL 1] Transition matrices (GT vs Pred) + diff + normalized versions

import numpy as np
import matplotlib.pyplot as plt

# Short alias for the number of stages; plot_matrix and the aggregation below read it as a global.
C = NUM_CLASSES  # 5

def transition_counts(seq, num_classes=5):
    """Count stage-to-stage transitions between consecutive epochs.

    Args:
        seq: (E,) int labels in [0..num_classes-1]. Negative labels are treated as
            unscored: any transition into or out of one is skipped instead of
            wrapping around to the last class through negative indexing.
        num_classes: number of classes C.

    Returns:
        (C, C) int64 array where M[i, j] = number of transitions i -> j
        (the diagonal counts epochs that stay in the same stage).

    Raises:
        ValueError: if a label is >= num_classes.
    """
    seq = np.asarray(seq, dtype=np.int64)
    M = np.zeros((num_classes, num_classes), dtype=np.int64)
    if len(seq) < 2:
        return M

    if seq.max() >= num_classes:
        raise ValueError(
            f"label {int(seq.max())} is out of range for num_classes={num_classes}"
        )

    a = seq[:-1]
    b = seq[1:]
    scored = (a >= 0) & (b >= 0)
    # np.add.at accumulates correctly when the same (i, j) pair occurs many times.
    np.add.at(M, (a[scored], b[scored]), 1)
    return M

def row_normalize(M, eps=1e-12):
    """Divide each row of ``M`` by its sum, returning float64.

    ``eps`` is added to every row sum, so an all-zero row stays all zeros
    instead of producing a division by zero.
    """
    counts = M.astype(np.float64)
    row_totals = counts.sum(axis=1, keepdims=True) + eps
    return np.divide(counts, row_totals)

# --- aggregate over all SHHS1 test nights (records from the standalone script)
# Transitions are counted within each night and then summed, so the last epoch of one
# night is never paired with the first epoch of the next.
M_gt = np.zeros((C, C), dtype=np.int64)
M_pr = np.zeros((C, C), dtype=np.int64)

for r in records:  # records = list of dicts with y_true/y_pred
    M_gt += transition_counts(r["y_true"], num_classes=C)
    M_pr += transition_counts(r["y_pred"], num_classes=C)

# Raw count difference (Pred - GT).
M_diff = M_pr.astype(np.int64) - M_gt.astype(np.int64)

# Row-normalised versions: each row becomes the distribution of the next stage given
# the current stage; their difference compares the two row-normalised matrices.
M_gt_norm = row_normalize(M_gt)
M_pr_norm = row_normalize(M_pr)
M_diff_norm = M_pr_norm - M_gt_norm

print("GT transition counts:\n", M_gt)
print("Pred transition counts:\n", M_pr)
print("Diff (Pred - GT) counts:\n", M_diff)

def plot_matrix(M, title, labels=None, fmt=".3f", save_path=None):
    """Render a C x C matrix as an annotated heatmap (rows = "From", columns = "To").

    Args:
        M: 2-D array indexed as M[from, to]; only the first C rows/columns are annotated.
        title: figure title.
        labels: tick labels; defaults to the stage names in ``LABELS``.
        fmt: format spec used for floating-point cells (integer cells print as ints).
        save_path: if given, the figure is saved there (dpi=400) and closed;
            otherwise it is shown.
    """
    if labels is None:
        labels = [LABELS[i] for i in range(C)]

    plt.figure(figsize=(6.2, 5.4))
    plt.imshow(M, aspect="auto")
    plt.title(title)
    plt.xticks(range(C), labels, rotation=45, ha="right")
    plt.yticks(range(C), labels)
    plt.xlabel("To")
    plt.ylabel("From")
    plt.colorbar()

    # Write each cell's value on top of its colour patch.
    for row, col in np.ndindex(C, C):
        cell = M[row, col]
        is_float = isinstance(cell, (float, np.floating))
        text = format(cell, fmt) if is_float else str(int(cell))
        plt.text(col, row, text, ha="center", va="center", fontsize=8)

    plt.tight_layout()
    if save_path is None:
        plt.show()
        return
    plt.savefig(save_path, dpi=400)
    plt.close()

# Re-wrap as Path. This does NOT create the folder: it relies on OUT_BASE having been
# defined (and the directory created) by the standalone cell above.
OUT_BASE = Path(OUT_BASE)  # ensure exists from your standalone script

# Save matrices
# Raw counts (int64) ...
np.save(OUT_BASE / "transitions_gt_counts.npy", M_gt)
np.save(OUT_BASE / "transitions_pred_counts.npy", M_pr)
np.save(OUT_BASE / "transitions_diff_counts.npy", M_diff)

# ... and row-normalised versions (float64).
np.save(OUT_BASE / "transitions_gt_row_norm.npy", M_gt_norm)
np.save(OUT_BASE / "transitions_pred_row_norm.npy", M_pr_norm)
np.save(OUT_BASE / "transitions_diff_row_norm.npy", M_diff_norm)

# Plot (counts)
plot_matrix(M_gt, "Transition matrix (GT)  counts", save_path=OUT_BASE / "transition_GT_counts.png")
plot_matrix(M_pr, "Transition matrix (Pred)  counts", save_path=OUT_BASE / "transition_Pred_counts.png")
plot_matrix(M_diff, "Transition matrix (Pred - GT)  counts", save_path=OUT_BASE / "transition_Diff_counts.png")

# Plot (row-normalized)
plot_matrix(M_gt_norm, "Transition matrix (GT)  row-normalized", fmt=".3f",
            save_path=OUT_BASE / "transition_GT_rowNorm.png")
plot_matrix(M_pr_norm, "Transition matrix (Pred)  row-normalized", fmt=".3f",
            save_path=OUT_BASE / "transition_Pred_rowNorm.png")
plot_matrix(M_diff_norm, "Transition matrix (Pred - GT)  row-normalized", fmt=".3f",
            save_path=OUT_BASE / "transition_Diff_rowNorm.png")

print("Saved transition matrices + plots into:", OUT_BASE)


GT transition counts:
 [[107106   8562   5044     88   1322]
 [  2473  10576   6038     12    727]
 [  8345     91 199085  13256   2599]
 [  1014      9  12238  58165    136]
 [  3038    599   1078     11  71920]]
Pred transition counts:
 [[100906  10937   3347     33    623]
 [  3269  16301   7830      2    726]
 [  7692    427 182738  11614   2521]
 [  1153      4  10354  68073    175]
 [  2685    484    800      4  80834]]
Diff (Pred - GT) counts:
 [[ -6200   2375  -1697    -55   -699]
 [   796   5725   1792    -10     -1]
 [  -653    336 -16347  -1642    -78]
 [   139     -5  -1884   9908     39]
 [  -353   -115   -278     -7   8914]]


  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0, flags=flags)


Saved transition matrices + plots into: /data2/Akbar1/sleep_stages_Dibatic/shhs_sleepstaging_planA/night_level_analysis_shhs1_test


In [18]:
# %% [ADD-ON CELL 2] REM latency (GT vs Pred) + BlandAltman + summary + save

import numpy as np
import pandas as pd

# Stage codes used by rem_latency_minutes; they match LABELS (4 = REM, 0 = W).
REM_ID = 4
W_ID = 0

def rem_latency_minutes(stage_seq):
    """Minutes from sleep onset to the first REM epoch.

    Sleep onset is the first epoch whose stage is not ``W_ID``; the latency is the
    number of epochs from there to the first ``REM_ID`` epoch, times the epoch length
    in minutes.

    Returns NaN when the sequence is empty, contains no sleep, or contains no REM
    at or after sleep onset.
    """
    stages = np.asarray(stage_seq, dtype=np.int64)
    if len(stages) == 0:
        return np.nan

    asleep = stages != W_ID
    if not asleep.any():
        return np.nan

    onset = int(asleep.argmax())

    # Offsets are relative to sleep onset, so the first one is already the latency in epochs.
    rem_offsets = np.nonzero(stages[onset:] == REM_ID)[0]
    if rem_offsets.size == 0:
        return np.nan

    epochs_to_rem = int(rem_offsets[0])
    return epochs_to_rem * (EPOCH_SEC / 60.0)

# compute per-night REM latency
REMLAT_COLS = ["REMlat_gt_min", "REMlat_pred_min", "REMlat_diff_min", "REMlat_abs_min"]

rem_rows = []
for r in records:
    gt_rl = rem_latency_minutes(r["y_true"])
    pr_rl = rem_latency_minutes(r["y_pred"])
    # The difference is only defined when both GT and prediction contain REM after sleep onset.
    both_defined = np.isfinite(gt_rl) and np.isfinite(pr_rl)
    rl_diff = (pr_rl - gt_rl) if both_defined else np.nan
    rem_rows.append({
        "subject_id": r["subject_id"],
        "REMlat_gt_min": gt_rl,
        "REMlat_pred_min": pr_rl,
        "REMlat_diff_min": rl_diff,
        "REMlat_abs_min": abs(rl_diff),
    })

# Explicit columns keep the frame well-formed even when `records` is empty.
df_remlat = pd.DataFrame(rem_rows, columns=["subject_id"] + REMLAT_COLS)

# merge into df_night if you already created it
if "df_night" in globals():
    # Drop REM-latency columns left over from an earlier run of this cell. Without this,
    # a second run makes merge emit *_x / *_y suffixed columns and the display below
    # raises KeyError.
    df_night = (
        df_night
        .drop(columns=REMLAT_COLS, errors="ignore")
        .merge(df_remlat[["subject_id"] + REMLAT_COLS], on="subject_id", how="left")
    )
    display(df_night[["subject_id"] + REMLAT_COLS].head())

# summary (only finite pairs)
mask = np.isfinite(df_remlat["REMlat_gt_min"].values) & np.isfinite(df_remlat["REMlat_pred_min"].values)
gt_vals = df_remlat.loc[mask, "REMlat_gt_min"].values
pr_vals = df_remlat.loc[mask, "REMlat_pred_min"].values

print("REM latency: valid nights =", int(mask.sum()), "/", len(df_remlat))
if mask.sum() > 0:
    mae = np.mean(np.abs(pr_vals - gt_vals))
    sd  = np.std(np.abs(pr_vals - gt_vals), ddof=1) if mask.sum() > 1 else 0.0
    print(f"REM latency MAE = {mae:.2f} ± {sd:.2f} minutes")

# Bland-Altman plot using the helper defined earlier in the notebook
bland_altman(
    gt_vals, pr_vals,
    title="BlandAltman: REM latency (minutes)",
    xlab="Mean REM latency (min)",
    ylab="REM latency diff (Pred - GT, min)",
    save_path=OUT_BASE / "bland_altman_REM_latency.png"
)

# save CSV
df_remlat.to_csv(OUT_BASE / "rem_latency_gt_vs_pred.csv", index=False)
print("Saved:", OUT_BASE / "rem_latency_gt_vs_pred.csv")
print("Saved:", OUT_BASE / "bland_altman_REM_latency.png")


Unnamed: 0,subject_id,REMlat_gt_min,REMlat_pred_min,REMlat_diff_min,REMlat_abs_min
0,shhs1-200018_v1,42.5,88.0,45.5,45.5
1,shhs1-200021_v1,61.5,62.5,1.0,1.0
2,shhs1-200035_v1,202.5,194.5,-8.0,8.0
3,shhs1-200039_v1,96.5,96.5,0.0,0.0
4,shhs1-200043_v1,60.0,59.5,-0.5,0.5


REM latency: valid nights = 540 / 548
REM latency MAE = 14.10 ± 34.86 minutes


  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0, flags=flags)


Saved: /data2/Akbar1/sleep_stages_Dibatic/shhs_sleepstaging_planA/night_level_analysis_shhs1_test/rem_latency_gt_vs_pred.csv
Saved: /data2/Akbar1/sleep_stages_Dibatic/shhs_sleepstaging_planA/night_level_analysis_shhs1_test/bland_altman_REM_latency.png


In [19]:
# %% [CELL] Paper-style hypnogram (overlay) with legend outside

import matplotlib.pyplot as plt
import numpy as np

def save_hypnogram_paper_overlay(
    y_true, y_pred, out_png,
    subject_id=None,
    show_title=True,
    legend_outside=True,
    dpi=400
):
    """Save a compact GT-vs-prediction hypnogram overlay for use in a paper.

    Args:
        y_true: ground-truth stage labels (0=W, 1=N1, 2=N2, 3=N3, 4=REM).
        y_pred: predicted stage labels, same length as ``y_true``.
        out_png: destination path of the figure.
        subject_id: shown in the title when ``show_title`` is True and this is not None.
        show_title: whether to draw a title at all.
        legend_outside: put the legend to the right of the axes; if False no legend is drawn.
        dpi: resolution of the saved image.
    """
    stages_gt = np.asarray(y_true, dtype=int)
    stages_pred = np.asarray(y_pred, dtype=int)
    t_hours = (np.arange(len(stages_gt)) * EPOCH_SEC) / 3600.0

    # Compact, paper-friendly aspect ratio.
    fig, ax = plt.subplots(figsize=(10.5, 2.4))

    ax.step(t_hours, stages_gt, where="post", linewidth=1.4, label="GT")
    ax.step(t_hours, stages_pred, where="post", linewidth=1.4, label="Pred", alpha=0.9)

    # Labels are plotted with their raw codes (0=W ... 4=REM); flipping the axis
    # puts W at the top and REM at the bottom.
    ax.invert_yaxis()
    ax.set_yticks([0,1,2,3,4])
    ax.set_yticklabels(["W","N1","N2","N3","REM"])

    ax.set_xlabel("Time (hours)")
    ax.set_ylabel("Stage")
    ax.grid(True, alpha=0.25)

    if show_title and subject_id is not None:
        ax.set_title(f"SHHS1 Test Hypnogram | {subject_id}")

    if not legend_outside:
        fig.tight_layout()
    else:
        # Frameless legend right of the plot; reserve the right margin for it.
        ax.legend(loc="center left", bbox_to_anchor=(1.01, 0.5), frameon=False)
        fig.tight_layout(rect=[0, 0, 0.86, 1])

    fig.savefig(out_png, dpi=dpi, bbox_inches="tight")
    plt.close(fig)

print("Ready: save_hypnogram_paper_overlay")


Ready: save_hypnogram_paper_overlay


In [20]:
# Example: regenerate for shhs1-201437_v1
# Requires `records` and `OUT_BASE` from the cells above.
sid = "shhs1-201437_v1"
# next() raises StopIteration if no record has this subject_id.
r = next(rr for rr in records if rr["subject_id"] == sid)

save_hypnogram_paper_overlay(
    r["y_true"], r["y_pred"],
    out_png=OUT_BASE / f"hypnogram_{sid}_paper.png",
    subject_id=sid,
    show_title=False,          # <- remove title if you want
    legend_outside=True        # <- legend outside
)
