# Stability & Optimization lab

Making training smooth and reliable:

- Norm choices: LayerNorm vs RMSNorm

- Residual wiring: plain vs scaled (α) vs ReZero gating

- Regularization: weight decay (AdamW), label smoothing

- Training knobs: LRs, (optional) warmup, grad clipping

- Masking sanity (padding/causal) to avoid silent bugs

In [30]:
import math, torch, torch.nn as nn, torch.nn.functional as F
import random
import numpy as np

## Drop-in RMSNorm + gated residuals

In [2]:
# ----- Norms -----
class RMSNorm(nn.Module):
    # LN: learnable scale + bias (per feature), with mean-centering and variance normalization.
    # RMSNorm: learnable scale only (per feature), with RMS normalization (no centering)
    def __init__(self, d, eps=1e-8):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(d))
        self.eps = eps
    def forward(self, x):
        # normalize by root-mean-square over features
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        x_hat = x / rms
        return self.scale * x_hat

def make_norm(kind: str, d_model: int):
    kind = kind.lower()
    if kind in ("ln", "layernorm"):
        return nn.LayerNorm(d_model)
    elif kind in ("rms", "rmsnorm"):
        return RMSNorm(d_model)
    else:
        raise ValueError(f"unknown norm: {kind}")

In [3]:
# ----- Residual wrappers -----
class ResidualAdd(nn.Module):
    """x + F(LN(x)) with optional scaling or gating (residual-zero)."""
    def __init__(self, d_model, fn, norm="ln", drop_p=0.1, mode="plain"):
        super().__init__()
        self.norm = make_norm(norm, d_model)
        self.fn = fn
        self.drop = nn.Dropout(drop_p)
        self.mode = mode
        if mode == "scaled":         # constant residual scale α
            self.alpha = 0.5
        elif mode == "rezero":       # learnable gate g, init 0
            # let training “turn on” the residuals, stabilizes very deep stacks and can reduce the need for aggressive warmup or heavy normalization
            self.g = nn.Parameter(torch.zeros(1))
    def forward(self, x, **kwargs):
        h = self.fn(self.norm(x), **kwargs)
        h = self.drop(h)
        if self.mode == "plain":
            return x + h
        elif self.mode == "scaled":
            return x + self.alpha * h
        elif self.mode == "rezero":
            return x + self.g * h
        else:
            raise ValueError(self.mode)

In [68]:
# Refactor my preLN encoder
class PreLNEncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, p_drop=0.1, ff_multi=4,
                 norm="ln", resid_mode="plain"):
        super().__init__()
        # use nn.MultiheadAttention wrapper
        self.mha_ctor = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, dropout=p_drop, batch_first=True)

        self.attn = ResidualAdd(
            d_model,
            # # out, attn = mha(x, pad_mask=..., causal=...), we only need the out, hence [0].
            fn=lambda x, pad_mask=None, causal=False, attn_mask=None: 
            self.mha_ctor(x, x, x, 
                     key_padding_mask=(~pad_mask) if pad_mask is not None else None, 
                     is_causal=causal, 
                     attn_mask=attn_mask, 
                     need_weights=False)[0], 
            norm=norm, drop_p=p_drop, mode=resid_mode
        )
        self.ff = ResidualAdd(
            d_model,
            fn=nn.Sequential(
                nn.Linear(d_model, ff_multi*d_model),
                nn.GELU(),
                nn.Linear(ff_multi*d_model, d_model),
            ),
            norm=norm, drop_p=p_drop, mode=resid_mode
        )
    def forward(self, x, pad_mask=None, causal=False, attn_mask=None):
        assert pad_mask.dtype == torch.bool
        x = self.attn(x, pad_mask=pad_mask, causal=causal, attn_mask=attn_mask) # [B,T,D]
        x = self.ff(x) # [B,T,D]
        return x

In [16]:
# PE helpers
class SinusoidalPE(nn.Module):
    def __init__(self, d_model: int, max_len: int = 256):
        super().__init__()
        pe = torch.zeros(max_len, d_model)               # [Tmax, D]
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # [Tmax, 1]
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) # [D/2]
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe)                   # not trainable

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: [B,T,D]
        T = x.size(1)
        return x + self.pe[:T].unsqueeze(0) 
    
class ClippedRelPosBias(nn.Module):
    def __init__(self, num_heads, max_rel=128):
        super().__init__()
        self.max_rel = max_rel
        self.table = nn.Parameter(torch.zeros(num_heads, 2*max_rel - 1))  # [-R+1..R-1]

    def forward(self, T, *, device=None):
        device = device or self.table.device
        q = torch.arange(T, device=device)[:, None]
        k = torch.arange(T, device=device)[None, :]
        rel = (k - q).clamp(-self.max_rel+1, self.max_rel-1)  # [T,T]
        idx = rel + (self.max_rel - 1)          # shift to [0..2R-2]
        bias = self.table[:, idx]               # [H,T,T]
        return bias
        # nn.MultiheadAttention's attn_mask accepts [B*H, T, T] (per (batch, head) bias).

