# Drop-in SDPA MHA with optional RoPE + RelBias

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import math, torch, torch.nn as nn, torch.nn.functional as F

In [3]:
# ----- (A) tiny RoPE helper -----
def apply_rope(q, k, base=10000.0):
    """Apply rotary position embeddings (half-split layout) to queries and keys.

    The last dimension is split into two halves ``(x1, x2)``; channel ``i`` of
    each half is rotated by the angle ``pos * base ** (-i / half)``.

    Args:
        q, k: tensors laid out as [B,H,T,Dh] with an even ``Dh``. Both are
            rotated with the sequence length taken from ``q``.
        base: frequency base of the geometric progression of rotation speeds.

    Returns:
        ``(q_rot, k_rot)``: new tensors with the same shape and dtype as the
        inputs (nothing is modified in place).

    Raises:
        AssertionError: if the head dimension is odd.
    """
    Dh = q.size(-1)
    assert Dh % 2 == 0, "RoPE needs even head dim"
    half = Dh // 2
    q1, q2 = q[..., :half], q[..., half:]
    k1, k2 = k[..., :half], k[..., half:]
    T = q.size(-2)
    device = q.device
    dtype = q.dtype

    # fp16/bf16 cannot represent every integer position (bf16 already rounds
    # 257 to 256), so building the angle table in the input dtype makes
    # distinct positions share one rotation. Build it in float32 for 16-bit
    # inputs and cast only the final sin/cos back; fp32/fp64 are unchanged.
    table_dtype = torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype

    pos = torch.arange(T, device=device, dtype=table_dtype)[:, None]          # [T,1]
    inv_freqs = torch.exp(
        -math.log(base) * (torch.arange(0, half, device=device, dtype=table_dtype) / half)
    )[None, :]                                                                # [1,half]
    ang = pos * inv_freqs                                                     # [T,half]
    # broadcast to [1,1,T,half] so it lines up with [B,H,T,half]
    sin = torch.sin(ang).to(dtype)[None, None, :, :]
    cos = torch.cos(ang).to(dtype)[None, None, :, :]

    # rotate (x1, x2) -> (x1*cos - x2*sin, x2*cos + x1*sin)
    def rot(x1, x2):
        return x1 * cos - x2 * sin, x2 * cos + x1 * sin

    q1r, q2r = rot(q1, q2)
    k1r, k2r = rot(k1, k2)                                                    # [B,H,T,half]
    q = torch.cat([q1r, q2r], dim=-1)
    k = torch.cat([k1r, k2r], dim=-1)
    return q, k

In [4]:
# ----- (B) clipped T5-style per-head relative bias -----
class ClippedRelPosBias(nn.Module):
    """Learned per-head additive bias indexed by the clipped offset ``key - query``.

    Offsets are clamped to ``[-(max_rel-1), max_rel-1]``, giving ``2*max_rel - 1``
    buckets per head. The table starts at zero.
    """

    def __init__(self, num_heads, max_rel=128):
        super().__init__()
        self.max_rel = max_rel
        n_buckets = 2 * max_rel - 1
        self.table = nn.Parameter(torch.zeros(num_heads, n_buckets))  # [H, 2R-1]

    def forward(self, T, device=None):
        """Return the bias for a length-``T`` sequence as a [H,T,T] tensor."""
        device = device or self.table.device
        limit = self.max_rel - 1
        positions = torch.arange(T, device=device)
        # offsets[i, j] = j - i  (row = query index, column = key index)
        offsets = positions.unsqueeze(0) - positions.unsqueeze(1)      # [T,T]
        # clip, then shift into the table's index range [0 .. 2R-2]
        buckets = torch.clamp(offsets, min=-limit, max=limit) + limit
        return self.table[:, buckets]                                  # [H,T,T]


In [58]:
def large_neg(dtype):
    """Return a large *finite* negative number to use as an additive mask value.

    16-bit float dtypes get a smaller magnitude (-1e4) than wider dtypes (-1e9).
    """
    half_precision = (torch.float16, torch.bfloat16)
    if dtype in half_precision:
        return -1e4
    return -1e9

