In [1]:
import uuid, sys, os
import pandas as pd
import numpy as np
from tqdm import tqdm
import ast
import math
import random
import matplotlib.pyplot as plt
from tqdm import trange
from pathlib import Path
from PIL import Image

from sklearn import metrics
from scipy import stats
from collections import Counter

# These environment variables are set before torch is imported so that they are
# already in place when CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # expose only physical GPU index 1 to this process
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"  # CUDA caching-allocator option
import torch
torch.cuda.set_device(0)  # 0 == "first visible" device -> physical GPU index 1 (the second GPU on the node)
print(torch.cuda.get_device_name(0))

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, WeightedRandomSampler
import pytorch_lightning as pl
from torch.optim import AdamW
import training_utils.partitioning_utils as pat_utils
# from rotary_embedding_torch import RotaryEmbedding
# from some_utils.RoPE_for_ViT import RoPEAttention
# from some_utils.ALiBi2D import ALiBi2DTransformerLayer

torch.manual_seed(0)

from accelerate import Accelerator
torch.cuda.empty_cache()

Tesla V100-SXM2-32GB


  warn(
  _torch_pytree._register_pytree_node(


In [2]:
import wandb
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33ms232958[0m ([33ms232958-danmarks-tekniske-universitet-dtu[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
### Setting a seed to get the same initialisation of weights on every run

def set_seed(seed: int = 0):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU, current CUDA
    device and all CUDA devices), and puts cuDNN into deterministic,
    non-benchmarking mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    # Exported as an environment variable.
    # NOTE(review): this presumably only influences processes started afterwards,
    # not the hashing of the already-running interpreter — confirm.
    os.environ["PYTHONHASHSEED"] = str(seed)

    # cuDNN: prefer reproducible kernels over auto-tuned ones.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Python, NumPy, then PyTorch (CPU, current GPU, every GPU).
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)


SEED = 0
set_seed(SEED)

In [3]:
# NOTE(review): hard-coded, user-specific absolute path — every later relative
# path/import is resolved against this directory, so the notebook only runs
# where this directory exists.
os.chdir("/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts")
# print(os.getcwd())

print("PyTorch:", torch.__version__)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print("Current location:", os.getcwd())

PyTorch: 2.9.1+cu128
Using device: cuda
Current location: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts


In [4]:
# Model parameters
# NOTE(review): the consumers of memory_verbose, model_save_steps, train_frac,
# test_frac and number_of_recycles are not visible in this part of the notebook;
# their meaning should be confirmed against the training code.
memory_verbose = False
use_wandb = True # Used to track loss in real-time without printing
model_save_steps = 3
train_frac = 1.0
test_frac = 1.0

# Same value as the `embedding_dim=1280` passed to the dataset classes below.
seq_embed_dimension = 1280 #| 960 | 1152
number_of_recycles = 2
# Same value as the datasets' default `embedding_pad_value` and the default
# `padding_value` of create_key_padding_mask.
padding_value = -5000

In [5]:
# ## Training variables
# Random UUID, regenerated every time this cell is executed, so each run gets
# its own output directory name.
runID = uuid.uuid4()

## Output path
# Only the path string is built here; the directory itself is not created in this cell.
trained_model_dir = f"/work3/s232958/data/trained/with_structure/{runID}"

def print_mem_consumption(device=0):
    """Print a short GPU memory summary; every value is reported as bytes / 1e9 (GB).

    Args:
        device: CUDA device (index) to inspect. Defaults to 0, the device this
            notebook selects at start-up.

    Prints four lines:
        Total memory     - total memory of the device.
        Reserved memory  - memory PyTorch has reserved from CUDA.
        Allocated memory - part of the reserved memory currently used by tensors.
        Free memory      - reserved minus allocated, i.e. slack inside PyTorch's pool
                           (NOT the memory still free on the device).
    """
    total = torch.cuda.get_device_properties(device).total_memory
    reserved = torch.cuda.memory_reserved(device)
    allocated = torch.cuda.memory_allocated(device)
    free_in_pool = reserved - allocated

    print("Total memory: ", total/1e9)
    print("Reserved memory: ", reserved/1e9)
    # True division like the other lines: floor division reported 0.0 for any
    # allocation smaller than 1 GB.
    print("Allocated memory: ", allocated/1e9)
    print("Free memory: ", free_in_pool/1e9)
print_mem_consumption()

Total memory:  34.072559616
Reserved memory:  0.0
Allocated memory:  0.0
Free memory:  0.0


## Loading PPint data

In [6]:
def _add_interface_id(frame):
    """Return a new frame with an `interface_id` column derived from `ID1`.

    `interface_id` is the first two "_"-separated fields of `ID1`
    (e.g. "6IDB_0_A" -> "6IDB_0"). Built with vectorised string operations and
    `.assign`, so the input frame is never written to — this avoids assigning a
    column onto a filtered slice, and the row-by-row `iterrows` loop.
    """
    id_parts = frame["ID1"].str.split("_")
    return frame.assign(interface_id=id_parts.str[0] + "_" + id_parts.str[1])


# "Small" tables: only used to attach the `dimer` flag to each interface.
Df_train_small = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_train.csv",index_col=0).reset_index(drop=True)
Df_train_small = Df_train_small[~Df_train_small.target_binder_id.str.startswith("6BJP")]
Df_train_small = _add_interface_id(Df_train_small)
Df_test_small = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_test.csv",index_col=0).reset_index(drop=True)
Df_test_small = _add_interface_id(Df_test_small)


# Main tables (with PDB-derived lengths), joined with the `dimer` flag.
# NOTE(review): the inner merge duplicates rows if `interface_id` is not unique
# in the "small" table — presumably it is unique; verify.
Df_train = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_train_w_pbd_lens.csv",index_col=0).reset_index(drop=True)
Df_train = Df_train.merge(Df_train_small[["dimer", "interface_id"]], on = "interface_id", how="inner")
Df_test = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_test_w_pbd_lens.csv",index_col=0).reset_index(drop=True)
Df_test = Df_test.merge(Df_test_small[["dimer", "interface_id"]], on = "interface_id", how="inner")
# Entries whose PDB id starts with "6BJP" are excluded from training.
Df_train = Df_train[~Df_train.PDB.str.startswith("6BJP")]

Df_train

Unnamed: 0,interface_id,PDB,ID1,ID2,seq_target,seq_target_len,seq_pdb_target,pdb_target_len,target_chain,seq_binder,seq_binder_len,seq_pdb_binder,pdb_binder_len,binder_chain,pdb_path,dimer
0,6IDB_0,6IDB,6IDB_0_A,6IDB_0_B,DKICLGHHAVSNGTKVNTLTERGVEVVNATETVERTNIPRICSKGK...,317,DKICLGHHAVSNGTKVNTLTERGVEVVNATETVERTNIPRICSKGK...,317,A,GLFGAIAGFIENGWEGLIDGWYGFRHQNAQGEGTAADYKSTQSAID...,172,GLFGAIAGFIENGWEGLIDGWYGFRHQNAQGEGTAADYKSTQSAID...,172,B,6idb.pdb.gz,False
1,2WZP_3,2WZP,2WZP_3_D,2WZP_3_G,VQLQESGGGLVQAGGSLRLSCTASRRTGSNWCMGWFRQLAGKEPEL...,122,VQLQESGGGLVQAGGSLRLSCTASRRTGSNWCMGWFRQLAGKEPEL...,122,D,TIKNFTFFSPNSTEFPVGSNNDGKLYMMLTGMDYRTIRRKDWSSPL...,266,TIKNFTFFSPNSTEFPVGSNNDGKLYMMLTGMDYRTIRRKDWSSPL...,266,G,2wzp.pdb.gz,False
2,1ZKP_0,1ZKP,1ZKP_0_A,1ZKP_0_C,LYFQSNAKTVVGFWGGFPEAGEATSGYLFEHDGFRLLVDCGSGVLA...,246,LYFQSNAMKMTVVGFWGGFPEAGEATSGYLFEHDGFRLLVDCGSGV...,251,A,AKTVVGFWGGFPEAGEATSGYLFEHDGFRLLVDCGSGVLAQLQKYI...,240,AMKMTVVGFWGGFPEAGEATSGYLFEHDGFRLLVDCGSGVLAQLQK...,245,C,1zkp.pdb.gz,True
3,6GRH_3,6GRH,6GRH_3_C,6GRH_3_D,SKHELSLVEVTHYTDPEVLAIVKDFHVRGNFASLPEFAERTFVSAV...,266,SKHELSLVEVTHYTDPEVLAIVKDFHVRGNFASLPEFAERTFVSAV...,266,C,MINVYSNLMSAWPATMAMSPKLNRNMPTFSQIWDYERITPASAAGE...,396,MINVYSNLMSAWPATMAMSPKLNRNMPTFSQIWDYERITPASAAGE...,396,D,6grh.pdb.gz,False
4,8R57_1,8R57,8R57_1_M,8R57_1_f,DLMTALQLVMKKSSAHDGLVKGLREAAKAIEKHAAQICVLAEDCDQ...,118,DLMTALQLVMKKSSAHDGLVKGLREAAKAIEKHAAQICVLAEDCDQ...,118,M,PKKQKHKHKKVKLAVLQFYKVDDATGKVTRLRKECPNADCGAGTFM...,64,PKKQKHKHKKVKLAVLQFYKVDDATGKVTRLRKECPNADCGAGTFM...,64,f,8r57.pdb.gz,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1972,4YO8_0,4YO8,4YO8_0_A,4YO8_0_B,HENLYFQGVQKIGILGAMREEITPILELFGVDFEEIPLGGNVFHKG...,238,HENLYFQGVQKIGILGAMREEITPILELFGVDFEEIPLGGNVFHKG...,238,A,HHHHHENLYFQGVQKIGILGAMREEITPILELFGVDFEEIPLGGNV...,242,HHHHHENLYFQGVQKIGILGAMREEITPILELFGVDFEEIPLGGNV...,242,B,4yo8.pdb.gz,True
1973,3CKI_0,3CKI,3CKI_0_A,3CKI_0_B,DPMKNTCKLLVVADHRFYRYMGRGEESTTTNYLIELIDRVDDIYRN...,256,DPMKNTCKLLVVADHRFYRYMGRGEESTTTNYLIELIDRVDDIYRN...,256,A,CTCSPSHPQDAFCNSDIVIRAKVVGKKLVKEGPFGTLVYTIKQMKM...,121,CTCSPSHPQDAFCNSDIVIRAKVVGKKLVKEGPFGTLVYTIKQMKM...,121,B,3cki.pdb.gz,False
1974,7MHY_1,7MHY,7MHY_1_M,7MHY_1_N,QVQLRQSGAELAKPGASVKMSCKASGYTFTNYWLHWIKQRPGQGLE...,118,QVQLRQSGAELAKPGASVKMSCKASGYTFTNYWLHWIKQRPGQGLE...,118,M,DVLMTQTPLSLPVSLGDQVSISCRSSQSIVHNTYLEWYLQKPGQSP...,109,DVLMTQTPLSLPVSLGDQVSISCRSSQSIVHNTYLEWYLQKPGQSP...,109,N,7mhy.pdb.gz,False
1975,7MHY_2,7MHY,7MHY_2_O,7MHY_2_P,IQLVQSGPELVKISCKASGYTFTNYGMNWVRQAPGKGLKWMGWINT...,100,IQLVQSGPELVKISCKASGYTFTNYGMNWVRQAPGKGLKWMGWINT...,100,O,VLMTQTPLSLPVSISCRSSQSIVHSNGNTYLEWYLQKPGQSPKLLI...,94,VLMTQTPLSLPVSISCRSSQSIVHSNGNTYLEWYLQKPGQSPKLLI...,94,P,7mhy.pdb.gz,False


### PPint dataloader (loading raw contact maps)

In [7]:
class CLIP_PPint_class_w_contacts(Dataset):
    """PPint binder/target pairs: padded ESM2 embeddings plus raw contact maps.

    Everything is loaded into memory in ``__init__``. For each ``interface_id``
    the embeddings are read from ``<emb_path>/<PDB>_<chain>.npy`` and the contact
    maps from ``<cont_path>/<PDB>_<chain>.npy``, where ``<PDB>`` is the part of
    ``interface_id`` before the first underscore.

    Args:
        dframe: DataFrame with the columns ``interface_id``, ``target_chain``,
            ``binder_chain``, ``pdb_target_len`` and ``pdb_binder_len``.
            A copy is stored; the caller's frame is not modified.
        paths: ``(emb_path, cont_path)`` directories.
        embedding_dim: Expected size of the embedding feature axis.
        embedding_pad_value: Value written into padded embedding rows.
        structure_pad_value: Stored as ``self.struct_pad``; not applied here
            because contact maps are returned unpadded.
        fixed_max_blen: Optional binder length (same unit as ``pdb_binder_len``)
            to pad to instead of the longest binder in ``dframe``. Two rows are
            added on top, exactly as for the data-derived default. Embeddings
            longer than this are truncated; contact maps are never truncated.
        fixed_max_tlen: Same as ``fixed_max_blen`` for the target side.

    Each item is ``(b_emb, t_emb, b_cmap, t_cmap, label)`` where the label is
    always ``1.0``.
    """

    def __init__(
        self,
        dframe,
        paths,
        embedding_dim=1280,
        embedding_pad_value=-5000.0,
        structure_pad_value=0.0,   # stored only; raw cmaps are not padded here
        fixed_max_blen=None,
        fixed_max_tlen=None,
    ):
        super().__init__()

        self.dframe = dframe.copy()
        self.embedding_dim = embedding_dim
        self.emb_pad = float(embedding_pad_value)
        self.struct_pad = float(structure_pad_value)

        # paths
        self.emb_path, self.cont_path = paths

        # padded lengths (previously fixed_max_* were accepted but ignored)
        self.max_blen = self._resolve_max_len(self.dframe["pdb_binder_len"], fixed_max_blen)
        self.max_tlen = self._resolve_max_len(self.dframe["pdb_target_len"], fixed_max_tlen)

        # index & storage
        self.dframe.set_index("interface_id", inplace=True)
        self.accessions = self.dframe.index.astype(str).tolist()
        self.name_to_row = {name: i for i, name in enumerate(self.accessions)}
        self.samples = []

        for accession in tqdm(self.accessions, total=len(self.accessions), desc="#Loading ESM2 embeddings and contacts"):
            # one row lookup per sample instead of one per field
            row = self.dframe.loc[accession]
            pdb_id = accession.split("_")[0]
            tgt_id = pdb_id + "_" + str(row.target_chain)
            bnd_id = pdb_id + "_" + str(row.binder_chain)

            # load embeddings
            t_emb = np.load(os.path.join(self.emb_path, f"{tgt_id}.npy"))     # [Lt, D]
            b_emb = np.load(os.path.join(self.emb_path, f"{bnd_id}.npy"))     # [Lb, D]

            # stored embeddings carry two rows more than the PDB sequence length
            assert (b_emb.shape[0] == row.pdb_binder_len+2)
            assert (t_emb.shape[0] == row.pdb_target_len+2)

            # sanity check
            if t_emb.shape[1] != embedding_dim or b_emb.shape[1] != embedding_dim:
                raise ValueError("Embedding dim mismatch.")

            # pad (or truncate) to the dataset-wide length
            t_emb = self._pad_embedding(t_emb, self.max_tlen)
            b_emb = self._pad_embedding(b_emb, self.max_blen)

            # ---------------- RAW CONTACT MAPS ----------------
            # variable-size maps, kept exactly as stored: no unfolding,
            # flattening or padding here
            t_cmap = torch.from_numpy(
                np.load(os.path.join(self.cont_path, f"{tgt_id}.npy"))
            ).float()

            b_cmap = torch.from_numpy(
                np.load(os.path.join(self.cont_path, f"{bnd_id}.npy"))
            ).float()

            self.samples.append((b_emb, t_emb, b_cmap, t_cmap))

    # ----------------------------------------------------
    # PADDED LENGTH
    # ----------------------------------------------------
    @staticmethod
    def _resolve_max_len(lengths, fixed_max):
        """Return the padded embedding length.

        Uses ``fixed_max`` when given, otherwise ``lengths.max()``; two rows are
        added in both cases to match the stored embeddings.
        """
        longest = lengths.max() if fixed_max is None else fixed_max
        return longest + 2

    # ----------------------------------------------------
    # PAD EMBEDDINGS
    # ----------------------------------------------------
    def _pad_embedding(self, arr, max_len):
        """Pad ``arr`` ([L, D]) with ``self.emb_pad`` rows up to ``max_len``, or cut it to ``max_len`` rows."""
        L, D = arr.shape
        if L < max_len:
            pad = np.full((max_len - L, D), self.emb_pad, dtype=arr.dtype)
            arr = np.concatenate([arr, pad], axis=0)
        else:
            arr = arr[:max_len]
        return arr

    # ----------------------------------------------------
    # DATASET API
    # ----------------------------------------------------
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(b_emb, t_emb, b_cmap, t_cmap, label)`` for sample ``idx``."""
        b_emb, t_emb, b_cmap, t_cmap = self.samples[idx]

        b_emb = torch.from_numpy(b_emb).float()
        t_emb = torch.from_numpy(t_emb).float()

        # this dataset stores no labels: every pair gets the label 1.0
        label = torch.tensor(1.0)

        return b_emb, t_emb, b_cmap, t_cmap, label

    # ----------------------------------------------------
    # GET ONE OR MULTIPLE BY NAME
    # ----------------------------------------------------
    def _get_by_name(self, name):
        """Fetch by ``interface_id``.

        A single string returns one item (same as ``__getitem__``). An iterable
        of names returns stacked embeddings and labels, plus two lists of
        contact maps (left as lists because their sizes differ).
        """
        if isinstance(name, str):
            return self.__getitem__(self.name_to_row[name])

        out = [self.__getitem__(self.name_to_row[n]) for n in name]

        b_list, t_list, bmap_list, tmap_list, lbl_list = zip(*out)

        b      = torch.stack(b_list)
        t      = torch.stack(t_list)
        labels = torch.stack(lbl_list)

        # return raw cmaps unchanged
        return b, t, list(bmap_list), list(tmap_list), labels

# Directories holding one .npy file per chain: ESM2 embeddings and contact maps.
emb_path = "/work3/s232958/data/PPint_DB/embeddings_esm2"
cont_path = "/work3/s232958/data/PPint_DB/contacts_esm2"

# global_max_blen = int(PPint_interaactions_df["seq_binder_len"].max())
# global_max_tlen = int(PPint_interaactions_df["seq_target_len"].max())

# Both constructors load every embedding and contact map into memory.
# NOTE(review): each dataset derives its padded length from its own dataframe,
# so train and test embeddings may be padded to different lengths.
training_Dataset = CLIP_PPint_class_w_contacts(
    Df_train,
    paths=[emb_path, cont_path],
    embedding_dim=1280
)

testing_Dataset = CLIP_PPint_class_w_contacts(
    Df_test,
    paths=[emb_path, cont_path],
    embedding_dim=1280
)

#Loading ESM2 embeddings and contacts: 100%|████████████████████████████████████████| 1977/1977 [00:20<00:00, 94.17it/s]
#Loading ESM2 embeddings and contacts: 100%|██████████████████████████████████████████| 494/494 [00:05<00:00, 94.33it/s]


In [8]:
### Indices of the non-dimer interfaces in the test split
indices_non_dimers_val = Df_test.index[~Df_test["dimer"]].tolist()

### Their accessions (interface IDs), then a 5-item check of the name-based lookup
accessions = Df_test.loc[indices_non_dimers_val, "interface_id"].tolist()
b, t, bct, tct, labels = testing_Dataset._get_by_name(accessions[:5])
labels

tensor([1., 1., 1., 1., 1.])

### Loading Meta-analysis dataset for validation

In [9]:
# Meta-analysis interactions: keep the needed columns and rename them to the
# binder/target naming used throughout this notebook.
interaction_df = (
    pd.read_csv("/work3/s232958/data/meta_analysis/interaction_df_metaanal.csv")
    [["A_seq", "B_seq", "target_id_mod", "target_binder_ID", "binder"]]
    .rename(columns={
        "A_seq": "seq_binder",
        "B_seq": "seq_target",
        "target_binder_ID": "binder_id",
        "target_id_mod": "target_id",
        "binder": "binder_label",
    })
)
interaction_df["seq_target_len"] = interaction_df["seq_target"].map(len)
interaction_df["seq_binder_len"] = interaction_df["seq_binder"].map(len)

# Targets df: one row per unique (ID, sequence), indexed by ID
target_df = (
    interaction_df[["target_id", "seq_target"]]
    .rename(columns={"seq_target": "sequence", "target_id": "ID"})
    .assign(seq_len=lambda frame: frame["sequence"].map(len))
    .drop_duplicates(subset=["ID", "sequence"])
    .set_index("ID")
)

# Binders df: one row per interaction (no de-duplication), indexed by ID
binder_df = (
    interaction_df[["binder_id", "seq_binder"]]
    .rename(columns={"seq_binder": "sequence", "binder_id": "ID"})
    .assign(seq_len=lambda frame: frame["sequence"].map(len))
    .set_index("ID")
)

# Interaction Dict: 1-based running number -> (target_id, binder_id)
interaction_Dict = {
    number: pair
    for number, pair in enumerate(
        zip(interaction_df["target_id"], interaction_df["binder_id"]), start=1
    )
}

# Reproducible shuffle of all rows, with a fresh 0..n-1 index
interaction_df_shuffled = interaction_df.sample(frac=1, random_state=SEED).reset_index(drop=True)
interaction_df_shuffled

Unnamed: 0,seq_binder,seq_target,target_id,binder_id,binder_label,seq_target_len,seq_binder_len
0,DIVEEAHKLLSRAMSEAMENDDPDKLRRANELYFKLEEALKNNDPK...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_124,True,101,62
1,SEELVEKVVEEILNSDLSNDQKILETHDRLMELHDQGKISKEEYYK...,LEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYV...,EGFR_2,EGFR_2_149,False,621,58
2,TINRVFHLHIQGDTEEARKAHEELVEEVRRWAEELAKRLNLTVRVT...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_339,False,101,65
3,DDLRKVERIASELAFFAAEQNDTKVAFTALELIHQLIRAIFHNDEE...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_1234,False,101,64
4,DEEVEELEELLEKAEDPRERAKLLRELAKLIRRDPRLRELATEVVA...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_48,False,165,65
...,...,...,...,...,...,...,...
3527,SEDELRELVKEIRKVAEKQGDKELRTLWIEAYDLLASLWYGAADEL...,TNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFK...,SARS_CoV2_RBD,SARS_CoV2_RBD_25,False,195,63
3528,TEEEILKMLVELTAHMAGVPDVKVEIHNGTLRVTVNGDTREARSVL...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_2027,False,101,65
3529,VEELKEARKLVEEVLRKKGDQIAEIWKDILEELEQRYQEGKLDPEE...,DYSFSCYSQLEVNGSQHSLTCAFEDPDVNTTNLEFEICGALVEVKC...,IL7Ra,IL7Ra_90,False,193,63
3530,DAEEEIREIVEKLNDPLLREILRLLELAKEKGDPRLEAELYLAFEK...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_1605,False,101,65


In [10]:
class CLIP_Meta_class(Dataset):
    """Meta-analysis binder/target pairs: padded ESM2 embeddings, raw contact maps, labels.

    Everything is loaded into memory in ``__init__``. The binder files are named
    after ``binder_id``; the target files are named after ``binder_id`` with its
    last "_"-separated field removed.

    Args:
        dframe: DataFrame with the columns ``binder_id``, ``binder_label``,
            ``seq_binder_len`` and ``seq_target_len``. A copy is stored; the
            caller's frame is not modified.
        paths: ``(binder_emb_dir, target_emb_dir, binder_contacts_dir,
            target_contacts_dir)``.
        embedding_dim: Expected size of the embedding feature axis.
        embedding_pad_value: Value written into padded embedding rows.
        structure_pad_value: Stored as ``self.struct_pad``; not applied here
            because contact maps are returned unpadded.
        fixed_max_blen: Optional binder length (same unit as ``seq_binder_len``)
            to pad to instead of the longest binder in ``dframe``. Two rows are
            added on top, exactly as for the data-derived default. Embeddings
            longer than this are truncated; contact maps are never truncated.
        fixed_max_tlen: Same as ``fixed_max_blen`` for the target side.

    Each item is ``(b_emb, t_emb, b_cmap, t_cmap, label)`` with a float label.
    """

    def __init__(
        self,
        dframe,
        paths,
        embedding_dim=1280,
        embedding_pad_value=-5000.0,
        structure_pad_value=0.0,
        fixed_max_blen=None,
        fixed_max_tlen=None,
    ):
        super().__init__()

        self.dframe = dframe.copy()
        self.embedding_dim = int(embedding_dim)
        self.emb_pad = float(embedding_pad_value)
        self.struct_pad = float(structure_pad_value)

        # paths
        self.bemb_path, self.temb_path, self.bcont_path, self.tcont_path = paths

        # padded lengths (previously fixed_max_* were accepted but ignored)
        self.max_blen = self._resolve_max_len(self.dframe["seq_binder_len"], fixed_max_blen)
        self.max_tlen = self._resolve_max_len(self.dframe["seq_target_len"], fixed_max_tlen)

        # index & bookkeeping
        self.dframe.set_index("binder_id", inplace=True)
        self.accessions = self.dframe.index.astype(str).tolist()
        self.name_to_row = {name: i for i, name in enumerate(self.accessions)}

        self.samples = []

        # ------------------------------------------------------------
        # LOAD ALL SAMPLES
        # ------------------------------------------------------------
        for accession in tqdm(self.accessions, desc="#Loading ESM2 embeddings + raw cmaps"):
            # one row lookup per sample instead of one per field
            row = self.dframe.loc[accession]
            label = torch.tensor(int(row["binder_label"]))

            # target id = binder id without its last "_"-separated field
            parts = accession.split("_")
            tgt_id = "_".join(parts[:-1])
            bnd_id = accession

            # load embeddings
            t_emb = np.load(os.path.join(self.temb_path, f"{tgt_id}.npy")) # [Lt, D]
            b_emb = np.load(os.path.join(self.bemb_path, f"{bnd_id}.npy")) # [Lb, D]

            # stored embeddings carry two rows more than the sequence length
            assert (b_emb.shape[0] == row.seq_binder_len+2)
            assert (t_emb.shape[0] == row.seq_target_len+2)

            # validate dims
            if t_emb.shape[1] != self.embedding_dim or b_emb.shape[1] != self.embedding_dim:
                raise ValueError("Embedding dim mismatch in dataset.")

            # loading maps
            t_cmap = torch.from_numpy(np.load(os.path.join(self.tcont_path, f"{tgt_id}.npy"))).float()   # [Lt, Lt]
            b_cmap = torch.from_numpy(np.load(os.path.join(self.bcont_path, f"{bnd_id}.npy"))).float()   # [Lb, Lb]

            # contact maps have two rows fewer than the embeddings
            assert (b_emb.shape[0] == b_cmap.shape[0]+2)
            assert (t_emb.shape[0] == t_cmap.shape[0]+2)

            # pad (or truncate) embeddings to the dataset-wide length
            t_emb = self._pad_embedding(t_emb, self.max_tlen)
            b_emb = self._pad_embedding(b_emb, self.max_blen)

            # store raw contact maps (no unfold, no patches)
            self.samples.append((b_emb, t_emb, b_cmap, t_cmap, label))

    # ------------------------------------------------------------
    # PADDED LENGTH
    # ------------------------------------------------------------
    @staticmethod
    def _resolve_max_len(lengths, fixed_max):
        """Return the padded embedding length.

        Uses ``fixed_max`` when given, otherwise ``lengths.max()``; two rows are
        added in both cases to match the stored embeddings.
        """
        longest = lengths.max() if fixed_max is None else fixed_max
        return longest + 2

    # ------------------------------------------------------------
    # PAD ESM2 EMBEDDINGS
    # ------------------------------------------------------------
    def _pad_embedding(self, arr, max_len):
        """Pad ``arr`` ([L, D]) with ``self.emb_pad`` rows up to ``max_len``, or cut it to ``max_len`` rows."""
        L, D = arr.shape
        if L < max_len:
            pad = np.full((max_len - L, D), self.emb_pad, dtype=arr.dtype)
            arr = np.concatenate([arr, pad], axis=0)
        else:
            arr = arr[:max_len]
        return arr

    # ------------------------------------------------------------
    # DATASET API
    # ------------------------------------------------------------
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(b_emb, t_emb, b_cmap, t_cmap, label)`` for sample ``idx``; the label is cast to float."""
        b_emb, t_emb, b_cmap, t_cmap, label = self.samples[idx]

        b_emb = torch.from_numpy(b_emb).float()
        t_emb = torch.from_numpy(t_emb).float()

        return b_emb, t_emb, b_cmap, t_cmap, label.float()

    # ------------------------------------------------------------
    # Fetch one or multiple items by name
    # ------------------------------------------------------------
    def _get_by_name(self, names):
        """Fetch by ``binder_id``.

        A single string returns one item (same as ``__getitem__``). An iterable
        of names returns stacked embeddings and labels, plus two lists of
        contact maps (left as lists because their sizes differ).
        """
        if isinstance(names, str):
            return self.__getitem__(self.name_to_row[names])

        items = [self.__getitem__(self.name_to_row[n]) for n in names]

        b_list, t_list, bmap_list, tmap_list, lbl_list = zip(*items)

        b = torch.stack(b_list)
        t = torch.stack(t_list)
        labels = torch.stack(lbl_list)

        # cmaps remain variable-size lists
        return b, t, list(bmap_list), list(tmap_list), labels

# ESM2 embedding directories for the meta-analysis set (binders and targets are stored separately)
bemb_path = "/work3/s232958/data/meta_analysis/embeddings_esm2_binders"
temb_path = "/work3/s232958/data/meta_analysis/embeddings_esm2_targets"

## Contact maps paths
bcont_path = "/work3/s232958/data/meta_analysis/binders_contacts"
tcont_path = "/work3/s232958/data/meta_analysis/targets_contacts"

# Loads every embedding and contact map into memory. The order of `paths` must
# match the unpacking in CLIP_Meta_class.__init__:
# binder embeddings, target embeddings, binder contacts, target contacts.
validation_Dataset = CLIP_Meta_class(
    interaction_df_shuffled,
    paths=[bemb_path, temb_path, bcont_path, tcont_path],
    embedding_dim=1280
)

#Loading ESM2 embeddings + raw cmaps: 100%|████████████████████████████████████████| 3532/3532 [00:31<00:00, 111.58it/s]


In [11]:
# Quick check of the name-based lookup: every 11th binder (by index label)
# among the first 100 shuffled rows.
accs = [
    binder_id
    for i, binder_id in interaction_df_shuffled[:100]["binder_id"].items()
    if i % 11 == 0
]
print(accs)
__, __, __, __, lbls = validation_Dataset._get_by_name(accs)
lbls

['FGFR2_124', 'Pdl1_49', 'Mdm2_41', 'FGFR2_1194', 'FGFR2_798', 'FGFR2_1731', 'FGFR2_25', 'InsulinR_74', 'IL7Ra_80', 'FGFR2_967']


tensor([1., 1., 1., 0., 0., 0., 0., 0., 0., 0.])

### Loading Boltzgen

## CLIP-model

In [13]:
def create_key_padding_mask(embeddings, padding_value=-5000, offset=10):
    """Flag padded positions along the last axis.

    A position is flagged (True) when every value on its last axis is strictly
    below ``padding_value + offset``; all other positions are False. The result
    has the shape of ``embeddings`` without its last axis.
    """
    threshold = padding_value + offset
    below_threshold = torch.lt(embeddings, threshold)
    return torch.all(below_threshold, dim=-1)

def create_mean_of_non_masked(embeddings, padding_mask):
    """Mean-pool every batch element over its non-masked positions.

    Args:
        embeddings: Tensor indexed as (batch, position, feature).
        padding_mask: Boolean tensor indexed as (batch, position); True marks
            positions that are excluded from the mean.

    Returns:
        Tensor with one mean feature vector per batch element.

    Raises:
        ValueError: If all positions of a batch element are masked. This used to
            print a message and call ``sys.exit(1)``, which terminates the whole
            process/kernel from inside a helper and cannot be handled by callers.
    """
    seq_embeddings = []
    for i in range(embeddings.shape[0]): # looping over all batch elements
        non_masked_embeddings = embeddings[i][~padding_mask[i]] # rows of the unmasked positions
        if len(non_masked_embeddings) == 0:
            raise ValueError(
                "You are masking all positions when creating sequence representation "
                f"(batch element {i})"
            )
        # the sequence is represented by a single feature vector
        seq_embeddings.append(non_masked_embeddings.mean(dim=0))
    return torch.stack(seq_embeddings)

def pad_contact_maps(contact_list, pad_value=0.0):
    """Batch variable-size contact maps into one tensor.

    contact_list: list of [L, L] or [1, L, L] tensors
    Returns: [B, 1, Lmax, Lmax], where Lmax is the largest last-axis size in the
    list; smaller maps are filled with pad_value on the right and at the bottom.
    """
    # Give 2-D maps a leading channel axis: [L, L] -> [1, L, L]
    with_channel = [m.unsqueeze(0) if m.dim() == 2 else m for m in contact_list]

    target_size = max(m.shape[-1] for m in with_channel)

    batch = []
    for m in with_channel:
        missing = target_size - m.shape[-1]
        # F.pad order for the last two axes: (left, right, top, bottom)
        batch.append(F.pad(m, (0, missing, 0, missing), value=pad_value))

    return torch.stack(batch, dim=0)

#### with RoPE

Standard Transformer Encoder Layer: <br>
$x_1 = x + Dropout(MultiHeadAttention(LayerNorm(x)))$ <br>
$x_2 = x_1 + Dropout(FeedForwardNetwork(LayerNorm(x_1)))$

How I will try to pass contact maps: <br>
$patches → Linear → transformer\_with\_RoPE(Q/K) → CrossAttn → output$

In [14]:
# import some_utils.RoPE_for_ViT as rope

# class RoPEStructEncoderLayer(nn.Module):
#     def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, rope_theta=10.0, rope_mixed=True):
#         super().__init__()
#         self.norm1 = nn.LayerNorm(d_model)
#         self.norm2 = nn.LayerNorm(d_model)

#         # RoPE-based attention (ViT style)
#         self.self_attn = rope.RoPEAttention(
#             dim=d_model,
#             num_heads=nhead,
#             qkv_bias=True,
#             attn_drop=dropout,
#             proj_drop=dropout,
#             rope_theta=rope_theta,
#             rope_mixed=rope_mixed,
#         )

#         # FFN
#         self.linear1 = nn.Linear(d_model, dim_feedforward)
#         self.dropout = nn.Dropout(dropout)
#         self.linear2 = nn.Linear(dim_feedforward, d_model)

#         self.dropout1 = nn.Dropout(dropout)
#         self.dropout2 = nn.Dropout(dropout)
#         self.activation = F.gelu

#     def forward(self, x, src_key_padding_mask=None):
#         # x: [B, N, d_model] (N = maybe 1 + H*W if CLS, or H*W if no CLS)

#         # pre-norm + attention
#         x_norm = self.norm1(x)

#         # If you want to respect src_key_padding_mask, you need to zero out padded tokens
#         # or build an attn_mask for RoPEAttention. The provided RoPEAttention doesn't use masks, 
#         # so a simple approach is:
#         if src_key_padding_mask is not None:
#             # set padded tokens to zero before attention
#             x_norm = x_norm.masked_fill(src_key_padding_mask.unsqueeze(-1), 0.0)

#         attn_out = self.self_attn(x_norm)  # RoPE applied inside

#         x1 = x + self.dropout1(attn_out)

#         # FFN
#         x2 = self.norm2(x1)
#         z = self.linear2(self.dropout(self.activation(self.linear1(x2))))
#         x = x1 + self.dropout2(z)

#         return x

In [15]:
# head_dim = self.seq_embed_dimension // 8
# rope_module = RotaryEmbedding(head_dim)

# self.seq_encoder = RoPEEncoderLayer(
#     d_model=self.seq_embed_dimension,
#     nhead=8,
#     rope_module=rope_module,
#     dim_feedforward=self.seq_embed_dimension,
#     dropout=0.1,
#     batch_first=True,
# )

#### with ALiBi2D

In [16]:
# x = validation_Dataset[1][3]
# x_mask = None
# print(x.shape)

# token_project = nn.Linear(
#     x.shape[1],
#     1152,
#     bias=True)

# x_proj = token_project(x)
# print(x_proj.shape)

# norm_struct = nn.LayerNorm(1152)

# x_proj_norm = norm_struct(x_proj)
# print(x_proj_norm.shape)
# batch = x_proj_norm.unsqueeze(0)
# print(batch.shape)

# ALiBi2Dlayer = ALiBi2DTransformerLayer(
#     d_model=1152,
#     nhead=8,
#     dropout=0.1
# )
# with torch.no_grad():
#     out = ALiBi2Dlayer(batch, x_mask)

# print("Input shape :", batch.shape)
# print("Output shape:", out.shape)

# assert out.shape == batch.shape
# print("✓ ALiBi2D transformer layer runs correctly!")


# def build_alibi_slopes(num_heads: int):
#     def get_slopes(n):
#         start = 2 ** (-2 ** -(math.log2(n) - 3))
#         ratio = start
#         return [start * (ratio ** i) for i in range(n)]

#     if math.log2(num_heads).is_integer():
#         return torch.tensor(get_slopes(num_heads))

#     closest_pow2 = 2 ** math.floor(math.log2(num_heads))
#     return torch.tensor(
#         get_slopes(closest_pow2)
#         + get_slopes(2 * closest_pow2)[0::2][: n_heads - closest_pow2]
#     )

# slopes_x = nn.Parameter(build_alibi_slopes(num_heads=8), requires_grad=True)
# slopes_y = nn.Parameter(build_alibi_slopes(num_heads=8), requires_grad=True)

# def _build_bias(real_N):
#     """
#     Build ALiBi-2D bias for a square grid of real_N tokens.
#     """

#     side = int(math.sqrt(real_N))          # H=W
#     idx   = torch.arange(real_N)

#     x_coord = idx % side
#     y_coord = idx // side

#     dx = (x_coord[:, None] - x_coord[None, :]).abs()
#     dy = (y_coord[:, None] - y_coord[None, :]).abs()

#     # [heads, real_N, real_N]
#     return (
#         slopes_x[:, None, None] * dx[None, :, :]
#         + slopes_y[:, None, None] * dy[None, :, :]
#     )

# _build_bias(4).shape

In [17]:
# class MiniCLIP_w_transformer_crossattn(pl.LightningModule):

#     def __init__(
#         self,
#         padding_value=-5000,
#         seq_embed_dimension = seq_embed_dimension,
#         struct_embed_dimension = struct_embed_dimension,
#         num_recycles=2
#     ):
#         super().__init__()
#         self.num_recycles = num_recycles
#         self.padding_value = padding_value
#         self.seq_embed_dimension = seq_embed_dimension
#         self.struct_embed_dimension = struct_embed_dimension

#         # CLIP scaling
#         self.logit_scale = nn.Parameter(torch.tensor(math.log(1/0.07)))
#         # self.struct_alpha = nn.Parameter(torch.tensor(0.0))

#         ### ---------------- SEQUENCE ENCODER ---------------- ###
#         self.seq_encoder = nn.TransformerEncoderLayer(
#             d_model=self.seq_embed_dimension, 
#             nhead=8,
#             dropout=0.1, 
#             batch_first=True, 
#             dim_feedforward=self.seq_embed_dimension)

#         self.norm_seq = nn.LayerNorm(self.seq_embed_dimension)

#         # cross-attention (sequence-side)
#         self.seq_cross_attn = nn.MultiheadAttention(
#             embed_dim=self.seq_embed_dimension, 
#             num_heads=8,
#             dropout=0.1, 
#             batch_first=True)

#         self.projection_head = nn.Sequential(
#             nn.Linear(self.seq_embed_dimension, 640),
#             nn.ReLU(),
#             nn.Linear(640, 320)
#         )

#         ### ---------------- STRUCTURE ENCODER ---------------- ###
#         self.token_project = nn.Linear(
#             self.struct_embed_dimension,
#             self.seq_embed_dimension,
#             bias=True)

#         self.norm_struct = nn.LayerNorm(self.seq_embed_dimension)

#         # ALiBi-2D transformer
#         self.struct_encoder = ALiBi2DTransformerLayer(
#             d_model=self.seq_embed_dimension,
#             nhead=8,
#             dropout=0.1
#         )

#         # structure → sequence
#         self.struct_to_seq_attn = nn.MultiheadAttention(
#             embed_dim=self.seq_embed_dimension, 
#             num_heads=8,
#             dropout=0.1, 
#             batch_first=True)


#     def forward(self, pep_emb, prot_emb, pep_contacts_list, prot_contacts_list,
#                 label=None, pep_int_mask=None, prot_int_mask=None, int_prob=None, mem_save=True):

#         device = pep_emb.device

#         # variable lengths
#         pep_max_len = max(x.size(0) for x in pep_contacts_list)
#         prot_max_len = max(x.size(0) for x in prot_contacts_list)

#         # 1. Project flat patches → [L_i, E]
#         pep_cm_proj  = [self.token_project(x.to(device)) for x in pep_contacts_list]
#         prot_cm_proj = [self.token_project(x.to(device)) for x in prot_contacts_list]

#         # Real number of patches
#         pep_real_N = [cm.size(0) for cm in pep_cm_proj]
#         prot_real_N = [cm.size(0) for cm in prot_cm_proj]
#         pep_hw = [(int(math.sqrt(n)), int(math.sqrt(n))) for n in pep_real_N]
#         prot_hw = [(int(math.sqrt(n)), int(math.sqrt(n))) for n in prot_real_N]

#         # 2. Pad along token dimension
#         pep_cm_pos = torch.stack([pad_tokens_to_L(cm, pep_max_len) for cm in pep_cm_proj], dim=0)
#         prot_cm_pos = torch.stack([pad_tokens_to_L(cm, prot_max_len) for cm in prot_cm_proj], dim=0)

#         # Key padding masks
#         pep_mask_emb  = create_key_padding_mask(pep_emb, padding_value=self.padding_value).to(device)
#         prot_mask_emb = create_key_padding_mask(prot_emb, padding_value=self.padding_value).to(device)

#         pep_mask_cm  = create_key_padding_mask(pep_cm_pos, padding_value=self.padding_value).to(device)
#         prot_mask_cm = create_key_padding_mask(prot_cm_pos, padding_value=self.padding_value).to(device)

#         # clones (safe)
#         b_emb = binder_emb
#         t_emb = target_emb
    
#         b_ct = binder_ct
#         t_ct = target_ct

#         for _ in range(self.num_recycles):

#             # Sequence self-attention
#             pep_trans_emb = self.seq_encoder(self.norm_seq(pep_emb),
#                                              src_key_padding_mask=pep_mask_emb)
#             prot_trans_emb = self.seq_encoder(self.norm_seq(prot_emb),
#                                               src_key_padding_mask=prot_mask_emb)

#             # Structure self-attention (ALiBi-2D)
#             pep_trans_cm = self.struct_encoder(self.norm_struct(pep_contacts), grid_hw=pep_hw, mask=pep_mask_cm)
#             prot_trans_cm = self.struct_encoder(self.norm_struct(prot_contacts), grid_hw=prot_hw, mask=prot_mask_cm)

#             # Cross-attendance: structure → sequence
#             pep_struct_upd, _ = self.struct_to_seq_attn(
#                 query=self.norm_seq(pep_trans_emb),
#                 key=self.norm_struct(pep_trans_cm),
#                 value=self.norm_struct(pep_trans_cm),
#                 key_padding_mask=pep_mask_cm)

#             prot_struct_upd, _ = self.struct_to_seq_attn(
#                 query=self.norm_seq(prot_trans_emb),
#                 key=self.norm_struct(prot_trans_cm),
#                 value=self.norm_struct(prot_trans_cm),
#                 key_padding_mask=prot_mask_cm)

#             pep_trans_emb  = pep_trans_emb  + pep_struct_upd # * self.struct_alpha.tanh()
#             prot_trans_emb = prot_trans_emb + prot_struct_upd # * self.struct_alpha.tanh() 

#             # Cross interactions
#             pep_cross, _ = self.seq_cross_attn(
#                 query=self.norm_seq(pep_trans_emb),
#                 key=self.norm_seq(prot_trans_emb),
#                 value=self.norm_seq(prot_trans_emb),
#                 key_padding_mask=prot_mask_emb)

#             prot_cross, _ = self.seq_cross_attn(
#                 query=self.norm_seq(prot_trans_emb),
#                 key=self.norm_seq(pep_trans_emb),
#                 value=self.norm_seq(pep_trans_emb),
#                 key_padding_mask=pep_mask_emb)

#             pep_emb = pep_emb + pep_cross
#             prot_emb = prot_emb + prot_cross

#         # Pool over non-padded positions
#         pep_seq_coding  = create_mean_of_non_masked(pep_emb, pep_mask_emb)
#         prot_seq_coding = create_mean_of_non_masked(prot_emb, prot_mask_emb)

#         # Head + L2 norm
#         pep_full  = F.normalize(self.projection_head(pep_seq_coding), dim=-1)
#         prot_full = F.normalize(self.projection_head(prot_seq_coding), dim=-1)

#         logits = torch.exp(self.logit_scale).clamp(max=100.0) * (pep_full * prot_full).sum(dim=-1)

#         return logits

#     def training_step(self, batch, device):
#         embedding_pep, embedding_prot, contacts_pep, contacts_target, labels = batch
#         # embedding_pep, embedding_prot, contacts_pep, contacts_prot = embedding_pep.to(device), embedding_prot.to(device), contacts_pep.to(device), contacts_prot.to(device)

#         # loss of predicting partner using peptide
#         positive_logits = self.forward(embedding_pep, embedding_prot, contacts_pep, contacts_prot)
#         positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits).to(device)) # F.binary_cross_entropy_with_logits does sigmoid transfromation inside, excepts data, labels
        
#         # Negative indexes
#         rows, cols = torch.triu_indices(embedding_pep.size(0), embedding_pep.size(0), offset=1) # upper triangle
        
#         pep_cm_list  = [contacts_pep[i] for i in rows.tolist()]  # list of [Li, 256]
#         prot_cm_list = [contacts_prot[j] for j in cols.tolist()]  # list of [Lj, 256]

#         # loss of predicting peptide using partner
#         # negative_logits = self.forward(embedding_pep[rows,:,:], embedding_prot[cols,:,:], contacts_pep[rows,:,:], contacts_prot[cols,:,:], int_prob=0.0)
#         negative_logits = self.forward(embedding_pep[rows,:,:], embedding_prot[cols,:,:], pep_cm_list, prot_cm_list, int_prob=0.0)
#         negative_loss =  F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits).to(device))
        
#         loss = (positive_loss + negative_loss) / 2
 
#         # del partner_prediction_loss, peptide_prediction_loss, embedding_pep, embedding_prot
#         torch.cuda.empty_cache()
#         return loss

#     def validation_step_PPint(self, batch, device):
#         # Predict on random batches of training batch size
#         embedding_pep, embedding_prot, contacts_pep, contacts_prot, labels = batch
#         embedding_pep, embedding_prot = embedding_pep.to(device), embedding_prot.to(device)
#         # contacts_pep, contacts_prot = contacts_pep.to(device), contacts_prot.to(device)
        
#         with torch.no_grad():

#             positive_logits = self.forward(embedding_pep, embedding_prot, contacts_pep, contacts_prot)
            
#             # loss of predicting partner using peptide
#             positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits).to(device))
            
#             # Negaive indexes
#             rows, cols = torch.triu_indices(embedding_prot.size(0), embedding_prot.size(0), offset=1)

#             pep_cm_list  = [contacts_pep[i] for i in rows.tolist()]  # list of [Li, 256]
#             prot_cm_list = [contacts_prot[j] for j in cols.tolist()]  # list of [Lj, 256]
#             negative_logits = self.forward(embedding_pep[rows,:,:], embedding_prot[cols,:,:], pep_cm_list, prot_cm_list, int_prob=0.0)
#             negative_loss =  F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits).to(device))

#             loss = (positive_loss + negative_loss) / 2

#             logit_matrix = torch.zeros((embedding_pep.size(0),embedding_pep.size(0)),device=self.device)
#             logit_matrix[rows, cols] = negative_logits
#             logit_matrix[cols, rows] = negative_logits
            
#             # Fill diagonal with positive scores
#             diag_indices = torch.arange(embedding_pep.size(0), device=self.device)
#             logit_matrix[diag_indices, diag_indices] = positive_logits.squeeze()

#             labels = torch.arange(embedding_prot.size(0)).to(self.device)
#             peptide_predictions = logit_matrix.argmax(dim=0)
#             peptide_ranks = logit_matrix.argsort(dim=0).diag() + 1
#             peptide_mrr = (peptide_ranks).float().pow(-1).mean()
            
#             # partner_accuracy = partner_predictions.eq(labels).float().mean()
#             peptide_accuracy = peptide_predictions.eq(labels).float().mean()
    
#             k = 3
#             peptide_topk_accuracy = torch.any((logit_matrix.topk(k, dim=0).indices - labels.reshape(1, -1)) == 0, dim=0).sum() / logit_matrix.shape[0]
    
#             del logit_matrix,positive_logits,negative_logits,embedding_pep,embedding_prot

#             return loss, peptide_accuracy, peptide_topk_accuracy
    
#     def validation_step_MetaDataset(self, batch, device):
#         embedding_pep, embedding_prot, contacts_pep, contacts_prot, labels = batch
#         embedding_pep, embedding_prot = embedding_pep.to(device), embedding_prot.to(device)
#         # contacts_pep, contacts_prot = contacts_pep.to(device), contacts_prot.to(device)
#         labels = labels.to(device).float()
    
#         with torch.no_grad():
#             logits = self.forward(embedding_pep, embedding_prot, contacts_pep, contacts_prot).float()
#             loss = F.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1))
#             return logits, loss

#     def calculate_logit_matrix(self, embedding_pep, embedding_prot, contacts_pep, contacts_prot):
        
#         rows, cols = torch.triu_indices(embedding_pep.size(0), embedding_pep.size(0), offset=1)
#         pep_cm_list  = [contacts_pep[i] for i in rows.tolist()]  # list of [Li, 256]
#         prot_cm_list = [contacts_prot[j] for j in cols.tolist()]  # list of [Lj, 256]
        
#         positive_logits = self.forward(embedding_pep, embedding_prot, contacts_pep, contacts_prot)
#         # negative_logits = self.forward(embedding_pep[rows,:,:], embedding_prot[cols,:,:], contacts_pep[rows,:,:], contactcontacts_prots_target[cols,:,:], int_prob=0.0)
#         negative_logits = self.forward(embedding_pep[rows,:,:], embedding_prot[cols,:,:], pep_cm_list, prot_cm_list, int_prob=0.0)
        
#         logit_matrix = torch.zeros((embedding_pep.size(0),embedding_pep.size(0)),device=self.device)
#         logit_matrix[rows, cols] = negative_logits
#         logit_matrix[cols, rows] = negative_logits
        
#         diag_indices = torch.arange(v.size(0), device=self.device)
#         logit_matrix[diag_indices, diag_indices] = positive_logits.squeeze()
        
#         return logit_matrix

#### CNN

##### CNN with random image

In [18]:
# from pathlib import Path
# from PIL import Image
# import numpy as np
# import torch
# import torch.nn as nn
# import matplotlib.pyplot as plt
# import math

# # Load image as PIL
# img_path = Path("/zhome/c9/0/203261/dog_image.png")
# pil_img = Image.open(img_path).convert("RGB")  # ensure 3 channels

# # Resize using PIL (width, height)
# pil_resized = pil_img.resize((600, 600))

# # Convert to NumPy array (H, W, C)
# im = np.asarray(pil_resized)

# # Convert to torch tensor and to N, C, H, W, then to float
# im_torch = torch.from_numpy(im)          # (H, W, C), dtype uint8
# im_torch = im_torch.permute(2, 0, 1)     # -> (C, H, W)
# im_torch = im_torch.unsqueeze(0)         # -> (1, C, H, W)
# im_torch = im_torch.float() / 255.0      # -> float32 in [0, 1]

# # Plot original and resized
# plt.figure(figsize=(6, 2))
# plt.subplot(1, 2, 1)
# plt.title("Original")
# plt.imshow(pil_img)
# plt.axis("off")

# plt.subplot(1, 2, 2)
# plt.title("Resized 600x600")
# plt.imshow(pil_resized)
# plt.axis("off")

# plt.tight_layout()
# plt.show()


# class ContactCNN(nn.Module):
#     def __init__(self, c_out=32):
#         super().__init__()
#         self.net = nn.Sequential(
#             nn.Conv2d(3, 16, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.Conv2d(16, c_out, kernel_size=3, padding=1),
#             nn.ReLU(),
#         )

#     def forward(self, x):
#         return self.net(x)


# cnn = ContactCNN(c_out=32)

# # Forward pass
# F = cnn(im_torch)                # (1, 32, 600, 600)
# F_squeeze = F.squeeze(0)         # (32, 600, 600)

# C, H, W = F_squeeze.shape

# cols = 8
# rows = math.ceil(C / cols)
# fig, axes = plt.subplots(rows, cols, figsize=(3.0*cols, 3.0*rows), constrained_layout=True)
# axes = axes.ravel()

# # Shared color scale across all maps
# vmin = float(F_squeeze.min().item())
# vmax = float(F_squeeze.max().item())

# im0 = None
# for i in range(rows * cols):
#     ax = axes[i]
#     if i < C:
#         M = F_squeeze[i].detach().cpu().numpy()  # (H, W)
#         im0 = ax.imshow(
#             M,
#             origin="lower",
#             interpolation="nearest",
#             cmap="binary",
#             vmin=vmin,
#             vmax=vmax
#         )
#         ax.set_title(f"Filter {i+1}", fontsize=10)
#         ax.set_xlabel("X")
#         ax.set_ylabel("Y")
#         ax.tick_params(labelsize=8)
#     else:
#         ax.axis("off")

# if im0 is not None:
#     fig.colorbar(im0, ax=axes.tolist(), fraction=0.02, pad=0.02)

# fig.suptitle("Conv Feature Maps (32 filters)", fontsize=14)
# plt.show()

In [19]:
# cont_np = np.load("/work3/s232958/data/PPint_DB/binders_contacts/6M9S_0_B.npy")  # [N, N], float or int
# fig = plt.figure(figsize=(6, 4))
# plt.title('Contact map')
# im = plt.imshow(cont_np, cmap='gray_r', aspect='auto')
# fig.colorbar(im, ax=plt.gca(), label='Value')

In [20]:
# # original contact map as a batch size (1, L_real, L_real)
# C_orig = torch.tensor(cont_np).unsqueeze(0)
# L_real, L_pad = cont_np.shape[0], 600 # 600- let's say padding to the longest sequence
# batch_size = 1

# # padding original input
# C_padded = torch.zeros(batch_size, L_pad, L_pad)  
# C_padded[:, :L_real, :L_real] = C_orig

# #creating mask for padded values
# mask = torch.zeros(batch_size, L_pad, dtype=torch.bool)  # (1, L_pad)
# mask[:, :L_real] = True
# # mask

# #adding channel
# C_in = C_padded.unsqueeze(1)  # (1, 1, L_pad, L_pad)

##### CNN for contact maps to [B, d_model]

In [30]:
class SimpleCNN(nn.Module):
    """Strided CNN that condenses a 2-D map into one fixed-size vector per sample.

    Four ``Conv2d -> BatchNorm2d -> ReLU`` stages, each with stride 2, widen the
    channels from ``channels`` up to ``channels * 8``. The resulting feature map is
    globally average-pooled and linearly projected to ``d_model``.

    Args:
        in_channels: number of channels of the input map.
        channels: width of the first convolutional stage (doubled at every stage).
        d_model: size of the returned feature vector.
    """

    def __init__(self, in_channels=1, channels=32, d_model=320):
        super().__init__()

        # Channel count entering stage 1, followed by the output width of each stage.
        widths = [in_channels, channels, channels * 2, channels * 4, channels * 8]

        stages = []
        for cin, cout in zip(widths[:-1], widths[1:]):
            # Every stage uses stride 2, i.e. roughly halves the spatial size
            # (L -> L/2 -> L/4 -> L/8 -> L/16). The conv carries no bias term;
            # BatchNorm2d follows it directly.
            stages.append(
                nn.Sequential(
                    nn.Conv2d(cin, cout, kernel_size=3, padding=1, stride=2, bias=False),
                    nn.BatchNorm2d(cout),
                    nn.ReLU(inplace=True),
                )
            )
        self.convolution = nn.Sequential(*stages)

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.project = nn.Linear(widths[-1], d_model)

    def forward(self, x):
        """Encode ``x`` of shape ``[B, in_channels, H, W]`` into ``[B, d_model]``."""
        features = self.convolution(x)      # [B, channels*8, ~H/16, ~W/16]
        pooled = self.pool(features)        # [B, channels*8, 1, 1]
        return self.project(pooled.flatten(1))


# class ContactCNN_Dilated(nn.Module):
#     """
#     DeepCov-like dilated CNN
#     Best for capturing long-range structure.
#     """
#     def __init__(self, in_channels=1, base_channels=64, d_model=320):
#         super().__init__()
#         C = base_channels

#         self.conv1 = nn.Sequential(
#             nn.Conv2d(in_channels, C, kernel_size=3, padding=1, dilation=1),
#             nn.ReLU())
        
#         self.conv2 = nn.Sequential(
#             nn.Conv2d(C, C, kernel_size=3, padding=2, dilation=2),
#             nn.ReLU())
        
#         self.conv3 = nn.Sequential(
#             nn.Conv2d(C, C, kernel_size=3, padding=4, dilation=4),
#             nn.ReLU())
        
#         self.conv4 = nn.Sequential(
#             nn.Conv2d(C, C, kernel_size=3, padding=8, dilation=8),
#             nn.ReLU())
        
#         self.conv5 = nn.Sequential(
#             nn.Conv2d(C, C, kernel_size=3, padding=16, dilation=16),
#             nn.ReLU())

#         self.pool = nn.AdaptiveAvgPool2d((4, 4))
#         self.project = nn.Linear(C * 4 * 4, d_model)

#     def forward(self, x):
#         h = self.conv1(x)
#         h = self.conv2(h)
#         h = self.conv3(h)
#         h = self.conv4(h)
#         h = self.conv5(h)
#         h = self.pool(h)
#         h = h.flatten(start_dim=1)
#         return self.project(h)

class Fusion(nn.Module):
    """Gated residual fusion of a sequence vector with a structure vector.

    Computes ``LayerNorm(seq_embed + tanh(alpha) * conv_struct)``. ``alpha`` is a
    learnable scalar initialised to 0, so the structure term is multiplied by 0
    at initialisation.

    Args:
        d_model: size of the last dimension of both inputs.
    """

    def __init__(self, d_model=320):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(0.0))
        self.norm = nn.LayerNorm(d_model)

    def forward(self, seq_embed, conv_struct):
        """Return the layer-normalised, gated sum of the two inputs."""
        gate = torch.tanh(self.alpha)           # scalar in (-1, 1)
        fused = seq_embed + gate * conv_struct
        return self.norm(fused)

In [31]:
class MiniCLIP_w_transformer_crossattn(pl.LightningModule):
    """CLIP-style scorer for (binder, target) pairs from sequence embeddings and contact maps.

    Sequence stream: both partners pass through one shared ``TransformerEncoderLayer``
    and one shared cross-attention module for ``num_recycles`` rounds, are mean-pooled
    over non-padded tokens and projected to 320 dimensions.
    Structure stream: each partner's padded contact map is encoded by ``SimpleCNN``
    into a 320-d vector that ``Fusion`` adds to the sequence vector through a gate.
    The pair score is the scaled dot product of the two fused, L2-normalised vectors.

    Relies on helpers defined elsewhere in the notebook: ``pad_contact_maps``,
    ``create_key_padding_mask`` and ``create_mean_of_non_masked``.

    Note: the ``*_step`` methods take ``(batch, device)`` rather than Lightning's
    ``(batch, batch_idx)``; they are called manually by ``TrainWrapper``.

    Args:
        padding_value: value marking padded positions in the embedding tensors; it is
            forwarded to ``create_key_padding_mask``.
        seq_embed_dimension: feature size of the per-token embeddings (the default is
            the notebook-level ``seq_embed_dimension`` at class-definition time).
        num_recycles: number of self-attention + cross-attention rounds.
    """

    def __init__(
        self,
        padding_value=-5000,
        seq_embed_dimension = seq_embed_dimension,
        num_recycles=2
    ):
        super().__init__()
        self.num_recycles = num_recycles
        self.padding_value = padding_value
        self.seq_embed_dimension = seq_embed_dimension

        # Learnable temperature, stored in log space (CLIP-like init: 1 / 0.07).
        self.logit_scale = nn.Parameter(torch.tensor(math.log(1/0.07)))
        # Not referenced by forward(); kept so existing checkpoints still load.
        self.struct_alpha = nn.Parameter(torch.tensor(0.0))

        self.fusion = Fusion(d_model=320)

        ### SEQUENCE ###
        # Self-attention (shared between binder and target).
        self.seq_encoder = nn.TransformerEncoderLayer(
            d_model=self.seq_embed_dimension,
            nhead=8,
            dropout=0.1,
            batch_first=True,
            dim_feedforward=self.seq_embed_dimension,
        )
        # Pre-norm applied before every attention call.
        self.norm_seq = nn.LayerNorm(self.seq_embed_dimension)

        # Cross-attention (shared for both directions) and projection to the joint space.
        self.seq_cross_attn = nn.MultiheadAttention(
            embed_dim=self.seq_embed_dimension, num_heads=8, dropout=0.1, batch_first=True
        )
        self.projection_head = nn.Sequential(
            nn.Linear(self.seq_embed_dimension, 640), nn.ReLU(), nn.Linear(640, 320)
        )

        ### STRUCTURE ###
        # CNN that maps a padded contact map to a 320-d vector.
        self.contact_encoder = SimpleCNN(in_channels=1, channels=32, d_model=320)

    def forward(self, pep_emb, prot_emb, pep_contacts_list, prot_contacts_list, label=None, pep_int_mask=None, prot_int_mask=None, int_prob=None, mem_save=True):
        """Score each (binder, target) pair of the batch.

        Args:
            pep_emb: binder embeddings, indexed as ``[B, Lp, seq_embed_dimension]``.
            prot_emb: target embeddings, indexed as ``[B, Lt, seq_embed_dimension]``.
            pep_contacts_list: per-sample binder contact maps, passed to ``pad_contact_maps``.
            prot_contacts_list: per-sample target contact maps, passed to ``pad_contact_maps``.
            label, pep_int_mask, prot_int_mask, int_prob: accepted for interface
                compatibility; not used by this implementation.
            mem_save: when true, calls ``torch.cuda.empty_cache()`` before computing logits.

        Returns:
            Tensor ``[B]`` of logits: ``min(exp(logit_scale), 100)`` times the dot product
            of the L2-normalised binder and target vectors.
        """
        # presumably [B, 1, Lmax, Lmax] (per the original notes) — helper is defined elsewhere
        pep_cmaps = pad_contact_maps(pep_contacts_list)
        prot_cmaps = pad_contact_maps(prot_contacts_list)
        device = pep_emb.device

        # Key padding masks (True = pad -> ignored by attention).
        pep_mask_emb = create_key_padding_mask(embeddings = pep_emb, padding_value = self.padding_value).to(device)
        prot_mask_emb = create_key_padding_mask(embeddings = prot_emb, padding_value = self.padding_value).to(device)

        # Work on copies so the residual updates below never touch the caller's tensors.
        pep_emb, prot_emb = pep_emb.to(device).clone(), prot_emb.to(device).clone()
        pep_cmaps, prot_cmaps = pep_cmaps.to(device).clone(), prot_cmaps.to(device).clone()

        for _ in range(self.num_recycles):
            # Self-attention on each sequence.
            pep_trans = self.seq_encoder(self.norm_seq(pep_emb), src_key_padding_mask=pep_mask_emb)
            prot_trans = self.seq_encoder(self.norm_seq(prot_emb), src_key_padding_mask=prot_mask_emb)

            # Cross-attention in both directions; keys/values come from the partner.
            pep_cross, _ = self.seq_cross_attn(self.norm_seq(pep_trans), self.norm_seq(prot_trans), self.norm_seq(prot_trans), key_padding_mask=prot_mask_emb)
            prot_cross, _ = self.seq_cross_attn(self.norm_seq(prot_trans), self.norm_seq(pep_trans), self.norm_seq(pep_trans), key_padding_mask=pep_mask_emb)

            # Residual update of the running states.
            pep_emb = pep_emb + pep_cross
            prot_emb = prot_emb + prot_cross

        # Pool over true (non-padded) tokens.
        pep_seq_vec = create_mean_of_non_masked(pep_emb, pep_mask_emb)
        prot_seq_vec = create_mean_of_non_masked(prot_emb, prot_mask_emb)

        pep_full = self.projection_head(pep_seq_vec)    # [B, 320]
        prot_full = self.projection_head(prot_seq_vec)  # [B, 320]

        # Structure stream: CNN over the padded contact maps.
        pep_cnn = self.contact_encoder(pep_cmaps)       # [B, 320]
        prot_cnn = self.contact_encoder(prot_cmaps)     # [B, 320]

        pep_full = F.normalize(self.fusion(pep_full, pep_cnn), dim=-1)
        prot_full = F.normalize(self.fusion(prot_full, prot_cnn), dim=-1)

        if mem_save:
            torch.cuda.empty_cache()

        scale = torch.exp(self.logit_scale).clamp(max=100.0)
        logits = scale * (pep_full * prot_full).sum(dim=-1)  # [B]

        return logits

    def _negative_pair_logits(self, embedding_pep, embedding_prot, contacts_pep, contacts_prot):
        """Score all in-batch mismatched pairs (binder i, target j) with i < j.

        Returns:
            ``(rows, cols, negative_logits)`` where ``rows``/``cols`` are the
            upper-triangle indices and ``negative_logits[k]`` scores binder
            ``rows[k]`` against target ``cols[k]``.
        """
        n = embedding_pep.size(0)
        rows, cols = torch.triu_indices(n, n, offset=1)  # upper triangle, no diagonal

        # Contact maps are gathered one by one because they are indexed per sample.
        pep_cm_list = [contacts_pep[i] for i in rows.tolist()]
        prot_cm_list = [contacts_prot[j] for j in cols.tolist()]

        negative_logits = self.forward(embedding_pep[rows,:,:], embedding_prot[cols,:,:], pep_cm_list, prot_cm_list, int_prob=0.0)
        return rows, cols, negative_logits

    def _assemble_logit_matrix(self, positive_logits, negative_logits, rows, cols):
        """Build the ``[B, B]`` score matrix from diagonal and upper-triangle logits.

        The lower triangle is a mirror of the upper one: pair (binder j, target i)
        is not scored separately but reuses the score of (binder i, target j).
        """
        n = positive_logits.size(0)
        logit_matrix = torch.zeros((n, n), device=self.device)
        logit_matrix[rows, cols] = negative_logits
        logit_matrix[cols, rows] = negative_logits

        # Diagonal holds the scores of the true pairs.
        diag_indices = torch.arange(n, device=self.device)
        logit_matrix[diag_indices, diag_indices] = positive_logits.squeeze()
        return logit_matrix

    def training_step(self, batch, device):
        """Return the contrastive BCE loss for one batch of true pairs.

        True pairs are pushed towards label 1, in-batch mismatched pairs
        (upper triangle) towards label 0; the two BCE terms are averaged.
        The batch's own ``labels`` entry is not used. Callers should skip
        batches with a single pair, since those yield no negatives.
        """
        embedding_pep, embedding_prot, contacts_pep, contacts_prot, labels = batch
        # Same device handling as the validation steps (a no-op when the tensors
        # already live on ``device``).
        embedding_pep, embedding_prot = embedding_pep.to(device), embedding_prot.to(device)

        # Positive pairs first, then negatives.
        positive_logits = self.forward(embedding_pep, embedding_prot, contacts_pep, contacts_prot)
        # binary_cross_entropy_with_logits applies the sigmoid internally.
        positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits))

        _, _, negative_logits = self._negative_pair_logits(embedding_pep, embedding_prot, contacts_pep, contacts_prot)
        negative_loss = F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits))

        loss = (positive_loss + negative_loss) / 2

        torch.cuda.empty_cache()
        return loss

    def validation_step_PPint(self, batch, device):
        """Evaluate one batch of true pairs without gradients.

        Returns:
            ``(loss, peptide_accuracy)``: the same loss as ``training_step`` and the
            fraction of targets (columns of the logit matrix) whose highest-scoring
            binder is the true partner. The batch's own ``labels`` entry is not used.
        """
        embedding_pep, embedding_prot, contacts_pep, contacts_prot, labels = batch
        embedding_pep, embedding_prot = embedding_pep.to(device), embedding_prot.to(device)

        with torch.no_grad():
            positive_logits = self.forward(embedding_pep, embedding_prot, contacts_pep, contacts_prot)
            positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits))

            rows, cols, negative_logits = self._negative_pair_logits(embedding_pep, embedding_prot, contacts_pep, contacts_prot)
            negative_loss = F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits))

            loss = (positive_loss + negative_loss) / 2

            logit_matrix = self._assemble_logit_matrix(positive_logits, negative_logits, rows, cols)

            # The true partner of column j sits on the diagonal, i.e. at row j.
            targets = torch.arange(logit_matrix.size(0), device=logit_matrix.device)
            peptide_predictions = logit_matrix.argmax(dim=0)
            peptide_accuracy = peptide_predictions.eq(targets).float().mean()

            return loss, peptide_accuracy

    def validation_step_MetaDataset(self, batch, device):
        """Score labelled pairs and return ``(logits, loss)`` without gradients.

        Unlike the PPint steps, no in-batch negatives are built here: the BCE loss
        is computed against the ``labels`` carried by the batch.
        """
        embedding_pep, embedding_prot, contacts_pep, contacts_prot, labels = batch
        embedding_pep, embedding_prot = embedding_pep.to(device), embedding_prot.to(device)
        labels = labels.to(device).float()

        with torch.no_grad():
            logits = self.forward(embedding_pep, embedding_prot, contacts_pep, contacts_prot).float()
            loss = F.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1))
            return logits, loss

    def calculate_logit_matrix(self, embedding_pep, embedding_prot, contacts_pep, contacts_prot):
        """Return the ``[B, B]`` matrix of pair scores for a batch of true pairs.

        The diagonal holds the true-pair logits; off-diagonal entries hold the
        in-batch mismatched-pair logits (upper triangle scored, lower mirrored).
        Gradient tracking is left to the caller.
        """
        positive_logits = self.forward(embedding_pep, embedding_prot, contacts_pep, contacts_prot)
        rows, cols, negative_logits = self._negative_pair_logits(embedding_pep, embedding_prot, contacts_pep, contacts_prot)

        return self._assemble_logit_matrix(positive_logits, negative_logits, rows, cols)

