In [1]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, DataCollatorWithPadding
import torch.nn.functional as F
import datasets

In [2]:
# Load the dataset saved on disk under "binding_ds_200k". Only its "test" split is
# used further down, when the evaluation DataLoader is built from ds['test'].
ds = datasets.load_from_disk("binding_ds_200k")


In [3]:
# Custom data collator
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader

# Checkpoint names handed to `from_pretrained` for the tokenizers (this cell) and the
# two pretrained encoders (model definition below).
esm_model_name = "facebook/esm2_t33_650M_UR50D"  # protein-sequence encoder (ESM2)
chem_model_name = "seyonec/ChemBERTa-zinc-base-v1" # ligand / SMILES encoder (ChemBERTa)

class CustomDataCollator:
    """Collate tokenized protein-ligand examples into a single padded batch.

    Every example in ``batch`` is a mapping that provides the keys
    ``ligand_input_ids``, ``ligand_attention_mask``, ``protein_input_ids``,
    ``protein_attention_mask`` and ``ic50``. Ligand and protein features are
    collated independently because each side has its own collator.

    Args:
        chem_collator: Callable that takes a list of
            ``{"input_ids", "attention_mask"}`` dicts for the ligands and returns
            a mapping with the same two keys, collated.
        esm_collator: Same contract, applied to the protein features.
    """

    def __init__(self, chem_collator, esm_collator):
        self.chem_collator = chem_collator
        self.esm_collator = esm_collator

    def __call__(self, batch):
        """Return a dict of collated ligand/protein tensors plus the ``ic50`` targets."""
        ligand_features = [
            {
                "input_ids": example["ligand_input_ids"],
                "attention_mask": example["ligand_attention_mask"],
            }
            for example in batch
        ]
        protein_features = [
            {
                "input_ids": example["protein_input_ids"],
                "attention_mask": example["protein_attention_mask"],
            }
            for example in batch
        ]

        collated_chem = self.chem_collator(ligand_features)
        collated_esm = self.esm_collator(protein_features)

        return {
            "ligand_input_ids": collated_chem["input_ids"],
            "ligand_attention_mask": collated_chem["attention_mask"],
            "protein_input_ids": collated_esm["input_ids"],
            "protein_attention_mask": collated_esm["attention_mask"],
            # Force float32: with an inferred dtype, integer-valued labels would
            # become int64 and a float regression loss would reject them.
            "ic50": torch.tensor(
                [example["ic50"] for example in batch], dtype=torch.float32
            ),
        }

# One tokenizer per encoder (ligand first, then protein); each backs its own padding collator.
chem_tokenizer, esm_tokenizer = [
    AutoTokenizer.from_pretrained(checkpoint_name)
    for checkpoint_name in (chem_model_name, esm_model_name)
]
chem_collator, esm_collator = [
    DataCollatorWithPadding(tokenizer=side_tokenizer)
    for side_tokenizer in (chem_tokenizer, esm_tokenizer)
]

# Joint collator that pads the ligand and protein halves of every example separately.
collator = CustomDataCollator(chem_collator=chem_collator, esm_collator=esm_collator)

bs = 32
test_dataloader = DataLoader(ds["test"], collate_fn=collator, batch_size=bs)


In [4]:
# Hyperparameters. In the code visible in this notebook only NUM_LAYERS is referenced
# (as the default number of cross-attention layers in the model). The other values
# look like training-time settings kept for reference — TODO confirm they match the
# run that produced the checkpoint loaded below.
WARMUP_STEPS = 5000
EPOCHS = 30
SAMPLE_SIZE = 200_000
BATCH_SIZE = 32
LR = 1e-4
NUM_LAYERS = 3