In [59]:
# ----- (C) SDPA-based MHA (pre-proj qkv, optional RoPE + RelBias) -----
class SDPAMHA(nn.Module):
    def __init__(self, d_model, num_heads, p_drop=0.0, use_rope=False, rel_bias=None):
        super().__init__()
        assert d_model % num_heads == 0
        self.h = num_heads
        self.dh = d_model // num_heads
        # Use chunk later to separate: fewer kernel launches, better cache locality, and lower Python/autograd overhead.
        self.qkv = nn.Linear(d_model, 3*d_model, bias=True)
        self.o   = nn.Linear(d_model, d_model, bias=True)
        self.drop_p = p_drop
        self.use_rope = use_rope
        self.rel_bias = rel_bias  # nn.Module or None

    def split_heads(self, x):  # [B,T,D] -> [B,H,T,Dh]
        B,T,D = x.shape
        x = x.view(B,T,self.h,self.dh).permute(0,2,1,3) # [B,T,D] -> [B, T, H, Dh] -> [B, H, T, Dh]
        return x
    def merge_heads(self, x):  # [B,H,T,Dh] -> [B,T,D]
        B,H,T,Dh = x.shape
        return x.permute(0,2,1,3).reshape(B,T,H*Dh)
        # permute usually makes the tensor non-contiguous, .reshape() is equivalent to .contiguous().view()

    def forward(self, x, pad_mask=None, causal=False):
        # pad_mask: [B,T] bool, True=real token
        B,T,D = x.shape
        q,k,v = self.qkv(x).chunk(3, dim=-1)               # [B,T,D] each
        q,k,v = self.split_heads(q), self.split_heads(k), self.split_heads(v)  # [B,H,T,Dh]

        if self.use_rope:
            q, k = apply_rope(q, k)                        # rotate Q/K in-place

        # ---- Build SDPA attn_mask ----
        # SDPA accepts:
        #  * boolean mask: True = ALLOW attention (opposite of nn.MultiheadAttention)
        #  * float mask: added to logits
        attn_mask = None
        if pad_mask is not None: # pad original dim: [B,T]
            # Allow attending to real keys only; do NOT AND with query mask
            m = pad_mask[:, None, None, :]   # [B,1,1,T], True = allowed keys
            attn_mask = m  # boolean is fine when no extra bias
        
        # Right before calling the encoder, once per batch:
        assert pad_mask is None or (pad_mask.sum(dim=1) > 0).all(), "Empty sequence in batch; add CLS/UNK or drop it."

        if self.rel_bias is not None:
            bias = self.rel_bias(T, device=x.device)       # [H,T,T]
            # turn everything into a FLOAT additive mask broadcastable to [B,H,T,T]
            bias = bias.unsqueeze(0).expand(B, -1, -1, -1) # [B,H,T,T]
            if attn_mask is not None and attn_mask.dtype == torch.bool:
                neg = large_neg(x.dtype)
                # convert boolean allow-mask to float additive: disallowed → a large finite negative
                # Can't use -inf, the MPS backend turn it into nan.
                float_mask = (~attn_mask).to(bias.dtype) * neg
                attn_mask = float_mask + bias               # [B,H,T,T] float
            else:
                attn_mask = bias                            # [B,H,T,T] float

        for name, t in {"q": q, "k": k, "v": v}.items():
            if not torch.isfinite(t).all():
                raise RuntimeError(f"{name} has non-finite values")

        if attn_mask is not None and attn_mask.dtype.is_floating_point:
            if not torch.isfinite(attn_mask[attn_mask > -1e30]).all():  # ignore our -inf sentinels
                raise RuntimeError("attn_mask has non-finite (non -inf) values")
            
        # ---- SDPA (handles scale, softmax, dropout, matmul) ----
        # SDPA expects [B,H,T,Dh]; attn_mask broadcastable to [B,H,T,T]
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=(self.drop_p if self.training else 0.0),
            is_causal=bool(causal)
        )                                                  # [B,H,T,Dh]
        x = self.merge_heads(out)                          # [B,T,D]
        x = self.o(x)                     # [B,T,D]
        if pad_mask is not None:
            x = x * pad_mask.unsqueeze(-1).to(x.dtype)  # zero padded query rows
        return x

