In [1]:
# ================================================================
# V5 (FINAL): N1-PUSH upgrade on top of your V4 (minimal changes)
# Adds ONLY:
#   (1) Soft labels near transitions (boundary soft targets)
#   (2) Confusion-cost matrix regularizer (cost-sensitive penalty)
#   (3) Duration-aware auxiliary head (predict remaining run-length bucket)
# Keeps EVERYTHING else (encoder/transformer/EMA/LA-CE/aux N1/trans loss/smoothing)
# SHHS2 stays external eval. MESA is eval-only if present in manifest.
# ================================================================

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
from collections import Counter
import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import accuracy_score, f1_score, cohen_kappa_score, confusion_matrix
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_auc_score, average_precision_score

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("CUDA available:", torch.cuda.is_available())
print("Visible CUDA devices:", torch.cuda.device_count())
print("Using device:", device)
if device.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))


CUDA available: True
Visible CUDA devices: 1
Using device: cuda
GPU: NVIDIA RTX A6000


In [2]:
# ----------------------------
# Paths / Manifest
# ----------------------------
ROOT = Path("/data2/Akbar1/sleep_stages_Dibatic/shhs_sleepstaging_planA/")
MANIFEST_PATH = ROOT / "manifest_sleepstaging_planA.csv"

manifest = pd.read_csv(MANIFEST_PATH)
print("Rows:", len(manifest))
print(manifest.groupby(["cohort","split"]).size())

df_train = manifest[(manifest.cohort=="SHHS1") & (manifest.split=="train")].copy()
df_val   = manifest[(manifest.cohort=="SHHS1") & (manifest.split=="val")].copy()
df_test  = manifest[(manifest.cohort=="SHHS1") & (manifest.split=="test")].copy()
df_ext   = manifest[(manifest.cohort=="SHHS2") & (manifest.split=="external_test")].copy()

# Optional: MESA eval-only if exists in manifest
df_mesa = manifest[(manifest.cohort=="MESA")].copy() if ("MESA" in manifest["cohort"].unique()) else None
if df_mesa is not None and len(df_mesa) > 0:
    # prefer explicit split name if you used it; fallback to all MESA rows
    if "external_test" in df_mesa["split"].unique():
        df_mesa = df_mesa[df_mesa.split=="external_test"].copy()
    print("MESA detected | rows:", len(df_mesa))
else:
    df_mesa = None
    print("MESA not detected in manifest (ok).")

print("DF sizes:", len(df_train), len(df_val), len(df_test), len(df_ext), ("| MESA="+str(len(df_mesa)) if df_mesa is not None else ""))


Rows: 9868
cohort  split        
MESA    external_test    1856
SHHS1   test              548
        train            4380
        val               548
SHHS2   external_test    2536
dtype: int64
MESA detected | rows: 1856
DF sizes: 4380 548 548 2536 | MESA=1856


In [3]:
# ----------------------------
# Augment + Normalize (same as V4)
# ----------------------------
class EEGAugment:
    """
    Safe EEG augmentations on epochs.
    x: np.ndarray float32, shape (E, T)
    """
    def __init__(self,
                 p_amp=0.5,
                 p_noise=0.5,
                 p_shift=0.5,
                 p_bandstop=0.3,
                 p_freqdrop=0.3,
                 amp_range=(0.8, 1.2),
                 noise_std=0.01,       # relative to per-epoch std
                 shift_max=125,        # 1 sec at 125Hz
                 bandstop_ranges=((49,51), (59,61)),
                 freqdrop_max_bins=12):
        self.p_amp = p_amp
        self.p_noise = p_noise
        self.p_shift = p_shift
        self.p_bandstop = p_bandstop
        self.p_freqdrop = p_freqdrop
        self.amp_range = amp_range
        self.noise_std = noise_std
        self.shift_max = shift_max
        self.bandstop_ranges = bandstop_ranges
        self.freqdrop_max_bins = freqdrop_max_bins

    def _amp_scale(self, x):
        s = np.random.uniform(self.amp_range[0], self.amp_range[1])
        return x * s

    def _gaussian_noise(self, x):
        std = np.std(x, axis=1, keepdims=True) + 1e-6
        noise = np.random.randn(*x.shape).astype(np.float32) * (self.noise_std * std)
        return x + noise

    def _time_shift(self, x):
        shift = np.random.randint(-self.shift_max, self.shift_max+1)
        return np.roll(x, shift=shift, axis=1)

    def _bandstop_fft(self, x, fs=125.0):
        E, T = x.shape
        X = np.fft.rfft(x, axis=1)
        freqs = np.fft.rfftfreq(T, d=1.0/fs)
        for (f1, f2) in self.bandstop_ranges:
            mask = (freqs >= f1) & (freqs <= f2)
            X[:, mask] = 0.0
        y = np.fft.irfft(X, n=T, axis=1).astype(np.float32)
        return y

    def _freq_dropout(self, x):
        E, T = x.shape
        X = np.fft.rfft(x, axis=1)
        Fbins = X.shape[1]
        drop = np.random.randint(1, self.freqdrop_max_bins+1)
        start = np.random.randint(0, max(1, Fbins - drop))
        X[:, start:start+drop] = 0.0
        y = np.fft.irfft(X, n=T, axis=1).astype(np.float32)
        return y

    def __call__(self, x, fs=125.0):
        if np.random.rand() < self.p_amp:
            x = self._amp_scale(x)
        if np.random.rand() < self.p_noise:
            x = self._gaussian_noise(x)
        if np.random.rand() < self.p_shift:
            x = self._time_shift(x)
        if np.random.rand() < self.p_bandstop:
            x = self._bandstop_fft(x, fs=fs)
        if np.random.rand() < self.p_freqdrop:
            x = self._freq_dropout(x)
        return x

def normalize_epochs_zscore(x, eps=1e-6, clip=10.0):
    """
    per-epoch z-score: (x - mean) / std
    x: (E,T) float32 -> (E,T) float32
    """
    mu = np.mean(x, axis=1, keepdims=True)
    sd = np.std(x, axis=1, keepdims=True) + eps
    x = (x - mu) / sd
    if clip is not None:
        x = np.clip(x, -clip, clip)
    return x.astype(np.float32)

augment = EEGAugment()
print("Augmenter + Normalizer ready (per-epoch z-score).")


Augmenter + Normalizer ready (per-epoch z-score).


In [4]:
# ----------------------------
# Labels / constants
# ----------------------------
LABELS = {0:"W", 1:"N1", 2:"N2", 3:"N3", 4:"REM"}
NUM_CLASSES = 5
FS = 125
T = 3750   # 30s * 125Hz


In [5]:
# ================================================================
# V5 DATASET: same as V4, but returns extra supervision:
#   - soft_targets (L,C): boundary soft labels
#   - dur_bucket (L,): remaining run-length bucket labels
# ================================================================

def _compute_runlength_remaining(y):
    """
    y: (E,) int64
    returns remaining run length at each t: rem[t] = how many epochs left (including t) in the current segment
    """
    E = len(y)
    rem = np.zeros((E,), dtype=np.int64)
    t = 0
    while t < E:
        j = t
        while j < E and y[j] == y[t]:
            j += 1
        seg_len = j - t
        for k in range(t, j):
            rem[k] = (j - k)  # including current epoch
        t = j
    return rem

def _bucketize_remaining(rem, edges=(2, 5, 10, 20, 40, 80, 160)):
    """
    rem: (E,) remaining run length
    edges define bucket upper bounds.
    returns bucket id in [0..len(edges)] (B buckets)
    """
    edges = np.array(edges, dtype=np.int64)
    b = np.zeros_like(rem, dtype=np.int64)
    for i, r in enumerate(rem):
        b[i] = int(np.searchsorted(edges, r, side="right"))
    return b

def _make_soft_targets_boundary(y, num_classes=5, radius=2, alpha_max=0.35):
    """
    Soft labels near transitions:
      - if an epoch is within `radius` of a transition, mix current label with neighbor (prev/next).
      - alpha decays with distance: alpha = alpha_max * (1 - dist/radius)
    y: (E,)
    returns soft: (E,C) float32, sums to 1
    """
    E = len(y)
    soft = np.zeros((E, num_classes), dtype=np.float32)
    soft[np.arange(E), y] = 1.0

    if E < 2 or radius <= 0:
        return soft

    trans = np.where(y[1:] != y[:-1])[0] + 1  # boundary index (start of new stage)
    if trans.size == 0:
        return soft

    # For each boundary t (new stage begins at t), affect indices near t and t-1
    for t in trans:
        for dt in range(-radius, radius+1):
            i = t + dt
            if i < 0 or i >= E:
                continue
            dist = abs(dt)
            if dist > radius:
                continue
            alpha = alpha_max * (1.0 - (dist / max(radius, 1)))
            if alpha <= 0:
                continue

            # choose neighbor label towards the boundary direction
            if i < t:
                # before boundary => mix with next label y[t]
                nb = y[t]
            else:
                # at/after boundary => mix with prev label y[t-1]
                nb = y[t-1]

            cur = y[i]
            if nb == cur:
                continue

            # mix: (1-alpha)*onehot(cur) + alpha*onehot(nb)
            soft[i, :] = 0.0
            soft[i, cur] = 1.0 - alpha
            soft[i, nb]  = alpha

    # normalize (safety)
    soft = soft / (soft.sum(axis=1, keepdims=True) + 1e-8)
    return soft.astype(np.float32)

class SleepSequenceDataset(Dataset):
    """
    One item = one subject recording.

    Returns:
      x: (L, 1, T) float32
      y: (L,) int64
      mask: (L,) bool (True valid)
      soft: (L,C) float32 (soft targets near transitions)
      dur_bucket: (L,) int64 (duration-aware label)
    """
    def __init__(self, df, mode="train",
                 max_hours=None,
                 min_hours=2.0,
                 augmentor=None,
                 exclude_unknown=True,
                 do_normalize=True,
                 # ---- V4 additions ----
                 boundary_oversample_p=0.70,
                 boundary_radius=2,
                 # ---- V5 additions ----
                 soft_radius=2,
                 soft_alpha_max=0.35,
                 dur_edges=(2,5,10,20,40,80,160)):
        self.paths = df["npz_path"].tolist()
        self.mode = mode
        self.max_hours = max_hours
        self.min_hours = min_hours
        self.augmentor = augmentor
        self.exclude_unknown = exclude_unknown
        self.do_normalize = do_normalize

        self.boundary_oversample_p = boundary_oversample_p
        self.boundary_radius = int(boundary_radius)

        self.soft_radius = int(soft_radius)
        self.soft_alpha_max = float(soft_alpha_max)
        self.dur_edges = tuple(dur_edges)

        print(f"SleepSequenceDataset[{mode}] files={len(self.paths)} max_hours={self.max_hours} normalize={self.do_normalize}")

    def __len__(self):
        return len(self.paths)

    def _pick_block_start_boundary_aware(self, y, L):
        E = len(y)
        if E <= L:
            return 0

        if np.random.rand() > self.boundary_oversample_p:
            return np.random.randint(0, E - L + 1)

        trans = np.where(y[1:] != y[:-1])[0] + 1
        if trans.size == 0:
            return np.random.randint(0, E - L + 1)

        t = int(np.random.choice(trans))
        start = t - (L // 2)
        start = max(0, min(start, E - L))
        return int(start)

    def __getitem__(self, idx):
        p = self.paths[idx]
        d = np.load(p, allow_pickle=True)

        x = d["x"].astype(np.float32)    # (E,T)
        y = d["y"].astype(np.int64)      # (E,)

        if self.exclude_unknown:
            keep = (y >= 0)
            x = x[keep]
            y = y[keep]

        if self.do_normalize:
            x = normalize_epochs_zscore(x, eps=1e-6, clip=10.0)

        E = len(y)

        # V5: compute soft targets + duration labels BEFORE block crop,
        # then crop consistently (so supervision matches x,y)
        soft = _make_soft_targets_boundary(y, num_classes=NUM_CLASSES,
                                           radius=self.soft_radius,
                                           alpha_max=self.soft_alpha_max)
        rem = _compute_runlength_remaining(y)
        dur_bucket = _bucketize_remaining(rem, edges=self.dur_edges)

        # sample long block for training
        if self.max_hours is not None:
            max_L = int((self.max_hours * 3600) / 30)
            min_L = int((self.min_hours * 3600) / 30)

            L = min(max_L, E)
            if E > L:
                start = self._pick_block_start_boundary_aware(y, L)
                x = x[start:start+L]
                y = y[start:start+L]
                soft = soft[start:start+L]
                dur_bucket = dur_bucket[start:start+L]
                E = L

        if self.mode == "train" and self.augmentor is not None:
            x = self.augmentor(x, fs=FS)

        x_t = torch.from_numpy(x).unsqueeze(1)              # (E,1,T)
        y_t = torch.from_numpy(y).long()                    # (E,)
        soft_t = torch.from_numpy(soft).float()             # (E,C)
        dur_t = torch.from_numpy(dur_bucket).long()         # (E,)
        mask = torch.ones((E,), dtype=torch.bool)

        return x_t, y_t, mask, soft_t, dur_t

def collate_pad(batch):
    lengths = [b[0].shape[0] for b in batch]
    Lmax = max(lengths)

    xs, ys, ms, ss, ds = [], [], [], [], []
    for x, y, m, s, d in batch:
        L = x.shape[0]
        padL = Lmax - L
        if padL > 0:
            x = torch.cat([x, torch.zeros((padL, 1, T), dtype=x.dtype)], dim=0)
            y = torch.cat([y, torch.zeros((padL,), dtype=y.dtype)], dim=0)
            m = torch.cat([m, torch.zeros((padL,), dtype=torch.bool)], dim=0)
            s = torch.cat([s, torch.zeros((padL, NUM_CLASSES), dtype=s.dtype)], dim=0)
            d = torch.cat([d, torch.zeros((padL,), dtype=d.dtype)], dim=0)
        xs.append(x); ys.append(y); ms.append(m); ss.append(s); ds.append(d)

    x = torch.stack(xs, dim=0)  # (B,L,1,T)
    y = torch.stack(ys, dim=0)  # (B,L)
    m = torch.stack(ms, dim=0)  # (B,L)
    s = torch.stack(ss, dim=0)  # (B,L,C)
    d = torch.stack(ds, dim=0)  # (B,L)
    return x, y, m, s, d


In [6]:
# ----------------------------
# DataLoaders (same batching)
# ----------------------------
train_seq_ds = SleepSequenceDataset(
    df_train, mode="train",
    max_hours=4.0, min_hours=2.0,
    augmentor=augment,
    exclude_unknown=True,
    do_normalize=True,
    boundary_oversample_p=0.70,
    boundary_radius=2,
    # V5
    soft_radius=2,
    soft_alpha_max=0.35,
    dur_edges=(2,5,10,20,40,80,160),
)

val_seq_ds  = SleepSequenceDataset(df_val,  mode="eval", max_hours=None, augmentor=None, do_normalize=True)
test_seq_ds = SleepSequenceDataset(df_test, mode="eval", max_hours=None, augmentor=None, do_normalize=True)
ext_seq_ds  = SleepSequenceDataset(df_ext,  mode="eval", max_hours=None, augmentor=None, do_normalize=True)
mesa_seq_ds = SleepSequenceDataset(df_mesa, mode="eval", max_hours=None, augmentor=None, do_normalize=True) if df_mesa is not None else None

BATCH_SUBJ = 2
NUM_WORKERS = 2
PIN = True

train_seq_loader = DataLoader(
    train_seq_ds, batch_size=BATCH_SUBJ, shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=PIN,
    collate_fn=collate_pad, persistent_workers=False
)

val_seq_loader  = DataLoader(val_seq_ds,  batch_size=1, shuffle=False, num_workers=1, pin_memory=PIN, collate_fn=collate_pad)
test_seq_loader = DataLoader(test_seq_ds, batch_size=1, shuffle=False, num_workers=1, pin_memory=PIN, collate_fn=collate_pad)
ext_seq_loader  = DataLoader(ext_seq_ds,  batch_size=1, shuffle=False, num_workers=1, pin_memory=PIN, collate_fn=collate_pad)
mesa_seq_loader = DataLoader(mesa_seq_ds, batch_size=1, shuffle=False, num_workers=1, pin_memory=PIN, collate_fn=collate_pad) if mesa_seq_ds is not None else None

xb, yb, mb, sb, db = next(iter(train_seq_loader))
print("Batch shapes:", xb.shape, yb.shape, mb.shape, sb.shape, db.shape)
print("Valid tokens:", int(mb.sum().item()), "| y min/max(valid):",
      int(yb[mb].min().item()), int(yb[mb].max().item()))


SleepSequenceDataset[train] files=4380 max_hours=4.0 normalize=True
SleepSequenceDataset[eval] files=548 max_hours=None normalize=True
SleepSequenceDataset[eval] files=548 max_hours=None normalize=True
SleepSequenceDataset[eval] files=2536 max_hours=None normalize=True
SleepSequenceDataset[eval] files=1856 max_hours=None normalize=True
Batch shapes: torch.Size([2, 480, 1, 3750]) torch.Size([2, 480]) torch.Size([2, 480]) torch.Size([2, 480, 5]) torch.Size([2, 480])
Valid tokens: 960 | y min/max(valid): 0 4


In [7]:
# ----------------------------
# Class weights (same as V4)
# ----------------------------
def class_counts_train(df):
    c = Counter()
    for p in tqdm(df["npz_path"].tolist(), desc="Counting train labels", leave=False):
        d = np.load(p, allow_pickle=True)
        y = d["y"].astype(np.int64)
        y = y[y >= 0]
        c.update(y.tolist())
    counts = np.array([c.get(i, 0) for i in range(NUM_CLASSES)], dtype=np.float64)
    return counts

counts = class_counts_train(df_train)
total = counts.sum()
weights = total / (NUM_CLASSES * np.maximum(counts, 1.0))
weights = weights / weights.mean()

print("Counts:", {LABELS[i]: int(counts[i]) for i in range(NUM_CLASSES)})
print("Weights:", {LABELS[i]: float(weights[i]) for i in range(NUM_CLASSES)})

class_weights = torch.tensor(weights, dtype=torch.float32).to(device)


                                                                                

Counts: {'W': 951298, 'N1': 159157, 'N2': 1801005, 'N3': 557286, 'REM': 620992}
Weights: {'W': 0.4653661355903559, 'N1': 2.781541962055294, 'N2': 0.24580824265053922, 'N3': 0.7943890104090797, 'REM': 0.7128946492947323}




In [8]:
# %% [code]
# ================================================================
# Model (V5.1): SAME V5 encoder + heads
# + Local-Global Attention (windowed most blocks, global every k blocks)
# Goal: better MESA generalization with small compute increase (vs full V6)
# ================================================================

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# -------------------------
# DropPath
# -------------------------
class DropPath(nn.Module):
    def __init__(self, drop_prob=0.1):
        super().__init__()
        self.drop_prob = float(drop_prob)

    def forward(self, x):
        if (not self.training) or self.drop_prob == 0.0:
            return x
        keep = 1.0 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        rand = keep + torch.rand(shape, device=x.device)
        mask = torch.floor(rand)
        return x / keep * mask

# -------------------------
# Residual Conv Block
# -------------------------
class ResConv1D(nn.Module):
    def __init__(self, c_in, c_out, k, s=1):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(c_in, c_out, k, stride=s, padding=k//2),
            nn.BatchNorm1d(c_out),
            nn.GELU(),
            nn.Conv1d(c_out, c_out, k, padding=k//2),
            nn.BatchNorm1d(c_out),
        )
        self.skip = nn.Conv1d(c_in, c_out, 1, stride=s) if (c_in != c_out or s != 1) else nn.Identity()
        self.act = nn.GELU()

    def forward(self, x):
        return self.act(self.conv(x) + self.skip(x))

# -------------------------
# Epoch Encoder (FFT SAFE)
# -------------------------
class EpochEncoder(nn.Module):
    """
    Input : (B,L,1,T)
    Output: (B,L,384)
    """
    def __init__(self, d_model=384):
        super().__init__()
        self.branch_short = ResConv1D(1, 128, k=7,  s=4)
        self.branch_mid   = ResConv1D(1, 128, k=15, s=4)
        self.branch_long  = ResConv1D(1, 128, k=31, s=4)

        self.freq_proj = nn.Sequential(
            nn.Linear(1876, 256),
            nn.LayerNorm(256),
            nn.GELU(),
        )

        self.fuse = nn.Sequential(
            nn.Linear(128*3 + 256, d_model),
            nn.LayerNorm(d_model),
            nn.GELU(),
        )

    def forward(self, x):
        B, L, _, T_ = x.shape
        x = x.view(B*L, 1, T_)

        zs = self.branch_short(x).mean(-1)
        zm = self.branch_mid(x).mean(-1)
        zl = self.branch_long(x).mean(-1)

        # FFT in fp32 (AMP safe)
        with torch.cuda.amp.autocast(enabled=False):
            xf32 = x.squeeze(1).float()
            Xf = torch.fft.rfft(xf32, dim=-1)
            mag = torch.abs(Xf)
            mag = mag[:, :1876]
            mag = torch.log1p(mag)
            mag = mag / (mag.mean(dim=1, keepdim=True) + 1e-6)

        zf = self.freq_proj(mag)
        z = torch.cat([zs, zm, zl, zf.to(zs.dtype)], dim=-1)
        z = self.fuse(z)
        return z.view(B, L, -1)

# -------------------------
# RoPE helpers
# -------------------------
def rotate_half(x):
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

class RoPE(nn.Module):
    def __init__(self, head_dim, base=10000):
        super().__init__()
        assert head_dim % 2 == 0
        self.head_dim = head_dim
        self.base = base

    def forward(self, x):
        # x: (B, L, H, Dh)
        B, L, H, Dh = x.shape
        half = Dh // 2
        freqs = 1.0 / (self.base ** (torch.arange(half, device=x.device) / half))
        t = torch.arange(L, device=x.device)
        angles = torch.einsum("l,d->ld", t, freqs)
        cos = torch.cos(angles)[None, :, None, :]
        sin = torch.sin(angles)[None, :, None, :]
        cos = cos.repeat_interleave(2, dim=-1)
        sin = sin.repeat_interleave(2, dim=-1)
        return (x * cos) + (rotate_half(x) * sin)

# ================================================================
# V5.1 Local-Global Attention (Windowed most blocks + periodic global)
# ================================================================
def _windows(L, w):
    out = []
    s = 0
    while s < L:
        e = min(L, s + w)
        out.append((s, e))
        s = e
    return out

class MultiHeadSelfAttentionRoPE_LocalGlobal(nn.Module):
    def __init__(self, d_model=384, n_heads=8, dropout=0.1, window_size=64):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.window_size = int(window_size)

        self.qkv = nn.Linear(d_model, 3*d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.drop = nn.Dropout(dropout)
        self.rope = RoPE(self.d_head)

    def forward(self, x, key_padding_mask=None, global_attn=False):
        """
        x: (B,L,D)
        key_padding_mask: (B,L) bool, True=valid
        global_attn: True => full attention, else windowed attention
        """
        B, L, D = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(B, L, self.n_heads, self.d_head)
        k = k.view(B, L, self.n_heads, self.d_head)
        v = v.view(B, L, self.n_heads, self.d_head)

        q = self.rope(q)
        k = self.rope(k)

        q = q.transpose(1, 2)  # (B,H,L,Dh)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # ---- Global attention ----
        if global_attn or self.window_size >= L:
            scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)  # (B,H,L,L)
            scores = scores.float()
            if key_padding_mask is not None:
                scores = scores.masked_fill(~key_padding_mask[:, None, None, :], -1e9)
            attn = torch.softmax(scores, dim=-1)
            attn = self.drop(attn).to(v.dtype)
            out = attn @ v
            out = out.transpose(1, 2).contiguous().view(B, L, D)
            return self.proj(out)

        # ---- Windowed attention ----
        w = self.window_size
        out = torch.zeros((B, self.n_heads, L, self.d_head), device=x.device, dtype=v.dtype)
        for (s, e) in _windows(L, w):
            qs = q[:, :, s:e, :]  # (B,H,w,Dh)
            ks = k[:, :, s:e, :]
            vs = v[:, :, s:e, :]

            scores = (qs @ ks.transpose(-2, -1)) / math.sqrt(self.d_head)  # (B,H,w,w)
            scores = scores.float()

            if key_padding_mask is not None:
                m = key_padding_mask[:, s:e]  # (B,w)
                scores = scores.masked_fill(~m[:, None, None, :], -1e9)

            attn = torch.softmax(scores, dim=-1)
            attn = self.drop(attn).to(vs.dtype)
            out[:, :, s:e, :] = attn @ vs

        out = out.transpose(1, 2).contiguous().view(B, L, D)
        return self.proj(out)

class TransformerBlockLG(nn.Module):
    def __init__(self, d_model=384, n_heads=8, drop=0.1, drop_path=0.1, window_size=64):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttentionRoPE_LocalGlobal(d_model, n_heads, drop, window_size=window_size)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4*d_model),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(4*d_model, d_model),
        )
        self.dp = DropPath(drop_path)

    def forward(self, x, mask, global_attn=False):
        x = x + self.dp(self.attn(self.ln1(x), key_padding_mask=mask, global_attn=global_attn))
        x = x + self.dp(self.mlp(self.ln2(x)))
        return x

