In [1]:
import uuid, sys, os
import pandas as pd
import numpy as np
from tqdm import tqdm
import ast
import math
import random
import matplotlib.pyplot as plt

from sklearn import metrics
from scipy import stats
from collections import Counter
from sklearn.metrics import average_precision_score, precision_recall_curve, roc_auc_score, confusion_matrix, ConfusionMatrixDisplay, classification_report, roc_curve

os.environ["CUDA_VISIBLE_DEVICES"] = "3"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
torch.cuda.set_device(0)  # 0 == "first visible" -> actually GPU 2 on the node
print(torch.cuda.get_device_name(0))

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, WeightedRandomSampler
import pytorch_lightning as pl
from torch.optim import AdamW

torch.manual_seed(0)

from accelerate import Accelerator
torch.cuda.empty_cache()
import training_utils.partitioning_utils as pat_utils
from tqdm import trange

Tesla V100-SXM2-32GB


  warn(
  _torch_pytree._register_pytree_node(


### Loading PPint and meta-analysis data

In [None]:
Df_train = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_train_w_pbd_lens.csv",index_col=0).reset_index(drop=True)
Df_test = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_test_w_pbd_lens.csv",index_col=0).reset_index(drop=True)

Df_train["target_chain"] = [str(row.ID1[:5]+row.ID1[-1]) for __, row in Df_train.iterrows()]
Df_train["binder_chain"] = [str(row.ID2[:5]+row.ID2[-1]) for __, row in Df_train.iterrows()]

Df_test["target_chain"] = [str(row.ID1[:5]+row.ID1[-1]) for __, row in Df_test.iterrows()]
Df_test["binder_chain"] = [str(row.ID2[:5]+row.ID2[-1]) for __, row in Df_test.iterrows()]

Df_train["target_binder_id"] = [str(row.ID1)+"_"+str(row.ID2) for __, row in Df_train.iterrows()]
Df_test["target_binder_id"] = [str(row.ID1)+"_"+str(row.ID2) for __, row in Df_test.iterrows()]

Df_train_small = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_train.csv",index_col=0).reset_index(drop=True)
Df_test_small = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_test.csv",index_col=0).reset_index(drop=True)

Df_train = pd.merge(Df_train, Df_train_small[["target_binder_id", "dimer"]], on="target_binder_id", how="inner")
Df_test = pd.merge(Df_test, Df_test_small[["target_binder_id", "dimer"]], on="target_binder_id", how="inner")

Df_train

## ESM-IF

In [None]:
class CLIP_PPint_class(Dataset):
    def __init__(
        self,
        dframe,
        path,
        embedding_dim=512,
        embedding_pad_value=-5000.0,
    ):
        super().__init__()
        self.dframe = dframe.copy()
        self.embedding_dim = int(embedding_dim)
        self.emb_pad = float(embedding_pad_value)

        # lengths
        self.max_blen = self.dframe["pdb_binder_len"].max()+2
        self.max_tlen = self.dframe["pdb_target_len"].max()+2

        # paths
        self.encoding_path  = path

        # index & storage
        self.dframe.set_index("target_binder_id", inplace=True)
        self.accessions = self.dframe.index.astype(str).tolist()
        self.name_to_row = {name: i for i, name in enumerate(self.accessions)}
        self.samples = []

        for accession in tqdm(self.accessions, total=len(self.accessions), desc="#Loading ESM2 embeddings and contacts"):
            parts = accession.split("_") # e.g. accession 7S8T_5_F_7S8T_5_G
            tgt_id = parts[0]+"_"+parts[2]
            bnd_id = parts[3]+"_"+parts[5]

            ### --- embeddings (pad to fixed lengths) --- ###
            
            # laod embeddings
            t_emb = np.load(os.path.join(self.encoding_path, f"{tgt_id}.npy"))     # [Lt, D]
            b_emb = np.load(os.path.join(self.encoding_path, f"{bnd_id}.npy"))     # [Lb, D]

            assert (b_emb.shape[0] == self.dframe.loc[accession].pdb_binder_len+2)
            assert (t_emb.shape[0] == self.dframe.loc[accession].pdb_target_len+2)

            # quich check whether embedding dimmension is as it suppose to be
            if t_emb.shape[1] != self.embedding_dim or b_emb.shape[1] != self.embedding_dim:
                raise ValueError("Embedding dim mismatch with 'embedding_dim'.")

            # add -5000 to all the padded target rows
            if t_emb.shape[0] < self.max_tlen:
                t_emb = np.concatenate([t_emb, np.full((self.max_tlen - t_emb.shape[0], t_emb.shape[1]), self.emb_pad, dtype=t_emb.dtype)], axis=0)
            else:
                t_emb = t_emb[: self.max_tlen] # no padding was used

            # add -5000 to all the padded binder rows
            if b_emb.shape[0] < self.max_blen:
                b_emb = np.concatenate([b_emb, np.full((self.max_blen - b_emb.shape[0], b_emb.shape[1]), self.emb_pad, dtype=b_emb.dtype)], axis=0)
            else:
                b_emb = b_emb[: self.max_blen] # no padding was used

            self.samples.append((b_emb, t_emb))

        # ---- Dataset API ----
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        b_arr, t_arr = self.samples[idx]
        binder_emb, target_emb = torch.from_numpy(b_arr).float(), torch.from_numpy(t_arr).float()
        label = torch.tensor(1, dtype=torch.float32)  # single scalar labe
        return binder_emb, target_emb, label

    def _get_by_name(self, name):
        # Single item -> return exactly what __getitem__ returns
        if isinstance(name, str):
            return self.__getitem__(self.name_to_row[name])
        
        # Multiple items -> fetch all
        out = [self.__getitem__(self.name_to_row[n]) for n in list(name)]
        b_list, t_list, lbl_list = zip(*out)
    
        # Stack embeddings
        b  = torch.stack([torch.as_tensor(x) for x in b_list],  dim=0)  # [B, ...]
        t  = torch.stack([torch.as_tensor(x) for x in t_list],  dim=0)  # [B, ...]
    
        # Stack labels
        labels = torch.stack(lbl_list)  # [B]
    
        return b, t, labels

emb_path = "/work3/s232958/data/PPint_DB/esmif_embeddings_noncanonical"

testing_Dataset = CLIP_PPint_class(
    Df_test,
    path=emb_path,
    embedding_dim=512
)

### Getting indeces of non-dimers
indices_non_dimers_Df = Df_test[~Df_test["dimer"]]
indices_non_dimers = Df_test[~Df_test["dimer"]].index.tolist()
indices_non_dimers[:5]

non_dimers_Dataset = CLIP_PPint_class(
    indices_non_dimers_Df,
    path=emb_path,
    embedding_dim=512
)

In [None]:
interaction_df = pd.read_csv("/work3/s232958/data/meta_analysis/interaction_df_metaanal_w_pbd_lens.csv").drop(columns = ["binder_id", "target_id"]).rename(columns = {
    "target_id_mod" : "target_id",
    "target_binder_ID" : "binder_id",
})

# Interaction Dict
interaction_df_shuffled = interaction_df.sample(frac=1, random_state=0).reset_index(drop=True)
interaction_df_shuffled

In [None]:
class CLIP_Meta_class(Dataset):
    def __init__(
        self,
        dframe,
        paths,
        embedding_dim=512,
        embedding_pad_value=-5000.0
    ):
        super().__init__()
        self.dframe = dframe.copy()
        self.embedding_dim = int(embedding_dim)
        self.emb_pad = float(embedding_pad_value)
        self.max_blen = self.dframe["pdb_len_binder"].max()+2
        self.max_tlen = self.dframe["pdb_len_target"].max()+2

        # paths
        self.encoding_bpath, self.encoding_tpath = paths

        # index & storage
        self.dframe.set_index("binder_id", inplace=True)
        self.accessions = self.dframe.index.astype(str).tolist()
        self.name_to_row = {name: i for i, name in enumerate(self.accessions)}
        self.samples = []

        for accession in tqdm(self.accessions, total=len(self.accessions), desc="#Loading ESM2 embeddings"):
            lbl = torch.tensor(int(self.dframe.loc[accession, "binder"]))
            parts = accession.split("_")
            tgt_id = "_".join(parts[:-1])
            bnd_id = accession

            ### --- embeddings (pad to fixed lengths) --- ###
            
            # laod embeddings
            t_emb = np.load(os.path.join(self.encoding_tpath, f"{tgt_id}.npy"))     # [Lt, D]
            b_emb = np.load(os.path.join(self.encoding_bpath, f"{bnd_id}.npy"))     # [Lb, D]

            assert (b_emb.shape[0] == self.dframe.loc[accession].pdb_len_binder+2)
            assert (t_emb.shape[0] == self.dframe.loc[accession].pdb_len_target+2)

            # quich check whether embedding dimmension is as it suppose to be
            if t_emb.shape[1] != self.embedding_dim or b_emb.shape[1] != self.embedding_dim:
                raise ValueError("Embedding dim mismatch with 'embedding_dim'.")

            # add -5000 to all the padded target rows
            if t_emb.shape[0] < self.max_tlen:
                t_emb = np.concatenate([t_emb, np.full((self.max_tlen - t_emb.shape[0], t_emb.shape[1]), self.emb_pad, dtype=t_emb.dtype)], axis=0)
            else:
                t_emb = t_emb[: self.max_tlen] # no padding was used

            # add -5000 to all the padded binder rows
            if b_emb.shape[0] < self.max_blen:
                b_emb = np.concatenate([b_emb, np.full((self.max_blen - b_emb.shape[0], b_emb.shape[1]), self.emb_pad, dtype=b_emb.dtype)], axis=0)
            else:
                b_emb = b_emb[: self.max_blen] # no padding was used

            self.samples.append((b_emb, t_emb, lbl))

    # ---- Dataset API ----
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        b_arr, t_arr, lbls = self.samples[idx]
        binder_emb, target_emb = torch.from_numpy(b_arr).float(), torch.from_numpy(t_arr).float()
        return binder_emb, target_emb, lbls

    def _get_by_name(self, name):
        # Single item -> return exactly what __getitem__ returns
        if isinstance(name, str):
            return self.__getitem__(self.name_to_row[name])
        
        # Multiple items -> fetch all
        out = [self.__getitem__(self.name_to_row[n]) for n in list(name)]
        b_list, t_list, lbl_list = zip(*out)
    
        # Stack embeddings
        b  = torch.stack([torch.as_tensor(x) for x in b_list],  dim=0)  # [B, ...]
        t  = torch.stack([torch.as_tensor(x) for x in t_list],  dim=0)  # [B, ...]
    
        # Stack labels
        labels = torch.stack(lbl_list)  # [B]
    
        return b, t, labels

bemb_path = "/work3/s232958/data/meta_analysis/esmif_embeddings_binders"
temb_path = "/work3/s232958/data/meta_analysis/esmif_embeddings_targets"

validation_Dataset = CLIP_Meta_class(
    # interaction_df_shuffled[:len(Df_test)],
    interaction_df_shuffled,
    paths=[bemb_path, temb_path],
    embedding_dim=512
)

### Pre-trained model

In [None]:
embedding_dimension = 512

def create_key_padding_mask(embeddings, padding_value=-5000, offset=10):
    return (embeddings < (padding_value + offset)).all(dim=-1)

def create_mean_of_non_masked(embeddings, padding_mask):
    # Use masked select and mean to compute the mean of non-masked elements
    # embeddings should be of shape (batch_size, seq_len, features)
    seq_embeddings = []
    for i in range(embeddings.shape[0]): # looping over all batch elements
        non_masked_embeddings = embeddings[i][~padding_mask[i]] # shape [num_real_tokens, features]
        if len(non_masked_embeddings) == 0:
            print("You are masking all positions when creating sequence representation")
            sys.exit(1)
        mean_embedding = non_masked_embeddings.mean(dim=0) # sequence is represented by the single vecotr [1152] [features]
        seq_embeddings.append(mean_embedding)
    return torch.stack(seq_embeddings)

class MiniCLIP_w_transformer_crossattn(pl.LightningModule):

    def __init__(self, padding_value = -5000, embed_dimension=embedding_dimension, num_recycles=2):

        super().__init__()
        self.num_recycles = num_recycles # how many times you iteratively refine embeddings with self- and cross-attention (ALPHA-Fold-style recycling).
        self.padding_value = padding_value
        self.embed_dimension = 512

        self.logit_scale = nn.Parameter(torch.tensor(math.log(1/0.07)))  # ~CLIP init

        self.transformerencoder =  nn.TransformerEncoderLayer(
            d_model=self.embed_dimension,
            nhead=8,
            dropout=0.1,
            batch_first=True,
            dim_feedforward=self.embed_dimension*2
            )
 
        self.norm = nn.LayerNorm(self.embed_dimension)  # For residual additions

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=self.embed_dimension,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        self.prot_embedder = nn.Sequential(
            nn.Linear(self.embed_dimension, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
        )
        
    def forward(self, pep_input, prot_input, label=None, pep_int_mask=None, prot_int_mask=None, int_prob=None, mem_save=True): # , pep_tokens, prot_tokens

        pep_mask = create_key_padding_mask(embeddings=pep_input, padding_value=self.padding_value)
        prot_mask = create_key_padding_mask(embeddings=prot_input, padding_value=self.padding_value)
 
        # Initialize residual states
        pep_emb = pep_input.clone()
        prot_emb = prot_input.clone()
 
        for _ in range(self.num_recycles):

            # Transformer encoding with residual
            pep_trans = self.transformerencoder(self.norm(pep_emb), src_key_padding_mask=pep_mask)
            prot_trans = self.transformerencoder(self.norm(prot_emb), src_key_padding_mask=prot_mask)

            # Cross-attention with residual
            pep_cross, _ = self.cross_attn(query=self.norm(pep_trans), key=self.norm(prot_trans), value=self.norm(prot_trans), key_padding_mask=prot_mask)
            prot_cross, _ = self.cross_attn(query=self.norm(prot_trans), key=self.norm(pep_trans), value=self.norm(pep_trans), key_padding_mask=pep_mask)
            
            # Additive update with residual connection
            pep_emb = pep_emb + pep_cross  
            prot_emb = prot_emb + prot_cross

        pep_seq_coding = create_mean_of_non_masked(pep_emb, pep_mask)
        prot_seq_coding = create_mean_of_non_masked(prot_emb, prot_mask)
        
        # Use self-attention outputs for embeddings
        pep_seq_coding = F.normalize(self.prot_embedder(pep_seq_coding), dim=-1)
        prot_seq_coding = F.normalize(self.prot_embedder(prot_seq_coding), dim=-1)
 
        if mem_save:
            torch.cuda.empty_cache()
        
        scale = torch.exp(self.logit_scale).clamp(max=100.0)
        logits = scale * (pep_seq_coding * prot_seq_coding).sum(dim=-1)
        
        return logits

    def training_step(self, batch, device):
        embedding_pep, embedding_prot, labels = batch
        embedding_pep, embedding_prot = embedding_pep.to(device), embedding_prot.to(device)
        
        positive_logits = self.forward(embedding_pep, embedding_prot)
        
        # Negative indexes
        rows, cols = torch.triu_indices(embedding_prot.size(0), embedding_prot.size(0), offset=1)         
        
        negative_logits = self(embedding_pep[rows,:,:], embedding_prot[cols,:,:], int_prob=0.0)

        # loss of predicting partner using peptide
        positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits).to(device))
 
        # loss of predicting peptide using partner
        negative_loss =  F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits).to(device))
        
        loss = (positive_loss + negative_loss) / 2
 
        # del partner_prediction_loss, peptide_prediction_loss, embedding_pep, embedding_prot
        torch.cuda.empty_cache()
        return loss

    def validation_step_PPint(self, batch, device):
        # Predict on random batches of training batch size
        embedding_pep, embedding_prot, labels = batch
        embedding_pep, embedding_prot = embedding_pep.to(device), embedding_prot.to(device)
        
        with torch.no_grad():

            positive_logits = self(embedding_pep, embedding_prot)
            
            # loss of predicting partner using peptide
            positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits).to(device))
            
            # Negaive indexes
            rows, cols = torch.triu_indices(embedding_prot.size(0), embedding_prot.size(0), offset=1)
            
            negative_logits = self(embedding_pep[rows,:,:], embedding_prot[cols,:,:], int_prob=0.0)
    
            negative_loss =  F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits).to(device))

            loss = (positive_loss + negative_loss) / 2
           
            logit_matrix = torch.zeros((embedding_pep.size(0),embedding_pep.size(0)),device=self.device)
            logit_matrix[rows, cols] = negative_logits
            logit_matrix[cols, rows] = negative_logits
            
            # Fill diagonal with positive scores
            diag_indices = torch.arange(embedding_pep.size(0), device=self.device)
            logit_matrix[diag_indices, diag_indices] = positive_logits.squeeze()

            labels = torch.arange(embedding_prot.size(0)).to(self.device)
            peptide_predictions = logit_matrix.argmax(dim=0)
            peptide_ranks = logit_matrix.argsort(dim=0).diag() + 1
            peptide_mrr = (peptide_ranks).float().pow(-1).mean()
            
            # partner_accuracy = partner_predictions.eq(labels).float().mean()
            peptide_accuracy = peptide_predictions.eq(labels).float().mean()
    
            k = 3
            peptide_topk_accuracy = torch.any((logit_matrix.topk(k, dim=0).indices - labels.reshape(1, -1)) == 0, dim=0).sum() / logit_matrix.shape[0]
    
            del logit_matrix,positive_logits,negative_logits,embedding_pep,embedding_prot

            return loss, peptide_accuracy, peptide_topk_accuracy
    
    def validation_step_MetaDataset(self, batch, device):
        embedding_binder, embedding_target, labels = batch
        embedding_binder = embedding_binder.to(device)
        embedding_target = embedding_target.to(device)
        labels = labels.to(device).float()
    
        with torch.no_grad():
            logits = self.forward(embedding_binder, embedding_target)
            logits = logits.float()
            loss = F.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1))
            return logits, loss

    def calculate_logit_matrix(self,embedding_pep,embedding_prot):
        rows, cols = torch.triu_indices(embedding_pep.size(0), embedding_pep.size(0), offset=1)
        
        positive_logits = self(embedding_pep, embedding_prot)
        negative_logits = self(embedding_pep[rows,:,:], embedding_prot[cols,:,:], int_prob=0.0)
        
        logit_matrix = torch.zeros((embedding_pep.size(0),embedding_pep.size(0)),device=self.device)
        logit_matrix[rows, cols] = negative_logits
        logit_matrix[cols, rows] = negative_logits
        
        diag_indices = torch.arange(embedding_pep.size(0), device=self.device)
        logit_matrix[diag_indices, diag_indices] = positive_logits.squeeze()
        
        return logit_matrix

