In [83]:
import sys, os
os.chdir("/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts")
# print(os.getcwd())

import pandas as pd
import numpy as np
from tqdm import tqdm
import ast
import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler
import pytorch_lightning as pl
from torch.optim import AdamW
torch.manual_seed(0)
# from accelerate import Accelerator

import matplotlib.pyplot as plt
import seaborn as sns

from Levenshtein import distance as Ldistance

import training_utils.dataset_utils as data_utils
import training_utils.partitioning_utils as pat_utils

# ---- Run configuration ----
# NOTE(review): several of these values are not read by the cells visible in
# this notebook (e.g. the loaders further down hard-code batch_size=32 and the
# model is built with its constructor defaults) — confirm which are still live.
memory_verbose = False  # presumably toggles GPU-memory printing — TODO confirm
use_wandb = True # Track the loss in real time (Weights & Biases) instead of printing it
model_save_steps = 1  # presumably the checkpointing interval — TODO confirm the unit
train_frac = 1.0  # fraction of the training indexes to sample (see the commented-out sampling code)
test_frac = 1.0  # fraction of the test indexes to sample (see the commented-out sampling code)

embedding_dimension = 1152 # per-residue embedding width; alternatives tried: 1280 | 960 | 1152
number_of_recycles = 2  # intended value for the model's num_recycles argument
padding_value = -5000  # sentinel stored at padded positions; the padding mask is derived from it

batch_size = 20
learning_rate = 2e-5
EPOCHS = 15

In [84]:
def gaussian_kernel(x, sigma):
    """Return the unnormalised Gaussian ``exp(-x**2 / (2 * sigma**2))``.

    Works element-wise on NumPy arrays as well as on scalars.
    """
    return np.exp(-x**2 / (2 * sigma**2))


def transform_vector(vector, sigma):
    """Smooth a binary interaction vector with a Gaussian distance kernel.

    Positions whose value is 0 receive ``gaussian_kernel(d, sigma)`` where
    ``d`` is the index distance to the closest position whose value is 1.
    Every other position is set to 1.0.

    Args:
        vector: 1-D array-like of interaction flags.
        sigma: Width of the Gaussian kernel, in index units.

    Returns:
        A float ``np.ndarray`` with the same shape as ``vector``.
    """
    vector = np.asarray(vector)
    transformed_vector = np.ones(vector.shape, dtype=float)

    interacting_indices = np.flatnonzero(vector == 1)
    zero_positions = np.flatnonzero(vector == 0)
    if zero_positions.size == 0:
        return transformed_vector

    if interacting_indices.size == 0:
        # No interacting position exists, so the distance is unbounded and the
        # kernel value is its limit, 0 (np.min over an empty array would raise).
        transformed_vector[zero_positions] = 0.0
        return transformed_vector

    # Pairwise |i - j| between every zero position i and every interacting j,
    # reduced to the distance to the closest interacting position.
    distances = np.abs(zero_positions[:, None] - interacting_indices[None, :])
    transformed_vector[zero_positions] = gaussian_kernel(distances.min(axis=1), sigma)
    return transformed_vector

def safe_shuffle(n, device):
    """Return a random permutation of ``range(n)`` that has no fixed point.

    Permutations are re-drawn until no index maps to itself.

    Args:
        n: Number of elements to permute.
        device: Torch device on which the permutation is created.

    Returns:
        A 1-D integer tensor of length ``n`` with ``result[i] != i`` for all i.

    Raises:
        ValueError: If ``n == 1``; a single element cannot be moved, so the
            rejection loop would never terminate.
    """
    if n == 1:
        raise ValueError("safe_shuffle needs n != 1: a single element has no fixed-point-free permutation")

    identity = torch.arange(n, device=device)  # loop-invariant, built once
    shuffled = torch.randperm(n, device=device)
    while torch.any(shuffled == identity):
        shuffled = torch.randperm(n, device=device)
    return shuffled

def create_key_padding_mask(embeddings, padding_value=0, offset=10):
    """Flag the positions whose whole feature vector looks like padding.

    A position counts as padded when every feature along the last axis is
    below ``padding_value + offset``.

    Returns:
        A boolean tensor with the shape of ``embeddings`` minus its last axis;
        ``True`` marks a padded position.
    """
    threshold = padding_value + offset
    below_threshold = embeddings < threshold
    return below_threshold.all(dim=-1)

def create_mean_of_non_masked(embeddings, padding_mask):
    """Mean-pool each sequence over its non-padded positions.

    Args:
        embeddings: Tensor of shape (batch_size, seq_len, features).
        padding_mask: Boolean tensor of shape (batch_size, seq_len); ``True``
            marks a padded position that is excluded from the mean.

    Returns:
        Tensor of shape (batch_size, features), one pooled vector per sequence.

    Raises:
        ValueError: If every position of some sequence is masked, so there is
            nothing to average. Raised instead of terminating the interpreter
            so that the caller gets a traceback and can handle the failure.
    """
    seq_embeddings = []
    for sample_embeddings, sample_mask in zip(embeddings, padding_mask):
        real_tokens = sample_embeddings[~sample_mask]  # [num_real_tokens, features]
        if real_tokens.shape[0] == 0:
            raise ValueError("You are masking all positions when creating sequence representation")
        seq_embeddings.append(real_tokens.mean(dim=0))  # one vector of size [features]
    return torch.stack(seq_embeddings)

