In [1]:
WARMUP_STEPS = 5000
EPOCHS = 5
BATCH_SIZE = 64
LR = 2e-4

import wandb
wandb.login()
wandb.init(
    project="nuclprot",
    name="testing nuclprot crossattn",
    config={
        "WARMUP_STEPS": WARMUP_STEPS,
        "EPOCHS": EPOCHS,
        "BATCH_SIZE": BATCH_SIZE,
        "LR": LR,
    },
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mnikolamilicevic[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [1]:
from datasets import load_dataset

ds = load_dataset("vladak/anthem_hla_seq")

In [2]:
ds

DatasetDict({
    train: Dataset({
        features: ['HLA', 'peptide', 'Label', 'Length', 'Sequence'],
        num_rows: 539019
    })
    test: Dataset({
        features: ['HLA', 'peptide', 'Label', 'Length', 'Sequence'],
        num_rows: 172580
    })
})

In [None]:
import torch

class EmbeddingCache(torch.utils.data.Dataset):
    """This class will precompute embeddings for the data and cache
    them for future reuse."""
    def __init__(self, data, key, value, emb_model, device) -> None:
        """
        Args:
            data: Huggingface dataset that will be cached.
            key: Column with unique values for each sample (e.g. raw sequence).
            value: The value that we are computing embedding for and caching.
            emb_model: Model used for computing embeddings.
            device: Device on which embeddings will reside.
        """
        self.device = device
        self.data = self._filter_duplicates(data, key)
        self.cache = self._cache_embeddings(data, value)
        
        
        pass

    def _filter_duplicates(self, data, key):
        seen = set()
        filtered_dataset =  data['train'].filter(lambda example: not (example[key] in seen or seen.add(example[key])))
        return filtered_dataset
    
    def _cache_embeddings(self, data, value):
        # ret cache
        pass
    
    def __getitem__(self, key):
        pass

    def __len__(self):
        return len(self.data)

    

In [16]:
unique_seq = ds['train'].unique("Sequence")
seen = set()
filtered_dataset =  ds['train'].filter(lambda example: not (example["Sequence"] in seen or seen.add(example["Sequence"])))
filtered_dataset

Filter:   0%|          | 0/539019 [00:00<?, ? examples/s]

Dataset({
    features: ['HLA', 'peptide', 'Label', 'Length', 'Sequence'],
    num_rows: 112
})

In [3]:
# # subsample
# from datasets import DatasetDict
# ds = DatasetDict({
#     split: ds[split].shuffle(seed=42).select(range(int(0.1 * len(ds[split]))))
#     for split in ds
# })

True

In [5]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, DataCollatorWithPadding
import torch.nn.functional as F

# MultiHeadCrossAttention  
class CrossAttentionLayer(nn.Module):
    def __init__(self, embed_dim=512, num_heads=8, dropout=0.1, ffn_hidden_dim=2048):
        super(CrossAttentionLayer, self).__init__()
        self.protein_to_ligand_attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ligand_to_protein_attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        ffn_hidden_dim = embed_dim * 3
        self.ffn_protein = nn.Sequential(
            nn.Linear(embed_dim, ffn_hidden_dim),
            nn.ReLU(),  # Non-linear activation
            nn.Linear(ffn_hidden_dim, embed_dim),
        )
        self.ffn_ligand = nn.Sequential(
            nn.Linear(embed_dim, ffn_hidden_dim),
            nn.ReLU(),  # Non-linear activation
            nn.Linear(ffn_hidden_dim, embed_dim),
        )
        self.protein_norm = nn.LayerNorm(embed_dim)
        self.ligand_norm = nn.LayerNorm(embed_dim)
        self.ffn_protein_norm = nn.LayerNorm(embed_dim)
        self.ffn_ligand_norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, protein_embedding, ligand_embedding, key_pad_mask_prot, key_pad_mask_ligand):
        # Protein attending to ligand
        attended_protein, _ = self.protein_to_ligand_attention(
            query=protein_embedding, 
            key=ligand_embedding,
            value=ligand_embedding,
            key_padding_mask=key_pad_mask_ligand
        )
        attended_protein = self.protein_norm(protein_embedding + attended_protein)  # Residual connection
        x_prot = self.ffn_protein(attended_protein)
        x_prot = self.ffn_protein_norm(attended_protein + self.dropout(x_prot))

        # Ligand attending to protein
        attended_ligand, _ = self.ligand_to_protein_attention(
            query=ligand_embedding, 
            key=protein_embedding, 
            value=protein_embedding,
            key_padding_mask=key_pad_mask_prot
        )
        attended_ligand = self.ligand_norm(ligand_embedding + attended_ligand)  # Residual connection
        x_ligand = self.ffn_ligand(attended_ligand)
        x_ligand = self.ffn_ligand_norm(attended_ligand + self.dropout(x_ligand))
        return x_prot, x_ligand

class BindingAffinityModelWithMultiHeadCrossAttention(nn.Module):
    def __init__(self, esm_model_name, dna_model_name, num_layers=3, hidden_dim=1024):
        super().__init__()
        # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load pretrained ESM2 model for proteins
        self.esm_model = AutoModel.from_pretrained(esm_model_name)
        self.esm_tokenizer = AutoTokenizer.from_pretrained(esm_model_name)
        self.esm_model.eval()
        
        # Load pretrained ChemLLM for SMILES (ligands)
        self.ligand_model = AutoModel.from_pretrained(dna_model_name)
        self.ligand_tokenizer = AutoTokenizer.from_pretrained(dna_model_name)
        self.ligand_model.eval()

        # Disable gradient computation for both base models
        for param in self.esm_model.parameters():
            param.requires_grad = False

        for param in self.ligand_model.parameters():
            param.requires_grad = False
        
        # Protein and SMILES embedding dimensions
        self.protein_embedding_dim = self.esm_model.config.hidden_size
        self.ligand_embedding_dim = self.ligand_model.config.hidden_size

        self.project_to_common = nn.Linear(self.ligand_embedding_dim, self.protein_embedding_dim)

        self.layers = nn.ModuleList([
            CrossAttentionLayer(embed_dim=self.protein_embedding_dim) for _ in range(num_layers)
        ])

        self.ffn_class_head = nn.Sequential(
            nn.Linear(2 * self.protein_embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

        
    def forward(
            self,
            ligand_input_ids,
            ligand_attention_mask,
            protein_input_ids,
            protein_attention_mask
        ):
        # Protein embedding
        # protein_inputs = self.esm_tokenizer(protein_sequence, return_tensors="pt")
        protein_inputs = {
            "input_ids": protein_input_ids,
            "attention_mask": protein_attention_mask
        }
        # perform in FP16 for lower memory usage (matmuls)
        with torch.cuda.amp.autocast():
            with torch.no_grad():
                protein_outputs = self.esm_model(**protein_inputs)
        special_tokens_mask_prot = (protein_inputs['input_ids'] == self.esm_tokenizer.cls_token_id)\
        | (protein_inputs['input_ids'] == self.esm_tokenizer.eos_token_id)\
        | (protein_inputs['input_ids'] == self.esm_tokenizer.pad_token_id)
        protein_embedding = protein_outputs.last_hidden_state
        
        # SMILES embedding
        ligand_inputs = {
            "input_ids": ligand_input_ids,
            "attention_mask": ligand_attention_mask
        }
        special_tokens_mask_ligand = (ligand_inputs['input_ids'] == self.ligand_tokenizer.bos_token_id)\
        | (ligand_inputs['input_ids'] == self.ligand_tokenizer.eos_token_id)\
        | (ligand_inputs['input_ids'] == self.ligand_tokenizer.pad_token_id)

        with torch.cuda.amp.autocast():
            with torch.no_grad():
                ligand_outputs = self.ligand_model(**ligand_inputs)
        ligand_embedding = ligand_outputs.last_hidden_state

        # project embeddings to same dimension
        ligand_embedding = self.project_to_common(ligand_embedding)
        
        for layer in self.layers:
            protein_embedding, ligand_embedding = layer(protein_embedding, ligand_embedding, special_tokens_mask_prot, special_tokens_mask_ligand)

        # Perform mean pooling
        ligand_embedding = (ligand_embedding * ~special_tokens_mask_ligand.unsqueeze(dim=-1)).mean(dim=1)
        protein_embedding = (protein_embedding * ~special_tokens_mask_prot.unsqueeze(dim=-1)).mean(dim=1)
        # Combine embeddings
        combined = torch.cat([protein_embedding, ligand_embedding], dim=1)
        logits = self.ffn_class_head(combined)
        return logits


esm_model_name = "facebook/esm2_t33_650M_UR50D"  # Replace with the correct ESM2 model name
dna_model_name = "InstaDeepAI/nucleotide-transformer-2.5b-multi-species" # Replace with the correct ChemLLM model name


In [6]:
# Data split 
from datasets import DatasetDict

dataset_test = ds['test']
dataset_test_val = dataset_test.train_test_split(test_size=0.5)

dataset_dict = {
    "train": ds['train'],
    "test": dataset_test_val["train"],
    "validation": dataset_test_val['test']
}
dataset = DatasetDict(dataset_dict)
dataset

DatasetDict({
    train: Dataset({
        features: ['HLA', 'peptide', 'Label', 'Length', 'Sequence', '__index_level_0__'],
        num_rows: 510896
    })
    test: Dataset({
        features: ['HLA', 'peptide', 'Label', 'Length', 'Sequence', '__index_level_0__'],
        num_rows: 81560
    })
    validation: Dataset({
        features: ['HLA', 'peptide', 'Label', 'Length', 'Sequence', '__index_level_0__'],
        num_rows: 81560
    })
})

In [7]:
# Tokenization of DNA and protein sequences
dna_tokenizer = AutoTokenizer.from_pretrained(dna_model_name)
esm_tokenizer = AutoTokenizer.from_pretrained(esm_model_name)
print(f"dna tokenizer is fast: {dna_tokenizer.is_fast}")
print(f"esm tokenizer is fast: {esm_tokenizer.is_fast}")

def tokenize_dna(examples):
    toks = dna_tokenizer(examples["Sequence"], truncation=True)
    return {
        "dna_input_ids": toks["input_ids"],
        "dna_attention_mask": toks["attention_mask"]
    }

def tokenize_proteins(examples):
    toks =  esm_tokenizer(examples["peptide"], truncation=True)
    return {
        "protein_input_ids": toks["input_ids"],
        "protein_attention_mask": toks["attention_mask"]
    }

tokenized_dataset = dataset.map(tokenize_proteins, batched=True)
tokenized_dataset = tokenized_dataset.map(tokenize_dna, batched=True)
tokenized_dataset

dna tokenizer is fast: False
esm tokenizer is fast: False


Map:   0%|          | 0/81560 [00:00<?, ? examples/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Map:   0%|          | 0/81560 [00:00<?, ? examples/s]

Map:   0%|          | 0/81560 [00:00<?, ? examples/s]

Map:   0%|          | 0/81560 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['HLA', 'peptide', 'Label', 'Length', 'Sequence', '__index_level_0__', 'protein_input_ids', 'protein_attention_mask', 'dna_input_ids', 'dna_attention_mask'],
        num_rows: 510896
    })
    test: Dataset({
        features: ['HLA', 'peptide', 'Label', 'Length', 'Sequence', '__index_level_0__', 'protein_input_ids', 'protein_attention_mask', 'dna_input_ids', 'dna_attention_mask'],
        num_rows: 81560
    })
    validation: Dataset({
        features: ['HLA', 'peptide', 'Label', 'Length', 'Sequence', '__index_level_0__', 'protein_input_ids', 'protein_attention_mask', 'dna_input_ids', 'dna_attention_mask'],
        num_rows: 81560
    })
})

