# Transformer Block

Now we'll combine attention with the other key components to build a complete transformer block.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# Import our attention implementation
import sys
sys.path.append('../src')

# For now, let's copy our attention code here
def scaled_dot_product_attention(Q, K, V, mask=None):
    """Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        Q: Query tensor whose last two dimensions are (query_len, d_k).
        K: Key tensor whose last two dimensions are (key_len, d_k).
        V: Value tensor whose last two dimensions are (key_len, d_v).
        mask: Optional tensor broadcastable to the score shape
            (..., query_len, key_len). Positions where ``mask == 0`` are
            excluded from attention.

    Returns:
        A tuple ``(output, attention_weights)``: the weighted sum of values
        and the softmax-normalised weights over the key dimension.
    """
    d_k = Q.size(-1)

    # Scale with a plain Python float: no tensor is allocated per call and the
    # scores keep the dtype and device of the inputs.
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)

    if mask is not None:
        # Fill with the most negative finite value of the score dtype. A fixed
        # -1e9 does not fit in float16, so it breaks half-precision runs.
        # Staying finite (instead of -inf) keeps a fully masked row uniform
        # rather than turning it into NaN.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

    attention_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attention_weights, V)

    return output, attention_weights

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention in which queries, keys and values all come from ``x``.

    The model dimension is split evenly across ``n_heads`` heads of width
    ``d_k = d_model // n_heads``; ``d_model`` must be divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0

        self.d_model, self.n_heads = d_model, n_heads
        self.d_k = d_model // n_heads

        # Input projections for queries, keys and values, then the output
        # projection applied after the heads are merged back together.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def _split_heads(self, projected, n_batch, n_tokens):
        # (batch, seq, d_model) -> (batch, n_heads, seq, d_k)
        per_head = projected.view(n_batch, n_tokens, self.n_heads, self.d_k)
        return per_head.transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x`` of shape (batch, seq_len, d_model).

        ``mask`` is forwarded unchanged to ``scaled_dot_product_attention``.
        Returns a tensor of shape (batch, seq_len, d_model); the attention
        weights are discarded.
        """
        n_batch, n_tokens, _ = x.shape

        queries = self._split_heads(self.W_q(x), n_batch, n_tokens)
        keys = self._split_heads(self.W_k(x), n_batch, n_tokens)
        values = self._split_heads(self.W_v(x), n_batch, n_tokens)

        context, _ = scaled_dot_product_attention(queries, keys, values, mask)

        # (batch, n_heads, seq, d_k) -> (batch, seq, d_model)
        merged = context.transpose(1, 2).reshape(n_batch, n_tokens, self.d_model)

        return self.W_o(merged)

## Feed-Forward Network

The second key component of a transformer block is the position-wise feed-forward network.

In [2]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear.

    Args:
        d_model: Width of the input and output features.
        d_ff: Hidden width; defaults to ``4 * d_model`` when ``None``.
        dropout: Dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model, d_ff=None, dropout=0.1):
        super().__init__()
        # Transformers conventionally expand to four times the model width.
        hidden_dim = 4 * d_model if d_ff is None else d_ff

        self.linear1 = nn.Linear(d_model, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map ``x`` through the two-layer MLP, preserving its shape."""
        # GELU rather than ReLU, following GPT.
        hidden = F.gelu(self.linear1(x))
        return self.linear2(self.dropout(hidden))

# Smoke test: the feed-forward network must preserve the input shape.
ff = FeedForward(64)
x = torch.randn((2, 10, 64))  # (batch, seq_len, d_model)
output = ff(x)
print(f"FeedForward output shape: {output.shape}")

FeedForward output shape: torch.Size([2, 10, 64])


## Layer Normalization

Layer normalization is crucial for stable training of deep transformers.

In [3]:
# Demonstrate LayerNorm on rows whose values live on very different scales
def visualize_layer_norm():
    """Print per-row mean and std of a toy tensor before and after LayerNorm.

    The input has four rows of eight random features, scaled by 1, 10, 0.1
    and 5 respectively. Nothing is returned; the statistics are printed.
    """
    row_scales = torch.tensor([1, 10, 0.1, 5]).view(1, 4, 1)
    x = torch.randn(1, 4, 8) * row_scales

    # Normalise over the last (feature) dimension of size 8.
    x_norm = nn.LayerNorm(8)(x)

    # NOTE: Tensor.std defaults to the unbiased (n - 1) estimator, whereas
    # LayerNorm normalises with the biased one, so the "after" std prints as
    # roughly sqrt(8 / 7) ~= 1.069 rather than exactly 1.
    print(f"Before LayerNorm - mean: {x.mean(dim=-1)}, std: {x.std(dim=-1)}")
    print(f"After LayerNorm - mean: {x_norm.mean(dim=-1)}, std: {x_norm.std(dim=-1)}")

visualize_layer_norm()

Before LayerNorm - mean: tensor([[ 0.2553,  0.3720,  0.0229, -1.3679]]), std: tensor([[1.0016, 9.0629, 0.0996, 7.0906]])
After LayerNorm - mean: tensor([[ 0.0000e+00, -1.1176e-08,  1.8626e-09,  7.4506e-09]],
       grad_fn=<MeanBackward1>), std: tensor([[1.0690, 1.0690, 1.0684, 1.0690]], grad_fn=<StdBackward0>)


## Complete Transformer Block

Now let's combine everything with residual connections!

In [4]:
class TransformerBlock(nn.Module):
    """Pre-LN transformer block: self-attention then feed-forward.

    Each sub-layer normalises its input first, and its dropout-regularised
    output is added back to the un-normalised input (residual connection).
    The output has the same shape as the input.
    """

    def __init__(self, d_model, n_heads, d_ff=None, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.ln1, self.ln2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        # A single dropout module is shared by both residual branches.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run both sub-layers on ``x``; ``mask`` is passed to the attention."""
        # Sub-layer 1: x + Dropout(Attention(LN(x))). Normalising before the
        # sub-layer is the "Pre-LN" arrangement used by GPT.
        attended = self.attention(self.ln1(x), mask)
        x = x + self.dropout(attended)

        # Sub-layer 2: x + Dropout(FeedForward(LN(x)))
        transformed = self.feed_forward(self.ln2(x))
        return x + self.dropout(transformed)

# Smoke test: a transformer block must map (batch, seq_len, d_model) to the same shape.
block = TransformerBlock(64, 8)
x = torch.randn((2, 10, 64))
output = block(x)
print(f"Transformer block output shape: {output.shape}")
print(f"Output maintained input shape: {output.shape == x.shape}")

Transformer block output shape: torch.Size([2, 10, 64])
Output maintained input shape: True


## Visualizing Residual Connections

Let's see why residual connections are important:

In [5]:
def test_gradient_flow():
    """Compare input-gradient norms of a 10-layer stack with and without residuals.

    Each run draws its own random input and ten freshly initialised
    ``Linear(64, 64)`` + GELU layers, back-propagates the sum of the final
    activations, and measures the gradient norm at the input. The two norms
    and their ratio are printed; nothing is returned.
    """

    def input_grad_norm(use_residual):
        # One forward/backward pass through a throwaway 10-layer stack.
        x = torch.randn(1, 10, 64, requires_grad=True)
        y = x
        for _ in range(10):
            branch = F.gelu(nn.Linear(64, 64)(y))
            y = y + branch if use_residual else branch
        y.sum().backward()
        return x.grad.norm().item()

    # The plain stack runs first, then the residual one.
    grad_without = input_grad_norm(use_residual=False)
    grad_with = input_grad_norm(use_residual=True)

    print(f"Gradient norm without residuals: {grad_without:.6f}")
    print(f"Gradient norm with residuals: {grad_with:.6f}")
    print(f"Ratio: {grad_with / (grad_without + 1e-8):.2f}x stronger gradients!")

test_gradient_flow()

Gradient norm without residuals: 0.000086
Gradient norm with residuals: 211.578506
Ratio: 2473709.67x stronger gradients!


## Findings

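- **Residual connections are absolutely critical!** Without them, gradients shrink rapidly with depth: in our 10-layer toy experiment the gradient norm at the input was about 2.5 million times larger with residuals than without (the exact ratio varies from run to run because the inputs and weights are random).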
- **Layer normalization** normalizes each token's activations across the feature dimension, keeping them on a stable scale regardless of the input's magnitude
- **Pre-LN architecture** (normalize before attention/FF) is more stable than Post-LN
- **GELU activation** is smoother than ReLU and works better for transformers
- The transformer block maintains input shape throughout - elegant design!