In [124]:
# Encoder→Pool→Head wrapper
class MHAEncoderClassifier(nn.Module):
    def __init__(self, vocab_size, d_model, pad_id, cls_id, num_heads, pe= "none", # {"none","sin","relbias"}
                 p_drop=0.1, ff_multi=4, norm="ln", resid_mode="plain", num_layers=2, pool="mean"):
        super().__init__()
        self.num_heads = num_heads
        self.cls_id = cls_id   # int or None
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.pe = pe
        if pe=="sin":
            self.posenc = SinusoidalPE(d_model) # returns x + pe(x)
        elif pe=="relbias":
            self.posenc=ClippedRelPosBias(num_heads) # returns [H,T,T]
        self.block = nn.ModuleList([PreLNEncoderBlock(d_model, num_heads, p_drop, ff_multi, norm, resid_mode) for _ in range(num_layers)])
        self.final_ln = make_norm(norm, d_model)
        self.head = nn.Linear(d_model, 1)
        self.pool = pool

    def masked_mean(self, x, mask):               # x:[B,T,D], mask:[B,T]
        m = mask.float().unsqueeze(-1)
        return (x*m).sum(1) / m.sum(1).clamp(min=1.0)

    def forward(self, ids: torch.Tensor, mask: torch.Tensor):
        attn_mask = None # default
        if self.pool == "cls" and self.cls_id is not None:
            # prepend CLS id and mask=True (non-pad)
            cls_col = torch.full((B, 1), self.cls_id, dtype=ids.dtype, device=ids.device)
            ids = torch.cat([cls_col, ids], dim=1)                             # [B,T+1]
            mask = torch.cat([torch.ones(B,1, dtype=torch.bool, device=mask.device), mask], dim=1)  # [B,T+1]

        x = self.embed(ids)                       # [B,T,D]
        B,T,_ = x.shape
        # pe
        if self.pe=="sin":
            x = self.posenc(x)
        elif self.pe=="relbias":
            bias = self.posenc(T, device=x.device)
            # [H, T, T] -> [1, H, T, T] -> [B, H, T, T] -> [B*H, T, T]
            attn_mask = bias.unsqueeze(0).expand(B, -1, -1, -1).reshape(B*self.num_heads, T, T)
        for b in self.block:
            x = b(x, pad_mask=mask, attn_mask=attn_mask, causal=False)             # [B,T,D]
        x = self.final_ln(x)

        if self.pool == "mean":
            sent = self.masked_mean(x, mask)      # [B,D]
        else:
            sent = x[:, 0, :]                     # first-token (CLS-style)
        logits = self.head(sent).view(-1)      # [B]
        return logits

In [None]:
## Binary label smoothing helper (for BCE)
# Use it once per batch before computing BCEWithLogitsLoss.
def smooth_labels(y, eps=0.05):
    # y in {0,1} -> {eps, 1-eps}
    return y*(1-eps) + (1-y)*eps

In [6]:
# Optimizer & (optional) warmup
# Param groups that exclude biases/norms from weight decay
def adamw_groups(model, lr_enc, lr_head, wd_enc=1e-4, wd_head=0.0):
    enc_decay, enc_no_decay, head_decay, head_no_decay = [], [], [], []

    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        is_head = n.startswith("head.")
        is_bias = n.endswith(".bias")
        is_norm = ("norm" in n.lower())  # catches LayerNorm/RMSNorm params
        # treat other scale-like knobs as no_decay too:
        is_scale_like = n.endswith(".scale") or n.endswith(".g") # scale or gating
        is_pos_emb = ("pos" in n.lower() and "emb" in n.lower()) # positional embeddings

        no_decay = is_bias or is_norm or is_scale_like or is_pos_emb

        if is_head:
            (head_no_decay if no_decay else head_decay).append(p)
        else:
            (enc_no_decay  if no_decay else enc_decay ).append(p)

    groups = []
    if enc_decay:
        groups.append({"params": enc_decay,     "lr": lr_enc,  "weight_decay": wd_enc})
    if enc_no_decay:
        groups.append({"params": enc_no_decay,  "lr": lr_enc,  "weight_decay": 0.0})
    if head_decay:
        groups.append({"params": head_decay,    "lr": lr_head, "weight_decay": wd_head})
    if head_no_decay:
        groups.append({"params": head_no_decay, "lr": lr_head, "weight_decay": 0.0})
    return groups


In [8]:
# Warmup (linear) → constant:
def linear_warmup(step, warmup_steps, base_lr):
    if step < warmup_steps:
        return base_lr * (step+1)/warmup_steps
    return base_lr

# inside training loop:
# step += 1
# for g in opt.param_groups: g["lr"] = linear_warmup(step, warmup_steps=200, base_lr=g["lr"])

