In [1]:
import uuid, sys, os
import pandas as pd
import numpy as np
from tqdm import tqdm
import ast
import math
import random

from sklearn import metrics
from scipy import stats
from collections import Counter

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
torch.cuda.set_device(0)  # 0 == "first visible" -> actually GPU 2 on the node
print(torch.cuda.get_device_name(0))

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, WeightedRandomSampler
import pytorch_lightning as pl
from torch.optim import AdamW

torch.manual_seed(0)

from accelerate import Accelerator
torch.cuda.empty_cache()
import training_utils.partitioning_utils as pat_utils
from tqdm import trange

NVIDIA A100-PCIE-40GB


In [58]:
import requests
requests.get("https://api.wandb.ai/status").status_code

import wandb
wandb.login(key="f8a6d759fe657b095d56bddbdb4d586dfaebd468", relogin=True)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /zhome/c9/0/203261/.netrc
[34m[1mwandb[0m: Currently logged in as: [33ms232958[0m ([33ms232958-danmarks-tekniske-universitet-dtu[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
### Setting a seed to have the same initiation of weights

def set_seed(seed: int = 42):
    # Python & NumPy
    random.seed(seed)
    np.random.seed(seed)
    
    # PyTorch
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # for multi-GPU

    # CuDNN settings (for convolution etc.)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # (Optional) for some Python hashing randomness
    os.environ["PYTHONHASHSEED"] = str(seed)

SEED = 0
set_seed(SEED)

In [4]:
os.chdir("/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts")
# print(os.getcwd())

print("PyTorch:", torch.__version__)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print("Current location:", os.getcwd())

PyTorch: 2.5.1
Using device: cuda
Current location: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts


In [54]:
# Model parameters
memory_verbose = False
use_wandb = True # Used to track loss in real-time without printing
model_save_steps = 1
train_frac = 1.0
test_frac = 1.0

embedding_dimension = 1152 #| 960 | 1152
number_of_recycles = 2
padding_value = -5000

In [61]:
# ## Training variables
runID = uuid.uuid4()

## Output path
trained_model_dir = f"/work3/s232958/data/trained/PPint_retrain10%_0.4_Christian/251116_{runID}"

def print_mem_consumption():
    # 1. Total memory available on the GPU (device 0)
    t = torch.cuda.get_device_properties(0).total_memory
    # 2. How much memory PyTorch has *reserved* from CUDA
    r = torch.cuda.memory_reserved(0)
    # 3. How much of that reserved memory is actually *used* by tensors
    a = torch.cuda.memory_allocated(0)
    # 4. Reserved but not currently allocated (so “free inside PyTorch’s pool”)
    f = r - a

    print("Total memory: ", t/1e9)      # total VRAM in GB
    print("Reserved memory: ", r/1e9)   # PyTorch’s reserved pool in GB
    print("Allocated memory: ", a//1e9) # actually in use (integer division)
    print("Free memory: ", f/1e9)       # slack in the reserved pool in GB
print_mem_consumption()

Total memory:  42.405855232
Reserved memory:  12.14251008
Allocated memory:  11.0
Free memory:  0.217145344


### Loading PPint dataframe

In [7]:
path_to_mmseqs_clustering = "/work3/s232958/data/PPint_DB/3_å_dataset5_singlefasta/clusterRes40"
all_seqs, clust, clust_keys = pat_utils.mmseqs_parser(path_to_mmseqs_clustering)

path_to_interaction_df = "/work3/s232958/data/PPint_DB/disordered_interfaces_no_cutoff_filtered_nonredundant80_3å_5.csv.gz"
disordered_interfaces_df = pd.read_csv(path_to_interaction_df,index_col=0).reset_index(drop=True)
disordered_interfaces_df["PDB_chain_name"] = (disordered_interfaces_df["PDB"] + "_" + disordered_interfaces_df["chainname"]).tolist()
disordered_interfaces_df["index_num"] = np.arange(len(disordered_interfaces_df))
disordered_interfaces_df["chain_name_index"] = [row["PDB_chain_name"] + "_" + str(row["index_num"]) for index, row in disordered_interfaces_df.iterrows()]
disordered_interfaces_df = disordered_interfaces_df.set_index("PDB_interface_name")
disordered_interfaces_df["interface_residues"] = disordered_interfaces_df["interface_residues"].apply(lambda x: ast.literal_eval(x))
# disordered_interfaces_df["inter_chain_hamming"] = [1 - (Ldistance(seq.split("-")[0], seq.split("-")[1]))/np.max([len(seq.split("-")[0]), len(seq.split("-")[1])]) for seq in disordered_interfaces_df["protien_interface_sequences"]]
disordered_interfaces_df["dimer"] = disordered_interfaces_df["inter_chain_hamming"] > 0.60
disordered_interfaces_df["clust_keys"] = [clust_keys.get(row["chain_name_index"]) for index, row in disordered_interfaces_df.iterrows()] 

pdb_interface_and_clust_keys = {index:disordered_interfaces_df.loc[index,"clust_keys"].values.tolist() for index in tqdm(disordered_interfaces_df.index.drop_duplicates(), total=len(disordered_interfaces_df)/2)}
new_clusters, new_clusters_clustkeys = pat_utils.recluster_mmseqs_keys_to_non_overlapping_groups(pdb_interface_and_clust_keys)

### Creating train and test datasets based on train and test-idexes
train_indexes, test_indexes = pat_utils.run_train_test_partition(interaction_df=disordered_interfaces_df,
                                                    clustering=new_clusters, # Clusters from Bidentate-graphs
                                                    train_ratio=0.8, 
                                                    test_ratio=0.2, 
                                                    v=True, 
                                                    seed=0)

disordered_interfaces_df["ID"] = [row["PDB"]+"_"+str(row["interface_index"])+"_"+row["chainname"] for __, row in disordered_interfaces_df.iterrows()]
disordered_interfaces_df["PDB_interface_name"] = disordered_interfaces_df.index

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 24725/24725.0 [00:36<00:00, 672.94it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 27834/27834 [00:00<00:00, 720834.90it/s]


0.8
0.2


In [8]:
print(f"Train indeces number: {len(train_indexes)}")
print(f"Test indeces number: {len(test_indexes)}")

Train indeces number: 19781
Test indeces number: 4944


In [9]:
# Creating new dataframe with pairs of proteins (PPints)
grouped = {}
for _, row in disordered_interfaces_df.iterrows():
    iface = row["PDB_interface_name"]
    seq = row["sequence"]
    rid = row["ID"]
    dimer = row["dimer"]
    
    if iface not in grouped:
        grouped[iface] = {
            "sequences": [],
            "IDs": [],
            "dimer": dimer,        # keep the dimer value for this interface
        }
    else:
        # Optional: sanity-check it's consistent per interface
        if grouped[iface]["dimer"] != dimer:
            print(f"Warning: multiple dimers for interface {iface}:",
                  grouped[iface]['dimer'], "vs", dimer)

    grouped[iface]["sequences"].append(seq)
    grouped[iface]["IDs"].append(rid)

records = []
for iface, vals in grouped.items():
    seqs = vals["sequences"]
    ids = vals["IDs"]
    if len(seqs) >= 2 and len(ids) >= 2:
        records.append({
            "interface_id": iface,
            "seq1": seqs[0],
            "seq2": seqs[1],
            "ID1": ids[0],
            "ID2": ids[1],
            "dimer": vals["dimer"],   # <- add dimer to final record
        })

PPint_interactions_NEW = pd.DataFrame(records)
PPint_interactions_NEW["seq_target_len"] = [len(row.seq1) for __, row in PPint_interactions_NEW.iterrows()]
PPint_interactions_NEW["seq_binder_len"] = [len(row.seq2) for __, row in PPint_interactions_NEW.iterrows()]
PPint_interactions_NEW["target_binder_id"] = PPint_interactions_NEW["ID1"] + "_" + PPint_interactions_NEW["ID2"]

PPint_interactions_NEW.head()

Unnamed: 0,interface_id,seq1,seq2,ID1,ID2,dimer,seq_target_len,seq_binder_len,target_binder_id
0,6NZA_0,MNTVRSEKDSMGAIDVPADKLWGAQTQRSLEHFRISTEKMPTSLIH...,TVRSEKDSMGAIDVPADKLWGAQTQRSLEHFRISTEKMPTSLIHAL...,6NZA_0_A,6NZA_0_B,True,461,459,6NZA_0_A_6NZA_0_B
1,9JKA_1,VAAGATLALLSFLTPLAFLLLPPLLWREELEPCGTACEGLFISVAF...,VAAGATLALLSFLTPLAFLLLPPLLWREELEPCGTACEGLFISVAF...,9JKA_1_B,9JKA_1_C,True,362,362,9JKA_1_B_9JKA_1_C
2,8DQ6_1,PTLNLFTNIPVDAVTCSDILKDATKAVAKIIGKPESYVMILLNSGV...,PTLNLFTNIPVDAVTCSDILKDATKAVAKIIGKPESYVMILLNSGV...,8DQ6_1_B,8DQ6_1_C,True,109,97,8DQ6_1_B_8DQ6_1_C
3,2YMZ_0,ARMFEMFNLDWKSGGTMKIKGHISEDAESFAINLGCKSSDLALHFN...,ARMFEMFNLDWKSGGTMKIKGHISEDAESFAINLGCKSSDLALHFN...,2YMZ_0_A,2YMZ_0_B,True,130,130,2YMZ_0_A_2YMZ_0_B
4,6IDB_0,DKICLGHHAVSNGTKVNTLTERGVEVVNATETVERTNIPRICSKGK...,GLFGAIAGFIENGWEGLIDGWYGFRHQNAQGEGTAADYKSTQSAID...,6IDB_0_A,6IDB_0_B,False,317,172,6IDB_0_A_6IDB_0_B


In [10]:
len(PPint_interactions_NEW)

24725

In [11]:
# sample random 10%
train_indexes_sample = random.sample(train_indexes, int(len(train_indexes) * 0.1))
test_indexes_sample = random.sample(test_indexes, int(len(test_indexes) * 0.1))

In [12]:
Df_train = PPint_interactions_NEW[PPint_interactions_NEW.interface_id.isin(train_indexes_sample)]
Df_test = PPint_interactions_NEW[PPint_interactions_NEW.interface_id.isin(test_indexes_sample)]
Df_train

Unnamed: 0,interface_id,seq1,seq2,ID1,ID2,dimer,seq_target_len,seq_binder_len,target_binder_id
24,6GRH_2,MINILPFEIISRNTKTLLITYISSVDITHEGMKKVLESLRSKQGII...,MINVYSNLMSAWPATMAMSPKLNRNMPTFSQIWDYERITPASAAGE...,6GRH_2_2,6GRH_2_D,False,284,396,6GRH_2_2_6GRH_2_D
34,8R57_0,MIISKKNRNEICKYLFQEGVLYAKKDYNLAKHPQIDVPNLQVIKLM...,TQISKKKKFVSDGVFYAELNEMLTRELAEDGYSGVEVRVTPMRTEI...,8R57_0_K,8R57_0_F,False,83,211,8R57_0_K_8R57_0_F
47,7CUJ_2,SVLDIGLPMSALQRKMMHRLVQYFAFCIDHFCTGPSDSRIQEKIRL...,IELEYKRKPIPDYDFMKGLETTLQELYVEHQSKKRR,7CUJ_2_B,7CUJ_2_D,False,256,36,7CUJ_2_B_7CUJ_2_D
53,1UZR_0,RVSAINWNRLQDEKDAEVWDRLTGNFWLPEKVPVSNDIPSWGTLTA...,DRVSAINWNRLQDEKDAEVWDRLTGNFWLPEKVPVSNDIPSWGTLT...,1UZR_0_A,1UZR_0_C,True,282,283,1UZR_0_A_1UZR_0_C
55,2B5I_0,STKKTQLQLEHLLLDLQMILNGINNYKNPKLTRMLTFKFYMPKKAT...,SQFTCFYNSRAQISCVWSQTSCQVHAWPDRRRWQQTCELLPVSQAS...,2B5I_0_A,2B5I_0_B,False,120,196,2B5I_0_A_2B5I_0_B
...,...,...,...,...,...,...,...,...,...
24648,3X2Z_1,GMKVTFLGHAVVLIEGKKNIIIDPFISGNPVCPVKLEGLPKIDYIL...,GMKVTFLGHAVVLIEGKKNIIIDPFISGNPVCPVKLEGLPKIDYIL...,3X2Z_1_B,3X2Z_1_C,True,227,227,3X2Z_1_B_3X2Z_1_C
24668,2RCZ_0,GSKVTLVKSRKNEEYGLRLASHIFVKEISQDSLAARDGNIQEGDVV...,GSKVTLVKSRKNEEYGLRLASHIFVKEISQDSLAARDGNIQEGDVV...,2RCZ_0_A,2RCZ_0_B,True,79,81,2RCZ_0_A_2RCZ_0_B
24669,3HNP_3,TLTMGFIGFGKSANRYHLPYLKTRNNIKVKTIFVRQINEELAAPYE...,TLTMGFIGFGANRYHLPYLKTRNNIKVKTIFVRQINEELAAPYEER...,3HNP_3_D,3HNP_3_F,True,344,294,3HNP_3_D_3HNP_3_F
24696,6OVP_0,MSLKVDGFTSSIIFDVIRDGLNDPSQAKQKAESIKKANAIIVFNLK...,MSLKVDGFTSSIIFDVIRDGLNDPSQAKQKAESIKKANAIIVFNLK...,6OVP_0_A,6OVP_0_B,True,118,128,6OVP_0_A_6OVP_0_B


In [13]:
Df_test

Unnamed: 0,interface_id,seq1,seq2,ID1,ID2,dimer,seq_target_len,seq_binder_len,target_binder_id
7,4POB_0,DHATVTVTDDSFQEDVVSSNKPVLVDFWATWCGPCKMVAPVLEEIA...,ATVTVTDDSFQEDVVSSNKPVLVDFWATWCGPCKMVAPVLEEIAKD...,4POB_0_A,4POB_0_B,True,107,105,4POB_0_A_4POB_0_B
12,7T6C_0,YYPFVRKALFQLDPERAHEFTFQQLRRITGTPFEALVRQKVPAKPV...,YYPFVRKALFQLDPERAHEFTFQQLRRITGTPFEALVRQKVPAKPV...,7T6C_0_A,7T6C_0_B,True,335,335,7T6C_0_A_7T6C_0_B
28,1EGP_0,LKSFPEVVGKTVDQAREYFTLHYPQYNVYFLPEGSPVTL,YNRVRVFYNPGTNVVNHVPHVG,1EGP_0_A,1EGP_0_B,False,39,22,1EGP_0_A_1EGP_0_B
62,7YH3_0,KVENPLLISLYSHYVEQILSETNSIDDANQKLRDLGKELGQQIYLN...,KVENPLLISLYSHYVEQILSETNSIDDANQKLRDLGKELGQQIYLN...,7YH3_0_A,7YH3_0_C,True,150,155,7YH3_0_A_7YH3_0_C
70,4WMO_1,GYRSCNEIKSSDSRAPDGIYTLATEDGESYQTFCDTTNGGGWTLVA...,GYRSCNEIKSSDSRAPDGIYTLATEDGESYQTFCDTTNGGGWTLVA...,4WMO_1_D,4WMO_1_E,True,271,271,4WMO_1_D_4WMO_1_E
...,...,...,...,...,...,...,...,...,...
24455,2OYS_0,NKIFIYAGVRNHNSKTLEYTKRLSSIISSRNNVDISFRTPFNSELE...,NKIFIYAGVRNHNSKTLEYTKRLSSIISSRNNVDISFRTPFNSELE...,2OYS_0_A,2OYS_0_B,True,227,227,2OYS_0_A_2OYS_0_B
24491,6XRF_1,TLYRLHEADLEIPDAWQDQSINIFKLPASGPAREASFVISRDASQG...,MDAQAAARLGDEIAHGFGVAAMVAGAVAGALIGAAVVAAATGGLAA...,6XRF_1_B,6XRF_1_C,False,140,57,6XRF_1_B_6XRF_1_C
24507,5Z2L_0,GAFTGKTVLILGGSRGIGAAIVRRFVTDGANVRFTYAGSKDAAKRL...,GAFTGKTVLILGGSRGIGAAIVRRFVTDGANVRFTYAGSKDAAKRL...,5Z2L_0_A,5Z2L_0_B,True,239,243,5Z2L_0_A_5Z2L_0_B
24617,4LRS_0,APRVRITDSTLRDGSHAMAHQFTEEQVRATVHALDAAGVEVIEVSH...,GKAVAAIVGPGNIGTDLLIKLQRSEHIEVRYMVGVDPASEGLARAR...,4LRS_0_A,4LRS_0_B,False,337,294,4LRS_0_A_4LRS_0_B


#### ESMC encodings

In [20]:
# interaction_Dict = {}

# for _, row in Df_train.iterrows():
#     key_prot, seq_prot = row['ID1'], row['seq1']
#     key_pep, seq_pep = row['ID2'], row['seq2']
#     interaction_Dict[key_prot] = seq_prot
#     interaction_Dict[key_pep] = seq_pep

# for _, row in Df_test.iterrows():
#     key_prot, seq_prot = row['ID1'], row['seq1']
#     key_pep, seq_pep = row['ID2'], row['seq2']
#     interaction_Dict[key_prot] = seq_prot
#     interaction_Dict[key_pep] = seq_pep

# assert (len(list(interaction_Dict.items())) == (len(Df_test) + len(Df_train))*2)

In [34]:
# from pathlib import Path
# from esm.models.esmc import ESMC
# from esm.models.esmc import ESMC
# from esm.sdk.api import ESMProtein, LogitsConfig
# from esm.pretrained import get_esmc_model_tokenizers  

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# tokenizer = get_esmc_model_tokenizers()
# model = ESMC(
#     d_model=1152,
#     n_heads=18,
#     n_layers=36,
#     tokenizer=tokenizer,
# ).eval()

# weights_path = Path("/work3/s232958/models/esmc-600m-2024-12/data/weights/esmc_600m_2024_12_v0.pth")
# state_dict = torch.load(weights_path, map_location=device)

# model.load_state_dict(state_dict)
# client = model.to(device)  # or whatever variable you used
# client.eval()

# def calculate_ESM_pr_res_embeddings(sequence):
#     protein = ESMProtein(sequence=sequence)
#     protein_tensor = client.encode(protein)
#     logits_output = client.logits(
#     protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
#     )
#     return logits_output.embeddings.detach().cpu().numpy()

# def to_numpy(x):
#     try:
#         return x.detach().cpu().numpy()
#     except AttributeError:
#         return np.asarray(x)

  state_dict = torch.load(weights_path, map_location=device)


In [38]:
# path_to_output_embeddings = "/work3/s232958/data/PPint_DB/embeddings_esmC"

# for name, seq in tqdm(interaction_Dict.items(), total=len(interaction_Dict.items()), desc="Embedding PPint"):
#     emb = calculate_ESM_pr_res_embeddings(seq)
#     emb_np = to_numpy(emb)
#     out_path = os.path.join(path_to_output_embeddings, f"{name}.npy")
#     np.save(out_path, emb_np)
#     # print(f"Protein {name} embedded and saved to {out_path}")

Embedding PPint: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 4944/4944 [04:17<00:00, 19.20it/s]


#### I will use `len(Df_train)` of datapoints for training, `len(Df_test)` for testing and `ALL` metaanalysis datapoints for validation.

#### CLIP_PPint_analysis_dataset

In [43]:
class CLIP_PPint_analysis_dataset(Dataset):
    def __init__(
        self,
        dframe,
        paths,
        embedding_dim=1152,
        embedding_pad_value=-5000.0,
    ):
        super().__init__()
        self.dframe = dframe.copy()
        self.embedding_dim = int(embedding_dim)
        self.emb_pad = float(embedding_pad_value)

        # lengths
        self.max_blen = self.dframe["seq_binder_len"].max()
        self.max_tlen = self.dframe["seq_target_len"].max()

        # paths
        self.encoding_bpath, self.encoding_tpath = paths

        # index & storage
        self.dframe.set_index("target_binder_id", inplace=True)
        self.accessions = self.dframe.index.astype(str).tolist()
        self.name_to_row = {name: i for i, name in enumerate(self.accessions)}
        self.samples = []

        for accession in tqdm(self.accessions, total=len(self.accessions), desc="#Loading ESMC embeddings"):
            parts = accession.split("_") # e.g. accession 7S8T_5_F_7S8T_5_G
            tgt_id = "_".join(parts[:3])
            bnd_id = "_".join(parts[3:])

            ### --- embeddings (pad to fixed lengths) --- ###
            
            # laod embeddings
            t_emb = np.load(os.path.join(self.encoding_tpath, f"{tgt_id}.npy"))[0]     # [Lt, D]
            b_emb = np.load(os.path.join(self.encoding_bpath, f"{bnd_id}.npy"))[0]     # [Lb, D]

            # quich check whether embedding dimmension is as it suppose to be
            if t_emb.shape[1] != self.embedding_dim or b_emb.shape[1] != self.embedding_dim:
                raise ValueError("Embedding dim mismatch with 'embedding_dim'.")

            # add -5000 to all the padded target rows
            if t_emb.shape[0] < self.max_tlen:
                t_emb = np.concatenate([t_emb, np.full((self.max_tlen - t_emb.shape[0], t_emb.shape[1]), self.emb_pad, dtype=t_emb.dtype)], axis=0)
            else:
                t_emb = t_emb[: self.max_tlen] # no padding was used

            # add -5000 to all the padded binder rows
            if b_emb.shape[0] < self.max_blen:
                b_emb = np.concatenate([b_emb, np.full((self.max_blen - b_emb.shape[0], b_emb.shape[1]), self.emb_pad, dtype=b_emb.dtype)], axis=0)
            else:
                b_emb = b_emb[: self.max_blen] # no padding was used

            self.samples.append((b_emb, t_emb))

    # ---- Dataset API ----
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        b_arr, t_arr = self.samples[idx]
        binder_emb, target_emb = torch.from_numpy(b_arr).float(), torch.from_numpy(t_arr).float()
        label = torch.tensor(1, dtype=torch.float32)  # single scalar labe
        return binder_emb, target_emb, label

    def _get_by_name(self, name):
        # Single item -> return exactly what __getitem__ returns
        if isinstance(name, str):
            return self.__getitem__(self.name_to_row[name])
        
        # Multiple items -> fetch all
        out = [self.__getitem__(self.name_to_row[n]) for n in list(name)]
        b_list, t_list, lbl_list = zip(*out)
    
        # Stack embeddings
        b  = torch.stack([torch.as_tensor(x) for x in b_list],  dim=0)  # [B, ...]
        t  = torch.stack([torch.as_tensor(x) for x in t_list],  dim=0)  # [B, ...]
    
        # Stack labels
        labels = torch.stack(lbl_list)  # [B]
    
        return b, t, labels

bemb_path = "/work3/s232958/data/PPint_DB/embeddings_esmC"
temb_path = "/work3/s232958/data/PPint_DB/embeddings_esmC"

training_Dataset = CLIP_PPint_analysis_dataset(
    Df_train,
    paths=[bemb_path, temb_path],
    embedding_dim=1152
)

testing_Dataset = CLIP_PPint_analysis_dataset(
    Df_test,
    paths=[bemb_path, temb_path],
    embedding_dim=1152
)

#Loading ESMC embeddings: 100%|████████████████████████████████████████████████████████████████████████████████| 1978/1978 [00:13<00:00, 144.56it/s]
#Loading ESMC embeddings: 100%|██████████████████████████████████████████████████████████████████████████████████| 494/494 [00:03<00:00, 137.72it/s]


In [44]:
### Getting indeces of non-dimers
indices_non_dimers_val = Df_test[~Df_test["dimer"]].index.tolist()
indices_non_dimers_val[:5]

### Getting accessions of non-dimers
accessions = [Df_test.loc[index].target_binder_id for index in indices_non_dimers_val]
emb_b, emb_t, labels = testing_Dataset._get_by_name(accessions[:5])
labels

tensor([1., 1., 1., 1., 1.])

### Loading Meta validation dataset

In [45]:
interaction_df = pd.read_csv("/work3/s232958/data/meta_analysis/interaction_df_metaanal.csv")[["A_seq", "B_seq", "target_id_mod", "target_binder_ID", "binder"]].rename(columns = {
    "A_seq" : "seq_binder",
    "B_seq" : "seq_target",
    "target_binder_ID" : "binder_id",
    "target_id_mod" : "target_id",
    "binder" : "binder_label"
})
interaction_df["seq_target_len"] = [len(seq) for seq in interaction_df["seq_target"].tolist()]
interaction_df["seq_binder_len"] = [len(seq) for seq in interaction_df["seq_binder"].tolist()]

# Targets df
target_df = interaction_df[["target_id","seq_target"]].rename(columns={"seq_target":"sequence", "target_id" : "ID"})
target_df["seq_len"] = target_df["sequence"].apply(len)
target_df = target_df.drop_duplicates(subset=["ID","sequence"])
target_df = target_df.set_index("ID")

# Binders df
binder_df = interaction_df[["binder_id","seq_binder"]].rename(columns={"seq_binder":"sequence", "binder_id" : "ID"})
binder_df["seq_len"] = binder_df["sequence"].apply(len)
binder_df = binder_df.set_index("ID")

# target_df

# Interaction Dict
interaction_Dict = dict(enumerate(zip(interaction_df["target_id"], interaction_df["binder_id"]), start=1))
interaction_df_shuffled = interaction_df.sample(frac=1, random_state=0).reset_index(drop=True)
interaction_df_shuffled

Unnamed: 0,seq_binder,seq_target,target_id,binder_id,binder_label,seq_target_len,seq_binder_len
0,DIVEEAHKLLSRAMSEAMENDDPDKLRRANELYFKLEEALKNNDPK...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_124,True,101,62
1,SEELVEKVVEEILNSDLSNDQKILETHDRLMELHDQGKISKEEYYK...,LEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYV...,EGFR_2,EGFR_2_149,False,621,58
2,TINRVFHLHIQGDTEEARKAHEELVEEVRRWAEELAKRLNLTVRVT...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_339,False,101,65
3,DDLRKVERIASELAFFAAEQNDTKVAFTALELIHQLIRAIFHNDEE...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_1234,False,101,64
4,DEEVEELEELLEKAEDPRERAKLLRELAKLIRRDPRLRELATEVVA...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_48,False,165,65
...,...,...,...,...,...,...,...
3527,SEDELRELVKEIRKVAEKQGDKELRTLWIEAYDLLASLWYGAADEL...,TNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFK...,SARS_CoV2_RBD,SARS_CoV2_RBD_25,False,195,63
3528,TEEEILKMLVELTAHMAGVPDVKVEIHNGTLRVTVNGDTREARSVL...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_2027,False,101,65
3529,VEELKEARKLVEEVLRKKGDQIAEIWKDILEELEQRYQEGKLDPEE...,DYSFSCYSQLEVNGSQHSLTCAFEDPDVNTTNLEFEICGALVEVKC...,IL7Ra,IL7Ra_90,False,193,63
3530,DAEEEIREIVEKLNDPLLREILRLLELAKEKGDPRLEAELYLAFEK...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_1605,False,101,65


#### ESMC embedding meta-analysis dataset

In [49]:
# meta_targets, meta_binders = {}, {}

# for _, row in interaction_df_shuffled.iterrows():
#     key_prot, seq_prot = row['target_id'], row['seq_target']
#     key_pep, seq_pep = row['binder_id'], row['seq_binder']
#     if key_prot not in meta_targets.keys():
#         meta_targets[key_prot] = seq_prot
#     else:
#         pass
#     meta_binders[key_pep] = seq_pep

# from pathlib import Path
# from esm.models.esmc import ESMC
# from esm.pretrained import get_esmc_model_tokenizers  

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# tokenizer = get_esmc_model_tokenizers()
# model = ESMC(
#     d_model=1152,
#     n_heads=18,
#     n_layers=36,
#     tokenizer=tokenizer,
# ).eval()

# weights_path = Path("/work3/s232958/models/esmc-600m-2024-12/data/weights/esmc_600m_2024_12_v0.pth")
# state_dict = torch.load(weights_path, map_location=device)

# model.load_state_dict(state_dict)
# client = model.to(device)  # or whatever variable you used
# client.eval()

# def calculate_ESM_pr_res_embeddings(sequence):
#     protein = ESMProtein(sequence=sequence)
#     protein_tensor = client.encode(protein)
#     logits_output = client.logits(
#     protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
#     )
#     return logits_output.embeddings.detach().cpu().numpy()

# def to_numpy(x):
#     try:
#         return x.detach().cpu().numpy()
#     except AttributeError:
#         return np.asarray(x)

# for name, seq in tqdm(meta_targets.items(), total=len(meta_targets.items()), desc="Embedding Meta targets"):
#     emb = calculate_ESM_pr_res_embeddings(seq)
#     emb_np = to_numpy(emb)
#     out_path = os.path.join("/work3/s232958/data/meta_analysis/targets_embeddings_esmC", f"{name}.npy")
#     np.save(out_path, emb_np)
#     # print(f"Protein {name} embedded and saved to {out_path}")

# for name, seq in tqdm(meta_binders.items(), total=len(meta_binders.items()), desc="Embedding Meta binders"):
#     emb = calculate_ESM_pr_res_embeddings(seq)
#     emb_np = to_numpy(emb)
#     out_path = os.path.join("/work3/s232958/data/meta_analysis/binders_embeddings_esmC", f"{name}.npy")
#     np.save(out_path, emb_np)
#     # print(f"Protein {name} embedded and saved to {out_path}")

  state_dict = torch.load(weights_path, map_location=device)
Embedding Meta targets: 100%|███████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 19.89it/s]
Embedding Meta binders: 100%|███████████████████████████████████████████████████████████████████████████████████| 3532/3532 [02:53<00:00, 20.41it/s]


#### Loading MetaData

In [51]:
class CLIP_PPint_MetaData(Dataset):
    def __init__(
        self,
        dframe,
        paths,
        embedding_dim=1152,
        embedding_pad_value=-5000.0
    ):
        super().__init__()
        self.dframe = dframe.copy()
        self.embedding_dim = int(embedding_dim)
        self.emb_pad = float(embedding_pad_value)
        self.max_blen = self.dframe["seq_binder_len"].max()
        self.max_tlen = self.dframe["seq_target_len"].max()

        # paths
        self.encoding_bpath, self.encoding_tpath = paths

        # index & storage
        self.dframe.set_index("binder_id", inplace=True)
        self.accessions = self.dframe.index.astype(str).tolist()
        self.name_to_row = {name: i for i, name in enumerate(self.accessions)}
        self.samples = []

        for accession in tqdm(self.accessions, total=len(self.accessions), desc="#Loading ESM2 embeddings"):
            lbl = torch.tensor(int(self.dframe.loc[accession, "binder_label"]))
            parts = accession.split("_") # e.g. accession 7S8T_5_F_7S8T_5_G
            tgt_id = "_".join(parts[:-1])
            bnd_id = accession

            ### --- embeddings (pad to fixed lengths) --- ###
            
            # laod embeddings
            t_emb = np.load(os.path.join(self.encoding_tpath, f"{tgt_id}.npy"))[0]     # [Lt, D]
            b_emb = np.load(os.path.join(self.encoding_bpath, f"{bnd_id}.npy"))[0]     # [Lb, D]

            # quich check whether embedding dimmension is as it suppose to be
            if t_emb.shape[1] != self.embedding_dim or b_emb.shape[1] != self.embedding_dim:
                raise ValueError("Embedding dim mismatch with 'embedding_dim'.")

            # add -5000 to all the padded target rows
            if t_emb.shape[0] < self.max_tlen:
                t_emb = np.concatenate([t_emb, np.full((self.max_tlen - t_emb.shape[0], t_emb.shape[1]), self.emb_pad, dtype=t_emb.dtype)], axis=0)
            else:
                t_emb = t_emb[: self.max_tlen] # no padding was used

            # add -5000 to all the padded binder rows
            if b_emb.shape[0] < self.max_blen:
                b_emb = np.concatenate([b_emb, np.full((self.max_blen - b_emb.shape[0], b_emb.shape[1]), self.emb_pad, dtype=b_emb.dtype)], axis=0)
            else:
                b_emb = b_emb[: self.max_blen] # no padding was used

            self.samples.append((b_emb, t_emb, lbl))

    # ---- Dataset API ----
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        b_arr, t_arr, lbls = self.samples[idx]
        binder_emb, target_emb = torch.from_numpy(b_arr).float(), torch.from_numpy(t_arr).float()
        return binder_emb, target_emb, lbls

    def _get_by_name(self, name):
        # Single item -> return exactly what __getitem__ returns
        if isinstance(name, str):
            return self.__getitem__(self.name_to_row[name])
        
        # Multiple items -> fetch all
        out = [self.__getitem__(self.name_to_row[n]) for n in list(name)]
        b_list, t_list, lbl_list = zip(*out)
    
        # Stack embeddings
        b  = torch.stack([torch.as_tensor(x) for x in b_list],  dim=0)  # [B, ...]
        t  = torch.stack([torch.as_tensor(x) for x in t_list],  dim=0)  # [B, ...]
    
        # Stack labels
        labels = torch.stack(lbl_list)  # [B]
    
        return b, t, labels

bemb_path = "/work3/s232958/data/meta_analysis/binders_embeddings_esmC"
temb_path = "/work3/s232958/data/meta_analysis/targets_embeddings_esmC"

validation_Dataset = CLIP_PPint_MetaData(
    # interaction_df_shuffled[:len(Df_test)],
    interaction_df_shuffled,
    paths=[bemb_path, temb_path],
    embedding_dim=1152
)

#Loading ESM2 embeddings: 100%|████████████████████████████████████████████████████████████████████████████████| 3532/3532 [00:17<00:00, 197.37it/s]


In [52]:
accessions_Meta = list(interaction_df_shuffled.binder_id)
emb_b, emb_t, labels = validation_Dataset._get_by_name(accessions_Meta[:5])
labels

tensor([1, 0, 0, 0, 0])

### Train model from scratch with 10% of PPint dataset using old architecture (encodings only)

In [53]:
def create_key_padding_mask(embeddings, padding_value=-5000, offset=10):
    return (embeddings < (padding_value + offset)).all(dim=-1)

def create_mean_of_non_masked(embeddings, padding_mask):
    # Use masked select and mean to compute the mean of non-masked elements
    # embeddings should be of shape (batch_size, seq_len, features)
    seq_embeddings = []
    for i in range(embeddings.shape[0]): # looping over all batch elements
        non_masked_embeddings = embeddings[i][~padding_mask[i]] # shape [num_real_tokens, features]
        if len(non_masked_embeddings) == 0:
            print("You are masking all positions when creating sequence representation")
            sys.exit(1)
        mean_embedding = non_masked_embeddings.mean(dim=0) # sequence is represented by the single vecotr [1152] [features]
        seq_embeddings.append(mean_embedding)
    return torch.stack(seq_embeddings)

class MiniCLIP_w_transformer_crossattn(pl.LightningModule):

    def __init__(self, padding_value = -5000, embed_dimension=embedding_dimension, num_recycles=2):

        super().__init__()
        self.num_recycles = num_recycles # how many times you iteratively refine embeddings with self- and cross-attention (ALPHA-Fold-style recycling).
        self.padding_value = padding_value
        self.embed_dimension = embed_dimension

        self.logit_scale = nn.Parameter(torch.tensor(math.log(1/0.07)))  # ~CLIP init

        self.transformerencoder =  nn.TransformerEncoderLayer(
            d_model=self.embed_dimension,
            nhead=8,
            dropout=0.1,
            batch_first=True,
            dim_feedforward=self.embed_dimension
            )
 
        self.norm = nn.LayerNorm(self.embed_dimension)  # For residual additions

        # self.cross_attn = nn.MultiheadAttention(
        #     embed_dim=self.embed_dimension,
        #     num_heads=8,
        #     dropout=0.1,
        #     batch_first=True
        # )

        self.prot_embedder = nn.Sequential(
            nn.Linear(self.embed_dimension, 640),
            nn.ReLU(),
            nn.Linear(640, 320),
        )
        
    def forward(self, pep_input, prot_input, label=None, pep_int_mask=None, prot_int_mask=None, int_prob=None, mem_save=True): # , pep_tokens, prot_tokens

        pep_mask = create_key_padding_mask(embeddings=pep_input, padding_value=self.padding_value)
        prot_mask = create_key_padding_mask(embeddings=prot_input, padding_value=self.padding_value)
 
        # Initialize residual states
        pep_emb = pep_input.clone()
        prot_emb = prot_input.clone()
 
        for _ in range(self.num_recycles):

            # Transformer encoding with residual
            pep_trans = self.transformerencoder(self.norm(pep_emb), src_key_padding_mask=pep_mask)
            prot_trans = self.transformerencoder(self.norm(prot_emb), src_key_padding_mask=prot_mask)

            # Cross-attention with residual
            # pep_cross, _ = self.cross_attn(query=self.norm(pep_trans), key=self.norm(prot_trans), value=self.norm(prot_trans), key_padding_mask=prot_mask)
            # prot_cross, _ = self.cross_attn(query=self.norm(prot_trans), key=self.norm(pep_trans), value=self.norm(pep_trans), key_padding_mask=pep_mask)
            
            # Additive update with residual connection
            pep_emb = pep_emb + pep_trans  
            prot_emb = prot_emb + prot_trans

        pep_seq_coding = create_mean_of_non_masked(pep_emb, pep_mask)
        prot_seq_coding = create_mean_of_non_masked(prot_emb, prot_mask)
        
        # Use self-attention outputs for embeddings
        pep_seq_coding = F.normalize(self.prot_embedder(pep_seq_coding), dim=-1)
        prot_seq_coding = F.normalize(self.prot_embedder(prot_seq_coding), dim=-1)
 
        if mem_save:
            torch.cuda.empty_cache()
        
        scale = torch.exp(self.logit_scale).clamp(max=100.0)
        logits = scale * (pep_seq_coding * prot_seq_coding).sum(dim=-1)
        
        return logits

    def training_step(self, batch, device):
        embedding_pep, embedding_prot, labels = batch
        embedding_pep, embedding_prot = embedding_pep.to(device), embedding_prot.to(device)
        
        positive_logits = self.forward(embedding_pep, embedding_prot)
        
        # Negative indexes
        rows, cols = torch.triu_indices(embedding_prot.size(0), embedding_prot.size(0), offset=1)         
        
        negative_logits = self(embedding_pep[rows,:,:], 
                          embedding_prot[cols,:,:], 
                          int_prob=0.0)

        # loss of predicting partner using peptide
        positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits).to(device))
 
        # loss of predicting peptide using partner
        negative_loss =  F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits).to(device))
        
        loss = (positive_loss + negative_loss) / 2
 
        # del partner_prediction_loss, peptide_prediction_loss, embedding_pep, embedding_prot
        torch.cuda.empty_cache()
        return loss

    def validation_step_PPint(self, batch, device):
        # Predict on random batches of training batch size
        embedding_pep, embedding_prot, labels = batch
        embedding_pep, embedding_prot = embedding_pep.to(device), embedding_prot.to(device)
        
        with torch.no_grad():

            positive_logits = self(embedding_pep, embedding_prot)
            
            # loss of predicting partner using peptide
            positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits).to(device))
            
            # Negaive indexes
            rows, cols = torch.triu_indices(embedding_prot.size(0), embedding_prot.size(0), offset=1)
            
            negative_logits = self(embedding_pep[rows,:,:], embedding_prot[cols,:,:], int_prob=0.0)
    
            negative_loss =  F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits).to(device))

            loss = (positive_loss + negative_loss) / 2
           
            logit_matrix = torch.zeros((embedding_pep.size(0),embedding_pep.size(0)),device=self.device)
            logit_matrix[rows, cols] = negative_logits
            logit_matrix[cols, rows] = negative_logits
            
            # Fill diagonal with positive scores
            diag_indices = torch.arange(embedding_pep.size(0), device=self.device)
            logit_matrix[diag_indices, diag_indices] = positive_logits.squeeze()

            labels = torch.arange(embedding_prot.size(0)).to(self.device)
            peptide_predictions = logit_matrix.argmax(dim=0)
            peptide_ranks = logit_matrix.argsort(dim=0).diag() + 1
            peptide_mrr = (peptide_ranks).float().pow(-1).mean()
            
            # partner_accuracy = partner_predictions.eq(labels).float().mean()
            peptide_accuracy = peptide_predictions.eq(labels).float().mean()
    
            k = 3
            peptide_topk_accuracy = torch.any((logit_matrix.topk(k, dim=0).indices - labels.reshape(1, -1)) == 0, dim=0).sum() / logit_matrix.shape[0]
    
            del logit_matrix,positive_logits,negative_logits,embedding_pep,embedding_prot

            return loss, peptide_accuracy, peptide_topk_accuracy
    
    def validation_step_MetaDataset(self, batch, device):
        embedding_binder, embedding_target, labels = batch
        embedding_binder = embedding_binder.to(device)
        embedding_target = embedding_target.to(device)
        labels = labels.to(device).float()
    
        with torch.no_grad():
            logits = self.forward(embedding_binder, embedding_target)
            logits = logits.float()
            loss = F.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1))
            return logits, loss

    def calculate_logit_matrix(self,embedding_pep,embedding_prot):
        rows, cols = torch.triu_indices(embedding_pep.size(0), embedding_pep.size(0), offset=1)
        
        positive_logits = self(embedding_pep, embedding_prot)
        negative_logits = self(embedding_pep[rows,:,:], embedding_prot[cols,:,:], int_prob=0.0)
        
        logit_matrix = torch.zeros((embedding_pep.size(0),embedding_pep.size(0)),device=self.device)
        logit_matrix[rows, cols] = negative_logits
        logit_matrix[cols, rows] = negative_logits
        
        diag_indices = torch.arange(embedding_pep.size(0), device=self.device)
        logit_matrix[diag_indices, diag_indices] = positive_logits.squeeze()
        
        return logit_matrix

