# 🧬 VEP Embedding and Scoring Tutorial

This tutorial covers the core of variant effect prediction (VEP) with genomic foundation models: extracting sequence embeddings and scoring a variant by the distance between its reference and alternative embeddings.

## Overview

In this tutorial, we'll explore:
- Sequence embedding extraction from genomic foundation models
- Variant effect scoring methodologies
- Distance metrics for comparing reference vs alternative sequences
- Interpretation of effect scores

## 1. Setup and Prerequisites

First, let's import the necessary libraries and set up our environment.

In [None]:
# Import essential libraries
import torch
import numpy as np
import pandas as pd
from typing import Tuple, List, Dict
from scipy.spatial.distance import cosine, euclidean
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

from omnigenbench import (
    OmniTokenizer,
    OmniModelForSequenceClassification
)
from autocuda import auto_cuda

print("✅ Libraries imported successfully!")
print(f"🔥 PyTorch version: {torch.__version__}")
print(f"🎯 CUDA available: {torch.cuda.is_available()}")

## 2. Configuration and Model Setup

Define our analysis configuration and load the genomic foundation model.

In [None]:
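# VEP Embedding Configuration
from dataclasses import dataclass, field

@dataclass
class EmbeddingConfig:
    # Fields are annotated so that @dataclass registers them; un-annotated
    # names would be plain class attributes, not dataclass fields.
    MODEL_NAME: str = "yangheng/OmniGenome-52M"
    MAX_LENGTH: int = 512
    DEVICE: str = field(default_factory=auto_cuda)  # resolved when the config is created
    
    # Embedding extraction settings
    POOLING_STRATEGY: str = "mean"  # Options: mean, max, cls, last
    LAYER_INDEX: int = -1  # Which transformer layer to use for embeddings
    
    # Scoring parameters (a mutable default needs default_factory)
    DISTANCE_METRICS: List[str] = field(
        default_factory=lambda: ["cosine", "euclidean", "manhattan"]
    )
    EFFECT_THRESHOLD: float = 0.1  # Threshold for functional significance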

config = EmbeddingConfig()

print("⚙️ Embedding Analysis Configuration:")
print(f"  🧬 Model: {config.MODEL_NAME}")
print(f"  📏 Max length: {config.MAX_LENGTH}")
print(f"  📱 Device: {config.DEVICE}")
print(f"  🎯 Pooling: {config.POOLING_STRATEGY}")
print(f"  📊 Distance metrics: {config.DISTANCE_METRICS}")

In [None]:
# Load tokenizer and model
print("🔄 Loading tokenizer and model...")

tokenizer = OmniTokenizer.from_pretrained(config.MODEL_NAME, trust_remote_code=True)

model = OmniModelForSequenceClassification.from_pretrained(
    config.MODEL_NAME,
    num_labels=2,
    trust_remote_code=True,
    output_hidden_states=True  # Essential for embedding extraction
)

model.to(config.DEVICE)
model.eval()

print(f"✅ Model setup complete!")
print(f"  🧠 Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
print(f"  📱 Device: {config.DEVICE}")

## 3. Embedding Extraction Functions

Define comprehensive functions for extracting sequence embeddings from the genomic foundation model.

In [None]:
def extract_sequence_embedding(
    sequence: str, 
    model, 
    tokenizer, 
    device: str,
    pooling_strategy: str = "mean",
    layer_index: int = -1
) -> np.ndarray:
    """
    Extract sequence embedding using specified pooling strategy.
    
    Args:
        sequence: DNA sequence string
        model: Pre-trained genomic model
        tokenizer: Corresponding tokenizer
        device: Computing device
        pooling_strategy: Pooling method ('mean', 'max', 'cls', 'last')
        layer_index: Which transformer layer to use
    
    Returns:
        Sequence embedding as numpy array
    """
    # Tokenize sequence
    inputs = tokenizer(
        sequence,
        padding=True,
        truncation=True,
        max_length=config.MAX_LENGTH,
        return_tensors="pt"
    ).to(device)
    
    # Extract embeddings
    with torch.no_grad():
        outputs = model(**inputs)
        hidden_states = outputs.hidden_states[layer_index]  # Shape: [batch, seq_len, hidden_dim]
        
        # Apply pooling strategy
        if pooling_strategy == "mean":
            # Mean pooling over sequence length
            embedding = hidden_states.mean(dim=1)
        elif pooling_strategy == "max":
            # Max pooling over sequence length
            embedding = hidden_states.max(dim=1)[0]
        elif pooling_strategy == "cls":
            # Use CLS token (first token)
            embedding = hidden_states[:, 0, :]
        elif pooling_strategy == "last":
            # Use last token
            embedding = hidden_states[:, -1, :]
        else:
            raise ValueError(f"Unknown pooling strategy: {pooling_strategy}")
    
    return embedding.cpu().numpy().squeeze()

def batch_extract_embeddings(
    sequences: List[str],
    model,
    tokenizer,
    device: str,
    batch_size: int = 8,
    pooling_strategy: str = "mean"
) -> np.ndarray:
    """
    Extract embeddings for multiple sequences in batches for efficiency.
    
    Args:
        sequences: List of DNA sequences
        model: Pre-trained genomic model
        tokenizer: Corresponding tokenizer
        device: Computing device
        batch_size: Batch size for processing
        pooling_strategy: Pooling method ('mean', 'max', 'cls')
    
    Returns:
        Array of embeddings [num_sequences, embedding_dim]
    """
    embeddings = []
    
    for i in tqdm(range(0, len(sequences), batch_size), desc="Extracting embeddings"):
        batch_sequences = sequences[i:i + batch_size]
        
        # Tokenize batch
        inputs = tokenizer(
            batch_sequences,
            padding=True,
            truncation=True,
            max_length=config.MAX_LENGTH,
            return_tensors="pt"
        ).to(device)
        
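        # Extract embeddings
        with torch.no_grad():
            outputs = model(**inputs)
            hidden_states = outputs.hidden_states[-1]  # Last layer
            # Padding mask [batch, seq_len, 1]; assumes the tokenizer returns
            # "attention_mask" (1 = real token, 0 = padding)
            mask = inputs["attention_mask"].unsqueeze(-1).to(hidden_states.dtype)
            
            # Apply pooling over real tokens only, so results do not depend
            # on how much padding a batch happens to contain
            if pooling_strategy == "mean":
                batch_embeddings = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
            elif pooling_strategy == "max":
                batch_embeddings = hidden_states.masked_fill(mask == 0, float("-inf")).max(dim=1)[0]
            elif pooling_strategy == "cls":
                batch_embeddings = hidden_states[:, 0, :]
            else:
                raise ValueError(f"Unknown pooling strategy: {pooling_strategy}")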
        
        embeddings.append(batch_embeddings.cpu().numpy())
    
    return np.vstack(embeddings)

print("✅ Embedding extraction functions defined!")
print("  🧬 extract_sequence_embedding() - Single sequence embedding")
print("  📦 batch_extract_embeddings() - Batch processing for efficiency")
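
When sequences in a batch differ in length, the shorter ones are padded, and a plain mean over the sequence axis averages those padding positions in. The sketch below illustrates mask-aware mean pooling on a toy NumPy array standing in for hidden states. The function name `masked_mean_pool` and the toy values are illustrative assumptions, not part of the tutorial's pipeline.

```python
import numpy as np

def masked_mean_pool(hidden_states: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Mean-pool [batch, seq_len, hidden] over real tokens only."""
    mask = attention_mask[..., None].astype(hidden_states.dtype)  # [batch, seq_len, 1]
    summed = (hidden_states * mask).sum(axis=1)                   # padded rows contribute 0
    counts = np.clip(mask.sum(axis=1), 1, None)                   # avoid division by zero
    return summed / counts

# Toy batch: two sequences, max length 3, hidden size 2.
# The second sequence has two real tokens; its third row is padding.
hidden = np.array([
    [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],
    [[4.0, 0.0], [6.0, 2.0], [9.0, 9.0]],
])
mask = np.array([[1, 1, 1], [1, 1, 0]])

naive = hidden.mean(axis=1)             # averages the padding row as well
masked = masked_mean_pool(hidden, mask)
print("naive :", naive)
print("masked:", masked)
```

For the unpadded first sequence both results agree; for the padded second sequence only the masked version reflects the two real tokens.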

## 4. Variant Effect Scoring Methods

Implement multiple distance metrics for calculating variant effect scores.

In [None]:
def calculate_variant_effects(
    ref_embedding: np.ndarray, 
    alt_embedding: np.ndarray,
    metrics: List[str] = ["cosine", "euclidean", "manhattan"]
) -> Dict[str, float]:
    """
    Calculate variant effect scores using multiple distance metrics.
    
    Args:
        ref_embedding: Reference sequence embedding
        alt_embedding: Alternative sequence embedding
        metrics: List of distance metrics to compute
    
    Returns:
        Dictionary of metric names and their scores
    """
    scores = {}
    
    for metric in metrics:
        if metric == "cosine":
            # Cosine distance (0 = same direction, 1 = orthogonal, 2 = opposite)
            score = cosine(ref_embedding, alt_embedding)
        elif metric == "euclidean":
            # Euclidean distance
            score = euclidean(ref_embedding, alt_embedding)
        elif metric == "manhattan":
            # Manhattan (L1) distance
            score = np.sum(np.abs(ref_embedding - alt_embedding))
        elif metric == "cosine_similarity":
            # Cosine similarity (1 = identical, -1 = opposite)
            # Convert to distance: 1 - similarity
            similarity = cosine_similarity([ref_embedding], [alt_embedding])[0, 0]
            score = 1 - similarity
        else:
            raise ValueError(f"Unknown metric: {metric}")
        
        scores[metric] = score
    
    return scores

def interpret_effect_score(score: float, metric: str = "cosine", threshold: float = 0.1) -> Dict[str, object]:
    """
    Interpret variant effect score in biological context.
    
    Args:
        score: Effect score
        metric: Distance metric used
        threshold: Threshold for functional significance
    
    Returns:
        Dictionary with interpretation results
    """
    # Define interpretation categories
    if metric == "cosine":
        if score < 0.01:
            impact = "Minimal"
            description = "Likely neutral variant"
            functional = False
        elif score < threshold:
            impact = "Low"
            description = "Possibly neutral or mildly functional"
            functional = False
        elif score < 0.3:
            impact = "Moderate"
            description = "Likely functional impact"
            functional = True
        else:
            impact = "High"
            description = "Strong functional impact expected"
            functional = True
    else:
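        # Generic interpretation for other metrics. Euclidean and Manhattan
        # distances are unbounded, so pass a threshold chosen on that metric's
        # scale (the 0.1 default is meant for cosine distance).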
        functional = score > threshold
        impact = "High" if score > threshold * 3 else ("Moderate" if functional else "Low")
        description = f"Effect score: {score:.4f}"
    
    return {
        "score": score,
        "impact": impact,
        "functional": functional,
        "description": description,
        "metric": metric
    }

print("✅ Variant effect scoring functions defined!")
print("  📊 calculate_variant_effects() - Multiple distance metrics")
print("  🔬 interpret_effect_score() - Biological interpretation")
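
To make the difference between the metrics concrete, here is a small worked example in plain Python. The three vectors are made up for illustration and are not model embeddings: `scaled` points in the same direction as `ref` with twice the magnitude, and `rotated` is orthogonal to `ref`. The helper functions are minimal re-implementations written for this sketch.

```python
import math

def cosine_distance(u, v):
    """1 - cosine similarity; ignores vector magnitude."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v)

def euclidean_distance(u, v):
    """Straight-line (L2) distance."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def manhattan_distance(u, v):
    """Sum of absolute coordinate differences (L1)."""
    return sum(abs(a - b) for a, b in zip(u, v))

ref = [1.0, 2.0, 2.0]
scaled = [2.0, 4.0, 4.0]    # same direction as ref, twice the magnitude
rotated = [2.0, -1.0, 0.0]  # orthogonal to ref (dot product is 0)

for name, alt in [("scaled", scaled), ("rotated", rotated)]:
    print(
        name,
        round(cosine_distance(ref, alt), 4),
        round(euclidean_distance(ref, alt), 4),
        round(manhattan_distance(ref, alt), 4),
    )
```

Cosine distance is zero for `scaled` because only direction matters, while Euclidean and Manhattan distances both register the change in magnitude.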

## 5. Practical Example: Variant Analysis

Let's demonstrate the embedding and scoring pipeline with example variants.

In [None]:
# Create example variant sequences
def create_variant_sequences(
    ref_allele: str, 
    alt_allele: str, 
    context_length: int = 100,
    position: int = None
) -> Tuple[str, str]:
    """
    Create reference and alternative sequences for variant analysis.
    
    Args:
        ref_allele: Reference allele
        alt_allele: Alternative allele
        context_length: Length of context sequence on each side
        position: Length of the upstream context, i.e. the 0-based index of the
            variant in the output (if None, uses context_length, which centers it)
    
    Returns:
        Tuple of (reference_sequence, alternative_sequence)
    """
    if position is None:
        position = context_length
    
    # Create synthetic context (in practice, this would come from reference genome)
    context_before = "ATCG" * (position // 4) + "ATCG"[:position % 4]
    context_after = "GCTA" * (context_length // 4) + "GCTA"[:context_length % 4]
    
    ref_sequence = context_before + ref_allele + context_after
    alt_sequence = context_before + alt_allele + context_after
    
    return ref_sequence, alt_sequence

# Example variants to analyze
example_variants = [
    {"id": "SNV_1", "ref": "A", "alt": "G", "type": "transition"},
    {"id": "SNV_2", "ref": "C", "alt": "T", "type": "transition"},
    {"id": "SNV_3", "ref": "A", "alt": "T", "type": "transversion"},
    {"id": "SNV_4", "ref": "G", "alt": "C", "type": "transversion"},
    {"id": "INDEL_1", "ref": "AT", "alt": "A", "type": "deletion"},
    {"id": "INDEL_2", "ref": "G", "alt": "GT", "type": "insertion"},
]

print("🧬 Analyzing example variants...")
results = []

for variant in tqdm(example_variants, desc="Processing variants"):
    # Create sequences
    ref_seq, alt_seq = create_variant_sequences(
        variant["ref"], 
        variant["alt"], 
        context_length=150
    )
    
    # Extract embeddings
    ref_embedding = extract_sequence_embedding(
        ref_seq, model, tokenizer, config.DEVICE,
        config.POOLING_STRATEGY, config.LAYER_INDEX
    )
    alt_embedding = extract_sequence_embedding(
        alt_seq, model, tokenizer, config.DEVICE,
        config.POOLING_STRATEGY, config.LAYER_INDEX
    )
    
    # Calculate effect scores
    effect_scores = calculate_variant_effects(
        ref_embedding, alt_embedding, config.DISTANCE_METRICS
    )
    
    # Interpret primary score (cosine)
    interpretation = interpret_effect_score(
        effect_scores["cosine"], "cosine", config.EFFECT_THRESHOLD
    )
    
    # Store results
    result = {
        "variant_id": variant["id"],
        "ref_allele": variant["ref"],
        "alt_allele": variant["alt"],
        "variant_type": variant["type"],
        "ref_sequence_length": len(ref_seq),
        "alt_sequence_length": len(alt_seq),
        **{f"{metric}_score": score for metric, score in effect_scores.items()},
        "predicted_functional": interpretation["functional"],
        "impact_level": interpretation["impact"],
        "description": interpretation["description"]
    }
    
    results.append(result)

# Convert to DataFrame for analysis
results_df = pd.DataFrame(results)

print(f"\n📊 Analysis Results:")
print(results_df[['variant_id', 'ref_allele', 'alt_allele', 'variant_type', 
                 'cosine_score', 'predicted_functional', 'impact_level']].to_string(index=False))

print(f"\n🎯 Summary:")
print(f"  📈 Total variants analyzed: {len(results_df)}")
print(f"  ⚡ Predicted functional: {results_df['predicted_functional'].sum()}")
print(f"  📊 Mean cosine score: {results_df['cosine_score'].mean():.4f}")

## 6. Advanced Analysis: Pooling Strategy Comparison

Compare different embedding pooling strategies to understand their impact on variant effect prediction.

In [None]:
# Compare pooling strategies
pooling_strategies = ["mean", "max", "cls"]
comparison_results = []

print("🔬 Comparing pooling strategies...")

# Use first variant as example
test_variant = example_variants[0]
ref_seq, alt_seq = create_variant_sequences(test_variant["ref"], test_variant["alt"])

for strategy in pooling_strategies:
    print(f"  📊 Testing {strategy} pooling...")
    
    # Extract embeddings with different pooling
    ref_embedding = extract_sequence_embedding(
        ref_seq, model, tokenizer, config.DEVICE, strategy
    )
    alt_embedding = extract_sequence_embedding(
        alt_seq, model, tokenizer, config.DEVICE, strategy
    )
    
    # Calculate cosine distance
    cosine_score = cosine(ref_embedding, alt_embedding)
    
    # Interpret effect
    interpretation = interpret_effect_score(cosine_score, "cosine")
    
    comparison_results.append({
        "pooling_strategy": strategy,
        "cosine_score": cosine_score,
        "predicted_functional": interpretation["functional"],
        "impact_level": interpretation["impact"],
        "embedding_shape": ref_embedding.shape
    })

comparison_df = pd.DataFrame(comparison_results)

print(f"\n📊 Pooling Strategy Comparison:")
print(comparison_df.to_string(index=False))

# Visualize pooling comparison
plt.figure(figsize=(10, 6))
plt.bar(comparison_df['pooling_strategy'], comparison_df['cosine_score'], 
        color=['skyblue', 'lightcoral', 'lightgreen'])
plt.xlabel('Pooling Strategy')
plt.ylabel('Cosine Distance Score')
plt.title('Variant Effect Scores by Pooling Strategy')  # emoji omitted: default matplotlib fonts lack these glyphs
plt.axhline(y=config.EFFECT_THRESHOLD, color='red', linestyle='--', 
            label=f'Functional threshold ({config.EFFECT_THRESHOLD})')
plt.legend()
plt.show()

print("✅ Pooling strategy comparison complete!")
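
One reason the pooling choice changes the score is that mean pooling averages over every position, so a change confined to a few token states shrinks as the sequence gets longer. The toy sketch below assumes, unrealistically, that a variant alters only the hidden state of its own token; in a real transformer, attention spreads the change to other positions as well. `mean_pool_shift` and `delta` are illustrative names, and the random states are stand-ins rather than model output.

```python
import numpy as np

def mean_pool_shift(seq_len: int, hidden_dim: int = 8, seed: int = 0) -> float:
    """Norm of the change in the mean-pooled vector when one token state moves."""
    rng = np.random.default_rng(seed)
    ref_states = rng.normal(size=(seq_len, hidden_dim))  # stand-in hidden states
    alt_states = ref_states.copy()
    delta = np.ones(hidden_dim)                          # hypothetical per-token change
    alt_states[seq_len // 2] += delta                    # perturb only the central token
    pooled_diff = alt_states.mean(axis=0) - ref_states.mean(axis=0)
    return float(np.linalg.norm(pooled_diff))

for seq_len in (10, 100, 1000):
    print(seq_len, mean_pool_shift(seq_len))
```

Under this assumption the shift equals the norm of `delta` divided by the sequence length, which is one way to reason about why long contexts can yield very small mean-pooled distances.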

## 7. Distance Metric Analysis

Analyze how the distance metrics relate to one another on the example variants.

In [None]:
# Analyze distance metric correlations
print("📊 Analyzing distance metric relationships...")

# Extract scores for all metrics
cosine_scores = results_df['cosine_score'].values
euclidean_scores = results_df['euclidean_score'].values
manhattan_scores = results_df['manhattan_score'].values

# Create correlation matrix
metric_data = pd.DataFrame({
    'Cosine': cosine_scores,
    'Euclidean': euclidean_scores,
    'Manhattan': manhattan_scores
})

correlation_matrix = metric_data.corr()

# Visualize correlations
plt.figure(figsize=(12, 5))

# Correlation heatmap
plt.subplot(1, 2, 1)
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', center=0,
            square=True, fmt='.3f')
plt.title('Distance Metric Correlations')  # emoji omitted: default matplotlib fonts lack these glyphs

# Score distributions
plt.subplot(1, 2, 2)
for metric in ['Cosine', 'Euclidean', 'Manhattan']:
    plt.hist(metric_data[metric], alpha=0.6, label=metric, bins=15)
plt.xlabel('Effect Score')
plt.ylabel('Frequency')
plt.title('Score Distributions by Metric')  # emoji omitted: default matplotlib fonts lack these glyphs
plt.legend()

plt.tight_layout()
plt.show()

print(f"\n🔗 Metric Correlations:")
print(correlation_matrix.round(3))

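# Compare mean scores between predicted-functional and predicted-neutral groups.
# The groups are defined by the cosine threshold, so this is a consistency check
# rather than independent validation, and the raw separations are on each
# metric's own scale.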
functional_variants = results_df[results_df['predicted_functional']]
neutral_variants = results_df[~results_df['predicted_functional']]

print(f"\n🎯 Discriminative Power Analysis:")
for metric in ['cosine_score', 'euclidean_score', 'manhattan_score']:
    if len(functional_variants) > 0 and len(neutral_variants) > 0:
        func_mean = functional_variants[metric].mean()
        neut_mean = neutral_variants[metric].mean()
        separation = abs(func_mean - neut_mean)
        print(f"  {metric.split('_')[0].title()}: Separation = {separation:.4f}")
    else:
        print(f"  {metric.split('_')[0].title()}: Insufficient data for comparison")
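
Raw separations are hard to compare across metrics because Euclidean and Manhattan distances live on a different scale than cosine distance. One option, sketched here as an assumption rather than the tutorial's method, is to divide the difference in means by the pooled standard deviation. The numbers are made up for illustration: the second "metric" is simply the first multiplied by 100.

```python
from statistics import mean, stdev

def standardized_separation(group_a, group_b):
    """|mean difference| divided by the pooled standard deviation.

    Each group needs at least two values for a sample standard deviation.
    """
    n_a, n_b = len(group_a), len(group_b)
    pooled_var = (
        (n_a - 1) * stdev(group_a) ** 2 + (n_b - 1) * stdev(group_b) ** 2
    ) / (n_a + n_b - 2)
    return abs(mean(group_a) - mean(group_b)) / pooled_var ** 0.5

# Illustrative numbers only (not model output)
metric_a_func, metric_a_neut = [0.30, 0.40, 0.50], [0.05, 0.10, 0.15]
metric_b_func = [x * 100 for x in metric_a_func]  # same data, rescaled
metric_b_neut = [x * 100 for x in metric_a_neut]

raw_a = abs(mean(metric_a_func) - mean(metric_a_neut))
raw_b = abs(mean(metric_b_func) - mean(metric_b_neut))
std_a = standardized_separation(metric_a_func, metric_a_neut)
std_b = standardized_separation(metric_b_func, metric_b_neut)
print(f"raw: {raw_a:.4f} vs {raw_b:.4f}")
print(f"standardized: {std_a:.4f} vs {std_b:.4f}")
```

The raw separations differ by a factor of 100 even though the two metrics carry identical information, while the standardized values agree.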

## 8. Summary and Best Practices

### Key Findings

The example above uses synthetic sequences and has no ground-truth labels, so the points below are working guidelines for variant effect prediction with genomic foundation models rather than validated results:

#### 🧬 **Embedding Extraction**
- **Mean pooling** generally provides robust sequence representations
- **Layer selection** matters: this tutorial uses the last hidden layer (`LAYER_INDEX = -1`), and other layers can be compared through `layer_index`
- **Sequence length** should include sufficient context around variants

#### 📊 **Distance Metrics**
- **Cosine distance** is effective for comparing sequence semantics
- **Euclidean distance** captures magnitude differences
- **Multiple metrics** provide complementary information

#### 🎯 **Effect Scoring**
- Threshold-based classification gives a simple binary functional call, but the cutoff (0.1 here) needs calibration on variants with known effects
- Score interpretation should consider biological context
- Validation with experimental data is crucial
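
As a sketch of what calibrating the threshold could look like, the example below picks the cutoff that maximizes Youden's J (true positive rate minus false positive rate) on a labelled set. This is one common choice, not something the tutorial prescribes. `calibrate_threshold` is a hypothetical helper, and the scores and labels are made up for illustration.

```python
def calibrate_threshold(scores, labels):
    """Return (cutoff, J) maximizing Youden's J = TPR - FPR.

    A variant is called functional when score >= cutoff.
    Labels: 1 = known functional, 0 = known neutral.
    """
    positives = sum(labels)
    negatives = len(labels) - positives
    if positives == 0 or negatives == 0:
        raise ValueError("Need at least one functional and one neutral variant")
    best_cutoff, best_j = None, -1.0
    for cutoff in sorted(set(scores)):
        tp = sum(1 for s, y in zip(scores, labels) if s >= cutoff and y == 1)
        fp = sum(1 for s, y in zip(scores, labels) if s >= cutoff and y == 0)
        j = tp / positives - fp / negatives
        if j > best_j:  # ties keep the lower cutoff
            best_cutoff, best_j = cutoff, j
    return best_cutoff, best_j

# Made-up scores and labels for illustration
scores = [0.01, 0.02, 0.04, 0.08, 0.12, 0.20, 0.35]
labels = [0, 0, 0, 1, 0, 1, 1]

cutoff, j_stat = calibrate_threshold(scores, labels)
print(f"cutoff={cutoff}, J={j_stat:.2f}")
```

In practice the labelled set would come from variants with experimentally established effects, and the chosen cutoff should be checked on held-out variants.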

### Best Practices

1. **🔧 Technical Recommendations**
   - Use mean pooling for stable embeddings
   - Include adequate sequence context (≥200bp)
   - Batch process for efficiency with large datasets
   - Validate thresholds on known functional variants

2. **🧬 Biological Considerations**
   - Consider variant type (SNV vs INDEL) in interpretation
   - Account for genomic context (coding vs non-coding)
   - Integrate with other genomic features when possible
   - Validate predictions with functional assays

3. **📈 Performance Optimization**
   - Use appropriate batch sizes for GPU memory
   - Cache embeddings for repeated analyses
   - Consider model size vs accuracy tradeoffs
   - Profile different pooling strategies for your data
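
Caching pays off because the same reference window is often scored against many alternative alleles. A minimal sketch of a content-keyed cache follows. `EmbeddingCache` is a hypothetical helper, `toy_embed` is a stand-in for a real model call, and normalizing the key to upper case assumes the model treats upper- and lower-case bases identically.

```python
import hashlib
from typing import Callable, Dict, List

class EmbeddingCache:
    """Memoize embeddings by sequence content (hypothetical helper)."""

    def __init__(self, embed_fn: Callable[[str], List[float]]):
        self._embed_fn = embed_fn
        self._store: Dict[str, List[float]] = {}
        self.misses = 0  # number of times the underlying function was called

    def get(self, sequence: str) -> List[float]:
        # Hash the sequence so long inputs do not become long dictionary keys
        key = hashlib.sha256(sequence.upper().encode("ascii")).hexdigest()
        if key not in self._store:
            self.misses += 1
            self._store[key] = self._embed_fn(sequence)
        return self._store[key]

def toy_embed(sequence: str) -> List[float]:
    """Stand-in for a model call: base composition as a 4-dim vector."""
    seq = sequence.upper()
    return [seq.count(base) / len(seq) for base in "ACGT"]

cache = EmbeddingCache(toy_embed)
first = cache.get("ACGTACGT")
second = cache.get("acgtacgt")  # same bases in lower case: served from cache
third = cache.get("ACGTACGA")   # new sequence: computed
print(cache.misses)
```

If the pooling strategy or layer can change between runs, those settings would need to be part of the key as well.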

In [None]:
# Final summary statistics
print("🎉 VEP Embedding and Scoring Analysis Complete!")
print("=" * 60)

print(f"\n📊 Analysis Summary:")
print(f"  🧬 Variants analyzed: {len(results_df)}")
print(f"  ⚡ Functional predictions: {results_df['predicted_functional'].sum()}")
print(f"  📈 Average effect scores:")
for metric in ['cosine_score', 'euclidean_score', 'manhattan_score']:
    mean_score = results_df[metric].mean()
    std_score = results_df[metric].std()
    print(f"    {metric.split('_')[0].title()}: {mean_score:.4f} ± {std_score:.4f}")

print(f"\n🔧 Technical Configuration:")
print(f"  🧠 Model: {config.MODEL_NAME}")
print(f"  📊 Pooling: {config.POOLING_STRATEGY}")
print(f"  📏 Max length: {config.MAX_LENGTH}")
print(f"  🎯 Threshold: {config.EFFECT_THRESHOLD}")

print(f"\n💡 Next Steps:")
print(f"  🔬 Validate predictions with experimental data")
print(f"  📊 Scale analysis to larger variant datasets")
print(f"  🧬 Integrate with other genomic features")
print(f"  🏥 Apply to clinical variant interpretation")

print(f"\n✅ Tutorial complete! Ready for variant effect prediction.")