In [60]:
# Define a ctor that returns a *fresh* model each fold
def make_model(vocab_size, d_model, pad_id, cls_id, num_heads, pe, p_drop, ff_multi, norm, resid_mode, num_layers, pool):
    return MHAEncoderClassifier(vocab_size=vocab_size, d_model=d_model, pad_id=pad_id, cls_id=cls_id,
                                num_heads=num_heads, pe=pe, p_drop=p_drop, ff_multi=ff_multi,
                                norm=norm, resid_mode=resid_mode, num_layers=num_layers, pool=pool)

In [66]:
def _call_model(model, Xb, Mb=None):
    """Call model with (X, mask) if mask given; unwrap (logits, extra) if returned."""
    if Mb is None:
        out = model(Xb)
    else:
        out = model(Xb, Mb)
    if isinstance(out, (tuple, list)): # out, attn
        logits = out[0]
    else:
        logits = out
    return logits

In [115]:
def _eval(X, y, M, model, criterion):
    model.eval()
    with torch.no_grad():
        logits = _call_model(model, X, M).squeeze(-1)
        y_flt  = y.float()
        loss = criterion(logits, y_flt)
        prob = torch.sigmoid(loss)
        preds = (prob >= 0.5).float()
        # acc
        acc = (preds==y).float().mean().item()
        # Brier (binary): mean (p - y)^2
        brier = torch.mean(torch.pow(preds-y_flt,2)).item()
        # avg margin around 0.5 (simple): mean |p - 0.5|
        margin = torch.mean(torch.abs(preds-0.5)).item()

    return {"loss": loss.item(), "acc": acc, "brier": brier, "margin": margin}

In [12]:
from sklearn.model_selection import StratifiedKFold

In [126]:
def kfold_train(X, y, vocab_size, pad_id, mask=None, cls_id = None, *,
    model_ctor=make_model,                # callable: () -> fresh model
    device=None,
    n_splits=5,
    epochs=40,
    batch_size=None,           # None = full-batch
    lr_enc=3e-3,
    lr_head=None,              # if None, use lr_enc for all params
    wd_enc=1e-4,
    wd_head=0.0,
    patience=4,
    min_delta=1e-4,
    clip=1.0,
    d_model=32, 
    num_heads=4, 
    pe= "none",
    p_drop=0.1, 
    ff_multi=4,
    norm="ln", 
    resid_mode="plain", 
    num_layers=2,
    pool="mean",
    use_adamw=True,
    groups_fn=adamw_groups,            # optional: callable(model, lr_enc, lr_head, wd_enc, wd_head) -> param_groups
    smooth_eps=0,
    warmup_steps=0,
    seed=42,):

    assert X.shape[0] == y.shape[0], "X and y must align"
    if mask is not None:
        assert mask.shape[0] == y.shape[0], "mask and y must align"
    device = device or (torch.device("mps") if torch.backends.mps.is_available()
                        else torch.device("cpu"))
    
    # move once to device
    X = X.to(device)
    y_long = y.to(device).long()   # keep an integer copy for stratify + metrics
    y_flt  = y_long.float()        # float copy for BCE
    M = mask.to(device) if mask is not None else None

    # Global seeding (once)
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    all_metrics = []

    # scikit-learn needs a NumPy array on CPU, the 1 makes fold numbers start at 1 instead of 0.
    y_np = y_long.detach().cpu().numpy()
    for fold, (tr_idx, va_idx) in enumerate(skf.split(np.zeros(len(y_np)), y_np), 1):
        tr_idx = torch.tensor(tr_idx, device=device)
        va_idx = torch.tensor(va_idx, device=device)
        # model, criterion, optimizer
        model = model_ctor(vocab_size, d_model, pad_id, cls_id, num_heads, pe, p_drop, ff_multi, norm, resid_mode, num_layers, pool).to(device)
        criterion = nn.BCEWithLogitsLoss()
        param_groups = groups_fn(model, lr_enc=lr_enc, lr_head=lr_head or lr_enc, wd_enc=wd_enc, wd_head=wd_head)
        Optim = torch.optim.AdamW if use_adamw else torch.optim.Adam
        opt = Optim(param_groups, betas=(0.9,0.98), eps=1e-9)

        for g in opt.param_groups:
            g.setdefault("base_lr", g["lr"])
        global_step = 0
        
        best = math.inf
        best_sd = None
        wait = 0

        for epoch in range(1, epochs + 1):
            model.train()
            # build batch
            if batch_size is None or batch_size >= len(tr_idx):
                batches = [tr_idx]
            else:
                perm = tr_idx[torch.randperm(len(tr_idx), device=device)]
                batches = [perm[i:i+batch_size] for i in range(0, len(perm), batch_size)]
            
            total_loss = 0.0
            for bi in batches:
                opt.zero_grad(set_to_none=True)
                # compute logit
                logit = _call_model(model,X[bi], M[bi] if M is not None else None)
                targets = y_flt[bi]
                # smooth_labels(y)
                if smooth_eps>0:
                    targets = smooth_labels(targets, eps=smooth_eps)
                # compute loss
                loss = criterion(logit, targets)
                loss.backward()
                total_loss += loss.item() * len(bi)
                # prevents exploding gradients by rescaling all gradients so their total norm ≤ clip
                if clip:
                    nn.utils.clip_grad_norm_(model.parameters(), clip)

                global_step += 1
                for g in opt.param_groups:
                    g["lr"] = linear_warmup(global_step, warmup_steps=warmup_steps, base_lr=g["base_lr"])
                opt.step()
            train_loss = total_loss / len(tr_idx)

            # evaluate on val
            val_dict = _eval(X[va_idx], y_long[va_idx], M[va_idx] if M is not None else None, model, criterion)
            if epoch % 5 ==0:
                print(f"Fold {fold}, epoch {epoch}, {[k+': '+str(v) for k,v in val_dict.items()]}")
            # early stopping on val
            if best-val_dict['loss']<min_delta:
                wait +=1
            else:
                wait = 0
            if val_dict['loss']<best:
                best = val_dict['loss']
                best_sd = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            if wait>=patience:
                break

        # restore best
        if best_sd is not None:
            model.load_state_dict(best_sd)
        # log final metrics on this fold with best model
        fold_metrics = _eval(X[va_idx], y_long[va_idx], M[va_idx] if M is not None else None, model, criterion)
        all_metrics.append(fold_metrics)

    # aggregate all folds
    keys = all_metrics[0].keys()
    mean = {k: float(np.mean([m[k] for m in all_metrics])) for k in keys}
    std  = {k: float(np.std( [m[k] for m in all_metrics])) for k in keys}
    return {"folds": all_metrics, "mean": mean, "std": std}

