# Drop-in SDPA MHA with optional RoPE + RelBias

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import math, torch, torch.nn as nn, torch.nn.functional as F

In [3]:
# ----- (A) tiny RoPE helper -----
def apply_rope(q, k, base=10000.0):
    """Apply rotary position embeddings (RoPE) to queries and keys.

    The last dimension is split into two halves ``(x1, x2)`` and each pair
    ``(x1[i], x2[i])`` is rotated by an angle ``pos * base**(-i / half)``.

    Args:
        q: Query tensor; the last dim is the head dim and the second-to-last
            dim is the sequence length (used as ``[B, H, T, Dh]`` in this file).
        k: Key tensor with the same sequence length and head dim as ``q``
            (the positions are derived from ``q`` only).
        base: Frequency base of the geometric progression of wavelengths.

    Returns:
        A ``(q_rot, k_rot)`` tuple of new tensors with the same shapes as the
        inputs. The inputs are not modified.

    Raises:
        AssertionError: If the head dim is odd.
    """
    head_dim = q.size(-1)
    assert head_dim % 2 == 0, "RoPE needs even head dim"
    half = head_dim // 2
    q1, q2 = q[..., :half], q[..., half:]
    k1, k2 = k[..., :half], k[..., half:]
    T = q.size(-2)
    device = q.device

    # Build the angle table in full precision: in fp16/bf16 integer positions
    # are rounded (bf16 cannot even represent 257), which silently corrupts the
    # rotation for longer sequences. Only the final sin/cos are cast down.
    table_dtype = torch.float64 if q.dtype == torch.float64 else torch.float32
    pos = torch.arange(T, device=device, dtype=table_dtype)[:, None]            # [T,1]
    exponents = torch.arange(0, half, device=device, dtype=table_dtype) / half
    inv_freqs = torch.exp(-math.log(base) * exponents)[None, :]                 # [1,half]
    ang = pos * inv_freqs                                                       # [T,half]

    # Cast to the activation dtype and add leading dims to broadcast over [B,H].
    sin = torch.sin(ang).to(q.dtype)[None, None, :, :]
    cos = torch.cos(ang).to(q.dtype)[None, None, :, :]

    # rotate (x1, x2) -> (x1*cos - x2*sin, x2*cos + x1*sin)
    def rot(x1, x2):
        return x1 * cos - x2 * sin, x2 * cos + x1 * sin

    q1r, q2r = rot(q1, q2)
    k1r, k2r = rot(k1, k2)
    q = torch.cat([q1r, q2r], dim=-1)
    k = torch.cat([k1r, k2r], dim=-1)
    return q, k

In [4]:
# ----- (B) clipped T5-style per-head relative bias -----
class ClippedRelPosBias(nn.Module):
    """Learned per-head additive bias indexed by clipped relative position.

    The relative offset ``key_index - query_index`` is clamped to
    ``[-(max_rel - 1), max_rel - 1]`` and used to look up one scalar per head
    from a zero-initialised table of shape ``[num_heads, 2 * max_rel - 1]``.
    """

    def __init__(self, num_heads, max_rel=128):
        super().__init__()
        self.max_rel = max_rel
        n_buckets = 2 * max_rel - 1
        self.table = nn.Parameter(torch.zeros(num_heads, n_buckets))  # [H, 2R-1]

    def forward(self, T, device=None):
        """Return the bias tensor of shape ``[H, T, T]`` for a length-``T`` sequence.

        ``device`` selects where the index grid is built; it falls back to the
        device of the bias table when not given.
        """
        if not device:
            device = self.table.device
        limit = self.max_rel - 1
        positions = torch.arange(T, device=device)
        # offsets[i, j] = j - i  (rows are queries, columns are keys)
        offsets = positions.unsqueeze(0) - positions.unsqueeze(1)     # [T, T]
        buckets = torch.clamp(offsets, -limit, limit) + limit        # values in [0, 2R-2]
        return self.table[:, buckets]                                 # [H, T, T]


