In [112]:
# Environment check and shared imports for this notebook.
import sys, os
print("Kernel Python:", sys.executable)  # shows which Python executable the kernel runs

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
import math

import torch
import torchvision.models as models

# Pick the current accelerator type (e.g. "cuda") when torch reports one, else "cpu".
# NOTE: `device` is reassigned to a torch.device in the checkpoint-loading cell.
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")
from torch.utils.data import Dataset
from torchvision import datasets

print("Current location:", os.getcwd())  # relative data/model paths below resolve from here

import pytorch_lightning as pl
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.nn.functional as F

# NOTE(review): MiniCLIP_w_transformer_crossattn imported here is shadowed by the class of the
# same name defined in a later cell of this notebook; the checkpoint is loaded into that local one.
from training_utils.model_architectures import MiniCLIP, MiniCLIP_cross_attn, MiniCLIP_w_transformer, MiniCLIP_w_transformer_crossattn

Kernel Python: /zhome/c9/0/203261/miniconda3/envs/esm_gpu/bin/python
Using cpu device
Current location: /zhome/c9/0/203261/DBLXXX_osaul/DBLXXX_osaul/tmp/ona_drafts


In [113]:
def gaussian_kernel(x, sigma):
    """Evaluate the unnormalised Gaussian ``exp(-x**2 / (2 * sigma**2))``.

    Works element-wise on scalars and NumPy arrays and equals 1.0 at ``x == 0``.
    """
    return np.exp(-x**2 / (2 * sigma**2))


def transform_vector(vector, sigma):
    """Smooth a binary interaction vector with a Gaussian fall-off.

    Every non-zero position maps to 1.0. Every zero position maps to
    ``gaussian_kernel(d, sigma)``, where ``d`` is the index distance to the
    nearest position whose value is exactly 1. If the vector contains no 1 at
    all, zero positions map to 0.0 (the kernel's limit at infinite distance)
    instead of raising.

    Args:
        vector: 1-D array-like of interaction flags (typically 0/1).
        sigma: width of the Gaussian kernel, in positions.

    Returns:
        1-D float64 ndarray with the same length as ``vector``.
    """
    vector = np.asarray(vector)
    # Non-zero entries (including the 1s themselves) stay at 1.0.
    transformed_vector = np.ones(vector.shape, dtype=float)
    zero_positions = np.flatnonzero(vector == 0)
    interacting_indices = np.flatnonzero(vector == 1)

    if zero_positions.size == 0:
        return transformed_vector
    if interacting_indices.size == 0:
        # No interacting residue: every zero position is "infinitely" far away.
        transformed_vector[zero_positions] = 0.0
        return transformed_vector

    # (n_zero, n_interacting) distance table; the row minimum is the nearest "1".
    distances = np.abs(zero_positions[:, None] - interacting_indices[None, :])
    transformed_vector[zero_positions] = gaussian_kernel(distances.min(axis=1), sigma)
    return transformed_vector

def safe_shuffle(n, device):
    """Return a random derangement of ``range(n)`` as a tensor on ``device``.

    Permutations are drawn with ``torch.randperm`` and rejected until no
    element remains at its own index (about e draws on average).

    Raises:
        ValueError: if ``n == 1``, for which no derangement exists (the
            rejection loop would otherwise never terminate).
    """
    if n == 1:
        raise ValueError("safe_shuffle: no derangement exists for n == 1")
    identity = torch.arange(n, device=device)  # loop-invariant, built once
    shuffled = torch.randperm(n, device=device)
    while torch.any(shuffled == identity):
        shuffled = torch.randperm(n, device=device)
    return shuffled

def create_key_padding_mask(embeddings, padding_value=0, offset=10):
    """Flag padding positions.

    A position is treated as padding when every value along the last axis is
    strictly below ``padding_value + offset``. The result is a boolean tensor
    with the last axis reduced away; True means "padding".
    """
    threshold = padding_value + offset
    below_threshold = embeddings < threshold
    return below_threshold.all(dim=-1)