class MiniCLIP_w_transformer_crossattn(pl.LightningModule):
    """CLIP-style scorer for (peptide/binder, protein/target) embedding pairs.

    Per recycle, both sequences pass through one shared transformer encoder
    layer and then attend to each other through one shared cross-attention
    module; the cross-attention output is added to the running embeddings.
    The refined embeddings are mean-pooled over non-padded positions,
    projected to 320 dimensions by a shared MLP, L2-normalised and compared
    with a temperature-scaled dot product.
    """

    def __init__(self, padding_value = -5000, embed_dimension=1152, num_recycles=1):
        """Build the model.

        Args:
            padding_value: Sentinel stored at padded positions of the inputs.
            embed_dimension: Width of the per-residue input embeddings.
            num_recycles: Number of self-/cross-attention refinement rounds.
        """
        super().__init__()
        self.num_recycles = num_recycles
        self.padding_value = padding_value
        self.embed_dimension = embed_dimension

        # Learnable log-temperature, initialised as in CLIP (1 / 0.07).
        self.logit_scale = nn.Parameter(torch.tensor(math.log(1/0.07)))

        self.transformerencoder = nn.TransformerEncoderLayer(
            d_model=self.embed_dimension,
            nhead=8,
            dropout=0.1,
            batch_first=True,
            dim_feedforward=self.embed_dimension,
        )

        # Single LayerNorm applied before every attention block.
        self.norm = nn.LayerNorm(self.embed_dimension)

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=self.embed_dimension,
            num_heads=8,
            dropout=0.1,
            batch_first=True,
        )

        # Shared projection head, used for both peptide and protein vectors.
        self.prot_embedder = nn.Sequential(
            nn.Linear(self.embed_dimension, 640),
            nn.ReLU(),
            nn.Linear(640, 320),
        )

    def forward(self, pep_input, prot_input, pep_int_mask=None, prot_int_mask=None, int_prob=None, mem_save=True): # , pep_tokens, prot_tokens
        """Score each (peptide, protein) pair of the batch.

        Args:
            pep_input: Peptide embeddings, (batch, pep_len, embed_dimension),
                padded with ``padding_value``.
            prot_input: Protein embeddings, (batch, prot_len, embed_dimension),
                padded with ``padding_value``.
            pep_int_mask: Accepted for interface compatibility; not used.
            prot_int_mask: Accepted for interface compatibility; not used.
            int_prob: Accepted for interface compatibility; not used.
            mem_save: If True, call ``torch.cuda.empty_cache()`` before scoring.

        Returns:
            Tensor of shape (batch,) holding one logit per pair.
        """
        pep_mask = create_key_padding_mask(embeddings=pep_input, padding_value=self.padding_value)
        prot_mask = create_key_padding_mask(embeddings=prot_input, padding_value=self.padding_value)

        # Residual streams start from the raw input embeddings.
        pep_emb = pep_input.clone()
        prot_emb = prot_input.clone()

        for _ in range(self.num_recycles):
            # Self-attention: the same encoder layer encodes both sequences.
            pep_trans = self.transformerencoder(self.norm(pep_emb), src_key_padding_mask=pep_mask)
            prot_trans = self.transformerencoder(self.norm(prot_emb), src_key_padding_mask=prot_mask)

            # LayerNorm is deterministic, so each encoder output is normalised
            # once and reused as query, key and value.
            pep_normed = self.norm(pep_trans)
            prot_normed = self.norm(prot_trans)

            # Cross-attention in both directions with shared weights.
            pep_cross, _ = self.cross_attn(query=pep_normed, key=prot_normed, value=prot_normed, key_padding_mask=prot_mask)
            prot_cross, _ = self.cross_attn(query=prot_normed, key=pep_normed, value=pep_normed, key_padding_mask=pep_mask)

            # Residual update of the running embeddings.
            pep_emb = pep_emb + pep_cross
            prot_emb = prot_emb + prot_cross

        pep_seq_coding = create_mean_of_non_masked(pep_emb, pep_mask)
        prot_seq_coding = create_mean_of_non_masked(prot_emb, prot_mask)

        # Project both pooled vectors with the shared head and L2-normalise.
        pep_seq_coding = F.normalize(self.prot_embedder(pep_seq_coding))
        prot_seq_coding = F.normalize(self.prot_embedder(prot_seq_coding))

        if mem_save:
            torch.cuda.empty_cache()

        # Temperature-scaled dot product; the scale is capped at 100.
        scale = torch.exp(self.logit_scale).clamp(max=100.0)
        logits = scale * (pep_seq_coding * prot_seq_coding).sum(dim=-1)

        return logits

    def training_step(self, batch, device):
        """Compute the training loss for one batch.

        Only the true pairs of the batch are scored: the in-batch negative
        term is disabled (kept below as commented-out code), so the loss is
        the binary cross-entropy of the positive logits against all-ones
        targets.

        Args:
            batch: Sequence whose first two items are the peptide and protein
                embedding tensors.
            device: Device the embeddings are moved to before scoring.

        Returns:
            Scalar loss tensor.
        """
        embedding_pep = batch[0].to(device)
        embedding_prot = batch[1].to(device)

        positive_logits = self(embedding_pep, embedding_prot)

        # Disabled in-batch negatives:
        # rows, cols = torch.triu_indices(embedding_prot.size(0), embedding_prot.size(0), offset=1)
        # negative_logits = self(embedding_pep[rows,:,:], embedding_prot[cols,:,:], int_prob=0.0)
        # negative_loss = F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits))

        # With the negative term disabled there is nothing to average with, so
        # the positive loss is returned as is (averaging it with the undefined
        # names positive_loss / negative_loss raised a NameError).
        loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits))

        torch.cuda.empty_cache()
        return loss

    def validation_step(self, batch, device):
        """Evaluate one batch with in-batch negatives, without gradients.

        Every pair (peptide i, protein j) with i < j is scored as a negative;
        the true pairs (i, i) are the positives.

        Args:
            batch: Sequence whose first two items are the peptide and protein
                embedding tensors.
            device: Device the embeddings are moved to before scoring.

        Returns:
            Tuple ``(loss, peptide_accuracy, peptide_topk_accuracy)``: the mean
            of the positive and negative BCE losses, the fraction of proteins
            whose highest-scoring peptide is the true partner, and the fraction
            whose true partner is among the top 3 (top ``batch`` when the batch
            holds fewer than 3 pairs).
        """
        embedding_pep = batch[0].to(device)
        embedding_prot = batch[1].to(device)
        n_pairs = embedding_prot.size(0)

        with torch.no_grad():
            positive_logits = self(embedding_pep, embedding_prot)
            positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits))

            # Index pairs (i, j), i < j, used as negatives.
            rows, cols = torch.triu_indices(n_pairs, n_pairs, offset=1)
            negative_logits = self(embedding_pep[rows, :, :],
                                   embedding_prot[cols, :, :],
                                   int_prob=0.0)
            negative_loss = F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits))

            loss = (positive_loss + negative_loss) / 2

            logit_matrix = self._assemble_logit_matrix(positive_logits, negative_logits, rows, cols)

            # Column j of the matrix ranks all peptides for protein j.
            labels = torch.arange(n_pairs, device=logit_matrix.device)
            peptide_predictions = logit_matrix.argmax(dim=0)
            peptide_accuracy = peptide_predictions.eq(labels).float().mean()

            k = min(3, n_pairs)  # topk fails when k exceeds the batch size
            topk_indices = logit_matrix.topk(k, dim=0).indices
            peptide_topk_accuracy = (topk_indices == labels.reshape(1, -1)).any(dim=0).sum() / n_pairs

        return loss, peptide_accuracy, peptide_topk_accuracy

    def calculate_logit_matrix(self,embedding_pep,embedding_prot):
        """Return the (batch, batch) score matrix for a batch of true pairs.

        The diagonal holds the logits of the true pairs. Pairs (i, j) with
        i < j are scored once and the value is mirrored to (j, i).

        Args:
            embedding_pep: Peptide embeddings, already on the model's device.
            embedding_prot: Protein embeddings, already on the model's device.

        Returns:
            Tensor of shape (batch, batch).
        """
        positive_logits = self(embedding_pep, embedding_prot)

        n_pairs = embedding_pep.size(0)
        rows, cols = torch.triu_indices(n_pairs, n_pairs, offset=1)

        negative_logits = self(embedding_pep[rows, :, :],
                               embedding_prot[cols, :, :],
                               int_prob=0.0)

        return self._assemble_logit_matrix(positive_logits, negative_logits, rows, cols)

    def _assemble_logit_matrix(self, positive_logits, negative_logits, rows, cols):
        """Place positive logits on the diagonal and negatives off-diagonal.

        Only the pairs (rows[k], cols[k]) were actually scored; the transposed
        entries reuse the same value, which makes the matrix symmetric.
        """
        n_pairs = positive_logits.size(0)
        target_device = positive_logits.device
        logit_matrix = torch.zeros((n_pairs, n_pairs), device=target_device)
        logit_matrix[rows, cols] = negative_logits
        logit_matrix[cols, rows] = negative_logits

        diag_indices = torch.arange(n_pairs, device=target_device)
        logit_matrix[diag_indices, diag_indices] = positive_logits.squeeze()
        return logit_matrix