# -------------------------
# FINAL MODEL (V5.1)
# - same heads as V5
# - Local-Global attention in transformer
# -------------------------
class HierSleepTransformerV5_1(nn.Module):
    def __init__(
        self,
        num_classes=5,
        d_model=384,
        depth=12,          # slightly deeper than V5=10 (still cheap)
        n_heads=8,
        dur_bins=8,
        window_size=64,    # local window length
        global_every=3     # every k blocks => global attention
    ):
        super().__init__()
        self.num_classes = num_classes
        self.dur_bins = dur_bins
        self.depth = int(depth)
        self.global_every = int(global_every)

        self.encoder = EpochEncoder(d_model)
        self.blocks = nn.ModuleList([
            TransformerBlockLG(
                d_model=d_model,
                n_heads=n_heads,
                drop=0.1,
                drop_path=0.1*(i+1)/depth,
                window_size=window_size
            )
            for i in range(depth)
        ])
        self.head = nn.Linear(d_model, num_classes)

        # same as V5
        self.aux_n1 = nn.Linear(d_model, 2)
        self.aux_dur = nn.Linear(d_model, dur_bins)
        self.trans_logits = nn.Parameter(torch.zeros(num_classes, num_classes))

    def forward(self, x, mask):
        z = self.encoder(x)
        for i, blk in enumerate(self.blocks):
            use_global = (self.global_every > 0) and ((i % self.global_every) == 0)
            z = blk(z, mask, global_attn=use_global)

        main_logits = self.head(z)      # (B,L,C)
        aux_logits  = self.aux_n1(z)    # (B,L,2)
        dur_logits  = self.aux_dur(z)   # (B,L,dur_bins)
        return main_logits, aux_logits, dur_logits

# derive dur_bins from dataset bucket edges
DUR_EDGES = (2,5,10,20,40,80,160)
DUR_BINS = len(DUR_EDGES) + 1

# Recommended defaults for MESA generalization:
# - depth=12 (small boost)
# - window_size=64 (try 64 or 96)
# - global_every=3 (global attention 1/3 blocks)
model = HierSleepTransformerV5_1(
    num_classes=NUM_CLASSES,
    d_model=384,
    depth=12,
    n_heads=8,
    dur_bins=DUR_BINS,
    window_size=64,
    global_every=3
).to(device)

print("V5.1 model params (M):", sum(p.numel() for p in model.parameters()) / 1e6)
print("V5.1 Local-Global settings:", dict(depth=12, window_size=64, global_every=3))


V5.1 model params (M): 22.905512
V5.1 Local-Global settings: {'depth': 12, 'window_size': 64, 'global_every': 3}


In [9]:
# ================================================================
# V5 switches: keep V4 defaults, add V5 knobs
# ================================================================
USE_EMA = True
EMA_DECAY = 0.999

USE_AUX_N1 = True
AUX_N1_WEIGHT = 0.30

USE_LA_CE = True
LA_TAU = 1.0

USE_HARD_NEG_N1 = True
HARD_NEG_MULT = 2.0

USE_TRANS_LOSS = True
TRANS_LOSS_WEIGHT = 0.10

# V5: soft-label loss near transitions (main logits vs soft targets)
USE_SOFT_BOUNDARY_LOSS = True
SOFT_BOUNDARY_WEIGHT = 0.25   # << adjust 0.15~0.35 if needed

# V5: confusion-cost regularizer (expected cost under predicted distribution)
USE_COST_MATRIX = True
COST_WEIGHT = 0.20            # << adjust 0.10~0.30 if needed

# V5: duration-aware head (run-length bucket prediction)
USE_AUX_DUR = True
AUX_DUR_WEIGHT = 0.15          # << adjust 0.10~0.25 if needed
# extra boost on N1 epochs for duration head (N1 is transition-heavy)
AUX_DUR_N1_MULT = 1.50         # << 1.2~2.0

print("V5 switches ready.")


V5 switches ready.


In [10]:
# -------------------------
# Logit-Adjusted Cross Entropy (LA-CE)
# -------------------------
class LogitAdjustedCE(nn.Module):
    def __init__(self, class_freq, tau=1.0):
        super().__init__()
        freq = torch.tensor(class_freq, dtype=torch.float32)
        self.register_buffer("log_prior", torch.log(freq / freq.sum()))
        self.tau = float(tau)

    def forward(self, logits, targets, reduction="none"):
        logits = logits + self.tau * self.log_prior
        return F.cross_entropy(logits, targets, reduction=reduction)

# -------------------------
# Class-dependent label smoothing (vector) -- keep as V4
# -------------------------
def label_smoothing_nll(logits, targets, smooth_per_class):
    """
    logits: (N,C)
    targets: (N,)
    smooth_per_class: tensor (C,) in [0,1)
    returns per-sample loss (N,)
    """
    logp = F.log_softmax(logits, dim=-1)  # (N,C)
    nll = -logp.gather(dim=-1, index=targets.view(-1,1)).squeeze(1)  # (N,)
    smooth = -logp.mean(dim=-1)  # (N,)
    s = smooth_per_class[targets]  # (N,)
    return (1 - s) * nll + s * smooth

# V4 idea: smaller smoothing for N1
smooth_vec = torch.tensor([0.02, 0.00, 0.05, 0.05, 0.02], dtype=torch.float32).to(device)
la_ce = LogitAdjustedCE(class_freq=counts, tau=LA_TAU).to(device)

# -------------------------
# V5: boundary soft-target cross-entropy
#   loss_soft = - sum_c soft_target[c] * log_softmax(logits)[c]
# -------------------------
def soft_target_ce(logits, soft_targets):
    """
    logits: (N,C)
    soft_targets: (N,C) sum to 1
    returns per-sample loss (N,)
    """
    logp = F.log_softmax(logits, dim=-1)
    return -(soft_targets * logp).sum(dim=-1)

# -------------------------
# V5: confusion-cost matrix regularizer
#   expected_cost = sum_j p(j) * cost[y, j]
# -------------------------
def build_cost_matrix(device):
    C = NUM_CLASSES
    cost = torch.zeros((C, C), dtype=torch.float32, device=device)

    # Base: small penalty for any wrong class
    cost += 0.05
    cost.fill_diagonal_(0.0)

    # N1 critical confusions:
    # true N1 predicted W or N2 => big penalty
    cost[1, 0] = 1.00
    cost[1, 2] = 1.00
    # also penalize predicting N1 when true W or N2 (symmetry-ish)
    cost[0, 1] = 0.60
    cost[2, 1] = 0.60

    # Optional: slightly penalize N2<->N3 confusion
    cost[2, 3] = 0.20
    cost[3, 2] = 0.20

    # Optional: W<->REM confusion mild
    cost[0, 4] = 0.15
    cost[4, 0] = 0.15

    return cost

COST_MAT = build_cost_matrix(device)
print("Cost matrix ready.")


Cost matrix ready.


In [11]:
# -------------------------
# EMA (same as V4)
# -------------------------
class EMA:
    def __init__(self, model, decay=0.999):
        self.decay = float(decay)
        self.shadow = {}
        for name, p in model.named_parameters():
            if p.requires_grad:
                self.shadow[name] = p.data.detach().clone()

    @torch.no_grad()
    def update(self, model):
        for name, p in model.named_parameters():
            if name in self.shadow:
                self.shadow[name].mul_(self.decay)
                self.shadow[name].add_((1.0 - self.decay) * p.data)

    def apply(self, model):
        self.backup = {}
        for name, p in model.named_parameters():
            if name in self.shadow:
                self.backup[name] = p.data.detach().clone()
                p.data.copy_(self.shadow[name])

    def restore(self, model):
        for name, p in model.named_parameters():
            if name in self.backup:
                p.data.copy_(self.backup[name])

ema = EMA(model, EMA_DECAY) if USE_EMA else None


In [12]:
# ================================================================
# V5 Losses: keep V4 + add:
#   - soft boundary loss
#   - cost matrix reg
#   - duration aux head loss (N1-weighted)
# ================================================================

def compute_transition_loss(model, y, mask):
    """
    Supervised transition matrix training:
    minimize CE( trans_logits[y[t-1]], y[t] ) over valid adjacent pairs.
    """
    B, L = y.shape
    y_prev = y[:, :-1]
    y_next = y[:, 1:]
    m_pair = mask[:, :-1] & mask[:, 1:]

    if m_pair.sum().item() == 0:
        return torch.zeros((), device=y.device)

    y_prev_v = y_prev[m_pair]
    y_next_v = y_next[m_pair]

    trans_logits = model.trans_logits  # (C,C)
    logits_pair = trans_logits[y_prev_v]  # (N,C)
    loss = F.cross_entropy(logits_pair, y_next_v)
    return loss

def masked_loss_v5(model, main_logits, aux_logits, dur_logits, y, mask, soft_targets, dur_bucket):
    """
    main_logits: (B,L,C)
    aux_logits : (B,L,2)
    dur_logits : (B,L,DUR_BINS)
    y         : (B,L)
    mask      : (B,L) bool
    soft_targets: (B,L,C)
    dur_bucket  : (B,L)
    """
    B, L, C = main_logits.shape

    # flatten valid
    logits2 = main_logits.view(B*L, C)
    y2 = y.view(B*L)
    m2 = mask.view(B*L)

    logits_valid = logits2[m2]
    y_valid = y2[m2]

    # ---- Main loss: LA-CE + class-dependent label smoothing (same as V4) ----
    if USE_LA_CE:
        adj_logits = logits_valid + la_ce.tau * la_ce.log_prior
        loss_main_vec = label_smoothing_nll(adj_logits, y_valid, smooth_vec)
    else:
        loss_main_vec = label_smoothing_nll(logits_valid, y_valid, smooth_vec)

    # ---- V4: Hard-negative mining for N1 (when predicted W or N2) ----
    if USE_HARD_NEG_N1:
        with torch.no_grad():
            pred = torch.argmax(logits_valid, dim=-1)
            hard = (y_valid == 1) & ((pred == 0) | (pred == 2))
        loss_main_vec = loss_main_vec * torch.where(
            hard,
            torch.tensor(HARD_NEG_MULT, device=logits_valid.device),
            torch.tensor(1.0, device=logits_valid.device)
        )

    loss = loss_main_vec.mean()

    # ---- V5: Soft-label boundary loss ----
    if USE_SOFT_BOUNDARY_LOSS:
        soft2 = soft_targets.view(B*L, C)[m2]  # (N,C)
        # use same LA adjustment for soft targets for consistency
        if USE_LA_CE:
            adj_logits_soft = logits_valid + la_ce.tau * la_ce.log_prior
            loss_soft_vec = soft_target_ce(adj_logits_soft, soft2)
        else:
            loss_soft_vec = soft_target_ce(logits_valid, soft2)
        loss = loss + SOFT_BOUNDARY_WEIGHT * loss_soft_vec.mean()

    # ---- V5: Confusion-cost regularizer (expected cost) ----
    if USE_COST_MATRIX:
        probs = torch.softmax(logits_valid.float(), dim=-1)  # (N,C)
        # cost_row for each target: (N,C)
        cost_row = COST_MAT[y_valid]  # (N,C)
        expected_cost = (probs * cost_row).sum(dim=-1)  # (N,)
        loss = loss + COST_WEIGHT * expected_cost.mean()

    # ---- Aux head: N1 vs Others (same) ----
    if USE_AUX_N1:
        aux2 = aux_logits.view(B*L, 2)[m2]
        y_aux = (y_valid == 1).long()
        loss_aux = F.cross_entropy(aux2, y_aux)
        loss = loss + AUX_N1_WEIGHT * loss_aux

    # ---- V5: Duration-aware aux head ----
    if USE_AUX_DUR:
        dur2 = dur_logits.view(B*L, DUR_BINS)[m2]  # (N,DUR_BINS)
        dur_t = dur_bucket.view(B*L)[m2]           # (N,)
        # boost duration loss on N1 epochs (transition-focused)
        w = torch.ones_like(dur_t, dtype=torch.float32, device=dur2.device)
        w = w * torch.where(y_valid == 1, torch.tensor(AUX_DUR_N1_MULT, device=dur2.device), torch.tensor(1.0, device=dur2.device))
        loss_dur = F.cross_entropy(dur2, dur_t, reduction="none")
        loss_dur = (loss_dur * w).mean()
        loss = loss + AUX_DUR_WEIGHT * loss_dur

    # ---- Transition loss (same) ----
    if USE_TRANS_LOSS:
        loss_trans = compute_transition_loss(model, y, mask)
        loss = loss + TRANS_LOSS_WEIGHT * loss_trans

    return loss