def create_mean_of_non_masked(embeddings, padding_mask):
    """Mean-pool each sequence over its non-padding positions.

    Args:
        embeddings: tensor of shape (batch_size, seq_len, features).
        padding_mask: bool tensor of shape (batch_size, seq_len); True marks
            positions to exclude.

    Returns:
        Tensor of shape (batch_size, features). An empty batch yields an
        empty (0, features) tensor.

    Raises:
        ValueError: if some sequence has every position masked, since its
            mean is undefined.
    """
    counts = (~padding_mask).sum(dim=1)
    if bool((counts == 0).any()):
        raise ValueError("You are masking all positions when creating sequence representation")
    # Zero out padded positions (masked_fill avoids 0 * inf = nan), then sum / count.
    summed = embeddings.masked_fill(padding_mask.unsqueeze(-1), 0).sum(dim=1)
    return summed / counts.unsqueeze(-1).to(summed.dtype)

class MiniCLIP_w_transformer_crossattn(pl.LightningModule):
    """Score (peptide, protein) pairs from per-position embeddings.

    Both inputs share one TransformerEncoderLayer, one LayerNorm, one
    MultiheadAttention block (used in both cross-attention directions) and one
    projection head. Padding positions are detected with
    ``create_key_padding_mask`` (all features below ``padding_value + 10``)
    and excluded from attention and mean-pooling.

    NOTE: submodule attribute names form the checkpoint ``state_dict`` keys and
    must not be renamed.
    """

    def __init__(self, padding_value=-5000, embed_dimension=1152, num_recycles=1):
        super().__init__()
        self.num_recycles = num_recycles
        self.padding_value = padding_value
        self.embed_dimension = embed_dimension
        # Learnable temperature, initialised like CLIP (1 / 0.07); clamped to <= 100 in forward.
        self.logit_scale = nn.Parameter(torch.tensor(math.log(1 / 0.07)))
        self.transformerencoder = nn.TransformerEncoderLayer(
            d_model=self.embed_dimension,
            nhead=8,
            dropout=0.1,
            batch_first=True,
            dim_feedforward=self.embed_dimension,
        )
        # Shared LayerNorm applied before the encoder and before cross-attention.
        self.norm = nn.LayerNorm(self.embed_dimension)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=self.embed_dimension,
            num_heads=8,
            dropout=0.1,
            batch_first=True,
        )
        # Projection head (embed_dimension -> 640 -> 320) shared by both inputs.
        self.prot_embedder = nn.Sequential(
            nn.Linear(self.embed_dimension, 640),
            nn.ReLU(),
            nn.Linear(640, 320),
        )

    def forward(self, pep_input, prot_input, pep_int_mask=None, prot_int_mask=None, int_prob=None, mem_save=True):
        """Return one logit per aligned pair ``(pep_input[b], prot_input[b])``.

        Args:
            pep_input, prot_input: embedding tensors indexed as
                (batch, seq_len, embed_dimension) elsewhere in this class;
                padded positions hold ``padding_value``.
            pep_int_mask, prot_int_mask, int_prob: accepted for interface
                compatibility; not used.
            mem_save: if True, call ``torch.cuda.empty_cache()`` after pooling.

        Returns:
            Tensor of shape (batch,): ``scale * cosine_similarity`` of the
            projected, mean-pooled representations.
        """
        pep_mask = create_key_padding_mask(embeddings=pep_input, padding_value=self.padding_value)
        prot_mask = create_key_padding_mask(embeddings=prot_input, padding_value=self.padding_value)

        # Running residual streams, updated once per recycle.
        pep_emb = pep_input.clone()
        prot_emb = prot_input.clone()

        for _ in range(self.num_recycles):
            # Self-attention on each pre-normalised stream (shared encoder weights).
            pep_trans = self.transformerencoder(self.norm(pep_emb), src_key_padding_mask=pep_mask)
            prot_trans = self.transformerencoder(self.norm(prot_emb), src_key_padding_mask=prot_mask)

            # Normalise once; the same tensors feed both cross-attention directions.
            pep_trans_n = self.norm(pep_trans)
            prot_trans_n = self.norm(prot_trans)

            pep_cross, _ = self.cross_attn(
                query=pep_trans_n,
                key=prot_trans_n,
                value=prot_trans_n,
                key_padding_mask=prot_mask,
            )
            prot_cross, _ = self.cross_attn(
                query=prot_trans_n,
                key=pep_trans_n,
                value=pep_trans_n,
                key_padding_mask=pep_mask,
            )

            # Only the cross-attention output is added to the residual stream;
            # the encoder output is used solely as the attention input.
            pep_emb = pep_emb + pep_cross
            prot_emb = prot_emb + prot_cross

        pep_seq_coding = create_mean_of_non_masked(pep_emb, pep_mask)
        prot_seq_coding = create_mean_of_non_masked(prot_emb, prot_mask)

        # Shared projection, then L2-normalise so the dot product is a cosine similarity.
        pep_seq_coding = F.normalize(self.prot_embedder(pep_seq_coding))
        prot_seq_coding = F.normalize(self.prot_embedder(prot_seq_coding))

        if mem_save:
            torch.cuda.empty_cache()

        scale = torch.exp(self.logit_scale).clamp(max=100.0)
        return scale * (pep_seq_coding * prot_seq_coding).sum(dim=-1)

    def training_step(self, batch, device):
        """Return the BCE loss for one batch of aligned pairs.

        ``batch[0]`` / ``batch[1]`` hold peptide / protein embeddings where
        row i of each forms a positive pair. Negatives are the pairs
        (pep_i, prot_j) for every i < j. The loss is the mean of the positive
        and negative BCE-with-logits terms.

        Args:
            batch: sequence whose first two items are the embedding tensors.
            device: device the embeddings are moved to.
        """
        embedding_pep = batch[0].to(device)
        embedding_prot = batch[1].to(device)

        positive_logits = self(embedding_pep, embedding_prot)

        n = embedding_prot.size(0)
        rows, cols = torch.triu_indices(n, n, offset=1)
        negative_logits = self(embedding_pep[rows, :, :], embedding_prot[cols, :, :], int_prob=0.0)

        # Targets are built with *_like, so they already share the logits' device.
        positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits))
        negative_loss = F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits))
        loss = (positive_loss + negative_loss) / 2

        torch.cuda.empty_cache()
        return loss

    def validation_step(self, batch, device):
        """Return ``(loss, top-1 accuracy, top-k accuracy)`` for one batch, without gradients.

        The loss is computed exactly as in ``training_step``. Accuracies use the
        full matrix ``logit_matrix[i, j] = score(pep_i, prot_j)``: for each
        protein column the highest-scoring peptide row is the prediction, and
        row j is the correct label. ``k`` is 3, reduced for batches smaller
        than 3.
        """
        embedding_pep = batch[0].to(device)
        embedding_prot = batch[1].to(device)
        n = embedding_pep.size(0)

        with torch.no_grad():
            positive_logits = self(embedding_pep, embedding_prot)
            positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits))

            rows, cols = torch.triu_indices(n, n, offset=1)
            upper_logits = self(embedding_pep[rows, :, :], embedding_prot[cols, :, :], int_prob=0.0)
            negative_loss = F.binary_cross_entropy_with_logits(upper_logits, torch.zeros_like(upper_logits))
            loss = (positive_loss + negative_loss) / 2

            # Entry [j, i] is score(pep_j, prot_i); it is not the mirror of
            # score(pep_i, prot_j), so the lower triangle needs its own pass.
            lower_logits = self(embedding_pep[cols, :, :], embedding_prot[rows, :, :], int_prob=0.0)
            logit_matrix = self._assemble_logit_matrix(positive_logits, rows, cols, upper_logits, lower_logits)

            labels = torch.arange(n, device=self.device)
            peptide_accuracy = logit_matrix.argmax(dim=0).eq(labels).float().mean()

            # topk raises when k exceeds the dimension size.
            k = min(3, n)
            topk_rows = logit_matrix.topk(k, dim=0).indices  # (k, n)
            peptide_topk_accuracy = (topk_rows == labels.unsqueeze(0)).any(dim=0).float().mean()

        return loss, peptide_accuracy, peptide_topk_accuracy

    def calculate_logit_matrix(self, embedding_pep, embedding_prot):
        """Return the (n, n) matrix ``M[i, j] = score(embedding_pep[i], embedding_prot[j])``.

        The diagonal holds the aligned pairs. Every off-diagonal entry gets its
        own forward pass. Gradients are tracked unless the caller wraps this in
        ``torch.no_grad()``.
        """
        n = embedding_pep.size(0)
        positive_logits = self(embedding_pep, embedding_prot)
        rows, cols = torch.triu_indices(n, n, offset=1)
        upper_logits = self(embedding_pep[rows, :, :], embedding_prot[cols, :, :], int_prob=0.0)
        lower_logits = self(embedding_pep[cols, :, :], embedding_prot[rows, :, :], int_prob=0.0)
        return self._assemble_logit_matrix(positive_logits, rows, cols, upper_logits, lower_logits)

    def _assemble_logit_matrix(self, positive_logits, rows, cols, upper_logits, lower_logits):
        """Scatter diagonal, upper- and lower-triangle scores into an (n, n) matrix on ``self.device``."""
        positive_logits = positive_logits.reshape(-1)
        n = positive_logits.shape[0]
        logit_matrix = torch.zeros((n, n), device=self.device)
        logit_matrix[rows, cols] = upper_logits
        logit_matrix[cols, rows] = lower_logits
        diag_indices = torch.arange(n, device=self.device)
        logit_matrix[diag_indices, diag_indices] = positive_logits
        return logit_matrix

