In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        """
        Multi-Head Attention implementation

        Args:
            d_model: Model dimension (embedding size)
            num_heads: Number of attention heads
            dropout: Dropout rate
        """
        super(MultiHeadAttention, self).__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        # Dimension of each head
        # Splits embedding dimension into multiple smaller chunks
        self.d_k = d_model // num_heads

        # Linear projections for Q, K, V
        self.w_q = nn.Linear(d_model, d_model, bias=True)
        self.w_k = nn.Linear(d_model, d_model, bias=True)
        self.w_v = nn.Linear(d_model, d_model, bias=True)

        # Output projection
        self.w_o = nn.Linear(d_model, d_model, bias=True)

        self.dropout = nn.Dropout(dropout)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """
        Compute scaled dot product attention

        Args:
            Q, K, V: Query, Key, Value matrices
            mask: Optional attention mask

        Returns:
            attention_output: Weighted values
            attention_weights: Attention weights
        """
        # Calculate attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Apply mask if provided (set masked positions to large negative value.
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # Apply softmax to get attention weights
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Apply attention weights to values
        attention_output = torch.matmul(attention_weights, V)

        return attention_output, attention_weights

    def forward(self, query, key, value, mask=None):
        """
        Forward pass of multi-head attention

        Args:
            query: Query tensor [batch_size, seq_len, d_model]
            key: Key tensor [batch_size, seq_len, d_model]
            value: Value tensor [batch_size, seq_len, d_model]
            mask: Optional attention mask [batch_size, seq_len, seq_len]

        Returns:
            output: Attention output [batch_size, seq_len, d_model]
            attention_weights: Attention weights for visualization
        """
        batch_size, seq_len = query.size(0), query.size(1)

        # 1. Linear projections to get Q, K, V
        Q = self.w_q(query)  # [batch_size, seq_len, d_model]
        K = self.w_k(key)    # [batch_size, seq_len, d_model]
        V = self.w_v(value)  # [batch_size, seq_len, d_model]

        # 2. Reshape and transpose for multi-head attention
        # Split into multiple heads: [batch_size, num_heads, seq_len, d_k]
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # 3. Apply scaled dot-product attention
        attention_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate heads
        # Transpose back: [batch_size, seq_len, num_heads, d_k]
        attention_output = attention_output.transpose(1, 2).contiguous()
        # Reshape to concatenate heads: [batch_size, seq_len, d_model]
        attention_output = attention_output.view(batch_size, seq_len, self.d_model)

        # 5. Final linear projection
        output = self.w_o(attention_output)

        return output, attention_weights

In [24]:
# Example usage and testing
def create_padding_mask(seq, pad_token=0):
    """Create padding mask for sequences with padding tokens"""
    return (seq != pad_token).unsqueeze(1).unsqueeze(2)

def create_causal_mask(seq_len):
    """Create causal (lower triangular) mask for decoder self-attention"""
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask.unsqueeze(0).unsqueeze(0)

# Example usage
if __name__ == "__main__":
    # Parameters
    batch_size = 2
    seq_len = 10
    d_model = 512
    num_heads = 8

    # Create sample input
    x = torch.randn(batch_size, seq_len, d_model)

    # Initialize multi-head attention
    mha = MultiHeadAttention(d_model, num_heads)

    # Self-attention (query, key, value are the same)
    output, attention_weights = mha(x, x, x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Attention weights shape: {attention_weights.shape}")

    # Example with causal mask (for decoder)
    causal_mask = create_causal_mask(seq_len)
    output_masked, attention_weights_masked = mha(x, x, x, mask=causal_mask)

    print(f"\nWith causal mask:")
    print(f"Output shape: {output_masked.shape}")
    print(f"Attention weights shape: {attention_weights_masked.shape}")

    # Verify that future positions are masked (should be close to 0)
    print(f"Attention to future positions (should be ~0): {attention_weights_masked[0, 0, 0, -1].item():.6f}")
    print(f"Attention to current/past positions: {attention_weights_masked[0, 0, 0, 0].item():.6f}")

Input shape: torch.Size([2, 10, 512])
Output shape: torch.Size([2, 10, 512])
Attention weights shape: torch.Size([2, 8, 10, 10])

With causal mask:
Output shape: torch.Size([2, 10, 512])
Attention weights shape: torch.Size([2, 8, 10, 10])
Attention to future positions (should be ~0): 0.000000
Attention to current/past positions: 1.111111


In [25]:
# Complete Transformer Block with Multi-Head Attention
class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        """
        Complete Transformer block with multi-head attention and feed-forward network

        Args:
            d_model: Model dimension (embedding size)
            num_heads: Number of attention heads
            d_ff: Feed-forward network dimension
            dropout: Dropout rate
        """
        super(TransformerBlock, self).__init__()

        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Feed-forward network
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention with residual connection
        attn_output, attention_weights = self.attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))

        # Feed-forward with residual connection
        ff_output = self.ff(x)
        x = self.norm2(x + self.dropout(ff_output))

        return x, attention_weights

In [26]:
# Example with complete transformer block
print("\n" + "="*50)
print("COMPLETE TRANSFORMER BLOCK EXAMPLE")
print("="*50)

transformer_block = TransformerBlock(d_model=512, num_heads=8, d_ff=2048)
output, attention_weights = transformer_block(x)

print(f"Transformer block output shape: {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")


COMPLETE TRANSFORMER BLOCK EXAMPLE
Transformer block output shape: torch.Size([2, 10, 512])
Attention weights shape: torch.Size([2, 8, 10, 10])