In [13]:
# -------------------------
# Optimizer + OneCycle (same values as V4)
# -------------------------
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-2)

EPOCHS = 60
steps_per_epoch = len(train_seq_loader)

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=5e-4,
    epochs=EPOCHS,
    steps_per_epoch=steps_per_epoch,
    pct_start=0.15,
    div_factor=20.0,
    final_div_factor=100.0
)
scheduler_is_step_per_batch = True
print("Using OneCycleLR (max_lr=5e-4)")

scaler = torch.cuda.amp.GradScaler(enabled=(device.type=="cuda"))


Using OneCycleLR (max_lr=5e-4)


In [14]:
# -------------------------
# Metrics helpers (same as V4)
# -------------------------
def _ece_from_probs(y_true, probs, n_bins=15):
    conf = probs.max(axis=1)
    pred = probs.argmax(axis=1)
    acc = (pred == y_true).astype(np.float32)

    bins = np.linspace(0.0, 1.0, n_bins+1)
    ece = 0.0
    for i in range(n_bins):
        lo, hi = bins[i], bins[i+1]
        m = (conf >= lo) & (conf < hi)
        if m.sum() == 0:
            continue
        bin_acc = acc[m].mean()
        bin_conf = conf[m].mean()
        ece += (m.mean()) * abs(bin_acc - bin_conf)
    return float(ece)

def _auroc_auprc_multiclass(y_true, probs, num_classes=5):
    Y = label_binarize(y_true, classes=list(range(num_classes)))
    aurocs, auprcs = [], []
    for c in range(num_classes):
        if Y[:, c].sum() == 0:
            continue
        try:
            aurocs.append(roc_auc_score(Y[:, c], probs[:, c]))
            auprcs.append(average_precision_score(Y[:, c], probs[:, c]))
        except Exception:
            pass
    if len(aurocs) == 0:
        return float("nan"), float("nan")
    return float(np.mean(aurocs)), float(np.mean(auprcs))

def estimate_transition_from_train(df, num_classes=5, eps=1.0):
    Tm = np.zeros((num_classes, num_classes), dtype=np.float64)
    for p in tqdm(df["npz_path"].tolist(), desc="Estimating transitions(train)", leave=False):
        d = np.load(p, allow_pickle=True)
        y = d["y"].astype(np.int64)
        y = y[y >= 0]
        if len(y) < 2:
            continue
        a = y[:-1]; b = y[1:]
        for i, j in zip(a, b):
            if 0 <= i < num_classes and 0 <= j < num_classes:
                Tm[i, j] += 1.0
    Tm = Tm + eps
    Tm = Tm / Tm.sum(axis=1, keepdims=True)
    return Tm.astype(np.float32)

def viterbi_decode(log_emis, log_trans):
    L, C = log_emis.shape
    dp = np.zeros((L, C), dtype=np.float32)
    back = np.zeros((L, C), dtype=np.int32)
    dp[0] = log_emis[0]
    for t in range(1, L):
        scores = dp[t-1][:, None] + log_trans
        back[t] = np.argmax(scores, axis=0)
        dp[t] = log_emis[t] + scores[back[t], np.arange(C)]
    path = np.zeros((L,), dtype=np.int32)
    path[-1] = int(np.argmax(dp[-1]))
    for t in range(L-2, -1, -1):
        path[t] = back[t+1, path[t+1]]
    return path

# V4 learned smoothing at inference (keep)
USE_LEARNED_SMOOTHING = True
def apply_learned_smoothing_probs(probs, model):
    Tm = torch.softmax(model.trans_logits, dim=1)  # (C,C)
    return probs @ Tm

# Viterbi option (keep off)
USE_VITERBI = False
Tmat = estimate_transition_from_train(df_train, num_classes=NUM_CLASSES, eps=1.0) if USE_VITERBI else None
logT = np.log(Tmat + 1e-12).astype(np.float32) if USE_VITERBI else None

def tensor_stats(name, t):
    t = t.detach()
    finite = torch.isfinite(t).all().item()
    absmax = t.abs().max().item() if t.numel() else 0.0
    return f"{name}: finite={finite} absmax={absmax:.6g} dtype={t.dtype} shape={tuple(t.shape)}"

@torch.no_grad()
def check_model_params_finite(model):
    for n, p in model.named_parameters():
        if p is None:
            continue
        if not torch.isfinite(p).all():
            return False, n
    return True, None


In [15]:
# -------------------------
# Eval (updated to accept new batch outputs, still reports same metrics)
# -------------------------
@torch.no_grad()
def eval_sequence(model, loader, desc="Eval"):
    model.eval()

    all_true, all_pred = [], []
    all_probs = []

    all_true_v, all_pred_v = [], []

    total_loss = 0.0
    total_n = 0
    bad_batches = 0
    first_bad_printed = False

    for bidx, (xb, yb, mb, sb, db) in enumerate(tqdm(loader, desc=desc, leave=False)):
        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)
        mb = mb.to(device, non_blocking=True)
        sb = sb.to(device, non_blocking=True)
        db = db.to(device, non_blocking=True)

        with torch.cuda.amp.autocast(enabled=(device.type=="cuda")):
            main_logits, aux_logits, dur_logits = model(xb, mb)
            loss = masked_loss_v5(model, main_logits, aux_logits, dur_logits, yb, mb, sb, db)

        if (not torch.isfinite(main_logits).all()) or (not torch.isfinite(loss)):
            bad_batches += 1
            if not first_bad_printed:
                first_bad_printed = True
                print(f"\n[WARN] {desc}: non-finite logits/loss at batch={bidx}")
                print(tensor_stats("logits", main_logits))
                print("loss:", float(loss.item()) if torch.isfinite(loss) else loss)
            continue

        probs = torch.softmax(main_logits.float(), dim=-1)

        if USE_LEARNED_SMOOTHING:
            probs = apply_learned_smoothing_probs(probs, model)

        if not torch.isfinite(probs).all():
            bad_batches += 1
            if not first_bad_printed:
                first_bad_printed = True
                print(f"\n[WARN] {desc}: non-finite probs at batch={bidx}")
                print(tensor_stats("logits", main_logits))
                print(tensor_stats("probs", probs))
            continue

        pred = torch.argmax(probs, dim=-1)

        yv = yb[mb].detach().cpu().numpy()
        pv = pred[mb].detach().cpu().numpy()
        pr = probs[mb].detach().cpu().numpy()

        if yv.size == 0:
            continue

        all_true.append(yv)
        all_pred.append(pv)
        all_probs.append(pr)

        if USE_VITERBI:
            probs_np = probs.detach().cpu().numpy()
            y_np = yb.detach().cpu().numpy()
            m_np = mb.detach().cpu().numpy()
            for i in range(probs_np.shape[0]):
                Li = int(m_np[i].sum())
                if Li <= 0:
                    continue
                emis = probs_np[i, :Li]
                yseq = y_np[i, :Li].astype(np.int64)
                if not np.isfinite(emis).all():
                    continue
                path = viterbi_decode(np.log(emis + 1e-12).astype(np.float32), logT)
                all_true_v.append(yseq)
                all_pred_v.append(path.astype(np.int64))

        n = int(mb.sum().item())
        total_loss += float(loss.item()) * n
        total_n += n

    if len(all_true) == 0:
        print(f"\n[STOP] {desc}: 0 valid batches collected (bad_batches={bad_batches}).")
        return {
            "loss": float("nan"),
            "acc": float("nan"),
            "macro_f1": float("nan"),
            "kappa": float("nan"),
            "AUROC": float("nan"),
            "AUPRC": float("nan"),
            "meanConf": float("nan"),
            "ECE": float("nan"),
            "f1_per_class": {LABELS[i]: float("nan") for i in range(NUM_CLASSES)},
            "cm": np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=np.int64),
            "bad_batches": int(bad_batches),
        }

    y_true = np.concatenate(all_true)
    y_pred = np.concatenate(all_pred)
    probs_all = np.concatenate(all_probs)

    acc = accuracy_score(y_true, y_pred)
    mf1 = f1_score(y_true, y_pred, average="macro")
    kappa = cohen_kappa_score(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred, labels=list(range(NUM_CLASSES)))

    auroc, auprc = _auroc_auprc_multiclass(y_true, probs_all, num_classes=NUM_CLASSES)
    mean_conf = float(probs_all.max(axis=1).mean())
    ece = _ece_from_probs(y_true, probs_all, n_bins=15)

    f1_per = {}
    for i in range(NUM_CLASSES):
        f1_per[LABELS[i]] = float(f1_score((y_true==i).astype(int), (y_pred==i).astype(int)))

    out = {
        "loss": total_loss / max(total_n, 1),
        "acc": float(acc),
        "macro_f1": float(mf1),
        "kappa": float(kappa),
        "AUROC": auroc,
        "AUPRC": auprc,
        "meanConf": mean_conf,
        "ECE": ece,
        "f1_per_class": f1_per,
        "cm": cm,
        "bad_batches": int(bad_batches),
    }

    if USE_VITERBI and len(all_true_v) > 0:
        yt = np.concatenate(all_true_v)
        yp = np.concatenate(all_pred_v)
        out["viterbi_acc"] = float(accuracy_score(yt, yp))
        out["viterbi_macro_f1"] = float(f1_score(yt, yp, average="macro"))
        out["viterbi_kappa"] = float(cohen_kappa_score(yt, yp))
        out["viterbi_cm"] = confusion_matrix(yt, yp, labels=list(range(NUM_CLASSES)))

    return out


In [16]:
# -------------------------
# Train one epoch (updated batch tuple + V5 loss)
# -------------------------
def train_one_epoch(model, loader, epoch, grad_clip=1.0, droppath_target=0.10, droppath_warm_epochs=10):
    model.train()
    running_loss = 0.0
    n_seen = 0

    # DropPath warmup (keep)
    warm = min(1.0, epoch / max(1, droppath_warm_epochs))
    for blk in model.blocks:
        blk.dp.drop_prob = droppath_target * warm

    pbar = tqdm(loader, desc=f"Train epoch {epoch}", leave=False)
    for step, (xb, yb, mb, sb, db) in enumerate(pbar):
        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)
        mb = mb.to(device, non_blocking=True)
        sb = sb.to(device, non_blocking=True)
        db = db.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(enabled=(device.type=="cuda")):
            main_logits, aux_logits, dur_logits = model(xb, mb)
            loss = masked_loss_v5(model, main_logits, aux_logits, dur_logits, yb, mb, sb, db)

        if not torch.isfinite(loss):
            print("\n[STOP] Non-finite loss detected:", float(loss.item()))
            print(tensor_stats("logits", main_logits))
            ok, bad = check_model_params_finite(model)
            print("params finite:", ok, "| first bad:", bad)
            return float("nan")

        scaler.scale(loss).backward()

        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        scaler.step(optimizer)
        scaler.update()

        if scheduler_is_step_per_batch:
            scheduler.step()

        if USE_EMA:
            ema.update(model)

        n = int(mb.sum().item())
        running_loss += float(loss.item()) * n
        n_seen += n
        pbar.set_postfix(loss=running_loss/max(n_seen,1), lr=optimizer.param_groups[0]["lr"])

    return running_loss / max(n_seen, 1)


In [17]:
# %% [code]
# -------------------------
# Training loop + checkpoints
# - Save BEST by VAL macro-F1 (main)
# - Save BEST by MESA macro-F1 (secondary; if MESA loader exists)
# - Keep Top-K by VAL macro-F1 for ensemble
# -------------------------
CKPT_DIR = ROOT / "checkpoints_hier_rope_seq_v5_1"   # <= new folder to avoid overwriting old runs
CKPT_DIR.mkdir(parents=True, exist_ok=True)

best_val_path  = CKPT_DIR / "BEST_VAL_macroF1.pt"
best_mesa_path = CKPT_DIR / "BEST_MESA_macroF1.pt"

best_val  = -1.0
best_mesa = -1.0

topk_paths = []
ENSEMBLE_K = 5

def _make_payload(epoch, tr_loss, val_m, test_m, ext_m, mesa_m):
    payload = {
        "epoch": epoch,
        "train_loss": float(tr_loss),

        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),

        "val_metrics": val_m,
        "test_metrics": test_m,
        "ext_metrics": ext_m,
        "mesa_metrics": mesa_m,

        "class_weights": class_weights.detach().cpu().numpy(),

        "use_ema": USE_EMA,
        "ema_decay": EMA_DECAY,
        "ema_shadow": {k: v.detach().cpu() for k, v in ema.shadow.items()} if USE_EMA else None,

        "use_learned_smoothing": USE_LEARNED_SMOOTHING,
        "use_viterbi": USE_VITERBI,
        "Tmat": Tmat if USE_VITERBI else None,

        # V5 config snapshot (works for V5.1 too)
        "v5_soft_boundary": dict(enabled=USE_SOFT_BOUNDARY_LOSS, weight=SOFT_BOUNDARY_WEIGHT),
        "v5_cost_matrix": dict(enabled=USE_COST_MATRIX, weight=COST_WEIGHT),
        "v5_aux_dur": dict(enabled=USE_AUX_DUR, weight=AUX_DUR_WEIGHT, n1_mult=AUX_DUR_N1_MULT, dur_edges=DUR_EDGES),
    }
    return payload


for epoch in range(1, EPOCHS + 1):
    tr_loss = train_one_epoch(model, train_seq_loader, epoch)

    # Evaluate with EMA weights
    if USE_EMA:
        ema.apply(model)

    val_m  = eval_sequence(model, val_seq_loader,  desc="VAL")
    test_m = eval_sequence(model, test_seq_loader, desc="SHHS1 TEST")
    ext_m  = eval_sequence(model, ext_seq_loader,  desc="SHHS2 EXT")
    mesa_m = eval_sequence(model, mesa_seq_loader, desc="MESA EXT") if mesa_seq_loader is not None else None

    if USE_EMA:
        ema.restore(model)

    print(f"\nEpoch {epoch:02d} | train_loss={tr_loss:.4f}")

    print(f"  VAL   : loss={val_m['loss']:.4f} acc={val_m['acc']:.4f} macroF1={val_m['macro_f1']:.4f} "
          f"kappa={val_m['kappa']:.4f} AUROC={val_m['AUROC']:.4f} AUPRC={val_m['AUPRC']:.4f} "
          f"meanConf={val_m['meanConf']:.4f} ECE={val_m['ECE']:.4f}")
    print(f"  F1/class: {val_m['f1_per_class']}")

    print(f"  TEST1 : loss={test_m['loss']:.4f} acc={test_m['acc']:.4f} macroF1={test_m['macro_f1']:.4f} "
          f"kappa={test_m['kappa']:.4f} AUROC={test_m['AUROC']:.4f} AUPRC={test_m['AUPRC']:.4f} "
          f"meanConf={test_m['meanConf']:.4f} ECE={test_m['ECE']:.4f}")

    print(f"  SHHS2 : loss={ext_m['loss']:.4f} acc={ext_m['acc']:.4f} macroF1={ext_m['macro_f1']:.4f} "
          f"kappa={ext_m['kappa']:.4f} AUROC={ext_m['AUROC']:.4f} AUPRC={ext_m['AUPRC']:.4f} "
          f"meanConf={ext_m['meanConf']:.4f} ECE={ext_m['ECE']:.4f}")

    if mesa_m is not None:
        print(f"  MESA  : loss={mesa_m['loss']:.4f} acc={mesa_m['acc']:.4f} macroF1={mesa_m['macro_f1']:.4f} "
              f"kappa={mesa_m['kappa']:.4f} AUROC={mesa_m['AUROC']:.4f} AUPRC={mesa_m['AUPRC']:.4f} "
              f"meanConf={mesa_m['meanConf']:.4f} ECE={mesa_m['ECE']:.4f}")
        print(f"  MESA F1/class: {mesa_m['f1_per_class']}")

    if USE_VITERBI:
        print(f"  VAL(viterbi): acc={val_m['viterbi_acc']:.4f} macroF1={val_m['viterbi_macro_f1']:.4f} kappa={val_m['viterbi_kappa']:.4f}")

    payload = _make_payload(epoch, tr_loss, val_m, test_m, ext_m, mesa_m)

    # ============================================================
    # (A) Save BEST by VAL macro-F1 (main)
    # ============================================================
    if (val_m is not None) and (val_m.get("macro_f1", float("nan")) == val_m.get("macro_f1", float("nan"))):
        if val_m["macro_f1"] > best_val:
            best_val = float(val_m["macro_f1"])
            payload["best_val_macroF1"] = best_val
            torch.save(payload, best_val_path)
            print("  ? Saved BEST_VAL:", best_val_path.name, f"(val_macroF1={best_val:.4f})")

            ck = CKPT_DIR / f"VALBEST_ep{epoch:03d}_valF1_{best_val:.4f}.pt"
            torch.save(payload, ck)
            topk_paths.append(ck)
            topk_paths = topk_paths[-ENSEMBLE_K:]
            print("  ? Added to VAL Top-K:", ck.name)

    # ============================================================
    # (B) Save BEST by MESA macro-F1 (secondary)
    # ============================================================
    if mesa_m is not None:
        mf1 = mesa_m.get("macro_f1", float("nan"))
        if mf1 == mf1:  # not NaN
            if mf1 > best_mesa:
                best_mesa = float(mf1)
                payload["best_mesa_macroF1"] = best_mesa
                torch.save(payload, best_mesa_path)
                print("  ? Saved BEST_MESA:", best_mesa_path.name, f"(mesa_macroF1={best_mesa:.4f})")

                ck2 = CKPT_DIR / f"MESABEST_ep{epoch:03d}_mesaF1_{best_mesa:.4f}.pt"
                torch.save(payload, ck2)
                print("  ? Snapshot MESA best:", ck2.name)

print("\n==============================")
print("Training finished")
print("==============================")
print("BEST VAL macroF1 :", f"{best_val:.4f}", "|", best_val_path)
if mesa_seq_loader is not None:
    print("BEST MESA macroF1:", f"{best_mesa:.4f}", "|", best_mesa_path)

