In [None]:
# !pip install mamba_ssm

Collecting mamba_ssm
  Using cached mamba_ssm-2.2.6.post3.tar.gz (113 kB)
ERROR: Operation cancelled by user
  Installing build dependencies ... canceled

In [None]:
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModel, AutoModelForCausalLM, BertConfig
import sys
import transformers.models.bert.modeling_bert
import builtins

class BlockImport:
    """Context manager that makes the given modules un-importable while active.

    On entry the current ``builtins.__import__`` is saved and replaced by a
    wrapper that raises ``ImportError`` for every blocked name and for any
    dotted submodule of a blocked name; every other import is forwarded to the
    saved hook.  On exit the saved hook is put back.

    Parameters
    ----------
    *blocked : str
        Top-level module names to block (e.g. ``"triton"``).
    """

    def __init__(self, *blocked):
        self.blocked = set(blocked)

    def _guarded_import(self, name, *args, **kwargs):
        """Import hook installed while the context is active."""
        for prefix in self.blocked:
            # Match the module itself or anything below it ("pkg", "pkg.sub").
            if name == prefix or name.startswith(prefix + "."):
                raise ImportError(f"Blocked import of {name}")
        return self._orig_import(name, *args, **kwargs)

    def __enter__(self):
        # Remember whatever hook is installed right now so it can be restored.
        self._orig_import = builtins.__import__
        builtins.__import__ = self._guarded_import

    def __exit__(self, exc_type, exc_value, traceback):
        builtins.__import__ = self._orig_import


