## Foundation Model Preparation

In [None]:
# Install required packages
!pip install -q omnigenbench torch datasets scikit-learn matplotlib seaborn

In [None]:
# This notebook walks through custom fine-tuning of OmniGenome models.
from transformers import AutoModel, AutoTokenizer

# Checkpoint used throughout the notebook (52M parameters)
model_name = "yangheng/OmniGenome-52M"

# 1. Load the tokenizer and the model for `model_name` for later use.
# trust_remote_code=True is used to load the model from the remote repository,
# which is necessary for OmniGenome models.
base_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
base_model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
# The OmniGenome-52M tokenizer can be initialized directly with AutoTokenizer.
# For other models, you can define a wrapper so that the tokenizer is
# compatible with the transformers tokenizer APIs.

## (Optional) Wrap a tokenizer to accommodate transformers APIs
### This is an optional step; you can skip it if you are using the default tokenizer provided by OmniGenome. The core idea is to create a tokenizer that uses the `__call__` method to process the input sequence and return the tokenized inputs in a format compatible with the transformers tokenizer APIs.

```
import itertools
import warnings
import torch
from ViennaRNA import ViennaRNA
from transformers import AutoTokenizer
from omnigenbench import OmniSingleNucleotideTokenizer
from omnigenbench import RNA2StructureCache


class Tokenizer(OmniSingleNucleotideTokenizer):
    """Tokenizer wrapper that augments base tokenization with k-mer and structure ids.

    Calling an instance tokenizes a single sequence with the wrapped
    ``base_tokenizer`` and adds two entries to the result:

    * ``kmer_ids`` -- sliding-window 3-mer indices derived from ``input_ids``;
    * ``str_ids``  -- ``input_ids`` of the structure string returned by
      ``RNA2StructureCache.fold`` for the same sequence.
    """

    def __init__(self, base_tokenizer=None, **kwargs):
        """Wrap ``base_tokenizer`` and create the structure-folding cache."""
        super(Tokenizer, self).__init__(base_tokenizer, **kwargs)
        self.metadata["tokenizer_name"] = self.__class__.__name__
        self.rna2str = RNA2StructureCache()

    # Token ids that are combined into 3-mers.
    # NOTE(review): presumably the ids of the four nucleotides in the base
    # vocabulary -- confirm against the tokenizer's vocab.
    bases = [4, 5, 6, 7]
    triplet_combinations = list(itertools.product(bases, repeat=3))
    # Maps each of the 4**3 = 64 triplets to an index in 0..63; index 64 is
    # therefore free and is used below for boundaries and unknown triplets.
    kmer_to_index = {tuple(triplet): i for i, triplet in enumerate(triplet_combinations)}

    def process_input_ids(self, input_ids, k=3):
        """Convert a 1-D tensor of token ids into sliding-window k-mer indices.

        The result starts and ends with the index 64, and any window that is
        not a key of ``kmer_to_index`` also maps to 64. With ``k=3`` the output
        has the same length as ``input_ids``.

        Note that ``kmer_to_index`` only contains triplets, so any ``k`` other
        than 3 maps every window to 64.
        """
        kmer_input_ids = [64]
        for i in range(len(input_ids) - k + 1):
            kmer = tuple(input_ids[i:i + k].tolist())
            kmer_input_ids.append(self.kmer_to_index.get(kmer, 64))
        kmer_input_ids.append(64)
        return torch.tensor(kmer_input_ids)

    def __call__(self, sequence, **kwargs):
        """Tokenize one sequence string and attach ``kmer_ids`` and ``str_ids``.

        ``kwargs`` are forwarded unchanged to the base tokenizer.
        NOTE(review): ``process_input_ids`` calls ``.tolist()`` on slices of
        ``input_ids[0]``, so callers presumably need ``return_tensors="pt"``
        -- confirm.
        """
        sequence = sequence.replace("U", "T")
        # The folding energy is returned alongside the structure but is not used here.
        structure, mfe = self.rna2str.fold(sequence, return_mfe=True)
        structure_inputs = self.base_tokenizer(structure, **kwargs)
        tokenized_inputs = self.base_tokenizer(sequence, **kwargs)
        kmer_ids = self.process_input_ids(tokenized_inputs['input_ids'][0], k=3)
        tokenized_inputs["kmer_ids"] = kmer_ids.unsqueeze(0)
        tokenized_inputs["str_ids"] = structure_inputs["input_ids"]
        return tokenized_inputs

    @staticmethod
    def from_pretrained(model_name_or_path, **kwargs):
        """Build a ``Tokenizer`` around the pretrained tokenizer at ``model_name_or_path``.

        ``kwargs`` are forwarded to ``AutoTokenizer.from_pretrained``.
        """
        # Return this wrapper class (not the plain parent class) so the
        # k-mer / structure logic in __call__ is actually available on the result.
        return Tokenizer(
            AutoTokenizer.from_pretrained(model_name_or_path, **kwargs)
        )

    def tokenize(self, sequence, **kwargs):
        """Split one sequence, or each sequence in a collection, into chunks.

        Chunks have length ``self.k`` and start every ``self.k - self.overlap``
        characters. Always returns a list of token lists, one per sequence.
        NOTE(review): ``self.k`` and ``self.overlap`` are not set in this
        class; they are expected to come from the base class -- confirm.
        """
        sequences = [sequence] if isinstance(sequence, str) else sequence
        return [
            [seq[j: j + self.k] for j in range(0, len(seq), self.k - self.overlap)]
            for seq in sequences
        ]

    def encode(self, input_ids, **kwargs):
        """Delegate to ``base_tokenizer.encode``."""
        return self.base_tokenizer.encode(input_ids, **kwargs)

    def decode(self, input_ids, **kwargs):
        """Delegate to ``base_tokenizer.decode``."""
        return self.base_tokenizer.decode(input_ids, **kwargs)

    def encode_plus(self, sequence, **kwargs):
        """Not supported by this wrapper; always raises ``NotImplementedError``."""
        raise NotImplementedError("The encode_plus() function is not implemented yet.")
```


