In [1]:
# %% [CELL 0] Imports + GPU + device
import os, json, math, random
from pathlib import Path
from tqdm import tqdm
from collections import Counter

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import accuracy_score, f1_score, cohen_kappa_score, confusion_matrix
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_auc_score, average_precision_score

# --- choose GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "4"   # change if needed

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("CUDA available:", torch.cuda.is_available())
print("Visible CUDA devices:", torch.cuda.device_count())
print("Using device:", device)
if device.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))


CUDA available: True
Visible CUDA devices: 1
Using device: cuda
GPU: NVIDIA RTX A6000


In [2]:
# %% [CELL 1] Paths + Manifest split (same)
ROOT = Path("/data2/Akbar1/sleep_stages_Dibatic/shhs_sleepstaging_planA/")
MANIFEST_PATH = ROOT / "manifest_sleepstaging_planA.csv"
assert MANIFEST_PATH.exists(), f"Missing manifest: {MANIFEST_PATH}"

manifest = pd.read_csv(MANIFEST_PATH)
print("Rows:", len(manifest))
print(manifest.groupby(["cohort","split"]).size())

df_train = manifest[(manifest.cohort=="SHHS1") & (manifest.split=="train")].copy()
df_val   = manifest[(manifest.cohort=="SHHS1") & (manifest.split=="val")].copy()
df_test  = manifest[(manifest.cohort=="SHHS1") & (manifest.split=="test")].copy()
df_ext   = manifest[(manifest.cohort=="SHHS2") & (manifest.split=="external_test")].copy()

df_mesa = manifest[(manifest.cohort=="MESA")].copy() if ("MESA" in manifest["cohort"].unique()) else None
if df_mesa is not None and len(df_mesa) > 0:
    if "external_test" in df_mesa["split"].unique():
        df_mesa = df_mesa[df_mesa.split=="external_test"].copy()
    print("MESA detected | rows:", len(df_mesa))
else:
    df_mesa = None
    print("MESA not detected in manifest (ok).")

print("DF sizes:", len(df_train), len(df_val), len(df_test), len(df_ext),
      ("| MESA="+str(len(df_mesa)) if df_mesa is not None else ""))


Rows: 9868
cohort  split        
MESA    external_test    1856
SHHS1   test              548
        train            4380
        val               548
SHHS2   external_test    2536
dtype: int64
MESA detected | rows: 1856
DF sizes: 4380 548 548 2536 | MESA=1856


In [3]:
# %% [CELL 2] Augment + Normalize (identical)
class EEGAugment:
    def __init__(self,
                 p_amp=0.5, p_noise=0.5, p_shift=0.5, p_bandstop=0.3, p_freqdrop=0.3,
                 amp_range=(0.8, 1.2), noise_std=0.01, shift_max=125,
                 bandstop_ranges=((49,51), (59,61)), freqdrop_max_bins=12):
        self.p_amp = p_amp
        self.p_noise = p_noise
        self.p_shift = p_shift
        self.p_bandstop = p_bandstop
        self.p_freqdrop = p_freqdrop
        self.amp_range = amp_range
        self.noise_std = noise_std
        self.shift_max = shift_max
        self.bandstop_ranges = bandstop_ranges
        self.freqdrop_max_bins = freqdrop_max_bins

    def _amp_scale(self, x):
        s = np.random.uniform(self.amp_range[0], self.amp_range[1])
        return x * s

    def _gaussian_noise(self, x):
        std = np.std(x, axis=1, keepdims=True) + 1e-6
        noise = np.random.randn(*x.shape).astype(np.float32) * (self.noise_std * std)
        return x + noise

    def _time_shift(self, x):
        shift = np.random.randint(-self.shift_max, self.shift_max+1)
        return np.roll(x, shift=shift, axis=1)

    def _bandstop_fft(self, x, fs=125.0):
        E, T_ = x.shape
        X = np.fft.rfft(x, axis=1)
        freqs = np.fft.rfftfreq(T_, d=1.0/fs)
        for (f1, f2) in self.bandstop_ranges:
            mask = (freqs >= f1) & (freqs <= f2)
            X[:, mask] = 0.0
        y = np.fft.irfft(X, n=T_, axis=1).astype(np.float32)
        return y

    def _freq_dropout(self, x):
        E, T_ = x.shape
        X = np.fft.rfft(x, axis=1)
        Fbins = X.shape[1]
        drop = np.random.randint(1, self.freqdrop_max_bins+1)
        start = np.random.randint(0, max(1, Fbins - drop))
        X[:, start:start+drop] = 0.0
        y = np.fft.irfft(X, n=T_, axis=1).astype(np.float32)
        return y

    def __call__(self, x, fs=125.0):
        if np.random.rand() < self.p_amp:
            x = self._amp_scale(x)
        if np.random.rand() < self.p_noise:
            x = self._gaussian_noise(x)
        if np.random.rand() < self.p_shift:
            x = self._time_shift(x)
        if np.random.rand() < self.p_bandstop:
            x = self._bandstop_fft(x, fs=fs)
        if np.random.rand() < self.p_freqdrop:
            x = self._freq_dropout(x)
        return x

def normalize_epochs_zscore(x, eps=1e-6, clip=10.0):
    mu = np.mean(x, axis=1, keepdims=True)
    sd = np.std(x, axis=1, keepdims=True) + eps
    x = (x - mu) / sd
    if clip is not None:
        x = np.clip(x, -clip, clip)
    return x.astype(np.float32)

augment = EEGAugment()
print("Augmenter + Normalizer ready (per-epoch z-score).")


Augmenter + Normalizer ready (per-epoch z-score).


In [4]:
# %% [CELL 3] Labels/constants + Dataset (identical)
LABELS = {0:"W", 1:"N1", 2:"N2", 3:"N3", 4:"REM"}
NUM_CLASSES = 5
FS = 125
T = 3750

def _compute_runlength_remaining(y):
    E = len(y)
    rem = np.zeros((E,), dtype=np.int64)
    t = 0
    while t < E:
        j = t
        while j < E and y[j] == y[t]:
            j += 1
        for k in range(t, j):
            rem[k] = (j - k)
        t = j
    return rem

def _bucketize_remaining(rem, edges=(2, 5, 10, 20, 40, 80, 160)):
    edges = np.array(edges, dtype=np.int64)
    b = np.zeros_like(rem, dtype=np.int64)
    for i, r in enumerate(rem):
        b[i] = int(np.searchsorted(edges, r, side="right"))
    return b

def _make_soft_targets_boundary(y, num_classes=5, radius=2, alpha_max=0.35):
    E = len(y)
    soft = np.zeros((E, num_classes), dtype=np.float32)
    soft[np.arange(E), y] = 1.0

    if E < 2 or radius <= 0:
        return soft

    trans = np.where(y[1:] != y[:-1])[0] + 1
    if trans.size == 0:
        return soft

    for t in trans:
        for dt in range(-radius, radius+1):
            i = t + dt
            if i < 0 or i >= E:
                continue
            dist = abs(dt)
            if dist > radius:
                continue
            alpha = alpha_max * (1.0 - (dist / max(radius, 1)))
            if alpha <= 0:
                continue

            if i < t:
                nb = y[t]
            else:
                nb = y[t-1]

            cur = y[i]
            if nb == cur:
                continue

            soft[i, :] = 0.0
            soft[i, cur] = 1.0 - alpha
            soft[i, nb]  = alpha

    soft = soft / (soft.sum(axis=1, keepdims=True) + 1e-8)
    return soft.astype(np.float32)