In [58]:
def large_neg(dtype):
    """Return a large finite negative value to use as an additive mask sentinel.

    Half-precision dtypes (float16 / bfloat16) get -1e4; every other dtype
    gets -1e9.
    """
    low_precision = (torch.float16, torch.bfloat16)
    if dtype in low_precision:
        return -1e4
    return -1e9

With nn.MultiheadAttention you get two knobs:

- key_padding_mask (bool, True = PAD/disallow)

- attn_mask (bool or float, broadcastable to [T,T], True = disallow for bool, added to logits for float) <br>

…but F.scaled_dot_product_attention (SDPA) has one attn_mask and a different convention:

- attn_mask: bool (True = ALLOW) or float (additive to logits) <br>

- There is no separate key_padding_mask argument.

So when you switch to SDPA, you must fold everything (padding, causal, relative bias) into that single attn_mask. Also, on MPS or in low precision, feeding -inf can produce NaNs, so prefer a large finite negative value.

In [78]:
def build_sdpa_mask(
    pad_mask: torch.Tensor | None,  # [B,T] bool, True = real token
    rel_bias: torch.Tensor | None,  # [H,T,T] float or None
    *, B: int, T: int, H: int, dtype: torch.dtype, device: torch.device,
) -> torch.Tensor | None:
    """
    Returns a FLOAT additive mask for SDPA, shape [B,H,T,T], or None.
    Combines key padding + optional relative bias. (Causal can be handled by SDPA's is_causal=True.)
    """
    attn_mask = None

    if pad_mask is not None and pad_mask.dtype == torch.bool:
        # key-only mask → 0 for allowed keys, large negative for PAD keys
        allow_keys = pad_mask[:, None, None, :]  # [B,1,1,T], True = allowed
        neg = large_neg(dtype)
        # convert boolean allow-mask to float additive: disallowed → a large finite negative
        # Can't use -inf, the MPS backend turn it into nan.
        float_mask = (~allow_keys).to(dtype=dtype) * neg

    if rel_bias is not None:
        # rel_bias should be [H,T,T]; broadcast to batch and match dtype
        bias = rel_bias.to(dtype=dtype, device=device).unsqueeze(0).expand(B, -1, -1, -1)  # [B,H,T,T]
        attn_mask = bias if attn_mask is None else (float_mask + bias)

    return attn_mask  # float additive mask

