# Model Training Tutorial: Fine-tuning Your Genomic AI Model 🏋️

Welcome to the most exciting part of our tutorial series! In the previous two tutorials, we prepared our data and initialized our model. Now it's time for **Model Training**.

> 💡 **Learning Objectives**: Understand the principles of supervised learning, learn how to use the different trainers, and complete an end-to-end model fine-tuning run

---

## What is Supervised Fine-tuning? 🎓

**Supervised fine-tuning** is the process of adapting a pre-trained model to a specific labeled dataset. The model "learns" through the following cycle:

```
1. 📊 Make predictions on sequences in the training set
2. 🎯 Compare predictions with true labels  
3. 📉 Calculate "error" or "loss"
4. 🔧 Adjust internal weights to reduce future errors
```

In practice, this cycle runs on small batches of sequences at a time. One full pass over the training data is called an "epoch", and training usually runs for several epochs.
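
The four steps above can be sketched with a deliberately tiny, framework-free example. This is not how OmniGenBench trains a genomic model; it is a toy logistic regression on made-up numbers, and every name in it (`features`, `predict`, `learning_rate`, and so on) is a hypothetical stand-in chosen for illustration. It only shows the predict, compare, loss, update cycle repeated over several epochs.

```python
import math

# Toy labeled data: one numeric feature per example and a binary label.
# These values are made up so that the loop has something to learn.
features = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]
labels = [0, 0, 0, 1, 1, 1]

weight, bias = 0.0, 0.0  # the model's "internal weights"
learning_rate = 0.5
num_epochs = 20


def predict(x, weight, bias):
    """Return the predicted probability of the positive class."""
    return 1.0 / (1.0 + math.exp(-(weight * x + bias)))


epoch_losses = []
for epoch in range(num_epochs):
    total_loss = 0.0
    for x, y in zip(features, labels):
        p = predict(x, weight, bias)  # 1. make a prediction
        error = p - y                 # 2. compare with the true label
        # 3. cross-entropy loss for this example
        total_loss += -(y * math.log(p) + (1 - y) * math.log(1 - p))
        # 4. adjust the weights along the negative gradient
        weight -= learning_rate * error * x
        bias -= learning_rate * error
    epoch_losses.append(total_loss / len(features))

print(f"mean loss in first epoch: {epoch_losses[0]:.3f}")
print(f"mean loss in last epoch:  {epoch_losses[-1]:.3f}")
```

The mean loss should fall from the first epoch to the last, which is the same signal you watch for when fine-tuning a real model.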

### 🧠 Why is Fine-tuning So Effective?

| Training Stage | Knowledge Learned | Analogy |
|---------|------------|------|
| **Pre-training** | General patterns of genomic language | 📚 Learned to "read" genomes |
| **Fine-tuning** | Task-specific specialized knowledge | 🎯 Learned to "understand" translation efficiency |

## Training Components in OmniGenBench 🔧

To start training, we need to assemble several key components:

### 1. 🗂️ **Datasets and DataLoaders**
Wrap our training, validation, and test data into PyTorch `DataLoader`s that efficiently provide batched data to the model.
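
To make the idea of "batched data" concrete, here is a conceptual sketch of what a data loader does: shuffle the example order, then hand out fixed-size groups. It is plain Python rather than the real PyTorch `DataLoader`, which additionally handles collation into tensors, parallel workers, and more. The function `iterate_batches` and the short sequences below are hypothetical.

```python
import random


def iterate_batches(examples, batch_size, shuffle=True, seed=0):
    """Yield lists of examples, each holding at most `batch_size` items."""
    indices = list(range(len(examples)))
    if shuffle:
        # A seeded generator keeps the order reproducible between runs.
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [examples[i] for i in indices[start:start + batch_size]]


# Made-up (sequence, label) pairs standing in for a real dataset.
examples = [
    ("AUGGCU", 1),
    ("GGCAUA", 0),
    ("UUAGCC", 1),
    ("CCGAUG", 0),
    ("AUAUGC", 1),
]

batches = list(iterate_batches(examples, batch_size=2))
for batch in batches:
    print(len(batch), batch)
```

With five examples and a batch size of 2, the last batch holds the single leftover example.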

### 2. 📊 **Evaluation Metrics**
Define how we measure success. For classification tasks, we'll use the **F1 score**, the harmonic mean of precision and recall.
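
As a worked example, the binary F1 score can be computed by hand from true positives, false positives, and false negatives. The helper `f1_score_binary` and the label lists are hypothetical, and this sketch does not show how `ClassificationMetric` averages across classes, which may differ.

```python
def f1_score_binary(y_true, y_pred, positive=1):
    """Compute precision, recall, and F1 for the `positive` class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)

    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    return precision, recall, 2 * precision * recall / (precision + recall)


# Made-up labels: 1 = High TE, 0 = Low TE.
y_true = [1, 1, 1, 0, 0, 0, 0, 0]
y_pred = [1, 1, 0, 1, 0, 0, 0, 0]

precision, recall, f1 = f1_score_binary(y_true, y_pred)
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

print(f"precision={precision:.3f} recall={recall:.3f} f1={f1:.3f}")
# precision=0.667 recall=0.667 f1=0.667
print(f"accuracy={accuracy:.3f}")
# accuracy=0.750
```

Here accuracy (0.750) is higher than F1 (0.667) because the many correctly predicted negatives raise accuracy but do not count toward F1 for the positive class.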

### 3. ⚙️ **Optimizer**
The algorithm that updates model weights. We'll use the popular **AdamW** optimizer.
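
For intuition, the AdamW update rule can be written out for a single scalar parameter. This is a minimal sketch of the published update rule, not the optimizer code used by OmniGenBench; in practice you would rely on `torch.optim.AdamW`. The function `adamw_step`, the `state` dictionary, and the toy objective are hypothetical.

```python
import math


def adamw_step(param, grad, state, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """Apply one AdamW update to a scalar parameter and return the new value."""
    state["step"] += 1
    t = state["step"]

    # Decoupled weight decay: shrink the parameter directly,
    # independently of the gradient-based step.
    param = param * (1.0 - lr * weight_decay)

    # Exponential moving averages of the gradient and its square.
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad * grad

    # Bias correction for the zero-initialized averages.
    m_hat = state["m"] / (1 - beta1 ** t)
    v_hat = state["v"] / (1 - beta2 ** t)

    return param - lr * m_hat / (math.sqrt(v_hat) + eps)


# Toy objective: minimize f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
w = 0.0
state = {"step": 0, "m": 0.0, "v": 0.0}
for _ in range(2000):
    grad = 2 * (w - 3.0)
    w = adamw_step(w, grad, state, lr=0.05)

print(f"w after optimization: {w:.3f}")
```

The parameter settles close to 3, slightly pulled toward zero by the weight decay term.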

### 4. 🎯 **Trainer**
OmniGenBench provides a powerful `Trainer` class that orchestrates the entire training process.

## 🚀 OmniGenBench Trainer Selection Guide

OmniGenBench provides multiple trainer backends to meet different needs:

| Trainer Type | Base Technology | Main Advantages | Best Use Cases |
|-----------|----------|----------|------------|
| **`Trainer`** (PyTorch Native) | Pure PyTorch | 🟢 Transparent and understandable<br>🟢 Full control | 🎯 Learning and understanding<br>🎯 Single GPU training |
| **`AccelerateTrainer`** | 🤗 Accelerate | 🟡 Seamless scaling<br>🟡 Distributed-friendly | 🎯 Multi-GPU/TPU<br>🎯 Large-scale training |
| **`HFTrainer`** | 🤗 Trainer | 🔴 Feature-rich<br>🔴 Complete ecosystem | 🎯 HF users<br>🎯 Standardized workflows |

**In this tutorial, we use `AccelerateTrainer`.** It matches the workflow of the complete tutorial, and the same code scales from a single device to multi-GPU setups.

> 💭 **Selection Principles**: Choose `Trainer` for learning and single-GPU training, `AccelerateTrainer` for multi-GPU/TPU or large-scale training, and `HFTrainer` for standardized Hugging Face workflows.

### 🛠️ Environment Setup and Data Preparation

First, let's set up our environment and rebuild the necessary components from previous tutorials.

In [None]:
# Install required packages
!pip install omnigenbench -U

In [None]:
import warnings
from omnigenbench import (
    ClassificationMetric,
    AccelerateTrainer,
    OmniTokenizer,
    OmniModelForSequenceClassification,
    OmniDatasetForSequenceClassification,
)

### 📊 Configure Training Parameters

Let's define all training hyperparameters. This centralized configuration matches our complete tutorial exactly.

In [None]:
# Training Configuration - matches complete tutorial exactly
config_or_model = "yangheng/OmniGenome-52M"  # pre-trained model identifier, reused by the cells below
label2id = {"0": 0, "1": 1}  # 0: Low TE, 1: High TE

print("📋 Training configuration initialized!")
print(f"  🤖 Model: {config_or_model}")
print(f"  🏷️ Labels: {label2id}")
print(f"  🎯 Task: Translation Efficiency Prediction")

### 🏗️ Assemble Training Components

Now, let's create all the objects needed for training, exactly as in the complete tutorial.

In [None]:
# 1. Load tokenizer - matches complete tutorial
print("🔄 Loading tokenizer...")
tokenizer = OmniTokenizer.from_pretrained(config_or_model)
print(f"✅ Tokenizer loaded: {config_or_model}")

# 2. Load datasets - matches complete tutorial exactly
print("📊 Loading datasets...")
datasets = OmniDatasetForSequenceClassification.from_hub(
    dataset_name_or_path="translation_efficiency_prediction",
    tokenizer=tokenizer,
    max_length=512,
    label2id=label2id
)

print(f"📊 Datasets loaded: {list(datasets.keys())}")
for split, dataset in datasets.items():
    print(f"  - {split}: {len(dataset)} samples")

In [None]:
# 3. Initialize model - matches complete tutorial exactly
print("🤖 Initializing model...")
model = OmniModelForSequenceClassification(
    config_or_model,
    tokenizer,
    num_labels=2,  # Binary classification: Low TE vs High TE
)

total_params = sum(p.numel() for p in model.parameters())
print(f"✅ Model initialized! Parameter count: {total_params / 1e6:.1f}M")

## 🚀 Start Training with AccelerateTrainer!

Now we'll use the same training approach as the complete tutorial. The `AccelerateTrainer` will automatically handle:

- ✅ Moving data and model to the correct device (GPU or CPU)
- ✅ Iterating through training data for the specified number of epochs
- ✅ Calculating loss and updating model weights
- ✅ Evaluating the model on the validation set after each epoch
- ✅ Logging performance metrics
- ✅ Saving the best-performing model checkpoint
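
The bookkeeping in the last three bullets (evaluate after each epoch, log metrics, keep the best checkpoint) can be sketched in plain Python. This is an illustration of the general pattern, not `AccelerateTrainer`'s actual implementation. The function `run_epochs`, its callback arguments, and the metric values are all hypothetical stand-ins, not real training results.

```python
def run_epochs(num_epochs, train_one_epoch, evaluate, save_checkpoint):
    """Train for `num_epochs`, saving a checkpoint whenever validation improves."""
    best_score = float("-inf")
    best_epoch = None
    history = []
    for epoch in range(1, num_epochs + 1):
        train_loss = train_one_epoch(epoch)  # update weights on the training set
        val_score = evaluate(epoch)          # score on the validation set
        history.append({"epoch": epoch, "train_loss": train_loss, "val_f1": val_score})
        if val_score > best_score:           # keep only the best-performing model
            best_score, best_epoch = val_score, epoch
            save_checkpoint(epoch)
    return best_epoch, best_score, history


# Made-up numbers standing in for real training and evaluation calls.
fake_train_loss = {1: 0.68, 2: 0.55, 3: 0.47, 4: 0.40}
fake_val_f1 = {1: 0.61, 2: 0.70, 3: 0.74, 4: 0.72}
saved_epochs = []

best_epoch, best_score, history = run_epochs(
    num_epochs=4,
    train_one_epoch=fake_train_loss.get,
    evaluate=fake_val_f1.get,
    save_checkpoint=saved_epochs.append,
)

print(f"best epoch: {best_epoch}, best validation F1: {best_score}")
# best epoch: 3, best validation F1: 0.74
print(f"checkpoints written at epochs: {saved_epochs}")
# checkpoints written at epochs: [1, 2, 3]
```

In this made-up run the training loss keeps falling in epoch 4 while the validation F1 drops, so the epoch 3 checkpoint is the one kept. Selecting on validation performance rather than training loss guards against overfitting.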

In [None]:
# Setup training - exactly as in complete tutorial
print("🚀 Setting up training with AccelerateTrainer...")

metric_functions = [ClassificationMetric().f1_score]

trainer = AccelerateTrainer(
    model=model,
    train_dataset=datasets["train"],
    eval_dataset=datasets["valid"],
    test_dataset=datasets["test"],
    compute_metrics=metric_functions,
)

print("🎓 Starting training...")
metrics = trainer.train()
trainer.save_model("ogb_te_finetuned")  # Matches complete tutorial

print('Metrics:', metrics)
print("\n🎉 Training completed!")

## 🎯 Understanding Training Results

After training completes, you'll see metrics that help you understand your model's performance:

### 📊 Key Metrics to Watch:
- **F1 Score**: Balances precision and recall (higher is better)
- **Accuracy**: Fraction of sequences classified correctly (higher is better)
- **Loss**: How "wrong" the model's predictions are (lower is better)

### 🎯 What Good Results Look Like:
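- **F1 Score > 0.7**: Good performance (a rough rule of thumb; the right threshold depends on the dataset and class balance)
- **F1 Score > 0.8**: Excellent performance
- **Stable validation metrics**: Validation scores that stop fluctuating between epochs suggest the model is learning generalizable patterns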

In [None]:
# Display training summary
print("📈 Training Summary:")
print("=" * 40)
print(f"✅ Model successfully fine-tuned for translation efficiency prediction")
print(f"📊 Dataset: Rice mRNA sequences with experimental TE labels")
print(f"🎯 Task: Binary classification (High TE vs Low TE)")
print(f"💾 Model saved as: 'ogb_te_finetuned'")
print(f"🚀 Ready for inference on new sequences!")

## 🎯 Summary and Next Steps

🎉 Congratulations! You have successfully completed the model training tutorial. Let's review what we've accomplished:

### ✅ **Skills Mastered**

✅ **Understood supervised fine-tuning**: How pre-trained models learn specific tasks  
✅ **Mastered the AccelerateTrainer**: Professional training with minimal code  
✅ **Learned training best practices**: Proper data loading, metric selection, model saving  
✅ **Completed end-to-end training**: From raw data to trained model  
✅ **Matched complete tutorial workflow**: Consistent with the main tutorial  

**Your model is now "intelligent"!** 🧠✨

Through fine-tuning, we have transformed a general genomic foundation model into a translation efficiency prediction specialist. This trained model has been saved and is ready for making predictions on new mRNA sequences.

---

### 🚀 What's Next...

In the final tutorial, we will explore:
- 🔮 **Model inference**: Using your trained model to make predictions on new sequences
- 📊 **Result interpretation**: Understanding and validating predictions
- 🌐 **Deployment options**: From research to production applications
- 🚀 **Real-world usage**: Applying your model to biological research

**Ready to put your trained model to work?**

> 🎊 **Milestone**: You are now a qualified genomic AI trainer!

👉 **Final Step**: Open [04_model_inference.ipynb](https://github.com/yangheng95/OmniGenBench/blob/master/examples/translation_efficiency_prediction/04_model_inference.ipynb) to complete your learning journey!