In [1]:
import uuid, sys, os
import pandas as pd
import numpy as np
from tqdm import tqdm
import ast
import math
import random

from sklearn import metrics
from scipy import stats
from collections import Counter

os.environ["CUDA_VISIBLE_DEVICES"] = "2"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
torch.cuda.set_device(0)  # 0 == "first visible" -> actually GPU 2 on the node
print(torch.cuda.get_device_name(0))

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, WeightedRandomSampler
import pytorch_lightning as pl
from torch.optim import AdamW

torch.manual_seed(0)

from accelerate import Accelerator
torch.cuda.empty_cache()

Tesla V100-SXM2-32GB


In [2]:
# Seed every random number generator used in this notebook so runs are repeatable.
SEED = 0
# NOTE(review): PYTHONHASHSEED is read when the interpreter starts, so assigning it
# here is only inherited by subprocesses launched later; it does not change string
# hashing in the already-running kernel.
os.environ["PYTHONHASHSEED"] = str(SEED)
random.seed(SEED)                 # Python's built-in `random` module
np.random.seed(SEED)              # NumPy legacy global generator
torch.manual_seed(SEED)           # PyTorch default generator
torch.cuda.manual_seed_all(SEED)  # PyTorch generators on all CUDA devices

In [3]:
import wandb
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33ms232958[0m ([33ms232958-danmarks-tekniske-universitet-dtu[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
# Switch the working directory to the project drafts folder; any relative path used
# after this point is resolved against it.
os.chdir("/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts")
# print(os.getcwd())

print("PyTorch:", torch.__version__)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print("Current location:", os.getcwd())

PyTorch: 2.5.1
Using device: cuda
Current location: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts


In [5]:
# Model parameters
memory_verbose = False
use_wandb = True # Used to track loss in real-time without printing
model_save_steps = 1  # presumably "save a checkpoint every N epochs" -- confirm against the training wrapper
train_frac = 1.0
test_frac = 1.0

# Width of the per-residue embeddings; also used as the model's default `embed_dimension`.
embedding_dimension = 1280 #| 960 | 1152
# Passed to the model as `num_recycles`.
number_of_recycles = 2
# Sentinel value; the dataset and model defaults use the same number for padded positions.
padding_value = -5000

# batch_size = 20
learning_rate = 2e-5
EPOCHS = 15

In [6]:
## Output path
trained_model_dir = "/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts"

## Embeddings paths
binders_embeddings = "/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/PPint_DB/binders_embeddings_esm2"
targets_embeddings = "/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/PPint_DB/targets_embeddings_esm2"

## Contact maps paths
binders_contacts = "/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/PPint_DB/binders_contacts"
targets_contacts = "/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/PPint_DB/targets_contacts"

# ## Training variables
runID = uuid.uuid4()

def print_mem_consumption(device=0):
    """Print a short GPU memory report for one CUDA device, in GB (1e9 bytes).

    Four numbers are printed:
      * Total memory     -- total memory reported for the device,
      * Reserved memory  -- memory PyTorch has reserved from CUDA,
      * Allocated memory -- memory currently allocated to tensors,
      * Free memory      -- reserved minus allocated.

    Args:
        device: CUDA device index to query. Defaults to 0, the device the
            original implementation always queried.
    """
    total_bytes = torch.cuda.get_device_properties(device).total_memory
    reserved_bytes = torch.cuda.memory_reserved(device)
    allocated_bytes = torch.cuda.memory_allocated(device)
    free_in_pool_bytes = reserved_bytes - allocated_bytes

    print("Total memory: ", total_bytes / 1e9)
    print("Reserved memory: ", reserved_bytes / 1e9)
    # True division like the other lines: floor division rounded any allocation
    # below 1 GB down to 0.0 and made this line inconsistent with the rest.
    print("Allocated memory: ", allocated_bytes / 1e9)
    print("Free memory: ", free_in_pool_bytes / 1e9)
print_mem_consumption()

Total memory:  34.072559616
Reserved memory:  0.0
Allocated memory:  0.0
Free memory:  0.0


### Loading PPint dataframe

In [7]:
# Read the interaction table and append the length of each target / binder sequence.
PPint_interaactions_df = pd.read_csv(
    "/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/PPint_DB/PPint_interactions.csv"
).assign(
    seq_target_len=lambda frame: frame["seq_target"].map(len),
    seq_binder_len=lambda frame: frame["seq_binder"].map(len),
)
PPint_interaactions_df

Unnamed: 0,seq_target,seq_binder,target_id,binder_id,target_binder_id,seq_target_len,seq_binder_len
0,SLTKTERTIIVSMWAKISTQADTIGTETLERLFLSHPQTKTYFPHF...,VHLTDAEKAAVSGLWGKVNADEVGGEALGRLLVVYPWTQRYFDSFG...,1JEB_2,1JEB_2,1JEB_2_1JEB_2,141,146
1,TTTLAFKFQHGVIAAVDSRASAGSYISALRVNKVIEINPYLLGTMS...,DRGVNTFSPEGRLFQVEYAIEAIKLGSTAIGIQTSEGVCLAVEKRI...,7B12_23,7B12_23,7B12_23_7B12_23,203,230
2,ITHLPPEVMLSIFSYLNPQELCRCSQVSMKWSQLTKTGSLWKHLYP...,PSIKLQSSDGEIFEVDVEIAKQSVTIKTMLEDLGDPVPLPNVNAAI...,6VCD_1,6VCD_1,6VCD_1_6VCD_1,255,135
3,NAKDVLGLTLLEKTLKERLNLKDAIIVSGDSDQSPWVKKEGRAAVA...,AKDVLGLTLLEKTLKERLNLKDAIIVSGDSDQSPWVKKEGRAAVAC...,2OKG_0,2OKG_0,2OKG_0_2OKG_0,241,243
4,DIVMSQSPSSLAVSVGEKVTMSCKSSQSLLYNNNQKNYLAWYQQKP...,VTLKESGPGILQPSQTLSLTCSFSGFSLSTYGMGVGWIRQPSGKGL...,3MBX_0,3MBX_0,3MBX_0_3MBX_0,220,229
...,...,...,...,...,...,...,...
2467,REIPLKVLVKAVLFACMLMRKTMASRVRVTILFATETGKSEALAWD...,QLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAEL...,3HR4_0,3HR4_0,3HR4_0_3HR4_0,189,145
2468,NKYRFIDVQPLTGVLGAEITGVDLREPLDDSTWNEILDAFHTYQVI...,KYRFIDVQPLTGVLGAEITGVDLREPLDDSTWNEILDAFHTYQVIY...,6D3M_0,6D3M_0,6D3M_0_6D3M_0,286,285
2469,TMKIAYLGPSGSFTHNVALHAFPAADLLPFENITEVIKAYESKQVC...,TMKIAYLGPSGSFTHNVALHAFPAADLLPFENITEVIKAYESKQVC...,4LUB_0,4LUB_0,4LUB_0_4LUB_0,188,190
2470,ERDEVGARKNAVDEEIERLSQPGDQRLNALAERFGGVLLSEIYDDV...,EPVTIVLSQGWVRSAKGHDIDAPGLNYKAGDSFKAAVKGKSNQPVV...,4MN4_2,4MN4_2,4MN4_2_4MN4_2,154,236


In [8]:
# Hold out a reproducible 20% of the rows for validation; the remainder is the training set.
n_validation_rows = round(len(PPint_interaactions_df) * 0.2)
Df_val = PPint_interaactions_df.sample(n=n_validation_rows, random_state=0)
Df_train = PPint_interaactions_df.drop(index=Df_val.index)

# A validation row is flagged as a dimer when target and binder sequences are identical.
Df_val["dimer"] = Df_val["seq_target"].eq(Df_val["seq_binder"])
indices_non_dimers_val = Df_val.index[~Df_val["dimer"].to_numpy()].tolist()

In [9]:
class CLIP_PPint_analysis_dataset(Dataset):
    """In-memory dataset of padded embeddings and contact maps for binder/target pairs.

    For every row of ``dframe`` the constructor loads four ``.npy`` files
    (binder embedding, target embedding, binder contact map, target contact
    map), pads or crops them to fixed lengths, and keeps them in
    ``self.samples``. Items are returned as float32 tensors in the order
    ``(binder_emb, target_emb, binder_contact, target_contact)``.

    Args:
        dframe: DataFrame with the columns ``target_binder_id``,
            ``seq_binder_len`` and ``seq_target_len``. It is copied, so the
            caller's frame is not modified.
        paths: Sequence of four directories, in this order: binder
            embeddings, target embeddings, binder contact maps, target
            contact maps.
        embedding_dim: Accepted for interface compatibility; the loader takes
            the embedding width from the loaded arrays and does not use it.
        padding_value: Sentinel written into every padded position.
        fixed_max_blen: Binder length to pad/crop to. When ``None`` the
            maximum ``seq_binder_len`` of ``dframe`` is used.
        fixed_max_tlen: Target length to pad/crop to. When ``None`` the
            maximum ``seq_target_len`` of ``dframe`` is used.

    Raises:
        ValueError: If a ``target_binder_id`` has fewer than four
            underscore-separated parts.
        FileNotFoundError: If any of the four expected files is missing.
    """

    def __init__(
        self,
        dframe,
        paths,
        embedding_dim=1280,
        padding_value=-5000.0,
        fixed_max_blen=None,
        fixed_max_tlen=None,
    ):
        super().__init__()

        self.dframe = dframe.copy()

        # Either use the provided global maxima or fall back to per-split maxima.
        if fixed_max_blen is None:
            self.max_blen = int(self.dframe["seq_binder_len"].max())
        else:
            self.max_blen = int(fixed_max_blen)

        if fixed_max_tlen is None:
            self.max_tlen = int(self.dframe["seq_target_len"].max())
        else:
            self.max_tlen = int(fixed_max_tlen)

        # Unpack paths.
        self.encoding_bpath = paths[0]  # binder embeddings
        self.encoding_tpath = paths[1]  # target embeddings
        self.contacts_bpath = paths[2]  # binder contact maps
        self.contacts_tpath = paths[3]  # target contact maps

        self.dframe = self.dframe.set_index("target_binder_id")
        self.accessions = self.dframe.index.astype(str).tolist()
        self.name_to_row = {name: i for i, name in enumerate(self.accessions)}
        self.samples = []

        iterator = tqdm(
            self.accessions,
            position=0,
            total=len(self.accessions),
            desc="#Loading ESM2 embeddings and contacts"
        )

        for accession in iterator:
            self.samples.append(self._load_sample(accession, padding_value))

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _require_file(path, description):
        """Raise FileNotFoundError naming ``description`` if ``path`` does not exist."""
        if not os.path.exists(path):
            raise FileNotFoundError(f"Missing {description} file: {path}")

    @staticmethod
    def _padding_dtype(dtype):
        """Return a dtype able to hold the padding sentinel.

        Floating dtypes are kept. Anything else (bool, unsigned or signed
        integers) is promoted to float32: writing e.g. -5000 into a bool or
        uint8 array silently turns it into True / a wrapped value, which makes
        padded positions indistinguishable from real data. Items are converted
        to float32 in ``__getitem__`` anyway.
        """
        return dtype if np.issubdtype(dtype, np.floating) else np.dtype(np.float32)

    @classmethod
    def _fit_rows(cls, arr, length, padding_value):
        """Pad (with ``padding_value`` rows) or crop a [L, D] array to [length, D]."""
        n_rows = arr.shape[0]
        if n_rows >= length:
            return arr[:length]
        filler = np.full(
            (length - n_rows, arr.shape[1]),
            padding_value,
            dtype=cls._padding_dtype(arr.dtype),
        )
        return np.concatenate([arr, filler], axis=0)

    @classmethod
    def _fit_square(cls, mat, length, padding_value):
        """Pad (bottom/right with ``padding_value``) or crop a [L, L] map to [length, length]."""
        side = mat.shape[0]
        if side >= length:
            return mat[:length, :length]
        canvas = np.full(
            (length, length),
            padding_value,
            dtype=cls._padding_dtype(mat.dtype),
        )
        canvas[:side, :side] = mat
        return canvas

    def _load_sample(self, accession, padding_value):
        """Load and pad/crop the four arrays belonging to one ``target_binder_id``."""
        parts = accession.split("_")
        if len(parts) < 4:
            raise ValueError(
                f"Expected target_binder_id to have at least 4 underscore-separated parts, got {accession}"
            )

        # The id is "<target>_<chain>_<binder>_<chain>"; files are prefixed t_ / b_.
        tname = f"t_{parts[0]}_{parts[1]}"
        bname = f"b_{parts[2]}_{parts[3]}"

        # --- Embeddings ---
        tnpy_path = os.path.join(self.encoding_tpath, f"{tname}.npy")
        bnpy_path = os.path.join(self.encoding_bpath, f"{bname}.npy")
        self._require_file(tnpy_path, "target embedding")
        self._require_file(bnpy_path, "binder embedding")
        t_final = self._fit_rows(np.load(tnpy_path), self.max_tlen, padding_value)
        b_final = self._fit_rows(np.load(bnpy_path), self.max_blen, padding_value)

        # --- Contact maps ---
        tcont_path = os.path.join(self.contacts_tpath, f"{tname}.npy")
        bcont_path = os.path.join(self.contacts_bpath, f"{bname}.npy")
        self._require_file(tcont_path, "target contact map")
        self._require_file(bcont_path, "binder contact map")
        tcont_final = self._fit_square(np.load(tcont_path), self.max_tlen, padding_value)
        bcont_final = self._fit_square(np.load(bcont_path), self.max_blen, padding_value)

        # Stored order: binder_emb, target_emb, binder_contact, target_contact.
        return (b_final, t_final, bcont_final, tcont_final)

    # --------------------------------------------------------------- Dataset API

    def __len__(self):
        """Number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return sample ``idx`` as four float32 tensors (binder, target, binder map, target map)."""
        b_arr, t_arr, bcont_arr, tcont_arr = self.samples[idx]

        b_tensor = torch.tensor(b_arr, dtype=torch.float32)
        t_tensor = torch.tensor(t_arr, dtype=torch.float32)
        bcont_tensor = torch.tensor(bcont_arr, dtype=torch.float32)
        tcont_tensor = torch.tensor(tcont_arr, dtype=torch.float32)

        return b_tensor, t_tensor, bcont_tensor, tcont_tensor

    def _get_by_name(self, name):
        """Fetch one sample or a stacked batch by ``target_binder_id``.

        Args:
            name: A single id (``str``) or an iterable of ids.

        Returns:
            For a single id, the same 4-tuple as ``__getitem__``. For an
            iterable, a 4-tuple of tensors stacked along a new leading batch
            dimension, in the order of the given ids.

        Raises:
            KeyError: If an id is not part of this dataset.
            ValueError: If an empty iterable of ids is given.
        """
        if isinstance(name, str):
            return self[self.name_to_row[name]]

        per_sample = [self[self.name_to_row[n]] for n in name]
        if not per_sample:
            raise ValueError("_get_by_name needs at least one name to build a batch")

        binder_batch, target_batch, binder_contact_b, target_contact_b = (
            torch.stack(group, dim=0) for group in zip(*per_sample)
        )
        return binder_batch, target_batch, binder_contact_b, target_contact_b

In [11]:
# Reuse the directories and hyper-parameters defined in the configuration cells
# instead of repeating the literals, so the two places cannot drift apart.
bemb_path = binders_embeddings
temb_path = targets_embeddings

## Contact maps paths
bcont_path = binders_contacts
tcont_path = targets_contacts

dataset_paths = [bemb_path, temb_path, bcont_path, tcont_path]

# Pad both splits to the maxima of the *whole* table so train and validation
# tensors share one shape.
global_max_blen = int(PPint_interaactions_df["seq_binder_len"].max())
global_max_tlen = int(PPint_interaactions_df["seq_target_len"].max())

training_Dataset = CLIP_PPint_analysis_dataset(
    Df_train,
    paths=dataset_paths,
    embedding_dim=embedding_dimension,
    padding_value=padding_value,
    fixed_max_blen=global_max_blen,
    fixed_max_tlen=global_max_tlen,
)

validation_Dataset = CLIP_PPint_analysis_dataset(
    Df_val,
    paths=dataset_paths,
    embedding_dim=embedding_dimension,
    padding_value=padding_value,
    fixed_max_blen=global_max_blen,
    fixed_max_tlen=global_max_tlen,
)

#Loading ESM2 embeddings and contacts: 100%|████████████████████████████████████████████████████████████| 1978/1978 [02:36<00:00, 12.66it/s]
#Loading ESM2 embeddings and contacts: 100%|██████████████████████████████████████████████████████████████| 494/494 [00:43<00:00, 11.23it/s]


Here the global maximum binder and target lengths are taken from the whole dataset and used to pad the contact maps, because the contact maps are later passed through `nn.Linear(in_features=max_len, out_features=embed_dimension)`, where `embed_dimension = 1280` is the ESM2 embedding size.

Instead of `embed_dimension_esm2 = 1280`, a different dimension could be used, but in essence all contact maps must eventually have the same size along `dim=1`; otherwise the same transformer cannot be used for all the maps.

In [12]:
# Ids of the validation pairs that are not dimers, in the order of `indices_non_dimers_val`.
accessions = Df_val.loc[indices_non_dimers_val, "target_binder_id"].tolist()
# Spot-check one sample: shape of its padded binder embedding.
validation_Dataset._get_by_name(accessions[16])[0].shape

torch.Size([461, 1280])

In [13]:
def create_key_padding_mask(x, padding_value=-5000, offset=10):
    """Flag padded positions of a batch.

    A position counts as padded when *every* value along the last dimension
    lies below ``padding_value + offset``.

    Args:
        x: Batched embeddings ``[B, L, D]`` or contact maps ``[B, L, L]``.
        padding_value: Sentinel used for padding.
        offset: Tolerance added to the sentinel before comparing.

    Returns:
        Boolean tensor ``[B, L]``; ``True`` marks a padded position.
    """
    threshold = padding_value + offset
    below_threshold = torch.lt(x, threshold)
    return torch.all(below_threshold, dim=-1)

def create_mean_of_non_masked(embeddings, padding_mask, input_type="encodings"):
    """Average each batch element over its real (non-padded) positions.

    Args:
        embeddings:
            - ``input_type == "encodings"``: ``[B, L, D]`` per-residue features.
            - ``input_type == "contact map"``: ``[B, L, L]`` contact maps.
        padding_mask: ``[B, L]`` boolean tensor; ``True`` = padded position,
            ``False`` = real residue.
        input_type: ``"encodings"`` or ``"contact map"``.

    Returns:
        - encodings: ``[B, D]``, the mean feature vector over real residues.
        - contact map: ``[B, 1]``, the mean of the upper-triangle entries
          (i < j) whose two residues are both real.
        An empty batch yields an empty ``[0, D]`` / ``[0, 1]`` tensor.

    Raises:
        ValueError: If ``input_type`` is unknown, or if a batch element has no
            real position (encodings) / no real residue pair (contact map).
            Raising instead of calling ``sys.exit`` keeps the kernel alive and
            matches the error already used for an unknown ``input_type``.
    """
    if input_type not in ("encodings", "contact map"):
        raise ValueError(f"Unknown input_type {input_type}")

    is_contact_map = input_type == "contact map"

    if embeddings.shape[0] == 0:
        width = 1 if is_contact_map else embeddings.shape[-1]
        return embeddings.new_zeros((0, width))

    if is_contact_map:
        # The upper-triangle index pairs are the same for every batch element,
        # so build them once, on the same device as the data.
        side = embeddings.shape[1]
        rows, cols = torch.triu_indices(side, side, offset=1, device=embeddings.device)

    pooled = []
    for sample, sample_mask in zip(embeddings, padding_mask):
        valid_pos = ~sample_mask  # [L], True for real residues

        if is_contact_map:
            # A pair (p, q) is usable only when both residues are real.
            pair_is_valid = valid_pos[rows] & valid_pos[cols]  # [num_pairs]
            if not bool(pair_is_valid.any()):
                raise ValueError("All pairs masked when summarizing contact map")
            mean_contact = sample[rows, cols][pair_is_valid].mean()
            pooled.append(mean_contact.unsqueeze(0))  # [1] so results can be stacked
        else:
            real_rows = sample[valid_pos]  # [L_real, D]
            if real_rows.shape[0] == 0:
                raise ValueError(
                    "You are masking all positions when creating sequence representation (encodings)"
                )
            pooled.append(real_rows.mean(dim=0))  # [D]

    return torch.stack(pooled)

In [None]:
class MiniCLIP_w_transformer_crossattn(pl.LightningModule):
    """CLIP-style scorer for binder ("pep") / target ("prot") pairs.

    Two parallel streams are processed with self-attention followed by
    cross-attention between the two partners, repeated ``num_recycles`` times:

      * a sequence stream operating on the per-residue embeddings, and
      * a structure stream operating on contact maps that are first projected
        from ``[B, L, max_len]`` to ``[B, L, embed_dimension]``.

    Each stream is mean-pooled over real residues, passed through a projection
    head and L2-normalised. The returned logit for a pair is the scaled average
    of the sequence and structure cosine similarities.

    Args:
        padding_value: Sentinel marking padded positions in all inputs.
        embed_dimension: Feature width of the residue embeddings and of the
            projected contact maps.
        num_recycles: Number of self-attention + cross-attention rounds.
        pep_max_len: Padded binder length (last dimension of binder contact maps).
        prot_max_len: Padded target length (last dimension of target contact maps).

    Raises:
        ValueError: If ``pep_max_len`` or ``prot_max_len`` is missing.
    """

    def __init__(
        self,
        padding_value=-5000,
        embed_dimension=embedding_dimension,
        num_recycles=2,
        pep_max_len=None,
        prot_max_len=None,
    ):
        super().__init__()

        if pep_max_len is None or prot_max_len is None:
            raise ValueError("pep_max_len and prot_max_len must be provided")

        self.num_recycles = num_recycles
        self.padding_value = padding_value
        self.embed_dimension = embed_dimension

        self.logit_scale = nn.Parameter(torch.tensor(math.log(1/0.07)))  # ~CLIP init

        #---------------Using different transformers for embeddings and structure---------------#

        ##### Embedding transformer #####

        # self attention
        self.seq_encoder = nn.TransformerEncoderLayer(
            d_model=self.embed_dimension,
            nhead=8,
            dropout=0.1,
            batch_first=True,
            dim_feedforward=self.embed_dimension
        )

        # Layer norm applied to the sequence stream before attention.
        self.norm = nn.LayerNorm(self.embed_dimension)

        # cross attention
        self.seq_cross_attn = nn.MultiheadAttention(
            embed_dim=self.embed_dimension,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        ##### Structure / contact maps transformer #####

        # Project each contact-map row (length max_len) to embed_dimension features.
        self.pep_contact_projector = nn.Linear(in_features=pep_max_len, out_features=self.embed_dimension)
        self.prot_contact_projector = nn.Linear(in_features=prot_max_len, out_features=self.embed_dimension)

        # Layer norm for the structure stream; forward() referenced it but it was never created.
        self.norm_struct = nn.LayerNorm(self.embed_dimension)

        # self attention
        self.struct_encoder = nn.TransformerEncoderLayer(
            d_model=self.embed_dimension,
            nhead=8,
            dropout=0.1,
            batch_first=True,
            dim_feedforward=self.embed_dimension
            )

        # cross attention
        self.struct_cross_attn = nn.MultiheadAttention(
            embed_dim=self.embed_dimension,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        # Shared by binder and target sequence representations.
        self.prot_projection_head = nn.Sequential(
            nn.Linear(self.embed_dimension, 640),
            nn.ReLU(),
            nn.Linear(640, 320),
        )

        # Shared by binder and target structure representations
        # (attribute name kept as-is so existing references keep working).
        self.strcut_projection_head = nn.Sequential(
            nn.Linear(self.embed_dimension, 640),
            nn.ReLU(),
            nn.Linear(640, 320),
        )

    def _zero_padding(self, contact_map):
        """Replace padding sentinels inside a contact map by 0.

        Rows of real residues still contain the sentinel in their padded
        columns; feeding -5000 into the linear projector would swamp the real
        contact values. The +10 tolerance matches the default offset of
        ``create_key_padding_mask``.
        """
        return contact_map.masked_fill(contact_map < (self.padding_value + 10), 0.0)

    def forward(self, pep_input, prot_input, pep_cm, prot_cm, label=None, pep_int_mask=None, prot_int_mask=None, int_prob=None, mem_save=True):
        """Score each (binder, target) pair of the batch.

        Args:
            pep_input: Binder embeddings ``[B, Lb, D]``, padded with the sentinel.
            prot_input: Target embeddings ``[B, Lt, D]``, padded with the sentinel.
            pep_cm: Binder contact maps ``[B, Lb, Lb]``, padded with the sentinel.
            prot_cm: Target contact maps ``[B, Lt, Lt]``, padded with the sentinel.
            label, pep_int_mask, prot_int_mask, int_prob: Accepted for call
                compatibility; not used by this implementation.
            mem_save: When true, ``torch.cuda.empty_cache()`` is called before
                the logits are computed.

        Returns:
            Tensor ``[B]`` of logits, one per pair.
        """
        # Key-padding masks ([B, L], True = padded) for both streams.
        pep_mask_emb = create_key_padding_mask(pep_input, padding_value=self.padding_value)
        prot_mask_emb = create_key_padding_mask(prot_input, padding_value=self.padding_value)
        pep_mask_cm = create_key_padding_mask(pep_cm, padding_value=self.padding_value)
        prot_mask_cm = create_key_padding_mask(prot_cm, padding_value=self.padding_value)

        # Residual states. No in-place update is performed, so no clone is needed.
        pep_emb = pep_input
        prot_emb = prot_input

        # Project the contact maps once; recycling then happens in feature space.
        # (Adding a [B, L, D] attention output to the raw [B, L, L] map, as the
        # previous code did, is a shape mismatch.)
        pep_struct = self.pep_contact_projector(self._zero_padding(pep_cm))      # [B, Lb, D]
        prot_struct = self.prot_contact_projector(self._zero_padding(prot_cm))   # [B, Lt, D]

        for _ in range(self.num_recycles):

            # Self-attention, sequence stream (normalised input for numerical stability).
            pep_trans = self.seq_encoder(self.norm(pep_emb), src_key_padding_mask=pep_mask_emb)
            prot_trans = self.seq_encoder(self.norm(prot_emb), src_key_padding_mask=prot_mask_emb)

            # Self-attention, structure stream.
            pep_struct_encoded = self.struct_encoder(self.norm_struct(pep_struct), src_key_padding_mask=pep_mask_cm)
            prot_struct_encoded = self.struct_encoder(self.norm_struct(prot_struct), src_key_padding_mask=prot_mask_cm)

            # Cross-attention: each partner queries the other; the mask always
            # belongs to the side providing keys/values.
            pep_q, prot_q = self.norm(pep_trans), self.norm(prot_trans)
            pep_cross, _ = self.seq_cross_attn(query=pep_q, key=prot_q, value=prot_q, key_padding_mask=prot_mask_emb)
            prot_cross, _ = self.seq_cross_attn(query=prot_q, key=pep_q, value=pep_q, key_padding_mask=pep_mask_emb)

            pep_sq, prot_sq = self.norm_struct(pep_struct_encoded), self.norm_struct(prot_struct_encoded)
            pep_struct_cross, _ = self.struct_cross_attn(query=pep_sq, key=prot_sq, value=prot_sq, key_padding_mask=prot_mask_cm)
            prot_struct_cross, _ = self.struct_cross_attn(query=prot_sq, key=pep_sq, value=pep_sq, key_padding_mask=pep_mask_cm)

            # Additive update with residual connection.
            pep_emb = pep_emb + pep_cross
            prot_emb = prot_emb + prot_cross
            pep_struct = pep_struct + pep_struct_cross
            prot_struct = prot_struct + prot_struct_cross

        # Mean over real residues -> [B, D] for every stream.
        pep_seq_coding = create_mean_of_non_masked(pep_emb, pep_mask_emb)
        prot_seq_coding = create_mean_of_non_masked(prot_emb, prot_mask_emb)
        pep_contacts_coding = create_mean_of_non_masked(pep_struct, pep_mask_cm)
        prot_contacts_coding = create_mean_of_non_masked(prot_struct, prot_mask_cm)

        # Projection heads + L2 normalisation -> unit vectors of size 320.
        pep_seq_coding = F.normalize(self.prot_projection_head(pep_seq_coding), dim=-1)
        prot_seq_coding = F.normalize(self.prot_projection_head(prot_seq_coding), dim=-1)
        pep_contacts_coding = F.normalize(self.strcut_projection_head(pep_contacts_coding), dim=-1)
        prot_contacts_coding = F.normalize(self.strcut_projection_head(prot_contacts_coding), dim=-1)

        pep_full = torch.cat([pep_seq_coding, pep_contacts_coding], dim=-1)
        prot_full = torch.cat([prot_seq_coding, prot_contacts_coding], dim=-1)

        if mem_save:
            torch.cuda.empty_cache()

        scale = torch.exp(self.logit_scale).clamp(max=100.0)
        # NOTE(review): the previous code built pep_full / prot_full but scored only
        # the sequence halves, leaving the structure stream without any effect.
        # The concatenation is used here; the factor 0.5 averages the two cosine
        # similarities so the score stays in [-1, 1] before scaling. Confirm intent.
        logits = scale * 0.5 * (pep_full * prot_full).sum(dim=-1)

        return logits

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _unpack_batch(batch, device):
        """Move the four batch tensors (binder emb, target emb, binder map, target map) to ``device``."""
        if len(batch) != 4:
            raise ValueError(
                f"Expected a batch of 4 tensors (binder emb, target emb, binder contacts, target contacts), got {len(batch)}"
            )
        embedding_pep, embedding_prot, contacts_pep, contacts_prot = batch
        return (
            embedding_pep.to(device),
            embedding_prot.to(device),
            contacts_pep.to(device),
            contacts_prot.to(device),
        )

    def _score_pairs(self, embedding_pep, embedding_prot, pep_cm, prot_cm):
        """Score the true pairs and the in-batch mismatched pairs.

        Negatives pair binder ``i`` with target ``j`` for every ``i < j``.

        Returns:
            ``(positive_logits [N], negative_logits [N*(N-1)/2], rows, cols)``.
        """
        n_pairs = embedding_pep.size(0)
        rows, cols = torch.triu_indices(n_pairs, n_pairs, offset=1, device=embedding_pep.device)

        positive_logits = self(embedding_pep, embedding_prot, pep_cm, prot_cm)
        if rows.numel() == 0:
            # A single-pair batch has no in-batch negatives.
            negative_logits = positive_logits.new_zeros(0)
        else:
            negative_logits = self(
                embedding_pep[rows], embedding_prot[cols],
                pep_cm[rows], prot_cm[cols],
                int_prob=0.0,
            )
        return positive_logits, negative_logits, rows, cols

    @staticmethod
    def _contrastive_loss(positive_logits, negative_logits):
        """Mean of the BCE on true pairs (label 1) and on mismatched pairs (label 0)."""
        positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits))
        if negative_logits.numel() == 0:
            return positive_loss
        negative_loss = F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits))
        return (positive_loss + negative_loss) / 2

    @staticmethod
    def _assemble_logit_matrix(positive_logits, negative_logits, rows, cols):
        """Build the [N, N] matrix (binder index x target index) of logits.

        Only binder ``i`` / target ``j`` with ``i < j`` is actually scored; the
        mirrored entry reuses the same value. The diagonal holds the true pairs.
        """
        n_pairs = positive_logits.reshape(-1).size(0)
        logit_matrix = torch.zeros(
            (n_pairs, n_pairs), device=positive_logits.device, dtype=positive_logits.dtype
        )
        logit_matrix[rows, cols] = negative_logits
        logit_matrix[cols, rows] = negative_logits
        diag_indices = torch.arange(n_pairs, device=positive_logits.device)
        logit_matrix[diag_indices, diag_indices] = positive_logits.reshape(-1)
        return logit_matrix

    # -------------------------------------------------------------- train / eval

    def training_step(self, batch, device):
        """Compute the contrastive loss for one batch.

        Args:
            batch: The four tensors produced by the dataset
                (binder emb, target emb, binder contacts, target contacts).
            device: Device the tensors are moved to before the forward pass.

        Returns:
            Scalar loss tensor.
        """
        tensors = self._unpack_batch(batch, device)
        positive_logits, negative_logits, _, _ = self._score_pairs(*tensors)
        loss = self._contrastive_loss(positive_logits, negative_logits)

        torch.cuda.empty_cache()
        return loss

    def validation_step(self, batch, device):
        """Evaluate one batch without gradients.

        Args:
            batch: Same four tensors as in ``training_step``.
            device: Device the tensors are moved to.

        Returns:
            ``(loss, peptide_accuracy, peptide_topk_accuracy)`` where accuracy
            is the fraction of targets whose highest-scoring binder in the
            batch is the true one, and top-k uses ``k = min(3, batch size)``.
        """
        tensors = self._unpack_batch(batch, device)

        with torch.no_grad():
            positive_logits, negative_logits, rows, cols = self._score_pairs(*tensors)
            loss = self._contrastive_loss(positive_logits, negative_logits)

            logit_matrix = self._assemble_logit_matrix(positive_logits, negative_logits, rows, cols)
            n_pairs = logit_matrix.shape[0]
            labels = torch.arange(n_pairs, device=logit_matrix.device)

            # Column j holds the scores of all binders against target j.
            peptide_predictions = logit_matrix.argmax(dim=0)
            peptide_accuracy = peptide_predictions.eq(labels).float().mean()

            # topk raises when k exceeds the batch size, so cap it.
            k = min(3, n_pairs)
            topk_indices = logit_matrix.topk(k, dim=0).indices  # [k, N]
            peptide_topk_accuracy = (topk_indices == labels.unsqueeze(0)).any(dim=0).float().mean()

            return loss, peptide_accuracy, peptide_topk_accuracy

    def calculate_logit_matrix(self, embedding_pep, embedding_prot, pep_cm=None, prot_cm=None):
        """Return the [N, N] logit matrix for a batch of pairs.

        Args:
            embedding_pep: Binder embeddings ``[N, Lb, D]``.
            embedding_prot: Target embeddings ``[N, Lt, D]``.
            pep_cm: Binder contact maps ``[N, Lb, Lb]``.
            prot_cm: Target contact maps ``[N, Lt, Lt]``.

        Raises:
            ValueError: If the contact maps are not supplied; ``forward`` needs them.
        """
        if pep_cm is None or prot_cm is None:
            raise ValueError("pep_cm and prot_cm are required: forward() consumes contact maps")

        positive_logits, negative_logits, rows, cols = self._score_pairs(
            embedding_pep, embedding_prot, pep_cm, prot_cm
        )
        return self._assemble_logit_matrix(positive_logits, negative_logits, rows, cols)

### Train model from scratch with 10% of PPint dataset using old architecture (encodings only)

In [None]:
# Build the model from the configuration values and place it on the device that
# was selected earlier (falls back to CPU when no GPU is visible).
model = MiniCLIP_w_transformer_crossattn(
    padding_value=padding_value,
    embed_dimension=embedding_dimension,
    num_recycles=number_of_recycles,
    pep_max_len=global_max_blen,
    prot_max_len=global_max_tlen,
).to(device)

model

### Training loop

In [None]:
def batch(iterable, n=1):
    """Yield consecutive slices of at most ``n`` items from a sliceable sequence.

    ``iterable`` must support ``len()`` and slicing (e.g. a list of
    observation IDs). The final slice may be shorter than ``n``.
    """
    total = len(iterable)
    # Slicing clamps at the end of the sequence, so the last chunk is simply shorter.
    yield from (iterable[start:start + n] for start in range(0, total, n))

class TrainWrapper():
    """Run training, validation, checkpointing and metric logging for one model.

    The wrapped ``model`` must expose ``train()``, ``eval()``,
    ``training_step(batch, device)``, ``validation_step(batch, device)``,
    ``calculate_logit_matrix(a, b)`` and ``state_dict()``.

    Args:
        model: Model to train (see required methods above).
        training_loader: Iterable of training batches supporting ``len()``.
        validation_loader: Iterable of validation batches supporting ``len()``.
        test_df: DataFrame with a ``target_binder_id`` column, looked up with
            ``.loc`` by the entries of ``test_indexes_for_auROC``.
        test_dataset: Dataset exposing ``_get_by_name(list_of_ids)`` returning
            a ``(binder_emb, target_emb)`` pair.
        optimizer: Optimizer stepping the model parameters.
        EPOCHS: Number of training epochs.
        runID: Identifier used in checkpoint folder and file names.
        device: Device passed to the model's step methods.
        test_indexes_for_auROC: Optional index labels used for the auROC/auPR
            evaluation; ``None``/empty disables it.
        auROC_batch_size: Chunk size used when scoring the auROC subset.
        model_save_steps: If truthy, save a checkpoint every N epochs.
        model_save_path: Directory under which checkpoints are written.
        v: Print progress to the console when true.
        wandb_tracker: Object with ``log(dict)`` and ``finish()`` (e.g. the
            ``wandb`` module); falsy disables tracking.
    """

    def __init__(self, 
                 model, 
                 training_loader, 
                 validation_loader,
                 test_df,
                 test_dataset, 
                 optimizer, 
                 EPOCHS, 
                 runID, 
                 device, 
                 test_indexes_for_auROC = None,
                 auROC_batch_size=10, 
                 model_save_steps=False, 
                 model_save_path=False, 
                 v=False, 
                 wandb_tracker=False):
        
        self.model = model 
        self.training_loader = training_loader
        self.validation_loader = validation_loader
        self.EPOCHS = EPOCHS
        self.wandb_tracker = wandb_tracker
        # If truthy (e.g., 1, 5), save a checkpoint every N epochs.
        self.model_save_steps = model_save_steps
        self.verbose = v
        self.best_vloss = 1_000_000
        self.optimizer = optimizer
        self.runID = runID
        self.trained_model_dir = model_save_path
        # Print validation metrics every N epochs when verbose.
        self.print_frequency_loss = 1
        self.device = device
        # Which observation indices to use when computing auROC/auPR after each epoch.
        self.test_indexes_for_auROC = test_indexes_for_auROC
        # Number of observations scored together when computing auROC/auPR.
        self.auROC_batch_size = auROC_batch_size
        self.test_dataset = test_dataset
        self.test_df = test_df

    @staticmethod
    def _has_indexes(indexes):
        """Return True when ``indexes`` is a non-empty collection.

        Plain truthiness (``if indexes:``) raises for multi-element NumPy
        arrays / pandas indexes, so the length is checked explicitly.
        """
        if indexes is None or indexes is False:
            return False
        try:
            return len(indexes) > 0
        except TypeError:
            # Not a sized container: fall back to ordinary truthiness.
            return bool(indexes)

    def train_one_epoch(self):
        """Run one optimisation pass over ``training_loader``.

        Batches whose first tensor has a single row are skipped, since no
        negatives can be formed from a batch of one.

        Returns:
            Mean loss over the batches that were actually processed
            (``0.0`` if none were).
        """
        self.model.train() 
        running_loss = 0.0
        n_processed = 0

        for data_batch in tqdm(self.training_loader, total=len(self.training_loader), desc="Running through epoch"):
            
            if data_batch[0].size(0) == 1: 
                continue
            
            self.optimizer.zero_grad() 
            loss = self.model.training_step(data_batch, self.device) 
            loss.backward() 
            self.optimizer.step() 

            torch.cuda.empty_cache()
            running_loss += loss.item()
            n_processed += 1
            
        # Skipped batches must not dilute the average.
        return running_loss / max(n_processed, 1)
    
    def calc_auroc_aupr_on_indexes(self, model, dataset, test_df, nondimer_indexes, batch_size = 10):
        """Compute auROC / auPR on a subset of observations.

        Matched pairs (matrix diagonal) are positives; the upper-triangle
        entries of each chunk's logit matrix are negatives.

        Args:
            model: Kept for interface compatibility; ``self.model`` is the
                model that is evaluated.
            dataset: Object exposing ``_get_by_name(list_of_ids)``.
            test_df: DataFrame with a ``target_binder_id`` column.
            nondimer_indexes: Index labels into ``test_df``.
            batch_size: Number of observations scored together.

        Returns:
            ``(auroc, aupr, all_TP_scores, all_FP_scores)``.
        """
        self.model.eval()
        all_TP_scores, all_FP_scores = [], []
        accessions = [test_df.loc[index].target_binder_id for index in nondimer_indexes]
        # Ceil division: the last, shorter chunk counts for the progress bar too.
        n_chunks = (len(accessions) + batch_size - 1) // batch_size
        
        with torch.no_grad():
            for index_batch in tqdm(batch(accessions, n=batch_size), total=n_chunks, desc="Calculating AUC"):

                # Loading the data into the GPU based on the index batch
                binder_emb, target_emb = dataset._get_by_name(index_batch)
                binder_emb, target_emb = binder_emb.to(self.device), target_emb.to(self.device)

                logit_matrix = self.model.calculate_logit_matrix(binder_emb, target_emb)
                
                all_TP_scores += logit_matrix.diag().detach().cpu().tolist()
                
                # Get FP scores from upper triangle (excluding diagonal)
                n = logit_matrix.size(0)
                rows, cols = torch.triu_indices(n, n, offset=1)
                all_FP_scores += logit_matrix[rows, cols].detach().cpu().tolist()
            
        # Calculate scores and labels
        all_score_predictions = np.array(all_TP_scores + all_FP_scores)
        all_labels = np.array([1]*len(all_TP_scores) + [0]*len(all_FP_scores))
                
        auroc = metrics.roc_auc_score(all_labels, all_score_predictions)
        aupr = metrics.average_precision_score(all_labels, all_score_predictions)
        
        return auroc, aupr, all_TP_scores, all_FP_scores


    def validate (self, dataloader, indexes_for_auc=False, auROC_dataset=False, auROC_df = False):
        """Evaluate the model on ``dataloader``.

        Args:
            dataloader: Iterable of batches supporting ``len()``.
            indexes_for_auc: Optional index labels; when non-empty, auROC and
                auPR are also computed on them.
            auROC_dataset: Dataset forwarded to ``calc_auroc_aupr_on_indexes``.
            auROC_df: DataFrame forwarded to ``calc_auroc_aupr_on_indexes``.

        Returns:
            ``(val_loss, val_accuracy, val_topk_accuracy)``, extended with
            ``(auroc, aupr)`` when ``indexes_for_auc`` is non-empty. The three
            means are taken over the batches actually processed.
        """
        self.model.eval()
        running_loss = 0.0
        running_accuracy = 0.0
        running_topk_accuracy = 0.0
        n_processed = 0

        with torch.no_grad():
            for data_batch in tqdm(dataloader, total=len(dataloader), desc="First Validation run"):
                if data_batch[0].size(0) == 1: # We can't make negatives on a batch of 1
                    continue
                loss, accuracy, topk_accuracy = self.model.validation_step(data_batch, self.device)
                running_loss += loss.item()
                running_accuracy += accuracy.item()
                running_topk_accuracy += topk_accuracy.item()
                n_processed += 1

        # Skipped batches must not dilute the averages.
        denom = max(n_processed, 1)
        val_loss = running_loss / denom
        val_accuracy = running_accuracy / denom
        val_topk_accuracy = running_topk_accuracy / denom

        if self._has_indexes(indexes_for_auc):
            # Calculating auc-scores
            non_dimer_auc, non_dimer_aupr, _, _ = self.calc_auroc_aupr_on_indexes(
                model=self.model,
                dataset=auROC_dataset,
                test_df=auROC_df,
                nondimer_indexes=indexes_for_auc,
                batch_size=self.auROC_batch_size,
            )
            return val_loss, val_accuracy, val_topk_accuracy, non_dimer_auc, non_dimer_aupr

        return val_loss, val_accuracy, val_topk_accuracy

    def _evaluate(self):
        """Validate on ``validation_loader``; always return a 5-tuple.

        The last two entries (auROC, auPR) are ``None`` when no auROC indexes
        were configured.
        """
        if self._has_indexes(self.test_indexes_for_auROC):
            return self.validate(
                dataloader=self.validation_loader,
                indexes_for_auc=self.test_indexes_for_auROC,
                auROC_dataset=self.test_dataset,
                auROC_df=self.test_df
            )
        val_loss, val_accuracy, val_topk_accuracy = self.validate(
            dataloader=self.validation_loader,
            indexes_for_auc=False,
            auROC_dataset=self.test_dataset,
            auROC_df=self.test_df
        )
        return val_loss, val_accuracy, val_topk_accuracy, None, None

    def _save_checkpoint(self, epoch, val_loss):
        """Write model + optimizer state for ``epoch`` under ``trained_model_dir``."""
        check_point_folder = os.path.join(
            self.trained_model_dir,
            f"{str(self.runID)}_checkpoint_{str(epoch)}"
        )
        if self.verbose:
            print("Saving model to:", check_point_folder)

        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(check_point_folder, exist_ok=True)

        checkpoint_path = os.path.join(
            check_point_folder,
            f"{str(self.runID)}_checkpoint_epoch_{str(epoch)}.pth"
        )
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'val_loss': val_loss
            }, 
            checkpoint_path
        )

    def train_model(self):
        """Validate once, then train for ``EPOCHS`` epochs.

        After every epoch the model is validated, optionally checkpointed
        (every ``model_save_steps`` epochs), and metrics are printed and/or
        sent to ``wandb_tracker``. The tracker is finished at the end.
        """
        if self.verbose:
            print(f"Training model {str(self.runID)}")
        
        # --- initial validation before training
        (val_loss_before, val_accuracy_before, val_topk_accuracy_before,
         val_nondimer_auc_before, val_nondimer_aupr_before) = self._evaluate()
        
        if self.verbose: 
            if val_nondimer_auc_before is not None:
                print(f'Before training - Val CLIP-loss {round(val_loss_before,4)}',
                      f'Accuracy: {round(val_accuracy_before,4)}',
                      f'Top-K accuracy : {round(val_topk_accuracy_before,4)}',
                      f'auc: {round(val_nondimer_auc_before,3)}')
            else:
                print(f'Before training - Val CLIP-loss {round(val_loss_before,4)}',
                      f'Accuracy: {round(val_accuracy_before,4)}',
                      f'Top-K accuracy : {round(val_topk_accuracy_before,4)}')

        if self.wandb_tracker:
            metrics_to_log = {
                "Val-loss": val_loss_before,
                "Val-acc": val_accuracy_before,
                "Val-TOPK-acc": val_topk_accuracy_before,
            }
            if val_nondimer_auc_before is not None:
                metrics_to_log["Val non-dimer auc"] = val_nondimer_auc_before
                metrics_to_log["Val non-dimer auPR"] = val_nondimer_aupr_before
            self.wandb_tracker.log(metrics_to_log)
        
        # --- training loop
        for epoch in tqdm(range(1, self.EPOCHS + 1), total=self.EPOCHS, desc="Epochs"):
            
            train_loss = self.train_one_epoch()
            
            # validation after epoch
            val_loss, val_accuracy, val_topk_accuracy, non_dimer_auc, non_dimer_aupr = self._evaluate()
            
            # checkpoint save
            if self.model_save_steps and epoch % self.model_save_steps == 0:
                self._save_checkpoint(epoch, val_loss)
            
            # console logging
            if self.verbose and epoch % self.print_frequency_loss == 0:
                if non_dimer_auc is not None:
                    print(f'EPOCH {epoch} -  Val loss {round(val_loss,4)}',
                          f'Accuracy: {round(val_accuracy,4)}',
                          f'Top-K accuracy: {round(val_topk_accuracy,4)}',
                          f'Val-Auc:{round(non_dimer_auc,3)}', 
                          f'Val-auPR:{round(non_dimer_aupr,3)}')
                else:
                    print(f'EPOCH {epoch} -  Val loss {round(val_loss,4)}',
                          f'Accuracy: {round(val_accuracy,4)}',
                          f'Top-K accuracy: {round(val_topk_accuracy,4)}')
            
            # wandb logging
            if self.wandb_tracker:
                metrics_to_log_epoch = {
                    "Epoch": epoch,
                    "Train-loss": train_loss,
                    "Val-loss": val_loss,
                    "Val-acc": val_accuracy,
                    "Val-TOPK-acc": val_topk_accuracy,
                }
                if non_dimer_auc is not None:
                    metrics_to_log_epoch["Val non-dimer auc"] = non_dimer_auc
                    metrics_to_log_epoch["Val non-dimer auPR"] = non_dimer_aupr
                self.wandb_tracker.log(metrics_to_log_epoch)

        if self.wandb_tracker:
            self.wandb_tracker.finish()

