# Model Initialization Tutorial: Setting Up PlantRNA-FM for Plant Genomics 🤖

Welcome to the second part of our tutorial series! In the previous tutorial, we prepared our data with the enhanced OmniDataset framework. Now we'll focus on **Model Initialization** - setting up the pre-trained **PlantRNA-FM** (Plant RNA Foundation Model) for our rice translation efficiency prediction task.

> 💡 **Learning Objectives**: Understand PlantRNA-FM and plant-specialized foundation models, master model initialization patterns, and configure models for plant genomic tasks.

---

## Understanding PlantRNA-FM: A Plant-Specialized Foundation Model 🧬

**PlantRNA-FM** represents a new generation of specialized foundation models for plant biology. Unlike general genomic models, PlantRNA-FM is specifically trained on plant RNA and genomic data, learning the unique patterns of plant gene regulation, codon usage, and RNA structure.

### 🔬 From Language to Plant Life

| Concept | Language Models | PlantRNA-FM |
|---------|----------------|---------------------------|
| **Vocabulary** | Words, punctuation | A, U, C, G (plant RNA) + regulatory elements |
| **Grammar** | Syntax, semantics | Plant codon usage, UTR structures, splicing signals |
| **Context** | Sentence meaning | Plant genomic context, expression regulation |
| **Training Data** | Books, web text | Plant transcriptomes, genomes, expression data |
| **Tasks** | Translation, summarization | TE prediction, gene regulation, RNA design |

### 🎯 The Power of Plant-Specific Pre-training

**PlantRNA-FM** (published in *Nature Machine Intelligence*) has been pre-trained on:
- 🌱 **Extensive plant RNA data**: Transcriptomes from major crop species (rice, maize, wheat, Arabidopsis)
- 📊 **Plant-specific features**: Alternative splicing patterns, UTR regulatory elements, polycistronic transcripts
- 🔬 **Expression contexts**: Tissue-specific expression, stress responses, developmental stages
- 💡 **Efficient architecture**: Only 35M parameters, yet reported to achieve state-of-the-art performance

This specialized pre-training allows PlantRNA-FM to understand:
- **Plant codon bias**: Species-specific codon usage preferences that affect translation
- **UTR regulatory elements**: 5' and 3' UTR structures critical for plant mRNA stability and translation
- **Plant-specific splicing**: Alternative splicing patterns unique to plant gene regulation
- **GC content effects**: How GC-rich regions affect plant mRNA secondary structure and ribosome binding
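
As a rough illustration of two of these sequence features, the sketch below computes GC content and in-frame codon counts with plain Python. The helper names (`gc_content`, `codon_counts`) and the toy sequence are hypothetical; they only show the kind of signal involved and are not part of PlantRNA-FM or OmniGenBench.

```python
from collections import Counter


def gc_content(seq: str) -> float:
    """Fraction of G and C nucleotides in the sequence (0.0 for an empty string)."""
    seq = seq.upper()
    return sum(base in "GC" for base in seq) / len(seq) if seq else 0.0


def codon_counts(cds: str) -> Counter:
    """Count in-frame codons, ignoring any trailing partial codon."""
    cds = cds.upper().replace("T", "U")  # accept DNA-style input as well
    usable = len(cds) - len(cds) % 3
    return Counter(cds[i:i + 3] for i in range(0, usable, 3))


example_cds = "AUGGCUGCCGCGUAA"  # hypothetical toy coding sequence
print(f"GC content: {gc_content(example_cds):.2f}")
print(f"Codon counts: {dict(codon_counts(example_cds))}")
```

PlantRNA-FM is not fed hand-crafted features like these; the idea is that pre-training lets it learn comparable patterns directly from raw sequences.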

## OmniGenBench Model Architecture 🏗️

The OmniGenBench framework provides specialized model classes for different genomic tasks. For plant genomics, **PlantRNA-FM** serves as the foundation, with task-specific heads for different applications:

### 🎯 Classification Models for Plant Genomics

| Model Class | Use Case | Plant Applications |
|-------------|----------|---------------------|
| **`OmniModelForSequenceClassification`** | Binary/Multi-class sequence labeling | Translation efficiency (our task), promoter detection, gene function prediction |
| **`OmniModelForMultiLabelSequenceClassification`** | Multiple labels per sequence | Transcription factor binding, regulatory element prediction |
| **`OmniModelForTokenClassification`** | Per-nucleotide labeling | Splice site detection, m6A modification sites |

### 📊 Regression Models for Plant Traits

| Model Class | Use Case | Plant Applications |
|-------------|----------|---------------------|
| **`OmniModelForSequenceRegression`** | Continuous sequence values | Expression levels, translation efficiency scores, protein abundance |
| **`OmniModelForTokenRegression`** | Per-nucleotide continuous values | RNA accessibility scores, ribosome density profiling |

### 🧬 Specialized Models for Plant RNA

| Model Class | Use Case | Plant Applications |
|-------------|----------|---------------------|
| **`OmniModelForMLM`** | Masked language modeling | Variant effect prediction, sequence completion |
| **`OmniModelForSeq2Seq`** | Sequence transformation | RNA secondary structure prediction |
| **`OmniModelForRNADesign`** | Sequence optimization | Synthetic gene design for plant transformation |

> 🎯 **For our rice translation efficiency task**: We use `OmniModelForSequenceClassification` with **PlantRNA-FM** because we're classifying each plant mRNA sequence as "High TE" (1) or "Low TE" (0), leveraging plant-specific codon usage patterns.
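
The selection logic in the tables above can be summarized as a small lookup. This is only a sketch: the `pick_model_class` helper and its string keys are hypothetical, while the class names are the ones listed in the tables.

```python
# Hypothetical helper: maps (prediction level, output type) to a class name
# taken from the tables above. It returns names only and imports nothing.
MODEL_CLASS_BY_TASK = {
    ("sequence", "classification"): "OmniModelForSequenceClassification",
    ("sequence", "multilabel"): "OmniModelForMultiLabelSequenceClassification",
    ("token", "classification"): "OmniModelForTokenClassification",
    ("sequence", "regression"): "OmniModelForSequenceRegression",
    ("token", "regression"): "OmniModelForTokenRegression",
}


def pick_model_class(level: str, output: str) -> str:
    """Return the model class name for a task such as ("sequence", "classification")."""
    try:
        return MODEL_CLASS_BY_TASK[(level, output)]
    except KeyError:
        raise ValueError(
            f"No model class listed for level={level!r}, output={output!r}"
        ) from None


# Rice TE task: one binary label per mRNA sequence
print(pick_model_class("sequence", "classification"))
```

If translation efficiency were modelled as a continuous score instead of a High/Low label, the same lookup would point to `OmniModelForSequenceRegression`.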

## Environment Setup 🔧

Let's start by installing the required packages and importing the necessary components.

In [None]:
# Install the required packages
!pip install omnigenbench -U

In [None]:
import torch
from omnigenbench import (
    OmniTokenizer,
    OmniModelForSequenceClassification,
    OmniDatasetForSequenceClassification,
)

print("✅ Libraries imported successfully!")
print(f"🔥 PyTorch version: {torch.__version__}")
print(f"🎯 GPU available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"   GPU device: {torch.cuda.get_device_name(0)}")

## Model Initialization: PlantRNA-FM for Plant Translation Efficiency 🚀

The beauty of the OmniGenBench framework lies in its simplicity. Model initialization with PlantRNA-FM follows a consistent pattern regardless of the specific plant genomics task.

In [None]:
# Configuration - Using PlantRNA-FM for rice translation efficiency
model_name_or_path = "yangheng/PlantRNA-FM"  # Plant RNA Foundation Model
label2id = {"0": 0, "1": 1}  # 0: Low TE, 1: High TE

print("⚙️ Configuration:")
print(f"   🌱 Base model: {model_name_or_path} (Plant-specialized)")
print(f"   🏷️ Label mapping: {label2id}")
print(f"   🧬 Organism: Rice (Oryza sativa)")
print(f"   📊 Task: Binary sequence classification (Translation Efficiency)")

### Step 1: Initialize the Tokenizer 📝

The tokenizer converts biological sequences into numerical tokens that the model can process. **It's crucial to use the exact same tokenizer that was used during pre-training**.

In [None]:
# Initialize tokenizer - must match the pre-trained model
tokenizer = OmniTokenizer.from_pretrained(model_name_or_path)
print(f"✅ Tokenizer loaded: {model_name_or_path}")

# Demonstrate tokenizer functionality
sample_sequence = "AUGCUGCUAUGCUA"  # Sample mRNA sequence
tokens = tokenizer(sample_sequence, return_tensors="pt")

print(f"\n🧬 Sample tokenization:")
print(f"   Input sequence: {sample_sequence}")
print(f"   Tokenized IDs: {tokens['input_ids'].squeeze().tolist()}")
print(f"   Sequence length: {len(tokens['input_ids'].squeeze())} tokens")
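
To build intuition for what a tokenizer does, here is a toy single-nucleotide tokenizer in plain Python. The vocabulary, special tokens, and IDs are invented for illustration and do not reflect PlantRNA-FM's actual vocabulary, which is why the real tokenizer must always be loaded from the same checkpoint as the model.

```python
# Hypothetical vocabulary: these IDs are made up and do NOT match PlantRNA-FM.
TOY_VOCAB = {"<cls>": 0, "<pad>": 1, "<eos>": 2, "<unk>": 3,
             "A": 4, "C": 5, "G": 6, "U": 7}


def toy_tokenize(sequence: str) -> list[int]:
    """Map each nucleotide to an ID and wrap the result in start/end tokens."""
    sequence = sequence.upper().replace("T", "U")  # treat DNA-style T as RNA U
    body = [TOY_VOCAB.get(base, TOY_VOCAB["<unk>"]) for base in sequence]
    return [TOY_VOCAB["<cls>"], *body, TOY_VOCAB["<eos>"]]


ids = toy_tokenize("AUGC")
print(ids)  # [0, 4, 7, 6, 5, 2]
```

If a different vocabulary assigned, say, `G` to ID 4, the model would look up the wrong embedding for every `G`, which is the practical reason a mismatched tokenizer silently degrades predictions.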

### Step 2: Initialize the Model 🤖

Now we'll initialize our sequence classification model. The model automatically loads the pre-trained weights and adds a classification head for our specific task.

In [None]:
# Initialize the model - exactly as in the complete tutorial
model = OmniModelForSequenceClassification(
    model_name_or_path,
    tokenizer,
    num_labels=2,  # Binary classification: Low TE vs High TE
)

print("🤖 Model initialized successfully!")
print(f"   📊 Task type: Sequence Classification")
print(f"   🏷️ Number of labels: 2 (Low TE, High TE)")
print(f"   🧬 Base architecture: OmniGenome Transformer")

# Display model information
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n📈 Model Statistics:")
print(f"   🔢 Total parameters: {total_params:,}")
print(f"   🎓 Trainable parameters: {trainable_params:,}")
print(f"   💾 Model size: ~{total_params * 4 / 1024**2:.1f} MB (assuming 4-byte float32 weights)")
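
The size estimate is simple arithmetic: parameter count times bytes per parameter. The sketch below works it through for the roughly 35M parameters quoted earlier; the `model_size_mb` helper is hypothetical, and the exact count printed by the cell above will differ because it includes the classification head.

```python
def model_size_mb(num_params: int, bytes_per_param: int = 4) -> float:
    """Approximate in-memory size of the weights in MiB (1 MiB = 1024**2 bytes)."""
    return num_params * bytes_per_param / 1024**2


approx_params = 35_000_000  # rounded figure quoted for PlantRNA-FM
for dtype_name, nbytes in [("float32", 4), ("float16", 2)]:
    print(f"{dtype_name}: ~{model_size_mb(approx_params, nbytes):.1f} MB")
```

This covers the weights only. Training needs additional memory for gradients, optimizer state, and activations.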

## Understanding Model Components 🔍

Let's explore what makes up our initialized model. Knowing which components are pre-trained and which are newly initialized clarifies what fine-tuning will actually change.

In [None]:
# Explore model architecture
print("🏗️ Model Architecture Components:")
print("\n📊 Core Components:")

# Check if model has expected components
if hasattr(model, 'model'):
    print("   ✅ Base Transformer Model (pre-trained OmniGenome)")
if hasattr(model, 'classifier'):
    print("   ✅ Classification Head (newly initialized)")
if hasattr(model, 'pooler'):
    print("   ✅ Pooling Layer (sequence → single representation)")

# Display configuration
print(f"\n⚙️ Model Configuration:")
if hasattr(model, 'config'):
    config = model.config
    print(f"   🔢 Hidden size: {getattr(config, 'hidden_size', 'N/A')}")
    print(f"   🏷️ Number of labels: {getattr(config, 'num_labels', 'N/A')}")
    print(f"   📏 Max sequence length: {getattr(config, 'max_position_embeddings', 'N/A')}")

print(f"\n🎯 Ready for: Translation Efficiency Prediction")

## Testing Model Functionality 🧪

Before training, let's test that our model can process sequences correctly. This helps us verify everything is set up properly.

In [None]:
# Test model with sample sequences
print("🧪 Testing Model Functionality")
print("=" * 40)

# Sample mRNA sequences for testing
test_sequences = [
    "AUGCUGCUAUGCUAGCUAGC",  # Short test sequence
    "AUGAAACCAACAAAAUGCAGUAGAAGUACUCUCGAGCUAUAGUCGCGACGUGCUGCCCCGC",  # Longer sequence (RNA alphabet throughout)
]

model.eval()  # Set to evaluation mode
with torch.no_grad():
    for i, sequence in enumerate(test_sequences, 1):
        print(f"\n🧬 Test Sequence {i}:")
        print(f"   Sequence: {sequence[:50]}{'...' if len(sequence) > 50 else ''}")
        print(f"   Length: {len(sequence)} nucleotides")
        
        # Tokenize
        inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=512)
        
        # Forward pass
        outputs = model(**inputs)
        
        # Get predictions (key access works for both plain dicts and Hugging Face ModelOutput objects)
        logits = outputs["logits"]
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        predicted_class = torch.argmax(logits, dim=-1)
        
        print(f"   📊 Raw logits: {logits.squeeze().tolist()}")
        print(f"   📈 Probabilities: [Low TE: {probabilities[0][0]:.3f}, High TE: {probabilities[0][1]:.3f}]")
        print(f"   🎯 Predicted class: {predicted_class.item()} ({'High TE' if predicted_class.item() == 1 else 'Low TE'})")

print(f"\n✅ Model functionality test completed!")
print(f"   🎯 Model can process sequences and generate predictions")
print(f"   📊 Ready for training on the translation efficiency dataset")
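
The prediction step above turns two raw logits into probabilities with softmax and then picks the larger one. A minimal pure-Python sketch of that arithmetic follows; the logit values are made up, and the `softmax` helper stands in for `torch.nn.functional.softmax`.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


id2label = {0: "Low TE", 1: "High TE"}
toy_logits = [0.2, -0.1]  # made-up values, not real model output

probs = softmax(toy_logits)
predicted_id = max(range(len(probs)), key=probs.__getitem__)  # argmax
print(f"Probabilities: {[round(p, 3) for p in probs]}")
print(f"Predicted class: {predicted_id} ({id2label[predicted_id]})")
```

Because the classification head is newly initialized, predictions at this stage should be treated as arbitrary. The test only confirms that the forward pass runs and produces outputs of the expected shape.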

## Understanding Pre-trained vs Fine-tuned Models 🎓

It's important to understand what we have at this stage and what training will accomplish.

### 🤖 Current Model State (Pre-trained + Classification Head)

**What the model knows:**
- ✅ **General genomic patterns**: Codon usage, sequence motifs, structural constraints
- ✅ **Biological relationships**: GC content effects, splice sites, regulatory elements
- ✅ **Sequence representations**: How to encode biological meaning from nucleotide sequences

**What the model doesn't know yet:**
- ❌ **Translation efficiency patterns**: Which sequence features correlate with high/low TE
- ❌ **Task-specific relationships**: How ribosome binding, codon optimization affect TE
- ❌ **Decision boundaries**: Where to classify sequences as high vs low TE

### 🎯 What Fine-tuning Will Accomplish

During training, the model will learn:
1. **Task-specific patterns**: Which genomic features predict translation efficiency
2. **Decision boundaries**: How to separate high TE from low TE sequences
3. **Biological relationships**: How sequence context affects ribosome binding and translation

The pre-trained knowledge provides a **strong foundation**, and fine-tuning **specializes** this knowledge for our specific task.

## Model Configuration Best Practices 📋

Let's discuss important configuration choices and their biological implications.

In [None]:
# Demonstrate key configuration options
print("⚙️ Model Configuration Best Practices")
print("=" * 45)

print("\n🎯 Key Configuration Choices:")
print("\n1. 📦 Base Model Selection:")
print(f"   Current: {model_name_or_path}")
print("   Alternatives: OmniGenome-52M (smaller, faster), OmniGenome-186M (larger, potentially more accurate)")
print("   Trade-off: Speed and memory vs Accuracy")

print("\n2. 🏷️ Label Configuration:")
print(f"   Labels: {list(label2id.keys())} → {list(label2id.values())}")
print("   Meaning: 0 = Low TE, 1 = High TE")
print("   Type: Binary classification")

print("\n3. 📏 Sequence Length Considerations:")
max_length = 512  # Common choice for mRNA sequences; counted in tokens, not raw nucleotides
print(f"   Max length: {max_length} tokens (special tokens included)")
print("   Rationale: Covers most mRNA 5' UTR + start of CDS")
print("   Biological relevance: Key regulatory regions for translation")

print("\n4. 🔧 Model Architecture:")
print("   Type: Transformer-based (attention mechanism)")
print("   Advantages: Captures long-range dependencies")
print("   Biological relevance: Secondary structures, distant motifs")

print("\n✅ Configuration optimized for translation efficiency prediction!")
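
To see what a fixed maximum length means in practice, the sketch below estimates how many nucleotides fit and what share of sequences would be truncated. It assumes one token per nucleotide plus two special tokens, which is an assumption about the tokenizer, not a documented property; both helper names are hypothetical.

```python
def truncate_for_model(sequence: str, max_length: int = 512,
                       num_special_tokens: int = 2) -> str:
    """Keep the 5' end of the sequence so it fits the token budget."""
    budget = max_length - num_special_tokens
    if budget <= 0:
        raise ValueError("max_length must exceed the number of special tokens")
    return sequence[:budget]


def fraction_truncated(sequences: list[str], max_length: int = 512,
                       num_special_tokens: int = 2) -> float:
    """Share of sequences that are longer than the nucleotide budget."""
    if not sequences:
        return 0.0
    budget = max_length - num_special_tokens
    return sum(len(s) > budget for s in sequences) / len(sequences)


toy_sequences = ["A" * 600, "AUG" * 50]  # hypothetical lengths: 600 and 150 nt
print(len(truncate_for_model(toy_sequences[0])))  # 510
print(fraction_truncated(toy_sequences))          # 0.5
```

Truncating from the 3' end keeps the 5' UTR and the start of the CDS, the regions the configuration above singles out as most relevant for translation. Running a check like `fraction_truncated` on your own data shows whether 512 is a sensible limit.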

## Connecting to Data Pipeline 🔗

Let's briefly connect our initialized model to the data pipeline we created in the previous tutorial, ensuring everything works together.

In [None]:
# Load the dataset to demonstrate model-data compatibility
print("🔗 Testing Model-Data Pipeline Integration")
print("=" * 45)

# Load datasets - exactly as in complete tutorial
datasets = OmniDatasetForSequenceClassification.from_hub(
    dataset_name_or_path="translation_efficiency_prediction",
    tokenizer=tokenizer,
    max_length=512,
    label2id=label2id
)

print(f"📊 Datasets loaded: {list(datasets.keys())}")
for split, dataset in datasets.items():
    print(f"  - {split}: {len(dataset)} samples")

# Test model with actual dataset sample
print("\n🧪 Testing with real dataset sample:")
sample = datasets['train'][0]
print(f"   Sample keys: {list(sample.keys())}")
print(f"   Input shape: {sample['input_ids'].shape}")
print(f"   Label: {sample['labels'].item()} ({'High TE' if sample['labels'].item() == 1 else 'Low TE'})")

# Test forward pass with real data
model.eval()
with torch.no_grad():
    # Add batch dimension
    batch_input = {k: v.unsqueeze(0) if v.dim() == 1 else v for k, v in sample.items() if k != 'labels'}
    outputs = model(**batch_input)
    
    # Key access works for both plain dicts and Hugging Face ModelOutput objects
    predicted_class = torch.argmax(outputs["logits"], dim=-1)
    print(f"   🎯 Model prediction: {predicted_class.item()} ({'High TE' if predicted_class.item() == 1 else 'Low TE'})")

print("\n✅ Model-data integration successful!")
print("   🎯 Model can process real dataset samples")
print("   📊 Ready for training pipeline")
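
The cell above adds a batch dimension to a single sample. With several sequences of different lengths, the usual approach is to pad them to a common length and pass an attention mask so the model ignores the padding. The sketch below shows the idea in plain Python; `pad_batch` and the padding ID of 1 are hypothetical, and the dataset class may already handle padding for you.

```python
def pad_batch(batch_ids: list[list[int]], pad_id: int = 1) -> dict[str, list[list[int]]]:
    """Right-pad token ID lists to a common length and build the attention mask."""
    if not batch_ids:
        raise ValueError("batch_ids must contain at least one sequence")
    max_len = max(len(ids) for ids in batch_ids)
    input_ids = [ids + [pad_id] * (max_len - len(ids)) for ids in batch_ids]
    # 1 marks a real token, 0 marks padding
    attention_mask = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in batch_ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = pad_batch([[0, 4, 2], [0, 4, 7, 6, 2]])  # toy token IDs
print(batch["input_ids"])       # [[0, 4, 2, 1, 1], [0, 4, 7, 6, 2]]
print(batch["attention_mask"])  # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
```

Wrapping each list in `torch.tensor(...)` would give tensors of shape `(batch_size, max_len)`, the same layout the forward pass above receives with a batch size of 1.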

## 🎉 Tutorial Summary and Next Steps

Congratulations! You have successfully completed the model initialization tutorial. Let's review what we've accomplished:

### 🎓 **Skills Mastered**

✅ **Understood Genomic Foundation Models**: The principles behind pre-trained genomic AI  
✅ **Mastered Model Architecture**: Different model types for various biological tasks  
✅ **Initialized Models Correctly**: Proper tokenizer and model setup  
✅ **Tested Model Functionality**: Verified the model can process genomic sequences  
✅ **Configured for Biology**: Optimized settings for translation efficiency prediction  
✅ **Integrated with Data**: Connected model to the data pipeline  


### 🔬 **Key Concepts Learned**

- **🧬 Pre-training vs Fine-tuning**: How models learn general then specific knowledge
- **🤖 Model Architecture**: Transformer-based models for genomic sequences
- **📊 Task Specialization**: Choosing the right model class for your biological question
- **⚙️ Configuration Management**: Best practices for reproducible model setup

### 🚀 **What's Next...**

You now have a properly initialized model ready for training! In the next tutorial, we will:

- 🎓 **Learn the training process**: How models learn from data
- ⚙️ **Configure training parameters**: Learning rates, optimizers, evaluation metrics
- 📊 **Monitor training progress**: Loss curves, validation metrics, early stopping
- 💾 **Save trained models**: Model persistence for later use

**Ready to train your genomic AI model?** 🧬✨

👉 **Next Step**: Open [03_model_training.ipynb](https://github.com/yangheng95/OmniGenBench/blob/master/examples/translation_efficiency_prediction/03_model_training.ipynb) to continue your learning journey!