In [None]:
model = MiniCLIP_w_transformer_crossattn().to("cuda")
path = "/work3/s232958/data/trained/original_architecture/cb12a130-9881-423e-88ba-9e18969fdb5f/cb12a130-9881-423e-88ba-9e18969fdb5f_checkpoint_6/cb12a130-9881-423e-88ba-9e18969fdb5f_checkpoint_epoch_6.pth"
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
checkpoint = torch.load(path, map_location=device)
# print(list(checkpoint["model_state_dict"]))
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)

model.eval()

In [None]:
def collate_varlen(batch):
    b_emb = torch.stack([x[0] for x in batch], dim=0)
    t_emb = torch.stack([x[1] for x in batch], dim=0)
    lbls = torch.tensor([x[2].float() for x in batch])
    return b_emb, t_emb, lbls

test_dataloader = DataLoader(testing_Dataset, batch_size=10, collate_fn=collate_varlen)
non_dimers_dataloader = DataLoader(non_dimers_Dataset, batch_size=10, collate_fn=collate_varlen)
validation_dataloader = DataLoader(validation_Dataset, batch_size=15, shuffle=False, drop_last = False)

print("len(validation_Dataset):", len(validation_Dataset))
print("len(validation_dataloader.dataset):", len(validation_dataloader.dataset))

