# Custom Fine‑Tuning Composition for Regression Tasks

This notebook shows how to fine-tune OmniGenome models for regression tasks, using either the built-in model and dataset classes or custom ones.

## Task Formulation

Regression tasks in OmniGenome can be:
- **Sequence Regression**: Predict a continuous value for the entire sequence
- **Token Regression**: Predict continuous values for each token (e.g., nucleotide) in the sequence
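
To make the difference concrete, the sketch below shows the label shapes the two formulations imply for a single toy RNA sequence. The sequence, the number of targets, and the variable names are illustrative assumptions and not part of the OmniGenome API.

```python
import numpy as np

sequence = "CAGUGCCGAG"  # toy RNA sequence (illustrative)
num_targets = 3          # assumed number of per-nucleotide measurements

# Sequence regression: one target vector for the whole sequence
sequence_label = np.zeros(1, dtype=np.float32)

# Token regression: one target vector per nucleotide
token_labels = np.zeros((len(sequence), num_targets), dtype=np.float32)

print(sequence_label.shape)  # (1,)
print(token_labels.shape)    # (10, 3)
```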

### Simple Model APIs
```
from omnigenbench import (
    OmniModelForSequenceRegression,
    OmniModelForTokenRegression,
    OmniModelForMatrixRegression,
    OmniModelForStructuralImputation,
)
```

### Simple Dataset APIs
```
from omnigenbench import (
    OmniDatasetForSequenceRegression,
    OmniDatasetForTokenRegression,
)
```

Alternatively, you can define your own model and dataset classes by inheriting from the base classes that OmniGenome provides, as demonstrated in the following sections.


## Foundation Model Preparation

In [None]:
# Install required packages
!pip install -q omnigenbench torch datasets scikit-learn matplotlib seaborn

In [None]:
# In this notebook, we will introduce the custom finetuning of OmniGenome models for regression tasks.
model_name = "yangheng/OmniGenome-52M" # 52M parameters

# 1. Load the model and tokenizer according to model_name for later use
from transformers import AutoTokenizer, AutoModel
# trust_remote_code=True allows transformers to run the custom model code shipped in the model repository, which OmniGenome models require
base_model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
base_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# The OmniGenome-52M tokenizer can be initialized with AutoTokenizer; for other models, you can define a wrapper that makes the tokenizer compatible with the transformers tokenizer API

## Define a Model with a Regression Head for Downstream Task

In [None]:
# For a sequence regression task (predicts a single value for the entire sequence)
from omnigenbench import OmniModelForSequenceRegression
model = OmniModelForSequenceRegression(
    config_or_model=model_name,
    tokenizer=base_tokenizer,  # Use the base tokenizer or the custom tokenizer defined above
    num_labels=1,  # Single regression value
)

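# For a token regression task (predicts values for each token in the sequence)
# Note: this reassigns `model`, so the token regression model is the one used in the rest of the notebook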
from omnigenbench import OmniModelForTokenRegression
model = OmniModelForTokenRegression(
    config_or_model=model_name,
    tokenizer=base_tokenizer,  # Use the base tokenizer or the custom tokenizer defined above
    num_labels=3,  # Multiple regression targets (e.g., reactivity, deg_Mg_pH10, deg_Mg_50C)
)

## (Optional) Define a Custom Model for Downstream Task

In [None]:
from omnigenbench import OmniModel, OmniPooling
import torch

class OmniModelForSequenceRegression(OmniModel):
    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = self.__class__.__name__
        self.pooler = OmniPooling(self.config)
        self.regressor = torch.nn.Linear(
            self.config.hidden_size, self.config.num_labels
        )
        self.loss_fn = torch.nn.MSELoss()
        # self.model_info()

    def forward(self, **inputs):
        labels = inputs.pop("labels", None)
        last_hidden_state = self.last_hidden_state_forward(**inputs)
        last_hidden_state = self.dropout(last_hidden_state)
        last_hidden_state = self.activation(last_hidden_state)
        last_hidden_state = self.pooler(inputs, last_hidden_state)
        predictions = self.regressor(last_hidden_state)
        outputs = {
            "logits": predictions,  # For regression, logits are the predictions
            "last_hidden_state": last_hidden_state,
            "labels": labels,
        }
        return outputs

    def predict(self, sequence_or_inputs, **kwargs):
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)

        predictions = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        outputs = {
            "predictions": predictions,
            "logits": predictions,
            "last_hidden_state": last_hidden_state,
        }

        return outputs

    def inference(self, sequence_or_inputs, **kwargs):
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)

        predictions = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        if not isinstance(sequence_or_inputs, list):
            outputs = {
                "predictions": predictions[0],
                "logits": predictions[0],
                "last_hidden_state": last_hidden_state[0],
            }
        else:
            outputs = {
                "predictions": predictions,
                "logits": predictions,
                "last_hidden_state": last_hidden_state,
            }

        return outputs

    def loss_function(self, logits, labels):
        loss = self.loss_fn(logits.view(-1, self.config.num_labels), labels.view(-1, self.config.num_labels))
        return loss


class OmniModelForTokenRegression(OmniModel):
    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = self.__class__.__name__
        self.regressor = torch.nn.Linear(
            self.config.hidden_size, self.config.num_labels
        )
        self.loss_fn = torch.nn.MSELoss()
        # self.model_info()

    def forward(self, **inputs):
        labels = inputs.pop("labels", None)
        last_hidden_state = self.last_hidden_state_forward(**inputs)
        last_hidden_state = self.dropout(last_hidden_state)
        last_hidden_state = self.activation(last_hidden_state)
        predictions = self.regressor(last_hidden_state)
        outputs = {
            "logits": predictions,  # For regression, logits are the predictions
            "last_hidden_state": last_hidden_state,
            "labels": labels,
        }
        return outputs

    def predict(self, sequence_or_inputs, **kwargs):
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)

        predictions = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        outputs = {
            "predictions": predictions,
            "logits": predictions,
            "last_hidden_state": last_hidden_state,
        }

        return outputs

    def inference(self, sequence_or_inputs, **kwargs):
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)

        inputs = raw_outputs["inputs"]
        predictions = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        # Filter out padding tokens for token-level predictions
        filtered_predictions = []
        for i in range(predictions.shape[0]):
            i_pred = predictions[i][inputs["input_ids"][i].ne(self.config.pad_token_id)][1:-1]
            filtered_predictions.append(i_pred.detach().cpu())

        if not isinstance(sequence_or_inputs, list):
            outputs = {
                "predictions": filtered_predictions[0],
                "logits": predictions[0],
                "last_hidden_state": last_hidden_state[0],
            }
        else:
            outputs = {
                "predictions": filtered_predictions,
                "logits": predictions,
                "last_hidden_state": last_hidden_state,
            }

        return outputs

    def loss_function(self, logits, labels):
        padding_value = (
            self.config.ignore_y if hasattr(self.config, "ignore_y") else -100
        )
        logits = logits.view(-1, self.config.num_labels)
        labels = labels.view(-1, self.config.num_labels)
        mask = torch.where(labels != padding_value)[0]

        filtered_logits = logits[mask]
        filtered_targets = labels[mask]

        loss = self.loss_fn(filtered_logits, filtered_targets)
        return loss
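
The token-level loss above skips positions whose label equals the padding value, so padded tokens do not contribute to the gradient. The NumPy sketch below illustrates the same idea with an element-wise mask, a slight simplification of the row-based selection in `loss_function`; the two agree when each position is either fully labeled or fully padded. The helper name `masked_mse` and the toy arrays are hypothetical.

```python
import numpy as np

IGNORE = -100  # padding value assumed for label positions without a target

def masked_mse(predictions, labels, ignore_value=IGNORE):
    """Mean squared error over label entries that are not `ignore_value`."""
    predictions = np.asarray(predictions, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.float64)
    keep = labels != ignore_value
    if not keep.any():
        return 0.0  # nothing to score
    diff = predictions[keep] - labels[keep]
    return float(np.mean(diff ** 2))

# Two token positions with real targets, one padded position
labels = np.array([[1.0, 2.0], [3.0, 4.0], [IGNORE, IGNORE]])
predictions = np.array([[1.0, 1.0], [3.0, 2.0], [9.0, 9.0]])
loss = masked_mse(predictions, labels)
print(loss)  # 1.25
```

The padded row's predictions (9.0, 9.0) have no effect on the result, which is the behavior the mask is meant to guarantee.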

## Define a Dataset for Downstream Task

In [None]:
# For a sequence regression task
from omnigenbench import OmniDatasetForSequenceRegression
# For a token regression task
from omnigenbench import OmniDatasetForTokenRegression

## (Optional) Define a Custom Dataset for Downstream Task

To define a custom dataset for a token regression task, inherit from `OmniDatasetForTokenRegression` and implement the `prepare_input` method to process the input data.
Make sure `prepare_input` returns tokenized inputs and labels in a format compatible with the tokenizer API.
```
import numpy as np
import torch
from omnigenbench import OmniDatasetForTokenRegression


class Dataset(OmniDatasetForTokenRegression):
    def __init__(self, data_source, tokenizer, max_length, **kwargs):
        super().__init__(data_source, tokenizer, max_length, **kwargs)

    def prepare_input(self, instance, **kwargs):
        target_cols = ["reactivity", "deg_Mg_pH10", "deg_Mg_50C"]
        instance["sequence"] = f'{instance["sequence"]}'
        tokenized_inputs = self.tokenizer(
            instance["sequence"],
            padding=kwargs.get("padding", "do_not_pad"),
            truncation=kwargs.get("truncation", True),
            max_length=self.max_length,
            return_tensors="pt",
        )
        # One row of per-nucleotide labels per target column: shape (3, n_labels)
        labels = np.array([instance[target_col] for target_col in target_cols])
        # Right-pad every row with -100 up to the tokenized length
        pad_len = len(tokenized_inputs["input_ids"].squeeze()) - labels.shape[1]
        padding = np.full((len(target_cols), max(pad_len, 0)), -100)
        # Transpose to shape (n_tokens, 3) so each token has one label vector
        labels = np.concatenate([labels, padding], axis=1).T
        tokenized_inputs["labels"] = torch.tensor(labels, dtype=torch.float32)
        for col in tokenized_inputs:
            tokenized_inputs[col] = tokenized_inputs[col].squeeze()
        return tokenized_inputs
```
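
The label handling in `prepare_input` can be checked in isolation. The sketch below is a standalone NumPy version of the padding step, assuming labels are right-padded with -100 up to the tokenized length; `pad_token_labels` and the toy values are hypothetical names and data, not part of the library.

```python
import numpy as np

def pad_token_labels(label_columns, num_tokens, ignore_value=-100):
    """Stack per-target label lists and right-pad them to `num_tokens`.

    Returns an array of shape (num_tokens, num_targets).
    """
    labels = np.asarray(label_columns, dtype=np.float32)  # (num_targets, n)
    pad_len = max(num_tokens - labels.shape[1], 0)
    padding = np.full((labels.shape[0], pad_len), ignore_value, dtype=np.float32)
    return np.concatenate([labels, padding], axis=1).T

# Hypothetical instance: 4 labeled nucleotides, 3 targets, 6 tokens after tokenization
columns = [[0.1, 0.2, 0.3, 0.4], [1.0, 1.1, 1.2, 1.3], [2.0, 2.1, 2.2, 2.3]]
padded = pad_token_labels(columns, num_tokens=6)
print(padded.shape)  # (6, 3)
```

The last two rows of `padded` are filled with -100, which is the value the loss and the metrics are configured to ignore.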

## Load the Dataset from a Local Path

In [None]:
dataset_path = "toy_datasets/RNA-mRNA/"  # Path to your dataset files
# RNA-mRNA is an mRNA degradation regression dataset (token regression), containing train.json, test.json, and valid.json files
train_file = dataset_path + "train.json"
test_file = dataset_path + "test.json"
valid_file = dataset_path + "valid.json"

# For token regression (predicting values for each token)
train_set = OmniDatasetForTokenRegression(
    data_source=train_file,
    tokenizer=base_tokenizer,  # Use the base tokenizer or the custom tokenizer defined above
    max_length=128,  # Set the maximum sequence length
)
test_set = OmniDatasetForTokenRegression(
    data_source=test_file,
    tokenizer=base_tokenizer,  # Use the base tokenizer or the custom tokenizer defined above
    max_length=128,  # Set the maximum sequence length
)
valid_set = OmniDatasetForTokenRegression(
    data_source=valid_file,
    tokenizer=base_tokenizer,  # Use the base tokenizer or the custom tokenizer defined above
    max_length=128,  # Set the maximum sequence length
)

# For sequence regression (predicting a single value for the entire sequence)
# train_set = OmniDatasetForSequenceRegression(
#     data_source=train_file,
#     tokenizer=base_tokenizer,
#     max_length=128,
# )
# test_set = OmniDatasetForSequenceRegression(
#     data_source=test_file,
#     tokenizer=base_tokenizer,
#     max_length=128,
# )
# valid_set = OmniDatasetForSequenceRegression(
#     data_source=valid_file,
#     tokenizer=base_tokenizer,
#     max_length=128,
# )

## Training Implementation

In [None]:
from omnigenbench import RegressionMetric  # contains all metrics from sklearn.metrics and some custom metrics for regression tasks
from omnigenbench import Trainer
import torch

# necessary hyperparameters
epochs = 10
learning_rate = 2e-5
weight_decay = 1e-5
batch_size = 8
max_length = 128
seeds = [45]  # Each seed will be used for one run

# Regression metrics
compute_metrics = [
    RegressionMetric(ignore_y=-100).mean_squared_error,  # MSE
    RegressionMetric(ignore_y=-100).mean_absolute_error,  # MAE
    RegressionMetric(ignore_y=-100).r2_score,  # R² score
    RegressionMetric(ignore_y=-100).mcrmse,  # Mean Columnwise Root Mean Squared Error (custom metric)
]

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True
)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size)

for seed in seeds:
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        eval_loader=valid_loader,
        test_loader=test_loader,
        batch_size=batch_size,
        epochs=epochs,
        optimizer=optimizer,
        compute_metrics=compute_metrics,
        seeds=seed,
    )

    metrics = trainer.train()
    test_metrics = metrics["test"][-1]
    print(metrics)
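
Of the metrics listed above, MCRMSE is the least standard: it is the root mean squared error computed per target column and then averaged across columns. The NumPy sketch below follows that textbook definition and skips positions labeled -100; it is an illustration only and may differ in detail from `RegressionMetric.mcrmse`.

```python
import numpy as np

def mcrmse(y_true, y_pred, ignore_value=-100):
    """Mean columnwise RMSE over arrays of shape (n_positions, n_targets)."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    column_rmse = []
    for col in range(y_true.shape[1]):
        keep = y_true[:, col] != ignore_value
        if not keep.any():
            continue  # skip targets with no labeled positions
        err = y_pred[keep, col] - y_true[keep, col]
        column_rmse.append(np.sqrt(np.mean(err ** 2)))
    return float(np.mean(column_rmse))

# Two labeled positions and one padded position, two targets
y_true = np.array([[0.0, 1.0], [0.0, 1.0], [-100, -100]])
y_pred = np.array([[3.0, 1.0], [4.0, 1.0], [0.0, 0.0]])
score = mcrmse(y_true, y_pred)
print(score)
```

Here the first column has an RMSE of sqrt(12.5) and the second is predicted exactly, so the score is half of sqrt(12.5).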

## Evaluation
After training, we evaluate the model on the validation set with regression metrics.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# Get predictions
model.eval()
all_predictions = []
all_labels = []

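# Assumption: the model may sit on a GPU after training, so move each batch to its device
device = next(model.parameters()).device

with torch.no_grad():
    for batch in valid_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        predictions = outputs['logits']
        labels = batch['labels']

        # Filter out padding tokens
        mask = labels != -100
        filtered_preds = predictions[mask]
        filtered_labels = labels[mask]

        all_predictions.append(filtered_preds.cpu().numpy())
        all_labels.append(filtered_labels.cpu().numpy())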

# Concatenate all predictions and labels
all_predictions = np.concatenate(all_predictions, axis=0)
all_labels = np.concatenate(all_labels, axis=0)

# Calculate metrics
mse = mean_squared_error(all_labels, all_predictions)
mae = mean_absolute_error(all_labels, all_predictions)
r2 = r2_score(all_labels, all_predictions)

print(f"Mean Squared Error: {mse:.4f}")
print(f"Mean Absolute Error: {mae:.4f}")
print(f"R² Score: {r2:.4f}")

# Plot predictions vs actual values
plt.figure(figsize=(10, 6))
plt.scatter(all_labels.flatten(), all_predictions.flatten(), alpha=0.5)
plt.plot([all_labels.min(), all_labels.max()], [all_labels.min(), all_labels.max()], 'r--', lw=2)
plt.xlabel('Actual Values')
plt.ylabel('Predicted Values')
plt.title('Predicted vs Actual Values')
plt.grid(True, alpha=0.3)
plt.show()

# Plot residuals
residuals = all_predictions.flatten() - all_labels.flatten()
plt.figure(figsize=(10, 6))
plt.scatter(all_predictions.flatten(), residuals, alpha=0.5)
plt.axhline(y=0, color='r', linestyle='--')
plt.xlabel('Predicted Values')
plt.ylabel('Residuals')
plt.title('Residual Plot')
plt.grid(True, alpha=0.3)
plt.show()
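
The evaluation above pools all targets into one flat array, which can hide a target that is predicted poorly. A possible complement is to score each target column separately, as in the NumPy sketch below. `per_target_r2` is a hypothetical helper, and the toy arrays stand in for unflattened labels and predictions of shape (batch, tokens, targets).

```python
import numpy as np

def per_target_r2(labels, predictions, ignore_value=-100):
    """R² for each target column, skipping positions equal to `ignore_value`."""
    labels = np.asarray(labels, dtype=np.float64)
    predictions = np.asarray(predictions, dtype=np.float64)
    scores = []
    for col in range(labels.shape[-1]):
        y = labels[..., col].ravel()
        p = predictions[..., col].ravel()
        keep = y != ignore_value
        y, p = y[keep], p[keep]
        if y.size == 0:
            scores.append(float("nan"))  # no labeled positions for this target
            continue
        ss_res = np.sum((y - p) ** 2)
        ss_tot = np.sum((y - y.mean()) ** 2)
        scores.append(float(1.0 - ss_res / ss_tot) if ss_tot > 0 else float("nan"))
    return scores

# One sequence, four token positions (the last is padding), two targets
toy_labels = np.array([[[1.0, 0.0], [2.0, 1.0], [3.0, 2.0], [-100, -100]]])
toy_predictions = np.array([[[1.0, 2.0], [2.0, 1.0], [3.0, 0.0], [5.0, 5.0]]])
r2_per_target = per_target_r2(toy_labels, toy_predictions)
print(r2_per_target)
```

In this toy case the first target is predicted exactly while the second is anti-correlated with its labels, a difference that a single pooled score would blur.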

## Model Saving, Loading, and Inference

In [None]:
path_to_save = "OmniGenome-52M-Regression"
model.save(path_to_save, overwrite=True)

# Load the model checkpoint
model = model.load(path_to_save)
results = model.inference("CAGUGCCGAGGCCACGCGGAGAACGAUCGAGGGUACAGCACUA")
print("Predictions:", results["predictions"])
print("Logits:", results["logits"])

# We can load the model checkpoint using the ModelHub
from omnigenbench import ModelHub

# Example: Load a pre-trained regression model
# regression_model = ModelHub.load("OmniGenome-186M-Regression")
# results = regression_model.inference("CAGUGCCGAGGCCACGCGGAGAACGAUCGAGGGUACAGCACUA")
# print("Predictions:", results["predictions"])
# print("Logits:", results["logits"])

## Model Prediction Explanation

For regression tasks, the model outputs continuous values:

- **Sequence Regression**: Returns a single continuous value for the entire sequence
- **Token Regression**: Returns continuous values for each token in the sequence
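
For token regression, the raw output has one row per token, including special and padding tokens, so it is longer than the input sequence. The NumPy sketch below mirrors the filtering step in the custom `inference` method shown earlier: drop padding positions, then drop the first and last remaining tokens. The token ids and the helper name are hypothetical, and the sketch assumes exactly one leading and one trailing special token.

```python
import numpy as np

def strip_special_and_padding(token_predictions, input_ids, pad_token_id):
    """Keep predictions for real nucleotides only.

    Assumes one leading and one trailing special token around the sequence,
    with padding (if any) appended on the right.
    """
    token_predictions = np.asarray(token_predictions)
    input_ids = np.asarray(input_ids)
    non_pad = input_ids != pad_token_id
    return token_predictions[non_pad][1:-1]

# Hypothetical ids: 0 = start token, 2 = end token, 1 = padding, others = nucleotides
input_ids = [0, 5, 6, 7, 2, 1, 1]
token_predictions = np.arange(7 * 3, dtype=np.float32).reshape(7, 3)
per_nucleotide = strip_special_and_padding(token_predictions, input_ids, pad_token_id=1)
print(per_nucleotide.shape)  # (3, 3)
```

After this step, row `i` of the result corresponds to nucleotide `i` of the input sequence, provided the tokenizer emits one token per nucleotide.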

The predictions can be interpreted based on your specific task:
- mRNA degradation rates
- Protein expression levels
- Binding affinities
- Structural properties

### Example Use Cases:
1. **mRNA Degradation Prediction**: Predict degradation rates for each nucleotide
2. **Protein Expression Prediction**: Predict expression levels from promoter sequences
3. **Binding Affinity Prediction**: Predict binding strengths between molecules
4. **Structural Property Prediction**: Predict structural characteristics of biomolecules