In [1]:
import uuid, sys, os
import pandas as pd
import numpy as np
from tqdm import tqdm
import ast
import math
import random

from sklearn import metrics
from scipy import stats
from collections import Counter

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
torch.cuda.set_device(0)  # 0 == "first visible" -> actually GPU 3 on the node
print(torch.cuda.get_device_name(0))

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, TensorDataset
from torch.optim.lr_scheduler import CosineAnnealingLR
import pytorch_lightning as pl
from torch.optim import AdamW

torch.manual_seed(0)

from accelerate import Accelerator

import matplotlib.pyplot as plt
import seaborn as sns

from Levenshtein import distance as Ldistance

import training_utils.dataset_utils as data_utils
import training_utils.partitioning_utils as pat_utils

import importlib
import training_utils.train_utils as train_utils
importlib.reload(train_utils)

Tesla V100-SXM2-32GB


<module 'training_utils.train_utils' from '/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/training_utils/train_utils.py'>

In [2]:
labels = torch.tensor([0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0.,
        0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.], dtype=torch.float32)
pos = (labels == 1).sum()
neg = (labels == 0).sum()
pos_weight = torch.tensor([neg / max(1, pos)])
pos_weight

tensor([5.4000])

In [3]:
SEED = 0
os.environ["PYTHONHASHSEED"] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [4]:
import wandb
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33ms232958[0m ([33ms232958-danmarks-tekniske-universitet-dtu[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
os.chdir("/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts")
# print(os.getcwd())

print("PyTorch:", torch.__version__)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print("Current location:", os.getcwd())

PyTorch: 2.5.1
Using device: cuda
Current location: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts


In [6]:
# Model parameters
memory_verbose = False
use_wandb = True # Used to track loss in real-time without printing
model_save_steps = 1
train_frac = 1.0
test_frac = 1.0

embedding_dimension = 1152 # 1280 | 960 | 1152
number_of_recycles = 2
padding_value = -5000

# batch_size = 20
learning_rate = 2e-5
EPOCHS = 15

In [7]:
## Model Class
### MiniClip 
def gaussian_kernel(x, sigma):
    return np.exp(-x**2 / (2 * sigma**2))

def transform_vector(vector, sigma):

    interacting_indices = np.where(vector == 1)[0]   # positions where vector == 1
    transformed_vector = np.zeros_like(vector, dtype=float)
    
    for i in range(len(vector)):
        if vector[i] == 0:
            distances = np.abs(interacting_indices - i)   # distance to all "1"s
            min_distance = np.min(distances)              # closest "1"
            transformed_vector[i] = gaussian_kernel(min_distance, sigma)
        else:
            transformed_vector[i] = 1.0
    return transformed_vector

def safe_shuffle(n, device):
    shuffled = torch.randperm(n, device=device)
    while torch.any(shuffled == torch.arange(n, device=device)):
        shuffled = torch.randperm(n, device=device)
    return shuffled

def create_key_padding_mask(embeddings, padding_value=-5000, offset=10):
    return (embeddings < (padding_value + offset)).all(dim=-1)

def create_mean_of_non_masked(embeddings, padding_mask):
    # Use masked select and mean to compute the mean of non-masked elements
    # embeddings should be of shape (batch_size, seq_len, features)
    seq_embeddings = []
    for i in range(embeddings.shape[0]): # looping over all batch elements
        non_masked_embeddings = embeddings[i][~padding_mask[i]] # shape [num_real_tokens, features]
        if len(non_masked_embeddings) == 0:
            print("You are masking all positions when creating sequence representation")
            sys.exit(1)
        mean_embedding = non_masked_embeddings.mean(dim=0) # sequence is represented by the single vecotr [1152] [features]
        seq_embeddings.append(mean_embedding)
    return torch.stack(seq_embeddings)

class MiniCLIP_w_transformer_crossattn(pl.LightningModule):

    def __init__(self, padding_value = -5000, embed_dimension=1152, num_recycles=1):

        super().__init__()
        self.num_recycles = num_recycles # how many times you iteratively refine embeddings with self- and cross-attention (ALPHA-Fold-style recycling).
        self.padding_value = padding_value
        self.embed_dimension = embed_dimension

        self.logit_scale = nn.Parameter(torch.tensor(math.log(1/0.07)))  # ~CLIP init

        self.transformerencoder =  nn.TransformerEncoderLayer(
            d_model=self.embed_dimension,
            nhead=8,
            dropout=0.1,
            batch_first=True,
            dim_feedforward=self.embed_dimension
            )
 
        self.norm = nn.LayerNorm(self.embed_dimension)  # For residual additions

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=self.embed_dimension,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        self.prot_embedder = nn.Sequential(
            nn.Linear(self.embed_dimension, 640),
            nn.ReLU(),
            nn.Linear(640, 320),
        )
        
    def forward(self, pep_input, prot_input, label=None, pep_int_mask=None, prot_int_mask=None, int_prob=None, mem_save=True): # , pep_tokens, prot_tokens

        pep_mask = create_key_padding_mask(embeddings=pep_input, padding_value=self.padding_value)
        prot_mask = create_key_padding_mask(embeddings=prot_input, padding_value=self.padding_value)
 
        # Initialize residual states
        pep_emb = pep_input.clone()
        prot_emb = prot_input.clone()
 
        for _ in range(self.num_recycles):

            # Transformer encoding with residual
            pep_trans = self.transformerencoder(self.norm(pep_emb), src_key_padding_mask=pep_mask)
            prot_trans = self.transformerencoder(self.norm(prot_emb), src_key_padding_mask=prot_mask)

            # Cross-attention with residual
            pep_cross, _ = self.cross_attn(query=self.norm(pep_trans), key=self.norm(prot_trans), value=self.norm(prot_trans), key_padding_mask=prot_mask)
            prot_cross, _ = self.cross_attn(query=self.norm(prot_trans), key=self.norm(pep_trans), value=self.norm(pep_trans), key_padding_mask=pep_mask)
            
            # Additive update with residual connection
            pep_emb = pep_emb + pep_cross  
            prot_emb = prot_emb + prot_cross

        pep_seq_coding = create_mean_of_non_masked(pep_emb, pep_mask)
        prot_seq_coding = create_mean_of_non_masked(prot_emb, prot_mask)
        
        # Use self-attention outputs for embeddings
        pep_seq_coding = F.normalize(self.prot_embedder(pep_seq_coding))
        prot_seq_coding = F.normalize(self.prot_embedder(prot_seq_coding))
 
        if mem_save:
            torch.cuda.empty_cache()
        
        scale = torch.exp(self.logit_scale).clamp(max=100.0)
        logits = scale * (pep_seq_coding * prot_seq_coding).sum(dim=-1) # Dot-Product for comparison
        
        return logits
    

    def training_step(self, batch, device):
        # Passing the sequences to the models
        embedding_pep = batch[0]
        embedding_prot = batch[1]
        binder_label = batch[2]
 
        embedding_pep = embedding_pep.to(device)
        embedding_prot = embedding_prot.to(device)
        binder_label = binder_label.to(device)

        logits = self.forward(embedding_pep, embedding_prot)
        binder_labels = binder_label.view_as(logits)  
        loss = F.binary_cross_entropy_with_logits(logits, binder_labels)
 
        torch.cuda.empty_cache()
        
        return loss
    
    def validation_step(self, batch, device):
        # Predict on random batches of training batch size
        embedding_pep, embedding_prot, binder_label = batch
 
        # Move to device
        embedding_pep  = embedding_pep.to(device)
        embedding_prot = embedding_prot.to(device)
        binder_label = binder_label.to(device).float() 

        with torch.no_grad():
            logits = self.forward(embedding_pep, embedding_prot)   # shape [B]
            binder_labels = binder_label.view_as(logits)  
            loss = F.binary_cross_entropy_with_logits(logits, binder_labels)
            
        return float(loss.item()), logits, binder_labels

In [8]:
## Output path
trained_model_dir = "/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts"

## Embeddings paths
binders_embeddings = "/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/meta_analysis/binders_embeddings"
targets_embeddings = "/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/meta_analysis/targets_embeddings"

# ## Training variables
runID = uuid.uuid4()

def print_mem_consumption():
    # 1. Total memory available on the GPU (device 0)
    t = torch.cuda.get_device_properties(0).total_memory
    # 2. How much memory PyTorch has *reserved* from CUDA
    r = torch.cuda.memory_reserved(0)
    # 3. How much of that reserved memory is actually *used* by tensors
    a = torch.cuda.memory_allocated(0)
    # 4. Reserved but not currently allocated (so “free inside PyTorch’s pool”)
    f = r - a

    print("Total memory: ", t/1e9)      # total VRAM in GB
    print("Reserved memory: ", r/1e9)   # PyTorch’s reserved pool in GB
    print("Allocated memory: ", a//1e9) # actually in use (integer division)
    print("Free memory: ", f/1e9)       # slack in the reserved pool in GB

#### Loading data frame

In [9]:
### Loading the dataset
interaction_df = pd.read_csv("../data/meta_analysis/interaction_df_metaanal.csv", index_col = 0).drop(columns = ["binder_id", "target_id"]).rename(columns={
    "A_seq" : "binder_seq",
    "B_seq" : "target_seq"
})

all_targets = interaction_df.target_id_mod.unique()
binder_nonbinder = interaction_df.binder.value_counts()
target_binder_nonbinder_Dict = dict(interaction_df.groupby("target_id_mod")["binder"].value_counts())
sorted_items = sorted(target_binder_nonbinder_Dict.items(), key=lambda kv: kv[1], reverse=True)

# %%
# Annotating each observation with a weight corresponding to whether it is considered a binder or not

### Weights for binder/non-binders
N_bins = len(interaction_df["binder"].value_counts())
pr_class_uniform_weight = 1 / N_bins
pr_class_weight_informed_with_size_of_bins = pr_class_uniform_weight  / interaction_df["binder"].value_counts()
pr_class_weight_informed_with_size_of_bins = pr_class_weight_informed_with_size_of_bins.to_dict()
interaction_df["class_weight"] = interaction_df.binder.apply(lambda x: pr_class_weight_informed_with_size_of_bins[x])
# binder_nonbinder_weights_Dict = dict(zip(interaction_df["target_binder_ID"], interaction_df["class_weight"]))

### Weights for target
N_bins = len(interaction_df["target_id_mod"].value_counts())
pr_class_uniform_weight = 1 / N_bins
pr_class_weight_informed_with_size_of_bins = pr_class_uniform_weight  / interaction_df["target_id_mod"].value_counts()
pr_class_weight_informed_with_size_of_bins = pr_class_weight_informed_with_size_of_bins.to_dict()
interaction_df["target_weight"] = interaction_df.target_id_mod.apply(lambda x: pr_class_weight_informed_with_size_of_bins[x])

### Combined weights
interaction_df["combined_weight"] = interaction_df["class_weight"]*interaction_df["target_weight"] 

In [10]:
interaction_df

Unnamed: 0,binder_chain,target_chains,binder,binder_seq,target_seq,target_id_mod,target_binder_ID,class_weight,target_weight,combined_weight
0,A,"[""B""]",False,LDFIVFAGPEKAIKFYKEMAKRNLEVKIWIDGDWAVVQVK,ANPYISVANIMLQNYVKQREKYNYDTLKEQFTFIKNASTSIVYMQF...,VirB8,VirB8_1,0.000159,0.000631,1.004956e-07
1,A,"[""B""]",False,SEQDETMHRIVRSVIQHAYKHNDEMAEYFAQNAAEIYKEQNKSEEA...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_1,0.000159,0.000029,4.686322e-09
2,A,"[""B""]",False,DYKQLKKHATKLLELAKKDPSSKRDLLRTAASYANKVLFEDSDPRA...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_2,0.000159,0.000029,4.686322e-09
3,A,"[""B""]",False,DEKEELERRANRVAFLAIQIQNEEYHRILAELYVQFMKAAENNDTE...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_3,0.000159,0.000029,4.686322e-09
4,A,"[""B""]",False,PDNKEKLMSIAVQLILRINEAARSEEQWRYANRAAFAAVEASSGSD...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_4,0.000159,0.000029,4.686322e-09
...,...,...,...,...,...,...,...,...,...,...
3527,A,"[""B""]",False,DLRKYAAELVDRLAEKYNLDSDQYNALVRLASELVWQGKSKEEIEK...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_62,0.000159,0.000947,1.507433e-07
3528,A,"[""B""]",False,SKEEIKKEAEELIEELKKKGYNLPLRILEFALKEIEETNSEKYYEQ...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_63,0.000159,0.000947,1.507433e-07
3529,A,"[""B""]",False,SPEYKKFLELIKEAEAARKAGDLDKAKELLEKALELAKKMKAKSLI...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_64,0.000159,0.000947,1.507433e-07
3530,A,"[""B""]",False,DPLLAYKLLKLSQKALEKAYAEDRERAEELLEEAEAALRSLGDEAG...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_65,0.000159,0.000947,1.507433e-07


# 13(leave-1-target-cluster-out)-fold CV training + weighting of samples

In [11]:
clusters = [
    ["VirB8"], ["FGFR2"], ["IL7Ra"], ["InsulinR"],
    ["EGFR", "EGFR_2", "EGFR_3"],     # keep together
    ["SARS_CoV2_RBD"], ["Pdl1"], ["TrkA"], ["IL10Ra"],
    ["LTK"], ["Mdm2"],
    ["sntx", "sntx_2"],               # keep together
    ["IL2Ra"],
]

random.Random(0).shuffle(clusters)
folds = np.array_split(np.array(clusters, dtype=object), len(clusters))   # list of np arrays
targets_folds = []
for f in folds:
    flat = []
    for group in f:
        flat.extend(group)
    targets_folds.append(flat)

def build_cv_splits(targets_folds):
    val_folds, train_folds = [], []
    K = len(targets_folds)
    for i in range(K):
        val_targets = list(targets_folds[i])  # copy
        train_targets = [t for j, fold in enumerate(targets_folds) if j != i for t in fold]
        val_folds.append(val_targets)
        train_folds.append(train_targets)
    return val_folds, train_folds

val_folds, train_folds = build_cv_splits(targets_folds)
cv_splits = list(zip(val_folds, train_folds))

In [12]:
for idx in range(len(cv_splits)):
    val_targets = cv_splits[idx][0]
    vals = len(interaction_df[interaction_df.target_id_mod.isin(val_targets)])
    trains = len(interaction_df) - vals
    print(f"Fold {idx+1} : training instances : {trains}, validation instances : {vals}")

Fold 1 : training instances : 1409, validation instances : 2123
Fold 2 : training instances : 3436, validation instances : 96
Fold 3 : training instances : 3499, validation instances : 33
Fold 4 : training instances : 3433, validation instances : 99
Fold 5 : training instances : 3483, validation instances : 49
Fold 6 : training instances : 3361, validation instances : 171
Fold 7 : training instances : 3415, validation instances : 117
Fold 8 : training instances : 3404, validation instances : 128
Fold 9 : training instances : 3510, validation instances : 22
Fold 10 : training instances : 3098, validation instances : 434
Fold 11 : training instances : 3433, validation instances : 99
Fold 12 : training instances : 3466, validation instances : 66
Fold 13 : training instances : 3437, validation instances : 95


In [13]:
for idx in range(len(cv_splits)):
    val_targets = cv_splits[idx][0]
    train_targets = cv_splits[idx][1]
    print(f"Fold {idx+1} : validation targets : {val_targets}")

Fold 1 : validation targets : ['FGFR2']
Fold 2 : validation targets : ['Mdm2']
Fold 3 : validation targets : ['LTK']
Fold 4 : validation targets : ['SARS_CoV2_RBD']
Fold 5 : validation targets : ['sntx', 'sntx_2']
Fold 6 : validation targets : ['IL7Ra']
Fold 7 : validation targets : ['InsulinR']
Fold 8 : validation targets : ['TrkA']
Fold 9 : validation targets : ['IL10Ra']
Fold 10 : validation targets : ['EGFR', 'EGFR_2', 'EGFR_3']
Fold 11 : validation targets : ['VirB8']
Fold 12 : validation targets : ['IL2Ra']
Fold 13 : validation targets : ['Pdl1']


#### Creating separate targets/ binder dataframes

In [14]:
# Targets df
target_df = interaction_df[["target_id_mod","target_seq"]].rename(columns={"target_seq":"sequence", "target_id_mod" : "ID"})
target_df["seq_len"] = target_df["sequence"].apply(len)
target_df = target_df.drop_duplicates(subset=["ID","sequence"])
target_df = target_df.set_index("ID")

# Binders df
binder_df = interaction_df[["target_binder_ID","binder_seq", "binder", "class_weight", "target_weight", "combined_weight"]].rename(columns={"binder_seq":"sequence", "target_binder_ID" : "ID", "binder" : "label"})
binder_df["seq_len"] = binder_df["sequence"].apply(len)
binder_df = binder_df.set_index("ID")

# Interaction Dict
interaction_Dict = dict(enumerate(zip(interaction_df["target_id_mod"], interaction_df["target_binder_ID"]), start=1))

In [15]:
binder_df

Unnamed: 0_level_0,sequence,label,class_weight,target_weight,combined_weight,seq_len
ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
VirB8_1,LDFIVFAGPEKAIKFYKEMAKRNLEVKIWIDGDWAVVQVK,False,0.000159,0.000631,1.004956e-07,40
FGFR2_1,SEQDETMHRIVRSVIQHAYKHNDEMAEYFAQNAAEIYKEQNKSEEA...,False,0.000159,0.000029,4.686322e-09,62
FGFR2_2,DYKQLKKHATKLLELAKKDPSSKRDLLRTAASYANKVLFEDSDPRA...,False,0.000159,0.000029,4.686322e-09,61
FGFR2_3,DEKEELERRANRVAFLAIQIQNEEYHRILAELYVQFMKAAENNDTE...,False,0.000159,0.000029,4.686322e-09,64
FGFR2_4,PDNKEKLMSIAVQLILRINEAARSEEQWRYANRAAFAAVEASSGSD...,False,0.000159,0.000029,4.686322e-09,64
...,...,...,...,...,...,...
IL2Ra_62,DLRKYAAELVDRLAEKYNLDSDQYNALVRLASELVWQGKSKEEIEK...,False,0.000159,0.000947,1.507433e-07,55
IL2Ra_63,SKEEIKKEAEELIEELKKKGYNLPLRILEFALKEIEETNSEKYYEQ...,False,0.000159,0.000947,1.507433e-07,56
IL2Ra_64,SPEYKKFLELIKEAEAARKAGDLDKAKELLEKALELAKKMKAKSLI...,False,0.000159,0.000947,1.507433e-07,56
IL2Ra_65,DPLLAYKLLKLSQKALEKAYAEDRERAEELLEEAEAALRSLGDEAG...,False,0.000159,0.000947,1.507433e-07,57


In [16]:
target_df

Unnamed: 0_level_0,sequence,seq_len
ID,Unnamed: 1_level_1,Unnamed: 2_level_1
VirB8,ANPYISVANIMLQNYVKQREKYNYDTLKEQFTFIKNASTSIVYMQF...,138
FGFR2,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,101
IL7Ra,DYSFSCYSQLEVNGSQHSLTCAFEDPDVNTTNLEFEICGALVEVKC...,193
InsulinR,EVCPGMDIRNNLTRLHELENCSVIEGHLQILLMFKTRPEDFRDLSF...,150
EGFR,RKVCNGIGIGEFKDSLSINATNIKHFKNCTSISGDLHILPVAFRGD...,191
SARS_CoV2_RBD,TNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFK...,195
Pdl1,NAFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDK...,115
EGFR_2,LEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYV...,621
TrkA,VSFPASVQLHTAVEMHHWCIPFSVDGQPAPSLRWLFNGSVLNETSF...,101
IL10Ra,GTELPSPPSVWFEAEFFHHILHWTPIPQQSESTCYEVALLRYGIES...,207


#### Creating separate targets/ binder dataframes (for validation/ training)

In [17]:
class CLIP_meta_analysis_dataset(Dataset):

    def __init__(self, sequence_df, esm_encoding_paths, embedding_dim=1152, padding_value=-5000):

        super(CLIP_meta_analysis_dataset, self).__init__()

        self.sequence_df = sequence_df # target/binder_df
        self.max_length = sequence_df["seq_len"].max()
        self.sequence_df["index_num"] = np.arange(len(self.sequence_df))
        # print(self.sequence_df)
        self.esm_encoding_paths = esm_encoding_paths
        num_samples = len(self.sequence_df)
        
        self.x = torch.full((num_samples, self.max_length, embedding_dim), padding_value, dtype=torch.float32)

        self.accessions = self.sequence_df.index.astype(str).tolist()
        self.name_to_row = {name: i for i, name in enumerate(self.accessions)}
        
        # Load embeddings into the pre-allocated tensor
        all_acc_list = self.accessions
        all_acc_loaded_list = []

        iterator = tqdm(all_acc_list, position=0, leave=True, total=num_samples, desc="# Reading in ESM-embeddings from folder")
        for i, accession in enumerate(iterator):
            npy_path = os.path.join(esm_encoding_paths, f"{accession}.npy")
            try:
                embd = np.load(npy_path)[0]
                length_to_pad = self.max_length - len(embd)
                if length_to_pad > 0:
                    zero_padding = np.ones((length_to_pad, embd.shape[1])) * padding_value
                    padded_array = np.concatenate((embd, zero_padding), axis=0)
                else:
                    padded_array = embd[:self.max_length] 
                self.x[i] = torch.tensor(padded_array, dtype=torch.float32)
                all_acc_loaded_list.append(accession)
            except FileNotFoundError as e:
                raise FileNotFoundError(f"Embedding file {accession}.npy not found.")
            
        missing = sorted(set(all_acc_list) - set(all_acc_loaded_list))
        if missing:
            raise FileNotFoundError(
                f"Missing {len(missing)} embedding files in '{esm_encoding_paths}'. "
                f"Examples: {missing}")
          
    def __len__(self):
        return int(self.x.shape[0])

    def __getitem__(self, idx):
        return self.x[idx]
    
    # add a helper:
    def get_by_name(self, name: str):
        return self.x[self.name_to_row[name]]

targets_dataset = CLIP_meta_analysis_dataset(target_df, esm_encoding_paths="/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/meta_analysis/targets_embeddings", embedding_dim=1152)
binders_dataset = CLIP_meta_analysis_dataset(binder_df, esm_encoding_paths="/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/meta_analysis/binders_embeddings", embedding_dim=1152)

# Reading in ESM-embeddings from folder: 100%|█████████████████████████████████████████| 16/16 [00:00<00:00, 172.58it/s]
# Reading in ESM-embeddings from folder: 100%|█████████████████████████████████████| 3532/3532 [00:06<00:00, 561.56it/s]


In [18]:
binders_dataset[12]

tensor([[-1.3036e-02, -2.4448e-03, -6.8176e-03,  ...,  5.2219e-03,
          1.3421e-02,  1.3365e-02],
        [ 3.5381e-03, -1.6180e-04, -2.7373e-02,  ...,  3.8627e-02,
          1.8391e-02,  2.9930e-02],
        [-5.4220e-03, -4.6685e-03, -5.6514e-02,  ...,  2.7079e-02,
          3.4216e-02,  9.9046e-03],
        ...,
        [-5.0000e+03, -5.0000e+03, -5.0000e+03,  ..., -5.0000e+03,
         -5.0000e+03, -5.0000e+03],
        [-5.0000e+03, -5.0000e+03, -5.0000e+03,  ..., -5.0000e+03,
         -5.0000e+03, -5.0000e+03],
        [-5.0000e+03, -5.0000e+03, -5.0000e+03,  ..., -5.0000e+03,
         -5.0000e+03, -5.0000e+03]])

In [19]:
torch.equal(targets_dataset[0], targets_dataset.get_by_name("VirB8"))

True

In [20]:
torch.equal(binders_dataset[0], binders_dataset.get_by_name("VirB8_1"))

True

In [21]:
def binder_to_target_name(bname: str) -> str:
    parts = bname.split("_")
    if bname.startswith("SARS"):
        return "SARS_CoV2_RBD"
    if len(parts) == 3:
        return f"{parts[0]}_{parts[1]}"
    return parts[0]

def binder_target_label(targets_dataset, binders_dataset, binder_ids, interaction_df, stack=True):
    
    listof_bindertargetlabel = []
    
    for bname in binder_ids:
        tname = binder_to_target_name(bname)

        # get embeddings by name
        b_emb = binders_dataset.get_by_name(bname)
        t_emb = targets_dataset.get_by_name(tname)

        # get label from df
        s = interaction_df.loc[interaction_df['target_binder_ID'] == bname, 'binder']
        # if s.empty:
        #     raise ValueError(f"No label found in interaction_df for binder id '{bname}'")
        lbl = torch.tensor(float(s.iat[0]), dtype=torch.float32)

        listof_bindertargetlabel.append((b_emb, t_emb, lbl))

    return listof_bindertargetlabel

# validation_data_5clusters = []
# training_data_5clusters = []

# for i, split in enumerate(cv_splits):
#     validation, training = split[0], split[1]
    
#     validation_binders = interaction_df.loc[interaction_df["target_id_mod"].isin(validation), "target_binder_ID"].tolist()
#     training_binders = interaction_df.loc[interaction_df["target_id_mod"].isin(training), "target_binder_ID"].tolist()

#     listof_bindertargetlabel = binder_target_label(targets_dataset, binders_dataset, validation_binders, interaction_df)
#     validation_data_5clusters.append(listof_bindertargetlabel)
    
#     listof_bindertargetlabel = binder_target_label(targets_dataset, binders_dataset, training_binders, interaction_df)
#     training_data_5clusters.append(listof_bindertargetlabel)

In [22]:
# len(validation_data_5clusters)
# >> 5 # 5 folds
# len(validation_data_5clusters[0])
# >> 249 # number of  instances per fold "1" used ofr validation
# len(validation_data_5clusters[0][0])
# >> 3 # binder_emb, target_emb, label

In [23]:
# for i in range(len(validation_data_5clusters)):
#     print(f"Run {i+1} : len(val_dataset) : {len(validation_data_5clusters[i])}, len(train_dataset) : {len(training_data_5clusters[i])}")

### Loading pretrained model for finetuning

In [24]:
ckpt_path = '../PPI_PLM/models/CLIP_no_structural_information/a1d0549b-3f90-4ce2-b795-7bca2276cb07_checkpoint_4/a1d0549b-3f90-4ce2-b795-7bca2276cb07_checkpoint_epoch_4.pth'
checkpoint = torch.load(ckpt_path, weights_only=False, map_location="cpu")
# print(list(checkpoint["model_state_dict"]))
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = MiniCLIP_w_transformer_crossattn()
model.load_state_dict(checkpoint['model_state_dict'])
torch.cuda.empty_cache()  # frees cached blocks (not live tensors)
device = torch.device("cuda:0")
model.to(device)
# model.train()

MiniCLIP_w_transformer_crossattn(
  (transformerencoder): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=1152, out_features=1152, bias=True)
    )
    (linear1): Linear(in_features=1152, out_features=1152, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=1152, out_features=1152, bias=True)
    (norm1): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (norm): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
  (cross_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=1152, out_features=1152, bias=True)
  )
  (prot_embedder): Sequential(
    (0): Linear(in_features=1152, out_features=640, bias=True)
    (1): ReLU()
    (2): Linear(in_features=640, out_features=32

### Loading training and validation datasets (DataLoaders)

In [25]:
all_binders = interaction_df["target_binder_ID"].tolist()
# all dataset: binder_enb, target_emb, label
ALL_btl_list = binder_target_label(targets_dataset, binders_dataset, all_binders, interaction_df)
ALL_btl_list.__len__()
ALL_btl_list.__getitem__(0)

(tensor([[-7.4324e-03,  2.6254e-04, -7.7275e-03,  ...,  5.6024e-03,
           1.3973e-02,  1.4799e-02],
         [-1.3152e-02,  4.0290e-02,  1.1565e-02,  ...,  9.0204e-03,
           2.5051e-02,  2.8246e-02],
         [ 6.4436e-03,  1.9148e-02, -2.1045e-02,  ...,  4.0891e-04,
           1.0381e-02, -4.7295e-03],
         ...,
         [-5.0000e+03, -5.0000e+03, -5.0000e+03,  ..., -5.0000e+03,
          -5.0000e+03, -5.0000e+03],
         [-5.0000e+03, -5.0000e+03, -5.0000e+03,  ..., -5.0000e+03,
          -5.0000e+03, -5.0000e+03],
         [-5.0000e+03, -5.0000e+03, -5.0000e+03,  ..., -5.0000e+03,
          -5.0000e+03, -5.0000e+03]]),
 tensor([[ 3.5021e-03, -1.8118e-03, -3.0359e-03,  ...,  1.7072e-04,
           6.6654e-03,  1.2242e-02],
         [ 2.2083e-02,  1.7457e-02, -3.6554e-03,  ...,  1.6579e-02,
           3.6905e-04,  1.7242e-02],
         [-6.4675e-03,  1.0692e-02, -8.5746e-04,  ...,  2.8062e-02,
          -1.3191e-02,  7.7366e-04],
         ...,
         [-5.0000e+03, -5

In [26]:
binders = torch.stack([torch.as_tensor(b) for b,_,_ in ALL_btl_list])   # [N, L, D]
targets = torch.stack([torch.as_tensor(t) for _,t,_ in ALL_btl_list])   # [N, L, D]
labels  = torch.tensor([float(y) for *_,y in ALL_btl_list], dtype=torch.float32)  # [N]

ALL_btl = TensorDataset(binders, targets, labels)
ALL_btl[0]

(tensor([[-7.4324e-03,  2.6254e-04, -7.7275e-03,  ...,  5.6024e-03,
           1.3973e-02,  1.4799e-02],
         [-1.3152e-02,  4.0290e-02,  1.1565e-02,  ...,  9.0204e-03,
           2.5051e-02,  2.8246e-02],
         [ 6.4436e-03,  1.9148e-02, -2.1045e-02,  ...,  4.0891e-04,
           1.0381e-02, -4.7295e-03],
         ...,
         [-5.0000e+03, -5.0000e+03, -5.0000e+03,  ..., -5.0000e+03,
          -5.0000e+03, -5.0000e+03],
         [-5.0000e+03, -5.0000e+03, -5.0000e+03,  ..., -5.0000e+03,
          -5.0000e+03, -5.0000e+03],
         [-5.0000e+03, -5.0000e+03, -5.0000e+03,  ..., -5.0000e+03,
          -5.0000e+03, -5.0000e+03]]),
 tensor([[ 3.5021e-03, -1.8118e-03, -3.0359e-03,  ...,  1.7072e-04,
           6.6654e-03,  1.2242e-02],
         [ 2.2083e-02,  1.7457e-02, -3.6554e-03,  ...,  1.6579e-02,
           3.6905e-04,  1.7242e-02],
         [-6.4675e-03,  1.0692e-02, -8.5746e-04,  ...,  2.8062e-02,
          -1.3191e-02,  7.7366e-04],
         ...,
         [-5.0000e+03, -5

In [27]:
class PairListDataset(torch.utils.data.Dataset):
    # examples: list of (binder_emb, target_emb, label)
    # target_ids: parallel list of target_id_mod (same order)
    def __init__(self, examples, target_ids):
        assert len(examples) == len(target_ids)
        self.examples = examples
        self.target_ids = list(map(str, target_ids))

    def __len__(self): return len(self.examples)

    def __getitem__(self, idx):
        b, t, y = self.examples[idx]
        return (torch.as_tensor(b, dtype=torch.float32),
                torch.as_tensor(t, dtype=torch.float32),
                torch.tensor(float(y), dtype=torch.float32),
                self.target_ids[idx])  # <- keep the id

In [28]:
interaction_df["combined_weight"] = (interaction_df["class_weight"]+interaction_df["target_weight"])/2
multipliers = []
for binder in interaction_df["binder"]:
    if binder == False:
        multipliers.append(0.5)
    else:
        multipliers.append(1)
interaction_df["combined_weight_account_pos"] = interaction_df["combined_weight"] * multipliers
interaction_df

Unnamed: 0,binder_chain,target_chains,binder,binder_seq,target_seq,target_id_mod,target_binder_ID,class_weight,target_weight,combined_weight,combined_weight_account_pos
0,A,"[""B""]",False,LDFIVFAGPEKAIKFYKEMAKRNLEVKIWIDGDWAVVQVK,ANPYISVANIMLQNYVKQREKYNYDTLKEQFTFIKNASTSIVYMQF...,VirB8,VirB8_1,0.000159,0.000631,0.000395,0.000198
1,A,"[""B""]",False,SEQDETMHRIVRSVIQHAYKHNDEMAEYFAQNAAEIYKEQNKSEEA...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_1,0.000159,0.000029,0.000094,0.000047
2,A,"[""B""]",False,DYKQLKKHATKLLELAKKDPSSKRDLLRTAASYANKVLFEDSDPRA...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_2,0.000159,0.000029,0.000094,0.000047
3,A,"[""B""]",False,DEKEELERRANRVAFLAIQIQNEEYHRILAELYVQFMKAAENNDTE...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_3,0.000159,0.000029,0.000094,0.000047
4,A,"[""B""]",False,PDNKEKLMSIAVQLILRINEAARSEEQWRYANRAAFAAVEASSGSD...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_4,0.000159,0.000029,0.000094,0.000047
...,...,...,...,...,...,...,...,...,...,...,...
3527,A,"[""B""]",False,DLRKYAAELVDRLAEKYNLDSDQYNALVRLASELVWQGKSKEEIEK...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_62,0.000159,0.000947,0.000553,0.000277
3528,A,"[""B""]",False,SKEEIKKEAEELIEELKKKGYNLPLRILEFALKEIEETNSEKYYEQ...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_63,0.000159,0.000947,0.000553,0.000277
3529,A,"[""B""]",False,SPEYKKFLELIKEAEAARKAGDLDKAKELLEKALELAKKMKAKSLI...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_64,0.000159,0.000947,0.000553,0.000277
3530,A,"[""B""]",False,DPLLAYKLLKLSQKALEKAYAEDRERAEELLEEAEAALRSLGDEAG...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_65,0.000159,0.000947,0.000553,0.000277


In [29]:
u = interaction_df[['binder', 'target_id_mod', 'combined_weight_account_pos']].drop_duplicates(subset=['target_id_mod', 'binder'])
print(u.sort_values('target_id_mod').to_string(index=False))
# or as a dict:
weight_map = u.set_index('target_id_mod')['combined_weight_account_pos'].to_dict()

 binder target_id_mod  combined_weight_account_pos
   True          EGFR                     0.000884
  False          EGFR                     0.000162
   True        EGFR_2                     0.000743
  False        EGFR_2                     0.000092
   True        EGFR_3                     0.008452
  False         FGFR2                     0.000047
   True         FGFR2                     0.000654
   True        IL10Ra                     0.002060
  False        IL10Ra                     0.000750
   True         IL2Ra                     0.001113
  False         IL2Ra                     0.000277
  False         IL7Ra                     0.000131
   True         IL7Ra                     0.000822
   True      InsulinR                     0.000906
  False      InsulinR                     0.000173
  False           LTK                     0.000513
   True           LTK                     0.001586
  False          Mdm2                     0.000203
   True          Mdm2          

In [30]:
train_targets = cv_splits[12][1]
g = torch.Generator().manual_seed(SEED)

train_weights_class = interaction_df.loc[interaction_df.target_id_mod.isin(train_targets), "class_weight"].tolist()
train_weights_target = interaction_df.loc[interaction_df.target_id_mod.isin(train_targets), "target_weight"].tolist()
train_weights_combined = interaction_df.loc[interaction_df.target_id_mod.isin(train_targets), "combined_weight"].tolist()
train_weights_combined_boosted = interaction_df.loc[interaction_df.target_id_mod.isin(train_targets), "combined_weight_account_pos"].tolist()

train_idx = interaction_df.loc[interaction_df.target_id_mod.isin(train_targets)].index.tolist()
train_target_ids = interaction_df.loc[train_idx, "target_id_mod"].astype(str).tolist()
train_binders_ds = PairListDataset([ALL_btl[idx] for idx in train_idx], target_ids=train_target_ids)

train_sampler = WeightedRandomSampler(weights=train_weights_combined, num_samples=len(train_binders_ds), replacement=True, generator = g)
train_loader   = DataLoader(train_binders_ds,  batch_size=20, sampler=train_sampler)

for bi, batch in enumerate(train_loader):
    _, _, labels, ids = batch
    c = Counter(ids)
    print(f"Batch {bi}: {dict(c)}  | positives={int(labels.sum().item())} / {labels.numel()}")
    if bi == 10: break  # first 5 batches

Batch 0: {'IL2Ra': 1, 'InsulinR': 1, 'FGFR2': 7, 'sntx': 1, 'Mdm2': 4, 'EGFR': 1, 'EGFR_2': 2, 'IL7Ra': 1, 'LTK': 1, 'VirB8': 1}  | positives=8 / 20
Batch 1: {'FGFR2': 7, 'InsulinR': 2, 'TrkA': 3, 'EGFR_3': 1, 'Mdm2': 3, 'IL10Ra': 1, 'IL7Ra': 1, 'sntx_2': 1, 'EGFR': 1}  | positives=7 / 20
Batch 2: {'TrkA': 1, 'Mdm2': 3, 'FGFR2': 7, 'EGFR_2': 2, 'EGFR': 1, 'IL2Ra': 2, 'SARS_CoV2_RBD': 1, 'IL7Ra': 2, 'VirB8': 1}  | positives=6 / 20
Batch 3: {'EGFR_2': 3, 'InsulinR': 2, 'IL2Ra': 1, 'FGFR2': 7, 'Mdm2': 3, 'IL7Ra': 2, 'TrkA': 1, 'EGFR': 1}  | positives=5 / 20
Batch 4: {'EGFR_2': 3, 'SARS_CoV2_RBD': 2, 'IL10Ra': 2, 'EGFR_3': 1, 'FGFR2': 9, 'sntx': 1, 'IL7Ra': 1, 'EGFR': 1}  | positives=6 / 20
Batch 5: {'FGFR2': 8, 'IL2Ra': 3, 'IL10Ra': 1, 'SARS_CoV2_RBD': 1, 'Mdm2': 4, 'EGFR': 1, 'EGFR_2': 2}  | positives=3 / 20
Batch 6: {'SARS_CoV2_RBD': 1, 'LTK': 1, 'InsulinR': 3, 'sntx_2': 2, 'FGFR2': 6, 'VirB8': 4, 'IL10Ra': 1, 'Mdm2': 1, 'IL7Ra': 1}  | positives=11 / 20
Batch 7: {'EGFR_2': 1, 'FGFR2': 1

In [31]:
train_sampler = WeightedRandomSampler(weights=train_weights_combined_boosted, num_samples=len(train_binders_ds), replacement=True, generator = g)
train_loader   = DataLoader(train_binders_ds,  batch_size=20, sampler=train_sampler)

for bi, batch in enumerate(train_loader):
    _, _, labels, ids = batch
    c = Counter(ids)
    print(f"Batch {bi}: {dict(c)}  | positives={int(labels.sum().item())} / {labels.numel()}")
    if bi == 10: break  # first 5 batches

Batch 0: {'EGFR_2': 2, 'IL2Ra': 1, 'LTK': 1, 'IL10Ra': 1, 'Mdm2': 3, 'sntx': 3, 'FGFR2': 6, 'EGFR': 1, 'sntx_2': 1, 'SARS_CoV2_RBD': 1}  | positives=10 / 20
Batch 1: {'Mdm2': 3, 'FGFR2': 7, 'IL7Ra': 2, 'InsulinR': 1, 'VirB8': 2, 'SARS_CoV2_RBD': 2, 'EGFR_3': 1, 'sntx': 2}  | positives=15 / 20
Batch 2: {'IL2Ra': 2, 'FGFR2': 3, 'VirB8': 3, 'SARS_CoV2_RBD': 2, 'LTK': 1, 'IL10Ra': 1, 'sntx': 4, 'InsulinR': 2, 'EGFR_3': 1, 'EGFR': 1}  | positives=10 / 20
Batch 3: {'IL2Ra': 2, 'FGFR2': 6, 'sntx': 2, 'SARS_CoV2_RBD': 2, 'VirB8': 2, 'EGFR': 1, 'TrkA': 1, 'Mdm2': 1, 'IL7Ra': 1, 'InsulinR': 1, 'EGFR_3': 1}  | positives=13 / 20
Batch 4: {'FGFR2': 6, 'SARS_CoV2_RBD': 1, 'EGFR_2': 4, 'LTK': 2, 'IL2Ra': 1, 'Mdm2': 1, 'VirB8': 1, 'EGFR': 2, 'TrkA': 1, 'sntx': 1}  | positives=8 / 20
Batch 5: {'IL7Ra': 1, 'FGFR2': 10, 'Mdm2': 2, 'EGFR': 1, 'VirB8': 3, 'InsulinR': 2, 'TrkA': 1}  | positives=15 / 20
Batch 6: {'Mdm2': 2, 'FGFR2': 8, 'EGFR_3': 1, 'IL7Ra': 1, 'VirB8': 1, 'IL10Ra': 2, 'IL2Ra': 1, 'TrkA': 1, 

In [32]:
# val_targets = cv_splits[0][0]
# train_targets = cv_splits[0][1]

# # indexes of validation binders
# val_idx = interaction_df.loc[interaction_df.target_id_mod.isin(val_targets)].index.tolist()

# # weights of validation binders
# # val_weights = interaction_df.loc[interaction_df.target_id_mod.isin(val_targets), "class_weight"].tolist()

# # validation dataset : binder_emb, target_emb, label
# val_binders_ds = [ALL_btl[idx] for idx in val_idx]
# val_binders_ds = PairListDataset(val_binders_ds)

# # indexes of training binders
# train_idx = interaction_df.loc[interaction_df.target_id_mod.isin(train_targets)].index.tolist()

# # weights of training binders
# train_weights_class = interaction_df.loc[interaction_df.target_id_mod.isin(train_targets), "class_weight"].tolist()
# train_weights_target = interaction_df.loc[interaction_df.target_id_mod.isin(train_targets), "target_weight"].tolist()
# train_weights_combined = interaction_df.loc[interaction_df.target_id_mod.isin(train_targets), "combined_weight"].tolist()

# # training dataset : binder_emb, target_emb, label
# train_binders_ds = [ALL_btl[idx] for idx in train_idx]
# train_binders_ds = PairListDataset(train_binders_ds)

In [33]:
import gc, torch
# del obj  # any large temps you created in the cell
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

# print_mem_consumption()
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum(p.numel() for p in model.parameters() if p.requires_grad)

torch.cuda.empty_cache()
print_mem_consumption()

Total memory:  34.072559616
Reserved memory:  0.065011712
Allocated memory:  0.0
Free memory:  0.008080896


### Training loop

In [34]:
def batch(iterable, n=20):
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]

class TrainWrapper_MetaAnal():

    def __init__(self, model, training_loader, validation_loader, test_dataset, 
                 optimizer, scheduler, EPOCHS, runID, device, test_indexes_for_auROC=None,
                 auROC_batch_size=18, model_save_steps=False, model_save_path=False, 
                 v=False, wandb_tracker=False, split_id=None):
        
        self.model = model 
        self.training_loader = training_loader
        self.validation_loader = validation_loader
        self.EPOCHS = EPOCHS
        self.wandb_tracker = wandb_tracker
        self.model_save_steps = model_save_steps
        self.verbose = v
        self.split_id = split_id
        self.best_vloss = 1e09
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.runID = runID
        self.trained_model_dir = model_save_path
        self.print_frequency_loss = 1
        self.device = device
        self.test_indexes_for_auROC = test_indexes_for_auROC
        self.auROC_batch_size = auROC_batch_size
        self.test_dataset = test_dataset

    def train_one_epoch(self):

        self.model.train()
        running_loss = 0.0

        for batch_data in self.training_loader:
            if batch_data[0].size(0) == 1:
                continue
            
            self.optimizer.zero_grad()
            loss = self.model.training_step(batch_data, self.device)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()

        if self.scheduler is not None:
            self.scheduler.step()

        return running_loss / max(1, len(self.training_loader))

    def calc_auroc_aupr_on_indexes(self, model, validation_dataset, batch_size=20, pad_value=-5000.0):

        model.eval()
        all_scores, all_labels = [], []
        batched_data = batch(validation_dataset, n=batch_size)

        with torch.no_grad():
            for one_batch in batched_data:
                items = [i for i in one_batch]
                binders = torch.stack([binder_emb for (binder_emb, _, _) in items]).to(self.device)
                targets = torch.stack([target_emb for (_, target_emb, _) in items]).to(self.device)
                labels = np.array([float(lbl) for *_, lbl in items], dtype=np.float32)

                logits = model.forward(binders, targets).detach().cpu().numpy()

                all_scores.extend(logits.tolist())
                all_labels.extend(labels.tolist())

        all_scores = np.array(all_scores, dtype=np.float64)
        all_labels = np.array(all_labels, dtype=np.int64)

        fpr, tpr, _ = metrics.roc_curve(all_labels, all_scores)
        auroc = metrics.roc_auc_score(all_labels, all_scores)
        aupr  = metrics.average_precision_score(all_labels, all_scores)

        return auroc, aupr, fpr, tpr
    
    def validate(self, dataloader, indexes_for_auc=False, auROC_dataset=False):
        self.model.eval()
        running_loss, n_loss = 0.0, 0
        all_scores, all_labels = [], []
    
        with torch.no_grad():
            for one_batch in dataloader:
                loss, logits, labels = self.model.validation_step(one_batch, self.device)
                running_loss += float(loss)
                n_loss += 1
                all_scores.append(logits.detach().float().cpu())
                all_labels.append(labels.detach().long().cpu())
    
        val_loss = running_loss / max(1, n_loss)
    
        if all_scores:
            scores = torch.cat(all_scores).numpy()
            labs   = torch.cat(all_labels).numpy()
            val_auroc = metrics.roc_auc_score(labs, scores)
            val_aupr  = metrics.average_precision_score(labs, scores)
        else:
            scores = np.array([], dtype=np.float32)   # ensure defined
            labs   = np.array([], dtype=np.int64)
            val_auroc = float("nan")
            val_aupr  = float("nan")
    
        return val_loss, val_auroc, val_aupr, scores, labs

    def train_model(self):

        if self.verbose:
            print(f"Training model {str(self.runID)}")
        
        # Pre-training snapshot
        val_loss, val_auroc, val_aupr, scores, labs = self.validate(
            dataloader=self.validation_loader,
            indexes_for_auc=self.test_indexes_for_auROC,
            auROC_dataset=self.test_dataset
        )

        if self.verbose:
            print(
                f'Before training - Val Loss {val_loss:.4f} | '
                f'Val AUROC {val_auroc if val_auroc==val_auroc else float("nan"):.4f} | '
                f'Val AUPR {val_aupr if val_aupr==val_aupr else float("nan"):.4f}'
            )
            
        if self.wandb_tracker:
            self.wandb_tracker.log(
                {"Val Loss": val_loss, "Val AUROC": val_auroc, "Val AUPR": val_aupr},
                step=0
            )
            
        # --- Epoch loop ---
        for epoch in range(1, self.EPOCHS + 1):
            torch.cuda.empty_cache()
            
            train_loss = self.train_one_epoch()
            val_loss, val_auroc, val_aupr, scores, labs = self.validate(
                dataloader=self.validation_loader,
                indexes_for_auc=self.test_indexes_for_auROC,
                auROC_dataset=self.test_dataset
            )
    
            if self.verbose and (epoch % self.print_frequency_loss == 0):
                print(
                    f'EPOCH {epoch} - Train Loss {train_loss:.4f} | '
                    f'Val Loss {val_loss:.4f} | Val AUROC {val_auroc if val_auroc==val_auroc else float("nan"):.4f} | '
                    f'Val AUPR {val_aupr if val_aupr==val_aupr else float("nan"):.4f}'
                )

            if self.scheduler is not None and self.wandb_tracker:
                lr = float(self.optimizer.param_groups[0]["lr"])
                self.wandb_tracker.log({"learning_rate": lr}, step=epoch)

            if scores.size and labs.size:
                pos_mask = labs == 1
                neg_mask = labs == 0
                median_pos = float(np.median(scores[pos_mask])) if pos_mask.any() else float("nan")
                median_neg = float(np.median(scores[neg_mask])) if neg_mask.any() else float("nan")
                gap = median_pos - median_neg if np.isfinite(median_pos) and np.isfinite(median_neg) else float("nan")
            else:
                median_pos = median_neg = gap = float("nan")

            if self.wandb_tracker:
                log_items = {
                    "Train Loss": train_loss,
                    "Val Loss": val_loss,
                    "Val AUROC": val_auroc,
                    "Val AUPR": val_aupr,
                    "val_pos_median_logit": median_pos,
                    "val_neg_median_logit": median_neg,
                    "val_logit_gap": gap,
                }
                self.wandb_tracker.log(log_items, step=epoch)
    
        if self.wandb_tracker:
            self.wandb_tracker.finish()

In [35]:
class PairListDataset(Dataset):
    def __init__(self, examples, weights=None):
        self.examples = examples
        self.weights = weights  # optional per-sample weights (list/array)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        b, t, y = self.examples[idx]
        b = torch.as_tensor(b, dtype=torch.float32)
        t = torch.as_tensor(t, dtype=torch.float32)
        y = torch.tensor(float(y), dtype=torch.float32)
        if self.weights is None:
            return b, t, y
        else:
            w = torch.tensor(float(self.weights[idx]), dtype=torch.float32)
            return b, t, y, w

In [36]:
batch_size = 20
learning_rate = 2e-5
EPOCHS = 5
g = torch.Generator().manual_seed(SEED)

# login once (env var preferred)
if use_wandb:
    import wandb
    wandb.login()

for i in range(len(cv_splits)):

    val_target_name = "_".join(cv_splits[i][0])
    
    # NEW model per split
    model = MiniCLIP_w_transformer_crossattn()
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)
    
    # validation
    val_targets = cv_splits[i][0]
    val_idx = interaction_df.loc[interaction_df.target_id_mod.isin(val_targets)].index.tolist()
    val_binders = [ALL_btl[idx] for idx in val_idx]
    val_binders = PairListDataset(val_binders)

    # training
    train_targets = cv_splits[i][1]

    # train_weights_class = interaction_df.loc[interaction_df.target_id_mod.isin(train_targets), "class_weight"].tolist()
    # train_weights_target = interaction_df.loc[interaction_df.target_id_mod.isin(train_targets), "target_weight"].tolist()
    # train_weights_combined = interaction_df.loc[interaction_df.target_id_mod.isin(train_targets), "combined_weight"].tolist()
    train_weights_combined_boost_positives = interaction_df.loc[interaction_df.target_id_mod.isin(train_targets), "combined_weight_account_pos"].tolist()

    train_idx = interaction_df.loc[interaction_df.target_id_mod.isin(train_targets)].index.tolist()
    train_binders = [ALL_btl[idx] for idx in train_idx]
    train_binders = PairListDataset(train_binders)

    # loaders
    ### no weighting
    # train_loader   = DataLoader(train_binders, batch_size=20, shuffle=True)

    ### class weighting
    # train_sampler = WeightedRandomSampler(weights=train_weights_class, num_samples=train_binders.__len__(), replacement=True)
    # train_loader   = DataLoader(train_binders,   batch_size=20, sampler=train_sampler)

    ### target weighting
    # train_sampler = WeightedRandomSampler(weights=train_weights_target, num_samples=train_binders.__len__(), replacement=True)
    # train_loader   = DataLoader(train_binders,   batch_size=20, sampler=train_sampler)

    ### combined weighting
    train_sampler = WeightedRandomSampler(weights=train_weights_combined_boost_positives, num_samples=train_binders.__len__(), replacement=True, generator = g)
    train_loader   = DataLoader(train_binders,   batch_size=20, sampler=train_sampler)
    
    val_loader   = DataLoader(val_binders,   batch_size=20, shuffle=False, drop_last = False)

    # accelerator
    accelerator = Accelerator()
    device = accelerator.device
    model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(model, optimizer, train_loader, val_loader, scheduler)

    # wandb
    if use_wandb:
        run = wandb.init(
            project="MetaAnal_leave1OutCV",
            name=f"CV_split{i+1}_{val_target_name}_{runID}",
            group="cv_splits",
            config={"learning_rate": learning_rate, "batch_size": batch_size, "epochs": EPOCHS,
                    "architecture": "MiniCLIP_w_transformer_crossattn", "dataset": "Meta analysis"},
        )
        wandb.watch(accelerator.unwrap_model(model), log="all", log_freq=100)
    else:
        run = None

    # train
    training_wrapper = TrainWrapper_MetaAnal(
        model=model,
        training_loader=train_loader,
        validation_loader=val_loader,
        test_dataset=val_binders,   # ok if you truly want “full val”
        optimizer=optimizer,
        scheduler = scheduler,
        EPOCHS=EPOCHS,
        runID=runID,
        device=device,
        model_save_steps=model_save_steps,
        model_save_path=trained_model_dir,
        v=True,
        wandb_tracker=run,
        split_id=i+1
    )
    training_wrapper.train_model()

    # cleanup between splits
    if use_wandb:
        wandb.finish()
    del training_wrapper, model, optimizer, train_loader, val_loader, scheduler
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    try:
        accelerator.free_memory()
    except AttributeError:
        pass

Training model fc69aac7-4547-4dfc-a36a-dc6f33b1b453
Before training - Val Loss 9.9593 | Val AUROC 0.4751 | Val AUPR 0.0824
EPOCH 1 - Train Loss 0.9805 | Val Loss 0.9979 | Val AUROC 0.5594 | Val AUPR 0.0995
EPOCH 2 - Train Loss 0.5084 | Val Loss 0.4123 | Val AUROC 0.5260 | Val AUPR 0.0906
EPOCH 3 - Train Loss 0.4628 | Val Loss 0.5752 | Val AUROC 0.5655 | Val AUPR 0.1010
EPOCH 4 - Train Loss 0.4123 | Val Loss 0.4040 | Val AUROC 0.4840 | Val AUPR 0.0831
EPOCH 5 - Train Loss 0.3987 | Val Loss 0.4304 | Val AUROC 0.5010 | Val AUPR 0.0860


0,1
Train Loss,█▂▂▁▁
Val AUPR,▁▇▄█▁▂
Val AUROC,▁█▅█▂▃
Val Loss,█▁▁▁▁▁
learning_rate,█▆▄▂▁
val_logit_gap,█▅█▁▃
val_neg_median_logit,█▂▄▁▂
val_pos_median_logit,█▂▅▁▂

0,1
Train Loss,0.39865
Val AUPR,0.08601
Val AUROC,0.50101
Val Loss,0.43039
learning_rate,0.0
val_logit_gap,-0.0221
val_neg_median_logit,-0.98652
val_pos_median_logit,-1.00862


Training model fc69aac7-4547-4dfc-a36a-dc6f33b1b453
Before training - Val Loss 3.8928 | Val AUROC 0.5330 | Val AUPR 0.6146
EPOCH 1 - Train Loss 0.7235 | Val Loss 0.7487 | Val AUROC 0.5836 | Val AUPR 0.6744
EPOCH 2 - Train Loss 0.5248 | Val Loss 0.6802 | Val AUROC 0.5472 | Val AUPR 0.6527
EPOCH 3 - Train Loss 0.4647 | Val Loss 0.7691 | Val AUROC 0.5175 | Val AUPR 0.6346
EPOCH 4 - Train Loss 0.4014 | Val Loss 0.7039 | Val AUROC 0.4949 | Val AUPR 0.6095
EPOCH 5 - Train Loss 0.3600 | Val Loss 0.6984 | Val AUROC 0.5317 | Val AUPR 0.6404


0,1
Train Loss,█▄▃▂▁
Val AUPR,▂█▆▄▁▄
Val AUROC,▄█▅▃▁▄
Val Loss,█▁▁▁▁▁
learning_rate,█▆▄▂▁
val_logit_gap,█▅▆▁▅
val_neg_median_logit,█▁█▂▂
val_pos_median_logit,█▁█▁▂

0,1
Train Loss,0.36003
Val AUPR,0.64044
Val AUROC,0.53171
Val Loss,0.69845
learning_rate,0.0
val_logit_gap,0.00912
val_neg_median_logit,0.46494
val_pos_median_logit,0.47406


Training model fc69aac7-4547-4dfc-a36a-dc6f33b1b453
Before training - Val Loss 9.1940 | Val AUROC 1.0000 | Val AUPR 1.0000
EPOCH 1 - Train Loss 0.7077 | Val Loss 0.6807 | Val AUROC 0.9889 | Val AUPR 0.9167
EPOCH 2 - Train Loss 0.5026 | Val Loss 0.3102 | Val AUROC 0.9667 | Val AUPR 0.6389
EPOCH 3 - Train Loss 0.4538 | Val Loss 0.4582 | Val AUROC 0.3333 | Val AUPR 0.0874
EPOCH 4 - Train Loss 0.3895 | Val Loss 0.4145 | Val AUROC 0.5111 | Val AUPR 0.1176
EPOCH 5 - Train Loss 0.3553 | Val Loss 0.3031 | Val AUROC 0.6778 | Val AUPR 0.1693


0,1
Train Loss,█▄▃▂▁
Val AUPR,█▇▅▁▁▂
Val AUROC,███▁▃▅
Val Loss,█▁▁▁▁▁
learning_rate,█▆▄▂▁
val_logit_gap,▆█▁▂▄
val_neg_median_logit,█▄▅▄▁
val_pos_median_logit,█▆▃▃▁

0,1
Train Loss,0.35533
Val AUPR,0.16934
Val AUROC,0.67778
Val Loss,0.3031
learning_rate,0.0
val_logit_gap,0.12218
val_neg_median_logit,-2.33522
val_pos_median_logit,-2.21303


Training model fc69aac7-4547-4dfc-a36a-dc6f33b1b453
Before training - Val Loss 9.7368 | Val AUROC 0.4444 | Val AUPR 0.0852
EPOCH 1 - Train Loss 0.6761 | Val Loss 4.1571 | Val AUROC 0.3765 | Val AUPR 0.0759
EPOCH 2 - Train Loss 0.4687 | Val Loss 2.7565 | Val AUROC 0.3284 | Val AUPR 0.0706
EPOCH 3 - Train Loss 0.4384 | Val Loss 2.7837 | Val AUROC 0.3259 | Val AUPR 0.0704
EPOCH 4 - Train Loss 0.3786 | Val Loss 1.9100 | Val AUROC 0.3444 | Val AUPR 0.0724
EPOCH 5 - Train Loss 0.3229 | Val Loss 1.7231 | Val AUROC 0.3346 | Val AUPR 0.0713


0,1
Train Loss,█▄▃▂▁
Val AUPR,█▄▁▁▂▁
Val AUROC,█▄▁▁▂▂
Val Loss,█▃▂▂▁▁
learning_rate,█▆▄▂▁
val_logit_gap,█▇▁▃▃
val_neg_median_logit,█▅▄▂▁
val_pos_median_logit,█▅▃▂▁

0,1
Train Loss,0.32294
Val AUPR,0.07128
Val AUROC,0.33457
Val Loss,1.7231
learning_rate,0.0
val_logit_gap,-0.95816
val_neg_median_logit,0.73724
val_pos_median_logit,-0.22092


Training model fc69aac7-4547-4dfc-a36a-dc6f33b1b453
Before training - Val Loss 7.9645 | Val AUROC 0.8571 | Val AUPR 0.4869
EPOCH 1 - Train Loss 0.7152 | Val Loss 0.3003 | Val AUROC 0.9558 | Val AUPR 0.7878
EPOCH 2 - Train Loss 0.5453 | Val Loss 0.4484 | Val AUROC 0.7959 | Val AUPR 0.3428
EPOCH 3 - Train Loss 0.4643 | Val Loss 0.4808 | Val AUROC 0.7041 | Val AUPR 0.2690
EPOCH 4 - Train Loss 0.4189 | Val Loss 0.5275 | Val AUROC 0.7449 | Val AUPR 0.3023
EPOCH 5 - Train Loss 0.3765 | Val Loss 0.5610 | Val AUROC 0.7211 | Val AUPR 0.2697


0,1
Train Loss,█▄▃▂▁
Val AUPR,▄█▂▁▁▁
Val AUROC,▅█▄▁▂▁
Val Loss,█▁▁▁▁▁
learning_rate,█▆▄▂▁
val_logit_gap,█▅▁▆▃
val_neg_median_logit,█▃▆▄▁
val_pos_median_logit,█▃▃▄▁

0,1
Train Loss,0.37652
Val AUPR,0.2697
Val AUROC,0.72109
Val Loss,0.56099
learning_rate,0.0
val_logit_gap,0.89109
val_neg_median_logit,-3.42676
val_pos_median_logit,-2.53567


Training model fc69aac7-4547-4dfc-a36a-dc6f33b1b453
Before training - Val Loss 9.1998 | Val AUROC 0.3201 | Val AUPR 0.1611
EPOCH 1 - Train Loss 0.7053 | Val Loss 0.7713 | Val AUROC 0.3021 | Val AUPR 0.1680
EPOCH 2 - Train Loss 0.4930 | Val Loss 0.7667 | Val AUROC 0.3118 | Val AUPR 0.1635
EPOCH 3 - Train Loss 0.4355 | Val Loss 0.9893 | Val AUROC 0.3188 | Val AUPR 0.1695
EPOCH 4 - Train Loss 0.3838 | Val Loss 1.1601 | Val AUROC 0.3310 | Val AUPR 0.1687
EPOCH 5 - Train Loss 0.3271 | Val Loss 1.1457 | Val AUROC 0.3310 | Val AUPR 0.1707


0,1
Train Loss,█▄▃▂▁
Val AUPR,▁▆▃▇▇█
Val AUROC,▅▁▃▅██
Val Loss,█▁▁▁▁▁
learning_rate,█▆▄▂▁
val_logit_gap,██▄▁▂
val_neg_median_logit,▇█▃▄▁
val_pos_median_logit,▇█▃▂▁

0,1
Train Loss,0.32706
Val AUPR,0.17067
Val AUROC,0.33102
Val Loss,1.14572
learning_rate,0.0
val_logit_gap,-2.30029
val_neg_median_logit,-2.06
val_pos_median_logit,-4.3603


Training model fc69aac7-4547-4dfc-a36a-dc6f33b1b453
Before training - Val Loss 7.5294 | Val AUROC 0.3799 | Val AUPR 0.1421
EPOCH 1 - Train Loss 0.6821 | Val Loss 0.6457 | Val AUROC 0.4546 | Val AUPR 0.1736
EPOCH 2 - Train Loss 0.4927 | Val Loss 0.8434 | Val AUROC 0.4964 | Val AUPR 0.1994
EPOCH 3 - Train Loss 0.4354 | Val Loss 1.2821 | Val AUROC 0.4433 | Val AUPR 0.1582
EPOCH 4 - Train Loss 0.3793 | Val Loss 1.0494 | Val AUROC 0.4809 | Val AUPR 0.1816
EPOCH 5 - Train Loss 0.3466 | Val Loss 0.7119 | Val AUROC 0.4912 | Val AUPR 0.1826


0,1
Train Loss,█▄▃▂▁
Val AUPR,▁▅█▃▆▆
Val AUROC,▁▅█▅▇█
Val Loss,█▁▁▂▁▁
learning_rate,█▆▄▂▁
val_logit_gap,▁█▅▅▆
val_neg_median_logit,▁▄█▆▁
val_pos_median_logit,▁▅█▆▂

0,1
Train Loss,0.34655
Val AUPR,0.1826
Val AUROC,0.49124
Val Loss,0.71186
learning_rate,0.0
val_logit_gap,-0.03559
val_neg_median_logit,-0.16405
val_pos_median_logit,-0.19963


Training model fc69aac7-4547-4dfc-a36a-dc6f33b1b453
Before training - Val Loss 11.1430 | Val AUROC 0.4435 | Val AUPR 0.0761
EPOCH 1 - Train Loss 0.7023 | Val Loss 0.4537 | Val AUROC 0.3212 | Val AUPR 0.0558
EPOCH 2 - Train Loss 0.5074 | Val Loss 0.7502 | Val AUROC 0.3389 | Val AUPR 0.0577
EPOCH 3 - Train Loss 0.4355 | Val Loss 0.8657 | Val AUROC 0.3595 | Val AUPR 0.0621
EPOCH 4 - Train Loss 0.3796 | Val Loss 0.8078 | Val AUROC 0.3539 | Val AUPR 0.0605
EPOCH 5 - Train Loss 0.3519 | Val Loss 0.8418 | Val AUROC 0.3847 | Val AUPR 0.0629


0,1
Train Loss,█▄▃▂▁
Val AUPR,█▁▂▃▃▃
Val AUROC,█▁▂▃▃▅
Val Loss,█▁▁▁▁▁
learning_rate,█▆▄▂▁
val_logit_gap,█▅▇▁▅
val_neg_median_logit,▂█▅▄▁
val_pos_median_logit,▃█▆▃▁

0,1
Train Loss,0.35192
Val AUPR,0.06294
Val AUROC,0.38469
Val Loss,0.84175
learning_rate,0.0
val_logit_gap,-0.38573
val_neg_median_logit,-1.43784
val_pos_median_logit,-1.82357


Training model fc69aac7-4547-4dfc-a36a-dc6f33b1b453
Before training - Val Loss 7.8833 | Val AUROC 1.0000 | Val AUPR 1.0000
EPOCH 1 - Train Loss 0.6764 | Val Loss 0.5927 | Val AUROC 0.4000 | Val AUPR 0.1099
EPOCH 2 - Train Loss 0.5024 | Val Loss 0.4847 | Val AUROC 0.9500 | Val AUPR 0.5833
EPOCH 3 - Train Loss 0.4444 | Val Loss 0.3751 | Val AUROC 1.0000 | Val AUPR 1.0000
EPOCH 4 - Train Loss 0.3839 | Val Loss 0.2600 | Val AUROC 1.0000 | Val AUPR 1.0000
EPOCH 5 - Train Loss 0.3447 | Val Loss 0.2376 | Val AUROC 1.0000 | Val AUPR 1.0000


0,1
Train Loss,█▄▃▂▁
Val AUPR,█▁▅███
Val AUROC,█▁▇███
Val Loss,█▁▁▁▁▁
learning_rate,█▆▄▂▁
val_logit_gap,▁▂▅██
val_neg_median_logit,█▇▇▁▂
val_pos_median_logit,▁▃█▄▅

0,1
Train Loss,0.34469
Val AUPR,1.0
Val AUROC,1.0
Val Loss,0.23756
learning_rate,0.0
val_logit_gap,2.45877
val_neg_median_logit,-2.18918
val_pos_median_logit,0.26959


Training model fc69aac7-4547-4dfc-a36a-dc6f33b1b453
Before training - Val Loss 10.3282 | Val AUROC 0.6088 | Val AUPR 0.1140
EPOCH 1 - Train Loss 0.7487 | Val Loss 0.8408 | Val AUROC 0.5303 | Val AUPR 0.0868
EPOCH 2 - Train Loss 0.5263 | Val Loss 0.8890 | Val AUROC 0.5223 | Val AUPR 0.0717
EPOCH 3 - Train Loss 0.4606 | Val Loss 0.7391 | Val AUROC 0.5693 | Val AUPR 0.0893
EPOCH 4 - Train Loss 0.4201 | Val Loss 0.5707 | Val AUROC 0.5493 | Val AUPR 0.0723
EPOCH 5 - Train Loss 0.3748 | Val Loss 0.5307 | Val AUROC 0.5754 | Val AUPR 0.0766


0,1
Train Loss,█▄▃▂▁
Val AUPR,█▄▁▄▁▂
Val AUROC,█▂▁▅▃▅
Val Loss,█▁▁▁▁▁
learning_rate,█▆▄▂▁
val_logit_gap,▄▁▄▄█
val_neg_median_logit,▇█▆▁▁
val_pos_median_logit,██▆▁▂

0,1
Train Loss,0.37481
Val AUPR,0.07656
Val AUROC,0.57539
Val Loss,0.53074
learning_rate,0.0
val_logit_gap,0.3918
val_neg_median_logit,-0.97267
val_pos_median_logit,-0.58087


Training model fc69aac7-4547-4dfc-a36a-dc6f33b1b453
Before training - Val Loss 11.2007 | Val AUROC 0.5383 | Val AUPR 0.1048
EPOCH 1 - Train Loss 0.7092 | Val Loss 6.0638 | Val AUROC 0.5420 | Val AUPR 0.1050
EPOCH 2 - Train Loss 0.4870 | Val Loss 4.7308 | Val AUROC 0.5235 | Val AUPR 0.0982
EPOCH 3 - Train Loss 0.4341 | Val Loss 3.9553 | Val AUROC 0.5160 | Val AUPR 0.0960
EPOCH 4 - Train Loss 0.3790 | Val Loss 3.4885 | Val AUROC 0.5333 | Val AUPR 0.1002
EPOCH 5 - Train Loss 0.3628 | Val Loss 3.1657 | Val AUROC 0.4741 | Val AUPR 0.0888


0,1
Train Loss,█▄▂▁▁
Val AUPR,██▅▄▆▁
Val AUROC,██▆▅▇▁
Val Loss,█▄▂▂▁▁
learning_rate,█▆▄▂▁
val_logit_gap,▃▄▁█▅
val_neg_median_logit,█▅▃▂▁
val_pos_median_logit,█▅▃▂▁

0,1
Train Loss,0.36278
Val AUPR,0.08882
Val AUROC,0.47407
Val Loss,3.16566
learning_rate,0.0
val_logit_gap,0.09053
val_neg_median_logit,3.4025
val_pos_median_logit,3.49304


Training model fc69aac7-4547-4dfc-a36a-dc6f33b1b453
Before training - Val Loss 10.5325 | Val AUROC 0.9278 | Val AUPR 0.4323
EPOCH 1 - Train Loss 0.7370 | Val Loss 0.8923 | Val AUROC 0.7833 | Val AUPR 0.2227
EPOCH 2 - Train Loss 0.5163 | Val Loss 0.3665 | Val AUROC 0.8056 | Val AUPR 0.2255
EPOCH 3 - Train Loss 0.4466 | Val Loss 0.2652 | Val AUROC 0.7556 | Val AUPR 0.1872
EPOCH 4 - Train Loss 0.4083 | Val Loss 0.2685 | Val AUROC 0.7778 | Val AUPR 0.2083
EPOCH 5 - Train Loss 0.3595 | Val Loss 0.2801 | Val AUROC 0.7389 | Val AUPR 0.1800


0,1
Train Loss,█▄▃▂▁
Val AUPR,█▂▂▁▂▁
Val AUROC,█▃▃▂▂▁
Val Loss,█▁▁▁▁▁
learning_rate,█▆▄▂▁
val_logit_gap,▁▅▃██
val_neg_median_logit,█▅▃▁▁
val_pos_median_logit,█▅▂▁▁

0,1
Train Loss,0.3595
Val AUPR,0.18001
Val AUROC,0.73889
Val Loss,0.28012
learning_rate,0.0
val_logit_gap,0.84735
val_neg_median_logit,-3.37515
val_pos_median_logit,-2.52779


Training model fc69aac7-4547-4dfc-a36a-dc6f33b1b453
Before training - Val Loss 9.6539 | Val AUROC 0.4106 | Val AUPR 0.1174
EPOCH 1 - Train Loss 0.7175 | Val Loss 0.6069 | Val AUROC 0.4267 | Val AUPR 0.1240
EPOCH 2 - Train Loss 0.5019 | Val Loss 1.2884 | Val AUROC 0.3996 | Val AUPR 0.1331
EPOCH 3 - Train Loss 0.4318 | Val Loss 1.0117 | Val AUROC 0.4739 | Val AUPR 0.1569
EPOCH 4 - Train Loss 0.3869 | Val Loss 0.8364 | Val AUROC 0.5231 | Val AUPR 0.1912
EPOCH 5 - Train Loss 0.3363 | Val Loss 0.8263 | Val AUROC 0.5512 | Val AUPR 0.2610


0,1
Train Loss,█▄▃▂▁
Val AUPR,▁▁▂▃▅█
Val AUROC,▂▂▁▄▇█
Val Loss,█▁▂▁▁▁
learning_rate,█▆▄▂▁
val_logit_gap,▁▂▅▅█
val_neg_median_logit,▁█▆▄▃
val_pos_median_logit,▁█▇▆▆

0,1
Train Loss,0.33635
Val AUPR,0.26104
Val AUROC,0.5512
Val Loss,0.82628
learning_rate,0.0
val_logit_gap,0.26613
val_neg_median_logit,0.17516
val_pos_median_logit,0.44128
