In [7]:
import uuid, sys, os
import pandas as pd
import numpy as np
from tqdm import tqdm
import ast
import math
import random

from sklearn import metrics
from scipy import stats
from collections import Counter

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
torch.cuda.set_device(0)  # 0 == "first visible" -> actually GPU 3 on the node
print(torch.cuda.get_device_name(0))

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, TensorDataset
import pytorch_lightning as pl
from torch.optim import AdamW

torch.manual_seed(0)

from accelerate import Accelerator

import matplotlib.pyplot as plt
import seaborn as sns
import training_utils.dataset_utils as data_utils
import training_utils.partitioning_utils as pat_utils

import importlib
# import training_utils.train_utils as train_utils
# importlib.reload(train_utils)

Tesla V100-SXM2-32GB


In [8]:
### Setting a seed to have the same initiation of weights

def set_seed(seed: int = 42):
    # Python & NumPy
    random.seed(seed)
    np.random.seed(seed)
    
    # PyTorch
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # for multi-GPU

    # CuDNN settings (for convolution etc.)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # (Optional) for some Python hashing randomness
    os.environ["PYTHONHASHSEED"] = str(seed)

SEED = 0
set_seed(SEED)

In [9]:
import requests
requests.get("https://api.wandb.ai/status").status_code

import wandb
wandb.login(key="f8a6d759fe657b095d56bddbdb4d586dfaebd468", relogin=True)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /zhome/c9/0/203261/.netrc


True

In [10]:
# Model parameters
memory_verbose = False
use_wandb = True # Used to track loss in real-time without printing
model_save_steps = 3
train_frac = 1.0
test_frac = 1.0

embedding_dimension = 1152 #| 960 | 1152
number_of_recycles = 2
padding_value = -5000

In [11]:
# ## Training variables
runID = uuid.uuid4()

## Output path
trained_model_dir = f"/work3/s232958/data/trained/fine_tunning_on_meta/{runID}"