In [None]:
ckpt = torch.load(path, map_location="cpu")
missing, unexpected = model.load_state_dict(ckpt["model_state_dict"], strict=False)
print("epoch in ckpt:", ckpt.get("epoch"))
print("missing:", missing)
print("unexpected:", unexpected)
print("logit_scale:", model.logit_scale.item())

### test-data

In [None]:
interaction_scores_pos = []
interaction_scores_neg = []    

for batch in tqdm(test_dataloader, total=round(len(Df_test)/10), desc="#Iterating through batched data"):
    b_emb, t_emb, lbls = batch
    embedding_pep = b_emb.to("cuda")
    embedding_prot = t_emb.to("cuda")

    with torch.no_grad():

        rows, cols = torch.triu_indices(embedding_pep.size(0), embedding_pep.size(0), offset=1)
        
        positive_logits = model(embedding_pep, embedding_prot)
        # negative_logits = self.forward(embedding_pep[rows,:,:], embedding_prot[cols,:,:], contacts_pep[rows,:,:], contacts_prot[cols,:,:], int_prob=0.0)
        negative_logits = model(embedding_pep[rows,:,:], embedding_prot[cols,:,:], int_prob=0.0)
        
        # print(logit_matrix)
        interaction_scores_pos.append(positive_logits)
        interaction_scores_neg.append(negative_logits)

# Convert list of tensors to single 1D tensors
pos_logits = torch.cat(interaction_scores_pos).detach().cpu().numpy()
neg_logits = torch.cat(interaction_scores_neg).detach().cpu().numpy()
print("Positives:", pos_logits.shape)
print("Negatives:", neg_logits.shape)

In [None]:
plt.figure(figsize=(6, 4))
plt.hist(pos_logits, bins=50, alpha=0.6, label="Positive pairs", density=True)
plt.hist(neg_logits, bins=50, alpha=0.6, label="Negative pairs", density=True)

plt.xlabel("Logit score")
plt.ylabel("Density")
plt.title("Distribution of Positive vs Negative Interaction Scores")
plt.legend()
plt.show()

#### non-dimers

In [None]:
interaction_scores = []
for batch in tqdm(non_dimers_dataloader, total=round(len(non_dimers_Dataset)/10), desc="#Iterating through batched data"):
    b_emb, t_emb, lbls = batch
    embedding_pep = b_emb.to("cuda")
    embedding_prot = t_emb.to("cuda")

    with torch.no_grad():

        rows, cols = torch.triu_indices(embedding_pep.size(0), embedding_pep.size(0), offset=1)
        
        positive_logits = model(embedding_pep, embedding_prot)
        # negative_logits = self.forward(embedding_pep[rows,:,:], embedding_prot[cols,:,:], contacts_pep[rows,:,:], contacts_prot[cols,:,:], int_prob=0.0)
        negative_logits = model(embedding_pep[rows,:,:], embedding_prot[cols,:,:], int_prob=0.0)
        
        # print(logit_matrix)
        interaction_scores_pos.append(positive_logits)
        interaction_scores_neg.append(negative_logits)

# Convert list of tensors to single 1D tensors
pos_logits = torch.cat(interaction_scores_pos).detach().cpu().numpy()
neg_logits = torch.cat(interaction_scores_neg).detach().cpu().numpy()
print("Positives:", pos_logits.shape)
print("Negatives:", neg_logits.shape)

In [None]:
plt.figure(figsize=(6, 4))
plt.hist(pos_logits, bins=50, alpha=0.6, label="Positive pairs", density=True)
plt.hist(neg_logits, bins=50, alpha=0.6, label="Negative pairs", density=True)

plt.xlabel("Logit score")
plt.ylabel("Density")
plt.title("Distribution of Positive vs Negative Interaction Scores")
plt.legend()
plt.show()

#### meta-binders

In [None]:
from sklearn import metrics

model.eval()
all_logits, all_lbls = [], []

with torch.no_grad():
    for batch in validation_dataloader:
        embedding_binder, embedding_target, labels = batch
        logits, _ = model.validation_step_MetaDataset(batch, device="cuda")
        all_logits.append(logits.detach().view(-1).cpu())
        all_lbls.append(labels.detach().view(-1).cpu())
        
all_logits = torch.cat(all_logits).numpy()
all_lbls   = torch.cat(all_lbls).numpy()
fpr, tpr, thresholds = metrics.roc_curve(all_lbls, all_logits)
meta_auroc = metrics.roc_auc_score(all_lbls, all_logits)
meta_aupr  = metrics.average_precision_score(all_lbls, all_logits)
# from sklearn.metrics import roc_auc_score
print("AUROC:", meta_auroc)

In [None]:
all_lbls

In [None]:
# Loading batches
interaction_scores = []

for batch in tqdm(validation_dataloader, total = round(len(interaction_df_shuffled)/10),  desc= "#Iterating through batched data"):
    b_emb, t_emb, lbls = batch
    embedding_pep = b_emb.to("cuda")
    embedding_prot = t_emb.to("cuda")

    with torch.no_grad():
        positive_logits = model(embedding_pep, embedding_prot)
        interaction_scores.append(positive_logits.unsqueeze(0))

predicted_interaction_scores = np.concatenate([batch_score.cpu().detach().numpy().reshape(-1,) for batch_score in interaction_scores])
interaction_probabilities = np.concatenate([torch.sigmoid(batch_score[0]).cpu().numpy() for batch_score in interaction_scores])

pos_logits, neg_logits = [], []
for i, row in interaction_df_shuffled.iterrows():
    logit = predicted_interaction_scores[i]
    if row.binder == False:
        neg_logits.append(logit)
    elif row.binder == True:
        pos_logits.append(logit)

plt.figure(figsize=(5, 4))

plt.hist(neg_logits, bins=50, alpha=0.6, label="Negative pairs", density=True)
plt.hist(pos_logits, bins=50, alpha=0.6, label="Positive pairs", density=True)

plt.xlabel("Logit score")
plt.ylabel("Density")
plt.title("Distribution of Positive vs Negative Interaction Scores")

# --- Simple grid behind ---
plt.grid(True, linestyle="--", alpha=0.5)

plt.legend()
plt.tight_layout()
plt.show()

In [None]:
interaction_df_shuffled["inter_prob"] = interaction_probabilities
interaction_df_shuffled["pred_binder"] = interaction_df_shuffled["inter_prob"] >= 0.5
interaction_df_shuffled["intr_scores"] = predicted_interaction_scores

pred_labels = interaction_probabilities >= 0.5
true_labels = np.array(interaction_df_shuffled["binder"])

true_positives = ((pred_labels == 1) & (true_labels == 1)).sum().item()
true_negatives = ((pred_labels == 0) & (true_labels == 0)).sum().item()
false_positives = ((pred_labels == 1) & (true_labels == 0)).sum().item()
false_negatives = ((pred_labels == 0) & (true_labels == 1)).sum().item()

predicted_positives = true_positives + false_positives
all_real_positives = true_positives + false_negatives

print(classification_report(true_labels, pred_labels, digits = 4))

fig, axes = plt.subplots(1, 2, figsize=(8, 4))

disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix(true_labels, pred_labels))
disp.plot(ax=axes[0])
axes[0].set_title("Confusion Matrix")

TPR = true_positives / (true_positives + true_negatives) # how good the model is at predicting the positive class when the actual outcome is positive.
# sensitivity = true_positives / (true_positives + false_negatives) # the same as TPR
FPR = false_positives / (false_positives + true_negatives) # how often a positive class is predicted when the actual outcome is negative.
# specificity = true_negatives / (true_negatives + false_positives) # FPR = 1 - specificity

fpr, tpr, thresholds = roc_curve(true_labels, interaction_probabilities)
auc = roc_auc_score(true_labels, interaction_probabilities)
print('AUC: %.3f' % auc)

axes[1].plot(fpr, tpr, linewidth=2)
axes[1].plot([0, 1], [0, 1], linestyle="--", linewidth=1)  # diagonal reference
axes[1].set_xlabel('False Positive Rate')
axes[1].set_ylabel('True Positive Rate')
axes[1].set_xlim(0, 1)
axes[1].set_ylim(0, 1)
axes[1].set_title('ROC Curve')

# show the plot
plt.tight_layout()
plt.show()

## ESM-IF + ESM-2

In [None]:
class CLIP_PPint_w_esmIF(Dataset):
    def __init__(
        self,
        dframe,
        paths,
        embedding_dim_struct=512,
        embedding_dim_seq=1280,
        embedding_pad_value=-5000.0,
    ):
        super().__init__()
        self.dframe = dframe.copy()
        self.embedding_dim_seq = embedding_dim_seq
        self.embedding_dim_struct = embedding_dim_struct
        self.emb_pad = embedding_pad_value

        # lengths
        self.max_blen_seq = self.dframe["seq_binder_len"].max()
        self.max_tlen_seq = self.dframe["seq_target_len"].max()
        self.max_blen_struct = self.dframe["pdb_binder_len"].max()
        self.max_tlen_struct = self.dframe["pdb_target_len"].max()

        # paths
        self.seq_encodings_path, self.struct_encodings_path = paths

        # index & storage
        self.dframe.set_index("target_binder_id", inplace=True)
        self.accessions = self.dframe.index.astype(str).tolist()
        self.name_to_row = {name: i for i, name in enumerate(self.accessions)}
        self.samples = []

        for accession in tqdm(self.accessions, total=len(self.accessions), desc="#Loading ESM2 embeddings and contacts"):
            parts = accession.split("_") # e.g. accession 7S8T_5_F_7S8T_5_G
            tgt_id = parts[0]+"_"+parts[2]
            bnd_id = parts[-3]+"_"+parts[-1]

            ### --- SEQ embeddings (pad to fixed lengths) --- ###
            # laod embeddings
            t_emb_seq = np.load(os.path.join(self.seq_encodings_path, f"{tgt_id}.npy"))     # [Lt, D]
            b_emb_seq = np.load(os.path.join(self.seq_encodings_path, f"{bnd_id}.npy"))     # [Lb, D]
            t_emb_struct = np.load(os.path.join(self.struct_encodings_path, f"{tgt_id}.npy"))     # [Lt, D]
            b_emb_struct = np.load(os.path.join(self.struct_encodings_path, f"{bnd_id}.npy"))     # [Lb, D]

            # quich check whether embedding dimmension is as it suppose to be
            if t_emb_seq.shape[1] != self.embedding_dim_seq or b_emb_seq.shape[1] != self.embedding_dim_seq:
                raise ValueError("Embedding dim mismatch with 'embedding_dim_seq'.")
            if t_emb_struct.shape[1] != self.embedding_dim_struct or b_emb_struct.shape[1] != self.embedding_dim_struct:
                raise ValueError("Embedding dim mismatch with 'embedding_dim'.")
                
            # add -5000 to all the padded target rows
                ### SEQ_embeddings ###
            if t_emb_seq.shape[0] < self.max_tlen_seq:
                t_emb_seq = np.concatenate([t_emb_seq, np.full((self.max_tlen_seq - t_emb_seq.shape[0], t_emb_seq.shape[1]), self.emb_pad, dtype=t_emb_seq.dtype)], axis=0)
            else:
                t_emb_seq = t_emb_seq[: self.max_tlen_seq] # no padding was used
            if b_emb_seq.shape[0] < self.max_blen_seq:
                b_emb_seq = np.concatenate([b_emb_seq, np.full((self.max_blen_seq - b_emb_seq.shape[0], b_emb_seq.shape[1]), self.emb_pad, dtype=b_emb_seq.dtype)], axis=0)
            else:
                b_emb_seq = b_emb_seq[: self.max_blen_seq] # no padding was used

                ### STRUCT_embeddings ###
            if t_emb_struct.shape[0] < self.max_tlen_struct:
                t_emb_struct = np.concatenate([t_emb_struct, np.full((self.max_tlen_struct - t_emb_struct.shape[0], t_emb_struct.shape[1]), self.emb_pad, dtype=t_emb_struct.dtype)], axis=0)
            else:
                t_emb_struct = t_emb_struct[: self.max_tlen_struct] # no padding was used
            if b_emb_struct.shape[0] < self.max_blen_struct:
                b_emb_struct = np.concatenate([b_emb_struct, np.full((self.max_blen_struct - b_emb_struct.shape[0], b_emb_struct.shape[1]), self.emb_pad, dtype=b_emb_struct.dtype)], axis=0)
            else:
                b_emb_struct = b_emb_struct[: self.max_blen_struct] # no padding was used

            self.samples.append((b_emb_seq, t_emb_seq, b_emb_struct, t_emb_struct))

    # ---- Dataset API ----
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        b_emb_seq, t_emb_seq, b_emb_struct, t_emb_struct = self.samples[idx]
        b_emb_seq, t_emb_seq = torch.from_numpy(b_emb_seq).float(), torch.from_numpy(t_emb_seq).float()
        b_emb_struct, t_emb_struct = torch.from_numpy(b_emb_struct).float(), torch.from_numpy(t_emb_struct).float()
        label = torch.tensor(1, dtype=torch.float32)  # single scalar labe
        return b_emb_seq, t_emb_seq, b_emb_struct, t_emb_struct, label

    def _get_by_name(self, name):
        # Single item -> return exactly what __getitem__ returns
        if isinstance(name, str):
            return self.__getitem__(self.name_to_row[name])
        
        # Multiple items -> fetch all
        out = [self.__getitem__(self.name_to_row[n]) for n in list(name)]
        b_emb_seq_list, t_emb_seq_list, b_emb_struct_list, t_emb_struct_list, lbl_list = zip(*out)
    
        # Stack embeddings
        b_emb_seq  = torch.stack([torch.as_tensor(x) for x in b_emb_seq_list],  dim=0)  # [B, ...]
        t_emb_seq  = torch.stack([torch.as_tensor(x) for x in t_emb_seq_list],  dim=0)  # [B, ...]
        
        b_emb_struct  = torch.stack([torch.as_tensor(x) for x in b_emb_struct_list],  dim=0)  # [B, ...]
        t_emb_struct  = torch.stack([torch.as_tensor(x) for x in t_emb_struct_list],  dim=0)  # [B, ...]
    
        # Stack labels
        labels = torch.stack(lbl_list)  # [B]
    
        return b_emb_seq, t_emb_seq, b_emb_struct, t_emb_struct, labels

emb_seq_path = "/work3/s232958/data/PPint_DB/embeddings_esm2"
emb_struct_path = "/work3/s232958/data/PPint_DB/esmif_embeddings_noncanonical"

testing_Dataset = CLIP_PPint_w_esmIF(
    Df_test,
    paths=[emb_seq_path, emb_struct_path],
    embedding_dim_seq=1280,
    embedding_dim_struct=512
)

### Getting indeces of non-dimers
non_dimers_Df = Df_test[~Df_test["dimer"]]
indices_non_dimers = Df_test[~Df_test["dimer"]].index.tolist()

non_dimers_Dataset = CLIP_PPint_w_esmIF(
    non_dimers_Df,
    paths=[emb_seq_path, emb_struct_path],
    embedding_dim_seq=1280,
    embedding_dim_struct=512
)