# Plug SDPA MHA into Pre-LN encoder block

In [7]:
class PreLNEncoderBlockSDPA(nn.Module):
    """
    Pre-LN Transformer encoder block wrapped around an injected attention module.

    Data flow (each sub-layer normalises its input, then adds back a residual):
        x -> x + Drop( Attn( Norm(x) ) )
          -> x + Drop( FFN ( Norm(x) ) )

    Args:
        d_model:    model width D.
        attn:       attention module called as
                    ``attn(x, pad_mask=..., causal=...)`` and returning [B,T,D]
                    (e.g. an SDPAMHA instance).
        ff_mult:    hidden width of the feed-forward net, as a multiple of D.
        p_drop:     dropout probability applied to both branch outputs.
        norm:       normaliser name, forwarded to ``make_norm`` (defined
                    outside this block).
        resid_mode: "plain"  -> x + h
                    "scaled" -> x + 0.5 * h
                    "rezero" -> x + g * h, with a learnable gate g starting at 0.
    """

    def __init__(
        self,
        d_model: int,
        *,
        attn: nn.Module,
        ff_mult: int = 4,
        p_drop: float = 0.1,
        norm: str = "ln",
        resid_mode: str = "plain",
    ):
        super().__init__()
        hidden = ff_mult * d_model

        # attention branch
        self.attn = attn
        self.ln1 = make_norm(norm, d_model)
        self.drop1 = nn.Dropout(p_drop)

        # feed-forward branch
        self.ln2 = make_norm(norm, d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
        )
        self.drop2 = nn.Dropout(p_drop)

        # residual flavour; only the extra state each flavour needs is created
        self.mode = resid_mode
        if resid_mode == "scaled":
            self.alpha = 0.5                       # fixed residual scale
        elif resid_mode == "rezero":
            self.g = nn.Parameter(torch.zeros(1))  # learnable gate

    def _resid(self, x, h):
        """Add branch output ``h`` to ``x`` according to ``self.mode``."""
        mode = self.mode
        if mode == "plain":
            return x + h
        if mode == "scaled":
            return x + self.alpha * h
        if mode == "rezero":
            return x + self.g * h
        raise ValueError(f"unknown resid_mode={self.mode}")

    def forward(self, x: torch.Tensor, pad_mask: torch.Tensor | None = None, causal: bool = False):
        """
        x:        [B,T,D]
        pad_mask: Bool[B,T], True = real token (not PAD); handed to the attention module as-is.
        causal:   forwarded to the attention module (usually False for encoders).
        Returns   [B,T,D].
        """
        batch, seq_len, _ = x.shape
        if pad_mask is not None:
            assert pad_mask.dtype == torch.bool and pad_mask.shape == (batch, seq_len), "pad_mask must be Bool[B,T]"

        # attention sub-layer
        attended = self.attn(self.ln1(x), pad_mask=pad_mask, causal=causal)
        x = self._resid(x, self.drop1(attended))

        # feed-forward sub-layer
        x = self._resid(x, self.drop2(self.ff(self.ln2(x))))
        return x  # [B,T,D]

In [8]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
from src.text_helpers import make_tensors
from src.encoder_classifier_wrapper import EncoderClassifier
from src.train_utils import TrainConfig, kfold_train

