## TFB Prediction Tutorial 3/4: Model Training - Teaching the Model to Predict

Welcome to the third tutorial in our series. So far, we have:
1.  Prepared our dataset and dataloaders ([01_data_preparation.ipynb](https://github.com/yangheng95/OmniGenBench/blob/master/examples/tfb_prediction/01_data_preparation.ipynb)).
2.  Selected and initialized the correct model architecture ([02_model_initialization.ipynb](https://github.com/yangheng95/OmniGenBench/blob/master/examples/tfb_prediction/02_model_initialization.ipynb)).

Now, we have all the ingredients ready. In this tutorial, we will combine them to **train** the model. Training is the process of teaching the model to make accurate predictions by showing it examples from our dataset and adjusting its internal parameters.

This tutorial will cover:
1.  **The Concept of Supervised Training**: What it means to train a model and what components are required.
2.  **Trainers in `OmniGenBench`**: An overview of the different training engines available in the framework.
3.  **The Training Process in Action**: A step-by-step guide to configuring and launching the training job.
4.  **Evaluation and Results**: How we measure the model's performance and what artifacts are produced.

By the end of this tutorial, you will have a fine-tuned model saved on your disk, ready for inference.

### 1. The Concept of Supervised Training

At its heart, **supervised training** is like teaching a student with a textbook and an answer key. We show the model an input (a DNA sequence), let it make a prediction, and then compare its prediction to the correct answer (the labels). The difference between the prediction and the answer is quantified by a **loss function**. The goal is to minimize this loss.

This process is iterative and happens in a **training loop**, which consists of four main steps:
1.  **Forward Pass**: The input data is passed through the model to get a prediction.
2.  **Loss Calculation**: The model's prediction is compared to the ground-truth labels using a loss function (e.g., Binary Cross-Entropy for our task).
3.  **Backward Pass (Backpropagation)**: The loss is used to calculate the gradient for each of the model's parameters. The gradient tells us how to adjust each parameter to reduce the loss.
4.  **Optimizer Step**: An **optimizer** (like AdamW) updates the model's parameters based on the calculated gradients.
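
To make the four steps concrete, here is a minimal sketch in plain Python rather than PyTorch. It is a toy under stated assumptions: a one-feature logistic model, made-up data, and plain gradient descent standing in for AdamW. None of the names below come from `OmniGenBench`.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def bce_loss(p, y, eps=1e-12):
    # Binary cross-entropy for a single example; p is clipped for stability.
    p = min(max(p, eps), 1.0 - eps)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# Made-up toy data: the label is 1 when the feature is positive.
inputs = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]
labels = [0, 0, 0, 1, 1, 1]

w, b = 0.0, 0.0   # model parameters
lr = 0.5          # learning rate
losses = []       # mean loss per epoch

for epoch in range(50):
    grad_w = grad_b = total_loss = 0.0
    for x, y in zip(inputs, labels):
        p = sigmoid(w * x + b)       # 1. forward pass
        total_loss += bce_loss(p, y) # 2. loss calculation
        grad_w += (p - y) * x        # 3. backward pass: dL/dw for sigmoid + BCE
        grad_b += (p - y)            #    and dL/db
    n = len(inputs)
    w -= lr * grad_w / n             # 4. optimizer step (plain gradient descent)
    b -= lr * grad_b / n
    losses.append(total_loss / n)

print(f"first epoch loss: {losses[0]:.3f}, last epoch loss: {losses[-1]:.3f}")
```

In a real PyTorch loop, steps 3 and 4 correspond to `loss.backward()` and `optimizer.step()`; the hand-written gradients here only serve to show what those calls compute.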

To manage this entire process, we need a few key components:
-   A **Model** to be trained.
-   **DataLoaders** to supply batches of data for training, validation, and testing.
-   A **Loss Function** to measure prediction error.
-   An **Optimizer** to update the model's weights.
-   An **Evaluation Metric** (e.g., ROC-AUC) to assess performance on a validation set.

A **Trainer** is an object that encapsulates this entire training loop, abstracting away the boilerplate code and providing a clean interface to manage training, evaluation, and logging.
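
As a rough illustration of what "encapsulating the loop" means, the sketch below wraps batching, the epoch loop, and per-epoch evaluation behind a single `train()` call. `ToyTrainer`, `train_step`, and `evaluate` are hypothetical names invented for this example; they are not the `OmniGenBench` API.

```python
class ToyTrainer:
    """Hypothetical minimal trainer: owns the epoch loop, batching, and evaluation."""

    def __init__(self, train_step, evaluate, train_data, valid_data,
                 epochs=3, batch_size=2):
        self.train_step = train_step   # callable(batch) -> loss; also updates the model
        self.evaluate = evaluate       # callable(valid_data) -> metric value
        self.train_data = train_data
        self.valid_data = valid_data
        self.epochs = epochs
        self.batch_size = batch_size

    def train(self):
        history = []
        for epoch in range(self.epochs):
            losses = []
            for start in range(0, len(self.train_data), self.batch_size):
                batch = self.train_data[start:start + self.batch_size]
                losses.append(self.train_step(batch))
            history.append({
                "epoch": epoch,
                "train_loss": sum(losses) / len(losses),
                "valid_metric": self.evaluate(self.valid_data),
            })
        return history


# Toy "model": a single number w that should learn the mean of the data.
state = {"w": 0.0}

def train_step(batch):
    target = sum(batch) / len(batch)
    error = state["w"] - target
    state["w"] -= 0.5 * error          # gradient step on error**2 with lr = 0.25
    return error ** 2                  # loss measured before the update

def evaluate(valid_data):
    target = sum(valid_data) / len(valid_data)
    return -abs(state["w"] - target)   # higher is better

trainer = ToyTrainer(train_step, evaluate, [1.0, 1.0, 1.0, 1.0], [1.0, 1.0])
history = trainer.train()
for record in history:
    print(record)
```

The caller only supplies the pieces that are specific to the task; the loop itself lives in one place and can be reused unchanged.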

### 2. Trainers in `OmniGenBench`

`OmniGenBench` provides a flexible system for training by offering several "Trainer" classes, each suited for different needs. All trainers handle the core training loop, but they differ in their features and complexity.

| Trainer Class         | Key Feature                               | When to Use                                                                                             |
| --------------------- | ----------------------------------------- | ------------------------------------------------------------------------------------------------------- |
| `Trainer`             | Native PyTorch Implementation             | For simple, single-GPU training. It's lightweight and gives you a clear, direct PyTorch experience.      |
| `AccelerateTrainer`   | Distributed Training via `accelerate`     | **Recommended for most users.** Easily scales from single-GPU to multi-GPU or multi-node setups with minimal code changes. |
| `HFTrainer`           | Integration with Hugging Face `Trainer`   | If you are already heavily invested in the Hugging Face ecosystem and prefer its `TrainingArguments` setup. |

For this tutorial, we will use the **`AccelerateTrainer`**. It offers a good balance of power and ease of use: the same code runs on a single GPU today and can scale to a multi-GPU or multi-node cluster later. The `accelerate` library by Hugging Face handles device placement and the coordination of distributed processes behind the scenes.

### 3. The Training Process in Action

Let's put everything together and train our model. The process involves three main code blocks:
1.  **Re-establishing the context**: We'll quickly re-run the code from the previous tutorials to get our `tokenizer`, `datasets`, and `model`.
2.  **Defining the evaluation metric**: We'll set up the function to compute ROC-AUC during validation.
3.  **Configuring and running the `AccelerateTrainer`**: This is where we define all training parameters and launch the job.

#### 3.1. Setup: Data and Model Initialization

First, let's import everything we need and run the setup code from the previous tutorials. This ensures we have all the necessary objects in our environment.

In [None]:
# Import PyTorch and the OmniGenBench components used in this tutorial
import torch
from omnigenbench import (
    OmniTokenizer,
    OmniModelForMultiLabelSequenceClassification,
    OmniDatasetForMultiLabelClassification,
    ClassificationMetric,
    AccelerateTrainer
)

print("✅ Libraries imported successfully!")
print(f"🔥 PyTorch version: {torch.__version__}")
print(f"🎯 CUDA available: {torch.cuda.is_available()}")

In [None]:
# Configuration: pretrained model and dataset identifiers
model_name_or_path = "yangheng/OmniGenome-52M"
dataset_name = "deepsea_tfb_prediction"

# Basic training parameters
num_labels = 919
max_length = 512
batch_size = 64
learning_rate = 2e-5
epochs = 3
output_dir = "./ogb_tfb_finetuned"

In [None]:
# Load the tokenizer and prepare the train/valid/test datasets
print("🔄 Loading tokenizer...")
tokenizer = OmniTokenizer.from_pretrained(model_name_or_path)

print("📊 Loading DeepSEA TFB dataset...")
datasets = OmniDatasetForMultiLabelClassification.from_hub(
    dataset_name_or_path=dataset_name,
    tokenizer=tokenizer,
    max_length=max_length,
    max_examples=1000,  # For quick testing; set to None for full dataset
    force_padding=False
)

print("📋 Dataset prepared:")
print(f"  📈 Training samples: {len(datasets['train'])}")
print(f"  🔍 Validation samples: {len(datasets['valid'])}")
print(f"  🧪 Test samples: {len(datasets['test'])}")

# Initialize the model with one output per label
print("🔄 Loading model...")
model = OmniModelForMultiLabelSequenceClassification(
    model_name_or_path,
    tokenizer,
    num_labels=num_labels,
)

total_params = sum(p.numel() for p in model.parameters())
print(f"✅ Model loaded with {num_labels} labels!")
print(f"📊 Parameters: {total_params / 1e6:.1f}M")
print("✅ Training setup complete!")

#### 3.2. Defining the Evaluation Metric

During training, we need to monitor how well our model is performing on data it hasn't seen before (the validation set). This helps us avoid overfitting and tells us when to stop training. For multi-label classification tasks, **Area Under the Receiver Operating Characteristic Curve (ROC-AUC)** is a standard and robust metric.
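
To show what an averaged ROC-AUC measures, here is a small pure-Python sketch. For each label, ROC-AUC equals the fraction of (positive, negative) example pairs in which the positive example receives the higher score, with ties counting as one half. The functions, the made-up data, and the handling of an ignored label value of `-100` are illustrative assumptions, not the `OmniGenBench` implementation.

```python
def roc_auc(y_true, y_score):
    """ROC-AUC as the fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    if not pos or not neg:
        return None  # undefined when only one class is present
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0
               for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def macro_roc_auc(labels, scores, ignore_value=-100):
    """Average the per-label ROC-AUC, skipping entries equal to ignore_value."""
    per_label = []
    for j in range(len(labels[0])):
        kept = [(row_y[j], row_s[j])
                for row_y, row_s in zip(labels, scores)
                if row_y[j] != ignore_value]
        auc = roc_auc([t for t, _ in kept], [s for _, s in kept])
        if auc is not None:
            per_label.append(auc)
    return sum(per_label) / len(per_label) if per_label else None

# Made-up example: 4 sequences, 2 labels; one entry is marked as ignored.
labels = [[1, 0], [0, 1], [1, -100], [0, 0]]
scores = [[0.9, 0.2], [0.3, 0.8], [0.6, 0.5], [0.4, 0.9]]

score = macro_roc_auc(labels, scores)
print(score)  # label 0 scores 1.0, label 1 scores 0.5, so the average is 0.75
```

A score of 0.5 corresponds to random ranking and 1.0 to perfect ranking, which is why the metric is insensitive to the choice of a decision threshold.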

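Rather than writing our own function, we use the `roc_auc_score` method of the `ClassificationMetric` class from `OmniGenBench`. It takes the model's predictions and the true labels and returns the ROC-AUC score, and the `ignore_y=-100` argument names the label value to exclude from the computation. We pass it to the trainer through the `compute_metrics` argument, and the `AccelerateTrainer` will automatically call it at the end of each epoch.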

In [None]:
# Set up the evaluation metric
print("📊 Setting up evaluation metrics...")

# Use the ClassificationMetric from OmniGenBench
metric_functions = [ClassificationMetric(ignore_y=-100).roc_auc_score]

print("✅ Metrics configured:")
print("   📈 ROC-AUC score for multi-label classification")
print("   🎯 Measures model's ability to distinguish between classes")
print("   📊 Averages across all 919 TF labels")

#### 3.3. Configuring and Running the `AccelerateTrainer`

This is the final and most important step. We will instantiate the `AccelerateTrainer` and provide it with all the necessary components and hyperparameters.

The arguments passed to the trainer in the cell below are:
-   `model`: The model to fine-tune.
-   `train_dataset`, `eval_dataset`, `test_dataset`: The three dataset splits prepared above. The trainer builds the batches from them.
-   `compute_metrics`: The list of evaluation functions defined above (here, ROC-AUC).
-   `learning_rate`: The learning rate, which controls the step size of the optimizer.
-   `batch_size`: The number of sequences in each batch.

Other settings, such as the number of epochs, the optimizer, and the early-stopping patience, are not passed here and keep their default values. Early stopping helps prevent overfitting: training stops if the monitored validation metric does not improve for a set number of epochs.
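
The early-stopping rule can be sketched in a few lines of plain Python. The function name, the return values, and the validation scores below are made up for illustration; the trainer's internal logic may differ in detail.

```python
def run_early_stopping(val_scores, patience):
    """Return (stop_epoch, best_epoch, best_score) for per-epoch validation scores."""
    best_score = float("-inf")
    best_epoch = -1
    epochs_without_improvement = 0
    for epoch, score in enumerate(val_scores):
        if score > best_score:
            # New best: a real trainer would typically save a checkpoint here.
            best_score, best_epoch = score, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch, best_score  # stop early
    return len(val_scores) - 1, best_epoch, best_score

# Made-up validation ROC-AUC values, one per epoch.
val_scores = [0.71, 0.78, 0.80, 0.79, 0.80, 0.77, 0.81]

result = run_early_stopping(val_scores, patience=2)
print(result)  # (4, 2, 0.8): stopped at epoch 4, best checkpoint from epoch 2
```

Note the trade-off in this made-up run: with `patience=2` training stops before the later improvement at epoch 6, whereas a larger patience would have reached it at the cost of more epochs.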

Let's create the trainer and start the training process by calling the `.train()` method.

In [None]:
# Initialize the AccelerateTrainer with the model, datasets, and metric
print("🚀 Setting up AccelerateTrainer...")

trainer = AccelerateTrainer(
    model=model,
    train_dataset=datasets["train"],
    eval_dataset=datasets["valid"],
    test_dataset=datasets["test"],
    compute_metrics=metric_functions,
    learning_rate=learning_rate,
    batch_size=batch_size,
)

print("🎓 Starting training...")
metrics = trainer.train()
trainer.save_model(output_dir)  # same directory as configured above

print("Metrics:", metrics)
print("✅ Training finished!")
print(f"💾 Model saved to: {output_dir}")

### Summary and Next Steps

In this tutorial, we have successfully fine-tuned our Genomic Foundation Model for the task of TFB prediction.

We have learned about:
-   The core concepts of supervised training.
-   The different trainers available in `OmniGenBench` and why `AccelerateTrainer` is a good choice.
-   How to configure and launch a training job, including setting up an evaluation metric and early stopping.
-   The outputs of the training process: the metrics returned by `trainer.train()` and the fine-tuned model saved to the `ogb_tfb_finetuned` directory.

We now have a model that has learned to identify transcription factor binding sites from raw DNA sequences. But a trained model is only useful if we can use it to make predictions on new, unseen data.

In the final tutorial of this series, **[4/4: Model Inference](./04_model_inference.ipynb)**, we will learn how to load our saved model and use it to perform inference, evaluate its performance on the test set, and interpret the results.