In [None]:
class CLIP_ESM2_ESMIF(Dataset):
    def __init__(
        self,
        dframe,
        paths,
        seq_embedding_dim=1280,
        struct_embedding_dim=512,
        embedding_pad_value=-5000.0,
    ):
        super().__init__()
        self.dframe = dframe.copy()
        self.embedding_dim_seq = int(seq_embedding_dim)
        self.embedding_dim_struct = int(struct_embedding_dim)
        self.emb_pad = float(embedding_pad_value)

        # lengths
        self.max_blen_seq = self.dframe["seq_len_binder"].max()
        self.max_tlen_seq = self.dframe["seq_len_target"].max()
        self.max_blen_struct = self.dframe["pdb_len_binder"].max()
        self.max_tlen_struct = self.dframe["pdb_len_target"].max()

        # paths
        self.seq_bembed, self.seq_tembed, self.struct_bembed, self.struct_tembed = paths

        # index & storage
        self.dframe.set_index("binder_id", inplace=True)
        self.accessions = self.dframe.index.astype(str).tolist()
        self.name_to_row = {name: i for i, name in enumerate(self.accessions)}
        self.samples = []

        for accession in tqdm(self.accessions, total=len(self.accessions), desc="#Loading ESM2 embeddings and contacts"):
            lbl = torch.tensor(int(self.dframe.loc[accession, "binder"]))
            parts = accession.split("_") # e.g. accession 7S8T_5_F_7S8T_5_G
            tgt_id = "_".join(parts[:-1])
            bnd_id = accession

            ### --- SEQ embeddings (pad to fixed lengths) --- ###
            # laod embeddings
            b_emb_seq = np.load(os.path.join(self.seq_bembed, f"{bnd_id}.npy"))     # [Lb, D]
            t_emb_seq = np.load(os.path.join(self.seq_tembed, f"{tgt_id}.npy"))     # [Lt, D]
            b_emb_struct = np.load(os.path.join(self.struct_bembed, f"{bnd_id}.npy"))     # [Lb, D]
            t_emb_struct = np.load(os.path.join(self.struct_tembed, f"{tgt_id}.npy"))     # [Lt, D]

            # quich check whether embedding dimmension is as it suppose to be
            if t_emb_seq.shape[1] != self.embedding_dim_seq or b_emb_seq.shape[1] != self.embedding_dim_seq:
                raise ValueError("Embedding dim mismatch with 'embedding_dim_seq'.")
            if t_emb_struct.shape[1] != self.embedding_dim_struct or b_emb_struct.shape[1] != self.embedding_dim_struct:
                raise ValueError("Embedding dim mismatch with 'embedding_dim'.")
                
            # add -5000 to all the padded target rows
                ### SEQ_embeddings ###
            if t_emb_seq.shape[0] < self.max_tlen_seq:
                t_emb_seq = np.concatenate([t_emb_seq, np.full((self.max_tlen_seq - t_emb_seq.shape[0], t_emb_seq.shape[1]), self.emb_pad, dtype=t_emb_seq.dtype)], axis=0)
            else:
                t_emb_seq = t_emb_seq[: self.max_tlen_seq] # no padding was used
            if b_emb_seq.shape[0] < self.max_blen_seq:
                b_emb_seq = np.concatenate([b_emb_seq, np.full((self.max_blen_seq - b_emb_seq.shape[0], b_emb_seq.shape[1]), self.emb_pad, dtype=b_emb_seq.dtype)], axis=0)
            else:
                b_emb_seq = b_emb_seq[: self.max_blen_seq] # no padding was used

                ### STRUCT_embeddings ###
            if t_emb_struct.shape[0] < self.max_tlen_struct:
                t_emb_struct = np.concatenate([t_emb_struct, np.full((self.max_tlen_struct - t_emb_struct.shape[0], t_emb_struct.shape[1]), self.emb_pad, dtype=t_emb_struct.dtype)], axis=0)
            else:
                t_emb_struct = t_emb_struct[: self.max_tlen_struct] # no padding was used
            if b_emb_struct.shape[0] < self.max_blen_struct:
                b_emb_struct = np.concatenate([b_emb_struct, np.full((self.max_blen_struct - b_emb_struct.shape[0], b_emb_struct.shape[1]), self.emb_pad, dtype=b_emb_struct.dtype)], axis=0)
            else:
                b_emb_struct = b_emb_struct[: self.max_blen_struct] # no padding was used

            self.samples.append((b_emb_seq, t_emb_seq, b_emb_struct, t_emb_struct, lbl))
        
    # ---- Dataset API ----
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        b_emb_seq, t_emb_seq, b_emb_struct, t_emb_struct, lbls = self.samples[idx]
        b_emb_seq, t_emb_seq = torch.from_numpy(b_emb_seq).float(), torch.from_numpy(t_emb_seq).float()
        b_emb_struct, t_emb_struct = torch.from_numpy(b_emb_struct).float(), torch.from_numpy(t_emb_struct).float()
        return b_emb_seq, t_emb_seq, b_emb_struct, t_emb_struct, lbls

    def _get_by_name(self, name):
        # Single item -> return exactly what __getitem__ returns
        if isinstance(name, str):
            return self.__getitem__(self.name_to_row[name])
        
        # Multiple items -> fetch all
        out = [self.__getitem__(self.name_to_row[n]) for n in list(name)]
        b_emb_seq_list, t_emb_seq_list, b_emb_struct_list, t_emb_struct_list, lbl_list = zip(*out)
    
        # Stack embeddings
        b_emb_seq  = torch.stack([torch.as_tensor(x) for x in b_emb_seq_list],  dim=0)  # [B, ...]
        t_emb_seq  = torch.stack([torch.as_tensor(x) for x in t_emb_seq_list],  dim=0)  # [B, ...]
        
        b_emb_struct  = torch.stack([torch.as_tensor(x) for x in b_emb_struct_list],  dim=0)  # [B, ...]
        t_emb_struct  = torch.stack([torch.as_tensor(x) for x in t_emb_struct_list],  dim=0)  # [B, ...]
    
        # Stack labels
        labels = torch.stack(lbl_list)  # [B]
    
        return b_emb_seq, t_emb_seq, b_emb_struct, t_emb_struct, labels

esm2_path_binders = "/work3/s232958/data/meta_analysis/embeddings_esm2_binders"
esm2_path_targets = "/work3/s232958/data/meta_analysis/embeddings_esm2_targets"

## Contact maps paths
esmIF_path_binders = "/work3/s232958/data/meta_analysis/esmif_embeddings_binders"
esmIF_path_targets = "/work3/s232958/data/meta_analysis/esmif_embeddings_targets"

validation_Dataset = CLIP_ESM2_ESMIF(
    interaction_df_shuffled,
    paths=[esm2_path_binders, esm2_path_targets, esmIF_path_binders, esmIF_path_targets],
)

In [None]:
test_dataloader = DataLoader(testing_Dataset, batch_size=10)
non_dimers_dataloader = DataLoader(non_dimers_Dataset, batch_size=10)
validation_dataloader = DataLoader(validation_Dataset, batch_size=10, shuffle=False, drop_last = False)

### Pre-trained model

In [2]:
def create_key_padding_mask(embeddings, padding_value=-5000, offset=10):
    return (embeddings < (padding_value + offset)).all(dim=-1)

def create_mean_of_non_masked(embeddings, padding_mask):
    # Use masked select and mean to compute the mean of non-masked elements
    # embeddings should be of shape (batch_size, seq_len, features)
    seq_embeddings = []
    for i in range(embeddings.shape[0]): # looping over all batch elements
        non_masked_embeddings = embeddings[i][~padding_mask[i]] # shape [num_real_tokens, features]
        if len(non_masked_embeddings) == 0:
            print("You are masking all positions when creating sequence representation")
            sys.exit(1)
        mean_embedding = non_masked_embeddings.mean(dim=0) # sequence is represented by the single vecotr [1152] [features]
        seq_embeddings.append(mean_embedding)
    return torch.stack(seq_embeddings)