In [77]:
# ----- (C) SDPA-based MHA (pre-proj qkv, optional RoPE + RelBias) -----
class SDPAMHA(nn.Module):
    """Multi-head self-attention built on F.scaled_dot_product_attention.

    Uses one fused QKV projection, optional rotary embeddings on Q/K and an
    optional relative-position bias module whose output is folded into the
    additive attention mask.

    Args:
        d_model: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        p_drop: Attention dropout probability (applied only in training mode).
        use_rope: If True, rotate Q/K with ``apply_rope`` before attention.
        rel_bias: Optional module called as ``rel_bias(T, device=...)`` that
            returns an ``[H,T,T]`` additive bias, or None.
    """

    def __init__(self, d_model, num_heads, p_drop=0.0, use_rope=False, rel_bias=None):
        super().__init__()
        assert d_model % num_heads == 0
        self.h = num_heads
        self.dh = d_model // num_heads
        # One fused projection, chunked later: fewer kernel launches, better
        # cache locality, and lower Python/autograd overhead.
        self.qkv = nn.Linear(d_model, 3*d_model, bias=True)
        self.o   = nn.Linear(d_model, d_model, bias=True)
        self.drop_p = p_drop
        self.use_rope = use_rope
        self.rel_bias = rel_bias  # nn.Module or None

    def split_heads(self, x):  # [B,T,D] -> [B,H,T,Dh]
        """Reshape ``[B,T,D]`` into per-head layout ``[B,H,T,Dh]``."""
        B, T, D = x.shape
        return x.view(B, T, self.h, self.dh).permute(0, 2, 1, 3)

    def merge_heads(self, x):  # [B,H,T,Dh] -> [B,T,D]
        """Inverse of ``split_heads``: ``[B,H,T,Dh]`` back to ``[B,T,D]``."""
        B, H, T, Dh = x.shape
        # permute usually makes the tensor non-contiguous; .reshape() copies
        # when needed, which .view() would refuse to do.
        return x.permute(0, 2, 1, 3).reshape(B, T, H*Dh)

    def forward(self, x, pad_mask=None, causal=False):
        """Run self-attention over ``x``.

        Args:
            x: ``[B,T,D]`` input.
            pad_mask: ``[B,T]`` bool, True = real token; or None.
            causal: If True, each query may only attend to keys at the same or
                earlier positions.

        Returns:
            ``[B,T,D]`` tensor; rows of padded query positions are zeroed when
            ``pad_mask`` is given.

        Raises:
            AssertionError: If some sequence in the batch has no real token.
            RuntimeError: If q/k/v contain non-finite values, or the additive
                mask contains NaN or +inf.
        """
        B, T, D = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)               # [B,T,D] each
        q, k, v = self.split_heads(q), self.split_heads(k), self.split_heads(v)  # [B,H,T,Dh]

        if self.use_rope:
            q, k = apply_rope(q, k)                          # returns new, rotated Q/K

        # ---- Build SDPA attn_mask ----
        # SDPA accepts:
        #  * boolean mask: True = ALLOW attention (opposite of nn.MultiheadAttention)
        #  * float mask: added to logits
        assert pad_mask is None or (pad_mask.sum(dim=1) > 0).all(), "Empty sequence in batch; add CLS/UNK or drop it."

        # `bias` must be defined even without a relative-bias module.
        bias = self.rel_bias(T, device=x.device) if self.rel_bias is not None else None  # [H,T,T] or None
        attn_mask = build_sdpa_mask(pad_mask, bias, B=B, T=T, H=self.h, dtype=q.dtype, device=q.device)

        # SDPA rejects an explicit attn_mask combined with is_causal=True, so
        # when a mask exists the causal constraint is folded into it instead.
        use_is_causal = bool(causal)
        if use_is_causal and attn_mask is not None:
            pos = torch.arange(T, device=q.device)
            future = pos[None, :] > pos[:, None]             # [T,T], True where key is after query
            causal_bias = torch.zeros(T, T, dtype=q.dtype, device=q.device).masked_fill(
                future, large_neg(q.dtype)
            )
            attn_mask = attn_mask + causal_bias              # broadcasts over [B,H]
            use_is_causal = False

        for name, t in {"q": q, "k": k, "v": v}.items():
            if not torch.isfinite(t).all():
                raise RuntimeError(f"{name} has non-finite values")

        if attn_mask is not None and attn_mask.dtype.is_floating_point:
            # -inf is tolerated as a masking sentinel; NaN and +inf are not.
            # (Filtering with a comparison first would silently drop NaNs.)
            if torch.isnan(attn_mask).any() or (attn_mask == float("inf")).any():
                raise RuntimeError("attn_mask has non-finite (non -inf) values")

        # ---- SDPA (handles scale, softmax, dropout, matmul) ----
        # SDPA expects [B,H,T,Dh]; attn_mask broadcastable to [B,H,T,T]
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=(self.drop_p if self.training else 0.0),
            is_causal=use_is_causal,
        )                                                    # [B,H,T,Dh]
        x = self.merge_heads(out)                            # [B,T,D]
        x = self.o(x)                                        # [B,T,D]
        if pad_mask is not None:
            x = x * pad_mask.unsqueeze(-1).to(x.dtype)       # zero padded query rows
        return x

# Plug SDPA MHA into Pre-LN encoder block

