In [1]:
import uuid, sys, os
import pandas as pd
import numpy as np
from tqdm import tqdm
import ast
import math
import random

from sklearn import metrics
from scipy import stats
from collections import Counter

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
torch.cuda.set_device(0)  # 0 == "first visible" -> actually GPU 2 on the node
print(torch.cuda.get_device_name(0))

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, WeightedRandomSampler
import pytorch_lightning as pl
from torch.optim import AdamW

torch.manual_seed(0)

from accelerate import Accelerator
torch.cuda.empty_cache()
import training_utils.partitioning_utils as pat_utils
from tqdm import trange

Tesla V100-SXM2-32GB


  warn(
  _torch_pytree._register_pytree_node(


In [2]:
import requests
# Reachability probe for the W&B API. The status code is only displayed when
# this expression is the last one in the cell, so here it merely verifies that
# the request itself does not raise.
requests.get("https://api.wandb.ai/status").status_code

import wandb
# SECURITY: the API key must never be committed to the notebook. It is read
# from the WANDB_API_KEY environment variable instead. Any key that was
# previously stored in this cell must be treated as leaked and rotated in the
# W&B account settings.
_wandb_api_key = os.environ.get("WANDB_API_KEY")
# Force a re-login only when an explicit key is supplied; otherwise fall back
# to the credentials wandb already has stored (e.g. in ~/.netrc).
wandb.login(key=_wandb_api_key, relogin=_wandb_api_key is not None)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /zhome/c9/0/203261/.netrc
[34m[1mwandb[0m: Currently logged in as: [33ms232958[0m ([33ms232958-danmarks-tekniske-universitet-dtu[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
### Set all random seeds so that weight initialisation is reproducible

def set_seed(seed: int = 42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    all CUDA devices), switches cuDNN to deterministic mode and exports
    ``PYTHONHASHSEED``.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Exported for completeness; Python reads PYTHONHASHSEED at interpreter
    # start-up, so this only influences processes launched afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Python, NumPy, PyTorch (CPU), PyTorch (current GPU), PyTorch (all GPUs).
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN: prefer deterministic kernels and disable the auto-tuner.
    for flag_name, flag_value in (("deterministic", True), ("benchmark", False)):
        setattr(torch.backends.cudnn, flag_name, flag_value)

SEED = 0
set_seed(SEED)

In [4]:
# Switch the working directory to the project drafts folder.
# NOTE(review): this is an absolute, user-specific path; the notebook will not
# run on another account or machine without editing it.
os.chdir("/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts")
# print(os.getcwd())

print("PyTorch:", torch.__version__)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print("Current location:", os.getcwd())

PyTorch: 2.9.1+cu128
Using device: cuda
Current location: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts


In [5]:
# Model parameters
# NOTE(review): memory_verbose, model_save_steps, train_frac and test_frac are
# not used in the cells visible here; presumably they are consumed by the
# training loop further down -- confirm.
memory_verbose = False
use_wandb = True # Used to track loss in real-time without printing
model_save_steps = 2
train_frac = 1.0
test_frac = 1.0

# Feature size of the pre-computed per-residue embeddings; used as the model's
# embed_dimension when the model is instantiated.
embedding_dimension = 512 #| 960 | 1152
# Passed to the model as num_recycles (number of refinement iterations).
number_of_recycles = 2
# Sentinel value for padded rows; matches the default pad value of the dataset
# classes and of create_key_padding_mask.
padding_value = -5000

In [6]:
# ## Training variables
# A fresh random UUID is drawn every time this cell runs, so each execution
# gets its own output directory name.
runID = uuid.uuid4()

## Output path
# NOTE(review): absolute, user-specific path; the directory is not created here.
trained_model_dir = f"/work3/s232958/data/trained/original_architecture/{runID}"

def print_mem_consumption():
    """Print GPU memory statistics (in GB) for CUDA device 0.

    Reports the device's total memory, the memory PyTorch has reserved from
    CUDA, the part of that pool currently allocated to tensors, and the slack
    inside the pool (reserved minus allocated).

    When CUDA is unavailable a short notice is printed instead of raising.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available; no GPU memory statistics to report.")
        return

    total_bytes = torch.cuda.get_device_properties(0).total_memory
    reserved_bytes = torch.cuda.memory_reserved(0)
    allocated_bytes = torch.cuda.memory_allocated(0)
    # Reserved but not currently allocated: free space inside PyTorch's pool.
    pool_free_bytes = reserved_bytes - allocated_bytes

    print("Total memory: ", total_bytes/1e9)
    print("Reserved memory: ", reserved_bytes/1e9)
    # True division, like the other lines: the former floor division truncated
    # every value below 1 GB to 0.0 and made the figures inconsistent.
    print("Allocated memory: ", allocated_bytes/1e9)
    print("Free memory: ", pool_free_bytes/1e9)
print_mem_consumption()

Total memory:  34.072559616
Reserved memory:  0.0
Allocated memory:  0.0
Free memory:  0.0


### Loading PPint dataframe

In [7]:
# Load the train/test interface tables (with PDB-derived sequence lengths).
Df_train = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_train_w_pbd_lens.csv",index_col=0).reset_index(drop=True)
Df_test = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_test_w_pbd_lens.csv",index_col=0).reset_index(drop=True)

# Derive chain-level identifiers with vectorised string operations instead of
# per-row iterrows() loops:
#   target_chain / binder_chain : first five characters of the ID plus its last
#                                 character, e.g. "6IDB_0_A" -> "6IDB_A"
#   target_binder_id            : "<ID1>_<ID2>", the key used for the merge
#                                 below and for dataset look-ups
for _ppint_df in (Df_train, Df_test):
    _ppint_df["target_chain"] = _ppint_df["ID1"].str[:5] + _ppint_df["ID1"].str[-1]
    _ppint_df["binder_chain"] = _ppint_df["ID2"].str[:5] + _ppint_df["ID2"].str[-1]
    _ppint_df["target_binder_id"] = _ppint_df["ID1"].astype(str) + "_" + _ppint_df["ID2"].astype(str)
del _ppint_df

# The smaller tables carry the boolean "dimer" annotation.
Df_train_small = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_train.csv",index_col=0).reset_index(drop=True)
Df_test_small = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_test.csv",index_col=0).reset_index(drop=True)

# Attach the "dimer" flag. The inner join drops rows without a match.
# NOTE(review): duplicated target_binder_id values on either side would
# multiply rows here -- confirm the key is unique.
Df_train = pd.merge(Df_train, Df_train_small[["target_binder_id", "dimer"]], on="target_binder_id", how="inner")
Df_test = pd.merge(Df_test, Df_test_small[["target_binder_id", "dimer"]], on="target_binder_id", how="inner")

Df_train

Unnamed: 0,interface_id,PDB,ID1,ID2,seq_target,seq_target_len,seq_pdb_target,pdb_target_len,target_chain,seq_binder,seq_binder_len,seq_pdb_binder,pdb_binder_len,binder_chain,pdb_path,target_binder_id,dimer
0,6IDB_0,6IDB,6IDB_0_A,6IDB_0_B,DKICLGHHAVSNGTKVNTLTERGVEVVNATETVERTNIPRICSKGK...,317,DKICLGHHAVSNGTKVNTLTERGVEVVNATETVERTNIPRICSKGK...,317,6IDB_A,GLFGAIAGFIENGWEGLIDGWYGFRHQNAQGEGTAADYKSTQSAID...,172,GLFGAIAGFIENGWEGLIDGWYGFRHQNAQGEGTAADYKSTQSAID...,172,6IDB_B,6idb.pdb.gz,6IDB_0_A_6IDB_0_B,False
1,2WZP_3,2WZP,2WZP_3_D,2WZP_3_G,VQLQESGGGLVQAGGSLRLSCTASRRTGSNWCMGWFRQLAGKEPEL...,122,VQLQESGGGLVQAGGSLRLSCTASRRTGSNWCMGWFRQLAGKEPEL...,122,2WZP_D,TIKNFTFFSPNSTEFPVGSNNDGKLYMMLTGMDYRTIRRKDWSSPL...,266,TIKNFTFFSPNSTEFPVGSNNDGKLYMMLTGMDYRTIRRKDWSSPL...,266,2WZP_G,2wzp.pdb.gz,2WZP_3_D_2WZP_3_G,False
2,1ZKP_0,1ZKP,1ZKP_0_A,1ZKP_0_C,LYFQSNAKTVVGFWGGFPEAGEATSGYLFEHDGFRLLVDCGSGVLA...,246,LYFQSNAMKMTVVGFWGGFPEAGEATSGYLFEHDGFRLLVDCGSGV...,251,1ZKP_A,AKTVVGFWGGFPEAGEATSGYLFEHDGFRLLVDCGSGVLAQLQKYI...,240,AMKMTVVGFWGGFPEAGEATSGYLFEHDGFRLLVDCGSGVLAQLQK...,245,1ZKP_C,1zkp.pdb.gz,1ZKP_0_A_1ZKP_0_C,True
3,6GRH_3,6GRH,6GRH_3_C,6GRH_3_D,SKHELSLVEVTHYTDPEVLAIVKDFHVRGNFASLPEFAERTFVSAV...,266,SKHELSLVEVTHYTDPEVLAIVKDFHVRGNFASLPEFAERTFVSAV...,266,6GRH_C,MINVYSNLMSAWPATMAMSPKLNRNMPTFSQIWDYERITPASAAGE...,396,MINVYSNLMSAWPATMAMSPKLNRNMPTFSQIWDYERITPASAAGE...,396,6GRH_D,6grh.pdb.gz,6GRH_3_C_6GRH_3_D,False
4,8R57_1,8R57,8R57_1_M,8R57_1_f,DLMTALQLVMKKSSAHDGLVKGLREAAKAIEKHAAQICVLAEDCDQ...,118,DLMTALQLVMKKSSAHDGLVKGLREAAKAIEKHAAQICVLAEDCDQ...,118,8R57_M,PKKQKHKHKKVKLAVLQFYKVDDATGKVTRLRKECPNADCGAGTFM...,64,PKKQKHKHKKVKLAVLQFYKVDDATGKVTRLRKECPNADCGAGTFM...,64,8R57_f,8r57.pdb.gz,8R57_1_M_8R57_1_f,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1972,4YO8_0,4YO8,4YO8_0_A,4YO8_0_B,HENLYFQGVQKIGILGAMREEITPILELFGVDFEEIPLGGNVFHKG...,238,HENLYFQGVQKIGILGAMREEITPILELFGVDFEEIPLGGNVFHKG...,238,4YO8_A,HHHHHENLYFQGVQKIGILGAMREEITPILELFGVDFEEIPLGGNV...,242,HHHHHENLYFQGVQKIGILGAMREEITPILELFGVDFEEIPLGGNV...,242,4YO8_B,4yo8.pdb.gz,4YO8_0_A_4YO8_0_B,True
1973,3CKI_0,3CKI,3CKI_0_A,3CKI_0_B,DPMKNTCKLLVVADHRFYRYMGRGEESTTTNYLIELIDRVDDIYRN...,256,DPMKNTCKLLVVADHRFYRYMGRGEESTTTNYLIELIDRVDDIYRN...,256,3CKI_A,CTCSPSHPQDAFCNSDIVIRAKVVGKKLVKEGPFGTLVYTIKQMKM...,121,CTCSPSHPQDAFCNSDIVIRAKVVGKKLVKEGPFGTLVYTIKQMKM...,121,3CKI_B,3cki.pdb.gz,3CKI_0_A_3CKI_0_B,False
1974,7MHY_1,7MHY,7MHY_1_M,7MHY_1_N,QVQLRQSGAELAKPGASVKMSCKASGYTFTNYWLHWIKQRPGQGLE...,118,QVQLRQSGAELAKPGASVKMSCKASGYTFTNYWLHWIKQRPGQGLE...,118,7MHY_M,DVLMTQTPLSLPVSLGDQVSISCRSSQSIVHNTYLEWYLQKPGQSP...,109,DVLMTQTPLSLPVSLGDQVSISCRSSQSIVHNTYLEWYLQKPGQSP...,109,7MHY_N,7mhy.pdb.gz,7MHY_1_M_7MHY_1_N,False
1975,7MHY_2,7MHY,7MHY_2_O,7MHY_2_P,IQLVQSGPELVKISCKASGYTFTNYGMNWVRQAPGKGLKWMGWINT...,100,IQLVQSGPELVKISCKASGYTFTNYGMNWVRQAPGKGLKWMGWINT...,100,7MHY_O,VLMTQTPLSLPVSISCRSSQSIVHSNGNTYLEWYLQKPGQSPKLLI...,94,VLMTQTPLSLPVSISCRSSQSIVHSNGNTYLEWYLQKPGQSPKLLI...,94,7MHY_P,7mhy.pdb.gz,7MHY_2_O_7MHY_2_P,False


In [8]:
class CLIP_PPint_class(Dataset):
    """In-memory dataset of (binder, target) embedding pairs for PPint interfaces.

    Every row of ``dframe`` describes one interface. For each row the
    per-residue embeddings of the target and binder chains are loaded from
    ``<path>/<PDB>_<chain>.npy``, checked against the expected length and
    feature size, and padded along the sequence axis to the longest
    target/binder in the frame. All samples are kept in memory as NumPy
    arrays and converted to tensors on access.

    Args:
        dframe: DataFrame with the columns ``target_binder_id`` (formatted like
            ``7S8T_5_F_7S8T_5_G``), ``pdb_binder_len`` and ``pdb_target_len``.
        path: Directory containing the ``.npy`` embedding files.
        embedding_dim: Expected feature size of every embedding.
        embedding_pad_value: Value written into padded rows.

    Raises:
        AssertionError: If an embedding's length is not the tabulated chain
            length + 2.
        ValueError: If an embedding's feature size differs from
            ``embedding_dim``.
    """

    def __init__(
        self,
        dframe,
        path,
        embedding_dim=512,
        embedding_pad_value=-5000.0,
    ):
        super().__init__()
        self.dframe = dframe.copy()
        self.embedding_dim = int(embedding_dim)
        self.emb_pad = float(embedding_pad_value)

        # Padded lengths: longest chain in the frame plus the two extra rows
        # every stored embedding carries (checked by the asserts below).
        self.max_blen = self.dframe["pdb_binder_len"].max()+2
        self.max_tlen = self.dframe["pdb_target_len"].max()+2

        # paths
        self.encoding_path = path

        # index & storage
        self.dframe = self.dframe.set_index("target_binder_id")
        self.accessions = self.dframe.index.astype(str).tolist()
        self.name_to_row = {name: i for i, name in enumerate(self.accessions)}
        self.samples = []

        # Expected lengths are read positionally, once, instead of through a
        # label-based `.loc[accession]` lookup per sample: this is faster and
        # stays correct even if an accession occurs more than once.
        binder_lens = self.dframe["pdb_binder_len"].tolist()
        target_lens = self.dframe["pdb_target_len"].tolist()

        for accession, binder_len, target_len in tqdm(
            zip(self.accessions, binder_lens, target_lens),
            total=len(self.accessions),
            desc="#Loading ESM2 embeddings and contacts",
        ):
            parts = accession.split("_")  # e.g. accession 7S8T_5_F_7S8T_5_G
            tgt_id = parts[0]+"_"+parts[2]
            bnd_id = parts[3]+"_"+parts[5]

            # load embeddings
            t_emb = np.load(os.path.join(self.encoding_path, f"{tgt_id}.npy"))     # [Lt, D]
            b_emb = np.load(os.path.join(self.encoding_path, f"{bnd_id}.npy"))     # [Lb, D]

            assert b_emb.shape[0] == binder_len+2, (
                f"{accession}: binder embedding has {b_emb.shape[0]} rows, expected {binder_len+2}"
            )
            assert t_emb.shape[0] == target_len+2, (
                f"{accession}: target embedding has {t_emb.shape[0]} rows, expected {target_len+2}"
            )

            # quick check that the embedding dimension is what it is supposed to be
            if t_emb.shape[1] != self.embedding_dim or b_emb.shape[1] != self.embedding_dim:
                raise ValueError("Embedding dim mismatch with 'embedding_dim'.")

            # pad both embeddings to the fixed dataset-wide lengths
            t_emb = self._pad_or_truncate(t_emb, self.max_tlen)
            b_emb = self._pad_or_truncate(b_emb, self.max_blen)

            self.samples.append((b_emb, t_emb))

    def _pad_or_truncate(self, emb, length):
        """Return ``emb`` with exactly ``length`` rows, padding with ``self.emb_pad`` or cutting."""
        missing = length - emb.shape[0]
        if missing > 0:
            filler = np.full((missing, emb.shape[1]), self.emb_pad, dtype=emb.dtype)
            return np.concatenate([emb, filler], axis=0)
        return emb[:length]

    # ---- Dataset API ----
    def __len__(self):
        """Number of loaded interface pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(binder_emb, target_emb, label)`` for sample ``idx``.

        Both embeddings are float32 tensors; the label is always the scalar
        ``1.0`` because every pair in this dataset is a positive.
        """
        b_arr, t_arr = self.samples[idx]
        binder_emb, target_emb = torch.from_numpy(b_arr).float(), torch.from_numpy(t_arr).float()
        label = torch.tensor(1, dtype=torch.float32)  # single scalar label
        return binder_emb, target_emb, label

    def _get_by_name(self, name):
        """Fetch one sample, or a stacked batch of samples, by accession.

        Args:
            name: A single accession string, or an iterable of accessions.

        Returns:
            For a string, exactly what ``__getitem__`` returns. Otherwise a
            tuple ``(binders, targets, labels)`` stacked along a new leading
            batch axis.
        """
        # Single item -> return exactly what __getitem__ returns
        if isinstance(name, str):
            return self.__getitem__(self.name_to_row[name])

        # Multiple items -> fetch all
        out = [self.__getitem__(self.name_to_row[n]) for n in list(name)]
        b_list, t_list, lbl_list = zip(*out)

        # Items are already tensors, so they can be stacked directly.
        b = torch.stack(b_list, dim=0)   # [B, ...]
        t = torch.stack(t_list, dim=0)   # [B, ...]
        labels = torch.stack(lbl_list)   # [B]

        return b, t, labels

emb_path = "/work3/s232958/data/PPint_DB/esmif_embeddings_noncanonical"

# Use the notebook-level configuration (embedding_dimension, padding_value)
# instead of repeating literals, so the datasets, the padding mask and the
# model cannot silently drift apart.
training_Dataset = CLIP_PPint_class(
    Df_train,
    path=emb_path,
    embedding_dim=embedding_dimension,
    embedding_pad_value=padding_value,
)

testing_Dataset = CLIP_PPint_class(
    Df_test,
    path=emb_path,
    embedding_dim=embedding_dimension,
    embedding_pad_value=padding_value,
)

#Loading ESM2 embeddings and contacts: 100%|███████████████████████████████████████| 1977/1977 [00:10<00:00, 186.23it/s]
#Loading ESM2 embeddings and contacts: 100%|█████████████████████████████████████████| 494/494 [00:02<00:00, 190.16it/s]


In [9]:
### Getting indices of non-dimers
# Row labels of Df_test whose "dimer" flag is False.
indices_non_dimers_val = Df_test[~Df_test["dimer"]].index.tolist()
# Not displayed: only the last expression of a cell is rendered.
indices_non_dimers_val[:5]

### Getting accessions of non-dimers
accessions = [Df_test.loc[index].target_binder_id for index in indices_non_dimers_val]
# Sanity check: fetch the first five non-dimer pairs by name and show their labels.
emb_b, emb_t, labels = testing_Dataset._get_by_name(accessions[:5])
labels

tensor([1., 1., 1., 1., 1.])

### Loading Meta validation dataset

In [10]:
# Load the meta-analysis interaction table. The original "binder_id" and
# "target_id" columns are dropped and replaced by "target_binder_ID"
# (renamed to "binder_id") and "target_id_mod" (renamed to "target_id").
interaction_df = pd.read_csv("/work3/s232958/data/meta_analysis/interaction_df_metaanal_w_pbd_lens.csv").drop(columns = ["binder_id", "target_id"]).rename(columns = {
    "target_id_mod" : "target_id",
    "target_binder_ID" : "binder_id",
})

# Shuffled copy of the interaction table (all rows, fixed random_state for
# reproducibility, fresh 0..N-1 index).
interaction_df_shuffled = interaction_df.sample(frac=1, random_state=0).reset_index(drop=True)
interaction_df_shuffled

Unnamed: 0,binder_chain,target_chains,binder,binder_seq,target_seq,target_id,binder_id,seq_len_binder,seq_len_target,pdb_len_binder,pdb_len_target
0,A,"[""B""]",True,DIVEEAHKLLSRAMSEAMENDDPDKLRRANELYFKLEEALKNNDPK...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_124,62,101,62,101
1,A,"[""B""]",False,SEELVEKVVEEILNSDLSNDQKILETHDRLMELHDQGKISKEEYYK...,LEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYV...,EGFR_2,EGFR_2_149,58,621,58,621
2,A,"[""B""]",False,TINRVFHLHIQGDTEEARKAHEELVEEVRRWAEELAKRLNLTVRVT...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_339,65,101,65,101
3,A,"[""B""]",False,DDLRKVERIASELAFFAAEQNDTKVAFTALELIHQLIRAIFHNDEE...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_1234,64,101,64,101
4,A,"[""B""]",False,DEEVEELEELLEKAEDPRERAKLLRELAKLIRRDPRLRELATEVVA...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_48,65,165,65,165
...,...,...,...,...,...,...,...,...,...,...,...
3527,A,"[""B""]",False,SEDELRELVKEIRKVAEKQGDKELRTLWIEAYDLLASLWYGAADEL...,TNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFK...,SARS_CoV2_RBD,SARS_CoV2_RBD_25,63,195,63,195
3528,A,"[""B""]",False,TEEEILKMLVELTAHMAGVPDVKVEIHNGTLRVTVNGDTREARSVL...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_2027,65,101,65,101
3529,A,"[""B""]",False,VEELKEARKLVEEVLRKKGDQIAEIWKDILEELEQRYQEGKLDPEE...,DYSFSCYSQLEVNGSQHSLTCAFEDPDVNTTNLEFEICGALVEVKC...,IL7Ra,IL7Ra_90,63,193,63,193
3530,A,"[""B""]",False,DAEEEIREIVEKLNDPLLREILRLLELAKEKGDPRLEAELYLAFEK...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_1605,65,101,65,101


In [11]:
class CLIP_Meta_class(Dataset):
    """In-memory dataset of labelled (binder, target) embedding pairs.

    Every row of ``dframe`` is one binder/target pair with a boolean
    ``binder`` label. Binder embeddings are loaded from
    ``<binder_dir>/<binder_id>.npy``; target embeddings from
    ``<target_dir>/<target>.npy``, where ``<target>`` is the binder id with
    its last ``_``-separated field removed (``FGFR2_124`` -> ``FGFR2``).
    Embeddings are validated and padded along the sequence axis to the
    longest binder/target in the frame.

    Args:
        dframe: DataFrame with the columns ``binder_id``, ``binder``,
            ``pdb_len_binder`` and ``pdb_len_target``.
        paths: Pair ``(binder_embedding_dir, target_embedding_dir)``.
        embedding_dim: Expected feature size of every embedding.
        embedding_pad_value: Value written into padded rows.

    Raises:
        AssertionError: If an embedding's length is not the tabulated length + 2.
        ValueError: If an embedding's feature size differs from
            ``embedding_dim``.
    """

    def __init__(
        self,
        dframe,
        paths,
        embedding_dim=512,
        embedding_pad_value=-5000.0
    ):
        super().__init__()
        self.dframe = dframe.copy()
        self.embedding_dim = int(embedding_dim)
        self.emb_pad = float(embedding_pad_value)
        # Padded lengths: longest chain plus the two extra rows every stored
        # embedding carries (checked by the asserts below).
        self.max_blen = self.dframe["pdb_len_binder"].max()+2
        self.max_tlen = self.dframe["pdb_len_target"].max()+2

        # paths
        self.encoding_bpath, self.encoding_tpath = paths

        # index & storage
        self.dframe = self.dframe.set_index("binder_id")
        self.accessions = self.dframe.index.astype(str).tolist()
        self.name_to_row = {name: i for i, name in enumerate(self.accessions)}
        self.samples = []

        # Per-row metadata is read positionally, once, instead of through
        # label-based `.loc[accession]` lookups per sample: this is faster and
        # stays correct even if a binder id occurs more than once.
        binder_flags = self.dframe["binder"].tolist()
        binder_lens = self.dframe["pdb_len_binder"].tolist()
        target_lens = self.dframe["pdb_len_target"].tolist()

        for accession, is_binder, binder_len, target_len in tqdm(
            zip(self.accessions, binder_flags, binder_lens, target_lens),
            total=len(self.accessions),
            desc="#Loading ESM2 embeddings",
        ):
            lbl = torch.tensor(int(is_binder))
            parts = accession.split("_")
            tgt_id = "_".join(parts[:-1])
            bnd_id = accession

            # load embeddings
            t_emb = np.load(os.path.join(self.encoding_tpath, f"{tgt_id}.npy"))     # [Lt, D]
            b_emb = np.load(os.path.join(self.encoding_bpath, f"{bnd_id}.npy"))     # [Lb, D]

            assert b_emb.shape[0] == binder_len+2, (
                f"{accession}: binder embedding has {b_emb.shape[0]} rows, expected {binder_len+2}"
            )
            assert t_emb.shape[0] == target_len+2, (
                f"{accession}: target embedding has {t_emb.shape[0]} rows, expected {target_len+2}"
            )

            # quick check that the embedding dimension is what it is supposed to be
            if t_emb.shape[1] != self.embedding_dim or b_emb.shape[1] != self.embedding_dim:
                raise ValueError("Embedding dim mismatch with 'embedding_dim'.")

            # pad both embeddings to the fixed dataset-wide lengths
            t_emb = self._pad_or_truncate(t_emb, self.max_tlen)
            b_emb = self._pad_or_truncate(b_emb, self.max_blen)

            self.samples.append((b_emb, t_emb, lbl))

    def _pad_or_truncate(self, emb, length):
        """Return ``emb`` with exactly ``length`` rows, padding with ``self.emb_pad`` or cutting."""
        missing = length - emb.shape[0]
        if missing > 0:
            filler = np.full((missing, emb.shape[1]), self.emb_pad, dtype=emb.dtype)
            return np.concatenate([emb, filler], axis=0)
        return emb[:length]

    # ---- Dataset API ----
    def __len__(self):
        """Number of loaded binder/target pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(binder_emb, target_emb, label)`` for sample ``idx``.

        Embeddings are float32 tensors; the label is the integer tensor
        (0 or 1) built from the ``binder`` column.
        """
        b_arr, t_arr, lbls = self.samples[idx]
        binder_emb, target_emb = torch.from_numpy(b_arr).float(), torch.from_numpy(t_arr).float()
        return binder_emb, target_emb, lbls

    def _get_by_name(self, name):
        """Fetch one sample, or a stacked batch of samples, by binder id.

        Args:
            name: A single binder id string, or an iterable of binder ids.

        Returns:
            For a string, exactly what ``__getitem__`` returns. Otherwise a
            tuple ``(binders, targets, labels)`` stacked along a new leading
            batch axis.
        """
        # Single item -> return exactly what __getitem__ returns
        if isinstance(name, str):
            return self.__getitem__(self.name_to_row[name])

        # Multiple items -> fetch all
        out = [self.__getitem__(self.name_to_row[n]) for n in list(name)]
        b_list, t_list, lbl_list = zip(*out)

        # Items are already tensors, so they can be stacked directly.
        b = torch.stack(b_list, dim=0)   # [B, ...]
        t = torch.stack(t_list, dim=0)   # [B, ...]
        labels = torch.stack(lbl_list)   # [B]

        return b, t, labels

bemb_path = "/work3/s232958/data/meta_analysis/esmif_embeddings_binders"
temb_path = "/work3/s232958/data/meta_analysis/esmif_embeddings_targets"

# Use the notebook-level configuration (embedding_dimension, padding_value)
# instead of repeating literals, so the dataset, the padding mask and the model
# cannot silently drift apart.
validation_Dataset = CLIP_Meta_class(
    # interaction_df_shuffled[:len(Df_test)],  # optional: subsample to the PPint test-set size
    interaction_df_shuffled,
    paths=[bemb_path, temb_path],
    embedding_dim=embedding_dimension,
    embedding_pad_value=padding_value,
)

#Loading ESM2 embeddings: 100%|████████████████████████████████████████████████████| 3532/3532 [00:17<00:00, 206.93it/s]


In [12]:
# Sanity check: fetch the first five meta-analysis pairs by binder id and show
# their labels. Note that this rebinds emb_b, emb_t and labels.
accessions_Meta = list(interaction_df_shuffled.binder_id)
emb_b, emb_t, labels = validation_Dataset._get_by_name(accessions_Meta[:5])
labels

tensor([1, 0, 0, 0, 0])

### Train model from scratch with 10% of PPint dataset using old architecture (encodings only)

In [13]:
def create_key_padding_mask(embeddings, padding_value=-5000, offset=10):
    """Flag the positions of ``embeddings`` that consist purely of padding.

    A position counts as padding when every feature along the last axis is
    strictly smaller than ``padding_value + offset``.

    Args:
        embeddings: Tensor whose last axis holds the features.
        padding_value: Sentinel value used for padded rows.
        offset: Tolerance added to ``padding_value`` to form the threshold.

    Returns:
        Boolean tensor with the last axis reduced; ``True`` marks padding.
    """
    threshold = padding_value + offset
    below_threshold = torch.lt(embeddings, threshold)
    return torch.all(below_threshold, dim=-1)

def create_mean_of_non_masked(embeddings, padding_mask):
    """Mean-pool every sequence over its non-padded positions.

    Args:
        embeddings: Tensor of shape (batch_size, seq_len, features).
        padding_mask: Boolean tensor of shape (batch_size, seq_len); ``True``
            marks padded positions that are excluded from the mean.

    Returns:
        Tensor of shape (batch_size, features) with one pooled vector per
        sequence.

    Raises:
        ValueError: If a sequence has no unmasked position. (Previously the
            function printed a message and called ``sys.exit(1)``, which
            terminates the whole process/kernel from inside a helper.)
    """
    seq_embeddings = []
    for sample_embeddings, sample_mask in zip(embeddings, padding_mask):  # one batch element at a time
        real_tokens = sample_embeddings[~sample_mask]  # [num_real_tokens, features]
        if len(real_tokens) == 0:
            raise ValueError("You are masking all positions when creating sequence representation")
        seq_embeddings.append(real_tokens.mean(dim=0))  # single vector of shape [features]
    return torch.stack(seq_embeddings)

class MiniCLIP_w_transformer_crossattn(pl.LightningModule):
    """CLIP-style scorer for (binder, target) embedding pairs.

    Both inputs are padded per-residue embeddings of shape
    (batch, seq_len, embed_dimension). They are refined for ``num_recycles``
    rounds by one shared transformer encoder layer followed by one shared
    cross-attention block, mean-pooled over non-padded positions, projected to
    256 dimensions, L2-normalised and compared with a scaled dot product.

    Args:
        padding_value: Sentinel value marking padded rows in the inputs.
        embed_dimension: Feature size of the input embeddings. Must be
            divisible by 8 (the number of attention heads).
        num_recycles: Number of refinement rounds.
    """

    def __init__(self, padding_value = -5000, embed_dimension=embedding_dimension, num_recycles=2):

        super().__init__()
        self.num_recycles = num_recycles  # how many times embeddings are refined with self- and cross-attention
        self.padding_value = padding_value
        # Honour the constructor argument; it used to be overwritten with a
        # hardcoded 512, which silently ignored any other requested size.
        self.embed_dimension = embed_dimension

        self.logit_scale = nn.Parameter(torch.tensor(math.log(1/0.07)))  # ~CLIP init

        self.transformerencoder =  nn.TransformerEncoderLayer(
            d_model=self.embed_dimension,
            nhead=8,
            dropout=0.1,
            batch_first=True,
            dim_feedforward=self.embed_dimension*2
            )

        self.norm = nn.LayerNorm(self.embed_dimension)  # shared pre-norm for every sub-block

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=self.embed_dimension,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        self.prot_embedder = nn.Sequential(
            nn.Linear(self.embed_dimension, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
        )

    def forward(self, pep_input, prot_input, label=None, pep_int_mask=None, prot_int_mask=None, int_prob=None, mem_save=True): # , pep_tokens, prot_tokens
        """Score each (pep_input[i], prot_input[i]) pair.

        Args:
            pep_input: Binder embeddings, (batch, seq_len_binder, embed_dimension).
            prot_input: Target embeddings, (batch, seq_len_target, embed_dimension).
            label, pep_int_mask, prot_int_mask, int_prob: Accepted for
                interface compatibility; not used by this implementation.
            mem_save: When true, ``torch.cuda.empty_cache()`` is called before
                the logits are computed.

        Returns:
            Tensor of shape (batch,) with one logit per pair: the cosine
            similarity of the two projected sequence vectors multiplied by
            ``exp(logit_scale)`` (clamped to at most 100).
        """
        pep_mask = create_key_padding_mask(embeddings=pep_input, padding_value=self.padding_value)
        prot_mask = create_key_padding_mask(embeddings=prot_input, padding_value=self.padding_value)

        # Initialize residual states
        pep_emb = pep_input.clone()
        prot_emb = prot_input.clone()

        for _ in range(self.num_recycles):

            # Self-attention over each sequence (same layer for both partners)
            pep_trans = self.transformerencoder(self.norm(pep_emb), src_key_padding_mask=pep_mask)
            prot_trans = self.transformerencoder(self.norm(prot_emb), src_key_padding_mask=prot_mask)

            # Cross-attention: each partner queries the other one
            pep_cross, _ = self.cross_attn(query=self.norm(pep_trans), key=self.norm(prot_trans), value=self.norm(prot_trans), key_padding_mask=prot_mask)
            prot_cross, _ = self.cross_attn(query=self.norm(prot_trans), key=self.norm(pep_trans), value=self.norm(pep_trans), key_padding_mask=pep_mask)

            # Additive update with residual connection
            pep_emb = pep_emb + pep_cross
            prot_emb = prot_emb + prot_cross

        # Pool over real (non-padded) positions only
        pep_seq_coding = create_mean_of_non_masked(pep_emb, pep_mask)
        prot_seq_coding = create_mean_of_non_masked(prot_emb, prot_mask)

        # Project with the shared MLP and L2-normalise
        pep_seq_coding = F.normalize(self.prot_embedder(pep_seq_coding), dim=-1)
        prot_seq_coding = F.normalize(self.prot_embedder(prot_seq_coding), dim=-1)

        if mem_save:
            torch.cuda.empty_cache()

        scale = torch.exp(self.logit_scale).clamp(max=100.0)
        logits = scale * (pep_seq_coding * prot_seq_coding).sum(dim=-1)

        return logits

    @staticmethod
    def _assemble_logit_matrix(positive_logits, negative_logits, rows, cols, n):
        """Build the n x n score matrix: positives on the diagonal, negatives mirrored off it."""
        # NOTE(review): negatives are computed only for pairs (binder i,
        # target j) with i < j and copied to position [j, i]. This presumes
        # the score of (binder i, target j) equals that of (binder j,
        # target i), which the forward pass does not obviously guarantee --
        # confirm.
        target_device = positive_logits.device
        logit_matrix = torch.zeros((n, n), device=target_device)
        logit_matrix[rows, cols] = negative_logits
        logit_matrix[cols, rows] = negative_logits

        diag_indices = torch.arange(n, device=target_device)
        logit_matrix[diag_indices, diag_indices] = positive_logits.squeeze()
        return logit_matrix

    def training_step(self, batch, device):
        """Compute the training loss for one batch of positive pairs.

        Positives are the aligned pairs of the batch; negatives are the
        mismatched pairs (binder i, target j) with i < j. The loss is the mean
        of the BCE-with-logits on the positives (target 1) and on the
        negatives (target 0).

        Args:
            batch: Tuple ``(binder_embeddings, target_embeddings, labels)``;
                the labels are not used.
            device: Device the embeddings are moved to.

        Returns:
            Scalar loss tensor.
        """
        embedding_pep, embedding_prot, labels = batch
        embedding_pep, embedding_prot = embedding_pep.to(device), embedding_prot.to(device)

        positive_logits = self.forward(embedding_pep, embedding_prot)

        # Negative indexes: all (i, j) with i < j
        rows, cols = torch.triu_indices(embedding_prot.size(0), embedding_prot.size(0), offset=1)

        negative_logits = self(embedding_pep[rows,:,:], embedding_prot[cols,:,:], int_prob=0.0)

        # aligned pairs should be scored as interacting
        positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits))

        # mismatched pairs should be scored as non-interacting
        negative_loss = F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits))

        loss = (positive_loss + negative_loss) / 2

        torch.cuda.empty_cache()
        return loss

    def validation_step_PPint(self, batch, device):
        """Evaluate one batch of positive pairs without gradient tracking.

        Args:
            batch: Tuple ``(binder_embeddings, target_embeddings, labels)``;
                the labels are not used.
            device: Device the embeddings are moved to.

        Returns:
            Tuple ``(loss, top1_accuracy, topk_accuracy)``. ``loss`` is
            computed as in ``training_step``. The accuracies are the fraction
            of targets (columns of the score matrix) whose true binder is the
            best-scoring / among the k best-scoring binders of the batch, with
            k = 3 (capped at the batch size).
        """
        embedding_pep, embedding_prot, labels = batch
        embedding_pep, embedding_prot = embedding_pep.to(device), embedding_prot.to(device)
        n = embedding_pep.size(0)

        with torch.no_grad():

            positive_logits = self(embedding_pep, embedding_prot)
            positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits))

            # Negative indexes: all (i, j) with i < j
            rows, cols = torch.triu_indices(n, n, offset=1)

            negative_logits = self(embedding_pep[rows,:,:], embedding_prot[cols,:,:], int_prob=0.0)
            negative_loss = F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits))

            loss = (positive_loss + negative_loss) / 2

            logit_matrix = self._assemble_logit_matrix(positive_logits, negative_logits, rows, cols, n)

            # The true partner of column j sits in row j.
            labels = torch.arange(n, device=logit_matrix.device)
            peptide_predictions = logit_matrix.argmax(dim=0)
            peptide_accuracy = peptide_predictions.eq(labels).float().mean()

            # topk raises when k exceeds the dimension size, so cap it for small batches.
            k = min(3, n)
            peptide_topk_accuracy = torch.any((logit_matrix.topk(k, dim=0).indices - labels.reshape(1, -1)) == 0, dim=0).sum() / logit_matrix.shape[0]

            return loss, peptide_accuracy, peptide_topk_accuracy

    def validation_step_MetaDataset(self, batch, device):
        """Score one batch of labelled pairs without gradient tracking.

        Args:
            batch: Tuple ``(binder_embeddings, target_embeddings, labels)``
                where the labels are 0/1.
            device: Device all tensors are moved to.

        Returns:
            Tuple ``(logits, loss)``: the per-pair logits and the
            BCE-with-logits loss against the labels.
        """
        embedding_binder, embedding_target, labels = batch
        embedding_binder = embedding_binder.to(device)
        embedding_target = embedding_target.to(device)
        labels = labels.to(device).float()

        with torch.no_grad():
            logits = self.forward(embedding_binder, embedding_target)
            logits = logits.float()
            loss = F.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1))
            return logits, loss

    def calculate_logit_matrix(self,embedding_pep,embedding_prot):
        """Return the (batch x batch) score matrix for a batch of positive pairs.

        Entry [i, i] is the score of the aligned pair i; entry [i, j] with
        i < j is the score of (binder i, target j), which is also written to
        [j, i]. Gradients are tracked unless the caller disables them.
        """
        n = embedding_pep.size(0)
        rows, cols = torch.triu_indices(n, n, offset=1)

        positive_logits = self(embedding_pep, embedding_prot)
        negative_logits = self(embedding_pep[rows,:,:], embedding_prot[cols,:,:], int_prob=0.0)

        return self._assemble_logit_matrix(positive_logits, negative_logits, rows, cols, n)

In [14]:
# Build the model from the notebook-level configuration and move it to the
# device selected earlier (instead of a hardcoded "cuda"), so the padding
# sentinel and the device stay consistent with the datasets and training code.
model = MiniCLIP_w_transformer_crossattn(
    padding_value=padding_value,
    embed_dimension=embedding_dimension,
    num_recycles=number_of_recycles,
).to(device)
model

MiniCLIP_w_transformer_crossattn(
  (transformerencoder): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
    )
    (linear1): Linear(in_features=512, out_features=1024, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=1024, out_features=512, bias=True)
    (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (cross_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
  )
  (prot_embedder): Sequential(
    (0): Linear(in_features=512, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=Tr

### Training loop

In [15]:
def batch(iterable, n=1):
    """Yield consecutive slices of at most ``n`` items from a sliceable sequence.

    ``iterable`` must support ``len()`` and slicing (e.g. a list of
    observation IDs). The final slice is shorter when the length is not a
    multiple of ``n``.
    """
    total = len(iterable)
    for start in range(0, total, n):
        stop = start + n
        yield iterable[start:(stop if stop < total else total)]

class TrainWrapper():
    """Drives training and per-epoch evaluation of a model on three data sources.

    The wrapped model is expected to expose ``training_step(batch, device)``,
    ``validation_step_MetaDataset(batch, device)``,
    ``validation_step_PPint(batch, device)`` and
    ``calculate_logit_matrix(binder_emb, target_emb)``.

    Evaluation uses:
      * ``val_loader``  -> the "Meta" metrics (loss, accuracy, AUROC, AUPR),
      * ``test_loader`` -> the "PPint" metrics (loss, accuracy, top-k accuracy),
      * ``test_dataset`` / ``test_df`` / ``test_indexes_for_auROC`` -> optional
        AUROC/AUPR over diagonal vs. upper-triangle logits.
    """

    def __init__(self, 
                 model, 
                 train_loader,
                 test_loader,
                 val_loader,
                 test_df,
                 test_dataset,
                 optimizer, 
                 epochs, 
                 runID, 
                 device, 
                 test_indexes_for_auROC = None,
                 auROC_batch_size=10, 
                 model_save_steps=False, 
                 model_save_path=False, 
                 v=False, 
                 wandb_tracker=False):
        """Store the collaborators and settings used by the training loop.

        Args:
            model: model exposing the step methods listed in the class docstring.
            train_loader: iterable of training batches (must support ``len()``).
            test_loader: iterable of batches for the PPint evaluation.
            val_loader: iterable of batches for the Meta evaluation; each batch
                unpacks into three items, the last being the labels.
            test_df: frame indexed with ``.loc`` that provides ``target_binder_id``.
            test_dataset: object exposing ``_get_by_name(list_of_ids)``.
            optimizer: optimizer with ``zero_grad()`` and ``step()``.
            epochs: number of training epochs.
            runID: identifier printed at the start of training.
            device: device handed to the model's step methods.
            test_indexes_for_auROC: index labels of ``test_df`` to score, or
                ``None`` to skip the extra AUROC/AUPR evaluation.
            auROC_batch_size: chunk size used for that evaluation.
            model_save_steps: stored; not used by the current training loop.
            model_save_path: stored as ``trained_model_dir``; not used currently.
            v: verbose flag; enables console reporting of metrics.
            wandb_tracker: falsy to disable logging, otherwise an object with
                ``log(dict, step=...)`` and ``finish()``.
        """
        self.model = model 
        self.training_loader = train_loader
        self.testing_loader = test_loader
        self.validation_loader = val_loader
        self.test_dataset = test_dataset
        self.test_df = test_df
        self.auROC_batch_size = auROC_batch_size
        
        self.EPOCHS = epochs
        self.optimizer = optimizer
        self.device = device
        
        self.wandb_tracker = wandb_tracker
        self.model_save_steps = model_save_steps
        self.verbose = v
        self.best_vloss = 1_000_000
        self.runID = runID
        self.trained_model_dir = model_save_path
        self.print_frequency_loss = 1
        self.test_indexes_for_auROC = test_indexes_for_auROC

    def train_one_epoch(self):
        """Run one optimisation pass over the training loader.

        Batches whose first tensor has a single row are skipped.

        Returns:
            Mean loss over the batches that were actually used, or ``nan`` when
            every batch was skipped (or the loader is empty).
        """
        self.model.train() 
        running_loss = 0.0
        used_batches = 0

        for batch_data in tqdm(self.training_loader, total=len(self.training_loader), desc="Running through epoch"):
            
            if batch_data[0].size(0) == 1: 
                continue
            
            self.optimizer.zero_grad()
            loss = self.model.training_step(batch_data, self.device)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()
            used_batches += 1

            # Drop references before releasing cached GPU blocks.
            del loss, batch_data
            torch.cuda.empty_cache()

        # Average over the batches that contributed, not over len(loader):
        # skipped batches would otherwise bias the reported loss downwards.
        if used_batches == 0:
            return float("nan")
        return running_loss / used_batches

    @staticmethod
    def _ranking_metrics(labels, scores):
        """Return ``(auroc, aupr)``; both are ``nan`` unless two classes are present."""
        labels = np.asarray(labels)
        scores = np.asarray(scores)
        # roc_auc_score raises ValueError when only one class is present.
        if labels.size == 0 or np.unique(labels).size < 2:
            return float("nan"), float("nan")
        auroc = metrics.roc_auc_score(labels, scores)
        aupr = metrics.average_precision_score(labels, scores)
        return auroc, aupr

    def calc_auroc_aupr_on_indexes(self, model, dataset, dataframe, nondimer_indexes, batch_size = 10):
        """Score diagonal vs. upper-triangle logits for the selected rows.

        For every chunk of ``batch_size`` identifiers the logit matrix is built
        with ``calculate_logit_matrix``; its diagonal entries are labelled 1 and
        its strictly-upper-triangle entries are labelled 0.

        Args:
            model: model to evaluate; ``None`` falls back to ``self.model``.
            dataset: object exposing ``_get_by_name(list_of_ids)`` returning
                ``(binder_emb, target_emb, labels)``.
            dataframe: frame from which ``target_binder_id`` is read via ``.loc``.
            nondimer_indexes: index labels of ``dataframe`` to evaluate.
            batch_size: number of identifiers per logit matrix.

        Returns:
            ``(auroc, aupr, all_TP_scores, all_FP_scores)``. The two metrics are
            ``nan`` when the pooled scores do not contain both classes.
        """
        if model is None:
            model = self.model
        model.eval()

        all_TP_scores, all_FP_scores = [], []
        accessions = [dataframe.loc[index].target_binder_id for index in nondimer_indexes]
        # Ceil, not floor: a trailing partial chunk is still one iteration.
        n_batches = math.ceil(len(accessions) / batch_size)
        
        with torch.no_grad():
            for index_batch in tqdm(batch(accessions, n=batch_size), total=n_batches, desc="Calculating AUC"):

                binder_emb, target_emb, _labels = dataset._get_by_name(index_batch)
                binder_emb, target_emb = binder_emb.to(self.device), target_emb.to(self.device)

                logit_matrix = model.calculate_logit_matrix(binder_emb, target_emb)
                
                # Diagonal entries are the scores counted as positives.
                all_TP_scores += logit_matrix.diag().detach().cpu().tolist()
                
                # Strictly-upper-triangle entries are the scores counted as negatives.
                n = logit_matrix.size(0)
                rows, cols = torch.triu_indices(n, n, offset=1)
                all_FP_scores += logit_matrix[rows, cols].detach().cpu().tolist()
            
        all_score_predictions = np.array(all_TP_scores + all_FP_scores)
        all_labels = np.array([1]*len(all_TP_scores) + [0]*len(all_FP_scores))

        auroc, aupr = self._ranking_metrics(all_labels, all_score_predictions)
        
        return auroc, aupr, all_TP_scores, all_FP_scores

    def validate(self):
        """Evaluate the model on the Meta loader, the PPint loader and, optionally, the AUROC subset.

        Batches whose first tensor has a single row are skipped in both loaders.
        Metrics of a loader with no usable batch are reported as ``nan``.

        Returns:
            When ``test_indexes_for_auROC`` is set, a 9-tuple
            ``(val_loss_PPint, val_accuracy_PPint, val_topk_accuracy_PPint,
            non_dimer_auc, non_dimer_aupr,
            val_loss_Meta, val_acc_Meta, meta_auroc, meta_aupr)``;
            otherwise the same tuple without the two ``non_dimer_*`` entries.
        """
        self.model.eval()
        
        # --- MetaDataset validation ---
        running_loss_Meta = 0.0
        all_logits = []
        all_lbls = []
        used_batches_meta = 0

        with torch.no_grad():
            for batch_data in tqdm(self.validation_loader, total=len(self.validation_loader)):
                if batch_data[0].size(0) == 1:
                    continue
                _, _, labels = batch_data
                logits, loss = self.model.validation_step_MetaDataset(batch_data, self.device)
                
                running_loss_Meta += loss.item()
                all_logits.append(logits.detach().view(-1).cpu())
                all_lbls.append(labels.detach().view(-1).cpu())
                used_batches_meta += 1
                
        if used_batches_meta > 0:
            val_loss_Meta = running_loss_Meta / used_batches_meta
            all_logits = torch.cat(all_logits).numpy()
            all_lbls   = torch.cat(all_lbls).numpy()

            meta_auroc, meta_aupr = self._ranking_metrics(all_lbls, all_logits)

            # A logit >= 0 is counted as a positive prediction.
            y_pred = (all_logits >= 0).astype(int)
            y_true = all_lbls.astype(int)
            val_acc_Meta = (y_pred == y_true).mean()
        else:
            val_loss_Meta = float("nan")
            meta_auroc = float("nan")
            meta_aupr = float("nan")
            val_acc_Meta = float("nan")

        # --- PPint validation ---
        running_loss_ValPPint = 0.0
        running_accuracy_ValPPint = 0.0
        running_topk_accuracy_ValPPint = 0.0
        used_batches_ppint = 0

        with torch.no_grad():
            for batch_data in tqdm(self.testing_loader, total=len(self.testing_loader)):
                if batch_data[0].size(0) == 1:
                    continue
                loss, partner_accuracy, peptide_topk_accuracy = self.model.validation_step_PPint(batch_data, self.device)
                running_loss_ValPPint += loss.item()
                running_accuracy_ValPPint += partner_accuracy.item()
                running_topk_accuracy_ValPPint += peptide_topk_accuracy.item()
                used_batches_ppint += 1
                
        if used_batches_ppint > 0:
            val_loss_PPint = running_loss_ValPPint / used_batches_ppint
            val_accuracy_PPint = running_accuracy_ValPPint / used_batches_ppint
            val_topk_accuracy_PPint = running_topk_accuracy_ValPPint / used_batches_ppint
        else:
            val_loss_PPint = float("nan")
            val_accuracy_PPint = float("nan")
            val_topk_accuracy_PPint = float("nan")

        # --- AUROC on specific indexes (optional) ---
        if self.test_indexes_for_auROC is not None:
            non_dimer_auc, non_dimer_aupr, _, _ = self.calc_auroc_aupr_on_indexes(
                model=self.model, 
                dataset=self.test_dataset,
                dataframe=self.test_df,
                nondimer_indexes=self.test_indexes_for_auROC,
                batch_size=self.auROC_batch_size
            )
            
            return (val_loss_PPint, val_accuracy_PPint, val_topk_accuracy_PPint,
                    non_dimer_auc, non_dimer_aupr,
                    val_loss_Meta, val_acc_Meta, meta_auroc, meta_aupr)

        return (val_loss_PPint, val_accuracy_PPint, val_topk_accuracy_PPint,
                val_loss_Meta, val_acc_Meta, meta_auroc, meta_aupr)

    def _run_validation(self):
        """Call ``validate()`` and return its results as a dict with fixed keys.

        ``non_dimer_auc`` / ``non_dimer_aupr`` are ``None`` when the optional
        AUROC evaluation is disabled.
        """
        results = self.validate()
        if self.test_indexes_for_auROC is not None:
            (ppint_loss, ppint_acc, ppint_topk,
             non_dimer_auc, non_dimer_aupr,
             meta_loss, meta_acc, meta_auroc, meta_aupr) = results
        else:
            (ppint_loss, ppint_acc, ppint_topk,
             meta_loss, meta_acc, meta_auroc, meta_aupr) = results
            non_dimer_auc, non_dimer_aupr = None, None

        return {
            "ppint_loss": ppint_loss,
            "ppint_acc": ppint_acc,
            "ppint_topk": ppint_topk,
            "non_dimer_auc": non_dimer_auc,
            "non_dimer_aupr": non_dimer_aupr,
            "meta_loss": meta_loss,
            "meta_acc": meta_acc,
            "meta_auroc": meta_auroc,
            "meta_aupr": meta_aupr,
        }

    @staticmethod
    def _print_metrics(header, meta_loss_label, m):
        """Print one block of validation metrics to the console."""
        print(header)
        print(f'{meta_loss_label} {round(m["meta_loss"],4)}')
        print(f'Meta Accuracy: {round(m["meta_acc"],4)}')
        print(f'Meta AUROC: {round(m["meta_auroc"],4)}')
        print(f'Meta AUPR: {round(m["meta_aupr"],4)}')
        print(f'PPint Test-Loss: {round(m["ppint_loss"],4)}')
        print(f'PPint Accuracy: {round(m["ppint_acc"],4)}')
        if m["non_dimer_auc"] is not None:
            print(f'PPint non-dimer AUROC: {round(m["non_dimer_auc"],4)}')
            print(f'PPint non-dimer AUPR: {round(m["non_dimer_aupr"],4)}')

    def _log_to_wandb(self, m, step, train_loss=None):
        """Send one step of metrics to the tracker; no-op when the tracker is falsy."""
        if not self.wandb_tracker:
            return

        payload = {
            "PPint Test-Loss": m["ppint_loss"],
            "Meta Val-loss": m["meta_loss"],
            "PPint Accuracy": m["ppint_acc"],
            "Meta Accuracy": m["meta_acc"],
            "Meta Val-AUROC": m["meta_auroc"],
            "Meta Val-AUPR": m["meta_aupr"],
        }
        # The training loss does not exist for the pre-training evaluation.
        if train_loss is not None:
            payload["PPint Train-loss"] = train_loss
        if m["non_dimer_auc"] is not None:
            payload.update({
                "PPint non-dimer AUROC": m["non_dimer_auc"],
                "PPint non-dimer AUPR": m["non_dimer_aupr"],
            })
        self.wandb_tracker.log(payload, step=step)

    def train_model(self):
        """Evaluate once before training, then train for ``EPOCHS`` epochs.

        After every epoch the model is evaluated again; metrics are printed when
        verbose and logged to the tracker when one is configured (step 0 is the
        pre-training evaluation, step ``k`` is epoch ``k``). The tracker's
        ``finish()`` is called at the end.
        """
        torch.cuda.empty_cache()
        
        if self.verbose:
            print(f"Training model {str(self.runID)}")

        # --- initial validation before training
        print("Initial validation before starting training")
        results = self._run_validation()

        if self.verbose: 
            self._print_metrics('Before training:', 'Meta Val-Loss', results)
        self._log_to_wandb(results, step=0)
        
        # --- training loop
        for epoch in tqdm(range(1, self.EPOCHS + 1), total=self.EPOCHS, desc="Epochs"):
            
            torch.cuda.empty_cache()
            
            train_loss = self.train_one_epoch()
            results = self._run_validation()
            
            torch.cuda.empty_cache()
            
            # NOTE: periodic checkpoint saving is currently disabled;
            # model_save_steps and trained_model_dir are stored but unused.
            
            # console logging
            if self.verbose and epoch % self.print_frequency_loss == 0:
                self._print_metrics(f'EPOCH {epoch}:', 'Meta Val Loss', results)
            
            # wandb logging
            self._log_to_wandb(results, step=epoch, train_loss=train_loss)

        if self.wandb_tracker:
            self.wandb_tracker.finish()

In [16]:
# Optimisation hyper-parameters.
learning_rate = 2e-5
EPOCHS = 12
# `batch_size` is reported in the wandb run config, so it must be the value the
# train/test loaders really use (it was declared as 10 while the loaders hard-coded 7).
batch_size = 7
val_batch_size = 15

optimizer = AdamW(model.parameters(), lr=learning_rate)
accelerator = Accelerator()
device = accelerator.device

# Only the training loader is shuffled and drops its trailing partial batch.
train_dataloader = DataLoader(training_Dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_dataloader = DataLoader(testing_Dataset, batch_size=batch_size, shuffle=False, drop_last=False)
val_dataloader = DataLoader(validation_Dataset, batch_size=val_batch_size, shuffle=False, drop_last=False)

# Let accelerate wrap the model, optimizer and loaders.
model, optimizer, train_dataloader, test_dataloader, val_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, test_dataloader, val_dataloader
)

In [17]:
# Sanity check: fetch the first validation batch and print its label tensor.
for i in val_dataloader:
    # Each batch unpacks into three items; only the third (labels) is used here.
    __, __, lbls = i
    print(lbls.to(device))
    break  # only the first batch is inspected

tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0], device='cuda:0')


In [18]:
# wandb: start a run only when tracking is enabled.
if use_wandb:
    run = wandb.init(
        project="CSSP_combined_Loss2.0",
        name="ESMIF",
        config={
            "learning_rate": learning_rate,
            "batch_size": batch_size,
            "epochs": EPOCHS,
            "architecture": "MiniCLIP_w_transformer_crossattn",
            "dataset": "PPint",
        },
    )
    wandb.watch(accelerator.unwrap_model(model), log="all", log_freq=100)
else:
    run = None

# train
training_wrapper = TrainWrapper(
            model=model,
            train_loader=train_dataloader,
            test_loader=test_dataloader,
            val_loader=val_dataloader,
            test_df=Df_test,
            test_dataset=testing_Dataset,
            optimizer=optimizer,
            epochs=EPOCHS,
            runID=runID,
            device=device,
            test_indexes_for_auROC=indices_non_dimers_val,
            auROC_batch_size=10,
            model_save_steps=model_save_steps,
            model_save_path=trained_model_dir,
            v=True,
            # Only hand over the tracker when a run was started; otherwise the
            # wrapper would call wandb.log()/finish() with no active run.
            wandb_tracker=wandb if use_wandb else False
)

training_wrapper.train_model() # start training

Training model 7e89e0b1-5650-42ef-8374-eca8f9c5fb7b
Initial validation before starting training


100%|█████████████████████████████████████████████████████████████████████████████████| 236/236 [00:15<00:00, 15.58it/s]
100%|███████████████████████████████████████████████████████████████████████████████████| 71/71 [00:06<00:00, 11.25it/s]
Calculating AUC: 13it [00:01,  6.95it/s]                                                                                


Before training:
Meta Val-Loss 8.4436
Meta Accuracy: 0.1107
Meta AUROC: 0.5306
Meta AUPR: 0.2002
PPint Test-Loss: 5.4842
PPint Accuracy: 0.8712
PPint non-dimer AUROC: 0.7049
PPint non-dimer AUPR: 0.4731


Epochs:   0%|                                                                                    | 0/12 [00:00<?, ?it/s]
Running through epoch:   0%|                                                                    | 0/282 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/282 [00:00<01:52,  2.51it/s][A
Running through epoch:   1%|▍                                                           | 2/282 [00:00<01:29,  3.14it/s][A
Running through epoch:   1%|▋                                                           | 3/282 [00:00<01:21,  3.42it/s][A
Running through epoch:   1%|▊                                                           | 4/282 [00:01<01:17,  3.58it/s][A
Running through epoch:   2%|█                                                           | 5/282 [00:01<01:15,  3.69it/s][A
Running through epoch:   2%|█▎                                                          | 6/282 [00:01<01:13,  3.76it/s][A
Running thr

EPOCH 1:
Meta Val Loss 0.4194
Meta Accuracy: 0.8737
Meta AUROC: 0.5227
Meta AUPR: 0.1669
PPint Test-Loss: 0.1802
PPint Accuracy: 0.9175
PPint non-dimer AUROC: 0.8459
PPint non-dimer AUPR: 0.5744



Running through epoch:   0%|                                                                    | 0/282 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/282 [00:00<01:16,  3.65it/s][A
Running through epoch:   1%|▍                                                           | 2/282 [00:00<01:15,  3.71it/s][A
Running through epoch:   1%|▋                                                           | 3/282 [00:00<01:15,  3.71it/s][A
Running through epoch:   1%|▊                                                           | 4/282 [00:01<01:13,  3.80it/s][A
Running through epoch:   2%|█                                                           | 5/282 [00:01<01:12,  3.80it/s][A
Running through epoch:   2%|█▎                                                          | 6/282 [00:01<01:11,  3.83it/s][A
Running through epoch:   2%|█▍                                                          | 7/282 [00:01<01:12,  3.81it/s][A
Running

EPOCH 2:
Meta Val Loss 0.5098
Meta Accuracy: 0.7777
Meta AUROC: 0.5698
Meta AUPR: 0.2084
PPint Test-Loss: 0.1921
PPint Accuracy: 0.9175
PPint non-dimer AUROC: 0.8774
PPint non-dimer AUPR: 0.6096



Running through epoch:   0%|                                                                    | 0/282 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/282 [00:00<01:15,  3.73it/s][A
Running through epoch:   1%|▍                                                           | 2/282 [00:00<01:12,  3.84it/s][A
Running through epoch:   1%|▋                                                           | 3/282 [00:00<01:13,  3.81it/s][A
Running through epoch:   1%|▊                                                           | 4/282 [00:01<01:11,  3.87it/s][A
Running through epoch:   2%|█                                                           | 5/282 [00:01<01:12,  3.84it/s][A
Running through epoch:   2%|█▎                                                          | 6/282 [00:01<01:12,  3.80it/s][A
Running through epoch:   2%|█▍                                                          | 7/282 [00:01<01:11,  3.85it/s][A
Running

EPOCH 3:
Meta Val Loss 0.398
Meta Accuracy: 0.8754
Meta AUROC: 0.5726
Meta AUPR: 0.1887
PPint Test-Loss: 0.1639
PPint Accuracy: 0.9276
PPint non-dimer AUROC: 0.8869
PPint non-dimer AUPR: 0.6197



Running through epoch:   0%|                                                                    | 0/282 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/282 [00:00<01:15,  3.70it/s][A
Running through epoch:   1%|▍                                                           | 2/282 [00:00<01:15,  3.71it/s][A
Running through epoch:   1%|▋                                                           | 3/282 [00:00<01:12,  3.83it/s][A
Running through epoch:   1%|▊                                                           | 4/282 [00:01<01:12,  3.86it/s][A
Running through epoch:   2%|█                                                           | 5/282 [00:01<01:12,  3.83it/s][A
Running through epoch:   2%|█▎                                                          | 6/282 [00:01<01:10,  3.90it/s][A
Running through epoch:   2%|█▍                                                          | 7/282 [00:01<01:10,  3.88it/s][A
Running

EPOCH 4:
Meta Val Loss 0.4155
Meta Accuracy: 0.8579
Meta AUROC: 0.5597
Meta AUPR: 0.1846
PPint Test-Loss: 0.1443
PPint Accuracy: 0.9276
PPint non-dimer AUROC: 0.8811
PPint non-dimer AUPR: 0.6138



Running through epoch:   0%|                                                                    | 0/282 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/282 [00:00<01:22,  3.40it/s][A
Running through epoch:   1%|▍                                                           | 2/282 [00:00<01:15,  3.71it/s][A
Running through epoch:   1%|▋                                                           | 3/282 [00:00<01:14,  3.74it/s][A
Running through epoch:   1%|▊                                                           | 4/282 [00:01<01:15,  3.71it/s][A
Running through epoch:   2%|█                                                           | 5/282 [00:01<01:13,  3.75it/s][A
Running through epoch:   2%|█▎                                                          | 6/282 [00:01<01:13,  3.78it/s][A
Running through epoch:   2%|█▍                                                          | 7/282 [00:01<01:12,  3.80it/s][A
Running

EPOCH 5:
Meta Val Loss 0.4276
Meta Accuracy: 0.8567
Meta AUROC: 0.5517
Meta AUPR: 0.1571
PPint Test-Loss: 0.1484
PPint Accuracy: 0.9376
PPint non-dimer AUROC: 0.8852
PPint non-dimer AUPR: 0.6362



Running through epoch:   0%|                                                                    | 0/282 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/282 [00:00<01:16,  3.69it/s][A
Running through epoch:   1%|▍                                                           | 2/282 [00:00<01:13,  3.80it/s][A
Running through epoch:   1%|▋                                                           | 3/282 [00:00<01:14,  3.72it/s][A
Running through epoch:   1%|▊                                                           | 4/282 [00:01<01:12,  3.83it/s][A
Running through epoch:   2%|█                                                           | 5/282 [00:01<01:13,  3.77it/s][A
Running through epoch:   2%|█▎                                                          | 6/282 [00:01<01:14,  3.73it/s][A
Running through epoch:   2%|█▍                                                          | 7/282 [00:01<01:12,  3.80it/s][A
Running

EPOCH 6:
Meta Val Loss 0.5651
Meta Accuracy: 0.7316
Meta AUROC: 0.5587
Meta AUPR: 0.1519
PPint Test-Loss: 0.1394
PPint Accuracy: 0.9336
PPint non-dimer AUROC: 0.8995
PPint non-dimer AUPR: 0.6467



Running through epoch:   0%|                                                                    | 0/282 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/282 [00:00<01:11,  3.91it/s][A
Running through epoch:   1%|▍                                                           | 2/282 [00:00<01:10,  3.96it/s][A
Running through epoch:   1%|▋                                                           | 3/282 [00:00<01:10,  3.95it/s][A
Running through epoch:   1%|▊                                                           | 4/282 [00:01<01:10,  3.92it/s][A
Running through epoch:   2%|█                                                           | 5/282 [00:01<01:10,  3.92it/s][A
Running through epoch:   2%|█▎                                                          | 6/282 [00:01<01:10,  3.92it/s][A
Running through epoch:   2%|█▍                                                          | 7/282 [00:01<01:10,  3.91it/s][A
Running

EPOCH 7:
Meta Val Loss 0.4309
Meta Accuracy: 0.8794
Meta AUROC: 0.5247
Meta AUPR: 0.1399
PPint Test-Loss: 0.1543
PPint Accuracy: 0.9356
PPint non-dimer AUROC: 0.8948
PPint non-dimer AUPR: 0.6583



Running through epoch:   0%|                                                                    | 0/282 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/282 [00:00<01:16,  3.69it/s][A
Running through epoch:   1%|▍                                                           | 2/282 [00:00<01:12,  3.84it/s][A
Running through epoch:   1%|▋                                                           | 3/282 [00:00<01:13,  3.79it/s][A
Running through epoch:   1%|▊                                                           | 4/282 [00:01<01:13,  3.80it/s][A
Running through epoch:   2%|█                                                           | 5/282 [00:01<01:13,  3.77it/s][A
Running through epoch:   2%|█▎                                                          | 6/282 [00:01<01:13,  3.75it/s][A
Running through epoch:   2%|█▍                                                          | 7/282 [00:01<01:12,  3.80it/s][A
Running

EPOCH 8:
Meta Val Loss 0.4315
Meta Accuracy: 0.8613
Meta AUROC: 0.5664
Meta AUPR: 0.1762
PPint Test-Loss: 0.1569
PPint Accuracy: 0.9437
PPint non-dimer AUROC: 0.9077
PPint non-dimer AUPR: 0.67



Running through epoch:   0%|                                                                    | 0/282 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/282 [00:00<01:13,  3.84it/s][A
Running through epoch:   1%|▍                                                           | 2/282 [00:00<01:14,  3.75it/s][A
Running through epoch:   1%|▋                                                           | 3/282 [00:00<01:12,  3.84it/s][A
Running through epoch:   1%|▊                                                           | 4/282 [00:01<01:12,  3.82it/s][A
Running through epoch:   2%|█                                                           | 5/282 [00:01<01:12,  3.82it/s][A
Running through epoch:   2%|█▎                                                          | 6/282 [00:01<01:11,  3.84it/s][A
Running through epoch:   2%|█▍                                                          | 7/282 [00:01<01:11,  3.87it/s][A
Running

EPOCH 9:
Meta Val Loss 0.4367
Meta Accuracy: 0.8508
Meta AUROC: 0.5468
Meta AUPR: 0.1536
PPint Test-Loss: 0.161
PPint Accuracy: 0.9416
PPint non-dimer AUROC: 0.8992
PPint non-dimer AUPR: 0.6484



Running through epoch:   0%|                                                                    | 0/282 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/282 [00:00<01:13,  3.80it/s][A
Running through epoch:   1%|▍                                                           | 2/282 [00:00<01:12,  3.87it/s][A
Running through epoch:   1%|▋                                                           | 3/282 [00:00<01:12,  3.83it/s][A
Running through epoch:   1%|▊                                                           | 4/282 [00:01<01:11,  3.89it/s][A
Running through epoch:   2%|█                                                           | 5/282 [00:01<01:10,  3.91it/s][A
Running through epoch:   2%|█▎                                                          | 6/282 [00:01<01:09,  3.95it/s][A
Running through epoch:   2%|█▍                                                          | 7/282 [00:01<01:10,  3.91it/s][A
Running

EPOCH 10:
Meta Val Loss 0.4346
Meta Accuracy: 0.8836
Meta AUROC: 0.5613
Meta AUPR: 0.1615
PPint Test-Loss: 0.1693
PPint Accuracy: 0.9376
PPint non-dimer AUROC: 0.9002
PPint non-dimer AUPR: 0.6596



Running through epoch:   0%|                                                                    | 0/282 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/282 [00:00<01:17,  3.65it/s][A
Running through epoch:   1%|▍                                                           | 2/282 [00:00<01:14,  3.74it/s][A
Running through epoch:   1%|▋                                                           | 3/282 [00:00<01:13,  3.80it/s][A
Running through epoch:   1%|▊                                                           | 4/282 [00:01<01:12,  3.82it/s][A
Running through epoch:   2%|█                                                           | 5/282 [00:01<01:13,  3.78it/s][A
Running through epoch:   2%|█▎                                                          | 6/282 [00:01<01:12,  3.82it/s][A
Running through epoch:   2%|█▍                                                          | 7/282 [00:01<01:12,  3.82it/s][A
Running

EPOCH 11:
Meta Val Loss 0.4414
Meta Accuracy: 0.88
Meta AUROC: 0.5351
Meta AUPR: 0.1668
PPint Test-Loss: 0.2002
PPint Accuracy: 0.9396
PPint non-dimer AUROC: 0.9104
PPint non-dimer AUPR: 0.67



Running through epoch:   0%|                                                                    | 0/282 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/282 [00:00<01:17,  3.61it/s][A
Running through epoch:   1%|▍                                                           | 2/282 [00:00<01:15,  3.69it/s][A
Running through epoch:   1%|▋                                                           | 3/282 [00:00<01:13,  3.80it/s][A
Running through epoch:   1%|▊                                                           | 4/282 [00:01<01:12,  3.84it/s][A
Running through epoch:   2%|█                                                           | 5/282 [00:01<01:12,  3.81it/s][A
Running through epoch:   2%|█▎                                                          | 6/282 [00:01<01:11,  3.87it/s][A
Running through epoch:   2%|█▍                                                          | 7/282 [00:01<01:10,  3.88it/s][A
Running

EPOCH 12:
Meta Val Loss 0.4432
Meta Accuracy: 0.8717
Meta AUROC: 0.523
Meta AUPR: 0.1329
PPint Test-Loss: 0.1814
PPint Accuracy: 0.9437
PPint non-dimer AUROC: 0.9039
PPint non-dimer AUPR: 0.6601





0,1
Meta Accuracy,▁█▇███▇██████
Meta Val-AUPR,▇▄█▆▆▃▃▂▅▃▄▄▁
Meta Val-AUROC,▂▁██▆▅▆▁▇▄▆▃▁
Meta Val-loss,█▁▁▁▁▁▁▁▁▁▁▁▁
PPint Accuracy,▁▅▅▆▆▇▇▇██▇██
PPint Test-Loss,█▁▁▁▁▁▁▁▁▁▁▁▁
PPint Train-loss,█▃▃▂▂▂▁▁▁▁▁▁
PPint non-dimer AUPR,▁▅▆▆▆▇▇██▇███
PPint non-dimer AUROC,▁▆▇▇▇▇█▇█████

0,1
Meta Accuracy,0.87174
Meta Val-AUPR,0.13291
Meta Val-AUROC,0.52298
Meta Val-loss,0.44315
PPint Accuracy,0.94366
PPint Test-Loss,0.18136
PPint Train-loss,0.04119
PPint non-dimer AUPR,0.66007
PPint non-dimer AUROC,0.90394