class SleepSequenceDataset(Dataset):
    def __init__(self, df, mode="train",
                 max_hours=None, min_hours=2.0,
                 augmentor=None, exclude_unknown=True, do_normalize=True,
                 boundary_oversample_p=0.70, boundary_radius=2,
                 soft_radius=2, soft_alpha_max=0.35,
                 dur_edges=(2,5,10,20,40,80,160)):
        self.paths = df["npz_path"].tolist()
        self.mode = mode
        self.max_hours = max_hours
        self.min_hours = min_hours
        self.augmentor = augmentor
        self.exclude_unknown = exclude_unknown
        self.do_normalize = do_normalize

        self.boundary_oversample_p = boundary_oversample_p
        self.boundary_radius = int(boundary_radius)

        self.soft_radius = int(soft_radius)
        self.soft_alpha_max = float(soft_alpha_max)
        self.dur_edges = tuple(dur_edges)

        print(f"SleepSequenceDataset[{mode}] files={len(self.paths)} max_hours={self.max_hours} normalize={self.do_normalize}")

    def __len__(self):
        return len(self.paths)

    def _pick_block_start_boundary_aware(self, y, L):
        E = len(y)
        if E <= L:
            return 0
        if np.random.rand() > self.boundary_oversample_p:
            return np.random.randint(0, E - L + 1)
        trans = np.where(y[1:] != y[:-1])[0] + 1
        if trans.size == 0:
            return np.random.randint(0, E - L + 1)
        t = int(np.random.choice(trans))
        start = t - (L // 2)
        start = max(0, min(start, E - L))
        return int(start)

    def __getitem__(self, idx):
        p = self.paths[idx]
        d = np.load(p, allow_pickle=True)

        x = d["x"].astype(np.float32)   # (E,T)
        y = d["y"].astype(np.int64)     # (E,)

        if self.exclude_unknown:
            keep = (y >= 0)
            x = x[keep]
            y = y[keep]

        if self.do_normalize:
            x = normalize_epochs_zscore(x, eps=1e-6, clip=10.0)

        E = len(y)

        soft = _make_soft_targets_boundary(
            y, num_classes=NUM_CLASSES, radius=self.soft_radius, alpha_max=self.soft_alpha_max
        )
        rem = _compute_runlength_remaining(y)
        dur_bucket = _bucketize_remaining(rem, edges=self.dur_edges)

        if self.max_hours is not None:
            max_L = int((self.max_hours * 3600) / 30)
            min_L = int((self.min_hours * 3600) / 30)
            L = min(max_L, E)
            if E > L:
                start = self._pick_block_start_boundary_aware(y, L)
                x = x[start:start+L]
                y = y[start:start+L]
                soft = soft[start:start+L]
                dur_bucket = dur_bucket[start:start+L]

        if self.mode == "train" and self.augmentor is not None:
            x = self.augmentor(x, fs=FS)

        x_t = torch.from_numpy(x).unsqueeze(1)     # (E,1,T)
        y_t = torch.from_numpy(y).long()           # (E,)
        soft_t = torch.from_numpy(soft).float()    # (E,C)
        dur_t = torch.from_numpy(dur_bucket).long()# (E,)
        mask = torch.ones((x_t.shape[0],), dtype=torch.bool)

        return x_t, y_t, mask, soft_t, dur_t

def collate_pad(batch):
    lengths = [b[0].shape[0] for b in batch]
    Lmax = max(lengths)
    xs, ys, ms, ss, ds = [], [], [], [], []
    for x, y, m, s, d in batch:
        L = x.shape[0]
        padL = Lmax - L
        if padL > 0:
            x = torch.cat([x, torch.zeros((padL, 1, T), dtype=x.dtype)], dim=0)
            y = torch.cat([y, torch.zeros((padL,), dtype=y.dtype)], dim=0)
            m = torch.cat([m, torch.zeros((padL,), dtype=torch.bool)], dim=0)
            s = torch.cat([s, torch.zeros((padL, NUM_CLASSES), dtype=s.dtype)], dim=0)
            d = torch.cat([d, torch.zeros((padL,), dtype=d.dtype)], dim=0)
        xs.append(x); ys.append(y); ms.append(m); ss.append(s); ds.append(d)

    x = torch.stack(xs, dim=0)
    y = torch.stack(ys, dim=0)
    m = torch.stack(ms, dim=0)
    s = torch.stack(ss, dim=0)
    d = torch.stack(ds, dim=0)
    return x, y, m, s, d


In [5]:
# %% [CELL 4] Eval loaders (identical)
val_seq_ds  = SleepSequenceDataset(df_val,  mode="eval", max_hours=None, augmentor=None, do_normalize=True)
test_seq_ds = SleepSequenceDataset(df_test, mode="eval", max_hours=None, augmentor=None, do_normalize=True)
ext_seq_ds  = SleepSequenceDataset(df_ext,  mode="eval", max_hours=None, augmentor=None, do_normalize=True)
mesa_seq_ds = SleepSequenceDataset(df_mesa, mode="eval", max_hours=None, augmentor=None, do_normalize=True) if df_mesa is not None else None

PIN = True
val_seq_loader  = DataLoader(val_seq_ds,  batch_size=1, shuffle=False, num_workers=1, pin_memory=PIN, collate_fn=collate_pad)
test_seq_loader = DataLoader(test_seq_ds, batch_size=1, shuffle=False, num_workers=1, pin_memory=PIN, collate_fn=collate_pad)
ext_seq_loader  = DataLoader(ext_seq_ds,  batch_size=1, shuffle=False, num_workers=1, pin_memory=PIN, collate_fn=collate_pad)
mesa_seq_loader = DataLoader(mesa_seq_ds, batch_size=1, shuffle=False, num_workers=1, pin_memory=PIN, collate_fn=collate_pad) if mesa_seq_ds is not None else None

print("Loaders built:",
      "VAL", len(val_seq_loader),
      "| TEST", len(test_seq_loader),
      "| SHHS2", len(ext_seq_loader),
      "| MESA", (len(mesa_seq_loader) if mesa_seq_loader is not None else None))


SleepSequenceDataset[eval] files=548 max_hours=None normalize=True
SleepSequenceDataset[eval] files=548 max_hours=None normalize=True
SleepSequenceDataset[eval] files=2536 max_hours=None normalize=True
SleepSequenceDataset[eval] files=1856 max_hours=None normalize=True
Loaders built: VAL 548 | TEST 548 | SHHS2 2536 | MESA 1856


In [6]:
# %% [CELL 5] Model (identical) + loss helpers needed for eval

class DropPath(nn.Module):
    def __init__(self, drop_prob=0.1):
        super().__init__()
        self.drop_prob = float(drop_prob)
    def forward(self, x):
        if (not self.training) or self.drop_prob == 0.0:
            return x
        keep = 1.0 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        rand = keep + torch.rand(shape, device=x.device)
        mask = torch.floor(rand)
        return x / keep * mask

class ResConv1D(nn.Module):
    def __init__(self, c_in, c_out, k, s=1):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(c_in, c_out, k, stride=s, padding=k//2),
            nn.BatchNorm1d(c_out),
            nn.GELU(),
            nn.Conv1d(c_out, c_out, k, padding=k//2),
            nn.BatchNorm1d(c_out),
        )
        self.skip = nn.Conv1d(c_in, c_out, 1, stride=s) if (c_in != c_out or s != 1) else nn.Identity()
        self.act = nn.GELU()
    def forward(self, x):
        return self.act(self.conv(x) + self.skip(x))

class EpochEncoder(nn.Module):
    def __init__(self, d_model=384):
        super().__init__()
        self.branch_short = ResConv1D(1, 128, k=7,  s=4)
        self.branch_mid   = ResConv1D(1, 128, k=15, s=4)
        self.branch_long  = ResConv1D(1, 128, k=31, s=4)

        self.freq_proj = nn.Sequential(
            nn.Linear(1876, 256),
            nn.LayerNorm(256),
            nn.GELU(),
        )

        self.fuse = nn.Sequential(
            nn.Linear(128*3 + 256, d_model),
            nn.LayerNorm(d_model),
            nn.GELU(),
        )

    def forward(self, x):
        B, L, _, T_ = x.shape
        x = x.view(B*L, 1, T_)

        zs = self.branch_short(x).mean(-1)
        zm = self.branch_mid(x).mean(-1)
        zl = self.branch_long(x).mean(-1)

        with torch.cuda.amp.autocast(enabled=False):
            xf32 = x.squeeze(1).float()
            Xf = torch.fft.rfft(xf32, dim=-1)
            mag = torch.abs(Xf)
            mag = mag[:, :1876]
            mag = torch.log1p(mag)
            mag = mag / (mag.mean(dim=1, keepdim=True) + 1e-6)

        zf = self.freq_proj(mag)
        z = torch.cat([zs, zm, zl, zf.to(zs.dtype)], dim=-1)
        z = self.fuse(z)
        return z.view(B, L, -1)

def rotate_half(x):
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

class RoPE(nn.Module):
    def __init__(self, head_dim, base=10000):
        super().__init__()
        assert head_dim % 2 == 0
        self.head_dim = head_dim
        self.base = base
    def forward(self, x):
        B, L, H, Dh = x.shape
        half = Dh // 2
        freqs = 1.0 / (self.base ** (torch.arange(half, device=x.device) / half))
        t = torch.arange(L, device=x.device)
        angles = torch.einsum("l,d->ld", t, freqs)
        cos = torch.cos(angles)[None, :, None, :]
        sin = torch.sin(angles)[None, :, None, :]
        cos = cos.repeat_interleave(2, dim=-1)
        sin = sin.repeat_interleave(2, dim=-1)
        return (x * cos) + (rotate_half(x) * sin)

def _windows(L, w):
    out = []
    s = 0
    while s < L:
        e = min(L, s + w)
        out.append((s, e))
        s = e
    return out

class MultiHeadSelfAttentionRoPE_LocalGlobal(nn.Module):
    def __init__(self, d_model=384, n_heads=8, dropout=0.1, window_size=64):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.window_size = int(window_size)

        self.qkv = nn.Linear(d_model, 3*d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.drop = nn.Dropout(dropout)
        self.rope = RoPE(self.d_head)

    def forward(self, x, key_padding_mask=None, global_attn=False):
        B, L, D = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(B, L, self.n_heads, self.d_head)
        k = k.view(B, L, self.n_heads, self.d_head)
        v = v.view(B, L, self.n_heads, self.d_head)

        q = self.rope(q)
        k = self.rope(k)

        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        if global_attn or self.window_size >= L:
            scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)
            scores = scores.float()
            if key_padding_mask is not None:
                scores = scores.masked_fill(~key_padding_mask[:, None, None, :], -1e9)
            attn = torch.softmax(scores, dim=-1)
            attn = self.drop(attn).to(v.dtype)
            out = attn @ v
            out = out.transpose(1, 2).contiguous().view(B, L, D)
            return self.proj(out)

        w = self.window_size
        out = torch.zeros((B, self.n_heads, L, self.d_head), device=x.device, dtype=v.dtype)
        for (s, e) in _windows(L, w):
            qs = q[:, :, s:e, :]
            ks = k[:, :, s:e, :]
            vs = v[:, :, s:e, :]

            scores = (qs @ ks.transpose(-2, -1)) / math.sqrt(self.d_head)
            scores = scores.float()
            if key_padding_mask is not None:
                m = key_padding_mask[:, s:e]
                scores = scores.masked_fill(~m[:, None, None, :], -1e9)

            attn = torch.softmax(scores, dim=-1)
            attn = self.drop(attn).to(vs.dtype)
            out[:, :, s:e, :] = attn @ vs

        out = out.transpose(1, 2).contiguous().view(B, L, D)
        return self.proj(out)

class TransformerBlockLG(nn.Module):
    def __init__(self, d_model=384, n_heads=8, drop=0.1, drop_path=0.1, window_size=64):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttentionRoPE_LocalGlobal(d_model, n_heads, drop, window_size=window_size)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4*d_model),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(4*d_model, d_model),
        )
        self.dp = DropPath(drop_path)

    def forward(self, x, mask, global_attn=False):
        x = x + self.dp(self.attn(self.ln1(x), key_padding_mask=mask, global_attn=global_attn))
        x = x + self.dp(self.mlp(self.ln2(x)))
        return x

