# Fine-Tuning OmniGenome for Transcription Factor Binding Prediction

This tutorial provides a comprehensive guide to fine-tuning a pre-trained genomic foundation model, [OmniGenome-52M](https://huggingface.co/yangheng/OmniGenome-52M), for a Transcription Factor Binding (TFB) prediction task. We will use the [DeepSEA](https://www.nature.com/articles/nmeth.3547) dataset and the OmniGenBench library to build a multi-label classifier that predicts the binding sites of 919 different transcription factors from a raw DNA sequence.

## Task Overview
The goal is to perform multi-label binary classification. Given a DNA sequence, the model must predict for each of the 919 possible chromatin features whether the sequence is a binding site (label 1) or not (label 0).

##  Dataset Description
The data for this tutorial is a preprocessed version of the DeepSEA dataset, which was originally designed for studying the effects of non-coding genetic variants. It consists of:

- Inputs: DNA sequences of 1000 base pairs (bp).

- Labels: 919 binary labels corresponding to various chromatin features, including transcription factor binding, DNase I sensitivity, histone marks, etc.

We will use a version of this dataset hosted on Hugging Face at [yangheng/tfb_prediction](https://huggingface.co/datasets/yangheng/tfb_prediction).

## Estimated Runtime

**The total runtime depends on your hardware.**

- **Full Run**: On a single NVIDIA RTX 4090 GPU, training with the default settings (MAX_EXAMPLES=100000, EPOCHS=10) takes approximately 1-2 hours.

- **Quick Test**: For a quick test run with MAX_EXAMPLES=1000, it should take about 5-10 minutes.

## Notebook Structure
This notebook is organized into the following sections:

1. **Setup & Installation**: Prepares the environment by installing necessary libraries.

2. **Import Libraries**: Loads all required Python packages.

3. **Configuration**: Centralizes all hyperparameters and paths for easy modification.

4. **Model Definition**: Implements a custom PyTorch model that integrates the OmniGenome backbone with a classification head.

5. **Data Loading and Preprocessing**: Defines the dataset class and logic for loading and tokenizing the DNA sequences.

6. **Initialization**: Instantiates the tokenizer, model, and datasets, preparing them for training.

7. **Training the Model**: Fine-tunes the model using the efficient AccelerateTrainer.

8. **Evaluation**: Assesses the final model's performance on the test set using the area under the ROC curve (AUC-ROC) metric.

9. **Inference Example**: Shows how to use the trained model for predictions on a new DNA sequence.

10. **Conclusion**: Summarizes the tutorial and suggests next steps.

# 1. Setup & Installation

Before we begin, you need to prepare your environment. This involves installing Git LFS to download the large dataset files from Hugging Face and then installing the required Python libraries.

## Step 1.1: Install and Initialize Git LFS
If you don't have Git LFS, you must install it to download the data. Run the appropriate command for your system in a **terminal**, not in this notebook.

- **On Debian/Ubuntu**: `sudo apt-get install git-lfs`

- **On macOS (with Homebrew)**:`brew install git-lfs`

- **On Windows**: Download and run the installer from the [official Git LFS website](https://git-lfs.com/).

## Step 1.2: Install Python Libraries

Now, run the following cell to install the necessary Python packages for this tutorial.

In [None]:
# Uncomment the following line to install the necessary packages
# !pip install torch numpy transformers omnigenbench autocuda findfile accelerate scikit-learn

# 2. Import Libraries
Next, we import the necessary libraries for data manipulation, model building, and training.

In [2]:
import os
import random
import zipfile

import autocuda
import findfile
import numpy as np
import torch
from omnigenbench import (
    AccelerateTrainer,
    ClassificationMetric,
    OmniDataset,
    OmniLoraModel,
    OmniModel,
    OmniPooling,
)
from transformers import AutoModel, AutoTokenizer, BatchEncoding

print("Libraries imported successfully.")

Libraries imported successfully.


# 3. Configuration
In this section, we will set up all the necessary parameters for our experiment. This includes downloading the data, selecting a model, defining training hyperparameters, and configuring the hardware.

## 3.1 Download and Verify the Dataset
First, we define the location of our dataset. The following code will download the `tfb_prediction` dataset from Hugging Face using `git`, extract it if necessary, and then verify that the required data files (`.npy`) are present.

In [19]:
# --- Define Data Source and Local Directory ---
DATASET_URL = "https://huggingface.co/datasets/yangheng/tfb_prediction"
LOCAL_DIR = "tfb_prediction_dataset"

# --- Download and Extract ---
if not os.path.isdir(LOCAL_DIR):
    print(f"Cloning dataset from {DATASET_URL} into {LOCAL_DIR}...")
    os.system(f"git clone {DATASET_URL} {LOCAL_DIR}")
    
    zip_path = os.path.join(LOCAL_DIR, "tfb_dataset.zip")
    if os.path.exists(zip_path):
        print(f"Extracting {zip_path}...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(LOCAL_DIR)
        os.remove(zip_path) # Clean up the zip file
else:
    print(f"Dataset already found in {LOCAL_DIR}.")

# --- Define and Verify File Paths ---
TRAIN_FILE = os.path.join(LOCAL_DIR, "train_tfb.npy")
TEST_FILE = os.path.join(LOCAL_DIR, "test_tfb.npy")
VALID_FILE = os.path.join(LOCAL_DIR, "valid_tfb.npy")

# Verify that the files exist to prevent errors later
if not os.path.exists(TRAIN_FILE):
    raise FileNotFoundError(f"Training file not found at {TRAIN_FILE}. Please check the download step.")
if not os.path.exists(TEST_FILE):
    raise FileNotFoundError(f"Test file not found at {TEST_FILE}.")
if not os.path.exists(VALID_FILE):
    print(f"Warning: Validation file not found at {VALID_FILE}. Skipping validation.")
else:
    print("All data files found successfully.")

Dataset already found in tfb_prediction_dataset.
All data files found successfully.


## 3.2: Select the Foundation Model
Here, you can choose which pre-trained genomic foundation model to fine-tune. We have provided a list of compatible models. To switch models, simply change the index for `AVAILABLE_MODELS`.

In [9]:
# --- Model Configuration ---
AVAILABLE_MODELS = [
    'yangheng/OmniGenome-52M',      # Default model for this tutorial
    'yangheng/OmniGenome-186M',     # A larger OmniGenome model
    'DNABERT-2-117M',               # DNABERT-2 by ailab
    'InstaDeepAI/nucleotide-transformer-500m-human-ref', # Nucleotide Transformer
    # 'DNABERT-2-117M',  # You can add more models here as needed,
    # 'LongSafari/hyenadna-large-1m-seqlen-hf',
    # 'InstaDeepAI/nucleotide-transformer-500m-human-ref',
    # 'multimolecule/rnafm', # RNA-specific models
    # 'multimolecule/rnamsm',
    # 'multimolecule/rnabert',
    # 'SpliceBERT-510nt', # Splice-specific model
]

# Select the model by its index in the list (0 = OmniGenome-52M)
MODEL_NAME_OR_PATH = AVAILABLE_MODELS[0]

print(f"Selected model: {MODEL_NAME_OR_PATH}")

Selected model: yangheng/OmniGenome-52M


## 3.3: Set Training Hyperparameters

This step defines the key parameters that control the training process, such as learning rate, batch size, and the number of epochs. You can also set a limit on the number of training examples for a quicker test run.

In [18]:
# --- Training Hyperparameters ---
EPOCHS = 10                  # Number of training epochs
LEARNING_RATE = 2e-5         # Optimizer learning rate
WEIGHT_DECAY = 1e-5          # Weight decay for regularization
BATCH_SIZE = 16              # Number of samples per batch
PATIENCE = 3                 # Number of epochs with no improvement to wait before stopping
MAX_LENGTH = 200             # Sequence length to process (the central 200bp of the 1000bp sequence)
SEED = 42                    # Random seed for reproducibility
MAX_EXAMPLES = 100000        # Max examples for training/testing. Use a small number (e.g., 1000) for a quick test run.
GRADIENT_ACCUMULATION_STEPS = 1 # Accumulates gradients over multiple steps for a larger effective batch size

print(f"Training with {MAX_EXAMPLES} examples for {EPOCHS} epochs.")

Training with 100000 examples for 10 epochs.


## 3.4: Configure Hardware Device
Finally, we'll set the device for training. The `autocuda` library will automatically select an available NVIDIA GPU if possible, otherwise it will fall back to the CPU.

In [8]:
# --- Device Setup ---
DEVICE = autocuda.auto_cuda()

print(f"Using device: {DEVICE}")

Using device: cuda:1


# 4. Model Definition
Here, we define our custom model, `OmniModelForMultiLabelClassification`. This class inherits from `OmniModel` and wraps a pre-trained Transformer (like OmniGenome) with a classification head suitable for our 919-label task.

In [17]:
class OmniModelForMultiLabelClassification(OmniModel):
    """
    Multi-label sequence classification model based on OmniGenome-52M.

    This model replaces the original DeepSEA CNN architecture with a pretrained
    Transformer encoder from the OmniGenome family. Optionally, convolutional
    layers can be stacked on top of the Transformer outputs for additional
    feature extraction.

    Parameters:
        config_or_model (PretrainedConfig or nn.Module):
            Configuration or instance of the pretrained Transformer model,
            typically obtained via AutoModel.from_pretrained().
        tokenizer (PreTrainedTokenizer):
            Tokenizer compatible with the Transformer encoder, used to convert
            DNA sequences into model inputs.
        threshold (float, optional):
            Probability threshold for binary decisions in predict(), defaults to 0.5.
        use_conv (bool, optional):
            If True, apply convolutional layers after the Transformer encoder
            for enhanced feature extraction, defaults to False.
        *args, **kwargs: Additional arguments passed to the base OmniModel class.

    Attributes:
        threshold (float):
            Probability cutoff for generating binary predictions.
        deepsea_classifier (nn.Sequential):
            Classification head consisting of a Tanh activation followed by
            a linear layer mapping to the number of labels.
        loss_fn (nn.BCEWithLogitsLoss):
            Binary cross-entropy loss with logits, using pos_weight to balance
            positive and negative samples.
        pooler (OmniPooling):
            Utility for pooling token-level outputs into a sequence-level vector.
        sigmoid (nn.Sigmoid):
            Activation used to convert logits to probabilities during inference.
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        self.threshold = kwargs.pop("threshold", 0.5)
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = "OmniModelForMultiLabelClassification"

        # Classification head based directly on Transformer outputs
        conv_output_dim = self.config.hidden_size
        self.deepsea_classifier = torch.nn.Sequential(
            torch.nn.Tanh(),
            torch.nn.Linear(conv_output_dim, self.config.num_labels),
        )

        # Use pos_weight to address class imbalance in BCEWithLogitsLoss
        self.loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([20.0]))
        self.pooler = OmniPooling(self.config)
        self.sigmoid = torch.nn.Sigmoid()
        self.model_info()

    def forward(self, **inputs):
        """
        Forward pass through the model: encode, pool, classify, and optionally compute loss.

        Args:
            inputs (dict):
                Must contain 'input_ids', 'attention_mask', etc., and optionally 'labels'.

        Returns:
            dict: {
                'logits': Tensor of shape (batch_size, num_labels), raw scores before
                          sigmoid activation.
                'last_hidden_state': Tensor of shape (batch_size, seq_len, hidden_size),
                                     the last layer hidden states from the Transformer.
                'loss' (optional): Computed BCEWithLogitsLoss if 'labels' provided.
            }

        Raises:
            ValueError: If all provided labels are zero or if logits and labels shapes mismatch.
        """
        labels = inputs.pop("labels", None)

        # Encode inputs
        last_hidden_state = self.last_hidden_state_forward(**inputs)
        last_hidden_state = self.dropout(last_hidden_state)
        last_hidden_state = self.activation(last_hidden_state)

        # Pooling strategy
        if self.pooler._is_causal_lm():
            pad_token_id = getattr(self.config, "pad_token_id", -100)
            sequence_lengths = inputs["input_ids"].ne(pad_token_id).sum(dim=1) - 1
            pooled_output = last_hidden_state[
                torch.arange(inputs["input_ids"].size(0), device=last_hidden_state.device),
                sequence_lengths,
            ]
        else:
            pooled_output = self.pooler(inputs, last_hidden_state)

        logits = self.deepsea_classifier(pooled_output)
        outputs = {"logits": logits, "last_hidden_state": last_hidden_state}

        if labels is not None:
            if torch.sum(labels[labels != -100]) == 0:
                raise ValueError("Labels cannot be all zeros.")
            labels = labels[labels != -100]
            loss = self.loss_fn(logits.view(-1), labels.view(-1).to(torch.float32))
            outputs["loss"] = loss

        return outputs

    def predict(self, sequence_or_inputs, **kwargs):
        """
        Perform inference on raw sequences or tokenized inputs, returning probabilities and predictions.

        Args:
            sequence_or_inputs (str, BatchEncoding, or dict):
                Raw DNA string or pre-tokenized inputs.
            padding (str, optional): Padding strategy for tokenizer, defaults to 'max_length'.
            max_length (int, optional): Maximum sequence length, defaults to 1024.
            **kwargs: Additional tokenizer arguments.

        Returns:
            dict: {
                'predictions': Tensor of binary labels,
                'probabilities': Tensor of positive class probabilities,
                'logits': Tensor of raw scores,
                'last_hidden_state': Transformer outputs.
            }
        """
        if not isinstance(sequence_or_inputs, BatchEncoding) and not isinstance(
                sequence_or_inputs, dict
        ):
            inputs = self.tokenizer(
                sequence_or_inputs,
                padding=kwargs.pop("padding", "max_length"),
                max_length=kwargs.pop("max_length", 1024),
                truncation=True,
                return_tensors="pt",
                **kwargs,
            )
        else:
            inputs = sequence_or_inputs
        inputs = inputs.to(self.model.device)

        with torch.no_grad():
            outputs = self(**inputs)
        logits = outputs["logits"]
        last_hidden_state = outputs["last_hidden_state"]

        probabilities = self.sigmoid(logits)
        predictions = (probabilities >= self.threshold).to(torch.int)

        return {
            "predictions": predictions,
            "probabilities": probabilities,
            "logits": logits,
            "last_hidden_state": last_hidden_state,
        }

    def loss_function(self, logits, labels):
        """
        Compute BCEWithLogitsLoss for multi-label classification.

        Args:
            logits (Tensor): Raw output scores, shape (batch_size, num_labels).
            labels (Tensor): Ground-truth labels as floats, same shape as logits.

        Returns:
            Tensor: Loss value.

        Raises:
            ValueError: If logits and labels shapes do not match.
        """
        valid_labels = labels.to(torch.float32)
        if logits.shape != valid_labels.shape:
            raise ValueError(f"Shape mismatch between logits {logits.shape} and labels {valid_labels.shape}")
        return self.loss_fn(logits, valid_labels)


print("OmniModelForMultiLabelClassification defined.")

OmniModelForMultiLabelClassification defined.


# 5. Data Loading and Preprocessing
This section defines our `DeepSEADataset` class, which handles loading the `.npy` files and preparing the DNA sequences for the model.

**Important Note on Tokenization**: Most genomic foundation models (like DNABERT and OmniGenome) are trained on sequences where each nucleotide is separated by a space (e.g., `"A T C G"` instead of `"ATCG"`). Our dataset class must perform this conversion before tokenization.

In [15]:
class DeepSEADataset(OmniDataset):
    """
    Dataset designed for the DeepSEA task, handling the conversion from DNA sequences to tokens
    """

    def __init__(self, data_source, tokenizer, max_length=None, **kwargs):
        super().__init__(data_source, tokenizer, max_length, **kwargs)
        for key, value in kwargs.items():
            self.metadata[key] = value

    def prepare_input(self, instance, **kwargs):
        """
        Prepare input data for DeepSEA

        Expected instance format:
        {
            'sequence': DNA sequence string (e.g., "ATCGATCG...")
            'labels': binary labels as numpy array of shape (919,)
        }
        """
        labels = None
        if isinstance(instance, str):
            sequence = instance
        elif isinstance(instance, dict):
            sequence = (
                instance.get("seq", None)
                if "seq" in instance
                else instance.get("sequence", None)
            )
            label = instance.get("label", None)
            labels = instance.get("labels", None)
            labels = labels if labels is not None else label
        else:
            raise Exception("Unknown instance format.")

        if sequence is None:
            raise ValueError("Sequence is required")

        # Convert DNA sequence to space-separated format for tokenizer
        # e.g., "ATCG" -> "A T C G"
        if isinstance(sequence, str):
            spaced_sequence = ' '.join(list(sequence))
        else:
            # If sequence is one-hot encoded, convert to string first
            if isinstance(sequence, np.ndarray) and sequence.shape[1] == 4:
                # one-hot to sequence string
                base_map = {0: 'A', 1: 'T', 2: 'C', 3: 'G'}
                sequence_str = ''.join([base_map[np.argmax(sequence[i])] for i in range(len(sequence))])
                spaced_sequence = ' '.join(list(sequence_str))
            else:
                raise ValueError(f"Unsupported sequence format: {type(sequence)}")

        # Use tokenizer to process the sequence
        tokenized_inputs = self.tokenizer(
            spaced_sequence[500-self.max_length//2:500+self.max_length//2],  # DeepSEA usually processes 200bp sequences
            # spaced_sequence,
            padding="do_not_pad",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            **kwargs,
        )

        # Squeeze dimensions
        for col in tokenized_inputs:
            tokenized_inputs[col] = tokenized_inputs[col].squeeze()

        if labels is not None:
            # For sequence classification, labels should be a fixed-length vector
            if isinstance(labels, np.ndarray):
                labels = torch.from_numpy(labels).float()
            elif not isinstance(labels, torch.Tensor):
                labels = torch.tensor(labels, dtype=torch.float32)

            tokenized_inputs["labels"] = labels

        return tokenized_inputs

# 6. Initialization and Data Inspection
Now, we instantiate the tokenizer, model, and datasets, bringing all the components together. We will also inspect a sample to see what our processed data looks like before feeding it to the model.

## 6.1. Initialize Tokenizer, Model, and Datasets
This step loads the pre-trained model and its corresponding tokenizer from Hugging Face, wraps it in our custom `OmniModelForMultiLabelClassification` class, and then creates the `Dataset` objects for our training, validation, and test splits.

In [None]:
# Initialize Tokenizer and Model
print("--- Initializing Tokenizer and Model ---")
if "multimolecule" in MODEL_NAME_OR_PATH.lower():
    from multimolecule import AutoModelForTokenPrediction, RnaTokenizer
    base_model = AutoModelForTokenPrediction.from_pretrained(MODEL_NAME_OR_PATH, trust_remote_code=True).base_model
    tokenizer = RnaTokenizer.from_pretrained(MODEL_NAME_OR_PATH, trust_remote_code=True)
else:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH)
    base_model = AutoModel.from_pretrained(MODEL_NAME_OR_PATH, trust_remote_code=True)

model = OmniModelForMultiLabelClassification(
    base_model,
    tokenizer,
    num_labels=919,  # DeepSEA has 919 binary labels for different chromatin features
    threshold=0.5,
)

# If you want to use LoRA, uncomment the following lines
# lora_config = {
#     "lora_r": 8,  # Rank of the LoRA layers
#     "lora_alpha": 16,  # Scaling factor for LoRA
#     "lora_dropout": 0.1,  # Dropout rate for LoRA layers
#     "target_modules": ["deepsea_classifier"],  # Target modules to apply LoRA
# }
# model = OmniLoraModel(model, lora_config=lora_config)

model.to(DEVICE).to(torch.float32) # Move model to the selected device
# 2. Create Datasets
print("\n--- Creating Datasets ---")
train_set = DeepSEADataset(
    data_source=TRAIN_FILE,
    tokenizer=tokenizer,
    max_length=MAX_LENGTH,
    max_examples=MAX_EXAMPLES,
    force_padding=False,  # DeepSEA does not require padding
)
test_set = DeepSEADataset(
    data_source=TEST_FILE,
    tokenizer=tokenizer,
    max_length=MAX_LENGTH,
    max_examples=MAX_EXAMPLES,
    force_padding=False,  # DeepSEA does not require padding
)
valid_set = DeepSEADataset(
    data_source=VALID_FILE,
    tokenizer=tokenizer,
    max_length=MAX_LENGTH,
    max_examples=MAX_EXAMPLES,
    force_padding=False,  # DeepSEA does not require padding
) if os.path.exists(VALID_FILE) else None

In [30]:
print("\n--- Initialization Complete ---")
print(f"Training set size: {len(train_set)}")
print(f"Test set size: {len(test_set)}")
if valid_set:
    print(f"Validation set size: {len(valid_set)}")


--- Initialization Complete ---
Training set size: 100000
Test set size: 100000
Validation set size: 8000


## Step 6.2. Inspect a Data Sample
Let's examine a single sample from our `train_set` to understand its structure. Our `DeepSEADataset` class processes the raw data into tokenized tensors (`input_ids`) and a label tensor.

In [21]:
# Get the first sample from the training set
sample = train_set[0]

print(f"Sample keys: {sample.keys()}")
print("-" * 30)

# --- Inspect Input ---
input_ids = sample['input_ids']
print(f"Input IDs Tensor Shape: {input_ids.shape}")

# Decode the input_ids back to a human-readable sequence
decoded_sequence = tokenizer.decode(input_ids, skip_special_tokens=True)
print(f"Decoded Sequence (first 60 chars): '{decoded_sequence[:60]}...'")
print("-" * 30)

# --- Inspect Labels ---
labels = sample['labels']
print(f"Labels Tensor Shape: {labels.shape}")
print(f"Labels Tensor Dtype: {labels.dtype}")

# Show the first 20 labels for this sequence
print(f"First 20 labels: {labels[:20].int().tolist()}")


Sample keys: dict_keys(['input_ids', 'attention_mask', 'labels'])
------------------------------
Input IDs Tensor Shape: torch.Size([102])
Decoded Sequence (first 60 chars): 'G T T C A A G A A T G C A T A A A T T G T A T C T T C A G A ...'
------------------------------
Labels Tensor Shape: torch.Size([919])
Labels Tensor Dtype: torch.float32
First 20 labels: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


# 7. Training the Model
We are now ready to train. We will use the `AccelerateTrainer` from `omnigenbench`, which simplifies the training and evaluation loop, handles device placement, and integrates with `torch.utils.data.DataLoader` and `accelerate`.

In [22]:
# Set random seed for reproducibility across all libraries
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Create DataLoaders for batching
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=BATCH_SIZE) if valid_set else None

# Define the metric for evaluation. For DeepSEA, ROC AUC is a standard metric.
metrics = [ClassificationMetric(ignore_y=-100).roc_auc_score]

# Create the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Initialize the Trainer
trainer = AccelerateTrainer(
    model=model,
    train_loader=train_loader,
    eval_loader=valid_loader, # Use validation set for early stopping and checkpointing
    test_loader=test_loader,
    optimizer=optimizer,
    epochs=EPOCHS,
    compute_metrics=metrics,
    patience=PATIENCE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    device=DEVICE
)

# Start Training
print("--- Starting Training ---")
metrics = trainer.train()
print("--- Training Finished ---")

--- Starting Training ---


Evaluating: 100%|██████████| 500/500 [00:08<00:00, 62.00it/s]


{'roc_auc_score': 0.49679106328943734}


Epoch 1/10 Loss: 0.6642: 100%|██████████| 6250/6250 [06:12<00:00, 16.76it/s]
Evaluating: 100%|██████████| 500/500 [00:07<00:00, 65.13it/s]


{'roc_auc_score': 0.6626522016257755}


Epoch 2/10 Loss: 0.6421: 100%|██████████| 6250/6250 [06:11<00:00, 16.82it/s]
Evaluating: 100%|██████████| 500/500 [00:08<00:00, 59.37it/s]


{'roc_auc_score': 0.6548992354311663}


Epoch 3/10 Loss: 0.6299: 100%|██████████| 6250/6250 [06:15<00:00, 16.66it/s]
Evaluating: 100%|██████████| 500/500 [00:08<00:00, 61.20it/s]


{'roc_auc_score': 0.655689085400226}


Epoch 4/10 Loss: 0.6026: 100%|██████████| 6250/6250 [06:10<00:00, 16.89it/s]
Evaluating: 100%|██████████| 500/500 [00:08<00:00, 60.77it/s]


{'roc_auc_score': 0.6620117452714009}


Epoch 5/10 Loss: 0.5628: 100%|██████████| 6250/6250 [06:13<00:00, 16.74it/s]
Evaluating: 100%|██████████| 500/500 [00:07<00:00, 62.64it/s]


{'roc_auc_score': 0.6503828399274646}
Early stopping at epoch 5.


Testing: 100%|██████████| 6250/6250 [01:42<00:00, 61.03it/s]


{'roc_auc_score': 0.6840973996364564}
--- Training Finished ---


# 8. Evaluation
After training, the `AccelerateTrainer` automatically loads the best model checkpoint (based on validation set performance). It then runs a final evaluation on the held-out test set to provide an unbiased measure of the model's performance.

In [24]:
print("--- Evaluating on Test Set ---")
# The evaluation is automated by the trainer
print(f"All metrics:", metrics)
for metric in metrics['test']:
    print(f"Test metric:  {metric}")

--- Evaluating on Test Set ---
All metrics: {'valid': [{'roc_auc_score': 0.49679106328943734}, {'roc_auc_score': 0.6626522016257755}, {'roc_auc_score': 0.6548992354311663}, {'roc_auc_score': 0.655689085400226}, {'roc_auc_score': 0.6620117452714009}, {'roc_auc_score': 0.6503828399274646}], 'best_valid': {'roc_auc_score': 0.6626522016257755}, 'test': [{'roc_auc_score': 0.6840973996364564}]}
Test metric:  {'roc_auc_score': 0.6840973996364564}


# 9. Inference Example
Finally, let's see how to use our fine-tuned model on a real-world example. Instead of a synthetic sequence, we'll take the first sample from our held-out test set. This is a powerful way to gut-check the model's performance, as we can directly compare its predictions with the actual ground truth labels for that sequence.

In [29]:

# 1. Select the first sample from the test set
test_sequence = tokenizer.decode(test_set[0]['input_ids'], skip_special_tokens=True) # you can also use any DNA sequence you want
print(f"Test Sequence: {test_sequence}")
true_labels = test_set[0]['labels'].int() # The ground truth



# 2. Prepare the sequence for the model (add spaces between characters)
spaced_sequence = ' '.join(list(test_sequence))
inputs = tokenizer(spaced_sequence, return_tensors="pt", max_length=MAX_LENGTH, truncation=True)
print(f"Input Sequence (from test set): '{inputs[:80]}...'")
print("-" * 65)

model.eval()

# 4. Make a prediction
with torch.no_grad():
    outputs = model.predict(inputs)

# 5. Extract predictions and probabilities
# We take the first element [0] as our batch size is 1.
predictions = outputs['predictions'][0].cpu().numpy()
probabilities = outputs['probabilities'][0].cpu().numpy()

# 6. Compare predictions with true labels for the first 20 TFs
print("--- Inference on a Test Sample (first 20 labels) ---")
print(f"{'Label #':<10} | {'Prediction':<13} | {'Ground Truth':<13} | {'Probability'}")
print("-" * 65)

for i in range(20):
    pred_label = 'Binds' if predictions[i] == 1 else 'Does not bind'
    true_label = 'Binds' if true_labels[i] == 1 else 'Does not bind'
    prob = probabilities[i]
    # Add a checkmark if the prediction is correct
    correct = "correct" if pred_label == true_label else "false"
    
    print(f"Label {i+1:<7} | {pred_label:<13} | {true_label:<13} | {prob:.4f}  {correct}")

# Optional: Calculate and print the accuracy just for this single sample
accuracy = (predictions == true_labels.cpu().numpy()).mean()
print("-" * 65)
print(f"Accuracy for this single sample: {accuracy:.2%}")

Test Sequence: G C C A T T G G C C G T C T G T G C C A C C T G C C C A C T G T G A A G G C A T G T G A C T T G G A T C C T G G T G A A G G A G G T G G C T G T G T G G C G G G G T G G G C A G G T A A A G A A G C A G
Input Sequence (from test set): '{'input_ids': tensor([[0, 6, 5, 5, 4, 7, 7, 6, 6, 5, 5, 6, 7, 5, 7, 6, 7, 6, 5, 5, 4, 5, 5, 7,
         6, 5, 5, 5, 4, 5, 7, 6, 7, 6, 4, 4, 6, 6, 5, 4, 7, 6, 7, 6, 4, 5, 7, 7,
         6, 6, 4, 7, 5, 5, 7, 6, 6, 7, 6, 4, 4, 6, 6, 4, 6, 6, 7, 6, 6, 5, 7, 6,
         7, 6, 7, 6, 6, 5, 6, 6, 6, 6, 7, 6, 6, 6, 5, 4, 6, 6, 7, 4, 4, 4, 6, 4,
         4, 6, 5, 4, 6, 2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1]])}...'
--------

# 10. Conclusion
Congratulations! You have successfully fine-tuned a genomic foundation model for transcription factor binding prediction.

In this tutorial, you have learned how to:

- Set up a project environment for genomic deep learning.

- Load and preprocess the DeepSEA dataset using a custom `OmniDataset` class.

- Define a multi-label classification model by adding a custom head to a pre-trained `OmniGenome` backbone.

- Train and evaluate the model efficiently using the `AccelerateTrainer` from OmniGenBench.

- Use the final model to make predictions on new DNA sequences.

From here, you could explore:

- Experimenting with other models: Try different backbones from the `AVAILABLE_MODELS` list (see section 3.2).

- Hyperparameter tuning: Adjust `LEARNING_RATE`, `BATCH_SIZE`, or `WEIGHT_DECAY` to improve performance (see section 3.3).

- Using LoRA: Uncomment the `OmniLoraModel` code in the Initialization section to try parameter-efficient fine-tuning (see section 6.1).

- Applying to other tasks: Adapt the pipeline for other genomic classification tasks, such as predicting promoter regions or splice sites.