[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vuhung16au/hf-transformer-trove/blob/main/examples/basic1.4/encoder-decoder.ipynb)
[![View on GitHub](https://img.shields.io/badge/View_on-GitHub-blue?logo=github)](https://github.com/vuhung16au/hf-transformer-trove/blob/main/examples/basic1.4/encoder-decoder.ipynb)

# Encoder-Decoder Architecture: Understanding Sequence-to-Sequence Models

## 🎯 Learning Objectives
By the end of this notebook, you will understand:
- The fundamental architecture of encoder-decoder models
- How attention mechanisms work in sequence-to-sequence tasks
- Practical use of T5 as a pre-trained encoder-decoder model
- Auto-regressive generation and cross-attention in practice
- Mathematical concepts behind encoder-decoder transformers

## 📋 Prerequisites
- Basic understanding of machine learning concepts
- Familiarity with Python and PyTorch
- Knowledge of NLP fundamentals (refer to [NLP Learning Journey](https://github.com/vuhung16au/nlp-learning-journey))

## 📚 What We'll Cover
1. **Setup**: Environment and device detection
2. **Mathematical Foundations**: Key equations and concepts
3. **T5**: Text-to-text transfer transformer in practice
4. **Auto-regressive Generation**: Step-by-step token generation
5. **Architecture Comparison**: Encoder vs Decoder vs Encoder-Decoder
6. **Model Selection Guide**: Choosing an architecture for a task
7. **Summary**: Key takeaways and next steps

## What are Encoder-Decoder Models?

The **encoder-decoder** architecture, also known as the **sequence-to-sequence (Seq2Seq)** architecture, is designed for tasks where the output is a new sequence that can differ from the input in length, wording, or language, such as translation, summarization, and question answering.

### The Two-Phase Mechanism

1. **Encoder Phase**: Reads the whole input sequence at once and produces one contextual vector per input token
2. **Decoder Phase**: Generates the output sequence one token at a time, consulting the encoder's vectors at every step

### Why Encoder-Decoders Excel

- **Separate Specialization**: Encoder understands input, decoder generates output
- **Different Modalities**: Can handle different input/output types and lengths
- **Cross-Attention**: Decoder can selectively attend to relevant parts of input
- **Auto-regressive**: Generates output token by token, using previous outputs
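
The control flow behind these points can be sketched without any neural network. The snippet below is a toy illustration only: `toy_encoder`, `toy_decoder_step`, and `toy_generate` are hypothetical stand-ins, and the "model" simply upper-cases the input. What it does show faithfully is the shape of the loop: the encoder runs once, and the decoder runs once per output token while looking back at the encoder's memory.

```python
# Toy sketch of the two-phase mechanism. This is NOT a neural network;
# all three functions are hypothetical stand-ins for illustration.

def toy_encoder(tokens):
    """Encode once: return one 'representation' per input token."""
    return [{"position": i, "content": tok} for i, tok in enumerate(tokens)]

def toy_decoder_step(memory, generated):
    """Produce the next output token from the encoder memory and the
    tokens generated so far (the auto-regressive context)."""
    step = len(generated)            # number of tokens produced so far
    if step >= len(memory):          # nothing left to attend to
        return "<eos>"
    attended = memory[step]          # stand-in for cross-attention
    return attended["content"].upper()

def toy_generate(tokens, max_steps=10):
    memory = toy_encoder(tokens)     # encoder phase runs exactly once
    generated = []
    for _ in range(max_steps):       # decoder phase runs once per output token
        next_token = toy_decoder_step(memory, generated)
        if next_token == "<eos>":
            break
        generated.append(next_token)
    return generated

print(toy_generate(["the", "cat", "sat"]))  # ['THE', 'CAT', 'SAT']
```

In a real model, `toy_decoder_step` is replaced by a decoder forward pass that returns a probability distribution over the vocabulary.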


## 1. Setup and Environment

Let's start by importing the necessary libraries and setting up our environment.

In [None]:
# Import essential libraries
import torch
import torch.nn.functional as F
from transformers import (
    AutoTokenizer, 
    AutoModelForSeq2SeqLM,
    T5Tokenizer,
    T5ForConditionalGeneration,
    BartTokenizer,
    BartForConditionalGeneration,
    pipeline
)
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Optional
import warnings
warnings.filterwarnings('ignore')

print("✅ Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")

In [None]:
# Device detection for optimal performance
def get_device():
    """
    Get the best available device for PyTorch operations.
    
    Priority order: CUDA > MPS (Apple Silicon) > CPU
    
    Returns:
        torch.device: The optimal device for current hardware
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"🚀 Using CUDA GPU: {torch.cuda.get_device_name()}")
        # Print GPU memory info
        print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB total")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        print("🍎 Using Apple MPS for Apple Silicon optimization")
    else:
        device = torch.device("cpu")
        print("💻 Using CPU (consider GPU for better performance)")
    
    return device

# Get the optimal device
device = get_device()
print(f"\n📱 Active device: {device}")

## 2. Mathematical Foundations

Let's understand the key mathematical concepts behind encoder-decoder models.

### Encoder Architecture
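The encoder processes the input sequence $X = (x_1, x_2, ..., x_n)$ to produce contextual representations $H = (h_1, h_2, ..., h_n)$. Each encoder layer applies multi-head self-attention, in which the queries $Q$, keys $K$, and values $V$ are all linear projections of the same layer input and $d_k$ is the dimension of each key vector:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O$$

$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$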
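
To make the $\text{Concat}(\text{head}_1, ..., \text{head}_h)$ step concrete, here is a minimal plain-Python sketch of multi-head self-attention. It rests on one loud simplification: all projection matrices, including $W^O$, are taken to be the identity, so each head just attends over its own slice of the input vectors. The helper names are illustrative and do not come from any library.

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Scaled dot-product attention over lists of row vectors."""
    d_k = len(K[0])
    out = []
    for q in Q:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in K]
        weights = softmax(scores)
        out.append([sum(w * v[j] for w, v in zip(weights, V))
                    for j in range(len(V[0]))])
    return out

def multi_head_self_attention(X, num_heads):
    """Split each vector into num_heads slices, attend per slice, concatenate.
    Assumption: every projection matrix (including W^O) is the identity."""
    d_model = len(X[0])
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_k = d_model // num_heads
    heads = []
    for h in range(num_heads):
        Xh = [row[h * d_k:(h + 1) * d_k] for row in X]   # slice for head h
        heads.append(attention(Xh, Xh, Xh))              # self-attention: Q = K = V
    # Concat(head_1, ..., head_h): join per-head outputs position by position
    return [sum((head[i] for head in heads), []) for i in range(len(X))]

X = [[1.0, 0.0, 0.0, 2.0],
     [0.0, 1.0, 2.0, 0.0],
     [1.0, 1.0, 1.0, 1.0]]
H = multi_head_self_attention(X, num_heads=2)
print(len(H), len(H[0]))  # 3 4
```

The output has the same shape as the input (one `d_model`-sized vector per position), which is what lets encoder layers be stacked.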

### Decoder with Cross-Attention
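The decoder combines two kinds of attention and ends with a next-token distribution:

1. **Masked Self-Attention**: $\text{MaskedAttn}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)V$, where $M_{ij} = 0$ if $j \le i$ and $M_{ij} = -\infty$ otherwise, so position $i$ cannot attend to later positions

2. **Cross-Attention**: $\text{CrossAttn}(Q_{dec}, K_{enc}, V_{enc}) = \text{softmax}\left(\frac{Q_{dec}K_{enc}^T}{\sqrt{d_k}}\right)V_{enc}$, where the queries come from the decoder and the keys and values come from the encoder output $H$

3. **Auto-regressive Generation**: $P(y_t | y_{<t}, X) = \text{softmax}(\text{Decoder}(y_{<t}, X)W_{vocab})$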
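
The effect of the mask $M$ in masked self-attention is easiest to see on a small example. The sketch below uses plain Python and made-up scores (they stand in for the scaled $QK^T$ products), adds a mask of $0$ and $-\infty$ entries, and applies the softmax row by row. The function names are illustrative only.

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]   # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

def causal_mask(n):
    """M[i][j] = 0 where position i may attend to j (j <= i), else -inf."""
    return [[0.0 if j <= i else float("-inf") for j in range(n)]
            for i in range(n)]

def masked_attention_weights(scores):
    """Add the causal mask to the scores, then apply softmax to each row."""
    n = len(scores)
    M = causal_mask(n)
    return [softmax([s + m for s, m in zip(scores[i], M[i])])
            for i in range(n)]

# Made-up attention scores for a 3-token decoder sequence
scores = [[0.5, 1.0, -0.3],
          [0.2, 0.1, 0.9],
          [1.5, -0.4, 0.3]]
weights = masked_attention_weights(scores)
for row in weights:
    print([round(w, 3) for w in row])
```

Every entry above the diagonal comes out as exactly zero, and each row still sums to 1, so the first position can only attend to itself.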

In [None]:
def demonstrate_attention_mechanism():
    """
    Demonstrate the mathematical concepts behind attention mechanisms.
    """
    print("🧮 MATHEMATICAL DEMONSTRATION: Attention Mechanism")
    print("=" * 55)
    
    # Create sample sequences
    seq_len = 4
    d_model = 6  # Small dimension for demonstration
    
    # Generate random query, key, value matrices
    torch.manual_seed(42)  # For reproducible results
    Q = torch.randn(1, seq_len, d_model)
    K = torch.randn(1, seq_len, d_model)
    V = torch.randn(1, seq_len, d_model)
    
    print(f"Input dimensions:")
    print(f"  Sequence length: {seq_len}")
    print(f"  Model dimension: {d_model}")
    
    # Step 1: Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_model)
    print(f"\nStep 1: Attention scores (Q @ K^T / √d_k)")
    print(f"  Score matrix shape: {scores.shape}")
    print(f"  Score range: [{scores.min():.3f}, {scores.max():.3f}]")
    
    # Step 2: Apply softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)
    print(f"\nStep 2: Attention weights (softmax of scores)")
    print(f"  Each row sums to 1.0: {attention_weights.sum(dim=-1)}")
    
    # Step 3: Apply attention to values
    attention_output = torch.matmul(attention_weights, V)
    print(f"\nStep 3: Attention output (weights @ V)")
    print(f"  Output shape: {attention_output.shape}")
    
    # Visualize attention weights
    plt.figure(figsize=(8, 6))
    
    plt.subplot(1, 2, 1)
    sns.heatmap(scores[0].detach().numpy(), annot=True, fmt='.2f', cmap='RdBu_r')
    plt.title('Attention Scores\n(Q @ K^T / √d_k)')
    plt.xlabel('Key Position')
    plt.ylabel('Query Position')
    
    plt.subplot(1, 2, 2)
    sns.heatmap(attention_weights[0].detach().numpy(), annot=True, fmt='.2f', cmap='Blues')
    plt.title('Attention Weights\n(Softmax of Scores)')
    plt.xlabel('Key Position')
    plt.ylabel('Query Position')
    
    plt.tight_layout()
    plt.show()
    
    return attention_weights, attention_output

# Run the demonstration
weights, output = demonstrate_attention_mechanism()

## 3. T5: Text-to-Text Transfer Transformer

T5 treats every NLP task as a text-to-text problem. It uses explicit task prefixes to indicate the desired operation.
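
As a small illustration of the prefix idea, the helper below builds prefixed input strings before they are tokenized. The function name `build_t5_input` and its task keys are hypothetical; the prefixes themselves are the ones used in this notebook's examples.

```python
def build_t5_input(task, text, context=None):
    """Format raw text as a prefixed T5 input string.

    Hypothetical helper: the task keys are made up for this sketch,
    the prefixes are the ones used elsewhere in this notebook.
    """
    if task == "translate_en_de":
        return f"translate English to German: {text}"
    if task == "summarize":
        return f"summarize: {text}"
    if task == "qa":
        if context is None:
            raise ValueError("question answering needs a context passage")
        return f"question: {text} context: {context}"
    raise ValueError(f"unknown task: {task}")

print(build_t5_input("summarize", "Machine learning is a powerful technology."))
print(build_t5_input("qa", "What is T5?", context="T5 is a text-to-text model."))
```

Because the task is encoded in the string itself, the same model and the same `generate` call can serve all of these tasks.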

In [None]:
# Load T5 model and tokenizer
print("🔄 Loading T5 model for text-to-text generation...")

try:
    # Using T5-small for demonstration (faster and less memory)
    t5_model_name = "t5-small"
    t5_tokenizer = T5Tokenizer.from_pretrained(t5_model_name)
    t5_model = T5ForConditionalGeneration.from_pretrained(t5_model_name)
    
    # Move model to device
    t5_model = t5_model.to(device)
    t5_model.eval()  # Set to evaluation mode
    
    print(f"✅ T5 model loaded successfully")
    print(f"   Model parameters: {t5_model.num_parameters():,}")
    print(f"   Encoder layers: {t5_model.config.num_layers}")
    print(f"   Decoder layers: {t5_model.config.num_decoder_layers}")
    print(f"   Attention heads: {t5_model.config.num_heads}")
    print(f"   Model dimension: {t5_model.config.d_model}")
    
    t5_available = True
    
except Exception as e:
    print(f"❌ Error loading T5 model: {e}")
    t5_available = False

In [None]:
def demonstrate_t5_tasks(model, tokenizer, device):
    """
    Demonstrate various T5 tasks showing encoder-decoder versatility.
    """
    print("📝 T5 ENCODER-DECODER DEMONSTRATION")
    print("=" * 40)
    
    # Define various tasks with their prefixes
    tasks = [
        {
            "name": "Translation",
            "input": "translate English to German: The weather is beautiful today.",
            "description": "Cross-lingual sequence transformation"
        },
        {
            "name": "Summarization", 
            "input": "summarize: Machine learning is a subset of artificial intelligence that involves training algorithms to make predictions or decisions based on data. It has applications in many fields including healthcare, finance, and technology.",
            "description": "Text compression while preserving key information"
        },
        {
            "name": "Question Answering",
            "input": "question: What is machine learning? context: Machine learning is a method of data analysis that automates analytical model building using algorithms that iteratively learn from data.",
            "description": "Information extraction from context"
        },
        {
            "name": "Sentiment Classification",
            # "sst2 sentence:" is the SST-2 prefix from T5's multi-task training mixture
            "input": "sst2 sentence: I absolutely love this new smartphone! The camera quality is amazing.",
            "description": "Sequence-to-label classification"
        }
    ]
    
    for task in tasks:
        print(f"\n🎯 Task: {task['name']}")
        print(f"   Purpose: {task['description']}")
        print(f"   Input: {task['input'][:80]}..." if len(task['input']) > 80 else f"   Input: {task['input']}")
        
        try:
            # Tokenize input
            inputs = tokenizer(
                task['input'], 
                return_tensors="pt", 
                max_length=512, 
                truncation=True
            ).to(device)
            
            # Generate output
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_length=50,
                    num_beams=3,          # beam search; deterministic without sampling
                    early_stopping=True   # temperature only applies when do_sample=True
                )
            
            # Decode output
            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            print(f"   Output: {generated_text}")
            
        except Exception as e:
            print(f"   ❌ Error: {e}")

if t5_available:
    demonstrate_t5_tasks(t5_model, t5_tokenizer, device)
else:
    print("❌ T5 model not available for demonstration")

## 4. Auto-regressive Generation Explained

Let's understand how encoder-decoder models generate text step by step in an auto-regressive manner.
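
Auto-regressive generation factorizes the probability of the whole output as a product of per-step probabilities, $P(Y | X) = \prod_t P(y_t | y_{<t}, X)$. The short sketch below works through that arithmetic in log space. The per-step probabilities are made-up numbers, not model outputs.

```python
import math

def sequence_log_prob(step_probs):
    """Sum of log P(y_t | y_<t, X) over all generated steps."""
    return sum(math.log(p) for p in step_probs)

# Hypothetical probabilities of the chosen token at each of three steps
step_probs = [0.9, 0.5, 0.8]
log_p = sequence_log_prob(step_probs)

print(f"log P(Y|X) = {log_p:.4f}")
print(f"P(Y|X)     = {math.exp(log_p):.4f}")  # 0.9 * 0.5 * 0.8 = 0.36
```

Summing logs instead of multiplying probabilities avoids numerical underflow for long sequences, which is why decoding algorithms such as beam search typically score candidates by log-probability.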

In [None]:
def demonstrate_autoregressive_generation(model, tokenizer, input_text, device, max_steps=8):
    """
    Demonstrate step-by-step auto-regressive generation.
    """
    print("🔄 AUTO-REGRESSIVE GENERATION STEP-BY-STEP")
    print("=" * 50)
    print(f"Input: {input_text}\n")
    
    try:
        # Encode input
        encoder_inputs = tokenizer(input_text, return_tensors="pt").to(device)
        
        # Get encoder outputs (these remain constant during generation)
        with torch.no_grad():
            encoder_outputs = model.get_encoder()(**encoder_inputs)
        
        print(f"📥 Encoder Phase Complete:")
        print(f"   Input tokens: {len(encoder_inputs.input_ids[0])}")
        print(f"   Encoder output shape: {encoder_outputs.last_hidden_state.shape}")
        print(f"   Each position now has a {encoder_outputs.last_hidden_state.shape[-1]}-dim representation\n")
        
        # Initialize decoder with the model's start token (for T5 this is the pad token)
        decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]], device=device)
        
        generated_tokens = []
        
        print("🎯 Decoder Phase (Auto-regressive Generation):")
        
        for step in range(max_steps):
            with torch.no_grad():
                # Run decoder with current sequence and encoder outputs
                decoder_outputs = model.get_decoder()(
                    input_ids=decoder_input_ids,
                    encoder_hidden_states=encoder_outputs.last_hidden_state,
                    encoder_attention_mask=encoder_inputs.attention_mask
                )
                
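                # Get logits for next token prediction
                hidden = decoder_outputs.last_hidden_state
                # T5 with tied embeddings rescales hidden states before the LM head
                if model.config.model_type == "t5" and model.config.tie_word_embeddings:
                    hidden = hidden * (model.config.d_model ** -0.5)
                logits = model.lm_head(hidden)
                next_token_logits = logits[0, -1, :]  # Last position logits
                
                # Get top-k predictions, with probabilities over the full vocabulary
                top_k = 5
                next_token_probs = F.softmax(next_token_logits, dim=-1)
                top_k_probs, top_k_indices = torch.topk(next_token_probs, top_k)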
                
                # Choose next token (greedy for demonstration)
                next_token_id = top_k_indices[0]
                next_token = tokenizer.decode(next_token_id, skip_special_tokens=True)
                
                print(f"   Step {step + 1}:")
                print(f"     Current sequence: {tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)}")
                print(f"     Next token: '{next_token}' (prob: {top_k_probs[0]:.3f})")
                print(f"     Top alternatives:")
                for i in range(1, min(3, top_k)):
                    alt_token = tokenizer.decode(top_k_indices[i], skip_special_tokens=True)
                    print(f"       '{alt_token}' (prob: {top_k_probs[i]:.3f})")
                
                # Stop if we hit end token or pad token
                if next_token_id == tokenizer.eos_token_id or next_token_id == tokenizer.pad_token_id:
                    print(f"     🛑 Generation stopped (end token reached)")
                    break
                
                # Add token to sequence
                decoder_input_ids = torch.cat(
                    [decoder_input_ids, next_token_id.unsqueeze(0).unsqueeze(0)], 
                    dim=-1
                )
                generated_tokens.append(next_token)
                
                print()  # Empty line for readability
        
        # Final generated text
        final_text = tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
        print(f"\n✅ Final Generated Text: {final_text}")
        print(f"📊 Generation Statistics:")
        print(f"   Total steps: {len(generated_tokens)}")
        print(f"   Generated tokens: {generated_tokens}")
        
    except Exception as e:
        print(f"❌ Error in auto-regressive demonstration: {e}")

# Demonstrate auto-regressive generation with T5
if t5_available:
    demo_text = "summarize: Machine learning is a powerful technology."
    demonstrate_autoregressive_generation(t5_model, t5_tokenizer, demo_text, device, max_steps=6)
else:
    print("❌ T5 model not available for auto-regressive demonstration")

## 5. Architecture Comparison: Encoder vs Decoder vs Encoder-Decoder

Let's understand the differences between the three main transformer architectures and when to use each.
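
One compact way to see the difference is to draw which positions each attention type is allowed to look at. The sketch below prints the three visibility patterns as small grids (rows are query positions, columns are key positions). The function names are illustrative, and the grids describe attention connectivity only, not learned weights.

```python
def bidirectional_pattern(n):
    """Encoder self-attention: every position sees every position."""
    return [[1] * n for _ in range(n)]

def causal_pattern(n):
    """Decoder self-attention: position i sees positions 0..i only."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]

def cross_pattern(target_len, source_len):
    """Cross-attention: every target position sees every source position."""
    return [[1] * source_len for _ in range(target_len)]

def show(name, pattern):
    print(name)
    for row in pattern:
        print("  " + " ".join("x" if visible else "." for visible in row))

show("Encoder-only (bi-directional self-attention)", bidirectional_pattern(4))
show("Decoder-only (causal self-attention)", causal_pattern(4))
show("Encoder-decoder cross-attention (2 target x 4 source)", cross_pattern(2, 4))
```

An encoder-decoder model uses all three patterns: the first inside the encoder, the second inside the decoder, and the third to connect the two.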

In [None]:
def compare_transformer_architectures():
    """
    Compare different transformer architectures with practical examples.
    """
    print("🏗️ TRANSFORMER ARCHITECTURE COMPARISON")
    print("=" * 45)
    
    architectures = {
        "Encoder-only (e.g., BERT)": {
            "attention": "Bi-directional self-attention",
            "training": "Masked Language Modeling (MLM)",
            "strengths": [
                "Excellent context understanding",
                "Great for classification tasks", 
                "Efficient inference",
                "Perfect for hate speech detection"
            ],
            "use_cases": [
                "Sentiment analysis",
                "Hate speech detection", 
                "Named entity recognition",
                "Text classification"
            ],
            "examples": "BERT, RoBERTa, DeBERTa, DistilBERT"
        },
        "Decoder-only (e.g., GPT)": {
            "attention": "Uni-directional (causal) self-attention",
            "training": "Causal Language Modeling (CLM)",
            "strengths": [
                "Excellent for text generation",
                "Auto-regressive by design",
                "Good few-shot learning",
                "Scales well with data/parameters"
            ],
            "use_cases": [
                "Text generation",
                "Creative writing",
                "Code generation", 
                "Conversational AI"
            ],
            "examples": "GPT-2, GPT-3/4, LLaMA, Mistral"
        },
        "Encoder-Decoder (e.g., T5)": {
            "attention": "Bi-directional + Cross-attention + Causal",
            "training": "Denoising/Sequence-to-sequence",
            "strengths": [
                "Best for sequence transformation",
                "Flexible input/output lengths",
                "Excellent for conditional generation",
                "Cross-attention provides interpretability"
            ],
            "use_cases": [
                "Machine translation",
                "Text summarization",
                "Question answering",
                "Paraphrasing"
            ],
            "examples": "T5, BART, mT5, PEGASUS"
        }
    }
    
    for arch_name, details in architectures.items():
        print(f"\n🔧 {arch_name}")
        print(f"   Attention Pattern: {details['attention']}")
        print(f"   Training Objective: {details['training']}")
        print(f"   Key Strengths:")
        for strength in details['strengths']:
            print(f"     • {strength}")
        print(f"   Best Use Cases:")
        for use_case in details['use_cases']:
            print(f"     • {use_case}")
        print(f"   Popular Models: {details['examples']}")
    
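    # Performance comparison (rough and qualitative: actual memory and speed
    # depend on model size, sequence length, and hardware)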
    print("\n📊 PERFORMANCE CHARACTERISTICS")
    print("=" * 35)
    
    metrics = [
        ["Architecture", "Memory Usage", "Inference Speed", "Training Complexity"],
        ["Encoder-only", "Moderate ⚡", "Fast ⚡⚡", "Simple ✅"],
        ["Decoder-only", "Efficient ⚡⚡", "Fast ⚡⚡", "Simple ✅"],
        ["Encoder-Decoder", "High ⚠️", "Slower ⚠️", "Complex ⚠️"]
    ]
    
    for row in metrics:
        print(f"   {row[0]:15} | {row[1]:15} | {row[2]:15} | {row[3]}")
        if row[0] == "Architecture":
            print("   " + "-" * 70)

# Run the comparison
compare_transformer_architectures()

## 6. Practical Model Selection Guide

Let's create a practical guide for choosing the right model architecture for different tasks.
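
The same guidance can be condensed into a two-question function. This is a hedged rule-of-thumb sketch, not a benchmark-backed rule: `choose_architecture` is a hypothetical helper, and real projects should also weigh model availability, size, and latency.

```python
def choose_architecture(needs_generation, output_differs_from_input=False):
    """Suggest an architecture family from two yes/no questions.

    Hypothetical helper encoding a rule of thumb:
    - no text generation           -> encoder-only
    - generation, different format -> encoder-decoder
    - generation, same format      -> decoder-only
    """
    if not needs_generation:
        return "Encoder-only"
    if output_differs_from_input:
        return "Encoder-Decoder"
    return "Decoder-only"

print(choose_architecture(needs_generation=False))                                  # Encoder-only
print(choose_architecture(needs_generation=True, output_differs_from_input=True))   # Encoder-Decoder
print(choose_architecture(needs_generation=True))                                   # Decoder-only
```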

In [None]:
def model_selection_guide():
    """
    Provide practical guidance for model selection.
    """
    print("🎯 MODEL SELECTION DECISION TREE")
    print("=" * 40)
    
    decision_tree = {
        "Do you need to generate new text?": {
            "No (Classification/Understanding)": {
                "Architecture": "Encoder-only",
                "Examples": "BERT, RoBERTa, DeBERTa",
                "Perfect for": [
                    "Hate speech detection",
                    "Sentiment analysis", 
                    "Text classification",
                    "Named entity recognition"
                ]
            },
            "Yes (Text Generation)": {
                "Is your input different from output format?": {
                    "Yes (Different formats)": {
                        "Architecture": "Encoder-Decoder",
                        "Examples": "T5, BART, mT5",
                        "Perfect for": [
                            "Machine translation",
                            "Text summarization",
                            "Question answering",
                            "Code generation from text"
                        ]
                    },
                    "No (Same format continuation)": {
                        "Architecture": "Decoder-only",
                        "Examples": "GPT-2/3/4, LLaMA",
                        "Perfect for": [
                            "Text completion",
                            "Creative writing",
                            "Conversational AI",
                            "Code completion"
                        ]
                    }
                }
            }
        }
    }
    
    # Print decision tree
    def print_decision_node(node, level=0):
        indent = "  " * level
        if isinstance(node, dict):
            for key, value in node.items():
                if key in ["Architecture", "Examples", "Perfect for"]:
                    if key == "Perfect for":
                        print(f"{indent}   {key}:")
                        for item in value:
                            print(f"{indent}     • {item}")
                    else:
                        print(f"{indent}   {key}: {value}")
                else:
                    print(f"{indent}❓ {key}")
                    print_decision_node(value, level + 1)
        elif isinstance(node, str):
            print(f"{indent}➡️ {node}")
    
    print_decision_node(decision_tree)
    
    # Practical recommendations
    print("\n💡 PRACTICAL RECOMMENDATIONS")
    print("=" * 32)
    
    recommendations = [
        {
            "scenario": "Building a hate speech detection system",
            "choice": "Encoder-only (BERT/RoBERTa)",
            "reason": "Needs bi-directional context understanding, no generation required"
        },
        {
            "scenario": "Creating a translation service", 
            "choice": "Encoder-Decoder (mT5/NLLB)",
            "reason": "Different input/output languages, needs cross-attention"
        },
        {
            "scenario": "Building a chatbot",
            "choice": "Decoder-only (GPT family)",
            "reason": "Conversational continuation, same format input/output"
        },
        {
            "scenario": "Automatic document summarization",
            "choice": "Encoder-Decoder (BART/PEGASUS)",
            "reason": "Long input → short output, requires content transformation"
        },
        {
            "scenario": "Sentiment analysis for reviews",
            "choice": "Encoder-only (DistilBERT)",
            "reason": "Classification task, needs full context understanding"
        }
    ]
    
    for rec in recommendations:
        print(f"\n🎯 Scenario: {rec['scenario']}")
        print(f"   Best Choice: {rec['choice']}")
        print(f"   Why: {rec['reason']}")

# Run the selection guide
model_selection_guide()

## Summary

Congratulations! You've learned the fundamentals of encoder-decoder architectures and how they work in practice.

### 🔑 Key Concepts Mastered
- **Encoder-Decoder Architecture**: Two-phase processing with specialized components
- **Cross-Attention**: How decoders attend to encoder representations
- **Auto-regressive Generation**: Step-by-step token generation process
- **Mathematical Foundations**: Attention mechanisms and sequence probability
- **Model Selection**: When to use encoder-only vs decoder-only vs encoder-decoder

### 📈 Best Practices Learned
- Use encoder-decoder models for sequence transformation tasks
- T5's text-to-text format provides flexibility across many tasks
- BART excels at generation tasks with its denoising pre-training
- Cross-attention weights can hint at which input tokens influenced each output token, though they are not a complete explanation of model decisions
- Choose architecture based on task requirements, not popularity

### 🚀 Next Steps
- **Fine-tuning**: Learn to adapt pre-trained models to your specific domains
- **Evaluation Metrics**: Study BLEU, ROUGE, and BERTScore for generation quality
- **Advanced Techniques**: Explore beam search, nucleus sampling, and length penalties
- **Deployment**: Build scalable inference systems for production use
- **Documentation**: Read the comprehensive [Encoder-Decoder Guide](../../docs/encoder-decoder.md)

### 📚 Additional Resources
- [Hugging Face Summarization Task Guide](https://huggingface.co/docs/transformers/tasks/summarization)
- [T5 Paper: "Exploring the Limits of Transfer Learning"](https://arxiv.org/abs/1910.10683)
- [BART Paper: "Denoising Sequence-to-Sequence Pre-training"](https://arxiv.org/abs/1910.13461)
- [Attention Is All You Need (Original Transformer Paper)](https://arxiv.org/abs/1706.03762)

### 🧠 Understanding the Impact

Encoder-decoder models represent a significant advancement in NLP because they:
- **Separate concerns**: Encoding (understanding) and decoding (generation)
- **Enable flexibility**: Handle different input/output sequence lengths
- **Offer some interpretability**: Cross-attention weights give a rough view of which input tokens the model attends to at each step
- **Adapt to many tasks**: A single model such as T5 can switch tasks through task-specific prefixes

The combination of bi-directional encoding and auto-regressive decoding with cross-attention makes encoder-decoder models exceptionally powerful for sequence-to-sequence tasks.

---

## About the Author

**Vu Hung Nguyen** - AI Engineer & Researcher

Connect with me:
- 🌐 **Website**: [vuhung16au.github.io](https://vuhung16au.github.io/)
- 💼 **LinkedIn**: [linkedin.com/in/nguyenvuhung](https://www.linkedin.com/in/nguyenvuhung/)
- 💻 **GitHub**: [github.com/vuhung16au](https://github.com/vuhung16au/)

*This notebook is part of the [HF Transformer Trove](https://github.com/vuhung16au/hf-transformer-trove) educational series.*