In [27]:
# ============================================================================
# POST-NORM TRANSFORMER BLOCK (Original Transformer - "Attention Is All You Need")
# ============================================================================
class PostNormTransformerBlock(nn.Module):
    """
    POST-NORM: LayerNorm is applied AFTER the residual connection
    Pattern: x = LayerNorm(x + Sublayer(x))

    Pros:
    - Original Transformer architecture
    - Stronger gradient signal to early layers

    Cons:
    - Can be unstable during training
    - Often requires learning rate warmup
    - Harder to train very deep models
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(PostNormTransformerBlock, self).__init__()

        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # POST-NORM: Norm AFTER residual

        # Attention block
        attention_output, attention_weights = self.attention(x, x, x, mask)
        # Norm after addition
        x = self.norm1(x + self.dropout(attention_output))

        # Feed-forward block
        ff_output = self.ff(x)
        # Norm after addition
        x = self.norm2(x + self.dropout(ff_output))

        return x, attention_weights

In [28]:
# ============================================================================
# PRE-NORM TRANSFORMER BLOCK (Modern variant - GPT, BERT variants)
# ============================================================================
class PreNormTransformerBlock(nn.Module):
    """
    PRE-NORM: LayerNorm is applied BEFORE the sub-layer (attention or FFN)
    Pattern: x = x + Sublayer(LayerNorm(x))

    Pros:
    - More stable training
    - Easier to train deep models (100+ layers)
    - No need for learning rate warmup
    - Better gradient flow

    Cons:
    - Slightly different from original paper
    - May need additional final LayerNorm at the end of the model
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(PreNormTransformerBlock, self).__init__()

        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # PRE-NORM: Norm BEFORE sub-layer

        # Attention block
        normed_x = self.norm1(x)  # Norm first
        attention_output, attention_weights = self.attention(normed_x, normed_x, normed_x, mask)
        x = x + self.dropout(attention_output)  # Residual without norm

        # Feed-forward block
        normed_x = self.norm2(x)  # Norm first
        ff_output = self.ff(normed_x)
        x = x + self.dropout(ff_output)  # Residual without norm

        return x, attention_weights

In [29]:
# ============================================================================
# DEMONSTRATION AND COMPARISON
# ============================================================================
if __name__ == "__main__":
    print("="*70)
    print("POST-NORM vs PRE-NORM TRANSFORMER BLOCKS")
    print("="*70)

    # Parameters
    batch_size = 2
    seq_len = 10
    d_model = 512
    num_heads = 8
    d_ff = 2048

    # Create sample input
    x = torch.randn(batch_size, seq_len, d_model)

    # Initialize both versions
    post_norm_block = PostNormTransformerBlock(d_model, num_heads, d_ff)
    pre_norm_block = PreNormTransformerBlock(d_model, num_heads, d_ff)

    print("\n--- POST-NORM (Original Transformer) ---")
    print("Order: Sublayer → Add Residual → LayerNorm")
    output_post, _ = post_norm_block(x)
    print(f"Input shape:  {x.shape}")
    print(f"Output shape: {output_post.shape}")
    print(f"Output mean:  {output_post.mean().item():.6f}")
    print(f"Output std:   {output_post.std().item():.6f}")

    print("\n--- PRE-NORM (Modern variant) ---")
    print("Order: LayerNorm → Sublayer → Add Residual")
    output_pre, _ = pre_norm_block(x)
    print(f"Input shape:  {x.shape}")
    print(f"Output shape: {output_pre.shape}")
    print(f"Output mean:  {output_pre.mean().item():.6f}")
    print(f"Output std:   {output_pre.std().item():.6f}")

    print("\n" + "="*70)
    print("KEY DIFFERENCES")
    print("="*70)
    print("""
POST-NORM (x = LayerNorm(x + Sublayer(x))):
  ✓ Original Transformer architecture
  ✓ Stronger gradient signal
  ✗ Can be unstable with deep models
  ✗ Requires learning rate warmup

PRE-NORM (x = x + Sublayer(LayerNorm(x))):
  ✓ More stable training
  ✓ Better for very deep models (100+ layers)
  ✓ No warmup needed
  ✓ Used in GPT-2, GPT-3, many modern models
  ✗ May need final LayerNorm at end of model
    """)

    print("\n" + "="*70)
    print("WHEN TO USE WHICH?")
    print("="*70)
    print("""
Use POST-NORM when:
  - Following the original Transformer paper exactly
  - Building shallow models (6-12 layers)
  - You have good hyperparameter tuning infrastructure

Use PRE-NORM when:
  - Building deep models (20+ layers)
  - You want stable training out-of-the-box
  - You want faster convergence
  - Building large language models (like GPT)
    Sequential""")

POST-NORM vs PRE-NORM TRANSFORMER BLOCKS

--- POST-NORM (Original Transformer) ---
Order: Sublayer → Add Residual → LayerNorm
Input shape:  torch.Size([2, 10, 512])
Output shape: torch.Size([2, 10, 512])
Output mean:  -0.000000
Output std:   1.000044

--- PRE-NORM (Modern variant) ---
Order: LayerNorm → Sublayer → Add Residual
Input shape:  torch.Size([2, 10, 512])
Output shape: torch.Size([2, 10, 512])
Output mean:  -0.000116
Output std:   1.060625

KEY DIFFERENCES

POST-NORM (x = LayerNorm(x + Sublayer(x))):
  ✓ Original Transformer architecture
  ✓ Stronger gradient signal
  ✗ Can be unstable with deep models
  ✗ Requires learning rate warmup

PRE-NORM (x = x + Sublayer(LayerNorm(x))):
  ✓ More stable training
  ✓ Better for very deep models (100+ layers)
  ✓ No warmup needed
  ✓ Used in GPT-2, GPT-3, many modern models
  ✗ May need final LayerNorm at end of model
    

WHEN TO USE WHICH?

Use POST-NORM when:
  - Following the original Transformer paper exactly
  - Building shallow 