# 🧬 mRNA Degradation Rate Prediction with OmniGenBench

Welcome to this tutorial on predicting **mRNA degradation rates** from RNA sequences using **OmniGenBench**. It walks through a complete genomic deep learning project, from the underlying biology to running inference with a fine-tuned model.

### 1. The Biological Challenge: What is mRNA Degradation?

**mRNA degradation** is a crucial regulatory mechanism in gene expression that controls the stability and lifespan of mRNA molecules in cells. The rate at which mRNA degrades directly impacts:

- **Protein production levels**: Stable mRNAs produce more proteins over time
- **Gene expression dynamics**: Rapid degradation enables quick response to cellular signals  
- **Cellular homeostasis**: Proper mRNA turnover maintains balanced protein levels
- **Disease mechanisms**: Dysregulated mRNA stability contributes to various disorders

Understanding and predicting mRNA degradation rates has profound implications across multiple domains:
- **Therapeutic Design**: Engineering mRNA-based therapeutics (like COVID-19 vaccines) with optimal stability
- **Synthetic Biology**: Designing gene circuits with precise temporal control
- **Disease Research**: Understanding how mutations affect mRNA stability in genetic disorders
- **Biotechnology**: Optimizing protein production in industrial applications

However, experimentally measuring degradation rates across thousands of mRNA sequences is time-consuming and costly. Computational methods, particularly deep learning with Genomic Foundation Models, can instead predict these rates directly from sequence.

### 2. The Data: mRNA Degradation Dataset

To train our predictive model, we utilize a carefully curated dataset containing mRNA sequences with experimentally determined degradation parameters.

- **What it contains**: mRNA sequences with per-nucleotide degradation measurements taken under different conditions
- **What it labels**: Each sequence has three arrays of continuous values, one value per measured nucleotide position:
  - `reactivity`: General degradation reactivity
  - `deg_Mg_pH10`: Degradation under Mg²⁺ at pH 10
  - `deg_Mg_50C`: Degradation under Mg²⁺ at 50 °C
- **Our Goal**: Train a model that can accurately predict degradation rates for each nucleotide position in mRNA sequences

**Dataset Structure:**

| sequence | reactivity | deg_Mg_pH10 | deg_Mg_50C |
|---------|------------|-------------|------------|
| AUGCCAU... | [0.1, 0.2, ...] | [0.15, 0.25, ...] | [0.3, 0.4, ...] |
| AUGCUA... | [0.05, 0.1, ...] | [0.2, 0.3, ...] | [0.25, 0.35, ...] |
| ... | ... | ... | ... |
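
As a rough illustration of the table above, a single record can be pictured as a Python dictionary in which each label column holds one value per measured nucleotide. The numbers below are invented, and `check_record` is a hypothetical helper rather than part of OmniGenBench. The sketch only assumes that the three label arrays share a length and are no longer than the sequence.

```python
# Toy record mirroring the table columns; all numbers are invented for illustration.
record = {
    "sequence": "AUGCCAU",
    "reactivity": [0.10, 0.20, 0.05, 0.30, 0.12, 0.08, 0.22],
    "deg_Mg_pH10": [0.15, 0.25, 0.10, 0.35, 0.20, 0.11, 0.27],
    "deg_Mg_50C": [0.30, 0.40, 0.20, 0.45, 0.33, 0.18, 0.38],
}

TARGET_COLUMNS = ["reactivity", "deg_Mg_pH10", "deg_Mg_50C"]


def check_record(rec, target_columns=TARGET_COLUMNS):
    """Return True if every target array has the same length and fits the sequence."""
    lengths = {len(rec[col]) for col in target_columns}
    if len(lengths) != 1:
        return False  # target arrays disagree in length
    return lengths.pop() <= len(rec["sequence"])


print(check_record(record))
```

A check like this is cheap to run on a custom dataset before handing it to a token-level regression pipeline.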

### 3. The Tool: From Language Models to Genomic Foundation Models

#### The Rise of Language Models
In recent years, **Language Models (LMs)** like BERT have revolutionized Natural Language Processing (NLP). Trained on vast amounts of text, they learn the underlying patterns of language—grammar, context, and even semantics. This allows them to be "fine-tuned" for a wide range of specific tasks.

#### A New Paradigm in Genomics: Genomic Foundation Models (GFMs)
The same principles can be applied to biology. The "language of life" is written in RNA and DNA sequences using nucleotides (A, C, G, U/T). **Genomic Foundation Models (GFMs)**, like **OmniGenome** (Yang et al., 2025), are large-scale models pre-trained on massive amounts of genomic sequences.

In this tutorial, we will follow a standard 4-step fine-tuning pipeline for **token-level regression**, where we predict continuous values for each nucleotide position in the sequence.

### 4. The Workflow: A 4-Step Guide to Fine-Tuning

The four steps and how they connect are shown below.

```mermaid
flowchart TD
    subgraph "4-Step Workflow for mRNA Degradation Prediction"
        A["📥 Step 1: Data Preparation<br/>Download and process the mRNA degradation dataset"] --> B["🔧 Step 2: Model Initialization<br/>Load the pre-trained OmniGenome model"]
        B --> C["🎓 Step 3: Model Training<br/>Fine-tune the model on the degradation dataset"]
        C --> D["🔮 Step 4: Model Inference<br/>Use the trained model to predict degradation rates"]
    end

    style A fill:#e1f5fe,stroke:#333,stroke-width:2px
    style B fill:#f3e5f5,stroke:#333,stroke-width:2px
    style C fill:#e8f5e8,stroke:#333,stroke-width:2px
    style D fill:#fff3e0,stroke:#333,stroke-width:2px
```

In short, the steps are:
1. **Data Preparation**: Download and preprocess the mRNA degradation dataset
2. **Model Initialization**: Load the pre-trained OmniGenome model and set it up for token regression
3. **Model Training**: Fine-tune the model using our dataset and validate its performance
4. **Model Inference**: Use the trained model to predict degradation rates on new sequences

Let's get started!

## 🚀 Step 1: Data Preparation

This first step is all about getting our data ready for analysis. It involves four key parts:
1. **Environment Setup**: Installing and importing the necessary libraries
2. **Configuration**: Defining all our important parameters in one place
3. **Data Acquisition**: Loading and preparing the mRNA degradation dataset
4. **Data Pipeline**: Creating an efficient pipeline to feed data to the model

### 1.1: Environment Setup

First, let's install the required Python packages. `omnigenbench` is our core library that provides state-of-the-art genomic foundation models and training utilities.

!pip install omnigenbench -U  # Install the latest version of omnigenbench

Next, we import the libraries we just installed. This gives us the tools for data processing, deep learning, and token-level regression modeling.

A key part of this setup is determining the best available hardware for training. The setup prioritizes a **CUDA-enabled GPU** if one is available, which can accelerate training substantially compared to a CPU (roughly 10-100x, depending on the GPU, batch size, and sequence length).
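
The hardware check described above can be sketched in a few lines. This is a minimal illustration built on `torch.cuda.is_available()`. `pick_device` is a hypothetical helper, and the trainer used later may handle device placement on its own.

```python
import importlib.util


def pick_device():
    """Return "cuda" when PyTorch sees a CUDA GPU, otherwise "cpu"."""
    # Fall back to CPU if PyTorch itself is not installed.
    if importlib.util.find_spec("torch") is None:
        return "cpu"
    import torch

    return "cuda" if torch.cuda.is_available() else "cpu"


device = pick_device()
print(f"Using device: {device}")
```

Running this before training is a quick way to confirm that the GPU is visible to PyTorch.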

### 1.2: Import Required Libraries

In [None]:
import os
import torch
import numpy as np

from omnigenbench import (
    RegressionMetric,
    AccelerateTrainer,
    ModelHub,
    OmniTokenizer,
    OmniDatasetForTokenRegression,
    OmniModelForTokenRegression,
)

### 1.3: Global Configuration

To make our tutorial easy to modify and understand, we'll centralize all important parameters in this section. This is a best practice in software development that makes experiments more reproducible.

#### Key Parameters
- **Dataset**: We define the dataset name for automatic downloading from our curated collection
- **Model**: We select which pre-trained OmniGenome model to use. For this tutorial, we'll use `OmniGenome-52M` because it is small enough to fine-tune quickly, which makes it well suited to learning the workflow

This centralized approach allows you to easily experiment with different settings without hunting through the code.

In [None]:
model_name_or_path = "yangheng/OmniGenome-52M"
dataset_name = "mrna_degradation_rgb"  # same identifier that is passed to from_hub() in Section 1.4

### 1.4: Data Acquisition and Loading

With our environment configured, it's time to load the mRNA degradation dataset. The enhanced OmniDataset framework automates this process by:
1. **Automatic downloading** from our curated dataset collection
2. **Processing sequences** with proper tokenization for token-level regression
3. **Handling multi-target labels** for the three degradation measurements
4. **Creating efficient data pipelines** ready for training

This ensures we have properly formatted data with train/validation/test splits ready for the next stage.

In [None]:
# Tokenizer initialization (the model itself is loaded in Step 2)
tokenizer = OmniTokenizer.from_pretrained(model_name_or_path)
print(f"✅ Tokenizer loaded: {model_name_or_path}")

# Load datasets using the enhanced OmniDataset framework for token regression
print("🏗️ Loading datasets with automatic download...")
datasets = OmniDatasetForTokenRegression.from_hub(
    dataset_name_or_path="mrna_degradation_rgb",
    tokenizer=tokenizer,
    max_length=128,
    target_columns=["reactivity", "deg_Mg_pH10", "deg_Mg_50C"],  # Three regression targets
)
print(f"📊 Datasets loaded: {list(datasets.keys())}")
for split, dataset in datasets.items():
    print(f"  - {split}: {len(dataset)} samples")

### 1.5: Dataset Loading with OmniGenBench

OmniGenBench handles most of the data loading work for token regression. The two subsections below summarize what happens inside the `from_hub` call from Section 1.4.

#### A. Automatic Data Processing
The `OmniDatasetForTokenRegression` class automatically:
1. **Downloads and processes** the dataset from our curated collection
2. **Handles sequence preprocessing** including tokenization and proper padding
3. **Manages multi-target regression formatting** for position-wise degradation prediction  
4. **Creates train/validation/test splits** ready for training

#### B. Built-in Optimizations
The framework includes several optimizations:
1. **Efficient batching** for variable-length sequences
2. **Memory management** for large genomic datasets
3. **Automatic label alignment** with tokenized sequences
4. **Proper masking** for padded positions using -100 labels

This streamlined approach eliminates the need for complex custom dataset classes while maintaining full flexibility and performance.
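
To see what label alignment and `-100` masking amount to, here is a minimal sketch. It assumes one token per nucleotide and one special token at each end of the sequence, which may not match the actual tokenizer. `align_labels` is a hypothetical helper, not the OmniGenBench implementation.

```python
IGNORE_VALUE = -100.0  # marks positions that should not contribute to the loss or metrics


def align_labels(labels, max_length, num_special_start=1, num_special_end=1):
    """Pad per-nucleotide labels to max_length, masking special tokens and padding."""
    room = max_length - num_special_start - num_special_end
    kept = list(labels[:room])  # truncate labels that do not fit
    aligned = (
        [IGNORE_VALUE] * num_special_start
        + kept
        + [IGNORE_VALUE] * num_special_end
    )
    # Fill the remaining positions with the ignore value
    aligned += [IGNORE_VALUE] * (max_length - len(aligned))
    return aligned


reactivity = [0.1, 0.2, 0.3]
print(align_labels(reactivity, max_length=8))
```

Every position that does not correspond to a measured nucleotide ends up holding the ignore value, so the loss and the metrics can skip it.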

In [None]:
print("📝 Data loading completed! Using modern OmniDataset framework.")
print(f"📊 Loaded datasets: {list(datasets.keys())}")
for split, dataset in datasets.items():
    print(f"  - {split}: {len(dataset)} samples")
    
# Inspect a sample to understand the data structure
if len(datasets["train"]) > 0:
    sample = datasets["train"][0]
    print(f"\n🔍 Sample data structure:")
    for key, value in sample.items():
        if isinstance(value, torch.Tensor):
            print(f"  {key}: shape {value.shape}, dtype {value.dtype}")
        else:
            print(f"  {key}: {type(value)}")

## 🚀 Step 2: Model Initialization

With our data pipeline in place, it's time to set up the model. This is where the power of Genomic Foundation Models (GFMs) comes into play. Instead of building a model from scratch, we will load the pre-trained **OmniGenome** model and adapt it for token-level regression.

This process involves three key components:
1. **The Tokenizer**: We use the same tokenizer from data preparation that converts RNA sequences into numerical format
2. **The Base Model**: The core OmniGenome model that has learned fundamental genomic patterns from pretraining
3. **The Regression Head**: A neural network layer that maps sequence representations to continuous degradation values for each token position

The `OmniModelForTokenRegression` class handles this seamlessly, combining the base model with the appropriate regression head for our multi-target prediction task.

In [None]:
# === Model Initialization ===
# OmniGenBench can load genomic foundation models hosted on the Hugging Face Hub.

model = OmniModelForTokenRegression(
    model_name_or_path,
    tokenizer,
    num_labels=3,  # Three regression targets: reactivity, deg_Mg_pH10, deg_Mg_50C
)

print(f"✅ Model loaded: {model_name_or_path}")
print(f"📊 Model configuration:")
print(f"  - Architecture: Token-level regression")
print(f"  - Number of targets: 3 (reactivity, deg_Mg_pH10, deg_Mg_50C)")
print(f"  - Max sequence length: 128")
print(f"  - Model parameters: ~52M")
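
Conceptually, a token-level regression head applies the same linear map to every position's hidden vector, producing one value per target. The sketch below shows only that shape logic, in plain Python with toy numbers. The hidden size, the weights, and the `regression_head` name are illustrative assumptions, and the real head inside `OmniModelForTokenRegression` may be structured differently.

```python
def regression_head(hidden_states, weight, bias):
    """Map each token's hidden vector (length H) to len(bias) regression outputs."""
    outputs = []
    for h in hidden_states:  # one hidden vector per token position
        row = [
            sum(h_j * w_j for h_j, w_j in zip(h, w_vec)) + b
            for w_vec, b in zip(weight, bias)
        ]
        outputs.append(row)
    return outputs


# Toy example: 2 tokens, hidden size 2, 3 targets
hidden_states = [[1.0, 0.0], [0.0, 2.0]]
weight = [[0.5, 0.0], [0.0, 0.5], [1.0, 1.0]]  # one weight vector per target
bias = [0.0, 0.1, -1.0]

predictions = regression_head(hidden_states, weight, bias)
print(len(predictions), len(predictions[0]))  # number of tokens, number of targets
```

With `num_labels=3`, each token position yields three numbers, one per degradation target.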

## 🚀 Step 3: Model Training

This is the most exciting part! With our data and model ready, we can now begin the **fine-tuning** process. During training, the model will learn to associate specific patterns in RNA sequences with degradation rates at each nucleotide position.

### Our Training Strategy

The training setup combines task-appropriate metrics with several standard optimizations:

1. **Evaluation Metrics**: For token-level regression tasks, we use:
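   - **Root Mean Squared Error (RMSE)**: The square root of the mean squared error, expressed in the same units as the targets; lower is better
   - **R² Score**: The fraction of variance in the measured values that the model explains; higher is better, with 1.0 a perfect fit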
   - **OmniGenBench supports 60+ ML metrics** and customized metrics for different tasks

2. **Advanced Training Features**:
   - **Automatic mixed precision** for faster training and memory efficiency
   - **Gradient accumulation** for effective large batch training
   - **Learning rate scheduling** with warmup for stable convergence
   - **Early stopping** based on validation performance to prevent overfitting

The `AccelerateTrainer` from `omnigenbench` wraps all this logic into a simple interface, leveraging Hugging Face Accelerate for distributed training support.
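
Early stopping, one of the features listed above, can be illustrated independently of the trainer. The sketch below assumes a lower-is-better validation metric such as RMSE and a hypothetical `patience` setting. It is not how `AccelerateTrainer` implements the feature.

```python
def should_stop(val_history, patience=3):
    """Return True if the best value is at least `patience` epochs old."""
    if len(val_history) <= patience:
        return False  # not enough epochs to judge
    best_epoch = val_history.index(min(val_history))  # first epoch with the best value
    epochs_since_best = len(val_history) - 1 - best_epoch
    return epochs_since_best >= patience


# Toy validation RMSE per epoch (invented numbers)
history = [0.40, 0.35, 0.33, 0.34, 0.36, 0.35]
print(should_stop(history, patience=3))
```

In this toy history the best RMSE occurs at the third epoch and the following three epochs do not beat it, so training would halt.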

In [None]:
# Define evaluation metrics for token-level regression.
# ignore_y=-100 excludes positions labeled -100 (padding) from the scores.
metric_functions = [
    RegressionMetric(ignore_y=-100).root_mean_squared_error,
    RegressionMetric(ignore_y=-100).r2_score,
]

# Initialize the modern AccelerateTrainer
trainer = AccelerateTrainer(
    model=model,
    train_dataset=datasets["train"],
    eval_dataset=datasets["valid"], 
    test_dataset=datasets["test"],
    compute_metrics=metric_functions,
)

print("🎓 Starting training...")
print("⚡ Using AccelerateTrainer with automatic optimizations:")
print("  - Mixed precision training for speed and memory efficiency")
print("  - Automatic gradient accumulation")
print("  - Learning rate scheduling with warmup")
print("  - Early stopping based on validation metrics")

# Train the model
metrics = trainer.train()
trainer.save_model("ogb_mrna_degradation_finetuned")

print("✅ Training completed!")
print("📊 Final metrics:")
for metric_name, metric_value in metrics.items():
    if isinstance(metric_value, dict):
        print(f"  {metric_name}:")
        for k, v in metric_value.items():
            print(f"    {k}: {v:.4f}")
    else:
        print(f"  {metric_name}: {metric_value:.4f}")
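
To make the two metrics and the `-100` mask concrete, the following sketch computes RMSE and R² by hand while skipping ignored positions. `masked_rmse_r2` is a hypothetical helper written for this illustration, and the numbers are invented; `RegressionMetric` may differ in details such as how it averages across targets.

```python
import math


def masked_rmse_r2(y_true, y_pred, ignore_value=-100):
    """Compute RMSE and R² over positions whose label is not ignore_value."""
    pairs = [(t, p) for t, p in zip(y_true, y_pred) if t != ignore_value]
    if not pairs:
        raise ValueError("no valid positions to score")
    n = len(pairs)
    ss_res = sum((t - p) ** 2 for t, p in pairs)  # residual sum of squares
    mean_true = sum(t for t, _ in pairs) / n
    ss_tot = sum((t - mean_true) ** 2 for t, _ in pairs)  # total sum of squares
    rmse = math.sqrt(ss_res / n)
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")
    return rmse, r2


# The third position is padding and is excluded from both metrics
y_true = [0.1, 0.2, -100, 0.4]
y_pred = [0.1, 0.3, 0.9, 0.2]
rmse, r2 = masked_rmse_r2(y_true, y_pred)
print(f"RMSE: {rmse:.4f}, R2: {r2:.4f}")
```

Note that R² can be negative, as it is here, when the predictions fit worse than simply predicting the mean of the valid labels.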

## 🔮 Step 4: Model Inference and Interpretation

Now that we have a trained model, let's use it for its intended purpose: predicting mRNA degradation rates on new RNA sequences. This process is called **inference**.

### The Inference Pipeline

Our inference pipeline consists of several key steps:
1. **Load the Model**: We load a fine-tuned model with ModelHub, either a published checkpoint or the one saved during training
2. **Process Input**: We take new RNA sequences and apply the same preprocessing steps
3. **Run Prediction**: We feed the processed sequence to the model and get degradation predictions for each nucleotide
4. **Interpret Results**: We analyze the position-wise degradation rates and identify key patterns

To demonstrate, we'll test our model on sample sequences with different characteristics and analyze the predicted degradation patterns.

In [None]:
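# Load a fine-tuned model for inference. This identifier points to a checkpoint on the
# Hugging Face Hub; to use the model trained in Step 3 instead, pass the path given to
# trainer.save_model(), i.e. "ogb_mrna_degradation_finetuned" (assuming ModelHub.load
# accepts a local directory).
inference_model = ModelHub.load("yangheng/ogb_mrna_degradation_finetuned")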

# Test sequences with different characteristics
sample_sequences = {
    "Structured RNA": "AUGCCGUGCUAAUCGCGGUAGCGCUAGGCUGCAUCGCGGUAGCGCUAGGCUGCAU",
    "AU-rich sequence": "AUGUAUAUAUGUAUAUGUAUAUGUAUAUGUAUAUGUAUAUGUAUAUGUAUAU",
    "GC-rich sequence": "AUGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGC",
    "Random sequence": "AUGCAGUCCGAUUCGAGCUACGUCGAUGCUAGCUCGAUGGCAUCCGAUUCGAG",
}

with torch.no_grad():
    print("🔮 Running inference on sample sequences...\n")
    
    for seq_name, sequence in sample_sequences.items():
        print(f"📊 Analysis for {seq_name}:")
        print(f"  📏 Sequence: {sequence[:50]}{'...' if len(sequence) > 50 else ''}")
        print(f"  📏 Length: {len(sequence)} nucleotides")
        
        # Get predictions
        outputs = inference_model.inference(sequence)
        predictions = outputs.get('predictions', None)
        
        if predictions is not None:
            predictions = np.array(predictions)
            print(f"  🎯 Prediction shape: {predictions.shape}")
            
            # Analyze each degradation target
            target_names = ["Reactivity", "deg_Mg_pH10", "deg_Mg_50C"]
            for i, target_name in enumerate(target_names):
                target_values = predictions[:, i] if len(predictions.shape) > 1 else predictions
                valid_predictions = target_values[target_values != -100]  # Remove padding
                
                if len(valid_predictions) > 0:
                    mean_val = np.mean(valid_predictions)
                    std_val = np.std(valid_predictions) 
                    max_val = np.max(valid_predictions)
                    min_val = np.min(valid_predictions)
                    
                    print(f"  📈 {target_name}:")
                    print(f"    Mean: {mean_val:.4f} ± {std_val:.4f}")
                    print(f"    Range: [{min_val:.4f}, {max_val:.4f}]")
                    
                    # Rough interpretation using illustrative thresholds
                    # (not calibrated against experimental data)
                    if mean_val > 0.3:
                        stability = "🔴 High degradation (unstable)"
                    elif mean_val > 0.15:
                        stability = "🟡 Moderate degradation"
                    else:
                        stability = "🟢 Low degradation (stable)"
                    print(f"    {stability}")
        
        print("─" * 50)
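
The sample names above describe base composition, which is easy to check directly. The following sketch computes the GC fraction of a sequence. `gc_fraction` is a hypothetical helper written for this illustration, and the two short sequences are shortened stand-ins for the samples.

```python
def gc_fraction(sequence):
    """Return the fraction of G and C nucleotides in an RNA or DNA sequence."""
    seq = sequence.upper()
    if not seq:
        return 0.0
    return sum(base in "GC" for base in seq) / len(seq)


examples = {
    "AU-rich": "AUGUAUAUAUGUAUAU",
    "GC-rich": "AUGCGCGCGCGCGCGC",
}
for name, seq in examples.items():
    print(f"{name}: GC fraction = {gc_fraction(seq):.2f}")
```

Reporting GC fraction next to the predicted degradation values makes it easier to compare sequences of different composition.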

### Advanced Analysis: Position-wise Degradation Patterns

Let's look more closely at how predicted degradation varies along the sequence and flag positions that appear unstable.

In [None]:
# Advanced analysis: Position-wise degradation pattern analysis
test_sequence = "AUGCCGUGCUAAUCGCGGUAGCGCUAGGCUGCAUCGCGGUAGCGCUAGGCUGCAU"

print("🔬 Advanced Analysis: Position-wise Degradation Patterns")
print("=" * 60)
print(f"Analyzing sequence: {test_sequence}")
print(f"Length: {len(test_sequence)} nucleotides\n")

# Get detailed predictions
outputs = inference_model.inference(test_sequence)
predictions = outputs.get('predictions', None)

if predictions is not None:
    predictions = np.array(predictions)
    target_names = ["Reactivity", "deg_Mg_pH10", "deg_Mg_50C"]
    
    print("📊 Position-wise Analysis:")
    print("Pos\tNuc\tReactivity\tdeg_Mg_pH10\tdeg_Mg_50C\tStability")
    print("-" * 65)
    
    for pos in range(min(20, len(test_sequence))):  # Show first 20 positions
        nucleotide = test_sequence[pos]
        
        if len(predictions.shape) > 1 and pos < predictions.shape[0]:
            reactivity = predictions[pos, 0]
            deg_ph10 = predictions[pos, 1] 
            deg_50c = predictions[pos, 2]
            
            # Skip padded positions
            if reactivity == -100:
                continue
                
            # Label stability from the average of the three targets
            # (thresholds are illustrative, not experimentally calibrated)
            avg_deg = (reactivity + deg_ph10 + deg_50c) / 3
            if avg_deg > 0.25:
                stability = "Unstable"
            elif avg_deg > 0.15:
                stability = "Moderate"
            else:
                stability = "Stable"
                
            print(f"{pos+1:2d}\t{nucleotide}\t{reactivity:.4f}\t\t{deg_ph10:.4f}\t\t{deg_50c:.4f}\t\t{stability}")
    
    print("\n🎯 Summary Insights:")
    print("• Positions with high degradation may indicate structural vulnerability")
    print("• GC-rich regions often show different degradation patterns than AU-rich regions")
    print("• The model captures position-specific degradation propensities")
    print("• These predictions can guide RNA engineering for stability optimization")

print("\n🎉 Tutorial completed successfully!")
print("🚀 Your model is ready for:")
print("  - Predicting mRNA stability in therapeutic design")
print("  - Optimizing RNA sequences for biotechnology applications")
print("  - Understanding sequence-structure-stability relationships")
print("  - Advancing synthetic biology and gene therapy research")
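
One simple way to act on a position-wise table like the one printed above is to rank positions by their average predicted degradation. The sketch below uses invented prediction values and a hypothetical `top_unstable_positions` helper. It assumes the predictions are already aligned as one row per nucleotide.

```python
def top_unstable_positions(sequence, predictions, k=3):
    """Return the k positions (1-based) with the highest mean predicted degradation."""
    scored = [
        (sum(row) / len(row), pos + 1, sequence[pos])
        for pos, row in enumerate(predictions[: len(sequence)])
    ]
    scored.sort(key=lambda item: item[0], reverse=True)  # highest mean first
    return [(pos, nuc, round(mean, 4)) for mean, pos, nuc in scored[:k]]


# Toy predictions: one [reactivity, deg_Mg_pH10, deg_Mg_50C] row per nucleotide
sequence = "AUGCC"
predictions = [
    [0.10, 0.20, 0.30],
    [0.40, 0.50, 0.60],
    [0.05, 0.05, 0.05],
    [0.30, 0.30, 0.30],
    [0.20, 0.10, 0.15],
]
for pos, nuc, mean in top_unstable_positions(sequence, predictions, k=2):
    print(f"position {pos} ({nuc}): mean predicted degradation {mean}")
```

A ranking like this gives a short list of candidate positions to examine first when redesigning a sequence for stability.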

## 🎉 Tutorial Summary and Next Steps

Congratulations! You have successfully completed this comprehensive tutorial on mRNA degradation rate prediction with OmniGenBench.

### What You've Learned

You've walked through a complete, end-to-end fine-tuning workflow for token-level regression, a useful skill in computational biology. Specifically, you have:

1. **Understood the "Why"**: Gained appreciation for the biological problem of mRNA stability and how Genomic Foundation Models provide powerful solutions for therapeutic and biotechnology applications.

2. **Mastered the 4-Step Workflow**:
   - **Step 1: Data Preparation**: You learned how to acquire, process, and efficiently load genomic datasets using the enhanced OmniDataset framework for token regression tasks.
   - **Step 2: Model Initialization**: You saw how to leverage pre-trained models and adapt them for multi-target token-level regression.
   - **Step 3: Model Training**: You implemented robust training strategies using AccelerateTrainer with proper evaluation metrics and modern optimizations.
   - **Step 4: Model Inference**: You used your fine-tuned model to make position-wise predictions and interpreted the biological significance of degradation patterns.

3. **Advanced Capabilities**: You explored:
   - Token-level regression for position-specific predictions
   - Multi-target modeling for multiple degradation conditions
   - Pattern analysis for understanding sequence-stability relationships
   - Real-world applications in RNA engineering and therapeutic design

### Next Steps and Applications

Your trained model can now be applied to:
- **mRNA Therapeutics**: Design stable mRNA vaccines and therapeutics
- **Synthetic Biology**: Engineer RNA circuits with predictable degradation kinetics  
- **Biotechnology**: Optimize protein expression systems
- **Research**: Study sequence-structure-function relationships in RNA biology

### Further Learning

Explore our other tutorials to expand your genomic AI toolkit:
- **[Translation Efficiency Prediction](../translation_efficiency_prediction/)**: Predict protein production rates
- **[RNA Secondary Structure Prediction](../rna_secondary_structure_prediction/)**: Model RNA folding patterns
- **[Transcription Factor Binding](../tfb_prediction/)**: Understand gene regulation

Thank you for following along. We hope this tutorial has provided you with the knowledge and confidence to apply deep learning to your own genomics research. The future of computational biology is in your hands!

**Happy coding and discovering! 🧬✨**