# In [None]:
class MHSA(nn.Module):
    """Multi-head self-attention without masking (every position attends to every position).

    Queries, keys and values are produced by separate ``nn.Linear(D, D)``
    projections. The feature dimension ``D`` is split into
    ``n_heads = D // head_dim`` contiguous slices of size ``head_dim``.
    Scaled dot-product attention runs independently in each head, the heads
    are concatenated back to ``D`` features, and ``wo`` projects the result.

    Args:
        D: model dimension of the input and output.
        head_dim: size of each attention head; must be positive and divide ``D``.

    Raises:
        ValueError: if ``head_dim`` is not positive or does not divide ``D``.
    """

    def __init__(self, D, head_dim=64):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if head_dim <= 0 or D % head_dim != 0:
            raise ValueError(
                f"D ({D}) must be a positive multiple of head_dim ({head_dim})"
            )
        self.D = D
        self.wq = nn.Linear(D, D)
        self.wk = nn.Linear(D, D)
        self.wv = nn.Linear(D, D)
        self.wo = nn.Linear(D, D)

        self.n_heads = D // head_dim
        self.head_dim = head_dim

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: tensor of shape ``[batch, seq, D]``.

        Returns:
            Tensor of shape ``[batch, seq, D]``.
        """
        B, S, D = x.shape
        q, k, v = self.wq(x), self.wk(x), self.wv(x)  # each [B, S, D]

        # Split the feature dim into heads, then move heads in front of seq:
        # [B, S, D] -> [B, S, n_heads, head_dim] -> [B, n_heads, S, head_dim].
        # A direct .view(B, n_heads, S, head_dim) would be wrong: it
        # reinterprets the row-major [B, S, D] buffer, so each "head" would
        # receive a mixture of different tokens' features.
        q = q.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)

        # Per-head q @ k^T, scaled by sqrt(head_dim): [B, n_heads, S_q, S_k].
        scores = torch.einsum('bnqd,bnkd->bnqk', q, k)
        scores = scores / (self.head_dim ** 0.5)

        # Softmax over the last dim (keys): for each query (a row of the
        # [S_q, S_k] matrix), the weights over all keys sum to 1.
        # No causal mask is applied, so a query attends to past AND future keys.
        A = F.softmax(scores, dim=-1)

        # Weighted sum of values per head:
        # [B, n_heads, S, S] x [B, n_heads, S, head_dim] -> [B, n_heads, S, head_dim].
        out = torch.einsum('bnqk,bnkd->bnqd', A, v)

        # Undo the head split: [B, n_heads, S, head_dim] -> [B, S, n_heads, head_dim]
        # -> [B, S, D]. reshape (not view) because transpose leaves the tensor
        # non-contiguous.
        out = out.transpose(1, 2).reshape(B, S, D)

        return self.wo(out)

# test it
if __name__ == "__main__":
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Test parameters
    batch_size = 2
    seq_len = 4
    dim = 128
    head_dim = 32

    # Create random input tensor
    x = torch.randn(batch_size, seq_len, dim)

    # Initialize MHSA module
    mhsa = MHSA(D=dim, head_dim=head_dim)

    # Forward pass
    output = mhsa(x)

    # Basic shape tests
    assert output.shape == (batch_size, seq_len, dim), f"Expected shape {(batch_size, seq_len, dim)} but got {output.shape}"

    # Test attention weights sum to 1 over keys. Heads are split the proper way:
    # [B, S, D] -> [B, S, n_heads, head_dim] -> [B, n_heads, S, head_dim].
    # (This check holds for any softmax, so it cannot catch layout bugs on
    # its own; the reference comparison below does.)
    with torch.no_grad():
        q = mhsa.wq(x).view(batch_size, seq_len, mhsa.n_heads, head_dim).transpose(1, 2)
        k = mhsa.wk(x).view(batch_size, seq_len, mhsa.n_heads, head_dim).transpose(1, 2)
        scores = torch.einsum('bnqd,bnkd->bnqk', q, k) / (head_dim ** 0.5)
        attn_weights = F.softmax(scores, dim=-1)

    # Check if attention weights sum to 1 (with small numerical tolerance)
    row_sums = attn_weights.sum(dim=-1)
    assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-6)

    # Compare against PyTorch's reference multi-head attention with identical weights.
    # nn.MultiheadAttention stacks the q/k/v projections into one in_proj matrix.
    ref = nn.MultiheadAttention(dim, mhsa.n_heads, batch_first=True)
    with torch.no_grad():
        ref.in_proj_weight.copy_(torch.cat([mhsa.wq.weight, mhsa.wk.weight, mhsa.wv.weight], dim=0))
        ref.in_proj_bias.copy_(torch.cat([mhsa.wq.bias, mhsa.wk.bias, mhsa.wv.bias], dim=0))
        ref.out_proj.weight.copy_(mhsa.wo.weight)
        ref.out_proj.bias.copy_(mhsa.wo.bias)
        expected, _ = ref(x, x, x, need_weights=False)
    assert torch.allclose(output, expected, atol=1e-5), "MHSA output differs from nn.MultiheadAttention"

    # Unmasked self-attention is permutation-equivariant over the sequence:
    # reversing the input tokens must reverse the output tokens.
    assert torch.allclose(mhsa(x.flip(1)), output.flip(1), atol=1e-5), "MHSA is not permutation-equivariant"

    print("All tests passed!")