class MiniCLIP_ESM2_ESMIF(pl.LightningModule):

    def __init__(self, padding_value = -5000, seq_embed_dimension=1280, struct_embed_dimension=512, num_recycles=2):

        super().__init__()
        self.num_recycles = num_recycles # how many times you iteratively refine embeddings with self- and cross-attention (ALPHA-Fold-style recycling).
        self.padding_value = padding_value
        self.seq_embed_dimension = seq_embed_dimension
        self.struct_embed_dimension = struct_embed_dimension

        self.logit_scale = nn.Parameter(torch.tensor(math.log(1/0.07)))  # ~CLIP init
        self.struct_alpha = nn.Parameter(torch.tensor(0.0))

        # --- SEQUENCE embeddings --- #
        
        self.norm_seq = nn.LayerNorm(self.seq_embed_dimension)  # For residual additions
        
        self.seq_encoder =  nn.TransformerEncoderLayer(
            d_model=self.seq_embed_dimension,
            nhead=8,
            dropout=0.1,
            batch_first=True,
            dim_feedforward=self.seq_embed_dimension
            )

        self.seq_cross_attn = nn.MultiheadAttention(
            embed_dim=self.seq_embed_dimension,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        self.seq_proj = nn.Sequential(
            nn.Linear(self.seq_embed_dimension, 640),
            nn.ReLU(),
            nn.Linear(640, 320),
        )

        # --- STRUCTURE embeddings --- #

        self.norm_struct = nn.LayerNorm(self.seq_embed_dimension)  # For residual additions
        
        self.initial_stuct_proj = nn.Linear(self.struct_embed_dimension, self.seq_embed_dimension)

        self.struct_encoder =  nn.TransformerEncoderLayer(
            d_model=self.seq_embed_dimension,
            nhead=8,
            dropout=0.1,
            batch_first=True,
            dim_feedforward=self.seq_embed_dimension
            )

        self.struct_to_seq_attn = nn.MultiheadAttention(
            embed_dim=self.seq_embed_dimension,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )
        
    def forward(self, pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb, label=None, pep_int_mask=None, prot_int_mask=None, int_prob=None, mem_save=True):
        
        # Key padding masks (True = pad -> to be ignored by attention)
        pep_seq_mask = create_key_padding_mask(embeddings = pep_seq_emb, padding_value = self.padding_value).to(device)   # [B, Lp]
        prot_seq_mask = create_key_padding_mask(embeddings = prot_seq_emb, padding_value = self.padding_value).to(device)    # [B, Lt]
        
        pep_struct_mask = create_key_padding_mask(embeddings = pep_struct_emb, padding_value = self.padding_value).to(device)     # [B, Lp_cm]
        prot_struct_mask = create_key_padding_mask(embeddings = prot_struct_emb, padding_value = self.padding_value).to(device)     # [B, Lt_cm]
    
        # Residual states
        pep_seq_emb = pep_seq_emb.to(device)
        prot_seq_emb = prot_seq_emb.to(device)
        pep_struct_emb = pep_struct_emb.to(device)
        prot_struct_emb = prot_struct_emb.to(device)
    
        for _ in range(self.num_recycles):
            
            # --- Self-attention encoders (sequence streams) ---
            pep_trans_seq = self.seq_encoder(self.norm_seq(pep_seq_emb), src_key_padding_mask=pep_seq_mask)   # [B, Lp, E]
            prot_trans_seq = self.seq_encoder(self.norm_seq(prot_seq_emb), src_key_padding_mask=prot_seq_mask)  # [B, Lt, E]
    
            # --- Self-attention encoders (structure streams) ---
            pep_trans_str = self.struct_encoder(self.norm_struct(self.initial_stuct_proj(pep_struct_emb)), src_key_padding_mask=pep_struct_mask)   # [B, Lp_cm, E]
            prot_trans_str = self.struct_encoder(self.norm_struct(self.initial_stuct_proj(prot_struct_emb)), src_key_padding_mask=prot_struct_mask)  # [B, Lt_cm, E]

            # --- Cross-attend to structures ---
            pep_struct_upd, _ = self.struct_to_seq_attn(query=self.norm_seq(pep_trans_seq), key=self.norm_struct(pep_trans_str), value=self.norm_struct(pep_trans_str), key_padding_mask=pep_struct_mask)
            prot_struct_upd, _ = self.struct_to_seq_attn(query=self.norm_seq(prot_trans_seq), key=self.norm_struct(prot_trans_str), value=self.norm_struct(prot_trans_str), key_padding_mask=prot_struct_mask)

            pep_trans_seq  = pep_trans_seq  + self.struct_alpha.tanh() * pep_struct_upd    # [B, Lp, E]
            prot_trans_seq = prot_trans_seq + self.struct_alpha.tanh() * prot_struct_upd    # [B, Lt, E]
    
            # --- Cross-attend binder vs target ---
            pep_cross,  _  = self.seq_cross_attn(query=self.norm_seq(pep_trans_seq), key=self.norm_seq(prot_trans_seq), value=self.norm_seq(prot_trans_seq), key_padding_mask=prot_seq_mask)
            prot_cross, _  = self.seq_cross_attn(query=self.norm_seq(prot_trans_seq), key=self.norm_seq(pep_trans_seq), value=self.norm_seq(pep_trans_seq), key_padding_mask=pep_seq_mask)
    
            # --- Residual updates ---
            pep_seq_emb = pep_seq_emb + pep_cross
            prot_seq_emb = prot_seq_emb + prot_cross
    
        # Pool (mean over non-masked positions)
        pep_seq_coding   = create_mean_of_non_masked(pep_seq_emb, pep_seq_mask)
        prot_seq_coding  = create_mean_of_non_masked(prot_seq_emb, prot_seq_mask)

        # Projections + L2-normalize
        pep_full   = F.normalize(self.seq_proj(pep_seq_coding),   dim=-1)
        prot_full  = F.normalize(self.seq_proj(prot_seq_coding),  dim=-1)
    
        if mem_save:
            torch.cuda.empty_cache()
    
        scale  = torch.exp(self.logit_scale).clamp(max=100.0)
        logits = scale * (pep_full * prot_full).sum(dim=-1)  # [B]
        
        return logits

    def training_step(self, batch, device):
        pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb, labels = batch

        # loss of predicting partner using peptide
        positive_logits = self.forward(pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb)
        positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits).to(device)) # F.binary_cross_entropy_with_logits does sigmoid transfromation inside, excepts data, labels
        
        # Negative indexes
        rows, cols = torch.triu_indices(pep_seq_emb.size(0), pep_seq_emb.size(0), offset=1) # upper triangle
        
        # loss of predicting peptide using partner
        # negative_logits = self.forward(embedding_pep[rows,:,:], embedding_prot[cols,:,:], contacts_pep[rows,:,:], contacts_prot[cols,:,:], int_prob=0.0)
        negative_logits = self.forward(pep_seq_emb[rows,:,:], prot_seq_emb[cols,:,:], pep_struct_emb[rows,:,:], prot_struct_emb[cols,:,:], int_prob=0.0)
        negative_loss =  F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits).to(device))
        
        loss = (positive_loss + negative_loss) / 2
 
        torch.cuda.empty_cache()
        return loss

    def validation_step_PPint(self, batch, device):
        # Predict on random batches of training batch size
        pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb, labels = batch
        pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb = pep_seq_emb.to(device), prot_seq_emb.to(device), pep_struct_emb.to(device), prot_struct_emb.to(device)
        # contacts_pep, contacts_prot = contacts_pep.to(device), contacts_prot.to(device)
        
        with torch.no_grad():

            positive_logits = self.forward(pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb)
            
            # loss of predicting partner using peptide
            positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits).to(device))
            
            # Negaive indexes
            rows, cols = torch.triu_indices(pep_seq_emb.size(0), pep_seq_emb.size(0), offset=1)
            
            negative_logits = self.forward(pep_seq_emb[rows,:,:], prot_seq_emb[cols,:,:], pep_struct_emb[rows,:,:], prot_struct_emb[cols,:,:], int_prob=0.0)
            negative_loss =  F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits).to(device))

            loss = (positive_loss + negative_loss) / 2

            logit_matrix = torch.zeros((pep_seq_emb.size(0), pep_seq_emb.size(0)),device=self.device)
            logit_matrix[rows, cols] = negative_logits
            logit_matrix[cols, rows] = negative_logits
            
            # Fill diagonal with positive scores
            diag_indices = torch.arange(pep_seq_emb.size(0), device=self.device)
            logit_matrix[diag_indices, diag_indices] = positive_logits.squeeze()

            labels = torch.arange(pep_seq_emb.size(0)).to(self.device)
            peptide_predictions = logit_matrix.argmax(dim=0)
            peptide_ranks = logit_matrix.argsort(dim=0).diag() + 1
            peptide_mrr = (peptide_ranks).float().pow(-1).mean()
            
            # partner_accuracy = partner_predictions.eq(labels).float().mean()
            peptide_accuracy = peptide_predictions.eq(labels).float().mean()
    
            return loss, peptide_accuracy
    
    def validation_step_MetaDataset(self, batch, device):
        pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb, labels = batch
        pep_seq_emb, prot_seq_emb = pep_seq_emb.to(device), prot_seq_emb.to(device) 
        pep_struct_emb, prot_struct_emb = pep_struct_emb.to(device), prot_struct_emb.to(device)
        # contacts_pep, contacts_prot = contacts_pep.to(device), contacts_prot.to(device)
        labels = labels.to(device).float()
    
        with torch.no_grad():
            logits = self.forward(pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb).float()
            loss = F.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1))
            return logits, loss

    def calculate_logit_matrix(self, pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb):
        
        rows, cols = torch.triu_indices(pep_seq_emb.size(0), pep_seq_emb.size(0), offset=1)
        positive_logits = self.forward(pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb)
        negative_logits = self.forward(pep_seq_emb[rows,:,:], prot_seq_emb[cols,:,:], pep_struct_emb[rows,:,:], prot_struct_emb[cols,:,:], int_prob=0.0)
        
        logit_matrix = torch.zeros((pep_seq_emb.size(0),pep_seq_emb.size(0)),device=self.device)
        logit_matrix[rows, cols] = negative_logits
        logit_matrix[cols, rows] = negative_logits
        
        diag_indices = torch.arange(pep_seq_emb.size(0), device=self.device)
        logit_matrix[diag_indices, diag_indices] = positive_logits.squeeze()
        
        return logit_matrix

In [5]:
model = MiniCLIP_ESM2_ESMIF().to("cuda")
path = "/work3/s232958/data/trained/with_structure/2dca0ab0-422d-4567-8970-30ab1504f5b2/9644ac4d-47d5-4c18-a6f4-285950dbfb97_checkpoint_9/9644ac4d-47d5-4c18-a6f4-285950dbfb97_checkpoint_epoch_9.pth"
checkpoint = torch.load(path, weights_only=False, map_location=torch.device('cpu'))
# print(list(checkpoint["model_state_dict"]))
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

alpha_raw = model.struct_alpha.item()
print(alpha_raw)
print(torch.sigmoid(torch.tensor(alpha_raw)))

