In [1]:
import uuid, sys, os
import pandas as pd
import numpy as np
from tqdm import tqdm
import ast
import math
import random
import matplotlib.pyplot as plt

from sklearn import metrics
from scipy import stats
from collections import Counter
from sklearn.metrics import average_precision_score, precision_recall_curve, roc_auc_score, confusion_matrix, ConfusionMatrixDisplay, classification_report, roc_curve

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
torch.cuda.set_device(0)  # 0 == "first visible" -> actually GPU 2 on the node
print(torch.cuda.get_device_name(0))

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, WeightedRandomSampler
import pytorch_lightning as pl
from torch.optim import AdamW

torch.manual_seed(0)

from accelerate import Accelerator
torch.cuda.empty_cache()
import training_utils.partitioning_utils as pat_utils
from tqdm import trange

Tesla V100-SXM2-32GB


  warn(
  _torch_pytree._register_pytree_node(


In [2]:
import requests
requests.get("https://api.wandb.ai/status").status_code

import wandb
wandb.login(key="f8a6d759fe657b095d56bddbdb4d586dfaebd468", relogin=True)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /zhome/c9/0/203261/.netrc
[34m[1mwandb[0m: Currently logged in as: [33ms232958[0m ([33ms232958-danmarks-tekniske-universitet-dtu[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
### Setting a seed to have the same initiation of weights

def set_seed(seed: int = 42):
    # Python & NumPy
    random.seed(seed)
    np.random.seed(seed)
    
    # PyTorch
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # for multi-GPU

    # CuDNN settings (for convolution etc.)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # (Optional) for some Python hashing randomness
    os.environ["PYTHONHASHSEED"] = str(seed)

SEED = 0
set_seed(SEED)

In [4]:
os.chdir("/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts")
# print(os.getcwd())

print("PyTorch:", torch.__version__)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print("Current location:", os.getcwd())

PyTorch: 2.9.1+cu128
Using device: cuda
Current location: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts


In [5]:
# Model parameters
memory_verbose = False
use_wandb = True # Used to track loss in real-time without printing
model_save_steps = 3
train_frac = 1.0
test_frac = 1.0

embedding_dimension = 512 #| 960 | 1152
number_of_recycles = 2
padding_value = -5000

In [6]:
# ## Training variables
runID = uuid.uuid4()

## Output path
trained_model_dir = f"/work3/s232958/data/trained/with_structure/{runID}"

def print_mem_consumption():
    # 1. Total memory available on the GPU (device 0)
    t = torch.cuda.get_device_properties(0).total_memory
    # 2. How much memory PyTorch has *reserved* from CUDA
    r = torch.cuda.memory_reserved(0)
    # 3. How much of that reserved memory is actually *used* by tensors
    a = torch.cuda.memory_allocated(0)
    # 4. Reserved but not currently allocated (so “free inside PyTorch’s pool”)
    f = r - a

    print("Total memory: ", t/1e9)      # total VRAM in GB
    print("Reserved memory: ", r/1e9)   # PyTorch’s reserved pool in GB
    print("Allocated memory: ", a//1e9) # actually in use (integer division)
    print("Free memory: ", f/1e9)       # slack in the reserved pool in GB
print_mem_consumption()

Total memory:  34.072559616
Reserved memory:  0.0
Allocated memory:  0.0
Free memory:  0.0


## ESM2 + ESM-IF

In [7]:
Df_train = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_train_w_pbd_lens.csv",index_col=0).reset_index(drop=True)
Df_test = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_test_w_pbd_lens.csv",index_col=0).reset_index(drop=True)

Df_train["target_chain"] = [str(row.ID1[:5]+row.ID1[-1]) for __, row in Df_train.iterrows()]
Df_train["binder_chain"] = [str(row.ID2[:5]+row.ID2[-1]) for __, row in Df_train.iterrows()]

Df_test["target_chain"] = [str(row.ID1[:5]+row.ID1[-1]) for __, row in Df_test.iterrows()]
Df_test["binder_chain"] = [str(row.ID2[:5]+row.ID2[-1]) for __, row in Df_test.iterrows()]

Df_train["target_binder_id"] = [str(row.ID1)+"_"+str(row.ID2) for __, row in Df_train.iterrows()]
Df_test["target_binder_id"] = [str(row.ID1)+"_"+str(row.ID2) for __, row in Df_test.iterrows()]

Df_train_small = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_train.csv",index_col=0).reset_index(drop=True)
Df_test_small = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_test.csv",index_col=0).reset_index(drop=True)

Df_train = pd.merge(Df_train, Df_train_small[["target_binder_id", "dimer"]], on="target_binder_id", how="inner")
Df_test = pd.merge(Df_test, Df_test_small[["target_binder_id", "dimer"]], on="target_binder_id", how="inner")

Df_train

Unnamed: 0,interface_id,PDB,ID1,ID2,seq_target,seq_target_len,seq_pdb_target,pdb_target_len,target_chain,seq_binder,seq_binder_len,seq_pdb_binder,pdb_binder_len,binder_chain,pdb_path,target_binder_id,dimer
0,6IDB_0,6IDB,6IDB_0_A,6IDB_0_B,DKICLGHHAVSNGTKVNTLTERGVEVVNATETVERTNIPRICSKGK...,317,DKICLGHHAVSNGTKVNTLTERGVEVVNATETVERTNIPRICSKGK...,317,6IDB_A,GLFGAIAGFIENGWEGLIDGWYGFRHQNAQGEGTAADYKSTQSAID...,172,GLFGAIAGFIENGWEGLIDGWYGFRHQNAQGEGTAADYKSTQSAID...,172,6IDB_B,6idb.pdb.gz,6IDB_0_A_6IDB_0_B,False
1,2WZP_3,2WZP,2WZP_3_D,2WZP_3_G,VQLQESGGGLVQAGGSLRLSCTASRRTGSNWCMGWFRQLAGKEPEL...,122,VQLQESGGGLVQAGGSLRLSCTASRRTGSNWCMGWFRQLAGKEPEL...,122,2WZP_D,TIKNFTFFSPNSTEFPVGSNNDGKLYMMLTGMDYRTIRRKDWSSPL...,266,TIKNFTFFSPNSTEFPVGSNNDGKLYMMLTGMDYRTIRRKDWSSPL...,266,2WZP_G,2wzp.pdb.gz,2WZP_3_D_2WZP_3_G,False
2,1ZKP_0,1ZKP,1ZKP_0_A,1ZKP_0_C,LYFQSNAKTVVGFWGGFPEAGEATSGYLFEHDGFRLLVDCGSGVLA...,246,LYFQSNAMKMTVVGFWGGFPEAGEATSGYLFEHDGFRLLVDCGSGV...,251,1ZKP_A,AKTVVGFWGGFPEAGEATSGYLFEHDGFRLLVDCGSGVLAQLQKYI...,240,AMKMTVVGFWGGFPEAGEATSGYLFEHDGFRLLVDCGSGVLAQLQK...,245,1ZKP_C,1zkp.pdb.gz,1ZKP_0_A_1ZKP_0_C,True
3,6GRH_3,6GRH,6GRH_3_C,6GRH_3_D,SKHELSLVEVTHYTDPEVLAIVKDFHVRGNFASLPEFAERTFVSAV...,266,SKHELSLVEVTHYTDPEVLAIVKDFHVRGNFASLPEFAERTFVSAV...,266,6GRH_C,MINVYSNLMSAWPATMAMSPKLNRNMPTFSQIWDYERITPASAAGE...,396,MINVYSNLMSAWPATMAMSPKLNRNMPTFSQIWDYERITPASAAGE...,396,6GRH_D,6grh.pdb.gz,6GRH_3_C_6GRH_3_D,False
4,8R57_1,8R57,8R57_1_M,8R57_1_f,DLMTALQLVMKKSSAHDGLVKGLREAAKAIEKHAAQICVLAEDCDQ...,118,DLMTALQLVMKKSSAHDGLVKGLREAAKAIEKHAAQICVLAEDCDQ...,118,8R57_M,PKKQKHKHKKVKLAVLQFYKVDDATGKVTRLRKECPNADCGAGTFM...,64,PKKQKHKHKKVKLAVLQFYKVDDATGKVTRLRKECPNADCGAGTFM...,64,8R57_f,8r57.pdb.gz,8R57_1_M_8R57_1_f,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1972,4YO8_0,4YO8,4YO8_0_A,4YO8_0_B,HENLYFQGVQKIGILGAMREEITPILELFGVDFEEIPLGGNVFHKG...,238,HENLYFQGVQKIGILGAMREEITPILELFGVDFEEIPLGGNVFHKG...,238,4YO8_A,HHHHHENLYFQGVQKIGILGAMREEITPILELFGVDFEEIPLGGNV...,242,HHHHHENLYFQGVQKIGILGAMREEITPILELFGVDFEEIPLGGNV...,242,4YO8_B,4yo8.pdb.gz,4YO8_0_A_4YO8_0_B,True
1973,3CKI_0,3CKI,3CKI_0_A,3CKI_0_B,DPMKNTCKLLVVADHRFYRYMGRGEESTTTNYLIELIDRVDDIYRN...,256,DPMKNTCKLLVVADHRFYRYMGRGEESTTTNYLIELIDRVDDIYRN...,256,3CKI_A,CTCSPSHPQDAFCNSDIVIRAKVVGKKLVKEGPFGTLVYTIKQMKM...,121,CTCSPSHPQDAFCNSDIVIRAKVVGKKLVKEGPFGTLVYTIKQMKM...,121,3CKI_B,3cki.pdb.gz,3CKI_0_A_3CKI_0_B,False
1974,7MHY_1,7MHY,7MHY_1_M,7MHY_1_N,QVQLRQSGAELAKPGASVKMSCKASGYTFTNYWLHWIKQRPGQGLE...,118,QVQLRQSGAELAKPGASVKMSCKASGYTFTNYWLHWIKQRPGQGLE...,118,7MHY_M,DVLMTQTPLSLPVSLGDQVSISCRSSQSIVHNTYLEWYLQKPGQSP...,109,DVLMTQTPLSLPVSLGDQVSISCRSSQSIVHNTYLEWYLQKPGQSP...,109,7MHY_N,7mhy.pdb.gz,7MHY_1_M_7MHY_1_N,False
1975,7MHY_2,7MHY,7MHY_2_O,7MHY_2_P,IQLVQSGPELVKISCKASGYTFTNYGMNWVRQAPGKGLKWMGWINT...,100,IQLVQSGPELVKISCKASGYTFTNYGMNWVRQAPGKGLKWMGWINT...,100,7MHY_O,VLMTQTPLSLPVSISCRSSQSIVHSNGNTYLEWYLQKPGQSPKLLI...,94,VLMTQTPLSLPVSISCRSSQSIVHSNGNTYLEWYLQKPGQSPKLLI...,94,7MHY_P,7mhy.pdb.gz,7MHY_2_O_7MHY_2_P,False


In [8]:
class CLIP_PPint_w_esmIF(Dataset):
    def __init__(
        self,
        dframe,
        paths,
        embedding_dim_struct=512,
        embedding_dim_seq=1280,
        embedding_pad_value=-5000.0,
    ):
        super().__init__()
        self.dframe = dframe.copy()
        self.embedding_dim_seq = embedding_dim_seq
        self.embedding_dim_struct = embedding_dim_struct
        self.emb_pad = embedding_pad_value

        # lengths
        self.max_blen_seq = self.dframe["seq_binder_len"].max()+2
        self.max_tlen_seq = self.dframe["seq_target_len"].max()+2
        self.max_blen_struct = self.dframe["pdb_binder_len"].max()+2
        self.max_tlen_struct = self.dframe["pdb_target_len"].max()+2

        # paths
        self.seq_encodings_path, self.struct_encodings_path = paths

        # index & storage
        self.dframe.set_index("target_binder_id", inplace=True)
        self.accessions = self.dframe.index.astype(str).tolist()
        self.name_to_row = {name: i for i, name in enumerate(self.accessions)}
        self.samples = []

        for accession in tqdm(self.accessions, total=len(self.accessions), desc="#Loading ESM2 embeddings and contacts"):
            parts = accession.split("_") # e.g. accession 7S8T_5_F_7S8T_5_G
            tgt_id = parts[0]+"_"+parts[2]
            bnd_id = parts[-3]+"_"+parts[-1]

            ### --- SEQ embeddings (pad to fixed lengths) --- ###
            # laod embeddings
            t_emb_seq = np.load(os.path.join(self.seq_encodings_path, f"{tgt_id}.npy"))     # [Lt, D]
            b_emb_seq = np.load(os.path.join(self.seq_encodings_path, f"{bnd_id}.npy"))     # [Lb, D]
            t_emb_struct = np.load(os.path.join(self.struct_encodings_path, f"{tgt_id}.npy"))     # [Lt, D]
            b_emb_struct = np.load(os.path.join(self.struct_encodings_path, f"{bnd_id}.npy"))     # [Lb, D]

            assert (b_emb_seq.shape[0] == self.dframe.loc[accession].pdb_binder_len+2 == b_emb_struct.shape[0])
            assert (t_emb_seq.shape[0] == self.dframe.loc[accession].pdb_target_len+2 == t_emb_struct.shape[0])

            # quich check whether embedding dimmension is as it suppose to be
            if t_emb_seq.shape[1] != self.embedding_dim_seq or b_emb_seq.shape[1] != self.embedding_dim_seq:
                raise ValueError("Embedding dim mismatch with 'embedding_dim_seq'.")
            if t_emb_struct.shape[1] != self.embedding_dim_struct or b_emb_struct.shape[1] != self.embedding_dim_struct:
                raise ValueError("Embedding dim mismatch with 'embedding_dim'.")
                
            # add -5000 to all the padded target rows
                ### SEQ_embeddings ###
            if t_emb_seq.shape[0] < self.max_tlen_seq:
                t_emb_seq = np.concatenate([t_emb_seq, np.full((self.max_tlen_seq - t_emb_seq.shape[0], t_emb_seq.shape[1]), self.emb_pad, dtype=t_emb_seq.dtype)], axis=0)
            else:
                t_emb_seq = t_emb_seq[: self.max_tlen_seq] # no padding was used
            if b_emb_seq.shape[0] < self.max_blen_seq:
                b_emb_seq = np.concatenate([b_emb_seq, np.full((self.max_blen_seq - b_emb_seq.shape[0], b_emb_seq.shape[1]), self.emb_pad, dtype=b_emb_seq.dtype)], axis=0)
            else:
                b_emb_seq = b_emb_seq[: self.max_blen_seq] # no padding was used

                ### STRUCT_embeddings ###
            if t_emb_struct.shape[0] < self.max_tlen_struct:
                t_emb_struct = np.concatenate([t_emb_struct, np.full((self.max_tlen_struct - t_emb_struct.shape[0], t_emb_struct.shape[1]), self.emb_pad, dtype=t_emb_struct.dtype)], axis=0)
            else:
                t_emb_struct = t_emb_struct[: self.max_tlen_struct] # no padding was used
            if b_emb_struct.shape[0] < self.max_blen_struct:
                b_emb_struct = np.concatenate([b_emb_struct, np.full((self.max_blen_struct - b_emb_struct.shape[0], b_emb_struct.shape[1]), self.emb_pad, dtype=b_emb_struct.dtype)], axis=0)
            else:
                b_emb_struct = b_emb_struct[: self.max_blen_struct] # no padding was used

            self.samples.append((b_emb_seq, t_emb_seq, b_emb_struct, t_emb_struct))

    # ---- Dataset API ----
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        b_emb_seq, t_emb_seq, b_emb_struct, t_emb_struct = self.samples[idx]
        b_emb_seq, t_emb_seq = torch.from_numpy(b_emb_seq).float(), torch.from_numpy(t_emb_seq).float()
        b_emb_struct, t_emb_struct = torch.from_numpy(b_emb_struct).float(), torch.from_numpy(t_emb_struct).float()
        label = torch.tensor(1, dtype=torch.float32)  # single scalar labe
        return b_emb_seq, t_emb_seq, b_emb_struct, t_emb_struct, label

    def _get_by_name(self, name):
        # Single item -> return exactly what __getitem__ returns
        if isinstance(name, str):
            return self.__getitem__(self.name_to_row[name])
        
        # Multiple items -> fetch all
        out = [self.__getitem__(self.name_to_row[n]) for n in list(name)]
        b_emb_seq_list, t_emb_seq_list, b_emb_struct_list, t_emb_struct_list, lbl_list = zip(*out)
    
        # Stack embeddings
        b_emb_seq  = torch.stack([torch.as_tensor(x) for x in b_emb_seq_list],  dim=0)  # [B, ...]
        t_emb_seq  = torch.stack([torch.as_tensor(x) for x in t_emb_seq_list],  dim=0)  # [B, ...]
        
        b_emb_struct  = torch.stack([torch.as_tensor(x) for x in b_emb_struct_list],  dim=0)  # [B, ...]
        t_emb_struct  = torch.stack([torch.as_tensor(x) for x in t_emb_struct_list],  dim=0)  # [B, ...]
    
        # Stack labels
        labels = torch.stack(lbl_list)  # [B]
    
        return b_emb_seq, t_emb_seq, b_emb_struct, t_emb_struct, labels

emb_seq_path = "/work3/s232958/data/PPint_DB/embeddings_esm2"
emb_struct_path = "/work3/s232958/data/PPint_DB/esmif_embeddings_noncanonical"

training_Dataset = CLIP_PPint_w_esmIF(
    Df_train,
    paths=[emb_seq_path, emb_struct_path],
    embedding_dim_seq=1280,
    embedding_dim_struct=512
)

testing_Dataset = CLIP_PPint_w_esmIF(
    Df_test,
    paths=[emb_seq_path, emb_struct_path],
    embedding_dim_seq=1280,
    embedding_dim_struct=512
)

#Loading ESM2 embeddings and contacts: 100%|████████████████████████████████████████| 1977/1977 [00:34<00:00, 56.76it/s]
#Loading ESM2 embeddings and contacts: 100%|██████████████████████████████████████████| 494/494 [00:10<00:00, 48.53it/s]


In [9]:
### Getting indeces of non-dimers
indices_non_dimers_val = Df_test[~Df_test["dimer"]].index.tolist()
indices_non_dimers_val[:5]

### Getting accessions of non-dimers
accessions = [Df_test.loc[index].target_binder_id for index in indices_non_dimers_val]
b, t, bct, tct, labels = testing_Dataset._get_by_name(accessions[:5])
labels

tensor([1., 1., 1., 1., 1.])

In [10]:
interaction_df = pd.read_csv("/work3/s232958/data/meta_analysis/interaction_df_metaanal_w_pbd_lens.csv").drop(columns = ["binder_id", "target_id"]).rename(columns = {
    "target_id_mod" : "target_id",
    "target_binder_ID" : "binder_id",
})

# Interaction Dict
interaction_df_shuffled = interaction_df.sample(frac=1, random_state=0).reset_index(drop=True)
interaction_df_shuffled

Unnamed: 0,binder_chain,target_chains,binder,binder_seq,target_seq,target_id,binder_id,seq_len_binder,seq_len_target,pdb_len_binder,pdb_len_target
0,A,"[""B""]",True,DIVEEAHKLLSRAMSEAMENDDPDKLRRANELYFKLEEALKNNDPK...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_124,62,101,62,101
1,A,"[""B""]",False,SEELVEKVVEEILNSDLSNDQKILETHDRLMELHDQGKISKEEYYK...,LEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYV...,EGFR_2,EGFR_2_149,58,621,58,621
2,A,"[""B""]",False,TINRVFHLHIQGDTEEARKAHEELVEEVRRWAEELAKRLNLTVRVT...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_339,65,101,65,101
3,A,"[""B""]",False,DDLRKVERIASELAFFAAEQNDTKVAFTALELIHQLIRAIFHNDEE...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_1234,64,101,64,101
4,A,"[""B""]",False,DEEVEELEELLEKAEDPRERAKLLRELAKLIRRDPRLRELATEVVA...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_48,65,165,65,165
...,...,...,...,...,...,...,...,...,...,...,...
3527,A,"[""B""]",False,SEDELRELVKEIRKVAEKQGDKELRTLWIEAYDLLASLWYGAADEL...,TNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFK...,SARS_CoV2_RBD,SARS_CoV2_RBD_25,63,195,63,195
3528,A,"[""B""]",False,TEEEILKMLVELTAHMAGVPDVKVEIHNGTLRVTVNGDTREARSVL...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_2027,65,101,65,101
3529,A,"[""B""]",False,VEELKEARKLVEEVLRKKGDQIAEIWKDILEELEQRYQEGKLDPEE...,DYSFSCYSQLEVNGSQHSLTCAFEDPDVNTTNLEFEICGALVEVKC...,IL7Ra,IL7Ra_90,63,193,63,193
3530,A,"[""B""]",False,DAEEEIREIVEKLNDPLLREILRLLELAKEKGDPRLEAELYLAFEK...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_1605,65,101,65,101


In [11]:
class CLIP_metadata_w_esmIF(Dataset):
    def __init__(
        self,
        dframe,
        paths,
        seq_embedding_dim=1280,
        struct_embedding_dim=512,
        embedding_pad_value=-5000.0,
    ):
        super().__init__()
        self.dframe = dframe.copy()
        self.embedding_dim_seq = int(seq_embedding_dim)
        self.embedding_dim_struct = int(struct_embedding_dim)
        self.emb_pad = float(embedding_pad_value)

        # lengths
        self.max_blen_seq = self.dframe["seq_len_binder"].max()+2
        self.max_tlen_seq = self.dframe["seq_len_target"].max()+2
        self.max_blen_struct = self.dframe["pdb_len_binder"].max()+2
        self.max_tlen_struct = self.dframe["pdb_len_target"].max()+2

        # paths
        self.seq_bembed, self.seq_tembed, self.struct_bembed, self.struct_tembed = paths

        # index & storage
        self.dframe.set_index("binder_id", inplace=True)
        self.accessions = self.dframe.index.astype(str).tolist()
        self.name_to_row = {name: i for i, name in enumerate(self.accessions)}
        self.samples = []

        for accession in tqdm(self.accessions, total=len(self.accessions), desc="#Loading ESM2 embeddings and contacts"):
            lbl = torch.tensor(int(self.dframe.loc[accession, "binder"]))
            parts = accession.split("_") # e.g. accession 7S8T_5_F_7S8T_5_G
            tgt_id = "_".join(parts[:-1])
            bnd_id = accession

            ### --- SEQ embeddings (pad to fixed lengths) --- ###
            # laod embeddings
            b_emb_seq = np.load(os.path.join(self.seq_bembed, f"{bnd_id}.npy"))     # [Lb, D]
            t_emb_seq = np.load(os.path.join(self.seq_tembed, f"{tgt_id}.npy"))     # [Lt, D]
            b_emb_struct = np.load(os.path.join(self.struct_bembed, f"{bnd_id}.npy"))     # [Lb, D]
            t_emb_struct = np.load(os.path.join(self.struct_tembed, f"{tgt_id}.npy"))     # [Lt, D]

            assert (b_emb_seq.shape[0] == self.dframe.loc[accession].pdb_len_binder+2 == b_emb_struct.shape[0])
            assert (t_emb_seq.shape[0] == self.dframe.loc[accession].pdb_len_target+2 == t_emb_struct.shape[0])

            # quich check whether embedding dimmension is as it suppose to be
            if t_emb_seq.shape[1] != self.embedding_dim_seq or b_emb_seq.shape[1] != self.embedding_dim_seq:
                raise ValueError("Embedding dim mismatch with 'embedding_dim_seq'.")
            if t_emb_struct.shape[1] != self.embedding_dim_struct or b_emb_struct.shape[1] != self.embedding_dim_struct:
                raise ValueError("Embedding dim mismatch with 'embedding_dim'.")
                
            # add -5000 to all the padded target rows
                ### SEQ_embeddings ###
            if t_emb_seq.shape[0] < self.max_tlen_seq:
                t_emb_seq = np.concatenate([t_emb_seq, np.full((self.max_tlen_seq - t_emb_seq.shape[0], t_emb_seq.shape[1]), self.emb_pad, dtype=t_emb_seq.dtype)], axis=0)
            else:
                t_emb_seq = t_emb_seq[: self.max_tlen_seq] # no padding was used
            if b_emb_seq.shape[0] < self.max_blen_seq:
                b_emb_seq = np.concatenate([b_emb_seq, np.full((self.max_blen_seq - b_emb_seq.shape[0], b_emb_seq.shape[1]), self.emb_pad, dtype=b_emb_seq.dtype)], axis=0)
            else:
                b_emb_seq = b_emb_seq[: self.max_blen_seq] # no padding was used

                ### STRUCT_embeddings ###
            if t_emb_struct.shape[0] < self.max_tlen_struct:
                t_emb_struct = np.concatenate([t_emb_struct, np.full((self.max_tlen_struct - t_emb_struct.shape[0], t_emb_struct.shape[1]), self.emb_pad, dtype=t_emb_struct.dtype)], axis=0)
            else:
                t_emb_struct = t_emb_struct[: self.max_tlen_struct] # no padding was used
            if b_emb_struct.shape[0] < self.max_blen_struct:
                b_emb_struct = np.concatenate([b_emb_struct, np.full((self.max_blen_struct - b_emb_struct.shape[0], b_emb_struct.shape[1]), self.emb_pad, dtype=b_emb_struct.dtype)], axis=0)
            else:
                b_emb_struct = b_emb_struct[: self.max_blen_struct] # no padding was used

            self.samples.append((b_emb_seq, t_emb_seq, b_emb_struct, t_emb_struct, lbl))
        
    # ---- Dataset API ----
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        b_emb_seq, t_emb_seq, b_emb_struct, t_emb_struct, lbls = self.samples[idx]
        b_emb_seq, t_emb_seq = torch.from_numpy(b_emb_seq).float(), torch.from_numpy(t_emb_seq).float()
        b_emb_struct, t_emb_struct = torch.from_numpy(b_emb_struct).float(), torch.from_numpy(t_emb_struct).float()
        return b_emb_seq, t_emb_seq, b_emb_struct, t_emb_struct, lbls

    def _get_by_name(self, name):
        # Single item -> return exactly what __getitem__ returns
        if isinstance(name, str):
            return self.__getitem__(self.name_to_row[name])
        
        # Multiple items -> fetch all
        out = [self.__getitem__(self.name_to_row[n]) for n in list(name)]
        b_emb_seq_list, t_emb_seq_list, b_emb_struct_list, t_emb_struct_list, lbl_list = zip(*out)
    
        # Stack embeddings
        b_emb_seq  = torch.stack([torch.as_tensor(x) for x in b_emb_seq_list],  dim=0)  # [B, ...]
        t_emb_seq  = torch.stack([torch.as_tensor(x) for x in t_emb_seq_list],  dim=0)  # [B, ...]
        
        b_emb_struct  = torch.stack([torch.as_tensor(x) for x in b_emb_struct_list],  dim=0)  # [B, ...]
        t_emb_struct  = torch.stack([torch.as_tensor(x) for x in t_emb_struct_list],  dim=0)  # [B, ...]
    
        # Stack labels
        labels = torch.stack(lbl_list)  # [B]
    
        return b_emb_seq, t_emb_seq, b_emb_struct, t_emb_struct, labels

esm2_path_binders = "/work3/s232958/data/meta_analysis/embeddings_esm2_binders"
esm2_path_targets = "/work3/s232958/data/meta_analysis/embeddings_esm2_targets"

## Contact maps paths
esmIF_path_binders = "/work3/s232958/data/meta_analysis/esmif_embeddings_binders"
esmIF_path_targets = "/work3/s232958/data/meta_analysis/esmif_embeddings_targets"

validation_Dataset = CLIP_metadata_w_esmIF(
    interaction_df_shuffled,
    paths=[esm2_path_binders, esm2_path_targets, esmIF_path_binders, esmIF_path_targets],
)

#Loading ESM2 embeddings and contacts: 100%|████████████████████████████████████████| 3532/3532 [00:38<00:00, 92.38it/s]


In [12]:
accessions_Meta = list(interaction_df_shuffled.binder_id)
emb_b_Seq, emb_t_Seq, emb_b_Struct, emb_t_Struct, labels = validation_Dataset._get_by_name(accessions_Meta[:5])
labels

tensor([1, 0, 0, 0, 0])

In [25]:
def create_key_padding_mask(embeddings, padding_value=-5000, offset=10):
    return (embeddings < (padding_value + offset)).all(dim=-1)

def create_mean_of_non_masked(embeddings, padding_mask):
    # Use masked select and mean to compute the mean of non-masked elements
    # embeddings should be of shape (batch_size, seq_len, features)
    seq_embeddings = []
    for i in range(embeddings.shape[0]): # looping over all batch elements
        non_masked_embeddings = embeddings[i][~padding_mask[i]] # shape [num_real_tokens, features]
        if len(non_masked_embeddings) == 0:
            print("You are masking all positions when creating sequence representation")
            sys.exit(1)
        mean_embedding = non_masked_embeddings.mean(dim=0) # sequence is represented by the single vecotr [1152] [features]
        seq_embeddings.append(mean_embedding)
    return torch.stack(seq_embeddings)

class MiniCLIP_w_transformer_crossattn(pl.LightningModule):

    def __init__(self, padding_value = -5000, seq_embed_dimension=1280, struct_embed_dimension=512, num_recycles=2):

        super().__init__()
        self.num_recycles = num_recycles # how many times you iteratively refine embeddings with self- and cross-attention (ALPHA-Fold-style recycling).
        self.padding_value = padding_value
        self.seq_embed_dimension = seq_embed_dimension
        self.struct_embed_dimension = struct_embed_dimension

        self.logit_scale = nn.Parameter(torch.tensor(math.log(1/0.07)))  # ~CLIP init
        self.struct_alpha = nn.Parameter(torch.tensor(0.0)) # Sigmoid(0) = 0.5

        # --- SEQUENCE embeddings --- #
        
        self.norm_seq = nn.LayerNorm(self.seq_embed_dimension)  # For residual additions
        
        self.seq_encoder =  nn.TransformerEncoderLayer(
            d_model=self.seq_embed_dimension,
            nhead=8,
            dropout=0.1,
            batch_first=True,
            dim_feedforward=self.seq_embed_dimension
            )

        self.seq_cross_attn = nn.MultiheadAttention(
            embed_dim=self.seq_embed_dimension,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        self.seq_proj = nn.Sequential(
            nn.Linear(self.seq_embed_dimension, 640),
            nn.ReLU(),
            nn.Linear(640, 320),
        )

        # --- STRUCTURE embeddings --- #

        self.norm_struct = nn.LayerNorm(self.seq_embed_dimension)  # For residual additions
        
        self.initial_stuct_proj = nn.Linear(self.struct_embed_dimension, self.seq_embed_dimension)

        self.struct_encoder =  nn.TransformerEncoderLayer(
            d_model=self.seq_embed_dimension,
            nhead=8,
            dropout=0.1,
            batch_first=True,
            dim_feedforward=self.seq_embed_dimension
            )

        self.struct_to_seq_attn = nn.MultiheadAttention(
            embed_dim=self.seq_embed_dimension,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )
        
    def forward(self, pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb, label=None, pep_int_mask=None, prot_int_mask=None, int_prob=None, mem_save=True):
        
        # Key padding masks (True = pad -> to be ignored by attention)
        pep_seq_mask = create_key_padding_mask(embeddings = pep_seq_emb, padding_value = self.padding_value).to(device)   # [B, Lp]
        prot_seq_mask = create_key_padding_mask(embeddings = prot_seq_emb, padding_value = self.padding_value).to(device)    # [B, Lt]
        
        pep_struct_mask = create_key_padding_mask(embeddings = pep_struct_emb, padding_value = self.padding_value).to(device)     # [B, Lp_cm]
        prot_struct_mask = create_key_padding_mask(embeddings = prot_struct_emb, padding_value = self.padding_value).to(device)     # [B, Lt_cm]
    
        # Residual states
        pep_seq_emb = pep_seq_emb.to(device)
        prot_seq_emb = prot_seq_emb.to(device)
        
        # Project 512-dim structural embeddings to 1280-dim sequence space
        pep_struct_emb_proj = self.initial_stuct_proj(pep_struct_emb.to(device))
        prot_struct_emb_proj = self.initial_stuct_proj(prot_struct_emb.to(device))
    
        for _ in range(self.num_recycles):
            
            # --- Self-attention encoders (sequence streams) ---
            pep_trans_seq = self.seq_encoder(self.norm_seq(pep_seq_emb), src_key_padding_mask=pep_seq_mask)   # [B, Lp, E]
            prot_trans_seq = self.seq_encoder(self.norm_seq(prot_seq_emb), src_key_padding_mask=prot_seq_mask)  # [B, Lt, E]
    
            # --- Self-attention encoders (structure streams) ---
            pep_trans_str = self.struct_encoder(self.norm_struct(pep_struct_emb_proj), src_key_padding_mask=pep_struct_mask)   # [B, Lp_cm, E]
            prot_trans_str = self.struct_encoder(self.norm_struct(prot_struct_emb_proj), src_key_padding_mask=prot_struct_mask)  # [B, Lt_cm, E]

            # --- Cross-attend to structures ---
            pep_struct_upd, _ = self.struct_to_seq_attn(query=self.norm_seq(pep_trans_seq), key=self.norm_struct(pep_trans_str), value=self.norm_struct(pep_trans_str), key_padding_mask=pep_struct_mask)
            prot_struct_upd, _ = self.struct_to_seq_attn(query=self.norm_seq(prot_trans_seq), key=self.norm_struct(prot_trans_str), value=self.norm_struct(prot_trans_str), key_padding_mask=prot_struct_mask)

            current_alpha = torch.sigmoid(self.struct_alpha)
            pep_trans_emb  = pep_trans_seq  + current_alpha * pep_struct_upd    # [B, Lp, E]
            prot_trans_emb = prot_trans_seq + current_alpha * prot_struct_upd    # [B, Lt, E]
    
            # --- Cross-attend binder vs target ---
            pep_cross,  _  = self.seq_cross_attn(query=self.norm_seq(pep_trans_seq), key=self.norm_seq(prot_trans_seq), value=self.norm_seq(prot_trans_seq), key_padding_mask=prot_seq_mask)
            prot_cross, _  = self.seq_cross_attn(query=self.norm_seq(prot_trans_seq), key=self.norm_seq(pep_trans_seq), value=self.norm_seq(pep_trans_seq), key_padding_mask=pep_seq_mask)
    
            # --- Residual updates ---
            pep_seq_emb = pep_seq_emb + pep_cross
            prot_seq_emb = prot_seq_emb + prot_cross
    
        # Pool (mean over non-masked positions)
        pep_seq_coding   = create_mean_of_non_masked(pep_seq_emb, pep_seq_mask)
        prot_seq_coding  = create_mean_of_non_masked(prot_seq_emb, prot_seq_mask)

        # Projections + L2-normalize
        pep_full   = F.normalize(self.seq_proj(pep_seq_coding),   dim=-1)
        prot_full  = F.normalize(self.seq_proj(prot_seq_coding),  dim=-1)
    
        if mem_save:
            torch.cuda.empty_cache()
    
        scale  = torch.exp(self.logit_scale).clamp(max=100.0)
        logits = scale * (pep_full * prot_full).sum(dim=-1)  # [B]
        
        return logits

    def training_step(self, batch, device):
        pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb, labels = batch

        # loss of predicting partner using peptide
        positive_logits = self.forward(pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb)
        positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits).to(device)) # F.binary_cross_entropy_with_logits does sigmoid transfromation inside, excepts data, labels
        
        # Negative indexes
        rows, cols = torch.triu_indices(pep_seq_emb.size(0), pep_seq_emb.size(0), offset=1) # upper triangle
        
        # loss of predicting peptide using partner
        # negative_logits = self.forward(embedding_pep[rows,:,:], embedding_prot[cols,:,:], contacts_pep[rows,:,:], contacts_prot[cols,:,:], int_prob=0.0)
        negative_logits = self.forward(pep_seq_emb[rows,:,:], prot_seq_emb[cols,:,:], pep_struct_emb[rows,:,:], prot_struct_emb[cols,:,:], int_prob=0.0)
        negative_loss =  F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits).to(device))
        
        # loss = (positive_loss + negative_loss) / 2
        # return loss
        
        base_loss = (positive_loss + negative_loss) / 2

        lambda_l1 = 1e-3 
        l1_penalty = lambda_l1 * torch.abs(torch.sigmoid(self.struct_alpha))
        
        total_loss = base_loss + l1_penalty
 
        # del partner_prediction_loss, peptide_prediction_loss, embedding_pep, embedding_prot
        torch.cuda.empty_cache()
        return total_loss

    def validation_step_PPint(self, batch, device):
        # Predict on random batches of training batch size
        pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb, labels = batch
        pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb = pep_seq_emb.to(device), prot_seq_emb.to(device), pep_struct_emb.to(device), prot_struct_emb.to(device)
        # contacts_pep, contacts_prot = contacts_pep.to(device), contacts_prot.to(device)
        
        with torch.no_grad():

            positive_logits = self.forward(pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb)
            
            # loss of predicting partner using peptide
            positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits).to(device))
            
            # Negaive indexes
            rows, cols = torch.triu_indices(pep_seq_emb.size(0), pep_seq_emb.size(0), offset=1)
            
            negative_logits = self.forward(pep_seq_emb[rows,:,:], prot_seq_emb[cols,:,:], pep_struct_emb[rows,:,:], prot_struct_emb[cols,:,:], int_prob=0.0)
            negative_loss =  F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits).to(device))

            loss = (positive_loss + negative_loss) / 2

            logit_matrix = torch.zeros((pep_seq_emb.size(0), pep_seq_emb.size(0)),device=self.device)
            logit_matrix[rows, cols] = negative_logits
            logit_matrix[cols, rows] = negative_logits
            
            # Fill diagonal with positive scores
            diag_indices = torch.arange(pep_seq_emb.size(0), device=self.device)
            logit_matrix[diag_indices, diag_indices] = positive_logits.squeeze()

            labels = torch.arange(pep_seq_emb.size(0)).to(self.device)
            peptide_predictions = logit_matrix.argmax(dim=0)
            peptide_ranks = logit_matrix.argsort(dim=0).diag() + 1
            peptide_mrr = (peptide_ranks).float().pow(-1).mean()
            
            # partner_accuracy = partner_predictions.eq(labels).float().mean()
            peptide_accuracy = peptide_predictions.eq(labels).float().mean()
    
            return loss, peptide_accuracy
    
    def validation_step_MetaDataset(self, batch, device):
        pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb, labels = batch
        pep_seq_emb, prot_seq_emb = pep_seq_emb.to(device), prot_seq_emb.to(device) 
        pep_struct_emb, prot_struct_emb = pep_struct_emb.to(device), prot_struct_emb.to(device)
        # contacts_pep, contacts_prot = contacts_pep.to(device), contacts_prot.to(device)
        labels = labels.to(device).float()
    
        with torch.no_grad():
            logits = self.forward(pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb).float()
            loss = F.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1))
            return logits, loss

    def calculate_logit_matrix(self, pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb):
        
        rows, cols = torch.triu_indices(pep_seq_emb.size(0), pep_seq_emb.size(0), offset=1)
        positive_logits = self.forward(pep_seq_emb, prot_seq_emb, pep_struct_emb, prot_struct_emb)
        negative_logits = self.forward(pep_seq_emb[rows,:,:], prot_seq_emb[cols,:,:], pep_struct_emb[rows,:,:], prot_struct_emb[cols,:,:], int_prob=0.0)
        
        logit_matrix = torch.zeros((pep_seq_emb.size(0),pep_seq_emb.size(0)),device=self.device)
        logit_matrix[rows, cols] = negative_logits
        logit_matrix[cols, rows] = negative_logits
        
        diag_indices = torch.arange(pep_seq_emb.size(0), device=self.device)
        logit_matrix[diag_indices, diag_indices] = positive_logits.squeeze()
        
        return logit_matrix

In [26]:
model = MiniCLIP_w_transformer_crossattn(
    seq_embed_dimension=1280,
    struct_embed_dimension=512,
    num_recycles=number_of_recycles
).to("cuda")

model

MiniCLIP_w_transformer_crossattn(
  (norm_seq): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  (seq_encoder): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=1280, out_features=1280, bias=True)
    )
    (linear1): Linear(in_features=1280, out_features=1280, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=1280, out_features=1280, bias=True)
    (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (seq_cross_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=1280, out_features=1280, bias=True)
  )
  (seq_proj): Sequential(
    (0): Linear(in_features=1280, out_features=640, bias=True)
    (1): ReLU()
    (2): Linear(in_features=640, out_features=320, b

In [27]:
def batch(iterable, n=1):
    """Takes any indexable iterable (e.g., a list of observation IDs) and yields contiguous slices of length n."""
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]

class TrainWrapper():

    def __init__(self, 
                 model, 
                 train_loader,
                 test_loader,
                 val_loader,
                 test_df,
                 test_dataset,
                 optimizer, 
                 epochs, 
                 runID, 
                 device, 
                 test_indexes_for_auROC = None,
                 auROC_batch_size=10, 
                 model_save_steps=False, 
                 model_save_path=False, 
                 v=False, 
                 wandb_tracker=False):
        
        self.model = model 
        self.training_loader = train_loader
        self.testing_loader = test_loader
        self.validation_loader = val_loader
        self.test_dataset = test_dataset
        self.test_df = test_df
        self.auROC_batch_size = auROC_batch_size
        
        self.EPOCHS = epochs
        self.optimizer = optimizer
        self.device = device
        
        self.wandb_tracker = wandb_tracker
        self.model_save_steps = model_save_steps
        self.verbose = v
        self.best_vloss = 1_000_000
        self.runID = runID
        self.trained_model_dir = model_save_path
        self.print_frequency_loss = 1
        self.test_indexes_for_auROC = test_indexes_for_auROC

    def train_one_epoch(self):

        self.model.train() 
        running_loss = 0

        for batch in tqdm(self.training_loader, total=len(self.training_loader), desc="Running through epoch"):
            
            if batch[0].size(0) == 1: 
                continue
            
            self.optimizer.zero_grad()
            loss = self.model.training_step(batch, self.device)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()

            del loss, batch
            torch.cuda.empty_cache()
            
        return running_loss / len(self.training_loader)

    def calc_auroc_aupr_on_indexes(self, model, dataset, dataframe, nondimer_indexes, batch_size = 10):

        self.model.eval()
        all_TP_scores, all_FP_scores = [], []
        accessions = [dataframe.loc[index].target_binder_id for index in nondimer_indexes]  # <-- use dataframe
        batches_local = batch(accessions, n=batch_size)
        
        with torch.no_grad():
            for index_batch in tqdm(batches_local, total=int(len(accessions)/batch_size), desc="Calculating AUC"):

                embedding_pep, embedding_prot, contacts_pep, contacts_prot, labels = dataset._get_by_name(index_batch)
                embedding_pep, embedding_prot = embedding_pep.to(self.device), embedding_prot.to(self.device)

                # Make sure this matches your model's signature:
                logit_matrix = self.model.calculate_logit_matrix(embedding_pep, embedding_prot, contacts_pep, contacts_prot)
                
                TP_scores = logit_matrix.diag().detach().cpu().tolist()
                all_TP_scores += TP_scores
                
                # Get FP scores from upper triangle (excluding diagonal)
                n = logit_matrix.size(0)
                rows, cols = torch.triu_indices(n, n, offset=1)
                FP_scores = logit_matrix[rows, cols].detach().cpu().tolist()
                all_FP_scores += FP_scores
            
        all_score_predictions = np.array(all_TP_scores + all_FP_scores)
        all_labels = np.array([1]*len(all_TP_scores) + [0]*len(all_FP_scores))
                
        fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_score_predictions)
        auroc = metrics.roc_auc_score(all_labels, all_score_predictions)
        aupr  = metrics.average_precision_score(all_labels, all_score_predictions)
        
        return auroc, aupr, all_TP_scores, all_FP_scores

    def validate(self):
        
        self.model.eval()
        
        running_loss_Meta = 0.0
        all_logits = []
        all_lbls = []
        used_batches_meta = 0

        # --- MetaDataset validation ---
        with torch.no_grad():
            for batch in tqdm(self.validation_loader, total=len(self.validation_loader)):
                if batch[0].size(0) == 1:
                    continue
                __, __, __, __, labels = batch
                logits, loss = self.model.validation_step_MetaDataset(batch, self.device)
                
                running_loss_Meta += loss.item()
                all_logits.append(logits.detach().view(-1).cpu())
                all_lbls.append(labels.detach().view(-1).cpu())
                used_batches_meta += 1
                
            if used_batches_meta > 0:
                val_loss_Meta = running_loss_Meta / used_batches_meta
                all_logits = torch.cat(all_logits).numpy()
                all_lbls   = torch.cat(all_lbls).numpy()
            
                fpr, tpr, thresholds = metrics.roc_curve(all_lbls, all_logits)
                meta_auroc = metrics.roc_auc_score(all_lbls, all_logits)
                meta_aupr  = metrics.average_precision_score(all_lbls, all_logits)

                y_pred = (all_logits >= 0).astype(int)
                y_true = all_lbls.astype(int)
                val_acc_Meta = (y_pred == y_true).mean()
            else:
                val_loss_Meta = float("nan")
                meta_auroc = float("nan")
                meta_aupr = float("nan")
                val_acc_Meta = float("nan")

        # --- PPint validation ---
        running_loss_ValPPint = 0.0
        running_accuracy_ValPPint = 0.0
        used_batches_ppint = 0

        with torch.no_grad():
            for batch in tqdm(self.testing_loader, total=len(self.testing_loader)):
                if batch[0].size(0) == 1:
                    continue
                loss, partner_accuracy = self.model.validation_step_PPint(batch, self.device)
                running_loss_ValPPint += loss.item()
                running_accuracy_ValPPint += partner_accuracy.item()
                used_batches_ppint += 1
                
            if used_batches_ppint > 0:
                val_loss_PPint = running_loss_ValPPint / used_batches_ppint
                val_accuracy_PPint = running_accuracy_ValPPint / used_batches_ppint
            else:
                val_loss_PPint = float("nan")
                val_accuracy_PPint = float("nan")

        # --- AUROC on specific indexes (optional) ---
        if self.test_indexes_for_auROC is not None:
            non_dimer_auc, non_dimer_aupr, ___, ___ = self.calc_auroc_aupr_on_indexes(
                model=self.model, 
                dataset=self.test_dataset,
                dataframe=self.test_df,
                nondimer_indexes=self.test_indexes_for_auROC,
                batch_size=self.auROC_batch_size
            )
            
            return (val_loss_PPint, val_accuracy_PPint,
                    non_dimer_auc, non_dimer_aupr,
                    val_loss_Meta, val_acc_Meta, meta_auroc, meta_aupr)

        else:
            return (val_loss_PPint, val_accuracy_PPint,
                    val_loss_Meta, val_acc_Meta, meta_auroc, meta_aupr)

    def train_model(self):
        
        torch.cuda.empty_cache()
        
        if self.verbose:
            print(f"Training model {str(self.runID)}")

        # --- initial validation before training
        print("Initial validation before starting training")
        if self.test_indexes_for_auROC is not None:
            (val_loss_PPint, val_accuracy_PPint,
             non_dimer_auc, non_dimer_aupr,
             val_loss_Meta, val_acc_Meta, meta_auroc, meta_aupr) = self.validate()
        else:
            (val_loss_PPint, val_accuracy_PPint,
             val_loss_Meta, val_acc_Meta, meta_auroc, meta_aupr) = self.validate()
            non_dimer_auc, non_dimer_aupr = None, None
                
        if self.verbose: 
            print(f'Before training:')
            print(f'Meta Val-Loss {round(val_loss_Meta,4)}')
            print(f'Meta Accuracy: {round(val_acc_Meta,4)}')
            print(f'Meta AUROC: {round(meta_auroc,4)}')
            print(f'Meta AUPR: {round(meta_aupr,4)}')
            print(f'PPint Test-Loss: {round(val_loss_PPint,4)}')
            print(f'PPint Accuracy: {round(val_accuracy_PPint,4)}')
            if non_dimer_auc is not None:
                print(f'PPint non-dimer AUROC: {round(non_dimer_auc,4)}')
                print(f'PPint non-dimer AUPR: {round(non_dimer_aupr,4)}')
        
        if self.wandb_tracker:
            metrics_to_log = {
                "PPint Test-Loss": val_loss_PPint,
                "Meta Val-loss": val_loss_Meta,
                "PPint Accuracy": val_accuracy_PPint,
                "Meta Accuracy": val_acc_Meta,
                "Meta Val-AUROC": meta_auroc,
                "Meta Val-AUPR": meta_aupr,
            }
            if non_dimer_auc is not None:
                metrics_to_log.update({
                    "PPint non-dimer AUROC": non_dimer_auc,
                    "PPint non-dimer AUPR": non_dimer_aupr,
                })
            self.wandb_tracker.log(metrics_to_log)
        
        # --- training loop
        for epoch in tqdm(range(1, self.EPOCHS + 1), total=self.EPOCHS, desc="Epochs"):
            
            torch.cuda.empty_cache()
            
            train_loss = self.train_one_epoch()
            
            # validation after epoch
            if self.test_indexes_for_auROC is not None:
                (val_loss_PPint, val_accuracy_PPint,
                 non_dimer_auc, non_dimer_aupr,
                 val_loss_Meta, val_acc_Meta, meta_auroc, meta_aupr) = self.validate()
            else:
                (val_loss_PPint, val_accuracy_PPint,
                 val_loss_Meta, val_acc_Meta, meta_auroc, meta_aupr) = self.validate()
                non_dimer_auc, non_dimer_aupr = None, None
            
            torch.cuda.empty_cache()
            
            # checkpoint save
            if self.model_save_steps and epoch % self.model_save_steps == 0:
                check_point_folder = os.path.join(self.trained_model_dir, f"{str(self.runID)}_checkpoint_{str(epoch)}")
                if self.verbose:
                    print("Saving model to:", check_point_folder)
                os.makedirs(check_point_folder, exist_ok=True)
                checkpoint_path = os.path.join(check_point_folder, f"{str(self.runID)}_checkpoint_epoch_{str(epoch)}.pth")
                torch.save({'epoch': epoch, 
                            'model_state_dict': self.model.state_dict(),
                            'optimizer_state_dict': self.optimizer.state_dict(), 
                            'val_loss_PPint': val_loss_PPint,
                            'val_loss_Meta': val_loss_Meta},
                           checkpoint_path)
            
            # console logging
            if self.verbose and epoch % self.print_frequency_loss == 0:
                print(f'EPOCH {epoch}:')
                print(f'Meta Val Loss {round(val_loss_Meta,4)}')
                print(f'Meta Accuracy: {round(val_acc_Meta,4)}')
                print(f'Meta AUROC: {round(meta_auroc,4)}')
                print(f'Meta AUPR: {round(meta_aupr,4)}')
                print(f'PPint Test-Loss: {round(val_loss_PPint,4)}')
                print(f'PPint Accuracy: {round(val_accuracy_PPint,4)}')
                if non_dimer_auc is not None:
                    print(f'PPint non-dimer AUROC: {round(non_dimer_auc,4)}')
                    print(f'PPint non-dimer AUPR: {round(non_dimer_aupr,4)}')
            
            # wandb logging
            if self.wandb_tracker:
                metrics_to_log_epoch = {
                    "PPint Train-loss": train_loss,
                    "PPint Test-Loss": val_loss_PPint,
                    "Meta Val-loss": val_loss_Meta,
                    "PPint Accuracy": val_accuracy_PPint,
                    "Meta Accuracy": val_acc_Meta,
                    "Meta Val-AUROC": meta_auroc,
                    "Meta Val-AUPR": meta_aupr,
                }
                if non_dimer_auc is not None:
                    metrics_to_log_epoch.update({
                        "PPint non-dimer AUROC": non_dimer_auc,
                        "PPint non-dimer AUPR": non_dimer_aupr,
                    })
                self.wandb_tracker.log(metrics_to_log_epoch)

        if self.wandb_tracker:
            self.wandb_tracker.finish()

In [28]:
train_dataloader = DataLoader(training_Dataset, batch_size=10, shuffle=True, drop_last=True)
test_dataloader = DataLoader(testing_Dataset, batch_size=10, shuffle=False)
val_dataloader = DataLoader(validation_Dataset, batch_size=10, shuffle=False, drop_last = False)

for i in val_dataloader:
    __, __, __, __, lbls = i
    print(lbls.to(device))
    break

tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')


In [29]:
runID = uuid.uuid4()
learning_rate = 2e-5
EPOCHS = 12
batch_size = 5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# optimizer = AdamW(model.parameters(), lr=learning_rate)
# accelerator = Accelerator()
# model, optimizer, train_dataloader, test_dataloader, val_dataloader = accelerator.prepare(model, optimizer, train_dataloader, test_dataloader, val_dataloader)
# device = accelerator.device

scaler_params = ['struct_alpha', 'pe_scale']
params_scalers = [p for n, p in model.named_parameters() if any(s in n for s in scaler_params)]
params_others = [p for n, p in model.named_parameters() if not any(s in n for s in scaler_params)]

optimizer_grouped_parameters = [
    {'params': params_scalers, 'lr': 1e-3}, 
    {'params': params_others, 'lr': 2e-5}
]

optimizer = AdamW(optimizer_grouped_parameters)

# 2. Accelerator setup
accelerator = Accelerator()

# Prepare everything (Accelerator handles moving to device)
model, optimizer, train_dataloader, test_dataloader, val_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, test_dataloader, val_dataloader
)