print("\nTop-K VAL checkpoints for ensemble:")
for p in topk_paths:
    print(" -", p)


                                                                                


Epoch 01 | train_loss=1.3762
  VAL   : loss=1.2399 acc=0.7872 macroF1=0.6513 kappa=0.6938 AUROC=0.9473 AUPRC=0.7838 meanConf=0.2081 ECE=0.5791
  F1/class: {'W': 0.8225161576933548, 'N1': 0.05058552087414566, 'N2': 0.8033134343332999, 'N3': 0.7542364701826956, 'REM': 0.8258525709836858}
  TEST1 : loss=1.2185 acc=0.7935 macroF1=0.6554 kappa=0.7028 AUROC=0.9509 AUPRC=0.7917 meanConf=0.2081 ECE=0.5855
  SHHS2 : loss=1.2955 acc=0.7675 macroF1=0.6282 kappa=0.6688 AUROC=0.9456 AUPRC=0.7761 meanConf=0.2077 ECE=0.5598
  MESA  : loss=1.7274 acc=0.6964 macroF1=0.5758 kappa=0.5577 AUROC=0.8980 AUPRC=0.6962 meanConf=0.2069 ECE=0.4895
  MESA F1/class: {'W': 0.7600170124661046, 'N1': 0.0027893560708116077, 'N2': 0.7053987029187828, 'N3': 0.6437907217558705, 'REM': 0.7672267376707401}
  ? Saved BEST_VAL: BEST_VAL_macroF1.pt (val_macroF1=0.6513)
  ? Added to VAL Top-K: VALBEST_ep001_valF1_0.6513.pt
  ? Saved BEST_MESA: BEST_MESA_macroF1.pt (mesa_macroF1=0.5758)
  ? Snapshot MESA best: MESABEST_ep001_m

                                                                                


Epoch 02 | train_loss=1.1840
  VAL   : loss=1.0878 acc=0.8294 macroF1=0.7371 kappa=0.7574 AUROC=0.9583 AUPRC=0.8177 meanConf=0.2318 ECE=0.5976
  F1/class: {'W': 0.8929749536464091, 'N1': 0.32192874692874696, 'N2': 0.8415611380865352, 'N3': 0.7665224441816495, 'REM': 0.8625898171277022}
  TEST1 : loss=1.0563 acc=0.8376 macroF1=0.7423 kappa=0.7687 AUROC=0.9620 AUPRC=0.8281 meanConf=0.2318 ECE=0.6058
  SHHS2 : loss=1.0430 acc=0.8366 macroF1=0.7236 kappa=0.7678 AUROC=0.9621 AUPRC=0.8265 meanConf=0.2308 ECE=0.6058
  MESA  : loss=1.5583 acc=0.7224 macroF1=0.6057 kappa=0.5956 AUROC=0.9181 AUPRC=0.7332 meanConf=0.2270 ECE=0.4954
  MESA F1/class: {'W': 0.7874290416156475, 'N1': 0.1026176631099569, 'N2': 0.7370577437978761, 'N3': 0.610963880057253, 'REM': 0.7902263250543472}
  ? Saved BEST_VAL: BEST_VAL_macroF1.pt (val_macroF1=0.7371)
  ? Added to VAL Top-K: VALBEST_ep002_valF1_0.7371.pt
  ? Saved BEST_MESA: BEST_MESA_macroF1.pt (mesa_macroF1=0.6057)
  ? Snapshot MESA best: MESABEST_ep002_mesaF

                                                                                


Epoch 03 | train_loss=1.1380
  VAL   : loss=1.0352 acc=0.8370 macroF1=0.7680 kappa=0.7721 AUROC=0.9611 AUPRC=0.8254 meanConf=0.2797 ECE=0.5573
  F1/class: {'W': 0.8984920431490275, 'N1': 0.44173441734417346, 'N2': 0.8473815833169547, 'N3': 0.7813822284908323, 'REM': 0.8711978877327458}
  TEST1 : loss=1.0014 acc=0.8450 macroF1=0.7741 kappa=0.7831 AUROC=0.9648 AUPRC=0.8363 meanConf=0.2798 ECE=0.5652
  SHHS2 : loss=0.9763 acc=0.8460 macroF1=0.7592 kappa=0.7843 AUROC=0.9653 AUPRC=0.8356 meanConf=0.2782 ECE=0.5678
  MESA  : loss=1.5273 acc=0.7462 macroF1=0.6558 kappa=0.6351 AUROC=0.9203 AUPRC=0.7437 meanConf=0.2696 ECE=0.4766
  MESA F1/class: {'W': 0.8132645434160393, 'N1': 0.2525295360238428, 'N2': 0.7573898988588511, 'N3': 0.6503484813853286, 'REM': 0.805698821154993}
  ? Saved BEST_VAL: BEST_VAL_macroF1.pt (val_macroF1=0.7680)
  ? Added to VAL Top-K: VALBEST_ep003_valF1_0.7680.pt
  ? Saved BEST_MESA: BEST_MESA_macroF1.pt (mesa_macroF1=0.6558)
  ? Snapshot MESA best: MESABEST_ep003_mesaF

                                                                                


Epoch 04 | train_loss=1.1066
  VAL   : loss=0.9896 acc=0.8369 macroF1=0.7786 kappa=0.7746 AUROC=0.9630 AUPRC=0.8283 meanConf=0.3612 ECE=0.4756
  F1/class: {'W': 0.9020108660360613, 'N1': 0.4918646636651362, 'N2': 0.8445191068778235, 'N3': 0.7820647696305216, 'REM': 0.8724920607567515}
  TEST1 : loss=0.9558 acc=0.8449 macroF1=0.7852 kappa=0.7856 AUROC=0.9665 AUPRC=0.8388 meanConf=0.3614 ECE=0.4836
  SHHS2 : loss=0.9286 acc=0.8479 macroF1=0.7742 kappa=0.7891 AUROC=0.9673 AUPRC=0.8408 meanConf=0.3577 ECE=0.4902
  MESA  : loss=1.5120 acc=0.7448 macroF1=0.6647 kappa=0.6369 AUROC=0.9198 AUPRC=0.7425 meanConf=0.3396 ECE=0.4053
  MESA F1/class: {'W': 0.8105911841460651, 'N1': 0.30588381231051004, 'N2': 0.7568977574893391, 'N3': 0.6500011293929352, 'REM': 0.8000120698132196}
  ? Saved BEST_VAL: BEST_VAL_macroF1.pt (val_macroF1=0.7786)
  ? Added to VAL Top-K: VALBEST_ep004_valF1_0.7786.pt
  ? Saved BEST_MESA: BEST_MESA_macroF1.pt (mesa_macroF1=0.6647)
  ? Snapshot MESA best: MESABEST_ep004_mesa

                                                                                


Epoch 05 | train_loss=1.0780
  VAL   : loss=0.9641 acc=0.8345 macroF1=0.7800 kappa=0.7731 AUROC=0.9644 AUPRC=0.8279 meanConf=0.4726 ECE=0.3619
  F1/class: {'W': 0.9037463144389026, 'N1': 0.5000596350335944, 'N2': 0.8402708386863467, 'N3': 0.7835339352448146, 'REM': 0.8723571353480635}
  TEST1 : loss=0.9276 acc=0.8437 macroF1=0.7883 kappa=0.7854 AUROC=0.9677 AUPRC=0.8389 meanConf=0.4724 ECE=0.3713
  SHHS2 : loss=0.8907 acc=0.8503 macroF1=0.7822 kappa=0.7936 AUROC=0.9694 AUPRC=0.8435 meanConf=0.4649 ECE=0.3853
  MESA  : loss=1.4128 acc=0.7669 macroF1=0.6881 kappa=0.6691 AUROC=0.9258 AUPRC=0.7479 meanConf=0.4349 ECE=0.3320
  MESA F1/class: {'W': 0.8460318673463424, 'N1': 0.3657237301305098, 'N2': 0.7753024636655753, 'N3': 0.6455588243749582, 'REM': 0.8078234968693556}
  ? Saved BEST_VAL: BEST_VAL_macroF1.pt (val_macroF1=0.7800)
  ? Added to VAL Top-K: VALBEST_ep005_valF1_0.7800.pt
  ? Saved BEST_MESA: BEST_MESA_macroF1.pt (mesa_macroF1=0.6881)
  ? Snapshot MESA best: MESABEST_ep005_mesaF

                                                                                


Epoch 06 | train_loss=1.0460
  VAL   : loss=0.9357 acc=0.8375 macroF1=0.7820 kappa=0.7766 AUROC=0.9651 AUPRC=0.8313 meanConf=0.5777 ECE=0.2598
  F1/class: {'W': 0.9034265213096001, 'N1': 0.4985085684206422, 'N2': 0.846469164720574, 'N3': 0.7884219039419792, 'REM': 0.8729646044563172}
  TEST1 : loss=0.9036 acc=0.8455 macroF1=0.7890 kappa=0.7874 AUROC=0.9678 AUPRC=0.8409 meanConf=0.5776 ECE=0.2679
  SHHS2 : loss=0.8625 acc=0.8537 macroF1=0.7846 kappa=0.7983 AUROC=0.9697 AUPRC=0.8452 meanConf=0.5666 ECE=0.2871
  MESA  : loss=1.4167 acc=0.7601 macroF1=0.6862 kappa=0.6615 AUROC=0.9244 AUPRC=0.7512 meanConf=0.5305 ECE=0.2296
  MESA F1/class: {'W': 0.8325163511777528, 'N1': 0.3766883237306729, 'N2': 0.7745656877173027, 'N3': 0.6530745283771919, 'REM': 0.7941502112615331}
  ? Saved BEST_VAL: BEST_VAL_macroF1.pt (val_macroF1=0.7820)
  ? Added to VAL Top-K: VALBEST_ep006_valF1_0.7820.pt


                                                                                


Epoch 07 | train_loss=1.0325
  VAL   : loss=0.9204 acc=0.8346 macroF1=0.7814 kappa=0.7738 AUROC=0.9661 AUPRC=0.8339 meanConf=0.6424 ECE=0.1921
  F1/class: {'W': 0.9073813002503439, 'N1': 0.5047862156987875, 'N2': 0.8376214308153102, 'N3': 0.7824870524306614, 'REM': 0.8748836183662468}
  TEST1 : loss=0.8867 acc=0.8431 macroF1=0.7892 kappa=0.7853 AUROC=0.9691 AUPRC=0.8442 meanConf=0.6420 ECE=0.2011
  SHHS2 : loss=0.8611 acc=0.8492 macroF1=0.7829 kappa=0.7926 AUROC=0.9701 AUPRC=0.8470 meanConf=0.6287 ECE=0.2204
  MESA  : loss=1.4491 acc=0.7500 macroF1=0.6831 kappa=0.6452 AUROC=0.9202 AUPRC=0.7455 meanConf=0.5743 ECE=0.1757
  MESA F1/class: {'W': 0.817154839480285, 'N1': 0.3823357842037449, 'N2': 0.7545487619038902, 'N3': 0.6475781936094035, 'REM': 0.8139909851161258}


                                                                                


Epoch 08 | train_loss=1.0274
  VAL   : loss=0.9186 acc=0.8393 macroF1=0.7845 kappa=0.7785 AUROC=0.9653 AUPRC=0.8343 meanConf=0.6701 ECE=0.1692
  F1/class: {'W': 0.9028583783783785, 'N1': 0.5075208354252435, 'N2': 0.8475247617355194, 'N3': 0.7909779976839667, 'REM': 0.8734851568147443}
  TEST1 : loss=0.8838 acc=0.8496 macroF1=0.7933 kappa=0.7925 AUROC=0.9684 AUPRC=0.8444 meanConf=0.6700 ECE=0.1795
  SHHS2 : loss=0.8770 acc=0.8500 macroF1=0.7836 kappa=0.7927 AUROC=0.9684 AUPRC=0.8449 meanConf=0.6547 ECE=0.1953
  MESA  : loss=1.4269 acc=0.7627 macroF1=0.6879 kappa=0.6616 AUROC=0.9234 AUPRC=0.7529 meanConf=0.6108 ECE=0.1519
  MESA F1/class: {'W': 0.8318389690173232, 'N1': 0.37510108243455215, 'N2': 0.7725722977736569, 'N3': 0.6469995106689007, 'REM': 0.8129401002701763}
  ? Saved BEST_VAL: BEST_VAL_macroF1.pt (val_macroF1=0.7845)
  ? Added to VAL Top-K: VALBEST_ep008_valF1_0.7845.pt


                                                                                


Epoch 09 | train_loss=1.0120
  VAL   : loss=0.9047 acc=0.8388 macroF1=0.7850 kappa=0.7791 AUROC=0.9664 AUPRC=0.8374 meanConf=0.6874 ECE=0.1515
  F1/class: {'W': 0.9079161317928969, 'N1': 0.5097196729517771, 'N2': 0.842656425180697, 'N3': 0.7879976588087322, 'REM': 0.876646121925917}
  TEST1 : loss=0.8696 acc=0.8475 macroF1=0.7930 kappa=0.7909 AUROC=0.9694 AUPRC=0.8475 meanConf=0.6871 ECE=0.1605
  SHHS2 : loss=0.8370 acc=0.8545 macroF1=0.7870 kappa=0.7995 AUROC=0.9702 AUPRC=0.8488 meanConf=0.6785 ECE=0.1759
  MESA  : loss=1.3487 acc=0.7877 macroF1=0.7011 kappa=0.6974 AUROC=0.9280 AUPRC=0.7594 meanConf=0.6289 ECE=0.1588
  MESA F1/class: {'W': 0.8748322708648105, 'N1': 0.3726567854119354, 'N2': 0.7930000597405839, 'N3': 0.6460641801596745, 'REM': 0.8188675796140673}
  ? Saved BEST_VAL: BEST_VAL_macroF1.pt (val_macroF1=0.7850)
  ? Added to VAL Top-K: VALBEST_ep009_valF1_0.7850.pt
  ? Saved BEST_MESA: BEST_MESA_macroF1.pt (mesa_macroF1=0.7011)
  ? Snapshot MESA best: MESABEST_ep009_mesaF1_

                                                                                


Epoch 10 | train_loss=1.0009
  VAL   : loss=0.9000 acc=0.8421 macroF1=0.7871 kappa=0.7828 AUROC=0.9667 AUPRC=0.8385 meanConf=0.6893 ECE=0.1527
  F1/class: {'W': 0.9069515218289491, 'N1': 0.5078441347034287, 'N2': 0.8491428759885136, 'N3': 0.795769158920646, 'REM': 0.8757183689140516}
  TEST1 : loss=0.8642 acc=0.8513 macroF1=0.7954 kappa=0.7954 AUROC=0.9697 AUPRC=0.8491 meanConf=0.6896 ECE=0.1617
  SHHS2 : loss=0.8365 acc=0.8581 macroF1=0.7901 kappa=0.8041 AUROC=0.9706 AUPRC=0.8503 meanConf=0.6778 ECE=0.1803
  MESA  : loss=1.3696 acc=0.7821 macroF1=0.6901 kappa=0.6897 AUROC=0.9265 AUPRC=0.7579 meanConf=0.6220 ECE=0.1601
  MESA F1/class: {'W': 0.8656622886306022, 'N1': 0.3784378581102198, 'N2': 0.7989929437436135, 'N3': 0.600760994206913, 'REM': 0.8065736140792611}
  ? Saved BEST_VAL: BEST_VAL_macroF1.pt (val_macroF1=0.7871)
  ? Added to VAL Top-K: VALBEST_ep010_valF1_0.7871.pt


                                                                                


Epoch 11 | train_loss=0.9887
  VAL   : loss=0.8902 acc=0.8461 macroF1=0.7913 kappa=0.7879 AUROC=0.9677 AUPRC=0.8409 meanConf=0.6915 ECE=0.1546
  F1/class: {'W': 0.9097645902117959, 'N1': 0.5142961430831001, 'N2': 0.8537726789369884, 'N3': 0.7999269639841148, 'REM': 0.8787435815541852}
  TEST1 : loss=0.8545 acc=0.8554 macroF1=0.7998 kappa=0.8007 AUROC=0.9710 AUPRC=0.8522 meanConf=0.6916 ECE=0.1638
  SHHS2 : loss=0.8246 acc=0.8628 macroF1=0.7952 kappa=0.8101 AUROC=0.9718 AUPRC=0.8531 meanConf=0.6817 ECE=0.1811
  MESA  : loss=1.4022 acc=0.7698 macroF1=0.6713 kappa=0.6714 AUROC=0.9252 AUPRC=0.7506 meanConf=0.6238 ECE=0.1463
  MESA F1/class: {'W': 0.8481995254822331, 'N1': 0.38374548385938695, 'N2': 0.7936095048748197, 'N3': 0.5308490249155322, 'REM': 0.799968129311853}
  ? Saved BEST_VAL: BEST_VAL_macroF1.pt (val_macroF1=0.7913)
  ? Added to VAL Top-K: VALBEST_ep011_valF1_0.7913.pt


                                                                                


Epoch 12 | train_loss=0.9710
  VAL   : loss=0.8915 acc=0.8456 macroF1=0.7911 kappa=0.7875 AUROC=0.9680 AUPRC=0.8404 meanConf=0.6934 ECE=0.1522
  F1/class: {'W': 0.9095300666351988, 'N1': 0.514005115089514, 'N2': 0.8517251527387021, 'N3': 0.8009426815924585, 'REM': 0.8793401278286406}
  TEST1 : loss=0.8516 acc=0.8558 macroF1=0.8007 kappa=0.8015 AUROC=0.9713 AUPRC=0.8523 meanConf=0.6933 ECE=0.1626
  SHHS2 : loss=0.8234 acc=0.8610 macroF1=0.7941 kappa=0.8079 AUROC=0.9717 AUPRC=0.8533 meanConf=0.6832 ECE=0.1778
  MESA  : loss=1.3752 acc=0.7748 macroF1=0.6803 kappa=0.6785 AUROC=0.9264 AUPRC=0.7514 meanConf=0.6239 ECE=0.1511
  MESA F1/class: {'W': 0.8568034782743202, 'N1': 0.3852569240654748, 'N2': 0.7910409705942075, 'N3': 0.5577179299091304, 'REM': 0.810677574624892}


                                                                                


