In [None]:
import uuid, sys, os
import pandas as pd
import numpy as np
from tqdm import tqdm
import ast
import math
import random

from sklearn import metrics
from scipy import stats
from collections import Counter

os.environ["CUDA_VISIBLE_DEVICES"] = "3"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
torch.cuda.set_device(0)  # 0 == "first visible" -> actually GPU 2 on the node
print(torch.cuda.get_device_name(0))

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, WeightedRandomSampler
import pytorch_lightning as pl
from torch.optim import AdamW

torch.manual_seed(0)

from accelerate import Accelerator
torch.cuda.empty_cache()

Tesla V100-SXM2-32GB


In [2]:
SEED = 0
os.environ["PYTHONHASHSEED"] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [3]:
import wandb
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33ms232958[0m ([33ms232958-danmarks-tekniske-universitet-dtu[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
os.chdir("/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts")
# print(os.getcwd())

print("PyTorch:", torch.__version__)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print("Current location:", os.getcwd())

PyTorch: 2.5.1
Using device: cuda
Current location: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts


In [5]:
# Model parameters
memory_verbose = False
use_wandb = True # Used to track loss in real-time without printing
model_save_steps = 1
train_frac = 1.0
test_frac = 1.0

embedding_dimension = 1280 #| 960 | 1152
number_of_recycles = 2
padding_value = -5000

# batch_size = 20
learning_rate = 2e-5
EPOCHS = 15

In [6]:
def create_key_padding_mask(embeddings, padding_value=-5000, offset=10):
    return (embeddings < (padding_value + offset)).all(dim=-1)

def create_mean_of_non_masked(embeddings, padding_mask):
    # Use masked select and mean to compute the mean of non-masked elements
    # embeddings should be of shape (batch_size, seq_len, features)
    seq_embeddings = []
    for i in range(embeddings.shape[0]): # looping over all batch elements
        non_masked_embeddings = embeddings[i][~padding_mask[i]] # shape [num_real_tokens, features]
        if len(non_masked_embeddings) == 0:
            print("You are masking all positions when creating sequence representation")
            sys.exit(1)
        mean_embedding = non_masked_embeddings.mean(dim=0) # sequence is represented by the single vecotr [1152] [features]
        seq_embeddings.append(mean_embedding)
    return torch.stack(seq_embeddings)

class MiniCLIP_w_transformer_crossattn(pl.LightningModule):

    def __init__(self, padding_value = -5000, embed_dimension=embedding_dimension, num_recycles=2):

        super().__init__()
        self.num_recycles = num_recycles # how many times you iteratively refine embeddings with self- and cross-attention (ALPHA-Fold-style recycling).
        self.padding_value = padding_value
        self.embed_dimension = embed_dimension

        self.logit_scale = nn.Parameter(torch.tensor(math.log(1/0.07)))  # ~CLIP init

        self.transformerencoder =  nn.TransformerEncoderLayer(
            d_model=self.embed_dimension,
            nhead=8,
            dropout=0.1,
            batch_first=True,
            dim_feedforward=self.embed_dimension
            )
 
        self.norm = nn.LayerNorm(self.embed_dimension)  # For residual additions

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=self.embed_dimension,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        self.prot_embedder = nn.Sequential(
            nn.Linear(self.embed_dimension, 640),
            nn.ReLU(),
            nn.Linear(640, 320),
        )
        
    def forward(self, pep_input, prot_input, label=None, pep_int_mask=None, prot_int_mask=None, int_prob=None, mem_save=True): # , pep_tokens, prot_tokens

        pep_mask = create_key_padding_mask(embeddings=pep_input, padding_value=self.padding_value)
        prot_mask = create_key_padding_mask(embeddings=prot_input, padding_value=self.padding_value)
 
        # Initialize residual states
        pep_emb = pep_input.clone()
        prot_emb = prot_input.clone()
 
        for _ in range(self.num_recycles):

            # Transformer encoding with residual
            pep_trans = self.transformerencoder(self.norm(pep_emb), src_key_padding_mask=pep_mask)
            prot_trans = self.transformerencoder(self.norm(prot_emb), src_key_padding_mask=prot_mask)

            # Cross-attention with residual
            pep_cross, _ = self.cross_attn(query=self.norm(pep_trans), key=self.norm(prot_trans), value=self.norm(prot_trans), key_padding_mask=prot_mask)
            prot_cross, _ = self.cross_attn(query=self.norm(prot_trans), key=self.norm(pep_trans), value=self.norm(pep_trans), key_padding_mask=pep_mask)
            
            # Additive update with residual connection
            pep_emb = pep_emb + pep_cross  
            prot_emb = prot_emb + prot_cross

        pep_seq_coding = create_mean_of_non_masked(pep_emb, pep_mask)
        prot_seq_coding = create_mean_of_non_masked(prot_emb, prot_mask)
        
        # Use self-attention outputs for embeddings
        pep_seq_coding = F.normalize(self.prot_embedder(pep_seq_coding))
        prot_seq_coding = F.normalize(self.prot_embedder(prot_seq_coding))
 
        if mem_save:
            torch.cuda.empty_cache()
        
        scale = torch.exp(self.logit_scale).clamp(max=100.0)
        logits = scale * (pep_seq_coding * prot_seq_coding).sum(dim=-1)
        
        return logits

    def training_step(self, batch, device):
        embedding_pep, embedding_prot = batch
        embedding_pep, embedding_prot = embedding_pep.to(device), embedding_prot.to(device)
        
        positive_logits = self.forward(embedding_pep, embedding_prot)
        
        # Negative indexes
        rows, cols = torch.triu_indices(embedding_prot.size(0), embedding_prot.size(0), offset=1)         
        
        negative_logits = self(embedding_pep[rows,:,:], 
                          embedding_prot[cols,:,:], 
                          int_prob=0.0)

        # loss of predicting partner using peptide
        positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits).to(device))
 
        # loss of predicting peptide using partner
        negative_loss =  F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits).to(device))
        
        loss = (positive_loss + negative_loss) / 2
 
        # del partner_prediction_loss, peptide_prediction_loss, embedding_pep, embedding_prot
        torch.cuda.empty_cache()
        return loss
    
    def validation_step(self, batch, device):
        # Predict on random batches of training batch size
        embedding_pep, embedding_prot = batch
        embedding_pep, embedding_prot = embedding_pep.to(device), embedding_prot.to(device)
        
        with torch.no_grad():

            positive_logits = self(embedding_pep, embedding_prot)
            
            # loss of predicting partner using peptide
            positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits).to(device))
            
            # Negaive indexes
            rows, cols = torch.triu_indices(embedding_prot.size(0), embedding_prot.size(0), offset=1)
            
            negative_logits = self(embedding_pep[rows,:,:], embedding_prot[cols,:,:], int_prob=0.0)
    
            negative_loss =  F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits).to(device))

            loss = (positive_loss + negative_loss) / 2
           
            logit_matrix = torch.zeros((embedding_pep.size(0),embedding_pep.size(0)),device=self.device)
            logit_matrix[rows, cols] = negative_logits
            logit_matrix[cols, rows] = negative_logits
            
            # Fill diagonal with positive scores
            diag_indices = torch.arange(embedding_pep.size(0), device=self.device)
            logit_matrix[diag_indices, diag_indices] = positive_logits.squeeze()

            labels = torch.arange(embedding_prot.size(0)).to(self.device)
            peptide_predictions = logit_matrix.argmax(dim=0)
            peptide_ranks = logit_matrix.argsort(dim=0).diag() + 1
            peptide_mrr = (peptide_ranks).float().pow(-1).mean()
            
            # partner_accuracy = partner_predictions.eq(labels).float().mean()
            peptide_accuracy = peptide_predictions.eq(labels).float().mean()
    
            k = 3
            peptide_topk_accuracy = torch.any((logit_matrix.topk(k, dim=0).indices - labels.reshape(1, -1)) == 0, dim=0).sum() / logit_matrix.shape[0]
    
            del logit_matrix,positive_logits,negative_logits,embedding_pep,embedding_prot

            return loss, peptide_accuracy, peptide_topk_accuracy

    def calculate_logit_matrix(self,embedding_pep,embedding_prot):
        rows, cols = torch.triu_indices(embedding_pep.size(0), embedding_pep.size(0), offset=1)
        
        positive_logits = self(embedding_pep, embedding_prot)
        negative_logits = self(embedding_pep[rows,:,:], embedding_prot[cols,:,:], int_prob=0.0)
        
        logit_matrix = torch.zeros((embedding_pep.size(0),embedding_pep.size(0)),device=self.device)
        logit_matrix[rows, cols] = negative_logits
        logit_matrix[cols, rows] = negative_logits
        
        diag_indices = torch.arange(embedding_pep.size(0), device=self.device)
        logit_matrix[diag_indices, diag_indices] = positive_logits.squeeze()
        
        return logit_matrix

