In [None]:
'''
Cross-attention is extremely simple, same as normal attention but x in Q = W_q(x) and x in W_k(x)/W_v(x) are no longer 
the same sequence. The "attention" paid from queries to keys/vals is now across sequences rather than 
within a sequence of tokens. This is actually the "og" attention in that it's how attention appeared 
in the original "Attention is All You Need" paper which was originally motivated by the language translation task. 
Cross-attention there allowed any token in "I am a guy" to attend to any token in "Je suis un mec," which is obviously 
something you'd want to do. 

Cross-attention isn't used this way in LLMs these days (or at all). It is still fundamental because it's used in many other 
modern settings: RAG to attend to the top-k embeddings a retrieval model/reranker gives you, multimodal LLMs for modality tokens 
to talk to each other, conditional diffusion models to steer the latent/make the latent a function of an external 
sequence you want to condition on/use as context, and much more. It's a pretty neat trick to know well if you're planning 
on reading about/working on arch at all. 

The interpretation of cross-attention is that the resulting vector SM(Q @ K.T / sqrt(head_dim)) @ V is still 
a linear combination of value vectors, but those value vectors now come from the other sequence (x2), while their 
weights (coefficients) are set by how well each query from the first sequence (x1) matches the other sequence's keys. 

Personal note: this is actually an important tool if you're innovating in architectures, for instance 
one of the first projects I tried years ago attempted to improve the performance of an LM on language (e.g. perplexity)
by having it attend to embeddings produced by a VISION encoder. The dream was that attending to an image of various 
flower species would help you distinguish them even when you were presented only textual descriptions. 
Ultimately this didn't work but others got a related idea using cross-attention to work augmenting 
LLMs with skills that other LLMs have, see (https://arxiv.org/pdf/2401.02412).
'''


In [1]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import math 

# let's quickly reimplement mhsa as a baseline that we can then tweak. 
# good practice to implement this from scratch many times!
class MHSA(nn.Module):
    """Multi-head self-attention over a single sequence.

    Queries, keys and values are all projected from the same input ``x`` by one
    fused linear layer, split into ``D // head_dim`` heads, combined with scaled
    dot-product attention, and mixed by a final output projection.

    Args:
        head_dim: width of each attention head; must evenly divide ``D``.
        D: width of the residual stream (input and output feature size).
        causal: when True, position ``i`` may only attend to positions ``j <= i``.
    """

    def __init__(self, head_dim=64, D=512, causal=True):
        super().__init__()
        assert D % head_dim == 0, "Error: head dimension does not evenly divide residual stream dimension."
        self.head_dim = head_dim
        self.D = D
        self.nheads = D // head_dim
        self.causal = causal

        # one fused projection; its output is laid out as [q | k | v] along the last dim
        self.qkv_proj = nn.Linear(D, 3 * D)
        self.out_proj = nn.Linear(D, D)

    def forward(self, x):  # [b, s, d] -> [b, s, d]
        """Apply self-attention to ``x`` of shape [b, s, D]; returns [b, s, D]."""
        B, S, _ = x.shape
        qkv = self.qkv_proj(x)  # [b, s, 3*D]
        qkv = qkv.reshape(B, S, 3, self.nheads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, b, nheads, s, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]  # each [b, nheads, s, head_dim]

        # scaled dot-product scores between every query i and key j
        attn_logits = torch.einsum('bnid,bnjd->bnij', q, k)  # [b, nh, s, s]
        attn_logits = attn_logits / math.sqrt(self.head_dim)

        if self.causal:
            # build the mask on the input's device so the module also works off-CPU
            pos = torch.arange(S, device=x.device)
            future = pos[None, :] > pos[:, None]  # [s, s], True where key j lies after query i
            # -inf logits become exactly zero weight after the softmax
            attn_logits = attn_logits.masked_fill(future, float('-inf'))

        A = F.softmax(attn_logits, dim=-1)  # [b, nh, s, s], each row sums to 1

        # weighted sum of values per head, then concatenate heads back into width D
        out = torch.einsum('bnij,bnjd->bnid', A, v)  # [b, nh, s, hd]
        out = out.transpose(1, 2).reshape(B, S, self.D)  # [b, s, d]

        return self.out_proj(out)  # [b, s, d]