| Variant | Norm      | Residual     | Loss                   | Optim             |
| :-----: | --------- | ------------ | ---------------------- | ----------------- |
|    A    | LayerNorm | plain        | BCE                    | Adam              |
|    B    | RMSNorm   | plain        | BCE                    | AdamW             |
|    C    | RMSNorm   | scaled α=0.5 | BCE + smoothing ε=0.05 | AdamW             |
|    D    | RMSNorm   | ReZero       | BCE                    | AdamW (no warmup) |


In [19]:
greeting_hard = [
    "good morning everyone",
    "hello there my friend",
    "hey buddy how are you",
    "good evening folks",
    "salutations from the sushi bar",          # greeting + food word
    "pizza party greetings to all",            # greeting + food word
    "hi from the ramen shop",                  # greeting + food word
    "hello and welcome to brunch",             # greeting + food word
]
food_hard = [
    "i love pizza",
    "pasta is tasty tonight",
    "fresh salad with apple",
    "i like sushi a lot",
    "good sandwich this morning",              # food + greeting words
    "ramen is great hello world",              # food + greeting word
    "eating an apple for breakfast",
    "not a fan of pizza anymore",              # negation
]
HARD_SUP = greeting_hard + food_hard
HARD_LABELS = [0]*len(greeting_hard) + [1]*len(food_hard)

In [22]:
PAD_TOKEN = "<pad>"
UNK_TOKEN = "<unk>"
CLS_TOKEN = "<cls>"

def build_vocab(sentences, min_freq=1):
    freq = {}
    for s in sentences:
        for w in s.split():
            freq[w] = freq.get(w, 0) + 1
    words = [w for w, c in sorted(freq.items(), key=lambda x: (-x[1], x[0])) if c >= min_freq]
    vocab = {PAD_TOKEN: 0, UNK_TOKEN: 1}
    for w in words:
        vocab[w] = len(vocab)
    return vocab

def build_vocab_with_specials(sentences):
    base = build_vocab(sentences)
    # reserve 0/1/2 = <pad>/<unk>/<cls>
    vocab = {PAD_TOKEN: 0, UNK_TOKEN: 1, CLS_TOKEN: 2}
    next_id = 3
    for w in base:
        if w in (PAD_TOKEN, UNK_TOKEN):  # skip if already special
            continue
        vocab[w] = next_id
        next_id += 1
    return vocab

# Build vocab with <cls>
VOCAB = build_vocab_with_specials(HARD_SUP)
pad_id, cls_id = VOCAB[PAD_TOKEN], VOCAB[CLS_TOKEN]

In [25]:
def tokenize(s): return s.strip().split()

def numericalize(tokens, vocab):
    return [vocab.get(t, vocab[UNK_TOKEN]) for t in tokens]

def pad_batch(batch_ids, pad_id=0):
    T = max(len(x) for x in batch_ids)
    padded, mask = [], []
    for ids in batch_ids:
        pad_len = T - len(ids)
        padded.append(ids + [pad_id]*pad_len)
        mask.append([1]*len(ids) + [0]*pad_len)
    return torch.tensor(padded, dtype=torch.long), torch.tensor(mask, dtype=torch.bool)

