# 🔍 Attention Analysis Tutorial: Exploring Genomic Attention Patterns

## 🎯 Key Feature: ALL OmniModel Types Support Attention Extraction!

This tutorial demonstrates attention extraction capabilities that are now available in **ALL OmniGenBench models** through the `EmbeddingMixin`.

### Supported Model Types
✅ `OmniModelForEmbedding` - Dedicated embedding extraction  
✅ `OmniModelForSequenceClassification` - Classification + Attention  
✅ `OmniModelForSequenceRegression` - Regression + Attention  
✅ `OmniModelForTokenClassification` - Token classification + Attention  
✅ `OmniModelForMLM` - Masked language modeling + Attention  
✅ **All other OmniModel variants** - Task-specific + Attention
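
The "every model type" claim works because the helpers live in one shared mixin that each model class inherits. The toy sketch below illustrates that pattern only. `ToyAttentionMixin`, `ToyBackbone`, the `_attention` hook and both toy models are hypothetical stand-ins, and OmniGenBench's real `EmbeddingMixin` is not reproduced here; its internals may differ.

```python
import math


class ToyAttentionMixin:
    """Hypothetical stand-in for a shared mixin: any subclass that provides
    an `_attention(tokens)` hook inherits these analysis helpers."""

    def extract_attention(self, sequence):
        tokens = list(sequence.upper())  # toy tokenizer: one token per nucleotide
        return {"tokens": tokens, "attention": self._attention(tokens)}

    def self_attention_scores(self, sequence):
        """Diagonal of the attention matrix: how much each token attends to itself."""
        att = self.extract_attention(sequence)["attention"]
        return [att[i][i] for i in range(len(att))]


class ToyBackbone:
    """Toy encoder: each token attends more strongly to identical nucleotides."""

    def _attention(self, tokens):
        rows = []
        for query in tokens:
            # Score 1.0 for a matching nucleotide, 0.0 otherwise, then softmax the row
            exps = [math.exp(1.0 if query == key else 0.0) for key in tokens]
            total = sum(exps)
            rows.append([e / total for e in exps])
        return rows


class ToyEmbeddingModel(ToyAttentionMixin, ToyBackbone):
    """Embedding-style model: only the shared helpers."""


class ToyClassifier(ToyAttentionMixin, ToyBackbone):
    """Task model: adds its own head but keeps the shared helpers."""

    def predict(self, sequence):
        # Toy classification head: label 1 when more than half the bases are G/C
        gc = sum(base in "GC" for base in sequence.upper())
        return int(2 * gc > len(sequence))


print(ToyClassifier().self_attention_scores("ACGA"))
```

Both toy models expose `extract_attention` without redefining it, while only the classifier adds a task-specific `predict`.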

### What You'll Learn
1. 🧬 Extract attention scores from genomic sequences
2. 📊 Analyze attention patterns and statistics
3. 🎨 Visualize attention heatmaps
4. 🔬 Compare attention patterns across sequences
5. 💡 Use attention extraction with any model type

## 🚀 Setup and Installation

In [None]:
!pip install omnigenbench torch transformers matplotlib seaborn -U

## 📚 Import Libraries

In [None]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Import various model types - ALL support attention extraction!
from omnigenbench import (
    OmniModelForEmbedding,
    OmniModelForSequenceClassification,
    OmniModelForSequenceRegression,
)

# Set up plotting
plt.style.use('seaborn-v0_8')
sns.set_palette('husl')

print("✅ Imports successful!")
print("💡 All OmniModel types support attention extraction!")

## 🔧 Load Model

### Important: You Can Use ANY OmniModel Type!

The attention extraction functionality is available in all model types. Choose the one that fits your use case:
- Use `OmniModelForEmbedding` for dedicated embedding/attention extraction
- Use `OmniModelForSequenceClassification` if you also need classification
- Use `OmniModelForSequenceRegression` if you also need regression
- Use any other variant, such as `OmniModelForTokenClassification` or `OmniModelForMLM`, for its own task

In [None]:
# Configuration
model_name = "anonymous8/OmniGenome-186M"  # Change to your preferred model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"🔧 Loading model: {model_name}")
print(f"📱 Device: {device}")

# Option 1: Use dedicated embedding model
model = OmniModelForEmbedding(model_name, trust_remote_code=True)

# Option 2: Use classification model (also supports attention extraction!)
# model = OmniModelForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)

# Option 3: Use regression model (also supports attention extraction!)
# model = OmniModelForSequenceRegression.from_pretrained(model_name, trust_remote_code=True)

model = model.to(device)
model.eval()

print(f"✅ Model loaded: {type(model).__name__}")
print("💡 This model supports both task-specific operations AND attention extraction!")

## 🧬 Prepare Test Sequences

In [None]:
# Example genomic sequences with different characteristics
test_sequences = [
    "ATCGATCGATCGTAGCTAGCTAGCT",  # Tandem repeats (ATCG, then TAGC), 48% GC
    "GGCCTTAACCGGTTAACCGGTTAA",   # Dinucleotide blocks, 50% GC
    "TTTTAAAACCCCGGGGTTTTAAAA",   # Homopolymer runs of length 4, 33% GC
    "AUGCGAUCUCGAGCUACGUCGAUGCUAGCUCGAUGGCAUCCGAUUCGAGCUACGUCGAUGCUAG",  # Longer RNA sequence (U instead of T), 64 nt
]

print("🧬 Test sequences prepared:")
for i, seq in enumerate(test_sequences, 1):
    print(f"  {i}. {seq[:40]}{'...' if len(seq) > 40 else ''}")
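
Composition labels on test sequences are easy to get wrong, so it is worth checking them. GC content (the fraction of G and C bases) works for both DNA (T) and RNA (U) alphabets because only G and C are counted. The `gc_content` helper is a plain-Python sketch, not part of OmniGenBench.

```python
def gc_content(seq: str) -> float:
    """Fraction of G/C bases; T and U both count as non-GC."""
    seq = seq.upper()
    if not seq:
        return 0.0
    return sum(base in "GC" for base in seq) / len(seq)


sequences = [
    "ATCGATCGATCGTAGCTAGCTAGCT",
    "GGCCTTAACCGGTTAACCGGTTAA",
    "TTTTAAAACCCCGGGGTTTTAAAA",
]
for i, seq in enumerate(sequences, 1):
    print(f"  {i}. length={len(seq)}, GC={gc_content(seq):.0%}")
```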

## 1️⃣ Extract Attention Scores from Single Sequence

In [None]:
# Extract attention from the first sequence
sequence = test_sequences[0]

print(f"🔍 Analyzing sequence: {sequence}")
print("⏳ Extracting attention scores...")

attention_result = model.extract_attention_scores(
    sequence=sequence,
    max_length=128,
    layer_indices=None,  # Extract all layers (or specify [0, 5, 11] for specific layers)
    head_indices=None,   # Extract all heads (or specify [0, 1, 2] for specific heads)
    return_on_cpu=True
)

print("\n✅ Attention extraction successful!")
print(f"📊 Attention tensor shape: {attention_result['attentions'].shape}")
print(f"   Format: (layers, heads, seq_len, seq_len)")
print(f"🔤 Number of tokens: {len(attention_result['tokens'])}")
print(f"🎯 First 10 tokens: {attention_result['tokens'][:10]}")

## 2️⃣ Compute Attention Statistics

In [None]:
# Compute comprehensive attention statistics
stats = model.get_attention_statistics(
    attention_result['attentions'],
    attention_result['attention_mask'],
    layer_aggregation="mean",  # Options: mean, max, sum, first, last
    head_aggregation="mean"    # Options: mean, max, sum
)

print("📈 Attention Statistics:")
print(f"  Attention matrix shape: {stats['attention_matrix'].shape}")
print(f"  Average attention entropy: {stats['attention_entropy'].mean():.4f}")
print(f"  Max attention concentration: {stats['attention_concentration'].max():.4f}")
print(f"  Average self-attention score: {stats['self_attention_scores'].mean():.4f}")
print(f"  Max attention per position (top 5): {stats['max_attention_per_position'][:5]}")
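
`get_attention_statistics` reduces the 4-D tensor to one matrix and then summarizes it. The library's exact formulas are not shown in this tutorial, so the following plain-Python sketch rests on stated assumptions: mean (or max) aggregation over all layers and heads, padded positions dropped via `attention_mask`, entropy as the Shannon entropy of each row, self-attention as the diagonal, and max attention as the row maximum. Concentration is not modeled. To try it on real output, convert the tensors with `.tolist()`. Treat this as a way to reason about the numbers, not as OmniGenBench's implementation.

```python
import math


def aggregate(attentions, how="mean"):
    """Collapse a (layers, heads, n, n) nested list to one n x n matrix.
    Pooling all (layer, head) slices equals mean-over-heads then mean-over-layers
    when every layer has the same number of heads."""
    slices = [head for layer in attentions for head in layer]
    n = len(slices[0])
    reduce = max if how == "max" else (lambda vals: sum(vals) / len(vals))
    return [[reduce([s[i][j] for s in slices]) for j in range(n)] for i in range(n)]


def attention_stats(matrix, mask):
    """Per-position statistics over unpadded tokens (mask value 1)."""
    keep = [i for i, m in enumerate(mask) if m == 1]
    entropy, self_att, max_att = [], [], []
    for i in keep:
        row = [matrix[i][j] for j in keep]
        total = sum(row)
        probs = [w / total for w in row]  # renormalize after dropping padding
        entropy.append(-sum(p * math.log(p) for p in probs if p > 0))
        self_att.append(matrix[i][i] / total)
        max_att.append(max(probs))
    return {
        "attention_entropy": entropy,
        "self_attention_scores": self_att,
        "max_attention_per_position": max_att,
    }


# One layer, two heads, three positions; the last position is padding
att = [[
    [[0.5, 0.5, 0.0], [0.2, 0.8, 0.0], [0.0, 0.0, 1.0]],
    [[0.7, 0.3, 0.0], [0.4, 0.6, 0.0], [0.0, 0.0, 1.0]],
]]
print(attention_stats(aggregate(att), [1, 1, 0]))
```

Lower entropy means a position spreads its attention over fewer tokens.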

## 3️⃣ Visualize Attention Patterns

In [None]:
# Visualize attention pattern for a specific layer and head
fig = model.visualize_attention_pattern(
    attention_result=attention_result,
    layer_idx=0,   # First layer
    head_idx=0,    # First attention head
    save_path="attention_heatmap.png",
    figsize=(12, 10)
)

if fig is not None:
    print("✅ Attention heatmap generated and saved!")
    plt.show()
else:
    print("⚠️  Visualization skipped (matplotlib not available)")
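
A heatmap is easier to interpret alongside the numbers behind it. The hypothetical helper below lists, for one query position, the tokens that receive the most attention in a single `(seq_len, seq_len)` slice, for example `attention_result['attentions'][layer_idx][head_idx]` converted to nested lists with `.tolist()`. It uses no OmniGenBench API, and the token names and weights are illustrative.

```python
def top_attended(matrix, tokens, query_idx, k=3):
    """Return the k (token, position, weight) triples that query_idx attends to most.
    Ties keep their original left-to-right order (sorted() is stable)."""
    row = matrix[query_idx]
    order = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
    return [(tokens[j], j, row[j]) for j in order[:k]]


# Toy 4-token slice (illustrative values; each row sums to 1)
tokens = ["<cls>", "A", "U", "G"]
matrix = [
    [0.70, 0.10, 0.10, 0.10],
    [0.40, 0.30, 0.20, 0.10],
    [0.25, 0.25, 0.25, 0.25],
    [0.05, 0.15, 0.60, 0.20],
]
for token, pos, weight in top_attended(matrix, tokens, query_idx=3):
    print(f"G@3 -> {token}@{pos}: {weight:.2f}")
```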

## 4️⃣ Batch Attention Extraction

In [None]:
# Extract attention from multiple sequences efficiently
print("⏳ Extracting attention from batch of sequences...")

batch_results = model.batch_extract_attention_scores(
    sequences=test_sequences[:3],  # First 3 sequences
    batch_size=2,
    max_length=128,
    layer_indices=[0, -1],  # First and last layer only
    head_indices=[0, 1, 2], # First 3 heads only
    return_on_cpu=True
)

print("✅ Batch attention extraction successful!")
print(f"📊 Processed {len(batch_results)} sequences")

for i, result in enumerate(batch_results, 1):
    print(f"  Sequence {i} attention shape: {result['attentions'].shape}")

## 5️⃣ Compare Attention Patterns Across Sequences

In [None]:
# Compare attention patterns between different sequences
print("🔬 Comparing attention patterns across sequences...\n")

for i, result in enumerate(batch_results, 1):
    stats = model.get_attention_statistics(
        result['attentions'],
        result['attention_mask']
    )
    
    seq = test_sequences[i-1]
    seq_preview = seq[:30] + ("..." if len(seq) > 30 else "")
    print(f"Sequence {i}: {seq_preview}")
    print(f"  Attention entropy: {stats['attention_entropy'].mean():.4f}")
    print(f"  Self-attention: {stats['self_attention_scores'].mean():.4f}")
    print(f"  Concentration: {stats['attention_concentration'].mean():.4f}")
    print()
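
Raw entropy grows with sequence length. A uniform distribution over *n* positions has entropy ln *n*, so longer sequences score higher even when their attention is equally diffuse. Dividing by ln *n* gives a value in [0, 1] that is easier to compare across sequences of different lengths. This normalization is a suggested addition; `get_attention_statistics` is not documented here as doing it.

```python
import math


def normalized_entropy(probs):
    """Shannon entropy of a probability row divided by ln(n): 0 = one-hot, 1 = uniform."""
    n = len(probs)
    if n < 2:
        return 0.0
    h = -sum(p * math.log(p) for p in probs if p > 0)
    return h / math.log(n)


uniform_short = [1 / 4] * 4
uniform_long = [1 / 64] * 64
# Raw entropies differ (ln 4 vs ln 64) although both rows are perfectly uniform
print(round(-sum(p * math.log(p) for p in uniform_short), 3),
      round(-sum(p * math.log(p) for p in uniform_long), 3))
print(round(normalized_entropy(uniform_short), 6), round(normalized_entropy(uniform_long), 6))
```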

## 6️⃣ Advanced: Embedding Extraction

### Bonus: Extract Embeddings Too!

Since all OmniModel types support both attention AND embedding extraction, you can easily get both:

In [None]:
# Extract embeddings from the same model
print("🎯 Extracting embeddings from the same model...")

# Single sequence
embedding = model.encode(test_sequences[0], agg="mean")
print(f"Single embedding shape: {embedding.shape}")

# Batch encoding
embeddings = model.batch_encode(test_sequences, batch_size=4, agg="mean")
print(f"Batch embeddings shape: {embeddings.shape}")

# Compute similarity
similarity = model.compute_similarity(embeddings[0], embeddings[1])
print(f"\nSimilarity between sequences 1 and 2: {similarity:.4f}")

print("\n✅ Both attention and embeddings extracted from the same model!")
print("💡 This works with ALL OmniModel types!")
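
`agg="mean"` presumably averages the per-token hidden states into one vector per sequence, and cosine similarity is a common way to compare such vectors. The plain-Python sketch below shows both steps and uses a mask so padding does not dilute the average. It is an assumption that `compute_similarity` uses cosine similarity, and the toy vectors are illustrative.

```python
import math


def masked_mean(hidden_states, mask):
    """Average token vectors where mask == 1 (padding excluded)."""
    kept = [h for h, m in zip(hidden_states, mask) if m == 1]
    dim = len(kept[0])
    return [sum(vec[d] for vec in kept) / len(kept) for d in range(dim)]


def cosine_similarity(a, b):
    """Dot product divided by the product of Euclidean norms (range [-1, 1])."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


# Three toy token vectors; the last position is padding and must not count
hidden = [[1.0, 0.0], [3.0, 2.0], [100.0, 100.0]]
emb_a = masked_mean(hidden, [1, 1, 0])  # -> [2.0, 1.0]
emb_b = [4.0, 2.0]
print(emb_a, round(cosine_similarity(emb_a, emb_b), 4))
```

Cosine similarity ignores vector length, so two sequences whose pooled embeddings point in the same direction score close to 1 regardless of magnitude.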

## 🎉 Summary

### What We've Learned

1. **Universal Support**: ALL OmniModel types support attention and embedding extraction
2. **Flexible API**: Same API works across all model types
3. **Rich Features**:
   - `extract_attention_scores()` - Single sequence attention
   - `batch_extract_attention_scores()` - Batch processing
   - `get_attention_statistics()` - Attention analysis
   - `visualize_attention_pattern()` - Visualization
   - `encode()` / `batch_encode()` - Embedding extraction
   - `compute_similarity()` - Similarity computation

### Supported Model Types
✅ OmniModelForEmbedding  
✅ OmniModelForSequenceClassification  
✅ OmniModelForSequenceRegression  
✅ OmniModelForTokenClassification  
✅ OmniModelForMLM  
✅ All other OmniModel variants!

### Key Benefits
- 🎯 Use task-specific models for their intended purpose
- 🔍 Extract attention and embeddings from the same model
- 💪 No need for separate embedding models
- 🚀 Efficient and unified API

---

**Next Steps**: Try using attention extraction with your fine-tuned models!