# Transcription Factor Binding Prediction with OmniGenBench

This notebook shows, step by step, how to use OmniGenBench for transcription factor binding (TFB) prediction by fine-tuning the **OmniGenome-52M** model on the **DeepSEA dataset**. The task is multi-label classification: for each DNA sequence, the model predicts a set of binary labels that indicate transcription factor binding and other chromatin features.

**Dataset Description:**
The dataset used in this notebook is derived from the DeepSEA dataset, which was designed for studying the effects of non-coding variants. It consists of 1000 bp DNA sequences. Each sequence carries 919 binary labels, one per chromatin feature (transcription factor binding, DNase I sensitivity, and histone marks). This notebook uses a preprocessed version, [`deepsea_tfb_prediction`](https://huggingface.co/datasets/yangheng/deepsea_tfb_prediction), hosted on Hugging Face.
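The sketch below makes "one sequence, 919 binary labels" concrete and shows how a subset of labels can be selected by index, which is the role `LABEL_INDICES` plays in the configuration. This notebook does not show the actual JSONL schema, so the field names `"sequence"` and `"label"` and the helper `parse_record` are hypothetical.

```python
import json
from typing import List, Tuple


def parse_record(line: str, label_indices: List[int]) -> Tuple[str, List[int]]:
    """Parse one JSONL line and keep only the labels listed in label_indices."""
    record = json.loads(line)
    sequence = record["sequence"]  # hypothetical field name
    all_labels = record["label"]   # hypothetical field name: 919 values of 0/1
    return sequence, [int(all_labels[i]) for i in label_indices]


# Toy record: a short sequence with 919 labels, only label 3 positive.
toy_labels = [0] * 919
toy_labels[3] = 1
toy_line = json.dumps({"sequence": "ACGT" * 5, "label": toy_labels})

seq, labels = parse_record(toy_line, label_indices=list(range(10)))
print(len(seq), labels)  # 20 [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
```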

**Estimated Runtime:**
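The total runtime depends on the hardware and the number of training examples (`MAX_EXAMPLES`). On a single NVIDIA RTX 4090 GPU, a full-scale run (`MAX_EXAMPLES=100000`, `EPOCHS=10`) takes approximately **1–2 hours**. The configuration used in this notebook (`MAX_EXAMPLES=1000`, up to `EPOCHS=50` with early stopping) is a quick test run and should take about **5–10 minutes** of training. Data loading also takes time: as the log in Section 4 shows, each JSONL split is read in full before it is truncated to `MAX_EXAMPLES`. The test split alone contains 455,024 examples.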
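You can estimate how runtime scales with `MAX_EXAMPLES` by counting optimizer steps. This sketch assumes the training loader keeps the last partial batch (no `drop_last`). Under that assumption, each epoch takes ceil(MAX_EXAMPLES / BATCH_SIZE) steps. With the notebook's `MAX_EXAMPLES = 1000` and `BATCH_SIZE = 64`, that gives 16 steps, which matches the `16/16` progress bars in the training log. The helper names are illustrative only.

```python
import math


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    """Batches per epoch, assuming the final partial batch is kept."""
    return math.ceil(num_examples / batch_size)


def total_steps(num_examples: int, batch_size: int, epochs: int) -> int:
    """Upper bound on optimizer steps if early stopping never triggers."""
    return steps_per_epoch(num_examples, batch_size) * epochs


print(steps_per_epoch(1_000, 64))    # 16
print(steps_per_epoch(100_000, 64))  # 1563
```

Going from 1,000 to 100,000 examples therefore multiplies the work per epoch by roughly 98x, so per-epoch time grows accordingly.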


## Notebook Structure

This notebook is organized into short sections. Most of the core logic lives in [`examples/tfb_prediction/utils.py`](https://github.com/COLA-Laboratory/OmniGenBench/blob/master/examples/tfb_prediction/utils.py) and is imported here:

1. **Setup & Installation**: Ensures all required libraries and dependencies are installed.
2. **Import Libraries**: Loads the necessary Python libraries for genomic data processing, model inference, and analysis.
3. **Configuration**: Defines key parameters such as file paths, model selection, and training hyperparameters.
4. **Model and Dataset Initialization**: Initializes the tokenizer, model, datasets, and data loaders.
5. **Finetuning**: Fine-tunes the model using `AccelerateTrainer` via utility functions.
6. **Inference Example**: Uses the trained model to make predictions on a new DNA sequence.

Run the cells in order, because each section depends on variables defined in the previous ones.

## 1. Setup & Installation

First, make sure all required packages are installed. If they already are, you can skip this cell. Otherwise, run it to install the dependencies.
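To decide whether the install cell is needed, you can check which packages are already present using only the standard library. The package list below mirrors the libraries this notebook imports. `installed_versions` is an illustrative helper, not part of OmniGenBench.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages this notebook installs or imports directly.
REQUIRED = ["numpy", "transformers", "omnigenbench", "autocuda", "findfile"]


def installed_versions(packages):
    """Return {package: installed version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


missing = [name for name, ver in installed_versions(REQUIRED).items() if ver is None]
print("Missing packages:", missing or "none")
```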

In [1]:
!pip install -U numpy transformers omnigenbench autocuda findfile




## 2. Import Libraries

Import all the necessary libraries for genomic data processing, model inference, and analysis.

In [2]:
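import importlib.util
import sys

import autocuda
import findfile

# Load the helper module utils.py from the current working directory and register
# it in sys.modules so that `from utils import ...` below resolves to this file.
utils_spec = importlib.util.spec_from_file_location("utils", "utils.py")
utils = importlib.util.module_from_spec(utils_spec)
sys.modules["utils"] = utils
utils_spec.loader.exec_module(utils)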

# Import reusable interfaces from local utils
from utils import (
    download_deepsea_dataset,
    load_tokenizer_and_model,
    build_datasets,
    create_dataloaders,
    run_finetuning,
    run_inference,
)

print("Libraries imported successfully.")






Libraries imported successfully.
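The import cell loads `utils.py` by file path instead of relying on `sys.path`. The self-contained sketch below shows the same standard-library pattern on a throwaway file, so you can see why `from utils import ...` works afterwards. The module name `toy_utils` and its `greet` function are hypothetical. Registering the module in `sys.modules` before executing it follows the recipe in the Python `importlib` documentation.

```python
import importlib.util
import sys
import tempfile
from pathlib import Path


def load_module_from_path(name: str, path: Path):
    """Load a .py file as a module and register it under `name` in sys.modules."""
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load {name!r} from {path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module  # register first so `from name import ...` finds it
    spec.loader.exec_module(module)
    return module


with tempfile.TemporaryDirectory() as tmp:
    helper = Path(tmp) / "toy_utils.py"
    helper.write_text("def greet(who):\n    return f'hello, {who}'\n")
    load_module_from_path("toy_utils", helper)
    from toy_utils import greet  # resolved via sys.modules, not sys.path

print(greet("OmniGenBench"))  # hello, OmniGenBench
```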


## 3. Configuration

Here, we define all the hyperparameters and settings for our experiment. This centralized configuration makes it easy to modify parameters and track experiments.

In [3]:
# --- Data File Paths ---
LOCAL_PATH = "deepsea_tfb_prediction"
download_deepsea_dataset(LOCAL_PATH)
TRAIN_FILE = findfile.find_cwd_file(['train', 'jsonl'])
TEST_FILE = findfile.find_cwd_file(['test', 'jsonl'])
VALID_FILE = findfile.find_cwd_file(['valid', 'jsonl'])

# --- Available Models for Testing ---
AVAILABLE_MODELS = [
    'yangheng/OmniGenome-52M',
    'yangheng/OmniGenome-186M',
    'yangheng/OmniGenome-v1.5',
    # You can add more models here as needed
]

MODEL_NAME_OR_PATH = AVAILABLE_MODELS[0]
USE_CONV_LAYERS = False  # Set to True to add DeepSEA-style convolutional layers on top of OmniGenome (not used in this demo)

# --- Training Hyperparameters ---
EPOCHS = 50
LEARNING_RATE = 5e-5
WEIGHT_DECAY = 1e-3
BATCH_SIZE = 64
PATIENCE = 3  # For early stopping
MAX_LENGTH = 200  # The length of the DNA sequence to be processed
SEED = 45
# LABEL_INDICES = list(range(10))  # Example: restrict training to the first 10 labels
LABEL_INDICES = list(range(919))  # All 919 DeepSEA labels
MAX_EXAMPLES = 1000  # Use a smaller number for quick testing (e.g., 1000), or None for all data

DEVICE = autocuda.auto_cuda()
print(f"Using device: {DEVICE}")


Downloading deepsea_tfb_prediction.zip from https://huggingface.co/datasets/yangheng/deepsea_tfb_prediction/resolve/main/deepsea_tfb_prediction.zip...
Downloaded deepsea_tfb_prediction\deepsea_tfb_prediction.zip
Extracted deepsea_tfb_prediction.zip into deepsea_tfb_prediction
Using device: cuda:0
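The configuration cell keeps its settings as module-level constants. To track experiments, you can also group those settings in one object and save a snapshot of them alongside each run's results. The `TFBConfig` dataclass below is a hypothetical container that mirrors the constants above. It is not an OmniGenBench class. `dataclasses.replace` derives a variant without touching the original.

```python
import json
from dataclasses import asdict, dataclass, field, replace
from typing import List, Optional


@dataclass(frozen=True)
class TFBConfig:
    """Hypothetical container mirroring this notebook's configuration constants."""
    model_name_or_path: str = "yangheng/OmniGenome-52M"
    epochs: int = 50
    learning_rate: float = 5e-5
    weight_decay: float = 1e-3
    batch_size: int = 64
    patience: int = 3
    max_length: int = 200
    seed: int = 45
    max_examples: Optional[int] = 1000
    label_indices: List[int] = field(default_factory=lambda: list(range(919)))


config = TFBConfig()
# A JSON snapshot of the exact settings, suitable for saving next to the results.
snapshot = json.dumps(asdict(config), sort_keys=True)
# A smaller variant for smoke tests; `config` itself is unchanged.
quick_test = replace(config, max_examples=100, epochs=2)
```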


## 4. Model and Dataset Initialization

Initialize the tokenizer and model, then build the datasets and data loaders with the helper functions from `utils.py`.

In [4]:

# 1. Initialize Tokenizer and Model
print("--- Initializing Tokenizer and Model ---")

# Use utility to load tokenizer and model
label_count = len(LABEL_INDICES)
tokenizer, model = load_tokenizer_and_model(
    MODEL_NAME_OR_PATH,
    num_labels=label_count,
    threshold=0.5,
    device=DEVICE,
)

# 2. Create Datasets via utility
print("\n--- Creating Datasets ---")
train_set, valid_set, test_set = build_datasets(
    tokenizer=tokenizer,
    train_file=TRAIN_FILE,
    test_file=TEST_FILE,
    valid_file=VALID_FILE,
    max_length=MAX_LENGTH,
    max_examples=MAX_EXAMPLES,
    label_indices=LABEL_INDICES,
)

# Create DataLoaders for batching (utils)
train_loader, valid_loader, test_loader = create_dataloaders(
    train_set=train_set,
    valid_set=valid_set,
    test_set=test_set,
    batch_size=BATCH_SIZE,
)

print("\n--- Initialization Complete ---")
print(f"Training set size: {len(train_set)}")
print(f"Test set size: {len(test_set)}")
if valid_set:
    print(f"Validation set size: {len(valid_set)}")


--- Initializing Tokenizer and Model ---


Some weights of OmniGenomeModel were not initialized from the model checkpoint at yangheng/OmniGenome-52M and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[2025-09-11 16:20:17.923] [omnigenbench 0.3.11alpha2]  Model Name: OmniModelForMultiLabelSequenceClassification
Model Metadata: {'library_name': 'omnigenbench', 'omnigenbench_version': '0.3.11alpha2', 'torch_version': '2.8.0+cu129+cu12.9+gita1cb3cc05d46d198467bebbb6e8fba50a325d4e7', 'transformers_version': '4.56.1', 'model_cls': 'OmniModelForMultiLabelSequenceClassification', 'tokenizer_cls': 'EsmTokenizer', 'model_name': 'OmniModelForMultiLabelSequenceClassification'}
Base Model Name: yangheng/OmniGenome-52M
Model Type: omnigenome
Model Architecture: None
Model Parameters: 52.453345 M
Model Config: OmniGenomeConfig {
  "OmniGenomefold_config": null,
  "attention_probs_dropout_prob": 0.0,
  "auto_map": {
    "AutoConfig": "configuration_omnigenome.OmniGenomeConfig",
    "AutoModel": "modeling_omnigenome.OmniGenomeModel",
    "AutoModelForMaskedLM": "modeling_omnigenome.OmniGenomeForMaskedLM",
    "AutoModelForSeq2SeqLM": "modeling_omnigenome.OmniGenomeForSeq2SeqLM",
    "AutoModelForSe

100%|██████████| 1000/1000 [00:01<00:00, 836.93it/s]


[2025-09-11 16:33:14.814] [omnigenbench 0.3.11alpha2]  All keys have consistent sequence lengths, skipping padding and truncation.
[2025-09-11 16:33:14.817] [omnigenbench 0.3.11alpha2]  Detected max_length=200 in the dataset, using it as the max_length.
[2025-09-11 16:33:14.821] [omnigenbench 0.3.11alpha2]  Loading data from deepsea_tfb_prediction\test.jsonl...
[2025-09-11 16:34:11.702] [omnigenbench 0.3.11alpha2]  Loaded 455024 examples from deepsea_tfb_prediction\test.jsonl
[2025-09-11 16:34:11.710] [omnigenbench 0.3.11alpha2]  Detected shuffle=True, shuffling the examples...
[2025-09-11 16:34:12.281] [omnigenbench 0.3.11alpha2]  Detected max_examples=1000, truncating the examples...


100%|██████████| 1000/1000 [00:02<00:00, 382.92it/s]


[2025-09-11 16:34:25.905] [omnigenbench 0.3.11alpha2]  All keys have consistent sequence lengths, skipping padding and truncation.
[2025-09-11 16:34:25.909] [omnigenbench 0.3.11alpha2]  Detected max_length=200 in the dataset, using it as the max_length.
[2025-09-11 16:34:25.912] [omnigenbench 0.3.11alpha2]  Loading data from valid.jsonl...
[2025-09-11 16:34:27.913] [omnigenbench 0.3.11alpha2]  Loaded 8000 examples from valid.jsonl
[2025-09-11 16:34:27.919] [omnigenbench 0.3.11alpha2]  Detected shuffle=True, shuffling the examples...
[2025-09-11 16:34:27.931] [omnigenbench 0.3.11alpha2]  Detected max_examples=1000, truncating the examples...


100%|██████████| 1000/1000 [00:02<00:00, 382.27it/s]


[2025-09-11 16:34:30.579] [omnigenbench 0.3.11alpha2]  All keys have consistent sequence lengths, skipping padding and truncation.

--- Initialization Complete ---
Training set size: 1000
Test set size: 1000
Validation set size: 1000
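The log above reports that padding and truncation were skipped because all sequences already had consistent lengths. When lengths differ, fixed-length encoding works in three steps:

1. Map each nucleotide to an id.
2. Truncate the ids to `max_length`.
3. Right-pad the ids and record which positions are real in an attention mask.

This is a toy sketch under stated assumptions. `TOY_VOCAB` and `encode_fixed_length` are hypothetical. The real OmniGenome tokenizer uses its own vocabulary and also adds special tokens.

```python
from typing import Dict, List

# Toy vocabulary for illustration only; not OmniGenome's real token ids.
TOY_VOCAB: Dict[str, int] = {"<pad>": 0, "<unk>": 1, "A": 2, "C": 3, "G": 4, "T": 5, "N": 6}


def encode_fixed_length(sequence: str, max_length: int) -> Dict[str, List[int]]:
    """Encode nucleotides to ids, truncate to max_length, then right-pad with a mask."""
    ids = [TOY_VOCAB.get(base, TOY_VOCAB["<unk>"]) for base in sequence.upper()]
    ids = ids[:max_length]
    attention_mask = [1] * len(ids)   # 1 marks real tokens
    pad = max_length - len(ids)
    ids += [TOY_VOCAB["<pad>"]] * pad
    attention_mask += [0] * pad       # 0 marks padding the model should ignore
    return {"input_ids": ids, "attention_mask": attention_mask}


example = encode_fixed_length("AGCT" * 60, max_length=200)  # 240 bp, truncated to 200
print(len(example["input_ids"]), sum(example["attention_mask"]))  # 200 200
```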


## 5. Finetuning

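Fine-tune the model using `AccelerateTrainer` (invoked through the `run_finetuning` compatibility wrapper). When a validation set is provided, early stopping monitors the validation ROC AUC. After training, the model is evaluated on the test set and saved to `save_dir` (here `tfb_model`).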
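ROC AUC is the probability that a randomly chosen positive example receives a higher score than a randomly chosen negative one. A score of 0.5 means random ranking, so the first evaluation in the log below (about 0.48, before any training) is at chance level.

The sketch computes AUC with this pairwise definition and then macro-averages it across labels, skipping labels that have only one class in the batch. Whether OmniGenBench averages labels this way is an assumption, and `binary_roc_auc` and `macro_roc_auc` are illustrative helpers. The pairwise loop is O(positives × negatives), which is fine for a demonstration but not for large evaluation sets.

```python
from typing import List, Sequence


def binary_roc_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """AUC as P(random positive outranks random negative); ties count as 1/2."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC is undefined when only one class is present")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def macro_roc_auc(label_matrix: List[List[int]], score_matrix: List[List[float]]) -> float:
    """Mean per-label AUC over labels with both classes (rows are examples)."""
    per_label = []
    for j in range(len(label_matrix[0])):
        y = [row[j] for row in label_matrix]
        s = [row[j] for row in score_matrix]
        if 0 < sum(y) < len(y):  # skip labels that are all 0 or all 1
            per_label.append(binary_roc_auc(y, s))
    if not per_label:
        raise ValueError("No label has both positive and negative examples")
    return sum(per_label) / len(per_label)


print(binary_roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```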

In [5]:


# Train with utilities
print("--- Starting Training ---")
trainer, metrics_best = run_finetuning(
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    test_loader=test_loader,
    epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    patience=PATIENCE,
    device=DEVICE,
    save_dir="tfb_model",
)
print(metrics_best)
print("--- Training Finished ---")


--- Starting Training ---


Evaluating: 100%|██████████| 16/16 [00:01<00:00,  8.35it/s]


[2025-09-11 16:34:33.933] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.4770981113171884}


Epoch 1/50 Loss: 0.6287: 100%|██████████| 16/16 [00:05<00:00,  3.19it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00, 11.70it/s]


[2025-09-11 16:34:41.093] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.5124675583581603}


Epoch 2/50 Loss: 0.5206: 100%|██████████| 16/16 [00:04<00:00,  3.99it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00, 11.12it/s]


[2025-09-11 16:34:47.421] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.507466398037488}


Epoch 3/50 Loss: 0.4018: 100%|██████████| 16/16 [00:03<00:00,  4.42it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00, 11.94it/s]


[2025-09-11 16:34:53.113] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.5075632543389192}


Epoch 4/50 Loss: 0.2979: 100%|██████████| 16/16 [00:03<00:00,  4.37it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00, 11.56it/s]


[2025-09-11 16:34:58.562] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.5163629917472332}


Epoch 5/50 Loss: 0.2263: 100%|██████████| 16/16 [00:03<00:00,  4.21it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00,  9.97it/s]


[2025-09-11 16:35:04.780] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.5365819806946566}


Epoch 6/50 Loss: 0.1826: 100%|██████████| 16/16 [00:04<00:00,  3.40it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00, 10.12it/s]


[2025-09-11 16:35:12.896] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.5637128430493652}


Epoch 7/50 Loss: 0.1575: 100%|██████████| 16/16 [00:05<00:00,  2.99it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00, 10.88it/s]


[2025-09-11 16:35:20.511] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.592799494799054}


Epoch 8/50 Loss: 0.1401: 100%|██████████| 16/16 [00:05<00:00,  3.19it/s]
Evaluating: 100%|██████████| 16/16 [00:02<00:00,  6.57it/s]


[2025-09-11 16:35:28.851] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.6204445193902365}


Epoch 9/50 Loss: 0.1290: 100%|██████████| 16/16 [00:06<00:00,  2.57it/s]
Evaluating: 100%|██████████| 16/16 [00:02<00:00,  7.82it/s]


[2025-09-11 16:35:38.176] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.644856646164891}


Epoch 10/50 Loss: 0.1223: 100%|██████████| 16/16 [00:04<00:00,  3.63it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00, 10.46it/s]


[2025-09-11 16:35:45.040] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.6646091523802289}


Epoch 11/50 Loss: 0.1162: 100%|██████████| 16/16 [00:04<00:00,  3.56it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00, 10.58it/s]


[2025-09-11 16:35:52.015] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.6836898404059145}


Epoch 12/50 Loss: 0.1130: 100%|██████████| 16/16 [00:04<00:00,  3.78it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00, 10.01it/s]


[2025-09-11 16:35:58.991] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.6977699848315684}


Epoch 13/50 Loss: 0.1094: 100%|██████████| 16/16 [00:03<00:00,  4.27it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00, 10.22it/s]


[2025-09-11 16:36:05.216] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.7117147373225314}


Epoch 14/50 Loss: 0.1066: 100%|██████████| 16/16 [00:03<00:00,  4.19it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00, 10.33it/s]


[2025-09-11 16:36:11.567] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.7226101104961575}


Epoch 15/50 Loss: 0.1050: 100%|██████████| 16/16 [00:03<00:00,  4.16it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00, 10.69it/s]


[2025-09-11 16:36:17.803] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.7322176116886996}


Epoch 16/50 Loss: 0.1042: 100%|██████████| 16/16 [00:04<00:00,  3.55it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00,  9.59it/s]


[2025-09-11 16:36:25.013] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.7385764215169774}


Epoch 17/50 Loss: 0.1017: 100%|██████████| 16/16 [00:04<00:00,  3.27it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00,  9.79it/s]


[2025-09-11 16:36:32.315] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.7464615646788201}


Epoch 18/50 Loss: 0.1014: 100%|██████████| 16/16 [00:04<00:00,  3.69it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00,  8.86it/s]


[2025-09-11 16:36:39.252] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.7515063267392464}


Epoch 19/50 Loss: 0.1006: 100%|██████████| 16/16 [00:04<00:00,  3.97it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00, 11.09it/s]


[2025-09-11 16:36:45.662] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.7559365654632169}


Epoch 20/50 Loss: 0.0993: 100%|██████████| 16/16 [00:03<00:00,  4.10it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00,  9.77it/s]


[2025-09-11 16:36:52.212] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.7602819931058815}


Epoch 21/50 Loss: 0.0994: 100%|██████████| 16/16 [00:03<00:00,  4.08it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00,  9.35it/s]


[2025-09-11 16:36:58.901] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.7634797103491975}


Epoch 22/50 Loss: 0.0978: 100%|██████████| 16/16 [00:04<00:00,  3.88it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00,  9.83it/s]


[2025-09-11 16:37:05.695] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.7663870917482092}


Epoch 23/50 Loss: 0.0983: 100%|██████████| 16/16 [00:04<00:00,  3.85it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00, 11.23it/s]


[2025-09-11 16:37:12.214] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.7692631164944693}


Epoch 24/50 Loss: 0.0970: 100%|██████████| 16/16 [00:03<00:00,  4.22it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00, 10.95it/s]


[2025-09-11 16:37:18.356] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.777992042339987}


Epoch 25/50 Loss: 0.0974: 100%|██████████| 16/16 [00:03<00:00,  4.21it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00, 11.13it/s]


[2025-09-11 16:37:24.396] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.772284442612742}


Epoch 26/50 Loss: 0.0968: 100%|██████████| 16/16 [00:03<00:00,  4.35it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00, 12.29it/s]


[2025-09-11 16:37:29.773] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.7736728102829498}


Epoch 27/50 Loss: 0.0974: 100%|██████████| 16/16 [00:03<00:00,  4.12it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00, 10.76it/s]


[2025-09-11 16:37:35.548] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.7748346518647681}


Epoch 28/50 Loss: 0.0963: 100%|██████████| 16/16 [00:03<00:00,  4.34it/s]
Evaluating: 100%|██████████| 16/16 [00:01<00:00, 10.95it/s]


[2025-09-11 16:37:41.146] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.7762962999817741}
[2025-09-11 16:37:41.154] [omnigenbench 0.3.11alpha2]  Early stopping at epoch 28.


Testing: 100%|██████████| 16/16 [00:01<00:00, 10.19it/s]


[2025-09-11 16:37:43.554] [omnigenbench 0.3.11alpha2]  {'roc_auc_score': 0.7819343077104175}
[2025-09-11 16:37:46.271] [omnigenbench 0.3.11alpha2]  The model is saved to tfb_model.
{'valid': [{'roc_auc_score': 0.4770981113171884}, {'roc_auc_score': 0.5124675583581603}, {'roc_auc_score': 0.507466398037488}, {'roc_auc_score': 0.5075632543389192}, {'roc_auc_score': 0.5163629917472332}, {'roc_auc_score': 0.5365819806946566}, {'roc_auc_score': 0.5637128430493652}, {'roc_auc_score': 0.592799494799054}, {'roc_auc_score': 0.6204445193902365}, {'roc_auc_score': 0.644856646164891}, {'roc_auc_score': 0.6646091523802289}, {'roc_auc_score': 0.6836898404059145}, {'roc_auc_score': 0.6977699848315684}, {'roc_auc_score': 0.7117147373225314}, {'roc_auc_score': 0.7226101104961575}, {'roc_auc_score': 0.7322176116886996}, {'roc_auc_score': 0.7385764215169774}, {'roc_auc_score': 0.7464615646788201}, {'roc_auc_score': 0.7515063267392464}, {'roc_auc_score': 0.7559365654632169}, {'roc_auc_score': 0.76028199310
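With `PATIENCE = 3`, the log above records the best validation ROC AUC at epoch 24 and stops at epoch 28, after four epochs without improvement. Here, index 0 is the evaluation before the first epoch.

The sketch below replays the logged scores, rounded to four decimals. It stops once the number of non-improving epochs exceeds `patience`. That rule is inferred from the log, not taken from the trainer's source code, and `early_stopping_epoch` is an illustrative helper.

```python
from typing import List, Tuple

# Validation ROC AUC per evaluation from the log above (index 0 = before training).
VALID_AUC = [
    0.4771, 0.5125, 0.5075, 0.5076, 0.5164, 0.5366, 0.5637, 0.5928,
    0.6204, 0.6449, 0.6646, 0.6837, 0.6978, 0.7117, 0.7226, 0.7322,
    0.7386, 0.7465, 0.7515, 0.7559, 0.7603, 0.7635, 0.7664, 0.7693,
    0.7780, 0.7723, 0.7737, 0.7748, 0.7763,
]


def early_stopping_epoch(scores: List[float], patience: int) -> Tuple[int, int]:
    """Return (best_epoch, stop_epoch); stop when stale epochs exceed `patience`."""
    best_score = float("-inf")
    best_epoch = 0
    stale = 0
    for epoch, score in enumerate(scores):
        if score > best_score:
            best_score, best_epoch, stale = score, epoch, 0
        else:
            stale += 1
            if stale > patience:
                return best_epoch, epoch
    return best_epoch, len(scores) - 1  # ran out of epochs without stopping


print(early_stopping_epoch(VALID_AUC, patience=3))  # (24, 28)
```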

## 6. Inference Example

Run a prediction on a single sequence using the fine-tuned model saved to `tfb_model`. The input goes through the same preprocessing path (`encode_tokens`) as the training data, so it matches what the model saw during training.

In [6]:

sample_sequence = "AGCT" * (MAX_LENGTH // 4)  # Construct sequence of required length

outputs = run_inference(
    model_dir="tfb_model",
    tokenizer=tokenizer,
    sample_sequence=sample_sequence,
    max_length=MAX_LENGTH,
    device=DEVICE,
)

predictions = outputs.get('predictions', None)
probabilities = outputs.get('probabilities', None)

print(f"Input sequence length: {len(sample_sequence)} bp")
if predictions is not None:
    print(f"Number of predicted labels: {len(predictions)}")
    print("\n--- Predictions for the first 10 TFs ---")
    for i in range(min(10, len(predictions))):
        pred_label = 'Binds' if int(predictions[i]) == 1 else 'Does not bind'
        if probabilities is not None:
            try:
                p = float(probabilities[i])
                print(f"Label {i+1}: Prediction={pred_label}, Prob={p:.4f}")
            except Exception:
                print(f"Label {i+1}: Prediction={pred_label}")
        else:
            print(f"Label {i+1}: Prediction={pred_label}")
else:
    print("No 'predictions' returned by model.inference; verify the saved model and inference API.")


[2025-09-11 16:37:50.733] [omnigenbench 0.3.11alpha2]  Model Name: OmniModelForMultiLabelSequenceClassification
Model Metadata: {'library_name': 'omnigenbench', 'omnigenbench_version': '0.3.11alpha2', 'torch_version': '2.8.0+cu129+cu12.9+gita1cb3cc05d46d198467bebbb6e8fba50a325d4e7', 'transformers_version': '4.56.1', 'model_cls': 'OmniModelForMultiLabelSequenceClassification', 'tokenizer_cls': 'EsmTokenizer', 'model_name': 'OmniModelForMultiLabelSequenceClassification'}
Base Model Name: tfb_model
Model Type: omnigenome
Model Architecture: ['OmniGenomeModel']
Model Parameters: 52.453345 M
Model Config: OmniGenomeConfig {
  "OmniGenomefold_config": null,
  "architectures": [
    "OmniGenomeModel"
  ],
  "attention_probs_dropout_prob": 0.0,
  "auto_map": {
    "AutoConfig": "configuration_omnigenome.OmniGenomeConfig",
    "AutoModel": "modeling_omnigenome.OmniGenomeModel",
    "AutoModelForMaskedLM": "modeling_omnigenome.OmniGenomeForMaskedLM",
    "AutoModelForSeq2SeqLM": "modeling_omnige

Attempting to cast a BatchEncoding to type torch.float32. This is not supported.


Input sequence length: 200 bp
Number of predicted labels: 919

--- Predictions for the first 10 TFs ---
Label 1: Prediction=Does not bind
Label 2: Prediction=Does not bind
Label 3: Prediction=Does not bind
Label 4: Prediction=Does not bind
Label 5: Prediction=Does not bind
Label 6: Prediction=Does not bind
Label 7: Prediction=Does not bind
Label 8: Prediction=Does not bind
Label 9: Prediction=Does not bind
Label 10: Prediction=Does not bind
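The model was loaded with `threshold=0.5`. Multi-label classifiers commonly score each label independently with a sigmoid and then compare each probability against a threshold. This sketch assumes OmniGenBench's multi-label head works that way, including the use of `>=` at the threshold. `multilabel_predict` is an illustrative helper, and the sigmoid is written in a numerically stable form so that large logits do not overflow.

Every label shown above reads "Does not bind". Chromatin labels are typically sparse, so a model trained on 1,000 examples can plausibly keep most probabilities below 0.5. Inspecting the probabilities or tuning per-label thresholds is often more informative than the 0/1 predictions alone.

```python
import math
from typing import List, Tuple


def stable_sigmoid(z: float) -> float:
    """Logistic function that avoids overflow for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def multilabel_predict(logits: List[float], threshold: float = 0.5) -> Tuple[List[int], List[float]]:
    """Independent sigmoid per label, then a per-label 0/1 decision at `threshold`."""
    probabilities = [stable_sigmoid(z) for z in logits]
    predictions = [1 if p >= threshold else 0 for p in probabilities]
    return predictions, probabilities


preds, probs = multilabel_predict([-3.0, 0.4, 2.5])
print(preds)  # [0, 1, 1]
```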
