# Comprehensive Transformer Implementation for Text Generation with KV Cache Optimization

This notebook implements a complete Transformer-based language model from scratch using TensorFlow/Keras.
The implementation focuses on text generation with an advanced Key-Value (KV) Cache optimization system
that significantly speeds up autoregressive text generation.

## What This Implementation Does:
1. Implements a full Transformer architecture with multi-head self-attention
2. Provides KV Cache optimization for faster text generation
3. Trains the model on text data (Shakespeare or custom corpus)
4. Generates text using the trained model with performance comparisons
5. Includes comprehensive logging and visualization tools

## Key Concepts Explained:
- **Transformers**: Neural network architecture that uses attention mechanisms
- **Self-Attention**: Allows the model to focus on different parts of the input sequence
- **KV Cache**: Optimization technique that stores computed attention keys/values to avoid redundant calculations
- **Autoregressive Generation**: Generating text one token at a time, using previous tokens as context
- **Causal Masking**: Prevents the model from looking at future tokens during training and generation

## 📺 Watch the Tutorial

Prefer a video walkthrough? Check out the accompanying tutorial on YouTube:

[Video Tutorial](https://youtu.be/N44mgyzxxKU)

## 🚀 Quick Start Guide

Ready to run the code? Follow these simple steps to set up your environment and execute the Transformer text generation model:

### Step 1: Create a Python Virtual Environment
```bash
# Create a new virtual environment
python -m venv /path/to/your/project/

# Activate the virtual environment
# On macOS/Linux:
source /path/to/your/project/bin/activate
```

### Step 2: Install Required Dependencies
```bash
# Install all required packages
pip install -r requirements.txt
```

The `requirements.txt` file includes:
- `tensorflow>=2.0` - Deep learning framework
- `numpy` - Numerical computing
- `matplotlib` - Plotting and visualization
- `visualkeras` - Model architecture visualization

### Step 3: Prepare Your Text Corpus (Optional)
The script can work with two data sources:
- **Web dataset:** Automatically downloads Shakespeare's complete works
- **Local corpus:** Place your own text file as `corpus.txt` in the project directory

To use your own text data:
```bash
# Place your text file in the project directory
cp /path/to/your/text/data.txt corpus.txt
```

### Step 4: Run the Main Script
```bash
# Execute the Transformer text generation model
python main.py
```

### What Happens When You Run It?
1. **Data Loading:** Downloads Shakespeare dataset or loads local corpus.txt
2. **Text Preprocessing:** Tokenizes and vectorizes text using TextVectorization
3. **Sequence Generation:** Creates input-target pairs for next-token prediction
4. **Model Building:** Constructs the Transformer architecture with positional encoding
5. **Training:** Trains the model with early stopping and TensorBoard logging
6. **Visualization:** Generates model architecture diagrams and training plots
7. **Text Generation:** Creates novel text sequences using the trained model

### Monitoring Training Progress
The script sets up TensorBoard logging automatically. You can launch TensorBoard during training to monitor it live, or afterwards to review the run:
```bash
# Launch TensorBoard (the script will show you the exact command)
tensorboard --logdir logs/[timestamp]
```

Then open your browser to `http://localhost:6006` to view, depending on what the script logs:
- Training loss curves
- The model graph
- Text generation samples (if logged as text summaries)
- Weight histograms (if histogram logging is enabled)

**Expected Runtime:** Approximately 10-20 minutes on a modern CPU, faster with GPU acceleration.

## Import Libraries and GPU Configuration

In [None]:
#!/usr/bin/env python3

import os  # Operating system interface for file/directory operations
from datetime import datetime  # For timestamping logs and outputs
import numpy as np  # Numerical computing library for array operations
import tensorflow as tf  # Deep learning framework - our main ML library

# TensorFlow/Keras specific imports for building neural network components
from tensorflow.keras.layers import (
    Layer,              # Base class for all neural network layers
    Dense,              # Fully connected (linear) layer - core building block
    LayerNormalization, # Normalization technique for stable training
    Dropout,            # Regularization technique to prevent overfitting
    Embedding,          # Converts token IDs to dense vector representations
    TextVectorization   # Preprocesses text data into numerical format
)
from tensorflow.keras.models import Model  # Base class for complex models
from tensorflow.keras.utils import get_file  # Utility for downloading datasets
from tensorflow.keras.callbacks import EarlyStopping  # Prevents overfitting during training
import matplotlib.pyplot as plt  # Plotting library for visualizations
import visualkeras  # Library for visualizing neural network architectures

### 🔍 **Detailed Analysis: Import Libraries and Their Critical Roles**

Each import in this transformer implementation serves a specific purpose in building the language model. Here is why each library is needed:

#### **Core System Libraries**

**`import os`** - Operating System Interface
- **Purpose**: Provides cross-platform file system operations
- **Usage in our implementation**: 
  - Creating directory structures for model checkpoints and logs
  - Managing file paths for saving/loading trained models
  - Handling corpus file detection and validation
- **Why it's critical**: Deep learning projects require extensive file management for datasets, model weights, logs, and outputs

**`from datetime import datetime`** - Timestamp Management
- **Purpose**: Provides date and time functionality for logging and organization
- **Usage in our implementation**:
  - Creating unique timestamps for TensorBoard log directories
  - Organizing training runs by date/time for experiment tracking
  - Generating unique filenames to prevent overwrites
- **Why it's critical**: Proper experiment tracking is essential for reproducible research and model development
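The exact timestamp format the script uses for `logs/[timestamp]` isn't shown in this section. As a sketch, assuming a `YYYYMMDD-HHMMSS` format and a hypothetical helper named `make_log_dir`, a per-run log directory can be built like this:

```python
import os
from datetime import datetime

def make_log_dir(root="logs", now=None):
    """Hypothetical helper: return a per-run path such as logs/20240101-120000."""
    if now is None:
        now = datetime.now()
    # Second-resolution timestamps sort chronologically as plain strings
    return os.path.join(root, now.strftime("%Y%m%d-%H%M%S"))

print(make_log_dir(now=datetime(2024, 1, 1, 12, 0, 0)))
```

Because each run gets its own subdirectory, pointing TensorBoard at the parent `logs` folder lets you compare runs side by side.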

#### **Numerical Computing Foundation**

**`import numpy as np`** - Numerical Python
- **Purpose**: Fundamental package for scientific computing with Python
- **Usage in our implementation**:
  - **Positional Encoding**: Creating sinusoidal position embeddings using trigonometric functions
  - **Data Preprocessing**: Converting text sequences to numerical arrays
  - **Mathematical Operations**: Angle calculations, array manipulations, and tensor operations
  - **Initialization**: Setting up mathematical constants and arrays
- **Why it's critical**: NumPy handles array math outside TensorFlow's ops, such as building the positional-encoding table, and NumPy arrays convert to and from TensorFlow tensors directly
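The notebook's positional-encoding code isn't part of this section. The following is a minimal sketch of the standard sinusoidal encoding from the original Transformer paper, computed with NumPy; the function name `positional_encoding` and the interleaved sin/cos layout are assumptions and may differ from the notebook's own version.

```python
import numpy as np

def positional_encoding(max_len, d_model):
    """Sinusoidal encoding: PE[pos, 2i] = sin(pos / 10000^(2i/d)), PE[pos, 2i+1] = cos(same angle)."""
    positions = np.arange(max_len)[:, np.newaxis]          # (max_len, 1)
    dims = np.arange(d_model)[np.newaxis, :]               # (1, d_model)
    # Each pair of dimensions (2i, 2i+1) shares one frequency
    angle_rates = 1.0 / np.power(10000.0, (2 * (dims // 2)) / d_model)
    angles = positions * angle_rates                       # (max_len, d_model)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles[:, 0::2])                  # even indices: sine
    pe[:, 1::2] = np.cos(angles[:, 1::2])                  # odd indices: cosine
    return pe.astype(np.float32)

pe = positional_encoding(max_len=50, d_model=16)
print(pe.shape)  # (50, 16)
```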

#### **Deep Learning Framework**

**`import tensorflow as tf`** - TensorFlow Core
- **Purpose**: Google's open-source machine learning framework
- **Usage in our implementation**:
  - **Tensor Operations**: All mathematical computations (matrix multiplication, activation functions)
  - **Automatic Differentiation**: Gradient computation for backpropagation
  - **GPU Acceleration**: Utilizing CUDA cores for parallel processing
  - **Graph Execution**: Optimized computation graphs for efficient training
- **Why it's critical**: TensorFlow provides the computational engine that makes training large transformer models feasible

#### **Neural Network Building Blocks**

**`Layer`** - Base Layer Class
- **Purpose**: Abstract base class for all neural network layers
- **Usage in our implementation**:
  - **Custom Layers**: MultiHeadSelfAttention and TransformerBlock inherit from this
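  - **State Management**: Tracks layer weights and the `training` flag and defines the forward pass in `call()`; gradients for the backward pass come from TensorFlow's automatic differentiation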
  - **Composability**: Enables stacking layers to build complex architectures
- **Deep Principle**: Object-oriented design allows modular, reusable components that can be combined to create sophisticated models

**`Dense`** - Fully Connected Layer
- **Purpose**: Implements linear transformation: output = input × weight + bias
- **Usage in our implementation**:
  - **Attention Projections**: Query, Key, Value transformations (Q = XW_q, K = XW_k, V = XW_v)
  - **Feed-Forward Networks**: Two-layer MLPs within transformer blocks
  - **Output Projection**: Final layer mapping hidden states to vocabulary logits (a softmax turns these into probabilities)
- **Mathematical Foundation**: Dense layers perform affine transformations that are fundamental to neural network expressiveness

**`LayerNormalization`** - Normalization Technique
- **Purpose**: Normalizes inputs across features to stabilize training
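- **Mathematical Formula**: LN(x) = γ × (x - μ) / √(σ² + ε) + β, where μ and σ² are the mean and variance over each token's features and ε is a small constant for numerical stability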
- **Usage in our implementation**:
  - **Pre-Norm Architecture**: Applied before attention and feed-forward operations
  - **Gradient Stabilization**: Helps keep activations and gradients well-scaled in deep networks, reducing the risk of vanishing/exploding gradients
  - **Training Acceleration**: Often allows higher learning rates and faster convergence
- **Why it's preferred over BatchNorm here**: It normalizes each token independently, so it does not depend on batch statistics and behaves the same for any batch size or sequence length
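To make the formula concrete, here is a NumPy sketch of layer normalization over the feature axis. It assumes ε = 1e-3 (the default `epsilon` of Keras's `LayerNormalization`); `layer_norm` is a hypothetical helper, not the Keras implementation.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-3):
    """LN(x) = gamma * (x - mean) / sqrt(var + eps) + beta, over the last axis."""
    mu = x.mean(axis=-1, keepdims=True)    # per-token mean across features
    var = x.var(axis=-1, keepdims=True)    # per-token variance across features
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(2, 5, 8))   # (batch, seq_len, features)
gamma, beta = np.ones(8), np.zeros(8)
out = layer_norm(x, gamma, beta)

print(out.mean(axis=-1).round(6))   # ~0 for every token
print(out.var(axis=-1).round(3))    # ~1 for every token (slightly below, due to eps)
# Unlike BatchNorm, a token's output does not depend on the rest of the batch:
assert np.allclose(layer_norm(x[:1], gamma, beta), out[:1])
```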

**`Dropout`** - Regularization Technique
- **Purpose**: Randomly sets input units to 0 during training to prevent overfitting
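- **Mathematical Principle**: During training, each neuron is "dropped out" (set to 0) with probability p and the surviving activations are scaled by 1/(1 - p), so their expected value is unchanged; at inference time dropout does nothing
- **Usage in our implementation**:
  - **Attention-Output Dropout**: Applied around the attention sublayer rather than inside it (the `MultiHeadSelfAttention` layer below applies no dropout to its attention weights)
  - **Feed-Forward Dropout**: Applied in the MLP layers
  - **Generalization**: Forces the model to not rely on specific neurons
- **Critical Insight**: Dropout can be viewed as approximately training an ensemble of many sub-networks that share weights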
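A NumPy sketch of inverted dropout, the scheme Keras's `Dropout` layer uses (it zeroes units only when `training=True` and scales the survivors by 1/(1 - rate)). The helper name `dropout` and the explicit `rng` argument are illustrative assumptions.

```python
import numpy as np

def dropout(x, rate, training, rng):
    """Inverted dropout: zero units with probability `rate`, rescale survivors by 1/(1-rate)."""
    if not training or rate == 0.0:
        return x                                  # identity at inference time
    keep = rng.random(x.shape) >= rate            # True with probability 1 - rate
    return np.where(keep, x / (1.0 - rate), 0.0)

rng = np.random.default_rng(0)
x = np.ones((1000, 100))
y = dropout(x, rate=0.1, training=True, rng=rng)
print((y == 0).mean(), y.mean())   # about 10% zeros, mean still about 1.0
```

Scaling during training (rather than at inference) means the inference path needs no special handling: the layer simply passes its input through.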

**`Embedding`** - Token-to-Vector Conversion
- **Purpose**: Converts discrete token IDs to dense vector representations
- **Mathematical Operation**: lookup_table[token_id] → dense_vector
- **Usage in our implementation**:
  - **Token Embeddings**: Converting word/subword tokens to learnable vectors
  - **Semantic Representation**: Capturing semantic relationships between tokens
  - **Dimensionality**: Maps vocabulary_size → embedding_dimension
- **Deep Learning Principle**: Embeddings learn distributed representations where similar tokens have similar vectors
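An embedding layer is a trainable lookup table. This NumPy sketch (with a made-up vocabulary of 10 tokens and 4-dimensional vectors) shows that the lookup is plain row indexing, equivalent to multiplying a one-hot encoding by the table:

```python
import numpy as np

vocab_size, embed_dim = 10, 4
rng = np.random.default_rng(0)
# The embedding "layer" is just a trainable (vocab_size, embed_dim) matrix
table = rng.normal(size=(vocab_size, embed_dim)).astype(np.float32)

token_ids = np.array([[3, 1, 4, 1]])     # (batch=1, seq_len=4)
vectors = table[token_ids]               # row lookup -> (1, 4, embed_dim)
print(vectors.shape)  # (1, 4, 4)

# Equivalent view: one-hot encoding times the table (the lookup avoids building the one-hot)
one_hot = np.eye(vocab_size, dtype=np.float32)[token_ids]
assert np.allclose(one_hot @ table, vectors)
```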

**`TextVectorization`** - Text Preprocessing Pipeline
- **Purpose**: Converts raw text to numerical sequences suitable for neural networks
- **Operations Performed**:
  - **Tokenization**: Splitting text into individual tokens
  - **Vocabulary Building**: Creating token-to-ID mappings
  - **Sequence Conversion**: Converting text to integer sequences
  - **Padding/Truncation**: Ensuring uniform sequence lengths
- **Usage in our implementation**: Preprocessing training data and handling text generation inputs

#### **Model Architecture and Training**

**`Model`** - Keras Model Base Class
- **Purpose**: High-level API for building and training complex neural networks
- **Features Provided**:
  - **Training Loop**: Built-in fit() method with optimization
  - **Serialization**: Save/load model weights and architecture
  - **Metrics**: Built-in loss computation and evaluation
  - **Callbacks**: Extensible training process with custom behaviors
- **Usage**: Our TransformerModel inherits from Model to get all these capabilities

**`get_file`** - Dataset Utility
- **Purpose**: Downloads and caches datasets from URLs
- **Usage in our implementation**: Downloading Shakespeare dataset for training
- **Features**: Automatic caching, optional integrity checking via a `file_hash`, and progress bars

**`EarlyStopping`** - Training Callback
- **Purpose**: Prevents overfitting by stopping training when validation metrics stop improving
- **Algorithm**: Monitors specified metric and stops training after patience epochs without improvement
- **Usage**: Automatically stops training when loss plateaus, saving computational resources
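The patience rule can be modeled in a few lines of plain Python. This is a simplified sketch with a hypothetical helper `early_stopping_epoch`; Keras's `EarlyStopping` callback applies the same idea to the monitored metric (`val_loss` by default) and adds options such as `baseline` and `restore_best_weights`.

```python
def early_stopping_epoch(losses, patience, min_delta=0.0):
    """Return the 0-based epoch at which training would stop, or None if it never stops.

    An epoch counts as an improvement only if it beats the best loss so far
    by more than min_delta.
    """
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(losses):
        if loss < best - min_delta:
            best, wait = loss, 0        # improvement: reset the counter
        else:
            wait += 1                   # no improvement this epoch
            if wait >= patience:
                return epoch
    return None

losses = [2.0, 1.5, 1.2, 1.21, 1.19, 1.25, 1.3]
print(early_stopping_epoch(losses, patience=2))                  # 6
print(early_stopping_epoch(losses, patience=2, min_delta=0.05))  # 4
```

Note how `min_delta` changes the outcome: the small drop from 1.2 to 1.19 resets the counter only when any improvement counts.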

#### **Visualization and Analysis**

**`matplotlib.pyplot`** - Plotting Library
- **Purpose**: Creating publication-quality plots and visualizations
- **Usage in our implementation**:
  - **Training Curves**: Plotting loss over epochs
  - **Performance Analysis**: Visualizing generation speed comparisons
  - **Model Insights**: Creating attention heatmaps and analysis plots

**`visualkeras`** - Neural Network Visualization
- **Purpose**: Creates visual representations of neural network architectures
- **Usage**: Generating diagrams of our transformer model structure
- **Benefit**: Helps understand model complexity and architecture design

### **🎯 Integration and Synergy**

These libraries work together to create a complete ecosystem:

1. **Data Flow**: `os` and `get_file` handle data acquisition → `TextVectorization` preprocesses → `numpy` provides mathematical operations
2. **Model Building**: `Layer`, `Dense`, `Embedding` create components → `Model` orchestrates training
3. **Training Process**: `tensorflow` provides the engine → `LayerNormalization` and `Dropout` ensure stability → `EarlyStopping` prevents overfitting
4. **Analysis**: `matplotlib` and `visualkeras` provide insights into model behavior and performance

Together, these imports cover everything the notebook needs: data handling, model architecture, training procedures, and analysis.

In [None]:
# =============================================================================
# GPU CONFIGURATION - Optimizing hardware utilization
# =============================================================================

# Detect all available GPU devices on the system
gpus = tf.config.list_physical_devices('GPU')

if gpus:  # If GPUs are detected
    try:
        # Configure each GPU for memory growth to prevent allocation issues
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        
        # Get logical GPU devices
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(f"{len(gpus)} Physical GPUs, {len(logical_gpus)} Logical GPUs configured.")
        
    except RuntimeError as e:
        print(f"Error setting up GPU memory growth: {e}")
else:
    print("No GPU detected, using CPU.")

### 🚀 **Deep Dive: GPU Configuration and Memory Management**

The GPU configuration section matters for performance when training and running transformer models. Here is what each call does and why memory growth is enabled:

#### **🔧 GPU Detection and Enumeration**

**`tf.config.list_physical_devices('GPU')`**
- **Purpose**: Discovers all available GPU hardware on the system
- **Technical Details**: 
  - Queries the CUDA runtime to enumerate physical GPU devices
  - Returns a list of PhysicalDevice objects representing each GPU
  - Works with NVIDIA GPUs (CUDA) and potentially other accelerators
- **Why This Matters**: 
  - Enables dynamic hardware detection across different systems
  - Allows code to adapt to available hardware resources
  - Prevents crashes when GPU drivers or hardware aren't available

#### **🧠 Memory Growth Strategy: The Key to Stable Training**

**`tf.config.experimental.set_memory_growth(gpu, True)`**

This is one of the most critical optimizations for deep learning workflows. Here's why:

**Traditional GPU Memory Allocation (Default Behavior):**
- By default, TensorFlow maps nearly all memory of every visible GPU as soon as the GPU is first used
- Example: On an 8GB GPU, TensorFlow reserves almost all 8GB immediately
- **Problems with this approach**:
  - Prevents running multiple models simultaneously
  - Causes out-of-memory errors in multi-process environments
  - Wastes memory when model doesn't need full capacity
  - Makes development and experimentation difficult

**Memory Growth Strategy (Our Implementation):**
- **Dynamic Allocation**: Starts with minimal memory usage
- **On-Demand Growth**: Allocates more memory only when needed
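- **Gradual Expansion**: Memory usage grows incrementally as the model requires it
- **No Shrinking**: Memory that has been claimed is not released back while the process runs (releasing it could fragment memory)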

**Technical Implementation Details:**
```python
# Without memory growth:
# GPU Memory: [████████████████████████████████] 8GB allocated immediately
# Available:  [                                ] 0GB for other processes

# With memory growth (illustrative numbers):
# Initial:    [██                              ] 512MB allocated
# As needed:  [████████                        ] 2GB allocated
# Maximum:    [████████████████████████████████] 8GB only if required
```

**Benefits for Transformer Training:**
1. **Multi-Model Development**: Can run multiple experiments simultaneously
2. **Resource Sharing**: Other applications can use GPU memory
3. **Debugging Friendly**: Easier to debug without memory conflicts
4. **Production Deployment**: Better resource utilization in production environments

#### **🎯 Logical vs Physical GPU Devices**

- **Physical GPUs**: The actual hardware devices (e.g., RTX 3080, A100)
- **Logical GPUs**: Software abstractions that can be created from physical GPUs

**Why This Distinction Matters:**
- **Memory Partitioning**: One physical GPU can be split into multiple logical GPUs
- **Resource Isolation**: Different models can use different logical GPUs
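- **Multi-Device Testing**: Logical GPUs exist within a single TensorFlow process, so they are mainly useful for sharing one card among workloads in the same program or for testing multi-GPU code on a single GPU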

**Example Scenarios:**
```python
# Scenario 1: Single GPU System
# Physical GPUs: 1 (RTX 3080 with 10GB)
# Logical GPUs: 1 (Full access to 10GB)

# Scenario 2: Multi-GPU Workstation
# Physical GPUs: 4 (Each with 24GB)
# Logical GPUs: 4 (Each model can use different GPU)

# Scenario 3: Memory-Limited Partitioning
# Physical GPUs: 1 (A100 with 40GB)
# Logical GPUs: 4 (Each limited to 10GB for different experiments)
```

#### **⚠️ Error Handling and Robustness**

**`try-except` Block Analysis:**
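```python
import tensorflow as tf

try:
    # Memory growth must be configured before the GPUs are first used
    for gpu in tf.config.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    # Raised when the GPUs were already initialized earlier in the session
    print(f"Error setting up GPU memory growth: {e}")
```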

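**Common RuntimeError Scenario:**
1. **Already Initialized**: The GPUs were initialized before this cell ran (for example, an earlier cell already executed a TensorFlow op on the GPU). Memory growth can only be set before first use; restarting the kernel and running this cell first resolves it.

Driver, CUDA, or device-permission problems usually do not raise here: TensorFlow logs warnings and the GPU is simply missing from `list_physical_devices('GPU')`, so the code takes the CPU branch instead.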

**Graceful Degradation Strategy:**
- **Primary**: Attempt GPU configuration with memory growth
- **Fallback**: Continue with default GPU settings if memory growth fails
- **Ultimate Fallback**: Use CPU if no GPU available

#### **🔍 Performance Implications for Transformer Models**

**Why GPU Configuration is Critical for Transformers:**

1. **Attention Computation Complexity**: O(n²) memory usage for sequence length n
2. **Large Parameter Count**: Modern transformers have millions to billions of parameters
3. **Batch Processing**: Training requires processing multiple sequences simultaneously
4. **KV Cache Storage**: Our implementation stores attention keys/values for fast generation

**Memory Usage Patterns:**
```python
# Training Phase:
# - Model weights: ~100MB to 10GB depending on size
# - Gradients: Same size as model weights
# - Activations: Varies with batch size and sequence length
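# - Optimizer states: ~2x model weight size with Adam (first and second moment per weight)

# Generation Phase:
# - Model weights: Same as training (no gradients or optimizer states needed)
# - KV Cache: Grows linearly with sequence length
# - Attention scores: one row per new token with the KV cache (O(sequence_length));
#   O(sequence_length²) when the whole prefix is reprocessed without it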
```
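The KV cache term can be estimated directly. In `MultiHeadSelfAttention.call()`, each layer caches one key tensor and one value tensor of shape `[batch_size, num_heads, total_seq_len, projection_dim]`. The sketch below multiplies those out; the configuration (4 layers, 8 heads of 32 dimensions, float32) is hypothetical, not the notebook's actual hyperparameters, and `kv_cache_bytes` is an illustrative helper.

```python
def kv_cache_bytes(num_layers, batch_size, num_heads, seq_len, head_dim, bytes_per_value=4):
    """Memory held by a KV cache: one key and one value tensor per layer.

    Each tensor has shape [batch_size, num_heads, seq_len, head_dim].
    """
    per_tensor = batch_size * num_heads * seq_len * head_dim * bytes_per_value
    return 2 * num_layers * per_tensor   # factor 2: keys + values

# Hypothetical config: 4 layers, 8 heads x 32 dims (embed_dim=256), float32
for seq_len in (128, 1024):
    mib = kv_cache_bytes(num_layers=4, batch_size=1, num_heads=8,
                         seq_len=seq_len, head_dim=32) / 2**20
    print(f"seq_len={seq_len}: {mib:.1f} MiB")   # 1.0 MiB, then 8.0 MiB
```

The estimate is linear in both `seq_len` and `batch_size`, which is why long generations with large batches are where cache memory becomes the limiting factor.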

#### **🎛️ Advanced Configuration Options**

**Additional GPU optimizations you could implement:**

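```python
import tensorflow as tf

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    # Hard memory cap (alternative to memory growth; the two cannot be combined
    # on the same device, and both must be set before the GPU is initialized)
    tf.config.set_logical_device_configuration(
        gpus[0], [tf.config.LogicalDeviceConfiguration(memory_limit=4096)])  # in MB, i.e. 4GB

# Mixed precision (float16 compute, float32 variables): faster on GPUs with Tensor Cores, less memory
tf.keras.mixed_precision.set_global_policy('mixed_float16')

# XLA JIT compilation (can speed up execution)
tf.config.optimizer.set_jit(True)
```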

#### **🏗️ Production Considerations**

**For Production Deployment:**
1. **Resource Monitoring**: Track GPU memory usage and temperature
2. **Batch Size Optimization**: Balance throughput vs memory usage
3. **Model Sharding**: Split large models across multiple GPUs
4. **Inference Optimization**: Use TensorRT or similar optimizations

**Container Deployment:**
```dockerfile
# Docker considerations for GPU access
# Requires the NVIDIA Container Toolkit on the host (e.g. `docker run --gpus all ...`)
# Must expose GPU devices to container
# Memory growth becomes even more critical in containerized environments
```

This GPU configuration section ensures our transformer implementation can:
- **Scale efficiently** across different hardware configurations
- **Share resources** in multi-user or multi-model environments
- **Degrade gracefully** when optimal hardware isn't available
- **Optimize memory usage** for both training and inference workloads

The memory growth strategy is particularly crucial for transformer models due to their large memory footprint and the variable memory requirements during different phases of training and generation.

## Multi-Head Self-Attention with KV Cache

The core of the Transformer architecture with advanced KV caching for efficient text generation.

In [None]:
class MultiHeadSelfAttention(Layer):
    """
    MULTI-HEAD SELF-ATTENTION MECHANISM - THE HEART OF TRANSFORMERS
    
    This class implements the core attention mechanism that allows tokens to
    communicate with each other. It includes advanced KV cache optimization
    for efficient autoregressive text generation.
    
    Key Features:
    - Multiple parallel attention heads for different relationship types
    - KV cache support for fast text generation
    - Causal masking for autoregressive modeling
    - Scaled dot-product attention with proper normalization
    """
    
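    def __init__(self, embed_dim, num_heads=8):
        super(MultiHeadSelfAttention, self).__init__()
        
        # Each head gets an equal slice of the embedding dimension
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.projection_dim = embed_dim // num_heads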
        
        # Linear transformation layers for Q, K, V
        self.query_dense = Dense(embed_dim)
        self.key_dense = Dense(embed_dim)
        self.value_dense = Dense(embed_dim)
        
        # Final layer to combine outputs from all attention heads
        self.combine_heads = Dense(embed_dim)

    def attention(self, query, key, value, mask=None):
        """
        Core attention computation with causal masking support
        """
        # Compute attention scores
        score = tf.matmul(query, key, transpose_b=True)
        
        # Scale by sqrt(d_k) so large dot products don't saturate the softmax (which shrinks gradients)
        dim_key = tf.cast(tf.shape(key)[-1], tf.float32)
        scaled_score = score / tf.math.sqrt(dim_key)
        
        # Apply causal mask if provided
        if mask is not None:
            scaled_score += (mask * -1e9)
        
        # Apply softmax to get attention weights
        weights = tf.nn.softmax(scaled_score, axis=-1)
        
        # Apply attention weights to values
        output = tf.matmul(weights, value)
        
        return output, weights

    def split_heads(self, x, batch_size):
        """
        Split embedding dimension into multiple heads for parallel processing
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.projection_dim))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, inputs, kv_cache=None, use_cache=False, training=False):
        """
        Forward pass with KV cache optimization for efficient text generation
        """
        batch_size = tf.shape(inputs)[0]
        seq_len = tf.shape(inputs)[1]
        
        # Compute Q, K, V transformations
        query = self.query_dense(inputs)
        key = self.key_dense(inputs)
        value = self.value_dense(inputs)
        
        # Split into multiple heads
        query = self.split_heads(query, batch_size)
        key = self.split_heads(key, batch_size)
        value = self.split_heads(value, batch_size)
        
        # KV Cache optimization
        if use_cache and kv_cache is not None:
            cached_key = kv_cache.get('key')
            cached_value = kv_cache.get('value')
            
            if cached_key is not None and cached_value is not None:
                key = tf.concat([cached_key, key], axis=2)
                value = tf.concat([cached_value, value], axis=2)
        
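        # Create causal mask. With a KV cache, the seq_len new queries sit at the
        # END of the total_seq_len keys, so the band's diagonal is offset by the
        # number of cached positions (past_len is 0 when no cache is used).
        total_seq_len = tf.shape(key)[2]
        past_len = total_seq_len - seq_len
        mask = tf.linalg.band_part(tf.ones((seq_len, total_seq_len)), -1, past_len)
        mask = tf.where(mask == 0, 1.0, 0.0)  # 1.0 marks future positions to block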
        
        # Apply attention mechanism
        attention_output, attention_weights = self.attention(query, key, value, mask)
        
        # Combine attention heads
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(attention_output, (batch_size, -1, self.embed_dim))
        output = self.combine_heads(concat_attention)
        
        # Update cache for next iteration
        new_cache = {'key': key, 'value': value} if use_cache else None
        
        return output, new_cache

### 🧠 **Deep Dive: Multi-Head Self-Attention - The Revolutionary Mechanism**

The Multi-Head Self-Attention mechanism is the central innovation that makes Transformers so effective. Let's take each component in turn, along with the mathematical and computational principles behind it.

#### **🎯 The Attention Revolution: Why It Changed Everything**

**Before Attention (RNNs/LSTMs):**
- Sequential processing: word₁ → word₂ → word₃ → ... → wordₙ
- Information bottleneck: distant words lose context
- No parallelization: must process sequentially
- Vanishing gradients: long-range dependencies are difficult to learn

**With Self-Attention (Transformers):**
- Parallel processing: all words attend to all other words simultaneously
- Direct connections: word₁ can directly influence wordₙ
- Dynamic relationships: attention weights adapt based on context
- Scalable: O(n²) complexity but highly parallelizable

#### **🔬 Mathematical Foundation: The Attention Equation**

**The Core Attention Formula:**
```
Attention(Q, K, V) = softmax(QK^T / √d_k)V
```

Let's break this down step by step:

**Step 1: Query-Key Similarity**
```python
scores = tf.matmul(query, key, transpose_b=True)  # QK^T
# Shape: [batch_size, num_heads, seq_len_q, seq_len_k]
```
- **Purpose**: Compute similarity between each query and every key
- **Intuition**: "How much should each word pay attention to every other word?"
- **Mathematical Meaning**: Dot product measures vector similarity/alignment

**Step 2: Scaled Attention (Critical for Stability)**
```python
scaled_scores = scores / tf.math.sqrt(dim_key)  # QK^T / √d_k
```
- **Why Scaling is Essential**: 
  - Without scaling: dot products grow with dimension size
  - Large values → softmax saturation → vanishing gradients
  - √d_k scaling keeps variance approximately constant
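- **Mathematical Insight**: If the components of a query q and a key k are independent with zero mean and unit variance, their dot product q·k has variance d_k; dividing by √d_k brings the variance back to 1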
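The variance argument can be checked numerically. This NumPy sketch draws random unit-variance queries and keys with d_k = 64 (an arbitrary choice) and compares the variance of the raw and scaled dot products:

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, n = 64, 20_000
# Rows of q and k: independent components with zero mean and unit variance
q = rng.standard_normal((n, d_k))
k = rng.standard_normal((n, d_k))

raw = np.sum(q * k, axis=1)          # n independent dot products q . k
scaled = raw / np.sqrt(d_k)          # the 1/sqrt(d_k) scaling from the attention formula

print(f"var(q.k)          ~ {raw.var():.1f}  (theory: {d_k})")
print(f"var(q.k/sqrt(dk)) ~ {scaled.var():.2f} (theory: 1)")
```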

**Step 3: Attention Weights via Softmax**
```python
attention_weights = tf.nn.softmax(scaled_scores, axis=-1)
```
- **Purpose**: Convert raw scores to probability distribution
- **Properties**: 
  - All weights sum to 1.0 for each query
  - Differentiable (enables gradient-based learning)
  - Emphasizes highest-scoring keys while maintaining some attention to others

**Step 4: Weighted Value Aggregation**
```python
output = tf.matmul(attention_weights, value)  # softmax(...)V
```
- **Purpose**: Combine value vectors based on attention weights
- **Result**: Each output position is a weighted combination of the value vectors it may attend to (under the causal mask, its own position and all earlier ones)
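Putting the four steps together: a self-contained NumPy sketch of causal scaled dot-product attention over `(batch, heads, seq, d_k)` arrays. It mirrors the notebook's convention of a mask with 1.0 at blocked positions multiplied by -1e9, but it is an illustration rather than the notebook's TensorFlow code; `causal_attention` is a hypothetical name.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)    # subtract the row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def causal_attention(q, k, v):
    """softmax(QK^T / sqrt(d_k) + mask) V for arrays shaped (..., seq, d_k)."""
    d_k = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_k)   # Steps 1-2: (..., seq_q, seq_k)
    seq_q, seq_k = scores.shape[-2:]
    # 1.0 marks future keys (same convention as the notebook's mask); the offset
    # handles queries that sit at the end of a longer key sequence
    future = np.triu(np.ones((seq_q, seq_k)), k=seq_k - seq_q + 1)
    weights = softmax(scores + future * -1e9, axis=-1)    # Step 3
    return weights @ v, weights                           # Step 4

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((2, 4, 5, 8)) for _ in range(3))  # (batch, heads, seq, d_k)
out, w = causal_attention(q, k, v)
print(out.shape, w.shape)   # (2, 4, 5, 8) (2, 4, 5, 5)
```

Two properties are easy to check: every row of `w` sums to 1, and the first position can only attend to itself, so its output equals its own value vector.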

#### **🎭 Multi-Head Architecture: Parallel Attention Specialists**

**Why Multiple Heads?**
Single attention head limitations:
- Can only capture one type of relationship at a time
- May focus on dominant patterns, missing subtle interactions
- Limited representational capacity

**Multi-Head Solution:**
```python
# Instead of one 512-dimensional attention:
# Use 8 heads × 64 dimensions each = 512 total
self.num_heads = 8
self.projection_dim = embed_dim // num_heads  # 512 // 8 = 64
```

**What Individual Heads Tend to Learn:**
Heads are not assigned roles in advance; each learns whatever projections help reduce the training loss. Analyses of trained models have found heads whose patterns resemble:
- **Syntactic relationships** (subject-verb agreement)
- **Semantic similarity** (related concepts)
- **Positional relationships** (adjacent words)
- **Long-range dependencies** (pronouns to antecedents)
- **Other learned patterns** specific to the data

#### **🔄 Head Splitting and Recombination Process**

**`split_heads()` Function Deep Analysis:**
```python
def split_heads(self, x, batch_size):
    # Input shape: [batch_size, seq_len, embed_dim]
    x = tf.reshape(x, (batch_size, -1, self.num_heads, self.projection_dim))
    # New shape: [batch_size, seq_len, num_heads, projection_dim]
    
    return tf.transpose(x, perm=[0, 2, 1, 3])
    # Final shape: [batch_size, num_heads, seq_len, projection_dim]
```

**Transformation Visualization:**
```
Original: [Batch, Sequence, 512]

Reshape: [Batch, Sequence, 8_heads, 64_per_head]

Transpose: [Batch, 8_heads, Sequence, 64_per_head]
```

**Why This Transformation?**
- Enables parallel processing of all attention heads
- Each head operates on its own 64-dimensional subspace
- Maintains batch and sequence dimensions for efficient computation
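The split is a pure reshape plus transpose, so it loses no information. A NumPy sketch with an assumed `embed_dim` of 512 and 8 heads, including a hypothetical `combine_heads` inverse that matches the transpose-and-reshape in `call()`:

```python
import numpy as np

def split_heads(x, num_heads):
    """(batch, seq, embed_dim) -> (batch, num_heads, seq, embed_dim // num_heads)."""
    batch, seq, embed_dim = x.shape
    x = x.reshape(batch, seq, num_heads, embed_dim // num_heads)
    return x.transpose(0, 2, 1, 3)

def combine_heads(x):
    """Inverse of split_heads: (batch, heads, seq, depth) -> (batch, seq, heads * depth)."""
    batch, heads, seq, depth = x.shape
    return x.transpose(0, 2, 1, 3).reshape(batch, seq, heads * depth)

x = np.arange(2 * 3 * 512, dtype=np.float32).reshape(2, 3, 512)
h = split_heads(x, num_heads=8)
print(h.shape)  # (2, 8, 3, 64)
assert np.array_equal(combine_heads(h), x)      # lossless round trip
# Head 1 sees feature slice 64:128 of every token
assert np.array_equal(h[:, 1], x[:, :, 64:128])
```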

#### **⚡ KV Cache: The Performance Game-Changer**

**The Autoregressive Generation Problem:**
```python
# Without KV Cache (Inefficient):
# Step 1: Process "The cat"
# Step 2: Process "The cat sat" (recomputes "The cat" keys/values)
# Step 3: Process "The cat sat on" (recomputes everything again)
# Step N: Reprocesses all N tokens, so each step repeats almost all of the previous work!
```

**KV Cache Solution:**
```python
# With KV Cache (Efficient):
# Step 1: Process "The cat", cache K₁,V₁
# Step 2: Process "sat", use cached K₁,V₁ + new K₂,V₂
# Step 3: Process "on", use cached K₁,V₁,K₂,V₂ + new K₃,V₃
# Step N: Computes K/V only for the new token; per-step attention cost grows linearly with the cache!
```

**Implementation Details:**
```python
if use_cache and kv_cache is not None:
    cached_key = kv_cache.get('key')      # Previously computed keys
    cached_value = kv_cache.get('value')  # Previously computed values
    
    if cached_key is not None and cached_value is not None:
        # Concatenate old and new keys/values
        key = tf.concat([cached_key, key], axis=2)    # Along sequence dimension
        value = tf.concat([cached_value, value], axis=2)
```

**Performance Impact:**
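- **Without Cache**: Each new token reruns the model over the whole prefix: the projections and feed-forward layers process all n tokens (O(n)) and attention costs O(n²) per step
- **With Cache**: Only the new token is projected, and its single query attends over the n cached keys (O(n)) per step
- **Speed Improvement**: The gap widens as sequences get longer; the actual speedup depends on model size, sequence length, and hardware, so it is worth measuring rather than assuming a fixed factor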
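The claim that caching changes cost but not results can be checked with a small single-head NumPy sketch (no batch dimension, no output projection, random weights; all simplifications for illustration). It compares attention computed in one pass over all tokens with attention computed one token at a time from a growing K/V cache:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attend(q, k, v):
    """Causal attention for queries that are the LAST q.shape[0] positions of k/v."""
    seq_q, seq_k = q.shape[0], k.shape[0]
    scores = q @ k.T / np.sqrt(q.shape[-1])
    # query i is at absolute position (seq_k - seq_q) + i; block keys after it
    future = np.triu(np.ones((seq_q, seq_k)), k=seq_k - seq_q + 1)
    return softmax(scores - 1e9 * future) @ v

rng = np.random.default_rng(0)
d, T = 16, 6                                    # hidden size, number of tokens
x = rng.standard_normal((T, d))                 # stand-in token representations
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))

# Without a cache: project and attend over the whole sequence
full = attend(x @ Wq, x @ Wk, x @ Wv)

# With a cache: one token per step; only its K/V are computed and appended
k_cache = np.empty((0, d))
v_cache = np.empty((0, d))
steps = []
for t in range(T):
    xt = x[t:t + 1]                                   # shape (1, d): the new token
    k_cache = np.concatenate([k_cache, xt @ Wk])      # analogous to tf.concat(..., axis=2)
    v_cache = np.concatenate([v_cache, xt @ Wv])
    steps.append(attend(xt @ Wq, k_cache, v_cache))
cached = np.concatenate(steps)

print(np.allclose(full, cached))   # True
```

Because the cached loop reproduces every row of the full pass, the cache is purely a compute optimization; the outputs do not change.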

#### **🎭 Causal Masking: Preventing Future Information Leakage**

**The Causality Requirement:**
In autoregressive generation, token at position i should only attend to positions ≤ i

**Mask Creation:**
```python
# Create lower triangular matrix
mask = tf.linalg.band_part(tf.ones((seq_len, total_seq_len)), -1, 0)
# Flip it: 1.0 marks blocked (future) positions, 0.0 allowed ones; attention() later multiplies by -1e9
mask = tf.where(mask == 0, 1.0, 0.0)
```

**Mask Visualization:**
```
Sequence: ["The", "cat", "sat", "on"]

Attention Scores (illustrative values, before masking):
     The  cat  sat  on
The  [1.0  0.8  0.2  0.1]
cat  [0.3  1.0  0.9  0.4]
sat  [0.1  0.6  1.0  0.8]
on   [0.2  0.3  0.7  1.0]

Causal Mask:
     The  cat  sat  on
The  [0    -∞   -∞   -∞ ]
cat  [0    0    -∞   -∞ ]
sat  [0    0    0    -∞ ]
on   [0    0    0    0  ]

After Masking and Softmax (illustrative; each row sums to 1, masked positions are 0):
     The  cat  sat  on
The  [1.0  0.0  0.0  0.0]
cat  [0.4  0.6  0.0  0.0]
sat  [0.1  0.3  0.6  0.0]
on   [0.1  0.2  0.3  0.4]
```
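One subtlety when a KV cache is in use: the queries cover only the `seq_len` new tokens, while the keys cover all `total_seq_len` positions. `tf.linalg.band_part(..., -1, 0)` anchors the diagonal at column 0, so a single new query would be allowed to see only the first cached key. Shifting `num_upper` by the number of cached positions (`total_seq_len - seq_len`) anchors it correctly. The NumPy helper `band_ones` below is a hypothetical stand-in that mimics `band_part` applied to a matrix of ones:

```python
import numpy as np

def band_ones(rows, cols, num_lower, num_upper):
    """NumPy stand-in for tf.linalg.band_part(tf.ones((rows, cols)), num_lower, num_upper)."""
    i = np.arange(rows)[:, None]
    j = np.arange(cols)[None, :]
    keep = ((num_lower < 0) | (i - j <= num_lower)) & ((num_upper < 0) | (j - i <= num_upper))
    return keep.astype(np.float32)

# Cached generation step: 1 new query, 3 cached keys + 1 new key
seq_len, total_seq_len = 1, 4
naive = band_ones(seq_len, total_seq_len, -1, 0)        # num_upper = 0
past = total_seq_len - seq_len
offset = band_ones(seq_len, total_seq_len, -1, past)    # num_upper = past length
print(naive)   # only column 0 kept: the new token would see just the first cached key
print(offset)  # all columns kept: the new token sees every cached key and itself
```

In TensorFlow the offset band would be `tf.linalg.band_part(tf.ones((seq_len, total_seq_len)), -1, total_seq_len - seq_len)`; without a cache the offset is 0 and the mask is the ordinary lower triangle.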

#### **🔧 Layer Architecture and Weight Initialization**

**Query, Key, Value Projections** (plain `Dense` layers with Keras's default initialization: Glorot-uniform kernels and zero biases):
```python
self.query_dense = Dense(embed_dim)  # X → Q transformation
self.key_dense = Dense(embed_dim)    # X → K transformation  
self.value_dense = Dense(embed_dim)  # X → V transformation
```

**Why Separate Projections?**
- **Learned Specialization**: Each projection learns different aspects
- **Query**: "What am I looking for?"
- **Key**: "What do I represent for others to find?"
- **Value**: "What information do I contribute?"

**Output Combination:**
```python
self.combine_heads = Dense(embed_dim)
```
- **Purpose**: Integrate information from all attention heads
- **Learnable Mixing**: Model learns optimal way to combine head outputs
- **Dimensionality**: Maps concatenated heads back to original embedding size
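The sketch below puts the projections, head splitting, causal masking, and head combination together into one NumPy function for causal multi-head self-attention. Plain weight matrices stand in for the Keras `Dense` layers, and biases are omitted. The function name `multi_head_self_attention` is hypothetical. The sketch mirrors the structure described above rather than reproducing the notebook's exact code.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(x, w_q, w_k, w_v, w_o, num_heads):
    """x: (seq_len, embed_dim); each w_*: (embed_dim, embed_dim)."""
    seq_len, embed_dim = x.shape
    assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
    head_dim = embed_dim // num_heads

    def split_heads(t):
        # (seq_len, embed_dim) -> (num_heads, seq_len, head_dim)
        return t.reshape(seq_len, num_heads, head_dim).transpose(1, 0, 2)

    q, k, v = split_heads(x @ w_q), split_heads(x @ w_k), split_heads(x @ w_v)

    # Scaled dot-product scores per head: (num_heads, seq_len, seq_len)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -1e9, scores)   # causal mask, broadcast over heads
    context = softmax(scores) @ v             # (num_heads, seq_len, head_dim)

    # Concatenate heads back to (seq_len, embed_dim), then mix them with w_o
    concat = context.transpose(1, 0, 2).reshape(seq_len, embed_dim)
    return concat @ w_o

rng = np.random.default_rng(1)
embed_dim, num_heads, seq_len = 16, 4, 5
weights = [rng.normal(size=(embed_dim, embed_dim)) * 0.1 for _ in range(4)]
x = rng.normal(size=(seq_len, embed_dim))
out = multi_head_self_attention(x, *weights, num_heads=num_heads)
print(out.shape)  # (5, 16)
```

Each head works in a `head_dim = embed_dim // num_heads` subspace, so total cost is similar to one full-width head. The final `w_o` plays the role of `combine_heads`.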

#### **🎪 Attention Patterns in Practice**

**Attention Patterns Commonly Reported in Trained Models:**
1. **Local Attention**: Focus on nearby words (n-gram patterns)
2. **Syntactic Attention**: Connect grammatically related words
3. **Semantic Attention**: Link semantically similar concepts
4. **Positional Attention**: Attend to specific relative positions
5. **Global Attention**: Some heads attend broadly across the sequence

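**Example Attention Visualization (hypothetical; in this causal model each token attends only to itself and earlier tokens):**
```
Input: "The quick brown fox jumps over the lazy dog"

Head 1 (Syntactic): "jumps" → "fox" (verb attends to its subject)
Head 2 (Semantic): "fox" → "quick", "brown"; "dog" → "lazy" (noun attends to its modifiers)
Head 3 (Positional): Each word → previous word
Head 4 (Determiners): "fox" → "The", "dog" → "the" (noun attends to its determiner)
```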

This Multi-Head Self-Attention mechanism is what enables Transformers to:
- **Capture Complex Relationships**: Multiple types of dependencies simultaneously
- **Process in Parallel**: Unlike sequential RNNs
- **Scale Efficiently**: With KV cache optimization for generation
- **Learn Interpretable Patterns**: Different heads often specialize in different linguistic phenomena

The combination of mathematical elegance, computational efficiency, and representational power makes this attention mechanism the foundation of modern NLP breakthroughs.

## Transformer Block

A complete transformer layer combining attention, feed-forward networks, and residual connections.

In [None]:
class TransformerBlock(Layer):
    """
    TRANSFORMER BLOCK - A complete transformer layer
    
    Combines:
    1. Multi-Head Self-Attention
    2. Feed-Forward Network
    3. Residual Connections
    4. Layer Normalization
    """
    
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super(TransformerBlock, self).__init__()
        
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate
        
        # Multi-head self-attention
        self.att = MultiHeadSelfAttention(embed_dim, num_heads)
        
        # Feed-forward network
        self.ffn = tf.keras.Sequential([
            Dense(ff_dim, activation="relu"),
            Dense(embed_dim),
        ])
        
        # Layer normalization
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        
        # Dropout for regularization
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, kv_cache=None, use_cache=False, training=False):
        """
        Forward pass with residual connections and layer normalization
        """
        # Multi-head self-attention with residual connection
        attn_output, new_cache = self.att(
            inputs, 
            kv_cache=kv_cache, 
            use_cache=use_cache, 
            training=training
        )
        
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        
        # Feed-forward network with residual connection
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        output = self.layernorm2(out1 + ffn_output)
        
        return output, new_cache

### 🏗️ **Deep Dive: Transformer Block - The Complete Processing Unit**

The Transformer Block is where the magic happens - it's a sophisticated processing unit that combines multiple architectural innovations to create a powerful, trainable, and stable deep learning component. Let's dissect every aspect of this engineering marvel.

#### **🎯 The Transformer Block Architecture: A Symphony of Components**

**The Complete Processing Pipeline:**
```
Input → Multi-Head Attention → Dropout → Residual Add → Layer Norm
  ↓
Feed-Forward Network → Dropout → Residual Add → Layer Norm → Output
```

This architecture represents years of research refinements, with each component serving a critical purpose in the overall system.

#### **🔗 Residual Connections: The Highway to Deep Learning**

**The Vanishing Gradient Problem:**
In deep networks without residual connections:
- Gradients can shrink exponentially as they backpropagate through many layers
- Early layers (those closest to the input) receive almost no learning signal
- Training gets harder, not easier, as depth grows
- Adding layers can even increase training error (the "degradation" problem)

**Residual Connection Solution:**
```python
# Instead of: output = F(input)
# We use: output = F(input) + input

attn_output = self.dropout1(attn_output, training=training)
out1 = self.layernorm1(inputs + attn_output)  # Residual connection!

ffn_output = self.dropout2(ffn_output, training=training)
output = self.layernorm2(out1 + ffn_output)   # Another residual connection!
```

**Mathematical Foundation:**
```
Traditional: y = F(x)
Residual: y = F(x) + x

Gradient flow:
∂y/∂x = ∂F(x)/∂x + 1
```
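A toy scalar calculation makes the "+1" term concrete. Assume each layer's own derivative ∂F/∂x is a small constant; 0.1 here is an arbitrary illustrative value. By the chain rule, the end-to-end gradient is the product of the per-layer derivatives.

```python
def end_to_end_gradient(num_layers, layer_grad, residual):
    """Chain-rule product of per-layer derivatives for a toy stack of scalar layers."""
    grad = 1.0
    for _ in range(num_layers):
        # Plain layer: dy/dx = F'(x). Residual layer: dy/dx = F'(x) + 1
        local_grad = layer_grad + 1.0 if residual else layer_grad
        grad *= local_grad
    return grad

plain = end_to_end_gradient(20, 0.1, residual=False)
res = end_to_end_gradient(20, 0.1, residual=True)
print(f"plain: {plain:.1e}, residual: {res:.2f}")  # plain: 1.0e-20, residual: 6.73
```

Real layers are matrices, so the residual Jacobian is I + ∂F/∂x rather than 1 + F'(x). The identity term still keeps the product from collapsing toward zero.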

**Why This Works:**
- **Gradient Highway**: The "+1" term ensures gradients always have a direct path
- **Identity Mapping**: If F(x) = 0, the layer becomes an identity function
- **Easier Optimization**: Model can learn when to use vs skip transformations
- **Stable Training**: Greatly reduces vanishing gradients, even in very deep networks

**Practical Benefits:**
- Enables training of 100+ layer networks
- Faster convergence during training
- Better gradient flow to early layers
- More stable optimization dynamics

#### **🧪 Layer Normalization: The Stability Engine**

**The Normalization Formula:**
```
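LayerNorm(x) = γ ⊙ (x - μ) / √(σ² + ε) + β

Where:
μ = mean(x)      # Mean across features (the embedding dimension)
σ² = var(x)      # Variance across features
ε = epsilon      # Small constant for numerical stability (1e-6 here)
γ, β = learnable per-feature scale and shift parameters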
```
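The NumPy version below is an illustrative sketch of the formula. γ and β are initialized to ones and zeros, matching the Keras defaults. The sketch shows that each token vector is normalized independently of the rest of the batch.

```python
import numpy as np

def layer_norm(x, gamma, beta, epsilon=1e-6):
    """Normalize over the last (feature) axis, then scale by gamma and shift by beta."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + epsilon) + beta

embed_dim = 8
gamma, beta = np.ones(embed_dim), np.zeros(embed_dim)

rng = np.random.default_rng(0)
batch = rng.normal(loc=3.0, scale=5.0, size=(2, 4, embed_dim))  # (batch, seq_len, embed_dim)
out = layer_norm(batch, gamma, beta)

print(np.allclose(out.mean(axis=-1), 0.0), np.allclose(out.std(axis=-1), 1.0, atol=1e-4))  # True True
# A single example gives the same result whether or not it is part of a batch
print(np.allclose(layer_norm(batch[:1], gamma, beta), out[:1]))  # True
```

The last check illustrates the "batch independence" row of the comparison table. Batch norm would give different statistics for a batch of 1.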

**Implementation Details:**
```python
self.layernorm1 = LayerNormalization(epsilon=1e-6)
self.layernorm2 = LayerNormalization(epsilon=1e-6)
```

**Why epsilon=1e-6?**
- Prevents division by zero when variance is very small
- Ensures numerical stability in edge cases
- Small enough to not affect normal operation

**Layer Norm vs Batch Norm:**

| Aspect | Batch Norm | Layer Norm |
|--------|------------|------------|
| **Normalization Axis** | Across batch dimension | Across feature dimension |
| **Sequence Handling** | Poor (variable lengths) | Excellent |
| **Training/Inference** | Different behavior | Consistent behavior |
| **Batch Size Dependency** | Yes (needs large batches) | No (works with batch=1) |
| **NLP Suitability** | Poor | Excellent |

**Why Layer Norm is Perfect for Transformers:**
1. **Variable Sequence Lengths**: Each sequence normalized independently
2. **Consistent Behavior**: Same computation during training and inference
3. **Batch Independence**: Works with any batch size, including 1
4. **Feature Stability**: Normalizes across embedding dimensions

#### **🏗️ Pre-Norm vs Post-Norm Architecture**

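**Our Implementation (Post-Norm):**
```python
# Post-normalization: Function → Dropout → Residual → Norm
attn_output, new_cache = self.att(inputs, ...)
attn_output = self.dropout1(attn_output, training=training)
out1 = self.layernorm1(inputs + attn_output)
```

**Alternative (Pre-Norm):**
```python
# Pre-normalization: Norm → Function → Dropout → Residual
normalized_input = self.layernorm1(inputs)
attn_output, new_cache = self.att(normalized_input, ...)
out1 = inputs + self.dropout1(attn_output, training=training)
```

**Trade-offs:**
- **Post-Norm (used here)**: The ordering from the original Transformer paper. Every residual sum is re-normalized, which works well for shallow stacks like the 4-layer model trained here, but deep post-norm stacks typically need learning-rate warmup to train stably
- **Pre-Norm**: Used by GPT-2 and many later decoder-only models. The residual path from input to output is never normalized, which gives a cleaner gradient path and more stable training for deep models
- **Final Normalization**: Pre-norm models usually add one extra LayerNorm after the last block, because the residual stream itself is never normalized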
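The two orderings differ only in where normalization sits relative to the residual addition. The NumPy sketch below writes both as wrappers. The `toy_sublayer` function is a hypothetical stand-in for attention or the feed-forward network, and γ = 1, β = 0 for simplicity. For reference, the `TransformerBlock` cell above computes `self.layernorm1(inputs + attn_output)`, which is the post-norm pattern.

```python
import numpy as np

def layer_norm(x, epsilon=1e-6):
    # Normalize over the feature axis (gamma = 1, beta = 0 for simplicity)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)

def post_norm(x, sublayer):
    # Original Transformer / TransformerBlock above: LayerNorm(x + Sublayer(x))
    return layer_norm(x + sublayer(x))

def pre_norm(x, sublayer):
    # GPT-2 style: x + Sublayer(LayerNorm(x)); the residual path is never normalized
    return x + sublayer(layer_norm(x))

rng = np.random.default_rng(0)
w = rng.normal(size=(8, 8)) * 0.1

def toy_sublayer(h):
    # Hypothetical stand-in for attention / FFN: a ReLU projection
    return np.maximum(0.0, h @ w)

x = rng.normal(loc=5.0, size=(3, 8))
post, pre = post_norm(x, toy_sublayer), pre_norm(x, toy_sublayer)
# Post-norm outputs are re-centred; pre-norm outputs keep the residual stream's offset
print(post.mean(axis=-1).round(6), pre.mean(axis=-1).round(2))
```

In the pre-norm variant the input `x` reaches the output untouched, apart from the added branch. This is the "direct gradient path" usually cited in its favour.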

#### **🧠 Feed-Forward Network: The Computational Powerhouse**

**Architecture:**
```python
self.ffn = tf.keras.Sequential([
    Dense(ff_dim, activation="relu"),  # Expansion layer
    Dense(embed_dim),                  # Projection layer
])
```

**The Two-Layer MLP Design:**

**Layer 1: Expansion (embed_dim → ff_dim)**
- **Purpose**: Expand representation to higher-dimensional space
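- **Typical Ratio**: ff_dim = 4 × embed_dim in the original Transformer (512 → 2048); this notebook's training configuration uses a 2× expansion (embed_dim = 256, ff_dim = 512)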
- **Activation**: ReLU for non-linearity and computational efficiency
- **Mathematical Operation**: ReLU(xW₁ + b₁)

**Layer 2: Projection (ff_dim → embed_dim)**
- **Purpose**: Project back to original embedding dimension
- **No Activation**: A linear output, so the result can take any real value (including negatives) before it is added back to the residual stream
- **Mathematical Operation**: (ReLU_output)W₂ + b₂

**Why This Architecture?**

**The Expansion Principle (shown with the common 4× ratio):**
```
Input: [batch, seq_len, 256]     # Original embedding space
  ↓
Expand: [batch, seq_len, 1024]   # 4x larger intermediate space
  ↓
Project: [batch, seq_len, 256]   # Back to embedding space
```
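The NumPy sketch below performs the same expand-then-project computation with random weights. `w1`, `b1`, `w2`, `b2` correspond to the two `Dense` layers. Because every position uses the same weights, the network acts on each token independently.

```python
import numpy as np

def feed_forward(x, w1, b1, w2, b2):
    """Position-wise FFN: ReLU(x W1 + b1) W2 + b2, applied to every token separately."""
    hidden = np.maximum(0.0, x @ w1 + b1)   # (batch, seq_len, ff_dim)
    return hidden @ w2 + b2                 # (batch, seq_len, embed_dim)

embed_dim, ff_dim = 256, 1024               # 4x expansion as in the diagram
rng = np.random.default_rng(0)
w1, b1 = rng.normal(scale=0.02, size=(embed_dim, ff_dim)), np.zeros(ff_dim)
w2, b2 = rng.normal(scale=0.02, size=(ff_dim, embed_dim)), np.zeros(embed_dim)

x = rng.normal(size=(2, 10, embed_dim))     # (batch, seq_len, embed_dim)
y = feed_forward(x, w1, b1, w2, b2)
print(y.shape)  # (2, 10, 256)

# Running a single token on its own gives the same result: no mixing across positions
print(np.allclose(feed_forward(x[:, 3:4], w1, b1, w2, b2), y[:, 3:4]))  # True
```

With this notebook's `ff_dim = 512`, each block's FFN has 256·512 + 512 + 512·256 + 256 = 262,912 parameters.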

**Benefits of Expansion:**
1. **Increased Capacity**: More parameters for complex transformations
2. **Non-Linear Mixing**: ReLU enables complex feature interactions
3. **Representation Learning**: Learns rich intermediate representations
4. **Computational Efficiency**: Balanced between capacity and speed

**ReLU Activation Choice:**
```python
def relu(x):
    return max(0.0, x)  # ReLU(x) = max(0, x); tf.nn.relu applies this elementwise to tensors
```

**Why ReLU over other activations?**
- **Computational Efficiency**: Simple max(0, x) operation
- **Gradient Properties**: No vanishing gradient for positive values
- **Sparsity**: Creates sparse representations (many zeros)
- **Empirical Success**: Proven effective in transformer architectures

#### **🎪 Dropout: The Regularization Maestro**

**Dropout Implementation:**
```python
self.dropout1 = Dropout(rate)  # After attention
self.dropout2 = Dropout(rate)  # After feed-forward
```

**Dropout Mechanism:**
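- **Training**: Randomly set activations to 0 with probability `rate` and scale the survivors by 1 / (1 - rate) (inverted dropout, as implemented by Keras)
- **Inference**: Dropout does nothing; the training-time rescaling already keeps expected values consistent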
- **Purpose**: Prevent overfitting and improve generalization
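Keras implements *inverted* dropout. Surviving activations are scaled up by 1 / (1 − rate) during training, so nothing needs to change at inference. The NumPy sketch below reproduces that behaviour; the function name and seed are illustrative.

```python
import numpy as np

def inverted_dropout(x, rate, training, rng):
    """Zero each element with probability `rate` during training and rescale the rest."""
    if not training or rate == 0.0:
        return x                              # inference: identity, no rescaling
    keep = rng.random(x.shape) >= rate        # keep with probability 1 - rate
    return np.where(keep, x / (1.0 - rate), 0.0)

rng = np.random.default_rng(0)
x = np.ones((1000, 100))
train_out = inverted_dropout(x, rate=0.1, training=True, rng=rng)
infer_out = inverted_dropout(x, rate=0.1, training=False, rng=rng)

print(round(train_out.mean(), 2))      # close to 1.0: the expected value is preserved
print(np.array_equal(infer_out, x))    # True
```

This is why the `TransformerBlock` passes `training=training` to each `Dropout` call. With `training=False`, the layers become identity functions.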

**Strategic Placement:**
```python
# After attention computation
attn_output = self.dropout1(attn_output, training=training)

# After feed-forward computation
ffn_output = self.dropout2(ffn_output, training=training)
```

**Why These Specific Locations?**
1. **On Each Sub-Layer Output**: Applied to the outputs of attention and of the feed-forward network
2. **Before Residual Addition**: Only the sub-layer branch is dropped; the skip connection carries `inputs` (or `out1`) through unchanged
3. **Balanced Regularization**: Regularizes both attention and feed-forward paths

#### **🔄 The Complete Forward Pass: Step-by-Step Analysis**

**Step 1: Attention Sub-Layer**
```python
# 1. Multi-head self-attention on the block input (also returns the updated KV cache)
attn_output, new_cache = self.att(
    inputs,
    kv_cache=kv_cache,
    use_cache=use_cache,
    training=training
)

# 2. Apply dropout for regularization (active only when training=True)
attn_output = self.dropout1(attn_output, training=training)

# 3. Residual connection, then layer normalization (post-norm)
out1 = self.layernorm1(inputs + attn_output)
```

**Step 2: Feed-Forward Sub-Layer**
```python
# 1. Position-wise feed-forward network
ffn_output = self.ffn(out1)

# 2. Apply dropout
ffn_output = self.dropout2(ffn_output, training=training)

# 3. Final residual connection, then layer normalization
output = self.layernorm2(out1 + ffn_output)
```

#### **📊 Information Flow and Representation Evolution**

**Representation Evolution Through the Block:**
```
Input Representation:
- Token embeddings + positional encoding (for the first block; later blocks receive the previous block's output)
- Contains basic semantic and positional information

After Attention:
- Context-aware representations
- Each token knows about relevant other tokens
- Relationship information encoded

After Feed-Forward:
- Non-linearly transformed representations
- Complex feature combinations
- Task-specific patterns learned

Output Representation:
- Rich, context-aware, non-linearly processed features
- Ready for next transformer block or final prediction
```

#### **🎯 Design Principles and Trade-offs**

**Key Design Decisions:**

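1. **Post-Norm Architecture**: The original Transformer ordering, adequate for this shallow 4-layer stack; pre-norm usually trains more stably once models get much deeper
2. **FF Expansion (2× in this config, 4× is common)**: Balance between capacity and computational cost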
3. **ReLU Activation**: Simplicity and efficiency vs potentially richer activations
4. **Dropout Placement**: Effective regularization vs potential information loss

**Computational Complexity:**
```
Attention: O(n² × d) where n = sequence length, d = embedding dimension
Feed-Forward: O(n × d × ff_dim) = O(n × d²) typically
Total per block: O(n² × d + n × d²)
```

**Scaling Considerations:**
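- **Short sequences (n small relative to d)**: The O(n × d²) terms (projections and feed-forward) dominate; this notebook's configuration (n = 100, d = 256) is in this regime
- **Long sequences (n large relative to d)**: The O(n² × d) attention term dominates
- **Large models**: Both terms grow with d, so both components scale significantly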
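A rough multiply-accumulate count for one block makes the crossover concrete. This back-of-the-envelope sketch counts only the large matrix products:

- the Q/K/V/output projections
- the score and weighted-sum products
- the two FFN layers

It ignores softmax, normalization, and biases. The inputs are this notebook's training configuration (n = 100, d = 256, ff_dim = 512).

```python
def block_macs(n, d, ff_dim):
    """Approximate multiply-accumulates for one transformer block over n tokens."""
    projections = 4 * n * d * d        # Q, K, V and output projections
    attention = 2 * n * n * d          # QK^T scores plus attention-weighted sum of V
    feed_forward = 2 * n * d * ff_dim  # expansion and projection layers
    return {"projections": projections, "attention": attention, "feed_forward": feed_forward}

costs = block_macs(n=100, d=256, ff_dim=512)
print(costs)
# {'projections': 26214400, 'attention': 5120000, 'feed_forward': 26214400}
```

Under these assumptions, the n² attention term catches up with the projection cost at n = 2d = 512. It matches projections plus FFN combined at n = 2d + ff_dim = 1024.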

This Transformer Block design represents the culmination of deep learning research, combining:
- **Attention mechanisms** for relationship modeling
- **Residual connections** for gradient flow
- **Layer normalization** for training stability
- **Feed-forward networks** for non-linear transformation
- **Dropout regularization** for generalization

The result is a powerful, trainable, and scalable building block that can be stacked to create models of arbitrary depth and complexity.

## Complete Transformer Model

The full transformer model with embedding, positional encoding, and multiple transformer blocks.

In [None]:
class TransformerModel(Model):
    """
    COMPLETE TRANSFORMER MODEL FOR TEXT GENERATION
    
    This implements a decoder-only transformer suitable for autoregressive text generation.
    Features include KV cache optimization and proper causal masking.
    """
    
    def __init__(self, vocab_size, embed_dim, num_heads, ff_dim, num_layers, seq_length):
        super(TransformerModel, self).__init__()
        
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.num_layers = num_layers
        self.seq_length = seq_length
        
        # Token embedding layer
        self.embedding = Embedding(vocab_size, embed_dim)
        
        # Positional encoding
        self.pos_encoding = self.positional_encoding(seq_length, embed_dim)
        
        # Stack of transformer blocks
        self.transformer_blocks = [
            TransformerBlock(embed_dim, num_heads, ff_dim) 
            for _ in range(num_layers)
        ]
        
        # Output projection layer
        self.dense = Dense(vocab_size)

    def positional_encoding(self, seq_length, embed_dim):
        """
        Generate sinusoidal positional encoding
        """
        angle_rads = self.get_angles(
            np.arange(seq_length)[:, np.newaxis],
            np.arange(embed_dim)[np.newaxis, :],
            embed_dim
        )
        
        # Apply sine to even indices
        angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
        
        # Apply cosine to odd indices
        angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
        
        pos_encoding = angle_rads[np.newaxis, ...]
        
        return tf.cast(pos_encoding, dtype=tf.float32)

    def get_angles(self, pos, i, embed_dim):
        """
        Calculate angles for positional encoding
        """
        angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(embed_dim))
        return pos * angle_rates

    def call(self, inputs, kv_cache=None, use_cache=False, start_pos=0, training=False):
        """
        Forward pass with KV cache support for efficient generation
        """
        seq_len = tf.shape(inputs)[1]
        
        # Select positional encodings for the positions processed in this call
        if start_pos > 0 and start_pos < self.seq_length:
            # Decode phase: positions start_pos .. start_pos + seq_len - 1
            pos_encoding = self.pos_encoding[:, start_pos:start_pos + seq_len, :]
        else:
            # Training/prefill (start_pos == 0) or decoding past the precomputed table
            if start_pos >= self.seq_length:
                # Out-of-range positions reuse the last position's encoding
                # (a simplification: the model never saw these positions during training)
                pos_encoding = self.pos_encoding[:, -1:, :]
                pos_encoding = tf.tile(pos_encoding, [1, seq_len, 1])
            else:
                pos_encoding = self.pos_encoding[:, :seq_len, :]
        
        # Convert tokens to embeddings and add positional encoding
        x = self.embedding(inputs)
        x += pos_encoding
        
        # Initialize KV cache if needed
        if use_cache and kv_cache is None:
            kv_cache = [None] * self.num_layers
        
        # Process through transformer layers
        new_caches = []
        for i, transformer_block in enumerate(self.transformer_blocks):
            layer_cache = kv_cache[i] if kv_cache else None
            
            x, new_cache = transformer_block(
                x, 
                kv_cache=layer_cache, 
                use_cache=use_cache, 
                training=training
            )
            
            new_caches.append(new_cache)
        
        # Project to vocabulary logits (unnormalized; softmax is applied by the loss or sampler)
        output = self.dense(x)
        
        if use_cache:
            return output, new_caches
        else:
            return output
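The `positional_encoding` / `get_angles` pair can be checked in isolation. Below is a standalone NumPy restatement of the same formula, PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)), followed by a few sanity checks. The function name `sinusoidal_positional_encoding` is hypothetical.

```python
import numpy as np

def sinusoidal_positional_encoding(seq_length, embed_dim):
    """Return a (seq_length, embed_dim) array of sinusoidal position encodings."""
    pos = np.arange(seq_length)[:, np.newaxis]            # (seq_length, 1)
    i = np.arange(embed_dim)[np.newaxis, :]               # (1, embed_dim)
    # Each sin/cos pair (2i, 2i+1) shares one frequency: 1 / 10000^(2i / embed_dim)
    angle_rates = 1.0 / np.power(10000, (2 * (i // 2)) / np.float32(embed_dim))
    angles = pos * angle_rates
    pe = np.zeros((seq_length, embed_dim))
    pe[:, 0::2] = np.sin(angles[:, 0::2])                 # even dimensions
    pe[:, 1::2] = np.cos(angles[:, 1::2])                 # odd dimensions
    return pe

pe = sinusoidal_positional_encoding(seq_length=100, embed_dim=256)
print(pe.shape)            # (100, 256)
print(pe[0, :4])           # position 0: [0. 1. 0. 1.]
```

The table only has `seq_length` rows. This is why `call` has to special-case `start_pos >= self.seq_length`.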

## Data Preparation and Utility Functions

Functions for creating training sequences and managing data sources.

In [None]:
def create_sequences(text, seq_length):
    """
    Create training sequences for language modeling
    
    Creates input-target pairs where target is input shifted by one position
    for next-token prediction training.
    """
    input_seqs = []
    target_seqs = []
    
    # Create overlapping sequences
    for i in range(len(text) - seq_length):
        input_seq = text[i:i + seq_length]
        target_seq = text[i + 1:i + seq_length + 1]
        
        input_seqs.append(input_seq)
        target_seqs.append(target_seq)
    
    return np.array(input_seqs), np.array(target_seqs)


def load_corpus(corpus_source):
    """
    Load text corpus from web or local file
    """
    if corpus_source == "web":
        print("Loading Shakespeare dataset from web...")
        path_to_file = get_file(
            'shakespeare.txt',
            'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt'
        )
        with open(path_to_file, 'rb') as f:
            text = f.read().decode(encoding='utf-8')
        print(f"Web dataset loaded. Text length: {len(text)} characters")
        
    elif corpus_source == "local":
        print("Loading corpus from local file 'corpus.txt'...")
        try:
            with open('corpus.txt', 'r', encoding='utf-8') as f:
                text = f.read()
            print(f"Local corpus loaded. Text length: {len(text)} characters")
        except FileNotFoundError:
            raise FileNotFoundError(
                "corpus.txt not found. Please make sure the file exists in the current directory."
            )
        except UnicodeDecodeError:
            print("Failed to decode with UTF-8, trying with latin-1...")
            with open('corpus.txt', 'r', encoding='latin-1') as f:
                text = f.read()
            print(f"Local corpus loaded with latin-1 encoding. Text length: {len(text)} characters")
    else:
        raise ValueError(
            f"Invalid corpus source: {corpus_source}. Must be 'web' or 'local'."
        )
    
    return text
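The Python loop in `create_sequences` is slow for large corpora. An equivalent vectorized version can use `numpy.lib.stride_tricks.sliding_window_view`, available since NumPy 1.20. This is a suggested alternative, not what the notebook uses, and `create_sequences_vectorized` is a hypothetical name.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def create_sequences_vectorized(tokens, seq_length):
    """Same (input, target) pairs as create_sequences, built without a Python loop."""
    tokens = np.asarray(tokens)
    # All windows of length seq_length + 1; each holds one input and its shifted target
    windows = sliding_window_view(tokens, seq_length + 1)
    # .copy() materializes writable arrays; drop it to keep memory-free read-only views
    return windows[:, :-1].copy(), windows[:, 1:].copy()

tokens = np.arange(10)
X, Y = create_sequences_vectorized(tokens, seq_length=4)
print(X.shape, Y.shape)   # (6, 4) (6, 4)
print(X[0], Y[0])         # [0 1 2 3] [1 2 3 4]
```

Like the loop version, this yields `len(tokens) - seq_length` pairs. Each target row is its input row shifted left by one token.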

## Text Generation with KV Cache Optimization

Text generation with optional KV cache support, which can substantially speed up long generations.

In [None]:
def generate_text_with_kv_cache(model, vectorizer, start_string, seq_length, 
                                num_generate=100, temperature=1.0, use_kv_cache=True):
    """
    OPTIMIZED TEXT GENERATION WITH KV CACHE SUPPORT
    
    This function demonstrates the significant performance improvement possible with
    KV caching during autoregressive generation.
    
    Key Features:
    - Prefill phase: Process initial prompt, build attention cache
    - Decode phase: Generate tokens one by one, reusing cached computations
    - Performance comparison between cached and standard generation
    
    Parameters:
    - model: Trained transformer model
    - vectorizer: Text preprocessing pipeline
    - start_string: Initial text to begin generation
    - seq_length: Model's expected sequence length
    - num_generate: Number of tokens to generate
    - temperature: Sampling temperature for creativity control
    - use_kv_cache: Whether to use KV cache optimization
    """
    
    # Preprocess input text
    input_eval = vectorizer([start_string]).numpy()
    
    # Handle sequence length mismatches
    if input_eval.shape[1] < seq_length:
        # Left-pad with zeros (TextVectorization's padding id), keeping the integer dtype
        padding = np.zeros((1, seq_length - input_eval.shape[1]), dtype=input_eval.dtype)
        input_eval = np.concatenate((padding, input_eval), axis=1)
    elif input_eval.shape[1] > seq_length:
        # Truncate to keep only the last tokens
        input_eval = input_eval[:, -seq_length:]

    input_eval = tf.convert_to_tensor(input_eval)
    text_generated = []
    vocab = vectorizer.get_vocabulary()
    
    if use_kv_cache:
        print("Using KV Cache for generation...")
        
        # PREFILL PHASE: Process prompt and build cache
        predictions, kv_cache = model(
            input_eval, 
            use_cache=True,
            start_pos=0,
            training=False
        )
        
        # Sample first token
        last_predictions = predictions[0, -1, :]
        last_predictions = last_predictions / temperature
        predicted_id = tf.random.categorical(
            tf.expand_dims(last_predictions, 0), 
            num_samples=1
        )[0, 0].numpy()
        
        if predicted_id < len(vocab):
            text_generated.append(vocab[predicted_id])
        
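        # The prompt was padded/truncated to seq_length, so decoding starts at
        # current_pos == seq_length; TransformerModel.call then reuses the last
        # positional encoding for every generated token while the KV cache grows.
        current_pos = input_eval.shape[1]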
        
        # DECODE PHASE: Generate tokens using cache
        for i in range(num_generate - 1):
            next_token = tf.convert_to_tensor([[predicted_id]])
            
            # Generate next token using cached keys/values
            predictions, kv_cache = model(
                next_token,
                kv_cache=kv_cache,
                use_cache=True,
                start_pos=current_pos,
                training=False
            )
            
            last_predictions = predictions[0, -1, :]
            last_predictions = last_predictions / temperature
            predicted_id = tf.random.categorical(
                tf.expand_dims(last_predictions, 0), 
                num_samples=1
            )[0, 0].numpy()
            
            if predicted_id < len(vocab):
                text_generated.append(vocab[predicted_id])
            
            current_pos += 1
            
    else:
        print("Using standard generation (no KV cache)...")
        
        # STANDARD GENERATION: No cache optimization
        for i in range(num_generate):
            predictions = model(input_eval, use_cache=False, training=False)
            predictions = predictions[0, -1, :]
            predictions = predictions / temperature
            predicted_id = tf.random.categorical(
                tf.expand_dims(predictions, 0), 
                num_samples=1
            )[0, 0].numpy()

            # Update input sequence
            input_eval = np.append(input_eval.numpy(), [[predicted_id]], axis=1)
            input_eval = input_eval[:, -seq_length:]
            input_eval = tf.convert_to_tensor(input_eval)

            if predicted_id < len(vocab):
                text_generated.append(vocab[predicted_id])

    return start_string + ' ' + ' '.join(text_generated)


def generate_text(model, vectorizer, start_string, seq_length, num_generate=100, temperature=1.0):
    """
    Legacy text generation function for backward compatibility
    """
    return generate_text_with_kv_cache(
        model, vectorizer, start_string, seq_length, 
        num_generate, temperature, use_kv_cache=False
    )
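Both generation paths sample with `tf.random.categorical` after dividing the logits by `temperature`. The effect of that division is easy to see in NumPy:

- temperatures below 1 sharpen the distribution toward the top token;
- temperatures above 1 flatten it.

The logits below are made-up example values.

```python
import numpy as np

def softmax_with_temperature(logits, temperature):
    """Convert logits to probabilities after scaling by 1 / temperature."""
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    scaled -= scaled.max()            # numerical stability
    probs = np.exp(scaled)
    return probs / probs.sum()

logits = [2.0, 1.0, 0.5, -1.0]       # hypothetical next-token scores
for t in (0.5, 1.0, 2.0):
    print(t, np.round(softmax_with_temperature(logits, t), 3))

# Sampling one token id, analogous to tf.random.categorical(logits / temperature, 1)
rng = np.random.default_rng(0)
token_id = rng.choice(len(logits), p=softmax_with_temperature(logits, 0.8))
```

As the temperature approaches 0, sampling approaches greedy decoding. Very high temperatures approach uniform sampling over the vocabulary.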

## Training Pipeline

Complete training pipeline with monitoring, visualization, and model persistence.

In [None]:
def train():
    """
    COMPREHENSIVE MODEL TRAINING PIPELINE
    
    Implements complete training workflow:
    1. Data loading and preprocessing
    2. Model architecture setup
    3. Training with callbacks and logging
    4. Visualization and monitoring
    5. Model persistence
    """
    
    # Configuration
    corpus_source = "local"  # Options: "web" for Shakespeare, "local" for corpus.txt
    
    # Load and preprocess data
    text = load_corpus(corpus_source)
    print("Preview of the dataset:")
    print(text[:500])

    # Text preprocessing configuration
    vocab_size = 10000
    seq_length = 100
    
    # Create text vectorizer
    vectorizer = TextVectorization(
        max_tokens=vocab_size,
        output_mode='int'
    )
    
    # Adapt vectorizer to text data
    text_ds = tf.data.Dataset.from_tensor_slices([text]).batch(1)
    vectorizer.adapt(text_ds)

    # Convert text to numerical format
    vectorized_text = vectorizer([text])[0]
    print(f"Vectorized text shape: {vectorized_text.shape}")
    print(f"First 10 vectorized tokens: {vectorized_text.numpy()[:10]}")

    # Generate training sequences
    X, Y = create_sequences(vectorized_text.numpy(), seq_length)
    print(f"Number of sequences generated: {len(X)}")
    
    assert X.size > 0, "Input data X is empty"
    assert Y.size > 0, "Target data Y is empty"
    
    X = tf.convert_to_tensor(X)
    Y = tf.convert_to_tensor(Y)
    print(f"Shape of X: {X.shape}")
    print(f"Shape of Y: {Y.shape}")

    # Model architecture configuration
    embed_dim = 256
    num_heads = 4
    ff_dim = 512
    num_layers = 4

    # Create the transformer model
    model = TransformerModel(
        vocab_size, embed_dim, num_heads, ff_dim, num_layers, seq_length
    )

    # Build the model
    _ = model(tf.random.uniform((1, seq_length), maxval=vocab_size, dtype=tf.int32))

    # Configure training
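    # The final Dense layer outputs raw logits, so the loss must apply softmax itself
    model.compile(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    )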
    
    model.summary()

    # Setup logging
    logdir = os.path.join("logs", datetime.now().strftime("%Y%m%d-%H%M%S"))
    os.makedirs(logdir, exist_ok=True)
    
    tensorboard_cb = tf.keras.callbacks.TensorBoard(
        log_dir=logdir,
        histogram_freq=1,
        write_graph=True,
        write_images=True,
        update_freq='epoch'
    )
    
    print(f"TensorBoard logs in: {os.path.abspath(logdir)}")
    print(f"Run: tensorboard --logdir {logdir}")

    # Training callbacks
    early_stopping = EarlyStopping(
        monitor='loss',
        patience=2,
        restore_best_weights=True
    )
    
    print("Starting training...")
    
    # Execute training
    history = model.fit(
        X, Y,
        epochs=20,
        batch_size=32,
        callbacks=[early_stopping, tensorboard_cb]
    )

    print("Training completed!")

    # Save model weights
    weights_save_path = "transformer_model.weights.h5"
    model.save_weights(weights_save_path)
    print(f"Model weights saved to: {weights_save_path}")

    # Save vectorizer and model configuration
    import pickle
    vectorizer_path = "text_vectorizer.pkl"
    
    model_metadata = {
        'vectorizer': vectorizer,
        'vocab_size': vocab_size,
        'seq_length': seq_length,
        'embed_dim': embed_dim,
        'num_heads': num_heads,
        'ff_dim': ff_dim,
        'num_layers': num_layers
    }
    
    with open(vectorizer_path, 'wb') as f:
        pickle.dump(model_metadata, f)
    print(f"Vectorizer and model parameters saved to: {vectorizer_path}")

    # Create training visualization
    plt.figure(figsize=(10, 6))
    plt.plot(history.history['loss'], label='Training Loss', linewidth=2)
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    plt.title('Training Loss Over Time', fontsize=14)
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plot_path = os.path.join(logdir, 'training_loss.png')
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"Training loss plot saved to: {plot_path}")

    return model, vectorizer, vocab_size, seq_length, logdir
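`TransformerModel` ends in a plain `Dense(vocab_size)` layer, so its outputs are unnormalized logits. Any cross-entropy loss must treat them as logits; in Keras, that means `SparseCategoricalCrossentropy(from_logits=True)`. For reference, this NumPy sketch computes the per-token loss the way a from-logits implementation does, using a numerically stable log-softmax. The logits and targets are toy values.

```python
import numpy as np

def sparse_ce_from_logits(logits, targets):
    """Mean of -log softmax(logits)[target] over all positions."""
    logits = np.asarray(logits, dtype=np.float64)
    # log-softmax via the log-sum-exp trick: z - max - log(sum(exp(z - max)))
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    picked = np.take_along_axis(log_probs, np.asarray(targets)[..., None], axis=-1)
    return -picked.mean()

# (batch=1, seq_len=2, vocab=3) logits and the matching next-token ids
logits = [[[2.0, -1.0, 0.0], [0.5, 0.5, 3.0]]]
targets = [[0, 2]]
loss = sparse_ce_from_logits(logits, targets)
print(round(loss, 4))
```

Feeding raw logits to a loss that expects probabilities skips the softmax. Negative or greater-than-one scores then get clipped instead of normalized, which corrupts the training signal.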

## Main Function and Execution Modes

The main function provides different execution modes to run the transformer model. This allows you to easily switch between training, generation, and performance comparison modes.

In [None]:
def main():
    """
    MAIN EXECUTION FUNCTION - CONTROLS PROGRAM FLOW
    
    This function serves as the entry point for the transformer model implementation.
    It provides four different execution modes to accommodate various use cases:
    
    1. Training Mode: Train a new model from scratch
    2. Generation Mode: Generate text using a pre-trained model
    3. Performance Comparison Mode: Compare KV cache vs standard generation
    4. Both Mode: Train a model and then run performance comparison
    
    Configuration Variables:
    - mode: Determines which operation to perform
    - use_kv_cache: Enables/disables KV cache optimization
    - weights_path: Path to saved model weights
    - vectorizer_path: Path to saved text vectorizer
    """
    
    # Configuration variables - Change these as needed
    mode = 'generate_compare'  # Options: 'train', 'generate', 'generate_compare', 'both'
    use_kv_cache = True  # Set to False to disable KV cache optimization
    weights_path = 'transformer_model.weights.h5'
    vectorizer_path = 'text_vectorizer.pkl'
    
    if mode == 'train':
        print("="*60)
        print("TRAINING MODE")
        print("="*60)
        train()
        
    elif mode == 'generate':
        print("="*60)
        print("GENERATION MODE")
        print("="*60)
        generate(use_kv_cache=use_kv_cache,
                 weights_path=weights_path,
                 vectorizer_path=vectorizer_path)
        
    elif mode == 'generate_compare':
        print("="*60)
        print("PERFORMANCE COMPARISON MODE")
        print("="*60)
        generate_compare(weights_path=weights_path,
                         vectorizer_path=vectorizer_path)
        
    elif mode == 'both':
        print("="*60)
        print("TRAINING MODE")
        print("="*60)
        model, vectorizer, vocab_size, seq_length, logdir = train()
        
        print("\n" + "="*60)
        print("PERFORMANCE COMPARISON MODE")
        print("="*60)
        generate_compare(weights_path=weights_path,
                         vectorizer_path=vectorizer_path)

    else:
        raise ValueError(
            f"Invalid mode: {mode}. Must be 'train', 'generate', 'generate_compare', or 'both'."
        )


if __name__ == "__main__":
    main()

## Understanding the Execution Modes

The main function provides four distinct execution modes, each designed for different use cases:

### 1. **Training Mode** (`mode = 'train'`)
- **Purpose**: Train a new transformer model from scratch
- **What it does**:
  - Loads text corpus (either from web or local file)
  - Preprocesses text data and creates training sequences
  - Builds and trains the transformer model
  - Saves model weights and vectorizer for later use
  - Creates training visualizations and logs
- **Use when**: You want to train a new model or retrain with different data
- **Requirements**: Text corpus (either `corpus.txt` file or web download)

### 2. **Generation Mode** (`mode = 'generate'`)
- **Purpose**: Generate text using a pre-trained model
- **What it does**:
  - Loads pre-trained model weights and vectorizer
  - Generates text with or without KV cache optimization
  - Allows you to experiment with different prompts and parameters
- **Use when**: You have a trained model and want to generate text
- **Requirements**: Pre-trained model weights and vectorizer files
- **Configuration**: Set `use_kv_cache=True` for faster generation

### 3. **Performance Comparison Mode** (`mode = 'generate_compare'`)
- **Purpose**: Compare KV cache optimization vs standard generation
- **What it does**:
  - Loads pre-trained model
  - Generates text from the same prompt with both methods (sampling is random, so the two outputs will generally differ)
  - Measures and compares generation times
  - Demonstrates the performance benefits of KV caching
- **Use when**: You want to see the performance improvement from KV cache
- **Requirements**: Pre-trained model weights and vectorizer files
- **Output**: Side-by-side comparison with timing metrics

### 4. **Both Mode** (`mode = 'both'`)
- **Purpose**: Complete workflow from training to performance demonstration
- **What it does**:
  - First runs training mode (trains new model)
  - Then automatically runs performance comparison
  - Provides end-to-end demonstration
- **Use when**: You want to see the complete pipeline in action
- **Requirements**: Text corpus for training
- **Duration**: Longest execution time (includes full training)

## Configuration Options

### Key Parameters to Modify:

```python
mode = 'generate_compare'  # Change this to select execution mode
use_kv_cache = True        # Enable/disable KV cache in generation mode
weights_path = 'transformer_model.weights.h5'  # Path to model weights
vectorizer_path = 'text_vectorizer.pkl'        # Path to text vectorizer
```

### File Requirements by Mode:

| Mode | Required Files | Generated Files |
|------|----------------|----------------|
| `train` | `corpus.txt` (if using local corpus) | `transformer_model.weights.h5`, `text_vectorizer.pkl`, logs |
| `generate` | `transformer_model.weights.h5`, `text_vectorizer.pkl` | Generated text output |
| `generate_compare` | `transformer_model.weights.h5`, `text_vectorizer.pkl` | Performance comparison results |
| `both` | `corpus.txt` (if using local corpus) | All training files + comparison results |

## Quick Start Guide

1. **First Time Usage**: Set `mode = 'train'` to train a new model
2. **Text Generation**: Set `mode = 'generate'` to create text with your trained model
3. **Performance Testing**: Set `mode = 'generate_compare'` to see KV cache benefits
4. **Full Demo**: Set `mode = 'both'` for complete training and testing workflow

Simply change the `mode` variable in the main function and run the script to switch between different operations!