In [21]:
import uuid, sys, os
import pandas as pd
import numpy as np
from tqdm import tqdm
import ast
import math
import random

from sklearn import metrics
from scipy import stats

# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# import torch
# torch.cuda.set_device(0)  # 0 == "first visible" -> actually GPU 2 on the node
# print(torch.cuda.get_device_name(0))

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler
import pytorch_lightning as pl
from torch.optim import AdamW

torch.manual_seed(0)

from accelerate import Accelerator

import matplotlib.pyplot as plt
import seaborn as sns

from Levenshtein import distance as Ldistance

import training_utils.dataset_utils as data_utils
import training_utils.partitioning_utils as pat_utils

import importlib
import training_utils.train_utils as train_utils
importlib.reload(train_utils)

<module 'training_utils.train_utils' from '/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/training_utils/train_utils.py'>

In [22]:
SEED = 0
os.environ["PYTHONHASHSEED"] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [23]:
import wandb
wandb.login()

True

In [24]:
os.chdir("/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts")
# print(os.getcwd())

print("PyTorch:", torch.__version__)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print("Current location:", os.getcwd())

PyTorch: 2.5.1
Using device: cpu
Current location: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts


In [25]:
# Model parameters
memory_verbose = False
use_wandb = True # Used to track loss in real-time without printing
model_save_steps = 1
train_frac = 1.0
test_frac = 1.0

embedding_dimension = 1152 # 1280 | 960 | 1152
number_of_recycles = 2
padding_value = -5000

batch_size = 20
learning_rate = 2e-5
EPOCHS = 15

In [26]:
## Model Class
### MiniClip 
def gaussian_kernel(x, sigma):
    return np.exp(-x**2 / (2 * sigma**2))

def transform_vector(vector, sigma):

    interacting_indices = np.where(vector == 1)[0]   # positions where vector == 1
    transformed_vector = np.zeros_like(vector, dtype=float)
    
    for i in range(len(vector)):
        if vector[i] == 0:
            distances = np.abs(interacting_indices - i)   # distance to all "1"s
            min_distance = np.min(distances)              # closest "1"
            transformed_vector[i] = gaussian_kernel(min_distance, sigma)
        else:
            transformed_vector[i] = 1.0
    return transformed_vector

def safe_shuffle(n, device):
    shuffled = torch.randperm(n, device=device)
    while torch.any(shuffled == torch.arange(n, device=device)):
        shuffled = torch.randperm(n, device=device)
    return shuffled

def create_key_padding_mask(embeddings, padding_value=-5000, offset=10):
    return (embeddings < (padding_value + offset)).all(dim=-1)

def create_mean_of_non_masked(embeddings, padding_mask):
    # Use masked select and mean to compute the mean of non-masked elements
    # embeddings should be of shape (batch_size, seq_len, features)
    seq_embeddings = []
    for i in range(embeddings.shape[0]): # looping over all batch elements
        non_masked_embeddings = embeddings[i][~padding_mask[i]] # shape [num_real_tokens, features]
        if len(non_masked_embeddings) == 0:
            print("You are masking all positions when creating sequence representation")
            sys.exit(1)
        mean_embedding = non_masked_embeddings.mean(dim=0) # sequence is represented by the single vecotr [1152] [features]
        seq_embeddings.append(mean_embedding)
    return torch.stack(seq_embeddings)

class MiniCLIP_w_transformer_crossattn(pl.LightningModule):

    def __init__(self, padding_value = -5000, embed_dimension=1152, num_recycles=1):

        super().__init__()
        self.num_recycles = num_recycles # how many times you iteratively refine embeddings with self- and cross-attention (ALPHA-Fold-style recycling).
        self.padding_value = padding_value
        self.embed_dimension = embed_dimension

        self.logit_scale = nn.Parameter(torch.tensor(math.log(1/0.07)))  # ~CLIP init

        self.transformerencoder =  nn.TransformerEncoderLayer(
            d_model=self.embed_dimension,
            nhead=8,
            dropout=0.1,
            batch_first=True,
            dim_feedforward=self.embed_dimension
            )
 
        self.norm = nn.LayerNorm(self.embed_dimension)  # For residual additions

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=self.embed_dimension,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        self.prot_embedder = nn.Sequential(
            nn.Linear(self.embed_dimension, 640),
            nn.ReLU(),
            nn.Linear(640, 320),
        )
        
    def forward(self, pep_input, prot_input, label=None, pep_int_mask=None, prot_int_mask=None, int_prob=None, mem_save=True): # , pep_tokens, prot_tokens

        pep_mask = create_key_padding_mask(embeddings=pep_input, padding_value=self.padding_value)
        prot_mask = create_key_padding_mask(embeddings=prot_input, padding_value=self.padding_value)
 
        # Initialize residual states
        pep_emb = pep_input.clone()
        prot_emb = prot_input.clone()
 
        for _ in range(self.num_recycles):

            # Transformer encoding with residual
            pep_trans = self.transformerencoder(self.norm(pep_emb), src_key_padding_mask=pep_mask)
            prot_trans = self.transformerencoder(self.norm(prot_emb), src_key_padding_mask=prot_mask)

            # Cross-attention with residual
            pep_cross, _ = self.cross_attn(query=self.norm(pep_trans), key=self.norm(prot_trans), value=self.norm(prot_trans), key_padding_mask=prot_mask)
            prot_cross, _ = self.cross_attn(query=self.norm(prot_trans), key=self.norm(pep_trans), value=self.norm(pep_trans), key_padding_mask=pep_mask)
            
            # Additive update with residual connection
            pep_emb = pep_emb + pep_cross  
            prot_emb = prot_emb + prot_cross

        pep_seq_coding = create_mean_of_non_masked(pep_emb, pep_mask)
        prot_seq_coding = create_mean_of_non_masked(prot_emb, prot_mask)
        
        # Use self-attention outputs for embeddings
        pep_seq_coding = F.normalize(self.prot_embedder(pep_seq_coding))
        prot_seq_coding = F.normalize(self.prot_embedder(prot_seq_coding))
 
        if mem_save:
            torch.cuda.empty_cache()
        
        scale = torch.exp(self.logit_scale).clamp(max=100.0)
        logits = scale * (pep_seq_coding * prot_seq_coding).sum(dim=-1) # Dot-Product for comparison
        
        return logits

    def training_step(self, batch, device):
        # Passing the sequences to the models
        embedding_pep = batch[0]
        embedding_prot = batch[1]
        binder_label = batch[2]
 
        embedding_pep = embedding_pep.to(device)
        embedding_prot = embedding_prot.to(device)
        binder_label = binder_label.to(device)

        logits = self.forward(embedding_pep, embedding_prot)
        binder_labels = binder_label.view_as(logits)  
        loss = F.binary_cross_entropy_with_logits(logits, binder_labels)
 
        torch.cuda.empty_cache()
        
        return loss
    
    def validation_step(self, batch, device):
        # Predict on random batches of training batch size
        embedding_pep, embedding_prot, binder_label = batch
 
        # Move to device
        embedding_pep  = embedding_pep.to(device)
        embedding_prot = embedding_prot.to(device)
        binder_label = binder_label.to(device).float() 

        with torch.no_grad():
            logits = self.forward(embedding_pep, embedding_prot)   # shape [B]
            binder_labels = binder_label.view_as(logits)  
            loss = F.binary_cross_entropy_with_logits(logits, binder_labels)
    
        return float(loss.item()), logits, binder_labels

In [29]:
## Output path
trained_model_dir = "/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts"

## Embeddings paths
binders_embeddings = "/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/meta_analysis/binders_embeddings"
targets_embeddings = "/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/meta_analysis/targets_embeddings"

