# 🧬 RNA SSP Tutorial 2/4: Model Initialization - From Task to Architecture

In the previous tutorial, [01_data_preparation.ipynb](01_data_preparation.ipynb), we defined our biological task (predicting RNA secondary structure) and prepared our data accordingly. We framed it as a **token classification** problem where each nucleotide receives a structural label.

This crucial step of defining the task and data format dictates our next decision: **choosing the right model architecture**.

> 📚 **Learning Objectives**: Master model initialization patterns, understand foundation model concepts, and leverage OmniGenBench's intelligent defaults

---

## The Power of Pre-trained Models 🚀

The core idea behind fine-tuning is to leverage **pre-trained foundation models**. These models have already learned the fundamental "language" of genomes from vast amounts of unlabeled sequence data. This pre-training gives them a general-purpose understanding of genomic patterns.

Our task is to take this general knowledge and specialize it for our specific problem: RNA secondary structure prediction.

### 🎯 Why Use Foundation Models?

| Traditional Methods | Foundation Model Approach |
|---------|------------|
| 🔧 Requires hand-crafted features | 🤖 Automatically learns representations |
| 📚 Relies on prior biological knowledge | 🔬 Discovers patterns from data |
| 🎯 Limited generalization ability | 🌐 Strong cross-task generalization |
| 📊 Needs large amounts of task-specific data | 💡 Can fine-tune with small datasets |

## Key Components: Model and Tokenizer

This tutorial will guide you through selecting and initializing OmniGenome for RNA structure prediction. We will cover:

1. **The OmniGenBench Model Zoo**: Available model architectures
2. **The Principle of Model Selection**: Matching models to tasks
3. **Model Architecture**: Understanding the "base + task head" design
4. **Inputs and Outputs**: What the model expects and produces
5. **Practical Implementation**: Initializing the model for our task

By the end of this tutorial, you will understand how to configure OmniGenome for token-level predictions.

### 1. The OmniGenBench Model Zoo

`OmniGenBench` provides model classes for a range of task types. Each class pairs a powerful pre-trained **base model** with a smaller, task-specific **head**, which is why these classes are often referred to by their "task heads."

Here is a summary of the main model classes available:

| Model Class | Task Type | RNA SSP Relevance |
| --- | --- | --- |
| `OmniModelForSequenceClassification` | Sequence Classification | Classifying entire RNA molecules |
| `OmniModelForMultiLabelSequenceClassification` | Multi-Label Classification | Multiple properties per sequence |
| **`OmniModelForTokenClassification`** | **Token Classification** | **Per-nucleotide structure prediction (our task)** |
| `OmniModelForSequenceRegression` | Sequence Regression | Predicting continuous values |
| `OmniModelForTokenRegression` | Token Regression | Per-position continuous predictions |
| `OmniModelForSeq2Seq` | Sequence-to-Sequence | Structure generation |

**OmniGenome** models (52M, 186M, 418M parameters) are particularly powerful for genomic tasks because they were pre-trained on diverse DNA/RNA sequences, making them excel at pattern recognition across different sequence contexts.
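
The table above can be read as a lookup from task type to model class. The sketch below illustrates that reading using only the class names listed in the table. The `select_model_class` helper and its task keys are hypothetical and are not part of the OmniGenBench API.

```python
# Hypothetical lookup: (prediction granularity, target type) -> class name.
# The class names come from the table above; the keys are illustrative only.
MODEL_CLASS_BY_TASK = {
    ("sequence", "classification"): "OmniModelForSequenceClassification",
    ("sequence", "multilabel"): "OmniModelForMultiLabelSequenceClassification",
    ("token", "classification"): "OmniModelForTokenClassification",
    ("sequence", "regression"): "OmniModelForSequenceRegression",
    ("token", "regression"): "OmniModelForTokenRegression",
    ("sequence", "seq2seq"): "OmniModelForSeq2Seq",
}


def select_model_class(granularity: str, target: str) -> str:
    """Return the model class name for a task, or raise if none is listed."""
    try:
        return MODEL_CLASS_BY_TASK[(granularity, target)]
    except KeyError:
        raise ValueError(
            f"No model class listed for {granularity}/{target}"
        ) from None


# RNA secondary structure: one categorical label per nucleotide.
print(select_model_class("token", "classification"))
```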

### 2. The Principle of Model Selection

The selection principle is straightforward: **match the model architecture to the machine learning task you defined**.

In our case:
- **Biological Problem**: Predicting whether each nucleotide in an RNA sequence is paired or unpaired and, if paired, whether it opens or closes the base pair
- **Data Format**: RNA sequence of variable length
- **Label Format**: A label for each position: `(`, `)`, or `.`
- **ML Task**: Since each position needs a label, this is **Token Classification**
- **Model Choice**: `OmniModelForTokenClassification` with **OmniGenome** base model
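
To make the label format concrete, here is a minimal sketch of how a dot-bracket string becomes one integer label per nucleotide. It uses the same `label2id` mapping that Step 2 defines. The `encode_structure` helper and the example structure are illustrative assumptions, not OmniGenBench functions or real folding data.

```python
# Same mapping as in Step 2 of this tutorial.
label2id = {"(": 0, ")": 1, ".": 2}


def encode_structure(sequence: str, structure: str) -> list[int]:
    """Convert a dot-bracket string into one label ID per nucleotide.

    Hypothetical helper for illustration only.
    """
    if len(sequence) != len(structure):
        raise ValueError("Sequence and structure must have the same length")
    unknown = set(structure) - set(label2id)
    if unknown:
        raise ValueError(f"Unsupported structure symbols: {sorted(unknown)}")
    return [label2id[symbol] for symbol in structure]


sequence = "AUGCCGUGC"
structure = ".((...))."  # illustrative structure, not a real prediction
labels = encode_structure(sequence, structure)
print(labels)  # [2, 0, 0, 2, 2, 2, 1, 1, 2]
```

Because every nucleotide gets exactly one label, the label list always has the same length as the sequence, which is what makes this a token classification problem.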

**Why OmniGenome for this task?**
- Pre-trained on genomic sequences (DNA/RNA)
- Understands sequence patterns and motifs
- Captures contextual relationships between positions
- Generalizes well across different RNA types

### 3. Model Architecture: Base Model + Task Head

Let's visualize the architecture. At its core, our model consists of two parts:

1. **OmniGenome Base Model**: A large, pre-trained transformer that reads RNA sequences (as tokens) and converts them into rich numerical representations (embeddings) that capture sequence patterns and contexts.

2. **The Token Classification Head**: A smaller neural network (usually one or two linear layers) that takes the per-token embeddings and transforms them into 3-class predictions (one for each structural label).

Here is a diagram illustrating this architecture:

```mermaid
graph TD
    subgraph "Input"
        A["RNA Sequence<br/>AUGCCGUGC"]
    end

    subgraph "Tokenization"
        B["Input Tokens<br/>[CLS], A, U, G, C, C, G, U, G, C, [SEP]"]
    end

    subgraph "OmniModelForTokenClassification"
        C("Base Model<br/>OmniGenome-52M")
        D("Classification Head<br/>Linear Layer + Softmax")
    end
    
    subgraph "Output"
        E["Per-Token Predictions<br/>[., (, (, ., ., ., ), ), .]"]
    end

    A --> B
    B --> C
    C -- "Token Embeddings" --> D
    D --> E
```

The base model does the heavy lifting of understanding the sequence, while the head adapts that understanding to our specific predictive goal. During fine-tuning, we update the weights of both the head and (to a lesser extent) the base model to optimize for structure prediction.
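
As a rough illustration of what the classification head computes, the sketch below applies one linear layer and a softmax to each token embedding, using only the Python standard library. The embeddings are random stand-ins for the base model's output, and the sizes are toy values chosen for readability. The real head is a PyTorch module inside `OmniModelForTokenClassification`, so treat this as a conceptual sketch rather than the library's implementation.

```python
import math
import random


def linear_head(embedding, weights, bias):
    """Project one token embedding to one logit per class."""
    return [
        sum(w * x for w, x in zip(row, embedding)) + b
        for row, b in zip(weights, bias)
    ]


def softmax(logits):
    """Turn logits into probabilities (shifted by the max for stability)."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


random.seed(0)
hidden_size, num_labels, seq_len = 4, 3, 5  # toy sizes, not the real model's

# Random stand-ins for the per-token embeddings produced by the base model.
token_embeddings = [
    [random.gauss(0, 1) for _ in range(hidden_size)] for _ in range(seq_len)
]
# Head parameters: one weight row and one bias per class.
weights = [
    [random.gauss(0, 0.1) for _ in range(hidden_size)] for _ in range(num_labels)
]
bias = [0.0] * num_labels

probabilities = [softmax(linear_head(h, weights, bias)) for h in token_embeddings]
predicted_ids = [max(range(num_labels), key=p.__getitem__) for p in probabilities]

print(len(probabilities), len(probabilities[0]))  # 5 3
```

The head is applied to each position independently, so the output has one probability vector per token.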

### 4. Inputs and Outputs: A Look at the Data Flow

Understanding what the model expects and returns is critical for debugging and interpretation.

#### Model Inputs
The model expects a **dictionary** with these keys:
- `input_ids`: Tokenized sequence (tensor of integers)
- `attention_mask`: Mask indicating real vs. padded tokens
- `labels` (during training): Ground-truth labels for each position
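
The sketch below shows how these three keys line up for a small padded batch. It uses plain Python lists instead of tensors, and a hypothetical toy vocabulary. The real token IDs come from the pre-trained tokenizer. It also assumes the common PyTorch convention of marking positions to skip in the loss with `-100`, the default `ignore_index` of `torch.nn.CrossEntropyLoss`. Whether OmniGenBench uses exactly this convention should be checked against its data-loading code.

```python
# Hypothetical toy vocabulary; real IDs come from the pre-trained tokenizer.
TOY_VOCAB = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "A": 3, "C": 4, "G": 5, "U": 6}
label2id = {"(": 0, ")": 1, ".": 2}
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_batch(sequences, structures):
    """Build padded input_ids, attention_mask, and labels as nested lists."""
    max_len = max(len(seq) for seq in sequences) + 2  # room for [CLS] and [SEP]
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for seq, struct in zip(sequences, structures):
        ids = [TOY_VOCAB["[CLS]"]] + [TOY_VOCAB[n] for n in seq] + [TOY_VOCAB["[SEP]"]]
        # Special tokens carry no structural label, so they are ignored in the loss.
        labels = [IGNORE_INDEX] + [label2id[s] for s in struct] + [IGNORE_INDEX]
        pad = max_len - len(ids)
        batch["input_ids"].append(ids + [TOY_VOCAB["[PAD]"]] * pad)
        batch["attention_mask"].append([1] * len(ids) + [0] * pad)
        batch["labels"].append(labels + [IGNORE_INDEX] * pad)
    return batch


batch = build_batch(["AUGCCGUGC", "GGAAACC"], [".((...)).", "((...))"])
for key, rows in batch.items():
    print(key, rows)
```

All three lists share the same shape, so position `i` of `labels` always refers to the token at position `i` of `input_ids`.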

#### Model Outputs
The model returns a **dictionary** with:
- `logits`: Raw prediction scores for each token and class (shape: [batch, seq_len, num_labels])
- `loss` (during training): Computed loss value
- `predictions` (during inference): Predicted class IDs

#### Example Flow
```python
# Input
sequence = "AUGCCGUGC"
# After tokenization: [CLS] A U G C C G U G C [SEP]
# Model output logits shape: [1, 11, 3]
# Predictions (one per token): [., ., (, (, ., ., ., ), ), ., .]
# (Ignore the predictions at the [CLS] and [SEP] positions)
```
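
The last step of this flow, going from logits to a structure string, can be sketched in plain Python. The `decode_structure` helper is hypothetical. It assumes that `[CLS]` is the first token, `[SEP]` is the last non-padding token, and padding is added on the right. The logits are made-up numbers for a three-nucleotide toy input, so the resulting string is not a realistic structure.

```python
id2label = {0: "(", 1: ")", 2: "."}


def decode_structure(logits, attention_mask):
    """Pick the best class per token, dropping [CLS], [SEP], and padding.

    Hypothetical helper; assumes right padding and a [CLS] ... [SEP] layout.
    """
    real_len = sum(attention_mask)         # number of non-padding tokens
    token_logits = logits[1:real_len - 1]  # strip [CLS] and [SEP]
    return "".join(
        id2label[max(range(len(row)), key=row.__getitem__)]
        for row in token_logits
    )


# Made-up logits for the tokens [CLS] A U G [SEP] [PAD]; shape [6, 3].
logits = [
    [0.0, 0.0, 0.1],  # [CLS]
    [2.0, 0.1, 0.3],  # A -> "("
    [0.2, 0.1, 1.5],  # U -> "."
    [0.1, 1.9, 0.2],  # G -> ")"
    [0.0, 0.0, 0.1],  # [SEP]
    [0.0, 0.0, 0.0],  # [PAD]
]
attention_mask = [1, 1, 1, 1, 1, 0]

print(decode_structure(logits, attention_mask))  # (.)
```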

## 🛠️ Practical Implementation

Now let's put theory into practice. We'll initialize the model in just a few lines of code.

### Step 1: Environment Setup and Imports

```python
# Install if needed
# !pip install omnigenbench -U
```

```python
import torch
from omnigenbench import (
    OmniTokenizer,
    OmniModelForTokenClassification,
)

print("✅ Libraries imported successfully!")
print(f"🔥 PyTorch version: {torch.__version__}")
print(f"🎯 CUDA available: {torch.cuda.is_available()}")
```

### Step 2: Configuration

Define our model and label configuration.

```python
# Model configuration
model_name_or_path = "yangheng/OmniGenome-52M"

# Label mapping for RNA secondary structure
label2id = {
    "(": 0,  # Opening base pair
    ")": 1,  # Closing base pair
    ".": 2   # Unpaired nucleotide
}
id2label = {v: k for k, v in label2id.items()}
num_labels = len(label2id)

print("✅ Configuration complete!")
print(f"📊 Model: {model_name_or_path}")
print(f"📊 Number of labels: {num_labels}")
print(f"📊 Label mapping: {label2id}")
```

### Step 3: Initialize Tokenizer

The tokenizer converts sequences into model inputs.

```python
# Load tokenizer matching the model
tokenizer = OmniTokenizer.from_pretrained(model_name_or_path)

print(f"✅ Tokenizer loaded: {model_name_or_path}")
print(f"📊 Vocabulary size: {tokenizer.vocab_size}")
print(f"📊 Special tokens: {tokenizer.special_tokens_map}")
```

### Step 4: Initialize Model

Now we create the model for token classification. This is remarkably simple with OmniGenBench!

```python
# Initialize model for token classification
model = OmniModelForTokenClassification(
    model_name_or_path,
    tokenizer=tokenizer,
    label2id=label2id,
    id2label=id2label,
)

print(f"✅ Model initialized: {model_name_or_path}")
print("\n📊 Model Configuration:")
print("  - Architecture: Token-level classification")
print("  - Base model: OmniGenome-52M")
print(f"  - Number of labels: {num_labels}")
print(f"  - Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"  - Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
```

### Step 5: Test Model with Sample Input

Let's verify the model works by testing it on a sample sequence.

```python
# Test with a sample sequence
sample_sequence = "AUGCCGUGCAUUAA"

# Tokenize
inputs = tokenizer(
    sample_sequence,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=512,
)

print(f"📝 Sample sequence: {sample_sequence}")
print(f"📊 Tokenized input shape: {inputs['input_ids'].shape}")
print(f"📊 Input IDs: {inputs['input_ids']}")

# Forward pass (no gradients needed for testing)
with torch.no_grad():
    outputs = model(**inputs)

print("\n📊 Model Outputs:")
print(f"  - Logits shape: {outputs['logits'].shape}")
print(f"  - Logits: {outputs['logits']}")

# Get predictions: take the highest-scoring class at each position
predictions = torch.argmax(outputs['logits'], dim=-1)[0]
predicted_structure = "".join([id2label[pred.item()] for pred in predictions[1:-1]])  # Skip [CLS] and [SEP]

print("\n🔮 Predicted Structure:")
print(f"  Sequence:  {sample_sequence}")
print(f"  Structure: {predicted_structure}")
print("\n💡 Note: The model has not been fine-tuned yet, so these predictions are not meaningful. Training should improve them.")
```
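
Because each position is classified independently, nothing forces the predicted string to be a well-formed dot-bracket structure, especially before fine-tuning. A small check like the one below can flag unbalanced output. The `pair_table` and `is_valid_structure` helpers are hypothetical utilities written for this tutorial, not OmniGenBench functions.

```python
def pair_table(structure: str) -> dict[int, int]:
    """Map each opening position to its closing position.

    Raises ValueError if the dot-bracket string is not well formed.
    """
    stack, pairs = [], {}
    for i, symbol in enumerate(structure):
        if symbol == "(":
            stack.append(i)
        elif symbol == ")":
            if not stack:
                raise ValueError(f"Unmatched ')' at position {i}")
            pairs[stack.pop()] = i
        elif symbol != ".":
            raise ValueError(f"Unexpected symbol {symbol!r} at position {i}")
    if stack:
        raise ValueError(f"Unmatched '(' at position {stack[-1]}")
    return pairs


def is_valid_structure(structure: str) -> bool:
    """Return True if every '(' has a matching ')'."""
    try:
        pair_table(structure)
    except ValueError:
        return False
    return True


print(pair_table(".((...))."))        # {2: 6, 1: 7}
print(is_valid_structure("(((..))"))  # False
```

Running `is_valid_structure(predicted_structure)` on the output above gives a quick signal of whether the prediction is even syntactically valid.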

### Understanding Model Components

Let's examine the model architecture:

```python
# Inspect model components
print("🏗️ Model Architecture:")
print(f"\nBase Model: {type(model.model).__name__}")
print(f"Classification Head: {type(model.classifier).__name__}")

# Show classification head details
print("\n📊 Classification Head Structure:")
print(model.classifier)

print("\n💡 The classification head transforms the base model's embeddings into class predictions.")
```

## 📚 Summary and Next Steps

In this tutorial, we:
1. ✅ Understood the concept of foundation models and pre-training
2. ✅ Explored the OmniGenBench model zoo
3. ✅ Learned the principle of matching models to tasks
4. ✅ Understood the "base model + task head" architecture
5. ✅ Initialized OmniGenome for token classification
6. ✅ Tested the model with a sample sequence

### What We've Accomplished
```python
# Model initialization in just two statements!
tokenizer = OmniTokenizer.from_pretrained(model_name_or_path)
model = OmniModelForTokenClassification(
    model_name_or_path,
    tokenizer=tokenizer,
    label2id=label2id,
)
```

### Key Takeaways
- **Foundation models** provide powerful pre-trained representations
- **Task heads** adapt base models to specific problems
- **Token classification** predicts a label for each position
- **OmniGenBench** makes model initialization effortless

### Next: Model Training
Now that our model is initialized, proceed to **[03_model_training.ipynb](03_model_training.ipynb)** to:
- Fine-tune the model on RNA structure data
- Configure training parameters
- Evaluate model performance
- Save the trained model

The model is ready to learn! 🚀