# Tutorial 6: mRNA Stability Prediction with Encodon

This notebook demonstrates predicting mRNA stability using pretrained Encodon models.

## Overview
- **Task**: Predict mRNA stability/degradation from sequences
- **Dataset**: mRNA Stability dataset (see [1])
- **Model**: Pretrained Encodon + Random Forest regressor

[1] Li, Sizhen, et al. "CodonBERT large language model for mRNA vaccines." Genome Research 34.7 (2024): 1027-1035.
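The overview pipeline (frozen embeddings → Random Forest → R² and Spearman correlation) can be sketched end to end on synthetic data. In this sketch the random 32-dimensional vectors stand in for Encodon embeddings, and the target is made up. It shows only the shape of the workflow, not the real dataset or model.

```python
import numpy as np
from scipy.stats import spearmanr
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score

rng = np.random.default_rng(0)

# Synthetic stand-ins: 200 "sequences" with 32-dimensional "embeddings"
X = rng.normal(size=(200, 32))
# Made-up target that depends on the first two dimensions plus a little noise
y = X[:, 0] - 0.5 * X[:, 1] + 0.1 * rng.normal(size=200)

# Index-based split: first 150 rows for training, last 50 for testing
X_train, X_test = X[:150], X[150:]
y_train, y_test = y[:150], y[150:]

# Fit a Random Forest on the embeddings and score it on held-out rows
rf = RandomForestRegressor(n_estimators=100, random_state=0)
rf.fit(X_train, y_train)
y_pred = rf.predict(X_test)

test_r2 = r2_score(y_test, y_pred)
test_rho, _ = spearmanr(y_test, y_pred)
print(f"Test R²: {test_r2:.3f} | Spearman ρ: {test_rho:.3f}")
```

The sections below follow the same steps, with real Encodon embeddings and the dataset's predefined splits in place of the synthetic arrays.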

## 1. Import Libraries and Setup

In [None]:
import os
import sys
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# ML libraries
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score
from scipy.stats import spearmanr
from sklearn.model_selection import train_test_split, GridSearchCV

# Visualization
import matplotlib.pyplot as plt

# Add project paths
sys.path.append('..')

# Import Encodon modules
from src.inference.encodon import EncodonInference
from src.inference.task_types import TaskTypes
from src.data.metadata import MetadataFields

# Fix random seed
torch.manual_seed(42)
np.random.seed(42)

from src.data.codon_bert_dataset import CodonBertDataset
from src.data.preprocess.codon_sequence import process_item
from torch.utils.data import DataLoader

print("✅ Libraries imported successfully!")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 2. Load Pretrained Encodon Model

In [None]:
# Define checkpoint paths
ckpt_path = "."  # directory containing the downloaded checkpoint folders
AVAILABLE_MODELS = [
    f"{ckpt_path}/NV-CodonFM-Encodon-80M-v1/NV-CodonFM-Encodon-80M-v1.safetensors",
    f"{ckpt_path}/NV-CodonFM-Encodon-600M-v1/NV-CodonFM-Encodon-600M-v1.safetensors",
    f"{ckpt_path}/NV-CodonFM-Encodon-1B-v1/NV-CodonFM-Encodon-1B-v1.safetensors",
    f"{ckpt_path}/NV-CodonFM-Encodon-Cdwt-1B-v1/NV-CodonFM-Encodon-Cdwt-1B-v1.safetensors"
]
checkpoint_path = AVAILABLE_MODELS[0]  # 80M model; change the index to use a larger checkpoint

model_loaded = False
device = "cuda" if torch.cuda.is_available() else "cpu"

if os.path.exists(checkpoint_path):
    try:
        # Create EncodonInference wrapper
        encodon_model = EncodonInference(
            model_path=checkpoint_path,
            task_type=TaskTypes.EMBEDDING_PREDICTION,
        )

        # Configure the model and move it to the selected device in inference mode
        encodon_model.configure_model()
        encodon_model.to(device)
        encodon_model.eval()

        print(f"✅ Model loaded from: {checkpoint_path}")
        print(f"Device: {device}")
        print(f"Parameters: {sum(p.numel() for p in encodon_model.model.parameters()):,}")

        model_loaded = True
    except Exception as e:
        print(f"❌ Failed to load {checkpoint_path}: {e}")
else:
    print(f"❌ Checkpoint not found: {checkpoint_path}")

if not model_loaded:
    print("❌ Could not load the model. Please check the checkpoint path.")
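The cell above only tries `AVAILABLE_MODELS[0]`. If you keep several checkpoints on disk, a small helper can select the first one that actually exists. `first_existing_checkpoint` is a hypothetical helper, not part of the Encodon codebase, and this is only a sketch.

```python
import os
from typing import Optional, Sequence


def first_existing_checkpoint(paths: Sequence[str]) -> Optional[str]:
    """Return the first path in `paths` that exists on disk, or None if none do."""
    for path in paths:
        if os.path.exists(path):
            return path
    return None


# Hypothetical usage with the list defined above:
# checkpoint_path = first_existing_checkpoint(AVAILABLE_MODELS)
```

`AVAILABLE_MODELS` is ordered from smallest to largest. To prefer a larger model when it is present, pass the list reversed.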

## 3. Load Dataset

In [None]:
# Download the mRNA Stability dataset if it doesn't exist
import subprocess

# NOTE: This assumes the notebook was launched from the codon-fm source directory.
# Otherwise, change the script path passed to `subprocess.run` so it points to data_scripts/ correctly.

root_path = "/data/validation/processed"
data_path = os.path.join(root_path, "mRNA_Stability.csv")
if not os.path.exists(data_path):
    print("📥 Downloading mRNA Stability dataset...")
    try:
        # sys.executable runs the script with the same Python interpreter as this kernel
        subprocess.run([
            sys.executable, "data_scripts/download_preprocess_codonbert_bench.py",
            "--dataset", "mRNA_Stability.csv",
            "--output-dir", root_path
        ], check=True)
        print("✅ Dataset downloaded and preprocessed successfully!")
    except subprocess.CalledProcessError as e:
        print(f"❌ Error downloading dataset: {e}")
        print("Please ensure data_scripts/ is available, or run the script manually.")
else:
    print("✅ Dataset already exists!")


In [None]:
# Load mRNA Stability dataset
data_loaded = False
if os.path.exists(data_path):
    try:
        data = pd.read_csv(data_path)
        print(f"✅ Loaded {len(data)} samples from: {data_path}")
        print(f"Columns: {list(data.columns)}")
        
        if 'split' in data.columns:
            print(f"Data splits: {data['split'].value_counts().to_dict()}")
        
        print(f"Target range: [{data['value'].min():.3f}, {data['value'].max():.3f}]")
        data_loaded = True
    except Exception as e:
        print(f"Failed to load {data_path}: {e}")

if not data_loaded:
    print("❌ Could not load mRNA Stability data")

## 4. Data Preprocessing

In [None]:
batch_size = 16
if data_loaded and model_loaded:
    print("=== DATA PREPROCESSING ===")
    # Create dataset
    dataset = CodonBertDataset(
        data_path=data_path,
        tokenizer=encodon_model.tokenizer,
        process_item=lambda seq, tokenizer: process_item(
            seq,
            context_length=encodon_model.model.hparams.max_position_embeddings,
            tokenizer=tokenizer
        )
    )

    print(f"Processing {len(dataset)} sequences")
    print(f"Target range: [{dataset.data['value'].min():.3f}, {dataset.data['value'].max():.3f}]")

    # Create data loader for batch processing
    # (shuffle=False keeps embedding rows aligned with dataset.data, which section 5 relies on)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Extract embeddings using the dataset
    print("\nExtracting embeddings...")
    all_embeddings = []
    all_labels = []

    # Inference only: disable gradient tracking to save memory
    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch_input = {
                MetadataFields.INPUT_IDS: batch[MetadataFields.INPUT_IDS].to(encodon_model.device),
                MetadataFields.ATTENTION_MASK: batch[MetadataFields.ATTENTION_MASK].to(encodon_model.device),
            }

            # Extract embeddings
            output = encodon_model.extract_embeddings(batch_input)
            all_embeddings.append(output.embeddings)
            all_labels.append(batch[MetadataFields.LABELS].numpy())

    # Combine all embeddings and labels
    embeddings = np.vstack(all_embeddings)
    targets = np.concatenate(all_labels)

    print(f"\n✅ Extracted embeddings: {embeddings.shape}")

else:
    print("❌ Skipping preprocessing")
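The preprocessing cell relies on `extract_embeddings` to return one fixed-length vector per sequence. This notebook does not show how Encodon reduces token-level hidden states to that vector. Masked mean pooling is one common way to do it; another is to take a special token's state. The NumPy sketch below illustrates masked mean pooling under that assumption. It is not necessarily what `EncodonInference` does internally, and `masked_mean_pool` is a hypothetical helper.

```python
import numpy as np


def masked_mean_pool(hidden_states: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token vectors per sequence, ignoring padded positions.

    hidden_states: (batch, seq_len, dim) token-level embeddings
    attention_mask: (batch, seq_len) with 1 for real tokens and 0 for padding
    returns: (batch, dim) sequence-level embeddings
    """
    mask = attention_mask[:, :, None].astype(hidden_states.dtype)  # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(axis=1)                     # (batch, dim)
    counts = np.clip(mask.sum(axis=1), 1.0, None)                   # avoid dividing by zero
    return summed / counts


# One sequence with two real tokens and one padding token (the padding value is ignored)
hidden = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
mask = np.array([[1, 1, 0]])
print(masked_mean_pool(hidden, mask))  # [[2. 3.]]
```

Excluding padding matters because sequences in a batch are padded to a common length. Without the mask, shorter sequences would be averaged with meaningless padding vectors.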

## 5. Train Random Forest

In [None]:
if 'embeddings' in locals():
    print("=== TRAINING RANDOM FOREST ===")
    
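    # Split data based on the dataset's predefined splits
    # (.to_numpy() gives plain boolean arrays for indexing the NumPy embeddings)
    train_mask = (dataset.data['split'] == 'train').to_numpy()
    val_mask = (dataset.data['split'] == 'val').to_numpy()
    test_mask = (dataset.data['split'] == 'test').to_numpy()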
    
    X_train = embeddings[train_mask]
    X_val = embeddings[val_mask]
    X_test = embeddings[test_mask]
    y_train = targets[train_mask]
    y_val = targets[val_mask]
    y_test = targets[test_mask]
    
    print(f"Train: {X_train.shape[0]}, Val: {X_val.shape[0]}, Test: {X_test.shape[0]}")
    
    # Combine train and validation for GridSearchCV
    X_train_val = np.vstack([X_train, X_val])
    y_train_val = np.concatenate([y_train, y_val])
    
    # Create validation indices for GridSearchCV
    # Train indices: 0 to len(X_train)-1
    # Val indices: len(X_train) to len(X_train_val)-1
    train_indices = list(range(len(X_train)))
    val_indices = list(range(len(X_train), len(X_train_val)))
    cv_splits = [(train_indices, val_indices)]
    
    # Define hyperparameter grid
    param_grid = {
        'n_estimators': [1000],
        'max_depth': [10],
        'min_samples_split': [25],
        'min_samples_leaf': [2],
    }
    
    # Create base model
    rf_base = RandomForestRegressor(random_state=42, n_jobs=-1)
    
    # Grid search with validation split
    print("Performing hyperparameter tuning...")
    grid_search = GridSearchCV(
        estimator=rf_base,
        param_grid=param_grid,
        cv=cv_splits,
        scoring='r2',
        n_jobs=-1,
        verbose=1
    )
    
    # Fit grid search
    grid_search.fit(X_train_val, y_train_val)
    
    # Get best model
    rf = grid_search.best_estimator_
    
    print(f"\n=== BEST PARAMETERS ===")
    for param, value in grid_search.best_params_.items():
        print(f"{param}: {value}")
    print(f"Best validation R²: {grid_search.best_score_:.4f}")
    
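    # GridSearchCV refits best_estimator_ on train+val by default (refit=True);
    # refit on the train split only so the validation scores below stay held out
    rf.fit(X_train, y_train)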
    
    # Predictions on all splits
    y_pred_train = rf.predict(X_train)
    y_pred_val = rf.predict(X_val)
    y_pred_test = rf.predict(X_test)
    
    # Calculate metrics for all splits
    train_r2 = r2_score(y_train, y_pred_train)
    val_r2 = r2_score(y_val, y_pred_val)
    test_r2 = r2_score(y_test, y_pred_test)
    
    train_spearmanr, _ = spearmanr(y_train, y_pred_train)
    val_spearmanr, _ = spearmanr(y_val, y_pred_val)
    test_spearmanr, _ = spearmanr(y_test, y_pred_test)
    
    print("\n=== FINAL RESULTS ===")
    print(f"Train R²: {train_r2:.4f} | Spearman ρ: {train_spearmanr:.4f}")
    print(f"Val R²:   {val_r2:.4f} | Spearman ρ: {val_spearmanr:.4f}")
    print(f"Test R²:  {test_r2:.4f} | Spearman ρ: {test_spearmanr:.4f}")
    
else:
    print("❌ Cannot train - missing data")
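The training cell builds `cv_splits` by hand so that `GridSearchCV` scores every candidate on the fixed validation split. scikit-learn's `PredefinedSplit` expresses the same idea declaratively: a `test_fold` array where `-1` means "always in training" and `0` means "in validation fold 0". The sketch below uses synthetic data in place of the notebook's embeddings, and the small parameter grid is only illustrative.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV, PredefinedSplit

rng = np.random.default_rng(42)
X_train_val = rng.normal(size=(120, 8))                  # synthetic stand-in for embeddings
y_train_val = X_train_val[:, 0] + 0.1 * rng.normal(size=120)
n_train = 90                                             # first 90 rows = train, last 30 = val

# -1: row is always used for training; 0: row belongs to validation fold 0
test_fold = np.full(len(X_train_val), -1)
test_fold[n_train:] = 0
cv = PredefinedSplit(test_fold)

grid_search = GridSearchCV(
    RandomForestRegressor(random_state=42),
    param_grid={"max_depth": [3, None], "min_samples_leaf": [1, 5]},
    cv=cv,
    scoring="r2",
)
grid_search.fit(X_train_val, y_train_val)
print(grid_search.best_params_, round(grid_search.best_score_, 3))
```

Both approaches give a single train/validation split. `PredefinedSplit` just avoids building index lists manually. As in the notebook, `GridSearchCV` refits the winning configuration on all of `X_train_val` unless `refit=False` is passed.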


## 6. Plot Results

In [None]:
if 'y_test' in locals():
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle('mRNA Stability Prediction Results', fontsize=16)
    
    # Predicted vs True for all splits
    splits = [('Train', y_train, y_pred_train, train_r2), 
              ('Validation', y_val, y_pred_val, val_r2), 
              ('Test', y_test, y_pred_test, test_r2)]
    
    for i, (split_name, y_true, y_pred, r2) in enumerate(splits):
        axes[0, i].scatter(y_true, y_pred, alpha=0.6)
        min_val = min(y_true.min(), y_pred.min())
        max_val = max(y_true.max(), y_pred.max())
        axes[0, i].plot([min_val, max_val], [min_val, max_val], 'r--', linewidth=2)
        axes[0, i].set_xlabel('True mRNA Stability')
        axes[0, i].set_ylabel('Predicted mRNA Stability')
        axes[0, i].set_title(f'{split_name}\nR² = {r2:.3f}')
        axes[0, i].grid(True, alpha=0.3)
    
    # Performance comparison
    r2_scores = [train_r2, val_r2, test_r2]
    spearmanr_scores = [train_spearmanr, val_spearmanr, test_spearmanr]
    
    x_pos = np.arange(len(splits))
    width = 0.35
    
    axes[1, 0].bar(x_pos - width/2, r2_scores, width, label='R²', alpha=0.7)
    axes[1, 0].bar(x_pos + width/2, spearmanr_scores, width, label='Spearman ρ', alpha=0.7)
    axes[1, 0].set_xlabel('Dataset Split')
    axes[1, 0].set_ylabel('Score')
    axes[1, 0].set_title('Performance Comparison')
    axes[1, 0].set_xticks(x_pos)
    axes[1, 0].set_xticklabels(['Train', 'Val', 'Test'])
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
    
    # Target distribution across splits
    axes[1, 1].hist([y_train, y_val, y_test], bins=15, alpha=0.7, 
                   label=['Train', 'Val', 'Test'], edgecolor='black')
    axes[1, 1].set_xlabel('mRNA Stability')
    axes[1, 1].set_ylabel('Frequency')
    axes[1, 1].set_title('Target Distribution by Split')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
    
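    # Feature importance: the 10 most important embedding dimensions (most important at the top)
    top_features = np.argsort(rf.feature_importances_)[-10:]
    axes[1, 2].barh(range(10), rf.feature_importances_[top_features])
    axes[1, 2].set_yticks(range(10))
    axes[1, 2].set_yticklabels([f"dim {j}" for j in top_features])
    axes[1, 2].set_xlabel('Importance')
    axes[1, 2].set_title('Top 10 Feature Importances')
    axes[1, 2].grid(True, alpha=0.3)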
    
    plt.tight_layout()
    plt.show()
    
else:
    print("❌ No results to plot")

## 7. Troubleshooting & Optimization Tips

### Common Issues and Solutions:

#### 1. Model Loading Issues
- **Problem**: Checkpoint not found
- **Solution**: Update checkpoint paths in section 2
- **Check**: Verify checkpoint files exist and are accessible

#### 2. Data Loading Issues
- **Problem**: Dataset not found
- **Solution**: Update data paths in section 3
- **Check**: Ensure the CSV has the required columns (id, ref_seq, value, split); section 5 uses `split` to separate train/val/test
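A quick column check catches a malformed CSV before the slower embedding step. Section 5 reads a `split` column, so the sketch checks for it along with the columns listed above. `missing_columns` and `REQUIRED_COLUMNS` are hypothetical names, and the toy DataFrame is made up.

```python
import pandas as pd

REQUIRED_COLUMNS = {"id", "ref_seq", "value", "split"}


def missing_columns(df: pd.DataFrame, required=REQUIRED_COLUMNS) -> set:
    """Return the required column names that are absent from df."""
    return set(required) - set(df.columns)


df = pd.DataFrame({"id": [0], "ref_seq": ["AUGGCC"], "value": [0.5]})
print(missing_columns(df))  # {'split'}
```

In the notebook, you could call `missing_columns(data)` right after `pd.read_csv(data_path)` in section 3.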

#### 3. Memory Issues
- **Problem**: CUDA out of memory
- **Solution**: Reduce `batch_size` in the preprocessing section (section 4)
- **Alternative**: Run on CPU by setting `device = "cpu"` in section 2 (slower)
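Instead of lowering `batch_size` by hand, a hypothetical helper can retry with a halved batch size whenever an out-of-memory error is raised. The sketch assumes the PyTorch convention that CUDA OOM errors are `RuntimeError`s whose message contains "out of memory". This holds for `torch.cuda.OutOfMemoryError` in recent PyTorch releases. In the notebook, `fn` would rebuild the `DataLoader` with the given batch size and run the extraction loop. The example uses a fake function so it runs without a GPU.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def run_with_batch_backoff(fn: Callable[[int], T], batch_size: int, min_batch_size: int = 1) -> T:
    """Call fn(batch_size), halving batch_size after each out-of-memory RuntimeError."""
    while True:
        try:
            return fn(batch_size)
        except RuntimeError as err:
            # Re-raise anything that is not an OOM, or when we cannot shrink further
            if "out of memory" not in str(err) or batch_size <= min_batch_size:
                raise
            batch_size = max(min_batch_size, batch_size // 2)
            print(f"OOM; retrying with batch_size={batch_size}")


def fake_extract(bs: int) -> int:
    """Simulated extraction step that only fits in memory for batch sizes <= 4."""
    if bs > 4:
        raise RuntimeError("CUDA out of memory (simulated)")
    return bs


print(run_with_batch_backoff(fake_extract, 16))  # prints two OOM messages, then 4
```

The helper deliberately re-raises non-OOM errors, so genuine bugs such as shape mismatches are not hidden by retries.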

#### 4. Performance Issues
- **Problem**: Low R² scores
- **Solutions**:
  - Try larger models (600M or 1B parameters)
  - Fine-tune the Encodon model instead of training only on frozen embeddings
  - Tune Random Forest hyperparameters
  - Check data quality and preprocessing


### Optimization Strategies:

#### 1. Model Architecture
- **80M model**: Fast, good for initial experiments
- **600M model**: Better performance, moderate cost
- **1B model**: Best performance, highest computational cost

#### 2. Hyperparameter Tuning
```python
# A broader Random Forest search space (3 x 4 x 3 x 3 = 108 combinations):
rf_params = {
    'n_estimators': [100, 200, 500],
    'max_depth': [10, 15, 20, None],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4]
}
```
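The grid above has 3 × 4 × 3 × 3 = 108 combinations, so an exhaustive `GridSearchCV` would fit 108 forests per fold. A swapped-in alternative, scikit-learn's `RandomizedSearchCV`, evaluates only a random sample of the combinations. The sketch uses synthetic data as a placeholder for the section 4 embeddings, and `n_iter=5` is an arbitrary illustrative budget. In the notebook you would pass `cv=cv_splits` instead of `cv=3`.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import RandomizedSearchCV

rf_params = {
    'n_estimators': [100, 200, 500],
    'max_depth': [10, 15, 20, None],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4]
}

# Synthetic stand-in data; replace with the embeddings and targets from section 4
rng = np.random.default_rng(0)
X = rng.normal(size=(150, 10))
y = X[:, 0] + 0.1 * rng.normal(size=150)

search = RandomizedSearchCV(
    RandomForestRegressor(random_state=42),
    param_distributions=rf_params,
    n_iter=5,          # evaluate 5 of the 108 combinations
    cv=3,
    scoring='r2',
    random_state=42,   # makes the sampled combinations reproducible
)
search.fit(X, y)
print(search.best_params_)
```

When every entry is a list, as here, the combinations are sampled without replacement. Raising `n_iter` trades more compute for better coverage of the grid.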