Epoch 13 | train_loss=0.9684
  VAL   : loss=0.8850 acc=0.8449 macroF1=0.7908 kappa=0.7869 AUROC=0.9679 AUPRC=0.8418 meanConf=0.6967 ECE=0.1482
  F1/class: {'W': 0.9082074738931563, 'N1': 0.5136625119846597, 'N2': 0.8526589750722562, 'N3': 0.8028680294526543, 'REM': 0.876402712724434}
  TEST1 : loss=0.8445 acc=0.8545 macroF1=0.7995 kappa=0.7999 AUROC=0.9713 AUPRC=0.8536 meanConf=0.6967 ECE=0.1577
  SHHS2 : loss=0.8195 acc=0.8585 macroF1=0.7916 kappa=0.8048 AUROC=0.9715 AUPRC=0.8539 meanConf=0.6872 ECE=0.1713
  MESA  : loss=1.4077 acc=0.7650 macroF1=0.6747 kappa=0.6655 AUROC=0.9226 AUPRC=0.7460 meanConf=0.6374 ECE=0.1279
  MESA F1/class: {'W': 0.8497563893677142, 'N1': 0.39343182429971657, 'N2': 0.7789371372374208, 'N3': 0.5603705243156581, 'REM': 0.7911653693915682}


                                                                                


Epoch 14 | train_loss=0.9548
  VAL   : loss=0.8771 acc=0.8468 macroF1=0.7922 kappa=0.7893 AUROC=0.9683 AUPRC=0.8433 meanConf=0.6931 ECE=0.1537
  F1/class: {'W': 0.9068939710387078, 'N1': 0.5125799235900785, 'N2': 0.8563243874800285, 'N3': 0.8069317208626372, 'REM': 0.8782787233784063}
  TEST1 : loss=0.8427 acc=0.8546 macroF1=0.7991 kappa=0.8000 AUROC=0.9712 AUPRC=0.8537 meanConf=0.6931 ECE=0.1616
  SHHS2 : loss=0.8160 acc=0.8595 macroF1=0.7920 kappa=0.8061 AUROC=0.9717 AUPRC=0.8539 meanConf=0.6838 ECE=0.1757
  MESA  : loss=1.4321 acc=0.7553 macroF1=0.6714 kappa=0.6523 AUROC=0.9206 AUPRC=0.7402 meanConf=0.6246 ECE=0.1307
  MESA F1/class: {'W': 0.8319930991389327, 'N1': 0.3885050534587185, 'N2': 0.77042741581311, 'N3': 0.565725846623087, 'REM': 0.8001431690961247}
  ? Saved BEST_VAL: BEST_VAL_macroF1.pt (val_macroF1=0.7922)
  ? Added to VAL Top-K: VALBEST_ep014_valF1_0.7922.pt


                                                                                


Epoch 15 | train_loss=0.9511
  VAL   : loss=0.8823 acc=0.8448 macroF1=0.7903 kappa=0.7871 AUROC=0.9685 AUPRC=0.8448 meanConf=0.7010 ECE=0.1439
  F1/class: {'W': 0.9075792846993379, 'N1': 0.511520261194801, 'N2': 0.851831381488589, 'N3': 0.8003401635019459, 'REM': 0.8803279091490095}
  TEST1 : loss=0.8479 acc=0.8512 macroF1=0.7961 kappa=0.7960 AUROC=0.9712 AUPRC=0.8548 meanConf=0.7001 ECE=0.1511
  SHHS2 : loss=0.8132 acc=0.8578 macroF1=0.7892 kappa=0.8042 AUROC=0.9718 AUPRC=0.8538 meanConf=0.6947 ECE=0.1632
  MESA  : loss=1.3294 acc=0.7929 macroF1=0.6934 kappa=0.7043 AUROC=0.9269 AUPRC=0.7491 meanConf=0.6514 ECE=0.1416
  MESA F1/class: {'W': 0.8935188394934496, 'N1': 0.38956907876948726, 'N2': 0.800517561656673, 'N3': 0.5755478073977041, 'REM': 0.8078792467147848}


                                                                                


Epoch 16 | train_loss=0.9436
  VAL   : loss=0.8772 acc=0.8511 macroF1=0.7954 kappa=0.7943 AUROC=0.9687 AUPRC=0.8441 meanConf=0.6948 ECE=0.1563
  F1/class: {'W': 0.9099404790206919, 'N1': 0.515363967789752, 'N2': 0.8618366546531354, 'N3': 0.8096582377319959, 'REM': 0.880199371758525}
  TEST1 : loss=0.8415 acc=0.8597 macroF1=0.8032 kappa=0.8059 AUROC=0.9718 AUPRC=0.8558 meanConf=0.6946 ECE=0.1650
  SHHS2 : loss=0.8094 acc=0.8668 macroF1=0.7987 kappa=0.8154 AUROC=0.9726 AUPRC=0.8566 meanConf=0.6862 ECE=0.1806
  MESA  : loss=1.4031 acc=0.7742 macroF1=0.6756 kappa=0.6769 AUROC=0.9241 AUPRC=0.7497 meanConf=0.6396 ECE=0.1358
  MESA F1/class: {'W': 0.8579381094346009, 'N1': 0.3980831595837401, 'N2': 0.7919238107359045, 'N3': 0.5197893455791903, 'REM': 0.8101628773838911}
  ? Saved BEST_VAL: BEST_VAL_macroF1.pt (val_macroF1=0.7954)
  ? Added to VAL Top-K: VALBEST_ep016_valF1_0.7954.pt


                                                                                


Epoch 17 | train_loss=0.9326
  VAL   : loss=0.8700 acc=0.8492 macroF1=0.7945 kappa=0.7925 AUROC=0.9691 AUPRC=0.8454 meanConf=0.6969 ECE=0.1522
  F1/class: {'W': 0.9094795627685544, 'N1': 0.5147981707679211, 'N2': 0.8586646876607258, 'N3': 0.8101974215645094, 'REM': 0.8795101202698737}
  TEST1 : loss=0.8332 acc=0.8582 macroF1=0.8031 kappa=0.8048 AUROC=0.9722 AUPRC=0.8570 meanConf=0.6969 ECE=0.1613
  SHHS2 : loss=0.8121 acc=0.8622 macroF1=0.7950 kappa=0.8098 AUROC=0.9723 AUPRC=0.8564 meanConf=0.6884 ECE=0.1738
  MESA  : loss=1.4625 acc=0.7584 macroF1=0.6555 kappa=0.6561 AUROC=0.9187 AUPRC=0.7312 meanConf=0.6391 ECE=0.1194
  MESA F1/class: {'W': 0.8388488625653173, 'N1': 0.39340952873329765, 'N2': 0.7856442477574462, 'N3': 0.4610390655167976, 'REM': 0.7983643314290623}


                                                                                


Epoch 18 | train_loss=0.9305
  VAL   : loss=0.8665 acc=0.8508 macroF1=0.7960 kappa=0.7945 AUROC=0.9691 AUPRC=0.8456 meanConf=0.6914 ECE=0.1594
  F1/class: {'W': 0.9104029332605219, 'N1': 0.5158613094184514, 'N2': 0.8605584405246326, 'N3': 0.8146361658923184, 'REM': 0.8786612914013997}
  TEST1 : loss=0.8318 acc=0.8598 macroF1=0.8047 kappa=0.8068 AUROC=0.9722 AUPRC=0.8573 meanConf=0.6913 ECE=0.1685
  SHHS2 : loss=0.8080 acc=0.8651 macroF1=0.7974 kappa=0.8136 AUROC=0.9726 AUPRC=0.8568 meanConf=0.6815 ECE=0.1836
  MESA  : loss=1.4099 acc=0.7656 macroF1=0.6651 kappa=0.6651 AUROC=0.9228 AUPRC=0.7400 meanConf=0.6287 ECE=0.1378
  MESA F1/class: {'W': 0.853673146732744, 'N1': 0.3964539117204117, 'N2': 0.781464934521073, 'N3': 0.49709213877756814, 'REM': 0.7967603796900727}
  ? Saved BEST_VAL: BEST_VAL_macroF1.pt (val_macroF1=0.7960)
  ? Added to VAL Top-K: VALBEST_ep018_valF1_0.7960.pt


                                                                                


Epoch 19 | train_loss=0.9251
  VAL   : loss=0.8666 acc=0.8528 macroF1=0.7973 kappa=0.7968 AUROC=0.9693 AUPRC=0.8470 meanConf=0.6982 ECE=0.1546
  F1/class: {'W': 0.9107019891641872, 'N1': 0.5168215575666572, 'N2': 0.8629085397229719, 'N3': 0.8159862404894604, 'REM': 0.8799265268636155}
  TEST1 : loss=0.8285 acc=0.8625 macroF1=0.8066 kappa=0.8101 AUROC=0.9726 AUPRC=0.8594 meanConf=0.6979 ECE=0.1647
  SHHS2 : loss=0.7974 acc=0.8682 macroF1=0.8005 kappa=0.8175 AUROC=0.9729 AUPRC=0.8586 meanConf=0.6897 ECE=0.1785
  MESA  : loss=1.3447 acc=0.7842 macroF1=0.6767 kappa=0.6903 AUROC=0.9272 AUPRC=0.7482 meanConf=0.6375 ECE=0.1467
  MESA F1/class: {'W': 0.8833401487675528, 'N1': 0.4026656453466105, 'N2': 0.7930738729742183, 'N3': 0.49032049057040383, 'REM': 0.8139613547736931}
  ? Saved BEST_VAL: BEST_VAL_macroF1.pt (val_macroF1=0.7973)
  ? Added to VAL Top-K: VALBEST_ep019_valF1_0.7973.pt


                                                                                


Epoch 20 | train_loss=0.9178
  VAL   : loss=0.8624 acc=0.8497 macroF1=0.7955 kappa=0.7935 AUROC=0.9694 AUPRC=0.8472 meanConf=0.6961 ECE=0.1536
  F1/class: {'W': 0.910900980336798, 'N1': 0.5176584421867442, 'N2': 0.8575299782834482, 'N3': 0.8135580252108877, 'REM': 0.8780830717393675}
  TEST1 : loss=0.8283 acc=0.8569 macroF1=0.8022 kappa=0.8034 AUROC=0.9722 AUPRC=0.8588 meanConf=0.6960 ECE=0.1609
  SHHS2 : loss=0.8036 acc=0.8624 macroF1=0.7956 kappa=0.8102 AUROC=0.9724 AUPRC=0.8576 meanConf=0.6866 ECE=0.1757
  MESA  : loss=1.3337 acc=0.7816 macroF1=0.6836 kappa=0.6886 AUROC=0.9278 AUPRC=0.7489 meanConf=0.6225 ECE=0.1591
  MESA F1/class: {'W': 0.8799828051918054, 'N1': 0.39132500784202895, 'N2': 0.7905464350860497, 'N3': 0.5545614488572406, 'REM': 0.8013662754491836}


                                                                                


Epoch 21 | train_loss=0.9127
  VAL   : loss=0.8611 acc=0.8535 macroF1=0.7981 kappa=0.7980 AUROC=0.9698 AUPRC=0.8476 meanConf=0.6954 ECE=0.1581
  F1/class: {'W': 0.9116560440747542, 'N1': 0.5175024422012373, 'N2': 0.8631839517993508, 'N3': 0.817486810928275, 'REM': 0.8806242634802199}
  TEST1 : loss=0.8270 acc=0.8614 macroF1=0.8056 kappa=0.8087 AUROC=0.9727 AUPRC=0.8601 meanConf=0.6956 ECE=0.1659
  SHHS2 : loss=0.8049 acc=0.8666 macroF1=0.7985 kappa=0.8153 AUROC=0.9727 AUPRC=0.8583 meanConf=0.6871 ECE=0.1795
  MESA  : loss=1.3670 acc=0.7763 macroF1=0.6633 kappa=0.6788 AUROC=0.9248 AUPRC=0.7428 meanConf=0.6340 ECE=0.1423
  MESA F1/class: {'W': 0.8760865181131026, 'N1': 0.4073827685680448, 'N2': 0.7897653854742318, 'N3': 0.450267798306504, 'REM': 0.792865100483342}
  ? Saved BEST_VAL: BEST_VAL_macroF1.pt (val_macroF1=0.7981)
  ? Added to VAL Top-K: VALBEST_ep021_valF1_0.7981.pt


                                                                                


Epoch 22 | train_loss=0.9052
  VAL   : loss=0.8611 acc=0.8461 macroF1=0.7937 kappa=0.7895 AUROC=0.9698 AUPRC=0.8481 meanConf=0.6966 ECE=0.1496
  F1/class: {'W': 0.913212066900468, 'N1': 0.5214028591998354, 'N2': 0.8499642162154398, 'N3': 0.8030098776930807, 'REM': 0.8807862023444281}
  TEST1 : loss=0.8261 acc=0.8542 macroF1=0.8015 kappa=0.8003 AUROC=0.9729 AUPRC=0.8607 meanConf=0.6966 ECE=0.1576
  SHHS2 : loss=0.7994 acc=0.8597 macroF1=0.7931 kappa=0.8071 AUROC=0.9731 AUPRC=0.8590 meanConf=0.6909 ECE=0.1687
  MESA  : loss=1.3383 acc=0.7861 macroF1=0.6831 kappa=0.6938 AUROC=0.9280 AUPRC=0.7462 meanConf=0.6327 ECE=0.1534
  MESA F1/class: {'W': 0.8847800344959317, 'N1': 0.3908507768299902, 'N2': 0.7937514948993454, 'N3': 0.5324008904311389, 'REM': 0.8137028875446652}


                                                                                


Epoch 23 | train_loss=0.8997
  VAL   : loss=0.8602 acc=0.8544 macroF1=0.7995 kappa=0.7992 AUROC=0.9699 AUPRC=0.8474 meanConf=0.6988 ECE=0.1557
  F1/class: {'W': 0.9136424310232364, 'N1': 0.5192790867905117, 'N2': 0.8641220733447089, 'N3': 0.819923777260216, 'REM': 0.8807334951337303}
  TEST1 : loss=0.8235 acc=0.8634 macroF1=0.8083 kappa=0.8114 AUROC=0.9731 AUPRC=0.8598 meanConf=0.6990 ECE=0.1644
  SHHS2 : loss=0.7894 acc=0.8702 macroF1=0.8033 kappa=0.8202 AUROC=0.9737 AUPRC=0.8611 meanConf=0.6931 ECE=0.1770
  MESA  : loss=1.3369 acc=0.7911 macroF1=0.6718 kappa=0.6991 AUROC=0.9290 AUPRC=0.7519 meanConf=0.6493 ECE=0.1418
  MESA F1/class: {'W': 0.8936444929991583, 'N1': 0.40481014915644054, 'N2': 0.8008336133987937, 'N3': 0.44194906356555047, 'REM': 0.817597660300688}
  ? Saved BEST_VAL: BEST_VAL_macroF1.pt (val_macroF1=0.7995)
  ? Added to VAL Top-K: VALBEST_ep023_valF1_0.7995.pt


                                                                                


Epoch 24 | train_loss=0.9016
  VAL   : loss=0.8577 acc=0.8529 macroF1=0.7986 kappa=0.7976 AUROC=0.9698 AUPRC=0.8481 meanConf=0.6991 ECE=0.1538
  F1/class: {'W': 0.9120627882151084, 'N1': 0.5194659620889129, 'N2': 0.8613799548248642, 'N3': 0.819158181464427, 'REM': 0.8807184956497335}
  TEST1 : loss=0.8203 acc=0.8619 macroF1=0.8071 kappa=0.8098 AUROC=0.9730 AUPRC=0.8602 meanConf=0.6993 ECE=0.1626
  SHHS2 : loss=0.7864 acc=0.8684 macroF1=0.8012 kappa=0.8180 AUROC=0.9734 AUPRC=0.8599 meanConf=0.6931 ECE=0.1752
  MESA  : loss=1.3721 acc=0.7792 macroF1=0.6615 kappa=0.6829 AUROC=0.9257 AUPRC=0.7402 meanConf=0.6401 ECE=0.1391
  MESA F1/class: {'W': 0.8763674009317307, 'N1': 0.39752121868185314, 'N2': 0.7944247125792737, 'N3': 0.4285353488111051, 'REM': 0.810701112786348}


                                                                                


Epoch 25 | train_loss=0.8899
  VAL   : loss=0.8599 acc=0.8515 macroF1=0.7976 kappa=0.7960 AUROC=0.9698 AUPRC=0.8461 meanConf=0.6973 ECE=0.1542
  F1/class: {'W': 0.9130998813679386, 'N1': 0.5167317291071879, 'N2': 0.8586733487778356, 'N3': 0.8206647568967546, 'REM': 0.8787206338144407}
  TEST1 : loss=0.8208 acc=0.8617 macroF1=0.8078 kappa=0.8099 AUROC=0.9732 AUPRC=0.8591 meanConf=0.6976 ECE=0.1641
  SHHS2 : loss=0.7843 acc=0.8681 macroF1=0.8023 kappa=0.8179 AUROC=0.9738 AUPRC=0.8608 meanConf=0.6911 ECE=0.1770
  MESA  : loss=1.3308 acc=0.7812 macroF1=0.6605 kappa=0.6871 AUROC=0.9267 AUPRC=0.7378 meanConf=0.6382 ECE=0.1430
  MESA F1/class: {'W': 0.8883402062900219, 'N1': 0.4115315488317526, 'N2': 0.7947059671223056, 'N3': 0.41309862255788804, 'REM': 0.7947837143690692}


                                                                                


Epoch 26 | train_loss=0.8836
  VAL   : loss=0.8583 acc=0.8547 macroF1=0.7992 kappa=0.7995 AUROC=0.9697 AUPRC=0.8480 meanConf=0.6995 ECE=0.1551
  F1/class: {'W': 0.912434036939314, 'N1': 0.5187906188188753, 'N2': 0.865178615101169, 'N3': 0.8210889205896338, 'REM': 0.8783786246742122}
  TEST1 : loss=0.8197 acc=0.8648 macroF1=0.8091 kappa=0.8131 AUROC=0.9731 AUPRC=0.8603 meanConf=0.6998 ECE=0.1650
  SHHS2 : loss=0.7970 acc=0.8688 macroF1=0.8014 kappa=0.8184 AUROC=0.9728 AUPRC=0.8596 meanConf=0.6927 ECE=0.1762
  MESA  : loss=1.3723 acc=0.7843 macroF1=0.6643 kappa=0.6900 AUROC=0.9251 AUPRC=0.7373 meanConf=0.6505 ECE=0.1339
  MESA F1/class: {'W': 0.8867353658885776, 'N1': 0.3971951792142745, 'N2': 0.7966779071892183, 'N3': 0.4265058654198926, 'REM': 0.8146156401335688}


                                                                                