In [7]:
class PreLNEncoderBlockSDPA(nn.Module):
    """
    Pre-LN Encoder block that *composes* your SDPAMHA attention module.
    Pre-LN + residual wiring + FFN
    x -> x + Drop( Attn( LN(x) ) )
       -> x + Drop( FFN( LN(x) ) )

    Expects:
      - attn: a module like SDPAMHA with signature:
              attn(x: [B,T,D], pad_mask: Optional[Bool[B,T]], causal: bool) -> [B,T,D]
      - pad_mask: Bool[B,T], True = real token (not PAD)
      - causal: usually False for encoders
    """
    def __init__(
        self,
        d_model: int,
        *,
        attn: nn.Module,          # <-- pass your SDPAMHA instance here
        ff_mult: int = 4,
        p_drop: float = 0.1,
        norm: str = "ln",
        resid_mode: str = "plain",  # {"plain","scaled","rezero"}
    ):
        super().__init__()
        self.attn = attn                       # <-- your SDPAMHA
        self.ln1  = make_norm(norm, d_model)
        self.drop1 = nn.Dropout(p_drop)

        self.ln2  = make_norm(norm, d_model)
        self.ff   = nn.Sequential(
            nn.Linear(d_model, ff_mult * d_model),
            nn.GELU(),
            nn.Linear(ff_mult * d_model, d_model),
        )
        self.drop2 = nn.Dropout(p_drop)

        self.mode = resid_mode
        if resid_mode == "rezero":
            self.g = nn.Parameter(torch.zeros(1))  # learnable gate
        elif resid_mode == "scaled":
            self.alpha = 0.5                       # constant residual scale

    def _resid(self, x, h):
        # residual add with optional scaling/gating
        if self.mode == "plain":
            return x + h
        elif self.mode == "scaled":
            return x + self.alpha * h
        elif self.mode == "rezero":
            return x + self.g * h
        else:
            raise ValueError(f"unknown resid_mode={self.mode}")

    def forward(self, x: torch.Tensor, pad_mask: torch.Tensor | None = None, causal: bool = False):
        """
        x:        [B,T,D]
        pad_mask: Bool[B,T], True = real token (not PAD). Will be passed through to SDPAMHA.
        causal:   usually False for encoders
        """
        B, T, D = x.shape
        if pad_mask is not None:
            assert pad_mask.dtype == torch.bool and pad_mask.shape == (B, T), "pad_mask must be Bool[B,T]"

        # --- Attention branch (Pre-LN) ---
        a_in = self.ln1(x)
        a_out = self.attn(a_in, pad_mask=pad_mask, causal=causal)  # expects [B,T,D] from your SDPAMHA
        x = self._resid(x, self.drop1(a_out))

        # --- FFN branch (Pre-LN) ---
        f_in = self.ln2(x)
        f_out = self.ff(f_in)
        x = self._resid(x, self.drop2(f_out))

        return x  # [B,T,D]

In [8]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
from src.text_helpers import make_tensors
from src.encoder_classifier_wrapper import EncoderClassifier
from src.train_utils import TrainConfig, kfold_train

In [39]:
# 0) Prepare X, y, M (ids [B,T], labels [B], mask [B,T] bool)
# Two small, deliberately confusable classes: several sentences contain
# vocabulary from the *other* class, so single keywords are not enough.
# Class 0: greetings.
greeting_hard = [
    "good morning everyone",
    "hello there my friend",
    "hey buddy how are you",
    "good evening folks",
    "salutations from the sushi bar",          # greeting + food word
    "pizza party greetings to all",            # greeting + food word
    "hi from the ramen shop",                  # greeting + food word
    "hello and welcome to brunch",             # greeting + food word
]
# Class 1: food statements.
food_hard = [
    "i love pizza",
    "pasta is tasty tonight",
    "fresh salad with apple",
    "i like sushi a lot",
    "good sandwich this morning",              # food + greeting words
    "ramen is great hello world",              # food + greeting word
    "eating an apple for breakfast",
    "not a fan of pizza anymore",              # negation
]
# Sentences and labels are concatenated in the same order (greetings first).
HARD_SUP = greeting_hard + food_hard
HARD_LABELS = [0]*len(greeting_hard) + [1]*len(food_hard)
# make_tensors is a project helper (src.text_helpers); presumably it tokenizes,
# builds the vocabulary and pads to a common length - confirm against its source.
# With add_cls=False the returned cls_id is None (see the cell output below).
X, M, y_float, y_long, stoi, itos, pad_id, cls_id = make_tensors(HARD_SUP, HARD_LABELS, min_freq=1, add_cls=False)

