# Transcription Factor Binding Prediction with OmniGenBench

This notebook provides a step-by-step guide to extend OmniGenBench to the TFB task based on the **OmniGenome-52M** model on the **DeepSEA dataset**. The goal is to perform multi-label classification to predict the binding sites of various transcription factors based on DNA sequences.

**Dataset Description:**
The dataset used in this notebook is derived from the DeepSEA dataset, which is designed for studying the effects of non-coding variants. It consists of DNA sequences of 1000 base pairs, each associated with 919 binary labels corresponding to various chromatin features (transcription factor binding, DNase I sensitivity, and histone marks). For this task, we use a preprocessed version available from the `yangheng/tfb_prediction` dataset on Hugging Face.

**Estimated Runtime:**
The total runtime for this notebook depends on the hardware and the number of training examples (`MAX_EXAMPLES`). On a single NVIDIA RTX 4090 GPU, training with the default settings (`MAX_EXAMPLES=100000`, `EPOCHS=10`) takes approximately **1 - 2 hours**. For a quick test run with `MAX_EXAMPLES=1000`, it should take about **5-10 minutes**.



## Notebook Structure

This notebook is organized into several sections, each focusing on a specific aspect of the Transcription Factor Binding (TFB) prediction pipeline. Below is an overview of the structure:

1. **Setup & Installation**: Ensures all required libraries and dependencies are installed.
2. **Import Libraries**: Loads the necessary Python libraries for genomic data processing, model inference, and analysis.
3. **Configuration**: Defines key parameters such as file paths, model selection, and training hyperparameters.
4. **Model Definition**: Implements a custom model class that integrates the OmniGenome backbone with a classification head tailored for the DeepSEA task.
5. **Data Loading and Preprocessing**: Handles the loading and preprocessing of the DeepSEA dataset, converting DNA sequences into tokenized inputs.
6. **Initialization**: Sets up the tokenizer, model, datasets, and data loaders for training and evaluation.
7. **Training the Model**: Fine-tunes the model using the `AccelerateTrainer` for efficient training and evaluation.
8. **Evaluation**: Assesses the model's performance on the test set using metrics such as ROC AUC.
9. **Inference Example**: Demonstrates how to use the trained model to make predictions on new DNA sequences.

Each section is designed to be modular, allowing for easy customization and extension. Follow the notebook sequentially to understand and execute the TFB prediction pipeline effectively.

## 1. Setup & Installation

First, let's ensure all the required packages are installed. If you have already installed them, you can skip this cell. Otherwise, uncomment and run the cell to install the dependencies.

In [5]:
# Uncomment the following line to install the necessary packages
# !pip install torch numpy transformers omnigenbench autocuda

## 2. Import Libraries

Import all the necessary libraries for genomic data processing, model inference, and analysis.

In [6]:
import random
import os
import autocuda
import findfile
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel, BatchEncoding
from omnigenbench import (
    OmniDataset,
    OmniModel,
    OmniPooling,
    Trainer,
    ClassificationMetric,
    AccelerateTrainer,
    OmniLoraModel
)
import zipfile

print("Libraries imported successfully.")

Libraries imported successfully.


## 3. Configuration

Here, we define all the hyperparameters and settings for our experiment. This centralized configuration makes it easy to modify parameters and track experiments.

In [7]:
# --- Data File Paths ---
# Download tfb_prediction dataset using git clone
local_dir = "tfb_prediction_dataset"

if not findfile.find_cwd_dir(local_dir):
    git_url = "https://huggingface.co/datasets/yangheng/tfb_prediction"
    os.system(f"git clone {git_url} {local_dir}") # Use subprocess.run if you prefer
    # import subprocess
    # subprocess.run(["git", "clone", git_url, local_dir])
    print(f"Cloned tfb_prediction dataset from {git_url} into {local_dir}")

    # Unzip the dataset if the zip file exists
    ZIP_DATASET = findfile.find_cwd_file("tfb_dataset.zip")
    if ZIP_DATASET:
        with zipfile.ZipFile(ZIP_DATASET, 'r') as zip_ref:
            zip_ref.extractall(local_dir)
        print(f"Extracted tfb_dataset.zip into {local_dir}")
        os.remove(ZIP_DATASET)
    else:
        print("tfb_dataset.zip not found. Skipping extraction.")