In [None]:
# --- Run hyper-parameters (override the defaults from the config cell)
batch_size = 10
learning_rate = 2e-5
EPOCHS = 15
g = torch.Generator().manual_seed(SEED)  # seeded generator -> reproducible shuffling

# login once (env var preferred)
if use_wandb:
    import wandb
    wandb.login()

optimizer = AdamW(model.parameters(), lr=learning_rate)
# Use the `batch_size` variable (it is what gets logged to W&B) and shuffle the
# training set every epoch with the seeded generator created above.
# NOTE(review): assumes `training_Dataset` is a map-style dataset — confirm.
train_dataloader = DataLoader(training_Dataset, batch_size=batch_size, shuffle=True, generator=g)
val_dataloader = DataLoader(validation_Dataset, batch_size=20, shuffle=False, drop_last = False)

# accelerator
accelerator = Accelerator()
device = accelerator.device
model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(model, optimizer, train_dataloader, val_dataloader)

# wandb
if use_wandb:
    run = wandb.init(
        project="PPint_retrain_w_10percent_ofdata",
        name="PPint_retrain",
        config={"learning_rate": learning_rate, "batch_size": batch_size, "epochs": EPOCHS,
                "architecture": "MiniCLIP_w_transformer_crossattn", "dataset": "Meta analysis"},
    )
    wandb.watch(accelerator.unwrap_model(model), log="all", log_freq=100)
else:
    run = None

# train
training_wrapper = TrainWrapper(model=model, 
                                training_loader = train_dataloader, 
                                validation_loader = val_dataloader, 
                                test_dataset = validation_Dataset,
                                test_df = Df_val,
                                optimizer = optimizer, 
                                EPOCHS = EPOCHS,
                                runID = runID, 
                                device = device, 
                                test_indexes_for_auROC = indices_non_dimers_val, 
                                model_save_steps = model_save_steps,
                                model_save_path = trained_model_dir, 
                                v = True, 
                                # Only hand over the tracker when a run was actually started;
                                # otherwise the wrapper would log to an uninitialised wandb.
                                wandb_tracker = wandb if use_wandb else False
                                )

training_wrapper.train_model()