In [41]:
M.shape

torch.Size([16, 6])

In [22]:
pad_id, cls_id

(0, None)

In [61]:
# 1) Build your attention block
def make_block(d_model):
    """Build one Pre-LN encoder block of width ``d_model``.

    The block uses 4-head SDPA attention with a clipped relative-position
    bias (no RoPE), dropout 0.1, a 4x FFN, LayerNorm and plain residuals.
    """
    relative_bias = ClippedRelPosBias(4)
    attention = SDPAMHA(
        d_model=d_model,
        num_heads=4,
        p_drop=0.1,
        use_rope=False,
        rel_bias=relative_bias,
    )
    return PreLNEncoderBlockSDPA(
        d_model,
        attn=attention,
        ff_mult=4,
        p_drop=0.1,
        norm="ln",
        resid_mode="plain",
    )

In [11]:
# 2) Build the model ctor
def make_model(vocab_size, d_model, pad_id, cls_id, num_layers=2, pool="mean"):
    """Construct an EncoderClassifier whose layers are produced by ``make_block``.

    No positional encoding is requested (``posenc=None``) and the final
    normalisation is ``"ln"``.
    """
    encoder_kwargs = dict(
        vocab_size=vocab_size,
        d_model=d_model,
        pad_id=pad_id,
        num_layers=num_layers,
        block_ctor=make_block,
        pool=pool,
        cls_id=cls_id,
        posenc=None,
        final_norm="ln",
    )
    return EncoderClassifier(**encoder_kwargs)

In [12]:
len(stoi.keys())

57

In [23]:
# Model hyper-parameters used by the remaining cells.
vocab_size = len(itos)
d_model = 64
# Instantiate once just to display the module tree (the result is not kept).
make_model(vocab_size, d_model, pad_id, cls_id)

EncoderClassifier(
  (embed): Embedding(57, 64, padding_idx=0)
  (layers): ModuleList(
    (0-1): 2 x PreLNEncoderBlockSDPA(
      (attn): SDPAMHA(
        (qkv): Linear(in_features=64, out_features=192, bias=True)
        (o): Linear(in_features=64, out_features=64, bias=True)
        (rel_bias): ClippedRelPosBias()
      )
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (drop1): Dropout(p=0.1, inplace=False)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ff): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
      )
      (drop2): Dropout(p=0.1, inplace=False)
    )
  )
  (final_ln): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  (head): Linear(in_features=64, out_features=1, bias=True)
)

In [36]:
X.unsqueeze(0).shape, M.unsqueeze(0).shape

(torch.Size([1, 16, 6]), torch.Size([1, 16, 6]))

In [42]:
# Stand-alone embedding layer used only to feed the attention smoke tests below.
embed = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
# NOTE: despite its name, X_ids holds the *embedded* tokens ([16, 6, 64]), not ids.
X_ids = embed(X)
X_ids.shape

torch.Size([16, 6, 64])

In [74]:
# Smoke test: attention without a relative-position bias; M is passed as pad_mask.
# The module is in training mode by default, so attention dropout (p=0.1) is active.
test_attn_base = SDPAMHA(d_model=d_model, num_heads=4, p_drop=0.1, use_rope=False, rel_bias=None)
test_attn_base(X_ids, M)