In [55]:
model = MiniCLIP_w_transformer_crossattn(embed_dimension=embedding_dimension, num_recycles=number_of_recycles).to("cuda")
model

MiniCLIP_w_transformer_crossattn(
  (transformerencoder): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=1152, out_features=1152, bias=True)
    )
    (linear1): Linear(in_features=1152, out_features=1152, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=1152, out_features=1152, bias=True)
    (norm1): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (norm): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
  (cross_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=1152, out_features=1152, bias=True)
  )
  (prot_embedder): Sequential(
    (0): Linear(in_features=1152, out_features=640, bias=True)
    (1): ReLU()
    (2): Linear(in_features=640, out_features=32

### Trianing loop

In [56]:
def batch(iterable, n=1):
    """Takes any indexable iterable (e.g., a list of observation IDs) and yields contiguous slices of length n."""
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]

class TrainWrapper():

    def __init__(self, 
                 model, 
                 train_loader,
                 test_loader,
                 val_loader,
                 test_df,
                 test_dataset,
                 optimizer, 
                 epochs, 
                 runID, 
                 device, 
                 test_indexes_for_auROC = None,
                 auROC_batch_size=10, 
                 model_save_steps=False, 
                 model_save_path=False, 
                 v=False, 
                 wandb_tracker=False):
        
        self.model = model 
        self.training_loader = train_loader
        self.testing_loader = test_loader
        self.validation_loader = val_loader
        self.test_dataset = test_dataset
        self.test_df = test_df
        self.auROC_batch_size = auROC_batch_size
        
        self.EPOCHS = epochs
        self.optimizer = optimizer
        self.device = device
        
        self.wandb_tracker = wandb_tracker
        self.model_save_steps = model_save_steps
        self.verbose = v
        self.best_vloss = 1_000_000
        self.runID = runID
        self.trained_model_dir = model_save_path
        self.print_frequency_loss = 1
        self.test_indexes_for_auROC = test_indexes_for_auROC

    def train_one_epoch(self):

        self.model.train() 
        running_loss = 0

        for batch in tqdm(self.training_loader, total=len(self.training_loader), desc="Running through epoch"):
            
            if batch[0].size(0) == 1: 
                continue
            
            self.optimizer.zero_grad()
            loss = self.model.training_step(batch, self.device)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()

            del loss, batch
            torch.cuda.empty_cache()
            
        return running_loss / len(self.training_loader)

    def calc_auroc_aupr_on_indexes(self, model, dataset, dataframe, nondimer_indexes, batch_size = 10):

        self.model.eval()
        all_TP_scores, all_FP_scores = [], []
        accessions = [dataframe.loc[index].target_binder_id for index in nondimer_indexes]  # <-- use dataframe
        batches_local = batch(accessions, n=batch_size)
        
        with torch.no_grad():
            for index_batch in tqdm(batches_local, total=int(len(accessions)/batch_size), desc="Calculating AUC"):

                binder_emb, target_emb, labels = dataset._get_by_name(index_batch)
                binder_emb, target_emb = binder_emb.to(self.device), target_emb.to(self.device)

                # Make sure this matches your model's signature:
                logit_matrix = self.model.calculate_logit_matrix(binder_emb, target_emb)
                
                TP_scores = logit_matrix.diag().detach().cpu().tolist()
                all_TP_scores += TP_scores
                
                # Get FP scores from upper triangle (excluding diagonal)
                n = logit_matrix.size(0)
                rows, cols = torch.triu_indices(n, n, offset=1)
                FP_scores = logit_matrix[rows, cols].detach().cpu().tolist()
                all_FP_scores += FP_scores
            
        all_score_predictions = np.array(all_TP_scores + all_FP_scores)
        all_labels = np.array([1]*len(all_TP_scores) + [0]*len(all_FP_scores))
                
        fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_score_predictions)
        auroc = metrics.roc_auc_score(all_labels, all_score_predictions)
        aupr  = metrics.average_precision_score(all_labels, all_score_predictions)
        
        return auroc, aupr, all_TP_scores, all_FP_scores

    def validate(self):
        
        self.model.eval()
        
        running_loss_Meta = 0.0
        all_logits = []
        all_lbls = []
        used_batches_meta = 0

        # --- MetaDataset validation ---
        with torch.no_grad():
            for batch in tqdm(self.validation_loader, total=len(self.validation_loader)):
                if batch[0].size(0) == 1:
                    continue
                embedding_binder, embedding_target, labels = batch
                logits, loss = self.model.validation_step_MetaDataset(batch, self.device)
                
                running_loss_Meta += loss.item()
                all_logits.append(logits.detach().view(-1).cpu())
                all_lbls.append(labels.detach().view(-1).cpu())
                used_batches_meta += 1
                
            if used_batches_meta > 0:
                val_loss_Meta = running_loss_Meta / used_batches_meta
                all_logits = torch.cat(all_logits).numpy()
                all_lbls   = torch.cat(all_lbls).numpy()
            
                fpr, tpr, thresholds = metrics.roc_curve(all_lbls, all_logits)
                meta_auroc = metrics.roc_auc_score(all_lbls, all_logits)
                meta_aupr  = metrics.average_precision_score(all_lbls, all_logits)

                y_pred = (all_logits >= 0).astype(int)
                y_true = all_lbls.astype(int)
                val_acc_Meta = (y_pred == y_true).mean()
            else:
                val_loss_Meta = float("nan")
                meta_auroc = float("nan")
                meta_aupr = float("nan")
                val_acc_Meta = float("nan")

        # --- PPint validation ---
        running_loss_ValPPint = 0.0
        running_accuracy_ValPPint = 0.0
        running_topk_accuracy_ValPPint = 0.0
        used_batches_ppint = 0

        with torch.no_grad():
            for batch in tqdm(self.testing_loader, total=len(self.testing_loader)):
                if batch[0].size(0) == 1:
                    continue
                loss, partner_accuracy, peptide_topk_accuracy = self.model.validation_step_PPint(batch, self.device)
                running_loss_ValPPint += loss.item()
                running_accuracy_ValPPint += partner_accuracy.item()
                running_topk_accuracy_ValPPint += peptide_topk_accuracy.item()
                used_batches_ppint += 1
                
            if used_batches_ppint > 0:
                val_loss_PPint = running_loss_ValPPint / used_batches_ppint
                val_accuracy_PPint = running_accuracy_ValPPint / used_batches_ppint
                val_topk_accuracy_PPint = running_topk_accuracy_ValPPint / used_batches_ppint
            else:
                val_loss_PPint = float("nan")
                val_accuracy_PPint = float("nan")
                val_topk_accuracy_PPint = float("nan")

        # --- AUROC on specific indexes (optional) ---
        if self.test_indexes_for_auROC is not None:
            non_dimer_auc, non_dimer_aupr, ___, ___ = self.calc_auroc_aupr_on_indexes(
                model=self.model, 
                dataset=self.test_dataset,
                dataframe=self.test_df,
                nondimer_indexes=self.test_indexes_for_auROC,
                batch_size=self.auROC_batch_size
            )
            
            return (val_loss_PPint, val_accuracy_PPint, val_topk_accuracy_PPint,
                    non_dimer_auc, non_dimer_aupr,
                    val_loss_Meta, val_acc_Meta, meta_auroc, meta_aupr)

        else:
            return (val_loss_PPint, val_accuracy_PPint, val_topk_accuracy_PPint,
                    val_loss_Meta, val_acc_Meta, meta_auroc, meta_aupr)

    def train_model(self):
        
        torch.cuda.empty_cache()
        
        if self.verbose:
            print(f"Training model {str(self.runID)}")

        # --- initial validation before training
        print("Initial validation before starting training")
        if self.test_indexes_for_auROC is not None:
            (val_loss_PPint, val_accuracy_PPint, val_topk_accuracy_PPint,
             non_dimer_auc, non_dimer_aupr,
             val_loss_Meta, val_acc_Meta, meta_auroc, meta_aupr) = self.validate()
        else:
            (val_loss_PPint, val_accuracy_PPint, val_topk_accuracy_PPint,
             val_loss_Meta, val_acc_Meta, meta_auroc, meta_aupr) = self.validate()
            non_dimer_auc, non_dimer_aupr = None, None
                
        if self.verbose: 
            print(f'Before training:')
            print(f'Meta Val-Loss {round(val_loss_Meta,4)}')
            print(f'Meta Accuracy: {round(val_acc_Meta,4)}')
            print(f'Meta AUROC: {round(meta_auroc,4)}')
            print(f'Meta AUPR: {round(meta_aupr,4)}')
            print(f'PPint Test-Loss: {round(val_loss_PPint,4)}')
            print(f'PPint Accuracy: {round(val_accuracy_PPint,4)}')
            if non_dimer_auc is not None:
                print(f'PPint non-dimer AUROC: {round(non_dimer_auc,4)}')
                print(f'PPint non-dimer AUPR: {round(non_dimer_aupr,4)}')
        
        if self.wandb_tracker:
            metrics_to_log = {
                "PPint Test-Loss": val_loss_PPint,
                "Meta Val-loss": val_loss_Meta,
                "PPint Accuracy": val_accuracy_PPint,
                "Meta Accuracy": val_acc_Meta,
                "Meta Val-AUROC": meta_auroc,
                "Meta Val-AUPR": meta_aupr,
            }
            if non_dimer_auc is not None:
                metrics_to_log.update({
                    "PPint non-dimer AUROC": non_dimer_auc,
                    "PPint non-dimer AUPR": non_dimer_aupr,
                })
            self.wandb_tracker.log(metrics_to_log)
        
        # --- training loop
        for epoch in tqdm(range(1, self.EPOCHS + 1), total=self.EPOCHS, desc="Epochs"):
            
            torch.cuda.empty_cache()
            
            train_loss = self.train_one_epoch()
            
            # validation after epoch
            if self.test_indexes_for_auROC is not None:
                (val_loss_PPint, val_accuracy_PPint, val_topk_accuracy_PPint,
                 non_dimer_auc, non_dimer_aupr,
                 val_loss_Meta, val_acc_Meta, meta_auroc, meta_aupr) = self.validate()
            else:
                (val_loss_PPint, val_accuracy_PPint, val_topk_accuracy_PPint,
                 val_loss_Meta, val_acc_Meta, meta_auroc, meta_aupr) = self.validate()
                non_dimer_auc, non_dimer_aupr = None, None
            
            torch.cuda.empty_cache()
            
            # checkpoint save
            if self.model_save_steps and epoch % self.model_save_steps == 0:
                check_point_folder = os.path.join(self.trained_model_dir, f"{str(self.runID)}_checkpoint_{str(epoch)}")
                if self.verbose:
                    print("Saving model to:", check_point_folder)
                os.makedirs(check_point_folder, exist_ok=True)
                checkpoint_path = os.path.join(check_point_folder, f"{str(self.runID)}_checkpoint_epoch_{str(epoch)}.pth")
                torch.save({'epoch': epoch, 
                            'model_state_dict': self.model.state_dict(),
                            'optimizer_state_dict': self.optimizer.state_dict(), 
                            'val_loss_PPint': val_loss_PPint,
                            'val_loss_Meta': val_loss_Meta},
                           checkpoint_path)
            
            # console logging
            if self.verbose and epoch % self.print_frequency_loss == 0:
                print(f'EPOCH {epoch}:')
                print(f'Meta Val Loss {round(val_loss_Meta,4)}')
                print(f'Meta Accuracy: {round(val_acc_Meta,4)}')
                print(f'Meta AUROC: {round(meta_auroc,4)}')
                print(f'Meta AUPR: {round(meta_aupr,4)}')
                print(f'PPint Test-Loss: {round(val_loss_PPint,4)}')
                print(f'PPint Accuracy: {round(val_accuracy_PPint,4)}')
                if non_dimer_auc is not None:
                    print(f'PPint non-dimer AUROC: {round(non_dimer_auc,4)}')
                    print(f'PPint non-dimer AUPR: {round(non_dimer_aupr,4)}')
            
            # wandb logging
            if self.wandb_tracker:
                metrics_to_log_epoch = {
                    "PPint Train-loss": train_loss,
                    "PPint Test-Loss": val_loss_PPint,
                    "Meta Val-loss": val_loss_Meta,
                    "PPint Accuracy": val_accuracy_PPint,
                    "Meta Accuracy": val_acc_Meta,
                    "Meta Val-AUROC": meta_auroc,
                    "Meta Val-AUPR": meta_aupr,
                }
                if non_dimer_auc is not None:
                    metrics_to_log_epoch.update({
                        "PPint non-dimer AUROC": non_dimer_auc,
                        "PPint non-dimer AUPR": non_dimer_aupr,
                    })
                self.wandb_tracker.log(metrics_to_log_epoch)

        if self.wandb_tracker:
            self.wandb_tracker.finish()