def print_mem_consumption():
    # 1. Total memory available on the GPU (device 0)
    t = torch.cuda.get_device_properties(0).total_memory
    # 2. How much memory PyTorch has *reserved* from CUDA
    r = torch.cuda.memory_reserved(0)
    # 3. How much of that reserved memory is actually *used* by tensors
    a = torch.cuda.memory_allocated(0)
    # 4. Reserved but not currently allocated (so “free inside PyTorch’s pool”)
    f = r - a

    print("Total memory: ", t/1e9)      # total VRAM in GB
    print("Reserved memory: ", r/1e9)   # PyTorch’s reserved pool in GB
    print("Allocated memory: ", a//1e9) # actually in use (integer division)
    print("Free memory: ", f/1e9)       # slack in the reserved pool in GB
print_mem_consumption()

Total memory:  34.072559616
Reserved memory:  0.0
Allocated memory:  0.0
Free memory:  0.0


### Loading Metadata (will be used for fine-tuning)

In [12]:
interaction_df = pd.read_csv("/work3/s232958/data/meta_analysis/interaction_df_metaanal.csv")[["A_seq", "B_seq", "target_id_mod", "target_binder_ID", "binder"]].rename(columns = {
    "A_seq" : "seq_binder",
    "B_seq" : "seq_target",
    "target_binder_ID" : "binder_id",
    "target_id_mod" : "target_id",
    "binder" : "binder_label"
})
interaction_df["seq_target_len"] = [len(seq) for seq in interaction_df["seq_target"].tolist()]
interaction_df["seq_binder_len"] = [len(seq) for seq in interaction_df["seq_binder"].tolist()]

# Targets df
target_df = interaction_df[["target_id","seq_target"]].rename(columns={"seq_target":"sequence", "target_id" : "ID"})
target_df["seq_len"] = target_df["sequence"].apply(len)
target_df = target_df.drop_duplicates(subset=["ID","sequence"])
target_df = target_df.set_index("ID")

# Binders df
binder_df = interaction_df[["binder_id","seq_binder"]].rename(columns={"seq_binder":"sequence", "binder_id" : "ID"})
binder_df["seq_len"] = binder_df["sequence"].apply(len)
binder_df = binder_df.set_index("ID")

# target_df

# Interaction Dict
interaction_Dict = dict(enumerate(zip(interaction_df["target_id"], interaction_df["binder_id"]), start=1))
interaction_df_shuffled = interaction_df.sample(frac=1, random_state=0).reset_index(drop=True)

### Weights for binder/non-binders
N_bins = len(interaction_df_shuffled["binder_label"].value_counts())
pr_class_uniform_weight = 1 / N_bins
pr_class_weight_informed_with_size_of_bins = pr_class_uniform_weight  / interaction_df_shuffled["binder_label"].value_counts()
pr_class_weight_informed_with_size_of_bins = pr_class_weight_informed_with_size_of_bins.to_dict()
interaction_df_shuffled["class_weight"] = interaction_df_shuffled["binder_label"].apply(lambda x: pr_class_weight_informed_with_size_of_bins[x])
# binder_nonbinder_weights_Dict = dict(zip(interaction_df["target_binder_ID"], interaction_df["class_weight"]))

### Weights for target
N_bins = len(interaction_df_shuffled["target_id"].value_counts())
pr_class_uniform_weight = 1 / N_bins
pr_class_weight_informed_with_size_of_bins = pr_class_uniform_weight  / interaction_df_shuffled["target_id"].value_counts()
pr_class_weight_informed_with_size_of_bins = pr_class_weight_informed_with_size_of_bins.to_dict()
interaction_df_shuffled["target_weight"] = interaction_df_shuffled["target_id"].apply(lambda x: pr_class_weight_informed_with_size_of_bins[x])

### Combined weights
interaction_df_shuffled["combined_weight"] = (interaction_df_shuffled["class_weight"]+interaction_df_shuffled["target_weight"])/2
interaction_df_shuffled

Unnamed: 0,seq_binder,seq_target,target_id,binder_id,binder_label,seq_target_len,seq_binder_len,class_weight,target_weight,combined_weight
0,DIVEEAHKLLSRAMSEAMENDDPDKLRRANELYFKLEEALKNNDPK...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_124,True,101,62,0.001279,0.000029,0.000654
1,SEELVEKVVEEILNSDLSNDQKILETHDRLMELHDQGKISKEEYYK...,LEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYV...,EGFR_2,EGFR_2_149,False,621,58,0.000159,0.000207,0.000183
2,TINRVFHLHIQGDTEEARKAHEELVEEVRRWAEELAKRLNLTVRVT...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_339,False,101,65,0.000159,0.000029,0.000094
3,DDLRKVERIASELAFFAAEQNDTKVAFTALELIHQLIRAIFHNDEE...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_1234,False,101,64,0.000159,0.000029,0.000094
4,DEEVEELEELLEKAEDPRERAKLLRELAKLIRRDPRLRELATEVVA...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_48,False,165,65,0.000159,0.000947,0.000553
...,...,...,...,...,...,...,...,...,...,...
3527,SEDELRELVKEIRKVAEKQGDKELRTLWIEAYDLLASLWYGAADEL...,TNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFK...,SARS_CoV2_RBD,SARS_CoV2_RBD_25,False,195,63,0.000159,0.000631,0.000395
3528,TEEEILKMLVELTAHMAGVPDVKVEIHNGTLRVTVNGDTREARSVL...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_2027,False,101,65,0.000159,0.000029,0.000094
3529,VEELKEARKLVEEVLRKKGDQIAEIWKDILEELEQRYQEGKLDPEE...,DYSFSCYSQLEVNGSQHSLTCAFEDPDVNTTNLEFEICGALVEVKC...,IL7Ra,IL7Ra_90,False,193,63,0.000159,0.000365,0.000262
3530,DAEEEIREIVEKLNDPLLREILRLLELAKEKGDPRLEAELYLAFEK...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_1605,False,101,65,0.000159,0.000029,0.000094


In [13]:
### Combined weights (boost positives, reduce number of FGFR2 binder)
multipliers = []
for name in interaction_df_shuffled["target_id"]:
    if name == "FGFR2":
        multipliers.append(1)
    else:
        multipliers.append(5)
interaction_df_shuffled["target_weight_FGFR2_reduced"] = interaction_df_shuffled["target_weight"] * multipliers

### Combined weights
# interaction_df["combined_weight"] = (interaction_df["class_weight"]+interaction_df["target_weight"])/2

### Combined weights (boost positives)
multipliers = []
for binder in interaction_df_shuffled["binder_label"]:
    if binder == False:
        multipliers.append(1)
    else:
        multipliers.append(2)
interaction_df_shuffled["combined_weight_boost_pos"] = ((interaction_df_shuffled["class_weight"]+interaction_df_shuffled["target_weight_FGFR2_reduced"])/2) * multipliers
interaction_df_shuffled

Unnamed: 0,seq_binder,seq_target,target_id,binder_id,binder_label,seq_target_len,seq_binder_len,class_weight,target_weight,combined_weight,target_weight_FGFR2_reduced,combined_weight_boost_pos
0,DIVEEAHKLLSRAMSEAMENDDPDKLRRANELYFKLEEALKNNDPK...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_124,True,101,62,0.001279,0.000029,0.000654,0.000029,0.001308
1,SEELVEKVVEEILNSDLSNDQKILETHDRLMELHDQGKISKEEYYK...,LEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYV...,EGFR_2,EGFR_2_149,False,621,58,0.000159,0.000207,0.000183,0.001035,0.000597
2,TINRVFHLHIQGDTEEARKAHEELVEEVRRWAEELAKRLNLTVRVT...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_339,False,101,65,0.000159,0.000029,0.000094,0.000029,0.000094
3,DDLRKVERIASELAFFAAEQNDTKVAFTALELIHQLIRAIFHNDEE...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_1234,False,101,64,0.000159,0.000029,0.000094,0.000029,0.000094
4,DEEVEELEELLEKAEDPRERAKLLRELAKLIRRDPRLRELATEVVA...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_48,False,165,65,0.000159,0.000947,0.000553,0.004735,0.002447
...,...,...,...,...,...,...,...,...,...,...,...,...
3527,SEDELRELVKEIRKVAEKQGDKELRTLWIEAYDLLASLWYGAADEL...,TNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFK...,SARS_CoV2_RBD,SARS_CoV2_RBD_25,False,195,63,0.000159,0.000631,0.000395,0.003157,0.001658
3528,TEEEILKMLVELTAHMAGVPDVKVEIHNGTLRVTVNGDTREARSVL...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_2027,False,101,65,0.000159,0.000029,0.000094,0.000029,0.000094
3529,VEELKEARKLVEEVLRKKGDQIAEIWKDILEELEQRYQEGKLDPEE...,DYSFSCYSQLEVNGSQHSLTCAFEDPDVNTTNLEFEICGALVEVKC...,IL7Ra,IL7Ra_90,False,193,63,0.000159,0.000365,0.000262,0.001827,0.000993
3530,DAEEEIREIVEKLNDPLLREILRLLELAKEKGDPRLEAELYLAFEK...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_1605,False,101,65,0.000159,0.000029,0.000094,0.000029,0.000094


In [14]:
class CLIP_Meta_class(Dataset):
    def __init__(
        self,
        dframe,
        paths,
        embedding_dim=1152,
        embedding_pad_value=-5000.0
    ):
        super().__init__()
        self.dframe = dframe.copy()
        self.embedding_dim = int(embedding_dim)
        self.emb_pad = float(embedding_pad_value)
        self.max_blen = self.dframe["seq_binder_len"].max()+2
        self.max_tlen = self.dframe["seq_target_len"].max()+2

        # paths
        self.encoding_bpath, self.encoding_tpath = paths

        # index & storage
        self.dframe.set_index("binder_id", inplace=True)
        self.accessions = self.dframe.index.astype(str).tolist()
        self.name_to_row = {name: i for i, name in enumerate(self.accessions)}
        self.samples = []

        for accession in tqdm(self.accessions, total=len(self.accessions), desc="#Loading ESMC embeddings"):
            lbl = torch.tensor(int(self.dframe.loc[accession, "binder_label"]))
            parts = accession.split("_") # e.g. accession 7S8T_5_F_7S8T_5_G
            tgt_id = "_".join(parts[:-1])
            bnd_id = accession

            ### --- embeddings (pad to fixed lengths) --- ###
            
            # laod embeddings
            t_emb = np.load(os.path.join(self.encoding_tpath, f"{tgt_id}.npy"))[0]     # [Lt, D]
            b_emb = np.load(os.path.join(self.encoding_bpath, f"{bnd_id}.npy"))[0]     # [Lb, D]

            # print(b_emb.shape[0], self.dframe.loc[accession].seq_binder_len)
            assert (b_emb.shape[0] == self.dframe.loc[accession].seq_binder_len+2)
            assert (t_emb.shape[0] == self.dframe.loc[accession].seq_target_len+2)

            # quich check whether embedding dimmension is as it suppose to be
            if t_emb.shape[1] != self.embedding_dim or b_emb.shape[1] != self.embedding_dim:
                raise ValueError("Embedding dim mismatch with 'embedding_dim'.")

            # add -5000 to all the padded target rows
            if t_emb.shape[0] < self.max_tlen:
                t_emb = np.concatenate([t_emb, np.full((self.max_tlen - t_emb.shape[0], t_emb.shape[1]), self.emb_pad, dtype=t_emb.dtype)], axis=0)
            else:
                t_emb = t_emb[: self.max_tlen] # no padding was used

            # add -5000 to all the padded binder rows
            if b_emb.shape[0] < self.max_blen:
                b_emb = np.concatenate([b_emb, np.full((self.max_blen - b_emb.shape[0], b_emb.shape[1]), self.emb_pad, dtype=b_emb.dtype)], axis=0)
            else:
                b_emb = b_emb[: self.max_blen] # no padding was used

            self.samples.append((b_emb, t_emb, lbl))

    # ---- Dataset API ----
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        b_arr, t_arr, lbls = self.samples[idx]
        binder_emb, target_emb = torch.from_numpy(b_arr).float(), torch.from_numpy(t_arr).float()
        return binder_emb, target_emb, lbls

    def _get_by_name(self, name):
        # Single item -> return exactly what __getitem__ returns
        if isinstance(name, str):
            return self.__getitem__(self.name_to_row[name])
        
        # Multiple items -> fetch all
        out = [self.__getitem__(self.name_to_row[n]) for n in list(name)]
        b_list, t_list, lbl_list = zip(*out)
    
        # Stack embeddings
        b  = torch.stack([torch.as_tensor(x) for x in b_list],  dim=0)  # [B, ...]
        t  = torch.stack([torch.as_tensor(x) for x in t_list],  dim=0)  # [B, ...]
    
        # Stack labels
        labels = torch.stack(lbl_list)  # [B]
    
        return b, t, labels

bemb_path = "/work3/s232958/data/meta_analysis/embeddings_esmC_binders"
temb_path = "/work3/s232958/data/meta_analysis/embeddings_esmC_targets"

finetuning_Dataset = CLIP_Meta_class(
    # interaction_df_shuffled[:len(Df_test)],
    interaction_df_shuffled,
    paths=[bemb_path, temb_path],
    embedding_dim=1152
)

#Loading ESMC embeddings: 100%|████████████████████████████████████████████████████| 3532/3532 [00:22<00:00, 156.82it/s]


In [15]:
accessions = list(interaction_df_shuffled.binder_id)
b, t, labels = finetuning_Dataset._get_by_name(accessions[16:30])
labels

tensor([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0])

### Loating Boltzgen data (validation set 1)

In [16]:
boltzgen_df = pd.read_csv("/work3/s232958/data/boltzgen/boltzgen_df_filtered.csv")
boltzgen_df

Unnamed: 0,binder_id,binder_seq,target_id,target_seq,binder_type,binder,boltz_iptm,af3_iptm,binder_id2,len_binder_seq,len_target_seq
0,pdgfrprot_16471,SHFVIGTAEAKSDSDEDIREALEKAANEAAEKAGLPPVKLTSVEIK...,pdgfr,LVVTPPGPELVLNVSSTFVLTCSGSAPVVWERMSQEPPQEMAKAQD...,prot,False,0.937140,0.88,pdgfr_1,89,289
1,insulinprot_34946,NPVVEEARKLLEKAKELLDEARKLLEEGDYEKAKELIEEAEKLLKE...,insulin,HLYPGEVCPGMDIRNNLTRLHELENCSVIEGHLQILLMFKTRPEDF...,prot,False,0.170498,0.45,insulin_1,85,894
2,pdgfrprot_35947,ITEEQRKELIEKAAELVVKAIEEGKLASEVKKELKEFAKKLGVELT...,pdgfr,LVVTPPGPELVLNVSSTFVLTCSGSAPVVWERMSQEPPQEMAKAQD...,prot,True,0.945968,0.79,pdgfr_2,81,289
3,insulinnano_52317,EVQLVESGGGLVQPGGSLRLSCAASGFTFSNYAMGWFRQAPGKGRE...,insulin,HLYPGEVCPGMDIRNNLTRLHELENCSVIEGHLQILLMFKTRPEDF...,nano,False,0.200709,0.14,insulin_2,132,894
4,1g13prot_19735,GKLSGKQLLELFKEKVKKLLEGKEELTREEVLEIVEKAVEETVKEA...,1g13,SSFSWDNCDEGKDPAVIRSLTLEPDPIIVPGNVTLSVMGSTSVPLS...,prot,False,0.851551,0.83,1g13_1,112,162
...,...,...,...,...,...,...,...,...,...,...,...
417,3qkgprot_47078,AVYTAVLTNTETGKEFTGTGKTPEEALRNAAEKFGREEGLGLEEVI...,3qkg,GPVPTPPDNIQVQENFNISRIYGKWYNLAIGSTSPWLKKIMDRMTV...,prot,False,0.600571,0.85,3qkg_26,86,193
418,3qkgprot_02705,ATEKVTVTCPLTGKEITVEIPVPPTVESLADAVVEIAKKCGLYATH...,3qkg,GPVPTPPDNIQVQENFNISRIYGKWYNLAIGSTSPWLKKIMDRMTV...,prot,True,0.871538,0.74,3qkg_27,84,193
419,3qkgprot_42882,APMTFKITLKNVETGVVEEVTVTAESAKAALEEALVKFNIDPFSIA...,3qkg,GPVPTPPDNIQVQENFNISRIYGKWYNLAIGSTSPWLKKIMDRMTV...,prot,False,0.894542,0.78,3qkg_28,92,193
420,3qkgprot_39630,AKRAIELAKAGRLEEAVEAVVEAAREKGLSDEEADLVRQGLVYAVE...,3qkg,GPVPTPPDNIQVQENFNISRIYGKWYNLAIGSTSPWLKKIMDRMTV...,prot,False,0.870742,0.88,3qkg_29,82,193


In [17]:
class CLIP_Boltzgen_class(Dataset):
    def __init__(
        self,
        dframe,
        path,
        embedding_dim=1152,
        embedding_pad_value=-5000.0,
    ):
        super().__init__()
        self.dframe = dframe.copy()
        self.embedding_dim = int(embedding_dim)
        self.emb_pad = float(embedding_pad_value)
        self.encoding_path = path

        # lengths
        self.max_blen = self.dframe["len_binder_seq"].max()+2
        self.max_tlen = self.dframe["len_target_seq"].max()+2

        # index & storage
        self.dframe.set_index("binder_id2", inplace=True)
        self.accessions = self.dframe.index.astype(str).tolist()
        self.name_to_row = {name: i for i, name in enumerate(self.accessions)}
        self.samples = []

        for accession in tqdm(self.accessions, total=len(self.accessions), desc="#Loading ESMC embeddings"):
            lbl = torch.tensor(int(self.dframe.loc[accession, "binder"]))
            parts = accession.split("_") # e.g. accession 7S8T_5_F_7S8T_5_G
            tgt_id = parts[0]
            bnd_id = accession
            # lbl = torch.tensor(int(self.dframe.loc[accession, "binder"]))

            ### --- embeddings (pad to fixed lengths) --- ###
            
            # laod embeddings
            t_emb = np.load(os.path.join(self.encoding_path, f"{tgt_id}.npy"))[0]     # [Lt, D]
            b_emb = np.load(os.path.join(self.encoding_path, f"{bnd_id}.npy"))[0]     # [Lb, D]

            # print(b_emb.shape[0], self.dframe.loc[accession].seq_binder_len)
            assert (b_emb.shape[0] == self.dframe.loc[accession].len_binder_seq+2)
            assert (t_emb.shape[0] == self.dframe.loc[accession].len_target_seq+2)

            # quich check whether embedding dimmension is as it suppose to be
            if t_emb.shape[1] != self.embedding_dim or b_emb.shape[1] != self.embedding_dim:
                raise ValueError("Embedding dim mismatch with 'embedding_dim'.")

            # add -5000 to all the padded target rows
            if t_emb.shape[0] < self.max_tlen:
                t_emb = np.concatenate([t_emb, np.full((self.max_tlen - t_emb.shape[0], t_emb.shape[1]), self.emb_pad, dtype=t_emb.dtype)], axis=0)
            else:
                t_emb = t_emb[: self.max_tlen] # no padding was used

            # add -5000 to all the padded binder rows
            if b_emb.shape[0] < self.max_blen:
                b_emb = np.concatenate([b_emb, np.full((self.max_blen - b_emb.shape[0], b_emb.shape[1]), self.emb_pad, dtype=b_emb.dtype)], axis=0)
            else:
                b_emb = b_emb[: self.max_blen] # no padding was used

            self.samples.append((b_emb, t_emb, lbl))

    # ---- Dataset API ----
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        b_arr, t_arr, lbls = self.samples[idx]
        binder_emb, target_emb = torch.from_numpy(b_arr).float(), torch.from_numpy(t_arr).float()
        return binder_emb, target_emb, lbls

    def _get_by_name(self, name):
        # Single item -> return exactly what __getitem__ returns
        if isinstance(name, str):
            return self.__getitem__(self.name_to_row[name])
        
        # Multiple items -> fetch all
        out = [self.__getitem__(self.name_to_row[n]) for n in list(name)]
        b_list, t_list, lbl_list = zip(*out)
    
        # Stack embeddings
        b  = torch.stack([torch.as_tensor(x) for x in b_list],  dim=0)  # [B, ...]
        t  = torch.stack([torch.as_tensor(x) for x in t_list],  dim=0)  # [B, ...]
    
        # Stack labels
        labels = torch.stack(lbl_list)  # [B]
    
        return b, t, labels

validation_Boltzgen = CLIP_Boltzgen_class(
    boltzgen_df, 
    path = "/work3/s232958/data/boltzgen/embeddings_esmC", 
    embedding_dim=1152
)

#Loading ESMC embeddings: 100%|██████████████████████████████████████████████████████| 422/422 [00:02<00:00, 156.15it/s]


### Loading PPint (validation set 2)

In [18]:
Df_train = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_train.csv",index_col=0).reset_index(drop=True)
Df_test = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_test.csv",index_col=0).reset_index(drop=True)

In [19]:
class CLIP_PPint_class(Dataset):
    def __init__(
        self,
        dframe,
        path,
        embedding_dim=1280,
        embedding_pad_value=-5000.0,
    ):
        super().__init__()
        self.dframe = dframe.copy()
        self.embedding_dim = int(embedding_dim)
        self.emb_pad = float(embedding_pad_value)

        # lengths
        self.max_blen = self.dframe["seq_binder_len"].max()+2
        self.max_tlen = self.dframe["seq_target_len"].max()+2

        # paths
        self.encoding_path  = path

        # index & storage
        self.dframe.set_index("target_binder_id", inplace=True)
        self.accessions = self.dframe.index.astype(str).tolist()
        self.name_to_row = {name: i for i, name in enumerate(self.accessions)}
        self.samples = []

        for accession in tqdm(self.accessions, total=len(self.accessions), desc="#Loading ESM2 embeddings and contacts"):
            parts = accession.split("_")
            tgt_id = parts[0]+"_"+parts[2]
            bnd_id = parts[3]+"_"+parts[5]

            ### --- embeddings (pad to fixed lengths) --- ###
            
            # laod embeddings
            t_emb = np.load(os.path.join(self.encoding_path, f"{tgt_id}.npy")).squeeze(0) # [Lt, D]
            b_emb = np.load(os.path.join(self.encoding_path, f"{bnd_id}.npy")).squeeze(0) # [Lb, D]

            # print(b_emb.shape[0], self.dframe.loc[accession].seq_binder_len)
            assert (b_emb.shape[0] == self.dframe.loc[accession].seq_binder_len+2)
            assert (t_emb.shape[0] == self.dframe.loc[accession].seq_target_len+2)

            # quich check whether embedding dimmension is as it suppose to be
            if t_emb.shape[1] != self.embedding_dim or b_emb.shape[1] != self.embedding_dim:
                raise ValueError("Embedding dim mismatch with 'embedding_dim'.")

            # add -5000 to all the padded target rows
            if t_emb.shape[0] < self.max_tlen:
                t_emb = np.concatenate([t_emb, np.full((self.max_tlen - t_emb.shape[0], t_emb.shape[1]), self.emb_pad, dtype=t_emb.dtype)], axis=0)
            else:
                t_emb = t_emb[: self.max_tlen] # no padding was used

            # add -5000 to all the padded binder rows
            if b_emb.shape[0] < self.max_blen:
                b_emb = np.concatenate([b_emb, np.full((self.max_blen - b_emb.shape[0], b_emb.shape[1]), self.emb_pad, dtype=b_emb.dtype)], axis=0)
            else:
                b_emb = b_emb[: self.max_blen] # no padding was used

            self.samples.append((b_emb, t_emb))

    # ---- Dataset API ----
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        b_arr, t_arr = self.samples[idx]
        binder_emb, target_emb = torch.from_numpy(b_arr).float(), torch.from_numpy(t_arr).float()
        label = torch.tensor(1, dtype=torch.float32)  # single scalar labe
        return binder_emb, target_emb, label

    def _get_by_name(self, name):
        # Single item -> return exactly what __getitem__ returns
        if isinstance(name, str):
            return self.__getitem__(self.name_to_row[name])
        
        # Multiple items -> fetch all
        out = [self.__getitem__(self.name_to_row[n]) for n in list(name)]
        b_list, t_list, lbl_list = zip(*out)
    
        # Stack embeddings
        b  = torch.stack([torch.as_tensor(x) for x in b_list],  dim=0)  # [B, ...]
        t  = torch.stack([torch.as_tensor(x) for x in t_list],  dim=0)  # [B, ...]
    
        # Stack labels
        labels = torch.stack(lbl_list)  # [B]
    
        return b, t, labels

emb_path = "/work3/s232958/data/PPint_DB/embeddings_esmC"

validation_PPint = CLIP_PPint_class(
    Df_test,
    path=emb_path,
    embedding_dim=1152
)

#Loading ESM2 embeddings and contacts: 100%|█████████████████████████████████████████| 494/494 [00:03<00:00, 134.60it/s]


In [20]:
### Getting indeces of non-dimers
indices_non_dimers_val = Df_test[~Df_test["dimer"]].index.tolist()
indices_non_dimers_val[:5]

### Getting accessions of non-dimers
accessions = [Df_test.loc[index].target_binder_id for index in indices_non_dimers_val]
emb_b, emb_t, labels = validation_PPint._get_by_name(accessions[10:17])

indices_non_dimers_val[:5]

[2, 10, 13, 17, 18]

### Loading Bindcraft (validation set 3)

In [21]:
bindcraft_df = pd.read_csv("./bindcraft_scores.csv")

In [22]:
class CLIP_bindcraft_dataset(Dataset):
    def __init__(
        self,
        dframe,
        path,
        embedding_dim=1152,
        embedding_pad_value=-5000.0,
    ):
        super().__init__()
        self.dframe = dframe.copy()
        self.embedding_dim = int(embedding_dim)
        self.emb_pad = float(embedding_pad_value)
        self.encoding_path = path

        # lengths
        self.max_blen = self.dframe["len_seq_binder"].max()
        self.max_tlen = self.dframe["len_seq_target"].max()

        # index & storage
        self.dframe.set_index("binder_id", inplace=True)
        self.accessions = self.dframe.index.astype(str).tolist()
        self.name_to_row = {name: i for i, name in enumerate(self.accessions)}
        self.samples = []

        for accession in tqdm(self.accessions, total=len(self.accessions), desc="#Loading ESMC embeddings"):
            tgt_id = self.dframe.loc[accession]["target_id"]
            bnd_id = accession
            lbl = torch.tensor(int(self.dframe.loc[accession, "binder"]))

            ### --- embeddings (pad to fixed lengths) --- ###
            
            # laod embeddings
            t_emb = np.load(os.path.join(self.encoding_path, f"{tgt_id}.npy"))[0]     # [Lt, D]
            b_emb = np.load(os.path.join(self.encoding_path, f"{bnd_id}.npy"))[0]     # [Lb, D]

            # quich check whether embedding dimmension is as it suppose to be
            if t_emb.shape[1] != self.embedding_dim or b_emb.shape[1] != self.embedding_dim:
                raise ValueError("Embedding dim mismatch with 'embedding_dim'.")

            # add -5000 to all the padded target rows
            if t_emb.shape[0] < self.max_tlen:
                t_emb = np.concatenate([t_emb, np.full((self.max_tlen - t_emb.shape[0], t_emb.shape[1]), self.emb_pad, dtype=t_emb.dtype)], axis=0)
            else:
                t_emb = t_emb[: self.max_tlen] # no padding was used

            # add -5000 to all the padded binder rows
            if b_emb.shape[0] < self.max_blen:
                b_emb = np.concatenate([b_emb, np.full((self.max_blen - b_emb.shape[0], b_emb.shape[1]), self.emb_pad, dtype=b_emb.dtype)], axis=0)
            else:
                b_emb = b_emb[: self.max_blen] # no padding was used

            self.samples.append((b_emb, t_emb, lbl))

    # ---- Dataset API ----
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        b_arr, t_arr, lbls = self.samples[idx]
        binder_emb, target_emb = torch.from_numpy(b_arr).float(), torch.from_numpy(t_arr).float()
        return binder_emb, target_emb, lbls

    def _get_by_name(self, name):
        # Single item -> return exactly what __getitem__ returns
        if isinstance(name, str):
            return self.__getitem__(self.name_to_row[name])
        
        # Multiple items -> fetch all
        out = [self.__getitem__(self.name_to_row[n]) for n in list(name)]
        b_list, t_list, lbl_list = zip(*out)
    
        # Stack embeddings
        b  = torch.stack([torch.as_tensor(x) for x in b_list],  dim=0)  # [B, ...]
        t  = torch.stack([torch.as_tensor(x) for x in t_list],  dim=0)  # [B, ...]
    
        # Stack labels
        labels = torch.stack(lbl_list)  # [B]
    
        return b, t, labels

validation_bindcraft = CLIP_bindcraft_dataset(
    bindcraft_df, 
    path = "/work3/s232958/data/bindcraft/embeddings_esmC", 
    embedding_dim=1152
)

#Loading ESMC embeddings: 100%|██████████████████████████████████████████████████████| 150/150 [00:00<00:00, 183.31it/s]


### Loading pretrained model for finetuning

In [32]:
def create_key_padding_mask(embeddings, padding_value=-5000, offset=10):
    return (embeddings < (padding_value + offset)).all(dim=-1)

def create_mean_of_non_masked(embeddings, padding_mask):
    # Use masked select and mean to compute the mean of non-masked elements
    # embeddings should be of shape (batch_size, seq_len, features)
    seq_embeddings = []
    for i in range(embeddings.shape[0]): # looping over all batch elements
        non_masked_embeddings = embeddings[i][~padding_mask[i]] # shape [num_real_tokens, features]
        if len(non_masked_embeddings) == 0:
            print("You are masking all positions when creating sequence representation")
            sys.exit(1)
        mean_embedding = non_masked_embeddings.mean(dim=0) # sequence is represented by the single vecotr [1152] [features]
        seq_embeddings.append(mean_embedding)
    return torch.stack(seq_embeddings)

class MiniCLIP_w_transformer_crossattn(pl.LightningModule):

    def __init__(self, padding_value = -5000, embed_dimension=embedding_dimension, num_recycles=2):

        super().__init__()
        self.num_recycles = num_recycles # how many times you iteratively refine embeddings with self- and cross-attention (ALPHA-Fold-style recycling).
        self.padding_value = padding_value
        self.embed_dimension = embed_dimension

        self.logit_scale = nn.Parameter(torch.tensor(math.log(1/0.07)))  # ~CLIP init

        self.transformerencoder =  nn.TransformerEncoderLayer(
            d_model=self.embed_dimension,
            nhead=8,
            dropout=0.1,
            batch_first=True,
            dim_feedforward=self.embed_dimension
            )
 
        self.norm = nn.LayerNorm(self.embed_dimension)  # For residual additions

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=self.embed_dimension,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        self.prot_embedder = nn.Sequential(
            nn.Linear(self.embed_dimension, 640),
            nn.ReLU(),
            nn.Linear(640, 320),
        )
        
    def forward(self, pep_input, prot_input, label=None, pep_int_mask=None, prot_int_mask=None, int_prob=None, mem_save=True): # , pep_tokens, prot_tokens

        pep_mask = create_key_padding_mask(embeddings=pep_input, padding_value=self.padding_value)
        prot_mask = create_key_padding_mask(embeddings=prot_input, padding_value=self.padding_value)
 
        # Initialize residual states
        pep_emb = pep_input.clone()
        prot_emb = prot_input.clone()
 
        for _ in range(self.num_recycles):

            # Transformer encoding with residual
            pep_trans = self.transformerencoder(self.norm(pep_emb), src_key_padding_mask=pep_mask)
            prot_trans = self.transformerencoder(self.norm(prot_emb), src_key_padding_mask=prot_mask)

            # Cross-attention with residual
            pep_cross, _ = self.cross_attn(query=self.norm(pep_trans), key=self.norm(prot_trans), value=self.norm(prot_trans), key_padding_mask=prot_mask)
            prot_cross, _ = self.cross_attn(query=self.norm(prot_trans), key=self.norm(pep_trans), value=self.norm(pep_trans), key_padding_mask=pep_mask)
            
            # Additive update with residual connection
            pep_emb = pep_emb + pep_cross  
            prot_emb = prot_emb + prot_cross

        pep_seq_coding = create_mean_of_non_masked(pep_emb, pep_mask)
        prot_seq_coding = create_mean_of_non_masked(prot_emb, prot_mask)
        
        # Use self-attention outputs for embeddings
        pep_seq_coding = F.normalize(self.prot_embedder(pep_seq_coding), dim=-1)
        prot_seq_coding = F.normalize(self.prot_embedder(prot_seq_coding), dim=-1)
 
        if mem_save:
            torch.cuda.empty_cache()
        
        scale = torch.exp(self.logit_scale).clamp(max=100.0)
        logits = scale * (pep_seq_coding * prot_seq_coding).sum(dim=-1)
        
        return logits

    def training_step(self, batch, device):
        embedding_pep, embedding_prot, labels = batch
        embedding_pep, embedding_prot = embedding_pep.to(device), embedding_prot.to(device)
        
        positive_logits = self.forward(embedding_pep, embedding_prot)
        
        # Negative indexes
        rows, cols = torch.triu_indices(embedding_prot.size(0), embedding_prot.size(0), offset=1)         
        
        negative_logits = self(embedding_pep[rows,:,:], 
                          embedding_prot[cols,:,:], 
                          int_prob=0.0)

        # loss of predicting partner using peptide
        positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits).to(device))
 
        # loss of predicting peptide using partner
        negative_loss =  F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits).to(device))
        
        loss = (positive_loss + negative_loss) / 2
 
        # del partner_prediction_loss, peptide_prediction_loss, embedding_pep, embedding_prot
        torch.cuda.empty_cache()
        return loss

    def validation_step_PPint(self, batch, device):
        # Predict on random batches of training batch size
        embedding_pep, embedding_prot, labels = batch
        embedding_pep, embedding_prot = embedding_pep.to(device), embedding_prot.to(device)
        
        with torch.no_grad():

            positive_logits = self(embedding_pep, embedding_prot)
            
            # loss of predicting partner using peptide
            positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits).to(device))
            
            # Negaive indexes
            rows, cols = torch.triu_indices(embedding_prot.size(0), embedding_prot.size(0), offset=1)
            
            negative_logits = self(embedding_pep[rows,:,:], embedding_prot[cols,:,:], int_prob=0.0)
    
            negative_loss =  F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits).to(device))

            loss = (positive_loss + negative_loss) / 2
           
            logit_matrix = torch.zeros((embedding_pep.size(0),embedding_pep.size(0)),device=self.device)
            logit_matrix[rows, cols] = negative_logits
            logit_matrix[cols, rows] = negative_logits
            
            # Fill diagonal with positive scores
            diag_indices = torch.arange(embedding_pep.size(0), device=self.device)
            logit_matrix[diag_indices, diag_indices] = positive_logits.squeeze()

            labels = torch.arange(embedding_prot.size(0)).to(self.device)
            peptide_predictions = logit_matrix.argmax(dim=0)
            peptide_ranks = logit_matrix.argsort(dim=0).diag() + 1
            peptide_mrr = (peptide_ranks).float().pow(-1).mean()
            
            # partner_accuracy = partner_predictions.eq(labels).float().mean()
            peptide_accuracy = peptide_predictions.eq(labels).float().mean()
    
            # k = 3
            # peptide_topk_accuracy = torch.any((logit_matrix.topk(k, dim=0).indices - labels.reshape(1, -1)) == 0, dim=0).sum() / logit_matrix.shape[0]
    
            del logit_matrix,positive_logits,negative_logits,embedding_pep,embedding_prot

            return loss, peptide_accuracy#, peptide_topk_accuracy
    
    def validation_step_Boltzgen(self, batch, device):
        embedding_binder, embedding_target, labels = batch
        embedding_binder = embedding_binder.to(device)
        embedding_target = embedding_target.to(device)
        labels = labels.to(device).float()
    
        with torch.no_grad():
            logits = self.forward(embedding_binder, embedding_target)
            logits = logits.float()
            loss = F.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1))
            return logits, loss

    def validation_step_Bindcraft(self, batch, device):
        embedding_binder, embedding_target, labels = batch
        embedding_binder = embedding_binder.to(device)
        embedding_target = embedding_target.to(device)
        labels = labels.to(device).float()
    
        with torch.no_grad():
            logits = self.forward(embedding_binder, embedding_target)
            logits = logits.float()
            loss = F.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1))
            return logits, loss

    def calculate_logit_matrix(self,embedding_pep,embedding_prot):
        rows, cols = torch.triu_indices(embedding_pep.size(0), embedding_pep.size(0), offset=1)
        
        positive_logits = self(embedding_pep, embedding_prot)
        negative_logits = self(embedding_pep[rows,:,:], embedding_prot[cols,:,:], int_prob=0.0)
        
        logit_matrix = torch.zeros((embedding_pep.size(0),embedding_pep.size(0)),device=self.device)
        logit_matrix[rows, cols] = negative_logits
        logit_matrix[cols, rows] = negative_logits
        
        diag_indices = torch.arange(embedding_pep.size(0), device=self.device)
        logit_matrix[diag_indices, diag_indices] = positive_logits.squeeze()
        
        return logit_matrix

