# Pre-LN manual MHA encoder block

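The attention module below is self-attention only. It supports (a) a padding mask (to ignore PAD tokens) and (b) an optional causal mask (to block attending to future positions). Typical mask settings:

 - Encoder self-attention: causal=False, use the padding mask.
 - Decoder self-attention: causal=True (+ padding mask if sequences are padded).
 - Cross-attention (decoder→encoder): causal=False, use the encoder’s padding mask on K/V. (Listed for reference only: the module below derives Q, K and V from the same input, so it does not implement cross-attention.)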

In [1]:
import math, torch, torch.nn as nn, torch.nn.functional as F

In [2]:
# Use Apple's Metal (MPS) backend when it is available; otherwise fall back to the CPU.
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [3]:
# ---------- Masks ----------
def make_padding_mask(lengths, T):
    """
    Build a boolean validity mask from per-sequence lengths.

    lengths: LongTensor [B] with the number of real tokens in each sequence
    T: padded sequence length (number of mask columns)
    returns: bool tensor [B, T] on lengths.device; entry (b, t) is True iff
             t < lengths[b]  (True = real token, False = PAD)
    """
    positions = torch.arange(T, device=lengths.device)  # [T]
    # [1,T] compared against [B,1] broadcasts to [B,T]
    return positions[None, :] < lengths[:, None]

def make_causal_mask(T, device, allow_k_to_q=False):
    """
    Build the standard lower-triangular causal mask.

    T: sequence length
    device: device on which the mask is created
    allow_k_to_q: placeholder for alternative (enc-dec) patterns; it is
                  currently ignored and the standard causal mask is always
                  returned.
    returns: bool tensor [T, T] (row = query position, col = key position),
             True where attention is allowed, i.e. key index <= query index.
    """
    pos = torch.arange(T, device=device)
    query_idx = pos[:, None]  # [T,1]
    key_idx = pos[None, :]    # [1,T]
    return key_idx <= query_idx  # True on / left of the diagonal

In [4]:
# ---------- Manual Multi-Head Self-Attention ----------
class MultiHeadSelfAttention(nn.Module):
    """
    Manual multi-head scaled dot-product self-attention.

    Q, K and V are all projected from the same input, split into `num_heads`
    heads of size d_model // num_heads, attended per head, merged back and
    passed through an output projection. None of the projections has a bias.

    d_model: model / embedding width D
    num_heads: number of heads H (must divide d_model)
    p_drop: dropout probability applied to the attention weights
    """

    def __init__(self, d_model: int, num_heads: int, p_drop: float = 0.0):
        super().__init__()
        # the H per-head outputs (Dh each) are concatenated back into d_model
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.h = num_heads
        self.dh = d_model // num_heads
        self.Wq = nn.Linear(d_model, d_model, bias=False)
        self.Wk = nn.Linear(d_model, d_model, bias=False)
        self.Wv = nn.Linear(d_model, d_model, bias=False)
        self.Wo = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(p_drop)
        self.scale = math.sqrt(self.dh)  # scores are divided by sqrt(Dh)

    def split_heads(self, x):
        # x: [B,T,D] -> [B,H,T,Dh]
        B, T, _ = x.shape
        return x.view(B, T, self.h, self.dh).permute(0, 2, 1, 3)

    def merge_heads(self, x):
        # x: [B,H,T,Dh] -> [B,T,D]
        B, H, T, Dh = x.shape
        # permute: [B,H,T,Dh] -> [B,T,H,Dh]
        # contiguous: permute returns a non-contiguous view; materialize a copy in the new order
        # view: -> [B,T,D] where D = H*Dh (view can only reshape contiguous tensors)
        return x.permute(0, 2, 1, 3).contiguous().view(B, T, H * Dh)

    def forward(self, x, pad_mask=None, causal=False):
        """
        x: [B,T,D]
        pad_mask: optional [B,T] mask over keys (True/non-zero = real token,
                  False/zero = PAD); converted to bool and moved to x.device
        causal: if True, query t may only attend to keys <= t
                (mask built by the module-level make_causal_mask helper)
        returns: out [B,T,D], attn [B,H,T,T]; one [T,T] attention matrix per
                 head (row: query token, col: key token). attn is returned
                 after dropout. A query row with no allowed key (e.g. a
                 sequence made only of PAD) gets all-zero weights instead of
                 NaN.
        """
        _, T, _ = x.shape
        Q = self.split_heads(self.Wq(x))  # [B,H,T,Dh]
        K = self.split_heads(self.Wk(x))  # [B,H,T,Dh]
        V = self.split_heads(self.Wv(x))  # [B,H,T,Dh]

        # scores: [B,H,T,T]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        # Combine every active mask into one "allowed" mask (True = keep).
        allowed = None
        if pad_mask is not None:
            # mask the key dimension: [B,T] -> [B,1,1,T]
            allowed = pad_mask.to(device=x.device, dtype=torch.bool)[:, None, None, :]
        if causal:
            # prevent looking right: [T,T] -> [1,1,T,T]
            cm = make_causal_mask(T, x.device)[None, None, :, :]
            allowed = cm if allowed is None else (allowed & cm)

        dead_rows = None
        if allowed is not None:
            # mask before softmax so each query distributes probability only over allowed keys
            scores = scores.masked_fill(~allowed, float("-inf"))
            # A row whose keys are all masked would softmax over only -inf and
            # produce NaN (in the output and in the gradients). Reset such rows
            # to finite scores here and zero their weights after the softmax.
            dead_rows = ~allowed.any(dim=-1, keepdim=True)
            scores = scores.masked_fill(dead_rows, 0.0)

        attn = torch.softmax(scores, dim=-1)             # [B,H,T,T]
        if dead_rows is not None:
            attn = attn.masked_fill(dead_rows, 0.0)
        attn = self.dropout(attn)
        # ctx = context (the attended representation), i.e., attn @ V
        ctx = torch.matmul(attn, V)                      # [B,H,T,Dh]
        out = self.Wo(self.merge_heads(ctx))             # [B,T,D]
        return out, attn

In [None]:
# ---------- Pre-LN Encoder Block with MHA ----------
class PreLNEncoderBlock(nn.Module):
    """
    Pre-LayerNorm Transformer block built on the manual MHA.

    Computes  h = x + Drop(MHA(LN(x)))  followed by  y = h + Drop(FFN(LN(h))).

    d_model: model width D
    num_heads: number of attention heads
    p_drop: dropout probability (attention weights and both residual branches)
    ff_mult: FFN hidden width as a multiple of d_model (commonly 4)
    """

    def __init__(self, d_model: int, num_heads: int, p_drop: float = 0.1, ff_mult: int = 4):
        super().__init__()
        hidden = ff_mult * d_model

        # --- attention sub-layer ---
        self.ln1 = nn.LayerNorm(d_model)
        self.mha = MultiHeadSelfAttention(d_model, num_heads, p_drop=p_drop)
        self.drop1 = nn.Dropout(p_drop)

        # --- position-wise feed-forward sub-layer ---
        # The FFN widens each token to `hidden` channels so features can be
        # re-combined nonlinearly, then projects back to d_model.
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),  # Gaussian Error Linear Unit
            nn.Linear(hidden, d_model),
        )
        self.drop2 = nn.Dropout(p_drop)

    def forward(self, x, pad_mask=None, causal=False):
        """
        x: [B,T,D]; pad_mask and causal are forwarded to the attention module.
        returns: (y [B,T,D], attn) where attn is the attention module's weights.
        """
        attn_out, attn = self.mha(self.ln1(x), pad_mask=pad_mask, causal=causal)
        h = x + self.drop1(attn_out)                   # residual around MHA
        y = h + self.drop2(self.ffn(self.ln2(h)))      # residual around FFN
        return y, attn


In [6]:
# Toy batch: B sequences padded to length T, model width D, H attention heads.
torch.manual_seed(0)  # fixed seed so the input and the block's weights are reproducible
B = 2
T = 6
D = 24
H = 4
lengths = torch.tensor([6, 4])  # second sequence has 2 PADs
pad_mask = make_padding_mask(lengths, T)  # [B,T], True = real token
pad_mask = pad_mask.to(DEVICE)
x = torch.randn(B, T, D, device=DEVICE)

# p_drop=0.0 disables dropout so the printed attention rows sum to 1
block = PreLNEncoderBlock(D, H, p_drop=0.0)
block = block.to(DEVICE)

In [None]:
# Encoder-style call: only the padding mask is applied (causal defaults to False)
y_pad, attn_pad = block(x, pad_mask=pad_mask)
print("attn_pad shape:", tuple(attn_pad.size()))  # [B,H,T,T]

attn_pad shape: (2, 4, 6, 6)


In [8]:
# PAD keys (last 2 positions of the second sequence, batch index 1) should get 0 probability.
# Row shown: batch index 1, head 0, last query position.
print("attn to PAD keys (batch 1, head 0):", attn_pad[1, 0, -1, :].tolist())

attn to PAD keys (batch 1, head 0): [0.3118632733821869, 0.24609842896461487, 0.20189827680587769, 0.24014002084732056, 0.0, 0.0]


In [10]:
# Same experiment with 8 heads (D=24 -> 3 dims per head); note that `block` is rebound here
H = 8
block = PreLNEncoderBlock(D, H, p_drop=0.0)
block = block.to(DEVICE)
y_pad8, attn_pad8 = block(x, pad_mask=pad_mask)
print("attn_pad shape:", tuple(attn_pad8.size()))  # [B,H,T,T]

attn_pad shape: (2, 8, 6, 6)


In [18]:
# Decoder-style call: causal mask only, no padding mask.
# The last query may attend to every position, so PAD keys also receive weight here.
y_pad8_d, attn_pad8_d = block(x, causal=True)
print("causal row (no pad mask) (batch 1, head 0):", attn_pad8_d[1, 0, -1, :].tolist())

causal row (no pad mask) (batch 1, head 0): [0.1813250333070755, 0.1686900109052658, 0.16043618321418762, 0.076955147087574, 0.15375010669231415, 0.2588435411453247]


In [None]:
# Future positions must be masked: for query index 1 (0-based), keys 2..T-1 get weight 0
print("causal row t=1 (head 0):", attn_pad8_d[0, 0, 1, :].tolist())

causal row t=1 (head 0): [0.31768906116485596, 0.6823108792304993, 0.0, 0.0, 0.0, 0.0]


In [21]:
# Causal LM on a padded batch: padding mask and causal mask applied together
y_pad8_b, attn_pad8_b = block(x, pad_mask=pad_mask, causal=True)
last_query_row = attn_pad8_b[1, 0, -1, :].tolist()
print("Ran with both masks (batch 1, head 0):", last_query_row)
print("shapes:", tuple(y_pad8_b.size()))

Ran with both masks (batch 1, head 0): [0.30868756771087646, 0.2871776819229126, 0.27312639355659485, 0.1310083568096161, 0.0, 0.0]
shapes: (2, 6, 24)


In [None]:

# Shape comparison against PyTorch's built-in nn.MultiheadAttention.
# This is a separate module with its own freshly initialised weights, so only the
# shapes (not the values) are comparable with the manual block above.
mha = nn.MultiheadAttention(embed_dim=D, num_heads=H, batch_first=True).to(DEVICE)
attn_out, attn_w = mha(
    query=x,
    key=x,
    value=x,
    key_padding_mask=~pad_mask,   # PyTorch convention is inverted: True = ignore this key
    attn_mask=None,
    is_causal=False,
    need_weights=True,
    average_attn_weights=True,    # weights averaged over heads -> [B,T,T]
)
print("torch.nn.MultiheadAttention ok:", tuple(attn_out.shape), tuple(attn_w.shape))

torch.nn.MultiheadAttention ok: (2, 6, 24) (2, 6, 6)


In [25]:
# Head-averaged weights for batch index 1, query 0: the two PAD keys get 0
attn_w[1, 0].tolist()

[0.23362691700458527,
 0.23248961567878723,
 0.2443876713514328,
 0.2894957959651947,
 0.0,
 0.0]

In [27]:
# Lower-triangular keep-mask (True = may attend), created on the same device as x
causal_mask = make_causal_mask(T, device=x.device)
causal_mask

tensor([[ True, False, False, False, False, False],
        [ True,  True, False, False, False, False],
        [ True,  True,  True, False, False, False],
        [ True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True]], device='mps:0')

In [30]:
# Inverted padding mask (True = PAD), i.e. the form passed as key_padding_mask
pad_mask.logical_not()

tensor([[False, False, False, False, False, False],
        [False, False, False, False,  True,  True]], device='mps:0')

In [31]:
# Built-in MHA with both masks. Its boolean masks use True = "not allowed",
# so both keep-masks are inverted before being passed in.
attn_out_d, attn_w_d = mha(
    query=x,
    key=x,
    value=x,
    key_padding_mask=~pad_mask,
    attn_mask=~causal_mask,
    is_causal=False,
    need_weights=True,
    average_attn_weights=True,
)
attn_w_d[1].tolist()

[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
 [0.4493405818939209, 0.5506594181060791, 0.0, 0.0, 0.0, 0.0],
 [0.35129815340042114, 0.3132854104042053, 0.33541643619537354, 0.0, 0.0, 0.0],
 [0.3033524453639984,
  0.2528725266456604,
  0.2039852887392044,
  0.23978973925113678,
  0.0,
  0.0],
 [0.27934426069259644,
  0.2716120481491089,
  0.21209318935871124,
  0.23695051670074463,
  0.0,
  0.0],
 [0.21622858941555023,
  0.26489898562431335,
  0.2565867304801941,
  0.26228567957878113,
  0.0,
  0.0]]

In [35]:
# Repeat with p_drop=0.1 to see the effect of dropout.
# A freshly constructed module is in training mode, so dropout is active in this call.
H = 4
block_do = PreLNEncoderBlock(D, H, p_drop=0.1)
block_do = block_do.to(DEVICE)
y_pad_do, attn_pad_do = block_do(x, pad_mask=pad_mask)
attn_pad_do.size()

torch.Size([2, 4, 6, 6])

In [38]:
# Post-dropout attention for batch index 1, head 0: dropped entries are 0 and
# the surviving ones are rescaled, so rows no longer sum to exactly 1
attn_pad_do[1][0].tolist()

[[0.26510390639305115,
  0.25778496265411377,
  0.285604864358902,
  0.3026174008846283,
  0.0,
  0.0],
 [0.30113857984542847, 0.0, 0.2674764096736908, 0.0, 0.0, 0.0],
 [0.3511961102485657, 0.0, 0.2110809087753296, 0.2361455112695694, 0.0, 0.0],
 [0.31595924496650696,
  0.27965593338012695,
  0.24938443303108215,
  0.2661115527153015,
  0.0,
  0.0],
 [0.27424779534339905,
  0.26770955324172974,
  0.2704418897628784,
  0.29871201515197754,
  0.0,
  0.0],
 [0.27505114674568176, 0.32703322172164917, 0.2993795871734619, 0.0, 0.0, 0.0]]