### Train model from scratch with 10% of PPint dataset

In [32]:
# Build the model from the notebook-level hyper-parameters, then move it to the GPU.
model = MiniCLIP_w_transformer_crossattn(seq_embed_dimension=seq_embed_dimension, num_recycles=number_of_recycles)
model = model.to("cuda")

# Display the module tree as the cell output.
model

MiniCLIP_w_transformer_crossattn(
  (fusion): Fusion(
    (norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  )
  (seq_encoder): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=1280, out_features=1280, bias=True)
    )
    (linear1): Linear(in_features=1280, out_features=1280, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=1280, out_features=1280, bias=True)
    (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (norm_seq): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  (seq_cross_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=1280, out_features=1280, bias=True)
  )
  (projection_head): Sequential(
    (0): Linear(in_features=1280, 

In [33]:
# NOTE(review): leftover scratch expression — its value is only displayed and is
# not assigned to anything; the cell can probably be deleted.
np.arange(10)

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

### Training loop

In [34]:
def batch(iterable, n=1):
    """Yield consecutive slices of at most ``n`` items from a sized, sliceable sequence.

    The last slice is shorter when ``len(iterable)`` is not a multiple of ``n``;
    an empty input yields nothing.
    """
    total = len(iterable)
    for start in range(0, total, n):
        stop = start + n
        if stop > total:
            stop = total  # clamp the final chunk to the end of the sequence
        yield iterable[start:stop]

class TrainWrapper():

    def __init__(self, 
                 model, 
                 train_loader,
                 test_loader,
                 val_loader,
                 test_df,
                 test_dataset,
                 optimizer, 
                 epochs, 
                 runID, 
                 device, 
                 test_indexes_for_auROC = None,
                 auROC_batch_size=10, 
                 model_save_steps=False, 
                 model_save_path=False, 
                 v=False, 
                 wandb_tracker=False):
        
        self.model = model 
        self.training_loader = train_loader
        self.testing_loader = test_loader
        self.validation_loader = val_loader
        self.test_dataset = test_dataset
        self.test_df = test_df
        self.auROC_batch_size = auROC_batch_size
        
        self.EPOCHS = epochs
        self.optimizer = optimizer
        self.device = device
        
        self.wandb_tracker = wandb_tracker
        self.model_save_steps = model_save_steps
        self.verbose = v
        self.best_vloss = 1_000_000
        self.runID = runID
        self.trained_model_dir = model_save_path
        self.print_frequency_loss = 1
        self.test_indexes_for_auROC = test_indexes_for_auROC

    def train_one_epoch(self):

        self.model.train() 
        running_loss = 0

        for batch in tqdm(self.training_loader, total=len(self.training_loader), desc="Running through epoch"):
            
            if batch[0].size(0) == 1: 
                continue
            
            self.optimizer.zero_grad()
            loss = self.model.training_step(batch, self.device)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()

            del loss, batch
            torch.cuda.empty_cache()
            
        return running_loss / len(self.training_loader)

    def calc_auroc_aupr_on_indexes(self, model, dataset, dataframe, nondimer_indexes, batch_size = 10):

        self.model.eval()
        all_TP_scores, all_FP_scores = [], []
        accessions = [dataframe.loc[index].target_binder_id for index in nondimer_indexes]  # <-- use dataframe
        batches_local = batch(accessions, n=batch_size)
        
        with torch.no_grad():
            for index_batch in tqdm(batches_local, total=int(len(accessions)/batch_size), desc="Calculating AUC"):

                embedding_pep, embedding_prot, contacts_pep, contacts_prot, labels = dataset._get_by_name(index_batch)
                embedding_pep, embedding_prot = embedding_pep.to(self.device), embedding_prot.to(self.device)

                # Make sure this matches your model's signature:
                logit_matrix = self.model.calculate_logit_matrix(embedding_pep, embedding_prot, contacts_pep, contacts_prot)
                
                TP_scores = logit_matrix.diag().detach().cpu().tolist()
                all_TP_scores += TP_scores
                
                # Get FP scores from upper triangle (excluding diagonal)
                n = logit_matrix.size(0)
                rows, cols = torch.triu_indices(n, n, offset=1)
                FP_scores = logit_matrix[rows, cols].detach().cpu().tolist()
                all_FP_scores += FP_scores
            
        all_score_predictions = np.array(all_TP_scores + all_FP_scores)
        all_labels = np.array([1]*len(all_TP_scores) + [0]*len(all_FP_scores))
                
        fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_score_predictions)
        auroc = metrics.roc_auc_score(all_labels, all_score_predictions)
        aupr  = metrics.average_precision_score(all_labels, all_score_predictions)
        
        return auroc, aupr, all_TP_scores, all_FP_scores

    def validate(self):
        
        self.model.eval()
        
        running_loss_Meta = 0.0
        all_logits = []
        all_lbls = []
        used_batches_meta = 0

        # --- MetaDataset validation ---
        with torch.no_grad():
            for batch in tqdm(self.validation_loader, total=len(self.validation_loader)):
                if batch[0].size(0) == 1:
                    continue
                __, __, __, __, labels = batch
                logits, loss = self.model.validation_step_MetaDataset(batch, self.device)
                
                running_loss_Meta += loss.item()
                all_logits.append(logits.detach().view(-1).cpu())
                all_lbls.append(labels.detach().view(-1).cpu())
                used_batches_meta += 1
                
            if used_batches_meta > 0:
                val_loss_Meta = running_loss_Meta / used_batches_meta
                all_logits = torch.cat(all_logits).numpy()
                all_lbls   = torch.cat(all_lbls).numpy()
            
                fpr, tpr, thresholds = metrics.roc_curve(all_lbls, all_logits)
                meta_auroc = metrics.roc_auc_score(all_lbls, all_logits)
                meta_aupr  = metrics.average_precision_score(all_lbls, all_logits)

                y_pred = (all_logits >= 0).astype(int)
                y_true = all_lbls.astype(int)
                val_acc_Meta = (y_pred == y_true).mean()
            else:
                val_loss_Meta = float("nan")
                meta_auroc = float("nan")
                meta_aupr = float("nan")
                val_acc_Meta = float("nan")

        # --- PPint validation ---
        running_loss_ValPPint = 0.0
        running_accuracy_ValPPint = 0.0
        used_batches_ppint = 0

        with torch.no_grad():
            for batch in tqdm(self.testing_loader, total=len(self.testing_loader)):
                if batch[0].size(0) == 1:
                    continue
                loss, peptide_accuracy = self.model.validation_step_PPint(batch, self.device)
                running_loss_ValPPint += loss.item()
                running_accuracy_ValPPint += peptide_accuracy.item()
                used_batches_ppint += 1
                
            if used_batches_ppint > 0:
                val_loss_PPint = running_loss_ValPPint / used_batches_ppint
                val_accuracy_PPint = running_accuracy_ValPPint / used_batches_ppint
            else:
                val_loss_PPint = float("nan")
                val_accuracy_PPint = float("nan")

        # --- AUROC on specific indexes (optional) ---
        if self.test_indexes_for_auROC is not None:
            non_dimer_auc, non_dimer_aupr, ___, ___ = self.calc_auroc_aupr_on_indexes(
                model=self.model, 
                dataset=self.test_dataset,
                dataframe=self.test_df,
                nondimer_indexes=self.test_indexes_for_auROC,
                batch_size=self.auROC_batch_size
            )
            
            return (val_loss_PPint, val_accuracy_PPint,
                    non_dimer_auc, non_dimer_aupr,
                    val_loss_Meta, val_acc_Meta, meta_auroc, meta_aupr)

        else:
            return (val_loss_PPint, val_accuracy_PPint,
                    val_loss_Meta, val_acc_Meta, meta_auroc, meta_aupr)

    def train_model(self):
        """Run a baseline validation pass, then train for ``self.EPOCHS`` epochs.

        Every epoch calls ``self.train_one_epoch()`` followed by
        ``self.validate()``. Depending on the instance configuration the
        method additionally:

        * saves a checkpoint (epoch number, model and optimizer state dicts and
          both validation losses) under ``self.trained_model_dir`` whenever the
          epoch number is a multiple of ``self.model_save_steps``;
        * prints the metrics when ``self.verbose`` is set and the epoch number
          is a multiple of ``self.print_frequency_loss``;
        * logs the metrics to ``self.wandb_tracker`` after every validation
          pass and calls its ``finish()`` once the loop has ended.

        Returns:
            None.
        """
        torch.cuda.empty_cache()

        def run_validation():
            """Call ``self.validate()`` and label the values it returns.

            ``validate()`` yields eight values when
            ``self.test_indexes_for_auROC`` is set and six otherwise; in the
            latter case the two subset scores are reported as ``None``.
            """
            with_subset = self.test_indexes_for_auROC is not None
            outcome = self.validate()
            if with_subset:
                (ppint_loss, ppint_acc, subset_auroc, subset_aupr,
                 meta_loss, meta_acc, meta_auroc, meta_aupr) = outcome
            else:
                (ppint_loss, ppint_acc,
                 meta_loss, meta_acc, meta_auroc, meta_aupr) = outcome
                subset_auroc, subset_aupr = None, None
            return {
                "ppint_loss": ppint_loss,
                "ppint_acc": ppint_acc,
                "subset_auroc": subset_auroc,
                "subset_aupr": subset_aupr,
                "meta_loss": meta_loss,
                "meta_acc": meta_acc,
                "meta_auroc": meta_auroc,
                "meta_aupr": meta_aupr,
            }

        def report(header, meta_loss_label, scores):
            """Print one metrics summary; subset lines only when available.

            ``meta_loss_label`` is a parameter because the baseline and the
            per-epoch reports historically use slightly different labels.
            """
            print(header)
            print(f'{meta_loss_label} {round(scores["meta_loss"],4)}')
            print(f'Meta Accuracy: {round(scores["meta_acc"],4)}')
            print(f'Meta AUROC: {round(scores["meta_auroc"],4)}')
            print(f'Meta AUPR: {round(scores["meta_aupr"],4)}')
            print(f'PPint Test-Loss: {round(scores["ppint_loss"],4)}')
            print(f'PPint Accuracy: {round(scores["ppint_acc"],4)}')
            if scores["subset_auroc"] is not None:
                print(f'PPint non-dimer AUROC: {round(scores["subset_auroc"],4)}')
                print(f'PPint non-dimer AUPR: {round(scores["subset_aupr"],4)}')

        def wandb_payload(scores, leading=None):
            """Build the dict sent to the tracker, optionally prefixed by ``leading``."""
            payload = dict(leading or {})
            payload.update({
                "PPint Test-Loss": scores["ppint_loss"],
                "Meta Val-loss": scores["meta_loss"],
                "PPint Accuracy": scores["ppint_acc"],
                "Meta Accuracy": scores["meta_acc"],
                "Meta Val-AUROC": scores["meta_auroc"],
                "Meta Val-AUPR": scores["meta_aupr"],
            })
            if scores["subset_auroc"] is not None:
                payload.update({
                    "PPint non-dimer AUROC": scores["subset_auroc"],
                    "PPint non-dimer AUPR": scores["subset_aupr"],
                })
            return payload

        if self.verbose:
            print(f"Training model {str(self.runID)}")

        # --- initial validation before training (baseline for the curves)
        print("Initial validation before starting training")
        scores = run_validation()

        if self.verbose:
            report('Before training:', 'Meta Val-Loss', scores)

        if self.wandb_tracker:
            self.wandb_tracker.log(wandb_payload(scores))

        # --- training loop
        for epoch in tqdm(range(1, self.EPOCHS + 1), total=self.EPOCHS, desc="Epochs"):

            torch.cuda.empty_cache()

            train_loss = self.train_one_epoch()

            # validation after epoch
            scores = run_validation()

            torch.cuda.empty_cache()

            # checkpoint save (skipped entirely when model_save_steps is falsy)
            if self.model_save_steps and epoch % self.model_save_steps == 0:
                check_point_folder = os.path.join(self.trained_model_dir, f"{str(self.runID)}_checkpoint_{str(epoch)}")
                if self.verbose:
                    print("Saving model to:", check_point_folder)
                os.makedirs(check_point_folder, exist_ok=True)
                checkpoint_path = os.path.join(check_point_folder, f"{str(self.runID)}_checkpoint_epoch_{str(epoch)}.pth")
                torch.save({'epoch': epoch,
                            'model_state_dict': self.model.state_dict(),
                            'optimizer_state_dict': self.optimizer.state_dict(),
                            'val_loss_PPint': scores["ppint_loss"],
                            'val_loss_Meta': scores["meta_loss"]},
                           checkpoint_path)

            # console logging
            if self.verbose and epoch % self.print_frequency_loss == 0:
                report(f'EPOCH {epoch}:', 'Meta Val Loss', scores)

            # wandb logging (the training loss is only known from epoch 1 on)
            if self.wandb_tracker:
                self.wandb_tracker.log(
                    wandb_payload(scores, {"PPint Train-loss": train_loss})
                )

        if self.wandb_tracker:
            self.wandb_tracker.finish()

In [35]:
# --- optimisation hyper-parameters ---
learning_rate = 2e-5
EPOCHS = 12
# This value is what ends up in the wandb run config. The DataLoaders are built
# with a hard-coded batch size of 5, so it has to be 5 here as well; otherwise
# the logged configuration misreports how the model was actually trained.
batch_size = 5

optimizer = AdamW(model.parameters(), lr=learning_rate)

# Accelerate decides which device to run on. The earlier manual
# `torch.device('cuda' if ... else 'cpu')` assignment was overwritten by the
# line below before anything read it, so it has been dropped.
accelerator = Accelerator()
device = accelerator.device

In [36]:
def collate_varlen(batch):
    """Collate samples whose token fields may differ in length.

    Samples are indexed positionally. Items 0 and 1 are stacked along a new
    leading dimension, items 2 and 3 are returned as plain Python lists (one
    entry per sample), and item 4 is cast with ``.float()`` and gathered into
    a single tensor.

    Args:
        batch: sequence of indexable samples with at least five items each.

    Returns:
        tuple: ``(b_emb, t_emb, b_ctok, t_ctok, lbls)`` in that order.
    """
    def column(position):
        # Pull the same positional field out of every sample.
        return [sample[position] for sample in batch]

    emb_b = torch.stack(column(0), dim=0)   # fixed length -> stack
    emb_t = torch.stack(column(1), dim=0)
    tokens_b = column(2)                    # var-len -> keep as list
    tokens_t = column(3)
    labels = torch.tensor([raw.float() for raw in column(4)])
    return emb_b, emb_t, tokens_b, tokens_t, labels

# One loader per split. Every split is collated with `collate_varlen` and uses
# a hard-coded batch size of 5 (independent of the `batch_size` variable).
# Only the training loader shuffles and drops its last incomplete batch.
train_dataloader = DataLoader(
    training_Dataset,
    batch_size=5,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_varlen,
)
test_dataloader = DataLoader(
    testing_Dataset,
    batch_size=5,
    shuffle=False,
    collate_fn=collate_varlen,
)
val_dataloader = DataLoader(
    validation_Dataset,
    batch_size=5,
    shuffle=False,
    drop_last=False,
    collate_fn=collate_varlen,
)

# accelerator: rebind every object to the prepared version returned by Accelerate
(model,
 optimizer,
 train_dataloader,
 test_dataloader,
 val_dataloader) = accelerator.prepare(
    model, optimizer, train_dataloader, test_dataloader, val_dataloader
)

In [37]:
# wandb: start a run (and watch the unwrapped model) only when tracking is enabled
if use_wandb:
    run = wandb.init(
        project="CLIP_retrain_w_PPint0.1",
        name=f"Retrain_PPint0.1_ESM2_w_struct_CNN_{runID}",
        config={
            "learning_rate": learning_rate,
            "batch_size": batch_size,
            "epochs": EPOCHS,
            "architecture": "MiniCLIP_w_transformer_crossattn",
            "dataset": "PPint",
        },
    )
    wandb.watch(accelerator.unwrap_model(model), log="all", log_freq=100)
else:
    run = None

# train
# NOTE(review): `indices_non_dimers_val` is applied to `Df_test` /
# `testing_Dataset` here — confirm the indices really refer to the test split.
training_wrapper = TrainWrapper(
            model=model,
            train_loader=train_dataloader,
            test_loader=test_dataloader,
            val_loader=val_dataloader,
            test_df=Df_test,
            test_dataset=testing_Dataset,
            optimizer=optimizer,
            epochs=EPOCHS,
            runID=runID,
            device=device,
            test_indexes_for_auROC=indices_non_dimers_val,
            auROC_batch_size=10,
            model_save_steps=model_save_steps,
            model_save_path=trained_model_dir,
            v=True,
            # The wrapper calls .log()/.finish() on any truthy tracker. Passing
            # the wandb module when no run was initialised would make those
            # calls fail, so disable tracking explicitly in that case.
            wandb_tracker=wandb if use_wandb else None
)

training_wrapper.train_model() # start training

Training model 1b3d5c1d-4063-440a-bd96-ca9b118f3d47
Initial validation before starting training


100%|█████████████████████████████████████████████████████████████████████████████████| 707/707 [00:55<00:00, 12.65it/s]
100%|███████████████████████████████████████████████████████████████████████████████████| 99/99 [00:19<00:00,  5.10it/s]
Calculating AUC: 13it [00:07,  1.85it/s]                                                                                


Before training:
Meta Val-Loss 10.6886
Meta Accuracy: 0.1107
Meta AUROC: 0.4844
Meta AUPR: 0.1191
PPint Test-Loss: 6.4585
PPint Accuracy: 0.8545
PPint non-dimer AUROC: 0.6477
PPint non-dimer AUPR: 0.3743


Epochs:   0%|                                                                                    | 0/12 [00:00<?, ?it/s]
Running through epoch:   0%|                                                                    | 0/395 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/395 [00:00<04:49,  1.36it/s][A
Running through epoch:   1%|▎                                                           | 2/395 [00:01<03:45,  1.75it/s][A
Running through epoch:   1%|▍                                                           | 3/395 [00:01<03:56,  1.66it/s][A
Running through epoch:   1%|▌                                                           | 4/395 [00:02<03:58,  1.64it/s][A
Running through epoch:   1%|▊                                                           | 5/395 [00:02<03:47,  1.71it/s][A
Running through epoch:   2%|▉                                                           | 6/395 [00:03<03:39,  1.77it/s][A
Running thr

EPOCH 1:
Meta Val Loss 0.4224
Meta Accuracy: 0.8777
Meta AUROC: 0.4948
Meta AUPR: 0.1152
PPint Test-Loss: 0.2367
PPint Accuracy: 0.9192
PPint non-dimer AUROC: 0.7664
PPint non-dimer AUPR: 0.485



Running through epoch:   0%|                                                                    | 0/395 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/395 [00:00<03:15,  2.02it/s][A
Running through epoch:   1%|▎                                                           | 2/395 [00:00<03:16,  2.00it/s][A
Running through epoch:   1%|▍                                                           | 3/395 [00:01<03:01,  2.16it/s][A
Running through epoch:   1%|▌                                                           | 4/395 [00:02<03:21,  1.94it/s][A
Running through epoch:   1%|▊                                                           | 5/395 [00:02<03:46,  1.72it/s][A
Running through epoch:   2%|▉                                                           | 6/395 [00:03<03:40,  1.77it/s][A
Running through epoch:   2%|█                                                           | 7/395 [00:03<03:28,  1.86it/s][A
Running

EPOCH 2:
Meta Val Loss 0.5143
Meta Accuracy: 0.7998
Meta AUROC: 0.4856
Meta AUPR: 0.1051
PPint Test-Loss: 0.2363
PPint Accuracy: 0.9071
PPint non-dimer AUROC: 0.7391
PPint non-dimer AUPR: 0.4088



Running through epoch:   0%|                                                                    | 0/395 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/395 [00:00<02:26,  2.69it/s][A
Running through epoch:   1%|▎                                                           | 2/395 [00:00<03:20,  1.96it/s][A
Running through epoch:   1%|▍                                                           | 3/395 [00:01<03:40,  1.78it/s][A
Running through epoch:   1%|▌                                                           | 4/395 [00:02<03:41,  1.77it/s][A
Running through epoch:   1%|▊                                                           | 5/395 [00:02<03:30,  1.86it/s][A
Running through epoch:   2%|▉                                                           | 6/395 [00:03<03:16,  1.98it/s][A
Running through epoch:   2%|█                                                           | 7/395 [00:03<03:25,  1.89it/s][A
Running

Saving model to: /work3/s232958/data/trained/with_structure/1b3d5c1d-4063-440a-bd96-ca9b118f3d47/1b3d5c1d-4063-440a-bd96-ca9b118f3d47_checkpoint_3


Epochs:  25%|██████████████████▊                                                        | 3/12 [14:18<42:48, 285.36s/it]

EPOCH 3:
Meta Val Loss 0.4217
Meta Accuracy: 0.8754
Meta AUROC: 0.4851
Meta AUPR: 0.1051
PPint Test-Loss: 0.2317
PPint Accuracy: 0.9152
PPint non-dimer AUROC: 0.8081
PPint non-dimer AUPR: 0.5012



Running through epoch:   0%|                                                                    | 0/395 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/395 [00:00<02:05,  3.13it/s][A
Running through epoch:   1%|▎                                                           | 2/395 [00:00<02:29,  2.62it/s][A
Running through epoch:   1%|▍                                                           | 3/395 [00:01<03:05,  2.11it/s][A
Running through epoch:   1%|▌                                                           | 4/395 [00:01<03:24,  1.91it/s][A
Running through epoch:   1%|▊                                                           | 5/395 [00:02<03:31,  1.84it/s][A
Running through epoch:   2%|▉                                                           | 6/395 [00:02<03:20,  1.94it/s][A
Running through epoch:   2%|█                                                           | 7/395 [00:03<03:11,  2.02it/s][A
Running

EPOCH 4:
Meta Val Loss 0.4163
Meta Accuracy: 0.8633
Meta AUROC: 0.5035
Meta AUPR: 0.1177
PPint Test-Loss: 0.2005
PPint Accuracy: 0.9273
PPint non-dimer AUROC: 0.81
PPint non-dimer AUPR: 0.5159



Running through epoch:   0%|                                                                    | 0/395 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/395 [00:00<03:40,  1.79it/s][A
Running through epoch:   1%|▎                                                           | 2/395 [00:01<03:40,  1.78it/s][A
Running through epoch:   1%|▍                                                           | 3/395 [00:01<03:25,  1.91it/s][A
Running through epoch:   1%|▌                                                           | 4/395 [00:02<03:09,  2.06it/s][A
Running through epoch:   1%|▊                                                           | 5/395 [00:02<03:25,  1.90it/s][A
Running through epoch:   2%|▉                                                           | 6/395 [00:03<03:37,  1.79it/s][A
Running through epoch:   2%|█                                                           | 7/395 [00:03<03:37,  1.79it/s][A
Running

EPOCH 5:
Meta Val Loss 0.4146
Meta Accuracy: 0.8865
Meta AUROC: 0.5215
Meta AUPR: 0.1191
PPint Test-Loss: 0.1954
PPint Accuracy: 0.9172
PPint non-dimer AUROC: 0.8281
PPint non-dimer AUPR: 0.5151



Running through epoch:   0%|                                                                    | 0/395 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/395 [00:00<03:03,  2.15it/s][A
Running through epoch:   1%|▎                                                           | 2/395 [00:01<03:38,  1.80it/s][A
Running through epoch:   1%|▍                                                           | 3/395 [00:01<03:37,  1.80it/s][A
Running through epoch:   1%|▌                                                           | 4/395 [00:02<03:30,  1.85it/s][A
Running through epoch:   1%|▊                                                           | 5/395 [00:02<03:23,  1.91it/s][A
Running through epoch:   2%|▉                                                           | 6/395 [00:03<03:15,  1.99it/s][A
Running through epoch:   2%|█                                                           | 7/395 [00:03<03:25,  1.89it/s][A
Running

Saving model to: /work3/s232958/data/trained/with_structure/1b3d5c1d-4063-440a-bd96-ca9b118f3d47/1b3d5c1d-4063-440a-bd96-ca9b118f3d47_checkpoint_6


Epochs:  50%|█████████████████████████████████████▌                                     | 6/12 [28:30<28:27, 284.55s/it]

EPOCH 6:
Meta Val Loss 0.4198
Meta Accuracy: 0.872
Meta AUROC: 0.4948
Meta AUPR: 0.1082
PPint Test-Loss: 0.2094
PPint Accuracy: 0.9273
PPint non-dimer AUROC: 0.8053
PPint non-dimer AUPR: 0.4959



Running through epoch:   0%|                                                                    | 0/395 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/395 [00:00<02:25,  2.71it/s][A
Running through epoch:   1%|▎                                                           | 2/395 [00:00<03:15,  2.01it/s][A
Running through epoch:   1%|▍                                                           | 3/395 [00:01<03:41,  1.77it/s][A
Running through epoch:   1%|▌                                                           | 4/395 [00:02<03:45,  1.73it/s][A
Running through epoch:   1%|▊                                                           | 5/395 [00:02<03:32,  1.83it/s][A
Running through epoch:   2%|▉                                                           | 6/395 [00:03<03:16,  1.98it/s][A
Running through epoch:   2%|█                                                           | 7/395 [00:03<03:28,  1.86it/s][A
Running

EPOCH 7:
Meta Val Loss 0.3969
Meta Accuracy: 0.8783
Meta AUROC: 0.5471
Meta AUPR: 0.1314
PPint Test-Loss: 0.2017
PPint Accuracy: 0.9293
PPint non-dimer AUROC: 0.8289
PPint non-dimer AUPR: 0.548



Running through epoch:   0%|                                                                    | 0/395 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/395 [00:00<02:21,  2.79it/s][A
Running through epoch:   1%|▎                                                           | 2/395 [00:00<03:10,  2.07it/s][A
Running through epoch:   1%|▍                                                           | 3/395 [00:01<03:29,  1.87it/s][A
Running through epoch:   1%|▌                                                           | 4/395 [00:02<03:36,  1.80it/s][A
Running through epoch:   1%|▊                                                           | 5/395 [00:02<03:26,  1.89it/s][A
Running through epoch:   2%|▉                                                           | 6/395 [00:03<03:14,  2.00it/s][A
Running through epoch:   2%|█                                                           | 7/395 [00:03<03:22,  1.92it/s][A
Running

EPOCH 8:
Meta Val Loss 0.5383
Meta Accuracy: 0.7424
Meta AUROC: 0.5896
Meta AUPR: 0.1415
PPint Test-Loss: 0.2603
PPint Accuracy: 0.9313
PPint non-dimer AUROC: 0.8202
PPint non-dimer AUPR: 0.5463



Running through epoch:   0%|                                                                    | 0/395 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/395 [00:00<03:25,  1.92it/s][A
Running through epoch:   1%|▎                                                           | 2/395 [00:01<03:48,  1.72it/s][A
Running through epoch:   1%|▍                                                           | 3/395 [00:01<03:51,  1.69it/s][A
Running through epoch:   1%|▌                                                           | 4/395 [00:02<03:32,  1.84it/s][A
Running through epoch:   1%|▊                                                           | 5/395 [00:02<03:21,  1.94it/s][A
Running through epoch:   2%|▉                                                           | 6/395 [00:03<03:29,  1.86it/s][A
Running through epoch:   2%|█                                                           | 7/395 [00:03<03:41,  1.75it/s][A
Running

Saving model to: /work3/s232958/data/trained/with_structure/1b3d5c1d-4063-440a-bd96-ca9b118f3d47/1b3d5c1d-4063-440a-bd96-ca9b118f3d47_checkpoint_9


Epochs:  75%|████████████████████████████████████████████████████████▎                  | 9/12 [42:38<14:09, 283.12s/it]

EPOCH 9:
Meta Val Loss 0.4515
Meta Accuracy: 0.8834
Meta AUROC: 0.5585
Meta AUPR: 0.1288
PPint Test-Loss: 0.2751
PPint Accuracy: 0.9212
PPint non-dimer AUROC: 0.8007
PPint non-dimer AUPR: 0.5002



Running through epoch:   0%|                                                                    | 0/395 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/395 [00:00<03:59,  1.65it/s][A
Running through epoch:   1%|▎                                                           | 2/395 [00:01<03:43,  1.76it/s][A
Running through epoch:   1%|▍                                                           | 3/395 [00:01<03:30,  1.86it/s][A
Running through epoch:   1%|▌                                                           | 4/395 [00:02<03:19,  1.96it/s][A
Running through epoch:   1%|▊                                                           | 5/395 [00:02<03:12,  2.03it/s][A
Running through epoch:   2%|▉                                                           | 6/395 [00:03<03:23,  1.91it/s][A
Running through epoch:   2%|█                                                           | 7/395 [00:03<03:26,  1.88it/s][A
Running

EPOCH 10:
Meta Val Loss 0.5386
Meta Accuracy: 0.8893
Meta AUROC: 0.5095
Meta AUPR: 0.1127
PPint Test-Loss: 0.2908
PPint Accuracy: 0.9293
PPint non-dimer AUROC: 0.8287
PPint non-dimer AUPR: 0.5054



Running through epoch:   0%|                                                                    | 0/395 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/395 [00:00<02:01,  3.25it/s][A
Running through epoch:   1%|▎                                                           | 2/395 [00:00<02:09,  3.03it/s][A
Running through epoch:   1%|▍                                                           | 3/395 [00:01<02:12,  2.96it/s][A
Running through epoch:   1%|▌                                                           | 4/395 [00:01<02:15,  2.89it/s][A
Running through epoch:   1%|▊                                                           | 5/395 [00:01<02:16,  2.85it/s][A
Running through epoch:   2%|▉                                                           | 6/395 [00:02<02:17,  2.83it/s][A
Running through epoch:   2%|█                                                           | 7/395 [00:02<02:17,  2.83it/s][A
Running

EPOCH 11:
Meta Val Loss 0.4595
Meta Accuracy: 0.8851
Meta AUROC: 0.5767
Meta AUPR: 0.1288
PPint Test-Loss: 0.2644
PPint Accuracy: 0.9313
PPint non-dimer AUROC: 0.8129
PPint non-dimer AUPR: 0.5062



Running through epoch:   0%|                                                                    | 0/395 [00:00<?, ?it/s][A
Running through epoch:   0%|▏                                                           | 1/395 [00:00<02:02,  3.22it/s][A
Running through epoch:   1%|▎                                                           | 2/395 [00:00<02:12,  2.97it/s][A
Running through epoch:   1%|▍                                                           | 3/395 [00:01<02:16,  2.87it/s][A
Running through epoch:   1%|▌                                                           | 4/395 [00:01<02:16,  2.87it/s][A
Running through epoch:   1%|▊                                                           | 5/395 [00:01<02:16,  2.86it/s][A
Running through epoch:   2%|▉                                                           | 6/395 [00:02<02:16,  2.86it/s][A
Running through epoch:   2%|█                                                           | 7/395 [00:02<02:14,  2.88it/s][A
Running

Saving model to: /work3/s232958/data/trained/with_structure/1b3d5c1d-4063-440a-bd96-ca9b118f3d47/1b3d5c1d-4063-440a-bd96-ca9b118f3d47_checkpoint_12


Epochs: 100%|██████████████████████████████████████████████████████████████████████████| 12/12 [52:48<00:00, 264.03s/it]

EPOCH 12:
Meta Val Loss 0.4759
Meta Accuracy: 0.8797
Meta AUROC: 0.4963
Meta AUPR: 0.1083
PPint Test-Loss: 0.2422
PPint Accuracy: 0.9273
PPint non-dimer AUROC: 0.8212
PPint non-dimer AUPR: 0.5106





0,1
Meta Accuracy,▁█▇█████▇████
Meta Val-AUPR,▄▃▁▁▃▄▂▆█▆▂▆▂
Meta Val-AUROC,▁▂▁▁▂▃▂▅█▆▃▇▂
Meta Val-loss,█▁▁▁▁▁▁▁▁▁▁▁▁
PPint Accuracy,▁▇▆▇█▇███▇███
PPint Test-Loss,█▁▁▁▁▁▁▁▁▁▁▁▁
PPint Train-loss,█▄▃▃▂▂▂▂▂▁▁▁
PPint non-dimer AUPR,▁▅▂▆▇▇▆██▆▆▆▆
PPint non-dimer AUROC,▁▆▅▇▇█▇██▇█▇█

0,1
Meta Accuracy,0.87967
Meta Val-AUPR,0.10828
Meta Val-AUROC,0.49632
Meta Val-loss,0.47594
PPint Accuracy,0.92727
PPint Test-Loss,0.24224
PPint Train-loss,0.0729
PPint non-dimer AUPR,0.51056
PPint non-dimer AUROC,0.82117