class dnalm_embedding_extraction():
    """Load a pretrained DNA language model and extract embeddings / likelihoods.

    Supported ``model_class`` values are the keys of ``_HUB_ORGS``.  The model
    and tokenizer are fetched from the Hugging Face hub under
    ``"<organisation>/<model_name>"``, moved to ``device`` and put in eval mode.

    Attributes
    ----------
    model_class : str
    model_name : str
        Full hub identifier (organisation prefix included).
    tokenizer, model
        Objects returned by the ``transformers`` ``from_pretrained`` loaders.
    mask_token : int or None
        ``tokenizer.mask_token_id``; ``None`` for HyenaDNA, which is only ever
        scored autoregressively here.
    device
        Device the model lives on and on which tensors are created.
    """

    # Hugging Face organisation hosting each supported model family.
    _HUB_ORGS = {
        "DNABERT2": "zhihan1996",
        "HyenaDNA": "LongSafari",
        "Nucleotide Transformer": "InstaDeepAI",
        "Caduceus": "kuleshov-group",
        "Mistral": "RaphaelMourad",
        "GENA-LM": "AIRI-Institute",
    }

    # Hard-coded (start_token_id, end_token_id) per family; None means the
    # tokenizer output has no such marker and a fallback boundary is used.
    _BOUNDARY_TOKENS = {
        "HyenaDNA": (None, 1),
        "DNABERT2": (1, 2),
        "Nucleotide Transformer": (3, None),
        "Caduceus": (None, 1),
        "Mistral": (1, 2),
        "GENA-LM": (1, 2),
    }

    # Families scored left-to-right and called without an attention mask.
    _CAUSAL_CLASSES = ("HyenaDNA", "Mistral")

    def __init__(self, model_class, model_name, device):
        """Download tokenizer and model, then move the model to ``device``.

        Raises
        ------
        ValueError
            If ``model_class`` is not one of the supported families.  (The
            previous behaviour was to print a message and then fail with an
            ``AttributeError`` on the missing ``self.model``.)
        """
        if model_class not in self._HUB_ORGS:
            supported = ", ".join(sorted(self._HUB_ORGS))
            raise ValueError(
                f"Model not supported: {model_class!r}. "
                f"Supported model classes: {supported}."
            )

        self.model_class = model_class
        self.model_name = f"{self._HUB_ORGS[model_class]}/{model_name}"

        if model_class == "DNABERT2":
            # Loading happens with `triton` made un-importable, as in the
            # original code (presumably to avoid the remote code's triton
            # attention path — TODO confirm).
            with BlockImport("triton"):
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_name, trust_remote_code=True
                )
                config = BertConfig.from_pretrained(self.model_name, trust_remote_code=True)
                self.model = AutoModelForMaskedLM.from_pretrained(
                    self.model_name, config=config, trust_remote_code=True
                )
        elif model_class == "HyenaDNA":
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name, trust_remote_code=True, padding_side="right"
            )
            self.model = AutoModelForCausalLM.from_pretrained(self.model_name, trust_remote_code=True)
        elif model_class == "Caduceus":
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name, trust_remote_code=True, padding_side="right"
            )
            self.model = AutoModelForMaskedLM.from_pretrained(self.model_name, trust_remote_code=True)
        elif model_class == "Mistral":
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
            self.model = AutoModelForCausalLM.from_pretrained(self.model_name, trust_remote_code=True)
        elif model_class == "GENA-LM":
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
            self.model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True)
        else:  # "Nucleotide Transformer"
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
            self.model = AutoModelForMaskedLM.from_pretrained(self.model_name, trust_remote_code=True)

        # HyenaDNA never had a mask token assigned; expose None so the
        # attribute exists for every model class.
        if model_class == "HyenaDNA":
            self.mask_token = None
        else:
            self.mask_token = self.tokenizer.mask_token_id

        self.device = device
        self.model.to(self.device)
        self.model.eval()

    @property
    def start_token(self):
        """Token id marking the sequence start, or None if there is none."""
        return self._BOUNDARY_TOKENS.get(self.model_class, (None, None))[0]

    @property
    def end_token(self):
        """Token id marking the sequence end, or None if there is none."""
        return self._BOUNDARY_TOKENS.get(self.model_class, (None, None))[1]

    @staticmethod
    def _boundary_columns(tokens, token_id, role):
        """Return the column of ``token_id`` in every row of ``tokens``.

        Raises ValueError unless the token occurs exactly once per row; the
        span logic below cannot be applied otherwise.
        """
        rows, cols = torch.where(tokens == token_id)
        # torch.where yields hits in row-major order, so "exactly one per row"
        # is equivalent to the row indices being 0, 1, ..., B-1.
        expected_rows = torch.arange(tokens.shape[0], device=rows.device)
        if not torch.equal(rows, expected_rows):
            raise ValueError(
                f"Expected exactly one {role} token (id={token_id}) per sequence, "
                f"found {rows.numel()} in a batch of {tokens.shape[0]}."
            )
        return cols

    def _encode_batch(self, batch):
        """Tokenize ``batch`` and locate the per-sequence content span.

        Returns ``(tokens, attention_mask, starts, ends)``.  ``tokens`` and
        ``attention_mask`` (may be None) are on ``self.device``; ``starts`` and
        ``ends`` delimit the half-open column range ``[start, end)`` that
        excludes the start/end marker tokens.
        """
        encoded = self.tokenizer.batch_encode_plus(batch, return_tensors="pt", padding=True)
        tokens = encoded["input_ids"]
        attention_mask = encoded.get("attention_mask")

        if self.start_token is not None:
            starts = self._boundary_columns(tokens, self.start_token, "start") + 1
        else:
            starts = 0

        if self.end_token is not None:
            ends = self._boundary_columns(tokens, self.end_token, "end")
        elif attention_mask is not None:
            # No end marker: the span runs to the last attended position.
            ends = attention_mask.sum(dim=1)
        else:
            # Neither an end marker nor a mask: use the full padded width.
            ends = tokens.shape[1]

        tokens = tokens.to(device=self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device=self.device)
        return tokens, attention_mask, starts, ends

    def _span_mask(self, starts, ends, shape):
        """Build a ``shape``-sized 0/1 float mask that is 1 for start <= col < end."""
        ends_col = torch.as_tensor(ends).reshape(-1, 1)
        starts_col = torch.as_tensor(starts, device=ends_col.device).reshape(-1, 1)
        positions = torch.arange(shape[1], device=ends_col.device).unsqueeze(0)
        inside = (positions >= starts_col) & (positions < ends_col)
        # expand() covers the case where both bounds are scalars.
        inside = inside.expand(shape[0], shape[1])
        return inside.to(device=self.device, dtype=torch.get_default_dtype())

    def get_embedding(self, sequences, batch_size):
        """Mean-pool last-layer hidden states over each sequence's content span.

        Parameters
        ----------
        sequences : list of str
        batch_size : int

        Returns
        -------
        np.ndarray
            One row per sequence (rows of all batches stacked with
            ``np.vstack``; an empty ``sequences`` therefore raises ValueError).
        """
        embeddings = []
        for batch_start in range(0, len(sequences), batch_size):
            batch = sequences[batch_start:batch_start + batch_size]
            tokens, attention_mask, starts, ends = self._encode_batch(batch)

            forward_kwargs = {"output_hidden_states": True, "return_dict": True}
            # The causal families are called without an attention mask.
            if self.model_class not in self._CAUSAL_CLASSES:
                forward_kwargs["attention_mask"] = attention_mask
            with torch.no_grad():
                outputs = self.model(tokens, **forward_kwargs)

            if self.model_class == "DNABERT2":
                # Per the original author's note, DNABERT2's remote code only
                # returns the last hidden layer, so hidden_states is used as-is.
                hidden = outputs.hidden_states
            else:
                hidden = outputs.hidden_states[-1]

            clip_mask = self._span_mask(starts, ends, tokens.shape[:2])
            if attention_mask is not None:
                clip_mask = clip_mask * attention_mask

            weights = clip_mask.unsqueeze(-1)
            summed = (hidden * weights).sum(dim=1)
            # clamp avoids division by zero for a sequence with an empty span.
            counts = weights.sum(dim=1).clamp(min=1e-9)
            embeddings.append((summed / counts).cpu().numpy())
        return np.vstack(embeddings)

    def _causal_log_likelihoods(self, tokens):
        """Per-position log-probability of each token given its left context.

        Column 0 has no left context and stays 0.
        """
        import torch.nn.functional as F

        lls = torch.zeros(tokens.shape[:2], device=self.device)
        with torch.no_grad():
            # cross_entropy expects the class dimension second: (B, V, L).
            logits = self.model(tokens).logits.swapaxes(1, 2)
            lls[:, 1:] = -F.cross_entropy(logits[:, :, :-1], tokens[:, 1:], reduction="none")
        return lls

    def _masked_log_likelihoods(self, tokens, attention_mask, clip_mask):
        """Per-position pseudo-log-likelihood: mask one column at a time.

        Only columns that lie inside at least one sequence's span are
        evaluated; all other columns stay 0 (they are multiplied by a zero
        mask by the caller anyway), which saves one forward pass each.
        """
        import torch.nn.functional as F

        lls = torch.zeros(tokens.shape[:2], device=self.device)
        active_columns = torch.nonzero((clip_mask > 0).any(dim=0)).flatten().tolist()
        for position in active_columns:
            masked_tokens = tokens.clone()
            masked_tokens[:, position] = self.mask_token
            with torch.no_grad():
                logits = self.model(masked_tokens, attention_mask=attention_mask).logits
                lls[:, position] = -F.cross_entropy(
                    logits[:, position], tokens[:, position], reduction="none"
                )
        return lls

    def get_likelihood(self, sequences, batch_size):
        """
        Compute log-likelihoods of sequences.

        HyenaDNA and Mistral are scored autoregressively (sum of next-token
        log-probabilities); all other families are scored by masking one
        position at a time and summing the log-probability of the true token.
        Only positions inside each sequence's content span contribute.

        Returns: numpy array of log-likelihoods (one per sequence).  Batches
        are joined with ``np.concatenate``, so an empty ``sequences`` raises
        ValueError.
        """
        likelihoods = []
        for batch_start in range(0, len(sequences), batch_size):
            batch = sequences[batch_start:batch_start + batch_size]
            tokens, attention_mask, starts, ends = self._encode_batch(batch)
            clip_mask = self._span_mask(starts, ends, tokens.shape[:2])

            if self.model_class in self._CAUSAL_CLASSES:
                lls = self._causal_log_likelihoods(tokens)
            else:
                lls = self._masked_log_likelihoods(tokens, attention_mask, clip_mask)

            likelihoods.append((lls * clip_mask).sum(dim=1).numpy(force=True))

        return np.concatenate(likelihoods)

