# Transcription Factor Binding Prediction with OmniGenBench

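This notebook provides a step-by-step guide to extending OmniGenBench to the transcription factor binding (TFB) prediction task, using the **OmniGenome-52M** model and the **DeepSEA dataset**. The task is multi-label classification: for each DNA sequence, predict 919 binary chromatin-profile labels, most of which are transcription factor binding profiles (the remainder are DNase I hypersensitivity and histone-mark profiles).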



## Notebook Structure

This notebook is organized into several sections, each focusing on a specific aspect of the Transcription Factor Binding (TFB) prediction pipeline. Below is an overview of the structure:

1. **Setup & Installation**: Ensures all required libraries and dependencies are installed.
2. **Import Libraries**: Loads the necessary Python libraries for genomic data processing, model inference, and analysis.
3. **Configuration**: Defines key parameters such as file paths, model selection, and training hyperparameters.
4. **Model Definition**: Implements a custom model class that integrates the OmniGenome backbone with a classification head tailored for the DeepSEA task.
5. **Data Loading and Preprocessing**: Handles the loading and preprocessing of the DeepSEA dataset, converting DNA sequences into tokenized inputs.
6. **Initialization**: Sets up the tokenizer, model, datasets, and data loaders for training and evaluation.
7. **Training the Model**: Fine-tunes the model using the `AccelerateTrainer` for efficient training and evaluation.
8. **Evaluation**: Assesses the model's performance on the test set using metrics such as ROC AUC.
9. **Inference Example**: Demonstrates how to use the trained model to make predictions on new DNA sequences.

Each section is designed to be modular, allowing for easy customization and extension. Follow the notebook sequentially to understand and execute the TFB prediction pipeline effectively.

## 1. Setup & Installation

First, let's ensure all the required packages are installed. If you have already installed them, you can skip this cell. Otherwise, uncomment and run the cell to install the dependencies.

In [1]:
# Uncomment the following line to install the necessary packages
# !pip install torch numpy transformers omnigenbench autocuda

## 2. Import Libraries

Import all the necessary libraries for genomic data processing, model inference, and analysis.

In [None]:
import random
import os
import autocuda
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel, BatchEncoding
from omnigenbench import (
    OmniDataset,
    OmniModel,
    OmniPooling,
    Trainer,
    ClassificationMetric,
    AccelerateTrainer
)

print("Libraries imported successfully.")

## 3. Configuration

Here, we define all the hyperparameters and settings for our experiment. This centralized configuration makes it easy to modify parameters and track experiments.

In [3]:
# --- Data File Paths ---
# Ensure these .npy files are in the same directory as the notebook, or provide the full path.
TRAIN_FILE = "train_tfb.npy"
TEST_FILE = "test_tfb.npy"
VALID_FILE = "valid_tfb.npy"

# --- Model Configuration ---
# --- Available Models for Testing ---
AVAILABLE_MODELS = [
    'yangheng/OmniGenome-52M',
    'yangheng/OmniGenome-186M',
    'yangheng/OmniGenome-v1.5',
    # You can add more models here as needed,
    # 'DNABERT-2-117M',
    # 'hyenadna-large-1m-seqlen-hf',
    # 'InstaDeepAI/nucleotide-transformer-500m-human-ref',
    # 'multimolecule/rnafm', # RNA-specific models
    # 'multimolecule/rnabert',
    # 'SpliceBERT-510nt', # Splice-specific model
]

MODEL_NAME_OR_PATH = AVAILABLE_MODELS[0]  # 'yangheng/OmniGenome-52M', the model this guide is built around
USE_CONV_LAYERS = False  # Set to True to add DeepSEA-style convolutional layers on top of OmniGenome

# --- Training Hyperparameters ---
EPOCHS = 30
LEARNING_RATE = 5e-5
WEIGHT_DECAY = 1e-3
BATCH_SIZE = 128
PATIENCE = 3  # For early stopping
MAX_LENGTH = 200  # The length of the DNA sequence to be processed
SEED = 45
MAX_EXAMPLES = 100000  # Use a smaller number for quick testing (e.g., 1000), or None for all data
GRADIENT_ACCUMULATION_STEPS = 1
CACHE_DATASET = True  # Set to True to cache preprocessed data for faster re-runs

# --- Device Setup ---
DEVICE = autocuda.auto_cuda()
print(f"Using device: {DEVICE}")

Using device: cuda:1
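As a rough sizing aid, the sketch below computes the effective batch size and the number of batches and optimizer steps per epoch implied by the settings above. It assumes the training split contains at least `MAX_EXAMPLES` examples and that `MAX_EXAMPLES` caps how many are loaded; `n_train` is that assumption, not a measured dataset size.

```python
import math

# Values copied from the configuration cell above
BATCH_SIZE = 128
GRADIENT_ACCUMULATION_STEPS = 1
MAX_EXAMPLES = 100000

# Assumption: the training split is capped at MAX_EXAMPLES examples
n_train = MAX_EXAMPLES

effective_batch_size = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
batches_per_epoch = math.ceil(n_train / BATCH_SIZE)  # the last batch may be partial
optimizer_steps_per_epoch = math.ceil(batches_per_epoch / GRADIENT_ACCUMULATION_STEPS)

print(f"Effective batch size: {effective_batch_size}")
print(f"Batches per epoch: {batches_per_epoch}")
print(f"Optimizer steps per epoch: {optimizer_steps_per_epoch}")
```

How a trailing, partial accumulation window is handled depends on the trainer, so treat the step count as an estimate.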


## 4. Model Definition

We define the `OmniModelForMultiLabelClassification`, which wraps the OmniGenome transformer. This class adds a classification head on top of the pre-trained backbone, tailored for the DeepSEA multi-label prediction task. It also includes an option to add convolutional layers, allowing for a hybrid architecture that combines the strengths of both transformers and CNNs.
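When `use_conv=True`, the convolution and pooling layers shrink the sequence dimension, which is why the forward pass ends with `adaptive_avg_pool1d` to obtain a fixed-size vector. The sketch below applies PyTorch's output-length formula for `Conv1d` and `MaxPool1d` (dilation 1), `floor((L + 2*padding - kernel_size) / stride) + 1`, to the optional layer stack in `OmniModelForMultiLabelClassification`. The input length of 200 tokens is an assumption; the real length depends on how many special tokens the tokenizer adds.

```python
def conv1d_out_len(length, kernel_size, stride=1, padding=0):
    """Output length of Conv1d/MaxPool1d with dilation=1 (PyTorch formula)."""
    return (length + 2 * padding - kernel_size) // stride + 1

def deepsea_conv_stack_len(length):
    """Trace the sequence length through the optional DeepSEA-style conv stack."""
    length = conv1d_out_len(length, kernel_size=8, padding=4)  # Conv1d(hidden, 320)
    length = conv1d_out_len(length, kernel_size=4, stride=4)   # MaxPool1d(4, 4)
    length = conv1d_out_len(length, kernel_size=8, padding=4)  # Conv1d(320, 480)
    length = conv1d_out_len(length, kernel_size=4, stride=4)   # MaxPool1d(4, 4)
    length = conv1d_out_len(length, kernel_size=8, padding=4)  # Conv1d(480, 960)
    return length

# Assumed input: 200 token positions from the backbone
print(deepsea_conv_stack_len(200))  # 13 positions x 960 channels before pooling
```

Averaging those 13 positions into one 960-dimensional vector is what lets the same `deepsea_classifier` head work regardless of input length.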

In [4]:
class OmniModelForMultiLabelClassification(OmniModel):
    """
    A custom model for multi-label classification of genomic sequences using an OmniGenome backbone.

    This model replaces the original DeepSEA CNN architecture with a pre-trained Transformer encoder.
    It can optionally add convolutional layers after the transformer embeddings before the final classification head.
    
    Args:
        config_or_model: The Hugging Face model configuration or a pre-trained model instance.
        tokenizer: The tokenizer corresponding to the model.
        use_conv (bool): If True, add convolutional layers on top of the transformer output.
    """
    def __init__(self, config_or_model, tokenizer, use_conv=False, *args, **kwargs):
        self.threshold = kwargs.pop("threshold", 0.5)
        self.use_conv = use_conv
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = "DeepSEA_OmniGenome"

        if self.use_conv:
            # Optional convolutional layers, mimicking the original DeepSEA CNN architecture
            self.conv_layers = torch.nn.Sequential(
                torch.nn.Conv1d(self.config.hidden_size, 320, kernel_size=8, padding=4),
                torch.nn.ReLU(),
                torch.nn.MaxPool1d(kernel_size=4, stride=4),
                torch.nn.Dropout(0.2),
                torch.nn.Conv1d(320, 480, kernel_size=8, padding=4),
                torch.nn.ReLU(),
                torch.nn.MaxPool1d(kernel_size=4, stride=4),
                torch.nn.Dropout(0.2),
                torch.nn.Conv1d(480, 960, kernel_size=8, padding=4),
                torch.nn.ReLU(),
                torch.nn.Dropout(0.5),
            )
            conv_output_dim = 960
        else:
            # If not using conv layers, the input to the classifier is the transformer's hidden size
            conv_output_dim = self.config.hidden_size

        # The original DeepSEA classification head architecture
        self.deepsea_classifier = torch.nn.Sequential(
            torch.nn.Linear(conv_output_dim, 925),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(925, self.config.num_labels)  # num_labels should be 919 for DeepSEA
        )

        self.loss_fn = torch.nn.BCEWithLogitsLoss() # Suitable for multi-label classification
        self.pooler = OmniPooling(self.config)
        self.sigmoid = torch.nn.Sigmoid()
        self.model_info() # Print model summary

    def forward(self, **inputs):
        """Defines the forward pass of the model."""
        labels = inputs.pop("labels", None)
        
        # Get embeddings from the OmniGenome backbone
        last_hidden_state = self.last_hidden_state_forward(**inputs)
        last_hidden_state = self.dropout(last_hidden_state)
        last_hidden_state = self.activation(last_hidden_state)

        if self.use_conv:
            # Apply convolutional layers
            # Reshape from (batch, seq_len, hidden) to (batch, hidden, seq_len) for Conv1d
            conv_input = last_hidden_state.transpose(1, 2)
            conv_output = self.conv_layers(conv_input)
            # Pool the output of the conv layers to a fixed size vector
            pooled_output = torch.nn.functional.adaptive_avg_pool1d(conv_output, 1).squeeze(-1)
        else:
            # Use standard pooling on the transformer output
            pooled_output = self.pooler(inputs, last_hidden_state)

        # Get logits from the final classification head
        logits = self.deepsea_classifier(pooled_output)
        outputs = {"logits": logits, "last_hidden_state": last_hidden_state}

        # Calculate loss if labels are provided
        if labels is not None:
            loss = self.loss_fn(logits, labels.to(torch.float32))
            outputs["loss"] = loss

        return outputs

    def predict(self, sequence_or_inputs, **kwargs):
        """Generates predictions for a given sequence or tokenized input."""
        if not isinstance(sequence_or_inputs, (BatchEncoding, dict)):
            # If input is a raw sequence, tokenize it
            inputs = self.tokenizer(
                sequence_or_inputs,
                padding=kwargs.pop("padding", "max_length"),
                max_length=kwargs.pop("max_length", 1024),
                truncation=True,
                return_tensors="pt",
                **kwargs,
            )
        else:
            inputs = sequence_or_inputs
        
        # Move inputs to the correct device
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self(**inputs)
        
        # Convert logits to probabilities and then to binary predictions
        probabilities = self.sigmoid(outputs["logits"])
        predictions = (probabilities >= self.threshold).to(torch.int)

        return {
            "predictions": predictions,
            "probabilities": probabilities,
            "logits": outputs["logits"],
            "last_hidden_state": outputs["last_hidden_state"],
        }

print("OmniModelForMultiLabelClassification defined.")

OmniModelForMultiLabelClassification defined.
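Because each of the 919 labels is an independent yes/no decision, the model uses `BCEWithLogitsLoss` (one sigmoid per label) rather than a softmax across labels, and `predict` thresholds each probability at `self.threshold`. The pure-Python sketch below reproduces that computation on made-up logits, using the numerically stable form `max(x, 0) - x*y + log(1 + exp(-|x|))` of binary cross-entropy with logits. It illustrates the math only and is not PyTorch's implementation.

```python
import math

def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over all labels, computed stably from raw logits."""
    losses = [
        max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
        for x, y in zip(logits, targets)
    ]
    return sum(losses) / len(losses)

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def threshold_predictions(logits, threshold=0.5):
    """Convert logits to per-label probabilities and 0/1 predictions."""
    probabilities = [sigmoid(x) for x in logits]
    predictions = [int(p >= threshold) for p in probabilities]
    return probabilities, predictions

# Toy example with 4 labels (a real DeepSEA output has 919)
logits = [2.0, -1.0, 0.0, -3.0]
targets = [1.0, 0.0, 1.0, 0.0]
probs, preds = threshold_predictions(logits)
print(preds)  # [1, 0, 1, 0]: sigmoid(0.0) == 0.5 meets the >= 0.5 threshold
print(round(bce_with_logits(logits, targets), 4))
```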


## 5. Data Loading and Preprocessing

This section handles data loading and preprocessing. We create a `DeepSEADataset` class that inherits from `OmniDataset`, which is relied on to load the `.npy` file passed as `data_source`. The subclass overrides `prepare_input` to convert each DNA sequence (a string or a one-hot array) into the space-separated, single-nucleotide format expected by the OmniGenome tokenizer, crop it to the central `max_length` bases, and attach the 919 labels as a float tensor.

In [5]:


class DeepSEADataset(OmniDataset):
    """
    Dataset for the DeepSEA task; converts DNA sequences into token sequences.
    """

    def __init__(self, data_source, tokenizer, max_length=None, **kwargs):
        super().__init__(data_source, tokenizer, max_length, **kwargs)
        for key, value in kwargs.items():
            self.metadata[key] = value

    def prepare_input(self, instance, **kwargs):
        """
        Prepare a single DeepSEA example for the model.

        Expected instance format (a raw sequence string is also accepted):
        {
            'sequence' (or 'seq'): DNA sequence string (e.g., "ATCGATCG...")
                                   or a one-hot array of shape (seq_len, 4)
            'labels' (or 'label'): binary labels as a numpy array of shape (919,)
        }
        """
        labels = None
        if isinstance(instance, str):
            sequence = instance
        elif isinstance(instance, dict):
            sequence = instance["seq"] if "seq" in instance else instance.get("sequence", None)
            labels = instance.get("labels", None)
            if labels is None:
                labels = instance.get("label", None)
        else:
            raise ValueError(f"Unknown instance format: {type(instance)}")

        if sequence is None:
            raise ValueError("Sequence is required")

        if isinstance(sequence, np.ndarray) and sequence.ndim == 2 and sequence.shape[1] == 4:
            # Decode a one-hot matrix back to bases. This column order must match
            # the encoding used in your .npy files; verify it before training.
            base_map = {0: 'A', 1: 'T', 2: 'C', 3: 'G'}
            sequence = ''.join(base_map[int(i)] for i in np.argmax(sequence, axis=1))
        elif not isinstance(sequence, str):
            raise ValueError(f"Unsupported sequence format: {type(sequence)}")

        # Keep the central max_length bases (for 1000 bp DeepSEA inputs and
        # max_length=200, bases 400-599). Crop BEFORE inserting spaces: slicing
        # the spaced string would count the spaces as characters.
        if self.max_length is not None and len(sequence) > self.max_length:
            start = len(sequence) // 2 - self.max_length // 2
            sequence = sequence[start:start + self.max_length]
        spaced_sequence = ' '.join(sequence)

        tokenized_inputs = self.tokenizer(
            spaced_sequence,
            padding="do_not_pad",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            **kwargs,
        )

        # Drop the batch dimension added by return_tensors="pt"
        for col in tokenized_inputs:
            tokenized_inputs[col] = tokenized_inputs[col].squeeze(0)

        if labels is not None:
            if isinstance(labels, np.ndarray):
                labels = torch.from_numpy(labels).float()
            elif not isinstance(labels, torch.Tensor):
                labels = torch.tensor(labels, dtype=torch.float32)

            tokenized_inputs["labels"] = labels

        return tokenized_inputs
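The two preprocessing steps in `prepare_input`, decoding a one-hot matrix and keeping the central window before inserting spaces, can be checked in isolation. The sketch below uses plain Python lists in place of NumPy arrays; `decode_one_hot` and `center_crop` are hypothetical helper names, and the `A, T, C, G` column order simply mirrors `base_map` in the dataset class, so confirm it against your own `.npy` encoding.

```python
BASE_MAP = {0: 'A', 1: 'T', 2: 'C', 3: 'G'}  # same column order as base_map in DeepSEADataset

def decode_one_hot(rows):
    """Turn a (seq_len, 4) one-hot matrix (list of lists) into a base string.

    Like np.argmax, ties (e.g. an all-zero 'N' row) resolve to the first column.
    """
    return ''.join(BASE_MAP[max(range(4), key=row.__getitem__)] for row in rows)

def center_crop(sequence, max_length):
    """Keep the central max_length bases; shorter sequences are returned unchanged."""
    if len(sequence) <= max_length:
        return sequence
    start = len(sequence) // 2 - max_length // 2
    return sequence[start:start + max_length]

one_hot = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
print(decode_one_hot(one_hot))                   # ATCG
print(' '.join(center_crop("AAATTTCCCGGG", 4)))  # T T C C

# Why crop before spacing: the same slice on the spaced string covers only 100 bases
seq_1000 = "A" * 400 + "C" * 200 + "G" * 400
spaced_slice = ' '.join(seq_1000)[400:600]
print(len(spaced_slice.replace(' ', '')))        # 100
```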



## 6. Initialization

Now, let's initialize the tokenizer, the model, and the datasets. This step brings everything together and prepares for the training phase.

In [None]:
# 1. Initialize Tokenizer and Model
print("--- Initializing Tokenizer and Model ---")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, trust_remote_code=True)
base_model = AutoModel.from_pretrained(MODEL_NAME_OR_PATH, trust_remote_code=True)

model = OmniModelForMultiLabelClassification(
    base_model,
    tokenizer,
    num_labels=919,  # DeepSEA has 919 binary labels for different chromatin features
    threshold=0.5,
    use_conv=USE_CONV_LAYERS
)
model.to(DEVICE).to(torch.float32)  # Move the model to the selected device and cast weights to float32

# 2. Create Datasets
print("\n--- Creating Datasets ---")
train_set = DeepSEADataset(
    data_source=TRAIN_FILE,
    tokenizer=tokenizer,
    max_length=MAX_LENGTH,
    max_examples=MAX_EXAMPLES,
)
test_set = DeepSEADataset(
    data_source=TEST_FILE,
    tokenizer=tokenizer,
    max_length=MAX_LENGTH,
    max_examples=MAX_EXAMPLES,
)
valid_set = DeepSEADataset(
    data_source=VALID_FILE,
    tokenizer=tokenizer,
    max_length=MAX_LENGTH,
    max_examples=MAX_EXAMPLES,
) if os.path.exists(VALID_FILE) else None

print("\n--- Initialization Complete ---")
print(f"Training set size: {len(train_set)}")
print(f"Test set size: {len(test_set)}")
if valid_set:
    print(f"Validation set size: {len(valid_set)}")

## 7. Training the Model

With everything set up, we can now train the model. We'll use the `AccelerateTrainer` for a streamlined and efficient training loop. The trainer handles the training loop, evaluation, early stopping, and device placement automatically.

In [None]:
# Set random seed for reproducibility across all libraries
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Create DataLoaders for batching
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=BATCH_SIZE) if valid_set else None

# Define the metric for evaluation. For DeepSEA, ROC AUC is a standard metric.
metrics = [ClassificationMetric(ignore_y=-100).roc_auc_score]

# Create the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Initialize the Trainer
trainer = AccelerateTrainer(
    model=model,
    train_loader=train_loader,
    eval_loader=valid_loader, # Use validation set for early stopping and checkpointing
    test_loader=test_loader,
    optimizer=optimizer,
    epochs=EPOCHS,
    metrics=metrics,
    patience=PATIENCE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    device=DEVICE
)

# Start Training
print("--- Starting Training ---")
trainer.train()
print("--- Training Finished ---")
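`PATIENCE` controls early stopping. The sketch below shows the usual patience rule (stop after `patience` consecutive epochs without a new best validation score, and keep the best epoch) on a made-up list of validation ROC AUC scores. It is a generic illustration under that assumption, not the internal logic of `AccelerateTrainer`.

```python
def early_stopping_epochs(val_scores, patience):
    """Return (best_epoch, stop_epoch) for a higher-is-better validation metric."""
    best_score = float("-inf")
    best_epoch = -1
    epochs_without_improvement = 0
    for epoch, score in enumerate(val_scores):
        if score > best_score:
            best_score, best_epoch = score, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return best_epoch, epoch  # stop early
    return best_epoch, len(val_scores) - 1  # ran through every epoch

# Made-up validation ROC AUC scores, one per epoch (0-indexed)
scores = [0.80, 0.85, 0.84, 0.83, 0.82, 0.90]
best, stopped = early_stopping_epochs(scores, patience=3)
print(f"Best epoch: {best}, stopped after epoch: {stopped}")
```

Here training stops after epoch 4 and the later 0.90 is never reached, which is the trade-off a small patience makes.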

## 8. Evaluation

After training is complete, the `AccelerateTrainer` automatically loads the best-performing model (selected on the validation set, when one is available). We can now evaluate this model on the held-out test set to get a final, unbiased measure of its performance.

In [None]:
print("--- Evaluating on Test Set ---")
test_results = trainer.evaluate(loader=test_loader)
print("\nTest Set Performance (based on the best model from training):")
for metric, value in test_results.items():
    print(f"  {metric}: {value:.4f}")
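ROC AUC for a multi-label task is usually computed per label and then averaged. The sketch below is a pure-Python, macro-averaged version using the pairwise (Mann-Whitney) definition of AUC: the probability that a random positive scores higher than a random negative, with ties counted as one half. Entries equal to `ignore_y` are masked out and labels with only one class present are left out of the average; this mirrors `ClassificationMetric(ignore_y=-100)` only by assumption, since its exact aggregation is not shown here. The pairwise loop is quadratic per label, so use a library implementation on the full test set.

```python
def roc_auc_binary(y_true, y_score):
    """Pairwise ROC AUC for one label; returns None if only one class is present."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    if not pos or not neg:
        return None
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))

def macro_roc_auc(y_true, y_score, ignore_y=-100):
    """Average per-label AUC over labels that have both classes after masking."""
    aucs = []
    for j in range(len(y_true[0])):
        pairs = [(t[j], s[j]) for t, s in zip(y_true, y_score) if t[j] != ignore_y]
        auc = roc_auc_binary([t for t, _ in pairs], [s for _, s in pairs])
        if auc is not None:
            aucs.append(auc)
    return sum(aucs) / len(aucs) if aucs else float("nan")

# Toy data: 4 sequences x 2 labels
y_true = [[1, 0], [0, 1], [1, 1], [0, 0]]
y_score = [[0.9, 0.2], [0.1, 0.8], [0.8, 0.15], [0.3, 0.1]]
print(macro_roc_auc(y_true, y_score))  # 0.875 (label 0: 1.0, label 1: 0.75)
```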

## 9. Inference Example

Finally, let's see how to use the fine-tuned model to make a prediction on a new, unseen DNA sequence. This demonstrates the practical application of the trained model.

In [None]:
# Create a sample DNA sequence of MAX_LENGTH bp, matching the input length used in training
sample_sequence = "AGCT" * (MAX_LENGTH // 4)

# Prepare the sequence for the model (add spaces between characters)
spaced_sequence = ' '.join(list(sample_sequence))
inputs = tokenizer(spaced_sequence, return_tensors="pt", max_length=MAX_LENGTH, truncation=True)

# Set the model to evaluation mode
model.eval()

# Make a prediction
with torch.no_grad():
    outputs = model.predict(inputs)

# Get the predictions and probabilities
predictions = outputs['predictions'].cpu().numpy().flatten()
probabilities = outputs['probabilities'].cpu().numpy().flatten()

print(f"Input sequence length: {len(sample_sequence)} bp")
print(f"Number of predicted labels: {len(predictions)}")

# Display predictions for the first 10 of the 919 chromatin-feature labels
print("\n--- Predictions for the first 10 labels ---")
for i in range(10):
    pred_label = 'Positive' if predictions[i] == 1 else 'Negative'
    print(f"Label {i+1}: Prediction={pred_label}, Probability={probabilities[i]:.4f}")
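With 919 outputs, ranking labels by probability is often more informative than printing the first few. The sketch below uses a hypothetical pure-Python helper, `top_k_labels`, that returns the indices of the `k` highest-probability labels and whether each passes the 0.5 threshold used above. In the notebook you would pass `probabilities.tolist()`; mapping indices to DeepSEA feature names requires a label list that this notebook does not load.

```python
def top_k_labels(probabilities, k=5, threshold=0.5):
    """Return (label_index, probability, is_positive) for the k most probable labels."""
    ranked = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)
    return [(i, probabilities[i], probabilities[i] >= threshold) for i in ranked[:k]]

# Made-up probabilities for 5 labels
example_probabilities = [0.10, 0.72, 0.40, 0.91, 0.05]
for index, prob, positive in top_k_labels(example_probabilities, k=3):
    print(f"Label {index + 1}: probability={prob:.2f}, positive={positive}")
```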