## TFB Prediction Tutorial 2/4: Model Initialization - From Task to Architecture

In the previous tutorial, [01_data_preparation.ipynb](https://github.com/yangheng95/OmniGenBench/blob/master/examples/tfb_prediction/01_data_preparation.ipynb), we defined our biological task, predicting transcription factor binding sites, and prepared our data accordingly. We framed it as a **multi-label sequence classification** problem. This crucial step of defining the task and data format dictates our next decision: **choosing the right model architecture**.


Welcome to the second part of our **streamlined** tutorial series. Having prepared the data with just 3 lines of code using the enhanced OmniDataset framework, we now focus on **Model Initialization** with equal simplicity.

> Learning Objectives: Master the universal model initialization pattern, understand foundation model concepts, and leverage OmniGenBench's intelligent defaults

---

## The Power of Pre-trained Models 

The core idea behind fine-tuning is to leverage **pre-trained foundation models**. These models have already learned the fundamental "language" of the genome from vast amounts of unlabeled sequence data. This pre-training endows them with a powerful, general-purpose understanding of genomic patterns.

Our task is to take this general knowledge and specialize it for our specific problem: transcription factor binding (TFB) prediction. With the enhanced OmniGenBench framework, this takes only a few lines of code.

### Why Use Foundation Models Instead of Traditional Methods?

| Traditional Methods | Foundation Model Approach |
|---------------------|---------------------------|
| Requires hand-crafted features | Automatically learns feature representations |
| Relies on prior biological knowledge | Discovers patterns from data |
| Limited generalization ability | Strong cross-task generalization |
| Needs large amounts of task-specific data | Can fine-tune with small amounts of data |


## Key Components: PlantRNA-FM Model and Tokenizer 

The enhanced OmniGenBench framework automates most of model initialization. Steps that once required manual configuration are now handled by sensible defaults, including when using PlantRNA-FM for plant genomics tasks.

This tutorial will guide you through the process of selecting and initializing PlantRNA-FM from the `OmniGenBench` framework. We will cover:

1.  **The OmniGenBench Model Zoo**: An overview of available model architectures, with focus on PlantRNA-FM for plant genomics.

2.  **The Principle of Model Selection**: How to choose PlantRNA-FM and configure it for plant transcription factor binding prediction.

3.  **Model Architecture**: Understanding the "PlantRNA-FM base + task head" design.

4.  **Inputs and Outputs**: What PlantRNA-FM expects as input and produces as output for plant regulatory analysis.

5.  **Practical Implementation**: Initializing PlantRNA-FM for our plant TFB prediction task.

By the end of this tutorial, you will understand how to leverage PlantRNA-FM for plant regulatory genomics problems.

### 1. The OmniGenBench Model Zoo: Plant-Specialized Models

`OmniGenBench` provides a comprehensive framework with various model architectures, each tailored to a specific type of task. These task-specific components are often referred to as "task heads." For plant genomics, **PlantRNA-FM** serves as the foundation: you combine this plant-specific **base model** with a smaller, task-specific **head**.

Here is a summary of the main model classes available in `OmniGenBench` and their plant genomics applications:

| Model Class                                       | Task Type                    | Plant Genomics Example                                       |
| ------------------------------------------------- | ---------------------------- | -------------------------------------------------------- |
| `OmniModelForSequenceClassification`              | Sequence Classification      | Classifying plant promoters as active/inactive, tissue-specific expression   |
| `OmniModelForMultiLabelSequenceClassification`    | Multi-Label Classification   | Predicting multiple TFB sites in plant regulatory regions (our task)       |
| `OmniModelForTokenClassification`                 | Token Classification         | Identifying splice sites in plant genes, m6A modification sites |
| `OmniModelForSequenceRegression`                  | Sequence Regression          | Predicting plant mRNA translation efficiency scores |
| `OmniModelForTokenRegression`                     | Token Regression             | Predicting per-base chromatin accessibility in plants     |
| `OmniModelForSeq2Seq`                             | Sequence-to-Sequence         | RNA secondary structure prediction in plants        |
| `OmniModelForRNADesign`                           | Sequence Generation          | Designing synthetic plant regulatory elements |
| `OmniModelForMLM` (Masked Language Model)         | Self-Supervised Pre-training | Learning representations from unlabeled plant DNA/RNA     |

**PlantRNA-FM** (published in *Nature Machine Intelligence*, 35M parameters) is particularly powerful for plant genomics because it was pre-trained on extensive plant transcriptome and genome data, making it excel at tasks involving plant regulatory elements, codon usage patterns, and RNA structures, all while being remarkably efficient.
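
The table in this section amounts to a lookup from task type to model class. The sketch below illustrates that lookup; the task-type keys and the `model_class_for` helper are hypothetical names introduced for this example, and the class names are kept as plain strings so the snippet runs without `omnigenbench` installed.

```python
# Task type -> OmniGenBench model class name, copied from the model zoo table.
# The dictionary keys are illustrative labels, not OmniGenBench identifiers.
MODEL_CLASS_BY_TASK = {
    "sequence_classification": "OmniModelForSequenceClassification",
    "multi_label_classification": "OmniModelForMultiLabelSequenceClassification",
    "token_classification": "OmniModelForTokenClassification",
    "sequence_regression": "OmniModelForSequenceRegression",
    "token_regression": "OmniModelForTokenRegression",
    "seq2seq": "OmniModelForSeq2Seq",
    "sequence_generation": "OmniModelForRNADesign",
    "masked_language_modeling": "OmniModelForMLM",
}


def model_class_for(task_type: str) -> str:
    """Return the model class name for a task type, or raise a clear error."""
    try:
        return MODEL_CLASS_BY_TASK[task_type]
    except KeyError:
        known = ", ".join(sorted(MODEL_CLASS_BY_TASK))
        raise ValueError(
            f"Unknown task type {task_type!r}; expected one of: {known}"
        ) from None


print(model_class_for("multi_label_classification"))
# OmniModelForMultiLabelSequenceClassification
```

Keeping the mapping explicit makes the selection step auditable: the task definition from the first tutorial determines the key, and the key determines the class.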

### 2. The Principle of Model Selection: PlantRNA-FM for Plant TFB Prediction

The selection principle is straightforward: **match the model architecture to the machine learning task you defined, and choose PlantRNA-FM for plant-specific genomics**.

In our case:
-   **Biological Problem**: Predicting which of 919 transcription factors bind to a given plant DNA sequence.
-   **Data Format**: A DNA sequence of length 1000 from plant genomes.
-   **Label Format**: A binary vector of length 919, where each element indicates binding (1) or no binding (0) for a specific TF.
-   **ML Task**: Since a single sequence can have multiple TFs binding to it, this is a **Multi-Label Sequence Classification** problem.
-   **Model Choice**: `OmniModelForMultiLabelSequenceClassification` with **PlantRNA-FM** base model.
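
To make the label format concrete, here is a minimal sketch of how a multi-label target can be encoded as a binary vector of length 919. The `multi_hot` helper and the example TF indices are hypothetical and chosen only for illustration.

```python
NUM_LABELS = 919  # number of transcription factors, from the task definition


def multi_hot(bound_tf_indices, num_labels=NUM_LABELS):
    """Build a binary label vector with a 1 at each bound TF index."""
    labels = [0] * num_labels
    for idx in bound_tf_indices:
        if not 0 <= idx < num_labels:
            raise IndexError(f"TF index {idx} is outside [0, {num_labels})")
        labels[idx] = 1
    return labels


# Hypothetical example: TFs 3, 17 and 918 bind this sequence.
labels = multi_hot([3, 17, 918])
print(len(labels), sum(labels))  # 919 3
```

Because more than one position can be 1 at the same time, the targets are not mutually exclusive. That is what separates multi-label classification from ordinary multi-class classification and motivates the dedicated model class.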

**Why PlantRNA-FM for this task?**
- Pre-trained on plant transcriptomes and regulatory regions
- Understands plant-specific transcription factor binding motifs
- Captures plant chromatin structure patterns
- Generalizes well across plant species (Arabidopsis, rice, maize, etc.)

### 3. Model Architecture: PlantRNA-FM Base + Task Head

Let's visualize the architecture. At its core, our model consists of two parts:

1.  **PlantRNA-FM Base Model**: This is a pre-trained transformer model trained specifically on plant genomic and transcriptomic data. Its job is to read a plant DNA sequence (as a series of tokens) and convert it into a rich numerical representation (an embedding) that captures plant-specific regulatory patterns, motifs, and chromatin contexts.
2.  **The Multi-Label Classification Head**: This is a smaller neural network (usually one or two linear layers) that sits on top of PlantRNA-FM. It takes the plant sequence embedding and transforms it into 919 independent binary predictions for transcription factor binding.

Here is a diagram illustrating this plant-specific architecture:

```mermaid
graph TD
    subgraph "Input"
        A["DNA Sequence<br/>GATTACA..."]
    end

    subgraph "Tokenization"
        B["Input Tokens<br/>[CLS], G, A, T, T, A, C, A, [SEP]"]
    end

    subgraph "OmniModelForMultiLabelSequenceClassification"
        C("Base Model<br/>pre-trained checkpoint")
        D("Classification Head<br/>Linear Layer + Sigmoid")
    end

    subgraph "Output"
        E["Prediction Vector<br/>[0.9, 0.1, 0.05, ..., 0.8]"]
    end

    A --> B
    B --> C
    C -- Sequence Embedding --> D
    D --> E
```

The base model does the heavy lifting of understanding the sequence, while the head adapts that understanding to our specific predictive goal. During fine-tuning, we will update the weights of both the head and (to a lesser extent) the base model to optimize for our TFB prediction task.
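
The "base model + task head" layout can be sketched in a few lines of PyTorch. This is a toy stand-in, not the OmniGenBench implementation: the `ToyMultiLabelClassifier` class, its embedding-only "base", the mean pooling, and the vocabulary and hidden sizes are all assumptions chosen to keep the example small.

```python
import torch
from torch import nn


class ToyMultiLabelClassifier(nn.Module):
    """Illustrative 'base model + task head' layout (not the OmniGenBench code)."""

    def __init__(self, vocab_size=8, hidden_size=32, num_labels=919):
        super().__init__()
        # Stand-in for the pre-trained base model: token embeddings only.
        self.base = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        # Task head: one linear layer producing one logit per label.
        self.head = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        token_embeddings = self.base(input_ids)        # (batch, seq, hidden)
        mask = attention_mask.unsqueeze(-1).float()    # (batch, seq, 1)
        # Mean-pool over real tokens only to get one embedding per sequence.
        summed = (token_embeddings * mask).sum(dim=1)
        pooled = summed / mask.sum(dim=1).clamp(min=1.0)
        return self.head(pooled)                       # (batch, num_labels)


toy_model = ToyMultiLabelClassifier()
input_ids = torch.tensor([[1, 2, 3, 4, 0], [1, 5, 6, 0, 0]])  # 0 = padding
attention_mask = (input_ids != 0).long()
logits = toy_model(input_ids, attention_mask)
probabilities = torch.sigmoid(logits)  # one independent probability per label
print(logits.shape)  # torch.Size([2, 919])
```

A real base model replaces the embedding layer with a full transformer encoder, but the contract is the same: token IDs go in, one embedding per sequence comes out, and the head maps that embedding to 919 logits.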

### 4. Inputs and Outputs: A Look at the Data Flow

Understanding what the model expects and what it returns is critical for debugging and interpretation.

-   **Model Input**: The model expects a batch of tokenized sequences. When you use the `AccelerateTrainer` (which we will see in the next tutorial), the data loader will automatically prepare a dictionary-like object for each batch. This object contains:
    -   `input_ids`: A tensor of token IDs representing the DNA sequences.
    -   `attention_mask`: A tensor indicating which tokens are real and which are padding.
    -   `labels`: A tensor containing the ground-truth labels for each sequence.

-   **Model Output**: During training, the model returns a dictionary containing:
    -   `loss`: The calculated loss value, which the trainer uses to update the model weights.
    -   `logits`: The raw, unnormalized output scores from the final linear layer. For our task, this will be a tensor of shape `(batch_size, 919)`.
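
The batch dictionary described above can be illustrated in plain Python. This is a sketch of the general padding idea rather than the data loader used by `AccelerateTrainer`; the `collate` helper, the padding ID of 0, and the token IDs are hypothetical, and the label vectors are shortened to 3 entries instead of 919 for readability.

```python
PAD_ID = 0  # hypothetical padding token ID


def collate(token_id_lists, label_lists, pad_id=PAD_ID):
    """Pad variable-length token ID lists and build the matching attention mask."""
    max_len = max(len(ids) for ids in token_id_lists)
    input_ids, attention_mask = [], []
    for ids in token_id_lists:
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)             # pad on the right
        attention_mask.append([1] * len(ids) + [0] * n_pad)  # 1 = real, 0 = padding
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": label_lists,
    }


batch = collate([[5, 6, 7, 8], [5, 9]], [[1, 0, 1], [0, 0, 1]])
print(batch["input_ids"])       # [[5, 6, 7, 8], [5, 9, 0, 0]]
print(batch["attention_mask"])  # [[1, 1, 1, 1], [1, 1, 0, 0]]
```

In practice these lists are tensors, but the structure is the same: every sequence in a batch has the same length, and the attention mask tells the model which positions to ignore.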

During inference (prediction), the model simply returns the `logits`. To get probabilities, we typically apply a Sigmoid function to the logits. To get binary predictions, we can then threshold these probabilities (e.g., predict 1 if probability > 0.5).
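
The post-processing step can be sketched with the standard library alone. The logit values below are made up for illustration and cover only 4 of the 919 labels; `sigmoid` and `predict` are helper names introduced for this example.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def predict(logits, threshold=0.5):
    """Convert one sequence's logits to probabilities and binary calls."""
    probs = [sigmoid(v) for v in logits]
    calls = [1 if p > threshold else 0 for p in probs]
    return probs, calls


# Hypothetical logits for 4 of the 919 labels.
probs, calls = predict([2.2, -2.2, 0.0, 1.4])
print(calls)  # [1, 0, 0, 1]
```

Each label is thresholded independently, so any number of labels can be predicted as 1 for the same sequence. With a threshold of 0.5, a label is called positive exactly when its logit is greater than 0.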

### 5. Practical Implementation: Initializing the Model

Now, let's translate this theory into code. We will perform the following steps:
1.  Import necessary libraries.
2.  Define a configuration object to hold all our parameters.
3.  Initialize the tokenizer and the `OmniModelForMultiLabelSequenceClassification` model.

First, let's set up our environment by importing the required modules.

In [None]:
import torch
from omnigenbench import (
    OmniTokenizer, 
    OmniModelForMultiLabelSequenceClassification
)

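Next, we define a configuration dictionary. This is a best practice that centralizes all important parameters, making the code cleaner and easier to modify. Note that `model_name_or_path` below points to the `yangheng/OmniGenome-52M` checkpoint; to fine-tune PlantRNA-FM instead, set it to that model's checkpoint identifier.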

In [None]:
# Configuration - matches complete tutorial exactly
config = {
    "model_name_or_path": "yangheng/OmniGenome-52M",  # pre-trained checkpoint to load
    "num_labels": 919,  # one output per transcription factor
    "max_length": 512,  # maximum number of tokens per sequence
    "device": "cuda" if torch.cuda.is_available() else "cpu",  # use the GPU when available
}

print("⚙️ Configuration:")
print(f"  🧬 Model: {config['model_name_or_path']}")
print(f"  🏷️ Labels: {config['num_labels']} TF binding sites")
print(f"  📏 Max length: {config['max_length']} tokens")
print(f"  📱 Device: {config['device']}")
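
The data format described earlier uses sequences of length 1000, while `max_length` is 512 tokens. The sketch below estimates how much of each sequence would reach the model under two assumptions that this tutorial does not confirm: one token per nucleotide with two special tokens (as drawn in the architecture diagram), and truncation to `max_length`. The `nucleotides_kept` helper is hypothetical.

```python
def nucleotides_kept(sequence_length, max_length, num_special_tokens=2):
    """Nucleotides that survive truncation, assuming one token per nucleotide."""
    return min(sequence_length, max_length - num_special_tokens)


kept = nucleotides_kept(sequence_length=1000, max_length=512)
print(kept)  # 510
```

If those assumptions hold for your tokenizer, roughly half of each 1000-nucleotide sequence is discarded at this setting, so it is worth checking tokenized lengths before training.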

Finally, we load the tokenizer and the model. The code below handles:
-   Loading the pre-trained tokenizer.
-   Loading the pre-trained base model.
-   Initializing our chosen task-specific model, `OmniModelForMultiLabelSequenceClassification`, which wraps the base model and adds the classification head.

In [None]:
# Load tokenizer - matches complete tutorial
print("🔄 Loading tokenizer...")
tokenizer = OmniTokenizer.from_pretrained(config["model_name_or_path"])
print(f"✅ Tokenizer loaded: {config['model_name_or_path']}")

# Load model - matches complete tutorial exactly
print("🔄 Loading model...")
model = OmniModelForMultiLabelSequenceClassification(
    config["model_name_or_path"],
    tokenizer,
    num_labels=config["num_labels"],
)

# Move model to the specified device
model.to(config["device"])

print("✅ Model and tokenizer loaded successfully.")
print(f"🎯 Model is on device: {next(model.parameters()).device}")

# Show model statistics
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"📊 Model parameters: {total_params:,} total, {trainable_params:,} trainable")
print(f"🧬 Ready for {config['num_labels']}-label TF binding prediction")

### Summary and Next Steps

Congratulations! You have successfully initialized a powerful Genomic Foundation Model tailored for your specific biological task.

To recap, we have:
-   Surveyed the different model architectures available in `OmniGenBench`.
-   Learned the principle of matching the model architecture to the ML task.
-   Understood the "base model + task head" design.
-   Initialized the `OmniModelForMultiLabelSequenceClassification` model and its tokenizer.

We now have the two key components ready: a prepared dataset and an initialized model. The next logical step is to bring them together and train the model to make accurate predictions.

In the next tutorial, **[3/4: Model Training](./03_model_training.ipynb)**, we will take this model and the data we prepared in the first tutorial and begin the fine-tuning process.