In [None]:
# def MLP model
import torch
import torch.nn as nn

# Define the MLP model for Binary Classification
class MLPBinary(nn.Module):
    """Two-hidden-layer perceptron producing one raw logit per input row.

    Layer widths are ``input_size -> 256 -> 128 -> 1`` with ReLU after each
    hidden layer.  No sigmoid is applied to the output, so the network is
    meant to be paired with a logit-based loss such as ``BCEWithLogitsLoss``.
    """

    def __init__(self, input_size):
        super().__init__()
        wide, narrow = 256, 128
        self.fc1 = nn.Linear(input_size, wide)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(wide, narrow)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(narrow, 1)

    def forward(self, x):
        """Return the un-squashed logit(s) for ``x``."""
        hidden = self.relu1(self.fc1(x))
        hidden = self.relu2(self.fc2(hidden))
        return self.fc3(hidden)

In [None]:
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    precision_recall_curve, average_precision_score,
    precision_score, recall_score, f1_score, accuracy_score
)
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader


def performance_cCREs(embeddings, label, batch_size=256, num_epochs=50, seed=None):
    """
    Train a small MLP probe on embeddings and score it on a held-out split.

    The data are split 70/30 (stratified, fixed ``random_state=40``), an
    ``MLPBinary`` network is trained with mini-batch Adam on
    ``BCEWithLogitsLoss``, and test-set probabilities are thresholded at 0.5.

    Parameters
    ----------
    embeddings : np.ndarray, shape (N, D)
    label      : array-like, shape (N,)
        Binary class labels; cast to int.
    batch_size : int
        Mini-batch size for training and evaluation.
    num_epochs : int
        Number of passes over the training set.
    seed : int or None
        When given, ``torch.manual_seed(seed)`` is called before the model is
        built so weight initialisation and batch shuffling are repeatable.
        Note that this reseeds PyTorch's global RNG.  ``None`` (default)
        leaves the RNG untouched, which is the previous behaviour.

    Returns
    -------
    prec_cls, rec_cls, f1_cls, acc_cls : float
        Precision, recall, F1 and accuracy on the held-out test set.
    """
    X = np.asarray(embeddings)
    y = np.asarray(label).astype(int)

    if seed is not None:
        torch.manual_seed(seed)

    # train/test split (stratified)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=40, stratify=y
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # tensors (targets as a float column, the shape BCEWithLogitsLoss needs
    # to match the model's (batch, 1) logits)
    X_train_t = torch.from_numpy(X_train).float()
    y_train_t = torch.from_numpy(y_train).float().view(-1, 1)
    X_test_t = torch.from_numpy(X_test).float()

    train_loader = DataLoader(
        TensorDataset(X_train_t, y_train_t),
        batch_size=batch_size, shuffle=True, drop_last=False,
    )
    # The test loader is unshuffled and keeps every row, so predictions line
    # up with y_test and the labels do not need to travel through the loader.
    test_loader = DataLoader(
        TensorDataset(X_test_t),
        batch_size=batch_size, shuffle=False, drop_last=False,
    )

    # model
    model = MLPBinary(X_train.shape[1]).to(device)

    # loss: no class weighting (assumes the classes are roughly balanced)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

    # --- training loop (batched) ---
    model.train()
    for _ in range(num_epochs):
        for xb, yb in train_loader:
            xb = xb.to(device)
            yb = yb.to(device)

            optimizer.zero_grad()
            loss = criterion(model(xb), yb)
            loss.backward()
            optimizer.step()

    # --- evaluation (batched) ---
    model.eval()
    batch_probs = []
    with torch.no_grad():
        for (xb,) in test_loader:
            logits = model(xb.to(device))
            batch_probs.append(torch.sigmoid(logits).cpu().numpy().ravel())

    probs = np.concatenate(batch_probs, axis=0)
    y_true = y_test

    # metrics at a fixed 0.5 decision threshold
    y_pred = (probs >= 0.5).astype(int)
    prec_cls = precision_score(y_true, y_pred, zero_division=0)
    rec_cls = recall_score(y_true, y_pred, zero_division=0)
    f1_cls = f1_score(y_true, y_pred, zero_division=0)
    acc_cls = accuracy_score(y_true, y_pred)

    return prec_cls, rec_cls, f1_cls, acc_cls

