In [16]:
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModel, AutoModelForCausalLM, BertConfig
import sys

# class NoModule:
#     def __init__(self, *module_names):
#         self.module_names = module_names
#         self.original_modules = {}

#     def __enter__(self):
#         for module_name in self.module_names:
#             if module_name in sys.modules:
#                 # Save the original module
#                 self.original_modules[module_name] = sys.modules[module_name]
#                 # Remove it so imports behave as if it's not installed
#                 del sys.modules[module_name]

#     def __exit__(self, exc_type, exc_value, traceback):
#         # Restore any original modules
#         for module_name in self.module_names:
#             if module_name in self.original_modules:
#                 sys.modules[module_name] = self.original_modules[module_name]
#             else:
#                 # If we created no original entry and something
#                 # imported it in the meantime, leave it alone
#                 pass

import builtins

class BlockImport:
    """Context manager that makes selected modules fail to import.

    While the context is active, ``builtins.__import__`` is replaced by a
    wrapper that raises ``ImportError`` for every blocked name and for any of
    its submodules (blocking ``"pkg"`` also blocks ``"pkg.sub"``). Every other
    import is forwarded to the importer that was installed when the context
    was entered; that importer is put back on exit.

    Args:
        *blocked: Module names whose import should fail inside the context.

    Example:
        with BlockImport("triton"):
            ...  # ``import triton`` raises ImportError here
    """

    def __init__(self, *blocked):
        self.blocked = set(blocked)
        # Importers to restore, one per active ``with`` level, so the same
        # instance can be entered more than once without losing the original.
        self._import_stack = []

    def __enter__(self):
        # Bind the previous importer in the closure instead of reading it from
        # ``self`` at call time: a nested entry must not make the wrapper call
        # itself.
        previous_import = builtins.__import__
        self._import_stack.append(previous_import)

        def fake_import(name, *args, **kwargs):
            if any(name == b or name.startswith(b + ".") for b in self.blocked):
                raise ImportError(f"Blocked import of {name}", name=name)
            return previous_import(name, *args, **kwargs)

        builtins.__import__ = fake_import
        # Returning self makes ``with BlockImport(...) as blocker`` usable.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the importer saved by the matching ``__enter__``; exceptions
        # raised inside the block are not suppressed.
        builtins.__import__ = self._import_stack.pop()


