In [None]:
from IPython.display import display, Image

# Overview figure of the 4-step workflow; the path is relative to the kernel's working directory.
display(Image(filename="4-step workflow.png"))


**Variant Effect Prediction** uses embedding comparison to assess functional impacts without training. We leverage `OmniModelForSequenceClassification` with **PlantRNA-FM** to extract plant-optimized embeddings and compare reference vs. alternative sequences.

### 4. Tutorial Structure

1. **[Data Preparation](01_vep_data_preparation.ipynb)**: Load variant datasets and reference genome
2. **[Model Setup](02_vep_model_setup.ipynb)**: Initialize PlantRNA-FM for plant genomic analysis
3. **[Embedding Extraction](03_embedding_and_scoring.ipynb)**: Compare reference and alternative sequences using PlantRNA-FM
4. **[Visualization](04_visualization_and_export.ipynb)**: Analyze and export results

Let's begin!


## 🚀 Step 1: Environment Setup and Configuration


In [None]:
# Install/upgrade OmniGenBench into the environment of the running kernel (%pip, not !pip).
# NOTE(review): `-U` always pulls the newest release; pin a version (omnigenbench==X.Y.Z) for reproducible runs.
%pip install omnigenbench -U


In [None]:
from omnigenbench import (
    OmniTokenizer,
    OmniModelForSequenceClassification,
    OmniDatasetForSequenceClassification
)


### Configuration

Define analysis parameters with sensible defaults:


In [None]:
# ---- Analysis configuration ----

# Hub identifiers for the variant dataset and the foundation model.
dataset_name = "yangheng/variant_effect_prediction"
model_name = "yangheng/PlantRNA-FM"  # plant-focused RNA foundation model

# Tokenization and batching.
max_length = 512  # passed as `max_length` to the dataset loader
batch_size = 16   # sequences per dataloader batch

# Variant window / sampling.
context_length = 200  # NOTE(review): not referenced by any other cell in this notebook
max_variants = 100    # subset size for quick testing

# Local directories.
cache_dir = "vep_data"      # dataset cache location passed to the loader
output_dir = "vep_results"  # NOTE(review): not referenced by any other cell in this notebook


## 📊 Step 2: Data Loading

Load the tokenizer and the variant dataset from the hub with OmniGenBench; downloaded data is cached under `cache_dir`:


In [None]:
# Tokenizer comes from the same checkpoint as the model configured above.
tokenizer = OmniTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)

# Fetch the dataset from the hub (cached under `cache_dir`) and tokenize with the tokenizer above.
# The returned object is indexed by split name (the 'test' split is used further down).
datasets = OmniDatasetForSequenceClassification.from_hub(
    dataset_name_or_path=dataset_name,
    tokenizer=tokenizer,
    cache_dir=cache_dir,
    max_length=max_length,
)


## 🔧 Step 3: Model Initialization


In [None]:
# Load the checkpoint named by `model_name` wrapped with a 2-label classification head.
# Only the backbone's hidden states are consumed later; the head is never trained here.
model = OmniModelForSequenceClassification.from_pretrained(
    model_name,
    trust_remote_code=True,
    num_labels=2,
)
# Switch to evaluation mode (presumably a torch nn.Module, so dropout is disabled).
model.eval()


## 🧬 Step 4: Variant Effect Scoring

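Extract first-token embeddings for every test sequence. Pairwise reference-vs-alternative effect scores are computed with `vep_compute_distances` in the appendix below: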


In [None]:
import torch
from tqdm.auto import tqdm


def first_token_embeddings(outputs):
    """Return last-layer embeddings at token index 0 as a (batch, dim) tensor.

    Two output layouts are accepted:
    * an object with a populated ``hidden_states`` sequence (Hugging Face
      ``ModelOutput`` style; it is ``None`` unless hidden states were requested);
    * a mapping holding a ``last_hidden_state`` entry.

    Raises:
        ValueError: if the output exposes neither layout.
    """
    hidden_states = getattr(outputs, "hidden_states", None)
    if hidden_states is not None:
        last_layer = hidden_states[-1]
    elif isinstance(outputs, dict) and outputs.get("last_hidden_state") is not None:
        last_layer = outputs["last_hidden_state"]
    else:
        raise ValueError(
            "Model output exposes neither `hidden_states` nor `last_hidden_state`; "
            "cannot extract embeddings."
        )
    # Index 0 is the CLS token when the tokenizer adds special tokens (assumed here).
    return last_layer[:, 0, :]


# Fixed order (shuffle=False) so row k of `all_embeddings` matches test example k.
test_dataset = datasets['test']
dataloader = test_dataset.get_dataloader(batch_size=batch_size, shuffle=False)

# Input tensors must live on the same device as the model weights.
try:
    device = next(model.parameters()).device
except (AttributeError, StopIteration):
    device = torch.device("cpu")

results = []
n_collected = 0
with torch.no_grad():
    for batch in tqdm(dataloader, desc="Processing variants"):
        batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
        outputs = model(**batch)
        # Move to CPU so the accumulated embeddings do not hold accelerator memory.
        embeddings = first_token_embeddings(outputs).detach().cpu()
        results.append(embeddings)
        n_collected += embeddings.shape[0]
        # Honour the quick-test cap from the configuration cell (set it to None to score everything).
        if max_variants is not None and n_collected >= max_variants:
            break

if not results:
    raise RuntimeError("The test split produced no batches; nothing to embed.")
all_embeddings = torch.cat(results, dim=0)
if max_variants is not None:
    all_embeddings = all_embeddings[:max_variants]


## 📈 Results and Visualization

For detailed analysis and visualization, see the complete tutorial notebooks:
- **[01_vep_data_preparation.ipynb](01_vep_data_preparation.ipynb)**
- **[02_vep_model_setup.ipynb](02_vep_model_setup.ipynb)**
- **[03_embedding_and_scoring.ipynb](03_embedding_and_scoring.ipynb)**
- **[04_visualization_and_export.ipynb](04_visualization_and_export.ipynb)**


# Appendix: Aligning with run_vep_mlm pipeline

This appendix shows how to reproduce the script-style variant effect scoring inside the notebook, keeping it aligned with `dev/benchpaper/vep/run_vep_mlm.py`:

- The input consists of paired sequences (`ref_seq`, `alt_seq`) and a `mutation_position` index within the context window.
- Tokenization uses `add_special_tokens=False` to keep token positions aligned with sequence positions.
- Distances reported:
  - `cls_dist`: distance at token index 0 (first-token embedding, matching the script's semantics). Because no special tokens are added, this is the first sequence token rather than a dedicated CLS token.
  - `mut_dist`: distance at the provided `mutation_position`.
- If a `label` column exists with at least two classes, ROC-AUC can be computed on `mut_dist` (see the commented example at the end of the code cell).

Note: For human reference genome (hg38) fetching and BED-to-sequence assembly, see the script for a CLI-first workflow. Here we focus on the scoring core.
Reference listing of the scoring core (the runnable version is the code cell that follows):

```python
import torch, numpy as np, pandas as pd
from scipy.spatial import distance
from sklearn.metrics import roc_auc_score
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM

@torch.inference_mode()
def vep_compute_distances(
    model, tokenizer,
    ref_seqs, alt_seqs, mut_positions,
    max_length: int, batch_size: int = 16, device: str = None
):
    """
    Compute first-token (index 0) and mutation-position embedding distances for ref/alt pairs.

    Returns: DataFrame with columns: cls_dist, mut_dist
    """
    if device is None:
        device = model.device
    results = []

    def encode_batch(seqs):
        toks = tokenizer(
            [' '.join(s) for s in seqs],
            return_tensors='pt',
            padding='max_length', max_length=max_length, truncation=True,
            add_special_tokens=False
        )
        toks = {k: v.to(device) for k, v in toks.items()}
        out = model(**toks, output_hidden_states=True)
        return out.hidden_states[-1].detach().cpu().numpy()  # (B, L, D)

    for i in range(0, len(ref_seqs), batch_size):
        r = encode_batch(ref_seqs[i:i+batch_size])
        a = encode_batch(alt_seqs[i:i+batch_size])
        for j, pos in enumerate(mut_positions[i:i+batch_size]):
            pos = int(pos) if 0 <= int(pos) < r.shape[1] else 0
            # L2-normalize then cosine distance (matches script style)
            def norm(x):
                n = np.linalg.norm(x)
                return x / (n + 1e-8)
            cls_dist = distance.cosine(norm(a[j, 0, :]), norm(r[j, 0, :]))
            mut_dist = distance.cosine(norm(a[j, pos, :]), norm(r[j, pos, :]))
            results.append((cls_dist, mut_dist))

    return pd.DataFrame(results, columns=['cls_dist', 'mut_dist'])

# Example (pseudocode):
# model_name = 'yangheng/OmniGenome-52M'
# tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# mdl = AutoModel.from_pretrained(model_name, trust_remote_code=True).eval().to('cuda')
# df must contain: ref_seq, alt_seq, mutation_position, and optional label
# scores = vep_compute_distances(mdl, tok, df.ref_seq.tolist(), df.alt_seq.tolist(), df.mutation_position.tolist(), max_length=400)
# if 'label' in df and df['label'].nunique() > 1:
#     print('AUC:', roc_auc_score(df['label'], scores['mut_dist']))
```

In [None]:
import torch, numpy as np, pandas as pd
from scipy.spatial import distance
from sklearn.metrics import roc_auc_score
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM

@torch.inference_mode()
def vep_compute_distances(
    model, tokenizer,
    ref_seqs, alt_seqs, mut_positions,
    max_length: int, batch_size: int = 16, device: "str | torch.device | None" = None
):
    """
    Compute first-token (index 0) and mutation-position embedding distances for ref/alt pairs.

    Each sequence is split into space-separated characters and tokenized with
    ``add_special_tokens=False``. Token index ``k`` is therefore meant to line up with
    sequence position ``k`` (this assumes one token per character; confirm for the tokenizer in use).
    Because no special tokens are added, index 0 is the first sequence token, not a CLS token.

    Args:
        model: Callable invoked as ``model(**tokens, output_hidden_states=True)``.
            Its result's ``hidden_states[-1]`` is used as a (batch, length, dim) array.
        tokenizer: Hugging Face-style tokenizer called with ``return_tensors='pt'``.
        ref_seqs, alt_seqs: Reference / alternative sequences, paired by position.
        mut_positions: One 0-based token index per pair. Indices outside the encoded
            length fall back to 0 (with a warning), so ``mut_dist`` then equals ``cls_dist``.
        max_length: Padding/truncation length passed to the tokenizer.
        batch_size: Number of pairs encoded per forward pass.
        device: Device for the input tensors; defaults to ``model.device``.

    Returns:
        DataFrame with one row per pair and columns ``cls_dist`` and ``mut_dist``
        (SciPy cosine distances between alt and ref embeddings).

    Raises:
        ValueError: if ``ref_seqs``, ``alt_seqs`` and ``mut_positions`` differ in length.
    """
    import warnings

    n_pairs = len(ref_seqs)
    if len(alt_seqs) != n_pairs or len(mut_positions) != n_pairs:
        raise ValueError(
            "ref_seqs, alt_seqs and mut_positions must have equal lengths; got "
            f"{n_pairs}, {len(alt_seqs)} and {len(mut_positions)}"
        )
    if device is None:
        device = model.device

    def encode_batch(seqs):
        toks = tokenizer(
            [' '.join(s) for s in seqs],
            return_tensors='pt',
            padding='max_length', max_length=max_length, truncation=True,
            add_special_tokens=False
        )
        toks = {k: v.to(device) for k, v in toks.items()}
        out = model(**toks, output_hidden_states=True)
        return out.hidden_states[-1].detach().cpu().numpy()  # (B, L, D)

    def norm(x):
        # L2-normalize (epsilon guards zero vectors). Cosine distance is scale-invariant,
        # so this only mirrors the reference script's preprocessing.
        return x / (np.linalg.norm(x) + 1e-8)

    results = []
    n_fallback = 0
    for i in range(0, n_pairs, batch_size):
        r = encode_batch(ref_seqs[i:i + batch_size])
        a = encode_batch(alt_seqs[i:i + batch_size])
        for j, raw_pos in enumerate(mut_positions[i:i + batch_size]):
            pos = int(raw_pos)
            if not 0 <= pos < r.shape[1]:
                # Out-of-window position: keep the script's fallback to index 0 but count it.
                pos = 0
                n_fallback += 1
            cls_dist = distance.cosine(norm(a[j, 0, :]), norm(r[j, 0, :]))
            mut_dist = distance.cosine(norm(a[j, pos, :]), norm(r[j, pos, :]))
            results.append((cls_dist, mut_dist))

    if n_fallback:
        warnings.warn(
            f"{n_fallback} mutation position(s) fell outside the encoded sequence length; "
            "mut_dist was computed at index 0 for those pairs (equal to cls_dist)."
        )
    return pd.DataFrame(results, columns=['cls_dist', 'mut_dist'])

# Example (pseudocode):
# model_name = 'yangheng/OmniGenome-52M'
# tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# mdl = AutoModel.from_pretrained(model_name, trust_remote_code=True).eval().to('cuda')
# df must contain columns: ref_seq, alt_seq, mutation_position, and (optionally) label
# scores = vep_compute_distances(mdl, tok, df.ref_seq.tolist(), df.alt_seq.tolist(), df.mutation_position.tolist(), max_length=400)
# if 'label' in df and df['label'].nunique() > 1:
#     print('AUC:', roc_auc_score(df['label'], scores['mut_dist']))