device = accelerator.device

In [30]:
# wandb
if use_wandb:
    run = wandb.init(
        project="CLIP_retrain_w_PPint0.1",
        name=f"ESM2&SMIF_scaler(sigmoid(0.5))_increaseLR_L1Loss_{runID}",
        config={"learning_rate": learning_rate, 
                "batch_size": batch_size, 
                "epochs": EPOCHS,
                "architecture": "MiniCLIP_w_transformer_crossattn", 
                "dataset": 
                "PPint"},
    )
    wandb.watch(accelerator.unwrap_model(model), log="all", log_freq=100)
else:
    run = None

# train
training_wrapper = TrainWrapper(
            model=model,
            train_loader=train_dataloader,
            test_loader=test_dataloader,
            val_loader=val_dataloader,
            test_df=Df_test,
            test_dataset=testing_Dataset,
            optimizer=optimizer,
            epochs=EPOCHS,
            runID=runID,
            device=device,
            test_indexes_for_auROC=indices_non_dimers_val,
            auROC_batch_size=10,
            model_save_steps=model_save_steps,
            model_save_path=trained_model_dir,
            v=True,
            wandb_tracker=wandb
)

training_wrapper.train_model() # start training

Training model 7b9cf6d2-8b2d-4644-ab69-4cbaf2caec25
Initial validation before starting training


