In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [8]:
# ============================================================================
# ORIGINAL TRANSFORMER (Attention is All You Need, 2017)
# ============================================================================

class OriginalMultiHeadAttention(nn.Module):
    """Multi-head attention as described in the original Transformer paper.

    Q, K and V each get their own ``nn.Linear`` projection, are split into
    ``num_heads`` heads of size ``d_k = d_model // num_heads``, pass through
    scaled dot-product attention, and are concatenated and mapped back with
    the output projection ``W_o``.

    Args:
        d_model: Model (embedding) dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Separate linear projections for Q, K, V
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Output projection applied after the heads are concatenated
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key`` / ``value``.

        Args:
            query: Tensor expected as ``(batch, q_len, d_model)``.
            key: Tensor expected as ``(batch, kv_len, d_model)``.
            value: Tensor expected as ``(batch, kv_len, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, q_len, kv_len)``; positions where
                ``mask == 0`` receive (numerically) zero attention weight.

        Returns:
            Tensor of shape ``(batch, q_len, d_model)``.
        """
        batch_size = query.size(0)

        # Linear projections and split into heads
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention scores: (batch, num_heads, q_len, kv_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # Fill with the most negative finite value of the scores' dtype.
            # A hard-coded -1e9 is not representable in float16, so masked_fill
            # raises an overflow error under half precision / autocast.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Weighted sum of values: (batch, num_heads, q_len, d_k)
        attn_output = torch.matmul(attn_weights, V)

        # Concatenate heads -> (batch, q_len, d_model), then project
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.W_o(attn_output)

        return output

In [9]:
class OriginalTransformerBlock(nn.Module):
    """
    Post-norm, encoder-style Transformer block (Vaswani et al., 2017).

    Each sublayer's output is dropped out, added back onto its input, and only
    then layer-normalised:

        x = LayerNorm(x + Dropout(MHA(x, x, x)))
        x = LayerNorm(x + Dropout(FFN(x)))

    Args:
        d_model: Model dimension.
        num_heads: Number of attention heads (must divide ``d_model``).
        d_ff: Hidden width of the position-wise feed-forward network.
        dropout: Dropout probability used inside attention, inside the FFN,
            and on both sublayer outputs before the residual addition.
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        self.attention = OriginalMultiHeadAttention(d_model, num_heads, dropout)

        # Position-wise FFN: expand to d_ff, ReLU, dropout, project back to d_model
        ffn_layers = [
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        ]
        self.ffn = nn.Sequential(*ffn_layers)

        # Post-norm: each LayerNorm normalises the residual *sum*
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)

        # Dropout on each sublayer output before it joins the residual stream
        self.dropout1, self.dropout2 = nn.Dropout(dropout), nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply self-attention then the FFN, each as ``LayerNorm(x + Dropout(sublayer(x)))``.

        ``mask`` is forwarded unchanged to the attention layer.
        """
        attended = self.norm1(x + self.dropout1(self.attention(x, x, x, mask)))
        return self.norm2(attended + self.dropout2(self.ffn(attended)))

In [10]:
# ============================================================================
# MODERN TRANSFORMER (Pre-Norm + Improvements)
# ============================================================================

class ModernMultiHeadAttention(nn.Module):
    """
    Self-attention with a fused QKV projection.

    A single ``nn.Linear(d_model, 3 * d_model)`` produces queries, keys and
    values in one matmul. The result is split into ``num_heads`` heads of
    size ``d_k``, run through scaled dot-product attention, merged, and
    projected back with ``proj``. Dropout is applied both to the attention
    weights and to the projected output.

    Note: attention is computed explicitly, materialising the full
    ``(seq_len, seq_len)`` score matrix per head. No Flash-Attention kernel
    is used.

    Args:
        d_model: Model dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability for the attention weights and the output.
        bias: Whether the QKV and output projections carry bias terms.
    """
    def __init__(self, d_model, num_heads, dropout=0.1, bias=True):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Fused QKV projection: one matmul instead of three
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=bias)

        # Output projection
        self.proj = nn.Linear(d_model, d_model, bias=bias)

        self.attn_dropout = nn.Dropout(dropout)
        self.proj_dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Self-attend over ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, seq_len, seq_len)``; positions where
                ``mask == 0`` receive (numerically) zero attention weight.

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)``.
        """
        batch_size, seq_len, d_model = x.shape

        # (batch, seq_len, 3 * d_model) -> (3, batch, num_heads, seq_len, d_k)
        qkv = self.qkv(x)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        Q, K, V = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention scores: (batch, num_heads, seq_len, seq_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # Fill with the most negative finite value of the scores' dtype.
            # A hard-coded -1e9 is not representable in float16, so masked_fill
            # raises an overflow error under half precision / autocast.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)

        # Weighted sum of values: (batch, num_heads, seq_len, d_k)
        attn_output = torch.matmul(attn_weights, V)

        # Concatenate heads -> (batch, seq_len, d_model); reshape copies if needed
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, d_model)

        # Output projection, then dropout on the result
        output = self.proj(attn_output)
        output = self.proj_dropout(output)

        return output

In [11]:
class ModernTransformerBlock(nn.Module):
    """
    Pre-norm Transformer block.

    LayerNorm is applied to the *input* of each sublayer. The sublayer output
    is added straight onto the un-normalised residual stream:

        x = x + MHA(LayerNorm(x))
        x = x + FFN(LayerNorm(x))

    Compared with the post-norm block:
    - Pre-LayerNorm leaves an un-normalised identity path through the residual
      stream, which tends to make deep stacks easier to train.
    - The feed-forward network uses GELU instead of ReLU.
    - ``bias=False`` removes the bias terms from every linear layer in the
      attention and the FFN. The LayerNorms keep their own affine parameters.

    Args:
        d_model: Model dimension.
        num_heads: Number of attention heads (must divide ``d_model``).
        d_ff: Hidden width of the feed-forward network.
        dropout: Dropout probability used inside attention and the FFN.
        bias: Whether the attention and FFN linear layers carry bias terms.
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1, bias=True):
        super().__init__()

        self.attention = ModernMultiHeadAttention(d_model, num_heads, dropout, bias=bias)

        # FFN: expand, GELU, dropout, project back, dropout
        ffn_layers = [
            nn.Linear(d_model, d_ff, bias=bias),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model, bias=bias),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*ffn_layers)

        # Pre-norm: these normalise each sublayer's *input*, not the residual sum
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Run the attention and FFN sublayers on normalised inputs, adding each result to the residual."""
        stream = x + self.attention(self.norm1(x), mask)
        return stream + self.ffn(self.norm2(stream))

In [12]:
# ============================================================================
# DEMONSTRATION AND COMPARISON
# ============================================================================

if __name__ == "__main__":
    # Seed the global RNG so the sample input, weight initialisation and
    # dropout masks are identical on every Restart & Run All.
    torch.manual_seed(0)

    # Hyperparameters
    batch_size = 2
    seq_len = 10
    d_model = 512
    num_heads = 8
    d_ff = 2048
    dropout = 0.1

    # Create sample input
    x = torch.randn(batch_size, seq_len, d_model)

    def report_block(title, block, inputs, architecture):
        """Print a section header, run ``block`` on ``inputs``, print shapes,
        parameter count and layout, and return the block's output."""
        print(title)
        print("-" * 70)
        output = block(inputs)
        print(f"Input shape:  {inputs.shape}")
        print(f"Output shape: {output.shape}")
        print(f"Parameters:   {sum(p.numel() for p in block.parameters()):,}")
        print(f"Architecture: {architecture}")
        return output

    print("=" * 70)
    print("TRANSFORMER BLOCK COMPARISON")
    print("=" * 70)

    # Original Transformer
    original_block = OriginalTransformerBlock(d_model, num_heads, d_ff, dropout)
    original_output = report_block(
        "\n1. ORIGINAL TRANSFORMER (Post-Norm, ReLU)",
        original_block,
        x,
        "X -> MHA -> Add -> Norm -> FFN -> Add -> Norm",
    )

    # Modern Transformer. It is built after the original block's forward pass,
    # so RNG consumption order matches the step-by-step version of this cell.
    # Both blocks report the same parameter count: the fused Linear(d, 3d) QKV
    # projection holds exactly the weights of three separate Linear(d, d) layers.
    modern_block = ModernTransformerBlock(d_model, num_heads, d_ff, dropout)
    modern_output = report_block(
        "\n2. MODERN TRANSFORMER (Pre-Norm, GELU)",
        modern_block,
        x,
        "X -> Norm -> MHA -> Add -> Norm -> FFN -> Add",
    )

    print("\n" + "=" * 70)
    print("KEY DIFFERENCES")
    print("=" * 70)
    print("""
Original (Post-Norm):
  ✓ Residual connection first, then normalization
  ✓ ReLU activation in FFN
  ✓ Separate Q, K, V projections
  ✗ Can have training instability with deep networks
  ✗ Requires careful learning rate tuning

Modern (Pre-Norm):
  ✓ Normalization first, then residual connection
  ✓ GELU activation (smoother gradients)
  ✓ Fused QKV projection (more efficient)
  ✓ Better gradient flow, easier to train deep networks
  ✓ More stable training, works well with larger learning rates
  ✓ Used in GPT-2, GPT-3, and most modern LLMs
    """)

TRANSFORMER BLOCK COMPARISON

1. ORIGINAL TRANSFORMER (Post-Norm, ReLU)
----------------------------------------------------------------------
Input shape:  torch.Size([2, 10, 512])
Output shape: torch.Size([2, 10, 512])
Parameters:   3,152,384
Architecture: X -> MHA -> Add -> Norm -> FFN -> Add -> Norm

2. MODERN TRANSFORMER (Pre-Norm, GELU)
----------------------------------------------------------------------
Input shape:  torch.Size([2, 10, 512])
Output shape: torch.Size([2, 10, 512])
Parameters:   3,152,384
Architecture: X -> Norm -> MHA -> Add -> Norm -> FFN -> Add

KEY DIFFERENCES

Original (Post-Norm):
  ✓ Residual connection first, then normalization
  ✓ ReLU activation in FFN
  ✓ Separate Q, K, V projections
  ✗ Can have training instability with deep networks
  ✗ Requires careful learning rate tuning

Modern (Pre-Norm):
  ✓ Normalization first, then residual connection
  ✓ GELU activation (smoother gradients)
  ✓ Fused QKV projection (more efficient)
  ✓ Better gradient flow