# 🧬 VEP Tutorial 3/4: Embedding Extraction and Effect Scoring with PlantRNA-FM

Welcome to the third tutorial in our VEP series. This is where theory meets practice: we'll extract embeddings for thousands of variants using **PlantRNA-FM** and compute effect scores at scale.

> 📚 **Prerequisites**: 
> - Complete [Tutorial 1: Data Preparation](01_vep_data_preparation.ipynb)
> - Complete [Tutorial 2: Model Setup](02_vep_model_setup.ipynb)
> - Understand embedding concepts from [Fundamental Concepts](../../00_fundamental_concepts.ipynb)

This tutorial demonstrates the core VEP workflow: comparing sequence embeddings from **PlantRNA-FM** (35M parameters, *Nature Machine Intelligence*) to predict functional effects in plant genomic variants.

## 1. The Science of Embedding-Based Scoring 🔬

### 1.1 Why Embeddings Work for VEP

**The Central Hypothesis:**
> Functionally similar sequences have similar embeddings in the learned representation space of PlantRNA-FM.

Therefore, a variant that substantially changes the PlantRNA-FM embedding is likely to have a functional impact on plant gene regulation or RNA structure.

```mermaid
graph TD
    A[Wild-type Sequence] --> B[PlantRNA-FM]
    C[Mutant Sequence] --> B
    B --> D[Embedding Space]
    D --> E{Large Distance?}
    E -->|Yes| F[Likely Pathogenic]
    E -->|No| G[Likely Benign]
    
    style A fill:#e1f5ff
    style C fill:#ffe1e1
    style D fill:#f5ffe1
    style F fill:#ffe1e1
    style G fill:#e1ffe1
```

### 1.2 Similarity vs. Distance Metrics

Different metrics capture different aspects of embedding relationships:

| Metric | Formula | Range | Interpretation | Best For |
|--------|---------|-------|----------------|----------|
| **Cosine Similarity** | $\frac{A \cdot B}{\lVert A \rVert \lVert B \rVert}$ | [-1, 1] | Angle between vectors | Direction-based comparison |
| **Cosine Distance** | $1 - \text{cosine similarity}$ | [0, 2] | Dissimilarity measure | Effect scores (our focus) |
| **Euclidean Distance** | $\lVert A - B \rVert_2$ | [0, ∞) | Straight-line distance | Magnitude-based comparison |
| **Manhattan Distance** | $\lVert A - B \rVert_1$ | [0, ∞) | Grid-based distance | Robust to outliers |
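
As a minimal sketch of the four metrics in the table, the snippet below evaluates them on two toy vectors with NumPy. The vectors `a` and `b` are made up for illustration and are not PlantRNA-FM embeddings.

```python
import numpy as np

# Toy vectors standing in for a reference and a mutant embedding (illustrative values)
a = np.array([1.0, 0.0, 2.0])
b = np.array([1.0, 1.0, 1.0])

cos_sim = float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
cos_dist = 1.0 - cos_sim
euclid = float(np.linalg.norm(a - b))     # L2 norm of the difference
manhattan = float(np.abs(a - b).sum())    # L1 norm of the difference

print(f"cosine similarity:  {cos_sim:.4f}")
print(f"cosine distance:    {cos_dist:.4f}")
print(f"euclidean distance: {euclid:.4f}")
print(f"manhattan distance: {manhattan:.4f}")
```

Note that the cosine values depend only on the angle between `a` and `b`, while the Euclidean and Manhattan values also depend on their magnitudes.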

**For VEP with PlantRNA-FM, we primarily use cosine distance** because:
- ✅ Normalized (scale-invariant)
- ✅ Captures semantic similarity in learned plant RNA space
- ✅ Robust to sequence length variation
- ✅ Well-validated in NLP and genomics applications
- ✅ Effective for comparing plant regulatory element embeddings
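
The scale-invariance point can be checked directly. In this sketch (toy vectors, not model outputs, and a hypothetical helper `cosine_distance`), multiplying one vector by a constant leaves the cosine distance unchanged while the Euclidean distance grows.

```python
import numpy as np

def cosine_distance(u: np.ndarray, v: np.ndarray) -> float:
    """1 - cosine similarity for two 1-D vectors."""
    return 1.0 - float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

ref = np.array([0.2, 0.5, 0.1])
alt = np.array([0.3, 0.4, 0.2])

cos_before = cosine_distance(ref, alt)
cos_after = cosine_distance(ref, 10 * alt)          # same direction, 10x the norm
euc_before = float(np.linalg.norm(ref - alt))
euc_after = float(np.linalg.norm(ref - 10 * alt))

print(f"cosine:    {cos_before:.4f} -> {cos_after:.4f}")
print(f"euclidean: {euc_before:.4f} -> {euc_after:.4f}")
```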

### 1.3 Pooling Strategies Revisited

How we aggregate token embeddings affects results:

| Strategy | Method | Advantages | Disadvantages |
|----------|--------|------------|---------------|
| **Mean Pooling** | Average all tokens | Stable, robust | May dilute signal |
| **CLS Token** | Use [CLS] embedding | Fast, BERT-standard | Single point representation |
| **Max Pooling** | Max across tokens | Highlights peaks | Sensitive to outliers |
| **Weighted Average** | Attention-weighted | Focuses on important regions | Requires attention weights |

**Recommendation:** Start with mean pooling, compare with others if results are ambiguous.
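
To make the first three strategies concrete, here is a small NumPy sketch on a made-up token matrix. The `hidden` array and the `mask` marking real versus padding tokens are illustrative assumptions; with a real model they would come from the final hidden layer and the tokenizer's attention mask.

```python
import numpy as np

# Hypothetical token embeddings for one sequence: 4 positions x 3 hidden dims.
# The last position is padding and should not influence the pooled vector.
hidden = np.array([
    [0.5, 0.1, 0.0],   # first token, playing the role of [CLS]
    [1.0, 2.0, 3.0],
    [3.0, 0.0, 1.0],
    [9.0, 9.0, 9.0],   # padding
])
mask = np.array([1, 1, 1, 0], dtype=bool)

mean_pooled = hidden[mask].mean(axis=0)   # average over real tokens
cls_pooled = hidden[0]                    # first-token representation
max_pooled = hidden[mask].max(axis=0)     # per-dimension maximum over real tokens

print("mean:", mean_pooled)
print("cls: ", cls_pooled)
print("max: ", max_pooled)
```

Dropping the padding row before pooling matters most for mean and max pooling, where padded positions would otherwise shift the result.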


## 2. The Embedding Extraction Workflow 🔄

Our pipeline processes variants in four stages:

```mermaid
graph LR
    A[Load<br/>Variants] --> B[Extract<br/>Embeddings]
    B --> C[Compute<br/>Distances]
    C --> D[Score<br/>Effects]
    
    style A fill:#e1f5ff
    style B fill:#ffe1f5
    style C fill:#f5ffe1
    style D fill:#ffe1e1
```

1. **Load**: Read variant data with reference and alternative sequences
2. **Extract**: Pass through model to get embeddings
3. **Compute**: Calculate pairwise distances
4. **Score**: Transform distances into interpretable effect scores
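
The four stages can be sketched end to end on toy data. This is only an illustration of the control flow: `toy_embed` is a hypothetical stand-in that returns nucleotide frequencies, not a PlantRNA-FM embedding, and the sequences are made up.

```python
import numpy as np

def toy_embed(seq: str) -> np.ndarray:
    """Hypothetical stand-in for a model: nucleotide frequency vector (A, C, G, U)."""
    return np.array([seq.count(base) / len(seq) for base in "ACGU"])

# 1. Load: (reference, alternative) sequence pairs (made-up examples)
variants = [("ACGUACGU", "ACGUACGU"), ("ACGUACGU", "ACGUAAGU")]

# 2. Extract: one embedding per sequence
ref = np.stack([toy_embed(r) for r, _ in variants])
alt = np.stack([toy_embed(a) for _, a in variants])

# 3. Compute: row-wise cosine distance between each ref/alt pair
cos_sim = (ref * alt).sum(axis=1) / (
    np.linalg.norm(ref, axis=1) * np.linalg.norm(alt, axis=1)
)
distances = 1.0 - cos_sim

# 4. Score: in this sketch the distance itself serves as the effect score
scores = distances
print(scores)
```

The first pair is identical, so its score is zero up to floating-point error; the second pair differs by one base and receives a small positive score.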


---

## 🛠️ Step-by-Step: Large-Scale Embedding Extraction

### 3.1: Environment Setup


In [None]:
import torch
import numpy as np
import pandas as pd
import warnings
from scipy.spatial.distance import cosine, euclidean, cityblock
from sklearn.metrics.pairwise import cosine_similarity
from omnigenbench import (
    OmniTokenizer,
    OmniModelForSequenceClassification,
    OmniDatasetForSequenceClassification
)
from dataclasses import dataclass
from tqdm.auto import tqdm
from typing import Literal, Optional

warnings.filterwarnings('ignore')
print("✅ Libraries imported successfully!")
print(f"🔧 PyTorch version: {torch.__version__}")
print(f"💻 Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")


### 3.2: Configuration

Comprehensive settings for embedding extraction and scoring:


In [None]:
@dataclass
class ScoringConfig:
    """Configuration for embedding extraction and variant scoring"""
    # Data settings
    dataset_name: str = "yangheng/variant_effect_prediction"
    cache_dir: str = "vep_data"
    
    # Model settings
    model_name: str = "yangheng/OmniGenome-52M"
    max_length: int = 512
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Embedding extraction
    pooling: Literal['mean', 'cls', 'max'] = 'mean'
    batch_size: int = 32  # Increase for faster processing
    
    # Scoring settings
    distance_metric: str = 'cosine'  # cosine, euclidean, manhattan
    normalize_scores: bool = True
    
    # Output settings
    save_embeddings: bool = False  # Set True to cache embeddings
    output_dir: str = "vep_results"

config = ScoringConfig()
print("📋 Configuration:")
print(f"   Model: {config.model_name}")
print(f"   Device: {config.device}")
print(f"   Pooling: {config.pooling}")
print(f"   Batch size: {config.batch_size}")
print(f"   Distance metric: {config.distance_metric}")


### 3.3: Load Model and Data


In [None]:
# Initialize tokenizer
print("🔤 Loading tokenizer...")
tokenizer = OmniTokenizer.from_pretrained(
    config.model_name, 
    trust_remote_code=True
)

# Load model
print("🤖 Loading model...")
model = OmniModelForSequenceClassification.from_pretrained(
    config.model_name,
    tokenizer=tokenizer,
    num_labels=2,
    trust_remote_code=True,
    output_hidden_states=True
)
model = model.to(config.device)
model.eval()

# Load datasets
print("📥 Loading datasets...")
datasets = OmniDatasetForSequenceClassification.from_hub(
    dataset_name=config.dataset_name,
    tokenizer=tokenizer,
    max_length=config.max_length,
    cache_dir=config.cache_dir
)

print("✅ All components loaded!")
print(f"   Test set size: {len(datasets['test'])} variants")


## 4. Embedding Extraction Implementation 🔬

### 4.1: Optimized Batch Extraction


In [None]:
def extract_embeddings_batch(
    model,
    dataloader,
    pooling='mean',
    device='cpu',
    show_progress=True
):
    """
    Extract embeddings from model in batches.
    
    Args:
        model: Pre-trained foundation model
        dataloader: DataLoader with tokenized sequences
        pooling: Pooling strategy ('mean', 'cls', 'max')
        device: Device for computation
        show_progress: Whether to show progress bar
        
    Returns:
        embeddings: numpy array of shape (n_samples, hidden_dim)
        
    Note:
        This function is optimized for large-scale processing with:
        - Batch processing for efficiency
        - GPU acceleration when available
        - Memory-efficient computation
    """
    all_embeddings = []
    
    iterator = tqdm(dataloader, desc=f"Extracting embeddings ({pooling})") if show_progress else dataloader
    
    with torch.no_grad():
        for batch in iterator:
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
            
            # Forward pass
            outputs = model(**batch)
            
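            # Get final layer hidden states
            hidden_states = outputs.hidden_states[-1]  # (batch_size, seq_len, hidden_dim)
            
            # Build a (batch_size, seq_len, 1) mask so padding tokens are excluded from pooling
            mask = batch.get('attention_mask')
            if mask is None:
                mask = torch.ones(hidden_states.shape[:2], device=hidden_states.device)
            mask = mask.unsqueeze(-1).to(hidden_states.dtype)
            
            # Apply pooling
            if pooling == 'mean':
                # Average over real (non-padding) tokens only
                embeddings = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
            elif pooling == 'cls':
                # Use CLS token (first position)
                embeddings = hidden_states[:, 0, :]
            elif pooling == 'max':
                # Max pooling over real tokens; padding is set to -inf so it never wins
                embeddings = hidden_states.masked_fill(mask == 0, float('-inf')).max(dim=1)[0]
            else:
                raise ValueError(f"Unknown pooling strategy: {pooling}")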
            
            # Move to CPU and store
            all_embeddings.append(embeddings.cpu())
    
    # Concatenate all batches
    all_embeddings = torch.cat(all_embeddings, dim=0)
    
    return all_embeddings.numpy()

# Test the function
print("🧪 Testing embedding extraction...")
test_loader = datasets['test'].get_dataloader(
    batch_size=config.batch_size,
    shuffle=False
)

# Extract embeddings (this may take a few minutes)
embeddings = extract_embeddings_batch(
    model=model,
    dataloader=test_loader,
    pooling=config.pooling,
    device=config.device,
    show_progress=True
)

print(f"\n✅ Extraction complete!")
print(f"   Embeddings shape: {embeddings.shape}")
print(f"   Memory usage: {embeddings.nbytes / 1024**2:.2f} MB")
print(f"   Embedding norm (mean): {np.linalg.norm(embeddings, axis=1).mean():.3f}")


### 4.2: Comparing Pooling Strategies

Let's compare the three pooling strategies on the full test set:


In [None]:
# Extract embeddings with different pooling strategies
pooling_strategies = ['mean', 'cls', 'max']
pooling_embeddings = {}

print("📊 Comparing pooling strategies...")
for strategy in pooling_strategies:
    emb = extract_embeddings_batch(
        model=model,
        dataloader=test_loader,
        pooling=strategy,
        device=config.device,
        show_progress=False
    )
    pooling_embeddings[strategy] = emb
    print(f"   {strategy:5s}: shape={emb.shape}, norm={np.linalg.norm(emb, axis=1).mean():.3f}")

# Compare pooling methods (illustrative: uses only the first sample's embedding)

print("\n🔍 Correlation between pooling strategies (first sample only):")
for i, strategy1 in enumerate(pooling_strategies):
    for strategy2 in pooling_strategies[i+1:]:
        # Calculate correlation between first embedding from each method
        corr = np.corrcoef(
            pooling_embeddings[strategy1][0],
            pooling_embeddings[strategy2][0]
        )[0, 1]
        print(f"   {strategy1} ↔ {strategy2}: {corr:.3f}")
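
The comparison above looks at a single sample. One way to extend it to every sample is a vectorized row-wise Pearson correlation, sketched below. `rowwise_pearson` is a hypothetical helper, and the random arrays merely stand in for two entries of `pooling_embeddings`.

```python
import numpy as np

def rowwise_pearson(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Pearson correlation between matching rows of x and y, shape (n_samples,)."""
    xc = x - x.mean(axis=1, keepdims=True)
    yc = y - y.mean(axis=1, keepdims=True)
    return (xc * yc).sum(axis=1) / (
        np.linalg.norm(xc, axis=1) * np.linalg.norm(yc, axis=1)
    )

rng = np.random.default_rng(0)
emb_a = rng.normal(size=(5, 16))                  # stand-in for one pooling strategy
emb_b = emb_a + 0.1 * rng.normal(size=(5, 16))    # stand-in for another, slightly perturbed

per_sample = rowwise_pearson(emb_a, emb_b)
print(f"mean r = {per_sample.mean():.3f}, min r = {per_sample.min():.3f}")
```

Reporting the mean and minimum across samples gives a more representative picture than the first row alone.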


## 5. Variant Effect Scoring 📊

### 5.1: Distance Metric Implementation


In [None]:
def calculate_pairwise_distances(
    embeddings1: np.ndarray,
    embeddings2: np.ndarray,
    metric: str = 'cosine'
) -> np.ndarray:
    """
    Calculate pairwise distances between two sets of embeddings.
    
    Args:
        embeddings1: First set of embeddings (n_samples, dim)
        embeddings2: Second set of embeddings (n_samples, dim)
        metric: Distance metric ('cosine', 'euclidean', 'manhattan')
        
    Returns:
        distances: Array of distances (n_samples,)
    """
    assert embeddings1.shape == embeddings2.shape, "Embedding shapes must match"
    
    distances = []
    for emb1, emb2 in zip(embeddings1, embeddings2):
        if metric == 'cosine':
            dist = cosine(emb1, emb2)
        elif metric == 'euclidean':
            dist = euclidean(emb1, emb2)
        elif metric == 'manhattan':
            dist = cityblock(emb1, emb2)
        else:
            raise ValueError(f"Unknown metric: {metric}")
        distances.append(dist)
    
    return np.array(distances)


def calculate_effect_scores(
    ref_embeddings: np.ndarray,
    alt_embeddings: np.ndarray,
    metric: str = 'cosine',
    normalize: bool = True
) -> np.ndarray:
    """
    Calculate variant effect scores from embeddings.
    
    Effect score interpretation:
    - Higher score = larger embedding change = likely functional impact
    - Lower score = smaller embedding change = likely benign
    
    Args:
        ref_embeddings: Reference sequence embeddings
        alt_embeddings: Alternative sequence embeddings
        metric: Distance metric to use
        normalize: Whether to min-max scale euclidean/manhattan scores to [0, 1]
            (cosine distances are left unchanged in [0, 2])
        
    Returns:
        effect_scores: Array of effect scores
    """
    # Calculate distances
    distances = calculate_pairwise_distances(
        ref_embeddings,
        alt_embeddings,
        metric=metric
    )
    
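    # Cosine distance is already bounded in [0, 2] and is left unchanged.
    # Euclidean/Manhattan distances are unbounded, so min-max scale them to [0, 1].
    if normalize and metric in ['euclidean', 'manhattan']:
        score_range = distances.max() - distances.min()
        if score_range > 0:
            distances = (distances - distances.min()) / score_range
        else:
            # All distances are identical: avoid division by zero
            distances = np.zeros_like(distances)
    
    return distances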

# Example: Calculate effect scores
# Note: In practice, you'd have paired ref/alt sequences in your dataset
# Here we demonstrate with a mock scenario
print("📊 Calculating effect scores...")

# Assuming embeddings alternate between ref and alt (for demonstration)
# In real data, you'd extract these separately
n_variants = len(embeddings) // 2
ref_emb = embeddings[0:2*n_variants:2]  # even rows: reference
alt_emb = embeddings[1:2*n_variants:2]  # odd rows: alternative

effect_scores = calculate_effect_scores(
    ref_emb,
    alt_emb,
    metric=config.distance_metric,
    normalize=config.normalize_scores
)

print(f"\n✅ Effect scores calculated!")
print(f"   Number of variants: {len(effect_scores)}")
print(f"   Mean effect score: {effect_scores.mean():.4f}")
print(f"   Std dev: {effect_scores.std():.4f}")
print(f"   Score range: [{effect_scores.min():.4f}, {effect_scores.max():.4f}]")
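
For large variant sets, the per-pair Python loop in `calculate_pairwise_distances` can become a bottleneck. A possible vectorized alternative in plain NumPy is sketched below; `rowwise_distances` is a hypothetical name, and the toy arrays are chosen so the results are easy to verify by hand.

```python
import numpy as np

def rowwise_distances(ref: np.ndarray, alt: np.ndarray, metric: str = "cosine") -> np.ndarray:
    """Distance between matching rows of ref and alt, computed without a Python loop."""
    if ref.shape != alt.shape:
        raise ValueError("Embedding shapes must match")
    if metric == "cosine":
        sim = (ref * alt).sum(axis=1) / (
            np.linalg.norm(ref, axis=1) * np.linalg.norm(alt, axis=1)
        )
        return 1.0 - sim
    if metric == "euclidean":
        return np.linalg.norm(ref - alt, axis=1)
    if metric == "manhattan":
        return np.abs(ref - alt).sum(axis=1)
    raise ValueError(f"Unknown metric: {metric}")

ref = np.array([[1.0, 0.0], [0.0, 2.0]])
alt = np.array([[0.0, 1.0], [0.0, 4.0]])   # row 0 orthogonal to ref, row 1 parallel

for name in ("cosine", "euclidean", "manhattan"):
    print(name, rowwise_distances(ref, alt, name))
```

The orthogonal pair has cosine distance 1 and the parallel pair has cosine distance 0, even though the parallel pair is farther apart in Euclidean terms.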


### 5.2: Comparing Distance Metrics

Different metrics may give different insights:


In [None]:
# Compare different distance metrics
metrics = ['cosine', 'euclidean', 'manhattan']
metric_scores = {}

print("📊 Comparing distance metrics...")
for metric in metrics:
    scores = calculate_effect_scores(
        ref_emb,
        alt_emb,
        metric=metric,
        normalize=True
    )
    metric_scores[metric] = scores
    print(f"   {metric:10s}: mean={scores.mean():.4f}, std={scores.std():.4f}")

# Correlation between metrics
print("\n🔍 Correlation between metrics:")
for i, metric1 in enumerate(metrics):
    for metric2 in metrics[i+1:]:
        corr = np.corrcoef(metric_scores[metric1], metric_scores[metric2])[0, 1]
        print(f"   {metric1} ↔ {metric2}: {corr:.3f}")
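
Pearson correlation, as used above, measures linear agreement. Since the metrics live on different scales, it can also be informative to ask whether they rank variants in the same order. The sketch below shows one option, a Spearman-style rank correlation; this is an addition for illustration rather than part of the original pipeline, `spearman_no_ties` is a hypothetical helper that assumes no tied scores, and the score arrays are made up.

```python
import numpy as np

def spearman_no_ties(x: np.ndarray, y: np.ndarray) -> float:
    """Spearman correlation computed as Pearson on ranks; assumes no tied values."""
    rank_x = np.argsort(np.argsort(x))
    rank_y = np.argsort(np.argsort(y))
    return float(np.corrcoef(rank_x, rank_y)[0, 1])

scores_a = np.array([0.10, 0.40, 0.20, 0.90])
scores_b = np.array([1.0, 16.0, 4.0, 81.0])   # same ordering, non-linear rescaling

pearson = float(np.corrcoef(scores_a, scores_b)[0, 1])
spearman = spearman_no_ties(scores_a, scores_b)
print(f"Pearson:  {pearson:.3f}")
print(f"Spearman: {spearman:.3f}")
```

Two metrics that order the variants identically get a rank correlation of 1 even when their Pearson correlation is lower.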


## 6. Statistical Analysis & Interpretation 📈

### 6.1: Score Distribution Analysis


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 4)

# Create comprehensive statistics
results_df = pd.DataFrame({
    'variant_id': range(len(effect_scores)),
    'effect_score': effect_scores,
    'cosine_score': metric_scores['cosine'],
    'euclidean_score': metric_scores['euclidean'],
    'manhattan_score': metric_scores['manhattan']
})

print("📊 Effect Score Statistics:")
print("=" * 60)
print(results_df['effect_score'].describe())

# Percentile-based thresholds
percentiles = [50, 75, 90, 95, 99]
print("\n📏 Percentile Thresholds:")
for p in percentiles:
    threshold = np.percentile(effect_scores, p)
    print(f"   {p}th percentile: {threshold:.4f}")
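
A percentile threshold can be turned into a flag for individual variants, and the reverse question (where does a given score sit in the distribution?) is a one-liner too. A minimal sketch with made-up scores:

```python
import numpy as np

# Made-up effect scores for ten variants
scores = np.array([0.01, 0.02, 0.03, 0.05, 0.08, 0.13, 0.21, 0.34, 0.55, 0.89])

# Flag variants above the 90th percentile
threshold_90 = np.percentile(scores, 90)
flagged = scores > threshold_90

# Empirical percentile rank of a new score: fraction of observed scores below it
new_score = 0.30
rank = float((scores < new_score).mean())

print(f"90th percentile threshold: {threshold_90:.4f}")
print(f"flagged variants: {int(flagged.sum())}")
print(f"score {new_score} is above {rank:.0%} of observed scores")
```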


### 6.2: Visualization


In [None]:
# Create visualization
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# 1. Distribution histogram
axes[0].hist(effect_scores, bins=50, alpha=0.7, edgecolor='black')
axes[0].axvline(effect_scores.mean(), color='red', linestyle='--', label=f'Mean: {effect_scores.mean():.3f}')
axes[0].axvline(np.percentile(effect_scores, 95), color='orange', linestyle='--', label='95th percentile')
axes[0].set_xlabel('Effect Score')
axes[0].set_ylabel('Frequency')
axes[0].set_title('Effect Score Distribution')
axes[0].legend()

# 2. Metric comparison
metric_data = [metric_scores[m] for m in metrics]
bp = axes[1].boxplot(metric_data, labels=metrics, patch_artist=True)
for patch in bp['boxes']:
    patch.set_facecolor('#e1f5ff')
axes[1].set_ylabel('Normalized Score')
axes[1].set_title('Distance Metric Comparison')
axes[1].grid(axis='y', alpha=0.3)

# 3. Score correlation
axes[2].scatter(metric_scores['cosine'], metric_scores['euclidean'], alpha=0.5, s=10)
axes[2].set_xlabel('Cosine Distance')
axes[2].set_ylabel('Euclidean Distance')
axes[2].set_title('Metric Correlation')
corr = np.corrcoef(metric_scores['cosine'], metric_scores['euclidean'])[0, 1]
axes[2].text(0.05, 0.95, f'r = {corr:.3f}', transform=axes[2].transAxes, 
             verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

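plt.tight_layout()
# Make sure the output directory exists before saving (it is otherwise created only in Section 7)
from pathlib import Path
Path(config.output_dir).mkdir(exist_ok=True, parents=True)
plt.savefig(f"{config.output_dir}/embedding_analysis.png", dpi=300, bbox_inches='tight')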
plt.show()

print("\n✅ Visualization saved!")


### 6.3: Variant Classification

Based on effect scores, classify variants:


In [None]:
# Define thresholds
threshold_high = np.percentile(effect_scores, 75)
threshold_low = np.percentile(effect_scores, 25)

# Classify variants
def classify_variant(score, low_thresh, high_thresh):
    if score > high_thresh:
        return 'High Impact'
    elif score > low_thresh:
        return 'Moderate Impact'
    else:
        return 'Low Impact'

results_df['impact'] = results_df['effect_score'].apply(
    lambda x: classify_variant(x, threshold_low, threshold_high)
)

print("📊 Variant Impact Classification:")
print("=" * 60)
print(results_df['impact'].value_counts())
print(f"\n🎯 Classification Thresholds:")
print(f"   Low → Moderate: {threshold_low:.4f}")
print(f"   Moderate → High: {threshold_high:.4f}")

# Sample high-impact variants
high_impact = results_df[results_df['impact'] == 'High Impact'].sort_values('effect_score', ascending=False)
print(f"\n🔝 Top 5 High-Impact Variants:")
print(high_impact.head())
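
For large result tables, the row-by-row `apply` can be replaced with a vectorized binning step. The sketch below uses `np.digitize` with the same boundary rules as `classify_variant`; the scores are made up for illustration.

```python
import numpy as np

# Made-up effect scores
scores = np.array([0.02, 0.05, 0.10, 0.20, 0.30, 0.45, 0.60, 0.90])
low, high = np.percentile(scores, [25, 75])

# right=True gives: score <= low -> 0, low < score <= high -> 1, score > high -> 2
bins = np.digitize(scores, [low, high], right=True)
labels = np.array(["Low Impact", "Moderate Impact", "High Impact"])[bins]

for score, label in zip(scores, labels):
    print(f"{score:.2f} -> {label}")
```

With quartile thresholds, roughly a quarter of variants land in each of the low and high bins by construction, so the labels describe relative rank within this dataset rather than an absolute effect size.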


## 7. Export Results 💾


In [None]:
import os
from pathlib import Path

# Create output directory
Path(config.output_dir).mkdir(exist_ok=True, parents=True)

# Export scores
output_file = f"{config.output_dir}/variant_effect_scores.csv"
results_df.to_csv(output_file, index=False)
print(f"✅ Results exported to: {output_file}")

# Optionally save embeddings
if config.save_embeddings:
    np.save(f"{config.output_dir}/ref_embeddings.npy", ref_emb)
    np.save(f"{config.output_dir}/alt_embeddings.npy", alt_emb)
    print(f"✅ Embeddings saved to: {config.output_dir}/")

# Export summary statistics
summary_stats = {
    'total_variants': len(effect_scores),
    'mean_score': float(effect_scores.mean()),
    'std_score': float(effect_scores.std()),
    'min_score': float(effect_scores.min()),
    'max_score': float(effect_scores.max()),
    'high_impact_count': int((results_df['impact'] == 'High Impact').sum()),
    'moderate_impact_count': int((results_df['impact'] == 'Moderate Impact').sum()),
    'low_impact_count': int((results_df['impact'] == 'Low Impact').sum()),
}

import json
with open(f"{config.output_dir}/summary_statistics.json", 'w') as f:
    json.dump(summary_stats, f, indent=2)

print(f"✅ Summary statistics saved!")


---

## 📝 Summary & Next Steps

### What We Accomplished

In this tutorial, we:
1. ✅ **Understood embedding distance metrics** - Learned cosine, Euclidean, and Manhattan distances
2. ✅ **Implemented batch extraction** - Efficiently processed thousands of variants
3. ✅ **Compared pooling strategies** - Evaluated mean, CLS, and max pooling
4. ✅ **Calculated effect scores** - Converted embeddings to interpretable scores
5. ✅ **Analyzed distributions** - Performed statistical analysis and visualization
6. ✅ **Classified variants** - Categorized by predicted impact level

### Key Takeaways

📊 **Distance metrics matter**: Cosine distance is preferred for semantic comparison

🔄 **Batch processing is essential**: Process large datasets efficiently with GPU acceleration

📈 **Thresholds are dataset-specific**: Use percentiles rather than absolute cutoffs

🎯 **Multiple metrics provide robustness**: Compare results across different distance measures

### Technical Highlights

**Optimized Embedding Extraction:**
```python
embeddings = extract_embeddings_batch(
    model, dataloader, pooling='mean', device='cuda'
)
```

**Effect Score Calculation:**
```python
scores = calculate_effect_scores(
    ref_embeddings, alt_embeddings, 
    metric='cosine', normalize=True
)
```

**Variant Classification:**
```python
threshold = np.percentile(scores, 75)
high_impact = scores > threshold
```

### Next Steps

Continue to **[Tutorial 4: Visualization & Export](04_visualization_and_export.ipynb)** where we'll:
- Create publication-quality figures
- Perform ROC/PR curve analysis
- Interpret results biologically
- Generate clinical reports

### Quick Reference

```python
# Complete workflow in 5 lines
model = OmniModelForSequenceClassification.from_pretrained(model_name, output_hidden_states=True)
embeddings = extract_embeddings_batch(model, dataloader, pooling='mean')
ref_emb, alt_emb = embeddings[::2], embeddings[1::2]  # Split pairs
scores = calculate_effect_scores(ref_emb, alt_emb, metric='cosine')
results_df = pd.DataFrame({'variant_id': range(len(scores)), 'effect_score': scores})
```

**Next Tutorial**: [04_visualization_and_export.ipynb](04_visualization_and_export.ipynb) 📊
