# RNN and LSTM Comprehensive Guide: Sequential Learning Mastery

This notebook provides a complete understanding of Recurrent Neural Networks (RNNs) and Long Short-Term Memory (LSTM) networks.

## Complete Learning Objectives:
1. **Sequential Data**: Understanding time series and sequence problems
2. **Vanilla RNN**: Architecture, forward pass, backpropagation through time
3. **Vanishing Gradients**: Why vanilla RNNs fail on long sequences
4. **LSTM Architecture**: Gates, cell state, hidden state mechanics
5. **GRU**: Simplified alternative to LSTM
6. **Applications**: Text, time series, sequence-to-sequence tasks
7. **Implementation**: From scratch understanding with real projects

**Prerequisites**: Complete foundational notebooks (01, 02, 03)

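**Why RNNs Matter for GNNs**: Graph Neural Networks build on related ideas. Message passing applies the same shared weights repeatedly to update each node's state from its neighbors, much as an RNN updates its hidden state from the previous time step.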

In [None]:
# Cell 1: Comprehensive RNN/LSTM Environment Setup
"""
SEQUENTIAL LEARNING LIBRARY ECOSYSTEM:

Core Deep Learning:
- tensorflow: Excellent RNN/LSTM support with tf.keras.layers.LSTM, GRU
- tensorflow.keras.preprocessing: Text and sequence preprocessing
- tensorflow.keras.utils: Sequence utilities and data generators

Text Processing:
- nltk: Natural Language Toolkit for text preprocessing
- re: Regular expressions for text cleaning
- string: String manipulation utilities

Time Series:
- pandas: Time series data manipulation and analysis
- datetime: Date and time handling
- numpy: Numerical operations for sequence data

Visualization:
- matplotlib: Time series plots, sequence visualizations
- seaborn: Statistical plots for sequence analysis
- plotly: Interactive time series plots
"""

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, optimizers, callbacks
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, classification_report
from sklearn.model_selection import train_test_split

import re
import string
import datetime
import warnings
warnings.filterwarnings('ignore')

# For text data
try:
    import nltk
    # Download required NLTK data
    nltk.download('punkt', quiet=True)
    nltk.download('stopwords', quiet=True)
    from nltk.corpus import stopwords
    from nltk.tokenize import word_tokenize
    NLTK_AVAILABLE = True
except ImportError:
    NLTK_AVAILABLE = False
    print("NLTK not available - will use basic text processing")

# Set random seeds for reproducibility
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)

print(f"🔧 RNN/LSTM ENVIRONMENT SETUP")
print(f"TensorFlow version: {tf.__version__}")
print(f"Keras version: {keras.__version__}")
print(f"NumPy version: {np.__version__}")
print(f"Pandas version: {pd.__version__}")
print(f"GPU available: {len(tf.config.list_physical_devices('GPU')) > 0}")
print(f"NLTK available: {NLTK_AVAILABLE}")
print(f"Random seed: {RANDOM_SEED}")

# Configure plotting for sequence visualization
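# 'seaborn-v0_8' exists in Matplotlib >= 3.6; older releases name the style 'seaborn'
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')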
plt.rcParams['figure.figsize'] = (15, 10)
plt.rcParams['font.size'] = 11
plt.rcParams['axes.grid'] = True
plt.rcParams['grid.alpha'] = 0.3

print("\n✅ All RNN/LSTM libraries imported and configured successfully!")
print("🎯 Ready for sequential learning exploration!")

## 1. Understanding Sequential Data: The Foundation

**What Makes Data Sequential?**
- **Order matters**: Past influences present and future
- **Temporal dependencies**: Information flows through time
- **Variable length**: Sequences can have different lengths
- **Context dependency**: Meaning depends on surrounding elements
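
To make "order matters" concrete, here is a minimal sketch comparing an order-insensitive summary (a bag of words) with an order-sensitive running state. The toy vocabulary `vocab` and the helper `running_state` are hypothetical names chosen for illustration; the recurrence is deliberately simplistic and is not an RNN.

```python
from collections import Counter

vocab = {"dog": 1, "bites": 2, "man": 3}  # hypothetical toy vocabulary
a = "dog bites man".split()
b = "man bites dog".split()

# A bag of words ignores order, so the two sentences look identical.
same_bag = Counter(a) == Counter(b)

def running_state(tokens, base=31):
    """Toy recurrence: state_t = base * state_{t-1} + id(token_t)."""
    state = 0
    for tok in tokens:
        state = base * state + vocab[tok]
    return state

print("same bag of words:", same_bag)
print("running states differ:", running_state(a) != running_state(b))
```

Because each step folds the previous state into the new one, the position of a token changes the final state. That is the property a recurrent hidden state provides.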

**Types of Sequential Data:**

1. **Time Series**: Stock prices, weather, sensor readings
   - Usually sampled at regular time intervals
   - Continuous values
   - Trend and seasonality patterns

2. **Natural Language**: Text, speech, conversations
   - Discrete tokens (words, characters)
   - Grammar and syntax dependencies
   - Long-range dependencies

3. **Biological Sequences**: DNA, proteins, gene expressions
   - Discrete alphabet (e.g., A, C, G, T for DNA; 20 amino acids for proteins)
   - Functional dependencies
   - Pattern recognition
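
For time series, a common way to turn one long series into training examples is a sliding window: each window of past values becomes an input and the next value becomes the target. The sketch below assumes next-step prediction on a 1-D series; `make_windows` and `window_size` are illustrative names, not part of any library.

```python
import numpy as np

def make_windows(series, window_size):
    """Split a 1-D series into (inputs, targets) pairs for next-step prediction."""
    series = np.asarray(series, dtype=np.float32)
    n_samples = len(series) - window_size
    if n_samples <= 0:
        raise ValueError("series must be longer than window_size")
    # Row i holds series[i : i + window_size]; its target is the value right after it.
    X = np.stack([series[i:i + window_size] for i in range(n_samples)])
    y = series[window_size:]
    # Recurrent layers usually expect (samples, time_steps, features).
    return X[..., np.newaxis], y

series = np.arange(10)
X, y = make_windows(series, window_size=3)
print(X.shape, y.shape)   # (7, 3, 1) (7,)
print(X[0, :, 0], y[0])   # [0. 1. 2.] 3.0
```

The window size is a modeling choice: it caps how far back the model can see within a single example.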

**Key Challenges:**
- **Variable lengths**: How to batch different sequence lengths?
- **Long dependencies**: Information from far past affecting present
- **Vanishing gradients**: Difficulty learning long-term patterns
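
Padding solves the batching problem but introduces fake time steps, so it is usually paired with a mask that marks which positions are real. The following NumPy sketch shows the idea under simple assumptions (1-D numeric sequences, zero padding on the right); `pad_with_mask` is a hypothetical helper. In Keras, `pad_sequences` together with a `Masking` layer or `Embedding(mask_zero=True)` plays a similar role.

```python
import numpy as np

def pad_with_mask(sequences, pad_value=0.0):
    """Right-pad sequences to a common length and return a boolean validity mask."""
    max_len = max(len(s) for s in sequences)
    padded = np.full((len(sequences), max_len), pad_value, dtype=np.float32)
    mask = np.zeros((len(sequences), max_len), dtype=bool)
    for i, seq in enumerate(sequences):
        padded[i, :len(seq)] = seq
        mask[i, :len(seq)] = True
    return padded, mask

sequences = [[1, 2, 3], [4, 5, 6, 7, 8], [9, 10]]
padded, mask = pad_with_mask(sequences)

# Averaging over padded positions biases the result toward the pad value...
naive_mean = padded.mean(axis=1)
# ...while the mask restricts the average to real time steps.
masked_mean = (padded * mask).sum(axis=1) / mask.sum(axis=1)

print("naive :", naive_mean)
print("masked:", masked_mean)
```

The same principle applies to losses and recurrent updates: padded steps should not contribute.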

In [None]:
# Cell 2: Sequential Data Analysis and Preparation

print("=== SEQUENTIAL DATA UNDERSTANDING ===")

# 1. TIME SERIES DATA EXAMPLE
print("\n📈 TIME SERIES DATA EXAMPLE:")

def create_synthetic_time_series(n_points=1000, trend=0.001, seasonality=True, noise_level=0.1):
    """
    Create synthetic time series with trend, seasonality, and noise
    """
    t = np.arange(n_points)
    
    # Trend component
    trend_component = trend * t
    
    # Seasonal component
    if seasonality:
        seasonal_component = (
            0.5 * np.sin(2 * np.pi * t / 50) +  # Short-term cycle
            0.3 * np.sin(2 * np.pi * t / 200)   # Long-term cycle
        )
    else:
        seasonal_component = 0
    
    # Noise component
    noise_component = np.random.normal(0, noise_level, n_points)
    
    # Combine components
    time_series = trend_component + seasonal_component + noise_component
    
    return time_series, {'trend': trend_component, 'seasonal': seasonal_component, 'noise': noise_component}

# Generate example time series
ts_data, ts_components = create_synthetic_time_series(n_points=500)

print(f"Time series characteristics:")
print(f"  Length: {len(ts_data)} time points")
print(f"  Value range: [{ts_data.min():.3f}, {ts_data.max():.3f}]")
print(f"  Mean: {ts_data.mean():.3f}, Std: {ts_data.std():.3f}")
print(f"  Components: Trend + Seasonality + Noise")

# 2. TEXT SEQUENCE DATA EXAMPLE
print("\n📚 TEXT SEQUENCE DATA EXAMPLE:")

# Sample text data for sequence learning
sample_texts = [
    "The quick brown fox jumps over the lazy dog.",
    "Machine learning is a subset of artificial intelligence.",
    "Neural networks learn patterns from data through training.",
    "Deep learning uses multiple layers to extract features.",
    "Recurrent networks process sequential information effectively.",
    "Long short-term memory networks solve vanishing gradients.",
    "Natural language processing enables computers to understand text.",
    "Time series forecasting predicts future values from past data."
]

def analyze_text_sequences(texts):
    """
    Analyze text sequence characteristics
    """
    # Basic statistics
    word_counts = [len(text.split()) for text in texts]
    char_counts = [len(text) for text in texts]
    
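    # Vocabulary analysis (strip surrounding punctuation so "data." and "data" count as one word)
    raw_tokens = ' '.join(texts).lower().split()
    all_words = [tok.strip(string.punctuation) for tok in raw_tokens if tok.strip(string.punctuation)]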
    unique_words = set(all_words)
    word_freq = {}
    for word in all_words:
        word_freq[word] = word_freq.get(word, 0) + 1
    
    return {
        'num_sequences': len(texts),
        'avg_words': np.mean(word_counts),
        'avg_chars': np.mean(char_counts),
        'vocabulary_size': len(unique_words),
        'total_words': len(all_words),
        'most_common': sorted(word_freq.items(), key=lambda x: x[1], reverse=True)[:5]
    }

text_stats = analyze_text_sequences(sample_texts)

print(f"Text sequence characteristics:")
print(f"  Number of sequences: {text_stats['num_sequences']}")
print(f"  Average words per sequence: {text_stats['avg_words']:.1f}")
print(f"  Average characters per sequence: {text_stats['avg_chars']:.1f}")
print(f"  Vocabulary size: {text_stats['vocabulary_size']}")
print(f"  Total words: {text_stats['total_words']}")
print(f"  Most common words: {text_stats['most_common']}")

# 3. SEQUENCE PREPROCESSING CHALLENGES
print("\n🔧 SEQUENCE PREPROCESSING CHALLENGES:")

def demonstrate_sequence_challenges():
    """
    Demonstrate common sequence preprocessing challenges
    """
    
    # Variable length sequences
    sequences = [
        [1, 2, 3],
        [4, 5, 6, 7, 8],
        [9, 10],
        [11, 12, 13, 14, 15, 16, 17]
    ]
    
    lengths = [len(seq) for seq in sequences]
    
    print(f"Challenge 1: Variable Length Sequences")
    print(f"  Sequence lengths: {lengths}")
    print(f"  Min length: {min(lengths)}, Max length: {max(lengths)}")
    print(f"  Problem: Can't batch variable lengths directly")
    
    # Padding solution
    max_len = max(lengths)
    padded_sequences = []
    for seq in sequences:
        padded = seq + [0] * (max_len - len(seq))  # Pad with zeros
        padded_sequences.append(padded)
    
    print(f"  Solution: Padding to max length ({max_len})")
    print(f"  Padded sequences: {padded_sequences}")
    
    # Long sequences memory challenge
    long_sequence = list(range(10000))
    memory_estimate = len(long_sequence) * 4 / 1024  # Float32 in KB
    
    print(f"\nChallenge 2: Long Sequences Memory")
    print(f"  Sequence length: {len(long_sequence):,}")
    print(f"  Memory per sequence: {memory_estimate:.1f} KB")
    print(f"  Batch of 32: {memory_estimate * 32:.1f} KB")
    print(f"  Problem: Memory grows linearly with sequence length, and BPTT must also store activations for every step")
    
    # Truncation solution
    max_sequence_length = 100
    truncated = long_sequence[:max_sequence_length]
    print(f"  Solution: Truncate to max length ({max_sequence_length})")
    print(f"  Information loss: {len(long_sequence) - len(truncated):,} elements")

demonstrate_sequence_challenges()

# 4. VISUALIZE SEQUENTIAL DATA PATTERNS
print("\n📊 SEQUENTIAL DATA VISUALIZATION:")

fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Plot 1: Time series components
t = np.arange(len(ts_data))
axes[0, 0].plot(t, ts_data, 'b-', linewidth=1, label='Combined Signal')
axes[0, 0].plot(t, ts_components['trend'], 'r--', linewidth=2, label='Trend')
axes[0, 0].set_title('Time Series: Combined Signal and Trend')
axes[0, 0].set_xlabel('Time')
axes[0, 0].set_ylabel('Value')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Plot 2: Seasonal component
axes[0, 1].plot(t, ts_components['seasonal'], 'g-', linewidth=2, label='Seasonal')
axes[0, 1].set_title('Time Series: Seasonal Component')
axes[0, 1].set_xlabel('Time')
axes[0, 1].set_ylabel('Value')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Plot 3: Noise component
axes[0, 2].plot(t, ts_components['noise'], 'orange', linewidth=1, alpha=0.7, label='Noise')
axes[0, 2].set_title('Time Series: Noise Component')
axes[0, 2].set_xlabel('Time')
axes[0, 2].set_ylabel('Value')
axes[0, 2].legend()
axes[0, 2].grid(True, alpha=0.3)

# Plot 4: Sequence length distribution
text_lengths = [len(text.split()) for text in sample_texts]
axes[1, 0].hist(text_lengths, bins=5, color='skyblue', alpha=0.7, edgecolor='black')
axes[1, 0].set_title('Text Sequence Length Distribution')
axes[1, 0].set_xlabel('Number of Words')
axes[1, 0].set_ylabel('Frequency')
axes[1, 0].grid(True, alpha=0.3)

# Plot 5: Word frequency distribution
word_freqs = list(text_stats['most_common'])
words, freqs = zip(*word_freqs)
axes[1, 1].bar(words, freqs, color='lightgreen', alpha=0.7)
axes[1, 1].set_title('Most Common Words')
axes[1, 1].set_xlabel('Words')
axes[1, 1].set_ylabel('Frequency')
axes[1, 1].tick_params(axis='x', rotation=45)

# Plot 6: Sequence preprocessing visualization
original_lengths = [3, 5, 2, 7]
padded_lengths = [7] * 4  # All padded to max length

x = np.arange(len(original_lengths))
width = 0.35

axes[1, 2].bar(x - width/2, original_lengths, width, label='Original', color='lightcoral', alpha=0.7)
axes[1, 2].bar(x + width/2, padded_lengths, width, label='Padded', color='lightblue', alpha=0.7)
axes[1, 2].set_title('Sequence Padding Effect')
axes[1, 2].set_xlabel('Sequence Index')
axes[1, 2].set_ylabel('Length')
axes[1, 2].set_xticks(x)
axes[1, 2].legend()

plt.tight_layout()
plt.show()

print(f"\n💡 SEQUENTIAL DATA KEY INSIGHTS:")
print(f"\n1. TEMPORAL DEPENDENCIES:")
print(f"   • Past information influences future predictions")
print(f"   • Order matters: [A,B,C] ≠ [C,B,A]")
print(f"   • Context provides meaning: 'bank' in finance vs river context")

print(f"\n2. PREPROCESSING CHALLENGES:")
print(f"   • Variable lengths → Padding or truncation needed")
print(f"   • Long sequences → Memory and computational constraints")
print(f"   • Vocabulary size → Embedding dimensionality trade-offs")

print(f"\n3. MODELING IMPLICATIONS:")
print(f"   • Need memory to store past information")
print(f"   • Gradients must flow through time")
print(f"   • Different sequence lengths need special handling")

print(f"\n4. APPLICATIONS:")
print(f"   • Time series: Stock prediction, weather forecasting")
print(f"   • NLP: Language modeling, machine translation")
print(f"   • Speech: Recognition, synthesis")
print(f"   • Biology: Gene sequence analysis, protein folding")

## 2. Vanilla RNN: The Foundation of Sequential Learning

**RNN Core Concept:**
- **Hidden state**: Memory that carries information through time
- **Recurrent connection**: Hidden state fed back to next time step
- **Parameter sharing**: Same weights used at each time step

**RNN Mathematics:**
```
h_t = tanh(W_hh * h_{t-1} + W_xh * x_t + b_h)
y_t = W_hy * h_t + b_y
```

**Where:**
- `h_t`: Hidden state at time t
- `x_t`: Input at time t
- `y_t`: Output at time t
- `W_hh`: Hidden-to-hidden weights
- `W_xh`: Input-to-hidden weights
- `W_hy`: Hidden-to-output weights

**Key Properties:**
1. **Memory**: Hidden state acts as memory
2. **Parameter sharing**: Efficiency across time steps
3. **Variable length**: Can handle sequences of any length
4. **Sequential processing**: Information flows step by step
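
Parameter sharing means the number of weights depends only on the layer sizes, never on the sequence length. The sketch below counts the parameters from the equations above and runs the same weights over a short and a long sequence. It uses the row-vector convention (`x_t @ W_xh`), and the helper names `rnn_param_count` and `run_rnn` are illustrative assumptions.

```python
import numpy as np

def rnn_param_count(input_size, hidden_size, output_size):
    """Parameters of a vanilla RNN with an output layer."""
    return (input_size * hidden_size      # W_xh
            + hidden_size * hidden_size   # W_hh
            + hidden_size                 # b_h
            + hidden_size * output_size   # W_hy
            + output_size)                # b_y

def run_rnn(X, W_xh, W_hh, b_h):
    """Return the final hidden state after processing every row of X."""
    h = np.zeros(W_hh.shape[0])
    for x_t in X:
        h = np.tanh(x_t @ W_xh + h @ W_hh + b_h)
    return h

rng = np.random.default_rng(0)
input_size, hidden_size, output_size = 3, 4, 2
W_xh = rng.normal(scale=0.5, size=(input_size, hidden_size))
W_hh = rng.normal(scale=0.5, size=(hidden_size, hidden_size))
b_h = np.zeros(hidden_size)

n_params = rnn_param_count(input_size, hidden_size, output_size)
h_short = run_rnn(rng.normal(size=(5, input_size)), W_xh, W_hh, b_h)
h_long = run_rnn(rng.normal(size=(50, input_size)), W_xh, W_hh, b_h)
print(n_params, h_short.shape, h_long.shape)  # 42 (4,) (4,)
```

Both sequences are handled by the same 42 parameters, and both end in a hidden state of the same size.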

In [None]:
# Cell 3: Vanilla RNN Implementation and Analysis

print("=== VANILLA RNN DEEP DIVE ===")

class SimpleRNN:
    """
    Educational implementation of a simple RNN cell
    This helps understand the mathematics behind RNNs
    """
    
    def __init__(self, input_size, hidden_size, output_size):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        
        # Initialize weights with He-style scaling, sqrt(2 / fan_in).
        # (Xavier/Glorot initialization uses sqrt(2 / (fan_in + fan_out)) instead.)
        self.W_xh = np.random.randn(input_size, hidden_size) * np.sqrt(2.0 / input_size)
        self.W_hh = np.random.randn(hidden_size, hidden_size) * np.sqrt(2.0 / hidden_size) 
        self.W_hy = np.random.randn(hidden_size, output_size) * np.sqrt(2.0 / hidden_size)
        
        # Initialize biases to zero
        self.b_h = np.zeros((hidden_size,))
        self.b_y = np.zeros((output_size,))
        
        # Store activations for analysis
        self.hidden_states = []
        self.outputs = []
    
    def forward_step(self, x_t, h_prev):
        """
        Single forward step of RNN
        
        Args:
            x_t: Input at time t, shape (input_size,)
            h_prev: Previous hidden state, shape (hidden_size,)
            
        Returns:
            h_t: New hidden state, shape (hidden_size,)
            y_t: Output at time t, shape (output_size,)
        """
        
        # Compute new hidden state
        # h_t = tanh(x_t @ W_xh + h_prev @ W_hh + b_h)
        linear_combination = (
            np.dot(x_t, self.W_xh) +     # Input contribution
            np.dot(h_prev, self.W_hh) +  # Previous hidden state contribution
            self.b_h                     # Bias
        )
        h_t = np.tanh(linear_combination)
        
        # Compute output
        # y_t = h_t @ W_hy + b_y
        y_t = np.dot(h_t, self.W_hy) + self.b_y
        
        return h_t, y_t
    
    def forward_sequence(self, X, initial_hidden=None):
        """
        Forward pass through entire sequence
        
        Args:
            X: Input sequence, shape (sequence_length, input_size)
            initial_hidden: Initial hidden state, shape (hidden_size,)
            
        Returns:
            hidden_states: List of hidden states
            outputs: List of outputs
        """
        
        sequence_length = X.shape[0]
        
        # Initialize hidden state if not provided
        if initial_hidden is None:
            h_t = np.zeros((self.hidden_size,))
        else:
            h_t = initial_hidden.copy()
        
        # Store results
        hidden_states = []
        outputs = []
        
        # Process each time step
        for t in range(sequence_length):
            x_t = X[t]
            h_t, y_t = self.forward_step(x_t, h_t)
            
            hidden_states.append(h_t.copy())
            outputs.append(y_t.copy())
        
        # Store for analysis
        self.hidden_states = hidden_states
        self.outputs = outputs
        
        return hidden_states, outputs
    
    def get_parameter_count(self):
        """Calculate total number of parameters"""
        params = (
            self.W_xh.size +
            self.W_hh.size +
            self.W_hy.size +
            self.b_h.size +
            self.b_y.size
        )
        return params

# Demonstrate RNN with simple sequence
print(f"\n🧠 SIMPLE RNN DEMONSTRATION:")

# Create a simple RNN
input_size = 3
hidden_size = 4
output_size = 2

rnn = SimpleRNN(input_size, hidden_size, output_size)

print(f"RNN Architecture:")
print(f"  Input size: {input_size}")
print(f"  Hidden size: {hidden_size}")
print(f"  Output size: {output_size}")
print(f"  Total parameters: {rnn.get_parameter_count()}")

# Parameter breakdown
print(f"\nParameter breakdown:")
print(f"  W_xh (input→hidden): {rnn.W_xh.shape} = {rnn.W_xh.size} params")
print(f"  W_hh (hidden→hidden): {rnn.W_hh.shape} = {rnn.W_hh.size} params")
print(f"  W_hy (hidden→output): {rnn.W_hy.shape} = {rnn.W_hy.size} params")
print(f"  b_h (hidden bias): {rnn.b_h.shape} = {rnn.b_h.size} params")
print(f"  b_y (output bias): {rnn.b_y.shape} = {rnn.b_y.size} params")

# Create example input sequence
sequence_length = 5
X_example = np.random.randn(sequence_length, input_size)

print(f"\nExample input sequence:")
print(f"  Shape: {X_example.shape} (time_steps, features)")
print(f"  Values:\n{X_example}")

# Forward pass through sequence
hidden_states, outputs = rnn.forward_sequence(X_example)

print(f"\nRNN Processing Results:")
for t in range(sequence_length):
    print(f"  Time step {t+1}:")
    print(f"    Input: {X_example[t]}")
    print(f"    Hidden state: {hidden_states[t]}")
    print(f"    Output: {outputs[t]}")
    print()

# Analyze hidden state evolution
print(f"\n📊 HIDDEN STATE ANALYSIS:")

hidden_states_array = np.array(hidden_states)  # Shape: (time_steps, hidden_size)
outputs_array = np.array(outputs)              # Shape: (time_steps, output_size)

print(f"Hidden state statistics:")
for t in range(sequence_length):
    h_t = hidden_states_array[t]
    print(f"  Step {t+1}: Mean={h_t.mean():.3f}, Std={h_t.std():.3f}, Range=[{h_t.min():.3f}, {h_t.max():.3f}]")

# Demonstrate RNN memory
print(f"\n🧠 RNN MEMORY DEMONSTRATION:")

def test_rnn_memory():
    """Test how RNN remembers information from early time steps"""
    
    # Create two similar sequences with different first elements
    base_sequence = np.ones((5, 2))  # All ones
    
    sequence_A = base_sequence.copy()
    sequence_A[0] = [5.0, 0.0]  # Strong signal at beginning
    
    sequence_B = base_sequence.copy() 
    sequence_B[0] = [-5.0, 0.0]  # Opposite signal at beginning
    
    # Create RNN for this test
    test_rnn = SimpleRNN(2, 8, 1)
    
    # Process both sequences
    _, outputs_A = test_rnn.forward_sequence(sequence_A)
    _, outputs_B = test_rnn.forward_sequence(sequence_B)
    
    print(f"Memory test results:")
    print(f"  Sequence A (starts with [5,0]):")
    for i, out in enumerate(outputs_A):
        print(f"    Step {i+1}: {out[0]:.3f}")
    
    print(f"  Sequence B (starts with [-5,0]):")
    for i, out in enumerate(outputs_B):
        print(f"    Step {i+1}: {out[0]:.3f}")
    
    # Check if difference persists
    final_diff = abs(outputs_A[-1][0] - outputs_B[-1][0])
    print(f"  Final output difference: {final_diff:.3f}")
    
    if final_diff > 0.1:
        print(f"  ✅ RNN remembers early information!")
    else:
        print(f"  ❌ RNN forgot early information")

test_rnn_memory()

# Visualize RNN behavior
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Plot 1: Input sequence
for i in range(input_size):
    axes[0, 0].plot(range(sequence_length), X_example[:, i], 
                   marker='o', label=f'Feature {i+1}', linewidth=2)
axes[0, 0].set_title('Input Sequence')
axes[0, 0].set_xlabel('Time Step')
axes[0, 0].set_ylabel('Input Value')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Plot 2: Hidden state evolution
for i in range(hidden_size):
    axes[0, 1].plot(range(sequence_length), hidden_states_array[:, i], 
                   marker='s', alpha=0.7, label=f'Hidden {i+1}')
axes[0, 1].set_title('Hidden State Evolution')
axes[0, 1].set_xlabel('Time Step')
axes[0, 1].set_ylabel('Hidden Value')
axes[0, 1].legend()  # hidden_size is small here, so label every unit
axes[0, 1].grid(True, alpha=0.3)

# Plot 3: Output sequence
for i in range(output_size):
    axes[0, 2].plot(range(sequence_length), outputs_array[:, i], 
                   marker='^', label=f'Output {i+1}', linewidth=2)
axes[0, 2].set_title('Output Sequence')
axes[0, 2].set_xlabel('Time Step')
axes[0, 2].set_ylabel('Output Value')
axes[0, 2].legend()
axes[0, 2].grid(True, alpha=0.3)

# Plot 4: Weight matrices visualization
im1 = axes[1, 0].imshow(rnn.W_xh, cmap='RdBu', aspect='auto')
axes[1, 0].set_title('Input-to-Hidden Weights (W_xh)')
axes[1, 0].set_xlabel('Hidden Units')
axes[1, 0].set_ylabel('Input Features')
plt.colorbar(im1, ax=axes[1, 0])

# Plot 5: Recurrent weights
im2 = axes[1, 1].imshow(rnn.W_hh, cmap='RdBu', aspect='auto')
axes[1, 1].set_title('Hidden-to-Hidden Weights (W_hh)')
axes[1, 1].set_xlabel('Hidden Units (t)')
axes[1, 1].set_ylabel('Hidden Units (t-1)')
plt.colorbar(im2, ax=axes[1, 1])

# Plot 6: Information flow diagram
axes[1, 2].axis('off')
info_text = """
RNN INFORMATION FLOW:

At each time step t:
1. Receive input x_t
2. Combine with previous hidden h_{t-1}
3. Apply transformation: 
   h_t = tanh(x_t·W_xh + h_{t-1}·W_hh + b_h)
4. Produce output:
   y_t = h_t·W_hy + b_y
5. Pass h_t to next time step

KEY PROPERTIES:
• Parameter sharing across time
• Sequential processing
• Memory through hidden state
• Variable sequence lengths

LIMITATIONS:
• Vanishing gradients
• Short-term memory
• Sequential bottleneck
"""

axes[1, 2].text(0.05, 0.95, info_text, transform=axes[1, 2].transAxes,
               verticalalignment='top', fontsize=10, fontfamily='monospace')

plt.tight_layout()
plt.show()

print(f"\n💡 VANILLA RNN KEY INSIGHTS:")
print(f"\n1. ARCHITECTURE:")
print(f"   • Hidden state h_t carries information through time")
print(f"   • Same parameters used at each time step (parameter sharing)")
print("   • Recurrent connection: h_t depends on h_{t-1}")

print(f"\n2. MATHEMATICAL FORMULATION:")
print("   • h_t = tanh(x_t @ W_xh + h_{t-1} @ W_hh + b_h)")
print(f"   • y_t = h_t @ W_hy + b_y")
print(f"   • Tanh activation keeps hidden states bounded in (-1, 1)")

print(f"\n3. MEMORY MECHANISM:")
print(f"   • Information from early time steps can influence later outputs")
print(f"   • Memory capacity limited by hidden state size")
print(f"   • Gradual information decay through time steps")

print(f"\n4. LIMITATIONS:")
print(f"   • Vanishing gradients: Difficulty learning long-term dependencies")
print(f"   • Sequential processing: Cannot parallelize across time")
print(f"   • Information bottleneck: All info must pass through hidden state")

## 3. The Vanishing Gradient Problem: Why Vanilla RNNs Struggle

**The Problem:**
- **Backpropagation Through Time (BPTT)**: Gradients flow backward through each time step
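- **Gradient multiplication**: Each backward step multiplies the gradient by the Jacobian ∂h_t/∂h_{t-1}, which combines `W_hh` with the tanh derivative
- **Vanishing**: If these Jacobians have norm < 1 (small recurrent weights or saturated tanh units), gradients → 0 exponentially in the number of steps
- **Exploding**: If they have norm > 1 (large recurrent weights), gradients can grow exponentially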
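
A scalar version of the argument shows how quickly repeated multiplication diverges from 1. This is only a sketch: it treats each backward step as multiplication by one constant factor, which real RNN Jacobians are not.

```python
def repeated_factor(factor, steps):
    """Gradient scale after multiplying by the same factor `steps` times."""
    return factor ** steps

steps = 50
shrink = repeated_factor(0.9, steps)   # about 0.005
keep = repeated_factor(1.0, steps)     # exactly 1.0
grow = repeated_factor(1.1, steps)     # about 117

for name, value in [("0.9", shrink), ("1.0", keep), ("1.1", grow)]:
    print(f"factor {name} over {steps} steps -> {value:.4g}")
```

Factors only 10% away from 1 change the gradient by more than two orders of magnitude over 50 steps, in either direction.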

**Mathematical Analysis:**
```
∂L/∂h_1 = ∂L/∂h_T * ∏(t=2 to T) ∂h_t/∂h_{t-1}
```

**Where each term:**
```
∂h_t/∂h_{t-1} = diag(tanh'(·)) * W_hh
```
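
The product of Jacobians can also be measured directly. The sketch below follows the column-vector convention of the formulas above, builds `W_hh` as a scaled orthogonal matrix so its spectral norm is known exactly, and tracks the spectral norm of ∂h_t/∂h_0 along one forward pass. The function name, sizes, and random inputs are assumptions made for illustration.

```python
import numpy as np

def jacobian_product_norms(W_hh, W_xh, inputs):
    """Return the spectral norm of dh_t/dh_0 for t = 1..T on one trajectory."""
    hidden_size = W_hh.shape[0]
    h = np.zeros(hidden_size)
    product = np.eye(hidden_size)
    norms = []
    for x_t in inputs:
        h = np.tanh(W_hh @ h + W_xh @ x_t)
        # dh_t/dh_{t-1} = diag(tanh'(a_t)) @ W_hh, with tanh'(a_t) = 1 - h_t^2
        jacobian = np.diag(1.0 - h ** 2) @ W_hh
        product = jacobian @ product          # chain rule: J_t ... J_1
        norms.append(np.linalg.norm(product, 2))
    return np.array(norms)

rng = np.random.default_rng(0)
hidden_size, input_size, T = 16, 4, 30
W_xh = rng.normal(scale=0.5, size=(hidden_size, input_size))
inputs = rng.normal(size=(T, input_size))

# Q is orthogonal (all singular values equal 1), so scale * Q has spectral norm `scale`.
Q, _ = np.linalg.qr(rng.normal(size=(hidden_size, hidden_size)))
for scale in (0.5, 1.0):
    norms = jacobian_product_norms(scale * Q, W_xh, inputs)
    print(f"spectral norm {scale}: step 1 = {norms[0]:.3e}, step {T} = {norms[-1]:.3e}")
```

Since tanh' ≤ 1, each Jacobian's spectral norm is at most that of `W_hh`, so with a spectral norm of 0.5 the product is bounded by 0.5^T. Even at a spectral norm of 1.0 the product cannot grow, and any tanh saturation shrinks it further.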

**Impact:**
- **Short memory**: Often struggles to learn dependencies beyond roughly 5-10 time steps (the exact range depends on the task and initialization)
- **Training difficulties**: Slow convergence, unstable gradients
- **Limited applications**: Poor for long sequences

**Solutions:**
1. **LSTM/GRU**: Gating mechanisms
2. **Gradient clipping**: Prevent exploding gradients
3. **Better initialization**: Careful weight initialization
4. **Residual connections**: Skip connections through time
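
Of these, gradient clipping is the simplest to illustrate. Below is a minimal NumPy sketch of clipping by global norm, assuming the gradients arrive as a list of arrays; `clip_by_global_norm` here is a hand-written helper. Keras optimizers expose comparable behavior through arguments such as `clipnorm` and `global_clipnorm`.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Rescale gradient arrays so their joint L2 norm is at most max_norm."""
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    if global_norm <= max_norm:
        return grads, global_norm
    scale = max_norm / global_norm
    return [g * scale for g in grads], global_norm

# Joint norm of these gradients: sqrt(3^2 + 4^2 + 12^2) = 13
grads = [np.array([3.0, 4.0]), np.array([[12.0]])]
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum(np.sum(g ** 2) for g in clipped))
print(f"norm before: {norm_before:.1f}, norm after: {norm_after:.1f}")
```

Clipping rescales every gradient by the same factor, so the update direction is preserved and only its length is capped. It addresses exploding gradients but does nothing for vanishing ones.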

In [None]:
# Cell 4: Vanishing Gradient Problem Demonstration

print("=== VANISHING GRADIENT PROBLEM ANALYSIS ===")

def analyze_gradient_flow():
    """
    Demonstrate vanishing gradient problem through gradient magnitude analysis
    """
    
    print(f"\n🔍 GRADIENT FLOW ANALYSIS:")
    
    # Create RNNs with different weight scales
    weight_scales = [0.5, 0.9, 1.0, 1.1, 1.5]
    sequence_length = 20
    
    gradient_magnitudes = {}
    
    for scale in weight_scales:
        print(f"\nAnalyzing weight scale: {scale}")
        
        # Simulate gradient flow (simplified analysis)
        # Gradient at time t flows back through (W_hh)^(T-t) path
        
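        # Assume tanh derivative ≈ 0.5 on average.
        # The per-step factor is then scale * 0.5, so every scale below 2.0 decays
        # in this simplified model and the "exploding" branch below is not reached.
        tanh_derivative = 0.5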
        
        gradients = []
        for t in range(sequence_length):
            # Gradient magnitude after flowing back t steps
            steps_back = sequence_length - 1 - t
            gradient_magnitude = (scale * tanh_derivative) ** steps_back
            gradients.append(gradient_magnitude)
        
        gradient_magnitudes[scale] = gradients
        
        final_gradient = gradients[0]  # Gradient at first time step
        print(f"  Final gradient magnitude: {final_gradient:.6f}")
        
        if final_gradient < 1e-5:
            print(f"  ❌ Vanishing gradients detected!")
        elif final_gradient > 1e5:
            print(f"  💥 Exploding gradients detected!")
        else:
            print(f"  ✅ Stable gradients")
    
    return gradient_magnitudes

def demonstrate_memory_limitations():
    """
    Demonstrate how vanilla RNNs struggle with long-term dependencies
    """
    
    print(f"\n🧠 MEMORY LIMITATION DEMONSTRATION:")
    
    # Create sequences with dependencies at different distances
    def create_memory_task(sequence_length, dependency_distance):
        """
        Create a sequence where early information must be remembered
        """
        # Sequence of mostly zeros with signal at beginning and end
        sequence = np.zeros((sequence_length, 2))
        
        # Important signal at the beginning
        sequence[0, 0] = 1.0  # Signal to remember
        
        # Query at the end
        sequence[-1, 1] = 1.0  # Query: "what was the signal?"
        
        # Target: 1 if signal was present, 0 otherwise
        target = 1.0
        
        return sequence, target
    
    # Test different dependency distances
    distances = [5, 10, 20, 50]
    
    print(f"Memory task results (simplified analysis):")
    for distance in distances:
        # Create task
        sequence, target = create_memory_task(distance, distance-1)
        
        # Estimate gradient strength for this distance
        # This is a simplified calculation
        gradient_strength = (0.5 * 0.9) ** (distance - 1)  # Rough estimate
        
        print(f"  Distance {distance:2d}: Gradient strength ≈ {gradient_strength:.6f}")
        
        if gradient_strength < 1e-3:
            difficulty = "Very Hard"
        elif gradient_strength < 1e-2:
            difficulty = "Hard"
        elif gradient_strength < 1e-1:
            difficulty = "Moderate"
        else:
            difficulty = "Easy"
        
        print(f"              Learning difficulty: {difficulty}")

# Analyze gradient flow
gradient_data = analyze_gradient_flow()

# Demonstrate memory limitations
demonstrate_memory_limitations()

# Create comprehensive visualization
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Plot 1: Gradient flow for different weight scales
for scale, gradients in gradient_data.items():
    time_steps = range(len(gradients))
    axes[0, 0].semilogy(time_steps, gradients, marker='o', 
                       label=f'Scale {scale}', linewidth=2)

axes[0, 0].set_title('Gradient Magnitude vs Time Step')
axes[0, 0].set_xlabel('Time Step')
axes[0, 0].set_ylabel('Gradient Magnitude (log scale)')
# Draw the threshold line before calling legend() so its label is included
axes[0, 0].axhline(y=1e-5, color='red', linestyle='--', alpha=0.7, label='Vanishing threshold')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Plot 2: Effect of sequence length on gradient
# Per-step factor = weight scale × assumed tanh derivative (0.5).
# All three factors are below 1, so every curve decays; only the rate differs.
seq_lengths = range(1, 31)
scale_09_gradients = [(0.9 * 0.5) ** (l-1) for l in seq_lengths]
scale_05_gradients = [(0.5 * 0.5) ** (l-1) for l in seq_lengths]
scale_15_gradients = [(1.5 * 0.5) ** (l-1) for l in seq_lengths]

axes[0, 1].semilogy(seq_lengths, scale_09_gradients, 'g-', label='Scale 0.9 (factor 0.45)', linewidth=2)
axes[0, 1].semilogy(seq_lengths, scale_05_gradients, 'r-', label='Scale 0.5 (factor 0.25)', linewidth=2)
axes[0, 1].semilogy(seq_lengths, scale_15_gradients, 'b-', label='Scale 1.5 (factor 0.75)', linewidth=2)

axes[0, 1].set_title('Gradient vs Sequence Length')
axes[0, 1].set_xlabel('Sequence Length')
axes[0, 1].set_ylabel('Final Gradient Magnitude')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Plot 3: Memory task difficulty
distances = [5, 10, 15, 20, 25, 30]
difficulties = [(0.5 * 0.9) ** (d-1) for d in distances]

axes[0, 2].semilogy(distances, difficulties, 'ro-', linewidth=2, markersize=8)
axes[0, 2].set_title('Memory Task Difficulty')
axes[0, 2].set_xlabel('Dependency Distance')
axes[0, 2].set_ylabel('Learning Signal Strength')
axes[0, 2].grid(True, alpha=0.3)
axes[0, 2].axhline(y=1e-3, color='red', linestyle='--', alpha=0.7, label='Difficulty threshold')
axes[0, 2].legend()

# Plot 4: Activation function derivatives
x = np.linspace(-3, 3, 100)
tanh_vals = np.tanh(x)
tanh_derivs = 1 - tanh_vals**2

axes[1, 0].plot(x, tanh_vals, 'b-', linewidth=2, label='tanh(x)')
axes[1, 0].plot(x, tanh_derivs, 'r--', linewidth=2, label="tanh'(x)")
axes[1, 0].set_title('Tanh and its Derivative')
axes[1, 0].set_xlabel('x')
axes[1, 0].set_ylabel('Value')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)
axes[1, 0].axhline(y=0, color='k', linestyle='-', alpha=0.3)

# Plot 5: Weight initialization impact (illustrative values, not measurements)
init_methods = ['Small\n(0.1)', 'Xavier\n√(2/(n_in+n_out))', 'Large\n(2.0)']
gradient_survival = [0.01, 0.1, 10.0]  # Rough, hand-picked estimates for illustration

bars = axes[1, 1].bar(init_methods, gradient_survival, 
                     color=['lightcoral', 'lightgreen', 'gold'], alpha=0.7)
axes[1, 1].set_title('Weight Initialization Impact')
axes[1, 1].set_ylabel('Gradient Survival Rate')
axes[1, 1].set_yscale('log')

# Add value labels
for bar, val in zip(bars, gradient_survival):
    axes[1, 1].text(bar.get_x() + bar.get_width()/2, bar.get_height(),
                   f'{val}', ha='center', va='bottom')

# Plot 6: Solutions overview
axes[1, 2].axis('off')
solutions_text = """
VANISHING GRADIENT SOLUTIONS:

1. LSTM/GRU Networks:
   • Gating mechanisms
   • Selective information flow
   • Constant error carousel

2. Gradient Clipping:
   • Prevent exploding gradients
   • Clip norm to threshold
   • Stable training

3. Better Initialization:
   • Xavier/He initialization
   • Identity matrix for W_hh
   • Orthogonal initialization

4. Residual Connections:
   • Skip connections through time
   • Highway networks
   • Direct gradient paths

5. Alternative Architectures:
   • Transformers (attention)
   • Temporal convolutional networks
   • State space models
"""

axes[1, 2].text(0.05, 0.95, solutions_text, transform=axes[1, 2].transAxes,
               verticalalignment='top', fontsize=10, fontfamily='monospace')

plt.tight_layout()
plt.show()

# Quantitative analysis
print(f"\n📊 QUANTITATIVE ANALYSIS:")

print(f"\nGradient decay rates:")
for scale in [0.5, 0.9, 1.0, 1.1]:
    # Decay factor for the gradient reaching the first step of a length-T sequence
    # (T - 1 backward steps)
    effective_rate_5 = (scale * 0.5) ** 4   # Sequence of 5 steps
    effective_rate_20 = (scale * 0.5) ** 19 # Sequence of 20 steps
    
    print(f"  Weight scale {scale}:")
    print(f"    Length-5 sequence:  {effective_rate_5:.6f} (decay factor)")
    print(f"    Length-20 sequence: {effective_rate_20:.6f} (decay factor)")

print(f"\nMemory capacity estimates:")
threshold = 1e-3  # Minimum gradient for effective learning
for scale in [0.5, 0.9, 1.0]:
    # Calculate maximum effective sequence length
    if scale * 0.5 < 1.0:
        max_length = int(np.log(threshold) / np.log(scale * 0.5))
        print(f"  Weight scale {scale}: ~{max_length} time steps")
    else:
        print(f"  Weight scale {scale}: Unstable (exploding gradients)")

print(f"\n💡 VANISHING GRADIENT KEY INSIGHTS:")
print(f"\n1. MATHEMATICAL ROOT CAUSE:")
print(f"   • Gradient = product of derivatives through time")
print(f"   • tanh' ≤ 1, so unless W_hh is large each Jacobian has norm < 1 → product → 0")
print(f"   • Exponential decay: (factor)^T")

print(f"\n2. PRACTICAL IMPLICATIONS:")
print(f"   • Vanilla RNNs: often only ~5-10 time steps of usable memory")
print(f"   • Long sequences: Early information lost")
print(f"   • Training: Slow convergence on long dependencies")

print(f"\n3. WHY IT MATTERS:")
print(f"   • Language: Long-range grammatical dependencies")
print(f"   • Time series: Seasonal patterns (yearly cycles)")
print(f"   • Speech: Sentence-level context")

print(f"\n4. SOLUTION PREVIEW:")
print(f"   • LSTM: Gated memory cells")
print(f"   • GRU: Simplified gating")
print(f"   • Attention: Direct connections to all time steps")