In [33]:
ckpt_path = '../PPI_PLM/models/CLIP_no_structural_information/a1d0549b-3f90-4ce2-b795-7bca2276cb07_checkpoint_4/a1d0549b-3f90-4ce2-b795-7bca2276cb07_checkpoint_epoch_4.pth'
checkpoint = torch.load(ckpt_path, weights_only=False, map_location="cpu")
# print(list(checkpoint["model_state_dict"]))
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = MiniCLIP_w_transformer_crossattn()
model.load_state_dict(checkpoint['model_state_dict'])
torch.cuda.empty_cache()  # frees cached blocks (not live tensors)
device = torch.device("cuda:0")
model.to(device)
# model.train()

MiniCLIP_w_transformer_crossattn(
  (transformerencoder): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=1152, out_features=1152, bias=True)
    )
    (linear1): Linear(in_features=1152, out_features=1152, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=1152, out_features=1152, bias=True)
    (norm1): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (norm): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
  (cross_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=1152, out_features=1152, bias=True)
  )
  (prot_embedder): Sequential(
    (0): Linear(in_features=1152, out_features=640, bias=True)
    (1): ReLU()
    (2): Linear(in_features=640, out_features=32

### Fine-tunning on meta_analysis dataset, validation on boltzgen

In [34]:
def batch(iterable, n=1):
    """Takes any indexable iterable (e.g., a list of observation IDs) and yields contiguous slices of length n."""
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]


class TrainWrapper:

    def __init__(self, 
                 model, 
                 train_loader, # meta
                 boltzgen_loader, # boltzgen
                 PPint_loader, # PPInt
                 bindcraft_loader, # bindcraft
                 test_df, # PPInt_val_df
                 PPint_dataset,     # validation_PPint
                 optimizer, 
                 epochs, 
                 runID, 
                 device, 
                 model_save_steps=False, 
                 model_save_path=False, 
                 v=False, 
                 wandb_tracker=False,
                 test_indexes_for_auROC=None,
                 auROC_batch_size=10):
        
        self.model = model 
        self.training_loader = train_loader
        self.boltzgen_loader = boltzgen_loader   # Boltzgen
        self.PPint_loader = PPint_loader       # PPInt
        self.bindcraft_loader = bindcraft_loader
        self.PPint_dataset = PPint_dataset
        self.test_df = test_df
        
        self.EPOCHS = epochs
        self.optimizer = optimizer
        self.device = device
        
        self.wandb_tracker = wandb_tracker
        self.model_save_steps = model_save_steps
        self.verbose = v
        self.best_vloss = 1_000_000
        self.runID = runID
        self.trained_model_dir = model_save_path
        self.print_frequency_loss = 1

        # for AUROC on specific indexes
        self.test_indexes_for_auROC = test_indexes_for_auROC
        self.auROC_batch_size = auROC_batch_size

    def train_one_epoch(self):

        self.model.train() 
        running_loss = 0.0

        for batch in tqdm(self.training_loader, total=len(self.training_loader), desc="Running through epoch"):
            
            if batch[0].size(0) == 1:
                continue
            
            self.optimizer.zero_grad()
            loss = self.model.training_step(batch, self.device)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()

            del loss, batch
            # torch.cuda.empty_cache()
            
        return running_loss / len(self.training_loader)

    def calc_auroc_aupr_on_indexes(self, model, dataset, dataframe, nondimer_indexes, batch_size=10):

        self.model.eval()
        all_TP_scores, all_FP_scores = [], []
        accessions = [dataframe.loc[index].target_binder_id for index in nondimer_indexes]
        batches_local = batch(accessions, n=batch_size)
        
        with torch.no_grad():
            for index_batch in tqdm(batches_local,
                                    total=int(len(accessions) / batch_size),
                                    desc="Calculating AUC"):

                binder_emb, target_emb, labels = dataset._get_by_name(index_batch)
                binder_emb, target_emb = binder_emb.to(self.device), target_emb.to(self.device)

                logit_matrix = self.model.calculate_logit_matrix(binder_emb, target_emb)
                
                TP_scores = logit_matrix.diag().detach().cpu().tolist()
                all_TP_scores += TP_scores
                
                # FP scores from upper triangle (excluding diagonal)
                n = logit_matrix.size(0)
                rows, cols = torch.triu_indices(n, n, offset=1)
                FP_scores = logit_matrix[rows, cols].detach().cpu().tolist()
                all_FP_scores += FP_scores
            
        all_score_predictions = np.array(all_TP_scores + all_FP_scores)
        all_labels = np.array([1] * len(all_TP_scores) + [0] * len(all_FP_scores))
                
        fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_score_predictions)
        auroc = metrics.roc_auc_score(all_labels, all_score_predictions)
        aupr  = metrics.average_precision_score(all_labels, all_score_predictions)
        
        return auroc, aupr, all_TP_scores, all_FP_scores

    def validate(self):
        
        self.model.eval()
        
        # --- Boltzgen validation ---
        running_loss_boltz = 0.0
        all_logits = []
        all_lbls = []
        used_batches_boltz = 0

        with torch.no_grad():
            for batch in tqdm(self.boltzgen_loader, total=len(self.boltzgen_loader),
                              desc="Boltzgen validation"):
                if batch[0].size(0) == 1:
                    continue

                embedding_binder, embedding_target, labels = batch
                logits, loss = self.model.validation_step_Boltzgen(batch, self.device)
                
                running_loss_boltz += loss.item()
                all_logits.append(logits.detach().view(-1).cpu())
                all_lbls.append(labels.detach().view(-1).cpu())
                used_batches_boltz += 1
                
            if used_batches_boltz > 0:
                val_loss_boltz = running_loss_boltz / used_batches_boltz
                all_logits = torch.cat(all_logits).numpy()
                all_lbls   = torch.cat(all_lbls).numpy()
            
                fpr, tpr, thresholds = metrics.roc_curve(all_lbls, all_logits)
                boltz_auroc = metrics.roc_auc_score(all_lbls, all_logits)
                boltz_aupr  = metrics.average_precision_score(all_lbls, all_logits)

                y_pred = (all_logits >= 0).astype(int)
                y_true = all_lbls.astype(int)
                val_acc_boltz = (y_pred == y_true).mean()
            else:
                val_loss_boltz = float("nan")
                boltz_auroc = float("nan")
                boltz_aupr = float("nan")
                val_acc_boltz = float("nan")

        # --- Bindcraft validation ---
        running_loss_bindcraft = 0.0
        all_logits = []
        all_lbls = []
        used_batches_bindcraft = 0

        with torch.no_grad():
            for batch in tqdm(self.bindcraft_loader, total=len(self.bindcraft_loader),
                              desc="Bindcraft validation"):
                if batch[0].size(0) == 1:
                    continue

                embedding_binder, embedding_target, labels = batch
                logits, loss = self.model.validation_step_Bindcraft(batch, self.device)
                
                running_loss_bindcraft += loss.item()
                all_logits.append(logits.detach().view(-1).cpu())
                all_lbls.append(labels.detach().view(-1).cpu())
                used_batches_bindcraft += 1
                
            if used_batches_bindcraft > 0:
                val_loss_bindcraft = running_loss_bindcraft / used_batches_bindcraft
                all_logits = torch.cat(all_logits).numpy()
                all_lbls   = torch.cat(all_lbls).numpy()
            
                fpr, tpr, thresholds = metrics.roc_curve(all_lbls, all_logits)
                bindcraft_auroc = metrics.roc_auc_score(all_lbls, all_logits)
                bindcraft_aupr  = metrics.average_precision_score(all_lbls, all_logits)

                y_pred = (all_logits >= 0).astype(int)
                y_true = all_lbls.astype(int)
                val_acc_bindcraft = (y_pred == y_true).mean()
            else:
                val_loss_bindcraft = float("nan")
                bindcraft_auroc = float("nan")
                bindcraft_aupr = float("nan")
                val_acc_bindcraft = float("nan")

        # --- PPInt validation ---
        running_loss_PPint = 0.0
        running_accuracy_PPint = 0.0
        # running_topk_accuracy_PPint = 0.0
        used_batches_ppint = 0

        with torch.no_grad():
            for batch in tqdm(self.PPint_loader, total=len(self.PPint_loader),
                              desc="PPInt validation"):
                if batch[0].size(0) == 1:
                    continue
                # loss, partner_accuracy, peptide_topk_accuracy = self.model.validation_step_PPint(batch, self.device)
                loss, partner_accuracy = self.model.validation_step_PPint(batch, self.device)
                running_loss_PPint += loss.item()
                running_accuracy_PPint += partner_accuracy.item()
                # running_topk_accuracy_PPint += peptide_topk_accuracy.item()
                used_batches_ppint += 1
                
            if used_batches_ppint > 0:
                val_loss_PPint = running_loss_PPint / used_batches_ppint
                val_accuracy_PPint = running_accuracy_PPint / used_batches_ppint
                # val_topk_accuracy_PPint = running_topk_accuracy_PPint / used_batches_ppint
            else:
                val_loss_PPint = float("nan")
                val_accuracy_PPint = float("nan")
                # val_topk_accuracy_PPint = float("nan")

        # --- AUROC on specific indexes (optional) ---
        if self.test_indexes_for_auROC is not None:
            non_dimer_auc, non_dimer_aupr, ___, ___ = self.calc_auroc_aupr_on_indexes(
                model=self.model, 
                dataset=self.PPint_dataset,
                dataframe=self.test_df,
                nondimer_indexes=self.test_indexes_for_auROC,
                batch_size=self.auROC_batch_size
            )
        else:
            non_dimer_auc = float("nan")
            non_dimer_aupr = float("nan")
            
        # Always return the same structure
        return (val_loss_PPint, val_accuracy_PPint, #val_topk_accuracy_PPint,
                non_dimer_auc, non_dimer_aupr,
                val_loss_boltz, val_acc_boltz, boltz_auroc, boltz_aupr,
                val_loss_bindcraft, val_acc_bindcraft, bindcraft_auroc, bindcraft_aupr)

    def train_model(self):

        if self.verbose:
            print(f"Training model {str(self.runID)}")
        
        # Pre-training snapshot
        (val_loss_PPint, val_accuracy_PPint, #val_topk_accuracy_PPint,
        non_dimer_auc, non_dimer_aupr,
        val_loss_boltz, val_acc_boltz, boltz_auroc, boltz_aupr,
        val_loss_bindcraft, val_acc_bindcraft, bindcraft_auroc, bindcraft_aupr) = self.validate()

        if self.verbose:
            print("Before training")
            print(f"Boltzgen Val Loss      {val_loss_boltz:.4f}")
            print(f"Boltzgen Val AUROC     {boltz_auroc if boltz_auroc == boltz_auroc else float('nan'):.4f}")
            print(f"Boltzgen Val AUPR      {boltz_aupr  if boltz_aupr  == boltz_aupr  else float('nan'):.4f}")
            
            print(f"Bindcraft Val Loss      {val_loss_bindcraft:.4f}")
            print(f"Bindcraft Val AUROC     {bindcraft_auroc if bindcraft_auroc == bindcraft_auroc else float('nan'):.4f}")
            print(f"Bindcraft Val AUPR      {bindcraft_aupr  if bindcraft_aupr  == bindcraft_aupr  else float('nan'):.4f}")
            
            print(f"PPInt Val Loss       {val_loss_PPint:.4f}")
            print(f"PPInt Val Accuracy   {val_accuracy_PPint if val_accuracy_PPint == val_accuracy_PPint else float('nan'):.4f}")
            # print(f"PPInt   Val Top-k Acc  {val_topk_accuracy_PPint if val_topk_accuracy_PPint == val_topk_accuracy_PPint else float('nan'):.4f}")
            
            print(f"Non-dimer AUROC        {non_dimer_auc if non_dimer_auc == non_dimer_auc else float('nan'):.4f}")
            print(f"Non-dimer AUPR         {non_dimer_aupr if non_dimer_aupr == non_dimer_aupr else float('nan'):.4f}")

        if self.wandb_tracker:
            log_items = {
                "Boltzgen Val Loss": val_loss_boltz,
                "Boltzgen Val AUROC": boltz_auroc,
                "Boltzgen Val AUPR": boltz_aupr,

                "Bindcraft Val Loss": val_loss_bindcraft,
                "Bindcraft Val AUROC": bindcraft_auroc,
                "Bindcraft Val AUPR": bindcraft_aupr,
                
                "PPInt Val Loss": val_loss_PPint,
                "PPInt Val Accuracy": val_accuracy_PPint,
                # "PPInt Val Topk Accuracy": val_topk_accuracy_PPint,
                
                "Non-dimer AUROC": non_dimer_auc,
                "Non-dimer AUPR": non_dimer_aupr,
            }
            self.wandb_tracker.log(log_items)
            
        # --- training loop
        for epoch in tqdm(range(1, self.EPOCHS + 1), total=self.EPOCHS, desc="Epochs"):
            
            torch.cuda.empty_cache()
            
            train_loss = self.train_one_epoch()
            
            # validation after epoch
            (val_loss_PPint, val_accuracy_PPint, #val_topk_accuracy_PPint,
            non_dimer_auc, non_dimer_aupr,
            val_loss_boltz, val_acc_boltz, boltz_auroc, boltz_aupr,
            val_loss_bindcraft, val_acc_bindcraft, bindcraft_auroc, bindcraft_aupr) = self.validate()
            
            torch.cuda.empty_cache()
            
            # checkpoint save
            # if self.model_save_steps and epoch % self.model_save_steps == 0:
            #     check_point_folder = os.path.join(self.trained_model_dir,
            #                                       f"{str(self.runID)}_checkpoint_{str(epoch)}")
            #     if self.verbose:
            #         print("Saving model to:", check_point_folder)
            #     os.makedirs(check_point_folder, exist_ok=True)
            #     checkpoint_path = os.path.join(check_point_folder,
            #                                    f"{str(self.runID)}_checkpoint_epoch_{str(epoch)}.pth")
            #     torch.save({
            #         'epoch': epoch, 
            #         'model_state_dict': self.model.state_dict(),
            #         'optimizer_state_dict': self.optimizer.state_dict(), 
            #         'val_loss_PPInt': val_loss_PPint,
            #         'val_loss_boltz': val_loss_boltz,
            #         'boltzgen_auroc': boltz_auroc, 
            #         'boltzgen_aupr': boltz_aupr,
            #         'PPInt_val_accuracy': val_accuracy_PPint,
            #         # 'PPInt_val_topk_accuracy': val_topk_accuracy_PPint,
            #         'non_dimer_auc': non_dimer_auc,
            #         'non_dimer_aupr': non_dimer_aupr,
            #     }, checkpoint_path)
            
            # console logging
            if self.verbose and epoch % self.print_frequency_loss == 0:
                print(f"[Epoch {epoch}]")
                print(f"Train Loss             {train_loss:.4f}")
                print(f"Boltzgen Val Loss      {val_loss_boltz:.4f}")
                print(f"Boltzgen Val AUROC     {boltz_auroc if boltz_auroc == boltz_auroc else float('nan'):.4f}")
                print(f"Boltzgen Val AUPR      {boltz_aupr  if boltz_aupr  == boltz_aupr  else float('nan'):.4f}")
                
                print(f"Bindcraft Val Loss      {val_loss_bindcraft:.4f}")
                print(f"Bindcraft Val AUROC     {bindcraft_auroc if bindcraft_auroc == bindcraft_auroc else float('nan'):.4f}")
                print(f"Bindcraft Val AUPR      {bindcraft_aupr  if bindcraft_aupr  == bindcraft_aupr  else float('nan'):.4f}")
                
                print(f"PPInt Val Loss       {val_loss_PPint:.4f}")
                print(f"PPInt Val Accuracy   {val_accuracy_PPint if val_accuracy_PPint == val_accuracy_PPint else float('nan'):.4f}")
                # print(f"PPInt   Val Top-k Acc  {val_topk_accuracy_PPint if val_topk_accuracy_PPint == val_topk_accuracy_PPint else float('nan'):.4f}")
                
                print(f"Non-dimer AUROC        {non_dimer_auc if non_dimer_auc == non_dimer_auc else float('nan'):.4f}")
                print(f"Non-dimer AUPR         {non_dimer_aupr if non_dimer_aupr == non_dimer_aupr else float('nan'):.4f}")
            
            # wandb logging
            if self.wandb_tracker:
                log_items = {
                    "Meta Train-loss": train_loss,
                    "Boltzgen Val Loss": val_loss_boltz,
                    "Boltzgen Val AUROC": boltz_auroc,
                    "Boltzgen Val AUPR": boltz_aupr,
    
                    "Bindcraft Val Loss": val_loss_bindcraft,
                    "Bindcraft Val AUROC": bindcraft_auroc,
                    "Bindcraft Val AUPR": bindcraft_aupr,
                    
                    "PPInt Val Loss": val_loss_PPint,
                    "PPInt Val Accuracy": val_accuracy_PPint,
                    # "PPInt Val Topk Accuracy": val_topk_accuracy_PPint,
                    
                    "Non-dimer AUROC": non_dimer_auc,
                    "Non-dimer AUPR": non_dimer_aupr,
                }
                self.wandb_tracker.log(log_items)

        if self.wandb_tracker:
            self.wandb_tracker.finish()

