# Module 2 - Exercise 3: Gradient Flow

## Learning Objectives
- Understand and observe vanishing gradient problems with sigmoid and tanh activations
- Experience exploding gradient problems with poor weight initialization
- Implement gradient clipping to handle exploding gradients
- Learn how ReLU activations help maintain gradient flow
- Understand how batch normalization stabilizes training
- Analyze gradient statistics to diagnose training issues

## Setup and Imports

In [None]:
# Clone the test repository
# (IPython shell escape: stderr is discarded and `|| true` keeps the cell from failing
#  if the clone errors, e.g. presumably when /tmp/tests already exists on a re-run)
!git clone https://github.com/racousin/data_science_practice.git /tmp/tests 2>/dev/null || true

# Import required modules
import sys
# Make the cloned test package importable; this must run before the imports below
sys.path.append('/tmp/tests/tests/python_deep_learning')

# Import the improved test utilities
from test_utils import NotebookTestRunner, create_inline_test
from module2.test_exercise3 import Exercise3Validator, EXERCISE3_SECTIONS

# Create test runner and validator
# NOTE(review): the ("module2", 3) arguments presumably identify module 2 / exercise 3
# for reporting — confirm against test_utils.NotebookTestRunner
test_runner = NotebookTestRunner("module2", 3)
validator = Exercise3Validator()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict

# Reproducibility: seed NumPy's and PyTorch's global random number generators
np.random.seed(42)
torch.manual_seed(42)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## Section 1: Vanishing Gradients with Sigmoid

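Let's first observe the vanishing gradient problem using sigmoid activations. The sigmoid function squashes inputs to (0, 1) and its derivative has a maximum value of 0.25. During backpropagation, each sigmoid layer multiplies the gradient by this derivative, so after $n$ layers the activation derivatives alone contribute a factor of at most $0.25^n$. Gradients therefore shrink as they propagate backward through deep networks, and the layers closest to the input receive the smallest updates.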

In [None]:
# TODO: Create a deep network with sigmoid activations
# The network should have at least 10 layers (5 Linear + 5 Sigmoid)
# Input: 10 features, Hidden layers: 20 units each, Output: 1
deep_sigmoid_network = None

# Once the network has been built, print its layer-by-layer architecture
if deep_sigmoid_network:
    print("Deep Sigmoid Network:", deep_sigmoid_network, sep="\n")

In [None]:
# TODO: Compute gradients through the sigmoid network
# 1. Create a random input tensor (batch_size=32, features=10)
# 2. Forward pass through the network
# 3. Compute loss (use mean of output)
# 4. Backward pass
# 5. Collect gradients from each Linear layer's weight.grad
sigmoid_gradients = None

# Bar chart of per-layer gradient norms on a log axis, so that norms differing
# by several orders of magnitude remain visible side by side
if sigmoid_gradients:
    grad_norms = [grad.norm().item() for grad in sigmoid_gradients]
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.bar(range(len(grad_norms)), grad_norms)
    ax.set_xlabel('Layer Index')
    ax.set_ylabel('Gradient Norm')
    ax.set_title('Gradient Norms in Deep Sigmoid Network')
    ax.set_yscale('log')
    plt.show()
    print(f"First layer gradient norm: {grad_norms[0]:.6f}")
    print(f"Last layer gradient norm: {grad_norms[-1]:.6f}")

In [None]:
# TODO: Calculate the vanishing gradient ratio
# Ratio = last_layer_gradient_norm / first_layer_gradient_norm
vanishing_ratio = None

if vanishing_ratio is not None:
    print(f"Vanishing gradient ratio: {vanishing_ratio:.8f}")
    # Whether the ratio lands above or below 1 depends on the order in which the
    # gradients were collected (input->output gives last/first > 1 for a vanishing
    # network, output->input gives < 1). The shrink factor between the two ends is
    # whichever of ratio and 1/ratio is >= 1. A zero ratio is skipped to avoid a
    # ZeroDivisionError.
    if vanishing_ratio > 0:
        shrink_factor = max(vanishing_ratio, 1 / vanishing_ratio)
        print(f"This means the gradient shrinks by a factor of {shrink_factor:.2f}!")

In [None]:
# Test Section 1: Vanishing Gradients with Sigmoid
# Turn each (validator method name, description) entry for this section into a
# (bound check, description) pair, then run the checks against this notebook's variables.
section_tests = [
    (getattr(validator, check_name), check_desc)
    for check_name, check_desc in EXERCISE3_SECTIONS["Section 1: Vanishing Gradients with Sigmoid"]
]
test_runner.test_section(
    "Section 1: Vanishing Gradients with Sigmoid",
    validator,
    section_tests,
    locals(),
)

## Section 2: Vanishing Gradients with Tanh

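The hyperbolic tangent (tanh) function also suffers from vanishing gradients, though typically less severely than sigmoid. Its derivative peaks at 1 (versus 0.25 for sigmoid) and its outputs are zero-centered, but it still saturates for large |x|. Let's compare.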

In [None]:
# TODO: Create a deep network with tanh activations
# Similar structure to sigmoid network but with Tanh activations
deep_tanh_network = None

# Once the network has been built, print its layer-by-layer architecture
if deep_tanh_network:
    print("Deep Tanh Network:", deep_tanh_network, sep="\n")

In [None]:
# TODO: Compute gradients through the tanh network
# Follow the same process as with sigmoid network
tanh_gradients = None

# Compare sigmoid vs tanh gradient flow
if sigmoid_gradients and tanh_gradients:
    sigmoid_norms = [torch.norm(g).item() for g in sigmoid_gradients]
    tanh_norms = [torch.norm(g).item() for g in tanh_gradients]

    plt.figure(figsize=(10, 4))
    # Each curve gets its own x-range: with a shared range, plt.plot raises a
    # length-mismatch error if the two networks have different depths.
    plt.plot(range(len(sigmoid_norms)), sigmoid_norms, 'r-', label='Sigmoid', marker='o')
    plt.plot(range(len(tanh_norms)), tanh_norms, 'b-', label='Tanh', marker='s')
    plt.xlabel('Layer Index')
    plt.ylabel('Gradient Norm (log scale)')
    plt.title('Gradient Flow: Sigmoid vs Tanh')
    plt.yscale('log')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

In [None]:
# Test Section 2: Vanishing Gradients with Tanh
# Turn each (validator method name, description) entry for this section into a
# (bound check, description) pair, then run the checks against this notebook's variables.
section_tests = [
    (getattr(validator, check_name), check_desc)
    for check_name, check_desc in EXERCISE3_SECTIONS["Section 2: Vanishing Gradients with Tanh"]
]
test_runner.test_section(
    "Section 2: Vanishing Gradients with Tanh",
    validator,
    section_tests,
    locals(),
)

## Section 3: Exploding Gradients

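Now let's create the opposite problem: exploding gradients. This typically happens with poor weight initialization or unstable network architectures. If each layer scales the backpropagated gradient by a factor greater than 1 (for example because its weights are too large), the gradient grows exponentially with depth and can overflow to `inf` or `NaN`.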

In [None]:
# TODO: Create a network prone to exploding gradients
# Use large weight initialization (std=5.0 or multiply weights by 5)
# Use Linear layers without activation functions between them
unstable_network = None

if unstable_network:
    print("Unstable Network (prone to exploding gradients):")
    # modules() also yields the container itself (and any non-Linear submodules),
    # so enumerating it directly mislabels the layers (the first Linear would be
    # "Layer 1"). Number only the Linear layers, starting at 0, so the indices line
    # up with per-layer gradient lists collected in the same order.
    linear_layers = [m for m in unstable_network.modules() if isinstance(m, nn.Linear)]
    for i, module in enumerate(linear_layers):
        print(f"Layer {i}: Weight std = {module.weight.data.std().item():.3f}")

In [None]:
# TODO: Observe exploding gradients
# Compute gradients and watch them explode!
# You might see NaN or very large values
exploding_gradients = None

if exploding_gradients:
    # Per-layer gradient norm; a layer whose gradient contains NaN is reported as NaN
    grad_norms = [torch.norm(g).item() if not torch.isnan(g).any() else float('nan')
                  for g in exploding_gradients]
    # A log axis cannot display NaN, inf or non-positive values, so blank them out for
    # the plot only. They are still listed in the printout below.
    plot_norms = [n if np.isfinite(n) and n > 0 else float('nan') for n in grad_norms]

    plt.figure(figsize=(10, 4))
    plt.bar(range(len(plot_norms)), plot_norms)
    plt.xlabel('Layer Index')
    plt.ylabel('Gradient Norm (log scale)')
    plt.title('Exploding Gradients!')
    # Exploding norms span many orders of magnitude; on a linear axis every bar
    # except the largest would collapse to (visually) zero height.
    plt.yscale('log')
    plt.show()

    print("Gradient norms:")
    for i, norm in enumerate(grad_norms):
        print(f"Layer {i}: {norm:.2e}")

In [None]:
# TODO: Implement gradient clipping
# Clip gradients to a maximum norm of 1.0
# Use torch.nn.utils.clip_grad_norm_
clipped_gradients = None

# Report the norm of each clipped per-layer gradient and the largest one
if clipped_gradients:
    clipped_norms = [grad.norm().item() for grad in clipped_gradients]
    print("Clipped gradient norms:")
    for i, norm in enumerate(clipped_norms):
        print(f"Layer {i}: {norm:.4f}")
    print(f"\nMax gradient norm after clipping: {max(clipped_norms):.4f}")

In [None]:
# Test Section 3: Exploding Gradients
# Turn each (validator method name, description) entry for this section into a
# (bound check, description) pair, then run the checks against this notebook's variables.
section_tests = [
    (getattr(validator, check_name), check_desc)
    for check_name, check_desc in EXERCISE3_SECTIONS["Section 3: Exploding Gradients"]
]
test_runner.test_section(
    "Section 3: Exploding Gradients",
    validator,
    section_tests,
    locals(),
)

## Section 4: Solutions - ReLU and Batch Normalization

Let's explore two popular solutions to gradient flow problems:
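1. **ReLU activations**: Do not saturate for positive inputs; the derivative is exactly 1 there, so gradients pass through active units unchanged
2. **Batch Normalization**: Stabilizes gradients by normalizing intermediate activations (per feature, over the mini-batch)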

In [None]:
# TODO: Create a deep network with ReLU activations
# Similar structure but use ReLU instead of sigmoid/tanh
relu_network = None

# Once the network has been built, print its layer-by-layer architecture
if relu_network:
    print("ReLU Network:", relu_network, sep="\n")

In [None]:
# TODO: Compute gradients through ReLU network
relu_gradients = None

# Compare all activation functions
if sigmoid_gradients and tanh_gradients and relu_gradients:
    sigmoid_norms = [torch.norm(g).item() for g in sigmoid_gradients]
    tanh_norms = [torch.norm(g).item() for g in tanh_gradients]
    relu_norms = [torch.norm(g).item() for g in relu_gradients]

    plt.figure(figsize=(12, 5))
    # Each curve gets its own x-range: with a shared range, plt.plot raises a
    # length-mismatch error if the networks were built with different depths.
    plt.plot(range(len(sigmoid_norms)), sigmoid_norms, 'r-', label='Sigmoid', marker='o')
    plt.plot(range(len(tanh_norms)), tanh_norms, 'b-', label='Tanh', marker='s')
    plt.plot(range(len(relu_norms)), relu_norms, 'g-', label='ReLU', marker='^')
    plt.xlabel('Layer Index (deeper →)')
    plt.ylabel('Gradient Norm (log scale)')
    plt.title('Gradient Flow Comparison: Different Activation Functions')
    plt.yscale('log')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

    # Calculate gradient preservation ratios. This uses the same convention as
    # vanishing_ratio in Section 1: last-layer norm / first-layer norm (0 if the
    # first norm is 0). A value close to 1 means the gradient magnitude is similar
    # at both ends.
    for name, norms in [("Sigmoid", sigmoid_norms), ("Tanh", tanh_norms), ("ReLU", relu_norms)]:
        ratio = norms[-1] / norms[0] if norms[0] > 0 else 0
        print(f"{name} gradient preservation ratio: {ratio:.6f}")

In [None]:
# TODO: Create a network with Batch Normalization
# Add BatchNorm1d layers after Linear layers (before activation)
batchnorm_network = None

# Once the network has been built, print its layer-by-layer architecture
if batchnorm_network:
    print("BatchNorm Network:", batchnorm_network, sep="\n")

In [None]:
# TODO: Compute gradients through batch normalized network
# Note: BatchNorm requires batch_size > 1 and training mode
batchnorm_gradients = None

if batchnorm_gradients:
    bn_norms = [torch.norm(g).item() for g in batchnorm_gradients]

    plt.figure(figsize=(10, 4))
    plt.bar(range(len(bn_norms)), bn_norms)
    plt.xlabel('Layer Index')
    plt.ylabel('Gradient Norm')
    plt.title('Gradient Flow with Batch Normalization')
    plt.show()

    # The baseline comes from Section 1. If that section was not completed,
    # sigmoid_gradients is still None and iterating it would raise a TypeError,
    # so only the BN statistic is reported in that case.
    if sigmoid_gradients:
        baseline_norms = [torch.norm(g).item() for g in sigmoid_gradients]
        print(f"Gradient norm std without BN: {np.std(baseline_norms):.6f}")
    print(f"Gradient norm std with BN: {np.std(bn_norms):.6f}")

In [None]:
# Test Section 4: Solutions - ReLU and Batch Normalization
# Turn each (validator method name, description) entry for this section into a
# (bound check, description) pair, then run the checks against this notebook's variables.
section_tests = [
    (getattr(validator, check_name), check_desc)
    for check_name, check_desc in EXERCISE3_SECTIONS["Section 4: Solutions - ReLU and Batch Normalization"]
]
test_runner.test_section(
    "Section 4: Solutions - ReLU and Batch Normalization",
    validator,
    section_tests,
    locals(),
)

## Section 5: Gradient Analysis

Let's implement tools to analyze gradient statistics. Monitoring these statistics is crucial for diagnosing training issues.

In [None]:
# TODO: Compute gradient statistics for any network
# Calculate mean, std, min, max of all gradients
def compute_gradient_stats(gradients_list):
    """Compute statistics across all gradients.

    TODO (exercise): flatten every gradient in ``gradients_list`` into a single
    tensor and return a dict with the keys 'mean', 'std', 'min' and 'max'.
    Until that is implemented, this stub returns None.
    """
    return None

# Apply to ReLU network gradients
gradient_stats = None

# Print every statistic with 6 decimals
if gradient_stats:
    print("Gradient Statistics:")
    for key in gradient_stats:
        value = gradient_stats[key]
        # Tensor-valued statistics are unwrapped to plain Python numbers first
        value = value.item() if isinstance(value, torch.Tensor) else value
        print(f"{key}: {value:.6f}")

In [None]:
# TODO: Create gradient histogram data
# Flatten all gradients and prepare for histogram plotting
gradient_histogram_data = None

if gradient_histogram_data is not None:
    plt.figure(figsize=(12, 4))

    # Convert to a flat NumPy array. detach() makes .numpy() safe even for a tensor
    # that still requires grad. ravel() guarantees 1-D data: otherwise plt.hist treats
    # the columns of a 2-D input as separate datasets, and len() counts rows rather
    # than gradient elements.
    if torch.is_tensor(gradient_histogram_data):
        hist_data = gradient_histogram_data.detach().cpu().numpy().ravel()
    else:
        hist_data = np.asarray(gradient_histogram_data).ravel()

    plt.subplot(1, 2, 1)
    plt.hist(hist_data, bins=50, alpha=0.7, color='blue', edgecolor='black')
    plt.xlabel('Gradient Value')
    plt.ylabel('Frequency')
    plt.title('Gradient Distribution')

    plt.subplot(1, 2, 2)
    # The 1e-10 offset keeps log10 finite for exact-zero gradients
    plt.hist(np.log10(np.abs(hist_data) + 1e-10), bins=50, alpha=0.7, color='green', edgecolor='black')
    plt.xlabel('Log10(|Gradient|)')
    plt.ylabel('Frequency')
    plt.title('Log-Scale Gradient Distribution')

    plt.tight_layout()
    plt.show()

    print(f"Total gradient elements: {hist_data.size}")
    print(f"Zero gradients: {np.sum(np.abs(hist_data) < 1e-10)}")
    print(f"Large gradients (>1.0): {np.sum(np.abs(hist_data) > 1.0)}")

In [None]:
# Test Section 5: Gradient Analysis
# Turn each (validator method name, description) entry for this section into a
# (bound check, description) pair, then run the checks against this notebook's variables.
section_tests = [
    (getattr(validator, check_name), check_desc)
    for check_name, check_desc in EXERCISE3_SECTIONS["Section 5: Gradient Analysis"]
]
test_runner.test_section(
    "Section 5: Gradient Analysis",
    validator,
    section_tests,
    locals(),
)

In [None]:
# Display final summary
# NOTE(review): presumably summarizes the results recorded by the
# test_runner.test_section(...) calls above — confirm against test_utils
test_runner.final_summary()