# dataset → tensors 
def make_tensors(sentences, labels, vocab, pad_id):
    tok = [tokenize(s) for s in sentences]
    ids = [numericalize(t, vocab) for t in tok]
    X, M = pad_batch(ids, pad_id)
    y = torch.tensor(labels, dtype=torch.float)
    return X, M, y

In [26]:
X, M, y = make_tensors(HARD_SUP, HARD_LABELS, VOCAB, pad_id)

In [45]:
len(VOCAB)

58

In [42]:
X.shape, M.shape, y.shape

(torch.Size([16, 6]), torch.Size([16, 6]), torch.Size([16]))

In [49]:
embed = nn.Embedding(len(VOCAB), 32, padding_idx=pad_id)
embed(X).shape

torch.Size([16, 6, 32])

In [117]:
v1 = kfold_train(X, y, vocab_size=len(VOCAB), pad_id=pad_id, cls_id=cls_id, mask=M,
    model_ctor=make_model,                # callable: () -> fresh model
    device=None,
    n_splits=5,
    epochs=40,
    batch_size=None,           # None = full-batch
    lr_enc=3e-3,
    lr_head=None,              # if None, use lr_enc for all params
    wd_enc=1e-4,
    wd_head=0.0,
    patience=4,
    min_delta=1e-4,
    clip=1.0,
    d_model=32, 
    num_heads=4, 
    pe= "none",
    p_drop=0.1, 
    ff_multi=4,
    norm="ln", 
    resid_mode="plain", 
    num_layers=2,
    pool="mean",
    use_adamw=False,
    groups_fn=adamw_groups,            # optional: callable(model, lr_enc, lr_head, wd_enc, wd_head) -> param_groups
    smooth_eps=0,
    warmup_steps=1,
    seed=42,)
v1