TRAIN_FILE = findfile.find_cwd_file("train_tfb.npy")
if not TRAIN_FILE:
    raise FileNotFoundError("Training file not found. Please ensure the dataset is downloaded and extracted correctly.")
TEST_FILE = findfile.find_cwd_file("test_tfb.npy")
if not TEST_FILE:
    raise FileNotFoundError("Test file not found. Please ensure the dataset is downloaded and extracted correctly.")
VALID_FILE = findfile.find_cwd_file("valid_tfb.npy")
if not VALID_FILE:
    print("Validation file not found. Skipping validation set.")
# TRAIN_FILE = "tfb_prediction_dataset/train_tfb.npy"
# TEST_FILE = "tfb_prediction_dataset/test_tfb.npy"
# VALID_FILE = "tfb_prediction_dataset/valid_tfb.npy"

# --- Model Configuration ---
# --- Available Models for Testing ---
AVAILABLE_MODELS = [
    'yangheng/OmniGenome-52M',
    'yangheng/OmniGenome-186M',
    'yangheng/OmniGenome-v1.5',

    # 'DNABERT-2-117M',  # You can add more models here as needed,
    # 'LongSafari/hyenadna-large-1m-seqlen-hf',
    # 'InstaDeepAI/nucleotide-transformer-500m-human-ref',
    # 'multimolecule/rnafm', # RNA-specific models
    # 'multimolecule/rnamsm',
    # 'multimolecule/rnabert',
    # 'SpliceBERT-510nt', # Splice-specific model
]

MODEL_NAME_OR_PATH = AVAILABLE_MODELS[0]

# --- Training Hyperparameters ---
EPOCHS = 30
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 1e-5
BATCH_SIZE = 16
PATIENCE = 3  # For early stopping
MAX_LENGTH = 200  # The length of the DNA sequence to be processed
SEED = 45
MAX_EXAMPLES = 100000  # Use a smaller number for quick testing (e.g., 1000), or None for all data
GRADIENT_ACCUMULATION_STEPS = 1

# --- Device Setup ---
DEVICE = autocuda.auto_cuda()
print(f"Using device: {DEVICE}")

Cloned tfb_prediction dataset from https://huggingface.co/datasets/yangheng/tfb_prediction into tfb_prediction_dataset
Extracted tfb_dataset.zip into tfb_prediction_dataset
Using device: cuda:0


## 4. Model Definition

We define the `OmniModelForMultiLabelClassification`, which wraps the OmniGenome transformer. This class adds a classification head on top of the pre-trained backbone, tailored for the DeepSEA multi-label prediction task. It also includes an option to add convolutional layers, allowing for a hybrid architecture that combines the strengths of both transformers and CNNs.

