# **Day 8 (Part 2 - Step 2): Position-wise Feed-Forward Network**

## The Non-Linear Powerhouse of the Transformer

### **1. Recap: What We've Built So Far**

**Our Encoder Layer Building Progress:**

| Component | Status | Purpose |
|-----------|--------|--------|
| Multi-Head Attention | ✅ Day 7 | Relate words to each other |
| Positional Encoding | ✅ Step 1 | Add position information |
| **Feed-Forward Network** | 🔧 **Today** | **Add non-linear transformations** |
| Layer Normalization | ⏳ Step 3 | Stabilize training |
| Residual Connections | ⏳ Step 3 | Enable gradient flow |
| Complete Encoder Layer | ⏳ Step 3 | Assemble everything |

**Today's Goal:** Build the Position-wise Feed-Forward Network (FFN), a simple but crucial component that adds **non-linear transformation power** to the Transformer!

Let's dive in! 🚀

### **2. What is the Position-wise Feed-Forward Network?**

#### **The Problem: Attention Is (Mostly) Linear!**

Remember the attention formula?

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

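Softmax adds some non-linearity to the attention *weights*, but the features that reach the output only go through **linear** operations:
- Matrix multiplications (the value and output projections)
- A weighted sum of the value vectors

So for fixed attention weights, each output is a linear function of the inputs: attention mixes information across positions, but it does not apply a non-linear transformation to the features themselves.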

**Why is this a problem?**

Linear transformations have limited expressiveness:
- Stacked linear layers are equivalent to a single linear layer (a composition of affine maps is still affine)
- Can't learn complex, non-linear patterns
- Limited representational power

$$\text{Linear}_2(\text{Linear}_1(x)) = \text{Linear}_3(x)$$

We need **non-linearity** to learn complex functions!
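A quick numerical check of the claim above: a minimal sketch (layer sizes chosen arbitrarily) that merges two stacked `nn.Linear` layers into one equivalent layer, then shows that a ReLU between them breaks the equivalence.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two stacked linear layers with no activation in between (sizes are arbitrary)
lin1 = nn.Linear(8, 32)
lin2 = nn.Linear(32, 8)

# nn.Linear computes x @ W.T + b, so lin2(lin1(x)) = x @ (W2 @ W1).T + (W2 @ b1 + b2)
lin3 = nn.Linear(8, 8)
with torch.no_grad():
    lin3.weight.copy_(lin2.weight @ lin1.weight)
    lin3.bias.copy_(lin2.weight @ lin1.bias + lin2.bias)

x = torch.randn(4, 8)
with torch.no_grad():
    stacked = lin2(lin1(x))                # two linear layers
    collapsed = lin3(x)                    # one merged linear layer
    with_relu = lin2(torch.relu(lin1(x)))  # ReLU in between

print("Two linear layers == one linear layer:", torch.allclose(stacked, collapsed, atol=1e-5))
print("With ReLU in between, still equal?   ", torch.allclose(with_relu, collapsed, atol=1e-5))
```

The first check holds for any weights; the second fails because ReLU zeroes out the negative hidden values, which no single affine map can reproduce.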

#### **The Solution: Feed-Forward Network**

The Feed-Forward Network (FFN) is a simple **two-layer MLP** with a non-linear activation:

$$\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$$

Or using ReLU notation:

$$\text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2$$

**In Many Modern Transformers (GELU):**

$$\text{FFN}(x) = \text{GELU}(xW_1 + b_1)W_2 + b_2$$
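To connect the formula to PyTorch, here is a minimal sketch (random weights and small sizes chosen for illustration) that evaluates $\max(0, xW_1 + b_1)W_2 + b_2$ literally with matrix products and checks it against two `nn.Linear` layers. The one subtlety is that `nn.Linear` stores its weight transposed, as `(out_features, in_features)`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model, d_ff = 8, 32
x = torch.randn(3, d_model)  # 3 positions, each a row vector of size d_model

# Parameters in the formula's notation: x W1 requires W1 of shape (d_model, d_ff)
W1, b1 = torch.randn(d_model, d_ff), torch.randn(d_ff)
W2, b2 = torch.randn(d_ff, d_model), torch.randn(d_model)

# FFN(x) = max(0, x W1 + b1) W2 + b2, written out literally
ffn_formula = torch.clamp(x @ W1 + b1, min=0) @ W2 + b2

# Same computation with nn.Linear, which stores weight as (out_features, in_features)
lin1, lin2 = nn.Linear(d_model, d_ff), nn.Linear(d_ff, d_model)
with torch.no_grad():
    lin1.weight.copy_(W1.T)
    lin1.bias.copy_(b1)
    lin2.weight.copy_(W2.T)
    lin2.bias.copy_(b2)
    ffn_modules = lin2(F.relu(lin1(x)))

print(ffn_formula.shape)                                    # torch.Size([3, 8])
print(torch.allclose(ffn_formula, ffn_modules, atol=1e-4))  # True
```

Swapping `torch.clamp(..., min=0)` for `F.gelu(...)` gives the GELU version of the same formula.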

**Architecture Diagram:**

```
Input (d_model=512)
        ↓
  [Linear Layer 1]
  (512 → 2048)
        ↓
  [ReLU / GELU]
        ↓
  [Dropout]
        ↓
  [Linear Layer 2]
  (2048 → 512)
        ↓
Output (d_model=512)
```
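The diagram maps one-to-one onto a plain `nn.Sequential`. A minimal sketch, assuming the sizes shown above and a dropout rate of 0.1, that prints the tensor shape after each stage:

```python
import torch
import torch.nn as nn

# Each entry mirrors one box in the diagram
ffn = nn.Sequential(
    nn.Linear(512, 2048),  # Linear Layer 1: expand 512 -> 2048
    nn.ReLU(),             # swap for nn.GELU() to get the GELU variant
    nn.Dropout(0.1),       # dropout rate assumed for illustration
    nn.Linear(2048, 512),  # Linear Layer 2: contract 2048 -> 512
)

x = torch.randn(1, 6, 512)  # (batch, seq_len, d_model)
h = x
for layer in ffn:
    h = layer(h)
    print(f"{type(layer).__name__:<8} -> {tuple(h.shape)}")
```

Only the last dimension changes (512, then 2048, then back to 512); the batch and sequence dimensions pass through untouched.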

<div align="center">
  <img src="https://miro.medium.com/v2/resize:fit:640/format:webp/0*l-kF0t8dFKSUFrXx.png" width="400"/>
  <p><i>Position-wise Feed-Forward Network</i></p>
</div>

### **3. Why "Position-wise"?**

The key insight is that the FFN is applied **independently to each position**:

```
Sequence: ["The", "cat", "sat", "on", "the", "mat"]
              ↓       ↓      ↓     ↓      ↓      ↓
            FFN     FFN    FFN   FFN    FFN    FFN
              ↓       ↓      ↓     ↓      ↓      ↓
          [out_1] [out_2] [out_3] [out_4] [out_5] [out_6]
```

**Important:** 
- Same weights for all positions (parameter sharing)
- No interaction between positions (unlike attention)
- Each position transformed independently
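One way to check the three properties above: a minimal sketch (random weights, small sizes) that applies the same two layers (a) to the whole sequence at once, (b) to each position separately in a Python loop, and (c) as two 1D convolutions with kernel size 1, which is how the original Transformer paper also describes this layer. All three should agree.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 16, 64

lin1, lin2 = nn.Linear(d_model, d_ff), nn.Linear(d_ff, d_model)
conv1 = nn.Conv1d(d_model, d_ff, kernel_size=1)
conv2 = nn.Conv1d(d_ff, d_model, kernel_size=1)

# Copy the linear weights into the convs: Conv1d weight shape is (out, in, kernel_size)
with torch.no_grad():
    conv1.weight.copy_(lin1.weight.unsqueeze(-1))
    conv1.bias.copy_(lin1.bias)
    conv2.weight.copy_(lin2.weight.unsqueeze(-1))
    conv2.bias.copy_(lin2.bias)

x = torch.randn(2, 5, d_model)  # (batch, seq_len, d_model)
with torch.no_grad():
    # (a) whole sequence at once: nn.Linear acts on the last dimension
    out_seq = lin2(torch.relu(lin1(x)))
    # (b) one position at a time, same weights every time
    out_loop = torch.stack([lin2(torch.relu(lin1(x[:, t]))) for t in range(x.size(1))], dim=1)
    # (c) kernel-size-1 convs; Conv1d expects (batch, channels, length)
    out_conv = conv2(torch.relu(conv1(x.transpose(1, 2)))).transpose(1, 2)

print("sequence == per-position loop:", torch.allclose(out_seq, out_loop, atol=1e-5))
print("sequence == 1x1 convolution:  ", torch.allclose(out_seq, out_conv, atol=1e-5))
```

A kernel of size 1 sees exactly one position, which is another way of saying the layer never mixes information across positions.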

**Analogy:**

Think of attention as a **group discussion** where everyone talks to each other, and FFN as **individual thinking** where each person processes information on their own.

| Component | What It Does | Analogy |
|-----------|--------------|--------|
| Attention | Positions interact with each other | Group discussion 👥 |
| FFN | Each position transformed independently | Individual thinking 🧠 |

### **4. The Expand-Contract Pattern**

A key design choice in the FFN is the **expand-then-contract** pattern:

```
d_model (512) → d_ff (2048) → d_model (512)
     ↓              ↓              ↓
  Narrow        Expanded        Narrow
```

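**Why Expand to 4× the Size?**

The 4× ratio (d_ff = 2048 for d_model = 512) is the setting used in the original Transformer paper; it is a widely copied convention rather than a hard rule.

1. **More Expressiveness**: The larger hidden dimension gives the activation more units to work with, so the layer can represent more complex functions of each token
2. **Projection Back to d_model**: The output has to be squeezed back to d_model, so the model must summarize the expanded features into a compact representation for the next layer
3. **Feature Detection**: Each hidden unit can act as a detector for a different pattern in the input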

**Analogy: The Thinking Process**

Imagine solving a complex problem:

1. **Input (512 dims)**: You receive information
2. **Expand (2048 dims)**: You think about many aspects, consider various possibilities
3. **Contract (512 dims)**: You summarize your thoughts into a conclusion

The expansion allows for **richer intermediate representations**!

<div align="center">
  <img src="https://theaisummer.com/static/3e9d1a5498e65f15e019bb48e50f529c/ee604/feed-forward-layer.png" width="500"/>
</div>

### **5. Activation Functions: ReLU vs GELU**

The original Transformer used **ReLU**, but modern transformers often use **GELU**.

#### **ReLU (Rectified Linear Unit)**

$$\text{ReLU}(x) = \max(0, x)$$

- Simple and fast
- "Hard" threshold at 0
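- Can cause "dead neurons": a unit whose pre-activation is negative for every input always outputs 0 and receives zero gradient, so it stops learning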
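A minimal sketch of why dead neurons happen: ReLU's gradient is exactly 0 for negative inputs, so a unit whose pre-activation is negative for every input gets no weight updates. The bias of -100 below is a hypothetical, artificial setting used only to force that situation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# ReLU passes gradient 1 for positive inputs and 0 for negative inputs
x = torch.tensor([-2.0, -0.5, 0.5, 2.0], requires_grad=True)
F.relu(x).sum().backward()
print(x.grad)  # tensor([0., 0., 1., 1.])

# A "dead" unit: push the pre-activation far below zero (hypothetical bias value)
unit = nn.Linear(4, 1)
with torch.no_grad():
    unit.bias.fill_(-100.0)

inputs = torch.randn(64, 4)
out = F.relu(unit(inputs))
out.sum().backward()

print("max output:       ", out.max().item())                    # 0.0
print("max |weight grad|:", unit.weight.grad.abs().max().item())  # 0.0
```

Since both the output and the weight gradient are zero, gradient descent has no signal to pull this unit back into the active region.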

#### **GELU (Gaussian Error Linear Unit)**

$$\text{GELU}(x) = x \cdot \Phi(x) = x \cdot \frac{1}{2}\left[1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right]$$

where $\Phi(x)$ is the cumulative distribution function of the standard normal distribution.

Approximation:
$$\text{GELU}(x) \approx 0.5x\left(1 + \tanh\left[\sqrt{\frac{2}{\pi}}\left(x + 0.044715x^3\right)\right]\right)$$

- Smooth activation
- "Soft" threshold (probabilistic)
- Used in GPT, BERT, and many later transformers
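Both formulas can be checked against PyTorch. A minimal sketch that evaluates the exact erf form and the tanh approximation by hand and compares each with `F.gelu`; the `approximate='tanh'` argument assumes PyTorch 1.12 or newer.

```python
import math
import torch
import torch.nn.functional as F

x = torch.linspace(-4, 4, 201)

# Exact GELU: x * Phi(x), with the standard normal CDF written via erf
gelu_exact = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

# Tanh approximation, term by term from the formula above
gelu_tanh = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))

print("exact formula == F.gelu(x):          ",
      torch.allclose(gelu_exact, F.gelu(x), atol=1e-6))
print("tanh formula == approximate='tanh':  ",
      torch.allclose(gelu_tanh, F.gelu(x, approximate='tanh'), atol=1e-6))
print(f"max |exact - tanh| on [-4, 4]: {(gelu_exact - gelu_tanh).abs().max().item():.2e}")
```

The last line shows how small the approximation error is over this range, which is why the cheaper tanh form has been a popular substitute.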

Let's visualize both!

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# Create input range
x = torch.linspace(-4, 4, 200)

# Compute activations
relu = F.relu(x)
gelu = F.gelu(x)

# Plot
plt.figure(figsize=(12, 5))

# ReLU
plt.subplot(1, 2, 1)
plt.plot(x.numpy(), relu.numpy(), 'b-', linewidth=2, label='ReLU')
plt.axhline(y=0, color='k', linestyle='-', alpha=0.3)
plt.axvline(x=0, color='k', linestyle='-', alpha=0.3)
plt.fill_between(x.numpy(), relu.numpy(), alpha=0.3)
plt.xlabel('Input (x)')
plt.ylabel('Output')
plt.title('ReLU: max(0, x)\n"Hard" threshold at 0', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)

# GELU
plt.subplot(1, 2, 2)
plt.plot(x.numpy(), gelu.numpy(), 'r-', linewidth=2, label='GELU')
plt.plot(x.numpy(), relu.numpy(), 'b--', linewidth=1, alpha=0.5, label='ReLU (reference)')
plt.axhline(y=0, color='k', linestyle='-', alpha=0.3)
plt.axvline(x=0, color='k', linestyle='-', alpha=0.3)
plt.fill_between(x.numpy(), gelu.numpy(), alpha=0.3, color='red')
plt.xlabel('Input (x)')
plt.ylabel('Output')
plt.title('GELU: x · Φ(x)\n"Soft" probabilistic threshold', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Key Differences:")
print("• ReLU: Hard cutoff at 0 - either fully on or fully off")
print("• GELU: Smooth transition - allows small negative values through")
print("• GELU is used in GPT, BERT, and many later transformers!")

### **6. Implementing the Position-wise Feed-Forward Network**

Now let's implement the FFN as a PyTorch module!

In [None]:
class PositionWiseFeedForward(nn.Module):
    """
    Position-wise Feed-Forward Network
    
    This is a simple two-layer MLP applied independently to each position.
    
    Architecture:
        Input (d_model) → Linear → Activation → Dropout → Linear → Output (d_model)
    
    Args:
        d_model: Model dimension (e.g., 512)
        d_ff: Feed-forward hidden dimension (typically 4 * d_model, e.g. 2048 when d_model = 512)
        dropout: Dropout probability (default: 0.1)
        activation: Activation function ('relu' or 'gelu')
    """
    
    def __init__(self, d_model, d_ff, dropout=0.1, activation='relu'):
        super(PositionWiseFeedForward, self).__init__()
        
        # Store dimensions for inspection
        self.d_model = d_model
        self.d_ff = d_ff
        
        # First linear layer: expand from d_model to d_ff
        self.linear1 = nn.Linear(d_model, d_ff)
        
        # Second linear layer: contract from d_ff back to d_model
        self.linear2 = nn.Linear(d_ff, d_model)
        
        # Dropout for regularization
        self.dropout = nn.Dropout(dropout)
        
        # Activation function
        if activation == 'relu':
            self.activation = F.relu
        elif activation == 'gelu':
            self.activation = F.gelu
        else:
            raise ValueError(f"Unknown activation: {activation}. Use 'relu' or 'gelu'.")
        
        self.activation_name = activation
    
    def forward(self, x):
        """
        Forward pass of the FFN.
        
        Args:
            x: Input tensor of shape (batch_size, seq_len, d_model)
        
        Returns:
            Output tensor of shape (batch_size, seq_len, d_model)
        """
        # Step 1: Expand dimensions (d_model → d_ff)
        x = self.linear1(x)  # (batch, seq_len, d_ff)
        
        # Step 2: Apply activation function
        x = self.activation(x)
        
        # Step 3: Apply dropout
        x = self.dropout(x)
        
        # Step 4: Contract dimensions (d_ff → d_model)
        x = self.linear2(x)  # (batch, seq_len, d_model)
        
        return x
    
    def extra_repr(self):
        # nn.Module's built-in __repr__ calls extra_repr() and still lists
        # the submodules (linear1, linear2, dropout) underneath this line
        return f"d_model={self.d_model}, d_ff={self.d_ff}, activation={self.activation_name}"

In [None]:
# Let's test our implementation!

# Configuration (matching the original Transformer paper)
d_model = 512
d_ff = 2048  # 4 * d_model
dropout = 0.1

# Create FFN with ReLU
ffn_relu = PositionWiseFeedForward(d_model, d_ff, dropout, activation='relu')
print("FFN with ReLU:")
print(ffn_relu)
print()

# Create FFN with GELU (modern transformers)
ffn_gelu = PositionWiseFeedForward(d_model, d_ff, dropout, activation='gelu')
print("FFN with GELU:")
print(ffn_gelu)

In [None]:
# Test with sample input
batch_size = 2
seq_len = 10

# Create random input
x = torch.randn(batch_size, seq_len, d_model)
print(f"Input shape: {x.shape}")
print(f"  → (batch_size={batch_size}, seq_len={seq_len}, d_model={d_model})")
print()

# Forward pass
ffn_relu.eval()  # Set to eval mode to disable dropout for testing
with torch.no_grad():
    output = ffn_relu(x)

print(f"Output shape: {output.shape}")
print(f"  → (batch_size={batch_size}, seq_len={seq_len}, d_model={d_model})")
print()
print("✅ Input and output shapes match! (as expected)")

### **7. Visualizing the Transformation**

Let's visualize what happens inside the FFN!

In [None]:
# Create a small FFN for visualization
d_model_small = 8
d_ff_small = 32

ffn_viz = PositionWiseFeedForward(d_model_small, d_ff_small, dropout=0.0, activation='relu')
ffn_viz.eval()

# Create a simple input (1 batch, 4 positions, 8 dims)
x_viz = torch.randn(1, 4, d_model_small)

# Get intermediate activations
with torch.no_grad():
    # Step 1: After first linear
    after_linear1 = ffn_viz.linear1(x_viz)
    
    # Step 2: After activation
    after_activation = F.relu(after_linear1)
    
    # Step 3: After second linear (final output)
    output_viz = ffn_viz.linear2(after_activation)

# Visualize the transformations
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Input
im1 = axes[0, 0].imshow(x_viz[0].numpy(), cmap='RdBu', aspect='auto', vmin=-2, vmax=2)
axes[0, 0].set_title(f'1. Input\nShape: (4 positions, {d_model_small} dims)', fontsize=11)
axes[0, 0].set_xlabel('Dimension')
axes[0, 0].set_ylabel('Position')
plt.colorbar(im1, ax=axes[0, 0])

# After first linear (expanded)
im2 = axes[0, 1].imshow(after_linear1[0].numpy(), cmap='RdBu', aspect='auto', vmin=-2, vmax=2)
axes[0, 1].set_title(f'2. After Linear1 (Expanded!)\nShape: (4 positions, {d_ff_small} dims)', fontsize=11)
axes[0, 1].set_xlabel('Dimension (expanded to 32)')
axes[0, 1].set_ylabel('Position')
plt.colorbar(im2, ax=axes[0, 1])

# After ReLU
im3 = axes[1, 0].imshow(after_activation[0].numpy(), cmap='RdBu', aspect='auto', vmin=-2, vmax=2)
axes[1, 0].set_title(f'3. After ReLU\nNegatives → 0 (see the zeros!)', fontsize=11)
axes[1, 0].set_xlabel('Dimension')
axes[1, 0].set_ylabel('Position')
plt.colorbar(im3, ax=axes[1, 0])

# Output (contracted back)
im4 = axes[1, 1].imshow(output_viz[0].numpy(), cmap='RdBu', aspect='auto', vmin=-2, vmax=2)
axes[1, 1].set_title(f'4. After Linear2 (Contracted!)\nShape: (4 positions, {d_model_small} dims)', fontsize=11)
axes[1, 1].set_xlabel('Dimension (back to 8)')
axes[1, 1].set_ylabel('Position')
plt.colorbar(im4, ax=axes[1, 1])

plt.suptitle('FFN Transformation: Expand → Activate → Contract', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("Observations:")
print(f"• Input:        {d_model_small} dimensions")
print(f"• Expanded to:  {d_ff_small} dimensions (4× larger!)")
print(f"• After ReLU:   Negative values become 0 (white areas)")
print(f"• Output:       Back to {d_model_small} dimensions")

### **8. Parameter Count Analysis**

Let's understand how many parameters are in the FFN!
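The count can also be worked out by hand. Each `nn.Linear(in, out)` holds an `in × out` weight matrix plus `out` biases, so with the original dimensions ($d_{model} = 512$, $d_{ff} = 2048$):

$$
\begin{aligned}
\text{params}_{\text{FFN}} &= \underbrace{d_{model}\,d_{ff} + d_{ff}}_{\text{Linear 1}} + \underbrace{d_{ff}\,d_{model} + d_{model}}_{\text{Linear 2}} = 2\,d_{model}\,d_{ff} + d_{ff} + d_{model} \\
&= 2 \cdot 512 \cdot 2048 + 2048 + 512 = 2{,}097{,}152 + 2{,}560 = 2{,}099{,}712
\end{aligned}
$$

Almost all of this comes from the two weight matrices; the biases add only 2,560 parameters.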

In [None]:
def count_parameters(model):
    """Count the number of trainable parameters in a model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def analyze_ffn_parameters(d_model, d_ff):
    """Analyze FFN parameter breakdown."""
    ffn = PositionWiseFeedForward(d_model, d_ff, dropout=0.1)
    
    # Linear 1: d_model ‚Üí d_ff
    linear1_weights = d_model * d_ff
    linear1_bias = d_ff
    
    # Linear 2: d_ff ‚Üí d_model
    linear2_weights = d_ff * d_model
    linear2_bias = d_model
    
    total = linear1_weights + linear1_bias + linear2_weights + linear2_bias
    
    print(f"FFN Parameter Analysis (d_model={d_model}, d_ff={d_ff})")
    print("=" * 60)
    print(f"Linear 1 (d_model → d_ff):")
    print(f"  • Weights: {d_model} × {d_ff} = {linear1_weights:,}")
    print(f"  • Bias:    {d_ff:,}")
    print(f"  • Subtotal: {linear1_weights + linear1_bias:,}")
    print()
    print(f"Linear 2 (d_ff → d_model):")
    print(f"  • Weights: {d_ff} × {d_model} = {linear2_weights:,}")
    print(f"  • Bias:    {d_model:,}")
    print(f"  • Subtotal: {linear2_weights + linear2_bias:,}")
    print()
    print(f"Total FFN Parameters: {total:,}")
    print()
    
    # Verify with actual model
    actual_count = count_parameters(ffn)
    print(f"Verification (actual count): {actual_count:,}")
    print(f"Match: {'✅ Yes!' if actual_count == total else '❌ No'}")
    
    return total

# Analyze with original Transformer dimensions
params = analyze_ffn_parameters(512, 2048)

In [None]:
# Compare FFN params to total Encoder Layer params
print("\n" + "=" * 60)
print("FFN's Share of Encoder Layer Parameters")
print("=" * 60)

d_model = 512
d_ff = 2048
num_heads = 8

# FFN parameters
ffn_params = 2 * d_model * d_ff + d_ff + d_model

# Multi-Head Attention parameters (W_q, W_k, W_v, W_o)
mha_params = 4 * (d_model * d_model + d_model)  # 4 linear layers with biases

# Layer Norm parameters (2 layer norms, each has gamma and beta)
ln_params = 2 * 2 * d_model

total_encoder_layer = ffn_params + mha_params + ln_params

print(f"Multi-Head Attention: {mha_params:,} params ({100*mha_params/total_encoder_layer:.1f}%)")
print(f"Feed-Forward Network: {ffn_params:,} params ({100*ffn_params/total_encoder_layer:.1f}%)")
print(f"Layer Normalization:  {ln_params:,} params ({100*ln_params/total_encoder_layer:.1f}%)")
print(f"-" * 40)
print(f"Total Encoder Layer:  {total_encoder_layer:,} params")
print()
print("💡 Insight: FFN contains about 2/3 of the encoder layer's parameters!")
print("   This is because d_ff = 4 × d_model makes the FFN very large.")

### **9. Position-wise Independence: Demonstration**

Let's verify that the FFN processes each position **independently**!

In [None]:
# Create an FFN
ffn = PositionWiseFeedForward(d_model=64, d_ff=256, dropout=0.0)
ffn.eval()

# Create two inputs that differ only at position 0
x1 = torch.randn(1, 5, 64)  # 5 positions
x2 = x1.clone()
x2[0, 0, :] = torch.randn(64)  # Change only position 0

print("Testing Position-wise Independence")
print("=" * 50)
print("x1 and x2 differ ONLY at position 0")
print()

with torch.no_grad():
    out1 = ffn(x1)
    out2 = ffn(x2)

# Check each position
for pos in range(5):
    diff = torch.abs(out1[0, pos] - out2[0, pos]).max().item()
    is_same = diff < 1e-6
    status = "✅ Same" if is_same else "❌ Different"
    print(f"Position {pos}: {status} (max diff: {diff:.2e})")

print()
print("Conclusion:")
print("‚Ä¢ Positions 1-4 are identical (not affected by change at position 0)")
print("‚Ä¢ Position 0 is different (as expected)")
print("‚Ä¢ This proves FFN is POSITION-WISE (no cross-position interaction)!")

### **10. Comparing with Attention: Key Differences**
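For the **Complexity** row in the comparison below, here is a rough count of multiply-adds for one layer on a sequence of length $n$, assuming $d = d_{model}$ and $d_{ff} = 4d$, and counting only the matrix multiplications (softmax, biases, and normalization are ignored):

$$
\begin{aligned}
\text{FFN:} \quad & 2\,n\,d\,d_{ff} = 8\,n\,d^2 \\
\text{Attention:} \quad & \underbrace{4\,n\,d^2}_{W_Q,\,W_K,\,W_V,\,W_O} + \underbrace{n^2 d}_{QK^T} + \underbrace{n^2 d}_{\text{weights} \times V} = 4\,n\,d^2 + 2\,n^2 d
\end{aligned}
$$

Both pay an $n\,d^2$ cost for their projections, but attention adds an $n^2 d$ term that overtakes its own projection cost once $2n^2 d > 4n d^2$, i.e. for $n > 2d$.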

In [None]:
# Let's contrast FFN with attention!

print("FFN vs Attention: Key Differences")
print("=" * 60)
print()

comparison = [
    ["Position Interaction", "No (independent)", "Yes (all-to-all)"],
    ["Purpose", "Non-linear transformation", "Aggregate information"],
    ["Computation Type", "Same weights, different inputs", "Content-dependent weights"],
    ["Parallelism", "Fully parallel", "Fully parallel"],
    ["Parameters", "~2/3 of layer", "~1/3 of layer"],
    ["Complexity", "O(n × d²)", "O(n² × d + n × d²)"],
    ["Role in Layer", "Feature transformation", "Context aggregation"],
]

# Print as table
print(f"{'Aspect':<25} {'FFN':<25} {'Attention':<25}")
print("-" * 75)
for row in comparison:
    print(f"{row[0]:<25} {row[1]:<25} {row[2]:<25}")

print()
print("🔑 Key Insight:")
print("   Attention = 'What should I pay attention to?'")
print("   FFN = 'How should I transform what I've gathered?'")
print("   Together, they form a powerful processing unit!")

### **11. Modern Variations: SwiGLU and Other Alternatives**

Modern transformers (like LLaMA, PaLM) often use improved FFN variants:

#### **SwiGLU (Used in LLaMA, PaLM)**

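$$\text{FFN}_{\text{SwiGLU}}(x) = \left(xW_1 \odot \text{Swish}(xW_{gate})\right)W_2$$

Where:
- $\text{Swish}(x) = x \cdot \sigma(x)$ (sigmoid-weighted; also known as SiLU)
- $\odot$ is element-wise multiplication
- $W_1$ and $W_{gate}$ are separate linear projections from $d_{model}$ to $d_{ff}$
- $W_2$ projects the gated hidden vector back to $d_{model}$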

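**Benefits:**
- The gate scales each hidden unit by a smooth, input-dependent amount instead of applying a fixed threshold
- Shazeer (2020), "GLU Variants Improve Transformer", reported lower perplexity for SwiGLU than for ReLU and GELU FFNs with parameter count and computation held roughly equal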
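The SwiGLU FFN has three weight matrices instead of two ($W_1$ and $W_{gate}$ of size $d_{model} \times d_{ff}$, plus the output projection $W_2$ of size $d_{ff} \times d_{model}$). Ignoring biases, matching the parameter count of a standard two-matrix FFN therefore means shrinking its hidden size:

$$3\,d_{model}\,d'_{ff} = 2\,d_{model}\,d_{ff} \quad\Longrightarrow\quad d'_{ff} = \tfrac{2}{3}\,d_{ff}$$

For example, $d_{ff} = 2048$ gives $d'_{ff} \approx 1365$; in practice the value is usually rounded to a hardware-friendly multiple.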

For now, we'll stick with the **original ReLU/GELU FFN**: it's simpler and works great for learning!

In [None]:
# Bonus: Let's implement a simple SwiGLU for the curious!

class SwiGLUFeedForward(nn.Module):
    """
    SwiGLU Feed-Forward Network (used in LLaMA, PaLM)
    
    SwiGLU-FFN(x) = ((x @ W1) * Swish(x @ W_gate)) @ W2
    """
    
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        
        # SwiGLU uses 3 linear layers instead of 2
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.w_gate = nn.Linear(d_model, d_ff, bias=False)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # Swish with beta = 1 is the same function as SiLU: x * sigmoid(x),
        # which PyTorch provides as F.silu
        gate = F.silu(self.w_gate(x))
        
        # SwiGLU: (x @ W1) * Swish(x @ W_gate), then project back with W2
        x = self.w1(x) * gate
        x = self.dropout(x)
        x = self.w2(x)
        
        return x

# Quick test
swiglu = SwiGLUFeedForward(512, 2048)
x_test = torch.randn(2, 10, 512)
out_swiglu = swiglu(x_test)
print(f"SwiGLU input shape:  {x_test.shape}")
print(f"SwiGLU output shape: {out_swiglu.shape}")
print("\n💡 SwiGLU is what LLaMA and PaLM use instead of ReLU/GELU FFN!")

### **12. Summary: What We Built**

Excellent work! You've built the Position-wise Feed-Forward Network! 🎉

**Key Takeaways:**

✅ **Purpose**: Add non-linear transformation power to the Transformer

✅ **Architecture**: Two-layer MLP with expand-contract pattern
```
d_model (512) → d_ff (2048) → d_model (512)
```

✅ **Position-wise**: Applied independently to each position (no cross-position interaction)

✅ **Activation**: ReLU (original) or GELU (modern transformers)

✅ **Parameters**: Contains about 2/3 of the Encoder layer's parameters!

✅ **Complement to Attention**: 
- Attention = Context aggregation (gather information)
- FFN = Feature transformation (process information)

---

**Our Progress:**

| Component | Status |
|-----------|--------|
| Multi-Head Attention | ✅ Day 7 |
| Positional Encoding | ✅ Step 1 |
| **Feed-Forward Network** | ✅ **Step 2 (Today!)** |
| Layer Normalization | ⏳ Step 3 |
| Residual Connections | ⏳ Step 3 |
| Complete Encoder Layer | ⏳ Step 3 |

**Next Up:** In Step 3, we'll combine everything to build the complete Encoder Layer!

### **13. Exercises**

Try these exercises to deepen your understanding!

**Exercise 1:** What happens if we set `d_ff = d_model` (no expansion)? Try it and compare outputs.

**Exercise 2:** Add a third linear layer to the FFN. Does this improve expressiveness?

**Exercise 3:** Using the `SwiGLUFeedForward` class from Section 11, compare its parameter count with the standard FFN at the same `d_ff`. What `d_ff` would make the two counts roughly equal?

**Exercise 4:** Visualize what happens with GELU vs ReLU: how do the intermediate activations differ?

In [None]:
# Space for your exercises!

# Exercise 1: FFN with no expansion
# ffn_no_expand = PositionWiseFeedForward(d_model=512, d_ff=512)  # d_ff = d_model
# Compare with standard FFN...

# Your code here!