class dnalm_embedding_extraction():
    """Mean-pooled sequence embeddings from a pretrained DNA language model.

    Supported ``model_class`` values and where each one is loaded from:

    * ``"DNABERT2"``               -> ``zhihan1996/<model_name>`` via ``AutoModel``
    * ``"HyenaDNA"``               -> ``LongSafari/<model_name>`` via ``AutoModelForCausalLM``
    * ``"Nucleotide Transformer"`` -> ``InstaDeepAI/<model_name>`` via ``AutoModelForMaskedLM``

    Args:
        model_class: One of ``SUPPORTED_MODEL_CLASSES``.
        model_name: Repository name without the organisation prefix.
        device: Device the model and the tokenized inputs are moved to.

    Raises:
        ValueError: If ``model_class`` is not supported.
    """

    SUPPORTED_MODEL_CLASSES = ("DNABERT2", "HyenaDNA", "Nucleotide Transformer")

    def __init__(self, model_class, model_name, device):
        # Fail early with a clear message instead of printing and then hitting
        # an AttributeError on the missing ``self.model``.
        if model_class not in self.SUPPORTED_MODEL_CLASSES:
            raise ValueError(
                f"Model class {model_class!r} is not supported; "
                f"expected one of {self.SUPPORTED_MODEL_CLASSES}."
            )

        self.model_class = model_class
        if model_class == "DNABERT2":
            self.model_name = f"zhihan1996/{model_name}"
            # Imports of ``triton`` are blocked while the remote code loads.
            # NOTE(review): presumably so the model falls back to its
            # non-triton attention path -- confirm against the model repo.
            with BlockImport("triton"):
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_name, trust_remote_code=True
                )
                self.model = AutoModel.from_pretrained(
                    self.model_name, trust_remote_code=True
                )
        elif model_class == "HyenaDNA":
            self.model_name = f"LongSafari/{model_name}"
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name, trust_remote_code=True, padding_side="right"
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name, trust_remote_code=True
            )
        else:  # "Nucleotide Transformer"
            self.model_name = f"InstaDeepAI/{model_name}"
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name, trust_remote_code=True
            )
            self.model = AutoModelForMaskedLM.from_pretrained(
                self.model_name, trust_remote_code=True
            )

        self.device = device
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _masked_mean(hidden, attention_mask):
        """Average ``hidden`` over dim 1, counting only positions where the mask is 1."""
        mask = attention_mask.unsqueeze(-1).float()
        summed = (hidden * mask).sum(dim=1)
        # Clamp so a row with no unmasked position cannot divide by zero.
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    def _forward(self, batch):
        """Tokenize ``batch`` and run the model.

        Returns:
            ``(hidden, attention_mask)`` where ``hidden`` is the tensor that
            gets pooled and ``attention_mask`` marks the positions to keep.
        """
        if self.model_class == "Nucleotide Transformer":
            enc = self.tokenizer.batch_encode_plus(
                batch,
                return_tensors="pt",
                padding=True,
                # ``max_length`` has no effect unless truncation is enabled.
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            )
            input_ids = enc["input_ids"].to(self.device)
            attention_mask = enc["attention_mask"].to(self.device)
            outputs = self.model(
                input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )
            return outputs.hidden_states[-1], attention_mask

        if self.model_class == "HyenaDNA":
            enc = self.tokenizer.batch_encode_plus(
                batch,
                return_tensors="pt",
                padding="longest",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            )
            input_ids = enc["input_ids"].to(self.device)
            # The mask is rebuilt from the pad token id; the model itself is
            # called without an attention mask.
            pad_id = self.tokenizer.pad_token_id
            if pad_id is None:
                attention_mask = torch.ones_like(input_ids, device=self.device)
            else:
                attention_mask = (input_ids != pad_id).long().to(self.device)
            outputs = self.model(
                input_ids=input_ids,
                output_hidden_states=True,
            )
            return outputs.hidden_states[-1], attention_mask

        if self.model_class == "DNABERT2":
            enc = self.tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
            input_ids = enc["input_ids"].to(self.device)
            attention_mask = enc["attention_mask"].to(self.device)
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            # The first element of the output is used as the per-token states.
            return outputs[0], attention_mask

        raise ValueError(f"Model class {self.model_class!r} is not supported.")

    def get_embedding(self, sequences, batch_size):
        """Embed ``sequences`` in batches and return one vector per sequence.

        Args:
            sequences: Sequence strings; must support ``len`` and slicing.
            batch_size: Number of sequences per forward pass.

        Returns:
            A NumPy array with one row per input sequence, holding the masked
            mean of the model's per-token hidden states.

        Raises:
            ValueError: If ``sequences`` is empty or the model class is unsupported.
        """
        if len(sequences) == 0:
            raise ValueError("sequences must contain at least one sequence.")

        embeddings = []
        for start in range(0, len(sequences), batch_size):
            # Slicing already clamps at the end of the sequence list.
            batch = sequences[start:start + batch_size]
            with torch.no_grad():
                hidden, attention_mask = self._forward(batch)
                mean_embeddings = self._masked_mean(hidden, attention_mask)
            embeddings.append(mean_embeddings.cpu().numpy())
        return np.vstack(embeddings)




In [9]:
# def MLP model
import torch
import torch.nn as nn

# Define the MLP model for Binary Classification
class MLPBinary(nn.Module):
    """Feed-forward binary classifier: input -> 256 -> 128 -> 1.

    The network returns one raw logit per sample; no sigmoid is applied, so it
    is meant to be paired with a logit-based loss such as ``BCEWithLogitsLoss``.

    Args:
        input_size: Number of features in each input vector.
    """

    def __init__(self, input_size):
        super().__init__()
        # Two hidden layers, each followed by a ReLU.
        self.fc1 = nn.Linear(input_size, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 128)
        self.relu2 = nn.ReLU()
        # Single-unit head producing the logit.
        self.fc3 = nn.Linear(128, 1)

    def forward(self, x):
        """Run ``x`` through the layers in order and return the logits."""
        for layer in (self.fc1, self.relu1, self.fc2, self.relu2, self.fc3):
            x = layer(x)
        return x

In [10]:
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import numpy as np
import torch
import torch.nn as nn