class HierSleepTransformerV5_1(nn.Module):
    def __init__(self, num_classes=5, d_model=384, depth=12, n_heads=8,
                 dur_bins=8, window_size=64, global_every=3):
        super().__init__()
        self.num_classes = num_classes
        self.dur_bins = dur_bins
        self.depth = int(depth)
        self.global_every = int(global_every)

        self.encoder = EpochEncoder(d_model)
        self.blocks = nn.ModuleList([
            TransformerBlockLG(
                d_model=d_model,
                n_heads=n_heads,
                drop=0.1,
                drop_path=0.1*(i+1)/depth,
                window_size=window_size
            )
            for i in range(depth)
        ])
        self.head = nn.Linear(d_model, num_classes)

        self.aux_n1 = nn.Linear(d_model, 2)
        self.aux_dur = nn.Linear(d_model, dur_bins)
        self.trans_logits = nn.Parameter(torch.zeros(num_classes, num_classes))

    def forward(self, x, mask):
        z = self.encoder(x)
        for i, blk in enumerate(self.blocks):
            use_global = (self.global_every > 0) and ((i % self.global_every) == 0)
            z = blk(z, mask, global_attn=use_global)
        main_logits = self.head(z)
        aux_logits  = self.aux_n1(z)
        dur_logits  = self.aux_dur(z)
        return main_logits, aux_logits, dur_logits

# dur bins
DUR_EDGES = (2,5,10,20,40,80,160)
DUR_BINS = len(DUR_EDGES) + 1

model = HierSleepTransformerV5_1(
    num_classes=NUM_CLASSES,
    d_model=384,
    depth=12,
    n_heads=8,
    dur_bins=DUR_BINS,
    window_size=64,
    global_every=3
).to(device)

print("Model params (M):", sum(p.numel() for p in model.parameters()) / 1e6)


Model params (M): 22.905512


In [7]:
# %% [CELL 6] Loss switches + helpers (identical defaults)
USE_LA_CE = True
LA_TAU = 1.0
USE_HARD_NEG_N1 = True
HARD_NEG_MULT = 2.0

USE_SOFT_BOUNDARY_LOSS = True
SOFT_BOUNDARY_WEIGHT = 0.25

USE_COST_MATRIX = True
COST_WEIGHT = 0.20

USE_AUX_N1 = True
AUX_N1_WEIGHT = 0.30

USE_AUX_DUR = True
AUX_DUR_WEIGHT = 0.15
AUX_DUR_N1_MULT = 1.50

USE_TRANS_LOSS = True
TRANS_LOSS_WEIGHT = 0.10

USE_LEARNED_SMOOTHING = True
USE_VITERBI = False  # keep same as your eval default

# class weights from your train set (needed for LA-CE prior)
def class_counts_train(df):
    c = Counter()
    for p in tqdm(df["npz_path"].tolist(), desc="Counting train labels", leave=False):
        d = np.load(p, allow_pickle=True)
        y = d["y"].astype(np.int64)
        y = y[y >= 0]
        c.update(y.tolist())
    counts = np.array([c.get(i, 0) for i in range(NUM_CLASSES)], dtype=np.float64)
    return counts

counts = class_counts_train(df_train)

class LogitAdjustedCE(nn.Module):
    def __init__(self, class_freq, tau=1.0):
        super().__init__()
        freq = torch.tensor(class_freq, dtype=torch.float32)
        self.register_buffer("log_prior", torch.log(freq / freq.sum()))
        self.tau = float(tau)
    def forward(self, logits, targets, reduction="none"):
        logits = logits + self.tau * self.log_prior
        return F.cross_entropy(logits, targets, reduction=reduction)

def label_smoothing_nll(logits, targets, smooth_per_class):
    logp = F.log_softmax(logits, dim=-1)
    nll = -logp.gather(dim=-1, index=targets.view(-1,1)).squeeze(1)
    smooth = -logp.mean(dim=-1)
    s = smooth_per_class[targets]
    return (1 - s) * nll + s * smooth

smooth_vec = torch.tensor([0.02, 0.00, 0.05, 0.05, 0.02], dtype=torch.float32).to(device)
la_ce = LogitAdjustedCE(class_freq=counts, tau=LA_TAU).to(device)

def soft_target_ce(logits, soft_targets):
    logp = F.log_softmax(logits, dim=-1)
    return -(soft_targets * logp).sum(dim=-1)

def build_cost_matrix(device):
    C = NUM_CLASSES
    cost = torch.zeros((C, C), dtype=torch.float32, device=device)
    cost += 0.05
    cost.fill_diagonal_(0.0)
    cost[1, 0] = 1.00
    cost[1, 2] = 1.00
    cost[0, 1] = 0.60
    cost[2, 1] = 0.60
    cost[2, 3] = 0.20
    cost[3, 2] = 0.20
    cost[0, 4] = 0.15
    cost[4, 0] = 0.15
    return cost

COST_MAT = build_cost_matrix(device)

def compute_transition_loss(model, y, mask):
    B, L = y.shape
    y_prev = y[:, :-1]
    y_next = y[:, 1:]
    m_pair = mask[:, :-1] & mask[:, 1:]
    if m_pair.sum().item() == 0:
        return torch.zeros((), device=y.device)
    y_prev_v = y_prev[m_pair]
    y_next_v = y_next[m_pair]
    trans_logits = model.trans_logits
    logits_pair = trans_logits[y_prev_v]
    return F.cross_entropy(logits_pair, y_next_v)

def masked_loss_v5(model, main_logits, aux_logits, dur_logits, y, mask, soft_targets, dur_bucket):
    B, L, C = main_logits.shape
    logits2 = main_logits.view(B*L, C)
    y2 = y.view(B*L)
    m2 = mask.view(B*L)

    logits_valid = logits2[m2]
    y_valid = y2[m2]

    if USE_LA_CE:
        adj_logits = logits_valid + la_ce.tau * la_ce.log_prior
        loss_main_vec = label_smoothing_nll(adj_logits, y_valid, smooth_vec)
    else:
        loss_main_vec = label_smoothing_nll(logits_valid, y_valid, smooth_vec)

    if USE_HARD_NEG_N1:
        with torch.no_grad():
            pred = torch.argmax(logits_valid, dim=-1)
            hard = (y_valid == 1) & ((pred == 0) | (pred == 2))
        loss_main_vec = loss_main_vec * torch.where(
            hard,
            torch.tensor(HARD_NEG_MULT, device=logits_valid.device),
            torch.tensor(1.0, device=logits_valid.device)
        )

    loss = loss_main_vec.mean()

    if USE_SOFT_BOUNDARY_LOSS:
        soft2 = soft_targets.view(B*L, C)[m2]
        if USE_LA_CE:
            adj_logits_soft = logits_valid + la_ce.tau * la_ce.log_prior
            loss_soft_vec = soft_target_ce(adj_logits_soft, soft2)
        else:
            loss_soft_vec = soft_target_ce(logits_valid, soft2)
        loss = loss + SOFT_BOUNDARY_WEIGHT * loss_soft_vec.mean()

    if USE_COST_MATRIX:
        probs = torch.softmax(logits_valid.float(), dim=-1)
        cost_row = COST_MAT[y_valid]
        expected_cost = (probs * cost_row).sum(dim=-1)
        loss = loss + COST_WEIGHT * expected_cost.mean()

    if USE_AUX_N1:
        aux2 = aux_logits.view(B*L, 2)[m2]
        y_aux = (y_valid == 1).long()
        loss_aux = F.cross_entropy(aux2, y_aux)
        loss = loss + AUX_N1_WEIGHT * loss_aux

    if USE_AUX_DUR:
        dur2 = dur_logits.view(B*L, DUR_BINS)[m2]
        dur_t = dur_bucket.view(B*L)[m2]
        w = torch.ones_like(dur_t, dtype=torch.float32, device=dur2.device)
        w = w * torch.where(y_valid == 1, torch.tensor(AUX_DUR_N1_MULT, device=dur2.device), torch.tensor(1.0, device=dur2.device))
        loss_dur = F.cross_entropy(dur2, dur_t, reduction="none")
        loss_dur = (loss_dur * w).mean()
        loss = loss + AUX_DUR_WEIGHT * loss_dur

    if USE_TRANS_LOSS:
        loss_trans = compute_transition_loss(model, y, mask)
        loss = loss + TRANS_LOSS_WEIGHT * loss_trans

    return loss

