# RoPE + Relative Bias in MHA

A manual MHA that supports: <br>

- Sinusoidal absolute positional encoding, added to the layer input

- RoPE (rotary) on Q/K only

- T5-style relative position bias added to the attention scores (simplified: relative offsets are clipped rather than log-bucketed)

In [1]:
import math, torch, torch.nn as nn, torch.nn.functional as F

# Preferred compute device: Apple MPS first, then CUDA, falling back to CPU.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

In [2]:
# -------------------- Masks --------------------
def make_padding_mask(lengths, T):
    """Build a boolean key-padding mask from per-sequence lengths.

    lengths: [B] integer tensor of valid (unpadded) lengths.
    T: padded sequence length.
    Returns a [B, T] bool tensor: True where position < length (real token),
    False at PAD slots. Created on the same device as `lengths`.
    """
    positions = torch.arange(T, device=lengths.device)   # [T]
    return positions[None, :] < lengths[:, None]         # [1,T] vs [B,1] -> [B,T]

def make_causal_mask(T, device):
    """Lower-triangular keep-mask for autoregressive attention.

    Entry [i, j] is True when key j is at or before query i (j <= i), i.e. the
    query may attend to it; future keys (j > i) are False.
    Returns a [T, T] bool tensor on `device`.
    """
    idx = torch.arange(T, device=device)
    key_pos, query_pos = idx[None, :], idx[:, None]   # [1,T], [T,1]
    return key_pos <= query_pos

In [15]:
# -------------------- Sinusoidal positions (absolute) --------------------
def sinusoidal_positions(T, d_model, device, base=10000.0):
    """Absolute sinusoidal position table.

    Column 2k holds sin(pos * base^(-2k/d_model)) and column 2k+1 holds the
    matching cos. An odd d_model is supported: the last column is then a sine
    with no cosine partner (previously this raised a shape-mismatch error).

    Args:
        T: number of positions (rows), positions 0..T-1.
        d_model: embedding width (columns).
        device: device for the returned tensor.
        base: wavelength base (10000 by default).
    Returns:
        float32 tensor of shape [T, d_model].
    """
    pos = torch.arange(T, device=device).float()             # [T]
    i = torch.arange(0, d_model, 2, device=device).float()   # [ceil(D/2)] even column indices
    inv_freq = torch.exp(-math.log(base) * (i / d_model))    # [ceil(D/2)]
    angles = pos[:, None] * inv_freq[None, :]                # [T, ceil(D/2)]
    emb = torch.zeros(T, d_model, device=device)
    emb[:, 0::2] = torch.sin(angles)
    # There are floor(D/2) odd columns; for odd d_model drop the extra angle.
    emb[:, 1::2] = torch.cos(angles[:, : d_model // 2])
    return emb                                               # [T, D]

In [4]:
# -------------------- RoPE cache & apply (for Q/K) --------------------
def rope_cache(T, dim, device, base=10000.0):
    """Precompute RoPE rotation tables for positions 0..T-1.

    With half = dim // 2, frequency f has inverse wavelength base^(-f / half).
    The tables are laid out for the half-split pairing used by apply_rope
    (feature f paired with feature f + half).

    Returns (sin, cos), each a float32 tensor of shape [T, dim // 2] with
    entry [t, f] = sin/cos(t * inv_freq[f]).
    Raises AssertionError if dim is odd.
    """
    assert dim % 2 == 0, "RoPE dimension must be even"
    half = dim // 2
    exponents = torch.arange(0, half, device=device).float() / half   # [half], in [0, 1)
    inv_freq = torch.exp(exponents * -math.log(base))                 # [half]
    positions = torch.arange(T, device=device).float()                # [T]
    angles = torch.outer(positions, inv_freq)                         # [T, half]
    return angles.sin(), angles.cos()

def apply_rope(x, sin, cos):
    """Rotate the last dimension of x by position-dependent angles (RoPE).

    Half-split layout: feature f of the first half is paired with feature
    f + Dh/2, and the pair at position t is rotated by the angle whose sine
    and cosine are sin[t, f] and cos[t, f].

    x: [..., T, Dh] (e.g. [B,H,T,Dh]) with Dh even.
    sin, cos: [T_max, Dh/2] with T_max >= T. Only the first T rows are used,
        so one cache built for a maximum length can serve shorter inputs.
    returns x_rot with the same shape as x.
    Raises ValueError if Dh is odd or the cache has fewer than T rows.
    """
    T, Dh = x.shape[-2], x.shape[-1]
    if Dh % 2 != 0:
        raise ValueError(f"apply_rope needs an even last dimension, got {Dh}")
    if sin.shape[0] < T or cos.shape[0] < T:
        raise ValueError(
            f"RoPE cache covers {min(sin.shape[0], cos.shape[0])} positions, need {T}"
        )
    half = Dh // 2
    x1, x2 = x[..., :half], x[..., half:]   # [..., T, half] each
    # [T, half] tables broadcast over all leading dims of x.
    sin_, cos_ = sin[:T], cos[:T]
    xr1 = x1 * cos_ - x2 * sin_
    xr2 = x1 * sin_ + x2 * cos_
    return torch.cat([xr1, xr2], dim=-1)    # [..., T, Dh]

In [5]:
# -------------------- T5-style Relative Position Bias --------------------
class RelativePositionBias(nn.Module):
    """Learned per-head scalar bias indexed by the clipped relative offset j - i.

    Each head owns 2*max_rel - 1 scalars, one per offset in
    [-(max_rel-1), +(max_rel-1)]; offsets farther than max_rel-1 share the
    border entry. The result is meant to be added to attention scores before
    softmax.

    Note: this is a simplified form of T5's relative bias. T5 maps offsets into
    log-spaced buckets, whereas here every in-range offset gets its own entry
    and larger offsets are clipped.

    Raises ValueError if max_rel < 1.
    """
    def __init__(self, num_heads: int, max_rel: int = 128):
        super().__init__()
        if max_rel < 1:
            raise ValueError(f"max_rel must be >= 1, got {max_rel}")
        self.num_heads = num_heads
        self.max_rel = max_rel
        # Zero init: the bias starts neutral (no positional preference) and is learned.
        self.table = nn.Parameter(torch.zeros(num_heads, 2 * max_rel - 1))  # [H, 2R-1]

    def forward(self, T: int, device=None):
        """Return the bias for length-T self-attention, shape [1, H, T, T].

        device: where to build the offset indices; defaults to the table's
            device so callers need not pass it.
        """
        if device is None:
            device = self.table.device
        R = self.max_rel
        pos = torch.arange(T, device=device)
        offset = pos[None, :] - pos[:, None]                 # [T,T], offset[i,j] = j - i
        # Clip to [-(R-1), R-1], then shift to table indices [0, 2R-2].
        bucket = offset.clamp(-(R - 1), R - 1) + (R - 1)
        # Advanced indexing: bias[h, i, j] = table[h, bucket[i, j]].
        return self.table[:, bucket].unsqueeze(0)             # [1, H, T, T]

In [6]:
# -------------------- Manual MHA with options --------------------
class MHALayer(nn.Module):
    """Manual multi-head self-attention with optional positional schemes.

    Args:
        d_model: model width; must be divisible by num_heads.
        num_heads: number of heads; per-head dim dh = d_model // num_heads.
        p_drop: dropout probability applied to the attention weights.
        posenc: one of
            "none" - no positional information;
            "sin"  - add an absolute sinusoidal table to the input before the
                     Q/K/V projections (so it also reaches V);
            "rope" - rotate Q and K (not V) with rotary embeddings; needs an
                     even per-head dim.
        use_relbias: if True, add a learned per-head relative-position bias
            (RelativePositionBias) to the attention scores.
        relbias_max_rel: forwarded as max_rel to RelativePositionBias.

    Raises:
        ValueError: unknown posenc, or posenc="rope" with an odd per-head dim.
    """
    _POSENCS = ("none", "sin", "rope")

    def __init__(self, d_model: int, num_heads: int,
                 p_drop: float = 0.0,
                 posenc: str = "none",          # {"none","sin","rope"}
                 use_relbias: bool = False,
                 relbias_max_rel: int = 128):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        if posenc not in self._POSENCS:
            # Previously a typo such as "Rope" silently disabled positional encoding.
            raise ValueError(f"posenc must be one of {self._POSENCS}, got {posenc!r}")
        self.d_model = d_model
        self.h = num_heads
        self.dh = d_model // num_heads
        # RoPE rotates feature pairs, so fail at construction instead of on first forward.
        if posenc == "rope" and self.dh % 2 != 0:
            raise ValueError(f"RoPE needs an even per-head dim, got dh={self.dh}")
        self.posenc = posenc
        self.use_relbias = use_relbias

        self.Wq = nn.Linear(d_model, d_model, bias=False)
        self.Wk = nn.Linear(d_model, d_model, bias=False)
        self.Wv = nn.Linear(d_model, d_model, bias=False)
        self.Wo = nn.Linear(d_model, d_model, bias=False)
        self.drop = nn.Dropout(p_drop)
        self.scale = math.sqrt(self.dh)
        if use_relbias:
            self.relbias = RelativePositionBias(num_heads, max_rel=relbias_max_rel)

    def split_heads(self, x):
        """[B,T,D] -> [B,H,T,Dh]."""
        B, T, D = x.shape
        return x.view(B, T, self.h, self.dh).permute(0, 2, 1, 3)

    def merge_heads(self, x):
        """[B,H,T,Dh] -> [B,T,D]."""
        B, H, T, Dh = x.shape
        return x.permute(0, 2, 1, 3).contiguous().view(B, T, H * Dh)

    def forward(self, x, pad_mask=None, causal=False):
        """Self-attention over x.

        x: [B,T,D]; pad_mask: [B,T] True=real, False=PAD (masks keys).
        causal: if True, query i may only attend to keys j <= i.
        Returns (out [B,T,D], attn [B,H,T,T]); attn is taken after dropout.
        Query rows that can see no key at all (e.g. an all-PAD sequence) get
        all-zero weights instead of NaN.
        """
        B, T, D = x.shape
        # (A) absolute sinusoid add BEFORE projections (classic)
        if self.posenc == "sin":
            x = x + sinusoidal_positions(T, D, x.device)

        Q = self.split_heads(self.Wq(x))  # [B,H,T,Dh]
        K = self.split_heads(self.Wk(x))  # [B,H,T,Dh]
        V = self.split_heads(self.Wv(x))  # [B,H,T,Dh]

        # (B) RoPE rotates Q,K (NOT V)
        if self.posenc == "rope":
            sin, cos = rope_cache(T, self.dh, x.device)
            Q = apply_rope(Q, sin, cos)
            K = apply_rope(K, sin, cos)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # [B,H,T,T]

        # (C) relative position bias, broadcast over batch
        if self.use_relbias:
            scores = scores + self.relbias(T, x.device)

        # Combined keep-mask, broadcastable to [B,H,T,T]; True = may attend.
        keep = None
        if pad_mask is not None:
            keep = pad_mask.bool()[:, None, None, :]            # [B,1,1,T]
        if causal:
            cm = make_causal_mask(T, x.device)[None, None]      # [1,1,T,T]
            keep = cm if keep is None else keep & cm
        if keep is not None:
            scores = scores.masked_fill(~keep, float("-inf"))

        attn = torch.softmax(scores, dim=-1)  # [B,H,T,T]
        if keep is not None:
            # softmax over a row that is entirely -inf is NaN; zero such rows so
            # NaN cannot propagate into ctx and the output.
            attn = attn.masked_fill(~keep.any(dim=-1, keepdim=True), 0.0)
        attn = self.drop(attn)
        ctx = torch.matmul(attn, V)           # [B,H,T,Dh]
        out = self.Wo(self.merge_heads(ctx))  # [B,T,D]
        return out, attn

In [10]:
# -------------------- Pre-LN Encoder Block --------------------
class PreLNEncoderBlock(nn.Module):
    """Pre-LayerNorm transformer block.

    Computes h = x + Drop(MHA(LN1(x))), then h + Drop(FFN(LN2(h))).
    The positional options (posenc / use_relbias / relbias_max_rel) are
    forwarded unchanged to MHALayer; passing causal=True to forward gives
    causal self-attention (there is no cross-attention).
    The FFN is Linear(D, ff_mult*D) -> GELU -> Linear(ff_mult*D, D).
    """
    def __init__(self, d_model:int, num_heads:int, p_drop:float=0.1,
                 posenc:str="none", use_relbias:bool=False, relbias_max_rel:int=128, ff_mult:int=4):
        super().__init__()
        hidden = ff_mult * d_model
        # Submodules are created in the same order as before so parameter
        # initialisation (RNG consumption) and state_dict ordering are unchanged.
        self.ln1 = nn.LayerNorm(d_model)
        self.mha = MHALayer(d_model, num_heads, p_drop=p_drop, posenc=posenc,
                            use_relbias=use_relbias, relbias_max_rel=relbias_max_rel)
        self.drop1 = nn.Dropout(p_drop)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, hidden), nn.GELU(), nn.Linear(hidden, d_model))
        self.drop2 = nn.Dropout(p_drop)

    def forward(self, x, pad_mask=None, causal=False):
        """x: [B,T,D] -> (block output [B,T,D], attention weights from the MHA sub-layer)."""
        mha_out, attn = self.mha(self.ln1(x), pad_mask=pad_mask, causal=causal)
        h = x + self.drop1(mha_out)
        h = h + self.drop2(self.ffn(self.ln2(h)))
        return h, attn


In [8]:
# Shared toy inputs for the demos below: batch 2, length 8, width 32, 4 heads.
torch.manual_seed(0)
B, T, D, H = 2, 8, 32, 4
# Sequence 0 fills all 8 slots; sequence 1 has 5 real tokens followed by 3 PAD.
lengths = torch.tensor([8, 5], device=DEVICE)
pad_mask = make_padding_mask(lengths, T)   # [B,T] True=real; already on DEVICE via `lengths`
x = torch.randn(B, T, D, device=DEVICE)

In [11]:
# == No posenc, no relbias, encoder mask ==
blk = PreLNEncoderBlock(D, H, p_drop=0.0, posenc="none", use_relbias=False).to(DEVICE)
y, A = blk(x, pad_mask=pad_mask, causal=False)
print("out shape:", tuple(y.shape), "attn shape:", tuple(A.shape))
# Sequence 1 (length 5), head 0, final query: its weights on the 3 PAD keys should be 0.
last_row = A[1, 0, -1]
print("attn[1,0,-1] (last query, head0) to PAD keys ~0:", last_row.tolist())

out shape: (2, 8, 32) attn shape: (2, 4, 8, 8)
attn[1,0,-1] (last query, head0) to PAD keys ~0: [0.18280170857906342, 0.23103535175323486, 0.2453910857439041, 0.19526836276054382, 0.145503431558609, 0.0, 0.0, 0.0]


In [16]:
# == sinusoidal Pos Enc ==
blk = PreLNEncoderBlock(D, H, p_drop=0.0, posenc="sin", use_relbias=False).to(DEVICE)
# Inference only: under no_grad, A carries no autograd history, so turning its
# entries into Python floats no longer emits the requires_grad/detach warning.
with torch.no_grad():
    y, A = blk(x, pad_mask=pad_mask, causal=False)
print("Sinusoid attn[0,0,3] row:", [round(v, 4) for v in A[0, 0, 3].tolist()])

Sinusoid attn[0,0,3] row: [0.1082, 0.1355, 0.1053, 0.1111, 0.0876, 0.1482, 0.1909, 0.1131]


Consider using tensor.detach() first. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:836.)
  print("Sinusoid attn[0,0,3] row:", [round(float(v),4) for v in A[0,0,3]])


In [18]:
# == Rope Pos Enc ==
blk = PreLNEncoderBlock(D, H, p_drop=0.0, posenc="rope", use_relbias=False).to(DEVICE)
# Inference only: no_grad avoids the requires_grad warning when converting to floats.
with torch.no_grad():
    y, A = blk(x, pad_mask=pad_mask, causal=False)
print("Rope attn[0,0,3] row:", [round(v, 4) for v in A[0, 0, 3].tolist()])

Rope attn[0,0,3] row: [0.1218, 0.1905, 0.2053, 0.075, 0.0678, 0.1006, 0.1499, 0.0891]


In [19]:
# == RoPE + T5 relative bias ==
blk = PreLNEncoderBlock(D, H, p_drop=0.0, posenc="rope", use_relbias=True).to(DEVICE)
# Inference only: no_grad avoids the requires_grad warning when converting to floats.
with torch.no_grad():
    y, A = blk(x, pad_mask=pad_mask, causal=False)
print("Rope attn[0,0,3] row:", [round(v, 4) for v in A[0, 0, 3].tolist()])

Rope attn[0,0,3] row: [0.1141, 0.1353, 0.0942, 0.0912, 0.1125, 0.1933, 0.1593, 0.1001]


In [20]:
# == Decoder-style: causal + padding together ==
blk = PreLNEncoderBlock(D, H, p_drop=0.0, posenc="rope", use_relbias=True, relbias_max_rel=4).to(DEVICE)
# Inference only: no_grad avoids the requires_grad warning when converting to floats.
with torch.no_grad():
    y, A = blk(x, pad_mask=pad_mask, causal=True)
print("causal row t=4 (head 0), weights beyond j>4 ~0:", [round(v, 4) for v in A[0, 0, 4].tolist()])

causal row t=4 (head 0), weights beyond j>4 ~0: [0.1817, 0.1905, 0.3169, 0.1942, 0.1167, 0.0, 0.0, 0.0]


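Sinusoidal (absolute add): a fixed position vector is added to the inputs. Attention then sees content plus absolute-position cues, which ties representations to absolute positions. In this layer the sinusoid is added before the Q/K/V projections, so it also reaches V. <br>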

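Relative position (bias on scores): a learned scalar per head and per relative offset j − i is added directly to the logits, as in T5. Unlike T5, which uses log-spaced buckets, this implementation clips offsets to ±(max_rel − 1). Because the bias does not depend on content, it imposes the same tendency on every query (e.g., strong local attention). This can break ties between otherwise similar keys and may help stabilize training. <br>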

RoPE (rotary): Q and K are rotated by angles proportional to their positions. Their dot product therefore depends on position only through the relative displacement between query and key. This lets the model condition matching on distance and direction in a smooth, content-aware way. Because RoPE is applied only to Q and K, position influences *which* keys are matched, but the values (V) retrieved once a key is matched carry no positional rotation.