In [57]:
learning_rate = 2e-5
EPOCHS = 15
g = torch.Generator().manual_seed(SEED)
batch_size = 10
optimizer = AdamW(model.parameters(), lr=learning_rate)
accelerator = Accelerator()
device = accelerator.device

def collate_varlen(batch):
    b_emb = torch.stack([x[0] for x in batch], dim=0)
    t_emb = torch.stack([x[1] for x in batch], dim=0)
    lbls = torch.tensor([x[2].float() for x in batch])
    return b_emb, t_emb, lbls

train_dataloader = DataLoader(training_Dataset, batch_size=7, collate_fn=collate_varlen)
test_dataloader = DataLoader(testing_Dataset, batch_size=7, collate_fn=collate_varlen)
val_dataloader = DataLoader(validation_Dataset, batch_size=20, shuffle=False, drop_last = False, collate_fn=collate_varlen)

# accelerator
model, optimizer, train_dataloader, test_dataloader, val_dataloader = accelerator.prepare(model, optimizer, train_dataloader, test_dataloader, val_dataloader)

In [59]:
for i in val_dataloader:
    __, __, lbls = i
    print(lbls.to(device))
    break

tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0.,
        0., 0.], device='cuda:0')


In [60]:
# wandb
if use_wandb:
    run = wandb.init(
        project="CLIP_retrain_w_10percent_of_PPint",
        name=f"PPint0.1_PPint_test_meta_val_{runID}",
        config={"learning_rate": learning_rate, 
                "batch_size": batch_size, 
                "epochs": EPOCHS,
                "architecture": "MiniCLIP_w_transformer_crossattn", 
                "dataset": 
                "PPint"},
    )
    wandb.watch(accelerator.unwrap_model(model), log="all", log_freq=100)