In [8]:
# Custom data collator
from transformers import DataCollatorWithPadding

class CustomDataCollator:
    def __init__(self, dna_collator, esm_collator):
            self.dna_collator = dna_collator
            self.esm_collator = esm_collator

    def __call__(self, batch):
        batch_dna = [{"input_ids": b["dna_input_ids"], "attention_mask": b["dna_attention_mask", "sequence": b["Sequence"]]} for b in batch]
        batch_protein = [{"input_ids": b["protein_input_ids"], "attention_mask": b["protein_attention_mask"], "peptide": b["peptide"]} for b in batch]

        collated_dna = self.dna_collator(batch_dna)
        collated_esm = self.esm_collator(batch_protein)

        return {
            "id"
            "dna_input_ids": collated_dna["input_ids"],
            "dna_attention_mask": collated_dna["attention_mask"],
            "protein_input_ids": collated_esm["input_ids"],
            "protein_attention_mask": collated_esm["attention_mask"],
            "label": torch.tensor([x['Label'] for x in batch])
        }

In [9]:
from torch.utils.data import DataLoader
dna_collator = DataCollatorWithPadding(tokenizer=dna_tokenizer)
esm_collator = DataCollatorWithPadding(tokenizer=esm_tokenizer)
collator = CustomDataCollator(dna_collator=dna_collator, esm_collator=esm_collator)
train_dataloader = DataLoader(tokenized_dataset["train"], batch_size=BATCH_SIZE, collate_fn=collator)
test_dataloader = DataLoader(tokenized_dataset["test"], batch_size=BATCH_SIZE, collate_fn=collator)
val_dataloader = DataLoader(tokenized_dataset["validation"], batch_size=BATCH_SIZE, collate_fn=collator)