def apply_learned_smoothing_probs(probs, model):
    Tm = torch.softmax(model.trans_logits, dim=1)
    return probs @ Tm


                                                                                

In [8]:
# %% [CELL 6] Loss switches + helpers (identical defaults)
USE_LA_CE = True
LA_TAU = 1.0
USE_HARD_NEG_N1 = True
HARD_NEG_MULT = 2.0

USE_SOFT_BOUNDARY_LOSS = True
SOFT_BOUNDARY_WEIGHT = 0.25

USE_COST_MATRIX = True
COST_WEIGHT = 0.20

USE_AUX_N1 = True
AUX_N1_WEIGHT = 0.30

USE_AUX_DUR = True
AUX_DUR_WEIGHT = 0.15
AUX_DUR_N1_MULT = 1.50

USE_TRANS_LOSS = True
TRANS_LOSS_WEIGHT = 0.10

USE_LEARNED_SMOOTHING = True
USE_VITERBI = False  # keep same as your eval default

# class weights from your train set (needed for LA-CE prior)
def class_counts_train(df):
    c = Counter()
    for p in tqdm(df["npz_path"].tolist(), desc="Counting train labels", leave=False):
        d = np.load(p, allow_pickle=True)
        y = d["y"].astype(np.int64)
        y = y[y >= 0]
        c.update(y.tolist())
    counts = np.array([c.get(i, 0) for i in range(NUM_CLASSES)], dtype=np.float64)
    return counts

counts = class_counts_train(df_train)

class LogitAdjustedCE(nn.Module):
    def __init__(self, class_freq, tau=1.0):
        super().__init__()
        freq = torch.tensor(class_freq, dtype=torch.float32)
        self.register_buffer("log_prior", torch.log(freq / freq.sum()))
        self.tau = float(tau)
    def forward(self, logits, targets, reduction="none"):
        logits = logits + self.tau * self.log_prior
        return F.cross_entropy(logits, targets, reduction=reduction)

def label_smoothing_nll(logits, targets, smooth_per_class):
    logp = F.log_softmax(logits, dim=-1)
    nll = -logp.gather(dim=-1, index=targets.view(-1,1)).squeeze(1)
    smooth = -logp.mean(dim=-1)
    s = smooth_per_class[targets]
    return (1 - s) * nll + s * smooth

smooth_vec = torch.tensor([0.02, 0.00, 0.05, 0.05, 0.02], dtype=torch.float32).to(device)
la_ce = LogitAdjustedCE(class_freq=counts, tau=LA_TAU).to(device)

def soft_target_ce(logits, soft_targets):
    logp = F.log_softmax(logits, dim=-1)
    return -(soft_targets * logp).sum(dim=-1)

def build_cost_matrix(device):
    C = NUM_CLASSES
    cost = torch.zeros((C, C), dtype=torch.float32, device=device)
    cost += 0.05
    cost.fill_diagonal_(0.0)
    cost[1, 0] = 1.00
    cost[1, 2] = 1.00
    cost[0, 1] = 0.60
    cost[2, 1] = 0.60
    cost[2, 3] = 0.20
    cost[3, 2] = 0.20
    cost[0, 4] = 0.15
    cost[4, 0] = 0.15
    return cost

COST_MAT = build_cost_matrix(device)

def compute_transition_loss(model, y, mask):
    B, L = y.shape
    y_prev = y[:, :-1]
    y_next = y[:, 1:]
    m_pair = mask[:, :-1] & mask[:, 1:]
    if m_pair.sum().item() == 0:
        return torch.zeros((), device=y.device)
    y_prev_v = y_prev[m_pair]
    y_next_v = y_next[m_pair]
    trans_logits = model.trans_logits
    logits_pair = trans_logits[y_prev_v]
    return F.cross_entropy(logits_pair, y_next_v)

def masked_loss_v5(model, main_logits, aux_logits, dur_logits, y, mask, soft_targets, dur_bucket):
    B, L, C = main_logits.shape
    logits2 = main_logits.view(B*L, C)
    y2 = y.view(B*L)
    m2 = mask.view(B*L)

    logits_valid = logits2[m2]
    y_valid = y2[m2]

    if USE_LA_CE:
        adj_logits = logits_valid + la_ce.tau * la_ce.log_prior
        loss_main_vec = label_smoothing_nll(adj_logits, y_valid, smooth_vec)
    else:
        loss_main_vec = label_smoothing_nll(logits_valid, y_valid, smooth_vec)

    if USE_HARD_NEG_N1:
        with torch.no_grad():
            pred = torch.argmax(logits_valid, dim=-1)
            hard = (y_valid == 1) & ((pred == 0) | (pred == 2))
        loss_main_vec = loss_main_vec * torch.where(
            hard,
            torch.tensor(HARD_NEG_MULT, device=logits_valid.device),
            torch.tensor(1.0, device=logits_valid.device)
        )

    loss = loss_main_vec.mean()

    if USE_SOFT_BOUNDARY_LOSS:
        soft2 = soft_targets.view(B*L, C)[m2]
        if USE_LA_CE:
            adj_logits_soft = logits_valid + la_ce.tau * la_ce.log_prior
            loss_soft_vec = soft_target_ce(adj_logits_soft, soft2)
        else:
            loss_soft_vec = soft_target_ce(logits_valid, soft2)
        loss = loss + SOFT_BOUNDARY_WEIGHT * loss_soft_vec.mean()

    if USE_COST_MATRIX:
        probs = torch.softmax(logits_valid.float(), dim=-1)
        cost_row = COST_MAT[y_valid]
        expected_cost = (probs * cost_row).sum(dim=-1)
        loss = loss + COST_WEIGHT * expected_cost.mean()

    if USE_AUX_N1:
        aux2 = aux_logits.view(B*L, 2)[m2]
        y_aux = (y_valid == 1).long()
        loss_aux = F.cross_entropy(aux2, y_aux)
        loss = loss + AUX_N1_WEIGHT * loss_aux

    if USE_AUX_DUR:
        dur2 = dur_logits.view(B*L, DUR_BINS)[m2]
        dur_t = dur_bucket.view(B*L)[m2]
        w = torch.ones_like(dur_t, dtype=torch.float32, device=dur2.device)
        w = w * torch.where(y_valid == 1, torch.tensor(AUX_DUR_N1_MULT, device=dur2.device), torch.tensor(1.0, device=dur2.device))
        loss_dur = F.cross_entropy(dur2, dur_t, reduction="none")
        loss_dur = (loss_dur * w).mean()
        loss = loss + AUX_DUR_WEIGHT * loss_dur

    if USE_TRANS_LOSS:
        loss_trans = compute_transition_loss(model, y, mask)
        loss = loss + TRANS_LOSS_WEIGHT * loss_trans

    return loss

def apply_learned_smoothing_probs(probs, model):
    Tm = torch.softmax(model.trans_logits, dim=1)
    return probs @ Tm


                                                                                

In [9]:
# %% [CELL 7] Load checkpoint + apply EMA shadow (identical)
CKPT_DIR = ROOT / "checkpoints_hier_rope_seq_v5_1"
BEST_CKPT = CKPT_DIR / "BEST_VAL_macroF1.pt"   # change if needed
assert BEST_CKPT.exists(), f"Missing ckpt: {BEST_CKPT}"
print("Loading:", BEST_CKPT)

ckpt = torch.load(BEST_CKPT, map_location="cpu")
sd = ckpt["model_state"]
missing, unexpected = model.load_state_dict(sd, strict=False)
print("Loaded model_state. missing:", len(missing), "| unexpected:", len(unexpected))

use_ema = bool(ckpt.get("use_ema", False))
ema_shadow = ckpt.get("ema_shadow", None)

if use_ema and isinstance(ema_shadow, dict) and len(ema_shadow) > 0:
    with torch.no_grad():
        curr = model.state_dict()
        applied = 0
        for k, v in ema_shadow.items():
            if k in curr:
                curr[k].copy_(v.to(curr[k].device).to(curr[k].dtype))
                applied += 1
        model.load_state_dict(curr, strict=False)
    print("Applied EMA shadow tensors:", applied)
else:
    print("EMA shadow not applied (missing/disabled).")

model.eval()


Loading: /data2/Akbar1/sleep_stages_Dibatic/shhs_sleepstaging_planA/checkpoints_hier_rope_seq_v5_1/BEST_VAL_macroF1.pt
Loaded model_state. missing: 0 | unexpected: 0
Applied EMA shadow tensors: 189