100%|█████████████████████████████████████████████████████████████████████████████████| 354/354 [01:03<00:00,  5.60it/s]
100%|███████████████████████████████████████████████████████████████████████████████████| 50/50 [00:39<00:00,  1.25it/s]
Calculating AUC: 13it [00:10,  1.25it/s]                                                                                


Before training:
Meta Val-Loss 10.444
Meta Accuracy: 0.1107
Meta AUROC: 0.5097
Meta AUPR: 0.1239
PPint Test-Loss: 6.231
PPint Accuracy: 0.832
PPint non-dimer AUROC: 0.6506
PPint non-dimer AUPR: 0.3726


Epochs:   0%|                                                                                    | 0/12 [00:00<?, ?it/s]
Running through epoch:   0%|                                                                    | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                           | 1/197 [00:01<06:04,  1.86s/it][A
Running through epoch:   1%|▌                                                           | 2/197 [00:03<06:01,  1.85s/it][A
Running through epoch:   2%|▉                                                           | 3/197 [00:05<06:00,  1.86s/it][A
Running through epoch:   2%|█▏                                                          | 4/197 [00:07<05:57,  1.85s/it][A
Running through epoch:   3%|█▌                                                          | 5/197 [00:09<05:55,  1.85s/it][A
Running through epoch:   3%|█▊                                                          | 6/197 [00:11<05:54,  1.86s/it][A
Running thr

EPOCH 1:
Meta Val Loss 0.4153
Meta Accuracy: 0.8867
Meta AUROC: 0.4813
Meta AUPR: 0.1042
PPint Test-Loss: 0.2394
PPint Accuracy: 0.868
PPint non-dimer AUROC: 0.7265
PPint non-dimer AUPR: 0.4475



Running through epoch:   0%|                                                                    | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                           | 1/197 [00:01<06:10,  1.89s/it][A
Running through epoch:   1%|▌                                                           | 2/197 [00:03<06:10,  1.90s/it][A
Running through epoch:   2%|▉                                                           | 3/197 [00:05<06:06,  1.89s/it][A
Running through epoch:   2%|█▏                                                          | 4/197 [00:07<06:05,  1.90s/it][A
Running through epoch:   3%|█▌                                                          | 5/197 [00:09<06:01,  1.88s/it][A
Running through epoch:   3%|█▊                                                          | 6/197 [00:11<05:58,  1.88s/it][A
Running through epoch:   4%|██▏                                                         | 7/197 [00:13<05:55,  1.87s/it][A
Running

EPOCH 2:
Meta Val Loss 0.4226
Meta Accuracy: 0.8831
Meta AUROC: 0.4849
Meta AUPR: 0.1135
PPint Test-Loss: 0.2442
PPint Accuracy: 0.866
PPint non-dimer AUROC: 0.769
PPint non-dimer AUPR: 0.4672



Running through epoch:   0%|                                                                    | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                           | 1/197 [00:01<06:07,  1.87s/it][A
Running through epoch:   1%|▌                                                           | 2/197 [00:03<06:04,  1.87s/it][A
Running through epoch:   2%|▉                                                           | 3/197 [00:05<06:01,  1.86s/it][A
Running through epoch:   2%|█▏                                                          | 4/197 [00:07<05:59,  1.86s/it][A
Running through epoch:   3%|█▌                                                          | 5/197 [00:09<05:57,  1.86s/it][A
Running through epoch:   3%|█▊                                                          | 6/197 [00:11<05:55,  1.86s/it][A
Running through epoch:   4%|██▏                                                         | 7/197 [00:13<05:53,  1.86s/it][A
Running

Saving model to: /work3/s232958/data/trained/with_structure/ab0444eb-1438-4513-a60f-6fef0340bb95/7b9cf6d2-8b2d-4644-ab69-4cbaf2caec25_checkpoint_3


Epochs:  25%|██████████████████▎                                                      | 3/12 [24:01<1:12:05, 480.66s/it]

EPOCH 3:
Meta Val Loss 0.4071
Meta Accuracy: 0.8884
Meta AUROC: 0.4825
Meta AUPR: 0.1062
PPint Test-Loss: 0.2371
PPint Accuracy: 0.856
PPint non-dimer AUROC: 0.8109
PPint non-dimer AUPR: 0.5071



Running through epoch:   0%|                                                                    | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                           | 1/197 [00:01<06:02,  1.85s/it][A
Running through epoch:   1%|▌                                                           | 2/197 [00:03<06:02,  1.86s/it][A
Running through epoch:   2%|▉                                                           | 3/197 [00:05<06:00,  1.86s/it][A
Running through epoch:   2%|█▏                                                          | 4/197 [00:07<06:00,  1.87s/it][A
Running through epoch:   3%|█▌                                                          | 5/197 [00:09<05:58,  1.87s/it][A
Running through epoch:   3%|█▊                                                          | 6/197 [00:11<05:55,  1.86s/it][A
Running through epoch:   4%|██▏                                                         | 7/197 [00:13<05:53,  1.86s/it][A
Running

EPOCH 4:
Meta Val Loss 0.4531
Meta Accuracy: 0.8672
Meta AUROC: 0.4751
Meta AUPR: 0.1033
PPint Test-Loss: 0.2217
PPint Accuracy: 0.882
PPint non-dimer AUROC: 0.7912
PPint non-dimer AUPR: 0.5286



Running through epoch:   0%|                                                                    | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                           | 1/197 [00:02<06:43,  2.06s/it][A
Running through epoch:   1%|▌                                                           | 2/197 [00:03<06:18,  1.94s/it][A
Running through epoch:   2%|▉                                                           | 3/197 [00:05<06:08,  1.90s/it][A
Running through epoch:   2%|█▏                                                          | 4/197 [00:07<06:02,  1.88s/it][A
Running through epoch:   3%|█▌                                                          | 5/197 [00:09<05:58,  1.87s/it][A
Running through epoch:   3%|█▊                                                          | 6/197 [00:11<05:55,  1.86s/it][A
Running through epoch:   4%|██▏                                                         | 7/197 [00:13<05:53,  1.86s/it][A
Running

EPOCH 5:
Meta Val Loss 0.5867
Meta Accuracy: 0.8879
Meta AUROC: 0.5132
Meta AUPR: 0.1281
PPint Test-Loss: 0.2109
PPint Accuracy: 0.88
PPint non-dimer AUROC: 0.8218
PPint non-dimer AUPR: 0.5309



Running through epoch:   0%|                                                                    | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                           | 1/197 [00:01<06:03,  1.86s/it][A
Running through epoch:   1%|▌                                                           | 2/197 [00:03<06:02,  1.86s/it][A
Running through epoch:   2%|▉                                                           | 3/197 [00:05<06:00,  1.86s/it][A
Running through epoch:   2%|█▏                                                          | 4/197 [00:07<05:58,  1.86s/it][A
Running through epoch:   3%|█▌                                                          | 5/197 [00:09<05:55,  1.85s/it][A
Running through epoch:   3%|█▊                                                          | 6/197 [00:11<05:54,  1.86s/it][A
Running through epoch:   4%|██▏                                                         | 7/197 [00:13<05:53,  1.86s/it][A
Running

Saving model to: /work3/s232958/data/trained/with_structure/ab0444eb-1438-4513-a60f-6fef0340bb95/7b9cf6d2-8b2d-4644-ab69-4cbaf2caec25_checkpoint_6


Epochs:  50%|█████████████████████████████████████▌                                     | 6/12 [48:04<48:04, 480.72s/it]

EPOCH 6:
Meta Val Loss 0.4929
Meta Accuracy: 0.889
Meta AUROC: 0.5385
Meta AUPR: 0.1458
PPint Test-Loss: 0.2681
PPint Accuracy: 0.884
PPint non-dimer AUROC: 0.8179
PPint non-dimer AUPR: 0.5257



Running through epoch:   0%|                                                                    | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                           | 1/197 [00:01<06:06,  1.87s/it][A
Running through epoch:   1%|▌                                                           | 2/197 [00:03<06:06,  1.88s/it][A
Running through epoch:   2%|▉                                                           | 3/197 [00:05<06:00,  1.86s/it][A
Running through epoch:   2%|█▏                                                          | 4/197 [00:07<05:58,  1.86s/it][A
Running through epoch:   3%|█▌                                                          | 5/197 [00:09<05:56,  1.86s/it][A
Running through epoch:   3%|█▊                                                          | 6/197 [00:11<05:54,  1.85s/it][A
Running through epoch:   4%|██▏                                                         | 7/197 [00:13<05:51,  1.85s/it][A
Running

EPOCH 7:
Meta Val Loss 0.4563
Meta Accuracy: 0.8573
Meta AUROC: 0.5037
Meta AUPR: 0.1139
PPint Test-Loss: 0.2053
PPint Accuracy: 0.884
PPint non-dimer AUROC: 0.7933
PPint non-dimer AUPR: 0.4871



Running through epoch:   0%|                                                                    | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                           | 1/197 [00:01<06:00,  1.84s/it][A
Running through epoch:   1%|▌                                                           | 2/197 [00:03<05:59,  1.85s/it][A
Running through epoch:   2%|▉                                                           | 3/197 [00:05<05:59,  1.85s/it][A
Running through epoch:   2%|█▏                                                          | 4/197 [00:07<05:57,  1.85s/it][A
Running through epoch:   3%|█▌                                                          | 5/197 [00:09<05:56,  1.85s/it][A
Running through epoch:   3%|█▊                                                          | 6/197 [00:11<05:53,  1.85s/it][A
Running through epoch:   4%|██▏                                                         | 7/197 [00:12<05:51,  1.85s/it][A
Running

EPOCH 8:
Meta Val Loss 0.463
Meta Accuracy: 0.8887
Meta AUROC: 0.5032
Meta AUPR: 0.1109
PPint Test-Loss: 0.241
PPint Accuracy: 0.884
PPint non-dimer AUROC: 0.8227
PPint non-dimer AUPR: 0.5412



Running through epoch:   0%|                                                                    | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                           | 1/197 [00:01<06:05,  1.86s/it][A
Running through epoch:   1%|▌                                                           | 2/197 [00:03<06:06,  1.88s/it][A
Running through epoch:   2%|▉                                                           | 3/197 [00:05<06:03,  1.87s/it][A
Running through epoch:   2%|█▏                                                          | 4/197 [00:07<06:00,  1.87s/it][A
Running through epoch:   3%|█▌                                                          | 5/197 [00:09<05:57,  1.86s/it][A
Running through epoch:   3%|█▊                                                          | 6/197 [00:11<05:54,  1.85s/it][A
Running through epoch:   4%|██▏                                                         | 7/197 [00:13<05:51,  1.85s/it][A
Running

Saving model to: /work3/s232958/data/trained/with_structure/ab0444eb-1438-4513-a60f-6fef0340bb95/7b9cf6d2-8b2d-4644-ab69-4cbaf2caec25_checkpoint_9


Epochs:  75%|██████████████████████████████████████████████████████▊                  | 9/12 [1:12:03<24:00, 480.02s/it]

EPOCH 9:
Meta Val Loss 0.447
Meta Accuracy: 0.8865
Meta AUROC: 0.503
Meta AUPR: 0.111
PPint Test-Loss: 0.2321
PPint Accuracy: 0.896
PPint non-dimer AUROC: 0.8165
PPint non-dimer AUPR: 0.5471



Running through epoch:   0%|                                                                    | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                           | 1/197 [00:01<05:58,  1.83s/it][A
Running through epoch:   1%|▌                                                           | 2/197 [00:03<06:00,  1.85s/it][A
Running through epoch:   2%|▉                                                           | 3/197 [00:05<05:58,  1.85s/it][A
Running through epoch:   2%|█▏                                                          | 4/197 [00:07<05:55,  1.84s/it][A
Running through epoch:   3%|█▌                                                          | 5/197 [00:09<05:54,  1.85s/it][A
Running through epoch:   3%|█▊                                                          | 6/197 [00:11<05:52,  1.85s/it][A
Running through epoch:   4%|██▏                                                         | 7/197 [00:12<05:50,  1.84s/it][A
Running

EPOCH 10:
Meta Val Loss 0.4731
Meta Accuracy: 0.848
Meta AUROC: 0.5227
Meta AUPR: 0.1146
PPint Test-Loss: 0.2186
PPint Accuracy: 0.896
PPint non-dimer AUROC: 0.8152
PPint non-dimer AUPR: 0.5249



Running through epoch:   0%|                                                                    | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                           | 1/197 [00:01<05:57,  1.83s/it][A
Running through epoch:   1%|▌                                                           | 2/197 [00:03<05:59,  1.84s/it][A
Running through epoch:   2%|▉                                                           | 3/197 [00:05<05:58,  1.85s/it][A
Running through epoch:   2%|█▏                                                          | 4/197 [00:07<05:56,  1.85s/it][A
Running through epoch:   3%|█▌                                                          | 5/197 [00:09<05:54,  1.85s/it][A
Running through epoch:   3%|█▊                                                          | 6/197 [00:11<05:53,  1.85s/it][A
Running through epoch:   4%|██▏                                                         | 7/197 [00:12<05:51,  1.85s/it][A
Running

EPOCH 11:
Meta Val Loss 0.4652
Meta Accuracy: 0.8876
Meta AUROC: 0.5592
Meta AUPR: 0.1469
PPint Test-Loss: 0.2793
PPint Accuracy: 0.884
PPint non-dimer AUROC: 0.8228
PPint non-dimer AUPR: 0.5424



Running through epoch:   0%|                                                                    | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                           | 1/197 [00:01<06:06,  1.87s/it][A
Running through epoch:   1%|▌                                                           | 2/197 [00:03<06:04,  1.87s/it][A
Running through epoch:   2%|▉                                                           | 3/197 [00:05<06:00,  1.86s/it][A
Running through epoch:   2%|█▏                                                          | 4/197 [00:07<05:59,  1.86s/it][A
Running through epoch:   3%|█▌                                                          | 5/197 [00:09<05:56,  1.86s/it][A
Running through epoch:   3%|█▊                                                          | 6/197 [00:11<05:54,  1.86s/it][A
Running through epoch:   4%|██▏                                                         | 7/197 [00:13<05:53,  1.86s/it][A
Running

Saving model to: /work3/s232958/data/trained/with_structure/ab0444eb-1438-4513-a60f-6fef0340bb95/7b9cf6d2-8b2d-4644-ab69-4cbaf2caec25_checkpoint_12


Epochs: 100%|████████████████████████████████████████████████████████████████████████| 12/12 [1:36:00<00:00, 480.06s/it]

EPOCH 12:
Meta Val Loss 0.5471
Meta Accuracy: 0.8867
Meta AUROC: 0.5355
Meta AUPR: 0.1239
PPint Test-Loss: 0.3168
PPint Accuracy: 0.892
PPint non-dimer AUROC: 0.8241
PPint non-dimer AUPR: 0.537





0,1
Meta Accuracy,▁████████████
Meta Val-AUPR,▄▁▃▁▁▅█▃▂▂▃█▄
Meta Val-AUROC,▄▂▂▂▁▄▆▃▃▃▅█▆
Meta Val-loss,█▁▁▁▁▁▁▁▁▁▁▁▁
PPint Accuracy,▁▅▅▄▆▆▇▇▇██▇█
PPint Test-Loss,█▁▁▁▁▁▁▁▁▁▁▁▁
PPint Train-loss,█▄▃▃▂▂▂▂▁▁▁▁
PPint non-dimer AUPR,▁▄▅▆▇▇▇▆██▇██
PPint non-dimer AUROC,▁▄▆▇▇██▇█████

0,1
Meta Accuracy,0.88675
Meta Val-AUPR,0.12394
Meta Val-AUROC,0.53548
Meta Val-loss,0.54705
PPint Accuracy,0.892
PPint Test-Loss,0.31678
PPint Train-loss,0.05423
PPint non-dimer AUPR,0.537
PPint non-dimer AUROC,0.82407