In [10]:
# Training loop 
import torch.nn as nn
from tqdm import tqdm

    
def lr_lambda(step):
        if step < WARMUP_STEPS:
            # Linear warmup
            return step / WARMUP_STEPS
        else:
            remaining_steps = total_steps - WARMUP_STEPS
            decay_step = step - WARMUP_STEPS
            return max(0.5 * LR, 1.0 - 0.5 * (decay_step / remaining_steps))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")
model = BindingAffinityModelWithMultiHeadCrossAttention(esm_model_name, dna_model_name).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=0)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
total_steps = EPOCHS * len(train_dataloader)


def train_model(model, train_dataloader, val_dataloader):
    step = 0
    ACCUMULATION_STEPS = 2

    for epoch in range(EPOCHS):
        print(f"Epoch: {epoch + 1}/{EPOCHS}")
        model.train()
        train_loss = 0.0
        train_progress = tqdm(train_dataloader, desc="Training")

        for batch in train_progress:
            dna_input_ids = batch["dna_input_ids"].to(device)
            dna_attention_mask = batch["dna_attention_mask"].to(device)
            protein_input_ids = batch["protein_input_ids"].to(device)
            protein_attention_mask = batch["protein_attention_mask"].to(device)
            targets = batch["label"].unsqueeze(dim=-1).to(device)
            preds = model(dna_input_ids, dna_attention_mask, protein_input_ids, protein_attention_mask)
            loss = criterion(preds, targets.float())
            loss.backward()
            train_loss += loss.item()
            step += 1
            if step % ACCUMULATION_STEPS == 0:
                optimizer.step()
                optimizer.zero_grad()
            scheduler.step()

            if step % 100 == 0:
                wandb.log({"train_loss": loss.item()})
                wandb.log({"lr": optimizer.param_groups[0]["lr"]})

        train_loss /= len(train_dataloader)
        print(f"Epoch: {epoch} Train loss: {train_loss}")

        model.eval()
        val_loss = 0.0
        val_progress = tqdm(val_dataloader, desc="Validation")
        with torch.no_grad():
            for batch in val_progress:
                dna_input_ids = batch["dna_input_ids"].to(device)
                dna_attention_mask = batch["dna_attention_mask"].to(device)
                protein_input_ids = batch["protein_input_ids"].to(device)
                protein_attention_mask = batch["protein_attention_mask"].to(device)
                targets = batch["label"].unsqueeze(dim=-1).to(device)
                preds = model(dna_input_ids, dna_attention_mask, protein_input_ids, protein_attention_mask)
                loss = criterion(preds, targets.float())
                val_loss += loss.item()
                
        val_loss /= len(val_dataloader)
        scheduler.step(val_loss)
        print(f"Epoch: {epoch} Val loss: {val_loss}")

