In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

# Let's start by understanding each component step by step

title_rule = "=" * 60
for banner_line in (title_rule, "TRANSFORMER COMPONENTS TUTORIAL", title_rule):
    print(banner_line)

TRANSFORMER COMPONENTS TUTORIAL


In [2]:
# 1. POSITIONAL ENCODING
# ======================
# Problem: Unlike RNNs, Transformers process all positions simultaneously,
# so attention by itself carries no information about token order
# Solution: Add positional information to input embeddings

class PositionalEncoding(nn.Module):
    """
    Adds positional information to word embeddings.
    Uses sine and cosine functions of different frequencies:

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: embedding dimension (odd values are supported).
        max_len: number of positions in the precomputed table.

    forward(x) adds the encoding of the first x.size(1) positions to x and
    returns a tensor in x's dtype. x is expected to be
    (batch, seq_len, d_model) with seq_len <= max_len; longer sequences
    fail because the table is only max_len positions long.
    """
    def __init__(self, d_model, max_len=5000):
        super().__init__()

        # Matrix holding one encoding row per position
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)

        # 1 / 10000^(2i / d_model) for every even index 2i, computed in log space
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) *
                             -(math.log(10000.0) / d_model))

        # Sine on even indices: ceil(d_model / 2) columns
        pe[:, 0::2] = torch.sin(position * div_term)
        # Cosine on odd indices: only floor(d_model / 2) columns, so trim
        # div_term; otherwise an odd d_model raises a shape mismatch
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Add batch dimension and register as buffer (saved with the model,
        # moved by .to(), but not a trainable parameter)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Cast to x's dtype so half-precision inputs are not silently
        # promoted to float32 by the addition
        return x + self.pe[:, :x.size(1)].to(dtype=x.dtype)

In [None]:
# 2. SELF-ATTENTION MECHANISM
# ===========================
# Key insight: Instead of processing sequentially like RNNs,
# allow each position to attend to ALL positions simultaneously

class SelfAttention(nn.Module):
    """
    Self-attention allows each position to attend to all positions.
    Think of it as: "For each word, how much should I focus on every other word?"

    forward(x, mask=None) returns (output, attention_weights):
        output:            same shape as x
        attention_weights: (..., seq_len, seq_len); each row sums to 1
    Positions where ``mask == 0`` receive zero attention weight.
    """
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

        # Linear projections for Query, Key, Value
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        # Step 1: Create Query, Key, Value matrices
        Q = self.W_q(x)  # What am I looking for?
        K = self.W_k(x)  # What do I contain?
        V = self.W_v(x)  # What is my actual content?

        # Step 2: Compute attention scores
        # scores[i,j] = how much position i should attend to position j
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_model)

        # Step 3: Apply mask if provided (for decoder).
        # Use the most negative value representable in the scores' dtype:
        # a hard-coded -1e9 overflows (and raises) for float16.
        if mask is not None:
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # Step 4: Apply softmax to get attention weights
        attention_weights = F.softmax(scores, dim=-1)

        # Step 5: Apply attention weights to values
        output = torch.matmul(attention_weights, V)

        return output, attention_weights

In [None]:
# 3. MULTI-HEAD ATTENTION
# =======================
# Instead of one attention, use multiple "heads" to capture different relationships