Epoch 27 | train_loss=0.8793
  VAL   : loss=0.8578 acc=0.8531 macroF1=0.7983 kappa=0.7978 AUROC=0.9696 AUPRC=0.8481 meanConf=0.6995 ECE=0.1536
  F1/class: {'W': 0.911131573918029, 'N1': 0.5176517812343201, 'N2': 0.8625515715997926, 'N3': 0.8204332764593352, 'REM': 0.8796560208531267}
  TEST1 : loss=0.8183 acc=0.8624 macroF1=0.8074 kappa=0.8104 AUROC=0.9730 AUPRC=0.8601 meanConf=0.6995 ECE=0.1629
  SHHS2 : loss=0.7918 acc=0.8663 macroF1=0.7988 kappa=0.8155 AUROC=0.9726 AUPRC=0.8590 meanConf=0.6929 ECE=0.1734
  MESA  : loss=1.3769 acc=0.7812 macroF1=0.6484 kappa=0.6851 AUROC=0.9226 AUPRC=0.7285 meanConf=0.6505 ECE=0.1307
  MESA F1/class: {'W': 0.8895659762743169, 'N1': 0.398053187506521, 'N2': 0.7942514651436017, 'N3': 0.3561363718931487, 'REM': 0.8039592258827005}


                                                                                


Epoch 28 | train_loss=0.8785
  VAL   : loss=0.8535 acc=0.8533 macroF1=0.7990 kappa=0.7984 AUROC=0.9701 AUPRC=0.8486 meanConf=0.6991 ECE=0.1542
  F1/class: {'W': 0.9132907659121927, 'N1': 0.5193561897923772, 'N2': 0.8616577685064055, 'N3': 0.820744516527649, 'REM': 0.8797908234184695}
  TEST1 : loss=0.8153 acc=0.8628 macroF1=0.8086 kappa=0.8111 AUROC=0.9734 AUPRC=0.8616 meanConf=0.6991 ECE=0.1637
  SHHS2 : loss=0.7843 acc=0.8679 macroF1=0.8008 kappa=0.8178 AUROC=0.9734 AUPRC=0.8610 meanConf=0.6937 ECE=0.1742
  MESA  : loss=1.3423 acc=0.7885 macroF1=0.6596 kappa=0.6954 AUROC=0.9265 AUPRC=0.7372 meanConf=0.6494 ECE=0.1391
  MESA F1/class: {'W': 0.8970256294823458, 'N1': 0.4101891605439488, 'N2': 0.7978533902597873, 'N3': 0.3749288574769499, 'REM': 0.8180813622528917}


                                                                                


Epoch 29 | train_loss=0.8712
  VAL   : loss=0.8562 acc=0.8568 macroF1=0.8019 kappa=0.8025 AUROC=0.9702 AUPRC=0.8488 meanConf=0.6985 ECE=0.1583
  F1/class: {'W': 0.9147962818323049, 'N1': 0.5222609682912983, 'N2': 0.8653339068314851, 'N3': 0.8254334653621173, 'REM': 0.8819111537232341}
  TEST1 : loss=0.8165 acc=0.8655 macroF1=0.8108 kappa=0.8144 AUROC=0.9737 AUPRC=0.8622 meanConf=0.6987 ECE=0.1668
  SHHS2 : loss=0.7900 acc=0.8695 macroF1=0.8029 kappa=0.8194 AUROC=0.9739 AUPRC=0.8616 meanConf=0.6927 ECE=0.1767
  MESA  : loss=1.4269 acc=0.7713 macroF1=0.6292 kappa=0.6698 AUROC=0.9210 AUPRC=0.7162 meanConf=0.6540 ECE=0.1173
  MESA F1/class: {'W': 0.8768155814387883, 'N1': 0.4071434383372159, 'N2': 0.7838374053621394, 'N3': 0.2764576584914391, 'REM': 0.801973383872619}
  ? Saved BEST_VAL: BEST_VAL_macroF1.pt (val_macroF1=0.8019)
  ? Added to VAL Top-K: VALBEST_ep029_valF1_0.8019.pt


                                                                                


Epoch 30 | train_loss=0.8682
  VAL   : loss=0.8568 acc=0.8536 macroF1=0.7993 kappa=0.7985 AUROC=0.9702 AUPRC=0.8491 meanConf=0.7001 ECE=0.1535
  F1/class: {'W': 0.9123386305585208, 'N1': 0.5201574044330228, 'N2': 0.8621761463410054, 'N3': 0.8211059605629455, 'REM': 0.880825893406134}
  TEST1 : loss=0.8173 acc=0.8632 macroF1=0.8087 kappa=0.8115 AUROC=0.9736 AUPRC=0.8622 meanConf=0.7004 ECE=0.1628
  SHHS2 : loss=0.8010 acc=0.8645 macroF1=0.7981 kappa=0.8130 AUROC=0.9730 AUPRC=0.8602 meanConf=0.6942 ECE=0.1704
  MESA  : loss=1.3991 acc=0.7746 macroF1=0.6423 kappa=0.6755 AUROC=0.9220 AUPRC=0.7256 meanConf=0.6501 ECE=0.1245
  MESA F1/class: {'W': 0.8796484356341056, 'N1': 0.4026582591862362, 'N2': 0.787439690668735, 'N3': 0.33545113515623304, 'REM': 0.8061734972772122}


                                                                                


Epoch 31 | train_loss=0.8636
  VAL   : loss=0.8580 acc=0.8544 macroF1=0.7997 kappa=0.7995 AUROC=0.9700 AUPRC=0.8481 meanConf=0.6984 ECE=0.1561
  F1/class: {'W': 0.9136555325563218, 'N1': 0.5190575747828885, 'N2': 0.8627768127764611, 'N3': 0.8236321797119966, 'REM': 0.8792940279979843}
  TEST1 : loss=0.8156 acc=0.8654 macroF1=0.8107 kappa=0.8143 AUROC=0.9738 AUPRC=0.8621 meanConf=0.6988 ECE=0.1666
  SHHS2 : loss=0.8014 acc=0.8675 macroF1=0.8012 kappa=0.8167 AUROC=0.9731 AUPRC=0.8605 meanConf=0.6916 ECE=0.1759
  MESA  : loss=1.3953 acc=0.7733 macroF1=0.6382 kappa=0.6742 AUROC=0.9214 AUPRC=0.7154 meanConf=0.6461 ECE=0.1273
  MESA F1/class: {'W': 0.882502743958533, 'N1': 0.4056757427148053, 'N2': 0.7867357071516771, 'N3': 0.3155043796948509, 'REM': 0.800520412957319}


                                                                                


Epoch 32 | train_loss=0.8562
  VAL   : loss=0.8550 acc=0.8563 macroF1=0.8016 kappa=0.8019 AUROC=0.9701 AUPRC=0.8491 meanConf=0.7016 ECE=0.1547
  F1/class: {'W': 0.9131450743694762, 'N1': 0.5227945172392972, 'N2': 0.8657903281947537, 'N3': 0.8255928270871838, 'REM': 0.8808463661453543}
  TEST1 : loss=0.8154 acc=0.8656 macroF1=0.8107 kappa=0.8144 AUROC=0.9736 AUPRC=0.8628 meanConf=0.7015 ECE=0.1640
  SHHS2 : loss=0.7958 acc=0.8685 macroF1=0.8014 kappa=0.8180 AUROC=0.9729 AUPRC=0.8605 meanConf=0.6958 ECE=0.1726
  MESA  : loss=1.4175 acc=0.7754 macroF1=0.6429 kappa=0.6768 AUROC=0.9214 AUPRC=0.7214 meanConf=0.6581 ECE=0.1173
  MESA F1/class: {'W': 0.8786775860722529, 'N1': 0.40466203904288695, 'N2': 0.7912759492232982, 'N3': 0.33109626480412996, 'REM': 0.8087639641142462}


                                                                                


Epoch 33 | train_loss=0.8496
  VAL   : loss=0.8557 acc=0.8544 macroF1=0.8000 kappa=0.7997 AUROC=0.9698 AUPRC=0.8491 meanConf=0.7022 ECE=0.1522
  F1/class: {'W': 0.9123773425257187, 'N1': 0.5201784828953891, 'N2': 0.8639929372594783, 'N3': 0.82425782693542, 'REM': 0.8790583619421285}
  TEST1 : loss=0.8137 acc=0.8639 macroF1=0.8095 kappa=0.8124 AUROC=0.9736 AUPRC=0.8629 meanConf=0.7022 ECE=0.1617
  SHHS2 : loss=0.7865 acc=0.8690 macroF1=0.8015 kappa=0.8191 AUROC=0.9731 AUPRC=0.8600 meanConf=0.6978 ECE=0.1713
  MESA  : loss=1.4026 acc=0.7752 macroF1=0.6430 kappa=0.6767 AUROC=0.9197 AUPRC=0.7134 meanConf=0.6538 ECE=0.1214
  MESA F1/class: {'W': 0.8846283901444393, 'N1': 0.39455586338784415, 'N2': 0.7870548719153613, 'N3': 0.3331700052841428, 'REM': 0.8153486608284607}


                                                                                


Epoch 34 | train_loss=0.8420
  VAL   : loss=0.8575 acc=0.8539 macroF1=0.7999 kappa=0.7989 AUROC=0.9698 AUPRC=0.8486 meanConf=0.7024 ECE=0.1515
  F1/class: {'W': 0.9125381894021986, 'N1': 0.5219168568310348, 'N2': 0.8627161015457201, 'N3': 0.8223946971511428, 'REM': 0.8800364608353964}
  TEST1 : loss=0.8154 acc=0.8640 macroF1=0.8095 kappa=0.8125 AUROC=0.9736 AUPRC=0.8622 meanConf=0.7025 ECE=0.1615
  SHHS2 : loss=0.7938 acc=0.8669 macroF1=0.8002 kappa=0.8161 AUROC=0.9728 AUPRC=0.8596 meanConf=0.6978 ECE=0.1690
  MESA  : loss=1.4791 acc=0.7571 macroF1=0.6337 kappa=0.6521 AUROC=0.9150 AUPRC=0.7043 meanConf=0.6440 ECE=0.1131
  MESA F1/class: {'W': 0.8546955824309515, 'N1': 0.3873440350683858, 'N2': 0.7778039337014593, 'N3': 0.3454018234540182, 'REM': 0.8033275698725117}


                                                                                


Epoch 35 | train_loss=0.8402
  VAL   : loss=0.8563 acc=0.8520 macroF1=0.7982 kappa=0.7968 AUROC=0.9698 AUPRC=0.8491 meanConf=0.7016 ECE=0.1504
  F1/class: {'W': 0.9125238222536499, 'N1': 0.5195382235093426, 'N2': 0.8597564726871767, 'N3': 0.819664838551828, 'REM': 0.8797045525652274}
  TEST1 : loss=0.8150 acc=0.8616 macroF1=0.8078 kappa=0.8098 AUROC=0.9733 AUPRC=0.8626 meanConf=0.7015 ECE=0.1602
  SHHS2 : loss=0.7888 acc=0.8653 macroF1=0.7978 kappa=0.8144 AUROC=0.9729 AUPRC=0.8585 meanConf=0.6971 ECE=0.1682
  MESA  : loss=1.4271 acc=0.7646 macroF1=0.6339 kappa=0.6625 AUROC=0.9171 AUPRC=0.7062 meanConf=0.6450 ECE=0.1195
  MESA F1/class: {'W': 0.8721942870746592, 'N1': 0.39020873160472336, 'N2': 0.7805933173059852, 'N3': 0.32897586482742475, 'REM': 0.7974655692842949}


                                                                                


Epoch 36 | train_loss=0.8331
  VAL   : loss=0.8586 acc=0.8553 macroF1=0.8006 kappa=0.8008 AUROC=0.9698 AUPRC=0.8475 meanConf=0.7031 ECE=0.1522
  F1/class: {'W': 0.9138710740124331, 'N1': 0.5203177391026664, 'N2': 0.864138593717329, 'N3': 0.8243046150640693, 'REM': 0.8803278688524591}
  TEST1 : loss=0.8158 acc=0.8658 macroF1=0.8111 kappa=0.8149 AUROC=0.9734 AUPRC=0.8621 meanConf=0.7032 ECE=0.1626
  SHHS2 : loss=0.7810 acc=0.8708 macroF1=0.8035 kappa=0.8214 AUROC=0.9734 AUPRC=0.8606 meanConf=0.6991 ECE=0.1717
  MESA  : loss=1.4345 acc=0.7691 macroF1=0.6297 kappa=0.6683 AUROC=0.9153 AUPRC=0.6965 meanConf=0.6538 ECE=0.1152
  MESA F1/class: {'W': 0.8819118698339179, 'N1': 0.3942627736347604, 'N2': 0.7826534440303646, 'N3': 0.2889515555398609, 'REM': 0.8007771148316682}


                                                                                


Epoch 37 | train_loss=0.8296
  VAL   : loss=0.8585 acc=0.8546 macroF1=0.8002 kappa=0.8001 AUROC=0.9696 AUPRC=0.8469 meanConf=0.7052 ECE=0.1494
  F1/class: {'W': 0.9148158005443134, 'N1': 0.5211702707614029, 'N2': 0.8626023229306982, 'N3': 0.8221763643259853, 'REM': 0.8802503442848711}
  TEST1 : loss=0.8151 acc=0.8648 macroF1=0.8104 kappa=0.8137 AUROC=0.9733 AUPRC=0.8621 meanConf=0.7057 ECE=0.1591
  SHHS2 : loss=0.7859 acc=0.8692 macroF1=0.8019 kappa=0.8191 AUROC=0.9728 AUPRC=0.8603 meanConf=0.7025 ECE=0.1666
  MESA  : loss=1.4441 acc=0.7680 macroF1=0.6341 kappa=0.6669 AUROC=0.9146 AUPRC=0.6994 meanConf=0.6576 ECE=0.1104
  MESA F1/class: {'W': 0.8767190993155212, 'N1': 0.38940561279742564, 'N2': 0.7843447502916584, 'N3': 0.3153964541612227, 'REM': 0.8046285643454594}


                                                                                


Epoch 38 | train_loss=0.8243
  VAL   : loss=0.8594 acc=0.8538 macroF1=0.7996 kappa=0.7990 AUROC=0.9695 AUPRC=0.8471 meanConf=0.7056 ECE=0.1482
  F1/class: {'W': 0.9144481223408594, 'N1': 0.5204850340410103, 'N2': 0.8611620506776664, 'N3': 0.8220400459649101, 'REM': 0.8797438796578507}
  TEST1 : loss=0.8165 acc=0.8641 macroF1=0.8098 kappa=0.8129 AUROC=0.9731 AUPRC=0.8620 meanConf=0.7061 ECE=0.1580
  SHHS2 : loss=0.7890 acc=0.8676 macroF1=0.8005 kappa=0.8171 AUROC=0.9724 AUPRC=0.8590 meanConf=0.7025 ECE=0.1650
  MESA  : loss=1.4246 acc=0.7775 macroF1=0.6401 kappa=0.6800 AUROC=0.9176 AUPRC=0.7021 meanConf=0.6690 ECE=0.1085
  MESA F1/class: {'W': 0.8920700515305772, 'N1': 0.3974041609592196, 'N2': 0.7890439689396019, 'N3': 0.31111807911214995, 'REM': 0.8110217718155429}


                                                                                


Epoch 39 | train_loss=0.8199
  VAL   : loss=0.8600 acc=0.8551 macroF1=0.8006 kappa=0.8007 AUROC=0.9696 AUPRC=0.8471 meanConf=0.7053 ECE=0.1498
  F1/class: {'W': 0.9143165587823279, 'N1': 0.5193387771360287, 'N2': 0.8632714960852084, 'N3': 0.8257741713870921, 'REM': 0.880437800386222}
  TEST1 : loss=0.8167 acc=0.8655 macroF1=0.8107 kappa=0.8146 AUROC=0.9732 AUPRC=0.8616 meanConf=0.7057 ECE=0.1598
  SHHS2 : loss=0.7893 acc=0.8689 macroF1=0.8017 kappa=0.8187 AUROC=0.9723 AUPRC=0.8588 meanConf=0.7024 ECE=0.1665
  MESA  : loss=1.4236 acc=0.7685 macroF1=0.6329 kappa=0.6690 AUROC=0.9152 AUPRC=0.6949 meanConf=0.6596 ECE=0.1089
  MESA F1/class: {'W': 0.8854503464203234, 'N1': 0.39397409596129707, 'N2': 0.7840195306143909, 'N3': 0.30697378197643566, 'REM': 0.7942644613604278}


                                                                                


Epoch 40 | train_loss=0.8144
  VAL   : loss=0.8603 acc=0.8543 macroF1=0.7997 kappa=0.7995 AUROC=0.9692 AUPRC=0.8470 meanConf=0.7068 ECE=0.1475
  F1/class: {'W': 0.9134511173424624, 'N1': 0.517702097292154, 'N2': 0.8627991833032985, 'N3': 0.8252403285569921, 'REM': 0.8794703618167822}
  TEST1 : loss=0.8168 acc=0.8657 macroF1=0.8111 kappa=0.8149 AUROC=0.9728 AUPRC=0.8615 meanConf=0.7070 ECE=0.1587
  SHHS2 : loss=0.7930 acc=0.8678 macroF1=0.8002 kappa=0.8172 AUROC=0.9717 AUPRC=0.8569 meanConf=0.7040 ECE=0.1638
  MESA  : loss=1.3988 acc=0.7791 macroF1=0.6486 kappa=0.6836 AUROC=0.9179 AUPRC=0.7104 meanConf=0.6678 ECE=0.1114
  MESA F1/class: {'W': 0.8958440489661558, 'N1': 0.3891990598222052, 'N2': 0.7904344689934365, 'N3': 0.36991260467327314, 'REM': 0.7975294550430264}


                                                                                


Epoch 41 | train_loss=0.8055
  VAL   : loss=0.8637 acc=0.8570 macroF1=0.8015 kappa=0.8026 AUROC=0.9691 AUPRC=0.8465 meanConf=0.7099 ECE=0.1471
  F1/class: {'W': 0.9142886665403758, 'N1': 0.5184127431065344, 'N2': 0.8667551148662803, 'N3': 0.8280248039876774, 'REM': 0.8799133223755086}
  TEST1 : loss=0.8211 acc=0.8676 macroF1=0.8116 kappa=0.8169 AUROC=0.9727 AUPRC=0.8606 meanConf=0.7104 ECE=0.1573
  SHHS2 : loss=0.7949 acc=0.8711 macroF1=0.8030 kappa=0.8213 AUROC=0.9716 AUPRC=0.8565 meanConf=0.7066 ECE=0.1645
  MESA  : loss=1.4639 acc=0.7633 macroF1=0.6287 kappa=0.6613 AUROC=0.9113 AUPRC=0.6885 meanConf=0.6626 ECE=0.1007
  MESA F1/class: {'W': 0.8770511102371925, 'N1': 0.38348326780040665, 'N2': 0.7795285938410402, 'N3': 0.3106751657691322, 'REM': 0.7925926433928038}


                                                                                