In [114]:
# Restore the epoch-4 checkpoint into the locally defined MiniCLIP_w_transformer_crossattn.
# Tensors are first mapped to CPU so the file loads without a GPU; the model is then
# moved to the first GPU when one is available.
path = '../PPI_PLM/models/CLIP_no_structural_information/a1d0549b-3f90-4ce2-b795-7bca2276cb07_checkpoint_4/a1d0549b-3f90-4ce2-b795-7bca2276cb07_checkpoint_epoch_4.pth'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# NOTE: weights_only=False unpickles arbitrary Python objects -- only load trusted checkpoints.
checkpoint = torch.load(path, map_location=torch.device('cpu'), weights_only=False)

model = MiniCLIP_w_transformer_crossattn()
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()  # inference mode (dropout off); the cell displays the model repr

MiniCLIP_w_transformer_crossattn(
  (transformerencoder): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=1152, out_features=1152, bias=True)
    )
    (linear1): Linear(in_features=1152, out_features=1152, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=1152, out_features=1152, bias=True)
    (norm1): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (norm): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
  (cross_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=1152, out_features=1152, bias=True)
  )
  (prot_embedder): Sequential(
    (0): Linear(in_features=1152, out_features=640, bias=True)
    (1): ReLU()
    (2): Linear(in_features=640, out_features=32

In [122]:
# Load the meta-analysis interaction table, drop the raw id columns and give the
# sequence columns role-based names (A = binder, B = target).
interaction_df = (
    pd.read_csv("../data/meta_analysis/interaction_df_metaanal.csv", index_col=0)
    .drop(columns=["binder_id", "target_id"])
    .rename(columns={"A_seq": "binder_seq", "B_seq": "target_seq"})
)
interaction_df["target_id_mod"].unique()

array(['VirB8', 'FGFR2', 'IL7Ra', 'InsulinR', 'EGFR', 'SARS_CoV2_RBD',
       'Pdl1', 'EGFR_2', 'TrkA', 'IL10Ra', 'LTK', 'Mdm2', 'EGFR_3',
       'sntx', 'sntx_2', 'IL2Ra'], dtype=object)

In [123]:
# Targets df
target_df = interaction_df[["target_id_mod","target_seq"]].rename(columns={"target_seq":"sequence", "target_id_mod" : "ID"})
target_df["seq_len"] = target_df["sequence"].apply(len)
target_df = target_df.drop_duplicates(subset=["ID","sequence"])
# target_df = target_df.set_index("ID")

# Binders df
binder_df = interaction_df[["target_binder_ID","binder_seq"]].rename(columns={"binder_seq":"sequence", "target_binder_ID" : "ID"})
binder_df["seq_len"] = binder_df["sequence"].apply(len)
# binder_df = binder_df.set_index("ID")

target_df

Unnamed: 0,ID,sequence,seq_len
0,VirB8,ANPYISVANIMLQNYVKQREKYNYDTLKEQFTFIKNASTSIVYMQF...,138
1,FGFR2,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,101
18,IL7Ra,DYSFSCYSQLEVNGSQHSLTCAFEDPDVNTTNLEFEICGALVEVKC...,193
25,InsulinR,EVCPGMDIRNNLTRLHELENCSVIEGHLQILLMFKTRPEDFRDLSF...,150
40,EGFR,RKVCNGIGIGEFKDSLSINATNIKHFKNCTSISGDLHILPVAFRGD...,191
42,SARS_CoV2_RBD,TNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFK...,195
81,Pdl1,NAFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDK...,115
87,EGFR_2,LEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYV...,621
88,TrkA,VSFPASVQLHTAVEMHHWCIPFSVDGQPAPSLRWLFNGSVLNETSF...,101
110,IL10Ra,GTELPSPPSVWFEAEFFHHILHWTPIPQQSESTCYEVALLRYGIES...,207


In [124]:
# Interaction Dict: 1-based running number -> (target_id_mod, target_binder_ID) per row.
interaction_Dict = {
    number: pair
    for number, pair in enumerate(
        zip(interaction_df["target_id_mod"], interaction_df["target_binder_ID"]),
        start=1,
    )
}
interaction_Dict

{1: ('VirB8', 'VirB8_1'),
 2: ('FGFR2', 'FGFR2_1'),
 3: ('FGFR2', 'FGFR2_2'),
 4: ('FGFR2', 'FGFR2_3'),
 5: ('FGFR2', 'FGFR2_4'),
 6: ('FGFR2', 'FGFR2_5'),
 7: ('FGFR2', 'FGFR2_6'),
 8: ('FGFR2', 'FGFR2_7'),
 9: ('FGFR2', 'FGFR2_8'),
 10: ('FGFR2', 'FGFR2_9'),
 11: ('FGFR2', 'FGFR2_10'),
 12: ('FGFR2', 'FGFR2_11'),
 13: ('FGFR2', 'FGFR2_12'),
 14: ('FGFR2', 'FGFR2_13'),
 15: ('FGFR2', 'FGFR2_14'),
 16: ('FGFR2', 'FGFR2_15'),
 17: ('FGFR2', 'FGFR2_16'),
 18: ('FGFR2', 'FGFR2_17'),
 19: ('IL7Ra', 'IL7Ra_1'),
 20: ('IL7Ra', 'IL7Ra_2'),
 21: ('IL7Ra', 'IL7Ra_3'),
 22: ('IL7Ra', 'IL7Ra_4'),
 23: ('IL7Ra', 'IL7Ra_5'),
 24: ('IL7Ra', 'IL7Ra_6'),
 25: ('IL7Ra', 'IL7Ra_7'),
 26: ('InsulinR', 'InsulinR_1'),
 27: ('InsulinR', 'InsulinR_2'),
 28: ('InsulinR', 'InsulinR_3'),
 29: ('InsulinR', 'InsulinR_4'),
 30: ('InsulinR', 'InsulinR_5'),
 31: ('InsulinR', 'InsulinR_6'),
 32: ('InsulinR', 'InsulinR_7'),
 33: ('FGFR2', 'FGFR2_18'),
 34: ('FGFR2', 'FGFR2_19'),
 35: ('FGFR2', 'FGFR2_20'),
 36: ('FGFR

In [131]:
class CLIP_dataset(Dataset):
    """In-memory dataset of per-sequence ESM embeddings.

    Row ``i`` of ``self.x`` holds the embedding of ``sequence_df.iloc[i]``,
    read from ``<esm_encoding_paths>/<ID>.npy`` and padded with
    ``padding_value`` up to ``self.max_length`` positions.

    Stored embeddings can be longer than the sequence (the 138-residue VirB8
    target is stored with 140 positions). Padding to ``seq_len.max()`` would
    therefore cut the longest embeddings, so by default the padded length is
    taken from the embeddings themselves.

    Args:
        sequence_df: DataFrame whose ``ID`` column names the ``.npy`` files.
        esm_encoding_paths: folder containing the ``.npy`` files.
        embedding_dim: feature size of each embedding position.
        padding_value: fill value for positions past the end of a sequence.
        max_length: positions per row. ``None`` (default) uses the longest
            loaded embedding, so nothing is truncated; longer embeddings are
            truncated to an explicit value.

    Raises:
        FileNotFoundError: if an ``<ID>.npy`` file is missing.
    """

    def __init__(self, sequence_df, esm_encoding_paths, embedding_dim=1152, padding_value=-5000, max_length=None):
        super().__init__()

        self.sigma = 1
        self.sequence_df = sequence_df
        self.esm_encoding_paths = esm_encoding_paths
        self.padding_value = padding_value

        accessions = self.sequence_df["ID"].tolist()
        iterator = tqdm(accessions, position=0, leave=True, total=len(accessions), desc="# Reading in ESM-embeddings from folder")
        embeddings = [self._load_embedding(accession) for accession in iterator]

        if max_length is None:
            max_length = max((len(embd) for embd in embeddings), default=0)
        self.max_length = max_length

        # Pre-filled with padding, so each row only needs its real positions copied in.
        self.x = torch.full((len(embeddings), self.max_length, embedding_dim), padding_value, dtype=torch.float32)
        for i, embd in enumerate(embeddings):
            kept = embd[: self.max_length]
            self.x[i, : len(kept)] = torch.as_tensor(kept, dtype=torch.float32)

    def _load_embedding(self, accession):
        """Load ``<accession>.npy`` as an (L, D) array, dropping a leading batch axis if present."""
        path = os.path.join(self.esm_encoding_paths, accession + ".npy")
        try:
            embd = np.load(path)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Embedding file {accession}.npy not found.") from err
        # Files observed in this notebook are stored as (1, L, D).
        if embd.ndim == 3:
            embd = embd[0]
        return embd

    def __len__(self):
        """Number of sequences (one row of ``self.x`` each)."""
        return self.x.shape[0]

    def __getitem__(self, index):
        """Return the padded embedding at positional ``index`` (0-based row order of ``sequence_df``)."""
        return self.x[index]

In [132]:
# Requires an <ID>.npy file for every target in target_df.
target_dataset = CLIP_dataset(
    target_df,
    esm_encoding_paths="../data/meta_analysis/targets_embeddings",
    embedding_dim=1152,
)

# Reading in ESM-embeddings from folder:  44%|████▍     | 7/16 [00:00<00:00, 159.02it/s]


FileNotFoundError: Embedding file EGFR_2.npy not found.

In [None]:
# Requires an <ID>.npy file for every binder in binder_df.
mb_dataset = CLIP_dataset(
    binder_df,
    esm_encoding_paths="../data/meta_analysis/binders_embeddings",
    embedding_dim=1152,
)



In [None]:
# Inspect the on-disk shape of one stored target embedding.
esm_encoding_paths = "../data/meta_analysis/targets_embeddings"
ola = np.load(
    os.path.join(esm_encoding_paths, "VirB8" + ".npy")
)
ola.shape

(1, 140, 1152)

In [None]:
# Scratch check: a 3x3 float32 tensor of zeros.
something = torch.zeros((3, 3), dtype=torch.float32)
something

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

In [None]:
# Scratch check: a 3x3x3 tensor of zeros in the default float dtype.
something = torch.full((3, 3, 3), 0.0)
something

tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]])

In [None]:
# class MetaDataset(Dataset):
    
#     def __init__(self, interaction_df, esm_encoding_paths, embedding_dim = 1152, max_len=None,padding_value=-5000, shuffle_embeddings=False,v=True):


#         super(MetaDataset, self).__init__()

#         self.sigma = 1 # Defines Gaussian smoothing strength for interaction masks.
#         self.interaction_df = interaction_df
#         self.padding_value = padding_value
#         self.v = v
        
#         max_target = self.interaction_df.target_seq.max()
#         max_binder = self.interaction_df.binder_seq.max()
#         self.max_length = max_target if max_target > max_binder else max_binder
        
#         self.esm_encoding_paths = esm_encoding_paths
        
#         num_samples = len(self.interaction_df)

#         if self.v:
#             print("## Pre-allocate tensors (targets, binders)")

#         # Two big tensors: one for targets, one for binders
#         self.x_target = torch.full((num_samples, self.max_length, embedding_dim), self.padding_value, dtype=torch.float32)
#         self.x_binder = torch.full((num_samples, self.max_length, embedding_dim), self.padding_value, dtype=torch.float32)

#         # Masks
#         self.mask_target = torch.zeros((num_samples, self.max_length), dtype=torch.float32)
#         self.mask_binder = torch.zeros((num_samples, self.max_length), dtype=torch.float32)

#         # Load embeddings for targets
#         for i in tqdm(range(num_samples), total=num_samples, desc="# Reading target embeddings", leave=True, position=0):

#             target = self.interaction_df.loc[i, "target_id_mod"]
#             embed = np.load(os.path.join(esm_encoding_paths, "targets_embeddings", f"{target}.npy"))

#             # Compute how much padding is needed for this chain.
#             length_to_pad = self.max_length - embed[1]

#             if length_to_pad > 0:
#                 zero_padding = np.ones((length_to_pad, embed.shape[1])) * self.padding_value
#                 padded_array = np.concatenate((embed, zero_padding), axis=0)
#             else:
#                 padded_array = embed

#             self.x_target[i] = torch.tensor(padded_array, dtype=torch.float32)

#         # Load embeddings into the pre-allocated tensor
#         iterator = tqdm(self.interaction_df["PDB_chain_name"],
#                         position=0, 
#                         leave=True,
#                         total=num_samples, 
#                         desc="# Reading in ESM-embeddings from folder")
        
#         # Loop over each chain. Load its .npy embedding file.
#         for i, chain_seq_name in enumerate(iterator):
#             arr = np.load(os.path.join(esm_encoding_paths, chain_seq_name + ".npy"))

#             # Some .npy may have shape (1, L, D). If so, remove the extra dimension → (L, D).
#             if arr.shape[0] == 1:
#                 arr = arr[0]
                
#             # Compute how much padding is needed for this chain.
#             length_to_pad = self.max_length - len(arr)

#             if length_to_pad > 0:
#                 zero_padding = np.ones((length_to_pad, arr.shape[1])) * self.padding_value
#                 padded_array = np.concatenate((arr, zero_padding), axis=0)
#             else:
#                 padded_array = arr[:self.max_length]  # Truncate if longer

#             # Fill the pre-allocated tensor
#             self.x[i] = torch.tensor(padded_array, dtype=torch.float32)
#             # faster : torch.from_numpy(padded_array).to(dtype=torch.float32)

#         if self.v:
#             print("All embeddings loaded into memory as a tensor.")

#         # Create a per-chain binary mask with 1s at interacting residue indices.
#         self.interaction_mask = torch.full((num_samples, self.max_length), 0, dtype=torch.float32)

#          # Load embeddings into the pre-allocated tensor
#         iterator = tqdm(self.interaction_df["interface_residues"],
#                         position=0, 
#                         leave=True,
#                         total=num_samples, 
#                         desc="# Setting up masking interaction mask matrix")

#         for i, interaction_indexes in enumerate(iterator):
#             interaction_mask_entry = np.zeros((1, self.max_length))
#             interaction_mask_entry[:,interaction_indexes] = 1.
#             interaction_mask_entry = interaction_mask_entry.reshape(-1,)
#             # Transforming with gaussian kernel density estiamtor
#             transformed_vector = transform_vector(interaction_mask_entry, self.sigma)
#             transformed_vector = torch.tensor(transformed_vector,dtype=torch.float32)
#             self.interaction_mask[i] = transformed_vector
    
#     # Number of interfaces (not chains)
#     def __len__(self):
#         return len(self.interaction_df.drop_duplicates(subset=["PDB","interface_index"]).index)

#     # Return the pair of chain embeddings and masks for one interface
#     def __getitem__(self, index):
#         pdb_interface = self.obs_index_to_interface_name[index]
#         index1,index2 = self.interaction_df.loc[pdb_interface,"index_num"].values
#         return self.x[index1], self.x[index2], self.interaction_mask[index1], self.interaction_mask[index2]
    
#     # def __getitem__(self, index):
#     #     index1, index2 = self.obs_index_to_indexnum_pair[index]
#     #     return (self.x[index1], self.x[index2], self.interaction_mask[index1], self.interaction_mask[index2])

#     # def _get_observation_weights(self):
#     #     return self.interaction_df.drop_duplicates(subset=["PDB","interface_index"]).loc[:,"observation_weight"].tolist()

#     def _get_item_from_PDB_index(self, PDB_indexes):
#         # Get index pairs for all interactions in the batch
#         index_pairs = self.interaction_df.loc[PDB_indexes, "index_num"].values.reshape(-1, 2)
#         index1 = index_pairs[:, 0]  # First index of each pair (shape: [batch_size])
#         index2 = index_pairs[:, 1]  # Second index of each pair (shape: [batch_size])
        
#         return (
#             self.x[index1],          # Shape: [batch_size, seq_len, embed_dim]
#             self.x[index2],          # Shape: [batch_size, seq_len, embed_dim]
#             self.interaction_mask[index1],  # Shape: [batch_size, seq_len]
#             self.interaction_mask[index2]   # Shape: [batch_size, seq_len]
#         )