else:
    run = None

# train
training_wrapper = TrainWrapper(
            model=model,
            train_loader=train_dataloader,
            test_loader=test_dataloader,
            val_loader=val_dataloader,
            test_df=Df_test,
            test_dataset=testing_Dataset,
            optimizer=optimizer,
            epochs=EPOCHS,
            runID=runID,
            device=device,
            test_indexes_for_auROC=indices_non_dimers_val,
            auROC_batch_size=10,
            model_save_steps=model_save_steps,
            model_save_path=trained_model_dir,
            v=True,
            wandb_tracker=wandb
)

training_wrapper.train_model() # start training

Training model 0e035ba9-3cbd-4aab-be0b-cc2e37723b27
Initial validation before starting training


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 177/177 [00:27<00:00,  6.39it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [00:12<00:00,  5.84it/s]
Calculating AUC: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:03<00:00,  3.41it/s]


Before training:
Meta Val-Loss 9.1679
Meta Accuracy: 0.1107
Meta AUROC: 0.4931
Meta AUPR: 0.1102
PPint Test-Loss: 5.4009
PPint Accuracy: 0.8249
PPint non-dimer AUROC: 0.6985
PPint non-dimer AUPR: 0.4501


Epochs:   0%|                                                                                                                | 0/15 [00:00<?, ?it/s]
Running through epoch:   0%|                                                                                                | 0/283 [00:00<?, ?it/s][A
Running through epoch:   0%|▎                                                                                       | 1/283 [00:00<02:38,  1.77it/s][A
Running through epoch:   1%|▌                                                                                       | 2/283 [00:01<02:32,  1.84it/s][A
Running through epoch:   1%|▉                                                                                       | 3/283 [00:01<02:30,  1.86it/s][A
Running through epoch:   1%|█▏                                                                                      | 4/283 [00:02<02:29,  1.86it/s][A
Running through epoch:   2%|█▌                                                             

Saving model to: /work3/s232958/data/trained/PPint_retrain10%_0.4_Christian/251116/0e035ba9-3cbd-4aab-be0b-cc2e37723b27_checkpoint_1


Epochs:   7%|██████▊                                                                                                | 1/15 [03:16<45:44, 196.07s/it]

EPOCH 1:
Meta Val Loss 0.4108
Meta Accuracy: 0.8825
Meta AUROC: 0.5085
Meta AUPR: 0.134
PPint Test-Loss: 0.2455
PPint Accuracy: 0.8717
PPint non-dimer AUROC: 0.7647
PPint non-dimer AUPR: 0.5088



Running through epoch:   0%|                                                                                                | 0/283 [00:00<?, ?it/s][A
Running through epoch:   0%|▎                                                                                       | 1/283 [00:00<02:05,  2.25it/s][A
Running through epoch:   1%|▌                                                                                       | 2/283 [00:00<02:19,  2.01it/s][A
Running through epoch:   1%|▉                                                                                       | 3/283 [00:01<02:24,  1.94it/s][A
Running through epoch:   1%|█▏                                                                                      | 4/283 [00:02<02:25,  1.91it/s][A
Running through epoch:   2%|█▌                                                                                      | 5/283 [00:02<02:26,  1.90it/s][A
Running through epoch:   2%|█▊                                                         

Saving model to: /work3/s232958/data/trained/PPint_retrain10%_0.4_Christian/251116/0e035ba9-3cbd-4aab-be0b-cc2e37723b27_checkpoint_2
EPOCH 2:
Meta Val Loss 0.3925
Meta Accuracy: 0.8808
Meta AUROC: 0.5232
Meta AUPR: 0.1474
PPint Test-Loss: 0.2201
PPint Accuracy: 0.8878
PPint non-dimer AUROC: 0.7849
PPint non-dimer AUPR: 0.5256



Running through epoch:   0%|                                                                                                | 0/283 [00:00<?, ?it/s][A
Running through epoch:   0%|▎                                                                                       | 1/283 [00:00<02:03,  2.29it/s][A
Running through epoch:   1%|▌                                                                                       | 2/283 [00:00<02:18,  2.03it/s][A
Running through epoch:   1%|▉                                                                                       | 3/283 [00:01<02:22,  1.97it/s][A
Running through epoch:   1%|█▏                                                                                      | 4/283 [00:02<02:24,  1.93it/s][A
Running through epoch:   2%|█▌                                                                                      | 5/283 [00:02<02:25,  1.92it/s][A
Running through epoch:   2%|█▊                                                         

Saving model to: /work3/s232958/data/trained/PPint_retrain10%_0.4_Christian/251116/0e035ba9-3cbd-4aab-be0b-cc2e37723b27_checkpoint_3



Epochs:  20%|████████████████████▌                                                                                  | 3/15 [09:46<39:05, 195.43s/it]

EPOCH 3:
Meta Val Loss 0.4264
Meta Accuracy: 0.8582
Meta AUROC: 0.5146
Meta AUPR: 0.1215
PPint Test-Loss: 0.2021
PPint Accuracy: 0.8858
PPint non-dimer AUROC: 0.8076
PPint non-dimer AUPR: 0.5404



Running through epoch:   0%|                                                                                                | 0/283 [00:00<?, ?it/s][A
Running through epoch:   0%|▎                                                                                       | 1/283 [00:00<02:02,  2.30it/s][A
Running through epoch:   1%|▌                                                                                       | 2/283 [00:00<02:18,  2.03it/s][A
Running through epoch:   1%|▉                                                                                       | 3/283 [00:01<02:22,  1.97it/s][A
Running through epoch:   1%|█▏                                                                                      | 4/283 [00:02<02:24,  1.93it/s][A
Running through epoch:   2%|█▌                                                                                      | 5/283 [00:02<02:25,  1.91it/s][A
Running through epoch:   2%|█▊                                                         

Saving model to: /work3/s232958/data/trained/PPint_retrain10%_0.4_Christian/251116/0e035ba9-3cbd-4aab-be0b-cc2e37723b27_checkpoint_4
EPOCH 4:
Meta Val Loss 0.4528
Meta Accuracy: 0.8451
Meta AUROC: 0.519
Meta AUPR: 0.1184
PPint Test-Loss: 0.2131
PPint Accuracy: 0.8833
PPint non-dimer AUROC: 0.8032
PPint non-dimer AUPR: 0.5457



Running through epoch:   0%|                                                                                                | 0/283 [00:00<?, ?it/s][A
Running through epoch:   0%|▎                                                                                       | 1/283 [00:00<02:02,  2.31it/s][A
Running through epoch:   1%|▌                                                                                       | 2/283 [00:00<02:18,  2.03it/s][A
Running through epoch:   1%|▉                                                                                       | 3/283 [00:01<02:22,  1.96it/s][A
Running through epoch:   1%|█▏                                                                                      | 4/283 [00:02<02:25,  1.92it/s][A
Running through epoch:   2%|█▌                                                                                      | 5/283 [00:02<02:25,  1.91it/s][A
Running through epoch:   2%|█▊                                                         

Saving model to: /work3/s232958/data/trained/PPint_retrain10%_0.4_Christian/251116/0e035ba9-3cbd-4aab-be0b-cc2e37723b27_checkpoint_5
EPOCH 5:
Meta Val Loss 0.6217
Meta Accuracy: 0.6852
Meta AUROC: 0.5243
Meta AUPR: 0.1162
PPint Test-Loss: 0.2145
PPint Accuracy: 0.8913
PPint non-dimer AUROC: 0.8049
PPint non-dimer AUPR: 0.5535



Running through epoch:   0%|                                                                                                | 0/283 [00:00<?, ?it/s][A
Running through epoch:   0%|▎                                                                                       | 1/283 [00:00<02:02,  2.30it/s][A
Running through epoch:   1%|▌                                                                                       | 2/283 [00:00<02:18,  2.03it/s][A
Running through epoch:   1%|▉                                                                                       | 3/283 [00:01<02:22,  1.96it/s][A
Running through epoch:   1%|█▏                                                                                      | 4/283 [00:02<02:24,  1.93it/s][A
Running through epoch:   2%|█▌                                                                                      | 5/283 [00:02<02:25,  1.91it/s][A
Running through epoch:   2%|█▊                                                         

Saving model to: /work3/s232958/data/trained/PPint_retrain10%_0.4_Christian/251116/0e035ba9-3cbd-4aab-be0b-cc2e37723b27_checkpoint_6


Epochs:  40%|█████████████████████████████████████████▏                                                             | 6/15 [19:29<29:11, 194.58s/it]

EPOCH 6:
Meta Val Loss 0.4948
Meta Accuracy: 0.8191
Meta AUROC: 0.5278
Meta AUPR: 0.1242
PPint Test-Loss: 0.2292
PPint Accuracy: 0.8994
PPint non-dimer AUROC: 0.8189
PPint non-dimer AUPR: 0.5587



Running through epoch:   0%|                                                                                                | 0/283 [00:00<?, ?it/s][A
Running through epoch:   0%|▎                                                                                       | 1/283 [00:00<02:03,  2.29it/s][A
Running through epoch:   1%|▌                                                                                       | 2/283 [00:00<02:23,  1.96it/s][A
Running through epoch:   1%|▉                                                                                       | 3/283 [00:01<02:25,  1.93it/s][A
Running through epoch:   1%|█▏                                                                                      | 4/283 [00:02<02:25,  1.91it/s][A
Running through epoch:   2%|█▌                                                                                      | 5/283 [00:02<02:25,  1.91it/s][A
Running through epoch:   2%|█▊                                                         

Saving model to: /work3/s232958/data/trained/PPint_retrain10%_0.4_Christian/251116/0e035ba9-3cbd-4aab-be0b-cc2e37723b27_checkpoint_7


Epochs:  47%|████████████████████████████████████████████████                                                       | 7/15 [22:44<25:58, 194.80s/it]

EPOCH 7:
Meta Val Loss 0.46
Meta Accuracy: 0.859
Meta AUROC: 0.5264
Meta AUPR: 0.1261
PPint Test-Loss: 0.2376
PPint Accuracy: 0.8954
PPint non-dimer AUROC: 0.8216
PPint non-dimer AUPR: 0.5498



Running through epoch:   0%|                                                                                                | 0/283 [00:00<?, ?it/s][A
Running through epoch:   0%|▎                                                                                       | 1/283 [00:00<02:04,  2.27it/s][A
Running through epoch:   1%|▌                                                                                       | 2/283 [00:00<02:19,  2.02it/s][A
Running through epoch:   1%|▉                                                                                       | 3/283 [00:01<02:23,  1.95it/s][A
Running through epoch:   1%|█▏                                                                                      | 4/283 [00:02<02:25,  1.92it/s][A
Running through epoch:   2%|█▌                                                                                      | 5/283 [00:02<02:26,  1.90it/s][A
Running through epoch:   2%|█▊                                                         

Saving model to: /work3/s232958/data/trained/PPint_retrain10%_0.4_Christian/251116/0e035ba9-3cbd-4aab-be0b-cc2e37723b27_checkpoint_8
EPOCH 8:
Meta Val Loss 0.4567
Meta Accuracy: 0.8695
Meta AUROC: 0.4826
Meta AUPR: 0.1035
PPint Test-Loss: 0.2475
PPint Accuracy: 0.9115
PPint non-dimer AUROC: 0.8458
PPint non-dimer AUPR: 0.6012



Running through epoch:   0%|                                                                                                | 0/283 [00:00<?, ?it/s][A
Running through epoch:   0%|▎                                                                                       | 1/283 [00:00<02:02,  2.31it/s][A
Running through epoch:   1%|▌                                                                                       | 2/283 [00:00<02:18,  2.04it/s][A
Running through epoch:   1%|▉                                                                                       | 3/283 [00:01<02:22,  1.96it/s][A
Running through epoch:   1%|█▏                                                                                      | 4/283 [00:02<02:24,  1.93it/s][A
Running through epoch:   2%|█▌                                                                                      | 5/283 [00:02<02:25,  1.92it/s][A
Running through epoch:   2%|█▊                                                         

Saving model to: /work3/s232958/data/trained/PPint_retrain10%_0.4_Christian/251116/0e035ba9-3cbd-4aab-be0b-cc2e37723b27_checkpoint_9
EPOCH 9:
Meta Val Loss 0.5098
Meta Accuracy: 0.8638
Meta AUROC: 0.4959
Meta AUPR: 0.1051
PPint Test-Loss: 0.2573
PPint Accuracy: 0.9135
PPint non-dimer AUROC: 0.8255
PPint non-dimer AUPR: 0.5657



Running through epoch:   0%|                                                                                                | 0/283 [00:00<?, ?it/s][A
Running through epoch:   0%|▎                                                                                       | 1/283 [00:00<02:03,  2.29it/s][A
Running through epoch:   1%|▌                                                                                       | 2/283 [00:00<02:18,  2.03it/s][A
Running through epoch:   1%|▉                                                                                       | 3/283 [00:01<02:22,  1.97it/s][A
Running through epoch:   1%|█▏                                                                                      | 4/283 [00:02<02:24,  1.93it/s][A
Running through epoch:   2%|█▌                                                                                      | 5/283 [00:02<02:25,  1.92it/s][A
Running through epoch:   2%|█▊                                                         

Saving model to: /work3/s232958/data/trained/PPint_retrain10%_0.4_Christian/251116/0e035ba9-3cbd-4aab-be0b-cc2e37723b27_checkpoint_10
EPOCH 10:
Meta Val Loss 0.5009
Meta Accuracy: 0.8423
Meta AUROC: 0.5351
Meta AUPR: 0.1272
PPint Test-Loss: 0.2578
PPint Accuracy: 0.9115
PPint non-dimer AUROC: 0.8164
PPint non-dimer AUPR: 0.5395



Running through epoch:   0%|                                                                                                | 0/283 [00:00<?, ?it/s][A
Running through epoch:   0%|▎                                                                                       | 1/283 [00:00<02:02,  2.30it/s][A
Running through epoch:   1%|▌                                                                                       | 2/283 [00:00<02:17,  2.04it/s][A
Running through epoch:   1%|▉                                                                                       | 3/283 [00:01<02:22,  1.97it/s][A
Running through epoch:   1%|█▏                                                                                      | 4/283 [00:02<02:24,  1.94it/s][A
Running through epoch:   2%|█▌                                                                                      | 5/283 [00:02<02:24,  1.92it/s][A
Running through epoch:   2%|█▊                                                         

Saving model to: /work3/s232958/data/trained/PPint_retrain10%_0.4_Christian/251116/0e035ba9-3cbd-4aab-be0b-cc2e37723b27_checkpoint_11
EPOCH 11:
Meta Val Loss 0.5301
Meta Accuracy: 0.8774
Meta AUROC: 0.5383
Meta AUPR: 0.1253
PPint Test-Loss: 0.2623
PPint Accuracy: 0.91
PPint non-dimer AUROC: 0.8206
PPint non-dimer AUPR: 0.5427



Running through epoch:   0%|                                                                                                | 0/283 [00:00<?, ?it/s][A
Running through epoch:   0%|▎                                                                                       | 1/283 [00:00<02:02,  2.30it/s][A
Running through epoch:   1%|▌                                                                                       | 2/283 [00:00<02:17,  2.04it/s][A
Running through epoch:   1%|▉                                                                                       | 3/283 [00:01<02:22,  1.97it/s][A
Running through epoch:   1%|█▏                                                                                      | 4/283 [00:02<02:24,  1.94it/s][A
Running through epoch:   2%|█▌                                                                                      | 5/283 [00:02<02:24,  1.92it/s][A
Running through epoch:   2%|█▊                                                         

Saving model to: /work3/s232958/data/trained/PPint_retrain10%_0.4_Christian/251116/0e035ba9-3cbd-4aab-be0b-cc2e37723b27_checkpoint_12
EPOCH 12:
Meta Val Loss 0.5456
Meta Accuracy: 0.8661
Meta AUROC: 0.4998
Meta AUPR: 0.1078
PPint Test-Loss: 0.2938
PPint Accuracy: 0.9034
PPint non-dimer AUROC: 0.8124
PPint non-dimer AUPR: 0.531



Running through epoch:   0%|                                                                                                | 0/283 [00:00<?, ?it/s][A
Running through epoch:   0%|▎                                                                                       | 1/283 [00:00<02:02,  2.30it/s][A
Running through epoch:   1%|▌                                                                                       | 2/283 [00:00<02:17,  2.04it/s][A
Running through epoch:   1%|▉                                                                                       | 3/283 [00:01<02:22,  1.97it/s][A
Running through epoch:   1%|█▏                                                                                      | 4/283 [00:02<02:26,  1.90it/s][A
Running through epoch:   2%|█▌                                                                                      | 5/283 [00:02<02:26,  1.89it/s][A
Running through epoch:   2%|█▊                                                         

Saving model to: /work3/s232958/data/trained/PPint_retrain10%_0.4_Christian/251116/0e035ba9-3cbd-4aab-be0b-cc2e37723b27_checkpoint_13
EPOCH 13:
Meta Val Loss 0.5506
Meta Accuracy: 0.87
Meta AUROC: 0.4858
Meta AUPR: 0.104
PPint Test-Loss: 0.2644
PPint Accuracy: 0.8939
PPint non-dimer AUROC: 0.8233
PPint non-dimer AUPR: 0.5604



Running through epoch:   0%|                                                                                                | 0/283 [00:00<?, ?it/s][A
Running through epoch:   0%|▎                                                                                       | 1/283 [00:00<02:03,  2.28it/s][A
Running through epoch:   1%|▌                                                                                       | 2/283 [00:00<02:20,  1.99it/s][A
Running through epoch:   1%|▉                                                                                       | 3/283 [00:01<02:23,  1.95it/s][A
Running through epoch:   1%|█▏                                                                                      | 4/283 [00:02<02:25,  1.92it/s][A
Running through epoch:   2%|█▌                                                                                      | 5/283 [00:02<02:25,  1.91it/s][A
Running through epoch:   2%|█▊                                                         

Saving model to: /work3/s232958/data/trained/PPint_retrain10%_0.4_Christian/251116/0e035ba9-3cbd-4aab-be0b-cc2e37723b27_checkpoint_14
EPOCH 14:
Meta Val Loss 0.546
Meta Accuracy: 0.861
Meta AUROC: 0.5029
Meta AUPR: 0.1114
PPint Test-Loss: 0.2727
PPint Accuracy: 0.8994
PPint non-dimer AUROC: 0.814
PPint non-dimer AUPR: 0.5199



Running through epoch:   0%|                                                                                                | 0/283 [00:00<?, ?it/s][A
Running through epoch:   0%|▎                                                                                       | 1/283 [00:00<02:03,  2.28it/s][A
Running through epoch:   1%|▌                                                                                       | 2/283 [00:00<02:19,  2.02it/s][A
Running through epoch:   1%|▉                                                                                       | 3/283 [00:01<02:23,  1.95it/s][A
Running through epoch:   1%|█▏                                                                                      | 4/283 [00:02<02:25,  1.92it/s][A
Running through epoch:   2%|█▌                                                                                      | 5/283 [00:02<02:25,  1.91it/s][A
Running through epoch:   2%|█▊                                                         

Saving model to: /work3/s232958/data/trained/PPint_retrain10%_0.4_Christian/251116/0e035ba9-3cbd-4aab-be0b-cc2e37723b27_checkpoint_15
EPOCH 15:
Meta Val Loss 0.5387
Meta Accuracy: 0.8423
Meta AUROC: 0.5258
Meta AUPR: 0.1244
PPint Test-Loss: 0.2467
PPint Accuracy: 0.9034
PPint non-dimer AUROC: 0.8272
PPint non-dimer AUPR: 0.5547


0,1
Meta Accuracy,▁████▆▇█████████
Meta Val-AUPR,▂▆█▄▃▃▄▅▁▁▅▄▂▁▂▄
Meta Val-AUROC,▂▄▆▅▆▆▇▇▁▃██▃▁▄▆
Meta Val-loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
PPint Accuracy,▁▅▆▆▆▆▇▇████▇▆▇▇
PPint Test-Loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
PPint Train-loss,█▅▄▃▂▂▂▂▂▁▁▁▁▁▁
PPint non-dimer AUPR,▁▄▄▅▅▆▆▆█▆▅▅▅▆▄▆
PPint non-dimer AUROC,▁▄▅▆▆▆▇▇█▇▇▇▆▇▆▇

0,1
Meta Accuracy,0.8423
Meta Val-AUPR,0.12445
Meta Val-AUROC,0.52578
Meta Val-loss,0.5387
PPint Accuracy,0.90342
PPint Test-Loss,0.24673
PPint Train-loss,0.02655
PPint non-dimer AUPR,0.55466
PPint non-dimer AUROC,0.82718