class MultiHeadAttention(nn.Module):
    """
    Multi-head attention runs several attention mechanisms in parallel.
    Each head can focus on different types of relationships.

    Args:
        d_model: model dimension; must be divisible by n_heads.
        n_heads: number of parallel heads, each of size d_k = d_model // n_heads.

    By default this is self-attention (queries, keys and values all come from
    ``x``). Pass ``memory`` to use it as cross-attention, where keys and values
    come from another sequence (e.g. an encoder output).
    """
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # dimension of each head

        # Linear layers for all heads combined
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model)  # Output projection

    def _split_heads(self, t):
        """Reshape (batch, length, d_model) -> (batch, n_heads, length, d_k)."""
        batch_size, length, _ = t.size()
        return t.view(batch_size, length, self.n_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None, memory=None):
        """
        Args:
            x: query sequence, (batch, q_len, d_model).
            mask: optional mask broadcastable to (batch, n_heads, q_len, kv_len);
                positions where mask == 0 are not attended to.
            memory: optional key/value sequence, (batch, kv_len, d_model).
                Defaults to ``x`` (self-attention).

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        if memory is None:
            memory = x
        batch_size, q_len, _ = x.size()

        # Step 1: Linear projections in batch from d_model => h x d_k
        Q = self._split_heads(self.W_q(x))
        K = self._split_heads(self.W_k(memory))
        V = self._split_heads(self.W_v(memory))

        # Step 2: Apply attention on all the projected vectors in batch
        attention_output, _ = self.scaled_dot_product_attention(Q, K, V, mask)

        # Step 3: Concatenate heads and put through final linear layer
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, q_len, self.d_model)

        return self.W_o(attention_output)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Return (softmax(Q K^T / sqrt(d_k)) V, attention weights)."""
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # dtype-aware fill: a hard-coded -1e9 overflows (raises) in float16
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)

        return output, attention_weights

In [None]:
# 4. LAYER NORMALIZATION
# ======================
# Normalizes inputs to each layer for stable training

class LayerNorm(nn.Module):
    """
    Layer normalization normalizes across the feature dimension.
    Helps with training stability and convergence.

    Computes gamma * (x - mean) / sqrt(var + eps) + beta over the last
    dimension, using the biased (population) variance. This matches the
    standard definition used by ``torch.nn.functional.layer_norm``.
    """
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))  # Scale parameter
        self.beta = nn.Parameter(torch.zeros(d_model))  # Shift parameter
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        centered = x - mean
        # Biased variance: unlike the unbiased x.std(), this stays finite for a
        # single feature, and eps inside the sqrt guards constant inputs
        var = centered.pow(2).mean(-1, keepdim=True)
        return self.gamma * centered / torch.sqrt(var + self.eps) + self.beta

In [None]:
# 5. FEED FORWARD NETWORK
# =======================
# Position-wise fully connected network applied to each position

In [None]:
class FeedForward(nn.Module):
    """
    Position-wise feed-forward network: a two-layer MLP applied to every
    position independently with the same weights.

    d_model -> d_ff (ReLU, then dropout) -> d_model
    """
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Expand to the inner width, then project back to the model width
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

# 6. ENCODER BLOCK
# ================
# Combines multi-head attention + feed forward with residual connections

class EncoderBlock(nn.Module):
    """
    One encoder layer: self-attention followed by a feed-forward network.
    Each sub-layer output goes through dropout, is added back to its input
    (residual connection) and is then layer-normalized ("Add & Norm").
    """
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        # Sub-layers
        self.multi_head_attention = MultiHeadAttention(d_model, n_heads)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        # One normalization per residual sub-layer
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Sub-layer 1: multi-head self-attention, then Add & Norm
        x = self.norm1(x + self.dropout(self.multi_head_attention(x, mask)))
        # Sub-layer 2: position-wise feed-forward, then Add & Norm
        return self.norm2(x + self.dropout(self.feed_forward(x)))

In [None]:
# 7. DECODER BLOCK
# ================
# Like encoder but with masked self-attention + encoder-decoder attention