tensor([[[-1.4653e-01, -4.6931e-03,  1.6437e-01,  ..., -1.9432e-01,
           1.5920e-01, -2.3063e-01],
         [-1.6405e-01,  7.5195e-02,  1.6117e-01,  ..., -1.0296e-01,
           1.9788e-01, -1.0972e-01],
         [-1.0968e-01,  5.3158e-02,  1.1645e-01,  ..., -1.6102e-01,
           9.6796e-02, -1.5366e-01],
         [-0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -0.0000e+00,
           0.0000e+00, -0.0000e+00],
         [-0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -0.0000e+00,
           0.0000e+00, -0.0000e+00],
         [-0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -0.0000e+00,
           0.0000e+00, -0.0000e+00]],

        [[-5.7306e-02, -1.1439e-01,  1.5806e-01,  ..., -2.3764e-01,
           2.2892e-01, -2.1861e-01],
         [-1.2536e-01, -8.5938e-02,  1.9407e-01,  ..., -7.1285e-02,
           1.7287e-01, -1.4739e-01],
         [-7.8699e-02, -1.6545e-01,  1.4682e-01,  ..., -1.4262e-01,
           1.6964e-01, -5.7005e-02],
         [-5.1923e-02, -1.2090e-01,  3.1135e-01,  ...

In [79]:
# Smoke test: same attention module, now with a 4-head clipped relative-position bias.
test_attn = SDPAMHA(d_model=d_model, num_heads=4, p_drop=0.1, use_rope=False, rel_bias=ClippedRelPosBias(4))
test_attn(X_ids, M)

tensor([[[ 0.1767,  0.0518, -0.0660,  ..., -0.0311, -0.0624,  0.1031],
         [ 0.1430,  0.0246, -0.0819,  ...,  0.0052, -0.0574,  0.1335],
         [ 0.1132,  0.0096, -0.1232,  ..., -0.0095, -0.0666,  0.0756],
         [ 0.0000,  0.0000, -0.0000,  ..., -0.0000, -0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0000,  ..., -0.0000, -0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0000,  ..., -0.0000, -0.0000,  0.0000]],

        [[ 0.0106, -0.1400, -0.2422,  ..., -0.1927, -0.1355, -0.0282],
         [ 0.0740, -0.1267, -0.1998,  ..., -0.1578, -0.0612, -0.0183],
         [-0.0796, -0.1274, -0.1989,  ..., -0.1984, -0.1063, -0.0164],
         [ 0.0234, -0.0495, -0.1561,  ..., -0.1623, -0.0674,  0.0167],
         [ 0.0000, -0.0000, -0.0000,  ..., -0.0000, -0.0000,  0.0000],
         [ 0.0000, -0.0000, -0.0000,  ..., -0.0000, -0.0000,  0.0000]],

        [[ 0.2408,  0.1793,  0.2721,  ..., -0.3403,  0.1242, -0.0357],
         [ 0.1675,  0.2320,  0.2757,  ..., -0.5905,  0.0434,  0.0443],
  

In [80]:
# Smoke test: a full Pre-LN encoder block (attention + FFN with residuals) on the embedded batch.
dummy_block = make_block(d_model)
dummy_block(X_ids, M)

tensor([[[ 1.5011, -1.9598,  0.6901,  ...,  1.0972, -0.1493,  0.5156],
         [ 0.0743,  0.6415, -0.0294,  ..., -0.0042, -0.0296, -0.4958],
         [-1.5815, -0.0377, -0.7621,  ...,  1.4959, -1.3784,  1.8664],
         [-0.0355,  0.1103,  0.0000,  ...,  0.0940, -0.0535,  0.0348],
         [-0.0355,  0.1103, -0.0527,  ...,  0.0940, -0.0535,  0.0348],
         [ 0.0000,  0.1103, -0.0527,  ...,  0.0940, -0.0535,  0.0348]],

        [[ 0.0631,  0.2130, -0.2740,  ...,  0.4496,  0.7326,  0.6123],
         [-0.5080,  0.6649,  1.0902,  ...,  0.0089,  3.0859, -1.7359],
         [-0.4412,  1.6686, -0.5772,  ...,  1.1268, -1.9022, -1.3289],
         [ 1.6179,  0.8584, -2.0297,  ...,  1.4784, -0.6282,  0.7068],
         [-0.0355,  0.1103, -0.0527,  ...,  0.0940, -0.0535,  0.0348],
         [-0.0355,  0.1103, -0.0527,  ...,  0.0940,  0.0000,  0.0348]],

        [[ 1.6170,  0.5406,  0.9320,  ...,  0.6168, -1.4465, -1.2652],
         [ 1.0966, -2.3336, -0.4699,  ...,  0.7605, -1.7515,  0.2239],
  

In [81]:
# 3) Train with k-fold
# TrainConfig and kfold_train are project helpers (src.train_utils); the exact
# meaning of batch_size=None is defined there - presumably full-batch, TODO confirm.
cfg = TrainConfig(epochs=40, batch_size=None, lr_enc=3e-3, warmup_steps=0)
# The lambda defers model construction so that kfold_train can build a model itself.
result = kfold_train(X, y_float, y_long, M, lambda: make_model(vocab_size, d_model, pad_id, cls_id), cfg)

Fold 1, epoch 5, ['loss: 0.9639574289321899', 'acc: 0.75', 'brier: 0.2648617923259735', 'margin: 0.3444175720214844']
Fold 2, epoch 5, ['loss: 1.3912625312805176', 'acc: 0.3333333432674408', 'brier: 0.5033624172210693', 'margin: 0.33863815665245056']
Fold 3, epoch 5, ['loss: 0.9259632229804993', 'acc: 0.6666666865348816', 'brier: 0.29952916502952576', 'margin: 0.32900500297546387']
Fold 3, epoch 10, ['loss: 0.48906663060188293', 'acc: 0.6666666865348816', 'brier: 0.18294931948184967', 'margin: 0.35060420632362366']
Fold 3, epoch 15, ['loss: 0.1283070147037506', 'acc: 1.0', 'brier: 0.021939540281891823', 'margin: 0.3844712972640991']
Fold 3, epoch 20, ['loss: 0.040471915155649185', 'acc: 1.0', 'brier: 0.0029690610244870186', 'margin: 0.4611048698425293']
Fold 3, epoch 25, ['loss: 0.033970560878515244', 'acc: 1.0', 'brier: 0.002316561061888933', 'margin: 0.46725329756736755']
Fold 4, epoch 5, ['loss: 1.4336987733840942', 'acc: 0.6666666865348816', 'brier: 0.3618767261505127', 'margin: 0.

In [82]:
result

{'folds': [{'loss': 0.6579810380935669,
   'acc': 0.5,
   'brier': 0.23283793032169342,
   'margin': 0.09363312274217606},
  {'loss': 0.819510281085968,
   'acc': 0.0,
   'brier': 0.31238171458244324,
   'margin': 0.057350825518369675},
  {'loss': 0.030911490321159363,
   'acc': 1.0,
   'brier': 0.0017949133180081844,
   'margin': 0.47002896666526794},
  {'loss': 0.6589000821113586,
   'acc': 0.6666666865348816,
   'brier': 0.2330356389284134,
   'margin': 0.04276571795344353},
  {'loss': 0.34120258688926697,
   'acc': 0.6666666865348816,
   'brier': 0.10634505748748779,
   'margin': 0.26382431387901306}],
 'mean': {'loss': 0.501701095700264,
  'acc': 0.5666666746139526,
  'brier': 0.1772790509276092,
  'margin': 0.18552058935165405},
 'std': {'loss': 0.28197171627743844,
  'acc': 0.32659863480444,
  'brier': 0.10979492692251096,
  'margin': 0.16268142993309942}}