## Define a Model with a Classification Head for Downstream Task

In [None]:
# For a sequence classification task, use OmniModelForSequenceClassification
from omnigenbench import OmniModelForSequenceClassification
model = OmniModelForSequenceClassification(
    config_or_model=model_name,
    tokenizer=base_tokenizer,  # Use the base tokenizer or the custom tokenizer defined above
    num_labels=3,  # Three output classes
)
# For a token classification task, you can use OmniModelForTokenClassification.
# NOTE: this second assignment replaces the `model` created above, so the
# token-classification model is the one left in `model` after this cell runs.
from omnigenbench import OmniModelForTokenClassification
model = OmniModelForTokenClassification(
    config_or_model=model_name,
    tokenizer=base_tokenizer,  # Use the base tokenizer or the custom tokenizer defined above
    num_labels=3,  # Three output classes per token
)


## (Optional) Define a Custom Model for Downstream Task


In [None]:
from omnigenbench import OmniModel, OmniPooling
import torch

class OmniModelForSequenceClassification(OmniModel):
    """Sequence classifier: pooled hidden states -> linear head -> softmax.

    ``forward`` returns class *probabilities* under the ``"logits"`` key (the
    softmax is applied inside the model), and ``predict`` / ``inference``
    build on that output.
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        """Initialise the base model and attach pooling, head and loss modules."""
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = self.__class__.__name__
        self.pooler = OmniPooling(self.config)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.classifier = torch.nn.Linear(
            self.config.hidden_size, self.config.num_labels
        )
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, **inputs):
        """Run the model on tokenized inputs.

        Returns a dict with:
            ``logits``            -- softmax-normalised class scores,
            ``last_hidden_state`` -- the pooled representation fed to the head,
            ``labels``            -- the ``labels`` entry popped from ``inputs`` (or None).
        """
        labels = inputs.pop("labels", None)
        # last_hidden_state_forward / dropout / activation are not defined in
        # this class; presumably inherited from OmniModel.
        last_hidden_state = self.last_hidden_state_forward(**inputs)
        last_hidden_state = self.dropout(last_hidden_state)
        last_hidden_state = self.activation(last_hidden_state)
        last_hidden_state = self.pooler(inputs, last_hidden_state)
        logits = self.classifier(last_hidden_state)
        logits = self.softmax(logits)
        outputs = {
            "logits": logits,
            "last_hidden_state": last_hidden_state,
            "labels": labels,
        }
        return outputs

    def predict(self, sequence_or_inputs, **kwargs):
        """Return predicted class indices together with the raw forward outputs.

        The result dict has ``predictions`` (argmax over the last axis of
        ``logits``), ``logits`` and ``last_hidden_state``.
        """
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)

        logits = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        # One vectorised argmax replaces the per-row loop and also copes with
        # an empty batch, where indexing the first row prediction would fail.
        outputs = {
            "predictions": logits.argmax(dim=-1).to(self.model.device),
            "logits": logits,
            "last_hidden_state": last_hidden_state,
        }

        return outputs

    def inference(self, sequence_or_inputs, **kwargs):
        """Return human-readable label predictions with a confidence score.

        Class indices are mapped through ``config.id2label`` (missing ids map
        to ``""``). ``confidence`` is the largest value of the softmaxed
        ``logits``. When ``sequence_or_inputs`` is not a list, the outputs for
        the first (only) item are returned un-batched.
        """
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)

        logits = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        predictions = [
            self.config.id2label.get(logits[i].argmax(dim=-1).item(), "")
            for i in range(logits.shape[0])
        ]

        if not isinstance(sequence_or_inputs, list):
            outputs = {
                "predictions": predictions[0],
                "logits": logits[0],
                "confidence": torch.max(logits[0]),
                "last_hidden_state": last_hidden_state[0],
            }
        else:
            outputs = {
                "predictions": predictions,
                "logits": logits,
                "confidence": torch.max(logits, dim=-1)[0],
                "last_hidden_state": last_hidden_state,
            }

        return outputs

    def loss_function(self, logits, labels):
        """Cross-entropy between the probabilities from ``forward`` and ``labels``.

        ``logits`` must be the softmax output produced by ``forward``.
        ``CrossEntropyLoss`` applies log-softmax to its input, so passing the
        probabilities directly would normalise twice and flatten the gradients.
        Passing ``log(p)`` instead is exact: ``log_softmax(log p) == log p``
        whenever ``p`` sums to one.
        """
        probs = logits.view(-1, self.config.num_labels)
        # Clamp away exact zeros so the log never produces -inf.
        log_probs = torch.log(probs.clamp_min(torch.finfo(probs.dtype).tiny))
        loss = self.loss_fn(log_probs, labels.view(-1))
        return loss

## Define a Dataset for Downstream Task



In [None]:
# For a sequence classification task
from omnigenbench import OmniDatasetForSequenceClassification
# For a token classification task
from omnigenbench import OmniDatasetForTokenClassification


## (Optional) Define a Custom Dataset for Downstream Task

### To define a custom dataset for a sequence classification task, you can inherit from `OmniDataset` and implement the `prepare_input` method to process the input data.
Make sure your dataset is in a format compatible with the tokenizer API, returning tokenized inputs and labels.
```
class Dataset(OmniDataset):
    """Custom dataset that turns each raw instance into tokenized tensors plus a label."""

    def __init__(self, data_source, tokenizer, max_length, **kwargs):
        """Forward all arguments unchanged to ``OmniDataset``."""
        super().__init__(data_source, tokenizer, max_length, **kwargs)

    def prepare_input(self, instance, **kwargs):
        """Tokenize ``instance["sequence"]`` and attach ``instance["label"]``.

        ``padding`` (default ``"do_not_pad"``) and ``truncation`` (default
        ``True``) may be overridden through ``kwargs``. Every returned tensor
        is squeezed, removing all size-1 dimensions.
        """
        padding = kwargs.get("padding", "do_not_pad")
        truncation = kwargs.get("truncation", True)

        encoded = self.tokenizer(
            instance["sequence"],
            padding=padding,
            truncation=truncation,
            max_length=self.max_length,
            return_tensors="pt",
        )
        encoded["labels"] = torch.tensor(instance["label"], dtype=torch.long)

        # Snapshot the items first so the mapping is not mutated while iterating.
        for key, value in list(encoded.items()):
            encoded[key] = value.squeeze()
        return encoded

In [None]:
## Load the dataset according to the path

In [None]:
# Directory holding the dataset files.
# Archive2 is an RNA secondary structure prediction dataset (token classification)
# that ships train.json, test.json, and valid.json.
dataset_path = "toy_datasets/Archive2/"

# File paths for the three splits, in the order train / test / valid.
train_file, test_file, valid_file = (
    dataset_path + split_name + ".json" for split_name in ("train", "test", "valid")
)

# Build one dataset per split with identical settings. Pass the custom
# tokenizer defined above instead of `base_tokenizer` if you are using it.
train_set, test_set, valid_set = (
    OmniDatasetForTokenClassification(
        data_source=split_file,
        tokenizer=base_tokenizer,
        max_length=512,  # maximum sequence length
    )
    for split_file in (train_file, test_file, valid_file)
)



In [None]:

# Split dataset into train/validation sets
# NOTE(review): `dataset` is not defined in any cell visible in this notebook,
# so this cell raises NameError on a fresh kernel unless `dataset` is loaded
# first. The train/valid/test sets used for training are built separately from
# JSON files -- confirm whether this split is still needed.
from sklearn.model_selection import train_test_split

# Hold out 20% of `dataset` for validation with a fixed seed; stratify on the
# 'label' column when the dataset exposes one.
train_dataset, val_dataset = train_test_split(
    dataset,
    test_size=0.2,
    random_state=42,
    stratify=dataset['label'] if 'label' in dataset.column_names else None
)
print(f"Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}")


## Training Implementation

In [None]:
# `torch` is imported here explicitly so this cell does not depend on the
# optional custom-model cell having been run first.
import torch

from omnigenbench import ClassificationMetric  # contains all metrics from sklearn.metrics and some custom metrics for classification tasks
from omnigenbench import Trainer

# necessary hyperparameters
epochs = 10
learning_rate = 2e-5
weight_decay = 1e-5
batch_size = 8
max_length = 512
seeds = [45]  # Each seed will be used for one run

# Metrics reported during evaluation; positions labelled -100 are ignored.
compute_metrics = [
    ClassificationMetric(ignore_y=-100).accuracy_score,
    ClassificationMetric(ignore_y=-100, average="macro").f1_score,
    ClassificationMetric(ignore_y=-100).matthews_corrcoef,
]

# Only the training data is shuffled; validation and test keep their order.
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True
)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size)


for seed in seeds:
    # A fresh optimizer is created for every run.
    # NOTE(review): `model` itself is not re-initialised between seeds, so with
    # more than one seed later runs continue from the previous run's weights.
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        eval_loader=valid_loader,
        test_loader=test_loader,
        batch_size=batch_size,
        epochs=epochs,
        optimizer=optimizer,
        compute_metrics=compute_metrics,
        seeds=seed,
    )

    metrics = trainer.train()
    # Last entry of the "test" history returned by the trainer.
    test_metrics = metrics["test"][-1]
    print(metrics)


## Model Loading and Inference

## Evaluation
After training, we evaluate the model on the validation set with a classification report and confusion matrix.

In [None]:

import numpy as np
from sklearn.metrics import classification_report, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

# `model.predict` may return either an array of class scores or a dict of
# outputs that carries the scores under "logits" (as the custom model defined
# in this notebook does), so unwrap the dict before taking the argmax.
# NOTE(review): `val_dataset` must already exist (it is produced by the
# train/validation split cell) -- confirm that cell runs on a fresh kernel.
preds = model.predict(val_dataset)
scores = preds["logits"] if isinstance(preds, dict) else preds
if hasattr(scores, "detach"):
    # torch tensor -> NumPy array on the CPU so NumPy / scikit-learn can use it
    scores = scores.detach().cpu().numpy()

y_true = np.array(val_dataset['label'])
# The class dimension is the last axis (same as axis=1 for 2-D scores).
y_pred = np.argmax(scores, axis=-1)

print(classification_report(y_true, y_pred, digits=3))

ConfusionMatrixDisplay.from_predictions(y_true, y_pred)
plt.show()


In [None]:
# Directory the fine-tuned model is written to; overwrite=True is passed so an
# existing checkpoint at this path does not block the save.
path_to_save = "OmniGenome-52M-SSP"
model.save(path_to_save, overwrite=True)

# Load the model checkpoint back from disk and rebind `model` to the result
model = model.load(path_to_save)
# Run inference on a single RNA sequence; the result is indexed by
# "predictions" and "logits" below.
results = model.inference("CAGUGCCGAGGCCACGCGGAGAACGAUCGAGGGUACAGCACUA")
print(results["predictions"])
print("logits:", results["logits"])

# We can load the model checkpoint using the ModelHub
# NOTE(review): "OmniGenome-186M-SSP" is presumably the name of a published
# checkpoint resolved by ModelHub -- confirm it is available before running.
from omnigenbench import ModelHub

ssp_model = ModelHub.load("OmniGenome-186M-SSP")
results = ssp_model.inference("CAGUGCCGAGGCCACGCGGAGAACGAUCGAGGGUACAGCACUA")
print(results["predictions"])
print("logits:", results["logits"])

## Model Prediction explanation