# ## Training variables
runID = uuid.uuid4()

def print_mem_consumption():
    # 1. Total memory available on the GPU (device 0)
    t = torch.cuda.get_device_properties(0).total_memory
    # 2. How much memory PyTorch has *reserved* from CUDA
    r = torch.cuda.memory_reserved(0)
    # 3. How much of that reserved memory is actually *used* by tensors
    a = torch.cuda.memory_allocated(0)
    # 4. Reserved but not currently allocated (so “free inside PyTorch’s pool”)
    f = r - a

    print("Total memory: ", t/1e9)      # total VRAM in GB
    print("Reserved memory: ", r/1e9)   # PyTorch’s reserved pool in GB
    print("Allocated memory: ", a//1e9) # actually in use (integer division)
    print("Free memory: ", f/1e9)       # slack in the reserved pool in GB

# print_mem_consumption()

#### Loading data frame

In [30]:
### Loading the dataset
interaction_df = pd.read_csv("../data/meta_analysis/interaction_df_metaanal.csv", index_col = 0).drop(columns = ["binder_id", "target_id"]).rename(columns={
    "A_seq" : "binder_seq",
    "B_seq" : "target_seq"
})

all_targets = interaction_df.target_id_mod.unique()
binder_nonbinder = interaction_df.binder.value_counts()
target_binder_nonbinder_Dict = dict(interaction_df.groupby("target_id_mod")["binder"].value_counts())
sorted_items = sorted(target_binder_nonbinder_Dict.items(), key=lambda kv: kv[1], reverse=True)

# %%
# Annotating each observation with a weight corresponding to whether it is considered a binder or not
N_bins = len(interaction_df["binder"].value_counts())
pr_class_uniform_weight = 1 / N_bins
pr_class_weight_informed_with_size_of_bins = pr_class_uniform_weight  / interaction_df["binder"].value_counts()
pr_class_weight_informed_with_size_of_bins = pr_class_weight_informed_with_size_of_bins.to_dict()
interaction_df["observation_weight"] = interaction_df.binder.apply(lambda x: pr_class_weight_informed_with_size_of_bins[x])
weights_Dict = dict(zip(interaction_df["target_binder_ID"], interaction_df["observation_weight"]))
interaction_df

Unnamed: 0,binder_chain,target_chains,binder,binder_seq,target_seq,target_id_mod,target_binder_ID,observation_weight
0,A,"[""B""]",False,LDFIVFAGPEKAIKFYKEMAKRNLEVKIWIDGDWAVVQVK,ANPYISVANIMLQNYVKQREKYNYDTLKEQFTFIKNASTSIVYMQF...,VirB8,VirB8_1,0.000159
1,A,"[""B""]",False,SEQDETMHRIVRSVIQHAYKHNDEMAEYFAQNAAEIYKEQNKSEEA...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_1,0.000159
2,A,"[""B""]",False,DYKQLKKHATKLLELAKKDPSSKRDLLRTAASYANKVLFEDSDPRA...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_2,0.000159
3,A,"[""B""]",False,DEKEELERRANRVAFLAIQIQNEEYHRILAELYVQFMKAAENNDTE...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_3,0.000159
4,A,"[""B""]",False,PDNKEKLMSIAVQLILRINEAARSEEQWRYANRAAFAAVEASSGSD...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_4,0.000159
...,...,...,...,...,...,...,...,...
3527,A,"[""B""]",False,DLRKYAAELVDRLAEKYNLDSDQYNALVRLASELVWQGKSKEEIEK...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_62,0.000159
3528,A,"[""B""]",False,SKEEIKKEAEELIEELKKKGYNLPLRILEFALKEIEETNSEKYYEQ...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_63,0.000159
3529,A,"[""B""]",False,SPEYKKFLELIKEAEAARKAGDLDKAKELLEKALELAKKMKAKSLI...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_64,0.000159
3530,A,"[""B""]",False,DPLLAYKLLKLSQKALEKAYAEDRERAEELLEEAEAALRSLGDEAG...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_65,0.000159


In [31]:
binders_per_target = dict(interaction_df['target_id_mod'].value_counts())
targets = interaction_df['target_id_mod'].unique().tolist()

# 5-fold CV training + weighting of samples

In [34]:
clusters = [
    ["VirB8"], ["FGFR2"], ["IL7Ra"], ["InsulinR"],
    ["EGFR", "EGFR_2", "EGFR_3"],     # keep together
    ["SARS_CoV2_RBD"], ["Pdl1"], ["TrkA"], ["IL10Ra"],
    ["LTK"], ["Mdm2"],
    ["sntx", "sntx_2"],               # keep together
    ["IL2Ra"],
]

random.Random(0).shuffle(clusters)
folds = np.array_split(np.array(clusters, dtype=object), 5)   # list of np arrays
targets_folds = []
for f in folds:
    flat = []
    for group in f:
        flat.extend(group)
    targets_folds.append(flat)

def build_cv_splits(targets_folds):
    val_folds, train_folds = [], []
    K = len(targets_folds)
    for i in range(K):
        val_targets = list(targets_folds[i])  # copy
        train_targets = [t for j, fold in enumerate(targets_folds) if j != i for t in fold]
        val_folds.append(val_targets)
        train_folds.append(train_targets)
    return val_folds, train_folds

val_folds, train_folds = build_cv_splits(targets_folds)
cv_splits = list(zip(val_folds, train_folds))

In [35]:
for idx in range(len(cv_splits)):
    val_targets = cv_splits[idx][0]
    vals = len(interaction_df[interaction_df.target_id_mod.isin(val_targets)])
    trains = len(interaction_df) - vals
    print(f"Fold {idx+1} : training instances : {trains}, validation instances : {vals}")

Fold 1 : training instances : 1280, validation instances : 2252
Fold 2 : training instances : 3213, validation instances : 319
Fold 3 : training instances : 3265, validation instances : 267
Fold 4 : training instances : 2999, validation instances : 533
Fold 5 : training instances : 3371, validation instances : 161


#### Creating separate targets/ binder dataframes

In [38]:
# Targets df
target_df = interaction_df[["target_id_mod","target_seq"]].rename(columns={"target_seq":"sequence", "target_id_mod" : "ID"})
target_df["seq_len"] = target_df["sequence"].apply(len)
target_df = target_df.drop_duplicates(subset=["ID","sequence"])
target_df = target_df.set_index("ID")

# Binders df
binder_df = interaction_df[["target_binder_ID","binder_seq", "binder"]].rename(columns={"binder_seq":"sequence", "target_binder_ID" : "ID", "binder" : "label"})
binder_df["seq_len"] = binder_df["sequence"].apply(len)
binder_df = binder_df.set_index("ID")
binder_df["observation_weight"] = binder_df.index.map(weights_Dict)

# Interaction Dict
interaction_Dict = dict(enumerate(zip(interaction_df["target_id_mod"], interaction_df["target_binder_ID"]), start=1))

In [37]:
target_df

Unnamed: 0_level_0,sequence,seq_len
ID,Unnamed: 1_level_1,Unnamed: 2_level_1
VirB8,ANPYISVANIMLQNYVKQREKYNYDTLKEQFTFIKNASTSIVYMQF...,138
FGFR2,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,101
IL7Ra,DYSFSCYSQLEVNGSQHSLTCAFEDPDVNTTNLEFEICGALVEVKC...,193
InsulinR,EVCPGMDIRNNLTRLHELENCSVIEGHLQILLMFKTRPEDFRDLSF...,150
EGFR,RKVCNGIGIGEFKDSLSINATNIKHFKNCTSISGDLHILPVAFRGD...,191
SARS_CoV2_RBD,TNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFK...,195
Pdl1,NAFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDK...,115
EGFR_2,LEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYV...,621
TrkA,VSFPASVQLHTAVEMHHWCIPFSVDGQPAPSLRWLFNGSVLNETSF...,101
IL10Ra,GTELPSPPSVWFEAEFFHHILHWTPIPQQSESTCYEVALLRYGIES...,207


In [39]:
binder_df

Unnamed: 0_level_0,sequence,label,seq_len,observation_weight
ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
VirB8_1,LDFIVFAGPEKAIKFYKEMAKRNLEVKIWIDGDWAVVQVK,False,40,0.000159
FGFR2_1,SEQDETMHRIVRSVIQHAYKHNDEMAEYFAQNAAEIYKEQNKSEEA...,False,62,0.000159
FGFR2_2,DYKQLKKHATKLLELAKKDPSSKRDLLRTAASYANKVLFEDSDPRA...,False,61,0.000159
FGFR2_3,DEKEELERRANRVAFLAIQIQNEEYHRILAELYVQFMKAAENNDTE...,False,64,0.000159
FGFR2_4,PDNKEKLMSIAVQLILRINEAARSEEQWRYANRAAFAAVEASSGSD...,False,64,0.000159
...,...,...,...,...
IL2Ra_62,DLRKYAAELVDRLAEKYNLDSDQYNALVRLASELVWQGKSKEEIEK...,False,55,0.000159
IL2Ra_63,SKEEIKKEAEELIEELKKKGYNLPLRILEFALKEIEETNSEKYYEQ...,False,56,0.000159
IL2Ra_64,SPEYKKFLELIKEAEAARKAGDLDKAKELLEKALELAKKMKAKSLI...,False,56,0.000159
IL2Ra_65,DPLLAYKLLKLSQKALEKAYAEDRERAEELLEEAEAALRSLGDEAG...,False,57,0.000159


#### Creating separate targets/ binder dataframes (for validation/ training)

In [40]:
class CLIP_meta_analysis_dataset(Dataset):

    def __init__(self, sequence_df, esm_encoding_paths, embedding_dim=1152, padding_value=-5000):

        super(CLIP_meta_analysis_dataset, self).__init__()

        self.sequence_df = sequence_df # target/binder_df
        self.max_length = sequence_df["seq_len"].max()
        self.sequence_df["index_num"] = np.arange(len(self.sequence_df))
        # print(self.sequence_df)
        self.esm_encoding_paths = esm_encoding_paths
        num_samples = len(self.sequence_df)
        
        self.x = torch.full((num_samples, self.max_length, embedding_dim), padding_value, dtype=torch.float32)

        self.accessions = self.sequence_df.index.astype(str).tolist()
        self.name_to_row = {name: i for i, name in enumerate(self.accessions)}
        
        # Load embeddings into the pre-allocated tensor
        all_acc_list = self.accessions
        all_acc_loaded_list = []

        iterator = tqdm(all_acc_list, position=0, leave=True, total=num_samples, desc="# Reading in ESM-embeddings from folder")
        for i, accession in enumerate(iterator):
            npy_path = os.path.join(esm_encoding_paths, f"{accession}.npy")
            try:
                embd = np.load(npy_path)[0]
                length_to_pad = self.max_length - len(embd)
                if length_to_pad > 0:
                    zero_padding = np.ones((length_to_pad, embd.shape[1])) * padding_value
                    padded_array = np.concatenate((embd, zero_padding), axis=0)
                else:
                    padded_array = embd[:self.max_length] 
                self.x[i] = torch.tensor(padded_array, dtype=torch.float32)
                all_acc_loaded_list.append(accession)
            except FileNotFoundError as e:
                raise FileNotFoundError(f"Embedding file {accession}.npy not found.")
            
        missing = sorted(set(all_acc_list) - set(all_acc_loaded_list))
        if missing:
            raise FileNotFoundError(
                f"Missing {len(missing)} embedding files in '{esm_encoding_paths}'. "
                f"Examples: {missing}")
          
    def __len__(self):
        return int(self.x.shape[0])

    def __getitem__(self, idx):
        return self.x[idx]
    
    # add a helper:
    def get_by_name(self, name: str):
        return self.x[self.name_to_row[name]]

targets_dataset = CLIP_meta_analysis_dataset(target_df, esm_encoding_paths="/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/meta_analysis/targets_embeddings", embedding_dim=1152)
binders_dataset = CLIP_meta_analysis_dataset(binder_df, esm_encoding_paths="/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/meta_analysis/binders_embeddings", embedding_dim=1152)

# targets_dataset_val = CLIP_meta_analysis_dataset(target_df_val, esm_encoding_paths="/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/meta_analysis/targets_embeddings", embedding_dim=1152)
# binders_dataset_val = CLIP_meta_analysis_dataset(binder_df_val, esm_encoding_paths="/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/meta_analysis/binders_embeddings", embedding_dim=1152)
# targets_dataset_train = CLIP_meta_analysis_dataset(target_df_train, esm_encoding_paths="/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/meta_analysis/targets_embeddings", embedding_dim=1152)
# binders_dataset_train = CLIP_meta_analysis_dataset(binder_df_train, esm_encoding_paths="/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/meta_analysis/binders_embeddings", embedding_dim=1152)

# Reading in ESM-embeddings from folder:   0%|          | 0/16 [00:00<?, ?it/s]

# Reading in ESM-embeddings from folder: 100%|██████████| 16/16 [00:00<00:00, 161.21it/s]
# Reading in ESM-embeddings from folder: 100%|██████████| 3532/3532 [00:08<00:00, 421.47it/s]


In [None]:
# Cheking whether index=0 is embedding of the first element of the dataframe 
torch.equal(targets_dataset[0], targets_dataset.get_by_name("VirB8"))

True

In [42]:
torch.equal(binders_dataset[0], binders_dataset.get_by_name("VirB8_1"))

True

In [43]:
def binder_to_target_name(bname: str) -> str:
    parts = bname.split("_")
    if bname.startswith("SARS"):
        return "SARS_CoV2_RBD"
    if len(parts) == 3:
        return f"{parts[0]}_{parts[1]}"
    return parts[0]

def binder_target_label(targets_dataset, binders_dataset, binder_ids, interaction_df, stack=True):
    
    listof_bindertargetlabel = []
    
    for bname in binder_ids:
        tname = binder_to_target_name(bname)

        # get embeddings by name
        b_emb = binders_dataset.get_by_name(bname)
        t_emb = targets_dataset.get_by_name(tname)

        # get label from df
        s = interaction_df.loc[interaction_df['target_binder_ID'] == bname, 'binder']
        # if s.empty:
        #     raise ValueError(f"No label found in interaction_df for binder id '{bname}'")
        lbl = torch.tensor(float(s.iat[0]), dtype=torch.float32)

        listof_bindertargetlabel.append((b_emb, t_emb, lbl))

    return listof_bindertargetlabel

### Loading pretrained model for finetuning

In [19]:
ckpt_path = '../PPI_PLM/models/CLIP_no_structural_information/a1d0549b-3f90-4ce2-b795-7bca2276cb07_checkpoint_4/a1d0549b-3f90-4ce2-b795-7bca2276cb07_checkpoint_epoch_4.pth'
checkpoint = torch.load(ckpt_path, weights_only=False, map_location="cpu")
# print(list(checkpoint["model_state_dict"]))
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = MiniCLIP_w_transformer_crossattn()
model.load_state_dict(checkpoint['model_state_dict'])
torch.cuda.empty_cache()  # frees cached blocks (not live tensors)
device = torch.device("cuda:0")
model.to(device)
# model.train()

RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

### Loading training and validation datasets (DataLoaders)

In [21]:
# validation_data_5clusters = []
# training_data_5clusters = []

# for i, split in enumerate(cv_splits):
#     validation, training = split[0], split[1]
    
#     validation_binders = interaction_df.loc[ interaction_df["target_id_mod"].isin(validation), "target_binder_ID"].tolist()
#     training_binders = interaction_df.loc[interaction_df["target_id_mod"].isin(training), "target_binder_ID"].tolist()

#     listof_bindertargetlabel = binder_target_label(targets_dataset, binders_dataset, validation_binders, interaction_df)
#     validation_data_5clusters.append(listof_bindertargetlabel)
    
#     listof_bindertargetlabel = binder_target_label(targets_dataset, binders_dataset, training_binders, interaction_df)
#     training_data_5clusters.append(listof_bindertargetlabel)

# train_loader = DataLoader(training_data_5clusters[0], batch_size=32, shuffle=True)
# val_loader = DataLoader(validation_data_5clusters[0], batch_size=32, shuffle=True)

# validation_dataset = validation_data_5clusters[0]
# validation_binders = interaction_df.loc[interaction_df["target_id_mod"].isin(cv_splits[0][0]), "target_binder_ID"].tolist()
# validation_row_indices = interaction_df.index[interaction_df["target_id_mod"].isin(cv_splits[0][0])].tolist()

# batch = next(iter(val_loader))
# print(f"Shape of the binders embeddings : {batch[0].shape}")
# print(f"Shape of the targets embeddings : {batch[1].shape}")
# print(f"Labels (0 - non-binder, 1 - binder) :{(batch[2].numpy())}\n")

# with torch.no_grad():
#     batch = next(iter(val_loader))
#     loss, auroc, aupr = model.validation_step(batch,device)
#     print(loss)

# len(validation_row_indices) == len(validation_binders)

In [22]:
import gc, torch
# del obj  # any large temps you created in the cell
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

# print_mem_consumption()
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum(p.numel() for p in model.parameters() if p.requires_grad)

torch.cuda.empty_cache()
print_mem_consumption()

Total memory:  42.405855232
Reserved memory:  0.065011712
Allocated memory:  0.0
Free memory:  0.008080896


### Training loop

In [20]:
def batch(iterable, n=20):
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]

