# 🤖 VEP Tutorial 2/4: Foundation Models for Variant Analysis

Welcome to the second tutorial in our VEP series. This guide focuses on understanding and configuring **PlantRNA-FM** (Plant RNA Foundation Model) for extracting meaningful sequence embeddings in variant effect prediction.

> 📚 **Prerequisites**: 
> - Complete [Tutorial 1: Data Preparation](01_vep_data_preparation.ipynb)
> - Understand the [Fundamental Concepts Tutorial](../00_fundamental_concepts.ipynb)

Before loading models, let's understand *why* PlantRNA-FM and similar foundation models are powerful for variant effect prediction.

## 1. Understanding Genomic Foundation Models 🧠

### 1.1 From Language Models to Genomic Models

Just as **BERT** and **GPT** learned language patterns from massive text corpora, **PlantRNA-FM** learns biological patterns from vast plant RNA and genomic databases, representing a new generation of specialized foundation models for plant genomics.

| Concept | Language Models | PlantRNA-FM |
|---------|-----------------|---------------------------|
| **Input** | Words, sentences | Plant RNA/DNA sequences |
| **Training** | Books, Wikipedia | Plant genomic databases, transcriptomes |
| **Learned patterns** | Grammar, syntax | Regulatory motifs, splicing signals, codon usage |
| **Embeddings** | Semantic meaning | Biological function and structure |
| **Applications** | Translation, QA | VEP, gene expression prediction, regulatory analysis |
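To make the analogy concrete: a language model maps words to integer ids, and a genomic model maps nucleotides (or k-mers) to ids in the same way. The toy tokenizer below uses a hypothetical one-token-per-nucleotide vocabulary. It is not PlantRNA-FM's actual tokenizer, whose vocabulary and special tokens should be read from `tokenizer` in Section 3.3.

```python
# Hypothetical single-nucleotide vocabulary; PlantRNA-FM's real vocabulary,
# special tokens, and ids come from its own tokenizer and will differ.
TOY_VOCAB = {"<cls>": 0, "<eos>": 1, "<unk>": 2, "A": 3, "C": 4, "G": 5, "U": 6}


def toy_tokenize(sequence: str) -> list[int]:
    """Map an RNA/DNA string to token ids, one token per nucleotide."""
    # Treat DNA input as RNA by converting T to U (an assumption of this toy)
    rna = sequence.upper().replace("T", "U")
    ids = [TOY_VOCAB.get(base, TOY_VOCAB["<unk>"]) for base in rna]
    # Wrap with start/end markers, analogous to BERT's [CLS] ... [SEP]
    return [TOY_VOCAB["<cls>"]] + ids + [TOY_VOCAB["<eos>"]]


print(toy_tokenize("ATCG"))  # [0, 3, 6, 4, 5, 1]
```

Just as a sentence becomes a sequence of word ids, a sequence becomes a sequence of nucleotide ids that the model then turns into contextual embeddings.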

### 1.2 Why PlantRNA-FM for VEP?

**Traditional approaches** (e.g., SIFT, PolyPhen) rely on:
- ❌ Hand-crafted features (conservation scores, physicochemical properties)
- ❌ Limited to known protein-coding regions
- ❌ Require multiple sequence alignments

**PlantRNA-FM and foundation model approaches** offer:
- ✅ **Learned representations**: Automatically discover plant-specific regulatory patterns
- ✅ **Contextual understanding**: Capture long-range dependencies in RNA structures
- ✅ **Transfer learning**: Apply knowledge from millions of plant sequences
- ✅ **Zero-shot capability**: Predict effects without variant-specific training
- ✅ **Plant-optimized**: Specialized for plant genomic features (e.g., alternative splicing, UTR structures)

### 1.3 Embedding-Based VEP: The Core Idea

```mermaid
graph TD
    A[Reference Sequence<br/>ATCG...] --> B[PlantRNA-FM]
    C[Alternative Sequence<br/>AGCG...] --> B
    B --> D["Embedding₁<br/>(hidden_dim vector)"]
    B --> E["Embedding₂<br/>(hidden_dim vector)"]
    D --> F[Compare<br/>Cosine Similarity]
    E --> F
    F --> G["Effect Score<br/>1 - cosine similarity"]
    
    style A fill:#e1f5ff
    style C fill:#ffe1e1
    style D fill:#e1ffe1
    style E fill:#ffe1f5
    style G fill:#fff5e1
```

**Key insight**: If a variant substantially changes the embedding produced by PlantRNA-FM (low cosine similarity between the reference and alternative embeddings), it is more likely to affect the function of the plant RNA or regulatory element.
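The comparison step in the diagram can be sketched in a few lines of NumPy. This minimal illustration uses toy 4-dimensional vectors in place of real PlantRNA-FM embeddings. Note that `1 - cosine similarity` ranges from 0 (same direction) to 2 (opposite direction).

```python
import numpy as np


def cosine_effect_score(ref_emb: np.ndarray, alt_emb: np.ndarray) -> float:
    """Effect score = 1 - cosine similarity between reference and alternative embeddings."""
    ref = np.asarray(ref_emb, dtype=np.float64)
    alt = np.asarray(alt_emb, dtype=np.float64)
    denom = np.linalg.norm(ref) * np.linalg.norm(alt)
    if denom == 0:
        raise ValueError("Cosine similarity is undefined for zero vectors")
    cos_sim = float(ref @ alt / denom)
    return 1.0 - cos_sim


# Toy vectors standing in for real (hidden_dim,) embeddings
ref = np.array([1.0, 0.0, 0.0, 0.0])
print(cosine_effect_score(ref, ref))                             # 0.0
print(cosine_effect_score(ref, np.array([0.0, 1.0, 0.0, 0.0])))  # 1.0
print(cosine_effect_score(ref, -ref))                            # 2.0
```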


## 2. Choosing the Right Model 🎯

### 2.1 Available Foundation Models

| Model | Parameters | Context Length | Specialization | Publication |
|-------|-----------|----------------|----------------|-------------|
| **PlantRNA-FM** | **35M** | 1024 | Plant RNA & genomics | *Nature Machine Intelligence* |
| **OmniGenome-52M** | 52M | 1024 | General genomics | - |
| **OmniGenome-418M** | 418M | 4096 | Deep genomics | - |

### 2.2 Model Selection Guidelines

**For this tutorial, we use PlantRNA-FM** because:
- 🌱 **Plant-specialized**: Trained on extensive plant RNA and genomic data
- 🎯 **Optimized for plant variants**: Better captures plant-specific regulatory elements
- 🚀 **Efficient**: Only 35M parameters
- 💾 **Low resource requirements**: Fast inference, small memory footprint (~140MB of float32 weights)
- 📚 **Peer-reviewed**: Published in *Nature Machine Intelligence* and evaluated on plant RNA tasks
- ✅ **Broadly applicable to plant VEP**: A reasonable starting point for variants affecting gene expression, splicing, and regulatory elements
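The ~140MB figure follows from parameter count × bytes per parameter. The sketch below assumes float32 (4 bytes) or float16 (2 bytes) weights and counts weights only; activations, the CUDA context, and framework overhead add to the real footprint.

```python
# Parameter counts as listed in the model table (35M, 52M, 418M)
MODELS = {"PlantRNA-FM": 35e6, "OmniGenome-52M": 52e6, "OmniGenome-418M": 418e6}
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2}


def weight_memory_mb(num_params: float, dtype: str = "fp32") -> float:
    """Approximate memory for model weights only, in MB (1 MB = 1e6 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e6


for name, n in MODELS.items():
    print(f"{name}: {weight_memory_mb(n):.0f} MB fp32, {weight_memory_mb(n, 'fp16'):.0f} MB fp16")
```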

**Consider alternatives when:**
- **OmniGenome-52M**: Quick prototyping or non-plant organisms
- **OmniGenome-418M**: Require maximum accuracy or very long sequences (>2000bp)
- Working with mixed datasets (plant and non-plant sequences)
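The guidelines above can be encoded as a simple rule of thumb. This is an illustrative sketch: `suggest_model` is a hypothetical helper, the names are display names rather than Hugging Face ids, and the context check assumes roughly one token per nucleotide while ignoring special tokens.

```python
# Context lengths from the model table, in tokens (~1 token per nucleotide assumed)
CONTEXT_TOKENS = {"PlantRNA-FM": 1024, "OmniGenome-52M": 1024, "OmniGenome-418M": 4096}


def suggest_model(seq_len_bp: int, plant_only: bool = True,
                  need_max_accuracy: bool = False) -> tuple[str, bool]:
    """Return (model name, whether the sequence fits its context window)."""
    if need_max_accuracy or seq_len_bp > 2000:
        choice = "OmniGenome-418M"
    elif not plant_only:
        choice = "OmniGenome-52M"
    else:
        choice = "PlantRNA-FM"
    # A False here signals that the input should be windowed around the variant
    fits = seq_len_bp <= CONTEXT_TOKENS[choice]
    return choice, fits


print(suggest_model(500))    # ('PlantRNA-FM', True)
print(suggest_model(1500))   # ('PlantRNA-FM', False)
print(suggest_model(3000))   # ('OmniGenome-418M', True)
```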


---

## 3. Step-by-Step: Model Initialization 🛠️

### 3.1 Environment Setup


In [None]:
import torch
import warnings
from omnigenbench import (
    OmniTokenizer,
    OmniModelForSequenceClassification
)
from dataclasses import dataclass

warnings.filterwarnings('ignore')
print("✅ Libraries imported successfully!")
print(f"🔧 PyTorch version: {torch.__version__}")
print(f"💻 CUDA available: {torch.cuda.is_available()}")


### 3.2 Configuration

Define all model parameters in a single dataclass:


In [None]:
# Configuration parameters
@dataclass
class ModelConfig:
    """Configuration for VEP model setup with PlantRNA-FM"""
    # Model selection - Using PlantRNA-FM for plant variant analysis
    model_name: str = "yangheng/PlantRNA-FM"  # Plant RNA Foundation Model
    
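    # Task settings (the classification head is typically newly initialized
    # and is not needed for embedding-based scoring)
    num_labels: int = 2  # Binary classification (benign/pathogenic)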
    
    # Embedding extraction
    output_hidden_states: bool = True  # Required for embedding extraction
    output_attentions: bool = False  # Set True to analyze attention patterns
    
    # Device management
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    
config = ModelConfig()
print("📋 Model Configuration:")
print(f"   Model: {config.model_name}")
print(f"   Device: {config.device}")
print(f"   Output embeddings: {config.output_hidden_states}")
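To try a variant of this configuration (for example, enabling attention outputs) without editing the class, `dataclasses.replace` returns a modified copy. `ConfigSketch` below is a hypothetical stand-in with the same fields as `ModelConfig`, with the device fixed to `"cpu"` so it runs without PyTorch; with the real class you would call `replace(config, output_attentions=True)`.

```python
from dataclasses import asdict, dataclass, replace


@dataclass(frozen=True)
class ConfigSketch:
    """Stand-in mirroring ModelConfig's fields (device fixed to avoid a torch import)."""
    model_name: str = "yangheng/PlantRNA-FM"
    num_labels: int = 2
    output_hidden_states: bool = True
    output_attentions: bool = False
    device: str = "cpu"


base = ConfigSketch()
# Create a modified copy for attention analysis; `base` itself is unchanged
attn_cfg = replace(base, output_attentions=True)
print(asdict(attn_cfg))
```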


### 3.3 Load PlantRNA-FM Model and Tokenizer

**What happens during PlantRNA-FM loading:**
1. 🔽 Download pre-trained PlantRNA-FM weights from Hugging Face (if not cached)
2. 🏗️ Initialize PlantRNA-FM architecture
3. ⚖️ Load plant-specific learned parameters
4. 🔤 Configure RNA/DNA tokenizer
5. 🎯 Set up for evaluation mode


In [None]:
# Initialize tokenizer
print("🔤 Loading tokenizer...")
tokenizer = OmniTokenizer.from_pretrained(
    config.model_name, 
    trust_remote_code=True
)
print(f"   Vocabulary size: {len(tokenizer)}")
print(f"   Special tokens: {tokenizer.all_special_tokens}")
print("✅ Tokenizer loaded!")
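Before tokenization it can help to normalize raw sequences taken from FASTA files or VCF-derived windows: soft-masked lowercase bases, line breaks, or stray characters may otherwise be tokenized unexpectedly. The sketch below is a hypothetical pre-processing helper, not part of OmniGenBench; the accepted alphabet `ACGTUN` is an assumption to adjust to what the real tokenizer supports.

```python
import re

# Characters accepted by this sketch (an assumption, not the tokenizer's vocabulary)
VALID_BASES = re.compile(r"^[ACGTUN]+$")


def normalize_sequence(seq: str) -> str:
    """Uppercase a sequence, drop all whitespace, and reject unexpected characters."""
    cleaned = "".join(seq.split()).upper()
    if not cleaned:
        raise ValueError("Empty sequence")
    if not VALID_BASES.match(cleaned):
        bad = sorted(set(cleaned) - set("ACGTUN"))
        raise ValueError(f"Unexpected characters in sequence: {bad}")
    return cleaned


print(normalize_sequence("atcg\nATCG"))  # ATCGATCG
```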


In [None]:
# Load pre-trained PlantRNA-FM model
print("\n🤖 Loading PlantRNA-FM...")
model = OmniModelForSequenceClassification.from_pretrained(
    config.model_name,
    tokenizer=tokenizer,
    num_labels=config.num_labels,
    trust_remote_code=True,
    output_hidden_states=config.output_hidden_states,
    output_attentions=config.output_attentions
)

# Move to appropriate device
model = model.to(config.device)

# Set to evaluation mode (disables dropout so embeddings are reproducible)
model.eval()

print("✅ PlantRNA-FM loaded successfully!")
print(f"   Device: {next(model.parameters()).device}")


## 4. Model Verification & Analysis 🔍

Let's examine the model architecture and verify it's ready for VEP:


In [None]:
# Model statistics
num_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("📊 Model Statistics:")
print("=" * 50)
print(f"Total parameters: {num_params:,} ({num_params/1e6:.1f}M)")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size (float32 weights): ~{num_params * 4 / 1024**2:.1f} MiB")

# Model architecture overview
print("\n🏗️ Model Architecture:")
print("=" * 50)
print(model)


### 4.1 Understanding Model Outputs

For VEP, we need to understand what the model produces:


In [None]:
# Test with sample sequence
print("🧪 Testing model outputs...")
test_sequence = "ATCGATCGATCG" * 10  # 120bp test sequence

# Tokenize
inputs = tokenizer(
    test_sequence, 
    return_tensors="pt", 
    max_length=512, 
    truncation=True,
    padding=True
)
inputs = {k: v.to(config.device) for k, v in inputs.items()}

# Forward pass
with torch.no_grad():
    outputs = model(**inputs)

print("\n📤 Model Outputs:")
print("=" * 50)
print(f"Logits shape: {outputs.logits.shape}")
print("  → (batch_size, num_labels)")
print(f"\nHidden states: {len(outputs.hidden_states)} tensors (embedding output + one per transformer layer)")
print(f"  → Layer 0 (embeddings): {outputs.hidden_states[0].shape}")
print(f"  → Layer -1 (final): {outputs.hidden_states[-1].shape}")
print("  → Shape: (batch_size, sequence_length, hidden_dim)")

# Extract embedding dimension
embedding_dim = outputs.hidden_states[-1].shape[-1]
print(f"\n💡 Embedding dimension: {embedding_dim}")
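Since `outputs.hidden_states` holds one tensor per layer, the same reference/alternative comparison can be repeated at every layer to see where a variant's signal appears, since layers may differ in sensitivity. The sketch below uses random NumPy arrays as stand-ins for real hidden states (which would come from `.squeeze(0).cpu().numpy()` on each tensor) and mean pooling without a padding mask, which is only valid for unpadded single sequences. `layerwise_effect_scores` is a hypothetical helper.

```python
import numpy as np


def layerwise_effect_scores(ref_hidden, alt_hidden):
    """Cosine-distance effect score per layer from mean-pooled hidden states.

    ref_hidden, alt_hidden: sequences of (seq_len, hidden_dim) arrays,
    one per entry of `outputs.hidden_states` (batch dimension removed).
    """
    scores = []
    for ref_layer, alt_layer in zip(ref_hidden, alt_hidden):
        r = ref_layer.mean(axis=0)
        a = alt_layer.mean(axis=0)
        cos_sim = r @ a / (np.linalg.norm(r) * np.linalg.norm(a))
        scores.append(1.0 - float(cos_sim))
    return scores


# Toy stand-in: 3 "layers", 5 tokens, 8 dims; alt differs only in the last layer
rng = np.random.default_rng(0)
ref_hidden = [rng.normal(size=(5, 8)) for _ in range(3)]
alt_hidden = [layer.copy() for layer in ref_hidden]
alt_hidden[-1][2] += 1.0  # perturb one token's vector in the last layer
print(layerwise_effect_scores(ref_hidden, alt_hidden))
```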


### 4.2 Embedding Extraction for VEP

For variant effect prediction, we extract embeddings from the final hidden layer:


In [None]:
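def extract_embedding(sequence, pooling='mean'):
    """
    Extract a sequence embedding from PlantRNA-FM's final hidden layer.
    
    Args:
        sequence: Plant DNA/RNA sequence string
        pooling: 'mean', 'cls', or 'max' pooling strategy
        
    Returns:
        NumPy embedding vector of shape (hidden_dim,)
    """
    # Tokenize; sequences longer than 512 tokens are truncated,
    # so a variant past that limit may be cut off
    inputs = tokenizer(
        sequence,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True
    )
    inputs = {k: v.to(config.device) for k, v in inputs.items()}
    
    # Forward pass without gradient tracking
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Final-layer hidden states: (1, seq_len, hidden_dim)
    hidden_states = outputs.hidden_states[-1]
    
    # Mask of real (non-padding) tokens; fall back to all ones if absent
    mask = inputs.get("attention_mask")
    if mask is None:
        mask = torch.ones(hidden_states.shape[:2], device=hidden_states.device)
    mask = mask.unsqueeze(-1).to(hidden_states.dtype)  # (1, seq_len, 1)
    
    # Apply pooling
    if pooling == 'mean':
        # Average over non-padding positions only
        embedding = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    elif pooling == 'cls':
        # Use the first token (CLS-style position)
        embedding = hidden_states[:, 0, :]
    elif pooling == 'max':
        # Max over non-padding positions (padding set to -inf first)
        embedding = hidden_states.masked_fill(mask == 0, float('-inf')).max(dim=1).values
    else:
        raise ValueError(f"Unknown pooling: {pooling}")
    
    return embedding.squeeze(0).cpu().numpy()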

# Test embedding extraction
print("🧬 Testing embedding extraction...")
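ref_seq = "ATCGATCGATCG" * 10  # 120 bp
# SNV: T→G at 0-based position 5 only (repeating the mutated unit
# would introduce 10 substitutions instead of one)
alt_seq = ref_seq[:5] + "G" + ref_seq[6:]
assert sum(r != a for r, a in zip(ref_seq, alt_seq)) == 1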

ref_emb = extract_embedding(ref_seq, pooling='mean')
alt_emb = extract_embedding(alt_seq, pooling='mean')

print(f"\n✅ Embedding extraction successful!")
print(f"   Reference embedding shape: {ref_emb.shape}")
print(f"   Alternative embedding shape: {alt_emb.shape}")
print(f"   Embedding L2 norm: {(ref_emb**2).sum()**0.5:.3f}")
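`extract_embedding` truncates inputs at `max_length=512` tokens, so a variant located beyond that point in a long sequence would be cut off and the reference and alternative embeddings would come out identical. One way to avoid this, sketched below, is to cut a fixed-size window around the variant before embedding both alleles. `centered_window` and its 500 bp default are hypothetical choices, picked to stay under 512 tokens assuming roughly one token per nucleotide plus special tokens.

```python
def centered_window(sequence: str, variant_pos: int, window_bp: int = 500) -> tuple[str, int]:
    """Return up to `window_bp` bases around a 0-based variant position,
    plus the variant's 0-based offset inside that window."""
    if not 0 <= variant_pos < len(sequence):
        raise IndexError("variant_pos is outside the sequence")
    half = window_bp // 2
    # Shift the window inward at sequence edges so it keeps its full length when possible
    start = max(0, min(variant_pos - half, len(sequence) - window_bp))
    end = min(len(sequence), start + window_bp)
    return sequence[start:end], variant_pos - start


long_ref = "ACGU" * 300  # hypothetical 1,200-nt sequence
window, offset = centered_window(long_ref, variant_pos=600)
print(len(window), offset)  # 500 250
```

Applying the same window to both alleles keeps the variant at the same offset in the reference and alternative inputs.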


### 4.3 Computing Variant Effect Scores


In [None]:
from scipy.spatial.distance import cosine

# scipy's `cosine` returns the cosine *distance* (1 - cosine similarity)
cosine_sim = 1 - cosine(ref_emb, alt_emb)
effect_score = 1 - cosine_sim  # Higher score = larger embedding change (range 0 to 2)

print("📊 Variant Effect Metrics:")
print("=" * 50)
print(f"Cosine similarity: {cosine_sim:.4f}")
print(f"Effect score (1 - similarity): {effect_score:.4f}")
print(f"\n💡 Interpretation:")
# Illustrative cutoffs only; they are not calibrated for PlantRNA-FM
if effect_score < 0.01:
    print("   → Likely benign (minimal embedding change)")
elif effect_score < 0.05:
    print("   → Possibly benign (small embedding change)")
else:
    print("   → Potentially functional (larger embedding change)")
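The 0.01 and 0.05 cutoffs above are illustrative; the raw scale of cosine distances depends on the model, layer, and pooling choice. One alternative is to express a variant's score as a percentile of scores from a background set (for example, common polymorphisms assumed to be largely neutral). The sketch below shows only the computation; `empirical_percentile` and the background values are hypothetical.

```python
import numpy as np


def empirical_percentile(score: float, background_scores) -> float:
    """Percentage (0-100) of background scores that are <= `score`."""
    background = np.asarray(background_scores, dtype=float)
    if background.size == 0:
        raise ValueError("background_scores must not be empty")
    return float(100.0 * np.mean(background <= score))


# Hypothetical background: effect scores of variants assumed to be neutral
background = [0.001, 0.002, 0.002, 0.004, 0.010]
print(empirical_percentile(0.004, background))  # 80.0
```

A variant scoring above nearly all background variants is a stronger candidate than one judged against a fixed cutoff chosen without reference data.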


---

## 📝 Summary & Next Steps

### What We Accomplished

In this tutorial, we:
1. ✅ **Understood PlantRNA-FM** - Learned how plant-specialized foundation models differ from traditional methods
2. ✅ **Loaded PlantRNA-FM** - Successfully initialized the plant RNA foundation model
3. ✅ **Verified model outputs** - Confirmed hidden states and embedding extraction
4. ✅ **Implemented embedding extraction** - Created reusable function for plant VEP
5. ✅ **Tested variant scoring** - Calculated effect scores from embeddings

### Key Takeaways

🌱 **PlantRNA-FM is plant-optimized**: Trained on extensive plant genomic data to capture plant-specific patterns

🧠 **Foundation models are powerful**: They learn from millions of sequences without manual feature engineering

📊 **Embeddings capture biological meaning**: Similar sequences tend to have similar embeddings

🎯 **Effect = embedding distance**: Large changes suggest functional impact in plant regulatory elements

### Technical Details

**Model Output Structure:**
```python
outputs.logits           # (batch, num_labels) - classification logits (head is untrained unless fine-tuned)
outputs.hidden_states    # Tuple of (num_layers + 1) tensors, each (batch, seq_len, hidden_dim)
outputs.attentions       # Per-layer attention weights (only if output_attentions=True)
```

**Pooling Strategies:**
- `mean`: Average over sequence → robust to length variation
- `cls`: Use CLS token → follows BERT convention
- `max`: Maximum values → highlights salient features
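Mean and max pooling should ignore padding positions when sequences of different lengths are batched together; otherwise padded rows distort the pooled vector. The NumPy sketch below applies an attention-style mask before pooling. It mirrors the three strategies listed above but is an illustration, not OmniGenBench code, and `pool` is a hypothetical helper.

```python
import numpy as np


def pool(hidden: np.ndarray, mask: np.ndarray, strategy: str = "mean") -> np.ndarray:
    """Pool (seq_len, hidden_dim) token vectors into one (hidden_dim,) vector.

    mask: (seq_len,) array with 1 for real tokens and 0 for padding.
    """
    keep = mask.astype(bool)
    if strategy == "mean":
        return hidden[keep].mean(axis=0)
    if strategy == "max":
        return hidden[keep].max(axis=0)
    if strategy == "cls":
        return hidden[0]  # assumes the first token is a CLS-style token
    raise ValueError(f"Unknown pooling: {strategy}")


hidden = np.array([[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]])  # last row is padding
mask = np.array([1, 1, 0])
print(pool(hidden, mask, "mean"))  # [2. 3.]
print(pool(hidden, mask, "max"))   # [3. 4.]
```

Without the mask, the mean above would be pulled toward the zero padding row.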

### Next Steps

Continue to **[Tutorial 3: Embedding & Scoring](03_embedding_and_scoring.ipynb)** where we'll:
- Process large variant datasets efficiently
- Implement multiple scoring methods
- Compare reference and alternative sequences
- Analyze effect score distributions

### Quick Reference

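```python
import torch
from omnigenbench import OmniTokenizer, OmniModelForSequenceClassification

# Load PlantRNA-FM and its tokenizer
tokenizer = OmniTokenizer.from_pretrained("yangheng/PlantRNA-FM", trust_remote_code=True)
model = OmniModelForSequenceClassification.from_pretrained(
    "yangheng/PlantRNA-FM",
    tokenizer=tokenizer,
    num_labels=2,
    trust_remote_code=True,
    output_hidden_states=True
).eval()

# Extract a mean-pooled embedding from the final layer
inputs = tokenizer("ATCGATCGATCG", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
embedding = outputs.hidden_states[-1].mean(dim=1)  # (1, hidden_dim)
```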

**Next Tutorial**: [03_embedding_and_scoring.ipynb](03_embedding_and_scoring.ipynb) 🧬
