In [1]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F 

class Attention(nn.Module):  # single head for now, we'll add multi-head later!
    """Single-head scaled dot-product self-attention.

    Computes ``wo(softmax(Q K^T / sqrt(D)) V)`` where ``Q``, ``K`` and ``V`` are
    independent linear projections of the same input ``x``.

    Args:
        D: hidden (feature) dimension of both the input and the output.
    """

    def __init__(self, D=256):
        super().__init__()
        self.D = D
        # divide by sqrt(D) to keep the variance of the logits roughly constant,
        # otherwise they grow with D and saturate the softmax.
        # Registered as a non-persistent buffer so it moves with .to(device/dtype)
        # together with the weights, while staying out of state_dict (so
        # checkpoints saved before this change still load).
        self.register_buffer(
            "scale", torch.sqrt(torch.tensor(D, dtype=torch.float32)), persistent=False
        )
        # these are just linear projections that map input -> query/key/value vectors
        self.wq = nn.Linear(D, D)  # query projection
        self.wk = nn.Linear(D, D)  # key projection
        self.wv = nn.Linear(D, D)  # value projection
        self.wo = nn.Linear(D, D)  # final output projection

        # critical to understand that weights themselves are independent of seqlen, S
        # so we can adaptively handle varying length sequences at inference time

    def forward(self, x, mask=None):
        """Apply self-attention over the sequence axis of ``x``.

        Args:
            x: input of shape ``[..., S, D]``, typically ``[B, S, D]`` (batch,
                sequence length, hidden dim). An unbatched ``[S, D]`` input also
                works, because only the last two axes are treated as
                (sequence, hidden).
            mask: optional mask broadcastable to ``[..., S, S]`` (query rows,
                key columns). A bool mask marks positions that MAY be attended
                to (``True`` = keep), the same convention as
                ``F.scaled_dot_product_attention``. A float mask is added to the
                logits instead. ``None`` (default) lets every query attend to
                every key. A query row whose keys are all masked out yields NaN.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # project input into Q, K, V vectors - each has the same shape as x
        Q, K, V = self.wq(x), self.wk(x), self.wv(x)

        # attention scores between every (query, key) pair:
        # [..., S, D] @ [..., D, S] -> [..., S, S]. Transposing the LAST two axes
        # (rather than axes 1 and 2) supports any number of leading batch dims.
        # Scaling prevents softmax saturation, which would kill gradients.
        A_logits = (Q @ K.transpose(-2, -1)) / self.scale
        # This is the heart of attention: each query is compared against every
        # key, which is why attention is called a "soft lookup". A query retrieves
        # a similarity-weighted blend of all values, rather than requiring an
        # exact key match as in e.g. a relational DB lookup.

        if mask is not None:
            if mask.dtype == torch.bool:
                # -inf logits become exactly 0 probability after softmax
                A_logits = A_logits.masked_fill(~mask, float("-inf"))
            else:
                A_logits = A_logits + mask

        # convert scores to probabilities: each query's row sums to 1 over the keys
        A = F.softmax(A_logits, dim=-1)  # [..., S, S]

        # weighted sum of values based on attention probs
        # [..., S, S] @ [..., S, D] -> [..., S, D], then project back to output space
        return self.wo(A @ V)


In [2]:
# Test the attention module
if __name__ == "__main__":
    # torch / nn / F are already imported in the first cell; no need to re-import.
    # Fixed seed so the random inputs and the module's random init are reproducible.
    torch.manual_seed(0)

    # dummy inputs
    batch_size, seq_len, hidden_dim = 2, 4, 256
    x = torch.randn(batch_size, seq_len, hidden_dim)

    attn = Attention(D=hidden_dim)

    # run the forward pass
    out = attn(x)

    # check shapes
    assert out.shape == (batch_size, seq_len, hidden_dim), f"Expected shape {(batch_size, seq_len, hidden_dim)} but got {out.shape}"

    # check values, not just shape: recompute wo(softmax(Q K^T / sqrt(D)) V)
    # from scratch with the module's own projection weights
    assert torch.isfinite(out).all(), "Output contains NaN/Inf"
    with torch.no_grad():
        q, k, v = attn.wq(x), attn.wk(x), attn.wv(x)
        probs = torch.softmax(q @ k.transpose(-2, -1) / hidden_dim ** 0.5, dim=-1)
        expected = attn.wo(probs @ v)
    assert torch.allclose(out, expected, atol=1e-5), "Output does not match reference attention"

    # gradients must reach the query/key projections through the softmax
    out.sum().backward()
    assert attn.wq.weight.grad is not None and torch.isfinite(attn.wq.weight.grad).all()
    assert attn.wk.weight.grad is not None and torch.isfinite(attn.wk.weight.grad).all()

    print("Test passed! Output shape is correct.")

Test passed! Output shape is correct.