class TrainWrapper_MetaAnal():

    def __init__(self, model, training_loader, validation_loader, test_dataset, 
                 optimizer, EPOCHS, runID, device, test_indexes_for_auROC=None,
                 auROC_batch_size=18, model_save_steps=False, model_save_path=False, 
                 v=False, wandb_tracker=False, split_id=None):
        
        self.model = model 
        self.training_loader = training_loader
        self.validation_loader = validation_loader
        self.EPOCHS = EPOCHS
        self.wandb_tracker = wandb_tracker
        self.model_save_steps = model_save_steps
        self.verbose = v
        self.split_id = split_id
        self.best_vloss = 1e09
        self.optimizer = optimizer
        self.runID = runID
        self.trained_model_dir = model_save_path
        self.print_frequency_loss = 1
        self.device = device
        self.test_indexes_for_auROC = test_indexes_for_auROC
        self.auROC_batch_size = auROC_batch_size
        self.test_dataset = test_dataset

    def train_one_epoch(self):

        self.model.train()
        running_loss = 0.0

        for batch_data in tqdm(self.training_loader, total=len(self.training_loader), desc="Running through epoch"):

            if batch_data[0].size(0) == 1:
                continue
            
            self.optimizer.zero_grad()
            loss = self.model.training_step(batch_data, self.device)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()

        return running_loss / max(1, len(self.training_loader))

    def calc_auroc_aupr_on_indexes(self, model, validation_dataset, batch_size=20, pad_value=-5000.0):

        model.eval()
        all_scores, all_labels = [], []
        batched_data = batch(validation_dataset, n=batch_size)

        with torch.no_grad():
            for one_batch in tqdm(batched_data, total= len(validation_dataset)/batch_size, desc="Calculating AUC"):

                # Gather items: dataset[i] -> (binder_emb [L,D], target_emb [L,D], label)
                items = [i for i in one_batch]
                binders = torch.stack([binder_emb for (binder_emb, _, _) in items]).to(self.device)
                targets = torch.stack([target_emb for (_, target_emb, _) in items]).to(self.device)
                labels = np.array([float(lbl) for *_, lbl in items], dtype=np.float32)

                # Forward: logits per pair [B]
                logits = model.forward(binders, targets).detach().cpu().numpy()

                all_scores.extend(logits.tolist())
                all_labels.extend(labels.tolist())

        all_scores = np.array(all_scores, dtype=np.float64)
        all_labels = np.array(all_labels, dtype=np.int64)

        fpr, tpr, _ = metrics.roc_curve(all_labels, all_scores)
        auroc = metrics.roc_auc_score(all_labels, all_scores)
        aupr  = metrics.average_precision_score(all_labels, all_scores)

        return auroc, aupr, fpr, tpr
    

    def validate(self, dataloader, indexes_for_auc=False, auROC_dataset=False):
        self.model.eval()
        running_loss, n_loss = 0.0, 0
        all_scores, all_labels = [], []
    
        with torch.no_grad():
            for batch_data in tqdm(dataloader, total=len(dataloader), desc="Validation"):
                loss, logits, labels = self.model.validation_step(batch_data, self.device)
                running_loss += float(loss)
                n_loss += 1
    
                # expect logits [B], labels [B] on device; move once
                all_scores.append(logits.detach().float().cpu())
                all_labels.append(labels.detach().long().cpu())
    
        val_loss = running_loss / max(1, n_loss)
    
        if all_scores:
            scores = torch.cat(all_scores).numpy()
            labs   = torch.cat(all_labels).numpy()
            val_auroc = metrics.roc_auc_score(labs, scores)
            val_aupr  = metrics.average_precision_score(labs, scores)
        else:
            val_auroc = float("nan")
            val_aupr  = float("nan")
    
        return val_loss, val_auroc, val_aupr

    def train_model(self):

        if self.verbose:
            print(f"Training model {str(self.runID)}")
        
        # Pre-training snapshot
        val_loss, val_auroc, val_aupr  = self.validate(
            dataloader=self.validation_loader, 
            indexes_for_auc=self.test_indexes_for_auROC, 
            auROC_dataset=self.test_dataset)

        if self.verbose:
            print(
                f'Before training - Val Loss {val_loss:.4f} | '
                f'Val AUROC {val_auroc if val_auroc==val_auroc else float("nan"):.4f} | '
                f'Val AUPR {val_aupr if val_aupr==val_aupr else float("nan"):.4f}'
            )

        # Optional full-set AUROC/AUPR (not per-batch average)
        # auroc_full = aupr_full = None
        # if self.test_indexes_for_auROC is not None:
        #     auroc_full, aupr_full, _, _ = self.calc_auroc_aupr_on_indexes(
        #         self.model, self.test_dataset, batch_size=self.auROC_batch_size
        #     )
            
        if self.wandb_tracker:
            log_items = {
                "Val Loss": val_loss,
                "Val AUROC": val_auroc,
                "Val AUPR": val_aupr,
            }
            self.wandb_tracker.log(log_items)
            
        # --- Epoch loop ---
        for epoch in tqdm(range(1, self.EPOCHS + 1), total=self.EPOCHS, desc="Epochs"):
            torch.cuda.empty_cache()
            train_loss = self.train_one_epoch()
    
            val_loss, val_auroc, val_aupr = self.validate(
                dataloader=self.validation_loader,
                indexes_for_auc=self.test_indexes_for_auROC,
                auROC_dataset=self.test_dataset,
            )

            torch.cuda.empty_cache()
            # Save checkpoints periodically
            if self.model_save_steps and (epoch % self.model_save_steps == 0):
                if epoch in [5, 10, 15]:
                    check_point_folder = os.path.join(self.trained_model_dir, f"{self.split_id}_{str(self.runID)}_epochs_{epoch}")
                    if self.verbose:
                        print("Saving model to:", check_point_folder)
                    os.makedirs(check_point_folder, exist_ok=True)
                    checkpoint_path = os.path.join(
                        check_point_folder, f"{str(self.runID)}_checkpoint_epoch_{epoch}.pth"
                    )
                    torch.save(
                        {
                            'epoch': epoch,
                            'model_state_dict': self.model.state_dict(),
                            'optimizer_state_dict': self.optimizer.state_dict(),
                            'val_loss': val_loss,
                        },
                        checkpoint_path,
                    )
    
            if self.verbose and (epoch % self.print_frequency_loss == 0):
                print(
                    f'EPOCH {epoch} - Train Loss {train_loss:.4f} | '
                    f'Val Loss {val_loss:.4f} | Val AUROC {val_auroc if val_auroc==val_auroc else float("nan"):.4f} | '
                    f'Val AUPR {val_aupr if val_aupr==val_aupr else float("nan"):.4f}'
                )
    
            if self.wandb_tracker:
                log_items = {
                    "Train Loss": train_loss,
                    "Val Loss": val_loss,
                    "Val AUROC": val_auroc,
                    "Val AUPR": val_aupr,
                    "epoch_#": epoch,
                }
                self.wandb_tracker.log(log_items)
    
        if self.wandb_tracker:
            self.wandb_tracker.finish()