In [None]:
import pandas as pd

# Location of the cCRE positive/negative dataset (a .csv file). The path lives
# under /content/drive, i.e. presumably a mounted Google Drive — adjust it when
# running elsewhere.
file_path = "/content/drive/MyDrive/GitHub/Biological-Foundation-Model/Notebooks/EP Pair Evaluation/accessary_files/cCRE_pos_neg_dataset.csv"

# Load the CSV with pandas' default parsing options. Later cells read the
# "sequence", "label" and "subtype" columns of this frame.
df = pd.read_csv(file_path)

In [None]:
## Note:
## Mistral-DNA-v1-1.6B-hg38 is no longer available on Hugging Face; the 422M
## checkpoint is used below.

# Use the GPU when one is visible and fall back to the CPU otherwise, matching
# the device selection inside performance_cCREs. Previously "cuda" was
# hard-coded in every call, which fails on a CPU-only runtime.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

embedding_extractor_hyenaDNA = dnalm_embedding_extraction(model_class="HyenaDNA", model_name="hyenadna-large-1m-seqlen-hf", device=DEVICE)
embedding_extractor_mistral = dnalm_embedding_extraction(model_class="Mistral", model_name="Mistral-DNA-v1-422M-hg38", device=DEVICE)

embedding_extractor_dnabert2 = dnalm_embedding_extraction(model_class="DNABERT2", model_name="DNABERT-2-117M", device=DEVICE)
embedding_extractor_nt = dnalm_embedding_extraction(model_class="Nucleotide Transformer", model_name="nucleotide-transformer-v2-500m-multi-species", device=DEVICE)
# embedding_extractor_caduceus = dnalm_embedding_extraction(model_class="Caduceus", model_name="caduceus-ps_seqlen-131k_d_model-256_n_layer-16", device=DEVICE)
embedding_extractor_genalm = dnalm_embedding_extraction(model_class="GENA-LM", model_name="gena-lm-bert-large-t2t", device=DEVICE)

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