-0.00883762538433075
tensor(0.4978)


#### Sigmoid(0.0)

In [8]:
def create_key_padding_mask(embeddings, padding_value=-5000, offset=10):
    return (embeddings < (padding_value + offset)).all(dim=-1)

def create_mean_of_non_masked(embeddings, padding_mask):
    # Use masked select and mean to compute the mean of non-masked elements
    # embeddings should be of shape (batch_size, seq_len, features)
    seq_embeddings = []
    for i in range(embeddings.shape[0]): # looping over all batch elements
        non_masked_embeddings = embeddings[i][~padding_mask[i]] # shape [num_real_tokens, features]
        if len(non_masked_embeddings) == 0:
            print("You are masking all positions when creating sequence representation")
            sys.exit(1)
        mean_embedding = non_masked_embeddings.mean(dim=0) # sequence is represented by the single vecotr [1152] [features]
        seq_embeddings.append(mean_embedding)
    return torch.stack(seq_embeddings)

class MiniCLIP_ESM2_ESMIF(pl.LightningModule):

    def __init__(self, padding_value = -5000, seq_embed_dimension=1280, struct_embed_dimension=512, num_recycles=2):

        super().__init__()
        self.num_recycles = num_recycles # how many times you iteratively refine embeddings with self- and cross-attention (ALPHA-Fold-style recycling).
        self.padding_value = padding_value
        self.seq_embed_dimension = seq_embed_dimension
        self.struct_embed_dimension = struct_embed_dimension

        self.logit_scale = nn.Parameter(torch.tensor(math.log(1/0.07)))  # ~CLIP init
        self.struct_alpha = nn.Parameter(torch.tensor(0.0)) # Sigmoid(0) = 0.5

        # --- SEQUENCE embeddings --- #
        
        self.norm_seq = nn.LayerNorm(self.seq_embed_dimension)  # For residual additions
        
        self.seq_encoder =  nn.TransformerEncoderLayer(
            d_model=self.seq_embed_dimension,
            nhead=8,
            dropout=0.1,
            batch_first=True,
            dim_feedforward=self.seq_embed_dimension
            )

        self.seq_cross_attn = nn.MultiheadAttention(
            embed_dim=self.seq_embed_dimension,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        self.seq_proj = nn.Sequential(
            nn.Linear(self.seq_embed_dimension, 640),
            nn.ReLU(),
            nn.Linear(640, 320),
        )

        # --- STRUCTURE embeddings --- #

        self.norm_struct = nn.LayerNorm(self.seq_embed_dimension)  # For residual additions
        
        self.initial_stuct_proj = nn.Linear(self.struct_embed_dimension, self.seq_embed_dimension)

        self.struct_encoder =  nn.TransformerEncoderLayer(
            d_model=self.seq_embed_dimension,
            nhead=8,
            dropout=0.1,
            batch_first=True,
            dim_feedforward=self.seq_embed_dimension
            )

        self.struct_to_seq_attn = nn.MultiheadAttention(
            embed_dim=self.seq_embed_dimension,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )
        
    def forward(self, pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb, label=None, pep_int_mask=None, prot_int_mask=None, int_prob=None, mem_save=True):
        
        # Key padding masks (True = pad -> to be ignored by attention)
        pep_seq_mask = create_key_padding_mask(embeddings = pep_seq_emb, padding_value = self.padding_value).to(device)   # [B, Lp]
        prot_seq_mask = create_key_padding_mask(embeddings = prot_seq_emb, padding_value = self.padding_value).to(device)    # [B, Lt]
        
        pep_struct_mask = create_key_padding_mask(embeddings = pep_struct_emb, padding_value = self.padding_value).to(device)     # [B, Lp_cm]
        prot_struct_mask = create_key_padding_mask(embeddings = prot_struct_emb, padding_value = self.padding_value).to(device)     # [B, Lt_cm]
    
        # Residual states
        pep_seq_emb = pep_seq_emb.to(device)
        prot_seq_emb = prot_seq_emb.to(device)
        pep_struct_emb = pep_struct_emb.to(device)
        prot_struct_emb = prot_struct_emb.to(device)
    
        for _ in range(self.num_recycles):
            
            # --- Self-attention encoders (sequence streams) ---
            pep_trans_seq = self.seq_encoder(self.norm_seq(pep_seq_emb), src_key_padding_mask=pep_seq_mask)   # [B, Lp, E]
            prot_trans_seq = self.seq_encoder(self.norm_seq(prot_seq_emb), src_key_padding_mask=prot_seq_mask)  # [B, Lt, E]
    
            # --- Self-attention encoders (structure streams) ---
            pep_trans_str = self.struct_encoder(self.norm_struct(self.initial_stuct_proj(pep_struct_emb)), src_key_padding_mask=pep_struct_mask)   # [B, Lp_cm, E]
            prot_trans_str = self.struct_encoder(self.norm_struct(self.initial_stuct_proj(prot_struct_emb)), src_key_padding_mask=prot_struct_mask)  # [B, Lt_cm, E]

            # --- Cross-attend to structures ---
            pep_struct_upd, _ = self.struct_to_seq_attn(query=self.norm_seq(pep_trans_seq), key=self.norm_struct(pep_trans_str), value=self.norm_struct(pep_trans_str), key_padding_mask=pep_struct_mask)
            prot_struct_upd, _ = self.struct_to_seq_attn(query=self.norm_seq(prot_trans_seq), key=self.norm_struct(prot_trans_str), value=self.norm_struct(prot_trans_str), key_padding_mask=prot_struct_mask)

            current_alpha = torch.sigmoid(self.struct_alpha)
            pep_trans_emb  = pep_trans_seq  + current_alpha * pep_struct_upd    # [B, Lp, E]
            prot_trans_emb = prot_trans_seq + current_alpha * prot_struct_upd    # [B, Lt, E]
    
            # --- Cross-attend binder vs target ---
            pep_cross,  _  = self.seq_cross_attn(query=self.norm_seq(pep_trans_seq), key=self.norm_seq(prot_trans_seq), value=self.norm_seq(prot_trans_seq), key_padding_mask=prot_seq_mask)
            prot_cross, _  = self.seq_cross_attn(query=self.norm_seq(prot_trans_seq), key=self.norm_seq(pep_trans_seq), value=self.norm_seq(pep_trans_seq), key_padding_mask=pep_seq_mask)
    
            # --- Residual updates ---
            pep_seq_emb = pep_seq_emb + pep_cross
            prot_seq_emb = prot_seq_emb + prot_cross
    
        # Pool (mean over non-masked positions)
        pep_seq_coding   = create_mean_of_non_masked(pep_seq_emb, pep_seq_mask)
        prot_seq_coding  = create_mean_of_non_masked(prot_seq_emb, prot_seq_mask)

        # Projections + L2-normalize
        pep_full   = F.normalize(self.seq_proj(pep_seq_coding),   dim=-1)
        prot_full  = F.normalize(self.seq_proj(prot_seq_coding),  dim=-1)
    
        if mem_save:
            torch.cuda.empty_cache()
    
        scale  = torch.exp(self.logit_scale).clamp(max=100.0)
        logits = scale * (pep_full * prot_full).sum(dim=-1)  # [B]
        
        return logits

    def training_step(self, batch, device):
        pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb, labels = batch

        # loss of predicting partner using peptide
        positive_logits = self.forward(pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb)
        positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits).to(device)) # F.binary_cross_entropy_with_logits does sigmoid transfromation inside, excepts data, labels
        
        # Negative indexes
        rows, cols = torch.triu_indices(pep_seq_emb.size(0), pep_seq_emb.size(0), offset=1) # upper triangle
        
        # loss of predicting peptide using partner
        # negative_logits = self.forward(embedding_pep[rows,:,:], embedding_prot[cols,:,:], contacts_pep[rows,:,:], contacts_prot[cols,:,:], int_prob=0.0)
        negative_logits = self.forward(pep_seq_emb[rows,:,:], prot_seq_emb[cols,:,:], pep_struct_emb[rows,:,:], prot_struct_emb[cols,:,:], int_prob=0.0)
        negative_loss =  F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits).to(device))
        
        loss = (positive_loss + negative_loss) / 2
 
        torch.cuda.empty_cache()
        return loss

    def validation_step_PPint(self, batch, device):
        # Predict on random batches of training batch size
        pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb, labels = batch
        pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb = pep_seq_emb.to(device), prot_seq_emb.to(device), pep_struct_emb.to(device), prot_struct_emb.to(device)
        # contacts_pep, contacts_prot = contacts_pep.to(device), contacts_prot.to(device)
        
        with torch.no_grad():

            positive_logits = self.forward(pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb)
            
            # loss of predicting partner using peptide
            positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits).to(device))
            
            # Negaive indexes
            rows, cols = torch.triu_indices(pep_seq_emb.size(0), pep_seq_emb.size(0), offset=1)
            
            negative_logits = self.forward(pep_seq_emb[rows,:,:], prot_seq_emb[cols,:,:], pep_struct_emb[rows,:,:], prot_struct_emb[cols,:,:], int_prob=0.0)
            negative_loss =  F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits).to(device))

            loss = (positive_loss + negative_loss) / 2

            logit_matrix = torch.zeros((pep_seq_emb.size(0), pep_seq_emb.size(0)),device=self.device)
            logit_matrix[rows, cols] = negative_logits
            logit_matrix[cols, rows] = negative_logits
            
            # Fill diagonal with positive scores
            diag_indices = torch.arange(pep_seq_emb.size(0), device=self.device)
            logit_matrix[diag_indices, diag_indices] = positive_logits.squeeze()

            labels = torch.arange(pep_seq_emb.size(0)).to(self.device)
            peptide_predictions = logit_matrix.argmax(dim=0)
            peptide_ranks = logit_matrix.argsort(dim=0).diag() + 1
            peptide_mrr = (peptide_ranks).float().pow(-1).mean()
            
            # partner_accuracy = partner_predictions.eq(labels).float().mean()
            peptide_accuracy = peptide_predictions.eq(labels).float().mean()
    
            return loss, peptide_accuracy
    
    def validation_step_MetaDataset(self, batch, device):
        pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb, labels = batch
        pep_seq_emb, prot_seq_emb = pep_seq_emb.to(device), prot_seq_emb.to(device) 
        pep_struct_emb, prot_struct_emb = pep_struct_emb.to(device), prot_struct_emb.to(device)
        # contacts_pep, contacts_prot = contacts_pep.to(device), contacts_prot.to(device)
        labels = labels.to(device).float()
    
        with torch.no_grad():
            logits = self.forward(pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb).float()
            loss = F.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1))
            return logits, loss

    def calculate_logit_matrix(self, pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb):
        
        rows, cols = torch.triu_indices(pep_seq_emb.size(0), pep_seq_emb.size(0), offset=1)
        positive_logits = self.forward(pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb)
        negative_logits = self.forward(pep_seq_emb[rows,:,:], prot_seq_emb[cols,:,:], pep_struct_emb[rows,:,:], prot_struct_emb[cols,:,:], int_prob=0.0)
        
        logit_matrix = torch.zeros((pep_seq_emb.size(0),pep_seq_emb.size(0)),device=self.device)
        logit_matrix[rows, cols] = negative_logits
        logit_matrix[cols, rows] = negative_logits
        
        diag_indices = torch.arange(pep_seq_emb.size(0), device=self.device)
        logit_matrix[diag_indices, diag_indices] = positive_logits.squeeze()
        
        return logit_matrix