In [None]:
batch_size = 20
learning_rate = 2e-5
EPOCHS = 15

all_binders = interaction_df["target_binder_ID"].tolist()

# all dataset: binder_enb, target_emb, label
ALL_btl = binder_target_label(targets_dataset, binders_dataset, all_binders, interaction_df)

# login once (env var preferred)
if use_wandb:
    import wandb
    wandb.login()

for i in range(len(cv_splits)):
    # NEW model per split
    model = MiniCLIP_w_transformer_crossattn()
    optimizer = AdamW(model.parameters(), lr=learning_rate)

    val_targets = cv_splits[i][0]
    val_idx = interaction_df.loc[interaction_df.target_id_mod.isin(val_targets)].index.tolist()
    val_binders = [ALL_btl[idx] for idx in val_idx]

    train_targets = cv_splits[i][1]
    train_idx = interaction_df.loc[interaction_df.target_id_mod.isin(train_targets)].index.tolist()
    train_binders = [ALL_btl[idx] for idx in train_idx]

    # loaders
    train_loader   = DataLoader(train_binders,   batch_size=20, shuffle=True, generator = torch.Generator().manual_seed(0))
    val_loader   = DataLoader(val_binders,   batch_size=20, shuffle=False, drop_last = False)

    # accelerator
    accelerator = Accelerator()
    device = accelerator.device
    model, optimizer, train_loader, val_loader = accelerator.prepare(model, optimizer, train_loader, val_loader)

    # wandb
    if use_wandb:
        run = wandb.init(
            project="CLIP_PPint_metaanalysis",
            name=f"split{i+1}_{runID}",
            group="cv_splits",
            config={"learning_rate": learning_rate, "batch_size": batch_size, "epochs": EPOCHS,
                    "architecture": "MiniCLIP_w_transformer_crossattn", "dataset": "Meta analysis"},
        )
        wandb.watch(accelerator.unwrap_model(model), log="all", log_freq=100)
    else:
        run = None

    # train
    training_wrapper = TrainWrapper_MetaAnal(
        model=model,
        training_loader=train_loader,
        validation_loader=val_loader,
        test_dataset=val_binders,   # ok if you truly want “full val”
        optimizer=optimizer,
        EPOCHS=EPOCHS,
        runID=runID,
        device=device,
        model_save_steps=model_save_steps,
        model_save_path=trained_model_dir,
        v=True,
        wandb_tracker=run,
        split_id=i+1
    )
    training_wrapper.train_model()

    # cleanup between splits
    if use_wandb:
        wandb.finish()
    del training_wrapper, model, optimizer, train_loader, val_loader
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    try:
        accelerator.free_memory()
    except AttributeError:
        pass