In [7]:
## Output path
trained_model_dir = "/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts"

## Embeddings paths
binders_embeddings = "/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/PPint_DB/binders_embeddings_esm2"
targets_embeddings = "/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/PPint_DB/targets_embeddings_esm2"

# ## Training variables
runID = uuid.uuid4()

def print_mem_consumption():
    # 1. Total memory available on the GPU (device 0)
    t = torch.cuda.get_device_properties(0).total_memory
    # 2. How much memory PyTorch has *reserved* from CUDA
    r = torch.cuda.memory_reserved(0)
    # 3. How much of that reserved memory is actually *used* by tensors
    a = torch.cuda.memory_allocated(0)
    # 4. Reserved but not currently allocated (so “free inside PyTorch’s pool”)
    f = r - a

    print("Total memory: ", t/1e9)      # total VRAM in GB
    print("Reserved memory: ", r/1e9)   # PyTorch’s reserved pool in GB
    print("Allocated memory: ", a//1e9) # actually in use (integer division)
    print("Free memory: ", f/1e9)       # slack in the reserved pool in GB
print_mem_consumption()

Total memory:  34.072559616
Reserved memory:  0.0
Allocated memory:  0.0
Free memory:  0.0


### Loading PPint dataframe

In [8]:
PPint_interaactions_df = pd.read_csv("/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/PPint_DB/PPint_interactions.csv")
PPint_interaactions_df["seq_target_len"] = [len(seq) for seq in PPint_interaactions_df["seq_target"].tolist()]
PPint_interaactions_df["seq_binder_len"] = [len(seq) for seq in PPint_interaactions_df["seq_binder"].tolist()]
# PPint_interaactions_df["index"] = [i for i in PPint_interaactions_df["target_binder_id"].tolist()]
PPint_interaactions_df

Unnamed: 0,seq_target,seq_binder,target_id,binder_id,target_binder_id,seq_target_len,seq_binder_len
0,SLTKTERTIIVSMWAKISTQADTIGTETLERLFLSHPQTKTYFPHF...,VHLTDAEKAAVSGLWGKVNADEVGGEALGRLLVVYPWTQRYFDSFG...,1JEB_2,1JEB_2,1JEB_2_1JEB_2,141,146
1,TTTLAFKFQHGVIAAVDSRASAGSYISALRVNKVIEINPYLLGTMS...,DRGVNTFSPEGRLFQVEYAIEAIKLGSTAIGIQTSEGVCLAVEKRI...,7B12_23,7B12_23,7B12_23_7B12_23,203,230
2,ITHLPPEVMLSIFSYLNPQELCRCSQVSMKWSQLTKTGSLWKHLYP...,PSIKLQSSDGEIFEVDVEIAKQSVTIKTMLEDLGDPVPLPNVNAAI...,6VCD_1,6VCD_1,6VCD_1_6VCD_1,255,135
3,NAKDVLGLTLLEKTLKERLNLKDAIIVSGDSDQSPWVKKEGRAAVA...,AKDVLGLTLLEKTLKERLNLKDAIIVSGDSDQSPWVKKEGRAAVAC...,2OKG_0,2OKG_0,2OKG_0_2OKG_0,241,243
4,DIVMSQSPSSLAVSVGEKVTMSCKSSQSLLYNNNQKNYLAWYQQKP...,VTLKESGPGILQPSQTLSLTCSFSGFSLSTYGMGVGWIRQPSGKGL...,3MBX_0,3MBX_0,3MBX_0_3MBX_0,220,229
...,...,...,...,...,...,...,...
2467,REIPLKVLVKAVLFACMLMRKTMASRVRVTILFATETGKSEALAWD...,QLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAEL...,3HR4_0,3HR4_0,3HR4_0_3HR4_0,189,145
2468,NKYRFIDVQPLTGVLGAEITGVDLREPLDDSTWNEILDAFHTYQVI...,KYRFIDVQPLTGVLGAEITGVDLREPLDDSTWNEILDAFHTYQVIY...,6D3M_0,6D3M_0,6D3M_0_6D3M_0,286,285
2469,TMKIAYLGPSGSFTHNVALHAFPAADLLPFENITEVIKAYESKQVC...,TMKIAYLGPSGSFTHNVALHAFPAADLLPFENITEVIKAYESKQVC...,4LUB_0,4LUB_0,4LUB_0_4LUB_0,188,190
2470,ERDEVGARKNAVDEEIERLSQPGDQRLNALAERFGGVLLSEIYDDV...,EPVTIVLSQGWVRSAKGHDIDAPGLNYKAGDSFKAAVKGKSNQPVV...,4MN4_2,4MN4_2,4MN4_2_4MN4_2,154,236


In [9]:
Df_val = PPint_interaactions_df.sample(n=round(len(PPint_interaactions_df) * 0.2), random_state=0)
Df_train = PPint_interaactions_df.drop(Df_val.index)
Df_val

Unnamed: 0,seq_target,seq_binder,target_id,binder_id,target_binder_id,seq_target_len,seq_binder_len
805,SLRFALTPGEPAGIGPDLCLLLARSAQPHPLIAIASRTLLQERAGQ...,SLRFALTPGEPAGIGPDLCLLLARSAQPHPLIAIASRTLLQERAGQ...,1YXO_0,1YXO_0,1YXO_0_1YXO_0,322,322
1675,SIKIECVLPENCRCGESPVWEEVSNSLLFVDIPAKKVCRWDSFTKQ...,SIKIECVLPENCRCGESPVWEEVSNSLLFVDIPAKKVCRWDSFTKQ...,3G4H_0,3G4H_0,3G4H_0_3G4H_0,297,297
1735,SLEESGGDLVKPGASLTLTCTASGFSFGWNDYMSWVRQAPGKGLEW...,IKMTQTPSSVSAAVGGTVTVNCRASEDIESYLAWYQQKPGQPPKLL...,6PEH_0,6PEH_0,6PEH_0_6PEH_0,219,213
1597,ALLQKTRIINSLQAAAGKPVNFKEAETLRDVIDSNIFVVSRRGKLL...,ALLQKTRIINSLQAAAGKPVNFKEAETLRDVIDSNIFVVSRRGKLL...,5LNH_0,5LNH_0,5LNH_0_5LNH_0,248,248
1188,NSVSVDLPGSMKVLVSKSSNADGKYDLIATVDALELSGTSDKNNGS...,KNSVSVDLPGSMKVLVSKSSNADGKYDLIATVDALELSGTSDKNNG...,3CKA_0,3CKA_0,3CKA_0_3CKA_0,315,316
...,...,...,...,...,...,...,...
2368,TANREAIDMARVAAGAAAAKLADDVVVIDVSGQLVITDCFVIASGS...,TANREAIDMARVAAGAAAAKLADDVVVIDVSGQLVITDCFVIASGS...,4WCW_0,4WCW_0,4WCW_0_4WCW_0,116,112
634,GNRGVVYLGPGKVEVQNIPYPKMQDPQGRQIDHGVILRVVSTNICG...,GNRGVVYLGPGKVEVQNIPYPKMQDPQGRQIDHGVILRVVSTNICG...,4JLW_2,4JLW_2,4JLW_2_4JLW_2,395,395
1585,SKPMEVYVSAVASPTKFWVQLIGPQSKKLASMVQEMTSYYSSAENR...,DQGRGRRPLN,5YGD_0,5YGD_0,5YGD_0_5YGD_0,210,10
926,QVQLVQSGAEVKKPGSSLKVSCKVSGGNLRSYGISWVRQAPGQGLE...,QSALTQPRSVSGSPGQSVTISCTGTSNDVGYYDHVSWYQQHPGKAP...,6P9I_0,6P9I_0,6P9I_0_6P9I_0,226,212


In [10]:
Df_val.loc[1735].target_binder_id

'6PEH_0_6PEH_0'

In [11]:
Df_val["dimer"] = Df_val["seq_target"] == Df_val["seq_binder"]
indices_non_dimers_val = Df_val[~Df_val["dimer"]].index.tolist()
indices_non_dimers_val[:5]

[1735, 1188, 789, 1261, 259]

In [12]:
### checking shape (size) of the encodings for later
fname = os.listdir("/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/PPint_DB/binders_embeddings_esm2")[153]
np.load(os.path.join("/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/PPint_DB/binders_embeddings_esm2", fname)).shape[1]

1280

In [13]:
class CLIP_PPint_analysis_dataset(Dataset):
    def __init__(self, dframe, tpath, bpath, embedding_dim=1280, padding_value=-5000.0):
        super().__init__()

        self.dframe = dframe.copy()
        self.max_tlen = int(self.dframe["seq_target_len"].max())
        self.max_blen = int(self.dframe["seq_binder_len"].max())
        self.encoding_tpath = tpath
        self.encoding_bpath = bpath
        self.dframe.set_index("target_binder_id", inplace=True)
        self.accessions = self.dframe.index.astype(str).tolist()
        self.name_to_row = {name: i for i, name in enumerate(self.accessions)} # 1YXO_0_1YXO_0 : 1, ...
        self.samples = []

        iterator = tqdm(self.accessions, position=0, total=len(self.accessions), desc="#Loading ESM2 embeddings")

        for accession in iterator:
            parts = accession.split("_")
            
            if len(parts) < 4:
                raise ValueError(
                    f"Expected target_binder_id to have at least 4 underscore-separated parts, got {accession}")

            target_id = parts[0] + "_" + parts[1]
            binder_id = parts[2] + "_" + parts[3]

            tname = f"t_{target_id}"
            bname = f"b_{binder_id}"

            tnpy_path = os.path.join(self.encoding_tpath, f"{tname}.npy")
            bnpy_path = os.path.join(self.encoding_bpath, f"{bname}.npy")

            if not os.path.exists(tnpy_path):
                raise FileNotFoundError(f"Missing target embedding file: {tnpy_path}")
            if not os.path.exists(bnpy_path):
                raise FileNotFoundError(f"Missing binder embedding file: {bnpy_path}")

            tembd = np.load(tnpy_path)
            if tembd.shape[0] < self.max_tlen:
                t_pad_len = self.max_tlen - tembd.shape[0]
                t_pad = np.full((t_pad_len, tembd.shape[1]), padding_value, dtype=tembd.dtype)
                t_final = np.concatenate([tembd, t_pad], axis=0)
            else:
                t_final = tembd[: self.max_tlen]

            bembd = np.load(bnpy_path)
            if bembd.shape[0] < self.max_blen:
                b_pad_len = self.max_blen - bembd.shape[0]
                b_pad = np.full((b_pad_len, bembd.shape[1]), padding_value, dtype=bembd.dtype)
                b_final = np.concatenate([bembd, b_pad], axis=0)
            else:
                b_final = bembd[: self.max_blen]

            # label = 1
            # self.samples.append((t_final, b_final, label))
            self.samples.append((t_final, b_final))


    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        t_arr, b_arr = self.samples[idx]
        # t_arr, b_arr, y = self.samples[idx]
        t_tensor = torch.tensor(t_arr, dtype=torch.float32)
        b_tensor = torch.tensor(b_arr, dtype=torch.float32)
        # y_tensor = torch.tensor(y, dtype=torch.float32)
        return  b_tensor, t_tensor #, y_tensor

    def _get_by_name(self, name):
        
        # only one index parsed
        if isinstance(name, str):
            idx = self.name_to_row[name]
            binder_tensor, target_tensor = self.__getitem__(idx)
            
            return binder_tensor, target_tensor

        # parsing list of indeces
        binder_list = []
        target_list = []
        for n in name:
            idx = self.name_to_row[n]
            b_tensor, t_tensor = self.__getitem__(idx)
            binder_list.append(b_tensor)
            target_list.append(t_tensor)

        # stack along batch dim = 0
        binder_batch = torch.stack(binder_list, dim=0)  # [B, max_blen, embedding_dim]
        target_batch = torch.stack(target_list, dim=0)  # [B, max_tlen, embedding_dim]

        return binder_batch, target_batch

targets_path ="/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/PPint_DB/targets_embeddings_esm2"
binders_path ="/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/PPint_DB/binders_embeddings_esm2"

training_Dataset = CLIP_PPint_analysis_dataset(Df_train, tpath=targets_path, bpath=binders_path, embedding_dim=1280)
validation_Dataset = CLIP_PPint_analysis_dataset(Df_val, tpath=targets_path, bpath=binders_path, embedding_dim=1280)

#Loading ESM2 embeddings: 100%|████████████████████████████████████████████████████████████████████████| 1978/1978 [00:09<00:00, 200.42it/s]
#Loading ESM2 embeddings: 100%|██████████████████████████████████████████████████████████████████████████| 494/494 [00:02<00:00, 192.89it/s]


In [14]:
accessions = [Df_val.loc[index].target_binder_id for index in indices_non_dimers_val]
validation_Dataset._get_by_name(accessions[:5])

(tensor([[[ 8.9161e-02,  1.1563e-02,  7.2288e-03,  ..., -2.2233e-01,
            2.1028e-01,  4.1612e-03],
          [ 1.1533e-01, -9.3344e-02, -8.4391e-02,  ...,  2.6825e-01,
            4.0479e-01, -3.6455e-01],
          [ 1.4608e-01, -1.0084e-01, -1.7621e-01,  ...,  1.7608e-01,
           -5.9677e-02, -7.8247e-02],
          ...,
          [-5.0000e+03, -5.0000e+03, -5.0000e+03,  ..., -5.0000e+03,
           -5.0000e+03, -5.0000e+03],
          [-5.0000e+03, -5.0000e+03, -5.0000e+03,  ..., -5.0000e+03,
           -5.0000e+03, -5.0000e+03],
          [-5.0000e+03, -5.0000e+03, -5.0000e+03,  ..., -5.0000e+03,
           -5.0000e+03, -5.0000e+03]],
 
         [[ 2.2493e-02,  3.1844e-03, -3.4938e-02,  ..., -2.1519e-01,
            1.8567e-01, -6.6455e-02],
          [-4.8609e-02, -1.8671e-02, -2.4903e-01,  ...,  6.6610e-02,
           -1.0414e-01, -6.7875e-02],
          [ 1.6492e-01, -1.3839e-02,  7.5602e-02,  ...,  5.7252e-03,
            3.0333e-02, -3.5512e-02],
          ...,
    

### Train model from scratch with 10% of PPint dataset using old architecture (encodings only)

In [15]:
model = MiniCLIP_w_transformer_crossattn(embed_dimension=embedding_dimension, num_recycles=number_of_recycles)
model.to("cuda")

MiniCLIP_w_transformer_crossattn(
  (transformerencoder): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=1280, out_features=1280, bias=True)
    )
    (linear1): Linear(in_features=1280, out_features=1280, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=1280, out_features=1280, bias=True)
    (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  (cross_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=1280, out_features=1280, bias=True)
  )
  (prot_embedder): Sequential(
    (0): Linear(in_features=1280, out_features=640, bias=True)
    (1): ReLU()
    (2): Linear(in_features=640, out_features=32

### Trianing loop

In [16]:
def batch(iterable, n=1):
    """Takes any indexable iterable (e.g., a list of observation IDs) and yields contiguous slices of length n."""
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]

class TrainWrapper():

    def __init__(self, 
                 model, 
                 training_loader, 
                 validation_loader,
                 test_df,
                 test_dataset, 
                 optimizer, 
                 EPOCHS, 
                 runID, 
                 device, 
                 test_indexes_for_auROC = None,
                 auROC_batch_size=10, 
                 model_save_steps=False, 
                 model_save_path=False, 
                 v=False, 
                 wandb_tracker=False):
        
        self.model = model 
        self.training_loader = training_loader
        self.validation_loader = validation_loader
        self.EPOCHS = EPOCHS
        self.wandb_tracker = wandb_tracker
        self.model_save_steps = model_save_steps # if truthy (e.g., 1, 5), save a checkpoint every N epochs.
        self.verbose = v
        self.best_vloss = 1_000_000
        self.optimizer = optimizer
        self.runID = runID
        self.trained_model_dir = model_save_path
        self.print_frequency_loss = 1
        self.device = device
        self.test_indexes_for_auROC = test_indexes_for_auROC
        self.auROC_batch_size = auROC_batch_size # which observation indices to use when computing auROC/auPR periodically.
        self.test_dataset = test_dataset
        self.test_df = test_df

    def train_one_epoch(self):

        self.model.train() 
        running_loss = 0 

        for batch in tqdm(self.training_loader, total=len(self.training_loader), desc="Running through epoch"):
            
            if batch[0].size(0) == 1: 
                continue
            
            self.optimizer.zero_grad() 
            loss = self.model.training_step(batch, self.device) 
            loss.backward() 
            self.optimizer.step() 

            torch.cuda.empty_cache()
            running_loss += loss.item()
            
        return running_loss / len(self.training_loader)
    
    def calc_auroc_aupr_on_indexes(self, model, dataset, test_df, nondimer_indexes, batch_size = 10):

        self.model.eval()
        all_TP_scores, all_FP_scores = [], []
        accessions = [test_df.loc[index].target_binder_id for index in nondimer_indexes]
        batches_local = batch(accessions, n=batch_size)
        
        with torch.no_grad():
            for index_batch in tqdm(batches_local, total=int(len(accessions)/batch_size), desc="Calculating AUC"):

                # Loading the data into the GPU based on the index batch
                binder_emb, target_emb = dataset._get_by_name(index_batch)
                binder_emb, target_emb = binder_emb.to(self.device), target_emb.to(self.device)

                logit_matrix = self.model.calculate_logit_matrix(binder_emb, target_emb)
                
                TP_scores = logit_matrix.diag().detach().cpu().tolist()
                all_TP_scores += TP_scores
                
                # Get FP scores from upper triangle (excluding diagonal)
                n = logit_matrix.size(0)
                rows, cols = torch.triu_indices(n, n, offset=1)
                FP_scores = logit_matrix[rows, cols].detach().cpu().tolist()
                all_FP_scores += FP_scores
            
        # Calculate scores and labels
        all_score_predictions = np.array(all_TP_scores + all_FP_scores)
        all_labels = np.array([1]*len(all_TP_scores) + [0]*len(all_FP_scores))
                
        # Calculate ROC curve metrics
        fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_score_predictions)
        auroc = metrics.roc_auc_score(all_labels, all_score_predictions)
        aupr = metrics.average_precision_score(all_labels, all_score_predictions)
        
        return auroc, aupr, all_TP_scores, all_FP_scores


    def validate (self, dataloader, indexes_for_auc=False, auROC_dataset=False, auROC_df = False):
        self.model.eval()
        running_loss = 0.0
        running_accuracy = 0.0
        running_topk_accuracy = 0.0

        with torch.no_grad():
            for batch in tqdm(dataloader, total=len(dataloader), desc="First Validation run"):
                if batch[0].size(0) == 1: # We can't make negatives on a batch of 1
                    continue
                loss, partner_accuracy, peptide_topk_accuracy = self.model.validation_step(batch, self.device)
                running_loss += loss.item()
                running_accuracy += partner_accuracy.item()
                running_topk_accuracy += peptide_topk_accuracy.item()
                
            val_loss = running_loss / len(dataloader)
            val_accuracy = running_accuracy / len(dataloader)
            val_topk_accuracy = running_topk_accuracy / len(dataloader)

            if indexes_for_auc:
                # Calculating auc-scores
                non_dimer_auc, non_dimer_aupr, ___, ___  = self.calc_auroc_aupr_on_indexes(model=self.model, 
                                                                                           dataset=auROC_dataset,
                                                                                           test_df=auROC_df,
                                                                                           nondimer_indexes=indexes_for_auc,
                                                                                           batch_size=self.auROC_batch_size)
                
                return val_loss, val_accuracy, val_topk_accuracy, non_dimer_auc, non_dimer_aupr
            
            else:
                return val_loss, val_accuracy, val_topk_accuracy

    def train_model(self):

        if self.verbose:
            print(f"Training model {str(self.runID)}")
        
        # --- initial validation before training
        if self.test_indexes_for_auROC:
            val_loss_before, val_accuracy_before, val_topk_accuracy_before, val_nondimer_auc_before, val_nondimer_aupr_before = self.validate(
                dataloader=self.validation_loader,
                indexes_for_auc=self.test_indexes_for_auROC,
                auROC_dataset=self.test_dataset,
                auROC_df=self.test_df
            )
        else:
            val_loss_before, val_accuracy_before, val_topk_accuracy_before = self.validate(
                dataloader=self.validation_loader,
                indexes_for_auc=self.test_indexes_for_auROC,
                auROC_dataset=self.test_dataset
            )
            val_nondimer_auc_before = None
            val_nondimer_aupr_before = None
        
        if self.verbose: 
            if val_nondimer_auc_before is not None:
                print(f'Before training - Val CLIP-loss {round(val_loss_before,4)}',
                      f'Accuracy: {round(val_accuracy_before,4)}',
                      f'Top-K accuracy : {round(val_topk_accuracy_before,4)}',
                      f'auc: {round(val_nondimer_auc_before,3)}')
            else:
                print(f'Before training - Val CLIP-loss {round(val_loss_before,4)}',
                      f'Accuracy: {round(val_accuracy_before,4)}',
                      f'Top-K accuracy : {round(val_topk_accuracy_before,4)}')

        if self.wandb_tracker:
            metrics_to_log = {
                "Val-loss": val_loss_before,
                "Val-acc": val_accuracy_before,
                "Val-TOPK-acc": val_topk_accuracy_before,
            }
            if val_nondimer_auc_before is not None:
                metrics_to_log["Val non-dimer auc"] = val_nondimer_auc_before
                metrics_to_log["Val non-dimer auPR"] = val_nondimer_aupr_before
            self.wandb_tracker.log(metrics_to_log)
        
        # --- training loop
        for epoch in tqdm(range(1, self.EPOCHS + 1), total=self.EPOCHS, desc="Epochs"):
            
            train_loss = self.train_one_epoch()
            
            # validation after epoch
            if self.test_indexes_for_auROC:
                val_loss, val_accuracy, val_topk_accuracy, non_dimer_auc, non_dimer_aupr = self.validate(
                    dataloader=self.validation_loader,
                    indexes_for_auc=self.test_indexes_for_auROC,
                    auROC_dataset=self.test_dataset,
                    auROC_df=self.test_df
                )
            else:
                val_loss, val_accuracy, val_topk_accuracy = self.validate(
                    dataloader=self.validation_loader,
                    indexes_for_auc=self.test_indexes_for_auROC,
                    auROC_dataset=self.test_dataset,
                    auROC_df=self.test_df
                )
                non_dimer_auc = None
                non_dimer_aupr = None
            
            # checkpoint save
            if self.model_save_steps:
                if epoch % self.model_save_steps == 0:
                    check_point_folder = os.path.join(
                        self.trained_model_dir,
                        f"{str(self.runID)}_checkpoint_{str(epoch)}"
                    )
                    if self.verbose:
                        print("Saving model to:", check_point_folder)
                    
                    if not os.path.exists(check_point_folder):
                        os.makedirs(check_point_folder)

                    checkpoint_path = os.path.join(
                        check_point_folder,
                        f"{str(self.runID)}_checkpoint_epoch_{str(epoch)}.pth"
                    )
                    torch.save(
                        {
                            'epoch': epoch,
                            'model_state_dict': self.model.state_dict(),
                            'optimizer_state_dict': self.optimizer.state_dict(),
                            'val_loss': val_loss
                        }, 
                        checkpoint_path
                    )
            
            # console logging
            if self.verbose:
                if epoch % self.print_frequency_loss == 0:
                    if non_dimer_auc is not None:
                        print(f'EPOCH {epoch} -  Val loss {round(val_loss,4)}',
                              f'Accuracy: {round(val_accuracy,4)}',
                              f'Top-K accuracy: {round(val_topk_accuracy,4)}',
                              f'Val-Auc:{round(non_dimer_auc,3)}', 
                              f'Val-auPR:{round(non_dimer_aupr,3)}')
                    else:
                        print(f'EPOCH {epoch} -  Val loss {round(val_loss,4)}',
                              f'Accuracy: {round(val_accuracy,4)}',
                              f'Top-K accuracy: {round(val_topk_accuracy,4)}')
            
            # wandb logging
            if self.wandb_tracker:
                metrics_to_log_epoch = {
                    "Epoch": epoch,
                    "Train-loss": train_loss,
                    "Val-loss": val_loss,
                    "Val-acc": val_accuracy,
                    "Val-TOPK-acc": val_topk_accuracy,
                }
                if non_dimer_auc is not None:
                    metrics_to_log_epoch["Val non-dimer auc"] = non_dimer_auc
                    metrics_to_log_epoch["Val non-dimer auPR"] = non_dimer_aupr
                self.wandb_tracker.log(metrics_to_log_epoch)

        if self.wandb_tracker:
            self.wandb_tracker.finish()

In [17]:
batch_size = 10
learning_rate = 2e-5
EPOCHS = 15
g = torch.Generator().manual_seed(SEED)

# login once (env var preferred)
if use_wandb:
    import wandb
    wandb.login()

optimizer = AdamW(model.parameters(), lr=learning_rate)
train_dataloader = DataLoader(training_Dataset, batch_size=10)
val_dataloader = DataLoader(validation_Dataset, batch_size=20, shuffle=False, drop_last = False)

# accelerator
accelerator = Accelerator()
device = accelerator.device
model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(model, optimizer, train_dataloader, val_dataloader)

# wandb
if use_wandb:
    run = wandb.init(
        project="PPint_retrain_w_10percent_ofdata",
        name=f"PPint_retrain",
        config={"learning_rate": learning_rate, "batch_size": batch_size, "epochs": EPOCHS,
                "architecture": "MiniCLIP_w_transformer_crossattn", "dataset": "Meta analysis"},
    )
    wandb.watch(accelerator.unwrap_model(model), log="all", log_freq=100)
else:
    run = None

# train
training_wrapper = TrainWrapper(model=model, 
                                training_loader = train_dataloader, 
                                validation_loader = val_dataloader, 
                                test_dataset = validation_Dataset,
                                test_df = Df_val,
                                optimizer = optimizer, 
                                EPOCHS = EPOCHS,
                                runID = runID, 
                                device = device, 
                                test_indexes_for_auROC = indices_non_dimers_val, 
                                model_save_steps = model_save_steps,
                                model_save_path = trained_model_dir, 
                                v = True, 
                                wandb_tracker = wandb
                                )

training_wrapper.train_model()

Training model 27e37d7a-3fa7-4627-bb90-a5629dac4806


First Validation run: 100%|█████████████████████████████████████████████████████████████████████████████████| 25/25 [00:38<00:00,  1.55s/it]
Calculating AUC: 38it [00:15,  2.47it/s]                                                                                                    


Before training - Val CLIP-loss 6.2014 Accuracy: 0.7817 Top-K accuracy : 0.8274 auc: 0.847


Epochs:   0%|                                                                                                        | 0/15 [00:00<?, ?it/s]
Running through epoch:   0%|                                                                                        | 0/198 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/198 [00:01<04:51,  1.48s/it][A
Running through epoch:   1%|▊                                                                               | 2/198 [00:02<04:39,  1.42s/it][A
Running through epoch:   2%|█▏                                                                              | 3/198 [00:04<04:31,  1.39s/it][A
Running through epoch:   2%|█▌                                                                              | 4/198 [00:05<04:27,  1.38s/it][A
Running through epoch:   3%|██                                                                              | 5/198 [00:06<04:25,  1.37s/it

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/27e37d7a-3fa7-4627-bb90-a5629dac4806_checkpoint_1


Epochs:   7%|██████▏                                                                                      | 1/15 [05:27<1:16:21, 327.24s/it]

EPOCH 1 -  Val loss 0.2862 Accuracy: 0.7954 Top-K accuracy: 0.8671 Val-Auc:0.946 Val-auPR:0.874



Running through epoch:   0%|                                                                                        | 0/198 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/198 [00:01<04:31,  1.38s/it][A
Running through epoch:   1%|▊                                                                               | 2/198 [00:02<04:31,  1.39s/it][A
Running through epoch:   2%|█▏                                                                              | 3/198 [00:04<04:28,  1.37s/it][A
Running through epoch:   2%|█▌                                                                              | 4/198 [00:05<04:26,  1.37s/it][A
Running through epoch:   3%|██                                                                              | 5/198 [00:06<04:24,  1.37s/it][A
Running through epoch:   3%|██▍                                                                             | 6/198 [00:08<04:22,  1.37

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/27e37d7a-3fa7-4627-bb90-a5629dac4806_checkpoint_2


Epochs:  13%|████████████▍                                                                                | 2/15 [10:55<1:10:59, 327.63s/it]

EPOCH 2 -  Val loss 0.2392 Accuracy: 0.7966 Top-K accuracy: 0.8831 Val-Auc:0.953 Val-auPR:0.888



Running through epoch:   0%|                                                                                        | 0/198 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/198 [00:01<04:29,  1.37s/it][A
Running through epoch:   1%|▊                                                                               | 2/198 [00:02<04:27,  1.36s/it][A
Running through epoch:   2%|█▏                                                                              | 3/198 [00:04<04:26,  1.37s/it][A
Running through epoch:   2%|█▌                                                                              | 4/198 [00:05<04:27,  1.38s/it][A
Running through epoch:   3%|██                                                                              | 5/198 [00:06<04:25,  1.38s/it][A
Running through epoch:   3%|██▍                                                                             | 6/198 [00:08<04:23,  1.37

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/27e37d7a-3fa7-4627-bb90-a5629dac4806_checkpoint_3


Epochs:  20%|██████████████████▌                                                                          | 3/15 [16:22<1:05:31, 327.62s/it]

EPOCH 3 -  Val loss 0.2294 Accuracy: 0.8174 Top-K accuracy: 0.8971 Val-Auc:0.957 Val-auPR:0.894



Running through epoch:   0%|                                                                                        | 0/198 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/198 [00:01<04:31,  1.38s/it][A
Running through epoch:   1%|▊                                                                               | 2/198 [00:03<05:01,  1.54s/it][A
Running through epoch:   2%|█▏                                                                              | 3/198 [00:04<04:44,  1.46s/it][A
Running through epoch:   2%|█▌                                                                              | 4/198 [00:05<04:36,  1.43s/it][A
Running through epoch:   3%|██                                                                              | 5/198 [00:07<04:31,  1.41s/it][A
Running through epoch:   3%|██▍                                                                             | 6/198 [00:08<04:29,  1.40

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/27e37d7a-3fa7-4627-bb90-a5629dac4806_checkpoint_4


Epochs:  27%|████████████████████████▊                                                                    | 4/15 [21:50<1:00:04, 327.73s/it]

EPOCH 4 -  Val loss 0.224 Accuracy: 0.8194 Top-K accuracy: 0.8991 Val-Auc:0.96 Val-auPR:0.901



Running through epoch:   0%|                                                                                        | 0/198 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/198 [00:01<04:30,  1.37s/it][A
Running through epoch:   1%|▊                                                                               | 2/198 [00:02<04:29,  1.37s/it][A
Running through epoch:   2%|█▏                                                                              | 3/198 [00:04<04:26,  1.37s/it][A
Running through epoch:   2%|█▌                                                                              | 4/198 [00:05<04:25,  1.37s/it][A
Running through epoch:   3%|██                                                                              | 5/198 [00:06<04:24,  1.37s/it][A
Running through epoch:   3%|██▍                                                                             | 6/198 [00:08<04:22,  1.37

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/27e37d7a-3fa7-4627-bb90-a5629dac4806_checkpoint_5


Epochs:  33%|███████████████████████████████▋                                                               | 5/15 [27:17<54:34, 327.47s/it]

EPOCH 5 -  Val loss 0.2314 Accuracy: 0.8166 Top-K accuracy: 0.8971 Val-Auc:0.96 Val-auPR:0.903



Running through epoch:   0%|                                                                                        | 0/198 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/198 [00:01<04:29,  1.37s/it][A
Running through epoch:   1%|▊                                                                               | 2/198 [00:02<04:28,  1.37s/it][A
Running through epoch:   2%|█▏                                                                              | 3/198 [00:04<04:26,  1.37s/it][A
Running through epoch:   2%|█▌                                                                              | 4/198 [00:05<04:25,  1.37s/it][A
Running through epoch:   3%|██                                                                              | 5/198 [00:06<04:24,  1.37s/it][A
Running through epoch:   3%|██▍                                                                             | 6/198 [00:08<04:22,  1.37

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/27e37d7a-3fa7-4627-bb90-a5629dac4806_checkpoint_6


Epochs:  40%|██████████████████████████████████████                                                         | 6/15 [32:45<49:08, 327.57s/it]

EPOCH 6 -  Val loss 0.2617 Accuracy: 0.8114 Top-K accuracy: 0.8971 Val-Auc:0.956 Val-auPR:0.895



Running through epoch:   0%|                                                                                        | 0/198 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/198 [00:01<04:28,  1.36s/it][A
Running through epoch:   1%|▊                                                                               | 2/198 [00:02<04:28,  1.37s/it][A
Running through epoch:   2%|█▏                                                                              | 3/198 [00:04<04:27,  1.37s/it][A
Running through epoch:   2%|█▌                                                                              | 4/198 [00:05<04:26,  1.38s/it][A
Running through epoch:   3%|██                                                                              | 5/198 [00:06<04:25,  1.38s/it][A
Running through epoch:   3%|██▍                                                                             | 6/198 [00:08<04:23,  1.37

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/27e37d7a-3fa7-4627-bb90-a5629dac4806_checkpoint_7


Epochs:  47%|████████████████████████████████████████████▎                                                  | 7/15 [38:12<43:39, 327.49s/it]

EPOCH 7 -  Val loss 0.313 Accuracy: 0.8166 Top-K accuracy: 0.9011 Val-Auc:0.957 Val-auPR:0.905



Running through epoch:   0%|                                                                                        | 0/198 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/198 [00:01<04:28,  1.36s/it][A
Running through epoch:   1%|▊                                                                               | 2/198 [00:02<04:26,  1.36s/it][A
Running through epoch:   2%|█▏                                                                              | 3/198 [00:04<04:25,  1.36s/it][A
Running through epoch:   2%|█▌                                                                              | 4/198 [00:05<04:24,  1.37s/it][A
Running through epoch:   3%|██                                                                              | 5/198 [00:06<04:24,  1.37s/it][A
Running through epoch:   3%|██▍                                                                             | 6/198 [00:08<04:41,  1.46

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/27e37d7a-3fa7-4627-bb90-a5629dac4806_checkpoint_8


Epochs:  53%|██████████████████████████████████████████████████▋                                            | 8/15 [43:40<38:12, 327.46s/it]

EPOCH 8 -  Val loss 0.2799 Accuracy: 0.8194 Top-K accuracy: 0.9103 Val-Auc:0.96 Val-auPR:0.907



Running through epoch:   0%|                                                                                        | 0/198 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/198 [00:01<04:30,  1.37s/it][A
Running through epoch:   1%|▊                                                                               | 2/198 [00:02<04:28,  1.37s/it][A
Running through epoch:   2%|█▏                                                                              | 3/198 [00:04<04:27,  1.37s/it][A
Running through epoch:   2%|█▌                                                                              | 4/198 [00:05<04:25,  1.37s/it][A
Running through epoch:   3%|██                                                                              | 5/198 [00:06<04:24,  1.37s/it][A
Running through epoch:   3%|██▍                                                                             | 6/198 [00:08<04:22,  1.37

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/27e37d7a-3fa7-4627-bb90-a5629dac4806_checkpoint_9


Epochs:  60%|█████████████████████████████████████████████████████████                                      | 9/15 [49:07<32:44, 327.44s/it]

EPOCH 9 -  Val loss 0.2648 Accuracy: 0.8286 Top-K accuracy: 0.9071 Val-Auc:0.962 Val-auPR:0.912



Running through epoch:   0%|                                                                                        | 0/198 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/198 [00:01<04:36,  1.40s/it][A
Running through epoch:   1%|▊                                                                               | 2/198 [00:02<04:32,  1.39s/it][A
Running through epoch:   2%|█▏                                                                              | 3/198 [00:04<04:32,  1.40s/it][A
Running through epoch:   2%|█▌                                                                              | 4/198 [00:05<04:30,  1.39s/it][A
Running through epoch:   3%|██                                                                              | 5/198 [00:06<04:27,  1.39s/it][A
Running through epoch:   3%|██▍                                                                             | 6/198 [00:08<04:24,  1.38

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/27e37d7a-3fa7-4627-bb90-a5629dac4806_checkpoint_10


Epochs:  67%|██████████████████████████████████████████████████████████████▋                               | 10/15 [54:34<27:16, 327.34s/it]

EPOCH 10 -  Val loss 0.3137 Accuracy: 0.8323 Top-K accuracy: 0.906 Val-Auc:0.959 Val-auPR:0.91



Running through epoch:   0%|                                                                                        | 0/198 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/198 [00:01<04:33,  1.39s/it][A
Running through epoch:   1%|▊                                                                               | 2/198 [00:02<04:33,  1.39s/it][A
Running through epoch:   2%|█▏                                                                              | 3/198 [00:04<04:30,  1.39s/it][A
Running through epoch:   2%|█▌                                                                              | 4/198 [00:05<04:27,  1.38s/it][A
Running through epoch:   3%|██                                                                              | 5/198 [00:06<04:26,  1.38s/it][A
Running through epoch:   3%|██▍                                                                             | 6/198 [00:08<04:23,  1.37

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/27e37d7a-3fa7-4627-bb90-a5629dac4806_checkpoint_11


Epochs:  73%|███████████████████████████████████████████████████████████████████▍                        | 11/15 [1:00:01<21:49, 327.34s/it]

EPOCH 11 -  Val loss 0.3158 Accuracy: 0.8363 Top-K accuracy: 0.9011 Val-Auc:0.958 Val-auPR:0.909



Running through epoch:   0%|                                                                                        | 0/198 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/198 [00:01<04:31,  1.38s/it][A
Running through epoch:   1%|▊                                                                               | 2/198 [00:02<04:28,  1.37s/it][A
Running through epoch:   2%|█▏                                                                              | 3/198 [00:04<04:27,  1.37s/it][A
Running through epoch:   2%|█▌                                                                              | 4/198 [00:05<04:25,  1.37s/it][A
Running through epoch:   3%|██                                                                              | 5/198 [00:06<04:25,  1.37s/it][A
Running through epoch:   3%|██▍                                                                             | 6/198 [00:08<04:22,  1.37

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/27e37d7a-3fa7-4627-bb90-a5629dac4806_checkpoint_12


Epochs:  80%|█████████████████████████████████████████████████████████████████████████▌                  | 12/15 [1:05:29<16:22, 327.44s/it]

EPOCH 12 -  Val loss 0.3575 Accuracy: 0.8254 Top-K accuracy: 0.9 Val-Auc:0.958 Val-auPR:0.904



Running through epoch:   0%|                                                                                        | 0/198 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/198 [00:01<04:29,  1.37s/it][A
Running through epoch:   1%|▊                                                                               | 2/198 [00:02<04:27,  1.37s/it][A
Running through epoch:   2%|█▏                                                                              | 3/198 [00:04<04:25,  1.36s/it][A
Running through epoch:   2%|█▌                                                                              | 4/198 [00:05<04:25,  1.37s/it][A
Running through epoch:   3%|██                                                                              | 5/198 [00:06<04:24,  1.37s/it][A
Running through epoch:   3%|██▍                                                                             | 6/198 [00:08<04:22,  1.37

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/27e37d7a-3fa7-4627-bb90-a5629dac4806_checkpoint_13


Epochs:  87%|███████████████████████████████████████████████████████████████████████████████▋            | 13/15 [1:10:56<10:54, 327.33s/it]

EPOCH 13 -  Val loss 0.3325 Accuracy: 0.8403 Top-K accuracy: 0.908 Val-Auc:0.958 Val-auPR:0.911



Running through epoch:   0%|                                                                                        | 0/198 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/198 [00:01<04:32,  1.38s/it][A
Running through epoch:   1%|▊                                                                               | 2/198 [00:02<04:30,  1.38s/it][A
Running through epoch:   2%|█▏                                                                              | 3/198 [00:04<04:27,  1.37s/it][A
Running through epoch:   2%|█▌                                                                              | 4/198 [00:05<04:26,  1.37s/it][A
Running through epoch:   3%|██                                                                              | 5/198 [00:06<04:24,  1.37s/it][A
Running through epoch:   3%|██▍                                                                             | 6/198 [00:08<04:22,  1.37

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/27e37d7a-3fa7-4627-bb90-a5629dac4806_checkpoint_14


Epochs:  93%|█████████████████████████████████████████████████████████████████████████████████████▊      | 14/15 [1:16:24<05:27, 327.36s/it]

EPOCH 14 -  Val loss 0.3343 Accuracy: 0.8494 Top-K accuracy: 0.9051 Val-Auc:0.953 Val-auPR:0.903



Running through epoch:   0%|                                                                                        | 0/198 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                               | 1/198 [00:01<04:30,  1.37s/it][A
Running through epoch:   1%|▊                                                                               | 2/198 [00:02<04:29,  1.37s/it][A
Running through epoch:   2%|█▏                                                                              | 3/198 [00:04<04:27,  1.37s/it][A
Running through epoch:   2%|█▌                                                                              | 4/198 [00:05<04:26,  1.37s/it][A
Running through epoch:   3%|██                                                                              | 5/198 [00:06<04:25,  1.37s/it][A
Running through epoch:   3%|██▍                                                                             | 6/198 [00:08<04:22,  1.37

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/27e37d7a-3fa7-4627-bb90-a5629dac4806_checkpoint_15


Epochs: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [1:21:51<00:00, 327.43s/it]

EPOCH 15 -  Val loss 0.3127 Accuracy: 0.8343 Top-K accuracy: 0.908 Val-Auc:0.959 Val-auPR:0.91





0,1
Epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
Train-loss,█▄▃▃▂▂▂▁▁▁▁▁▁▁▁
Val non-dimer auPR,▁▆▇▇▇▇▇█████▇█▇█
Val non-dimer auc,▁▇▇███████████▇█
Val-TOPK-acc,▁▄▆▇▇▇▇▇███▇▇███
Val-acc,▁▂▃▅▅▅▄▅▅▆▆▇▆▇█▆
Val-loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Epoch,15.0
Train-loss,0.01892
Val non-dimer auPR,0.90974
Val non-dimer auc,0.95873
Val-TOPK-acc,0.908
Val-acc,0.83429
Val-loss,0.31272