In [85]:
# Restore a previously trained MiniCLIP model from its epoch-4 checkpoint.
path = '../PPI_PLM/models/CLIP_no_structural_information/a1d0549b-3f90-4ce2-b795-7bca2276cb07_checkpoint_4/a1d0549b-3f90-4ce2-b795-7bca2276cb07_checkpoint_epoch_4.pth'
# SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
# objects — only load checkpoint files from a trusted source.
checkpoint = torch.load(path, weights_only=False, map_location=torch.device('cpu'))
# print(list(checkpoint["model_state_dict"]))
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# NOTE(review): the model is built with its constructor defaults
# (num_recycles=1), while the configuration cell sets number_of_recycles = 2.
# num_recycles adds no parameters, so load_state_dict cannot detect a mismatch
# — confirm which value the checkpoint was trained with.
model = MiniCLIP_w_transformer_crossattn()
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
# Training mode keeps dropout active; presumably the checkpoint is fine-tuned
# further rather than only evaluated — TODO confirm.
model.train()

MiniCLIP_w_transformer_crossattn(
  (transformerencoder): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=1152, out_features=1152, bias=True)
    )
    (linear1): Linear(in_features=1152, out_features=1152, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=1152, out_features=1152, bias=True)
    (norm1): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (norm): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
  (cross_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=1152, out_features=1152, bias=True)
  )
  (prot_embedder): Sequential(
    (0): Linear(in_features=1152, out_features=640, bias=True)
    (1): ReLU()
    (2): Linear(in_features=640, out_features=32

In [86]:
## Output path: directory for the artefacts of this run
## (presumably the trained model checkpoints — TODO confirm)
trained_model_dir = "/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts"

## Embeddings paths: folders with the precomputed binder and target embeddings,
## relative to the working directory set at the top of the notebook
binders_embeddings = "../data/meta_analysis/binders_embeddings"
targets_embeddings = "../data/meta_analysis/targets_embeddings"

# ## Training variables
# runID = uuid.uuid4()

# def print_mem_consumption():
#     # 1. Total memory available on the GPU (device 0)
#     t = torch.cuda.get_device_properties(0).total_memory
#     # 2. How much memory PyTorch has *reserved* from CUDA
#     r = torch.cuda.memory_reserved(0)
#     # 3. How much of that reserved memory is actually *used* by tensors
#     a = torch.cuda.memory_allocated(0)
#     # 4. Reserved but not currently allocated (so “free inside PyTorch’s pool”)
#     f = r - a

#     print("Total memory: ", t/1e9)      # total VRAM in GB
#     print("Reserved memory: ", r/1e9)   # PyTorch’s reserved pool in GB
#     print("Allocated memory: ", a//1e9) # actually in use (integer division)
#     print("Free memory: ", f/1e9)       # slack in the reserved pool in GB

In [87]:
### Loading the dataset
# One row per (binder, target) observation. The original ID columns are
# dropped and the A/B sequence columns are renamed to binder/target.
interaction_df = pd.read_csv("../data/meta_analysis/interaction_df_metaanal.csv", index_col = 0).drop(columns = ["binder_id", "target_id"]).rename(columns={
    "A_seq" : "binder_seq",
    "B_seq" : "target_seq"
})
all_targets = interaction_df.target_id_mod.unique()
# Overall number of binders vs. non-binders.
binder_nonbinder = interaction_df.binder.value_counts()
# Number of observations per (target, binder flag); keys are 2-tuples.
target_binder_nonbinder_Dict = dict(interaction_df.groupby("target_id_mod")["binder"].value_counts())
# Largest (target, binder flag) groups first.
sorted_items = sorted(target_binder_nonbinder_Dict.items(), key=lambda kv: kv[1], reverse=True)
sorted_items[:10]

[(('FGFR2', False), 1930),
 (('EGFR_2', False), 286),
 (('FGFR2', True), 193),
 (('IL7Ra', False), 133),
 (('EGFR', False), 120),
 (('TrkA', False), 119),
 (('InsulinR', False), 97),
 (('SARS_CoV2_RBD', False), 90),
 (('VirB8', False), 90),
 (('Pdl1', False), 83)]

In [88]:
# %%
# Annotating each observation with a weight corresponding to whether it is considered a binder or not.
# Each class (binder / non-binder) receives the same total weight, 1 / N_bins,
# shared equally among its observations, so all weights sum to 1.
N_bins = len(interaction_df["binder"].value_counts())  # number of distinct classes
pr_class_uniform_weight = 1 / N_bins
# Weight of a single observation of each class: class share / class size.
pr_class_weight_informed_with_size_of_bins = pr_class_uniform_weight  / interaction_df["binder"].value_counts()
pr_class_weight_informed_with_size_of_bins = pr_class_weight_informed_with_size_of_bins.to_dict()
interaction_df["observation_weight"] = interaction_df.binder.apply(lambda x: pr_class_weight_informed_with_size_of_bins[x])
interaction_df

Unnamed: 0,binder_chain,target_chains,binder,binder_seq,target_seq,target_id_mod,target_binder_ID,observation_weight
0,A,"[""B""]",False,LDFIVFAGPEKAIKFYKEMAKRNLEVKIWIDGDWAVVQVK,ANPYISVANIMLQNYVKQREKYNYDTLKEQFTFIKNASTSIVYMQF...,VirB8,VirB8_1,0.000159
1,A,"[""B""]",False,SEQDETMHRIVRSVIQHAYKHNDEMAEYFAQNAAEIYKEQNKSEEA...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_1,0.000159
2,A,"[""B""]",False,DYKQLKKHATKLLELAKKDPSSKRDLLRTAASYANKVLFEDSDPRA...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_2,0.000159
3,A,"[""B""]",False,DEKEELERRANRVAFLAIQIQNEEYHRILAELYVQFMKAAENNDTE...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_3,0.000159
4,A,"[""B""]",False,PDNKEKLMSIAVQLILRINEAARSEEQWRYANRAAFAAVEASSGSD...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_4,0.000159
...,...,...,...,...,...,...,...,...
3527,A,"[""B""]",False,DLRKYAAELVDRLAEKYNLDSDQYNALVRLASELVWQGKSKEEIEK...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_62,0.000159
3528,A,"[""B""]",False,SKEEIKKEAEELIEELKKKGYNLPLRILEFALKEIEETNSEKYYEQ...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_63,0.000159
3529,A,"[""B""]",False,SPEYKKFLELIKEAEAARKAGDLDKAKELLEKALELAKKMKAKSLI...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_64,0.000159
3530,A,"[""B""]",False,DPLLAYKLLKLSQKALEKAYAEDRERAEELLEEAEAALRSLGDEAG...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_65,0.000159


In [89]:
# Size of the train/validation split obtained by holding out whole targets.
# The counters deliberately avoid the names `sum` / `sum_all`: assigning to
# `sum` shadows the builtin for every later cell of the notebook session.
# ["IL7Ra", "sntx_2", "sntx", "VirB8","Mdm2", "IL10Ra", "IL2Ra"] validation set
validation_targets = {"IL7Ra", "sntx_2", "sntx", "VirB8", "Mdm2", "IL10Ra", "IL2Ra"}

n_total = 0       # all observations
n_validation = 0  # observations belonging to a held-out target
for key, n_obs in target_binder_nonbinder_Dict.items():
    n_total += n_obs
    if key[0] in validation_targets:  # key is (target ID, binder flag)
        n_validation += n_obs

n_train = n_total - n_validation
print(n_train, n_validation)
print(n_validation / n_train)  # validation size relative to the training size

3029 503
0.16606140640475403


Validation set (held-out targets): ["IL7Ra", "sntx_2", "sntx", "VirB8", "Mdm2", "IL10Ra", "IL2Ra"]

In [90]:
# Targets df: one row per distinct (target ID, sequence), indexed by target ID.
# NOTE(review): the index is unique only if every target ID has exactly one
# sequence — confirm.
target_df = interaction_df[["target_id_mod","target_seq"]].rename(columns={"target_seq":"sequence", "target_id_mod" : "ID"})
target_df["seq_len"] = target_df["sequence"].apply(len)
target_df = target_df.drop_duplicates(subset=["ID","sequence"])
target_df = target_df.set_index("ID")

# Binders df: one row per observation, indexed by the binder ID; `label` is
# the binder / non-binder flag.
binder_df = interaction_df[["target_binder_ID","binder_seq", "binder"]].rename(columns={"binder_seq":"sequence", "target_binder_ID" : "ID", "binder" : "label"})
binder_df["seq_len"] = binder_df["sequence"].apply(len)
binder_df = binder_df.set_index("ID")

# target_df

# Interaction Dict: running number (starting at 1) -> (target ID, binder ID),
# in the row order of interaction_df.
interaction_Dict = dict(enumerate(zip(interaction_df["target_id_mod"], interaction_df["target_binder_ID"]), start=1))
# interaction_Dict

In [91]:
# Hold out whole targets for validation; all remaining targets are training data.
held_out_targets = ["IL7Ra", "sntx_2", "sntx", "VirB8","Mdm2", "IL10Ra", "IL2Ra"]

target_df_val = target_df.loc[held_out_targets]
target_df_train = target_df.loc[target_df.index.difference(held_out_targets)]
# print(len(target_df_val), len(target_df_train))

# A binder goes to validation when its ID starts with a held-out target name.
# NOTE(review): this is a plain prefix match, so a training target whose name
# merely begins with a held-out name would be selected too — confirm none exists.
idx = binder_df.index.astype(str)
held_out_prefixes = tuple(held_out_targets)
mask = pd.Series(idx).str.startswith(held_out_prefixes).to_numpy()

binder_df_val = binder_df[mask]      # binders of the held-out targets
binder_df_train = binder_df[~mask]   # everything else
print(len(binder_df_train), len(binder_df_val))

3029 503


In [92]:
# Sanity check: IL7Ra is a held-out target, so no training binder should match.
binder_df_train[binder_df_train.index.str.startswith("IL7Ra")]

Unnamed: 0_level_0,sequence,label,seq_len
ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1


In [93]:
# Sanity check: the binders of the held-out target IL7Ra are in the validation split.
binder_df_val[binder_df_val.index.str.startswith("IL7Ra")]

Unnamed: 0_level_0,sequence,label,seq_len
ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
IL7Ra_1,SVQLEALEILIKELKKELEKAKKELEKASPEEKKKLEEKAKKLEEL...,False,84
IL7Ra_2,EEKDKYIEEAQYVAVEALEYIKDGTAEEGEKAKEEAEKKIRELLTK...,False,74
IL7Ra_3,DEEEKKKLAEEAIEAVEKGDLEKAKELLKKLAEAAKTEEEAEKWLS...,True,78
IL7Ra_4,EEKEKAKEIIDRAVKEAKKEAEKEDEETKKKTLEIIERAELVVKAD...,False,85
IL7Ra_5,VEEDLKKALKALKEGNKIEAAEHLLAARVEALLKGDEETAEKVEEA...,False,65
...,...,...,...
IL7Ra_167,GITIFINADDPTVAELAKSANPKHAHFHPAGAVWIELDDPTASKIV...,False,62
IL7Ra_168,SREDRAARIVLEALRQMIKNVEDPKDARLIYLKAEQAKKIVDDPTV...,False,63
IL7Ra_169,DESQKETLTKLIKLAVKAIMNNDPDTAKKVVDKLRKVASEANDHMA...,False,61
IL7Ra_170,DKEELKKKIHKLAQIVARHHREDDSTVNDVAVIVLKLLRQDTEEAL...,False,62


In [94]:
def _build_embedding_dataset(sequence_df, embeddings_dir):
    """Wrap ``sequence_df`` in a dataset reading 1152-d embeddings from ``embeddings_dir``."""
    return data_utils.CLIP_meta_analysis_dataset(
        sequence_df,
        esm_encoding_paths=embeddings_dir,
        embedding_dim=1152,
    )


_targets_dir = "../data/meta_analysis/targets_embeddings"
_binders_dir = "../data/meta_analysis/binders_embeddings"

# Validation split first, then the training split.
# NOTE(review): the recorded output of this cell ends with a CPU allocation
# failure of 3,489,408,000 bytes — presumably the training-binder embeddings
# held as one dense float32 array; confirm and consider loading lazily.
targets_dataset_val = _build_embedding_dataset(target_df_val, _targets_dir)
binders_dataset_val = _build_embedding_dataset(binder_df_val, _binders_dir)

targets_dataset_train = _build_embedding_dataset(target_df_train, _targets_dir)
binders_dataset_train = _build_embedding_dataset(binder_df_train, _binders_dir)

# %%
### Creating the CLIP datasets
# train_dataset = data_utils.CLIP_dataset(train_df,
#                        esm_encoding_paths=path_to_esm_embeddings,
#                       embedding_dim=embedding_dimension,
#                     padding_value=padding_value,
#                     max_len=max_sequence_length,
#                         # shuffle_embeddings=True
#                             )

# test_dataset = data_utils.CLIP_dataset(test_df,
#                        esm_encoding_paths=path_to_esm_embeddings,
#                       embedding_dim=embedding_dimension,
#                       padding_value=padding_value,
#                         max_len=max_sequence_length)
# test_dataset._get_item_from_PDB_index(unique_test_indexes_nondimer)[0].shape

class PairDataset(Dataset):
    """Pairs every binder embedding with the embedding of its target.

    Binder IDs are expected to have the form ``"<target ID>_<suffix>"``, e.g.
    ``"VirB8_1"`` or ``"SARS_CoV2_RBD_5"``; the target ID is everything before
    the last underscore.
    """

    def __init__(self, targets_dataset, binders_dataset, binder_df):
        """Store the lookup objects.

        Args:
            targets_dataset: Container returning a target embedding for a
                target-ID string.
            binders_dataset: Container returning a binder embedding for a
                binder-ID string; it also defines the dataset length.
            binder_df: DataFrame indexed by binder ID with a ``label`` column.
        """
        self.targets = targets_dataset
        self.binders = binders_dataset
        self.binder_df = binder_df

    def __len__(self):
        """Return the number of binders."""
        return len(self.binders)

    def __getitem__(self, bname):
        """Return ``(binder_emb, target_emb, label)`` for one binder.

        Args:
            bname: Binder ID string, or an integer position into
                ``binder_df`` (the kind of index DataLoader's default samplers
                produce), which is translated to the binder ID at that row.
                Assumes ``binder_df`` and ``binders_dataset`` hold the same
                binders — TODO confirm.

        Returns:
            Tuple of the binder embedding, the target embedding and the
            integer label.
        """
        if not isinstance(bname, str):
            bname = self.binder_df.index[bname]

        # Strip only the trailing suffix: target IDs may contain underscores
        # themselves, so taking a fixed number of leading parts mis-parses IDs
        # such as "SARS_CoV2_RBD_5".
        target_name = bname.rsplit("_", 1)[0]

        binder_emb = self.binders[bname]
        target_emb = self.targets[target_name]
        label = int(self.binder_df["label"].loc[bname])

        return binder_emb, target_emb, label
    
# Pair datasets for the two splits; each item is (binder_emb, target_emb, label).
training_dataset = PairDataset(targets_dataset_train, binders_dataset_train, binder_df_train)
validation_dataset = PairDataset(targets_dataset_val, binders_dataset_val, binder_df_val)

# Reading in ESM-embeddings from folder: 100%|██████████| 7/7 [00:00<00:00, 519.61it/s]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.sequence_df["index_num"] = np.arange(len(self.sequence_df))
# Reading in ESM-embeddings from folder: 100%|██████████| 503/503 [00:00<00:00, 603.95it/s]
# Reading in ESM-embeddings from folder: 100%|██████████| 9/9 [00:00<00:00, 229.02it/s]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.sequence_df["index_num"] = np.arange(len(self.sequence_df))


RuntimeError: [enforce fail at alloc_cpu.cpp:117] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 3489408000 bytes. Error code 12 (Cannot allocate memory)

In [None]:
# Spot check: fetch one validation pair by binder ID and keep only its label.
__, __, label = validation_dataset["IL7Ra_4"]
# NOTE: a bare expression that is not the last statement of a notebook cell is
# not displayed; wrap it in print() to see the value.
label

# Split sizes (number of binders per split).
print(len(validation_dataset))
print(len(training_dataset))

503
3029


In [None]:
# def collate_pad(batch, pad_value=-5000.0):
#     """
#     Pads variable-length sequences in the batch to the max length (separately for binders and targets).
#     Expects items like (binder_emb [Lb,D], target_emb [Lt,D], bname, tname).
#     Returns:
#       binders: [B, Lb_max, D], targets: [B, Lt_max, D], names: list of (bname, tname)
#     """
#     binders, targets, bnames, tnames = [], [], [], []
#     for binder_emb, target_emb, bname, tname in batch:
#         binders.append(binder_emb)  # [L,D]
#         targets.append(target_emb)  # [L,D]
#         bnames.append(bname)
#         tnames.append(tname)

#     # pad along sequence length (dim=0), keep embed dim
#     binders = pad_sequence(binders, batch_first=True, padding_value=pad_value)  # [B, Lb_max, D]
#     targets = pad_sequence(targets, batch_first=True, padding_value=pad_value)  # [B, Lt_max, D]
#     return binders, targets, bnames, tnames

# # ---- Build datasets and loaders ----
# train_dataset = PairDataset(targets_dataset_train, binders_dataset_train, list_binders_train)
# val_dataset   = PairDataset(targets_dataset_val,   binders_dataset_val,   list_binders_val)

# Batch loaders over the pair datasets.
# NOTE(review): batch_size is hard-coded to 32 although the configuration cell
# defines batch_size = 20, and the training loader does not shuffle — confirm
# both are intended.
# NOTE(review): without a custom sampler the loader asks the dataset for
# integer indices — confirm PairDataset.__getitem__ accepts them.
train_loader = DataLoader(training_dataset, batch_size=32)
val_loader = DataLoader(validation_dataset, batch_size=32)

train_loader


<torch.utils.data.dataloader.DataLoader at 0x7f875a7cf2b0>

Unnamed: 0,binder_chain,target_chains,binder,binder_seq,target_seq,target_id_mod,target_binder_ID,observation_weight
0,A,"[""B""]",False,LDFIVFAGPEKAIKFYKEMAKRNLEVKIWIDGDWAVVQVK,ANPYISVANIMLQNYVKQREKYNYDTLKEQFTFIKNASTSIVYMQF...,VirB8,VirB8_1,0.000159
1,A,"[""B""]",False,SEQDETMHRIVRSVIQHAYKHNDEMAEYFAQNAAEIYKEQNKSEEA...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_1,0.000159
2,A,"[""B""]",False,DYKQLKKHATKLLELAKKDPSSKRDLLRTAASYANKVLFEDSDPRA...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_2,0.000159
3,A,"[""B""]",False,DEKEELERRANRVAFLAIQIQNEEYHRILAELYVQFMKAAENNDTE...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_3,0.000159
4,A,"[""B""]",False,PDNKEKLMSIAVQLILRINEAARSEEQWRYANRAAFAAVEASSGSD...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_4,0.000159
...,...,...,...,...,...,...,...,...
3527,A,"[""B""]",False,DLRKYAAELVDRLAEKYNLDSDQYNALVRLASELVWQGKSKEEIEK...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_62,0.000159
3528,A,"[""B""]",False,SKEEIKKEAEELIEELKKKGYNLPLRILEFALKEIEETNSEKYYEQ...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_63,0.000159
3529,A,"[""B""]",False,SPEYKKFLELIKEAEAARKAGDLDKAKELLEKALELAKKMKAKSLI...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_64,0.000159
3530,A,"[""B""]",False,DPLLAYKLLKLSQKALEKAYAEDRERAEELLEEAEAALRSLGDEAG...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_65,0.000159


In [None]:
# %%
# Variables for making the dataset
# max_sequence_length = interaction_df["seq. len"].max()
# train_df = disordered_interfaces_df.loc[random.sample(train_indexes,k=int(len(train_indexes)*train_frac)),:].copy() # Making a sample of the training dataset
# # train_df = disordered_interfaces_df.loc[train_indexes,:].copy() # Using the entire dataset
# train_df = train_df[train_df.dimer == False] # Removing Dimers from the dataset (TODO: remove)

# test_df = disordered_interfaces_df.loc[random.sample(test_indexes,k=int(len(test_indexes)*test_frac)),:].copy() # Making a sample of the testing dataset
# # test_df = disordered_interfaces_df.loc[test_indexes,:].copy()
# test_df = test_df[test_df.dimer == False]

# # test_nondimer_df = test_df[(test_df.dimer == False) & (test_df["mean disorder interface"] > 0.0)] # Here i can filter for different disorder scores in the interface
# unique_test_indexes_nondimer = test_df.index.unique().tolist()

In [None]:
# Running number (starting at 1) -> (target ID, binder ID), in the row order of
# interaction_df.
# NOTE(review): identical to the interaction_Dict assignment made earlier in
# the notebook; one of the two can go.
interaction_Dict = dict(enumerate(zip(interaction_df["target_id_mod"], interaction_df["target_binder_ID"]), start=1))

In [None]:
random.seed(0)
# ### Creating the DataLoaders
# NOTE(review): `train_dataset` is not assigned in the visible part of the
# notebook (its construction only appears in commented-out code); the pair
# dataset built above is named `training_dataset` and defines no
# `_get_observation_weights` method — confirm which object is meant.
train_weights = train_dataset._get_observation_weights()

In [None]:
# %%
random.seed(0)
# ### Creating the DataLoaders
# NOTE(review): `train_dataset` and `test_dataset` are not assigned in the
# visible part of the notebook (their construction only appears in
# commented-out code) — confirm they exist before running this cell.
# The observation weights are computed but unused while the weighted samplers
# below stay commented out.
train_weights = train_dataset._get_observation_weights()
# train_sampler = WeightedRandomSampler(train_weights, len(train_weights), replacement=True)
train_dataloader = DataLoader(train_dataset, 
                              batch_size=batch_size, 
                              # sampler=train_sampler
                             )
val_weights = test_dataset._get_observation_weights()
# val_sampler = WeightedRandomSampler(val_weights, len(val_weights), replacement=True)
val_dataloader = DataLoader(test_dataset, 
                            batch_size=batch_size,
                            # sampler=val_sampler
                           )

# Example of getting observations from the dataset based on indexes
# test_dimer_df = test_df[test_df.dimer == False]
# unique_test_indexes_nondimer = test_dimer_df.index.unique().tolist()
# input_x, input_y, int_x, int_y = test_dataset._get_item_from_PDB_index(unique_test_indexes_nondimer)

# %%
## Model Class
### MiniClip
def create_key_padding_mask(embeddings, padding_value=0, offset=10):
    """Return a boolean mask that is True where a position is padding.

    A position is treated as padding when all of its features (last axis) are
    smaller than ``padding_value + offset``.
    """
    cutoff = padding_value + offset
    return torch.all(embeddings < cutoff, dim=-1)

def create_mean_of_non_masked(embeddings, padding_mask):
    """Average every sequence over the positions that are not padding.

    Args:
        embeddings: Tensor of shape (batch_size, seq_len, features).
        padding_mask: Boolean tensor of shape (batch_size, seq_len) in which
            ``True`` marks padding that must not contribute to the mean.

    Returns:
        Tensor of shape (batch_size, features).

    Raises:
        ValueError: If a sequence has no unmasked position. An exception is
            raised rather than exiting the interpreter so the failure carries
            a traceback and can be handled by the caller.
    """
    seq_embeddings = []
    for sample_embeddings, sample_mask in zip(embeddings, padding_mask):
        real_tokens = sample_embeddings[~sample_mask]
        if real_tokens.shape[0] == 0:
            raise ValueError("You are masking all positions when creating sequence representation")
        seq_embeddings.append(real_tokens.mean(dim=0))

    return torch.stack(seq_embeddings)

class MiniCLIP_w_transformer_crossattn(pl.LightningModule):
    """Pair scorer built from a shared transformer encoder layer and cross-attention.

    Both inputs are encoded with the same ``nn.TransformerEncoderLayer``, exchange
    information through one shared ``nn.MultiheadAttention`` module, are mean-pooled
    over their non-padded positions, projected to 320 features, L2-normalised and
    compared with a scaled dot product.

    Args:
        padding_value: Fill value marking padded positions in the input embeddings.
        embed_dimension: Size of the last axis of the input embeddings.
        num_recycles: Number of encode / cross-attend / residual-add rounds.
    """

    def __init__(self, padding_value = -5000, embed_dimension=1280,num_recycles=1):
        super().__init__()
        self.num_recycles = num_recycles
        self.padding_value = padding_value
        self.embed_dimension = embed_dimension
        # Learnable log-scale for the logits, initialised to log(1 / 0.07) (~CLIP init).
        self.logit_scale = nn.Parameter(torch.tensor(math.log(1/0.07)))
        # One encoder layer, applied to both inputs (weight sharing).
        self.transformerencoder = nn.TransformerEncoderLayer(
            d_model=self.embed_dimension,
            nhead=8,
            dropout=0.1,
            batch_first=True,
            dim_feedforward=self.embed_dimension,
        )

        # Single LayerNorm reused before the encoder and before cross-attention.
        self.norm = nn.LayerNorm(self.embed_dimension)

        # One cross-attention module, used in both directions.
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=self.embed_dimension,
            num_heads=8,
            dropout=0.1,
            batch_first=True,
        )

        # Projection head applied to both pooled representations.
        self.prot_embedder = nn.Sequential(
            nn.Linear(self.embed_dimension, 640),
            nn.ReLU(),
            nn.Linear(640, 320),
        )

    def forward(self, pep_input, prot_input, pep_int_mask=None, prot_int_mask=None, int_prob=None, mem_save=True):
        """Return one logit per aligned pair ``(pep_input[i], prot_input[i])``.

        Args:
            pep_input: First batch of padded embeddings; indexed with three axes by
                the callers in this class.
            prot_input: Second batch of padded embeddings, same batch size.
            pep_int_mask, prot_int_mask, int_prob: Accepted so existing call sites
                keep working; not used by this implementation.
            mem_save: When True, call ``torch.cuda.empty_cache()`` before returning.

        Returns:
            Tensor with one scaled dot-product logit per batch element.
        """
        # Padding masks: True where a position's whole feature vector is padding.
        pep_mask = create_key_padding_mask(embeddings=pep_input, padding_value=self.padding_value)
        prot_mask = create_key_padding_mask(embeddings=prot_input, padding_value=self.padding_value)

        # Residual states start from the raw embeddings. No copy is needed because
        # nothing below modifies them in place.
        pep_emb = pep_input
        prot_emb = prot_input

        for _ in range(self.num_recycles):
            # Encode each sequence separately with the same encoder layer.
            pep_trans = self.transformerencoder(self.norm(pep_emb), src_key_padding_mask=pep_mask)
            prot_trans = self.transformerencoder(self.norm(prot_emb), src_key_padding_mask=prot_mask)

            # Normalise once per side; each result serves as query in one direction
            # and as key/value in the other.
            pep_normed = self.norm(pep_trans)
            prot_normed = self.norm(prot_trans)

            pep_cross, _ = self.cross_attn(query=pep_normed, key=prot_normed, value=prot_normed, key_padding_mask=prot_mask)
            prot_cross, _ = self.cross_attn(query=prot_normed, key=pep_normed, value=pep_normed, key_padding_mask=pep_mask)

            # Residual update: add the cross information back onto the running state.
            pep_emb = pep_emb + pep_cross
            prot_emb = prot_emb + prot_cross

        # Mean-pool only over non-masked positions.
        pep_seq_coding = create_mean_of_non_masked(pep_emb, pep_mask)
        prot_seq_coding = create_mean_of_non_masked(prot_emb, prot_mask)

        # Project with the shared head and L2-normalise.
        pep_seq_coding = F.normalize(self.prot_embedder(pep_seq_coding))
        prot_seq_coding = F.normalize(self.prot_embedder(prot_seq_coding))

        if mem_save:
            torch.cuda.empty_cache()

        # exp(logit_scale) is capped at 100 so the learned scale cannot grow unbounded.
        scale = torch.exp(self.logit_scale).clamp(max=100.0)
        logits = scale * (pep_seq_coding * prot_seq_coding).sum(dim=-1)  # Dot-product for comparison

        return logits

    def _score_batch_pairs(self, embedding_pep, embedding_prot):
        """Score the aligned pairs and every mismatched pairing (i, j) with i < j.

        Returns ``(positive_logits, negative_logits, rows, cols)``, where ``rows`` and
        ``cols`` are the strict upper-triangle indices used to build the negatives.
        When the batch holds a single sample there are no negatives and an empty
        tensor is returned instead of running the model on an empty batch.
        """
        batch_size = embedding_pep.size(0)
        positive_logits = self(embedding_pep, embedding_prot)

        rows, cols = torch.triu_indices(batch_size, batch_size, offset=1)
        if rows.numel() == 0:
            negative_logits = torch.zeros(0, dtype=positive_logits.dtype, device=positive_logits.device)
        else:
            negative_logits = self(embedding_pep[rows,:,:],
                                   embedding_prot[cols,:,:],
                                   int_prob=0.0)
        return positive_logits, negative_logits, rows, cols

    @staticmethod
    def _balanced_bce_loss(positive_logits, negative_logits):
        """Average the BCE of positives (target 1) and negatives (target 0).

        Falls back to the positive term alone when there are no negatives, because
        the mean over an empty tensor would be NaN.
        """
        positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits))
        if negative_logits.numel() == 0:
            return positive_loss
        negative_loss = F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits))
        return (positive_loss + negative_loss) / 2

    @staticmethod
    def _assemble_logit_matrix(positive_logits, negative_logits, rows, cols):
        """Build the square score matrix: positives on the diagonal, negatives elsewhere.

        The matrix takes the logits' dtype and device so index assignment also works
        when the model does not run in float32 or lives on another device.
        """
        batch_size = positive_logits.size(0)
        logit_matrix = torch.zeros((batch_size, batch_size),
                                   dtype=positive_logits.dtype,
                                   device=positive_logits.device)
        if rows.numel() > 0:
            # NOTE(review): only (first input i, second input j) with i < j is scored; entry
            # [j, i] reuses that logit instead of scoring (j, i) itself — confirm this
            # symmetric approximation is intended for the ranking metrics.
            logit_matrix[rows, cols] = negative_logits
            logit_matrix[cols, rows] = negative_logits

        # Fill diagonal with positive scores.
        diag_indices = torch.arange(batch_size, device=positive_logits.device)
        logit_matrix[diag_indices, diag_indices] = positive_logits
        return logit_matrix

    def training_step(self, batch, device):
        """Compute the training loss for one batch.

        Args:
            batch: Indexable whose items 0 and 1 are the two embedding tensors.
            device: Device the embeddings are moved to before scoring.

        Returns:
            Mean of the positive-pair and negative-pair BCE losses.
        """
        embedding_pep = batch[0].to(device)
        embedding_prot = batch[1].to(device)

        positive_logits, negative_logits, _, _ = self._score_batch_pairs(embedding_pep, embedding_prot)
        loss = self._balanced_bce_loss(positive_logits, negative_logits)

        torch.cuda.empty_cache()
        return loss

    def validation_step(self, batch, device):
        """Compute loss, top-1 accuracy and top-3 accuracy for one batch, without gradients.

        Args:
            batch: Indexable whose items 0 and 1 are the two embedding tensors.
            device: Device the embeddings are moved to before scoring.

        Returns:
            Tuple ``(loss, peptide_accuracy, peptide_topk_accuracy)``. The accuracies
            ask, for every column of the score matrix, whether the diagonal entry is
            the highest-scoring row (top-1) or among the k highest rows (top-k).
        """
        embedding_pep = batch[0].to(device)
        embedding_prot = batch[1].to(device)

        with torch.no_grad():
            positive_logits, negative_logits, rows, cols = self._score_batch_pairs(embedding_pep, embedding_prot)
            loss = self._balanced_bce_loss(positive_logits, negative_logits)

            logit_matrix = self._assemble_logit_matrix(positive_logits, negative_logits, rows, cols)

            labels = torch.arange(logit_matrix.shape[0], device=logit_matrix.device)
            peptide_predictions = logit_matrix.argmax(dim=0)
            peptide_accuracy = peptide_predictions.eq(labels).float().mean()

            # topk raises when k exceeds the axis length, so cap it for small batches.
            k = min(3, logit_matrix.shape[0])
            topk_indices = logit_matrix.topk(k, dim=0).indices
            peptide_topk_accuracy = torch.any((topk_indices - labels.reshape(1, -1)) == 0, dim=0).sum() / logit_matrix.shape[0]

            return loss, peptide_accuracy, peptide_topk_accuracy

    def calculate_logit_matrix(self,embedding_pep,embedding_prot):
        """Return the square score matrix for a batch of aligned embeddings.

        Row index refers to ``embedding_pep`` and column index to ``embedding_prot``;
        the diagonal holds the scores of the aligned pairs.
        """
        positive_logits, negative_logits, rows, cols = self._score_batch_pairs(embedding_pep, embedding_prot)
        return self._assemble_logit_matrix(positive_logits, negative_logits, rows, cols)
    