Using device cuda


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of EsmModel were not initialized from the model checkpoint at InstaDeepAI/nucleotide-transformer-2.5b-multi-species and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
train_model(model, train_dataloader, val_dataloader) 

Epoch: 1/5


Training:   0%|          | 0/7983 [00:00<?, ?it/s]

Training: 100%|██████████| 7983/7983 [3:28:54<00:00,  1.57s/it]  


Epoch: 0 Train loss: 0.33527016314545716


Validation: 100%|██████████| 1275/1275 [26:39<00:00,  1.25s/it]


Epoch: 0 Val loss: 0.27273947266387005
Epoch: 2/5


Training:  79%|███████▉  | 6293/7983 [2:44:17<44:07,  1.57s/it]  


KeyboardInterrupt: 

In [23]:
def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params

total, trainable = count_parameters(model)
print(f"Total parameters: {total:,}")
print(f"Trainable parameters: {trainable:,}")

Total parameters: 3,293,934,871
Trainable parameters: 104,297,729


In [22]:
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, classification_report


# def evaluate_model(model, test_loader):
model.eval()
all_predictions = []
all_targets = []
test_progress = tqdm(test_dataloader, desc="Test set")
with torch.no_grad():
    for batch in test_progress:
        ligand_input_ids = batch["dna_input_ids"].to(device)
        ligand_attention_mask = batch["dna_attention_mask"].to(device)
        protein_input_ids = batch["protein_input_ids"].to(device)
        protein_attention_mask = batch["protein_attention_mask"].to(device)
        targets = batch["label"].unsqueeze(dim=-1).to(device)
        preds = model(
            ligand_input_ids,
            ligand_attention_mask,
            protein_input_ids,
            protein_attention_mask,
        )
        # transform preds to 0 - 1 
        # do sigmoid or sth
        probs = torch.sigmoid(preds)
        preds = (probs > 0.5).float()
        all_targets.append(targets)
        all_predictions.append(preds)