In [39]:
# 0) Prepare X, y, M (ids [B,T], labels [B], mask [B,T] bool)
# Class 0: greetings, several of which deliberately contain food words.
greeting_hard = [
    "good morning everyone",
    "hello there my friend",
    "hey buddy how are you",
    "good evening folks",
    "salutations from the sushi bar",          # greeting + food word
    "pizza party greetings to all",            # greeting + food word
    "hi from the ramen shop",                  # greeting + food word
    "hello and welcome to brunch",             # greeting + food word
]
# Class 1: food sentences, several of which deliberately contain greeting words.
food_hard = [
    "i love pizza",
    "pasta is tasty tonight",
    "fresh salad with apple",
    "i like sushi a lot",
    "good sandwich this morning",              # food + greeting words
    "ramen is great hello world",              # food + greeting word
    "eating an apple for breakfast",
    "not a fan of pizza anymore",              # negation
]
# Sentences and labels are kept aligned: all greetings first, then all food.
HARD_SUP = [*greeting_hard, *food_hard]
HARD_LABELS = [0 for _ in greeting_hard] + [1 for _ in food_hard]
(X, M, y_float, y_long, stoi, itos, pad_id, cls_id) = make_tensors(
    HARD_SUP,
    HARD_LABELS,
    min_freq=1,
    add_cls=False,
)

In [41]:
M.shape

torch.Size([16, 6])

In [22]:
pad_id, cls_id

(0, None)

In [61]:
# 1) Build your attention block
def make_block(d_model, *, num_heads=4, p_drop=0.1, use_rope=False, max_rel=128,
               ff_mult=4, norm="ln", resid_mode="plain"):
    """Build one Pre-LN encoder block with SDPA attention and a relative bias.

    ``make_block(d_model)`` builds the same configuration as before (4 heads,
    dropout 0.1, no RoPE, clipped relative bias). The keyword-only options
    exist so the head count is given once and shared by the attention module
    and its bias table, which must agree on it.

    Args:
        d_model: model width.
        num_heads: heads for both SDPAMHA and ClippedRelPosBias.
        p_drop: dropout used inside attention and on both residual branches.
        use_rope: enable RoPE inside the attention module.
        max_rel: clipping distance of the relative-position bias.
        ff_mult, norm, resid_mode: forwarded to PreLNEncoderBlockSDPA.

    Returns:
        A PreLNEncoderBlockSDPA instance.
    """
    rel_bias = ClippedRelPosBias(num_heads, max_rel=max_rel)
    attn = SDPAMHA(d_model=d_model, num_heads=num_heads, p_drop=p_drop,
                   use_rope=use_rope, rel_bias=rel_bias)
    return PreLNEncoderBlockSDPA(d_model, attn=attn, ff_mult=ff_mult, p_drop=p_drop,
                                 norm=norm, resid_mode=resid_mode)

In [11]:
# 2) Build the model ctor
def make_model(vocab_size, d_model, pad_id, cls_id, num_layers=2, pool="mean"):
    """Construct an EncoderClassifier whose layers are produced by ``make_block``.

    No separate positional encoding is requested (``posenc=None``) and a final
    "ln" norm is requested; all other arguments are passed straight through.
    """
    encoder_kwargs = dict(
        vocab_size=vocab_size,
        d_model=d_model,
        pad_id=pad_id,
        num_layers=num_layers,
        block_ctor=make_block,
        pool=pool,
        cls_id=cls_id,
        posenc=None,
        final_norm="ln",
    )
    return EncoderClassifier(**encoder_kwargs)

In [12]:
len(stoi.keys())

57

In [23]:
# Vocabulary size is taken from itos (presumably the index -> token list
# returned by make_tensors; verify there). d_model is the embedding width.
vocab_size = len(itos)
d_model = 64
# Build one throw-away model so the notebook prints its module tree.
make_model(vocab_size, d_model, pad_id, cls_id)

EncoderClassifier(
  (embed): Embedding(57, 64, padding_idx=0)
  (layers): ModuleList(
    (0-1): 2 x PreLNEncoderBlockSDPA(
      (attn): SDPAMHA(
        (qkv): Linear(in_features=64, out_features=192, bias=True)
        (o): Linear(in_features=64, out_features=64, bias=True)
        (rel_bias): ClippedRelPosBias()
      )
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (drop1): Dropout(p=0.1, inplace=False)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ff): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
      )
      (drop2): Dropout(p=0.1, inplace=False)
    )
  )
  (final_ln): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  (head): Linear(in_features=64, out_features=1, bias=True)
)

In [36]:
X.unsqueeze(0).shape, M.unsqueeze(0).shape

(torch.Size([1, 16, 6]), torch.Size([1, 16, 6]))