# Importing the modules
# NOTE(review): MiniCLIP and MiniCLIP_cross_attn are not referenced in the visible part of
# this file; the model built below is the class defined in this notebook.
from model_architectures import MiniCLIP, MiniCLIP_cross_attn
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# padding_value is not passed, so the constructor default (-5000) applies; it equals the
# module-level `padding_value` set at the top of the file.
model = MiniCLIP_w_transformer_crossattn(embed_dimension=embedding_dimension,
                                        num_recycles=number_of_recycles)
model = model.to(device)
## Testing information flow from batch
# batch = next(iter(train_dataloader))
# with torch.no_grad():
#     loss = model.training_step(batch,device)
#     loss, partner_accuracy, peptide_topk_accuracy = model.validation_step(batch,device)
    

# Counting trainable model parameters (the cache-clearing calls are disabled)
# print_mem_consumption()
# Keep only parameters that receive gradients, then sum their element counts.
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
# torch.cuda.empty_cache()
# print_mem_consumption()

# %% [markdown]
# ## Training loop


# %%

# Start a Weights & Biases run when tracking is enabled; otherwise bind `wandb` to False so
# code further down can pass it along as a "no tracker" value.
if use_wandb:
    import wandb

    wandb.init(
        # set the wandb project where this run will be logged
        project="CLIP_binder",
        name=str(runID),
        # track hyperparameters and run metadata
        config={
            "learning_rate": learning_rate,
            # Log the module layout itself. `str(model.__repr__)` would log the
            # bound-method wrapper ("<bound method Module.__repr__ of ...>") instead.
            "architecture": str(model),
            "Batch size": batch_size,
            "dataset": "Small-sample dataset of 5 (5000)",
            "training-procedure": "new_binary_cross",
        },
    )
    # Watch the model with wandb to log gradients, weights, and biases
    wandb.watch(model, log='all')
