<!-- 
Tutorial Information
====================
File: Attention_Analysis_Tutorial.ipynb
Last Updated: November 2, 2025
Author: YANG, HENG <hy345@exeter.ac.uk>
Version: 2.0 (Comprehensive Review & Enhancement)

Prerequisites:
- Python 3.8+
- OmniGenBench library (pip install omnigenbench)
- Basic understanding of genomic sequences and PyTorch
- Recommended: CUDA-capable GPU (CPU works but slower)

Estimated Time: 30-45 minutes
Difficulty Level: Intermediate

Learning Path:
1. Complete this tutorial first
2. Then explore: Genomic Embeddings Tutorial
3. Advanced: RNA Secondary Structure Prediction
-->

# Attention Score Extraction from Genomic Foundation Models

## Tutorial Overview

This tutorial demonstrates how to extract and analyze attention patterns from genomic foundation models (GFMs) to understand sequence representations learned by transformer architectures.

### Learning Objectives

By completing this tutorial, you will be able to:

1. **Extract** attention scores from genomic sequences using OmniGenBench models
2. **Analyze** attention patterns using statistical metrics
3. **Visualize** attention heatmaps to interpret model focus
4. **Compare** attention patterns across different sequences
5. **Apply** attention extraction to any OmniModel type (embedding, classification, regression, etc.)

### What is Attention?

Attention mechanisms in transformer models assign importance weights to different positions in a sequence, enabling the model to "focus" on relevant features. In genomic contexts, attention patterns can offer hints about:

- **Motif recognition**: Which nucleotide positions interact
- **Structural dependencies**: Long-range relationships in sequences
- **Feature importance**: What the model considers relevant for predictions
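
To make the weighting concrete, the following is a minimal pure-Python sketch of scaled dot-product attention: each query is compared with every key, and a softmax turns the scores into one row of weights. The helper names and the toy vectors are illustrative assumptions, not part of OmniGenBench or of any model's real parameters.

```python
import math

def softmax(scores):
    """Convert raw scores into positive weights that sum to 1."""
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(queries, keys):
    """Return one row of attention weights per query (scaled dot-product)."""
    dim = len(keys[0])
    rows = []
    for q in queries:
        scores = [
            sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(dim)
            for k in keys
        ]
        rows.append(softmax(scores))
    return rows

# Toy 2-dimensional vectors for three positions (illustrative values only)
toy_vectors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
toy_attention = attention_weights(toy_vectors, toy_vectors)

for position, row in enumerate(toy_attention):
    print(position, [round(w, 3) for w in row])
```

Each printed row is a probability distribution over positions. A real transformer computes many such matrices, one per layer and head, from learned projections of the token embeddings.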

### Why Attention Extraction Matters

- **Model interpretability**: Understand what the model "sees" in genomic sequences
- **Biological insights**: Discover sequence patterns the model associates with functions
- **Model debugging**: Identify potential biases or unexpected attention patterns
- **Transfer learning**: Use attention patterns to guide feature engineering

### Key Feature: Universal Attention Support

**ALL OmniModel types support attention extraction** through the `EmbeddingMixin` base class:

| Model Type | Primary Purpose | Attention Support |
|------------|----------------|-------------------|
| `OmniModelForEmbedding` | Embedding extraction | ✓ |
| `OmniModelForSequenceClassification` | Sequence-level classification | ✓ |
| `OmniModelForSequenceRegression` | Sequence-level regression | ✓ |
| `OmniModelForTokenClassification` | Token-level prediction | ✓ |
| `OmniModelForMLM` | Masked language modeling | ✓ |

You don't need a separate "embedding model" – any task-specific model provides attention extraction!

### Prerequisites

- **Python**: 3.8+
- **Environment**: CPU or CUDA-capable GPU (GPU recommended for large models)
- **Knowledge**: Basic understanding of Python, PyTorch tensors, and genomic sequences
- **Packages**: `omnigenbench`, `torch`, `matplotlib`, `seaborn`, `numpy`

### References

- [Attention Is All You Need](https://arxiv.org/abs/1706.03762) - Original Transformer paper
- [A Primer in BERTology: What We Know About How BERT Works](https://arxiv.org/abs/2002.12327) - Attention analysis in language models
- [OmniGenBench Documentation](https://github.com/yangheng95/OmniGenBench) - Framework documentation

## 1. Environment Setup and Configuration

### 1.1 Installation

First, ensure OmniGenBench is installed. This may take a few minutes on the first run.

In [None]:
# Install OmniGenBench if not already available
!pip install omnigenbench -U -q

# Verify installation
import omnigenbench
print(f"OmniGenBench version: {omnigenbench.__version__}")

### 1.2 Import Required Libraries

In [None]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import warnings
from pathlib import Path

# Import various model types - ALL support attention extraction through EmbeddingMixin
from omnigenbench import (
    OmniModelForEmbedding,
    OmniModelForSequenceClassification,
    OmniModelForSequenceRegression,
    OmniTokenizer,
)

# Configure plotting
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 11,
    'axes.labelsize': 12,
    'axes.titlesize': 14,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.dpi': 100,
})
sns.set_palette('husl')

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)

print("[SUCCESS] Imports successful!")
print("[INFO] All OmniModel types support attention extraction via EmbeddingMixin")

### 1.3 Configuration (Single Source of Truth)

Define all configurable parameters in one place for easy modification.

In [None]:
# ==================== Configuration ====================
# Single source of truth for all parameters - modify here as needed

# Model Configuration
MODEL_NAME = "yangheng/OmniGenome-186M"  # Pre-trained genomic foundation model
TRUST_REMOTE_CODE = True  # Required for custom model architectures

# Sequence Processing
MAX_LENGTH = 128  # Maximum sequence length for tokenization
BATCH_SIZE = 4    # Number of sequences to process in parallel

# Attention Extraction
EXTRACT_ALL_LAYERS = None  # None = all layers, or specify list like [0, 5, 11]
EXTRACT_ALL_HEADS = None   # None = all heads, or specify list like [0, 1, 2]

# Device Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Output Configuration
OUTPUT_DIR = Path("attention_outputs")
OUTPUT_DIR.mkdir(exist_ok=True)

# Reproducibility
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# ==================== Environment Info ====================
print("[INFO] Configuration Summary:")
print(f"  Model: {MODEL_NAME}")
print(f"  Device: {DEVICE}")
print(f"  CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"  CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"  CUDA version: {torch.version.cuda}")
print(f"  PyTorch version: {torch.__version__}")
print(f"  Max sequence length: {MAX_LENGTH}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Random seed: {RANDOM_SEED}")
print(f"  Output directory: {OUTPUT_DIR}")
print("[SUCCESS] Configuration complete!")

## 2. Model Loading

### Choosing the Right Model Type

The attention extraction functionality is available in **all OmniModel types** through the `EmbeddingMixin` base class. Choose based on your use case:

- **`OmniModelForEmbedding`**: Use when you only need embeddings and attention (no task-specific head)
- **`OmniModelForSequenceClassification`**: Use when you need classification + attention analysis
- **`OmniModelForSequenceRegression`**: Use when you need regression + attention analysis
- **`OmniModelForTokenClassification`**: Use when you need token-level predictions + attention

For this tutorial, we'll use `OmniModelForEmbedding` as it's the simplest option focused on representation learning.

In [None]:
print(f"[INFO] Loading model: {MODEL_NAME}")
print(f"[INFO] Target device: {DEVICE}")

try:
    # Load the embedding model
    # OmniModelForEmbedding automatically loads both the model and tokenizer
    model = OmniModelForEmbedding(
        config_or_model=MODEL_NAME,
        trust_remote_code=TRUST_REMOTE_CODE
    )
    
    # Move to device and set to evaluation mode
    model = model.to(DEVICE)
    model.eval()
    
    # Verify model capabilities
    assert hasattr(model, 'extract_attention_scores'), \
        "Model must have extract_attention_scores method"
    assert hasattr(model, 'encode'), \
        "Model must have encode method"
    
    print(f"[SUCCESS] Model loaded: {type(model).__name__}")
    print(f"[INFO] Model supports:")
    print("  - Attention extraction (extract_attention_scores)")
    print("  - Embedding generation (encode, batch_encode)")
    print("  - Similarity computation (compute_similarity)")
    print("  - Attention visualization (visualize_attention_pattern)")
    print("  - Attention statistics (get_attention_statistics)")
    
except Exception as e:
    print(f"[ERROR] Failed to load model: {e}")
    raise

# Alternative: Load task-specific models (these also support attention extraction)
# 
# Option 2: Classification model
# tokenizer = OmniTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=TRUST_REMOTE_CODE)
# model = OmniModelForSequenceClassification(
#     config_or_model=MODEL_NAME,
#     tokenizer=tokenizer,
#     num_labels=2,
#     trust_remote_code=TRUST_REMOTE_CODE
# ).to(DEVICE).eval()
#
# Option 3: Regression model
# tokenizer = OmniTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=TRUST_REMOTE_CODE)
# model = OmniModelForSequenceRegression(
#     config_or_model=MODEL_NAME,
#     tokenizer=tokenizer,
#     num_labels=1,
#     trust_remote_code=TRUST_REMOTE_CODE
# ).to(DEVICE).eval()

## 3. Prepare Test Sequences

We'll use diverse genomic sequences to demonstrate different attention patterns:

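1. **Regular sequence**: Balanced nucleotide composition
2. **GC/AT block sequence**: Alternating G+C-only and A+T-only blocks (50% GC overall); substitute a GC-richer sequence to probe promoter-like composition
3. **Repeat pattern**: Short homopolymer runs (TTTT, AAAA, CCCC, GGGG)
4. **Long RNA sequence**: Extended sequence written with U instead of T, to test attention across distances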

These sequences represent common patterns in genomics and will help illustrate how attention mechanisms capture different sequence features.

In [None]:
# Define test sequences with different characteristics
test_sequences = [
    "ATCGATCGATCGTAGCTAGCTAGCT",     # Regular: Mixed nucleotides (25 bp)
    "GGCCTTAACCGGTTAACCGGTTAA",      # GC/AT blocks: 50% GC content (24 bp)
    "TTTTAAAACCCCGGGGTTTTAAAA",      # Repeat: Homopolymer runs of 4 (24 bp)
    "AUGCGAUCUCGAGCUACGUCGAUGCUAGCUCGAUGGCAUCCGAUUCGAGCUACGUCGAUGCUAG",  # RNA: Longer sequence (64 nt)
]

# Analyze sequence properties
print("[INFO] Test sequence properties:")
print("=" * 70)
for i, seq in enumerate(test_sequences, 1):
    length = len(seq)
    gc_content = (seq.count('G') + seq.count('C') + seq.count('g') + seq.count('c')) / length * 100
    seq_type = "RNA" if 'U' in seq.upper() else "DNA"
    
    print(f"Sequence {i} ({seq_type}):")
    print(f"  Length: {length} bp")
    print(f"  GC%: {gc_content:.1f}%")
    print(f"  Preview: {seq[:40]}{'...' if length > 40 else ''}")
    print()

print(f"[SUCCESS] Prepared {len(test_sequences)} test sequences")

## 4. Single Sequence Attention Extraction

### 4.1 Extract Attention Scores

The `extract_attention_scores()` method returns a dictionary with `attentions`, `tokens`, and `attention_mask` entries. By default, `attentions` holds the weights for all transformer layers and heads, in this format:

```
attention_tensor: (num_layers, num_heads, seq_len, seq_len)
```

Where:
- **num_layers**: Number of transformer layers in the model
- **num_heads**: Number of attention heads per layer
- **seq_len**: Tokenized sequence length (including special tokens)
- **Attention values**: Range [0, 1], sum to 1 across the last dimension (softmax normalized)
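
The index order of this layout can be illustrated with a toy stand-in. The sketch below uses plain nested lists rather than a real tensor, and every value in it is made up; `nested_shape` is a hypothetical helper written only for this example.

```python
# Toy stand-in for the 4-D attention tensor, built from nested lists:
# 2 layers x 2 heads x 3 query positions x 3 key positions.
uniform_row = [1 / 3, 1 / 3, 1 / 3]
uniform_head = [uniform_row[:] for _ in range(3)]
self_focused = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

toy_attentions = [
    [uniform_head, self_focused],  # layer 0: heads 0 and 1
    [self_focused, uniform_head],  # layer 1: heads 0 and 1
]

def nested_shape(nested):
    """Return the shape of a regular nested list as a tuple."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return tuple(shape)

shape = nested_shape(toy_attentions)

# Index order is [layer][head][query position][key position]
weight = toy_attentions[0][1][2][2]  # layer 0, head 1: token 2 attending to itself

# Every query row is a probability distribution over key positions
row_sums = [sum(row) for layer in toy_attentions for head in layer for row in head]

print(shape, weight)
```

With a PyTorch tensor the same lookup would be written `attentions[layer, head, query, key]`, and the row-sum check corresponds to summing over the last dimension.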

In [None]:
# Select first sequence for detailed analysis
sequence = test_sequences[0]

print(f"[INFO] Analyzing sequence: {sequence}")
print(f"[INFO] Sequence length: {len(sequence)} bp")
print("[INFO] Extracting attention scores...")

try:
    attention_result = model.extract_attention_scores(
        sequence=sequence,
        max_length=MAX_LENGTH,
        layer_indices=EXTRACT_ALL_LAYERS,  # None = all layers
        head_indices=EXTRACT_ALL_HEADS,     # None = all heads
        return_on_cpu=True  # Transfer to CPU to save GPU memory
    )
    
    # Validate result structure
    assert 'attentions' in attention_result, "Missing 'attentions' key"
    assert 'tokens' in attention_result, "Missing 'tokens' key"
    assert 'attention_mask' in attention_result, "Missing 'attention_mask' key"
    
    attentions = attention_result['attentions']
    tokens = attention_result['tokens']
    attention_mask = attention_result['attention_mask']
    
    # Verify attention tensor properties
    assert attentions.ndim == 4, f"Expected 4D tensor, got {attentions.ndim}D"
    assert attentions.shape[2] == attentions.shape[3], "Attention matrix must be square"
    
    # Extract dimensions
    num_layers, num_heads, seq_len, _ = attentions.shape
    num_tokens = len(tokens)
    
    print(f"[SUCCESS] Attention extraction complete!")
    print(f"[INFO] Attention tensor shape: {attentions.shape}")
    print(f"       Format: (layers={num_layers}, heads={num_heads}, seq_len={seq_len}, seq_len={seq_len})")
    print(f"[INFO] Tokenized into {num_tokens} tokens")
    print(f"[INFO] First 10 tokens: {tokens[:10]}")
    print(f"[INFO] Attention value range: [{attentions.min():.4f}, {attentions.max():.4f}]")
    
    # Verify softmax normalization (rows sum to 1)
    row_sums = attentions[0, 0, :, :].sum(dim=-1)
    print(f"[INFO] Attention row sums (should be ~1.0): min={row_sums.min():.4f}, max={row_sums.max():.4f}")
    
except Exception as e:
    print(f"[ERROR] Attention extraction failed: {e}")
    raise

### 4.2 Compute Attention Statistics

Attention statistics help quantify attention behavior:

- **Attention matrix**: Aggregated attention across layers and heads
- **Attention entropy**: Measures how "spread out" attention is (higher = more uniform)
- **Attention concentration**: Measures how "focused" attention is (higher = more peaked)
- **Self-attention scores**: Diagonal values showing how much each token attends to itself
- **Max attention per position**: Identifies which tokens receive the most attention
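
As a rough illustration of what such metrics measure, the sketch below computes Shannon entropy and a simple concentration score for single attention rows. The definitions are assumptions chosen for clarity (in particular, concentration is taken to be the largest weight in a row); the formulas used inside `get_attention_statistics()` may differ.

```python
import math

def row_entropy(weights):
    """Shannon entropy (in nats) of one attention row; zero weights contribute nothing."""
    return -sum(w * math.log(w) for w in weights if w > 0)

def row_concentration(weights):
    """Largest single weight in the row (assumed definition, for illustration)."""
    return max(weights)

def self_attention_scores(matrix):
    """Diagonal of a square attention matrix: how much each token attends to itself."""
    return [matrix[i][i] for i in range(len(matrix))]

uniform = [0.25, 0.25, 0.25, 0.25]   # attention spread evenly
peaked = [0.97, 0.01, 0.01, 0.01]    # attention focused on one position

print(f"uniform: entropy={row_entropy(uniform):.3f}, concentration={row_concentration(uniform):.2f}")
print(f"peaked:  entropy={row_entropy(peaked):.3f}, concentration={row_concentration(peaked):.2f}")
```

The uniform row reaches the maximum possible entropy for four positions, `log(4)`, while the peaked row has low entropy and high concentration, which is the inverse relationship the bullet points describe.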

In [None]:
# Compute comprehensive attention statistics
print("[INFO] Computing attention statistics...")

try:
    stats = model.get_attention_statistics(
        attention_result['attentions'],
        attention_result['attention_mask'],
        layer_aggregation="mean",  # Options: mean, max, sum, first, last
        head_aggregation="mean"    # Options: mean, max, sum
    )
    
    # Validate statistics
    assert 'attention_matrix' in stats, "Missing attention_matrix"
    assert 'attention_entropy' in stats, "Missing attention_entropy"
    assert 'attention_concentration' in stats, "Missing attention_concentration"
    assert 'self_attention_scores' in stats, "Missing self_attention_scores"
    assert 'max_attention_per_position' in stats, "Missing max_attention_per_position"
    
    print("[SUCCESS] Statistics computed!")
    print("=" * 70)
    print(f"Attention Matrix Shape: {stats['attention_matrix'].shape}")
    print(f"  Expected: (seq_len, seq_len) aggregated across layers and heads")
    print()
    print(f"Attention Entropy:")
    print(f"  Mean: {stats['attention_entropy'].mean():.4f}")
    print(f"  Std:  {stats['attention_entropy'].std():.4f}")
    print(f"  Range: [{stats['attention_entropy'].min():.4f}, {stats['attention_entropy'].max():.4f}]")
    print(f"  Interpretation: Higher entropy = attention more spread out")
    print()
    print(f"Attention Concentration:")
    print(f"  Mean: {stats['attention_concentration'].mean():.4f}")
    print(f"  Max:  {stats['attention_concentration'].max():.4f}")
    print(f"  Interpretation: Higher concentration = attention more focused")
    print()
    print(f"Self-Attention Scores:")
    print(f"  Mean: {stats['self_attention_scores'].mean():.4f}")
    print(f"  Interpretation: How much each token attends to itself")
    print()
    print(f"Max Attention Per Position (first 5):")
    for i, val in enumerate(stats['max_attention_per_position'][:5]):
        token = tokens[i] if i < len(tokens) else "N/A"
        print(f"  Position {i} ({token}): {val:.4f}")
        
except Exception as e:
    print(f"[ERROR] Statistics computation failed: {e}")
    raise

### 4.3 Visualize Attention Patterns

Attention heatmaps visualize which positions attend to which. Key observations:

- **Diagonal patterns**: Strong diagonal indicates local attention (each token attends to nearby tokens)
- **Vertical/horizontal lines**: Specific tokens receiving/giving attention to many others
- **Block patterns**: Related regions attending to each other (e.g., motif recognition)
- **Special tokens**: [CLS], [SEP], [PAD] tokens often show distinct patterns
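
Some of these visual patterns can also be checked numerically. The following sketch, using hypothetical helper names and hand-written toy matrices, measures how much attention lies near the diagonal and finds the key position that receives the most attention (a "vertical line" in the heatmap).

```python
def local_attention_fraction(matrix, window=1):
    """Average share of each row's attention within `window` positions of the diagonal."""
    fractions = []
    for i, row in enumerate(matrix):
        near = sum(w for j, w in enumerate(row) if abs(i - j) <= window)
        total = sum(row)
        fractions.append(near / total if total > 0 else 0.0)
    return sum(fractions) / len(fractions)

def most_attended_position(matrix):
    """Index of the column (key position) with the largest total attention."""
    column_totals = [sum(row[j] for row in matrix) for j in range(len(matrix[0]))]
    return column_totals.index(max(column_totals))

# Toy matrix with a strong diagonal (local attention)
diagonal_heavy = [
    [0.8, 0.2, 0.0, 0.0],
    [0.1, 0.8, 0.1, 0.0],
    [0.0, 0.1, 0.8, 0.1],
    [0.0, 0.0, 0.2, 0.8],
]

# Toy matrix where nearly every token attends to position 0 (vertical line)
first_token_sink = [
    [1.0, 0.0, 0.0, 0.0],
    [0.9, 0.1, 0.0, 0.0],
    [0.9, 0.0, 0.1, 0.0],
    [0.9, 0.0, 0.0, 0.1],
]

print(local_attention_fraction(diagonal_heavy, window=1))
print(local_attention_fraction(first_token_sink, window=0))
print(most_attended_position(first_token_sink))
```

A vertical line at the first position is a common sight when a special token such as `[CLS]` absorbs much of the attention, so a result like this is worth checking against the token list before reading biology into it.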

In [None]:
# Visualize attention pattern for a specific layer and head
print("[INFO] Generating attention heatmap...")

try:
    layer_idx = 0   # First layer (change to -1 for last layer)
    head_idx = 0    # First attention head
    save_path = OUTPUT_DIR / "attention_heatmap_layer0_head0.png"
    
    fig = model.visualize_attention_pattern(
        attention_result=attention_result,
        layer_idx=layer_idx,
        head_idx=head_idx,
        save_path=str(save_path),
        figsize=(14, 12)
    )
    
    if fig is not None:
        print(f"[SUCCESS] Attention heatmap generated!")
        print(f"[INFO] Saved to: {save_path}")
        print(f"[INFO] Visualizing Layer {layer_idx}, Head {head_idx}")
        plt.tight_layout()
        plt.show()
    else:
        print("[WARNING] Visualization skipped (matplotlib not available or error occurred)")
        
except Exception as e:
    print(f"[ERROR] Visualization failed: {e}")
    # Continue execution even if visualization fails
    import traceback
    traceback.print_exc()

# Interpretation guide
print("\n[INFO] Heatmap Interpretation Guide:")
print("  - Bright colors = high attention weights (strong relationship)")
print("  - Dark colors = low attention weights (weak relationship)")
print("  - Diagonal = tokens attending to themselves or nearby positions")
print("  - Vertical lines = tokens that many others attend to (important positions)")
print("  - Horizontal lines = tokens that attend to many others (query-rich positions)")

## 5. Batch Attention Extraction

### 5.1 Efficient Batch Processing

When analyzing multiple sequences, batch processing is more efficient than individual extraction:

- **Memory efficiency**: Processes sequences in batches to avoid OOM errors
- **Computational efficiency**: Leverages GPU parallelism
- **Layer/head selection**: Can limit extraction to specific layers/heads to reduce memory

**Best Practices:**
- Use smaller batch sizes if encountering memory issues
- Extract only needed layers (e.g., `[0, -1]` for first and last)
- Enable `return_on_cpu=True` to free GPU memory immediately
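
Conceptually, a `batch_size` argument splits the input into consecutive chunks that are processed one at a time. The sketch below shows that chunking with a hypothetical `iter_batches` helper; it is not the library's implementation, whose internals are not shown in this tutorial.

```python
def iter_batches(items, batch_size):
    """Yield consecutive slices of `items` holding at most `batch_size` elements."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

# Toy sequences (illustrative only)
toy_sequences = ["ATCG", "GGCC", "TTAA", "CGCG", "ATAT"]
batches = list(iter_batches(toy_sequences, batch_size=2))

for index, batch in enumerate(batches):
    print(f"batch {index}: {batch}")
```

Halving the batch size roughly halves the peak memory needed per step, at the cost of more steps, which is why reducing it is the first thing to try after an out-of-memory error.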

In [None]:
# Extract attention from multiple sequences efficiently
print("[INFO] Extracting attention from batch of sequences...")
print(f"[INFO] Processing {len(test_sequences[:3])} sequences")
print(f"[INFO] Batch size: {BATCH_SIZE // 2}")  # Use smaller batch for demo

try:
    batch_results = model.batch_extract_attention_scores(
        sequences=test_sequences[:3],  # First 3 sequences
        batch_size=BATCH_SIZE // 2,    # Smaller batch size for memory efficiency
        max_length=MAX_LENGTH,
        layer_indices=[0, -1],  # First and last layer only (reduce memory)
        head_indices=[0, 1, 2], # First 3 heads only (reduce memory)
        return_on_cpu=True      # Free GPU memory immediately
    )
    
    # Validate batch results
    assert len(batch_results) == 3, f"Expected 3 results, got {len(batch_results)}"
    
    print(f"[SUCCESS] Batch attention extraction complete!")
    print(f"[INFO] Processed {len(batch_results)} sequences")
    print("=" * 70)
    
    for i, result in enumerate(batch_results, 1):
        assert 'attentions' in result, f"Result {i} missing attentions"
        assert 'tokens' in result, f"Result {i} missing tokens"
        
        attn_shape = result['attentions'].shape
        num_tokens = len(result['tokens'])
        
        print(f"Sequence {i}:")
        print(f"  Attention shape: {attn_shape}")
        print(f"  Format: (layers={attn_shape[0]}, heads={attn_shape[1]}, seq_len={attn_shape[2]}, seq_len={attn_shape[3]})")
        print(f"  Number of tokens: {num_tokens}")
        print(f"  First 5 tokens: {result['tokens'][:5]}")
        print()
        
except Exception as e:
    print(f"[ERROR] Batch extraction failed: {e}")
    raise

### 5.2 Compare Attention Patterns Across Sequences

Comparing statistics across sequences helps identify:

- **Sequence complexity**: More complex sequences may show higher entropy
- **Structural patterns**: Repeat sequences may show lower entropy (more predictable)
- **Attention sharpness**: High concentration means attention mass sits on a few positions; by itself it does not show that the model is "confident" about relationships

In [None]:
# Compare attention patterns between different sequences
print("[INFO] Comparing attention patterns across sequences...")
print("=" * 70)

try:
    comparison_data = []
    
    for i, result in enumerate(batch_results, 1):
        stats = model.get_attention_statistics(
            result['attentions'],
            result['attention_mask']
        )
        
        # Extract key metrics
        entropy_mean = stats['attention_entropy'].mean().item()
        concentration_mean = stats['attention_concentration'].mean().item()
        self_attn_mean = stats['self_attention_scores'].mean().item()
        
        comparison_data.append({
            'seq_idx': i,
            'entropy': entropy_mean,
            'concentration': concentration_mean,
            'self_attention': self_attn_mean
        })
        
        seq_preview = test_sequences[i-1][:30] + ("..." if len(test_sequences[i-1]) > 30 else "")
        print(f"Sequence {i}: {seq_preview}")
        print(f"  Attention entropy:       {entropy_mean:.4f}")
        print(f"  Attention concentration: {concentration_mean:.4f}")
        print(f"  Self-attention:          {self_attn_mean:.4f}")
        print()
    
    # Summary comparison
    print("[INFO] Summary:")
    entropy_vals = [d['entropy'] for d in comparison_data]
    print(f"  Entropy range: [{min(entropy_vals):.4f}, {max(entropy_vals):.4f}]")
    print(f"  Highest entropy: Sequence {entropy_vals.index(max(entropy_vals)) + 1} (more dispersed attention)")
    print(f"  Lowest entropy:  Sequence {entropy_vals.index(min(entropy_vals)) + 1} (more focused attention)")
    
except Exception as e:
    print(f"[ERROR] Comparison failed: {e}")
    raise

## 6. Combined Attention and Embedding Extraction

### Unified Interface for Attention and Embeddings

Since all OmniModel types inherit from `EmbeddingMixin`, you can extract **both** attention and embeddings from the same model:

- **Attention**: Shows *how* the model processes sequences (relationships between positions)
- **Embeddings**: Show *what* the model learned (fixed-length representations)

This is useful for:
- **Visualization**: Combine attention with embedding-based clustering/UMAP
- **Transfer learning**: Use both attention patterns and embeddings as features
- **Model debugging**: Compare attention behavior with embedding similarity
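
Embedding similarity is usually measured with cosine similarity, which compares the direction of two vectors and ignores their length. The sketch below is a plain-Python version for illustration; the function name and the toy vectors are assumptions, and the zero-vector convention is a choice made for this example.

```python
import math

def cosine_similarity(vec_a, vec_b):
    """Cosine of the angle between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(vec_a, vec_b))
    norm_a = math.sqrt(sum(a * a for a in vec_a))
    norm_b = math.sqrt(sum(b * b for b in vec_b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)

# Toy 3-dimensional "embeddings" (illustrative values only)
embedding_a = [0.2, 0.9, 0.1]
embedding_b = [0.25, 0.8, 0.05]
embedding_c = [-0.9, 0.1, 0.4]

print(f"a vs b: {cosine_similarity(embedding_a, embedding_b):.4f}")
print(f"a vs c: {cosine_similarity(embedding_a, embedding_c):.4f}")
```

Values near 1 indicate vectors pointing in nearly the same direction, values near 0 indicate unrelated directions, and negative values indicate opposing directions.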

In [None]:
# Extract embeddings from the same model used for attention
print("[INFO] Extracting embeddings from the same model...")

try:
    # Single sequence embedding
    single_embedding = model.encode(
        test_sequences[0],
        max_length=MAX_LENGTH,
        agg="mean"  # Options: mean, head, tail
    )
    print(f"[SUCCESS] Single sequence embedding shape: {single_embedding.shape}")
    
    # Batch embedding extraction
    batch_embeddings = model.batch_encode(
        test_sequences[:3],
        batch_size=BATCH_SIZE,
        max_length=MAX_LENGTH,
        agg="mean"
    )
    print(f"[SUCCESS] Batch embeddings shape: {batch_embeddings.shape}")
    print(f"           Format: (num_sequences={batch_embeddings.shape[0]}, embedding_dim={batch_embeddings.shape[1]})")
    
    # Compute pairwise similarity
    print("\n[INFO] Computing pairwise sequence similarities:")
    print("=" * 70)
    for i in range(len(batch_embeddings)):
        for j in range(i + 1, len(batch_embeddings)):
            similarity = model.compute_similarity(
                batch_embeddings[i],
                batch_embeddings[j]
            )
            print(f"  Sequence {i+1} vs Sequence {j+1}: {similarity:.4f}")
    
    print(f"\n[SUCCESS] Both attention and embeddings extracted from the same model!")
    print("[INFO] This unified interface works with ALL OmniModel types:")
    print("       - OmniModelForEmbedding")
    print("       - OmniModelForSequenceClassification")
    print("       - OmniModelForSequenceRegression")
    print("       - OmniModelForTokenClassification")
    print("       - And all other OmniModel variants!")
    
except Exception as e:
    print(f"[ERROR] Embedding extraction failed: {e}")
    raise

## 7. Advanced Topics and Best Practices

### 7.1 Memory Management

**GPU Memory Considerations:**
- Each attention tensor can be large: `layers × heads × seq_len × seq_len × 4 bytes` (assuming float32 values)
- Example: 12 layers, 12 heads, 512 seq_len = 12 × 12 × 512 × 512 × 4 bytes ≈ 151 MB per sequence
- Use `return_on_cpu=True` to immediately transfer results to CPU
- Extract specific layers/heads to reduce memory usage
- Reduce batch size if encountering OOM errors

**CPU Memory Considerations:**
- Processing on CPU is slower, but the limit becomes system RAM rather than GPU memory
- Consider processing sequences sequentially for very large datasets
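
The size formula is easy to evaluate before running an extraction. The sketch below assumes 4 bytes per value (float32) and uses a hypothetical helper name; the layer, head, and length figures are the example numbers from the text, not properties read from a specific model.

```python
def attention_bytes(num_layers, num_heads, seq_len, bytes_per_value=4):
    """Bytes needed to store one sequence's full attention tensor."""
    return num_layers * num_heads * seq_len * seq_len * bytes_per_value

# Example figures from the text: 12 layers, 12 heads, 512 tokens
full = attention_bytes(12, 12, 512)

# Keeping only 2 layers and 3 heads shrinks the tensor proportionally
subset = attention_bytes(2, 3, 512)

# Shorter sequences help quadratically: 128 tokens instead of 512
short = attention_bytes(12, 12, 128)

for label, size in [("full", full), ("2 layers, 3 heads", subset), ("128 tokens", short)]:
    print(f"{label}: {size / 1e6:.1f} MB")
```

Because the sequence length enters squared, cutting it by a factor of 4 reduces the tensor by a factor of 16, which is often a bigger saving than dropping layers or heads.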

In [None]:
# Demonstrate memory-efficient extraction strategies

print("[INFO] Memory Management Strategies:")
print("=" * 70)

# Strategy 1: Extract only specific layers
print("Strategy 1: Extract specific layers only")
specific_layers = [0, -1]  # First and last layer
result_specific = model.extract_attention_scores(
    sequence=test_sequences[0],
    max_length=MAX_LENGTH,
    layer_indices=specific_layers,
    return_on_cpu=True
)
print(f"  Full extraction would be: (all_layers, all_heads, seq_len, seq_len)")
print(f"  Optimized extraction: {result_specific['attentions'].shape}")
print(f"  Memory saved: {((1 - len(specific_layers)/num_layers) * 100):.1f}%")
print()

# Strategy 2: Extract specific heads
print("Strategy 2: Extract specific heads only")
specific_heads = [0, 1]  # First 2 heads
result_heads = model.extract_attention_scores(
    sequence=test_sequences[0],
    max_length=MAX_LENGTH,
    head_indices=specific_heads,
    return_on_cpu=True
)
print(f"  Optimized extraction: {result_heads['attentions'].shape}")
print(f"  Memory saved: {((1 - len(specific_heads)/num_heads) * 100):.1f}%")
print()

# Strategy 3: Return on CPU
print("Strategy 3: Use return_on_cpu=True")
print("  [INFO] Immediately transfers results to CPU, freeing GPU memory")
print("  [INFO] Essential for processing large batches")
print()

print("[SUCCESS] Memory management strategies demonstrated!")

### 7.2 Limitations and Boundary Conditions

**Sequence Length Constraints:**
- **max_length parameter**: Must not exceed the model's maximum context window (smaller values such as the 128 used in this tutorial are fine)
- OmniGenome-186M supports up to 1024 tokens (check model card for specifics)
- Sequences exceeding max_length are **truncated**, not split
- Attention to truncated regions is lost
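
One common workaround for truncation is to split a long sequence into overlapping windows and analyze each window separately. The sketch below is an illustrative assumption, not an OmniGenBench feature: `sliding_windows` is a hypothetical helper, window sizes are counted in characters rather than tokens, and attention between positions in different windows is still not captured.

```python
def sliding_windows(sequence, window, stride):
    """Split `sequence` into overlapping (start, subsequence) windows covering every position."""
    if window < 1 or stride < 1:
        raise ValueError("window and stride must be at least 1")
    if stride > window:
        raise ValueError("stride larger than window would skip positions")
    if len(sequence) <= window:
        return [(0, sequence)]
    starts = list(range(0, len(sequence) - window + 1, stride))
    if starts[-1] + window < len(sequence):
        starts.append(len(sequence) - window)  # final window flush with the end
    return [(start, sequence[start:start + window]) for start in starts]

# Toy 20-character sequence (illustrative only)
long_sequence = "ACGT" * 5
windows = sliding_windows(long_sequence, window=8, stride=6)

for start, chunk in windows:
    print(start, chunk)
```

When choosing the window size for a real model, leave room for special tokens, since `max_length` counts tokens after tokenization rather than raw characters.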

**Attention Extraction Limitations:**
- **Backbone must return attention weights**: Every OmniModel type exposes the extraction methods, but the underlying model must actually output attention weights for them to work
- **Computational cost**: Scales as O(L × H × N²) where L=layers, H=heads, N=sequence length
- **Memory requirements**: Full attention matrices can be very large
- **Interpretation challenges**: Attention ≠ importance (see [Attention is not Explanation](https://arxiv.org/abs/1902.10186))

**When Attention Analysis is Useful:**
- Debugging unexpected model predictions
- Understanding learned sequence motifs
- Comparing attention patterns across model variants
- Identifying position-specific biases

**When Attention Analysis May NOT Be Useful:**
- Directly determining feature importance (use gradient-based methods instead)
- Explaining black-box predictions (attention is one of many factors)
- Sequences with extreme length (computational constraints)

### 7.3 Aggregation Strategies

Different aggregation methods provide different insights:

In [None]:
# Compare different aggregation strategies
print("[INFO] Comparing aggregation strategies:")
print("=" * 70)

aggregation_strategies = [
    ("mean", "mean"),  # Average across layers and heads
    ("first", "mean"), # First layer, average heads
    ("last", "mean"),  # Last layer, average heads
    ("max", "max"),    # Maximum across layers and heads
]

for layer_agg, head_agg in aggregation_strategies:
    try:
        stats_agg = model.get_attention_statistics(
            attention_result['attentions'],
            attention_result['attention_mask'],
            layer_aggregation=layer_agg,
            head_aggregation=head_agg
        )
        
        entropy = stats_agg['attention_entropy'].mean().item()
        concentration = stats_agg['attention_concentration'].mean().item()
        
        print(f"Strategy: layer={layer_agg:5s}, head={head_agg:4s}")
        print(f"  Entropy: {entropy:.4f}, Concentration: {concentration:.4f}")
        
    except Exception as e:
        print(f"Strategy: layer={layer_agg}, head={head_agg} - Error: {e}")

print()
print("[INFO] Interpretation:")
print("  - 'mean' aggregation: Smoother, averaged patterns across layers/heads")
print("  - 'first' layer: Early representations (closer to input)")
print("  - 'last' layer: Late representations (task-specific features)")
print("  - 'max' aggregation: Emphasizes strongest attention signals")

## 8. Summary and Next Steps

### 8.1 What We Learned

In this tutorial, you learned how to:

**Core Skills:**
- [x] Extract attention scores from genomic sequences
- [x] Compute attention statistics (entropy, concentration, self-attention)
- [x] Visualize attention patterns as heatmaps
- [x] Process multiple sequences in batches efficiently
- [x] Compare attention patterns across different sequences
- [x] Extract both attention and embeddings from the same model

**Key Concepts:**
- [x] Attention mechanisms in transformer models for genomics
- [x] Universal attention support across all OmniModel types
- [x] Memory management strategies for large-scale analysis
- [x] Aggregation strategies (layer-wise and head-wise)
- [x] Limitations and appropriate use cases

### 8.2 API Methods Reference

All methods are available through `EmbeddingMixin` (inherited by all OmniModel types):

| Method | Purpose | Returns |
|--------|---------|---------|
| `extract_attention_scores()` | Single sequence attention | Dict with 'attentions', 'tokens', 'attention_mask' |
| `batch_extract_attention_scores()` | Multiple sequences attention | List of dicts |
| `get_attention_statistics()` | Compute attention metrics | Dict with statistical measures |
| `visualize_attention_pattern()` | Create attention heatmap | matplotlib Figure |
| `encode()` | Single sequence embedding | Tensor (embedding_dim,) |
| `batch_encode()` | Multiple sequences embeddings | Tensor (num_seqs, embedding_dim) |
| `compute_similarity()` | Embedding similarity | Float (cosine similarity) |

### 8.3 Supported Model Types

**All these models support attention extraction:**

| Model Type | Primary Purpose | Task Head |
|------------|----------------|-----------|
| `OmniModelForEmbedding` | Representation learning | None |
| `OmniModelForSequenceClassification` | Sequence-level classification | Linear classifier |
| `OmniModelForSequenceRegression` | Sequence-level regression | Linear regressor |
| `OmniModelForTokenClassification` | Token-level prediction | Token classifier |
| `OmniModelForMLM` | Masked language modeling | MLM head |

### 8.4 Troubleshooting

**Common Issues and Solutions:**

| Issue | Possible Cause | Solution |
|-------|---------------|----------|
| `CUDA out of memory` | Sequence too long / batch too large | Reduce batch_size, extract specific layers/heads, use return_on_cpu=True |
| `AttributeError: 'Model' object has no attribute 'extract_attention_scores'` | Old OmniGenBench version | Update: `pip install omnigenbench -U` |
| `Attention values not summing to 1` | Using wrong aggregation | Check layer/head aggregation, use raw attention from specific head |
| `Heatmap not showing` | matplotlib not available | Install: `pip install matplotlib seaborn` |
| `Sequence truncated` | max_length too small | Increase MAX_LENGTH (check model's max context window) |

### 8.5 Next Steps

**Explore More:**

1. **Fine-tune on your data**: Train task-specific models and analyze attention changes
2. **Compare models**: Extract attention from different GFMs and compare patterns
3. **Biological validation**: Correlate attention patterns with known biological motifs
4. **Advanced visualization**: Create attention flow diagrams across layers
5. **Integration**: Combine attention with other interpretability methods (SHAP, integrated gradients)

**Related Tutorials:**
- [Genomic Embeddings Tutorial](../genomic_embeddings/) - Deep dive into embedding extraction
- [RNA Secondary Structure Prediction](../rna_secondary_structure_prediction/) - Task-specific attention analysis
- [Variant Effect Prediction](../variant_effect_prediction/) - Attention for variant analysis

**Further Reading:**
- [Attention Is All You Need](https://arxiv.org/abs/1706.03762) - Original transformer paper
- [Attention is not Explanation](https://arxiv.org/abs/1902.10186) - Critical perspective on attention
- [Effective gene expression prediction from sequence by integrating long-range interactions](https://www.nature.com/articles/s41592-021-01252-x) - Attention-based genomic model (Enformer)

### 8.6 Citation

If you use OmniGenBench in your research, please cite:

```bibtex
@software{omnigenbench2025,
  author = {Yang, Heng},
  title = {OmniGenBench: A Unified Framework for Genomic Foundation Models},
  year = {2025},
  publisher = {GitHub},
  url = {https://github.com/yangheng95/OmniGenBench}
}
```

### 8.7 Feedback and Contributions

Found an issue or have suggestions? Please:
- Open an issue on [GitHub](https://github.com/yangheng95/OmniGenBench/issues)
- Submit a pull request with improvements
- Join discussions in the community forum

---

**Thank you for completing this tutorial!** You now have the tools to analyze attention patterns in genomic foundation models. Happy exploring!