HierSleepTransformerV5_1(
  (encoder): EpochEncoder(
    (branch_short): ResConv1D(
      (conv): Sequential(
        (0): Conv1d(1, 128, kernel_size=(7,), stride=(4,), padding=(3,))
        (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): GELU(approximate=none)
        (3): Conv1d(128, 128, kernel_size=(7,), stride=(1,), padding=(3,))
        (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (skip): Conv1d(1, 128, kernel_size=(1,), stride=(4,))
      (act): GELU(approximate=none)
    )
    (branch_mid): ResConv1D(
      (conv): Sequential(
        (0): Conv1d(1, 128, kernel_size=(15,), stride=(4,), padding=(7,))
        (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): GELU(approximate=none)
        (3): Conv1d(128, 128, kernel_size=(15,), stride=(1,), padding=(7,))
        (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track

In [10]:
# %% [CELL 8] Eval (identical behavior) + return raw arrays for saving
def _ece_from_probs(y_true, probs, n_bins=15):
    conf = probs.max(axis=1)
    pred = probs.argmax(axis=1)
    acc = (pred == y_true).astype(np.float32)
    bins = np.linspace(0.0, 1.0, n_bins+1)
    ece = 0.0
    for i in range(n_bins):
        lo, hi = bins[i], bins[i+1]
        m = (conf >= lo) & (conf < hi)
        if m.sum() == 0:
            continue
        bin_acc = acc[m].mean()
        bin_conf = conf[m].mean()
        ece += (m.mean()) * abs(bin_acc - bin_conf)
    return float(ece)

def _auroc_auprc_multiclass(y_true, probs, num_classes=5):
    Y = label_binarize(y_true, classes=list(range(num_classes)))
    aurocs, auprcs = [], []
    for c in range(num_classes):
        if Y[:, c].sum() == 0:
            continue
        try:
            aurocs.append(roc_auc_score(Y[:, c], probs[:, c]))
            auprcs.append(average_precision_score(Y[:, c], probs[:, c]))
        except Exception:
            pass
    if len(aurocs) == 0:
        return float("nan"), float("nan")
    return float(np.mean(aurocs)), float(np.mean(auprcs))

@torch.no_grad()
def eval_sequence_collect(model, loader, desc="Eval"):
    model.eval()

    all_true, all_pred, all_probs = [], [], []
    total_loss = 0.0
    total_n = 0
    bad_batches = 0

    for bidx, (xb, yb, mb, sb, db) in enumerate(tqdm(loader, desc=desc, leave=False)):
        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)
        mb = mb.to(device, non_blocking=True)
        sb = sb.to(device, non_blocking=True)
        db = db.to(device, non_blocking=True)

        with torch.cuda.amp.autocast(enabled=(device.type=="cuda")):
            main_logits, aux_logits, dur_logits = model(xb, mb)
            loss = masked_loss_v5(model, main_logits, aux_logits, dur_logits, yb, mb, sb, db)

        if (not torch.isfinite(main_logits).all()) or (not torch.isfinite(loss)):
            bad_batches += 1
            continue

        probs = torch.softmax(main_logits.float(), dim=-1)
        if USE_LEARNED_SMOOTHING:
            probs = apply_learned_smoothing_probs(probs, model)

        if not torch.isfinite(probs).all():
            bad_batches += 1
            continue

        pred = torch.argmax(probs, dim=-1)

        yv = yb[mb].detach().cpu().numpy()
        pv = pred[mb].detach().cpu().numpy()
        pr = probs[mb].detach().cpu().numpy()

        if yv.size == 0:
            continue

        all_true.append(yv)
        all_pred.append(pv)
        all_probs.append(pr)

        n = int(mb.sum().item())
        total_loss += float(loss.item()) * n
        total_n += n

    if len(all_true) == 0:
        metrics = {
            "loss": float("nan"), "acc": float("nan"), "macro_f1": float("nan"), "kappa": float("nan"),
            "AUROC": float("nan"), "AUPRC": float("nan"), "meanConf": float("nan"), "ECE": float("nan"),
            "f1_per_class": {LABELS[i]: float("nan") for i in range(NUM_CLASSES)},
            "cm": np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=np.int64),
            "bad_batches": int(bad_batches),
        }
        return metrics, None

    y_true = np.concatenate(all_true)
    y_pred = np.concatenate(all_pred)
    probs_all = np.concatenate(all_probs)

    acc = accuracy_score(y_true, y_pred)
    mf1 = f1_score(y_true, y_pred, average="macro")
    kappa = cohen_kappa_score(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred, labels=list(range(NUM_CLASSES)))

    auroc, auprc = _auroc_auprc_multiclass(y_true, probs_all, num_classes=NUM_CLASSES)
    mean_conf = float(probs_all.max(axis=1).mean())
    ece = _ece_from_probs(y_true, probs_all, n_bins=15)

    f1_per = {LABELS[i]: float(f1_score((y_true==i).astype(int), (y_pred==i).astype(int)))
              for i in range(NUM_CLASSES)}

    metrics = {
        "loss": total_loss / max(total_n, 1),
        "acc": float(acc),
        "macro_f1": float(mf1),
        "kappa": float(kappa),
        "AUROC": auroc,
        "AUPRC": auprc,
        "meanConf": mean_conf,
        "ECE": ece,
        "f1_per_class": f1_per,
        "cm": cm,
        "bad_batches": int(bad_batches),
    }

    raw = {"y_true": y_true, "y_pred": y_pred, "probs": probs_all}
    return metrics, raw

def print_metrics(tag, m):
    print(f"\n===== {tag} =====")
    for k in ["loss","acc","macro_f1","kappa","AUROC","AUPRC","meanConf","ECE"]:
        if k in m:
            print(f"{k:9s}: {m[k]:.4f}" if isinstance(m[k], (float,int)) else f"{k:9s}: {m[k]}")
    if "bad_batches" in m:
        print("bad_batches:", m["bad_batches"])
    if "f1_per_class" in m:
        print("F1/class :", m["f1_per_class"])
    if "cm" in m:
        labs = [LABELS[i] for i in range(NUM_CLASSES)]
        print("Confusion Matrix labels:", labs)
        print(m["cm"])


In [13]:
# %% [CELL 9] Run evals + save SHHS2+MESA together
OUT_DIR = ROOT / "eval_outputs_fresh_notebook"
OUT_DIR.mkdir(parents=True, exist_ok=True)

val_m,  _      = eval_sequence_collect(model, val_seq_loader,  desc="VAL")
test_m, _      = eval_sequence_collect(model, test_seq_loader, desc="SHHS1 TEST")
shhs2_m, shhs2_raw = eval_sequence_collect(model, ext_seq_loader,  desc="SHHS2 EXT")

print_metrics("VAL", val_m)
print_metrics("SHHS1 TEST", test_m)
print_metrics("SHHS2 EXT", shhs2_m)

mesa_m, mesa_raw = (None, None)
if mesa_seq_loader is not None:
    mesa_m, mesa_raw = eval_sequence_collect(model, mesa_seq_loader, desc="MESA EXT")
    print_metrics("MESA EXT", mesa_m)
else:
    print("\n[MESA] mesa_seq_loader not found. (Skip)")

# --- save SHHS2 + MESA together (metrics + raw arrays)
bundle = {
    "checkpoint": str(BEST_CKPT),
    "shhs2_metrics": shhs2_m,
    "mesa_metrics": mesa_m,
}



# raw arrays saved in NPZ (compact)
npz_path = OUT_DIR / "EVAL_SHHS2_MESA_raw.npz"
save_dict = {
    "shhs2_y_true": shhs2_raw["y_true"] if shhs2_raw is not None else np.array([], dtype=np.int64),
    "shhs2_y_pred": shhs2_raw["y_pred"] if shhs2_raw is not None else np.array([], dtype=np.int64),
    "shhs2_probs":  shhs2_raw["probs"]  if shhs2_raw is not None else np.zeros((0, NUM_CLASSES), dtype=np.float32),
}
if mesa_raw is not None:
    save_dict.update({
        "mesa_y_true": mesa_raw["y_true"],
        "mesa_y_pred": mesa_raw["y_pred"],
        "mesa_probs":  mesa_raw["probs"],
    })
else:
    save_dict.update({
        "mesa_y_true": np.array([], dtype=np.int64),
        "mesa_y_pred": np.array([], dtype=np.int64),
        "mesa_probs":  np.zeros((0, NUM_CLASSES), dtype=np.float32),
    })

np.savez_compressed(npz_path, **save_dict)

# also save confusion matrices as .npy
np.save(OUT_DIR / "cm_shhs2.npy", shhs2_m["cm"])
if mesa_m is not None:
    np.save(OUT_DIR / "cm_mesa.npy", mesa_m["cm"])

print("\nSaved:")
print(" -", json_path)
print(" -", npz_path)
print(" -", OUT_DIR / "cm_shhs2.npy")
print(" -", OUT_DIR / "cm_mesa.npy" if mesa_m is not None else "(no mesa cm)")


                                                                                


===== VAL =====
loss     : 0.8562
acc      : 0.8568
macro_f1 : 0.8019
kappa    : 0.8025
AUROC    : 0.9702
AUPRC    : 0.8488
meanConf : 0.6985
ECE      : 0.1583
bad_batches: 0
F1/class : {'W': 0.9147962818323049, 'N1': 0.5222609682912983, 'N2': 0.8653339068314851, 'N3': 0.8254334653621173, 'REM': 0.8819111537232341}
Confusion Matrix labels: ['W', 'N1', 'N2', 'N3', 'REM']
[[106582   5961   4623    503   1969]
 [  2237  12880   3175     10   1733]
 [  3628   8370 185088  16458   9607]
 [   228     22   9001  62936    195]
 [   705   2056   2746    203  71747]]

===== SHHS1 TEST =====
loss     : 0.8165
acc      : 0.8655
macro_f1 : 0.8108
kappa    : 0.8144
AUROC    : 0.9737
AUPRC    : 0.8622
meanConf : 0.6987
ECE      : 0.1668
bad_batches: 0
F1/class : {'W': 0.9216682597977562, 'N1': 0.536456381428274, 'N2': 0.8737501224798317, 'N3': 0.8319843540426431, 'REM': 0.8902020233270428}
Confusion Matrix labels: ['W', 'N1', 'N2', 'N3', 'REM']
[[109920   5925   3963    418   2187]
 [  1898  12883  

                                                                                


===== MESA EXT =====
loss     : 1.4269
acc      : 0.7713
macro_f1 : 0.6292
kappa    : 0.6698
AUROC    : 0.9210
AUPRC    : 0.7162
meanConf : 0.6540
ECE      : 0.1173
bad_batches: 0
F1/class : {'W': 0.8768155814387883, 'N1': 0.4071434383372159, 'N2': 0.7838374053621394, 'N3': 0.2764576584914391, 'REM': 0.801973383872619}
Confusion Matrix labels: ['W', 'N1', 'N2', 'N3', 'REM']
[[692384  25860  69702   2217  39800]
 [ 23939  67551  75639     50  20011]
 [ 24904  47000 689414   5824  30756]
 [  1618    125 114166  23897    971]
 [  6507   4103  12255    115 231889]]

Saved:
 - /data2/Akbar1/sleep_stages_Dibatic/shhs_sleepstaging_planA/eval_outputs_fresh_notebook/EVAL_SHHS2_MESA_bundle.json
 - /data2/Akbar1/sleep_stages_Dibatic/shhs_sleepstaging_planA/eval_outputs_fresh_notebook/EVAL_SHHS2_MESA_raw.npz
 - /data2/Akbar1/sleep_stages_Dibatic/shhs_sleepstaging_planA/eval_outputs_fresh_notebook/cm_shhs2.npy
 - /data2/Akbar1/sleep_stages_Dibatic/shhs_sleepstaging_planA/eval_outputs_fresh_noteboo

In [17]:
# %% [COMBINE ONLY] Use existing shhs2_m + mesa_m (no ext_m)

# shhs2_m and mesa_m already exist in your globals (per your print)
combo = combine_eval_metrics(shhs2_m, mesa_m)

print("\n===== COMBINED (SHHS2 + MESA) =====")
print(f"epochs: SHHS2={combo['n_epochs_shhs2']} | MESA={combo['n_epochs_mesa']} | total={combo['n_epochs_total']}")
for k in ["loss","acc","macro_f1","weighted_f1","kappa","AUROC","AUPRC","meanConf","ECE"]:
    print(f"{k:11s}: {combo[k]:.4f}")

print("F1/class:", combo["f1_per_class"])
print("Confusion Matrix labels:", [LABELS[i] for i in range(NUM_CLASSES)])
print(combo["cm"])



===== COMBINED (SHHS2 + MESA) =====
epochs: SHHS2=2858985 | MESA=2210697 | total=5069682
loss       : 1.0677
acc        : 0.8267
macro_f1   : 0.7496
weighted_f1: 0.8274
kappa      : 0.7562
AUROC      : 0.9508
AUPRC      : 0.7982
meanConf   : 0.6759
ECE        : 0.1508
F1/class: {'W': 0.9030743596455058, 'N1': 0.4503597389424367, 'N2': 0.833533267853821, 'N3': 0.7017157627449399, 'REM': 0.8590734722414846}
Confusion Matrix labels: ['W', 'N1', 'N2', 'N3', 'REM']
[[1523253   71311  107940    9361   55115]
 [  31242  133422   98515     204   29425]
 [  39945   79484 1633218   99097   55242]
 [   1980     154  142330  300254    1290]
 [  10083   15334   29794     847  600842]]


In [27]:
# %% [CELL 1] Minimal eval utilities (self-contained)

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

from sklearn.metrics import accuracy_score, f1_score, cohen_kappa_score, confusion_matrix
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_auc_score, average_precision_score

def _ece_from_probs(y_true, probs, n_bins=15):
    conf = probs.max(axis=1)
    pred = probs.argmax(axis=1)
    acc = (pred == y_true).astype(np.float32)

    bins = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for i in range(n_bins):
        lo, hi = bins[i], bins[i + 1]
        m = (conf >= lo) & (conf < hi)
        if m.sum() == 0:
            continue
        ece += (m.mean()) * abs(acc[m].mean() - conf[m].mean())
    return float(ece)

def _auroc_auprc_multiclass(y_true, probs, num_classes):
    Y = label_binarize(y_true, classes=list(range(num_classes)))
    aurocs, auprcs = [], []
    for c in range(num_classes):
        if Y[:, c].sum() == 0:
            continue
        try:
            aurocs.append(roc_auc_score(Y[:, c], probs[:, c]))
            auprcs.append(average_precision_score(Y[:, c], probs[:, c]))
        except Exception:
            pass
    if len(aurocs) == 0:
        return float("nan"), float("nan")
    return float(np.mean(aurocs)), float(np.mean(auprcs))

def _print_metrics(tag, m):
    print(f"\n===== {tag} =====")
    for k in ["loss","acc","macro_f1","kappa","AUROC","AUPRC","meanConf","ECE"]:
        if k in m:
            print(f"{k:9s}: {m[k]:.4f}")
    if "bad_batches" in m:
        print("bad_batches:", m["bad_batches"])
    if "f1_per_class" in m:
        print("F1/class :", m["f1_per_class"])
    if "cm" in m:
        labs = [LABELS[i] for i in range(NUM_CLASSES)] if "LABELS" in globals() else list(range(NUM_CLASSES))
        print("Confusion Matrix labels:", labs)
        print(m["cm"])


In [28]:
# %% [CELL 2] eval_sequence_simple (no dependency on masked_loss_v5)

@torch.no_grad()
def eval_sequence_simple(model, loader, desc="Eval", num_classes=5, use_ce_loss=True):
    model.eval()

    all_true, all_pred, all_probs = [], [], []
    total_loss = 0.0
    total_n = 0
    bad_batches = 0

    for batch in tqdm(loader, desc=desc, leave=False):
        # Support both dataset outputs:
        # (xb,yb,mb) OR (xb,yb,mb,sb,db)
        if len(batch) == 3:
            xb, yb, mb = batch
        else:
            xb, yb, mb = batch[0], batch[1], batch[2]

        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)
        mb = mb.to(device, non_blocking=True)

        # Model forward: your V5.1 returns (main, aux, dur)
        out = model(xb, mb)
        if isinstance(out, (tuple, list)) and len(out) >= 1:
            main_logits = out[0]
        else:
            main_logits = out

        if not torch.isfinite(main_logits).all():
            bad_batches += 1
            continue

        probs = torch.softmax(main_logits.float(), dim=-1)
        pred = torch.argmax(probs, dim=-1)

        # valid positions only
        yv = yb[mb].detach().cpu().numpy()
        pv = pred[mb].detach().cpu().numpy()
        pr = probs[mb].detach().cpu().numpy()

        if yv.size == 0:
            continue

        all_true.append(yv)
        all_pred.append(pv)
        all_probs.append(pr)

        if use_ce_loss:
            # standard CE for logging only
            logits_valid = main_logits[mb]              # (N,C)
            y_valid = yb[mb]                            # (N,)
            loss = F.cross_entropy(logits_valid, y_valid)
            n = int(mb.sum().item())
            total_loss += float(loss.item()) * n
            total_n += n

    if len(all_true) == 0:
        return {
            "loss": float("nan"),
            "acc": float("nan"),
            "macro_f1": float("nan"),
            "kappa": float("nan"),
            "AUROC": float("nan"),
            "AUPRC": float("nan"),
            "meanConf": float("nan"),
            "ECE": float("nan"),
            "f1_per_class": {LABELS[i]: float("nan") for i in range(num_classes)},
            "cm": np.zeros((num_classes, num_classes), dtype=np.int64),
            "bad_batches": int(bad_batches),
        }

    y_true = np.concatenate(all_true)
    y_pred = np.concatenate(all_pred)
    probs_all = np.concatenate(all_probs)

    acc = accuracy_score(y_true, y_pred)
    mf1 = f1_score(y_true, y_pred, average="macro")
    kappa = cohen_kappa_score(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))

    auroc, auprc = _auroc_auprc_multiclass(y_true, probs_all, num_classes=num_classes)
    mean_conf = float(probs_all.max(axis=1).mean())
    ece = _ece_from_probs(y_true, probs_all, n_bins=15)

    f1_per = {LABELS[i]: float(f1_score((y_true==i).astype(int), (y_pred==i).astype(int)))
              for i in range(num_classes)}

    return {
        "loss": (total_loss / max(total_n, 1)) if use_ce_loss else float("nan"),
        "acc": float(acc),
        "macro_f1": float(mf1),
        "kappa": float(kappa),
        "AUROC": auroc,
        "AUPRC": auprc,
        "meanConf": mean_conf,
        "ECE": ece,
        "f1_per_class": f1_per,
        "cm": cm,
        "bad_batches": int(bad_batches),
        "n_epochs": int(len(y_true)),
    }


In [29]:
# %% [CELL 3] Build inhouse loader + evaluate + save

from pathlib import Path
import pandas as pd
import json
import re
from torch.utils.data import DataLoader

INHOUSE_DIR = Path("/data2/Akbar1/sleep_stages_Dibatic/shhs_sleepstaging_planA/inhouse_npz_shhs_style/")
assert INHOUSE_DIR.exists(), f"Not found: {INHOUSE_DIR}"

npz_files = sorted(INHOUSE_DIR.glob("*.npz"))
assert len(npz_files) > 0, "No .npz files found in the inhouse folder."
print("Found NPZ:", len(npz_files))
print("Example:", npz_files[0])

def _infer_subject_id(p: Path):
    s = p.stem
    s = re.sub(r"_inhouse$", "", s)
    return s

# Save a small manifest (safe)
inhouse_manifest_path = ROOT / "manifests" / "inhouse_12subj_manifest.csv"
inhouse_manifest_path.parent.mkdir(parents=True, exist_ok=True)

df_inhouse = pd.DataFrame({
    "subject_id": [_infer_subject_id(p) for p in npz_files],
    "npz_path": [str(p) for p in npz_files],
    "split": ["inhouse"] * len(npz_files),
})
df_inhouse.to_csv(inhouse_manifest_path, index=False)
print("Saved inhouse manifest:", inhouse_manifest_path)

# Build dataset using your existing class
df_inhouse2 = pd.read_csv(inhouse_manifest_path)
inhouse_ds = SleepSequenceDataset(df_inhouse2, mode="eval", max_hours=None, augmentor=None, do_normalize=True)

# IMPORTANT: use the SAME collate as your main loaders
# In your code it is called collate_pad
assert "collate_pad" in globals(), "collate_pad not found. Run the cell where collate_pad is defined."
inhouse_seq_loader = DataLoader(
    inhouse_ds,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    collate_fn=collate_pad,
)

print("Built inhouse_seq_loader batches:", len(inhouse_seq_loader))

# Evaluate
inhouse_m = eval_sequence_simple(model, inhouse_seq_loader, desc="INHOUSE (12 subj)", num_classes=NUM_CLASSES, use_ce_loss=True)
_print_metrics("INHOUSE (12 subj)", inhouse_m)

# ---------- save (JSON-safe) ----------
def to_jsonable(obj):
    import numpy as np
    if obj is None: return None
    if isinstance(obj, (str, int, float, bool)): return obj
    if isinstance(obj, (np.integer,)): return int(obj)
    if isinstance(obj, (np.floating,)): return float(obj)
    if isinstance(obj, (np.ndarray,)): return obj.tolist()
    if isinstance(obj, dict): return {str(k): to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)): return [to_jsonable(v) for v in obj]
    return str(obj)

OUT_DIR = ROOT / "eval_bundles"
OUT_DIR.mkdir(parents=True, exist_ok=True)

out_path = OUT_DIR / "EVAL_INHOUSE_12.json"
with open(out_path, "w") as f:
    json.dump(to_jsonable({"inhouse_metrics": inhouse_m, "inhouse_manifest": str(inhouse_manifest_path)}), f, indent=2)

print("Saved:", out_path)


Found NPZ: 12
Example: /data2/Akbar1/sleep_stages_Dibatic/shhs_sleepstaging_planA/inhouse_npz_shhs_style/sub1.npz
Saved inhouse manifest: /data2/Akbar1/sleep_stages_Dibatic/shhs_sleepstaging_planA/manifests/inhouse_12subj_manifest.csv
SleepSequenceDataset[eval] files=12 max_hours=None normalize=True
Built inhouse_seq_loader batches: 12


                                                                                


===== INHOUSE (12 subj) =====
loss     : 1.0821
acc      : 0.6096
macro_f1 : 0.5443
kappa    : 0.4674
AUROC    : 0.8609
AUPRC    : 0.6095
meanConf : 0.7280
ECE      : 0.1186
bad_batches: 0
F1/class : {'W': 0.7451757864128998, 'N1': 0.33427533306419055, 'N2': 0.6250714204090961, 'N3': 0.2944120100083403, 'REM': 0.7226753670473084}
Confusion Matrix labels: ['W', 'N1', 'N2', 'N3', 'REM']
[[2819  429  681   83  258]
 [ 117  414  200   20   69]
 [ 202  633 2735  408  165]
 [  57  132  973  353   18]
 [ 101   49   19    1  886]]
Saved: /data2/Akbar1/sleep_stages_Dibatic/shhs_sleepstaging_planA/eval_bundles/EVAL_INHOUSE_12.json




In [30]:
#### real world deployemnt analysis 

In [31]:
# %% [CELL] Model size / parameter counts

import os
from pathlib import Path
import torch

def count_params(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable

total_params, trainable_params = count_params(model)

# Rough FP32 size on disk if stored as raw parameters (not checkpoint overhead)
fp32_param_bytes = total_params * 4
fp16_param_bytes = total_params * 2

print("Model params:")
print(f"  total params     : {total_params:,} ({total_params/1e6:.3f} M)")
print(f"  trainable params : {trainable_params:,} ({trainable_params/1e6:.3f} M)")
print(f"  FP32 param size  : {fp32_param_bytes/1024**2:.2f} MB")
print(f"  FP16 param size  : {fp16_param_bytes/1024**2:.2f} MB")

# If you have a best checkpoint path (optional)
if "BEST_CKPT" in globals():
    ckpt_path = Path(BEST_CKPT)
    if ckpt_path.exists():
        print(f"Checkpoint file size: {ckpt_path.stat().st_size/1024**2:.2f} MB  ({ckpt_path.name})")


Model params:
  total params     : 22,905,512 (22.906 M)
  trainable params : 22,905,512 (22.906 M)
  FP32 param size  : 87.38 MB
  FP16 param size  : 43.69 MB
Checkpoint file size: 349.79 MB  (BEST_VAL_macroF1.pt)


In [32]:
# %% [CELL] Inference latency + GPU memory benchmark (sequence-level)

import time
import numpy as np
import torch
from tqdm import tqdm

def _forward_main_logits(model, xb, mb):
    out = model(xb, mb)
    if isinstance(out, (tuple, list)):
        return out[0]  # main logits
    return out

@torch.no_grad()
def benchmark_inference_sequence_loader(
    model,
    loader,
    device,
    n_warmup=5,
    n_runs=20,
    max_batches=None,
    use_amp=True,
    desc="BENCH"
):
    """
    Measures end-to-end inference over 'max_batches' subjects in loader.
    Returns ms/subject and ms/epoch (epoch=one pass over chosen batches).
    Also reports peak GPU memory footprint during inference.
    """
    assert device.type in ["cuda", "cpu"]

    model.eval()

    # pick batches to test (keep it deterministic)
    batches = []
    for i, batch in enumerate(loader):
        batches.append(batch)
        if max_batches is not None and len(batches) >= max_batches:
            break
    assert len(batches) > 0, "Loader produced 0 batches."

    # ---- warmup (GPU only matters)
    if device.type == "cuda":
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

    for _ in range(n_warmup):
        for batch in batches:
            xb, yb, mb = batch[0], batch[1], batch[2]
            xb = xb.to(device, non_blocking=True)
            mb = mb.to(device, non_blocking=True)

            if device.type == "cuda":
                with torch.cuda.amp.autocast(enabled=use_amp):
                    _ = _forward_main_logits(model, xb, mb)
            else:
                _ = _forward_main_logits(model, xb, mb)

    if device.type == "cuda":
        torch.cuda.synchronize()

    # ---- timed runs
    epoch_times = []
    subj_times = []

    for r in range(n_runs):
        if device.type == "cuda":
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()

        t0 = time.perf_counter()
        n_subj = 0

        for batch in batches:
            xb, yb, mb = batch[0], batch[1], batch[2]
            xb = xb.to(device, non_blocking=True)
            mb = mb.to(device, non_blocking=True)

            ts = time.perf_counter()
            if device.type == "cuda":
                with torch.cuda.amp.autocast(enabled=use_amp):
                    logits = _forward_main_logits(model, xb, mb)
            else:
                logits = _forward_main_logits(model, xb, mb)

            if device.type == "cuda":
                torch.cuda.synchronize()
            te = time.perf_counter()

            subj_times.append((te - ts) * 1000.0)  # ms per subject
            n_subj += 1

        if device.type == "cuda":
            torch.cuda.synchronize()
        t1 = time.perf_counter()

        epoch_times.append((t1 - t0) * 1000.0)  # ms per pass over chosen subjects

    # memory stats (GPU)
    if device.type == "cuda":
        peak_alloc = torch.cuda.max_memory_allocated() / 1024**2
        peak_reserved = torch.cuda.max_memory_reserved() / 1024**2
    else:
        peak_alloc = float("nan")
        peak_reserved = float("nan")

    epoch_times = np.array(epoch_times, dtype=np.float64)
    subj_times = np.array(subj_times, dtype=np.float64)

    out = {
        "device": str(device),
        "use_amp": bool(use_amp),
        "n_subjects_tested": int(len(batches)),
        "n_runs": int(n_runs),
        "n_warmup": int(n_warmup),

        "ms_per_epoch_mean": float(epoch_times.mean()),
        "ms_per_epoch_std": float(epoch_times.std(ddof=1) if len(epoch_times) > 1 else 0.0),

        "ms_per_subject_mean": float(subj_times.mean()),
        "ms_per_subject_std": float(subj_times.std(ddof=1) if len(subj_times) > 1 else 0.0),

        "subjects_per_sec": float(1000.0 / max(subj_times.mean(), 1e-9)),
        "peak_mem_allocated_MB": float(peak_alloc),
        "peak_mem_reserved_MB": float(peak_reserved),
    }
    return out

# Choose a loader to benchmark:
# - inhouse_seq_loader (12 subjects) is perfect for "real-world" local feasibility
loader_for_bench = inhouse_seq_loader if "inhouse_seq_loader" in globals() else val_seq_loader

bench = benchmark_inference_sequence_loader(
    model,
    loader_for_bench,
    device=device,
    n_warmup=3,
    n_runs=10,
    max_batches=12,     # use 12 if inhouse, else you can set 50 etc.
    use_amp=(device.type=="cuda"),
    desc="BENCH"
)

print("\n=== Inference Efficiency ===")
for k in ["device","use_amp","n_subjects_tested","n_runs","ms_per_subject_mean","ms_per_subject_std",
          "ms_per_epoch_mean","ms_per_epoch_std","subjects_per_sec",
          "peak_mem_allocated_MB","peak_mem_reserved_MB"]:
    print(f"{k:22s}: {bench[k]}")



=== Inference Efficiency ===
device                : cuda
use_amp               : True
n_subjects_tested     : 12
n_runs                : 10
ms_per_subject_mean   : 95.66590481748183
ms_per_subject_std    : 24.166755547636562
ms_per_epoch_mean     : 1148.9899184554815
ms_per_epoch_std      : 2.3279415758547057
subjects_per_sec      : 10.453044916136744
peak_mem_allocated_MB : 1403.85107421875
peak_mem_reserved_MB  : 11528.0


In [33]:
# %% [CELL] Paper-ready reporting text

def fmt(mean, std, unit="ms"):
    return f"{mean:.2f}±{std:.2f} {unit}"

total_params, trainable_params = sum(p.numel() for p in model.parameters()), sum(p.numel() for p in model.parameters() if p.requires_grad)

model_size_str = f"{total_params/1e6:.2f}M params"
lat_str = fmt(bench["ms_per_subject_mean"], bench["ms_per_subject_std"], "ms/recording")
epoch_str = fmt(bench["ms_per_epoch_mean"], bench["ms_per_epoch_std"], "ms per {N} recordings").replace("{N}", str(bench["n_subjects_tested"]))

mem_alloc = bench["peak_mem_allocated_MB"]
mem_res = bench["peak_mem_reserved_MB"]
mem_str = f"peak GPU mem {mem_alloc:.1f} MB allocated ({mem_res:.1f} MB reserved)"

print("Suggested paper sentence:")
print(
    f"We report computational efficiency including model size ({model_size_str}), "
    f"inference latency ({lat_str}), and memory footprint ({mem_str}) "
    f"to assess feasibility for deployment on resource-constrained platforms."
)

print("\nTable row (example):")
print(f"Model | Params | Latency | Peak Mem")
print(f"Ours  | {model_size_str} | {lat_str} | {mem_alloc:.1f} MB")


Suggested paper sentence:
We report computational efficiency including model size (22.91M params), inference latency (95.67±24.17 ms/recording), and memory footprint (peak GPU mem 1403.9 MB allocated (11528.0 MB reserved)) to assess feasibility for deployment on resource-constrained platforms.

Table row (example):
Model | Params | Latency | Peak Mem
Ours  | 22.91M params | 95.67±24.17 ms/recording | 1403.9 MB


In [34]:
## real world infrance analysis  cpu

In [35]:
# %% [CELL] CPU benchmark (same loader) + optional thread control

import os
import torch

# ---- optional: control CPU threads (choose one)
# If you want max speed on CPU, try setting these BEFORE running benchmark.
# Start with 8 or 16 depending on your server.
os.environ["OMP_NUM_THREADS"] = "16"
os.environ["MKL_NUM_THREADS"] = "16"
torch.set_num_threads(16)
torch.set_num_interop_threads(2)

print("CPU threads:", torch.get_num_threads(), "| interop:", torch.get_num_interop_threads())

# Use SAME loader you used for GPU
loader_for_bench = inhouse_seq_loader  # 12 subjects

# Create CPU copy of model (so you don't move your GPU model)
model_cpu = type(model)(**model.__dict__.get('_init_kwargs', {})) if False else None
# ^ ignore that line; we do a safe deepcopy approach below

import copy
model_cpu = copy.deepcopy(model).to("cpu")
model_cpu.eval()

device_cpu = torch.device("cpu")

bench_cpu = benchmark_inference_sequence_loader(
    model_cpu,
    loader_for_bench,
    device=device_cpu,
    n_warmup=1,
    n_runs=3,
    max_batches=3,        # start small; increase later to 12
    use_amp=False,
    desc="CPU_BENCH"
)

print("\n=== CPU Inference Efficiency ===")
for k in ["device","use_amp","n_subjects_tested","n_runs","ms_per_subject_mean","ms_per_subject_std",
          "ms_per_epoch_mean","ms_per_epoch_std","subjects_per_sec",
          "peak_mem_allocated_MB","peak_mem_reserved_MB"]:
    print(f"{k:22s}: {bench_cpu[k]}")


CPU threads: 16 | interop: 2

=== CPU Inference Efficiency ===
device                : cpu
use_amp               : False
n_subjects_tested     : 3
n_runs                : 3
ms_per_subject_mean   : 3184.3750653788447
ms_per_subject_std    : 354.2850801274565
ms_per_epoch_mean     : 9553.179225884378
ms_per_epoch_std      : 61.86108768615994
subjects_per_sec      : 0.31403335959767986
peak_mem_allocated_MB : nan
peak_mem_reserved_MB  : nan


In [36]:
# %% [CELL] Side-by-side comparison (GPU vs CPU)

import pandas as pd

def row_from_bench(name, b):
    return {
        "Device": name,
        "Subjects": b["n_subjects_tested"],
        "Runs": b["n_runs"],
        "Latency (ms/subject)": f"{b['ms_per_subject_mean']:.2f} ± {b['ms_per_subject_std']:.2f}",
        "Epoch time (ms)": f"{b['ms_per_epoch_mean']:.2f} ± {b['ms_per_epoch_std']:.2f}",
        "Throughput (subj/s)": f"{b['subjects_per_sec']:.2f}",
        "Peak mem alloc (MB)": f"{b['peak_mem_allocated_MB']:.1f}" if b["peak_mem_allocated_MB"]==b["peak_mem_allocated_MB"] else "N/A",
        "Peak mem reserved (MB)": f"{b['peak_mem_reserved_MB']:.1f}" if b["peak_mem_reserved_MB"]==b["peak_mem_reserved_MB"] else "N/A",
        "AMP": b["use_amp"],
    }

rows = []
rows.append(row_from_bench("GPU", bench))
rows.append(row_from_bench("CPU", bench_cpu))

df = pd.DataFrame(rows)
df


Unnamed: 0,Device,Subjects,Runs,Latency (ms/subject),Epoch time (ms),Throughput (subj/s),Peak mem alloc (MB),Peak mem reserved (MB),AMP
0,GPU,12,10,95.67 ± 24.17,1148.99 ± 2.33,10.45,1403.9,11528.0,True
1,CPU,3,3,3184.38 ± 354.29,9553.18 ± 61.86,0.31,,,False


In [37]:
# %% [CELL] Paper-ready text (GPU + CPU)

total_params = sum(p.numel() for p in model.parameters())

gpu_lat = f"{bench['ms_per_subject_mean']:.1f}±{bench['ms_per_subject_std']:.1f} ms/recording"
cpu_lat = f"{bench_cpu['ms_per_subject_mean']:.1f}±{bench_cpu['ms_per_subject_std']:.1f} ms/recording"

gpu_mem = f"{bench['peak_mem_allocated_MB']:.0f} MB allocated ({bench['peak_mem_reserved_MB']:.0f} MB reserved)"
# CPU mem not measured by torch; report as "N/A" unless you want psutil-based RSS
print(
    f"We report computational efficiency metrics including model size ({total_params/1e6:.2f}M parameters), "
    f"inference latency on GPU ({gpu_lat}) and CPU ({cpu_lat}), and GPU memory footprint "
    f"({gpu_mem}), providing insights into feasibility on resource-constrained platforms."
)


We report computational efficiency metrics including model size (22.91M parameters), inference latency on GPU (95.7±24.2 ms/recording) and CPU (3184.4±354.3 ms/recording), and GPU memory footprint (1404 MB allocated (11528 MB reserved)), providing insights into feasibility on resource-constrained platforms.