## Probed Absolute Accuracy

In [None]:
# Balanced subsample: draw 5000 rows from every label group with a fixed seed.
# The groups are sampled explicitly instead of via groupby(...).apply(...):
# apply() on a frame that still contains the grouping column emits a pandas
# deprecation warning, and once that behaviour is removed the "label" column
# (needed by the probing cells) would no longer be part of the result.
# Each group is sampled with its own random_state=42 and the pieces are
# concatenated in group order, so the selected rows are the same as before.
df_sub = pd.concat(
    [group.sample(n=5000, random_state=42) for _, group in df.groupby("label")]
)
# df_sub = df

  .apply(lambda x: x.sample(n =5000, random_state=42))


In [None]:
# Supervised probe for Mistral and GENA-LM (Caduceus is currently left out):
# mean-pooled embeddings -> MLP classifier -> held-out metrics.
for embedding_extractor in (embedding_extractor_mistral, embedding_extractor_genalm):
    pooled = embedding_extractor.get_embedding(
        sequences=df_sub["sequence"].tolist(),
        batch_size=100,
    )
    # Store one pooled vector per row on df_sub, then rebuild the 2-D float32
    # feature matrix from that column.
    df_sub['embedding_mean'] = list(pooled)
    features = np.vstack(df_sub['embedding_mean']).astype(np.float32)
    targets = df_sub["label"].astype(int).values

    prec_cls, rec_cls, f1_cls, acc_cls = performance_cCREs(features, targets, num_epochs=100)

    print(embedding_extractor.model_class)
    print(f"  Precision: {prec_cls:.3f}")
    print(f"  Recall   : {rec_cls:.3f}")
    print(f"  F1 score : {f1_cls:.3f}")
    print(f"  Accuracy : {acc_cls:.3f}\n")

Mistral
  Precision: 0.730
  Recall   : 0.806
  F1 score : 0.766
  Accuracy : 0.754





GENA-LM
  Precision: 0.871
  Recall   : 0.822
  F1 score : 0.846
  Accuracy : 0.850