In [35]:
# weights = torch.tensor(list(interaction_df_shuffled.combined_weight), dtype=torch.float) # class + target weighting
# weights = torch.tensor(list(interaction_df_shuffled.class_weight), dtype=torch.float) # class weighting
weights = torch.tensor(list(interaction_df_shuffled.target_weight), dtype=torch.float) # target weighting
# weights = torch.tensor(list(interaction_df_shuffled.combined_weight_boost_pos), dtype=torch.float) # extra boost of positives
# weights = torch.tensor(list(interaction_df_shuffled.target_weight_FGFR2_reduced), dtype=torch.float) # reducing influence of FGFR2

g = torch.Generator().manual_seed(SEED)

finetune_sampler = WeightedRandomSampler(weights=weights, num_samples=len(weights), replacement=True, generator=g)

In [36]:
targets = [interaction_df_shuffled.loc[i].target_id for i in finetune_sampler]
target_Dict = {}
for target_name in targets:
    if target_name not in target_Dict.keys():
        target_Dict[target_name] = 1
    else:
        target_Dict[target_name] = target_Dict[target_name] + 1

target_Dict

{'EGFR_2': 197,
 'LTK': 228,
 'IL2Ra': 216,
 'InsulinR': 226,
 'EGFR_3': 243,
 'FGFR2': 209,
 'sntx_2': 216,
 'TrkA': 235,
 'sntx': 210,
 'Pdl1': 218,
 'IL7Ra': 199,
 'VirB8': 233,
 'EGFR': 245,
 'IL10Ra': 234,
 'SARS_CoV2_RBD': 214,
 'Mdm2': 209}