Epoch 42 | train_loss=0.7986
  VAL   : loss=0.8635 acc=0.8566 macroF1=0.8016 kappa=0.8022 AUROC=0.9691 AUPRC=0.8469 meanConf=0.7089 ECE=0.1477
  F1/class: {'W': 0.9129112595542977, 'N1': 0.5193354683746997, 'N2': 0.8661584875020124, 'N3': 0.8288816603289503, 'REM': 0.8809109541117393}
  TEST1 : loss=0.8195 acc=0.8668 macroF1=0.8113 kappa=0.8160 AUROC=0.9728 AUPRC=0.8611 meanConf=0.7093 ECE=0.1576
  SHHS2 : loss=0.7959 acc=0.8703 macroF1=0.8019 kappa=0.8202 AUROC=0.9715 AUPRC=0.8560 meanConf=0.7066 ECE=0.1637
  MESA  : loss=1.4604 acc=0.7687 macroF1=0.6334 kappa=0.6685 AUROC=0.9128 AUPRC=0.6906 meanConf=0.6729 ECE=0.0959
  MESA F1/class: {'W': 0.8860039281656523, 'N1': 0.39464912051663664, 'N2': 0.7798451714874246, 'N3': 0.3024585449782757, 'REM': 0.8038458161383791}


                                                                                


Epoch 43 | train_loss=0.7894
  VAL   : loss=0.8669 acc=0.8564 macroF1=0.8008 kappa=0.8019 AUROC=0.9687 AUPRC=0.8452 meanConf=0.7097 ECE=0.1468
  F1/class: {'W': 0.913268019834257, 'N1': 0.5157785411839898, 'N2': 0.86598798903854, 'N3': 0.8283748028794448, 'REM': 0.8807001495729138}
  TEST1 : loss=0.8230 acc=0.8676 macroF1=0.8116 kappa=0.8168 AUROC=0.9724 AUPRC=0.8596 meanConf=0.7102 ECE=0.1573
  SHHS2 : loss=0.7994 acc=0.8705 macroF1=0.8020 kappa=0.8204 AUROC=0.9709 AUPRC=0.8545 meanConf=0.7076 ECE=0.1629
  MESA  : loss=1.4873 acc=0.7678 macroF1=0.6274 kappa=0.6668 AUROC=0.9100 AUPRC=0.6778 meanConf=0.6785 ECE=0.0893
  MESA F1/class: {'W': 0.8856587859232609, 'N1': 0.38847153177237437, 'N2': 0.7809815219272196, 'N3': 0.278640667689659, 'REM': 0.8032762984783094}


                                                                                


Epoch 44 | train_loss=0.7864
  VAL   : loss=0.8679 acc=0.8570 macroF1=0.8014 kappa=0.8026 AUROC=0.9686 AUPRC=0.8456 meanConf=0.7135 ECE=0.1435
  F1/class: {'W': 0.9135736644766261, 'N1': 0.5180412893292005, 'N2': 0.8666085724670773, 'N3': 0.8270657005020651, 'REM': 0.881616927441434}
  TEST1 : loss=0.8241 acc=0.8673 macroF1=0.8112 kappa=0.8164 AUROC=0.9723 AUPRC=0.8598 meanConf=0.7137 ECE=0.1536
  SHHS2 : loss=0.7995 acc=0.8706 macroF1=0.8015 kappa=0.8205 AUROC=0.9707 AUPRC=0.8542 meanConf=0.7117 ECE=0.1589
  MESA  : loss=1.5138 acc=0.7642 macroF1=0.6212 kappa=0.6619 AUROC=0.9064 AUPRC=0.6716 meanConf=0.6807 ECE=0.0835
  MESA F1/class: {'W': 0.8840860200126082, 'N1': 0.38231866858110913, 'N2': 0.7775108618861565, 'N3': 0.2676878566713638, 'REM': 0.7944912804493954}


                                                                                


Epoch 45 | train_loss=0.7820
  VAL   : loss=0.8705 acc=0.8549 macroF1=0.7997 kappa=0.8001 AUROC=0.9685 AUPRC=0.8443 meanConf=0.7130 ECE=0.1419
  F1/class: {'W': 0.9131344260037187, 'N1': 0.5151141516188052, 'N2': 0.8643058807568423, 'N3': 0.8262639663896152, 'REM': 0.8794717247616627}
  TEST1 : loss=0.8264 acc=0.8655 macroF1=0.8098 kappa=0.8143 AUROC=0.9723 AUPRC=0.8589 meanConf=0.7131 ECE=0.1524
  SHHS2 : loss=0.8046 acc=0.8681 macroF1=0.7992 kappa=0.8173 AUROC=0.9707 AUPRC=0.8536 meanConf=0.7104 ECE=0.1577
  MESA  : loss=1.5320 acc=0.7609 macroF1=0.6249 kappa=0.6584 AUROC=0.9046 AUPRC=0.6712 meanConf=0.6804 ECE=0.0806
  MESA F1/class: {'W': 0.8780594124271174, 'N1': 0.3827318972105464, 'N2': 0.7773572144872781, 'N3': 0.28759407644693, 'REM': 0.7986190185241758}


                                                                                


Epoch 46 | train_loss=0.7754
  VAL   : loss=0.8751 acc=0.8562 macroF1=0.8001 kappa=0.8015 AUROC=0.9682 AUPRC=0.8436 meanConf=0.7141 ECE=0.1421
  F1/class: {'W': 0.9136829136829137, 'N1': 0.5126742585907741, 'N2': 0.8661305427555712, 'N3': 0.8282278824455871, 'REM': 0.8795420112758613}
  TEST1 : loss=0.8316 acc=0.8666 macroF1=0.8100 kappa=0.8154 AUROC=0.9720 AUPRC=0.8578 meanConf=0.7142 ECE=0.1524
  SHHS2 : loss=0.8024 acc=0.8711 macroF1=0.8017 kappa=0.8211 AUROC=0.9707 AUPRC=0.8528 meanConf=0.7124 ECE=0.1587
  MESA  : loss=1.5601 acc=0.7619 macroF1=0.6185 kappa=0.6588 AUROC=0.9023 AUPRC=0.6641 meanConf=0.6874 ECE=0.0745
  MESA F1/class: {'W': 0.8793450110775021, 'N1': 0.37740074228861925, 'N2': 0.7805537498299259, 'N3': 0.2547466145469775, 'REM': 0.8003178244809042}


                                                                                


Epoch 47 | train_loss=0.7701
  VAL   : loss=0.8736 acc=0.8548 macroF1=0.7990 kappa=0.8000 AUROC=0.9681 AUPRC=0.8435 meanConf=0.7133 ECE=0.1417
  F1/class: {'W': 0.913308347178754, 'N1': 0.5114737073355835, 'N2': 0.8640145325836068, 'N3': 0.8266421573420418, 'REM': 0.8794895416633176}
  TEST1 : loss=0.8311 acc=0.8652 macroF1=0.8089 kappa=0.8137 AUROC=0.9717 AUPRC=0.8574 meanConf=0.7135 ECE=0.1517
  SHHS2 : loss=0.8047 acc=0.8686 macroF1=0.7992 kappa=0.8180 AUROC=0.9700 AUPRC=0.8508 meanConf=0.7117 ECE=0.1569
  MESA  : loss=1.5459 acc=0.7620 macroF1=0.6196 kappa=0.6594 AUROC=0.9040 AUPRC=0.6667 meanConf=0.6848 ECE=0.0773
  MESA F1/class: {'W': 0.8796865036049296, 'N1': 0.37768913962980016, 'N2': 0.7814192143936353, 'N3': 0.26129314510336543, 'REM': 0.7978030359758083}


                                                                                


Epoch 48 | train_loss=0.7643
  VAL   : loss=0.8801 acc=0.8556 macroF1=0.7995 kappa=0.8008 AUROC=0.9676 AUPRC=0.8413 meanConf=0.7149 ECE=0.1408
  F1/class: {'W': 0.9135645068132714, 'N1': 0.5114268066707843, 'N2': 0.8651511333837266, 'N3': 0.8271698511141722, 'REM': 0.8802275112234317}
  TEST1 : loss=0.8347 acc=0.8660 macroF1=0.8093 kappa=0.8146 AUROC=0.9716 AUPRC=0.8565 meanConf=0.7148 ECE=0.1512
  SHHS2 : loss=0.8080 acc=0.8698 macroF1=0.8000 kappa=0.8194 AUROC=0.9696 AUPRC=0.8501 meanConf=0.7134 ECE=0.1564
  MESA  : loss=1.5930 acc=0.7565 macroF1=0.6115 kappa=0.6517 AUROC=0.8984 AUPRC=0.6487 meanConf=0.6870 ECE=0.0694
  MESA F1/class: {'W': 0.8780706045773585, 'N1': 0.37516643044281234, 'N2': 0.7753677110505479, 'N3': 0.2272719611676667, 'REM': 0.8014887376364002}


                                                                                


Epoch 49 | train_loss=0.7588
  VAL   : loss=0.8820 acc=0.8553 macroF1=0.7990 kappa=0.8003 AUROC=0.9675 AUPRC=0.8412 meanConf=0.7156 ECE=0.1397
  F1/class: {'W': 0.9126463619819387, 'N1': 0.5102580940441367, 'N2': 0.865589884045358, 'N3': 0.8266779220607924, 'REM': 0.8796983592685461}
  TEST1 : loss=0.8358 acc=0.8657 macroF1=0.8088 kappa=0.8142 AUROC=0.9717 AUPRC=0.8570 meanConf=0.7154 ECE=0.1504
  SHHS2 : loss=0.8178 acc=0.8682 macroF1=0.7982 kappa=0.8172 AUROC=0.9690 AUPRC=0.8487 meanConf=0.7139 ECE=0.1543
  MESA  : loss=1.6371 acc=0.7503 macroF1=0.6108 kappa=0.6436 AUROC=0.8946 AUPRC=0.6465 meanConf=0.6876 ECE=0.0628
  MESA F1/class: {'W': 0.8675291938885745, 'N1': 0.37192055106659544, 'N2': 0.7737991610101216, 'N3': 0.23557165791983797, 'REM': 0.8053557968141011}


                                                                                


Epoch 50 | train_loss=0.7544
  VAL   : loss=0.8844 acc=0.8553 macroF1=0.7987 kappa=0.8003 AUROC=0.9674 AUPRC=0.8412 meanConf=0.7174 ECE=0.1379
  F1/class: {'W': 0.9123589710528012, 'N1': 0.508675799086758, 'N2': 0.8658425334648967, 'N3': 0.827181905107323, 'REM': 0.879621230008326}
  TEST1 : loss=0.8380 acc=0.8655 macroF1=0.8083 kappa=0.8139 AUROC=0.9714 AUPRC=0.8556 meanConf=0.7171 ECE=0.1484
  SHHS2 : loss=0.8192 acc=0.8681 macroF1=0.7981 kappa=0.8170 AUROC=0.9688 AUPRC=0.8481 meanConf=0.7156 ECE=0.1525
  MESA  : loss=1.6205 acc=0.7546 macroF1=0.6127 kappa=0.6490 AUROC=0.8984 AUPRC=0.6551 meanConf=0.6884 ECE=0.0662
  MESA F1/class: {'W': 0.8722530968810555, 'N1': 0.37191213104853926, 'N2': 0.7742811661509993, 'N3': 0.24740533696666483, 'REM': 0.7977499118592695}


                                                                                


Epoch 51 | train_loss=0.7479
  VAL   : loss=0.8871 acc=0.8554 macroF1=0.7987 kappa=0.8003 AUROC=0.9672 AUPRC=0.8409 meanConf=0.7184 ECE=0.1370
  F1/class: {'W': 0.9119350255201694, 'N1': 0.5086719030341693, 'N2': 0.8659974135444675, 'N3': 0.8270344135568156, 'REM': 0.8799736315128297}
  TEST1 : loss=0.8401 acc=0.8657 macroF1=0.8081 kappa=0.8140 AUROC=0.9712 AUPRC=0.8553 meanConf=0.7182 ECE=0.1475
  SHHS2 : loss=0.8194 acc=0.8684 macroF1=0.7984 kappa=0.8173 AUROC=0.9688 AUPRC=0.8481 meanConf=0.7163 ECE=0.1521
  MESA  : loss=1.6016 acc=0.7637 macroF1=0.6182 kappa=0.6603 AUROC=0.9019 AUPRC=0.6626 meanConf=0.6935 ECE=0.0702
  MESA F1/class: {'W': 0.8829846681282832, 'N1': 0.37244233416163985, 'N2': 0.7784511636362373, 'N3': 0.2588374306626842, 'REM': 0.7982962962962963}


                                                                                


Epoch 52 | train_loss=0.7480
  VAL   : loss=0.8877 acc=0.8542 macroF1=0.7976 kappa=0.7989 AUROC=0.9668 AUPRC=0.8397 meanConf=0.7179 ECE=0.1363
  F1/class: {'W': 0.9115122156516662, 'N1': 0.5061936378951541, 'N2': 0.864346377085565, 'N3': 0.8263915742009058, 'REM': 0.8794810939771612}
  TEST1 : loss=0.8402 acc=0.8648 macroF1=0.8075 kappa=0.8131 AUROC=0.9710 AUPRC=0.8543 meanConf=0.7178 ECE=0.1471
  SHHS2 : loss=0.8222 acc=0.8671 macroF1=0.7972 kappa=0.8157 AUROC=0.9685 AUPRC=0.8469 meanConf=0.7158 ECE=0.1513
  MESA  : loss=1.6115 acc=0.7615 macroF1=0.6183 kappa=0.6576 AUROC=0.9009 AUPRC=0.6618 meanConf=0.6932 ECE=0.0683
  MESA F1/class: {'W': 0.8791134254792102, 'N1': 0.3718202500392514, 'N2': 0.777598164166584, 'N3': 0.2620407342769163, 'REM': 0.8008094049178764}


                                                                                


Epoch 53 | train_loss=0.7362
  VAL   : loss=0.8912 acc=0.8546 macroF1=0.7980 kappa=0.7993 AUROC=0.9665 AUPRC=0.8390 meanConf=0.7194 ECE=0.1352
  F1/class: {'W': 0.91127425725571, 'N1': 0.5083118286022925, 'N2': 0.8648459791533221, 'N3': 0.8258118660625889, 'REM': 0.8798365852291482}
  TEST1 : loss=0.8434 acc=0.8648 macroF1=0.8074 kappa=0.8129 AUROC=0.9706 AUPRC=0.8542 meanConf=0.7191 ECE=0.1457
  SHHS2 : loss=0.8214 acc=0.8682 macroF1=0.7978 kappa=0.8171 AUROC=0.9683 AUPRC=0.8470 meanConf=0.7181 ECE=0.1500
  MESA  : loss=1.6407 acc=0.7587 macroF1=0.6156 kappa=0.6535 AUROC=0.8978 AUPRC=0.6591 meanConf=0.6953 ECE=0.0633
  MESA F1/class: {'W': 0.8756451208678944, 'N1': 0.3650170991752163, 'N2': 0.7762250552575676, 'N3': 0.26519932138272756, 'REM': 0.795991707466448}


                                                                                


Epoch 54 | train_loss=0.7404
  VAL   : loss=0.8944 acc=0.8542 macroF1=0.7976 kappa=0.7988 AUROC=0.9663 AUPRC=0.8383 meanConf=0.7196 ECE=0.1346
  F1/class: {'W': 0.9114692268129607, 'N1': 0.5070608737455682, 'N2': 0.8643049561432127, 'N3': 0.8261595059781893, 'REM': 0.878778641151183}
  TEST1 : loss=0.8464 acc=0.8644 macroF1=0.8066 kappa=0.8124 AUROC=0.9704 AUPRC=0.8532 meanConf=0.7194 ECE=0.1450
  SHHS2 : loss=0.8267 acc=0.8674 macroF1=0.7969 kappa=0.8161 AUROC=0.9678 AUPRC=0.8452 meanConf=0.7182 ECE=0.1492
  MESA  : loss=1.6397 acc=0.7573 macroF1=0.6154 kappa=0.6522 AUROC=0.8965 AUPRC=0.6551 meanConf=0.6954 ECE=0.0619
  MESA F1/class: {'W': 0.8763210967017258, 'N1': 0.36596910347911965, 'N2': 0.7752956075482312, 'N3': 0.2646414399031236, 'REM': 0.794972436759436}


                                                                                


Epoch 55 | train_loss=0.7388
  VAL   : loss=0.8970 acc=0.8537 macroF1=0.7969 kappa=0.7981 AUROC=0.9661 AUPRC=0.8373 meanConf=0.7195 ECE=0.1342
  F1/class: {'W': 0.9112729765628348, 'N1': 0.5052577073209591, 'N2': 0.8637195235783893, 'N3': 0.8252146372107633, 'REM': 0.8790746975932162}
  TEST1 : loss=0.8477 acc=0.8639 macroF1=0.8060 kappa=0.8117 AUROC=0.9705 AUPRC=0.8523 meanConf=0.7191 ECE=0.1448
  SHHS2 : loss=0.8283 acc=0.8669 macroF1=0.7964 kappa=0.8154 AUROC=0.9677 AUPRC=0.8446 meanConf=0.7183 ECE=0.1487
  MESA  : loss=1.6373 acc=0.7585 macroF1=0.6180 kappa=0.6539 AUROC=0.8966 AUPRC=0.6533 meanConf=0.6965 ECE=0.0619
  MESA F1/class: {'W': 0.8796449374175523, 'N1': 0.36907885147937614, 'N2': 0.7740414209235598, 'N3': 0.26462967849094465, 'REM': 0.8025423289549176}


                                                                                