In [None]:
# Same supervised probe for HyenaDNA, DNABERT2 and Nucleotide Transformer.
for embedding_extractor in (
    embedding_extractor_hyenaDNA,
    embedding_extractor_dnabert2,
    embedding_extractor_nt,
):
    pooled = embedding_extractor.get_embedding(
        sequences=df_sub["sequence"].tolist(),
        batch_size=100,
    )
    # Store one pooled vector per row on df_sub, then rebuild the 2-D float32
    # feature matrix from that column.
    df_sub['embedding_mean'] = list(pooled)
    features = np.vstack(df_sub['embedding_mean']).astype(np.float32)
    targets = df_sub["label"].astype(int).values

    prec_cls, rec_cls, f1_cls, acc_cls = performance_cCREs(features, targets, num_epochs=100)

    print(embedding_extractor.model_class)
    print(f"  Precision: {prec_cls:.3f}")
    print(f"  Recall   : {rec_cls:.3f}")
    print(f"  F1 score : {f1_cls:.3f}")
    print(f"  Accuracy : {acc_cls:.3f}\n")


HyenaDNA
  Precision: 0.622
  Recall   : 0.789
  F1 score : 0.696
  Accuracy : 0.655

DNABERT2
  Precision: 0.827
  Recall   : 0.753
  F1 score : 0.788
  Accuracy : 0.798

Nucleotide Transformer
  Precision: 0.752
  Recall   : 0.763
  F1 score : 0.758
  Accuracy : 0.756



## Zero-shot accuracy

In [None]:
# Pair row k of the first half of df with row k of the second half, side by
# side. NOTE(review): this presumably relies on the file listing the original
# sequences first and their shuffled counterparts second, in matching order —
# confirm against how the CSV was built.
midpoint = len(df) // 2
df_upper = df.iloc[:midpoint].copy().reset_index(drop=True)
df_lower = (
    df.iloc[midpoint:]
      .copy()
      .reset_index(drop=True)
      .rename(columns={'sequence': 'sequence_shf', 'subtype': "subtype_shf", "label": "label_shf"})
)
df_combined = pd.concat([df_upper, df_lower], axis=1)

# Work on a fixed random subset of 1000 pairs (this rebinds df_sub, replacing
# the frame used in the probing section).
df_sub = df_combined.sample(n=1000, random_state = 42)
# df_sub = df_combined

In [None]:
# Zero-shot check for Mistral and GENA-LM (Caduceus is currently left out):
# fraction of pairs where the "sequence" column scores at least as high as
# its "sequence_shf" partner.
for embedding_extractor in (embedding_extractor_mistral, embedding_extractor_genalm):
    pos_scores = embedding_extractor.get_likelihood(
        sequences=df_sub["sequence"].tolist(), batch_size=50
    )
    neg_scores = embedding_extractor.get_likelihood(
        sequences=df_sub["sequence_shf"].tolist(), batch_size=50
    )
    df_sub['seq_pos_liklihood'] = list(pos_scores)
    df_sub['seq_neg_liklihood'] = list(neg_scores)

    print(embedding_extractor.model_class)
    wins = (df_sub['seq_neg_liklihood'] <= df_sub['seq_pos_liklihood']).sum()
    print(wins / df_sub.shape[0])

Mistral
0.895




GENA-LM
0.941


In [None]:
# Same zero-shot check for HyenaDNA, DNABERT2 and Nucleotide Transformer.
for embedding_extractor in (
    embedding_extractor_hyenaDNA,
    embedding_extractor_dnabert2,
    embedding_extractor_nt,
):
    pos_scores = embedding_extractor.get_likelihood(
        sequences=df_sub["sequence"].tolist(), batch_size=50
    )
    neg_scores = embedding_extractor.get_likelihood(
        sequences=df_sub["sequence_shf"].tolist(), batch_size=50
    )
    df_sub['seq_pos_liklihood'] = list(pos_scores)
    df_sub['seq_neg_liklihood'] = list(neg_scores)

    print(embedding_extractor.model_class)
    wins = (df_sub['seq_neg_liklihood'] <= df_sub['seq_pos_liklihood']).sum()
    print(wins / df_sub.shape[0])

HyenaDNA
0.857
DNABERT2
0.836
Nucleotide Transformer
0.709


In [None]:
from google.colab import runtime

# NOTE(review): runtime.unassign() appears to disconnect and release the
# current Colab runtime, so any in-memory state would be lost afterwards —
# confirm this is intended before running the notebook interactively.
runtime.unassign()