In [37]:
targets = [interaction_df_shuffled.loc[i].binder_label for i in finetune_sampler]
class_Dict = {}
for class_ in targets:
    if class_ not in class_Dict.keys():
        class_Dict[class_] = 1
    else:
        class_Dict[class_] = class_Dict[class_] + 1

class_Dict

{False: 2641, True: 891}

In [38]:
learning_rate = 2e-5
EPOCHS = 12
g = torch.Generator().manual_seed(SEED)
batch_size = 10
optimizer = AdamW(model.parameters(), lr=learning_rate)
accelerator = Accelerator()
device = accelerator.device

# finetune_dataloader = DataLoader(finetuning_Dataset, batch_size=10, shuffle=True, drop_last = True) # no resampling
finetune_dataloader = DataLoader(finetuning_Dataset, batch_size=7, sampler=finetune_sampler, shuffle=False, drop_last = False) #class or target
boltzgen_dataloader = DataLoader(validation_Boltzgen, batch_size=10, shuffle=False, drop_last = False)
bindcraft_dataloader = DataLoader(validation_bindcraft, batch_size=10, shuffle=False, drop_last = False)
PPint_dataloader = DataLoader(validation_PPint, batch_size=10, shuffle=False, drop_last = False)

# accelerator
model, optimizer, finetune_dataloader, boltzgen_dataloader, bindcraft_dataloader, PPint_dataloader = accelerator.prepare(model, optimizer, finetune_dataloader, boltzgen_dataloader, bindcraft_dataloader, PPint_dataloader)