In [42]:
# Stand-alone embedding layer used to feed the attention modules directly.
embed = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
# NOTE: despite its name, X_ids holds the embedded vectors, not token ids.
X_ids = embed(X)
X_ids.shape

torch.Size([16, 6, 64])

In [62]:
# Smoke test: attention with neither RoPE nor relative bias.
# The module is never switched to eval(), so attention dropout (p=0.1) is
# active and the printed values vary between runs. Rows at padded positions
# are zero because SDPAMHA multiplies its output by the pad mask.
test_attn_base = SDPAMHA(d_model=d_model, num_heads=4, p_drop=0.1, use_rope=False, rel_bias=None)
test_attn_base(X_ids, M)

tensor([[[-5.9122e-02,  3.1716e-01, -4.0687e-01,  ..., -3.4909e-01,
           4.4481e-02,  2.4969e-01],
         [-6.5194e-02,  2.5074e-01, -4.3719e-01,  ..., -2.8228e-01,
          -1.4467e-02,  1.3763e-01],
         [-3.0605e-02,  1.2488e-01, -3.9658e-01,  ..., -2.6313e-01,
          -4.9807e-02,  1.5437e-01],
         [-0.0000e+00,  0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-0.0000e+00,  0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-0.0000e+00,  0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
          -0.0000e+00,  0.0000e+00]],

        [[ 1.0762e-01, -1.1731e-01, -2.3151e-01,  ...,  1.5283e-01,
           4.9701e-01,  1.3370e-01],
         [-3.9643e-02, -1.9676e-01, -2.4049e-01,  ...,  9.1659e-02,
           5.0472e-01,  5.4839e-02],
         [ 8.1359e-02, -1.7565e-01, -2.1104e-01,  ...,  1.3653e-01,
           4.8680e-01,  1.5607e-01],
         [ 1.1275e-01, -9.0813e-02, -2.0965e-01,  ...

In [63]:
# Smoke test: same attention, now with a 4-head clipped relative-position bias.
# The bias table starts at zero; outputs differ from the previous cell because
# this is a separately initialised module (and dropout is active in train mode).
test_attn = SDPAMHA(d_model=d_model, num_heads=4, p_drop=0.1, use_rope=False, rel_bias=ClippedRelPosBias(4))
test_attn(X_ids, M)

tensor([[[ 0.1203,  0.0819,  1.0217,  ...,  0.1773, -0.1169, -0.2097],
         [-0.0564,  0.1021,  0.5484,  ..., -0.0596, -0.3017, -0.0144],
         [ 0.0375,  0.2190,  0.9778,  ...,  0.2113, -0.3425, -0.1834],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000, -0.0000, -0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000, -0.0000, -0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000, -0.0000, -0.0000]],

        [[ 0.5148,  0.3040,  0.2782,  ..., -0.2069, -0.0372, -0.0615],
         [ 0.3745,  0.2437,  0.2511,  ..., -0.0970, -0.0199, -0.0289],
         [ 0.2498,  0.2925,  0.1969,  ..., -0.1357,  0.0079, -0.0318],
         [ 0.2801,  0.1823,  0.1290,  ..., -0.1760, -0.0445, -0.0606],
         [ 0.0000,  0.0000,  0.0000,  ..., -0.0000,  0.0000, -0.0000],
         [ 0.0000,  0.0000,  0.0000,  ..., -0.0000,  0.0000, -0.0000]],

        [[-0.1756,  0.0578, -0.0193,  ..., -0.2550, -0.0818, -0.1166],
         [-0.3312,  0.1054, -0.0264,  ..., -0.0339, -0.2005, -0.0836],
  

In [64]:
# Smoke test of a full encoder block (attention + FFN with residuals).
# Padded rows are no longer zero here: only the attention output is masked,
# and the FFN branch still contributes at padded positions.
dummy_block = make_block(d_model)
dummy_block(X_ids, M)

tensor([[[ 1.2913, -1.9837,  1.1537,  ...,  1.2065, -0.3974,  0.8566],
         [ 0.3824,  0.6312, -0.0348,  ..., -0.1078, -0.3644,  0.2711],
         [-1.5770, -1.0095, -0.5870,  ...,  1.1273, -1.5503,  2.0250],
         [ 0.0544, -0.0563,  0.0144,  ..., -0.0632, -0.0425, -0.0380],
         [ 0.0544, -0.0563,  0.0144,  ..., -0.0632, -0.0425, -0.0380],
         [ 0.0544, -0.0563,  0.0144,  ..., -0.0632, -0.0425, -0.0380]],

        [[-0.8735,  0.3605,  0.3083,  ...,  0.2514,  0.2068,  0.6130],
         [-1.1872,  1.3559,  1.2867,  ..., -0.3615,  2.5537, -1.8473],
         [-0.7304,  1.0131, -0.1052,  ...,  0.3780, -1.8876, -1.8145],
         [ 1.6754,  1.1961, -1.1641,  ...,  1.8507, -0.8605,  0.8219],
         [ 0.0544, -0.0563,  0.0144,  ..., -0.0632, -0.0425,  0.0000],
         [ 0.0000, -0.0563,  0.0144,  ..., -0.0632,  0.0000, -0.0380]],

        [[ 1.5728,  0.4876,  1.1922,  ...,  0.7780, -2.5383, -1.3821],
         [ 0.7214, -2.4537, -0.3922,  ...,  0.1538, -1.9146, -0.6202],
  

In [27]:
y_float

tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.])

In [65]:
# 3) Train with k-fold
# Training configuration: 40 epochs, encoder learning rate 3e-3, no warm-up.
# batch_size=None presumably means full-batch training; confirm in TrainConfig.
cfg = TrainConfig(epochs=40, batch_size=None, lr_enc=3e-3, warmup_steps=0)
# The lambda is a model factory, so kfold_train can build a fresh model each
# time it calls it (presumably once per fold; kfold_train is defined elsewhere).
result = kfold_train(X, y_float, y_long, M, lambda: make_model(vocab_size, d_model, pad_id, cls_id), cfg)

Fold 1, epoch 5, ['loss: 1.072036623954773', 'acc: 0.75', 'brier: 0.2785666584968567', 'margin: 0.3352586030960083']
Fold 2, epoch 5, ['loss: 1.3459709882736206', 'acc: 0.3333333432674408', 'brier: 0.49663570523262024', 'margin: 0.31585898995399475']
Fold 3, epoch 5, ['loss: 0.95474773645401', 'acc: 0.6666666865348816', 'brier: 0.2979356050491333', 'margin: 0.3949678838253021']
Fold 4, epoch 5, ['loss: 1.465563178062439', 'acc: 0.6666666865348816', 'brier: 0.36048707365989685', 'margin: 0.3280555307865143']
Fold 5, epoch 5, ['loss: 0.576449453830719', 'acc: 0.6666666865348816', 'brier: 0.20134033262729645', 'margin: 0.23472516238689423']


In [66]:
result

{'folds': [{'loss': 0.7314648628234863,
   'acc': 0.5,
   'brier': 0.26776188611984253,
   'margin': 0.08592170476913452},
  {'loss': 0.8871936798095703,
   'acc': 0.3333333432674408,
   'brier': 0.3426041901111603,
   'margin': 0.12830226123332977},
  {'loss': 0.8276603817939758,
   'acc': 0.6666666865348816,
   'brier': 0.2919054329395294,
   'margin': 0.24631471931934357},
  {'loss': 0.6768187880516052,
   'acc': 0.6666666865348816,
   'brier': 0.2418615072965622,
   'margin': 0.04568130895495415},
  {'loss': 0.550352156162262,
   'acc': 0.6666666865348816,
   'brier': 0.18369753658771515,
   'margin': 0.15188170969486237}],
 'mean': {'loss': 0.7346979737281799,
  'acc': 0.5666666805744172,
  'brier': 0.2655661106109619,
  'margin': 0.13162034079432489},
 'std': {'loss': 0.11767819108178501,
  'acc': 0.13333333879709255,
  'brier': 0.052721567986518274,
  'margin': 0.0678972984955782}}