[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vuhung16au/hf-transformer-trove/blob/main/examples/basic2.5/multiple-sequences.ipynb)
[![View on GitHub](https://img.shields.io/badge/View_on-GitHub-blue?logo=github)](https://github.com/vuhung16au/hf-transformer-trove/blob/main/examples/basic2.5/multiple-sequences.ipynb)

# Handling Multiple Sequences: Padding, Attention Masks, and Long Context

## 🎯 Learning Objectives
By the end of this notebook, you will understand:
- How to handle sequences of different lengths in batch processing
- The importance and mechanics of padding and attention masks
- Strategies for dealing with longer sequences and truncation
- Advanced techniques with Longformer for extended context understanding
- Visual analysis of attention patterns across different sequence lengths

## 📋 Prerequisites
- Basic understanding of machine learning concepts
- Familiarity with Python and PyTorch
- Knowledge of NLP fundamentals (refer to [NLP Learning Journey](https://github.com/vuhung16au/nlp-learning-journey))
- Understanding of tokenization concepts (refer to `02_tokenizers.ipynb`)

## 📚 What We'll Cover
1. Section 1: Understanding the Multiple Sequences Problem
2. Section 2: Padding and Truncation Strategies
3. Section 3: Attention Masks Deep Dive
4. Section 4: Handling Longer Sequences
5. Section 5: Longformer for Extended Context
6. Section 6: Best Practices and Performance Optimization
7. Section 7: Summary and Next Steps

In [None]:
# Import essential libraries for this comprehensive tutorial
import torch
import torch.nn.functional as F
from transformers import (
    AutoTokenizer, AutoModel, AutoModelForSequenceClassification,
    LongformerTokenizer, LongformerModel, LongformerForSequenceClassification,
    AutoConfig, BertTokenizer, BertModel, RobertaTokenizer, RobertaModel
)
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from typing import List, Dict, Optional, Union, Tuple
import warnings
warnings.filterwarnings('ignore')

# Load environment variables from .env.local for local development
import os
try:
    from dotenv import load_dotenv
    load_dotenv('.env.local', override=True)
    print("Environment variables loaded from .env.local")
except ImportError:
    print("python-dotenv not installed, skipping .env.local loading")

# For Google Colab compatibility
try:
    from google.colab import userdata
    COLAB_AVAILABLE = True
except ImportError:
    COLAB_AVAILABLE = False

def get_api_key(key_name: str, required: bool = False) -> Optional[str]:
    """
    Load API key from environment or Google Colab secrets.
    
    Args:
        key_name: Environment variable name
        required: Whether to raise error if not found
        
    Returns:
        API key string or None
    """
    # Try Colab secrets first
    if COLAB_AVAILABLE:
        try:
            return userdata.get(key_name)
        except Exception:
            # Secret missing or access not granted; fall back to the environment
            pass
    
    # Try environment variable
    api_key = os.getenv(key_name)
    
    if required and not api_key:
        raise ValueError(
            f"{key_name} not found. Set it in:\n"
            f"- Local: .env.local file\n"
            f"- Colab: Secrets manager"
        )
    
    return api_key

def get_device() -> torch.device:
    """
    Automatically detect and return the best available device.
    
    Priority: CUDA > MPS (Apple Silicon) > CPU
    
    Returns:
        torch.device: The optimal device for current hardware
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"🚀 Using CUDA GPU: {torch.cuda.get_device_name()}")
    elif torch.backends.mps.is_available():
        device = torch.device("mps") 
        print("🍎 Using Apple MPS (Apple Silicon)")
    else:
        device = torch.device("cpu")
        print("💻 Using CPU (consider GPU for better performance)")
    
    return device

# Setup authentication and device
hf_token = get_api_key('HF_TOKEN', required=False)
if hf_token:
    os.environ['HF_TOKEN'] = hf_token
    print("✅ Hugging Face token configured")

device = get_device()

# Set up plotting style for educational visualizations
plt.style.use('default')  # Use default style for better compatibility
sns.set_palette("husl")

print(f"\n=== Setup Information ===")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {device}")
print(f"Ready for multiple sequence processing! 🎯")

## Section 1: Understanding the Multiple Sequences Problem

When working with real-world NLP tasks, we rarely process single sequences. Instead, we need to handle batches of text sequences that vary significantly in length. This creates several challenges:

### Key Challenges:
- **Variable Length**: Text sequences have different numbers of tokens
- **Batch Processing**: A batch is a single rectangular tensor, so every sequence in it must have the same length
- **Memory Efficiency**: Longer sequences consume more computational resources
- **Attention Mechanics**: Models need to know which positions hold real tokens and which hold padding

Let's start by demonstrating this problem with real examples.

In [None]:
def demonstrate_sequence_length_problem():
    """
    Demonstrate the variable sequence length problem with real text examples.
    """
    print("🔍 DEMONSTRATING THE MULTIPLE SEQUENCES PROBLEM")
    print("=" * 55)
    
    # Example texts with varying lengths (whitespace-separated word counts noted per line)
    example_texts = [
        "Great post!",  # Short: 2 words
        "I really enjoyed reading this article about machine learning.",  # Medium: 9 words
        "This comprehensive tutorial on natural language processing and transformer models provides detailed insights into modern NLP techniques and their practical applications in industry.",  # Long: 23 words
        "AI",  # Very short: 1 word
        "The development of large language models has revolutionized how we approach various NLP tasks, from text classification to question answering, enabling more sophisticated and nuanced understanding of human language patterns."  # Very long: 30 words
    ]
    
    # Load a BERT tokenizer for demonstration
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    
    print("📊 Text Length Analysis:")
    print("-" * 40)
    
    tokenized_lengths = []
    for i, text in enumerate(example_texts):
        # Tokenize without any padding or truncation
        tokens = tokenizer.encode(text, add_special_tokens=True)
        word_count = len(text.split())
        token_count = len(tokens)
        
        tokenized_lengths.append(token_count)
        
        print(f"{i+1}. Words: {word_count:2d} | Tokens: {token_count:2d} | Text: '{text[:50]}{'...' if len(text) > 50 else ''}'")
    
    print(f"\n📈 Token Length Statistics:")
    print(f"   Min length: {min(tokenized_lengths)} tokens")
    print(f"   Max length: {max(tokenized_lengths)} tokens")
    print(f"   Range: {max(tokenized_lengths) - min(tokenized_lengths)} tokens")
    print(f"   Average: {np.mean(tokenized_lengths):.1f} tokens")
    
    # Visualize the length distribution
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    bars = plt.bar(range(1, len(tokenized_lengths) + 1), tokenized_lengths, color='skyblue', alpha=0.8)
    plt.title('Token Count per Sequence')
    plt.xlabel('Sequence Number')
    plt.ylabel('Number of Tokens')
    plt.grid(axis='y', alpha=0.3)
    
    # Add value labels on bars
    for bar, length in zip(bars, tokenized_lengths):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5, 
                str(length), ha='center', va='bottom')
    
    plt.subplot(1, 2, 2)
    # Create a visual representation of sequence lengths
    sequences_visual = np.zeros((len(tokenized_lengths), max(tokenized_lengths)))
    for i, length in enumerate(tokenized_lengths):
        sequences_visual[i, :length] = 1
    
    # 'RdYlBu' maps low values (0 = padding) to red and high values (1 = token) to blue
    plt.imshow(sequences_visual, cmap='RdYlBu', aspect='auto')
    plt.title('Sequence Length Visualization\n(Blue = tokens, Red = padding needed)')
    plt.xlabel('Token Position')
    plt.ylabel('Sequence Number')
    plt.yticks(range(len(tokenized_lengths)), [f'Seq {i+1}' for i in range(len(tokenized_lengths))])
    
    plt.tight_layout()
    plt.show()
    
    print("\n🚨 THE PROBLEM:")
    print("   • Neural networks need fixed-size inputs for batch processing")
    print("   • Variable sequence lengths prevent efficient batching")
    print("   • Need padding to make all sequences the same length")
    print("   • Need attention masks to ignore padded positions")
    
    return example_texts, tokenized_lengths

# Run the demonstration
example_texts, tokenized_lengths = demonstrate_sequence_length_problem()

## Section 2: Padding and Truncation Strategies

To solve the variable sequence length problem, we use **padding** and **truncation**:

### Padding Strategies:
- **Right Padding**: Add padding tokens to the end of sequences (most common)
- **Left Padding**: Add padding tokens to the beginning (typical for decoder-only models during batched generation, so new tokens directly follow the real text)
- **Dynamic Padding**: Pad only to the length of the longest sequence in the batch
- **Fixed Padding**: Pad all sequences to a predetermined maximum length

### Truncation Strategies:
- **Right Truncation**: Remove tokens from the end
- **Left Truncation**: Remove tokens from the beginning  
- **Middle Truncation**: Remove tokens from the middle (less common)
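
Before turning to the tokenizer, the strategies above can be sketched in plain Python on lists of token IDs. This is only an illustration of the mechanics, not how the `transformers` tokenizers implement them: `PAD_ID`, `truncate`, and `pad_batch` are hypothetical names, and the token IDs are made up.

```python
from typing import List, Optional, Tuple

PAD_ID = 0  # hypothetical padding id; real tokenizers expose theirs as tokenizer.pad_token_id


def truncate(ids: List[int], max_len: int, side: str = "right") -> List[int]:
    """Cut a sequence down to max_len tokens, dropping from the chosen side."""
    if max_len <= 0:
        return []
    if len(ids) <= max_len:
        return list(ids)
    if side == "right":
        return ids[:max_len]          # keep the beginning
    if side == "left":
        return ids[-max_len:]         # keep the end
    if side == "middle":
        head = (max_len + 1) // 2     # keep the start and the end, drop the middle
        tail = max_len - head
        return ids[:head] + (ids[-tail:] if tail else [])
    raise ValueError(f"unknown side: {side}")


def pad_batch(batch: List[List[int]], side: str = "right",
              max_len: Optional[int] = None) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad every sequence to a common length and build matching attention masks.

    max_len=None gives dynamic padding (longest sequence in the batch);
    an integer gives fixed padding, with right truncation of anything longer.
    """
    target = max(len(ids) for ids in batch) if max_len is None else max_len
    padded, masks = [], []
    for ids in batch:
        ids = truncate(ids, target)
        n_pad = target - len(ids)
        if side == "right":
            padded.append(ids + [PAD_ID] * n_pad)
            masks.append([1] * len(ids) + [0] * n_pad)
        else:  # left padding
            padded.append([PAD_ID] * n_pad + ids)
            masks.append([0] * n_pad + [1] * len(ids))
    return padded, masks


batch = [[101, 7, 8, 102], [101, 9, 102], [101, 3, 4, 5, 6, 102]]
dynamic_ids, dynamic_mask = pad_batch(batch)          # pad to longest (6)
fixed_ids, fixed_mask = pad_batch(batch, max_len=5)   # pad/truncate to 5
left_ids, _ = pad_batch(batch, side="left")

print(dynamic_ids[1], dynamic_mask[1])
print(fixed_ids[2])
print(left_ids[1])
```

Note that this toy truncation cuts off the final token (for example a closing special token), whereas real tokenizers truncate the text and then re-add special tokens.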

Let's implement and compare these strategies:

In [None]:
def demonstrate_padding_strategies():
    """
    Demonstrate different padding and truncation strategies.
    """
    print("🛠️ PADDING AND TRUNCATION STRATEGIES")
    print("=" * 45)
    
    # Load preferred hate speech detection model tokenizer
    tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-hate-latest")
    
    # Use example texts from hate speech detection domain
    texts = [
        "This message promotes understanding and respect.",
        "Great work!",
        "I completely disagree with this perspective, but I respect your right to express it and welcome civil discourse on the topic.",
        "Thanks"
    ]
    
    print("📝 Original texts:")
    for i, text in enumerate(texts):
        print(f"{i+1}. '{text}'")
    
    print("\n" + "=" * 60)
    
    # Strategy 1: No padding/truncation (shows the problem)
    print("\n1️⃣ NO PADDING/TRUNCATION (The Problem):")
    try:
        unpadded = tokenizer(texts, return_tensors="pt")
        print("   ✅ Successful batching")
    except Exception as e:
        print(f"   ❌ Error: {str(e)[:100]}...")
        print("   This is why we need padding!")
    
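    # Strategy 2: Dynamic padding to the longest sequence in the batch
    # (padding goes on the right, the default padding side for RoBERTa tokenizers)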
    print("\n2️⃣ RIGHT PADDING:")
    right_padded = tokenizer(texts, padding=True, return_tensors="pt")
    print(f"   Shape: {right_padded['input_ids'].shape}")
    print(f"   Attention mask shape: {right_padded['attention_mask'].shape}")
    
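    # Strategy 3: Fixed max length padding
    # truncation=True is needed here: the third text is longer than 20 tokens, and
    # without truncation the batch would be ragged and tensor conversion would fail
    print("\n3️⃣ FIXED LENGTH PADDING (max_length=20):")
    fixed_padded = tokenizer(texts, padding='max_length', truncation=True, max_length=20, return_tensors="pt")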
    print(f"   Shape: {fixed_padded['input_ids'].shape}")
    
    # Strategy 4: Truncation with padding
    print("\n4️⃣ TRUNCATION + PADDING (max_length=15):")
    truncated_padded = tokenizer(texts, padding=True, truncation=True, max_length=15, return_tensors="pt")
    print(f"   Shape: {truncated_padded['input_ids'].shape}")
    
    # Visualize the different strategies
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('Padding and Truncation Strategies Visualization', fontsize=16)
    
    strategies = [
        ("Right Padding", right_padded),
        ("Fixed Length (20)", fixed_padded),
        ("Truncated + Padded (15)", truncated_padded),
        ("Attention Masks", right_padded)  # Show attention masks
    ]
    
    for idx, (title, encoded) in enumerate(strategies):
        row, col = idx // 2, idx % 2
        ax = axes[row, col]
        
        if title == "Attention Masks":
            # Show attention masks
            data = encoded['attention_mask'].numpy()
            cmap = 'RdBu_r'
        else:
            # Show input IDs (replace pad tokens with -1 for visualization)
            data = encoded['input_ids'].numpy()
            pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
            data = np.where(data == pad_token_id, -1, 1)  # 1 for real tokens, -1 for padding
            cmap = 'RdYlBu_r'
        
        im = ax.imshow(data, cmap=cmap, aspect='auto')
        ax.set_title(title)
        ax.set_xlabel('Token Position')
        ax.set_ylabel('Sequence Number')
        ax.set_yticks(range(len(texts)))
        ax.set_yticklabels([f'Text {i+1}' for i in range(len(texts))])
        
        # Add colorbar
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    
    plt.tight_layout()
    plt.show()
    
    # Show detailed token analysis for first text
    print("\n🔍 DETAILED ANALYSIS - First Text:")
    print(f"   Original: '{texts[0]}'")
    
    # Show tokens for different strategies
    strategies_detail = [
        ("Right Padded", right_padded['input_ids'][0]),
        ("Fixed Length", fixed_padded['input_ids'][0]),  
        ("Truncated", truncated_padded['input_ids'][0])
    ]
    
    for name, token_ids in strategies_detail:
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
        print(f"   {name:15}: {tokens}")
        print(f"   {'Length':15}: {len([t for t in tokens if t != tokenizer.pad_token])} real tokens, {len(tokens)} total")
    
    return {
        'right_padded': right_padded,
        'fixed_padded': fixed_padded,
        'truncated_padded': truncated_padded
    }

# Run the demonstration
padding_results = demonstrate_padding_strategies()

## Section 3: Attention Masks Deep Dive

Attention masks are crucial for telling the model which tokens to pay attention to and which to ignore. This is especially important when we have padded sequences.

### Understanding Attention Masks:
- **1**: Pay attention to this token (real content)
- **0**: Ignore this token (padding)
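- **Purpose**: Padding positions are excluded from the attention softmax, so they cannot influence the representations of real tokens
- **Effect**: A sequence gets the same output regardless of how much padding its batch adds (up to floating-point noise)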
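
To make the mechanics concrete, here is a minimal sketch of a masked softmax over one row of attention scores, in plain Python. The scores are made-up numbers and `masked_softmax` is a hypothetical helper; real models apply the same idea to tensors, typically by adding a large negative value to masked positions before the softmax.

```python
import math
from typing import List


def masked_softmax(scores: List[float], mask: List[int]) -> List[float]:
    """Softmax over attention scores where mask == 0 positions receive zero weight.

    Assumes at least one position has mask == 1.
    """
    # Masked positions are set to -inf, so exp() maps them to exactly 0.0
    biased = [s if m == 1 else float("-inf") for s, m in zip(scores, mask)]
    peak = max(biased)  # subtract the max for numerical stability
    exps = [math.exp(b - peak) for b in biased]
    total = sum(exps)
    return [e / total for e in exps]


scores = [2.0, 1.0, 0.5, 0.5]   # last two positions are padding
mask = [1, 1, 0, 0]

with_mask = masked_softmax(scores, mask)
without_mask = masked_softmax(scores, [1, 1, 1, 1])
unpadded = masked_softmax(scores[:2], [1, 1])

print("with mask:   ", [round(w, 3) for w in with_mask])
print("without mask:", [round(w, 3) for w in without_mask])
print("unpadded:    ", [round(w, 3) for w in unpadded])
```

With the mask, the weights on the two real tokens match the unpadded case and the padded positions get zero weight. Without it, some attention leaks onto padding and the weights on real tokens shrink.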

Let's explore how attention masks work in detail:

In [None]:
def demonstrate_attention_masks():
    """
    Comprehensive demonstration of attention masks and their importance.
    """
    print("🎭 ATTENTION MASKS DEEP DIVE")
    print("=" * 35)
    
    # Use preferred hate speech detection model
    model_name = "cardiffnlp/twitter-roberta-base-hate-latest"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    # Example texts with very different lengths
    texts = [
        "Love this!",
        "This is a wonderful example of positive communication.",
        "AI"
    ]
    
    print("📝 Processing these texts:")
    for i, text in enumerate(texts):
        print(f"{i+1}. '{text}'")
    
    # Tokenize with padding
    encoded = tokenizer(texts, padding=True, return_tensors="pt", return_attention_mask=True)
    
    print(f"\n📊 Encoded Results:")
    print(f"   Batch shape: {encoded['input_ids'].shape}")
    print(f"   Attention mask shape: {encoded['attention_mask'].shape}")
    
    # Create detailed visualization
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle('Attention Masks: Understanding Token Processing', fontsize=16)
    
    # 1. Input IDs heatmap
    ax1 = axes[0, 0]
    input_ids_viz = encoded['input_ids'].numpy()
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    
    # Create visualization data: 1 for real tokens, 0 for padding
    viz_data = np.where(input_ids_viz == pad_token_id, 0, 1)
    
    # 'RdYlBu' maps 0 (padding) to red and 1 (real token) to blue, matching the title
    im1 = ax1.imshow(viz_data, cmap='RdYlBu', aspect='auto')
    ax1.set_title('Input Tokens\n(Blue = Real Token, Red = Padding)')
    ax1.set_xlabel('Token Position')
    ax1.set_ylabel('Sequence')
    ax1.set_yticks(range(len(texts)))
    ax1.set_yticklabels([f'Text {i+1}' for i in range(len(texts))])
    
    # 2. Attention masks heatmap
    ax2 = axes[0, 1]
    attention_masks = encoded['attention_mask'].numpy()
    # 'RdYlBu' maps 0 (ignore) to red and 1 (attend) to blue, matching the title
    im2 = ax2.imshow(attention_masks, cmap='RdYlBu', aspect='auto')
    ax2.set_title('Attention Masks\n(Blue = Attend, Red = Ignore)')
    ax2.set_xlabel('Token Position')
    ax2.set_ylabel('Sequence')
    ax2.set_yticks(range(len(texts)))
    ax2.set_yticklabels([f'Text {i+1}' for i in range(len(texts))])
    
    # 3. Token-by-token breakdown for first sequence
    ax3 = axes[1, 0]
    first_seq_tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'][0])
    first_seq_mask = encoded['attention_mask'][0].numpy()
    
    colors = ['lightcoral' if mask == 0 else 'lightblue' for mask in first_seq_mask]
    bars = ax3.bar(range(len(first_seq_tokens)), first_seq_mask, color=colors)
    ax3.set_title(f'First Sequence Token Analysis\n\"{texts[0]}\"')
    ax3.set_xlabel('Token Position')
    ax3.set_ylabel('Attention Value')
    ax3.set_xticks(range(len(first_seq_tokens)))
    ax3.set_xticklabels(first_seq_tokens, rotation=45, ha='right')
    ax3.set_ylim(-0.1, 1.1)
    
    # 4. Sequence length comparison
    ax4 = axes[1, 1]
    real_lengths = [torch.sum(mask).item() for mask in encoded['attention_mask']]
    total_length = encoded['input_ids'].shape[1]
    
    x_pos = range(len(texts))
    bars1 = ax4.bar(x_pos, real_lengths, label='Real Tokens', color='lightblue', alpha=0.8)
    bars2 = ax4.bar(x_pos, [total_length - length for length in real_lengths], 
                   bottom=real_lengths, label='Padding Tokens', color='lightcoral', alpha=0.8)
    
    ax4.set_title('Token Composition per Sequence')
    ax4.set_xlabel('Sequence Number')
    ax4.set_ylabel('Number of Tokens')
    ax4.set_xticks(x_pos)
    ax4.set_xticklabels([f'Text {i+1}' for i in range(len(texts))])
    ax4.legend()
    
    # Add value labels
    for i, (real, total) in enumerate(zip(real_lengths, [total_length] * len(texts))):
        ax4.text(i, real/2, str(real), ha='center', va='center', fontweight='bold')
        if total - real > 0:
            ax4.text(i, real + (total-real)/2, str(total-real), ha='center', va='center', fontweight='bold')
    
    plt.tight_layout()
    plt.show()
    
    # Detailed token-by-token analysis
    print("\n🔍 TOKEN-BY-TOKEN ANALYSIS:")
    print("=" * 40)
    
    for i, text in enumerate(texts):
        print(f"\n📄 Sequence {i+1}: '{text}'")
        tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'][i])
        input_ids = encoded['input_ids'][i].tolist()
        attention_mask = encoded['attention_mask'][i].tolist()
        
        print(f"   {'Position':<8} {'Token':<15} {'ID':<8} {'Attention':<10} {'Status'}")
        print("   " + "-" * 55)
        
        for pos, (token, token_id, attn) in enumerate(zip(tokens, input_ids, attention_mask)):
            status = "ATTEND" if attn == 1 else "IGNORE"
            print(f"   {pos:<8} {token:<15} {token_id:<8} {attn:<10} {status}")
    
    print("\n💡 WHY ATTENTION MASKS MATTER:")
    print("   ✅ Prevent model from learning patterns in padding tokens")
    print("   ✅ Ensure attention weights sum correctly over real tokens")
    print("   ✅ Improve training stability and convergence")
    print("   ✅ Enable variable-length sequence processing in batches")
    print("   ✅ Essential for tasks like hate speech detection where text length varies greatly")
    
    return encoded

# Run the demonstration
attention_results = demonstrate_attention_masks()

## Section 4: Handling Longer Sequences

Real-world texts can be very long - social media posts, articles, documents, and conversations. Standard transformer models like BERT have limitations:

### Sequence Length Limitations:
- **BERT**: Maximum 512 tokens
- **RoBERTa**: Maximum 512 tokens  
- **Memory**: Quadratic scaling with sequence length
- **Computational Cost**: Attention complexity is O(n²)

### Strategies for Long Sequences:
1. **Truncation**: Cut sequences to fit (may lose important information)
2. **Sliding Windows**: Process text in overlapping chunks
3. **Hierarchical Processing**: Summarize chunks, then process summaries
4. **Specialized Models**: Longformer, BigBird, etc.
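
The sliding-window strategy can be sketched independently of any tokenizer. The helper below is a hypothetical illustration working on a plain list of token IDs; here `stride` means the step between window starts, so the overlap is `window_size - stride`.

```python
from typing import List


def sliding_windows(token_ids: List[int], window_size: int, stride: int) -> List[List[int]]:
    """Split token_ids into windows of at most window_size, starting every stride tokens."""
    if window_size <= 0 or not 0 < stride <= window_size:
        raise ValueError("need window_size > 0 and 0 < stride <= window_size")
    windows = []
    for start in range(0, len(token_ids), stride):
        windows.append(token_ids[start:start + window_size])
        if start + window_size >= len(token_ids):
            break  # this window already reaches the end of the sequence
    return windows


tokens = list(range(10))  # stand-in for real token IDs
chunks = sliding_windows(tokens, window_size=4, stride=2)
for chunk in chunks:
    print(chunk)
```

Each window is then run through the model separately, and the per-window outputs are combined (for classification, taking the maximum or the mean of the window scores is a common choice). Hugging Face tokenizers can produce overlapping chunks directly via `return_overflowing_tokens=True` together with `max_length` and `stride`; be aware that in that API `stride` is the number of overlapping tokens, not the step size used in this sketch.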

Let's explore these strategies:

In [None]:
def demonstrate_long_sequence_handling():
    """
    Demonstrate strategies for handling sequences longer than model limits.
    """
    print("📏 HANDLING LONGER SEQUENCES")
    print("=" * 35)
    
    # Create a long text example (simulating a long social media thread or article)
    long_text = """
    The field of artificial intelligence has undergone tremendous growth and transformation over the past decade. 
    Machine learning algorithms have become increasingly sophisticated, enabling applications that were once 
    considered science fiction. Natural language processing, in particular, has seen remarkable advances with 
    the introduction of transformer architectures and large language models. These developments have 
    revolutionized how we approach text classification, sentiment analysis, and content moderation tasks.
    
    When dealing with content moderation and hate speech detection, it's crucial to consider context and nuance. 
    Simple keyword-based approaches often fail to capture the complexity of human communication. Modern NLP 
    models can better understand context, sarcasm, and subtle forms of harmful content. However, they also 
    face challenges with longer texts that exceed typical model limitations.
    
    The challenge of processing long sequences is particularly relevant for analyzing extended conversations, 
    comment threads, or lengthy posts. Traditional transformer models have fixed input size limitations, 
    typically 512 tokens for models like BERT and RoBERTa. This constraint requires careful consideration 
    of how to handle longer content while preserving important contextual information.
    """.strip()
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-hate-latest")
    
    # Analyze the long text
    tokens = tokenizer.encode(long_text, add_special_tokens=True)
    word_count = len(long_text.split())
    token_count = len(tokens)
    
    print(f"📊 Long Text Analysis:")
    print(f"   Words: {word_count}")
    print(f"   Tokens: {token_count}")
    print(f"   Characters: {len(long_text)}")
    print(f"   Model limit (typical): 512 tokens")
    print(f"   Overflow: {max(0, token_count - 512)} tokens")
    
    # Strategy 1: Simple truncation
    print("\n1️⃣ SIMPLE TRUNCATION STRATEGY:")
    truncated = tokenizer(long_text, max_length=512, truncation=True, return_tensors="pt")
    truncated_tokens = tokenizer.convert_ids_to_tokens(truncated['input_ids'][0])
    
    print(f"   Kept tokens: {truncated['input_ids'].shape[1]}")
    print(f"   Lost tokens: {token_count - truncated['input_ids'].shape[1]}")
    print(f"   Information loss: {((token_count - truncated['input_ids'].shape[1]) / token_count * 100):.1f}%")
    
    # Strategy 2: Sliding window approach
    print("\n2️⃣ SLIDING WINDOW STRATEGY:")
    window_size = 400  # Leave room for special tokens
    stride = 200  # 50% overlap
    
    windows = []
    for start in range(0, len(tokens), stride):
        end = min(start + window_size, len(tokens))
        window_tokens = tokens[start:end]
        windows.append(window_tokens)
        if end >= len(tokens):
            break
    
    print(f"   Number of windows: {len(windows)}")
    print(f"   Window size: {window_size} tokens")
    print(f"   Stride: {stride} tokens")
    print(f"   Overlap: {window_size - stride} tokens")
    
    # Visualization of strategies
    fig, axes = plt.subplots(2, 2, figsize=(16, 10))
    fig.suptitle('Long Sequence Handling Strategies', fontsize=16)
    
    # 1. Original sequence length
    ax1 = axes[0, 0]
    ax1.barh(['Original', 'Model Limit'], [token_count, 512], color=['lightcoral', 'lightblue'])
    ax1.set_title('Sequence Length vs Model Limit')
    ax1.set_xlabel('Number of Tokens')
    for i, v in enumerate([token_count, 512]):
        ax1.text(v + 10, i, str(v), va='center')
    
    # 2. Truncation loss
    ax2 = axes[0, 1]
    kept = min(token_count, 512)
    lost = max(0, token_count - 512)  # Clamp at 0: pie() rejects negative wedge sizes
    ax2.pie([kept, lost], labels=['Kept', 'Lost'], autopct='%1.1f%%', colors=['lightblue', 'lightcoral'])
    ax2.set_title('Information Loss with Truncation')
    
    # 3. Sliding window coverage
    ax3 = axes[1, 0]
    window_starts = list(range(0, len(tokens), stride))[:len(windows)]
    window_coverage = np.zeros(len(tokens))
    
    for i, start in enumerate(window_starts):
        end = min(start + window_size, len(tokens))
        window_coverage[start:end] += 1
    
    ax3.plot(window_coverage, linewidth=2, color='green')
    ax3.fill_between(range(len(window_coverage)), window_coverage, alpha=0.3, color='green')
    ax3.set_title('Sliding Window Coverage\n(Height = Number of Times Token is Processed)')
    ax3.set_xlabel('Token Position')
    ax3.set_ylabel('Coverage Count')
    ax3.grid(True, alpha=0.3)
    
    # 4. Strategy comparison table
    ax4 = axes[1, 1]
    ax4.axis('tight')
    ax4.axis('off')
    
    # Create comparison table
    strategies_data = [
        ["Strategy", "Pros", "Cons"],
        ["Truncation", "Simple, fast", "Loses information"],
        ["Sliding Window", "Preserves all info", "Redundant processing"],
        ["Specialized Models", "Built for long sequences", "Higher memory usage"]
    ]
    
    table = ax4.table(cellText=strategies_data, cellLoc='center', loc='center')
    table.auto_set_font_size(False)
    table.set_fontsize(9)
    table.scale(1, 2)
    ax4.set_title('Strategy Comparison')
    
    # Color header row
    for i in range(3):
        table[(0, i)].set_facecolor('#E6E6FA')
    
    plt.tight_layout()
    plt.show()
    
    print("\n📋 KEY TAKEAWAYS:")
    print("   📏 BERT- and RoBERTa-style encoders are limited to 512 tokens")
    print("   ✂️  Truncation is simple but loses information")
    print("   🪟 Sliding windows preserve information but increase computation")
    print("   🤖 Specialized models like Longformer handle longer sequences efficiently")
    print("   ⚖️  Choose strategy based on your specific use case and resources")
    
    return {
        'original_tokens': token_count,
        'windows': windows,
        'truncated': truncated
    }

# Run the demonstration
long_sequence_results = demonstrate_long_sequence_handling()

## Section 5: Longformer for Extended Context

**Longformer** is a specialized transformer model designed to handle much longer sequences efficiently. It addresses the quadratic memory and compute complexity of standard transformers.

### Key Innovations:
- **Sparse Attention**: Combines sliding-window (local) attention, optionally dilated, with global attention on a few selected tokens
- **Extended Length**: Can handle up to 4,096 tokens (8x BERT's 512-token limit)
- **Efficient Memory**: Linear scaling with sequence length for a fixed window size
- **Task Performance**: Maintains performance on long document tasks
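
As a rough illustration of why the sparse pattern scales linearly, the sketch below counts how many (query, key) pairs are allowed under a local window plus global tokens, compared with full attention. It is a toy counting exercise, not Longformer's actual implementation; `attended_pairs` and the window of 8 are assumptions chosen for readability.

```python
from typing import Set


def attended_pairs(seq_len: int, window: int, global_positions: Set[int]) -> int:
    """Count (query, key) pairs allowed by local-window plus global attention."""
    half = window // 2
    count = 0
    for q in range(seq_len):
        for k in range(seq_len):
            local = abs(q - k) <= half                      # inside the sliding window
            is_global = q in global_positions or k in global_positions
            if local or is_global:
                count += 1
    return count


n = 64
full = n * n
local_only = attended_pairs(n, window=8, global_positions=set())
sparse = attended_pairs(n, window=8, global_positions={0})  # position 0 acts like a CLS token

print(f"full attention pairs: {full}")
print(f"local window only:    {local_only}")
print(f"local + one global:   {sparse}")
```

Doubling `n` roughly doubles the sparse counts but quadruples the full count, which is the linear versus quadratic behaviour described above.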

Let's explore Longformer concepts and compare with standard models:

In [None]:
def demonstrate_longformer_concepts():
    """
    Demonstrate Longformer concepts and compare with standard attention.
    """
    print("🔬 LONGFORMER: EXTENDED CONTEXT PROCESSING")
    print("=" * 50)
    
    # Create conceptual visualization of attention patterns
    seq_length = 64  # Smaller for visualization
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('Longformer Attention Patterns vs Standard Attention', fontsize=16)
    
    # 1. Standard full attention (BERT-style)
    ax1 = axes[0, 0]
    full_attention = np.ones((seq_length, seq_length))
    im1 = ax1.imshow(full_attention, cmap='Blues')
    ax1.set_title('Standard Full Attention (BERT)\nO(n²) complexity')
    ax1.set_xlabel('Key Position')
    ax1.set_ylabel('Query Position')
    plt.colorbar(im1, ax=ax1)
    
    # 2. Local sliding window attention
    ax2 = axes[0, 1]
    window_size = 8
    local_attention = np.zeros((seq_length, seq_length))
    
    for i in range(seq_length):
        start = max(0, i - window_size // 2)
        end = min(seq_length, i + window_size // 2 + 1)
        local_attention[i, start:end] = 1
    
    im2 = ax2.imshow(local_attention, cmap='Greens')
    ax2.set_title(f'Local Attention (window={window_size})\nO(n×w) complexity')
    ax2.set_xlabel('Key Position')
    ax2.set_ylabel('Query Position')
    plt.colorbar(im2, ax=ax2)
    
    # 3. Global + Local attention (Longformer style)
    ax3 = axes[1, 0]
    global_local_attention = local_attention.copy()
    
    # Add global attention at a few illustrative positions (position 0 plays the role of the CLS token)
    global_tokens = [0, 1, seq_length//4, seq_length//2, seq_length-1]
    for token in global_tokens:
        if token < seq_length:
            global_local_attention[token, :] = 1  # Global tokens attend to all
            global_local_attention[:, token] = 1  # All tokens attend to global
    
    im3 = ax3.imshow(global_local_attention, cmap='Oranges')
    ax3.set_title('Longformer: Global + Local Attention\nEfficient sparse patterns')
    ax3.set_xlabel('Key Position')
    ax3.set_ylabel('Query Position')
    plt.colorbar(im3, ax=ax3)
    
    # 4. Computational complexity comparison
    ax4 = axes[1, 1]
    
    # Calculate complexity for different sequence lengths
    seq_lengths = np.array([128, 256, 512, 1024, 2048, 4096])
    full_complexity = seq_lengths ** 2
    local_complexity = seq_lengths * window_size
    longformer_complexity = seq_lengths * (window_size + len(global_tokens))
    
    ax4.plot(seq_lengths, full_complexity, 'b-', label='Full Attention O(n²)', linewidth=2)
    ax4.plot(seq_lengths, local_complexity, 'g-', label='Local Only O(n×w)', linewidth=2)
    ax4.plot(seq_lengths, longformer_complexity, 'orange', label='Longformer O(n×(w+g))', linewidth=2)
    
    ax4.axvline(x=512, color='blue', linestyle='--', alpha=0.5, label='BERT limit')
    ax4.axvline(x=4096, color='orange', linestyle='--', alpha=0.5, label='Longformer limit')
    
    ax4.set_xlabel('Sequence Length')
    ax4.set_ylabel('Computational Operations')
    ax4.set_title('Computational Complexity Comparison')
    ax4.legend()
    ax4.set_yscale('log')
    ax4.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
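    # Show efficiency gains
    # Note: these counts use the toy window (8) and 5 global tokens from the plots above;
    # allenai/longformer-base-4096 uses a much wider window (512), so real gains are smaller
    print("\n📊 EFFICIENCY COMPARISON (attention score counts, toy window):")
    print(f"   {'Sequence Length':<15} {'Full Attention':<15} {'Longformer':<15} {'Ratio'}")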
    print("   " + "-" * 65)
    
    for seq_len in [512, 1024, 2048, 4096]:
        full_ops = seq_len ** 2
        longformer_ops = seq_len * (window_size + len(global_tokens))
        speedup = full_ops / longformer_ops if longformer_ops > 0 else 0
        
        print(f"   {seq_len:<15} {full_ops:<15,} {longformer_ops:<15,} {speedup:.1f}x")
    
    print("\n💡 LONGFORMER KEY ADVANTAGES:")
    print("   ✅ 8x longer sequences than BERT (4096 vs 512 tokens)")
    print("   ✅ Sparse attention reduces memory complexity from O(n²) to O(n)")
    print("   ✅ Maintains performance on long document tasks")
    print("   ✅ Ideal for analyzing long social media threads or articles")
    print("   ✅ Better context understanding for hate speech detection in long texts")
    
    print("\n🎯 USE CASES FOR LONGFORMER:")
    print("   📄 Long document classification and analysis")
    print("   💬 Extended conversation thread moderation")
    print("   📰 Article sentiment analysis and summarization")
    print("   🧵 Social media thread context understanding")
    print("   📚 Research paper and academic document processing")
    
    # Try to load the real Longformer tokenizer if it is available
    print("\n🔬 TRYING TO LOAD LONGFORMER TOKENIZER:")
    try:
        from transformers import LongformerTokenizer
        
        print("📥 Loading Longformer tokenizer...")
        longformer_tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
        print("✅ Longformer tokenizer loaded successfully!")
        print("   Max sequence length: 4096 tokens")
        print(f"   Vocab size: {longformer_tokenizer.vocab_size}")
        
        # Compare tokenization
        sample_text = "This is a sample text for comparing tokenization approaches."
        bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        
        bert_tokens = bert_tokenizer.encode(sample_text)
        longformer_tokens = longformer_tokenizer.encode(sample_text)
        
        print("\n🔍 Tokenization Comparison:")
        print(f"   Sample: '{sample_text}'")
        print(f"   BERT tokens ({len(bert_tokens)}): {bert_tokens}")
        print(f"   Longformer tokens ({len(longformer_tokens)}): {longformer_tokens}")
        
    except Exception as e:
        print(f"⚠️  Could not load Longformer tokenizer: {e}")
        print("💡 This is expected in some environments. The conceptual demonstration above shows the key ideas.")
    
    return {'demonstration': 'completed'}

# Run the Longformer demonstration
longformer_results = demonstrate_longformer_concepts()
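
The operation counts above come from the estimate n × (w + g). As a rough cross-check, the sketch below builds the sparse attention pattern explicitly and counts the allowed query-key pairs. It is a pure-Python illustration, not Longformer's actual implementation; the function names and the convention that `window_size` is split evenly across both sides of each token are assumptions made for this example.

```python
def longformer_attention_pattern(seq_len, window_size, global_positions):
    """Build a seq_len x seq_len matrix: True where query i may attend to key j."""
    half_window = window_size // 2  # assumption: window is split evenly left/right
    global_set = set(global_positions)
    pattern = []
    for i in range(seq_len):
        row = []
        for j in range(seq_len):
            in_window = abs(i - j) <= half_window
            # Global tokens attend everywhere, and every token attends to them.
            is_global = i in global_set or j in global_set
            row.append(in_window or is_global)
        pattern.append(row)
    return pattern


def count_attended_pairs(pattern):
    """Count the query-key pairs that the pattern allows."""
    return sum(sum(row) for row in pattern)


seq_len = 16
pattern = longformer_attention_pattern(seq_len, window_size=4, global_positions=[0])
sparse_pairs = count_attended_pairs(pattern)
full_pairs = seq_len ** 2

print(f"Full attention pairs:   {full_pairs}")
print(f"Sparse attention pairs: {sparse_pairs}")
```

The exact count differs from n × (w + g) by small constant factors (the window includes the token itself, and each global token contributes both a row and a column), but both grow linearly in n while full attention grows quadratically.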

## Section 6: Best Practices and Performance Optimization

Based on our exploration of multiple sequence handling, let's consolidate the best practices for real-world applications.

### Performance Optimization Strategies:
- **Dynamic Padding**: Only pad to the maximum length in each batch
- **Gradient Accumulation**: Handle larger effective batch sizes with limited memory
- **Mixed Precision**: Use FP16 for faster training and inference
- **Sequence Length Distribution**: Understand your data to optimize processing
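
Of these strategies, gradient accumulation is the least obvious from its description. A minimal sketch follows, assuming a toy linear classifier and random tensors as stand-ins for a real model and tokenized batches (both hypothetical, chosen only to keep the example self-contained). Each micro-batch loss is divided by the number of accumulation steps so that the summed gradients match what a single large batch would produce.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy stand-ins (hypothetical): a linear classifier and random features.
model = nn.Linear(16, 2)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

features = torch.randn(32, 16)
labels = torch.randint(0, 2, (32,))

micro_batch_size = 8
accumulation_steps = 4  # effective batch size = 8 * 4 = 32

optimizer.zero_grad()
for step in range(accumulation_steps):
    start = step * micro_batch_size
    x = features[start:start + micro_batch_size]
    y = labels[start:start + micro_batch_size]
    # Scale the loss so the accumulated gradient equals the full-batch mean gradient.
    loss = loss_fn(model(x), y) / accumulation_steps
    loss.backward()  # gradients add up in .grad across micro-batches

# Keep a copy for inspection, then apply one optimizer update for all micro-batches.
accumulated_grad = model.weight.grad.clone()
optimizer.step()
optimizer.zero_grad()

print(f"Accumulated gradient norm: {accumulated_grad.norm().item():.4f}")
```

Only one micro-batch of activations is held in memory at a time, which is why this helps when a batch of 32 would not fit. Mixed precision is typically layered on top of such a loop with `torch.autocast`; it is left out here because the benefit depends on the hardware.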

The cell below demonstrates dynamic padding and length-distribution profiling on a small sample dataset:

In [None]:
def demonstrate_best_practices():
    """
    Demonstrate best practices for handling multiple sequences efficiently.
    """
    print("🎯 BEST PRACTICES FOR MULTIPLE SEQUENCES")
    print("=" * 50)
    
    # Load the tokenizer of the hate speech detection model (the model weights are not needed here)
    model_name = "cardiffnlp/twitter-roberta-base-hate-latest"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    # Create a realistic dataset with varying lengths
    sample_texts = [
        "Thanks!",
        "This is a great example of positive communication.",
        "I appreciate the constructive discussion in this thread.",
        "AI ethics is an important topic that deserves careful consideration.",
        "Excellent work on this project! Very helpful.",
        "Good job.",
        "The development of responsible AI systems requires collaboration.",
        "Thanks for sharing this resource.",
        "This analysis provides valuable insights.",
        "Perfect!"
    ] * 5  # Repeat to create a larger dataset
    
    print(f"📊 Dataset Analysis ({len(sample_texts)} texts):")
    
    # Analyze length distribution
    lengths = [len(tokenizer.encode(text)) for text in sample_texts]
    
    print(f"   Token lengths - Min: {min(lengths)}, Max: {max(lengths)}, Mean: {np.mean(lengths):.1f}")
    print(f"   Std deviation: {np.std(lengths):.1f}")
    
    # Strategy comparison
    import time
    
    print("\n🏁 STRATEGY PERFORMANCE COMPARISON:")
    
    # Fixed padding (worst case): every sequence is padded to 128 tokens
    start_time = time.perf_counter()
    fixed_padded = tokenizer(sample_texts, padding='max_length', max_length=128,
                            truncation=True, return_tensors="pt")
    fixed_time = time.perf_counter() - start_time
    
    # Dynamic padding: pad only to the longest sequence in the batch
    start_time = time.perf_counter()
    dynamic_padded = tokenizer(sample_texts, padding=True, truncation=True,
                              return_tensors="pt")
    dynamic_time = time.perf_counter() - start_time
    
    print(f"   Fixed padding (128):   Shape {fixed_padded['input_ids'].shape}, Time: {fixed_time:.3f}s")
    print(f"   Dynamic padding:       Shape {dynamic_padded['input_ids'].shape}, Time: {dynamic_time:.3f}s")
    
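    # Calculate memory savings; tokenizer tensors are int64, so ask the tensor for its element size
    fixed_memory = fixed_padded['input_ids'].numel() * fixed_padded['input_ids'].element_size()
    dynamic_memory = dynamic_padded['input_ids'].numel() * dynamic_padded['input_ids'].element_size()
    memory_savings = (fixed_memory - dynamic_memory) / fixed_memory * 100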
    
    print(f"   Memory usage - Fixed: {fixed_memory:,} bytes, Dynamic: {dynamic_memory:,} bytes")
    print(f"   Memory savings: {memory_savings:.1f}%")
    
    # Visualize optimization benefits
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('Multiple Sequence Processing: Optimization Strategies', fontsize=16)
    
    # 1. Length distribution
    ax1 = axes[0, 0]
    ax1.hist(lengths, bins=15, alpha=0.7, color='skyblue', edgecolor='black')
    ax1.axvline(np.mean(lengths), color='red', linestyle='--', label=f'Mean: {np.mean(lengths):.1f}')
    ax1.axvline(np.median(lengths), color='orange', linestyle='--', label=f'Median: {np.median(lengths):.1f}')
    ax1.set_title('Token Length Distribution')
    ax1.set_xlabel('Number of Tokens')
    ax1.set_ylabel('Frequency')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # 2. Memory usage comparison
    ax2 = axes[0, 1]
    strategies = ['Fixed\nPadding', 'Dynamic\nPadding']
    memory_usage = [fixed_memory/1024, dynamic_memory/1024]  # Convert to KB
    colors = ['lightcoral', 'lightgreen']
    
    bars = ax2.bar(strategies, memory_usage, color=colors, alpha=0.8)
    ax2.set_title('Memory Usage Comparison')
    ax2.set_ylabel('Memory Usage (KB)')
    
    # Add value labels
    for bar, value in zip(bars, memory_usage):
        ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1, 
                f'{value:.1f} KB', ha='center', va='bottom')
    
    # 3. Padding efficiency visualization
    ax3 = axes[1, 0]
    
    # Create visual representation of padding efficiency
    sample_batch = sample_texts[:8]
    sample_lengths = [len(tokenizer.encode(text)) for text in sample_batch]
    max_length = max(sample_lengths)
    
    # Dynamic padding visualization
    dynamic_padding_matrix = np.zeros((len(sample_batch), max_length))
    for i, length in enumerate(sample_lengths):
        dynamic_padding_matrix[i, :length] = 1
    
    # 'RdYlBu' maps 1 (content) to blue and 0 (padding) to red, matching the title below
    im3 = ax3.imshow(dynamic_padding_matrix, cmap='RdYlBu', aspect='auto', vmin=0, vmax=1)
    ax3.set_title('Dynamic Padding Efficiency\n(Blue = Content, Red = Padding)')
    ax3.set_xlabel('Token Position')
    ax3.set_ylabel('Sequence Number')
    ax3.set_yticks(range(len(sample_batch)))
    ax3.set_yticklabels([f'Seq {i+1} ({l} tokens)' for i, l in enumerate(sample_lengths)])
    
    # 4. Performance summary
    ax4 = axes[1, 1]
    
    # Illustrative, hand-picked scores (not measured) that summarize the qualitative trade-offs
    metrics = ['Memory\nEfficiency', 'Processing\nSpeed', 'Batch\nUtilization']
    fixed_scores = [60, 70, 40]  # Illustrative scores out of 100
    dynamic_scores = [90, 85, 95]  # Illustrative scores out of 100
    
    x = np.arange(len(metrics))
    width = 0.35
    
    bars1 = ax4.bar(x - width/2, fixed_scores, width, label='Fixed Padding', 
                   color='lightcoral', alpha=0.8)
    bars2 = ax4.bar(x + width/2, dynamic_scores, width, label='Dynamic Padding', 
                   color='lightgreen', alpha=0.8)
    
    ax4.set_title('Qualitative Comparison (illustrative scores, not measured)')
    ax4.set_ylabel('Illustrative Score (0-100)')
    ax4.set_xticks(x)
    ax4.set_xticklabels(metrics)
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    
    # Add value labels
    for bars in [bars1, bars2]:
        for bar in bars:
            height = bar.get_height()
            ax4.text(bar.get_x() + bar.get_width()/2, height + 1,
                    f'{height}', ha='center', va='bottom')
    
    plt.tight_layout()
    plt.show()
    
    # Summary of best practices
    print("\n🏆 BEST PRACTICES SUMMARY:")
    print("=" * 40)
    
    recommendations = [
        ("✅ Dynamic Padding", "Prefer dynamic padding over a fixed max_length"),
        ("✅ Length Grouping", "Group similar-length sequences for efficient batching"),
        ("✅ Attention Masks", "Always use attention masks to handle padding properly"),
        ("✅ Model Choice", "Choose Longformer for consistently long sequences"),
        ("✅ Memory Monitoring", "Monitor GPU memory usage and adjust accordingly"),
        ("✅ Data Profiling", "Understand your data's length distribution")
    ]
    
    for practice, description in recommendations:
        print(f"   {practice:<20}: {description}")
    
    print("\n🚨 COMMON PITFALLS TO AVOID:")
    pitfalls = [
        ("❌ Fixed Max Length", "Wastes memory and computation on short sequences"),
        ("❌ No Attention Masks", "Model attends to padding tokens, which distorts its outputs"),
        ("❌ Ignoring Length Distribution", "Inefficient batching strategy"),
        ("❌ Too Large Batches", "Can cause out-of-memory errors"),
        ("❌ No Truncation Strategy", "Crashes on unexpectedly long sequences")
    ]
    
    for pitfall, description in pitfalls:
        print(f"   {pitfall:<25}: {description}")
    
    return {
        'memory_savings': memory_savings,
        'fixed_memory': fixed_memory,
        'dynamic_memory': dynamic_memory
    }

# Run the best practices demonstration
optimization_results = demonstrate_best_practices()
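
The "Length Grouping" recommendation is not exercised by the cell above, so here is a small sketch of the idea. It uses made-up token lengths (an assumption for illustration, not measurements from the dataset above) and hypothetical helper names, and compares the padded size of batches formed in arrival order against batches formed after sorting by length.

```python
def padded_token_count(lengths, batch_size):
    """Total slots (content + padding) when each batch is padded to its own longest sequence."""
    total = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        total += max(batch) * len(batch)
    return total


def group_by_length(lengths, batch_size):
    """Return batches of indices so that similar-length sequences share a batch."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    return [order[start:start + batch_size] for start in range(0, len(order), batch_size)]


# Hypothetical token lengths for eight sequences
token_lengths = [3, 12, 4, 15, 5, 11, 3, 14]
batch_size = 2

arrival_order_slots = padded_token_count(token_lengths, batch_size)
grouped_batches = group_by_length(token_lengths, batch_size)
grouped_lengths = [token_lengths[i] for batch in grouped_batches for i in batch]
grouped_slots = padded_token_count(grouped_lengths, batch_size)

print(f"Content tokens:           {sum(token_lengths)}")
print(f"Slots, arrival order:     {arrival_order_slots}")
print(f"Slots, grouped by length: {grouped_slots}")
```

Sorting the whole dataset removes shuffling, so in training a common compromise is to sort within larger random chunks rather than globally.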

---

## 📋 Summary

### 🔑 Key Concepts Mastered
- **Multiple Sequence Challenge**: Understanding why variable-length sequences create processing difficulties
- **Padding and Truncation**: Strategic approaches to normalize sequence lengths for batch processing
- **Attention Masks**: Critical mechanism for distinguishing real content from padding tokens
- **Long Sequence Handling**: Techniques for processing sequences beyond standard model limits
- **Longformer Architecture**: Specialized model design for extended context understanding
- **Attention Visualization**: Methods to understand model behavior across different sequence types
- **Performance Optimization**: Best practices for efficient multiple sequence processing

### 📈 Best Practices Learned
- Use dynamic padding instead of fixed max_length for memory efficiency
- Always implement proper attention masking to prevent padding interference
- Group sequences by similar lengths for optimal batching efficiency
- Choose a long-context model such as Longformer for consistently long sequences
- Profile your data's length distribution to inform processing strategies
- Monitor memory usage and adjust batch sizes accordingly
- Visualize attention patterns to understand model behavior and debug issues

### 🚀 Next Steps
- **Advanced Training**: Explore gradient accumulation and mixed precision training
- **Custom Models**: Implement custom attention mechanisms for specific use cases
- **Production Deployment**: Learn about model serving and inference optimization
- **Evaluation Metrics**: Measure how sequence length affects model performance
- **Data Preprocessing**: Advanced techniques for handling diverse text sources

---

## About the Author

**Vu Hung Nguyen** - AI Engineer & Researcher

Connect with me:
- 🌐 **Website**: [vuhung16au.github.io](https://vuhung16au.github.io/)
- 💼 **LinkedIn**: [linkedin.com/in/nguyenvuhung](https://www.linkedin.com/in/nguyenvuhung/)
- 💻 **GitHub**: [github.com/vuhung16au](https://github.com/vuhung16au/)

*This notebook is part of the [HF Transformer Trove](https://github.com/vuhung16au/hf-transformer-trove) educational series.*