# Single-Head Self-Attention Encoder (Pre-LN)

- Build a single-head self-attention *encoder block* (Pre-LN) <br>
- Add sinusoidal positional encoding (absolute) <br>
- Add a position-wise FFN <br>
- Pool (CLS or masked-mean) -> linear head for a tiny binary classification task

In [1]:
import math, random
from dataclasses import dataclass
from typing import List, Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


In [2]:
# setup: seeding, compute device, special tokens, model hyper-params
SEED = 42
random.seed(SEED)          # seed Python's global RNG
torch.manual_seed(SEED)    # seed torch's global RNG
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

PAD_TOKEN = "<pad>"
UNK_TOKEN = "<unk>"
D_MODEL = 32               # embedding / hidden width
P_DROP = 0.1               # dropout probability used by the encoder block

In [3]:
# vocab helpers
def build_vocab(sentences: List[str], min_freq: int = 1) -> Dict[str, int]:
    """Map each whitespace-separated word to a contiguous integer id.

    Ids 0 and 1 are reserved for PAD_TOKEN and UNK_TOKEN. The remaining words
    are ordered by descending frequency, with ties broken alphabetically. Words
    seen fewer than ``min_freq`` times are left out, so ``numericalize`` maps
    them to the UNK id.

    A corpus word that equals a special token (e.g. the literal "<pad>") is
    skipped. Assigning it an id would overwrite the reserved id and make it
    collide with the id of the next word.
    """
    freq: Dict[str, int] = {}
    for s in sentences:
        for w in s.split():
            freq[w] = freq.get(w, 0) + 1
    vocab = {PAD_TOKEN: 0, UNK_TOKEN: 1}
    for w, c in sorted(freq.items(), key=lambda x: (-x[1], x[0])):
        if c >= min_freq and w not in vocab:
            vocab[w] = len(vocab)   # next contiguous id
    return vocab

def tokenize(s: str) -> List[str]:
    """Split ``s`` into tokens on runs of whitespace, ignoring leading/trailing whitespace."""
    stripped = s.strip()
    return stripped.split()
def numericalize(tokens: List[str], vocab: Dict[str, int]) -> List[int]:
    """Map each token to its vocab id; tokens missing from ``vocab`` get the UNK_TOKEN id."""
    return [
        vocab.get(token, vocab[UNK_TOKEN])  # unknown token -> <unk> id
        for token in tokens
    ]

def pad_batch(batch_ids: List[List[int]], pad_id: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
    """Right-pad id lists to the length of the longest one.

    Returns ``(ids, mask)``, both shaped [B, T]. ``ids`` is a long tensor
    padded with ``pad_id``. ``mask`` is a bool tensor that is True on real
    tokens and False on padding.
    """
    width = max(map(len, batch_ids))
    rows = [ids + [pad_id] * (width - len(ids)) for ids in batch_ids]
    keep = [[True] * len(ids) + [False] * (width - len(ids)) for ids in batch_ids]
    return torch.tensor(rows, dtype=torch.long), torch.tensor(keep, dtype=torch.bool)


In [4]:
# Positional encoding (sinusoidal)
class SinusoidalPE(nn.Module):
    """Fixed (non-learned) absolute sinusoidal positional encoding, added to the input.

    Row ``p`` of the table is
    ``[sin(p*w_0), cos(p*w_0), sin(p*w_1), cos(p*w_1), ...]`` with
    ``w_i = 10000 ** (-2i / d_model)``. Odd ``d_model`` is supported: the last
    column is then a sine with no matching cosine.

    Args:
        d_model: feature width D of the inputs.
        max_len: longest sequence length supported by ``forward``.
    """

    def __init__(self, d_model: int, max_len: int = 256):
        super().__init__()
        pe = torch.zeros(max_len, d_model)                                  # [Tmax, D]
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)      # [Tmax, 1]
        div = torch.exp(torch.arange(0, d_model, 2).float()
                        * (-math.log(10000.0) / d_model))                   # [ceil(D/2)]
        angles = pos * div                                                  # [Tmax, ceil(D/2)]
        pe[:, 0::2] = torch.sin(angles)
        # Even columns number ceil(D/2) and odd columns floor(D/2); slice so an
        # odd d_model does not raise a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)   # buffer: follows .to(device), not a trainable parameter

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the encodings of positions ``0..T-1`` to ``x`` [B, T, D] and return the sum.

        Raises:
            ValueError: if ``T`` exceeds the ``max_len`` given at construction.
        """
        T = x.size(1)
        if T > self.pe.size(0):
            raise ValueError(f"sequence length {T} exceeds max_len {self.pe.size(0)}")
        return x + self.pe[:T].unsqueeze(0)   # broadcast over the batch dimension

Q (query) = what a token is looking for, K (key) = what it can be matched on, V (value) = the content it contributes once matched. <br>

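Attention scores are scaled dot products $QK^\top/\sqrt{d_{\text{model}}}$, with padding keys masked out. Each token's update is a softmax-weighted mix of the rows of $V$.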

In [5]:
# Single-head self-attention
class SelfAttention1H(nn.Module):  # attention (context mixing)
    """Single-head scaled dot-product self-attention with a key-padding mask.

    Q, K and V are bias-free linear projections of the input. Scores are
    ``Q @ K^T / sqrt(d_model)``, and padding keys are excluded before the softmax.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.q = nn.Linear(d_model, d_model, bias=False)
        self.k = nn.Linear(d_model, d_model, bias=False)
        self.v = nn.Linear(d_model, d_model, bias=False)
        self.scale = math.sqrt(d_model)

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        """
        x:    [B,T,D]
        mask: [B,T], truthy = real token. Bool or 0/1 integer tensors are accepted.
        returns:
          out:  [B,T,D]
          attn: [B,T,T] attention matrix (row: query token, col: key token).
                A row for a sequence with no real tokens is all zeros, not NaN.
        """
        Q, K, V = self.q(x), self.k(x), self.v(x)                   # [B,T,D]
        scores = torch.matmul(Q, K.transpose(1, 2)) / self.scale    # [B,T,T]
        key_mask = mask.bool().unsqueeze(1)                          # [B,1,T]; `~` needs a bool tensor
        scores = scores.masked_fill(~key_mask, float("-inf"))        # forbid attending to PADs
        attn = torch.softmax(scores, dim=-1)                         # [B,T,T]
        # If every key in a row is padding, softmax over all -inf yields NaN.
        # Zero the masked weights: this is a no-op for normal rows, where they
        # are already exactly 0, and it turns all-pad rows into zeros.
        attn = attn.masked_fill(~key_mask, 0.0)
        out = torch.matmul(attn, V)                                  # [B,T,D]
        return out, attn


In [6]:
# Position-wise FFN: nonlinear feature interactions after attention has aggregated context
class FFN(nn.Module):
    """Position-wise feed-forward network.

    The layers are Linear(d_model -> d_ff), GELU, Linear(d_ff -> d_model) and
    Dropout(p_drop). Only the last (feature) dimension is mixed, so every
    position is transformed independently.
    """

    def __init__(self, d_model: int, d_ff: int = 128, p_drop: float = 0.0):
        super().__init__()
        expand = nn.Linear(d_model, d_ff)     # Transformers typically use d_ff ≈ 4 * d_model
        project = nn.Linear(d_ff, d_model)
        self.net = nn.Sequential(expand, nn.GELU(), project, nn.Dropout(p_drop))

    def forward(self, x):
        out = self.net(x)
        return out

In [7]:
# Pre-LN Encoder Block (single head)
class EncoderBlock(nn.Module):
    """One Pre-LN Transformer encoder layer with a single attention head.

    Computes
        h   = x + Dropout(Attn(LN1(x), mask))
        out = h + Dropout(FFN(LN2(h)))
    and returns ``(out, attn)``, where ``attn`` holds the attention weights
    produced by the attention sub-layer.
    """

    def __init__(self, d_model: int, p_drop: float = 0.1, ffn_mult: int = 4):
        super().__init__()
        # attention sub-layer
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = SelfAttention1H(d_model)
        self.drop1 = nn.Dropout(p_drop)
        # feed-forward sub-layer: residual dropout comes from drop2; the FFN's own dropout is off
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FFN(d_model, d_ff=d_model * ffn_mult, p_drop=0.0)
        self.drop2 = nn.Dropout(p_drop)

    def forward(self, x, mask):
        ctx, weights = self.attn(self.ln1(x), mask)      # normalize first, then attend
        h = x + self.drop1(ctx)                          # residual around attention
        out = h + self.drop2(self.ffn(self.ln2(h)))      # residual around FFN
        return out, weights

In [8]:
# Encoder + pooling + head
@dataclass
class ModelCfg:
    """Hyper-parameters for ``SingleHeadEncoderClassifier``."""

    vocab_size: int      # number of token ids (rows of the embedding table)
    d_model: int         # embedding / hidden width
    pad_id: int          # id of the padding token (passed as the Embedding's padding_idx)
    pool: str = "cls"    # sentence pooling: "cls" (first position) or "mean" (masked mean)

class SingleHeadEncoderClassifier(nn.Module):
    """Tiny binary classifier built from a single-head encoder block.

    Pipeline: token embedding, sinusoidal PE, one Pre-LN encoder block, final
    LayerNorm, pooling, then a linear layer producing one logit.

    ``forward(ids, mask)`` returns ``(logits [B], attn)``. The logits are raw
    (pre-sigmoid) scores.

    Pooling (``cfg.pool``):
      * ``"cls"``: hidden state at position 0. Pair this with inputs that
        carry a summary token there, e.g. batches from ``pad_batch_with_cls``.
      * ``"mean"``: mean over positions where ``mask`` is True.

    Raises:
        ValueError: if ``cfg.pool`` is neither ``"cls"`` nor ``"mean"``.
            Without this check, a typo would silently select CLS pooling.
    """

    def __init__(self, cfg: ModelCfg):
        super().__init__()
        if cfg.pool not in ("cls", "mean"):
            raise ValueError(f"pool must be 'cls' or 'mean', got {cfg.pool!r}")
        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model, padding_idx=cfg.pad_id)
        self.pe = SinusoidalPE(cfg.d_model)
        self.block = EncoderBlock(cfg.d_model, p_drop=P_DROP)
        self.final_ln = nn.LayerNorm(cfg.d_model)  # final LN is common in Pre-LN stacks
        self.head = nn.Linear(cfg.d_model, 1)
        self.pool = cfg.pool
        self.pad_id = cfg.pad_id

    def masked_mean(self, x, mask):               # x:[B,T,D], mask:[B,T]
        """Average ``x`` over positions where ``mask`` is True and return a [B,D] tensor.

        The count is clamped to at least 1, so a row with no True entries
        returns a zero vector instead of NaN.
        """
        m = mask.unsqueeze(-1).to(x.dtype)        # match x's dtype (e.g. half precision)
        return (x * m).sum(1) / m.sum(1).clamp(min=1.0)

    def forward(self, ids: torch.Tensor, mask: torch.Tensor):
        x = self.embed(ids)                       # [B,T,D]
        x = self.pe(x)                            # add sinusoidal positions
        x, attn = self.block(x, mask)             # [B,T,D], [B,T,T]
        x = self.final_ln(x)

        if self.pool == "mean":
            sent = self.masked_mean(x, mask)      # [B,D]
        else:
            sent = x[:, 0, :]                     # first-token (CLS-style) sentence representation

        logits = self.head(sent).squeeze(-1)      # [B]
        return logits, attn


In [9]:
# Tiny "hard" supervised set: label 0 = greeting, label 1 = food.
# Several sentences deliberately contain cue words from the other class.
greeting_hard = [
    "good morning everyone",
    "hello there my friend",
    "hey buddy how are you",
    "good evening folks",
    "salutations from the sushi bar",          # greeting + food word
    "pizza party greetings to all",            # greeting + food word
    "hi from the ramen shop",                  # greeting + food word
    "hello and welcome to brunch",             # greeting + food word
]
food_hard = [
    "i love pizza",
    "pasta is tasty tonight",
    "fresh salad with apple",
    "i like sushi a lot",
    "good sandwich this morning",              # food + greeting words
    "ramen is great hello world",              # food + greeting word
    "eating an apple for breakfast",
    "not a fan of pizza anymore",              # negation
]
HARD_SUP = [*greeting_hard, *food_hard]
HARD_LABELS = [0 for _ in greeting_hard] + [1 for _ in food_hard]

In [11]:
# --- Add a real <cls> token and batching with CLS ---
CLS_TOKEN = "<cls>"

def build_vocab_with_specials(sentences, min_freq: int = 1):
    """Build a vocab with ids 0/1/2 reserved for <pad>/<unk>/<cls>.

    Words keep the order produced by ``build_vocab`` and get contiguous ids
    starting at 3. A word equal to any special token is skipped. That covers
    <pad>/<unk>, which ``build_vocab`` always returns, and a literal "<cls>"
    in the corpus, which would otherwise overwrite the reserved id 2.

    Args:
        sentences: whitespace-tokenizable strings.
        min_freq: forwarded to ``build_vocab``; rarer words are left out.
    """
    vocab = {PAD_TOKEN: 0, UNK_TOKEN: 1, CLS_TOKEN: 2}
    for w in build_vocab(sentences, min_freq=min_freq):
        if w in vocab:          # special token (reserved id) - never reassign
            continue
        vocab[w] = len(vocab)   # next contiguous id
    return vocab

# Build the vocab (with <cls>) and keep the reserved special ids at hand
VOCAB = build_vocab_with_specials(HARD_SUP)
pad_id = VOCAB[PAD_TOKEN]
cls_id = VOCAB[CLS_TOKEN]

In [12]:
def pad_batch_with_cls(batch_ids, pad_id, cls_id):
    """Prepend ``cls_id`` to each id list, then right-pad to a common length.

    Returns ``(X, M)``, both shaped [B, 1 + longest]. ``X`` is a long tensor
    padded with ``pad_id``. ``M`` is a bool mask that is True on real tokens,
    including the CLS position.
    """
    rows = [[cls_id] + ids for ids in batch_ids]
    width = max(len(r) for r in rows)
    ids_padded = [r + [pad_id] * (width - len(r)) for r in rows]
    keep = [[True] * len(r) + [False] * (width - len(r)) for r in rows]
    return torch.tensor(ids_padded, dtype=torch.long), torch.tensor(keep, dtype=torch.bool)

In [16]:
# --- Simple shuffled train/validation split (default 80/20) ---
def split_train_val(sentences, labels, val_frac=0.2, seed=42):
    """Shuffle indices with a private ``random.Random(seed)`` and cut them into train/val.

    The train part holds ``max(1, int(n * (1 - val_frac)))`` items and the
    rest go to validation. The global ``random`` state is not touched.
    Returns ``(train_sentences, train_labels, val_sentences, val_labels)``.
    """
    order = list(range(len(sentences)))
    random.Random(seed).shuffle(order)
    n_train = max(1, int(len(order) * (1 - val_frac)))
    train_ids, val_ids = order[:n_train], order[n_train:]

    def pick(seq, ids):
        return [seq[i] for i in ids]

    return (pick(sentences, train_ids), pick(labels, train_ids),
            pick(sentences, val_ids), pick(labels, val_ids))

# 75/25 split of the hard set, reproducible via a fixed seed
tr_s, tr_y, va_s, va_y = split_train_val(sentences=HARD_SUP, labels=HARD_LABELS, val_frac=0.25, seed=42)

In [17]:
# dataset → tensors
def make_tensors(sentences, labels, vocab, pad_id, cls_id):
    """Turn raw sentences plus labels into model-ready tensors ``(X, M, y)``.

    ``X`` and ``M`` come from ``pad_batch_with_cls``: ids with <cls>
    prepended, and a matching bool mask. ``y`` is a float tensor of the labels.
    """
    ids = [numericalize(tokenize(s), vocab) for s in sentences]
    X, M = pad_batch_with_cls(ids, pad_id, cls_id)
    return X, M, torch.tensor(labels, dtype=torch.float)

# Materialize train/validation tensors; the training/eval helpers move them to the model's device
Xtr, Mtr, ytr = make_tensors(tr_s, tr_y, vocab=VOCAB, pad_id=pad_id, cls_id=cls_id)
Xva, Mva, yva = make_tensors(va_s, va_y, vocab=VOCAB, pad_id=pad_id, cls_id=cls_id)

In [None]:
def evaluate_step(model, X, M, y, threshold=0.5):
    """Score ``model`` on one full batch without tracking gradients.

    Switches the model to eval mode (disabling dropout) and leaves it there.
    Returns ``(loss, acc, brier, margin)`` as Python floats:
      loss   - mean BCE-with-logits
      acc    - fraction where ``sigmoid(logit) >= threshold`` matches ``y``
      brier  - mean squared error between probability and label
      margin - mean ``|p - 0.5|``, an average-confidence measure that ignores ``threshold``
    """
    model.eval()
    with torch.no_grad():
        logits, _attn = model(X, M)
        probs = logits.sigmoid()
        loss = F.binary_cross_entropy_with_logits(logits, y)
        hits = (probs >= threshold).float() == y
        acc = hits.float().mean().item()
        brier = ((probs - y) ** 2).mean().item()
        margin = (probs - 0.5).abs().mean().item()
    return loss.item(), acc, brier, margin

# --- Train loop (mini-batch, early stopping on validation loss) ---
def train_encoder(model, Xtr, Mtr, ytr, Xva, Mva, yva,
                  epochs=80, batch_size=8, lr=3e-3, weight_decay=1e-4,
                  patience=8, min_delta=1e-4, clip=1.0, verbose=True):
    """Train ``model`` with Adam and BCE-with-logits, then restore the best-validation weights.

    Each epoch shuffles the training set, runs mini-batches of ``batch_size``
    (with global grad-norm clipping when ``clip`` is truthy), and evaluates on
    the validation tensors. Training stops early after ``patience``
    consecutive epochs in which validation loss fails to improve by more than
    ``min_delta``. When ``verbose`` is set, a progress line is printed at
    epoch 1 and every 5th epoch. Returns the model carrying the
    lowest-validation-loss weights.
    """
    device = next(model.parameters()).device
    Xtr, Mtr, ytr, Xva, Mva, yva = (t.to(device) for t in (Xtr, Mtr, ytr, Xva, Mva, yva))

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.BCEWithLogitsLoss()
    n_train = Xtr.size(0)

    best_val, best_state, stale = float("inf"), None, 0
    for epoch in range(1, epochs + 1):
        model.train()
        order = torch.randperm(n_train, device=device)   # fresh shuffle every epoch
        running = 0.0
        for start in range(0, n_train, batch_size):
            batch = order[start:start + batch_size]
            xb, mb, yb = Xtr[batch], Mtr[batch], ytr[batch]
            optimizer.zero_grad()
            logits, _ = model(xb, mb)
            loss = criterion(logits, yb)
            loss.backward()
            if clip:  # rescale all grads uniformly so their total norm is at most `clip`
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip)
            optimizer.step()
            running += loss.item() * xb.size(0)   # sum of per-example losses
        train_loss = running / n_train

        val_loss, val_acc, val_brier, val_margin = evaluate_step(model, Xva, Mva, yva)
        if verbose and (epoch == 1 or epoch % 5 == 0):
            print(f"epoch {epoch:03d} | train_loss={train_loss:.4f} | "
                  f"val_loss={val_loss:.4f} | val_acc={val_acc:.3f} | "
                  f"val_brier={val_brier:.4f} | val_margin={val_margin:.3f}")

        # early stopping on val loss
        if best_val - val_loss > min_delta:
            best_val = val_loss
            best_state = {name: t.detach().clone() for name, t in model.state_dict().items()}
            stale = 0
            continue
        stale += 1
        if stale >= patience:
            if verbose:
                print(f"early stop at epoch {epoch} | best_val_loss={best_val:.4f}")
            break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model

In [31]:
# Model with CLS pooling (the pool="mean" variant is trained further below)
cfg = ModelCfg(vocab_size=len(VOCAB), d_model=D_MODEL, pad_id=pad_id, pool="cls")
model = SingleHeadEncoderClassifier(cfg).to(DEVICE)

# Train (early stopping on val loss; the best-val weights are restored)
model = train_encoder(
    model, Xtr, Mtr, ytr, Xva, Mva, yva,
    epochs=80, batch_size=8, lr=3e-3, weight_decay=1e-4,
    patience=8, min_delta=1e-4, clip=1.0, verbose=True,
)

# Final report on train + val
tr_metrics = evaluate_step(model, *(t.to(DEVICE) for t in (Xtr, Mtr, ytr)))
va_metrics = evaluate_step(model, *(t.to(DEVICE) for t in (Xva, Mva, yva)))
print(*(f"{tag} loss={m[0]:.4f} acc={m[1]:.3f} brier={m[2]:.4f} margin={m[3]:.3f}"
        for tag, m in (("TRAIN ", tr_metrics), ("VALID ", va_metrics))), sep="\n")


epoch 001 | train_loss=0.6810 | val_loss=0.8285 | val_acc=0.250 | val_brier=0.3164 | val_margin=0.119
epoch 005 | train_loss=0.5749 | val_loss=0.8381 | val_acc=0.250 | val_brier=0.3211 | val_margin=0.149
epoch 010 | train_loss=0.2093 | val_loss=1.1494 | val_acc=0.500 | val_brier=0.4224 | val_margin=0.286
early stop at epoch 11 | best_val_loss=0.7269
TRAIN  loss=0.6137 acc=0.833 brier=0.2108 margin=0.054
VALID  loss=0.7269 acc=0.250 brier=0.2669 margin=0.060


In [30]:
# Same setup with pool="mean", to compare against CLS pooling without pretraining.
cfg2 = ModelCfg(vocab_size=len(VOCAB), d_model=D_MODEL, pad_id=pad_id, pool="mean")
model2 = SingleHeadEncoderClassifier(cfg2).to(DEVICE)

# Train (early stopping on val loss; the best-val weights are restored)
model2 = train_encoder(
    model2, Xtr, Mtr, ytr, Xva, Mva, yva,
    epochs=80, batch_size=8, lr=3e-3, weight_decay=1e-4,
    patience=8, min_delta=1e-4, clip=1.0, verbose=True,
)

# Final report on train + val
tr_metrics = evaluate_step(model2, *(t.to(DEVICE) for t in (Xtr, Mtr, ytr)))
va_metrics = evaluate_step(model2, *(t.to(DEVICE) for t in (Xva, Mva, yva)))
print(*(f"{tag} loss={m[0]:.4f} acc={m[1]:.3f} brier={m[2]:.4f} margin={m[3]:.3f}"
        for tag, m in (("TRAIN ", tr_metrics), ("VALID ", va_metrics))), sep="\n")


epoch 001 | train_loss=0.7326 | val_loss=0.7517 | val_acc=0.000 | val_brier=0.2792 | val_margin=0.028
epoch 005 | train_loss=0.4753 | val_loss=0.9579 | val_acc=0.250 | val_brier=0.3739 | val_margin=0.125
early stop at epoch 9 | best_val_loss=0.7517
TRAIN  loss=0.6538 acc=0.833 brier=0.2304 margin=0.065
VALID  loss=0.7517 acc=0.000 brier=0.2792 margin=0.028


In [81]:
# Head-only warm start (few epochs) → then unfreeze encoder
def freeze(module, flag: bool):
    """Set ``requires_grad = flag`` on every parameter of ``module``.

    Note the polarity: ``flag=False`` freezes the module and ``flag=True`` unfreezes it.
    """
    for param in module.parameters():
        param.requires_grad = flag

def _ws_train_epoch(model, opt, crit, train, batch_size, clip, device):
    """Run one shuffled mini-batch pass over ``train = (X, M, y)`` and return the mean per-example loss."""
    Xtr, Mtr, ytr = train
    model.train()
    N = Xtr.size(0)
    perm = torch.randperm(N, device=device)  # random permutation of 0..N-1
    total_loss = 0.0
    for i in range(0, N, batch_size):
        idx = perm[i:i + batch_size]
        xb, mb, yb = Xtr[idx], Mtr[idx], ytr[idx]
        opt.zero_grad()
        logits, _ = model(xb, mb)
        loss = crit(logits, yb)
        loss.backward()
        if clip:  # rescale all grads uniformly so their total norm is at most `clip`
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip)
        opt.step()
        total_loss += loss.item() * xb.size(0)
    return total_loss / N


def _ws_run_stage(model, opt, crit, train, val, epochs, patience, best_val, best_state, *,
                  batch_size, min_delta, clip, verbose, device, log_prefix):
    """Train over the epoch numbers in ``epochs`` with val-loss early stopping.

    Starts from the incoming ``best_val``/``best_state`` and returns the updated pair.
    """
    wait = 0
    for e in epochs:
        train_loss = _ws_train_epoch(model, opt, crit, train, batch_size, clip, device)
        val_loss, val_acc, val_brier, val_margin = evaluate_step(model, *val)

        if verbose:
            print(f"{log_prefix}epoch {e:03d} | train_loss={train_loss:.4f} | "
                  f"val_loss={val_loss:.4f} | val_acc={val_acc:.3f} | "
                  f"val_brier={val_brier:.4f} | val_margin={val_margin:.3f}")

        # early stopping on val loss
        if best_val - val_loss > min_delta:
            best_val = val_loss
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                if verbose:
                    print(f"early stop at epoch {e} | best_val_loss={best_val:.4f}")
                break
    return best_val, best_state


def train_encoder_warm_start(model, Xtr, Mtr, ytr, Xva, Mva, yva,
                             epochs=60, warm_epochs=5, batch_size=8, weight_decay=1e-4,
                             patience_warm=2, patience_full=4, min_delta=1e-4, clip=1.0,
                             verbose=True, warm_lr=1e-3, enc_lr=1e-4, head_lr=3e-3):
    """Train in two stages: a head-only warm start, then full training.

    Stage 1 freezes ``model.embed`` and ``model.block`` and trains only
    ``final_ln`` and ``head`` (Adam, ``lr=warm_lr``, no weight decay). It runs
    epochs ``1..warm_epochs`` with early-stopping patience ``patience_warm``.
    The best weights so far are then restored and the encoder is unfrozen. The
    unfreeze also happens if Stage 1 raises.

    Stage 2 trains all parameters over epochs ``warm_epochs+1..epochs`` with
    patience ``patience_full``. It uses one Adam optimizer with per-group
    settings:
      * embed + block: ``lr=enc_lr``, ``weight_decay=weight_decay``
      * head:          ``lr=head_lr``, no weight decay
      * final_ln:      ``lr=head_lr``, ``weight_decay=weight_decay``

    Both stages share a single best validation loss. The lowest-validation-loss
    weights seen in either stage are loaded before returning.

    Returns:
        The same ``model`` instance, carrying the best weights.
    """
    device = next(model.parameters()).device
    train = (Xtr.to(device), Mtr.to(device), ytr.to(device))
    val = (Xva.to(device), Mva.to(device), yva.to(device))
    crit = nn.BCEWithLogitsLoss()
    loop_kw = dict(batch_size=batch_size, min_delta=min_delta, clip=clip,
                   verbose=verbose, device=device)

    # Stage 1: head-only warm start (embedding + encoder block frozen).
    freeze(model.embed, False)
    freeze(model.block, False)
    try:
        opt = torch.optim.Adam(
            list(model.final_ln.parameters()) + list(model.head.parameters()),
            lr=warm_lr, weight_decay=0,
        )
        best_val, best_state = _ws_run_stage(
            model, opt, crit, train, val, range(1, warm_epochs + 1), patience_warm,
            float("inf"), None, log_prefix="Warm start ", **loop_kw,
        )
    finally:
        # Always unfreeze, even when Stage 1 is interrupted (e.g. KeyboardInterrupt in a
        # notebook), so the encoder is never silently left frozen for later runs.
        freeze(model.embed, True)
        freeze(model.block, True)

    # Restore the best-so-far weights so Stage 2 starts from the best warm-start point.
    if best_state is not None:
        model.load_state_dict(best_state)

    # Stage 2: train everything, with a smaller LR on the encoder than on the head.
    trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    enc_params = [p for n, p in trainable if n.startswith(("embed.", "block."))]
    head_params = [p for n, p in trainable if n.startswith("head.")]
    other_params = [p for n, p in trainable if n.startswith("final_ln")]
    opt = torch.optim.Adam([
        {"params": enc_params,   "lr": enc_lr,  "weight_decay": weight_decay},  # smaller, steadier
        {"params": head_params,  "lr": head_lr, "weight_decay": 0.0},           # freer head to grow margin
        {"params": other_params, "lr": head_lr, "weight_decay": weight_decay},
    ])
    # A ReduceLROnPlateau scheduler stepped on val_loss is possible here, but plateau
    # schedulers are brittle on a tiny, noisy validation set.

    best_val, best_state = _ws_run_stage(
        model, opt, crit, train, val, range(warm_epochs + 1, epochs + 1), patience_full,
        best_val, best_state, log_prefix="", **loop_kw,
    )

    if best_state is not None:
        model.load_state_dict(best_state)
    return model
    

In [82]:
# CLS-pooled model (same cfg as the first run), trained with a head-only warm start
model_w = SingleHeadEncoderClassifier(cfg).to(DEVICE)
model_w = train_encoder_warm_start(model_w, Xtr, Mtr, ytr, Xva, Mva, yva)

# Final report on train + val
tr_metrics = evaluate_step(model_w, *(t.to(DEVICE) for t in (Xtr, Mtr, ytr)))
va_metrics = evaluate_step(model_w, *(t.to(DEVICE) for t in (Xva, Mva, yva)))
print(*(f"{tag} loss={m[0]:.4f} acc={m[1]:.3f} brier={m[2]:.4f} margin={m[3]:.3f}"
        for tag, m in (("TRAIN ", tr_metrics), ("VALID ", va_metrics))), sep="\n")

Warm start epoch 001 | train_loss=0.7538 | val_loss=0.6124 | val_acc=0.750 | val_brier=0.2101 | val_margin=0.125
Warm start epoch 002 | train_loss=0.7188 | val_loss=0.6187 | val_acc=0.750 | val_brier=0.2131 | val_margin=0.114
Warm start epoch 003 | train_loss=0.7372 | val_loss=0.6240 | val_acc=0.750 | val_brier=0.2157 | val_margin=0.105
early stop at epoch 3 | best_val_loss=0.6124
epoch 006 | train_loss=0.7234 | val_loss=0.6286 | val_acc=0.750 | val_brier=0.2179 | val_margin=0.098
epoch 007 | train_loss=0.7146 | val_loss=0.6432 | val_acc=0.750 | val_brier=0.2251 | val_margin=0.077
epoch 008 | train_loss=0.7073 | val_loss=0.6669 | val_acc=0.750 | val_brier=0.2369 | val_margin=0.047
epoch 009 | train_loss=0.6933 | val_loss=0.6954 | val_acc=0.500 | val_brier=0.2511 | val_margin=0.021
early stop at epoch 9 | best_val_loss=0.6124
TRAIN  loss=0.7419 acc=0.417 brier=0.2742 margin=0.119
VALID  loss=0.6124 acc=0.750 brier=0.2101 margin=0.125


In [75]:
# Mean-pooled model (cfg2), trained with a head-only warm start
model2_w = SingleHeadEncoderClassifier(cfg2).to(DEVICE)
model2_w = train_encoder_warm_start(model2_w, Xtr, Mtr, ytr, Xva, Mva, yva)

# Final report on train + val
tr_metrics = evaluate_step(model2_w, *(t.to(DEVICE) for t in (Xtr, Mtr, ytr)))
va_metrics = evaluate_step(model2_w, *(t.to(DEVICE) for t in (Xva, Mva, yva)))
print(*(f"{tag} loss={m[0]:.4f} acc={m[1]:.3f} brier={m[2]:.4f} margin={m[3]:.3f}"
        for tag, m in (("TRAIN ", tr_metrics), ("VALID ", va_metrics))), sep="\n")


Warm start epoch 001 | train_loss=0.6965 | val_loss=0.5255 | val_acc=0.750 | val_brier=0.1693 | val_margin=0.126
Warm start epoch 002 | train_loss=0.6891 | val_loss=0.5287 | val_acc=0.750 | val_brier=0.1706 | val_margin=0.122
Warm start epoch 003 | train_loss=0.6913 | val_loss=0.5315 | val_acc=0.750 | val_brier=0.1719 | val_margin=0.118
early stop at epoch 3 | best_val_loss=0.5255
epoch 006 | train_loss=0.6909 | val_loss=0.5421 | val_acc=0.750 | val_brier=0.1765 | val_margin=0.103
epoch 007 | train_loss=0.6722 | val_loss=0.5571 | val_acc=0.750 | val_brier=0.1834 | val_margin=0.085
epoch 008 | train_loss=0.6532 | val_loss=0.5729 | val_acc=0.750 | val_brier=0.1908 | val_margin=0.066
epoch 009 | train_loss=0.6514 | val_loss=0.5886 | val_acc=1.000 | val_brier=0.1983 | val_margin=0.057
early stop at epoch 9 | best_val_loss=0.5255
TRAIN  loss=0.6900 acc=0.500 brier=0.2485 margin=0.057
VALID  loss=0.5255 acc=0.750 brier=0.1693 margin=0.126


Monitor the Brier score and the margin together. If the margin rises while the Brier score also rises, the model is becoming confidently wrong → lower the head LR a little.

Plateau schedulers (e.g. `ReduceLROnPlateau`) are brittle when the validation set is tiny and noisy.