In [8]:
class OmniModelForMultiLabelClassification(OmniModel):
    """
    Multi-label sequence classification model based on OmniGenome-52M.

    This model replaces the original DeepSEA CNN architecture with a pretrained
    Transformer encoder from the OmniGenome family. Optionally, convolutional
    layers can be stacked on top of the Transformer outputs for additional
    feature extraction.

    Parameters:
        config_or_model (PretrainedConfig or nn.Module):
            Configuration or instance of the pretrained Transformer model,
            typically obtained via AutoModel.from_pretrained().
        tokenizer (PreTrainedTokenizer):
            Tokenizer compatible with the Transformer encoder, used to convert
            DNA sequences into model inputs.
        threshold (float, optional):
            Probability threshold for binary decisions in predict(), defaults to 0.5.
        use_conv (bool, optional):
            If True, apply convolutional layers after the Transformer encoder
            for enhanced feature extraction, defaults to False.
        *args, **kwargs: Additional arguments passed to the base OmniModel class.

    Attributes:
        threshold (float):
            Probability cutoff for generating binary predictions.
        deepsea_classifier (nn.Sequential):
            Classification head consisting of a Tanh activation followed by
            a linear layer mapping to the number of labels.
        loss_fn (nn.BCEWithLogitsLoss):
            Binary cross-entropy loss with logits, using pos_weight to balance
            positive and negative samples.
        pooler (OmniPooling):
            Utility for pooling token-level outputs into a sequence-level vector.
        sigmoid (nn.Sigmoid):
            Activation used to convert logits to probabilities during inference.
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        self.threshold = kwargs.pop("threshold", 0.5)
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = "OmniModelForMultiLabelClassification"

        # Classification head based directly on Transformer outputs
        conv_output_dim = self.config.hidden_size
        self.deepsea_classifier = torch.nn.Sequential(
            torch.nn.Tanh(),
            torch.nn.Linear(conv_output_dim, self.config.num_labels),
        )

        # Use pos_weight to address class imbalance in BCEWithLogitsLoss
        self.loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([20.0]))
        self.pooler = OmniPooling(self.config)
        self.sigmoid = torch.nn.Sigmoid()
        self.model_info()

    def forward(self, **inputs):
        """
        Forward pass through the model: encode, pool, classify, and optionally compute loss.

        Args:
            inputs (dict):
                Must contain 'input_ids', 'attention_mask', etc., and optionally 'labels'.

        Returns:
            dict: {
                'logits': Tensor of shape (batch_size, num_labels), raw scores before
                          sigmoid activation.
                'last_hidden_state': Tensor of shape (batch_size, seq_len, hidden_size),
                                     the last layer hidden states from the Transformer.
                'loss' (optional): Computed BCEWithLogitsLoss if 'labels' provided.
            }

        Raises:
            ValueError: If all provided labels are zero or if logits and labels shapes mismatch.
        """
        labels = inputs.pop("labels", None)

        # Encode inputs
        last_hidden_state = self.last_hidden_state_forward(**inputs)
        last_hidden_state = self.dropout(last_hidden_state)
        last_hidden_state = self.activation(last_hidden_state)

        # Pooling strategy
        if self.pooler._is_causal_lm():
            pad_token_id = getattr(self.config, "pad_token_id", -100)
            sequence_lengths = inputs["input_ids"].ne(pad_token_id).sum(dim=1) - 1
            pooled_output = last_hidden_state[
                torch.arange(inputs["input_ids"].size(0), device=last_hidden_state.device),
                sequence_lengths,
            ]
        else:
            pooled_output = self.pooler(inputs, last_hidden_state)

        logits = self.deepsea_classifier(pooled_output)
        outputs = {"logits": logits, "last_hidden_state": last_hidden_state}

        if labels is not None:
            if torch.sum(labels[labels != -100]) == 0:
                raise ValueError("Labels cannot be all zeros.")
            labels = labels[labels != -100]
            loss = self.loss_fn(logits.view(-1), labels.view(-1).to(torch.float32))
            outputs["loss"] = loss

        return outputs

    def predict(self, sequence_or_inputs, **kwargs):
        """
        Perform inference on raw sequences or tokenized inputs, returning probabilities and predictions.

        Args:
            sequence_or_inputs (str, BatchEncoding, or dict):
                Raw DNA string or pre-tokenized inputs.
            padding (str, optional): Padding strategy for tokenizer, defaults to 'max_length'.
            max_length (int, optional): Maximum sequence length, defaults to 1024.
            **kwargs: Additional tokenizer arguments.

        Returns:
            dict: {
                'predictions': Tensor of binary labels,
                'probabilities': Tensor of positive class probabilities,
                'logits': Tensor of raw scores,
                'last_hidden_state': Transformer outputs.
            }
        """
        if not isinstance(sequence_or_inputs, BatchEncoding) and not isinstance(
                sequence_or_inputs, dict
        ):
            inputs = self.tokenizer(
                sequence_or_inputs,
                padding=kwargs.pop("padding", "max_length"),
                max_length=kwargs.pop("max_length", 1024),
                truncation=True,
                return_tensors="pt",
                **kwargs,
            )
        else:
            inputs = sequence_or_inputs
        inputs = inputs.to(self.model.device)

        with torch.no_grad():
            outputs = self(**inputs)
        logits = outputs["logits"]
        last_hidden_state = outputs["last_hidden_state"]

        probabilities = self.sigmoid(logits)
        predictions = (probabilities >= self.threshold).to(torch.int)

        return {
            "predictions": predictions,
            "probabilities": probabilities,
            "logits": logits,
            "last_hidden_state": last_hidden_state,
        }

    def loss_function(self, logits, labels):
        """
        Compute BCEWithLogitsLoss for multi-label classification.

        Args:
            logits (Tensor): Raw output scores, shape (batch_size, num_labels).
            labels (Tensor): Ground-truth labels as floats, same shape as logits.

        Returns:
            Tensor: Loss value.

        Raises:
            ValueError: If logits and labels shapes do not match.
        """
        valid_labels = labels.to(torch.float32)
        if logits.shape != valid_labels.shape:
            raise ValueError(f"Shape mismatch between logits {logits.shape} and labels {valid_labels.shape}")
        return self.loss_fn(logits, valid_labels)


print("OmniModelForMultiLabelClassification defined.")

OmniModelForMultiLabelClassification defined.


## 5. Data Loading and Preprocessing

This section handles the data loading. We define a helper function `load_deepsea_npy_data` to parse the specific format of the DeepSEA `.npy` files. Then, we create a `DeepSEADataset` class that inherits from `OmniDataset` and uses this loader. The dataset class is responsible for converting DNA sequences into a format suitable for the OmniGenome tokenizer (i.e., space-separated tokens).

In [9]:


class DeepSEADataset(OmniDataset):
    """
    Dataset designed for the DeepSEA task, handling the conversion from DNA sequences to tokens
    """

    def __init__(self, data_source, tokenizer, max_length=None, **kwargs):
        super().__init__(data_source, tokenizer, max_length, **kwargs)
        for key, value in kwargs.items():
            self.metadata[key] = value

    def prepare_input(self, instance, **kwargs):
        """
        Prepare input data for DeepSEA

        Expected instance format:
        {
            'sequence': DNA sequence string (e.g., "ATCGATCG...")
            'labels': binary labels as numpy array of shape (919,)
        }
        """
        labels = None
        if isinstance(instance, str):
            sequence = instance
        elif isinstance(instance, dict):
            sequence = (
                instance.get("seq", None)
                if "seq" in instance
                else instance.get("sequence", None)
            )
            label = instance.get("label", None)
            labels = instance.get("labels", None)
            labels = labels if labels is not None else label
        else:
            raise Exception("Unknown instance format.")

        if sequence is None:
            raise ValueError("Sequence is required")

        # Convert DNA sequence to space-separated format for tokenizer
        # e.g., "ATCG" -> "A T C G"
        if isinstance(sequence, str):
            spaced_sequence = ' '.join(list(sequence))
        else:
            # If sequence is one-hot encoded, convert to string first
            if isinstance(sequence, np.ndarray) and sequence.shape[1] == 4:
                # one-hot to sequence string
                base_map = {0: 'A', 1: 'T', 2: 'C', 3: 'G'}
                sequence_str = ''.join([base_map[np.argmax(sequence[i])] for i in range(len(sequence))])
                spaced_sequence = ' '.join(list(sequence_str))
            else:
                raise ValueError(f"Unsupported sequence format: {type(sequence)}")

        # Use tokenizer to process the sequence
        tokenized_inputs = self.tokenizer(
            spaced_sequence[500-self.max_length//2:500+self.max_length//2],  # DeepSEA usually processes 200bp sequences
            # spaced_sequence,
            padding="do_not_pad",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            **kwargs,
        )

        # Squeeze dimensions
        for col in tokenized_inputs:
            tokenized_inputs[col] = tokenized_inputs[col].squeeze()

        if labels is not None:
            # For sequence classification, labels should be a fixed-length vector
            if isinstance(labels, np.ndarray):
                labels = torch.from_numpy(labels).float()
            elif not isinstance(labels, torch.Tensor):
                labels = torch.tensor(labels, dtype=torch.float32)

            tokenized_inputs["labels"] = labels

        return tokenized_inputs



## 6. Initialization

Now, let's initialize the tokenizer, the model, and the datasets. This step brings everything together and prepares for the training phase.

In [10]:
# Initialize Tokenizer and Model
print("--- Initializing Tokenizer and Model ---")
if "multimolecule" in MODEL_NAME_OR_PATH.lower():
    from multimolecule import AutoModelForTokenPrediction, RnaTokenizer
    base_model = AutoModelForTokenPrediction.from_pretrained(MODEL_NAME_OR_PATH, trust_remote_code=True).base_model
    tokenizer = RnaTokenizer.from_pretrained(MODEL_NAME_OR_PATH, trust_remote_code=True)
else:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH)
    base_model = AutoModel.from_pretrained(MODEL_NAME_OR_PATH, trust_remote_code=True)

model = OmniModelForMultiLabelClassification(
    base_model,
    tokenizer,
    num_labels=919,  # DeepSEA has 919 binary labels for different chromatin features
    threshold=0.5,
)

# If you want to use LoRA, uncomment the following lines
# lora_config = {
#     "lora_r": 8,  # Rank of the LoRA layers
#     "lora_alpha": 16,  # Scaling factor for LoRA
#     "lora_dropout": 0.1,  # Dropout rate for LoRA layers
#     "target_modules": ["deepsea_classifier"],  # Target modules to apply LoRA
# }
# model = OmniLoraModel(model, lora_config=lora_config)

model.to(DEVICE).to(torch.float32) # Move model to the selected device
# 2. Create Datasets
print("\n--- Creating Datasets ---")
train_set = DeepSEADataset(
    data_source=TRAIN_FILE,
    tokenizer=tokenizer,
    max_length=MAX_LENGTH,
    max_examples=MAX_EXAMPLES,
    force_padding=False,  # DeepSEA does not require padding
)
test_set = DeepSEADataset(
    data_source=TEST_FILE,
    tokenizer=tokenizer,
    max_length=MAX_LENGTH,
    max_examples=MAX_EXAMPLES,
    force_padding=False,  # DeepSEA does not require padding
)
valid_set = DeepSEADataset(
    data_source=VALID_FILE,
    tokenizer=tokenizer,
    max_length=MAX_LENGTH,
    max_examples=MAX_EXAMPLES,
    force_padding=False,  # DeepSEA does not require padding
) if os.path.exists(VALID_FILE) else None

print("\n--- Initialization Complete ---")
print(f"Training set size: {len(train_set)}")
print(f"Test set size: {len(test_set)}")
if valid_set:
    print(f"Validation set size: {len(valid_set)}")

--- Initializing Tokenizer and Model ---


Some weights of RnaMsmForTokenPrediction were not initialized from the model checkpoint at multimolecule/rnamsm and are newly initialized: ['token_head.decoder.bias', 'token_head.decoder.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model Name: OmniModelForMultiLabelClassification
Model Metadata: {'library_name': 'omnigenbench', 'omnigenbench_version': '0.3.7alpha', 'torch_version': '2.7.0+cu128+cu12.8+git134179474539648ba7dee1317959529fbd0e7f89', 'transformers_version': '4.53.3', 'model_cls': 'OmniModelForMultiLabelClassification', 'tokenizer_cls': 'RnaTokenizer', 'model_name': 'OmniModelForMultiLabelClassification'}
Base Model Name: multimolecule/rnamsm
Model Type: rnamsm
Model Architecture: ['RnaMsmForPreTraining']
Model Parameters: 95.329792 M
Model Config: RnaMsmConfig {
  "architectures": [
    "RnaMsmForPreTraining"
  ],
  "attention_bias": true,
  "attention_dropout": 0.1,
  "attention_type": "standard",
  "bos_token_id": 1,
  "embed_positions_msa": true,
  "eos_token_id": 2,
  "head": null,
  "hidden_act": "gelu",
  "hidden_dropout": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "0",
    "1": "1",
    "2": "2",
    "3": "3",
    "4": "4",
    "5": "5",
    "6": "6",
    "7": "7",
    "8": "8",
    "

100%|██████████| 100000/100000 [03:07<00:00, 534.49it/s]


All keys have consistent sequence lengths, skipping padding and truncation.
Detected max_length=200 in the dataset, using it as the max_length.
Loading data from tfb_prediction_dataset\test_tfb.npy...


  right=ast.Str(s=sentinel),
  return Constant(*args, **kwargs)


FileNotFoundError: [Errno 2] No such file or directory: 'tfb_prediction_dataset\\test_tfb.npy'

## 7. Training the Model

With everything set up, we can now train the model. We'll use the `AccelerateTrainer` for a streamlined and efficient training loop. The trainer handles the training loop, evaluation, early stopping, and device placement automatically.

In [None]:
# Set random seed for reproducibility across all libraries
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Create DataLoaders for batching
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=BATCH_SIZE) if valid_set else None

# Define the metric for evaluation. For DeepSEA, ROC AUC is a standard metric.
metrics = [ClassificationMetric(ignore_y=-100).roc_auc_score]

# Create the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Initialize the Trainer
trainer = AccelerateTrainer(
    model=model,
    train_loader=train_loader,
    eval_loader=valid_loader, # Use validation set for early stopping and checkpointing
    test_loader=test_loader,
    optimizer=optimizer,
    epochs=EPOCHS,
    compute_metrics=metrics,
    patience=PATIENCE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    device=DEVICE
)

# Start Training
print("--- Starting Training ---")
metrics = trainer.train()
print("--- Training Finished ---")

--- Starting Training ---


  warn("Single sequence input detected, RNA-MSM works best with MSA inputs.")
Evaluating: 100%|██████████| 500/500 [00:16<00:00, 30.17it/s]


{'roc_auc_score': 0.5128501615091636}


  warn("Single sequence input detected, RNA-MSM works best with MSA inputs.")
Epoch 1/30 Loss: 0.6747:  84%|████████▎ | 5232/6250 [08:22<01:36, 10.59it/s]

## 8. Evaluation

After training is complete, the `AccelerateTrainer` automatically loads the best performing model (based on the validation set). We can now evaluate this model on the held-out test set to get a final, unbiased measure of its performance.

In [16]:
print("--- Evaluating on Test Set ---")
# The evaluation is automated by the trainer
print(f"All metrics:", metrics)
for metric in metrics['test']:
    print(f"Test metric:  {metric}")

--- Evaluating on Test Set ---
All metrics: {'valid': [{'roc_auc_score': 0.66544658242437}, {'roc_auc_score': 0.6510533985521894}, {'roc_auc_score': 0.6531804294068921}, {'roc_auc_score': 0.6499433313733695}, {'roc_auc_score': 0.6428444822852545}], 'best_valid': {'roc_auc_score': 0.66544658242437}, 'test': [{'roc_auc_score': 0.6833502729163198}]}
Test metric:  {'roc_auc_score': 0.6833502729163198}


## 9. Inference Example

Finally, let's see how to use the fine-tuned model to make a prediction on a new, unseen DNA sequence. This demonstrates the practical application of the trained model.

In [10]:
# Create a sample DNA sequence (must be at least MAX_LENGTH base pairs long)
sample_sequence = "AGCT" * (MAX_LENGTH // 4) # Create a sequence of the required length

# Prepare the sequence for the model (add spaces between characters)
spaced_sequence = ' '.join(list(sample_sequence))
inputs = tokenizer(spaced_sequence, return_tensors="pt", max_length=MAX_LENGTH, truncation=True)

# Set the model to evaluation mode
model.eval()

# Make a prediction
with torch.no_grad():
    outputs = model.predict(inputs)

# Get the predictions and probabilities
predictions = outputs['predictions'].cpu().numpy().flatten()
probabilities = outputs['probabilities'].cpu().numpy().flatten()

print(f"Input sequence length: {len(sample_sequence)} bp")
print(f"Number of predicted labels: {len(predictions)}")

# Display predictions for the first 10 transcription factors
print("\n--- Predictions for the first 10 TFs ---")
for i in range(10):
    pred_label = 'Binds' if predictions[i] == 1 else 'Does not bind'
    print(f"Label {i+1}: Prediction={pred_label}, Probability={probabilities[i]:.4f}")


Input sequence length: 200 bp
Number of predicted labels: 919

--- Predictions for the first 10 TFs ---
Label 1: Prediction=Does not bind, Probability=0.3310
Label 2: Prediction=Does not bind, Probability=0.3695
Label 3: Prediction=Does not bind, Probability=0.3184
Label 4: Prediction=Does not bind, Probability=0.2379
Label 5: Prediction=Binds, Probability=0.5397
Label 6: Prediction=Does not bind, Probability=0.4741
Label 7: Prediction=Does not bind, Probability=0.4068
Label 8: Prediction=Does not bind, Probability=0.3112
Label 9: Prediction=Does not bind, Probability=0.3325
Label 10: Prediction=Does not bind, Probability=0.3347
