In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention with optional masking.

    Args:
        embed_dim: Model width (d_model). Must be divisible by ``num_heads``.
        num_heads: Number of attention heads (h).
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim, num_heads=8, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.embed_dim = embed_dim  # d_model
        self.num_heads = num_heads  # h
        self.head_dim = embed_dim // num_heads  # d_k = d_model / h

        # √d_k: attention scores are *divided* by this, i.e. scaled by 1/√d_k.
        self.scale = math.sqrt(self.head_dim)

        # Efficient projection: one (embed_dim -> embed_dim) matrix per Q/K/V
        # instead of h small ones; forward() reshapes the result into h heads.
        self.query_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.key_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.value_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        self.dropout = nn.Dropout(dropout)

        # Output projection: concatenated heads -> original dimension
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        """Apply masked multi-head self-attention to ``x``.

        Args:
            x: (batch_size, seq_len, embed_dim)
            mask: Optional tensor broadcastable to
                  (batch_size, num_heads, seq_len, seq_len), e.g.
                  (batch_size, 1, 1, seq_len) or (batch_size, 1, seq_len, seq_len).
                  1 = attend, 0 = mask out.
        Returns:
            output: (batch_size, seq_len, embed_dim)
            attention_weights: (batch_size, num_heads, seq_len, seq_len).
                These are taken *after* dropout, so in training mode with
                dropout > 0 the rows need not sum to 1.
        """
        batch_size, seq_len, embed_dim = x.shape

        # Step 1: Project to Q, K, V
        # Shape: (batch_size, seq_len, embed_dim)
        Q = self.query_proj(x)
        K = self.key_proj(x)
        V = self.value_proj(x)

        # Step 2: Reshape for multi-head attention
        # (batch, seq, embed_dim) -> (batch, seq, num_heads, head_dim)
        # then transpose to (batch, num_heads, seq, head_dim) so every head is
        # processed in parallel by the batched matmuls below.
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Step 3: Scaled dot-product scores
        # Q: (batch, num_heads, seq_len, head_dim)
        # K^T: (batch, num_heads, head_dim, seq_len)
        # scores: (batch, num_heads, seq_len, seq_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        # Step 4: Apply mask if provided (padding mask or causal mask)
        if mask is not None:
            # Fill with the dtype's most negative *finite* value, not -inf.
            # With -inf, a query whose keys are all masked (e.g. an all-padding
            # sequence) would softmax over a row of -inf and produce NaN, which
            # then propagates through out_proj and into the loss. With a finite
            # fill such rows become uniform instead. Partially masked rows are
            # unaffected: exp(min - row_max) underflows to exactly 0.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # Step 5: Softmax over keys to get attention probabilities
        # Shape: (batch, num_heads, seq_len, seq_len)
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Step 6: Weighted sum of values
        # attention_weights: (batch, num_heads, seq_len, seq_len)
        # V: (batch, num_heads, seq_len, head_dim)
        # Output: (batch, num_heads, seq_len, head_dim)
        attended_values = torch.matmul(attention_weights, V)

        # Step 7: Concatenate heads and project
        # (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, num_heads, head_dim)
        # -> (batch, seq_len, embed_dim). contiguous() is required because
        # transpose() returns a non-contiguous view, which view() rejects.
        attended_values = attended_values.transpose(1, 2).contiguous().view(
            batch_size, seq_len, embed_dim
        )

        # Final output projection
        output = self.out_proj(attended_values)

        return output, attention_weights


class TransformerBlock(nn.Module):
    """Single Pre-LN transformer block: self-attention + feed-forward.

    Args:
        embed_dim: Model width (d_model).
        num_heads: Number of attention heads; must divide ``embed_dim``.
        ff_dim: Hidden width of the position-wise feed-forward network.
        dropout: Dropout probability used inside attention, inside the FFN,
            and on each sublayer's output before the residual add.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()

        # Multi-head self-attention
        self.attention = MultiHeadSelfAttention(embed_dim, num_heads, dropout)

        # Layer normalization applied to each sublayer's *input* (Pre-LN)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        # Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.
        # Applied independently at every position after attention has mixed
        # information across positions.
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim)
        )

        # Dropout on each sublayer output before it is added to the residual.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run one transformer block.

        Args:
            x: (batch_size, seq_len, embed_dim)
            mask: Optional attention mask, passed unchanged to the
                  self-attention sublayer (1 = attend, 0 = mask out).
        Returns:
            x: (batch_size, seq_len, embed_dim). The attention weights
               produced by the sublayer are discarded.
        """
        # Architecture note: this block uses Pre-LayerNorm (Pre-LN):
        #     x + Sublayer(LayerNorm(x))
        # The original "Attention Is All You Need" Transformer used Post-LN:
        #     LayerNorm(x + Sublayer(x))
        # Pre-LN is commonly preferred in modern transformers because it tends
        # to train more stably, especially for deep stacks.

        # 1. Self-attention sublayer with residual connection
        attended_x, _ = self.attention(self.norm1(x), mask)
        x = x + self.dropout(attended_x)  # Residual: original + attended

        # 2. Feed-forward sublayer with residual connection
        ffn_out = self.ffn(self.norm2(x))
        x = x + self.dropout(ffn_out)  # Residual: original + transformed

        return x


# Example usage
if __name__ == "__main__":
    # Demo dimensions, defined once so the input tensor, the model width and
    # the causal mask cannot drift out of sync.
    batch_size, seq_len, embed_dim = 2, 10, 512

    # Create a transformer block
    transformer_block = TransformerBlock(
        embed_dim=embed_dim,   # d_model
        num_heads=8,           # h
        ff_dim=4 * embed_dim,  # Typically 4 x embed_dim (= 2048 here)
        dropout=0.1,
    )

    # Sample input: batch of 2 sequences, each of length 10, embedding dim 512
    x = torch.randn(batch_size, seq_len, embed_dim)

    # Causal (lower-triangular) mask for autoregressive modeling: query i may
    # attend only to keys 0..i.
    # Shape: (1, 1, seq_len, seq_len) - broadcasts over batch and heads.
    causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)

    # Forward pass. This is only a shape demo, so no autograd graph is needed.
    with torch.no_grad():
        output = transformer_block(x, mask=causal_mask)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")  # Same as input!
    print(f"Parameters: {sum(p.numel() for p in transformer_block.parameters()):,}")

Input shape: torch.Size([2, 10, 512])
Output shape: torch.Size([2, 10, 512])
Parameters: 3,150,848