Training model 0a93e157-a6d6-4c99-8d95-4ba29724bf42


Validation: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 113/113 [00:11<00:00, 10.15it/s]


Before training - Val Loss 10.2305 | Val AUROC 0.4424 | Val AUPR 0.1063


Epochs:   0%|                                                                                                        | 0/15 [00:00<?, ?it/s]
Running through epoch:   0%|                                                                                         | 0/64 [00:00<?, ?it/s][A
Running through epoch:   2%|█▎                                                                               | 1/64 [00:00<00:19,  3.15it/s][A
Running through epoch:   3%|██▌                                                                              | 2/64 [00:00<00:14,  4.21it/s][A
Running through epoch:   5%|███▊                                                                             | 3/64 [00:00<00:12,  4.69it/s][A
Running through epoch:   6%|█████                                                                            | 4/64 [00:00<00:12,  4.96it/s][A
Running through epoch:   8%|██████▎                                                                          | 5/64 [00:01<00:11,  5.11it/s

EPOCH 1 - Train Loss 1.0058 | Val Loss 0.3225 | Val AUROC 0.6703 | Val AUPR 0.2729



Running through epoch:   0%|                                                                                         | 0/64 [00:00<?, ?it/s][A
Running through epoch:   2%|█▎                                                                               | 1/64 [00:00<00:12,  5.03it/s][A
Running through epoch:   3%|██▌                                                                              | 2/64 [00:00<00:11,  5.31it/s][A
Running through epoch:   5%|███▊                                                                             | 3/64 [00:00<00:11,  5.32it/s][A
Running through epoch:   6%|█████                                                                            | 4/64 [00:00<00:11,  5.38it/s][A
Running through epoch:   8%|██████▎                                                                          | 5/64 [00:00<00:10,  5.42it/s][A
Running through epoch:   9%|███████▌                                                                         | 6/64 [00:01<00:10,  5.46

EPOCH 2 - Train Loss 0.3858 | Val Loss 0.4388 | Val AUROC 0.6524 | Val AUPR 0.2674



Running through epoch:   0%|                                                                                         | 0/64 [00:00<?, ?it/s][A
Running through epoch:   2%|█▎                                                                               | 1/64 [00:00<00:12,  4.90it/s][A
Running through epoch:   3%|██▌                                                                              | 2/64 [00:00<00:11,  5.21it/s][A
Running through epoch:   5%|███▊                                                                             | 3/64 [00:00<00:11,  5.35it/s][A
Running through epoch:   6%|█████                                                                            | 4/64 [00:00<00:11,  5.41it/s][A
Running through epoch:   8%|██████▎                                                                          | 5/64 [00:00<00:10,  5.44it/s][A
Running through epoch:   9%|███████▌                                                                         | 6/64 [00:01<00:10,  5.46

EPOCH 3 - Train Loss 0.3292 | Val Loss 0.3436 | Val AUROC 0.6506 | Val AUPR 0.2422



Running through epoch:   0%|                                                                                         | 0/64 [00:00<?, ?it/s][A
Running through epoch:   2%|█▎                                                                               | 1/64 [00:00<00:12,  4.95it/s][A
Running through epoch:   3%|██▌                                                                              | 2/64 [00:00<00:11,  5.23it/s][A
Running through epoch:   5%|███▊                                                                             | 3/64 [00:00<00:11,  5.34it/s][A
Running through epoch:   6%|█████                                                                            | 4/64 [00:00<00:11,  5.36it/s][A
Running through epoch:   8%|██████▎                                                                          | 5/64 [00:00<00:10,  5.40it/s][A
Running through epoch:   9%|███████▌                                                                         | 6/64 [00:01<00:10,  5.43

EPOCH 4 - Train Loss 0.3204 | Val Loss 0.3452 | Val AUROC 0.6393 | Val AUPR 0.2376



Running through epoch:   0%|                                                                                         | 0/64 [00:00<?, ?it/s][A
Running through epoch:   2%|█▎                                                                               | 1/64 [00:00<00:13,  4.82it/s][A
Running through epoch:   3%|██▌                                                                              | 2/64 [00:00<00:12,  5.16it/s][A
Running through epoch:   5%|███▊                                                                             | 3/64 [00:00<00:11,  5.22it/s][A
Running through epoch:   6%|█████                                                                            | 4/64 [00:00<00:11,  5.30it/s][A
Running through epoch:   8%|██████▎                                                                          | 5/64 [00:00<00:11,  5.31it/s][A
Running through epoch:   9%|███████▌                                                                         | 6/64 [00:01<00:10,  5.33

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/1_0a93e157-a6d6-4c99-8d95-4ba29724bf42_epochs_5


Epochs:  33%|████████████████████████████████                                                                | 5/15 [01:55<03:54, 23.41s/it]

EPOCH 5 - Train Loss 0.3191 | Val Loss 0.3531 | Val AUROC 0.6589 | Val AUPR 0.2641



Running through epoch:   0%|                                                                                         | 0/64 [00:00<?, ?it/s][A
Running through epoch:   2%|█▎                                                                               | 1/64 [00:00<00:12,  4.85it/s][A
Running through epoch:   3%|██▌                                                                              | 2/64 [00:00<00:11,  5.30it/s][A
Running through epoch:   5%|███▊                                                                             | 3/64 [00:00<00:11,  5.35it/s][A
Running through epoch:   6%|█████                                                                            | 4/64 [00:00<00:11,  5.41it/s][A
Running through epoch:   8%|██████▎                                                                          | 5/64 [00:00<00:10,  5.40it/s][A
Running through epoch:   9%|███████▌                                                                         | 6/64 [00:01<00:10,  5.39

EPOCH 6 - Train Loss 0.3015 | Val Loss 0.3348 | Val AUROC 0.6240 | Val AUPR 0.2102



Running through epoch:   0%|                                                                                         | 0/64 [00:00<?, ?it/s][A
Running through epoch:   2%|█▎                                                                               | 1/64 [00:00<00:12,  4.91it/s][A
Running through epoch:   3%|██▌                                                                              | 2/64 [00:00<00:11,  5.24it/s][A
Running through epoch:   5%|███▊                                                                             | 3/64 [00:00<00:11,  5.34it/s][A
Running through epoch:   6%|█████                                                                            | 4/64 [00:00<00:11,  5.38it/s][A
Running through epoch:   8%|██████▎                                                                          | 5/64 [00:00<00:10,  5.42it/s][A
Running through epoch:   9%|███████▌                                                                         | 6/64 [00:01<00:10,  5.44

EPOCH 7 - Train Loss 0.3045 | Val Loss 0.3733 | Val AUROC 0.6370 | Val AUPR 0.2124



Running through epoch:   0%|                                                                                         | 0/64 [00:00<?, ?it/s][A
Running through epoch:   2%|█▎                                                                               | 1/64 [00:00<00:12,  4.87it/s][A
Running through epoch:   3%|██▌                                                                              | 2/64 [00:00<00:11,  5.18it/s][A
Running through epoch:   5%|███▊                                                                             | 3/64 [00:00<00:11,  5.28it/s][A
Running through epoch:   6%|█████                                                                            | 4/64 [00:00<00:11,  5.36it/s][A
Running through epoch:   8%|██████▎                                                                          | 5/64 [00:00<00:10,  5.39it/s][A
Running through epoch:   9%|███████▌                                                                         | 6/64 [00:01<00:10,  5.42

EPOCH 8 - Train Loss 0.2885 | Val Loss 0.3914 | Val AUROC 0.6578 | Val AUPR 0.2113



Running through epoch:   0%|                                                                                         | 0/64 [00:00<?, ?it/s][A
Running through epoch:   2%|█▎                                                                               | 1/64 [00:00<00:12,  4.86it/s][A
Running through epoch:   3%|██▌                                                                              | 2/64 [00:00<00:12,  5.16it/s][A
Running through epoch:   5%|███▊                                                                             | 3/64 [00:00<00:11,  5.26it/s][A
Running through epoch:   6%|█████                                                                            | 4/64 [00:00<00:11,  5.33it/s][A
Running through epoch:   8%|██████▎                                                                          | 5/64 [00:00<00:10,  5.37it/s][A
Running through epoch:   9%|███████▌                                                                         | 6/64 [00:01<00:10,  5.39

EPOCH 9 - Train Loss 0.2807 | Val Loss 0.3473 | Val AUROC 0.6178 | Val AUPR 0.2049



Running through epoch:   0%|                                                                                         | 0/64 [00:00<?, ?it/s][A
Running through epoch:   2%|█▎                                                                               | 1/64 [00:00<00:12,  4.95it/s][A
Running through epoch:   3%|██▌                                                                              | 2/64 [00:00<00:11,  5.25it/s][A
Running through epoch:   5%|███▊                                                                             | 3/64 [00:00<00:11,  5.34it/s][A
Running through epoch:   6%|█████                                                                            | 4/64 [00:00<00:11,  5.35it/s][A
Running through epoch:   8%|██████▎                                                                          | 5/64 [00:00<00:10,  5.39it/s][A
Running through epoch:   9%|███████▌                                                                         | 6/64 [00:01<00:10,  5.39

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/1_0a93e157-a6d6-4c99-8d95-4ba29724bf42_epochs_10


Epochs:  67%|███████████████████████████████████████████████████████████████▎                               | 10/15 [03:51<01:56, 23.39s/it]

EPOCH 10 - Train Loss 0.2988 | Val Loss 0.3232 | Val AUROC 0.6619 | Val AUPR 0.2555



Running through epoch:   0%|                                                                                         | 0/64 [00:00<?, ?it/s][A
Running through epoch:   2%|█▎                                                                               | 1/64 [00:00<00:12,  4.87it/s][A
Running through epoch:   3%|██▌                                                                              | 2/64 [00:00<00:11,  5.21it/s][A
Running through epoch:   5%|███▊                                                                             | 3/64 [00:00<00:11,  5.30it/s][A
Running through epoch:   6%|█████                                                                            | 4/64 [00:00<00:11,  5.33it/s][A
Running through epoch:   8%|██████▎                                                                          | 5/64 [00:00<00:10,  5.38it/s][A
Running through epoch:   9%|███████▌                                                                         | 6/64 [00:01<00:10,  5.42

EPOCH 11 - Train Loss 0.2719 | Val Loss 0.3899 | Val AUROC 0.6655 | Val AUPR 0.2245



Running through epoch:   0%|                                                                                         | 0/64 [00:00<?, ?it/s][A
Running through epoch:   2%|█▎                                                                               | 1/64 [00:00<00:12,  5.16it/s][A
Running through epoch:   3%|██▌                                                                              | 2/64 [00:00<00:11,  5.34it/s][A
Running through epoch:   5%|███▊                                                                             | 3/64 [00:00<00:11,  5.41it/s][A
Running through epoch:   6%|█████                                                                            | 4/64 [00:00<00:11,  5.44it/s][A
Running through epoch:   8%|██████▎                                                                          | 5/64 [00:00<00:10,  5.43it/s][A
Running through epoch:   9%|███████▌                                                                         | 6/64 [00:01<00:10,  5.44

EPOCH 12 - Train Loss 0.2439 | Val Loss 0.4286 | Val AUROC 0.6448 | Val AUPR 0.1817



Running through epoch:   0%|                                                                                         | 0/64 [00:00<?, ?it/s][A
Running through epoch:   2%|█▎                                                                               | 1/64 [00:00<00:12,  4.88it/s][A
Running through epoch:   3%|██▌                                                                              | 2/64 [00:00<00:11,  5.22it/s][A
Running through epoch:   5%|███▊                                                                             | 3/64 [00:00<00:11,  5.35it/s][A
Running through epoch:   6%|█████                                                                            | 4/64 [00:00<00:11,  5.42it/s][A
Running through epoch:   8%|██████▎                                                                          | 5/64 [00:00<00:10,  5.40it/s][A
Running through epoch:   9%|███████▌                                                                         | 6/64 [00:01<00:10,  5.43

EPOCH 13 - Train Loss 0.2539 | Val Loss 0.3486 | Val AUROC 0.6441 | Val AUPR 0.2188



Running through epoch:   0%|                                                                                         | 0/64 [00:00<?, ?it/s][A
Running through epoch:   2%|█▎                                                                               | 1/64 [00:00<00:12,  4.92it/s][A
Running through epoch:   3%|██▌                                                                              | 2/64 [00:00<00:11,  5.23it/s][A
Running through epoch:   5%|███▊                                                                             | 3/64 [00:00<00:11,  5.33it/s][A
Running through epoch:   6%|█████                                                                            | 4/64 [00:00<00:11,  5.39it/s][A
Running through epoch:   8%|██████▎                                                                          | 5/64 [00:00<00:10,  5.44it/s][A
Running through epoch:   9%|███████▌                                                                         | 6/64 [00:01<00:10,  5.47

EPOCH 14 - Train Loss 0.2117 | Val Loss 0.3472 | Val AUROC 0.6509 | Val AUPR 0.1689



Running through epoch:   0%|                                                                                         | 0/64 [00:00<?, ?it/s][A
Running through epoch:   2%|█▎                                                                               | 1/64 [00:00<00:12,  4.91it/s][A
Running through epoch:   3%|██▌                                                                              | 2/64 [00:00<00:11,  5.25it/s][A
Running through epoch:   5%|███▊                                                                             | 3/64 [00:00<00:11,  5.36it/s][A
Running through epoch:   6%|█████                                                                            | 4/64 [00:00<00:11,  5.25it/s][A
Running through epoch:   8%|██████▎                                                                          | 5/64 [00:00<00:11,  5.31it/s][A
Running through epoch:   9%|███████▌                                                                         | 6/64 [00:01<00:10,  5.37

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/1_0a93e157-a6d6-4c99-8d95-4ba29724bf42_epochs_15


Epochs: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [05:46<00:00, 23.07s/it]

EPOCH 15 - Train Loss 0.1948 | Val Loss 0.3760 | Val AUROC 0.6341 | Val AUPR 0.2026





0,1
Train Loss,█▃▂▂▂▂▂▂▂▂▂▁▂▁▁
Val AUPR,▁██▇▇█▅▅▅▅▇▆▄▆▄▅
Val AUROC,▁█▇▇▇█▇▇█▆██▇▇▇▇
Val Loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch_#,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█

0,1
Train Loss,0.19479
Val AUPR,0.20261
Val AUROC,0.63413
Val Loss,0.37598
epoch_#,15.0


Training model 0a93e157-a6d6-4c99-8d95-4ba29724bf42


Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:01<00:00,  9.79it/s]


Before training - Val Loss 8.8336 | Val AUROC 0.4660 | Val AUPR 0.1474


Epochs:   0%|                                                                                                        | 0/15 [00:00<?, ?it/s]
Running through epoch:   0%|                                                                                        | 0/161 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/161 [00:00<00:32,  4.91it/s][A
Running through epoch:   1%|▉                                                                               | 2/161 [00:00<00:30,  5.22it/s][A
Running through epoch:   2%|█▍                                                                              | 3/161 [00:00<00:29,  5.31it/s][A
Running through epoch:   2%|█▉                                                                              | 4/161 [00:00<00:29,  5.38it/s][A
Running through epoch:   3%|██▍                                                                             | 5/161 [00:00<00:29,  5.38it/s

EPOCH 1 - Train Loss 0.5003 | Val Loss 2.7625 | Val AUROC 0.4511 | Val AUPR 0.1437



Running through epoch:   0%|                                                                                        | 0/161 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/161 [00:00<00:32,  4.92it/s][A
Running through epoch:   1%|▉                                                                               | 2/161 [00:00<00:30,  5.22it/s][A
Running through epoch:   2%|█▍                                                                              | 3/161 [00:00<00:29,  5.37it/s][A
Running through epoch:   2%|█▉                                                                              | 4/161 [00:00<00:28,  5.42it/s][A
Running through epoch:   3%|██▍                                                                             | 5/161 [00:00<00:28,  5.44it/s][A
Running through epoch:   4%|██▉                                                                             | 6/161 [00:01<00:28,  5.46

EPOCH 2 - Train Loss 0.3109 | Val Loss 3.0376 | Val AUROC 0.3881 | Val AUPR 0.1309



Running through epoch:   0%|                                                                                        | 0/161 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/161 [00:00<00:32,  4.94it/s][A
Running through epoch:   1%|▉                                                                               | 2/161 [00:00<00:30,  5.24it/s][A
Running through epoch:   2%|█▍                                                                              | 3/161 [00:00<00:29,  5.35it/s][A
Running through epoch:   2%|█▉                                                                              | 4/161 [00:00<00:29,  5.41it/s][A
Running through epoch:   3%|██▍                                                                             | 5/161 [00:00<00:28,  5.42it/s][A
Running through epoch:   4%|██▉                                                                             | 6/161 [00:01<00:28,  5.45

EPOCH 3 - Train Loss 0.3031 | Val Loss 2.8789 | Val AUROC 0.3854 | Val AUPR 0.1304



Running through epoch:   0%|                                                                                        | 0/161 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/161 [00:00<00:32,  4.94it/s][A
Running through epoch:   1%|▉                                                                               | 2/161 [00:00<00:30,  5.23it/s][A
Running through epoch:   2%|█▍                                                                              | 3/161 [00:00<00:29,  5.39it/s][A
Running through epoch:   2%|█▉                                                                              | 4/161 [00:00<00:28,  5.42it/s][A
Running through epoch:   3%|██▍                                                                             | 5/161 [00:00<00:28,  5.43it/s][A
Running through epoch:   4%|██▉                                                                             | 6/161 [00:01<00:28,  5.48

EPOCH 4 - Train Loss 0.3018 | Val Loss 2.9851 | Val AUROC 0.3850 | Val AUPR 0.1302



Running through epoch:   0%|                                                                                        | 0/161 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/161 [00:00<00:32,  4.89it/s][A
Running through epoch:   1%|▉                                                                               | 2/161 [00:00<00:30,  5.21it/s][A
Running through epoch:   2%|█▍                                                                              | 3/161 [00:00<00:29,  5.39it/s][A
Running through epoch:   2%|█▉                                                                              | 4/161 [00:00<00:28,  5.45it/s][A
Running through epoch:   3%|██▍                                                                             | 5/161 [00:00<00:28,  5.47it/s][A
Running through epoch:   4%|██▉                                                                             | 6/161 [00:01<00:28,  5.45

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/2_0a93e157-a6d6-4c99-8d95-4ba29724bf42_epochs_5


Epochs:  33%|████████████████████████████████                                                                | 5/15 [02:36<05:16, 31.63s/it]

EPOCH 5 - Train Loss 0.2882 | Val Loss 2.7379 | Val AUROC 0.3885 | Val AUPR 0.1307



Running through epoch:   0%|                                                                                        | 0/161 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/161 [00:00<00:32,  4.90it/s][A
Running through epoch:   1%|▉                                                                               | 2/161 [00:00<00:30,  5.26it/s][A
Running through epoch:   2%|█▍                                                                              | 3/161 [00:00<00:29,  5.39it/s][A
Running through epoch:   2%|█▉                                                                              | 4/161 [00:00<00:28,  5.47it/s][A
Running through epoch:   3%|██▍                                                                             | 5/161 [00:00<00:28,  5.43it/s][A
Running through epoch:   4%|██▉                                                                             | 6/161 [00:01<00:28,  5.49

EPOCH 6 - Train Loss 0.2874 | Val Loss 2.9520 | Val AUROC 0.3823 | Val AUPR 0.1296



Running through epoch:   0%|                                                                                        | 0/161 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/161 [00:00<00:32,  4.88it/s][A
Running through epoch:   1%|▉                                                                               | 2/161 [00:00<00:30,  5.23it/s][A
Running through epoch:   2%|█▍                                                                              | 3/161 [00:00<00:29,  5.32it/s][A
Running through epoch:   2%|█▉                                                                              | 4/161 [00:00<00:29,  5.37it/s][A
Running through epoch:   3%|██▍                                                                             | 5/161 [00:00<00:28,  5.40it/s][A
Running through epoch:   4%|██▉                                                                             | 6/161 [00:01<00:28,  5.42

EPOCH 7 - Train Loss 0.2753 | Val Loss 3.1113 | Val AUROC 0.3749 | Val AUPR 0.1283



Running through epoch:   0%|                                                                                        | 0/161 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/161 [00:00<00:31,  5.00it/s][A
Running through epoch:   1%|▉                                                                               | 2/161 [00:00<00:30,  5.24it/s][A
Running through epoch:   2%|█▍                                                                              | 3/161 [00:00<00:29,  5.35it/s][A
Running through epoch:   2%|█▉                                                                              | 4/161 [00:00<00:29,  5.41it/s][A
Running through epoch:   3%|██▍                                                                             | 5/161 [00:00<00:28,  5.43it/s][A
Running through epoch:   4%|██▉                                                                             | 6/161 [00:01<00:28,  5.46

EPOCH 8 - Train Loss 0.2646 | Val Loss 3.2358 | Val AUROC 0.3658 | Val AUPR 0.1269



Running through epoch:   0%|                                                                                        | 0/161 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/161 [00:00<00:32,  4.91it/s][A
Running through epoch:   1%|▉                                                                               | 2/161 [00:00<00:30,  5.22it/s][A
Running through epoch:   2%|█▍                                                                              | 3/161 [00:00<00:29,  5.32it/s][A
Running through epoch:   2%|█▉                                                                              | 4/161 [00:00<00:29,  5.39it/s][A
Running through epoch:   3%|██▍                                                                             | 5/161 [00:00<00:28,  5.42it/s][A
Running through epoch:   4%|██▉                                                                             | 6/161 [00:01<00:28,  5.45

EPOCH 9 - Train Loss 0.2486 | Val Loss 3.0444 | Val AUROC 0.3623 | Val AUPR 0.1261



Running through epoch:   0%|                                                                                        | 0/161 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/161 [00:00<00:32,  4.93it/s][A
Running through epoch:   1%|▉                                                                               | 2/161 [00:00<00:30,  5.24it/s][A
Running through epoch:   2%|█▍                                                                              | 3/161 [00:00<00:29,  5.31it/s][A
Running through epoch:   2%|█▉                                                                              | 4/161 [00:00<00:29,  5.37it/s][A
Running through epoch:   3%|██▍                                                                             | 5/161 [00:00<00:28,  5.39it/s][A
Running through epoch:   4%|██▉                                                                             | 6/161 [00:01<00:28,  5.42

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/2_0a93e157-a6d6-4c99-8d95-4ba29724bf42_epochs_10


Epochs:  67%|███████████████████████████████████████████████████████████████▎                               | 10/15 [05:15<02:40, 32.15s/it]

EPOCH 10 - Train Loss 0.2307 | Val Loss 2.5160 | Val AUROC 0.3844 | Val AUPR 0.1305



Running through epoch:   0%|                                                                                        | 0/161 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/161 [00:00<00:32,  4.99it/s][A
Running through epoch:   1%|▉                                                                               | 2/161 [00:00<00:30,  5.27it/s][A
Running through epoch:   2%|█▍                                                                              | 3/161 [00:00<00:29,  5.38it/s][A
Running through epoch:   2%|█▉                                                                              | 4/161 [00:00<00:28,  5.42it/s][A
Running through epoch:   3%|██▍                                                                             | 5/161 [00:00<00:28,  5.43it/s][A
Running through epoch:   4%|██▉                                                                             | 6/161 [00:01<00:28,  5.44

EPOCH 11 - Train Loss 0.2081 | Val Loss 2.6395 | Val AUROC 0.3357 | Val AUPR 0.1219



Running through epoch:   0%|                                                                                        | 0/161 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/161 [00:00<00:32,  4.93it/s][A
Running through epoch:   1%|▉                                                                               | 2/161 [00:00<00:30,  5.23it/s][A
Running through epoch:   2%|█▍                                                                              | 3/161 [00:00<00:29,  5.40it/s][A
Running through epoch:   2%|█▉                                                                              | 4/161 [00:00<00:29,  5.39it/s][A
Running through epoch:   3%|██▍                                                                             | 5/161 [00:00<00:28,  5.49it/s][A
Running through epoch:   4%|██▉                                                                             | 6/161 [00:01<00:28,  5.42

EPOCH 12 - Train Loss 0.1832 | Val Loss 2.4420 | Val AUROC 0.3718 | Val AUPR 0.1278



Running through epoch:   0%|                                                                                        | 0/161 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/161 [00:00<00:32,  4.92it/s][A
Running through epoch:   1%|▉                                                                               | 2/161 [00:00<00:30,  5.25it/s][A
Running through epoch:   2%|█▍                                                                              | 3/161 [00:00<00:29,  5.34it/s][A
Running through epoch:   2%|█▉                                                                              | 4/161 [00:00<00:29,  5.39it/s][A
Running through epoch:   3%|██▍                                                                             | 5/161 [00:00<00:28,  5.44it/s][A
Running through epoch:   4%|██▉                                                                             | 6/161 [00:01<00:28,  5.49

EPOCH 13 - Train Loss 0.1466 | Val Loss 2.2302 | Val AUROC 0.3731 | Val AUPR 0.1280



Running through epoch:   0%|                                                                                        | 0/161 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/161 [00:00<00:32,  4.97it/s][A
Running through epoch:   1%|▉                                                                               | 2/161 [00:00<00:30,  5.24it/s][A
Running through epoch:   2%|█▍                                                                              | 3/161 [00:00<00:29,  5.31it/s][A
Running through epoch:   2%|█▉                                                                              | 4/161 [00:00<00:29,  5.35it/s][A
Running through epoch:   3%|██▍                                                                             | 5/161 [00:00<00:29,  5.36it/s][A
Running through epoch:   4%|██▉                                                                             | 6/161 [00:01<00:28,  5.39

`validation_dataset` / `training_dataset` store embedding of binder, embedding of target and label (binder - 1, non-binder - 0)