all_predictions = torch.cat(all_predictions).cpu()
all_targets = torch.cat(all_targets).cpu()

accuracy = accuracy_score(all_targets, all_predictions)
precision = precision_score(all_targets, all_predictions)
recall = recall_score(all_targets, all_predictions)
f1 = f1_score(all_targets, all_predictions)
auc = roc_auc_score(all_targets, all_predictions)
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1-score: {f1:.4f}")
print(classification_report(all_targets, all_predictions))

# evaluate_model(model, test_loader=test_dataloader)

Test set: 100%|██████████| 255/255 [02:42<00:00,  1.57it/s]


Accuracy: 0.8590
Precision: 0.8560
Recall: 0.8577
F1-score: 0.8568
              precision    recall  f1-score   support

           0       0.86      0.86      0.86      4144
           1       0.86      0.86      0.86      4012

    accuracy                           0.86      8156
   macro avg       0.86      0.86      0.86      8156
weighted avg       0.86      0.86      0.86      8156



In [13]:
dataset['train'].features

{'HLA': Value(dtype='string', id=None),
 'peptide': Value(dtype='string', id=None),
 'Label': Value(dtype='int64', id=None),
 'Length': Value(dtype='int64', id=None),
 'Sequence': Value(dtype='string', id=None),
 '__index_level_0__': Value(dtype='int64', id=None)}