In [39]:
for i in finetune_dataloader:
    __, __, lbls = i
    print(lbls.to(device))
    break

tensor([1, 0, 1, 0, 0, 0, 0], device='cuda:0')


In [40]:
# wandb
if use_wandb:
    run = wandb.init(
        project="Finetuning_on_Meta02",
        name=f"target_weighting_{runID}",
        config={
            "learning_rate": learning_rate,
            "batch_size": batch_size,
            "epochs": EPOCHS,
            "architecture": "MiniCLIP_w_transformer_crossattn",
            "dataset": "PPint",
        },
    )
    wandb.watch(accelerator.unwrap_model(model), log="all", log_freq=100)
else:
    run = None

# train
training_wrapper = TrainWrapper(
    model=model,
    train_loader=finetune_dataloader,
    boltzgen_loader = boltzgen_dataloader,
    bindcraft_loader = bindcraft_dataloader,
    PPint_loader=PPint_dataloader,
    test_df=Df_test,
    PPint_dataset = validation_PPint,
    optimizer=optimizer,
    epochs=EPOCHS,
    runID=runID,
    device=device,
    test_indexes_for_auROC=indices_non_dimers_val,
    auROC_batch_size=10,
    model_save_steps=model_save_steps,
    model_save_path=trained_model_dir,
    v=True,
    wandb_tracker=wandb if use_wandb else None,
)