class DecoderBlock(nn.Module):
    """
    Decoder block: Masked Multi-Head Attention + Encoder-Decoder Attention + Feed Forward.
    Each sub-layer output goes through dropout, a residual connection and LayerNorm.
    """
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.masked_multi_head_attention = MultiHeadAttention(d_model, n_heads)
        self.encoder_decoder_attention = MultiHeadAttention(d_model, n_heads)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: decoder input, (batch, tgt_len, d_model).
            encoder_output: output of the encoder stack, (batch, src_len, d_model).
            src_mask: optional mask over source positions, broadcastable to
                (batch, n_heads, tgt_len, src_len); 0 = do not attend.
            tgt_mask: optional mask over target positions, broadcastable to
                (batch, n_heads, tgt_len, tgt_len); 0 = do not attend.

        Returns:
            Tensor of shape (batch, tgt_len, d_model).
        """
        # Masked self-attention over the target (tgt_mask, e.g. a look-ahead
        # mask, restricts which target positions each position may see)
        attn_output = self.masked_multi_head_attention(x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))

        # Encoder-decoder attention: queries come from the decoder, keys and
        # values from the encoder output. The score matrix is therefore
        # (tgt_len x src_len), which is what src_mask is shaped for.
        attn_output = self._cross_attention(x, encoder_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))

        # Feed forward
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))

        return x

    def _cross_attention(self, x, encoder_output, src_mask=None):
        """Apply self.encoder_decoder_attention with queries from x and keys/values from encoder_output."""
        attn = self.encoder_decoder_attention
        batch_size, tgt_len, _ = x.size()

        def split_heads(t):
            # (batch, length, d_model) -> (batch, n_heads, length, d_k)
            return t.view(batch_size, t.size(1), attn.n_heads, attn.d_k).transpose(1, 2)

        Q = split_heads(attn.W_q(x))
        K = split_heads(attn.W_k(encoder_output))
        V = split_heads(attn.W_v(encoder_output))

        context, _ = attn.scaled_dot_product_attention(Q, K, V, src_mask)
        # Concatenate heads back to (batch, tgt_len, d_model), then output projection
        context = context.transpose(1, 2).contiguous().view(batch_size, tgt_len, attn.d_model)
        return attn.W_o(context)

In [None]:
# 8. COMPLETE TRANSFORMER MODEL
# =============================

class Transformer(nn.Module):
    """
    Full encoder-decoder Transformer.

    Source and target tokens get their own embedding tables. Both are scaled by
    sqrt(d_model), combined with the shared sinusoidal positional encoding and
    passed through dropout. They then run through n_layers encoder blocks and
    n_layers decoder blocks. forward returns per-position vocabulary logits for
    the target sequence.
    """
    def __init__(self, vocab_size, d_model=512, n_heads=8, n_layers=6, d_ff=2048, 
                 max_len=5000, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        # Token embeddings + positional encoding
        self.src_embedding = nn.Embedding(vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len)

        # Identical hyper-parameters for every encoder and decoder layer
        layer_args = (d_model, n_heads, d_ff, dropout)
        self.encoder_blocks = nn.ModuleList(
            [EncoderBlock(*layer_args) for _ in range(n_layers)])
        self.decoder_blocks = nn.ModuleList(
            [DecoderBlock(*layer_args) for _ in range(n_layers)])

        # Map decoder states to vocabulary logits
        self.output_projection = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def _embed(self, tokens, embedding):
        """Embed token ids, scale by sqrt(d_model), add positions, apply dropout."""
        scaled = embedding(tokens) * math.sqrt(self.d_model)
        return self.dropout(self.positional_encoding(scaled))

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # Encoder stack
        memory = self._embed(src, self.src_embedding)
        for block in self.encoder_blocks:
            memory = block(memory, src_mask)

        # Decoder stack, attending to the encoder output
        hidden = self._embed(tgt, self.tgt_embedding)
        for block in self.decoder_blocks:
            hidden = block(hidden, memory, src_mask, tgt_mask)

        return self.output_projection(hidden)

In [None]:
# 9. DEMONSTRATION AND TESTING
# ============================

def create_padding_mask(seq, pad_idx=0):
    """Return a boolean mask that is True at real tokens and False at padding.

    For a (batch, seq_len) input the result is (batch, 1, 1, seq_len), which
    broadcasts over attention heads and query positions.
    """
    keep = seq.ne(pad_idx)
    return keep[:, None, None]

def create_look_ahead_mask(size):
    """Return a (size, size) boolean mask where entry [i, j] is True iff j <= i.

    Query position i may therefore attend only to itself and earlier positions.
    """
    return torch.ones(size, size).tril().bool()

# Seed once so every random tensor and weight initialization in the demos
# below is reproducible on Restart & Run All
torch.manual_seed(0)

print("\n1. TESTING POSITIONAL ENCODING")
print("-" * 40)
pos_enc = PositionalEncoding(d_model=512, max_len=100)
# Test with dummy input (reused by the tests below)
dummy_input = torch.randn(1, 10, 512)  # batch_size=1, seq_len=10, d_model=512
pos_encoded = pos_enc(dummy_input)
assert pos_encoded.shape == dummy_input.shape
assert not torch.equal(pos_encoded, dummy_input), "encoding should change the input"
print(f"Input shape: {dummy_input.shape}")
print(f"After positional encoding: {pos_encoded.shape}")
print("✓ Positional encoding adds location information to embeddings")

print("\n2. TESTING SELF-ATTENTION")
print("-" * 40)
self_attn = SelfAttention(d_model=512)
attn_output, attn_weights = self_attn(dummy_input)
# Verify before claiming success: shapes, and each row of weights is a distribution
assert attn_output.shape == dummy_input.shape
assert attn_weights.shape == (1, 10, 10)
assert torch.allclose(attn_weights.sum(dim=-1), torch.ones(1, 10))
print(f"Attention output shape: {attn_output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print("✓ Self-attention allows each position to attend to all positions")

print("\n3. TESTING MULTI-HEAD ATTENTION")
print("-" * 40)
multi_head_attn = MultiHeadAttention(d_model=512, n_heads=8)
mh_output = multi_head_attn(dummy_input)
# Heads are concatenated and projected back, so the shape is preserved
assert mh_output.shape == dummy_input.shape
print(f"Multi-head attention output shape: {mh_output.shape}")
print("✓ Multi-head attention captures different types of relationships")

print("\n4. TESTING ENCODER BLOCK")
print("-" * 40)
encoder_block = EncoderBlock(d_model=512, n_heads=8, d_ff=2048)
encoder_output = encoder_block(dummy_input)
# Residual connections require the block to preserve its input shape
assert encoder_output.shape == dummy_input.shape
print(f"Encoder block output shape: {encoder_output.shape}")
print("✓ Encoder block combines attention + feed-forward with residual connections")

print("\n5. TESTING COMPLETE TRANSFORMER")
print("-" * 40)
vocab_size = 10000
transformer = Transformer(vocab_size=vocab_size, d_model=512, n_heads=8, n_layers=6)

# Create dummy sequences. Token ids start at 1 so no position is treated
# as padding (create_padding_mask uses pad_idx=0).
demo_batch_size, src_len, tgt_len = 2, 10, 8
src_seq = torch.randint(1, vocab_size, (demo_batch_size, src_len))
tgt_seq = torch.randint(1, vocab_size, (demo_batch_size, tgt_len))

# Create masks
src_mask = create_padding_mask(src_seq)                                # (batch, 1, 1, src_len)
tgt_mask = create_look_ahead_mask(tgt_len).unsqueeze(0).unsqueeze(0)   # (1, 1, tgt_len, tgt_len)

print(f"Source sequence shape: {src_seq.shape}")
print(f"Target sequence shape: {tgt_seq.shape}")
print(f"Source mask shape: {src_mask.shape}")
print(f"Target mask shape: {tgt_mask.shape}")

# Forward pass: only a shape check, so skip building the autograd graph
with torch.no_grad():
    output = transformer(src_seq, tgt_seq, src_mask, tgt_mask)
assert output.shape == (demo_batch_size, tgt_len, vocab_size)
print(f"Transformer output shape: {output.shape}")
print("✓ Complete transformer processes input successfully!")

summary_rule = "=" * 60
print("\n" + summary_rule)
print("KEY CONCEPTS SUMMARY:")
print(summary_rule)

key_concepts = [
    "🔸 POSITIONAL ENCODING: Adds position information since Transformers lack inherent sequence order",
    "🔸 SELF-ATTENTION: Each position attends to all positions simultaneously",
    "🔸 MULTI-HEAD ATTENTION: Multiple attention mechanisms capture different relationships",
    "🔸 RESIDUAL CONNECTIONS: Help gradients flow and enable deeper networks",
    "🔸 LAYER NORMALIZATION: Stabilizes training by normalizing layer inputs",
    "🔸 ENCODER-DECODER: Encoder processes input, decoder generates output autoregressively",
    "🔸 MASKS: Prevent attention to padding tokens and future positions in decoder",
]
for concept in key_concepts:
    print(concept)

print("\n🚀 You now understand the core components of Transformers!")
print("Each component serves a specific purpose in enabling the model to process")
print("sequences in parallel while capturing complex relationships between positions.")