def performance_cCREs(embeddings, label, num_epochs=2000, seed=None):
    """Train an ``MLPBinary`` probe on embeddings and score it on a held-out split.

    The data is split 70/30 (stratified, ``random_state=40``). The probe is
    trained with full-batch Adam for exactly ``num_epochs`` steps using a
    class-weighted ``BCEWithLogitsLoss``; no early stopping is performed. The
    held-out predictions are thresholded at 0.5.

    Args:
        embeddings: 2-D NumPy array with one feature row per sample.
        label: 1-D NumPy array of non-negative integer labels (0 or 1).
        num_epochs: Number of full-batch optimisation steps.
        seed: If given, passed to ``torch.manual_seed`` before the model is
            built so the weight initialisation is reproducible. ``None``
            leaves the global torch RNG untouched.

    Returns:
        ``(precision, recall, f1, accuracy)`` on the held-out split.

    Raises:
        ValueError: If the training labels contain a value other than 0 or 1.
    """
    # Held-out split (stratified on the labels).
    X_train, X_test, y_train, y_test = train_test_split(
        embeddings, label, test_size=0.3, random_state=40, stratify=label
    )

    # Input scaling is currently disabled:
    # scaler = StandardScaler()
    # X_train = scaler.fit_transform(X_train)
    # X_test  = scaler.transform(X_test)

    if seed is not None:
        torch.manual_seed(seed)

    # Tensors on the GPU when one is available, otherwise on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    X_train_t = torch.from_numpy(X_train).float().to(device)
    y_train_t = torch.from_numpy(y_train).float().reshape(-1, 1).to(device)
    X_test_t = torch.from_numpy(X_test).float().to(device)

    model = MLPBinary(X_train.shape[1]).to(device)

    # Class imbalance: weight positives by the negative/positive ratio.
    # ``minlength=2`` keeps the unpacking valid when a class is absent from
    # the training split.
    class_counts = np.bincount(y_train, minlength=2)
    if class_counts.size > 2:
        raise ValueError("performance_cCREs expects binary labels encoded as 0/1.")
    neg, pos = class_counts
    pos_weight = torch.tensor(neg / max(pos, 1), dtype=torch.float, device=device)

    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

    # Full-batch training for a fixed number of steps.
    model.train()
    for _ in range(num_epochs):
        optimizer.zero_grad()
        loss = criterion(model(X_train_t), y_train_t)
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        probs = torch.sigmoid(model(X_test_t)).cpu().numpy().ravel()

    y_pred = (probs >= 0.5).astype(int)
    prec_cls = precision_score(y_test, y_pred, zero_division=0)
    rec_cls = recall_score(y_test, y_pred, zero_division=0)
    f1_cls = f1_score(y_test, y_pred, zero_division=0)
    acc_cls = accuracy_score(y_test, y_pred)

    return (prec_cls, rec_cls, f1_cls, acc_cls)


In [11]:
import pandas as pd

# Location of the cCRE positive/negative dataset (a CSV file).
file_path = "/content/drive/MyDrive/GitHub/Biological-Foundation-Model/Notebooks/EP Pair Evaluation/accessary_files/cCRE_pos_neg_dataset.csv"

# Load the full table.
df = pd.read_csv(file_path)

# Size and seed of the per-label subsample.
N_PER_LABEL = 5000
SAMPLE_SEED = 42

# Draw the same number of rows from every label, seeding each draw identically.
# The groups are sampled one by one and concatenated in group order, which
# avoids calling ``GroupBy.apply`` on a frame that still contains the grouping
# column (deprecated in recent pandas releases).
df_sub = pd.concat(
    [
        label_rows.sample(n=N_PER_LABEL, random_state=SAMPLE_SEED)
        for _, label_rows in df.groupby("label")
    ]
)

  .apply(lambda x: x.sample(n = 5000, random_state=42))


In [17]:
# Pick the device once: the GPU when one is available, otherwise the CPU, so
# the cell also runs on a machine without CUDA.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One extractor per model family; all three stay loaded at the same time.
embedding_extractor_hyenaDNA = dnalm_embedding_extraction(
    model_class="HyenaDNA",
    model_name="hyenadna-large-1m-seqlen-hf",
    device=DEVICE,
)
embedding_extractor_dnabert2 = dnalm_embedding_extraction(
    model_class="DNABERT2",
    model_name="DNABERT-2-117M",
    device=DEVICE,
)
embedding_extractor_nt = dnalm_embedding_extraction(
    model_class="Nucleotide Transformer",
    model_name="nucleotide-transformer-v2-500m-multi-species",
    device=DEVICE,
)


Some weights of BertModel were not initialized from the model checkpoint at zhihan1996/DNABERT-2-117M and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [18]:
# Inputs shared by every model: computed once instead of on each iteration.
sequences = df_sub["sequence"].tolist()
labels = df_sub["label"].astype(int).values

for embedding_extractor in [embedding_extractor_hyenaDNA, embedding_extractor_dnabert2, embedding_extractor_nt]:
    embeddings = embedding_extractor.get_embedding(sequences=sequences, batch_size=25)
    # Keep the per-row embeddings on the frame; the column is overwritten on
    # every iteration, so after the loop it holds the last model's embeddings.
    df_sub['embedding_mean'] = list(embeddings)
    # The array is passed on directly rather than being rebuilt from the column.
    prec_cls, rec_cls, f1_cls, acc_cls = performance_cCREs(embeddings.astype(np.float32), labels)
    print(embedding_extractor.model_class)
    print(f"  Precision: {prec_cls:.3f}")
    print(f"  Recall   : {rec_cls:.3f}")
    print(f"  F1 score : {f1_cls:.3f}")
    print(f"  Accuracy : {acc_cls:.3f}\n")


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


HyenaDNA
  Precision: 0.668
  Recall   : 0.783
  F1 score : 0.721
  Accuracy : 0.697

DNABERT2
  Precision: 0.797
  Recall   : 0.789
  F1 score : 0.793
  Accuracy : 0.794





Nucleotide Transformer
  Precision: 0.737
  Recall   : 0.727
  F1 score : 0.732
  Accuracy : 0.734