training_wrapper.train_model()

Training model e9af7af9-505b-446f-aca8-4e9d82ded006


Boltzgen validation: 100%|██████████████████████████████████████████████████████████████| 43/43 [00:04<00:00,  9.03it/s]
Bindcraft validation: 100%|█████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 17.84it/s]
PPInt validation: 100%|█████████████████████████████████████████████████████████████████| 50/50 [00:18<00:00,  2.74it/s]
Calculating AUC: 13it [00:04,  3.12it/s]                                                                                


Before training
Boltzgen Val Loss      0.8226
Boltzgen Val AUROC     0.6001
Boltzgen Val AUPR      0.3194
Bindcraft Val Loss      1.1844
Bindcraft Val AUROC     0.5094
Bindcraft Val AUPR      0.4988
PPInt Val Loss       0.3069
PPInt Val Accuracy   0.8880
Non-dimer AUROC        0.8934
Non-dimer AUPR         0.6688


Epochs:   0%|                                                                                    | 0/12 [00:00<?, ?it/s]
Running through epoch:   0%|                                                                    | 0/505 [00:00<?, ?it/s][A
Running through epoch:   0%|                                                            | 1/505 [00:00<04:28,  1.88it/s][A
Running through epoch:   0%|▏                                                           | 2/505 [00:01<04:34,  1.83it/s][A
Running through epoch:   1%|▎                                                           | 3/505 [00:01<04:36,  1.81it/s][A
Running through epoch:   1%|▍                                                           | 4/505 [00:02<04:36,  1.81it/s][A
Running through epoch:   1%|▌                                                           | 5/505 [00:02<04:35,  1.81it/s][A
Running through epoch:   1%|▋                                                           | 6/505 [00:03<04:36,  1.80it/s][A
Running thr

[Epoch 1]
Train Loss             0.4855
Boltzgen Val Loss      0.8539
Boltzgen Val AUROC     0.4694
Boltzgen Val AUPR      0.2206
Bindcraft Val Loss      1.0090
Bindcraft Val AUROC     0.6964
Bindcraft Val AUPR      0.6535
PPInt Val Loss       0.3962
PPInt Val Accuracy   0.8620
Non-dimer AUROC        0.8484
Non-dimer AUPR         0.5802



Running through epoch:   0%|                                                                    | 0/505 [00:00<?, ?it/s][A
Running through epoch:   0%|                                                            | 1/505 [00:00<04:31,  1.85it/s][A
Running through epoch:   0%|▏                                                           | 2/505 [00:01<04:41,  1.79it/s][A
Running through epoch:   1%|▎                                                           | 3/505 [00:01<04:39,  1.80it/s][A
Running through epoch:   1%|▍                                                           | 4/505 [00:02<04:38,  1.80it/s][A
Running through epoch:   1%|▌                                                           | 5/505 [00:02<04:37,  1.80it/s][A
Running through epoch:   1%|▋                                                           | 6/505 [00:03<04:37,  1.80it/s][A
Running through epoch:   1%|▊                                                           | 7/505 [00:03<04:37,  1.80it/s][A
Running

[Epoch 2]
Train Loss             0.3446
Boltzgen Val Loss      0.8795
Boltzgen Val AUROC     0.4880
Boltzgen Val AUPR      0.2404
Bindcraft Val Loss      1.0891
Bindcraft Val AUROC     0.7172
Bindcraft Val AUPR      0.6761
PPInt Val Loss       0.3822
PPInt Val Accuracy   0.8660
Non-dimer AUROC        0.8110
Non-dimer AUPR         0.5326



Running through epoch:   0%|                                                                    | 0/505 [00:00<?, ?it/s][A
Running through epoch:   0%|                                                            | 1/505 [00:00<04:30,  1.86it/s][A
Running through epoch:   0%|▏                                                           | 2/505 [00:01<04:38,  1.81it/s][A
Running through epoch:   1%|▎                                                           | 3/505 [00:01<04:37,  1.81it/s][A
Running through epoch:   1%|▍                                                           | 4/505 [00:02<04:37,  1.80it/s][A
Running through epoch:   1%|▌                                                           | 5/505 [00:02<04:36,  1.81it/s][A
Running through epoch:   1%|▋                                                           | 6/505 [00:03<04:36,  1.80it/s][A
Running through epoch:   1%|▊                                                           | 7/505 [00:03<04:35,  1.81it/s][A
Running

[Epoch 3]
Train Loss             0.2981
Boltzgen Val Loss      0.8858
Boltzgen Val AUROC     0.4956
Boltzgen Val AUPR      0.2538
Bindcraft Val Loss      1.3882
Bindcraft Val AUROC     0.7420
Bindcraft Val AUPR      0.7044
PPInt Val Loss       0.4486
PPInt Val Accuracy   0.8580
Non-dimer AUROC        0.7965
Non-dimer AUPR         0.5042



Running through epoch:   0%|                                                                    | 0/505 [00:00<?, ?it/s][A
Running through epoch:   0%|                                                            | 1/505 [00:00<04:29,  1.87it/s][A
Running through epoch:   0%|▏                                                           | 2/505 [00:01<04:35,  1.82it/s][A
Running through epoch:   1%|▎                                                           | 3/505 [00:01<04:37,  1.81it/s][A
Running through epoch:   1%|▍                                                           | 4/505 [00:02<04:37,  1.80it/s][A
Running through epoch:   1%|▌                                                           | 5/505 [00:02<04:37,  1.80it/s][A
Running through epoch:   1%|▋                                                           | 6/505 [00:03<04:36,  1.80it/s][A
Running through epoch:   1%|▊                                                           | 7/505 [00:03<04:36,  1.80it/s][A
Running

[Epoch 4]
Train Loss             0.2572
Boltzgen Val Loss      0.8595
Boltzgen Val AUROC     0.5178
Boltzgen Val AUPR      0.2852
Bindcraft Val Loss      1.3648
Bindcraft Val AUROC     0.7041
Bindcraft Val AUPR      0.6621
PPInt Val Loss       0.5338
PPInt Val Accuracy   0.8560
Non-dimer AUROC        0.7822
Non-dimer AUPR         0.4848



Running through epoch:   0%|                                                                    | 0/505 [00:00<?, ?it/s][A
Running through epoch:   0%|                                                            | 1/505 [00:00<04:31,  1.86it/s][A
Running through epoch:   0%|▏                                                           | 2/505 [00:01<04:37,  1.81it/s][A
Running through epoch:   1%|▎                                                           | 3/505 [00:01<04:38,  1.80it/s][A
Running through epoch:   1%|▍                                                           | 4/505 [00:02<04:37,  1.81it/s][A
Running through epoch:   1%|▌                                                           | 5/505 [00:02<04:37,  1.80it/s][A
Running through epoch:   1%|▋                                                           | 6/505 [00:03<04:37,  1.80it/s][A
Running through epoch:   1%|▊                                                           | 7/505 [00:03<04:36,  1.80it/s][A
Running

[Epoch 5]
Train Loss             0.2439
Boltzgen Val Loss      0.8985
Boltzgen Val AUROC     0.5013
Boltzgen Val AUPR      0.2464
Bindcraft Val Loss      1.3849
Bindcraft Val AUROC     0.7022
Bindcraft Val AUPR      0.6484
PPInt Val Loss       0.4271
PPInt Val Accuracy   0.8500
Non-dimer AUROC        0.7962
Non-dimer AUPR         0.5137



Running through epoch:   0%|                                                                    | 0/505 [00:00<?, ?it/s][A
Running through epoch:   0%|                                                            | 1/505 [00:00<04:32,  1.85it/s][A
Running through epoch:   0%|▏                                                           | 2/505 [00:01<04:36,  1.82it/s][A
Running through epoch:   1%|▎                                                           | 3/505 [00:01<04:37,  1.81it/s][A
Running through epoch:   1%|▍                                                           | 4/505 [00:02<04:37,  1.80it/s][A
Running through epoch:   1%|▌                                                           | 5/505 [00:02<04:37,  1.80it/s][A
Running through epoch:   1%|▋                                                           | 6/505 [00:03<04:36,  1.81it/s][A
Running through epoch:   1%|▊                                                           | 7/505 [00:03<04:36,  1.80it/s][A
Running

[Epoch 6]
Train Loss             0.2257
Boltzgen Val Loss      0.9476
Boltzgen Val AUROC     0.5186
Boltzgen Val AUPR      0.2817
Bindcraft Val Loss      1.5948
Bindcraft Val AUROC     0.7285
Bindcraft Val AUPR      0.7006
PPInt Val Loss       0.4213
PPInt Val Accuracy   0.8420
Non-dimer AUROC        0.7756
Non-dimer AUPR         0.5024



Running through epoch:   0%|                                                                    | 0/505 [00:00<?, ?it/s][A
Running through epoch:   0%|                                                            | 1/505 [00:00<04:38,  1.81it/s][A
Running through epoch:   0%|▏                                                           | 2/505 [00:01<04:41,  1.78it/s][A
Running through epoch:   1%|▎                                                           | 3/505 [00:01<04:39,  1.79it/s][A
Running through epoch:   1%|▍                                                           | 4/505 [00:02<04:40,  1.79it/s][A
Running through epoch:   1%|▌                                                           | 5/505 [00:02<04:39,  1.79it/s][A
Running through epoch:   1%|▋                                                           | 6/505 [00:03<04:39,  1.79it/s][A
Running through epoch:   1%|▊                                                           | 7/505 [00:03<04:37,  1.80it/s][A
Running

[Epoch 7]
Train Loss             0.2099
Boltzgen Val Loss      1.0819
Boltzgen Val AUROC     0.4638
Boltzgen Val AUPR      0.2406
Bindcraft Val Loss      1.3477
Bindcraft Val AUROC     0.7300
Bindcraft Val AUPR      0.7045
PPInt Val Loss       0.4900
PPInt Val Accuracy   0.8360
Non-dimer AUROC        0.7926
Non-dimer AUPR         0.5117



Running through epoch:   0%|                                                                    | 0/505 [00:00<?, ?it/s][A
Running through epoch:   0%|                                                            | 1/505 [00:00<04:31,  1.86it/s][A
Running through epoch:   0%|▏                                                           | 2/505 [00:01<04:35,  1.83it/s][A
Running through epoch:   1%|▎                                                           | 3/505 [00:01<04:37,  1.81it/s][A
Running through epoch:   1%|▍                                                           | 4/505 [00:02<04:36,  1.81it/s][A
Running through epoch:   1%|▌                                                           | 5/505 [00:02<04:36,  1.81it/s][A
Running through epoch:   1%|▋                                                           | 6/505 [00:03<04:36,  1.80it/s][A
Running through epoch:   1%|▊                                                           | 7/505 [00:03<04:35,  1.80it/s][A
Running

[Epoch 8]
Train Loss             0.2159
Boltzgen Val Loss      1.0013
Boltzgen Val AUROC     0.5025
Boltzgen Val AUPR      0.2600
Bindcraft Val Loss      1.7025
Bindcraft Val AUROC     0.6887
Bindcraft Val AUPR      0.6457
PPInt Val Loss       0.4110
PPInt Val Accuracy   0.8540
Non-dimer AUROC        0.7761
Non-dimer AUPR         0.5013



Running through epoch:   0%|                                                                    | 0/505 [00:00<?, ?it/s][A
Running through epoch:   0%|                                                            | 1/505 [00:00<04:30,  1.86it/s][A
Running through epoch:   0%|▏                                                           | 2/505 [00:01<04:37,  1.81it/s][A
Running through epoch:   1%|▎                                                           | 3/505 [00:01<04:40,  1.79it/s][A
Running through epoch:   1%|▍                                                           | 4/505 [00:02<04:39,  1.80it/s][A
Running through epoch:   1%|▌                                                           | 5/505 [00:02<04:37,  1.80it/s][A
Running through epoch:   1%|▋                                                           | 6/505 [00:03<04:38,  1.79it/s][A
Running through epoch:   1%|▊                                                           | 7/505 [00:03<04:38,  1.79it/s][A
Running

[Epoch 9]
Train Loss             0.1912
Boltzgen Val Loss      0.8915
Boltzgen Val AUROC     0.5261
Boltzgen Val AUPR      0.3098
Bindcraft Val Loss      1.7189
Bindcraft Val AUROC     0.7358
Bindcraft Val AUPR      0.6869
PPInt Val Loss       0.5328
PPInt Val Accuracy   0.8640
Non-dimer AUROC        0.7819
Non-dimer AUPR         0.4922



Running through epoch:   0%|                                                                    | 0/505 [00:00<?, ?it/s][A
Running through epoch:   0%|                                                            | 1/505 [00:00<04:26,  1.89it/s][A
Running through epoch:   0%|▏                                                           | 2/505 [00:01<04:33,  1.84it/s][A
Running through epoch:   1%|▎                                                           | 3/505 [00:01<04:35,  1.82it/s][A
Running through epoch:   1%|▍                                                           | 4/505 [00:02<04:38,  1.80it/s][A
Running through epoch:   1%|▌                                                           | 5/505 [00:02<04:37,  1.80it/s][A
Running through epoch:   1%|▋                                                           | 6/505 [00:03<04:36,  1.81it/s][A
Running through epoch:   1%|▊                                                           | 7/505 [00:03<04:35,  1.81it/s][A
Running

[Epoch 10]
Train Loss             0.1908
Boltzgen Val Loss      0.8882
Boltzgen Val AUROC     0.5157
Boltzgen Val AUPR      0.3002
Bindcraft Val Loss      1.6591
Bindcraft Val AUROC     0.7176
Bindcraft Val AUPR      0.6941
PPInt Val Loss       0.4905
PPInt Val Accuracy   0.8500
Non-dimer AUROC        0.7742
Non-dimer AUPR         0.4897



Running through epoch:   0%|                                                                    | 0/505 [00:00<?, ?it/s][A
Running through epoch:   0%|                                                            | 1/505 [00:00<04:32,  1.85it/s][A
Running through epoch:   0%|▏                                                           | 2/505 [00:01<04:36,  1.82it/s][A
Running through epoch:   1%|▎                                                           | 3/505 [00:01<04:35,  1.82it/s][A
Running through epoch:   1%|▍                                                           | 4/505 [00:02<04:35,  1.82it/s][A
Running through epoch:   1%|▌                                                           | 5/505 [00:02<04:35,  1.81it/s][A
Running through epoch:   1%|▋                                                           | 6/505 [00:03<04:34,  1.81it/s][A
Running through epoch:   1%|▊                                                           | 7/505 [00:03<04:35,  1.80it/s][A
Running

[Epoch 11]
Train Loss             0.1831
Boltzgen Val Loss      1.0504
Boltzgen Val AUROC     0.4992
Boltzgen Val AUPR      0.2591
Bindcraft Val Loss      1.7423
Bindcraft Val AUROC     0.7393
Bindcraft Val AUPR      0.7120
PPInt Val Loss       0.5031
PPInt Val Accuracy   0.8440
Non-dimer AUROC        0.7762
Non-dimer AUPR         0.4957



Running through epoch:   0%|                                                                    | 0/505 [00:00<?, ?it/s][A
Running through epoch:   0%|                                                            | 1/505 [00:00<04:31,  1.85it/s][A
Running through epoch:   0%|▏                                                           | 2/505 [00:01<04:35,  1.83it/s][A
Running through epoch:   1%|▎                                                           | 3/505 [00:01<04:34,  1.83it/s][A
Running through epoch:   1%|▍                                                           | 4/505 [00:02<04:35,  1.82it/s][A
Running through epoch:   1%|▌                                                           | 5/505 [00:02<04:35,  1.82it/s][A
Running through epoch:   1%|▋                                                           | 6/505 [00:03<04:34,  1.82it/s][A
Running through epoch:   1%|▊                                                           | 7/505 [00:03<04:34,  1.81it/s][A
Running

[Epoch 12]
Train Loss             0.1824
Boltzgen Val Loss      0.9493
Boltzgen Val AUROC     0.5375
Boltzgen Val AUPR      0.2773
Bindcraft Val Loss      1.4292
Bindcraft Val AUROC     0.7265
Bindcraft Val AUPR      0.6913
PPInt Val Loss       0.4530
PPInt Val Accuracy   0.8480
Non-dimer AUROC        0.7764
Non-dimer AUPR         0.4943





0,1
Bindcraft Val AUPR,▁▆▇█▆▆██▆▇▇█▇
Bindcraft Val AUROC,▁▇▇█▇▇██▆█▇██
Bindcraft Val Loss,▃▁▂▅▄▅▇▄██▇█▅
Boltzgen Val AUPR,█▁▂▃▆▃▅▂▄▇▇▄▅
Boltzgen Val AUROC,█▁▂▃▄▃▄▁▃▄▄▃▅
Boltzgen Val Loss,▁▂▃▃▂▃▄█▆▃▃▇▄
Meta Train-loss,█▅▄▃▂▂▂▂▁▁▁▁
Non-dimer AUPR,█▅▃▂▁▂▂▂▂▁▁▁▁
Non-dimer AUROC,█▅▃▂▁▂▁▂▁▁▁▁▁
PPInt Val Accuracy,█▄▅▄▄▃▂▁▃▅▃▂▃

0,1
Bindcraft Val AUPR,0.69133
Bindcraft Val AUROC,0.72651
Bindcraft Val Loss,1.42925
Boltzgen Val AUPR,0.27733
Boltzgen Val AUROC,0.53745
Boltzgen Val Loss,0.94932
Meta Train-loss,0.1824
Non-dimer AUPR,0.49426
Non-dimer AUROC,0.7764
PPInt Val Accuracy,0.848