else:
    print("WandB Tracking not used ")
    wandb = False

# %%
# Set the optimizer: AdamW over all model parameters with the module-level learning rate.
optimizer = AdamW(model.parameters(), lr=learning_rate)
# initialize accelerator for training
# NOTE(review): `Accelerator` is used here, but the only import of it visible at the top of
# this file (`from accelerate import Accelerator`) is commented out — confirm it is imported
# elsewhere, otherwise this line raises NameError.
accelerator = Accelerator()
# Rebind the model, optimizer and both loaders to the objects returned by `prepare`.
model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(model, optimizer, train_dataloader, val_dataloader)


import importlib
import training_utils.train_utils as train_utils
# Reload so edits made to train_utils on disk are picked up without restarting the session.
importlib.reload(train_utils)

# Bundle the model, loaders, optimizer and run settings into the project's training wrapper.
# NOTE(review): `unique_test_indexes_nondimer` is only assigned in a commented-out example in
# the visible part of this file, and `runID` / `trained_model_dir` are not defined in view —
# confirm all three exist before this cell runs.
training_wrapper = train_utils.TrainWrapper(
                                            model=model, 
                                            training_loader=train_dataloader, 
                                            validation_loader=val_dataloader, 
                                            test_dataset=test_dataset,
                                            optimizer=optimizer, 
                                            EPOCHS=EPOCHS,
                                            runID=runID, 
                                            device=device, 
                                            test_indexes_for_auROC=unique_test_indexes_nondimer, 
                                            model_save_steps=model_save_steps,
                                            model_save_path=trained_model_dir, 
                                            v=True, 
                                            # Either the imported wandb module or False, depending on use_wandb.
                                            wandb_tracker=wandb
                                            )

# Run the training procedure implemented by TrainWrapper.
training_wrapper.train_model()