Fold 1, epoch 5, ['loss: 0.6099370718002319', 'acc: 0.5', 'brier: 0.5', 'margin: 0.5']
Fold 2, epoch 5, ['loss: 0.3418367803096771', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 10, ['loss: 0.1882925182580948', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 15, ['loss: 0.14169655740261078', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 20, ['loss: 0.1305415779352188', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 25, ['loss: 0.1261541098356247', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 30, ['loss: 0.12303049117326736', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 35, ['loss: 0.12025728076696396', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 40, ['loss: 0.11739969998598099', 'acc: 0.3333333432674408', 'brier: 0.

{'folds': [{'loss': 0.5880377292633057,
   'acc': 0.5,
   'brier': 0.5,
   'margin': 0.5},
  {'loss': 0.11739969998598099,
   'acc': 0.3333333432674408,
   'brier': 0.6666666865348816,
   'margin': 0.5},
  {'loss': 0.6298583149909973,
   'acc': 0.3333333432674408,
   'brier': 0.6666666865348816,
   'margin': 0.5},
  {'loss': 0.6424539685249329,
   'acc': 0.6666666865348816,
   'brier': 0.3333333432674408,
   'margin': 0.5},
  {'loss': 0.45326921343803406,
   'acc': 0.6666666865348816,
   'brier': 0.3333333432674408,
   'margin': 0.5}],
 'mean': {'loss': 0.48620378524065017,
  'acc': 0.5000000119209289,
  'brier': 0.5000000119209289,
  'margin': 0.5},
 'std': {'loss': 0.19622539854449506,
  'acc': 0.14907120294265402,
  'brier': 0.149071202942654,
  'margin': 0.0}}

In [118]:
v2 = kfold_train(X, y, vocab_size=len(VOCAB), pad_id=pad_id, cls_id=cls_id, mask=M,
    model_ctor=make_model,                # callable: () -> fresh model
    device=None,
    n_splits=5,
    epochs=40,
    batch_size=None,           # None = full-batch
    lr_enc=3e-3,
    lr_head=None,              # if None, use lr_enc for all params
    wd_enc=1e-4,
    wd_head=0.0,
    patience=4,
    min_delta=1e-4,
    clip=1.0,
    d_model=32, 
    num_heads=4, 
    pe= "none",
    p_drop=0.1, 
    ff_multi=4,
    norm="rms", 
    resid_mode="plain", 
    num_layers=2,
    pool="mean",
    use_adamw=True,
    groups_fn=adamw_groups,            # optional: callable(model, lr_enc, lr_head, wd_enc, wd_head) -> param_groups
    smooth_eps=1,
    warmup_steps=0,
    seed=42,)
v2

Fold 1, epoch 5, ['loss: 0.7572334408760071', 'acc: 0.5', 'brier: 0.5', 'margin: 0.5']
Fold 2, epoch 5, ['loss: 0.8552422523498535', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 3, epoch 5, ['loss: 0.6297099590301514', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 4, epoch 5, ['loss: 0.5862617492675781', 'acc: 0.6666666865348816', 'brier: 0.3333333432674408', 'margin: 0.5']
Fold 5, epoch 5, ['loss: 1.0515602827072144', 'acc: 0.6666666865348816', 'brier: 0.3333333432674408', 'margin: 0.5']


{'folds': [{'loss': 0.6719502210617065,
   'acc': 0.5,
   'brier': 0.5,
   'margin': 0.5},
  {'loss': 0.5542317032814026,
   'acc': 0.3333333432674408,
   'brier': 0.6666666865348816,
   'margin': 0.5},
  {'loss': 0.6280735731124878,
   'acc': 0.3333333432674408,
   'brier': 0.6666666865348816,
   'margin': 0.5},
  {'loss': 0.5338841676712036,
   'acc': 0.6666666865348816,
   'brier': 0.3333333432674408,
   'margin': 0.5},
  {'loss': 0.8048624992370605,
   'acc': 0.6666666865348816,
   'brier': 0.3333333432674408,
   'margin': 0.5}],
 'mean': {'loss': 0.6386004328727722,
  'acc': 0.5000000119209289,
  'brier': 0.5000000119209289,
  'margin': 0.5},
 'std': {'loss': 0.09690167861395321,
  'acc': 0.14907120294265402,
  'brier': 0.149071202942654,
  'margin': 0.0}}

In [119]:
v3 = kfold_train(X, y, vocab_size=len(VOCAB), pad_id=pad_id, cls_id=cls_id, mask=M,
    model_ctor=make_model,                # callable: () -> fresh model
    device=None,
    n_splits=5,
    epochs=40,
    batch_size=None,           # None = full-batch
    lr_enc=3e-3,
    lr_head=None,              # if None, use lr_enc for all params
    wd_enc=1e-4,
    wd_head=0.0,
    patience=4,
    min_delta=1e-4,
    clip=1.0,
    d_model=32, 
    num_heads=4, 
    pe= "none",
    p_drop=0.1, 
    ff_multi=4,
    norm="rms", 
    resid_mode="scaled", 
    num_layers=2,
    pool="mean",
    use_adamw=True,
    groups_fn=adamw_groups,            # optional: callable(model, lr_enc, lr_head, wd_enc, wd_head) -> param_groups
    smooth_eps=0.05,
    warmup_steps=1,
    seed=42,)
v3

Fold 1, epoch 5, ['loss: 0.6311007142066956', 'acc: 0.5', 'brier: 0.5', 'margin: 0.5']
Fold 2, epoch 5, ['loss: 0.4173022210597992', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 10, ['loss: 0.28249046206474304', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 15, ['loss: 0.16031070053577423', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 20, ['loss: 0.12056028842926025', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 25, ['loss: 0.13040964305400848', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 3, epoch 5, ['loss: 0.6759376525878906', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 4, epoch 5, ['loss: 0.7648405432701111', 'acc: 0.6666666865348816', 'brier: 0.3333333432674408', 'margin: 0.5']
Fold 5, epoch 5, ['loss: 0.7301142811775208', 'acc: 0.6666666865348816', 'brier: 0.333

{'folds': [{'loss': 0.6285253763198853,
   'acc': 0.5,
   'brier': 0.5,
   'margin': 0.5},
  {'loss': 0.11987989395856857,
   'acc': 0.3333333432674408,
   'brier': 0.6666666865348816,
   'margin': 0.5},
  {'loss': 0.6319447159767151,
   'acc': 0.3333333432674408,
   'brier': 0.6666666865348816,
   'margin': 0.5},
  {'loss': 0.6227778792381287,
   'acc': 0.6666666865348816,
   'brier': 0.3333333432674408,
   'margin': 0.5},
  {'loss': 0.4085395634174347,
   'acc': 0.6666666865348816,
   'brier': 0.3333333432674408,
   'margin': 0.5}],
 'mean': {'loss': 0.48233348578214646,
  'acc': 0.5000000119209289,
  'brier': 0.5000000119209289,
  'margin': 0.5},
 'std': {'loss': 0.20014912736074122,
  'acc': 0.14907120294265402,
  'brier': 0.149071202942654,
  'margin': 0.0}}

In [120]:
v4 = kfold_train(X, y, vocab_size=len(VOCAB), pad_id=pad_id, cls_id=cls_id, mask=M,
    model_ctor=make_model,                # callable: () -> fresh model
    device=None,
    n_splits=5,
    epochs=40,
    batch_size=None,           # None = full-batch
    lr_enc=3e-3,
    lr_head=None,              # if None, use lr_enc for all params
    wd_enc=1e-4,
    wd_head=0.0,
    patience=4,
    min_delta=1e-4,
    clip=1.0,
    d_model=32, 
    num_heads=4, 
    pe= "none",
    p_drop=0.1, 
    ff_multi=4,
    norm="rms", 
    resid_mode="rezero", 
    num_layers=2,
    pool="mean",
    use_adamw=True,
    groups_fn=adamw_groups,            # optional: callable(model, lr_enc, lr_head, wd_enc, wd_head) -> param_groups
    smooth_eps=0,
    warmup_steps=0,
    seed=42,)
v4

Fold 1, epoch 5, ['loss: 0.6179345846176147', 'acc: 0.5', 'brier: 0.5', 'margin: 0.5']
Fold 1, epoch 10, ['loss: 0.6171682476997375', 'acc: 0.5', 'brier: 0.5', 'margin: 0.5']
Fold 2, epoch 5, ['loss: 0.512432336807251', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 10, ['loss: 0.482908695936203', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 15, ['loss: 0.4404577314853668', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 20, ['loss: 0.3761892318725586', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 25, ['loss: 0.2831224203109741', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 30, ['loss: 0.1605003923177719', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 35, ['loss: 0.06715098023414612', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
F

{'folds': [{'loss': 0.6169049739837646,
   'acc': 0.5,
   'brier': 0.5,
   'margin': 0.5},
  {'loss': 0.030630091205239296,
   'acc': 0.3333333432674408,
   'brier': 0.6666666865348816,
   'margin': 0.5},
  {'loss': 0.6387419700622559,
   'acc': 0.3333333432674408,
   'brier': 0.6666666865348816,
   'margin': 0.5},
  {'loss': 0.6206154227256775,
   'acc': 0.6666666865348816,
   'brier': 0.3333333432674408,
   'margin': 0.5},
  {'loss': 0.7470483183860779,
   'acc': 0.6666666865348816,
   'brier': 0.3333333432674408,
   'margin': 0.5}],
 'mean': {'loss': 0.530788155272603,
  'acc': 0.5000000119209289,
  'brier': 0.5000000119209289,
  'margin': 0.5},
 'std': {'loss': 0.25458421701752676,
  'acc': 0.14907120294265402,
  'brier': 0.149071202942654,
  'margin': 0.0}}

In [121]:
v1b = kfold_train(X, y, vocab_size=len(VOCAB), pad_id=pad_id, cls_id=cls_id, mask=M,
    model_ctor=make_model,                # callable: () -> fresh model
    device=None,
    n_splits=5,
    epochs=40,
    batch_size=None,           # None = full-batch
    lr_enc=3e-3,
    lr_head=None,              # if None, use lr_enc for all params
    wd_enc=0.0,
    wd_head=0.0,
    patience=4,
    min_delta=1e-4,
    clip=1.0,
    d_model=32, 
    num_heads=4, 
    pe= "none",
    p_drop=0.1, 
    ff_multi=4,
    norm="ln", 
    resid_mode="plain", 
    num_layers=2,
    pool="mean",
    use_adamw=True,
    groups_fn=adamw_groups,            # optional: callable(model, lr_enc, lr_head, wd_enc, wd_head) -> param_groups
    smooth_eps=0,
    warmup_steps=1,
    seed=42,)
v1b

Fold 1, epoch 5, ['loss: 0.6081080436706543', 'acc: 0.5', 'brier: 0.5', 'margin: 0.5']
Fold 2, epoch 5, ['loss: 0.3438802659511566', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 10, ['loss: 0.1896723508834839', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 15, ['loss: 0.14223147928714752', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 20, ['loss: 0.13118593394756317', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 25, ['loss: 0.12694151699543', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 30, ['loss: 0.12393798679113388', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 35, ['loss: 0.12127849459648132', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 40, ['loss: 0.11863742023706436', 'acc: 0.3333333432674408', 'brier: 0.6

{'folds': [{'loss': 0.587480366230011,
   'acc': 0.5,
   'brier': 0.5,
   'margin': 0.5},
  {'loss': 0.11863742023706436,
   'acc': 0.3333333432674408,
   'brier': 0.6666666865348816,
   'margin': 0.5},
  {'loss': 0.6298519968986511,
   'acc': 0.3333333432674408,
   'brier': 0.6666666865348816,
   'margin': 0.5},
  {'loss': 0.6422449946403503,
   'acc': 0.6666666865348816,
   'brier': 0.3333333432674408,
   'margin': 0.5},
  {'loss': 0.451951265335083,
   'acc': 0.6666666865348816,
   'brier': 0.3333333432674408,
   'margin': 0.5}],
 'mean': {'loss': 0.48603320866823196,
  'acc': 0.5000000119209289,
  'brier': 0.5000000119209289,
  'margin': 0.5},
 'std': {'loss': 0.19571343128173535,
  'acc': 0.14907120294265402,
  'brier': 0.149071202942654,
  'margin': 0.0}}

In [128]:
v1c = kfold_train(X, y, vocab_size=len(VOCAB), pad_id=pad_id, cls_id=cls_id, mask=M,
    model_ctor=make_model,                # callable: () -> fresh model
    device=None,
    n_splits=5,
    epochs=40,
    batch_size=None,           # None = full-batch
    lr_enc=3e-3,
    lr_head=None,              # if None, use lr_enc for all params
    wd_enc=1e-4,
    wd_head=0.0,
    patience=4,
    min_delta=1e-4,
    clip=1.0,
    d_model=32, 
    num_heads=4, 
    pe= "none",
    p_drop=0.1, 
    ff_multi=4,
    norm="ln", 
    resid_mode="plain", 
    num_layers=2,
    pool="mean",
    use_adamw=True,
    groups_fn=adamw_groups,            # optional: callable(model, lr_enc, lr_head, wd_enc, wd_head) -> param_groups
    smooth_eps=0,
    warmup_steps=1,
    seed=42,)
v1c

Fold 1, epoch 5, ['loss: 0.608108401298523', 'acc: 0.5', 'brier: 0.5', 'margin: 0.5']
Fold 2, epoch 5, ['loss: 0.3438807427883148', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 10, ['loss: 0.18967290222644806', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 15, ['loss: 0.14223168790340424', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 20, ['loss: 0.13118627667427063', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 25, ['loss: 0.12694230675697327', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 30, ['loss: 0.12393901497125626', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 35, ['loss: 0.12127966433763504', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 40, ['loss: 0.11863859742879868', 'acc: 0.3333333432674408', 'brier: 

{'folds': [{'loss': 0.5874805450439453,
   'acc': 0.5,
   'brier': 0.5,
   'margin': 0.5},
  {'loss': 0.11863859742879868,
   'acc': 0.3333333432674408,
   'brier': 0.6666666865348816,
   'margin': 0.5},
  {'loss': 0.6298520565032959,
   'acc': 0.3333333432674408,
   'brier': 0.6666666865348816,
   'margin': 0.5},
  {'loss': 0.6422448754310608,
   'acc': 0.6666666865348816,
   'brier': 0.3333333432674408,
   'margin': 0.5},
  {'loss': 0.45195162296295166,
   'acc': 0.6666666865348816,
   'brier': 0.3333333432674408,
   'margin': 0.5}],
 'mean': {'loss': 0.48603353947401046,
  'acc': 0.5000000119209289,
  'brier': 0.5000000119209289,
  'margin': 0.5},
 'std': {'loss': 0.19571298512595783,
  'acc': 0.14907120294265402,
  'brier': 0.149071202942654,
  'margin': 0.0}}

In [127]:
# mini batch
v1d = kfold_train(X, y, vocab_size=len(VOCAB), pad_id=pad_id, cls_id=cls_id, mask=M,
    model_ctor=make_model,                # callable: () -> fresh model
    device=None,
    n_splits=5,
    epochs=40,
    batch_size=2,           # None = full-batch
    lr_enc=3e-3,
    lr_head=None,              # if None, use lr_enc for all params
    wd_enc=1e-4,
    wd_head=0.0,
    patience=4,
    min_delta=1e-4,
    clip=1.0,
    d_model=32, 
    num_heads=4, 
    pe= "none",
    p_drop=0.1, 
    ff_multi=4,
    norm="ln", 
    resid_mode="plain", 
    num_layers=2,
    pool="mean",
    use_adamw=True,
    groups_fn=adamw_groups,            # optional: callable(model, lr_enc, lr_head, wd_enc, wd_head) -> param_groups
    smooth_eps=0,
    warmup_steps=1,
    seed=42,)
v1d

Fold 1, epoch 5, ['loss: 1.2776756286621094', 'acc: 0.5', 'brier: 0.5', 'margin: 0.5']
Fold 2, epoch 5, ['loss: 0.04318267107009888', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 10, ['loss: 0.041952747851610184', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 15, ['loss: 0.029705777764320374', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 20, ['loss: 0.020358499139547348', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 25, ['loss: 0.014227413572371006', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 30, ['loss: 0.009127969853579998', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 35, ['loss: 0.005400742869824171', 'acc: 0.3333333432674408', 'brier: 0.6666666865348816', 'margin: 0.5']
Fold 2, epoch 40, ['loss: 0.0035138383973389864', 'acc: 0.3333333432674408'

{'folds': [{'loss': 0.6329571008682251,
   'acc': 0.5,
   'brier': 0.5,
   'margin': 0.5},
  {'loss': 0.0035138383973389864,
   'acc': 0.3333333432674408,
   'brier': 0.6666666865348816,
   'margin': 0.5},
  {'loss': 0.6062131524085999,
   'acc': 0.3333333432674408,
   'brier': 0.6666666865348816,
   'margin': 0.5},
  {'loss': 1.0578625202178955,
   'acc': 0.6666666865348816,
   'brier': 0.3333333432674408,
   'margin': 0.5},
  {'loss': 0.02071155607700348,
   'acc': 0.6666666865348816,
   'brier': 0.3333333432674408,
   'margin': 0.5}],
 'mean': {'loss': 0.4642516335938126,
  'acc': 0.5000000119209289,
  'brier': 0.5000000119209289,
  'margin': 0.5},
 'std': {'loss': 0.402491144875818,
  'acc': 0.14907120294265402,
  'brier': 0.149071202942654,
  'margin': 0.0}}