Epoch 56 | train_loss=0.7345
  VAL   : loss=0.8967 acc=0.8544 macroF1=0.7974 kappa=0.7990 AUROC=0.9661 AUPRC=0.8373 meanConf=0.7199 ECE=0.1344
  F1/class: {'W': 0.9113338737973603, 'N1': 0.5052182068525134, 'N2': 0.8647356835889731, 'N3': 0.8263716918792161, 'REM': 0.8794089430995601}
  TEST1 : loss=0.8473 acc=0.8645 macroF1=0.8066 kappa=0.8125 AUROC=0.9704 AUPRC=0.8524 meanConf=0.7196 ECE=0.1449
  SHHS2 : loss=0.8261 acc=0.8676 macroF1=0.7971 kappa=0.8163 AUROC=0.9678 AUPRC=0.8450 meanConf=0.7184 ECE=0.1492
  MESA  : loss=1.6370 acc=0.7604 macroF1=0.6179 kappa=0.6561 AUROC=0.8973 AUPRC=0.6543 meanConf=0.6980 ECE=0.0624
  MESA F1/class: {'W': 0.8801862803956113, 'N1': 0.36964501276258516, 'N2': 0.775996612285542, 'N3': 0.2578585894162191, 'REM': 0.8059122093898531}


                                                                                


Epoch 57 | train_loss=0.7317
  VAL   : loss=0.8965 acc=0.8544 macroF1=0.7975 kappa=0.7990 AUROC=0.9661 AUPRC=0.8374 meanConf=0.7200 ECE=0.1344
  F1/class: {'W': 0.9113929477944708, 'N1': 0.5059710747934628, 'N2': 0.8647055259613456, 'N3': 0.8264935405600368, 'REM': 0.8788160149556005}
  TEST1 : loss=0.8474 acc=0.8643 macroF1=0.8064 kappa=0.8122 AUROC=0.9703 AUPRC=0.8523 meanConf=0.7197 ECE=0.1446
  SHHS2 : loss=0.8274 acc=0.8674 macroF1=0.7969 kappa=0.8160 AUROC=0.9676 AUPRC=0.8446 meanConf=0.7186 ECE=0.1488
  MESA  : loss=1.6385 acc=0.7594 macroF1=0.6188 kappa=0.6552 AUROC=0.8974 AUPRC=0.6547 meanConf=0.6985 ECE=0.0609
  MESA F1/class: {'W': 0.8789030321083758, 'N1': 0.37231641511385277, 'N2': 0.7756851238358365, 'N3': 0.26282667408918325, 'REM': 0.8040424204616344}


                                                                                


Epoch 58 | train_loss=0.7312
  VAL   : loss=0.8971 acc=0.8543 macroF1=0.7975 kappa=0.7990 AUROC=0.9660 AUPRC=0.8374 meanConf=0.7205 ECE=0.1339
  F1/class: {'W': 0.9114835303746428, 'N1': 0.5070025188916877, 'N2': 0.8644828124599471, 'N3': 0.8256660658692446, 'REM': 0.8788082520947721}
  TEST1 : loss=0.8477 acc=0.8644 macroF1=0.8065 kappa=0.8124 AUROC=0.9703 AUPRC=0.8525 meanConf=0.7201 ECE=0.1443
  SHHS2 : loss=0.8288 acc=0.8673 macroF1=0.7968 kappa=0.8159 AUROC=0.9675 AUPRC=0.8445 meanConf=0.7191 ECE=0.1481
  MESA  : loss=1.6617 acc=0.7556 macroF1=0.6136 kappa=0.6500 AUROC=0.8948 AUPRC=0.6494 meanConf=0.6985 ECE=0.0571
  MESA F1/class: {'W': 0.874987230825808, 'N1': 0.366399392380676, 'N2': 0.7740154561004935, 'N3': 0.2549783010890235, 'REM': 0.7976458546571136}


                                                                                


Epoch 59 | train_loss=0.7309
  VAL   : loss=0.8973 acc=0.8538 macroF1=0.7970 kappa=0.7984 AUROC=0.9660 AUPRC=0.8371 meanConf=0.7199 ECE=0.1339
  F1/class: {'W': 0.9113124927626979, 'N1': 0.5052021967049426, 'N2': 0.8638882534465614, 'N3': 0.8258085614310475, 'REM': 0.8787680902538079}
  TEST1 : loss=0.8478 acc=0.8640 macroF1=0.8063 kappa=0.8119 AUROC=0.9702 AUPRC=0.8521 meanConf=0.7195 ECE=0.1445
  SHHS2 : loss=0.8282 acc=0.8668 macroF1=0.7964 kappa=0.8153 AUROC=0.9676 AUPRC=0.8444 meanConf=0.7184 ECE=0.1484
  MESA  : loss=1.6545 acc=0.7562 macroF1=0.6152 kappa=0.6509 AUROC=0.8955 AUPRC=0.6514 meanConf=0.6982 ECE=0.0580
  MESA F1/class: {'W': 0.8754429959736342, 'N1': 0.36845183953350913, 'N2': 0.7737054114426044, 'N3': 0.25983416682903593, 'REM': 0.7986418076305717}


                                                                                


Epoch 60 | train_loss=0.7269
  VAL   : loss=0.8973 acc=0.8542 macroF1=0.7974 kappa=0.7988 AUROC=0.9661 AUPRC=0.8373 meanConf=0.7203 ECE=0.1339
  F1/class: {'W': 0.9114018435272994, 'N1': 0.50579814459373, 'N2': 0.8643699493124304, 'N3': 0.8260416051642361, 'REM': 0.879179170454121}
  TEST1 : loss=0.8480 acc=0.8642 macroF1=0.8063 kappa=0.8121 AUROC=0.9703 AUPRC=0.8523 meanConf=0.7200 ECE=0.1442
  SHHS2 : loss=0.8272 acc=0.8675 macroF1=0.7969 kappa=0.8161 AUROC=0.9676 AUPRC=0.8446 meanConf=0.7189 ECE=0.1486
  MESA  : loss=1.6603 acc=0.7564 macroF1=0.6138 kappa=0.6510 AUROC=0.8951 AUPRC=0.6499 meanConf=0.6988 ECE=0.0576
  MESA F1/class: {'W': 0.8754923565608126, 'N1': 0.36817944384698115, 'N2': 0.7744424761581208, 'N3': 0.252421341822384, 'REM': 0.7983495556495981}

Training finished
BEST VAL macroF1 : 0.8019 | /data2/Akbar1/sleep_stages_Dibatic/shhs_sleepstaging_planA/checkpoints_hier_rope_seq_v5_1/BEST_VAL_macroF1.pt
BEST MESA macroF1: 0.7011 | /data2/Akbar1/sleep_stages_Dibatic/shhs_s

In [1]:
# =========================
# CELL 0: Imports (safe)
# =========================
import os
import re
import json
from pathlib import Path

import numpy as np
import torch


In [2]:
# ============================================
# CELL 1: Paths (EDIT ONLY THIS CELL IF NEEDED)
# ============================================
ROOT = Path("/data2/Akbar1/sleep_stages_Dibatic/shhs_sleepstaging_planA/")

# Your training log shows checkpoints in:
# /data2/Akbar1/.../checkpoints_hier_rope_seq_v5_1/
# so default to that (change if needed).
CKPT_DIR = ROOT / "checkpoints_hier_rope_seq_v5_1"

# If you want to force a specific file, set BEST_CKPT explicitly.
# Otherwise it will auto-pick BEST_VAL_macroF1.pt if exists.
BEST_CKPT = CKPT_DIR / "BEST_VAL_macroF1.pt"

assert CKPT_DIR.exists(), f"Checkpoint folder not found: {CKPT_DIR}"

if not BEST_CKPT.exists():
    # fallback: search common best names
    candidates = [
        CKPT_DIR / "BEST_VAL_macroF1.pt",
        CKPT_DIR / "V5_best_val_macroF1.pt",
        CKPT_DIR / "best.pt",
        CKPT_DIR / "checkpoint_best.pt",
    ]
    found = None
    for c in candidates:
        if c.exists():
            found = c
            break
    if found is None:
        # last fallback: pick newest .pt
        pts = sorted(CKPT_DIR.glob("*.pt"), key=lambda p: p.stat().st_mtime, reverse=True)
        assert len(pts) > 0, f"No .pt checkpoints found in: {CKPT_DIR}"
        found = pts[0]
    BEST_CKPT = found

print("Using checkpoint:", BEST_CKPT)


Using checkpoint: /data2/Akbar1/sleep_stages_Dibatic/shhs_sleepstaging_planA/checkpoints_hier_rope_seq_v5_1/BEST_VAL_macroF1.pt


In [3]:
# ============================================
# CELL 2: Load checkpoint (robust)
# ============================================
ckpt = torch.load(BEST_CKPT, map_location="cpu")

print("Checkpoint keys:", sorted(list(ckpt.keys()))[:40], "...")
print("Has model_state:", "model_state" in ckpt)
print("Has state_dict:", "state_dict" in ckpt)
print("Has model object:", "model" in ckpt)
print("Has ema_shadow:", "ema_shadow" in ckpt)
print("use_ema flag:", ckpt.get("use_ema", False))
print("arch/model_class:", ckpt.get("arch", None), ckpt.get("model_class", None))


Checkpoint keys: ['Tmat', 'best_val_macroF1', 'class_weights', 'ema_decay', 'ema_shadow', 'epoch', 'ext_metrics', 'mesa_metrics', 'model_state', 'optimizer_state', 'test_metrics', 'train_loss', 'use_ema', 'use_learned_smoothing', 'use_viterbi', 'v5_aux_dur', 'v5_cost_matrix', 'v5_soft_boundary', 'val_metrics'] ...
Has model_state: True
Has state_dict: False
Has model object: False
Has ema_shadow: True
use_ema flag: True
arch/model_class: None None


In [4]:
# ============================================================
# CELL 3: Build / recover model WITHOUT hardcoding class name
# ============================================================

def _find_candidate_model_classes(globs: dict):
    """
    Heuristic: find classes that likely represent your model.
    Priority: names containing V5, Sleep, Transformer, Hier, etc.
    """
    candidates = []
    for name, obj in globs.items():
        if isinstance(obj, type):  # is a class
            n = name.lower()
            if any(k in n for k in ["sleep", "transformer", "hier", "stage", "v5"]):
                candidates.append(name)

    def score(nm: str):
        n = nm.lower()
        s = 0
        if "v5" in n: s += 10
        if "hier" in n: s += 6
        if "sleep" in n: s += 6
        if "transformer" in n: s += 6
        if "net" in n: s += 2
        return s

    candidates = sorted(set(candidates), key=lambda x: score(x), reverse=True)
    return candidates

def _build_model_from_notebook_or_ckpt(ckpt: dict):
    """
    Build model using (in order):
      A) ckpt["model"] object
      B) build_model()/make_model()/get_model() if exists
      C) ckpt["model_class"]/ckpt["arch"] class name
      D) autodetect from globals
    """
    # A) checkpoint contains full model object
    if "model" in ckpt and ckpt["model"] is not None:
        m = ckpt["model"]
        if isinstance(m, torch.nn.Module):
            print("[Model] Using ckpt['model'] object directly.")
            return m
        else:
            print("[Model] ckpt['model'] exists but is not a torch.nn.Module. Ignoring.")

    # B) common builder functions in notebook
    for fn_name in ["build_model", "make_model", "get_model"]:
        fn = globals().get(fn_name, None)
        if callable(fn):
            try:
                print(f"[Model] Using notebook builder: {fn_name}()")
                return fn()
            except TypeError:
                # some builders require args; fall through
                print(f"[Model] Found {fn_name} but it needs args. Will try other methods.")
            except Exception as e:
                print(f"[Model] {fn_name}() failed: {e}. Will try other methods.")

    # C) checkpoint tells class name
    for key in ["model_class", "arch"]:
        cls_name = ckpt.get(key, None)
        if isinstance(cls_name, str) and cls_name in globals() and isinstance(globals()[cls_name], type):
            cls = globals()[cls_name]
            print(f"[Model] Using class from ckpt['{key}']:", cls_name)
            # Try instantiate with common patterns
            # NOTE: adjust here if your constructor is very custom
            try:
                return cls(num_classes=NUM_CLASSES).to(device)
            except Exception:
                try:
                    return cls(NUM_CLASSES).to(device)
                except Exception:
                    # last attempt: no args
                    return cls().to(device)

    # D) auto-detect likely class names in notebook globals
    candidates = _find_candidate_model_classes(globals())
    if len(candidates) == 0:
        raise NameError(
            "No candidate model class found in notebook globals.\n"
            "Make sure you executed the cell that defines your model class (V5) before running evaluation."
        )

    print("[Model] Auto-detected candidate classes (top 10):", candidates[:10])
    picked = candidates[0]
    cls = globals()[picked]
    print("[Model] Picking:", picked)

    # Try instantiate with your earlier hyperparams (edit ONLY if needed)
    # These are safe tries; if fails, it will attempt simpler calls.
    try:
        return cls(num_classes=NUM_CLASSES, d_model=384, depth=10, n_heads=8).to(device)
    except Exception:
        try:
            return cls(num_classes=NUM_CLASSES).to(device)
        except Exception:
            try:
                return cls(NUM_CLASSES).to(device)
            except Exception:
                return cls().to(device)

# ---- build model now ----
assert "device" in globals(), "device not found. Make sure you defined: device = torch.device('cuda'...)"
assert "NUM_CLASSES" in globals(), "NUM_CLASSES not found. Make sure NUM_CLASSES and LABELS are defined."

model = _build_model_from_notebook_or_ckpt(ckpt)
model = model.to(device)

print("Model type:", type(model).__name__)


AssertionError: device not found. Make sure you defined: device = torch.device('cuda'...)

In [5]:
# ============================================
# CELL 4: Load weights + apply EMA if present
# ============================================
def _get_state_dict(ckpt: dict):
    if "model_state" in ckpt and isinstance(ckpt["model_state"], dict):
        return ckpt["model_state"]
    if "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict):
        return ckpt["state_dict"]
    # sometimes saved as "model"->state_dict only, but we already tried model object above
    raise KeyError("No 'model_state' or 'state_dict' found in checkpoint.")

sd = _get_state_dict(ckpt)

missing, unexpected = model.load_state_dict(sd, strict=False)
print("Loaded weights.")
print("  missing keys   :", len(missing))
print("  unexpected keys:", len(unexpected))

# Apply EMA shadow weights if present
use_ema = bool(ckpt.get("use_ema", False))
ema_shadow = ckpt.get("ema_shadow", None)

if use_ema and isinstance(ema_shadow, dict) and len(ema_shadow) > 0:
    with torch.no_grad():
        curr = model.state_dict()
        applied = 0
        for k, v in ema_shadow.items():
            if k in curr:
                curr[k].copy_(v.to(curr[k].device).to(curr[k].dtype))
                applied += 1
        model.load_state_dict(curr, strict=False)
    print(f"Applied EMA shadow weights: {applied} tensors")
else:
    print("EMA weights not found / not used.")

model.eval()


NameError: name 'model' is not defined

In [24]:
# ============================================================
# CELL 5: Evaluate (uses your existing eval_sequence function)
# ============================================================
assert "eval_sequence" in globals(), (
    "eval_sequence() not found in globals.\n"
    "Run the cell where you defined eval_sequence before this evaluation cell."
)

def _print_metrics(tag, m):
    print(f"\n===== {tag} =====")
    for k in ["loss","acc","macro_f1","kappa","AUROC","AUPRC","meanConf","ECE"]:
        if k in m:
            print(f"{k:9s}: {m[k]:.4f}")
    if "bad_batches" in m:
        print("bad_batches:", m["bad_batches"])
    if "f1_per_class" in m:
        print("F1/class :", m["f1_per_class"])
    if "cm" in m:
        labs = [LABELS[i] for i in range(NUM_CLASSES)] if "LABELS" in globals() else list(range(NUM_CLASSES))
        print("Confusion Matrix (rows=true, cols=pred) labels:", labs)
        print(m["cm"])

# Required loaders must exist
for nm in ["val_seq_loader", "test_seq_loader", "ext_seq_loader"]:
    assert nm in globals(), f"{nm} not found. Please create it before running evaluation."

val_m  = eval_sequence(model, val_seq_loader,  desc="VAL")
test_m = eval_sequence(model, test_seq_loader, desc="SHHS1 TEST")
ext_m  = eval_sequence(model, ext_seq_loader,  desc="SHHS2 EXT")

_print_metrics("VAL", val_m)
_print_metrics("SHHS1 TEST", test_m)
_print_metrics("SHHS2 EXT", ext_m)

# Optional MESA
if "mesa_seq_loader" in globals():
    mesa_m = eval_sequence(model, mesa_seq_loader, desc="MESA EXT")
    _print_metrics("MESA EXT", mesa_m)
else:
    print("\n[MESA] mesa_seq_loader not found. (Skip)")


                                                                                


===== VAL =====
loss     : 0.8704
acc      : 0.8593
macro_f1 : 0.8023
kappa    : 0.8048
AUROC    : 0.9705
AUPRC    : 0.8486
meanConf : 0.6920
ECE      : 0.1673
bad_batches: 0
F1/class : {'W': 0.9133298332330255, 'N1': 0.5192242833052276, 'N2': 0.8697101058497362, 'N3': 0.8263279694311394, 'REM': 0.8830118850982974}
Confusion Matrix (rows=true, cols=pred) labels: ['W', 'N1', 'N2', 'N3', 'REM']
[[106987   5595   4747    501   1808]
 [  2458  12316   3538      8   1715]
 [  4002   7452 188691  14192   8814]
 [   238     14  10567  61416    147]
 [   956   2028   3223    149  71101]]

===== SHHS1 TEST =====
loss     : 0.8333
acc      : 0.8686
macro_f1 : 0.8119
kappa    : 0.8175
AUROC    : 0.9738
AUPRC    : 0.8619
meanConf : 0.6915
ECE      : 0.1771
bad_batches: 0
F1/class : {'W': 0.9209574694378058, 'N1': 0.5344476681245398, 'N2': 0.8784149801573615, 'N3': 0.8341778076226524, 'REM': 0.8915967122465064}
Confusion Matrix (rows=true, cols=pred) labels: ['W', 'N1', 'N2', 'N3', 'REM']
[[110403

                                                                                


===== MESA EXT =====
loss     : 1.4322
acc      : 0.7700
macro_f1 : 0.6181
kappa    : 0.6656
AUROC    : 0.9228
AUPRC    : 0.7178
meanConf : 0.6281
ECE      : 0.1419
bad_batches: 0
F1/class : {'W': 0.8713183480964034, 'N1': 0.38780251316123404, 'N2': 0.7856198036341492, 'N3': 0.24597317049716316, 'REM': 0.7999503681321465}
Confusion Matrix (rows=true, cols=pred) labels: ['W', 'N1', 'N2', 'N3', 'REM']
[[685027  22805  80228    885  41018]
 [ 24112  59594  84645     18  18821]
 [ 24274  33508 708261   4150  27705]
 [  1500     55 117830  20463    929]
 [  7516   4190  14201     91 228871]]