class CrossAttentionLayer(nn.Module):
    """One bidirectional cross-attention block between protein and ligand tokens.

    The protein sequence attends to the ligand sequence and, independently, the
    ligand sequence attends to the protein sequence. Each direction is followed
    by a residual connection + LayerNorm and a position-wise feed-forward
    network with its own residual connection + LayerNorm (post-norm layout).

    Args:
        embed_dim: Feature size shared by both input sequences.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        dropout: Dropout probability used inside the attention modules, on the
            attention residual branches and inside the feed-forward networks.
        ffn_hidden_dim: Accepted for interface compatibility but ignored; the
            feed-forward hidden width is always ``3 * embed_dim``.
    """

    def __init__(self, embed_dim=512, num_heads=8, dropout=0.1, ffn_hidden_dim=2048):
        super().__init__()
        self.protein_to_ligand_attention = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.ligand_to_protein_attention = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        # NOTE: the ffn_hidden_dim argument is deliberately overridden, as in the
        # original code. Honouring it would change the Linear shapes and make
        # previously saved state dicts unloadable.
        ffn_hidden_dim = embed_dim * 3
        self.dropout = dropout
        # Parameter-free module, so it adds no keys to the state dict.
        self.residual_dropout = nn.Dropout(dropout)
        self.ffn_protein = nn.Sequential(
            nn.Linear(embed_dim, ffn_hidden_dim),
            nn.GELU(),  # Non-linear activation
            nn.Dropout(dropout),
            nn.Linear(ffn_hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )
        self.ffn_ligand = nn.Sequential(
            nn.Linear(embed_dim, ffn_hidden_dim),
            nn.GELU(),  # Must be an instance: nn.Sequential rejects the bare class
            nn.Dropout(dropout),
            nn.Linear(ffn_hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )
        self.protein_norm = nn.LayerNorm(embed_dim)
        self.ligand_norm = nn.LayerNorm(embed_dim)
        self.ffn_protein_norm = nn.LayerNorm(embed_dim)
        self.ffn_ligand_norm = nn.LayerNorm(embed_dim)

    def forward(
        self,
        protein_embedding,
        ligand_embedding,
        key_pad_mask_prot,
        key_pad_mask_ligand,
    ):
        """Run both cross-attention directions.

        Args:
            protein_embedding: ``(batch, protein_len, embed_dim)`` tensor.
            ligand_embedding: ``(batch, ligand_len, embed_dim)`` tensor.
            key_pad_mask_prot: ``(batch, protein_len)`` mask; ``True`` marks
                protein positions the ligand queries must not attend to.
            key_pad_mask_ligand: ``(batch, ligand_len)`` mask; ``True`` marks
                ligand positions the protein queries must not attend to.

        Returns:
            Tuple ``(protein_out, ligand_out)`` with the same shapes as the two
            embedding inputs.
        """
        # Protein attending to ligand
        attended_protein, _ = self.protein_to_ligand_attention(
            query=protein_embedding,
            key=ligand_embedding,
            value=ligand_embedding,
            key_padding_mask=key_pad_mask_ligand,
        )
        attended_protein = self.protein_norm(
            protein_embedding + self.residual_dropout(attended_protein)
        )  # Residual connection
        # The FFN already ends with Dropout, so its output is added as-is.
        x_prot = self.ffn_protein_norm(
            attended_protein + self.ffn_protein(attended_protein)
        )

        # Ligand attending to protein (uses the *input* protein embedding, so the
        # two directions are independent of each other within one layer).
        attended_ligand, _ = self.ligand_to_protein_attention(
            query=ligand_embedding,
            key=protein_embedding,
            value=protein_embedding,
            key_padding_mask=key_pad_mask_prot,
        )
        attended_ligand = self.ligand_norm(
            ligand_embedding + self.residual_dropout(attended_ligand)
        )  # Residual connection
        x_ligand = self.ffn_ligand_norm(
            attended_ligand + self.ffn_ligand(attended_ligand)
        )
        return x_prot, x_ligand


class BindingAffinityModelWithMultiHeadCrossAttention(nn.Module):
    """Regress an IC50 value from a tokenized protein / ligand pair.

    Two pretrained encoders (loaded with ``AutoModel.from_pretrained``) embed the
    protein and the ligand. Their parameters are frozen (``requires_grad=False``)
    and they are always run under ``torch.no_grad()``. The protein embedding is
    linearly projected to the ligand embedding width, both sequences go through
    a stack of ``CrossAttentionLayer`` blocks, are pooled over the sequence
    axis, concatenated, layer-normalised and fed to a small MLP head.

    Args:
        esm_model_name: Checkpoint name of the protein encoder and its tokenizer.
        chem_model_name: Checkpoint name of the ligand encoder and its tokenizer.
        num_layers: Number of cross-attention layers.
        hidden_dim: Hidden width of the final regression MLP.
        dropout: Dropout probability used in the regression MLP.
    """

    def __init__(
        self, esm_model_name, chem_model_name, num_layers=NUM_LAYERS, hidden_dim=1024, dropout=0.1
    ):
        super().__init__()
        self.dropout = dropout
        # Load pretrained ESM2 model for proteins. The tokenizer is only kept to
        # look up special-token ids when building masks in forward().
        self.esm_model = AutoModel.from_pretrained(esm_model_name)
        self.esm_tokenizer = AutoTokenizer.from_pretrained(esm_model_name)
        self.esm_model.eval()

        # Load pretrained chemistry language model for SMILES (ligands)
        self.ligand_model = AutoModel.from_pretrained(chem_model_name)
        self.ligand_tokenizer = AutoTokenizer.from_pretrained(chem_model_name)
        self.ligand_model.eval()

        # Disable gradient computation for both base models
        for param in self.esm_model.parameters():
            param.requires_grad = False

        for param in self.ligand_model.parameters():
            param.requires_grad = False

        # Protein and SMILES embedding dimensions
        self.protein_embedding_dim = self.esm_model.config.hidden_size
        self.ligand_embedding_dim = self.ligand_model.config.hidden_size

        # Project protein features to the ligand width so both sequences can
        # share the cross-attention layers.
        self.project_to_common = nn.Linear(
            self.protein_embedding_dim, self.ligand_embedding_dim
        )

        self.layers = nn.ModuleList(
            [
                CrossAttentionLayer(embed_dim=self.ligand_embedding_dim)
                for _ in range(num_layers)
            ]
        )
        # After projection both pooled vectors have ligand_embedding_dim
        # features, so the concatenation is 2 * ligand_embedding_dim wide
        # (the same width the regression head below expects).
        combined_dim = 2 * self.ligand_embedding_dim
        self.final_norm = nn.LayerNorm(combined_dim)

        self.ffn_ic50 = nn.Sequential(
            nn.Linear(combined_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    def forward(
        self,
        ligand_input_ids,
        ligand_attention_mask,
        protein_input_ids,
        protein_attention_mask,
    ):
        """Predict IC50 for a batch.

        Args:
            ligand_input_ids: Token ids for the ligand encoder
                (assumed ``(batch, ligand_len)`` — TODO confirm against the dataset).
            ligand_attention_mask: Attention mask matching ``ligand_input_ids``.
            protein_input_ids: Token ids for the protein encoder
                (assumed ``(batch, protein_len)`` — TODO confirm against the dataset).
            protein_attention_mask: Attention mask matching ``protein_input_ids``.

        Returns:
            Tensor with one prediction per example; the last dimension has size 1.
        """
        # Mixed precision only makes sense on CUDA; elsewhere autocast is disabled.
        use_amp = protein_input_ids.device.type == "cuda"
        amp_device_type = "cuda" if use_amp else "cpu"

        # Frozen encoders: no gradients, reduced precision matmuls on GPU.
        with torch.no_grad(), torch.autocast(device_type=amp_device_type, enabled=use_amp):
            protein_outputs = self.esm_model(
                input_ids=protein_input_ids,
                attention_mask=protein_attention_mask,
            )
            ligand_outputs = self.ligand_model(
                input_ids=ligand_input_ids,
                attention_mask=ligand_attention_mask,
            )

        # True at CLS / EOS / PAD positions of the protein sequence.
        special_tokens_mask_prot = (
            (protein_input_ids == self.esm_tokenizer.cls_token_id)
            | (protein_input_ids == self.esm_tokenizer.eos_token_id)
            | (protein_input_ids == self.esm_tokenizer.pad_token_id)
        )
        # True at BOS / EOS / PAD positions of the ligand sequence.
        special_tokens_mask_ligand = (
            (ligand_input_ids == self.ligand_tokenizer.bos_token_id)
            | (ligand_input_ids == self.ligand_tokenizer.eos_token_id)
            | (ligand_input_ids == self.ligand_tokenizer.pad_token_id)
        )

        # Make sure the trainable (float32) head never receives half-precision
        # activations from the autocast region; a no-op when already float32.
        protein_embedding = protein_outputs.last_hidden_state.float()
        ligand_embedding = ligand_outputs.last_hidden_state.float()

        # project embeddings to same dimension
        protein_embedding = self.project_to_common(protein_embedding)

        for i, layer in enumerate(self.layers):
            protein_embedding, ligand_embedding = layer(
                protein_embedding,
                ligand_embedding,
                special_tokens_mask_prot,
                special_tokens_mask_ligand,
            )
            # Every second layer (i = 1, 3, ...) adds the previous layer's output as a
            # skip connection. prev_* is always set at i = 0 before it is first read.
            if i % 2 == 1:
                protein_embedding = protein_embedding + prev_protein_embedding
                ligand_embedding = ligand_embedding + prev_ligand_embedding
            prev_protein_embedding, prev_ligand_embedding = protein_embedding, ligand_embedding

        # Pool over the sequence axis with special/pad positions zeroed out.
        # NOTE(review): .mean(dim=1) divides by the full padded length, not by the
        # number of unmasked tokens, so the pooled scale depends on how much padding
        # the batch has. Left unchanged because trained weights depend on this
        # exact pooling; revisit together with a retrain.
        ligand_embedding = (
            ligand_embedding * ~special_tokens_mask_ligand.unsqueeze(dim=-1)
        ).mean(dim=1)
        protein_embedding = (
            protein_embedding * ~special_tokens_mask_prot.unsqueeze(dim=-1)
        ).mean(dim=1)

        # Combine embeddings
        combined = torch.cat([protein_embedding, ligand_embedding], dim=1)
        combined = self.final_norm(combined)
        ic50_prediction = self.ffn_ic50(combined)
        return ic50_prediction

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the model (this loads both pretrained encoders) and move it to the chosen device.
model = BindingAffinityModelWithMultiHeadCrossAttention(esm_model_name, chem_model_name)
model = model.to(device)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from tqdm import tqdm

def evaluate_model(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and report regression metrics.

    Args:
        model: Module called as ``model(ligand_input_ids, ligand_attention_mask,
            protein_input_ids, protein_attention_mask)``.
        test_loader: Iterable of batches (dicts) with the four input keys above
            plus ``"ic50"`` holding one target per example.

    Returns:
        dict with the keys ``"mse"``, ``"mae"`` and ``"r2"`` (Python floats).
        The same numbers are also printed.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    model.eval()

    # Run on whatever device the model lives on instead of relying on a notebook global.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    all_predictions = []
    all_targets = []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Test set"):
            preds = model(
                batch["ligand_input_ids"].to(model_device),
                batch["ligand_attention_mask"].to(model_device),
                batch["protein_input_ids"].to(model_device),
                batch["protein_attention_mask"].to(model_device),
            )
            # Keep results on the CPU so accelerator memory does not grow with the test set.
            all_predictions.append(preds.detach().cpu())
            # Targets get a trailing axis so they line up with the model output.
            all_targets.append(batch["ic50"].unsqueeze(dim=-1).cpu())

    if not all_predictions:
        raise ValueError("test_loader yielded no batches; nothing to evaluate")

    all_predictions = torch.cat(all_predictions)
    all_targets = torch.cat(all_targets)

    loss = torch.nn.MSELoss()
    print(
        f"Test set mean squared error (MSE): {loss(all_predictions, all_targets)}"
    )

    y_true = all_targets.numpy()
    y_pred = all_predictions.numpy()
    mse = mean_squared_error(y_true, y_pred)
    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    print(f"MSE: {mse}, MAE: {mae}, R²: {r2}")

    return {"mse": float(mse), "mae": float(mae), "r2": float(r2)}

In [6]:
# Load the saved checkpoint. map_location remaps tensors that were saved from a GPU
# onto the device selected for this session, so the load also works on a CPU-only machine.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load("affinity_03_02_2025_16_28_20_24.pt", map_location=device)
# checkpoint = torch.load("/root/finetuning/affinity_21_01_2025_23_00_56_26.pt")

In [7]:
# Restore the weights stored under the checkpoint's "model_state_dict" entry.
# (For a file that is itself a bare state dict, pass `checkpoint` directly instead.)
trained_weights = checkpoint["model_state_dict"]
model.load_state_dict(trained_weights)

<All keys matched successfully>

In [8]:
# Score the restored model on the held-out test split.
evaluate_model(model=model, test_loader=test_dataloader)

Test set: 100%|██████████| 260/260 [09:14<00:00,  2.13s/it]

Test set mean squared error (MSE): 6.241326808929443
MSE: 6.241326808929443, MAE: 1.8755537271499634, R²: 0.4396736217242718