In [2]:
class CrossMHSA(nn.Module):
    """Multi-head cross-attention: queries from ``x1``, keys/values from ``x2``.

    Identical to self-attention except that the query projection reads one
    sequence while the key and value projections read another, so the two
    sequences may differ in length (and, via ``kv_dim``, in feature width).
    No causal mask is applied: every query may attend to every key.

    Args:
        head_dim: width of each attention head; must evenly divide ``D``.
        D: feature width of ``x1`` and of the output.
        kv_dim: feature width of ``x2``. Defaults to ``D`` when None.
    """

    def __init__(self, head_dim=64, D=512, kv_dim=None):
        super().__init__()
        assert D % head_dim == 0, "Error: head dimension does not evenly divide residual stream dimension."
        self.head_dim = head_dim
        self.D = D
        self.kv_dim = D if kv_dim is None else kv_dim
        self.nheads = D // head_dim

        # separate projections because q and k/v read different inputs
        self.q_proj = nn.Linear(D, D)
        self.k_proj = nn.Linear(self.kv_dim, D)
        self.v_proj = nn.Linear(self.kv_dim, D)
        self.out_proj = nn.Linear(D, D)

    def forward(self, x1, x2):  # ([b, s1, d], [b, s2, kv_dim]) -> [b, s1, d]
        """Let each position of ``x1`` attend over all positions of ``x2``.

        Raises:
            ValueError: if ``x1`` and ``x2`` have different batch sizes.
        """
        B, S1, _ = x1.shape
        B2, S2, _ = x2.shape
        if B2 != B:
            raise ValueError(f"batch size mismatch: x1 has batch {B}, x2 has batch {B2}")

        q = self.q_proj(x1)  # [b, s1, d]
        k = self.k_proj(x2)  # [b, s2, d]
        v = self.v_proj(x2)  # [b, s2, d]

        # split the feature dim into heads and move heads ahead of the sequence dim
        q = q.reshape(B, S1, self.nheads, self.head_dim).transpose(1, 2)  # [b, nh, s1, hd]
        k = k.reshape(B, S2, self.nheads, self.head_dim).transpose(1, 2)  # [b, nh, s2, hd]
        v = v.reshape(B, S2, self.nheads, self.head_dim).transpose(1, 2)  # [b, nh, s2, hd]

        # scaled dot-product scores between every x1 query i and every x2 key j
        attn_logits = torch.einsum('bnid,bnjd->bnij', q, k)  # [b, nh, s1, s2]
        attn_logits = attn_logits / math.sqrt(self.head_dim)

        A = F.softmax(attn_logits, dim=-1)  # [b, nh, s1, s2], each row sums to 1 over s2

        # weighted sum of x2's values per head, then concatenate heads back into width D
        out = torch.einsum('bnij,bnjd->bnid', A, v)  # [b, nh, s1, hd]
        out = out.transpose(1, 2).reshape(B, S1, self.D)  # [b, s1, d]

        return self.out_proj(out)  # [b, s1, d]



In [3]:
# tests
# Fixed seed so Restart & Run All reproduces exactly the same inputs and weights.
torch.manual_seed(0)

batch_size = 2
seq_len = 5
dim = 512
head_dim = 64

# test 1: when sequences are the same, cross-attention must reduce to self-attention
x = torch.randn(batch_size, seq_len, dim)

# initialize both attention modules
# for fair comparison, we need to set causal=False in MHSA (CrossMHSA applies no mask)
mhsa = MHSA(head_dim=head_dim, D=dim, causal=False)
cross_attn = CrossMHSA(head_dim=head_dim, D=dim)

# initialize weights to be the same
with torch.no_grad():
    # MHSA's fused projection is laid out as [q | k | v] along the output dim,
    # so stacking the three separate projections in that order makes them equivalent
    mhsa.qkv_proj.weight.copy_(torch.cat([
        cross_attn.q_proj.weight,
        cross_attn.k_proj.weight,
        cross_attn.v_proj.weight
    ], dim=0))
    mhsa.qkv_proj.bias.copy_(torch.cat([
        cross_attn.q_proj.bias,
        cross_attn.k_proj.bias,
        cross_attn.v_proj.bias
    ], dim=0))

    # set output projections to be the same
    cross_attn.out_proj.weight.copy_(mhsa.out_proj.weight)
    cross_attn.out_proj.bias.copy_(mhsa.out_proj.bias)

# forward pass
mhsa_output = mhsa(x)
cross_attn_output = cross_attn(x, x)  # same sequence for both inputs

# check if outputs are the same when inputs are the same
is_close = torch.allclose(mhsa_output, cross_attn_output, rtol=1e-4, atol=1e-4)
print(f"when sequences are identical: outputs match = {is_close}")
assert is_close, "cross-attention on (x, x) should match non-causal self-attention on x"

# test 2: when sequences are different
seq_len1 = 5
seq_len2 = 7
x1 = torch.randn(batch_size, seq_len1, dim)
x2 = torch.randn(batch_size, seq_len2, dim)

# forward pass with different sequences
cross_attn_diff = cross_attn(x1, x2)

# check output shape: the output follows the query sequence's length, not the key/value one
expected_shape = (batch_size, seq_len1, dim)
assert cross_attn_diff.shape == expected_shape, f"expected shape {expected_shape}, got {cross_attn_diff.shape}"

# MHSA.forward takes a single sequence, so there is no self-attention call to compare against here;
# the shape assertion above is what demonstrates that cross-attention handles s1 != s2.
print("cross-attention allows different sequence lengths, while self-attention requires identical sequences")

# verify that cross-attention with different sequences gives different results than with identical sequences
x1_copy = x1.clone()
cross_attn_same = cross_attn(x1_copy, x1_copy)
is_different = not torch.allclose(cross_attn_same, cross_attn_diff, rtol=1e-4, atol=1e-4)
print(f"cross-attention gives different results with different sequences: {is_different}")
assert is_different, "attending to a different sequence should change the output"

print("all tests completed!")


when sequences are identical: outputs match = True
cross-attention allows different sequence lengths, while self-attention requires identical sequences
cross-attention gives different results with different sequences: True
all tests completed!