In [9]:
model = MiniCLIP_ESM2_ESMIF().to("cuda")
path = "/work3/s232958/data/trained/with_structure/ab0444eb-1438-4513-a60f-6fef0340bb95/d26175a6-5994-4146-b8f4-041f1e8fd80a_checkpoint_9/d26175a6-5994-4146-b8f4-041f1e8fd80a_checkpoint_epoch_9.pth"
checkpoint = torch.load(path, weights_only=False, map_location=torch.device('cpu'))
# print(list(checkpoint["model_state_dict"]))
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

alpha_raw = model.struct_alpha.item()
print(alpha_raw)
print(torch.sigmoid(torch.tensor(alpha_raw)))

0.0
tensor(0.5000)


#### test-dataset

In [None]:
interaction_scores_pos = []
interaction_scores_neg = []    

for batch in tqdm(test_dataloader, total=round(len(Df_test)/10), desc="#Iterating through batched data"):
    seq_embedding_pep, seq_embedding_prot, str_embedding_pep, str_embedding_prot, lbls = batch
    
    seq_embedding_pep = seq_embedding_pep.to("cuda")
    seq_embedding_prot = seq_embedding_prot.to("cuda")
    str_embedding_pep = str_embedding_pep.to("cuda")
    str_embedding_prot = str_embedding_prot.to("cuda")

    with torch.no_grad():

        rows, cols = torch.triu_indices(seq_embedding_pep.size(0), seq_embedding_pep.size(0), offset=1)
        positive_logits = model(seq_embedding_pep, seq_embedding_prot, str_embedding_pep, str_embedding_prot)
        # negative_logits = self.forward(embedding_pep[rows,:,:], embedding_prot[cols,:,:], str_embedding_pep[rows,:,:], str_embedding_prot[cols,:,:], int_prob=0.0)
        negative_logits = model(seq_embedding_pep[rows,:,:], seq_embedding_prot[cols,:,:], str_embedding_pep[rows,:,:], str_embedding_prot[cols,:,:], int_prob=0.0)
        
        # print(logit_matrix)
        interaction_scores_pos.append(positive_logits)
        interaction_scores_neg.append(negative_logits)

# Convert list of tensors to single 1D tensors
pos_logits = torch.cat(interaction_scores_pos).detach().cpu().numpy()
neg_logits = torch.cat(interaction_scores_neg).detach().cpu().numpy()
print("Positives:", pos_logits.shape)
print("Negatives:", neg_logits.shape)

In [None]:
plt.figure(figsize=(6, 4))
plt.hist(pos_logits, bins=50, alpha=0.6, label="Positive pairs", density=True)
plt.hist(neg_logits, bins=50, alpha=0.6, label="Negative pairs", density=True)

plt.xlabel("Logit score")
plt.ylabel("Density")
plt.title("Distribution of Positive vs Negative Interaction Scores")
plt.legend()
plt.show()

#### non-dimers datset

In [None]:
interaction_scores = []
for batch in tqdm(non_dimers_dataloader, total=round(len(non_dimers_Df)/10), desc="#Iterating through batched data"):
    seq_embedding_pep, seq_embedding_prot, str_embedding_pep, str_embedding_prot, lbls = batch
    
    seq_embedding_pep = seq_embedding_pep.to("cuda")
    seq_embedding_prot = seq_embedding_prot.to("cuda")
    str_embedding_pep = str_embedding_pep.to("cuda")
    str_embedding_prot = str_embedding_prot.to("cuda")

    with torch.no_grad():

        rows, cols = torch.triu_indices(seq_embedding_pep.size(0), seq_embedding_pep.size(0), offset=1)
        positive_logits = model(seq_embedding_pep, seq_embedding_prot, str_embedding_pep, str_embedding_prot)
        # negative_logits = self.forward(embedding_pep[rows,:,:], embedding_prot[cols,:,:], str_embedding_pep[rows,:,:], str_embedding_prot[cols,:,:], int_prob=0.0)
        negative_logits = model(seq_embedding_pep[rows,:,:], seq_embedding_prot[cols,:,:], str_embedding_pep[rows,:,:], str_embedding_prot[cols,:,:], int_prob=0.0)
        
        # print(logit_matrix)
        interaction_scores_pos.append(positive_logits)
        interaction_scores_neg.append(negative_logits)

# Convert list of tensors to single 1D tensors
pos_logits = torch.cat(interaction_scores_pos).detach().cpu().numpy()
neg_logits = torch.cat(interaction_scores_neg).detach().cpu().numpy()
print("Positives:", pos_logits.shape)
print("Negatives:", neg_logits.shape)

In [None]:
plt.figure(figsize=(6, 4))
plt.hist(pos_logits, bins=50, alpha=0.6, label="Positive pairs", density=True)
plt.hist(neg_logits, bins=50, alpha=0.6, label="Negative pairs", density=True)

plt.xlabel("Logit score")
plt.ylabel("Density")
plt.title("Distribution of Positive vs Negative Interaction Scores")
plt.legend()
plt.show()

#### meta-analysis dataset

In [None]:
# Loading batches
interaction_scores = []

for batch in tqdm(validation_dataloader, total = round(len(interaction_df_shuffled)/10),  desc= "#Iterating through batched data"):
    seq_embedding_pep, seq_embedding_prot, str_embedding_pep, str_embedding_prot, lbls = batch
    
    seq_embedding_pep = seq_embedding_pep.to("cuda")
    seq_embedding_prot = seq_embedding_prot.to("cuda")
    str_embedding_pep = str_embedding_pep.to("cuda")
    str_embedding_prot = str_embedding_prot.to("cuda")

    with torch.no_grad():
        positive_logits = model(seq_embedding_pep, seq_embedding_prot, str_embedding_pep, str_embedding_prot)
        interaction_scores.append(positive_logits.unsqueeze(0))

predicted_interaction_scores = np.concatenate([batch_score.cpu().detach().numpy().reshape(-1,) for batch_score in interaction_scores])
interaction_probabilities = np.concatenate([torch.sigmoid(batch_score[0]).cpu().numpy() for batch_score in interaction_scores])

pos_logits, neg_logits = [], []
for i, row in interaction_df_shuffled.iterrows():
    logit = predicted_interaction_scores[i]
    if row.binder == False:
        neg_logits.append(logit)
    elif row.binder == True:
        pos_logits.append(logit)

In [None]:
plt.figure(figsize=(6, 4))
plt.hist(pos_logits, bins=50, alpha=0.6, label="Positive pairs", density=True)
plt.hist(neg_logits, bins=50, alpha=0.6, label="Negative pairs", density=True)

plt.xlabel("Logit score")
plt.ylabel("Density")
plt.title("Distribution of Positive vs Negative Interaction Scores")
plt.legend()
plt.show()