In [1]:
import uuid, sys, os
import pandas as pd
import numpy as np
from tqdm import tqdm
import ast
import math
import random

from sklearn import metrics
from scipy import stats
from collections import Counter

from transformers import EsmModel, AutoTokenizer # huggingface
from peft import LoraConfig, get_peft_model
import esm

os.environ["CUDA_VISIBLE_DEVICES"] = "2"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
torch.cuda.set_device(0)  # 0 == "first visible" -> actually GPU 2 on the node
print(torch.cuda.get_device_name(0))

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, WeightedRandomSampler
import pytorch_lightning as pl
from torch.optim import AdamW

torch.manual_seed(0)

from accelerate import Accelerator
torch.cuda.empty_cache()
import training_utils.partitioning_utils as pat_utils
from tqdm import trange

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


Tesla V100-SXM2-32GB


  warn(


In [2]:
import requests
requests.get("https://api.wandb.ai/status").status_code

import wandb
wandb.login(key="f8a6d759fe657b095d56bddbdb4d586dfaebd468", relogin=True)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /zhome/c9/0/203261/.netrc
[34m[1mwandb[0m: Currently logged in as: [33ms232958[0m ([33ms232958-danmarks-tekniske-universitet-dtu[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
### Setting a seed to have the same initiation of weights

def set_seed(seed: int = 42):
    # Python & NumPy
    random.seed(seed)
    np.random.seed(seed)
    
    # PyTorch
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # for multi-GPU

    # CuDNN settings (for convolution etc.)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # (Optional) for some Python hashing randomness
    os.environ["PYTHONHASHSEED"] = str(seed)

SEED = 0
set_seed(SEED)

In [4]:
os.chdir("/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts")
# print(os.getcwd())

print("PyTorch:", torch.__version__)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print("Current location:", os.getcwd())

PyTorch: 2.9.1+cu128
Using device: cuda
Current location: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts


In [5]:
# Model parameters
memory_verbose = False
use_wandb = True # Used to track loss in real-time without printing
model_save_steps = 3
train_frac = 1.0
test_frac = 1.0

embedding_dimension = 1280 #| 960 | 1152
number_of_recycles = 2
padding_value = -5000

In [6]:
# ## Training variables
runID = uuid.uuid4()

## Output path
trained_model_dir = f"/work3/s232958/data/trained/original_architecture/{runID}"

def print_mem_consumption():
    # 1. Total memory available on the GPU (device 0)
    t = torch.cuda.get_device_properties(0).total_memory
    # 2. How much memory PyTorch has *reserved* from CUDA
    r = torch.cuda.memory_reserved(0)
    # 3. How much of that reserved memory is actually *used* by tensors
    a = torch.cuda.memory_allocated(0)
    # 4. Reserved but not currently allocated (so “free inside PyTorch’s pool”)
    f = r - a

    print("Total memory: ", t/1e9)      # total VRAM in GB
    print("Reserved memory: ", r/1e9)   # PyTorch’s reserved pool in GB
    print("Allocated memory: ", a//1e9) # actually in use (integer division)
    print("Free memory: ", f/1e9)       # slack in the reserved pool in GB
print_mem_consumption()

Total memory:  34.072559616
Reserved memory:  0.0
Allocated memory:  0.0
Free memory:  0.0


### Loading seq_encoder and proj_head

In [7]:
def to_numpy(x):
    """Safely convert torch tensor to numpy."""
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)

class ESM2EncoderLoRA(nn.Module):
    def __init__(self, padding_value=-5000.0):
        super().__init__()

        self.padding_value = padding_value

        self.model = EsmModel.from_pretrained(
            "facebook/esm2_t33_650M_UR50D",
            output_hidden_states=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

        # Freeze original weights
        for p in self.model.parameters():
            p.requires_grad = False

        # LoRA on top layers
        lora_cfg = LoraConfig(
            task_type="FEATURE_EXTRACTION",
            inference_mode=False,
            r=4,
            lora_alpha=1,
            lora_dropout=0.1,
            bias="none",
            # target_modules=["query", "key", "value", "dense"],
            target_modules=["query", "key", "value", "dense"],
            layers_to_transform=list(range(25, 33)),
        )

        self.model = get_peft_model(self.model, lora_cfg)

    @torch.no_grad()
    def get_attentions(self, sequences):
        inputs = self.tokenizer(
            sequences, return_tensors="pt", padding=True
        ).to(self.model.device)

        out = self.model(**inputs, output_attentions=True)
        return out.attentions   # list[num_layers] → [B, num_heads, L, L]

    def forward(self, sequences):
        inputs = self.tokenizer(
            sequences, return_tensors="pt", padding=True
        ).to(self.model.device)

        out = self.model(**inputs)
        reps = out.hidden_states[-1]                  # [B, Ltok, 1280]
        # reps = reps[:, 1:-1, :]                       # remove CLS/EOS

        # seq_lengths = [len(s) for s in sequences]
        # Lmax = max(seq_lengths)

        # B, D = reps.size(0), reps.size(-1)
        # padded = torch.full((B, Lmax, D), self.padding_value, device=reps.device)

        # for i, (r, real_len) in enumerate(zip(reps, seq_lengths)):
        #     padded[i, :real_len] = r[:real_len]

        return reps

### ESM-C boosting

In [None]:
seq_encoder_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_combinedLoss02/fd002c98-cf05-4f5a-8684-af9cc03cf6dc/seq_encoder_cos-sim0.2.pt"

seq_encoder_state_dict = torch.load(seq_encoder_checkpoint_path, map_location=device)
seq_encoder = ESM2EncoderLoRA()
seq_encoder.load_state_dict(seq_encoder_state_dict)
seq_encoder.to(device)
seq_encoder.eval()

# seq_down_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_combinedLoss02/2d2b1f23-4d4e-49b0-a768-7625349e0800/seq_down_cos-sim0.1.pt"
# seq_down_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_combinedLoss02/2d2b1f23-4d4e-49b0-a768-7625349e0800/seq_down_cos-sim0.2.pt"
# seq_down_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_combinedLoss02/2d2b1f23-4d4e-49b0-a768-7625349e0800/seq_down_cos-sim0.3.pt"
# seq_down_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_combinedLoss02/2d2b1f23-4d4e-49b0-a768-7625349e0800/seq_down_cos-sim0.15.pt"
seq_down_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_combinedLoss02/fd002c98-cf05-4f5a-8684-af9cc03cf6dc/seq_down_cos-sim0.2.pt"

seq_down_state_dict = torch.load(seq_down_checkpoint_path, map_location=device)
seq_down = nn.Linear(1152, 512)
seq_down.load_state_dict(seq_down_state_dict)
seq_down.to(device)
seq_down.eval()

### Loading trained for 0.1/0.2/0.3/0.4/0.5 cos-sim (combined Loss 2.0)

In [8]:
# seq_encoder_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_combinedLoss02/2d2b1f23-4d4e-49b0-a768-7625349e0800/seq_encoder_cos-sim0.1.pt"
seq_encoder_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_combinedLoss02/2d2b1f23-4d4e-49b0-a768-7625349e0800/seq_encoder_cos-sim0.2.pt"
# seq_encoder_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_combinedLoss02/2d2b1f23-4d4e-49b0-a768-7625349e0800/seq_encoder_cos-sim0.3.pt"
# seq_encoder_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_combinedLoss02/2d2b1f23-4d4e-49b0-a768-7625349e0800/seq_encoder_cos-sim0.15.pt"
# seq_encoder_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_combinedLoss02/2d2b1f23-4d4e-49b0-a768-7625349e0800/seq_encoder_cos-sim0.35.pt"

seq_encoder_state_dict = torch.load(seq_encoder_checkpoint_path, map_location=device)
seq_encoder = ESM2EncoderLoRA()
seq_encoder.load_state_dict(seq_encoder_state_dict)
seq_encoder.to(device)
seq_encoder.eval()

# seq_down_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_combinedLoss02/2d2b1f23-4d4e-49b0-a768-7625349e0800/seq_down_cos-sim0.1.pt"
seq_down_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_combinedLoss02/2d2b1f23-4d4e-49b0-a768-7625349e0800/seq_down_cos-sim0.2.pt"
# seq_down_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_combinedLoss02/2d2b1f23-4d4e-49b0-a768-7625349e0800/seq_down_cos-sim0.3.pt"
# seq_down_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_combinedLoss02/2d2b1f23-4d4e-49b0-a768-7625349e0800/seq_down_cos-sim0.15.pt"
# seq_down_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_combinedLoss02/2d2b1f23-4d4e-49b0-a768-7625349e0800/seq_down_cos-sim0.35.pt"

seq_down_state_dict = torch.load(seq_down_checkpoint_path, map_location=device)
seq_down = nn.Linear(1280, 512)
seq_down.load_state_dict(seq_down_state_dict)
seq_down.to(device)
seq_down.eval()

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.weight', 'esm.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Linear(in_features=1280, out_features=512, bias=True)

### Loading trained for 0.3/0.4/0.5 cos-sim (Token-level Loss)

In [None]:
# # seq_encoder_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_NewLoss/5e338957-e280-40f1-a29f-b0c6b1551994/seq_encoder_cos-sim0.15.pt"
# # seq_encoder_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_NewLoss/5e338957-e280-40f1-a29f-b0c6b1551994/seq_encoder_cos-sim0.3.pt"
# # seq_encoder_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_NewLoss/5e338957-e280-40f1-a29f-b0c6b1551994/seq_encoder_cos-sim0.4.pt"
# # seq_encoder_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_NewLoss/5e338957-e280-40f1-a29f-b0c6b1551994/seq_encoder_cos-sim0.5.pt"
# seq_encoder_state_dict = torch.load(seq_encoder_checkpoint_path, map_location=device)
# seq_encoder = ESM2EncoderLoRA()
# seq_encoder.load_state_dict(seq_encoder_state_dict)
# seq_encoder.to(device)
# seq_encoder.eval()

# # seq_down_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_NewLoss/5e338957-e280-40f1-a29f-b0c6b1551994/seq_down_cos-sim0.15.pt"
# # seq_down_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_OldLoss/39c990f3-0db5-430a-8095-075f38a5e808/seq_down_cos-sim0.3.pt"
# # seq_down_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_NewLoss/5e338957-e280-40f1-a29f-b0c6b1551994/seq_down_cos-sim0.4.pt"
# seq_down_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_OldLoss/39c990f3-0db5-430a-8095-075f38a5e808/seq_down_cos-sim0.5.pt"
# seq_down_state_dict = torch.load(seq_down_checkpoint_path, map_location=device)
# seq_down = nn.Linear(1280, 512)
# seq_down.load_state_dict(seq_down_state_dict)
# seq_down.to(device)
# seq_down.eval()

### Loading trained for 0.15/0.2/0.3 cos-sim (Combined Loss)

In [None]:
device = "cuda"
### /work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_OldLoss/4334cb1c-12f0-4fa2-982c-da9f78779bee
### seq_encoder_1.pt cos-sim ~ 0.05
### [VAL]  avg cos: 0.0563, std: 0.0121

# seq_encoder_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_OldLoss/4334cb1c-12f0-4fa2-982c-da9f78779bee/seq_encoder_1.pt"
# seq_encoder_state_dict = torch.load(seq_encoder_checkpoint_path, map_location=device)
# seq_encoder = ESM2EncoderLoRA()
# seq_encoder.load_state_dict(seq_encoder_state_dict)
# seq_encoder.to(device)
# seq_encoder.eval()

# seq_down_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_OldLoss/4334cb1c-12f0-4fa2-982c-da9f78779bee/seq_down_1.pt"
# seq_down_state_dict = torch.load(seq_down_checkpoint_path, map_location=device)
# seq_down = nn.Linear(1280, 512)
# seq_down.load_state_dict(seq_down_state_dict)
# seq_down.to(device)
# seq_down.eval()

In [None]:
### /work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_OldLoss/4334cb1c-12f0-4fa2-982c-da9f78779bee
### seq_encoder_3.pt cos-sim ~ 0.10
### [VAL]  avg cos: 0.1166, std: 0.0175

# seq_encoder_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_OldLoss/4334cb1c-12f0-4fa2-982c-da9f78779bee/seq_encoder_3.pt"
# seq_encoder_state_dict = torch.load(seq_encoder_checkpoint_path, map_location=device)
# seq_encoder = ESM2EncoderLoRA()
# seq_encoder.load_state_dict(seq_encoder_state_dict)
# seq_encoder.to(device)
# seq_encoder.eval()

# seq_down_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_OldLoss/4334cb1c-12f0-4fa2-982c-da9f78779bee/seq_down_3.pt"
# seq_down_state_dict = torch.load(seq_down_checkpoint_path, map_location=device)
# seq_down = nn.Linear(1280, 512)
# seq_down.load_state_dict(seq_down_state_dict)
# seq_down.to(device)
# seq_down.eval()

In [None]:
# ### /work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_OldLoss/4334cb1c-12f0-4fa2-982c-da9f78779bee
# ### seq_encoder_5.pt cos-sim ~ 0.15
# ### [VAL]  avg cos: 0.1361, std: 0.0190 from 20_02

# seq_encoder_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_OldLoss/4334cb1c-12f0-4fa2-982c-da9f78779bee/seq_encoder_5.pt"
# seq_encoder_state_dict = torch.load(seq_encoder_checkpoint_path, map_location=device)
# seq_encoder = ESM2EncoderLoRA()
# seq_encoder.load_state_dict(seq_encoder_state_dict)
# seq_encoder.to(device)
# seq_encoder.eval()

# seq_down_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_OldLoss/4334cb1c-12f0-4fa2-982c-da9f78779bee/seq_down_5.pt"
# seq_down_state_dict = torch.load(seq_down_checkpoint_path, map_location=device)
# seq_down = nn.Linear(1280, 512)
# seq_down.load_state_dict(seq_down_state_dict)
# seq_down.to(device)
# seq_down.eval()

In [None]:
### /work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_OldLoss/4334cb1c-12f0-4fa2-982c-da9f78779bee
### seq_encoder_25.pt cos-sim ~ 0.15
### [VAL]  avg cos: 0.1529, std: 0.0200

seq_encoder_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_OldLoss/4334cb1c-12f0-4fa2-982c-da9f78779bee/seq_encoder_25.pt"
seq_encoder_state_dict = torch.load(seq_encoder_checkpoint_path, map_location=device)
seq_encoder = ESM2EncoderLoRA()
seq_encoder.load_state_dict(seq_encoder_state_dict)
seq_encoder.to(device)
seq_encoder.eval()

seq_down_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_OldLoss/4334cb1c-12f0-4fa2-982c-da9f78779bee/seq_down_25.pt"
seq_down_state_dict = torch.load(seq_down_checkpoint_path, map_location=device)
seq_down = nn.Linear(1280, 512)
seq_down.load_state_dict(seq_down_state_dict)
seq_down.to(device)
seq_down.eval()

In [None]:
# ### /work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_OldLoss/4334cb1c-12f0-4fa2-982c-da9f78779bee
# ### seq_encoder_50.pt cos-sim ~ 0.17

# seq_encoder_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_OldLoss/4334cb1c-12f0-4fa2-982c-da9f78779bee/seq_encoder_50.pt"
# seq_encoder_state_dict = torch.load(seq_encoder_checkpoint_path, map_location=device)
# seq_encoder = ESM2EncoderLoRA()
# seq_encoder.load_state_dict(seq_encoder_state_dict)
# seq_encoder.to(device)
# seq_encoder.eval()

# seq_down_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_OldLoss/4334cb1c-12f0-4fa2-982c-da9f78779bee/seq_down_50.pt"
# seq_down_state_dict = torch.load(seq_down_checkpoint_path, map_location=device)
# seq_down = nn.Linear(1280, 512)
# seq_down.load_state_dict(seq_down_state_dict)
# seq_down.to(device)
# seq_down.eval()

In [None]:
# seq_encoder_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_OldLoss/39c990f3-0db5-430a-8095-075f38a5e808/seq_encoder_cos-sim0.2.pt"
# seq_encoder_state_dict = torch.load(seq_encoder_checkpoint_path, map_location=device)
# seq_encoder = ESM2EncoderLoRA()
# seq_encoder.load_state_dict(seq_encoder_state_dict)
# seq_encoder.to(device)
# seq_encoder.eval()

# seq_down_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_OldLoss/39c990f3-0db5-430a-8095-075f38a5e808/seq_down_cos-sim0.2.pt"
# seq_down_state_dict = torch.load(seq_down_checkpoint_path, map_location=device)
# seq_down = nn.Linear(1280, 512)
# seq_down.load_state_dict(seq_down_state_dict)
# seq_down.to(device)
# seq_down.eval()

In [None]:
# seq_encoder_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_OldLoss/39c990f3-0db5-430a-8095-075f38a5e808/seq_encoder_cos-sim0.3.pt"
# seq_encoder_state_dict = torch.load(seq_encoder_checkpoint_path, map_location=device)
# seq_encoder = ESM2EncoderLoRA()
# seq_encoder.load_state_dict(seq_encoder_state_dict)
# seq_encoder.to(device)
# seq_encoder.eval()

# seq_down_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_OldLoss/39c990f3-0db5-430a-8095-075f38a5e808/seq_down_cos-sim0.3.pt"
# seq_down_state_dict = torch.load(seq_down_checkpoint_path, map_location=device)
# seq_down = nn.Linear(1280, 512)
# seq_down.load_state_dict(seq_down_state_dict)
# seq_down.to(device)
# seq_down.eval()

In [None]:
seq_encoder_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_OldLoss/39c990f3-0db5-430a-8095-075f38a5e808/seq_encoder_cos-sim0.35.pt"
seq_encoder_state_dict = torch.load(seq_encoder_checkpoint_path, map_location=device)
seq_encoder = ESM2EncoderLoRA()
seq_encoder.load_state_dict(seq_encoder_state_dict)
seq_encoder.to(device)
seq_encoder.eval()

seq_down_checkpoint_path = "/work3/s232958/data/trained/boostingESM2wESMIF/train_on_PPint_OldLoss/39c990f3-0db5-430a-8095-075f38a5e808/seq_down_cos-sim0.35.pt"
seq_down_state_dict = torch.load(seq_down_checkpoint_path, map_location=device)
seq_down = nn.Linear(1280, 512)
seq_down.load_state_dict(seq_down_state_dict)
seq_down.to(device)
seq_down.eval()

### Loading data

In [9]:
Df_train_small = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_train.csv",index_col=0).reset_index(drop=True)
Df_train_small = Df_train_small[~Df_train_small.target_binder_id.str.startswith("6BJP")]
Df_train_small["interface_id"] = [row.ID1.split("_")[0]+"_"+row.ID1.split("_")[1] for __, row in Df_train_small.iterrows()]
Df_test_small = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_test.csv",index_col=0).reset_index(drop=True)
Df_test_small["interface_id"] = [row.ID1.split("_")[0]+"_"+row.ID1.split("_")[1] for __, row in Df_test_small.iterrows()] 


Df_train = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_train_w_pbd_lens.csv",index_col=0).reset_index(drop=True)
Df_train = Df_train.merge(Df_train_small[["dimer", "interface_id"]], on = "interface_id", how="inner")
Df_test = pd.read_csv("/work3/s232958/data/PPint_DB/PPint_test_w_pbd_lens.csv",index_col=0).reset_index(drop=True)
Df_test = Df_test.merge(Df_test_small[["dimer", "interface_id"]], on = "interface_id", how="inner")
Df_train = Df_train[~Df_train.PDB.str.startswith("6BJP")]

Df_train

Unnamed: 0,interface_id,PDB,ID1,ID2,seq_target,seq_target_len,seq_pdb_target,pdb_target_len,target_chain,seq_binder,seq_binder_len,seq_pdb_binder,pdb_binder_len,binder_chain,pdb_path,dimer
0,6IDB_0,6IDB,6IDB_0_A,6IDB_0_B,DKICLGHHAVSNGTKVNTLTERGVEVVNATETVERTNIPRICSKGK...,317,DKICLGHHAVSNGTKVNTLTERGVEVVNATETVERTNIPRICSKGK...,317,A,GLFGAIAGFIENGWEGLIDGWYGFRHQNAQGEGTAADYKSTQSAID...,172,GLFGAIAGFIENGWEGLIDGWYGFRHQNAQGEGTAADYKSTQSAID...,172,B,6idb.pdb.gz,False
1,2WZP_3,2WZP,2WZP_3_D,2WZP_3_G,VQLQESGGGLVQAGGSLRLSCTASRRTGSNWCMGWFRQLAGKEPEL...,122,VQLQESGGGLVQAGGSLRLSCTASRRTGSNWCMGWFRQLAGKEPEL...,122,D,TIKNFTFFSPNSTEFPVGSNNDGKLYMMLTGMDYRTIRRKDWSSPL...,266,TIKNFTFFSPNSTEFPVGSNNDGKLYMMLTGMDYRTIRRKDWSSPL...,266,G,2wzp.pdb.gz,False
2,1ZKP_0,1ZKP,1ZKP_0_A,1ZKP_0_C,LYFQSNAKTVVGFWGGFPEAGEATSGYLFEHDGFRLLVDCGSGVLA...,246,LYFQSNAMKMTVVGFWGGFPEAGEATSGYLFEHDGFRLLVDCGSGV...,251,A,AKTVVGFWGGFPEAGEATSGYLFEHDGFRLLVDCGSGVLAQLQKYI...,240,AMKMTVVGFWGGFPEAGEATSGYLFEHDGFRLLVDCGSGVLAQLQK...,245,C,1zkp.pdb.gz,True
3,6GRH_3,6GRH,6GRH_3_C,6GRH_3_D,SKHELSLVEVTHYTDPEVLAIVKDFHVRGNFASLPEFAERTFVSAV...,266,SKHELSLVEVTHYTDPEVLAIVKDFHVRGNFASLPEFAERTFVSAV...,266,C,MINVYSNLMSAWPATMAMSPKLNRNMPTFSQIWDYERITPASAAGE...,396,MINVYSNLMSAWPATMAMSPKLNRNMPTFSQIWDYERITPASAAGE...,396,D,6grh.pdb.gz,False
4,8R57_1,8R57,8R57_1_M,8R57_1_f,DLMTALQLVMKKSSAHDGLVKGLREAAKAIEKHAAQICVLAEDCDQ...,118,DLMTALQLVMKKSSAHDGLVKGLREAAKAIEKHAAQICVLAEDCDQ...,118,M,PKKQKHKHKKVKLAVLQFYKVDDATGKVTRLRKECPNADCGAGTFM...,64,PKKQKHKHKKVKLAVLQFYKVDDATGKVTRLRKECPNADCGAGTFM...,64,f,8r57.pdb.gz,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1972,4YO8_0,4YO8,4YO8_0_A,4YO8_0_B,HENLYFQGVQKIGILGAMREEITPILELFGVDFEEIPLGGNVFHKG...,238,HENLYFQGVQKIGILGAMREEITPILELFGVDFEEIPLGGNVFHKG...,238,A,HHHHHENLYFQGVQKIGILGAMREEITPILELFGVDFEEIPLGGNV...,242,HHHHHENLYFQGVQKIGILGAMREEITPILELFGVDFEEIPLGGNV...,242,B,4yo8.pdb.gz,True
1973,3CKI_0,3CKI,3CKI_0_A,3CKI_0_B,DPMKNTCKLLVVADHRFYRYMGRGEESTTTNYLIELIDRVDDIYRN...,256,DPMKNTCKLLVVADHRFYRYMGRGEESTTTNYLIELIDRVDDIYRN...,256,A,CTCSPSHPQDAFCNSDIVIRAKVVGKKLVKEGPFGTLVYTIKQMKM...,121,CTCSPSHPQDAFCNSDIVIRAKVVGKKLVKEGPFGTLVYTIKQMKM...,121,B,3cki.pdb.gz,False
1974,7MHY_1,7MHY,7MHY_1_M,7MHY_1_N,QVQLRQSGAELAKPGASVKMSCKASGYTFTNYWLHWIKQRPGQGLE...,118,QVQLRQSGAELAKPGASVKMSCKASGYTFTNYWLHWIKQRPGQGLE...,118,M,DVLMTQTPLSLPVSLGDQVSISCRSSQSIVHNTYLEWYLQKPGQSP...,109,DVLMTQTPLSLPVSLGDQVSISCRSSQSIVHNTYLEWYLQKPGQSP...,109,N,7mhy.pdb.gz,False
1975,7MHY_2,7MHY,7MHY_2_O,7MHY_2_P,IQLVQSGPELVKISCKASGYTFTNYGMNWVRQAPGKGLKWMGWINT...,100,IQLVQSGPELVKISCKASGYTFTNYGMNWVRQAPGKGLKWMGWINT...,100,O,VLMTQTPLSLPVSISCRSSQSIVHSNGNTYLEWYLQKPGQSPKLLI...,94,VLMTQTPLSLPVSISCRSSQSIVHSNGNTYLEWYLQKPGQSPKLLI...,94,P,7mhy.pdb.gz,False


In [10]:
Df_test

Unnamed: 0,interface_id,PDB,ID1,ID2,seq_target,seq_target_len,seq_pdb_target,pdb_target_len,target_chain,seq_binder,seq_binder_len,seq_pdb_binder,pdb_binder_len,binder_chain,pdb_path,dimer
0,1NNW_0,1NNW,1NNW_0_A,1NNW_0_B,VYVAVLANIAGNLPALTAALSRIEEMREEGYEIEKYYILGNIVGLF...,251,VYVAVLANIAGNLPALTAALSRIEEMREEGYEIEKYYILGNIVGLF...,251,A,VYVAVLANIAGNLPALTAALSRIEEMREEGYEIEKYYILGNIVGLF...,251,MVYVAVLANIAGNLPALTAALSRIEEMREEGYEIEKYYILGNIVGL...,252,B,1nnw.pdb.gz,True
1,3UCN_0,3UCN,3UCN_0_A,3UCN_0_B,TADLSPLLEANRKWADECAAKDSTYFSKVAGSQAPEYLYIGCADSR...,222,TADLSPLLEANRKWADECAAKDSTYFSKVAGSQAPEYLYIGCADSR...,222,A,TADLSPLLEANRKWADECAAKDSTYFSKVAGSQAPEYLYIGCADSR...,222,TADLSPLLEANRKWADECAAKDSTYFSKVAGSQAPEYLYIGCADSR...,222,B,3ucn.pdb.gz,True
2,1POV_1,1POV,1POV_1_1,1POV_1_3,QHRSRSESSIESFFARGACVTIMTVDNPASTTNKDKLFAVWKITYK...,235,QHRSRSESSIESFFARGACVTIMTVDNPASTTNKDKLFAVWKITYK...,235,1,GLPVMNTPGSNQYLTADNFQSPCALPEFDVTPPIDIPGEVKNMMEL...,238,GLPVMNTPGSNQYLTADNFQSPCALPEFDVTPPIDIPGEVKNMMEL...,238,3,1pov.pdb.gz,False
3,3R6Y_2,3R6Y,3R6Y_2_C,3R6Y_2_D,VRIEKDFLGEKEIPKDAYYGVQTIRATENFPITGYRIHPELIKSLG...,383,VRIEKDFLGEKEIPKDAYYGVQTIRATENFPITGYRIHPELIKSLG...,383,C,VRIEKDFLGEKEIPKDAYYGVQTIRATENFPITGYRIHPELIKSLG...,390,VRIEKDFLGEKEIPKDAYYGVQTIRATENFPITGYRIHPELIKSLG...,390,D,3r6y.pdb.gz,True
4,5YHI_0,5YHI,5YHI_0_A,5YHI_0_B,PMRYPVDVYTGKIQVDGELMLTELGLEGDGPDRALCHYPREHYLYW...,202,PMRYPVDVYTGKIQVDGELMLTELGLEGDGPDRALCHYPREHYLYW...,202,A,PMRYPVDVYTGKIQVDGELMLTELGLEGDGPDRALCHYPREHYLYW...,201,PMRYPVDVYTGKIQVDGELMLTELGLEGDGPDRALCHYPREHYLYW...,201,B,5yhi.pdb.gz,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
489,3GXE_0,3GXE,3GXE_0_B,3GXE_0_F,DQCIVDDITYNVQDTFHKKHEEGHMLNCTCFGQGRGRWKCDPVDQC...,89,DQCIVDDITYNVQDTFHKKHEEGHMLNCTCFGQGRGRWKCDPVDQC...,89,B,GLGMKGHRGF,10,GLPGMKGHRGF,11,F,3gxe.pdb.gz,False
490,6LY5_21,6LY5,6LY5_21_d,6LY5_21_l,PSPIFGGSTGGWLRKAQVEEKYVITWDSPKEQIFEMPTGGAAIMRE...,132,PSPIFGGSTGGWLRKAQVEEKYVITWDSPKEQIFEMPTGGAAIMRE...,132,d,ANFIKPYNDDPFVGHLATPITSSAVTRSLLKNLPAYRFGLTPLLRG...,144,ANFIKPYNDDPFVGHLATPITSSAVTRSLLKNLPAYRFGLTPLLRG...,144,l,6ly5.pdb.gz,False
491,5MLK_0,5MLK,5MLK_0_A,5MLK_0_B,ARISKVLVANRGEIAVRVIRAARDAGLPSVAVYAEPDAESPHVRLA...,451,ARISKVLVANRGEIAVRVIRAARDAGLPSVAVYAEPDAESPHVRLA...,451,A,ARISKVLVANRGEIAVRVIRAARDAGLPSVAVYAEPDAESPHVRLA...,384,ARISKVLVANRGEIAVRVIRAARDAGLPSVAVYAEPDAESPHVRLA...,384,B,5mlk.pdb.gz,True
492,8BS4_0,8BS4,8BS4_0_A,8BS4_0_B,HPVLEKLKAAHSYNPKEFEWNLKSGRVFIIKSYSEDDIHRSIKYSI...,195,HPVLEKLKAAHSYNPKEFEWNLKSGRVFIIKSYSEDDIHRSIKYSI...,195,A,GHPVLEKLKAAHSYNPKEFEWNLKSGRVFIIKSYSEDDIHRSIKYS...,193,GHPVLEKLKAAHSYNPKEFEWNLKSGRVFIIKSYSEDDIHRSIKYS...,193,B,8bs4.pdb.gz,True


In [11]:
def compute_esm2_embeddings_PPint(encoder, proj_head, df, out_dir):
    
    os.makedirs(out_dir, exist_ok=True)
    encoder.eval()
    proj_head.eval()
    existing = set(f for f in os.listdir(out_dir) if f.endswith(".npy"))

    with torch.no_grad():
        for _, row in tqdm(df.iterrows(), total=len(df), desc="Computing ESM2 embeddings (BOOSTED)"):
            pdb = row["PDB"]
            binder_chain = row["binder_chain"]
            target_chain = row["target_chain"]

            binder_id = f"{pdb}_{binder_chain}"
            target_id = f"{pdb}_{target_chain}"

            seq_binder = row["seq_pdb_binder"]
            seq_target = row["seq_pdb_target"]

            binder_fname = f"{binder_id}.npy"
            target_fname = f"{target_id}.npy"

            # ---- Binder ----
            if binder_fname not in existing:
                reps_b = encoder([seq_binder])     # [1, Ltok, Din]
                reps_b = proj_head(reps_b)         # [1, Ltok, Dproj] or [1, Ltok, Din] depending on head
                b_emb = reps_b[0].detach().cpu().numpy()  # [Ltok, D]

                assert b_emb.shape[0] == len(seq_binder) + 2, \
                    f"{binder_id}: {b_emb.shape[0]} vs {len(seq_binder)+2}"

                np.save(os.path.join(out_dir, binder_fname), b_emb)
                existing.add(binder_fname)

            # ---- Target ----
            if target_fname not in existing:
                reps_t = encoder([seq_target])
                reps_t = proj_head(reps_t)
                t_emb = reps_t[0].detach().cpu().numpy()

                assert t_emb.shape[0] == len(seq_target) + 2, \
                    f"{target_id}: {t_emb.shape[0]} vs {len(seq_target)+2}"

                np.save(os.path.join(out_dir, target_fname), t_emb)
                existing.add(target_fname)


out_dir = "/work3/s232958/data/PPint_DB/embeddings_esm2_boosted"

### Computing for Df_train
compute_esm2_embeddings_PPint(seq_encoder, seq_down, Df_train, out_dir)

### Computing for Df_test
compute_esm2_embeddings_PPint(seq_encoder, seq_down, Df_test, out_dir)

Computing ESM2 embeddings (BOOSTED): 100%|████████████████████████████████████████████████| 1977/1977 [03:44<00:00,  8.79it/s]
Computing ESM2 embeddings (BOOSTED): 100%|██████████████████████████████████████████████████| 494/494 [01:29<00:00,  5.51it/s]


In [12]:
class CLIP_PPint_class(Dataset):
    def __init__(
        self,
        dframe,
        path,
        embedding_dim=1280,
        embedding_pad_value=-5000.0,
    ):
        super().__init__()
        self.dframe = dframe.copy()
        self.embedding_dim = int(embedding_dim)
        self.emb_pad = float(embedding_pad_value)

        # lengths
        self.max_blen = self.dframe["pdb_binder_len"].max()+2
        self.max_tlen = self.dframe["pdb_target_len"].max()+2

        # paths
        self.encoding_path  = path

        # index & storage
        self.dframe.set_index("interface_id", inplace=True)
        self.accessions = self.dframe.index.astype(str).tolist()
        self.name_to_row = {name: i for i, name in enumerate(self.accessions)}
        self.samples = []

        for accession in tqdm(self.accessions, total=len(self.accessions), desc="#Loading ESM2 embeddings and contacts"):
            tgt_id = accession.split("_")[0]+"_"+str(self.dframe.loc[accession].target_chain)
            bnd_id = accession.split("_")[0]+"_"+str(self.dframe.loc[accession].binder_chain)

            ### --- embeddings (pad to fixed lengths) --- ###
            
            # laod embeddings
            t_emb = np.load(os.path.join(self.encoding_path, f"{tgt_id}.npy")) # [Lt, D]
            b_emb = np.load(os.path.join(self.encoding_path, f"{bnd_id}.npy")) # [Lt, D]

            # print(b_emb.shape[0], self.dframe.loc[accession].seq_binder_len)
            assert (b_emb.shape[0] == self.dframe.loc[accession].pdb_binder_len+2)
            assert (t_emb.shape[0] == self.dframe.loc[accession].pdb_target_len+2)

            # quich check whether embedding dimmension is as it suppose to be
            if t_emb.shape[1] != self.embedding_dim or b_emb.shape[1] != self.embedding_dim:
                raise ValueError("Embedding dim mismatch with 'embedding_dim'.")

            # add -5000 to all the padded target rows
            if t_emb.shape[0] < self.max_tlen:
                t_emb = np.concatenate([t_emb, np.full((self.max_tlen - t_emb.shape[0], t_emb.shape[1]), self.emb_pad, dtype=t_emb.dtype)], axis=0)
            else:
                t_emb = t_emb[: self.max_tlen] # no padding was used

            # add -5000 to all the padded binder rows
            if b_emb.shape[0] < self.max_blen:
                b_emb = np.concatenate([b_emb, np.full((self.max_blen - b_emb.shape[0], b_emb.shape[1]), self.emb_pad, dtype=b_emb.dtype)], axis=0)
            else:
                b_emb = b_emb[: self.max_blen] # no padding was used

            self.samples.append((b_emb, t_emb))

    # ---- Dataset API ----
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        b_arr, t_arr = self.samples[idx]
        binder_emb, target_emb = torch.from_numpy(b_arr).float(), torch.from_numpy(t_arr).float()
        label = torch.tensor(1, dtype=torch.float32)  # single scalar labe
        return binder_emb, target_emb, label

    def _get_by_name(self, name):
        # Single item -> return exactly what __getitem__ returns
        if isinstance(name, str):
            return self.__getitem__(self.name_to_row[name])
        
        # Multiple items -> fetch all
        out = [self.__getitem__(self.name_to_row[n]) for n in list(name)]
        b_list, t_list, lbl_list = zip(*out)
    
        # Stack embeddings
        b  = torch.stack([torch.as_tensor(x) for x in b_list],  dim=0)  # [B, ...]
        t  = torch.stack([torch.as_tensor(x) for x in t_list],  dim=0)  # [B, ...]
    
        # Stack labels
        labels = torch.stack(lbl_list)  # [B]
    
        return b, t, labels

emb_path = "/work3/s232958/data/PPint_DB/embeddings_esm2_boosted"

training_Dataset = CLIP_PPint_class(
    Df_train,
    path=emb_path,
    embedding_dim=512
)

testing_Dataset = CLIP_PPint_class(
    Df_test,
    path=emb_path,
    embedding_dim=512
)

#Loading ESM2 embeddings and contacts: 100%|█████████████████████████████████████████████| 1977/1977 [00:07<00:00, 278.45it/s]
#Loading ESM2 embeddings and contacts: 100%|███████████████████████████████████████████████| 494/494 [00:02<00:00, 217.47it/s]


In [13]:
### Getting indeces of non-dimers
indices_non_dimers_val = Df_test[~Df_test["dimer"]].index.tolist()
indices_non_dimers_val[:5]

### Getting accessions of non-dimers
accessions = [Df_test.loc[index].interface_id for index in indices_non_dimers_val]
emb_b, emb_t, labels = testing_Dataset._get_by_name(accessions[:5])
labels

tensor([1., 1., 1., 1., 1.])

### Loading Meta validation dataset

In [14]:
interaction_df = pd.read_csv("/work3/s232958/data/meta_analysis/interaction_df_metaanal.csv")[["A_seq", "B_seq", "target_id_mod", "target_binder_ID", "binder"]].rename(columns = {
    "A_seq" : "seq_binder",
    "B_seq" : "seq_target",
    "target_binder_ID" : "binder_id",
    "target_id_mod" : "target_id",
    "binder" : "binder_label"
})
interaction_df["seq_target_len"] = [len(seq) for seq in interaction_df["seq_target"].tolist()]
interaction_df["seq_binder_len"] = [len(seq) for seq in interaction_df["seq_binder"].tolist()]

# Targets df
target_df = interaction_df[["target_id","seq_target"]].rename(columns={"seq_target":"sequence", "target_id" : "ID"})
target_df["seq_len"] = target_df["sequence"].apply(len)
target_df = target_df.drop_duplicates(subset=["ID","sequence"])
target_df = target_df.set_index("ID")

# Binders df
binder_df = interaction_df[["binder_id","seq_binder"]].rename(columns={"seq_binder":"sequence", "binder_id" : "ID"})
binder_df["seq_len"] = binder_df["sequence"].apply(len)
binder_df = binder_df.set_index("ID")

# target_df

# Interaction Dict
interaction_Dict = dict(enumerate(zip(interaction_df["target_id"], interaction_df["binder_id"]), start=1))
interaction_df_shuffled = interaction_df.sample(frac=1, random_state=0).reset_index(drop=True)
interaction_df_shuffled

Unnamed: 0,seq_binder,seq_target,target_id,binder_id,binder_label,seq_target_len,seq_binder_len
0,DIVEEAHKLLSRAMSEAMENDDPDKLRRANELYFKLEEALKNNDPK...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_124,True,101,62
1,SEELVEKVVEEILNSDLSNDQKILETHDRLMELHDQGKISKEEYYK...,LEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYV...,EGFR_2,EGFR_2_149,False,621,58
2,TINRVFHLHIQGDTEEARKAHEELVEEVRRWAEELAKRLNLTVRVT...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_339,False,101,65
3,DDLRKVERIASELAFFAAEQNDTKVAFTALELIHQLIRAIFHNDEE...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_1234,False,101,64
4,DEEVEELEELLEKAEDPRERAKLLRELAKLIRRDPRLRELATEVVA...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_48,False,165,65
...,...,...,...,...,...,...,...
3527,SEDELRELVKEIRKVAEKQGDKELRTLWIEAYDLLASLWYGAADEL...,TNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFK...,SARS_CoV2_RBD,SARS_CoV2_RBD_25,False,195,63
3528,TEEEILKMLVELTAHMAGVPDVKVEIHNGTLRVTVNGDTREARSVL...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_2027,False,101,65
3529,VEELKEARKLVEEVLRKKGDQIAEIWKDILEELEQRYQEGKLDPEE...,DYSFSCYSQLEVNGSQHSLTCAFEDPDVNTTNLEFEICGALVEVKC...,IL7Ra,IL7Ra_90,False,193,63
3530,DAEEEIREIVEKLNDPLLREILRLLELAKEKGDPRLEAELYLAFEK...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_1605,False,101,65


In [15]:
def compute_esm2_embeddings_Meta(encoder, proj_head, df, tout_dir, bout_dir):
    
    os.makedirs(tout_dir, exist_ok=True)
    os.makedirs(bout_dir, exist_ok=True)
    t_existing = set(f for f in os.listdir(tout_dir) if f.endswith(".npy"))
    b_existing = set(f for f in os.listdir(bout_dir) if f.endswith(".npy"))

    encoder.eval()
    proj_head.eval()

    with torch.no_grad():
        for _, row in tqdm(df.iterrows(), total=len(df), desc="Computing ESM2 embeddings (BOOSTED)"):

            binder_id = row["binder_id"]
            target_id = row["target_id"]

            seq_binder = row["seq_binder"]
            seq_target = row["seq_target"]

            binder_fname = f"{binder_id}.npy"
            target_fname = f"{target_id}.npy"

            # ---- Binder ----
            if binder_fname not in b_existing:
                reps_b = encoder([seq_binder])     # [1, Ltok, Din]
                reps_b = proj_head(reps_b)         # [1, Ltok, Dproj] or [1, Ltok, Din] depending on head
                b_emb = reps_b[0].detach().cpu().numpy()  # [Ltok, D]

                assert b_emb.shape[0] == len(seq_binder) + 2, \
                    f"{binder_id}: {b_emb.shape[0]} vs {len(seq_binder)+2}"

                np.save(os.path.join(bout_dir, binder_fname), b_emb)
                b_existing.add(binder_fname)

            # ---- Target ----
            if target_fname not in t_existing:
                reps_t = encoder([seq_target])
                reps_t = proj_head(reps_t)
                t_emb = reps_t[0].detach().cpu().numpy()

                assert t_emb.shape[0] == len(seq_target) + 2, \
                    f"{target_id}: {t_emb.shape[0]} vs {len(seq_target)+2}"

                np.save(os.path.join(tout_dir, target_fname), t_emb)
                t_existing.add(target_fname)

tout_dir = "/work3/s232958/data/meta_analysis/embeddings_esm2_boosted_targets"
bout_dir = "/work3/s232958/data/meta_analysis/embeddings_esm2_boosted_binders"

compute_esm2_embeddings_Meta(seq_encoder, seq_down, interaction_df_shuffled, tout_dir, bout_dir)

Computing ESM2 embeddings (BOOSTED): 100%|████████████████████████████████████████████████| 3532/3532 [03:05<00:00, 19.00it/s]


In [16]:
class CLIP_PPint_metaanal(Dataset):
    def __init__(
        self,
        dframe,
        paths,
        embedding_dim=1280,
        embedding_pad_value=-5000.0
    ):
        super().__init__()
        self.dframe = dframe.copy()
        self.embedding_dim = int(embedding_dim)
        self.emb_pad = float(embedding_pad_value)
        self.max_blen = self.dframe["seq_binder_len"].max()+2
        self.max_tlen = self.dframe["seq_target_len"].max()+2

        # paths
        self.encoding_bpath, self.encoding_tpath = paths

        # index & storage
        self.dframe.set_index("binder_id", inplace=True)
        self.accessions = self.dframe.index.astype(str).tolist()
        self.name_to_row = {name: i for i, name in enumerate(self.accessions)}
        self.samples = []

        for accession in tqdm(self.accessions, total=len(self.accessions), desc="#Loading ESM2 embeddings"):
            lbl = torch.tensor(int(self.dframe.loc[accession, "binder_label"]))
            parts = accession.split("_") # e.g. accession 7S8T_5_F_7S8T_5_G
            tgt_id = "_".join(parts[:-1])
            bnd_id = accession

            ### --- embeddings (pad to fixed lengths) --- ###
            
            # laod embeddings
            t_emb = np.load(os.path.join(self.encoding_tpath, f"{tgt_id}.npy")) # [Lt, D]
            b_emb = np.load(os.path.join(self.encoding_bpath, f"{bnd_id}.npy")) # [Lb, D]
            
            assert (b_emb.shape[0] == self.dframe.loc[accession].seq_binder_len+2)
            assert (t_emb.shape[0] == self.dframe.loc[accession].seq_target_len+2)

            # quich check whether embedding dimmension is as it suppose to be
            if t_emb.shape[1] != self.embedding_dim or b_emb.shape[1] != self.embedding_dim:
                raise ValueError("Embedding dim mismatch with 'embedding_dim'.")

            # add -5000 to all the padded target rows
            if t_emb.shape[0] < self.max_tlen:
                t_emb = np.concatenate([t_emb, np.full((self.max_tlen - t_emb.shape[0], t_emb.shape[1]), self.emb_pad, dtype=t_emb.dtype)], axis=0)
            else:
                t_emb = t_emb[: self.max_tlen] # no padding was used

            # add -5000 to all the padded binder rows
            if b_emb.shape[0] < self.max_blen:
                b_emb = np.concatenate([b_emb, np.full((self.max_blen - b_emb.shape[0], b_emb.shape[1]), self.emb_pad, dtype=b_emb.dtype)], axis=0)
            else:
                b_emb = b_emb[: self.max_blen] # no padding was used

            self.samples.append((b_emb, t_emb, lbl))

    # ---- Dataset API ----
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        b_arr, t_arr, lbls = self.samples[idx]
        binder_emb, target_emb = torch.from_numpy(b_arr).float(), torch.from_numpy(t_arr).float()
        return binder_emb, target_emb, lbls

    def _get_by_name(self, name):
        # Single item -> return exactly what __getitem__ returns
        if isinstance(name, str):
            return self.__getitem__(self.name_to_row[name])
        
        # Multiple items -> fetch all
        out = [self.__getitem__(self.name_to_row[n]) for n in list(name)]
        b_list, t_list, lbl_list = zip(*out)
    
        # Stack embeddings
        b  = torch.stack([torch.as_tensor(x) for x in b_list],  dim=0)  # [B, ...]
        t  = torch.stack([torch.as_tensor(x) for x in t_list],  dim=0)  # [B, ...]
    
        # Stack labels
        labels = torch.stack(lbl_list)  # [B]
    
        return b, t, labels

temb_path = "/work3/s232958/data/meta_analysis/embeddings_esm2_boosted_targets"
bemb_path = "/work3/s232958/data/meta_analysis/embeddings_esm2_boosted_binders"

validation_Dataset = CLIP_PPint_metaanal(
    # interaction_df_shuffled[:len(Df_test)],
    interaction_df_shuffled,
    paths=[bemb_path, temb_path],
    embedding_dim=512
)

#Loading ESM2 embeddings: 100%|██████████████████████████████████████████████████████████| 3532/3532 [00:14<00:00, 251.27it/s]


In [17]:
accessions_Meta = list(interaction_df_shuffled.binder_id)
emb_b, emb_t, labels = validation_Dataset._get_by_name(accessions_Meta[:5])
labels

tensor([1, 0, 0, 0, 0])

In [18]:
embedding_dimension = 512

def create_key_padding_mask(embeddings, padding_value=-5000, offset=10):
    return (embeddings < (padding_value + offset)).all(dim=-1)

def create_mean_of_non_masked(embeddings, padding_mask):
    # Use masked select and mean to compute the mean of non-masked elements
    # embeddings should be of shape (batch_size, seq_len, features)
    seq_embeddings = []
    for i in range(embeddings.shape[0]): # looping over all batch elements
        non_masked_embeddings = embeddings[i][~padding_mask[i]] # shape [num_real_tokens, features]
        if len(non_masked_embeddings) == 0:
            print("You are masking all positions when creating sequence representation")
            sys.exit(1)
        mean_embedding = non_masked_embeddings.mean(dim=0) # sequence is represented by the single vecotr [1152] [features]
        seq_embeddings.append(mean_embedding)
    return torch.stack(seq_embeddings)

class MiniCLIP_w_transformer_crossattn(pl.LightningModule):

    def __init__(self, padding_value = -5000, embed_dimension=embedding_dimension, num_recycles=2):

        super().__init__()
        self.num_recycles = num_recycles # how many times you iteratively refine embeddings with self- and cross-attention (ALPHA-Fold-style recycling).
        self.padding_value = padding_value
        self.embed_dimension = embed_dimension

        self.logit_scale = nn.Parameter(torch.tensor(math.log(1/0.07)))  # ~CLIP init

        self.transformerencoder =  nn.TransformerEncoderLayer(
            d_model=self.embed_dimension,
            nhead=8,
            dropout=0.1,
            batch_first=True,
            dim_feedforward=self.embed_dimension
            )
 
        self.norm = nn.LayerNorm(self.embed_dimension)  # For residual additions

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=self.embed_dimension,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        self.prot_embedder = nn.Sequential(
            nn.Linear(self.embed_dimension, 640),
            nn.ReLU(),
            nn.Linear(640, 320),
        )
        
    def forward(self, pep_input, prot_input, label=None, pep_int_mask=None, prot_int_mask=None, int_prob=None, mem_save=True): # , pep_tokens, prot_tokens

        pep_mask = create_key_padding_mask(embeddings=pep_input, padding_value=self.padding_value)
        prot_mask = create_key_padding_mask(embeddings=prot_input, padding_value=self.padding_value)
 
        # Initialize residual states
        pep_emb = pep_input.clone()
        prot_emb = prot_input.clone()
 
        for _ in range(self.num_recycles):

            # Transformer encoding with residual
            pep_trans = self.transformerencoder(self.norm(pep_emb), src_key_padding_mask=pep_mask)
            prot_trans = self.transformerencoder(self.norm(prot_emb), src_key_padding_mask=prot_mask)

            # Cross-attention with residual
            pep_cross, _ = self.cross_attn(query=self.norm(pep_trans), key=self.norm(prot_trans), value=self.norm(prot_trans), key_padding_mask=prot_mask)
            prot_cross, _ = self.cross_attn(query=self.norm(prot_trans), key=self.norm(pep_trans), value=self.norm(pep_trans), key_padding_mask=pep_mask)
            
            # Additive update with residual connection
            pep_emb = pep_emb + pep_cross  
            prot_emb = prot_emb + prot_cross

        pep_seq_coding = create_mean_of_non_masked(pep_emb, pep_mask)
        prot_seq_coding = create_mean_of_non_masked(prot_emb, prot_mask)
        
        # Use self-attention outputs for embeddings
        pep_seq_coding = F.normalize(self.prot_embedder(pep_seq_coding), dim=-1)
        prot_seq_coding = F.normalize(self.prot_embedder(prot_seq_coding), dim=-1)
 
        if mem_save:
            torch.cuda.empty_cache()
        
        scale = torch.exp(self.logit_scale).clamp(max=100.0)
        logits = scale * (pep_seq_coding * prot_seq_coding).sum(dim=-1)
        
        return logits

    def training_step(self, batch, device):
        embedding_pep, embedding_prot, labels = batch
        embedding_pep, embedding_prot = embedding_pep.to(device), embedding_prot.to(device)
        
        positive_logits = self.forward(embedding_pep, embedding_prot)
        
        # Negative indexes
        rows, cols = torch.triu_indices(embedding_prot.size(0), embedding_prot.size(0), offset=1)         
        
        negative_logits = self(embedding_pep[rows,:,:], 
                          embedding_prot[cols,:,:], 
                          int_prob=0.0)

        # loss of predicting partner using peptide
        positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits).to(device))
 
        # loss of predicting peptide using partner
        negative_loss =  F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits).to(device))
        
        loss = (positive_loss + negative_loss) / 2
 
        # del partner_prediction_loss, peptide_prediction_loss, embedding_pep, embedding_prot
        torch.cuda.empty_cache()
        return loss

    def validation_step_PPint(self, batch, device):
        # Predict on random batches of training batch size
        embedding_pep, embedding_prot, labels = batch
        embedding_pep, embedding_prot = embedding_pep.to(device), embedding_prot.to(device)
        
        with torch.no_grad():

            positive_logits = self(embedding_pep, embedding_prot)
            
            # loss of predicting partner using peptide
            positive_loss = F.binary_cross_entropy_with_logits(positive_logits, torch.ones_like(positive_logits).to(device))
            
            # Negaive indexes
            rows, cols = torch.triu_indices(embedding_prot.size(0), embedding_prot.size(0), offset=1)
            
            negative_logits = self(embedding_pep[rows,:,:], embedding_prot[cols,:,:], int_prob=0.0)
    
            negative_loss =  F.binary_cross_entropy_with_logits(negative_logits, torch.zeros_like(negative_logits).to(device))

            loss = (positive_loss + negative_loss) / 2
           
            logit_matrix = torch.zeros((embedding_pep.size(0),embedding_pep.size(0)),device=self.device)
            logit_matrix[rows, cols] = negative_logits
            logit_matrix[cols, rows] = negative_logits
            
            # Fill diagonal with positive scores
            diag_indices = torch.arange(embedding_pep.size(0), device=self.device)
            logit_matrix[diag_indices, diag_indices] = positive_logits.squeeze()

            labels = torch.arange(embedding_prot.size(0)).to(self.device)
            peptide_predictions = logit_matrix.argmax(dim=0)
            peptide_ranks = logit_matrix.argsort(dim=0).diag() + 1
            peptide_mrr = (peptide_ranks).float().pow(-1).mean()
            
            # partner_accuracy = partner_predictions.eq(labels).float().mean()
            peptide_accuracy = peptide_predictions.eq(labels).float().mean()
    
            k = 3
            peptide_topk_accuracy = torch.any((logit_matrix.topk(k, dim=0).indices - labels.reshape(1, -1)) == 0, dim=0).sum() / logit_matrix.shape[0]
    
            del logit_matrix,positive_logits,negative_logits,embedding_pep,embedding_prot

            return loss, peptide_accuracy, peptide_topk_accuracy
    
    def validation_step_MetaDataset(self, batch, device):
        embedding_binder, embedding_target, labels = batch
        embedding_binder = embedding_binder.to(device)
        embedding_target = embedding_target.to(device)
        labels = labels.to(device).float()
    
        with torch.no_grad():
            logits = self.forward(embedding_binder, embedding_target)
            logits = logits.float()
            loss = F.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1))
            return logits, loss

    def calculate_logit_matrix(self,embedding_pep,embedding_prot):
        rows, cols = torch.triu_indices(embedding_pep.size(0), embedding_pep.size(0), offset=1)
        
        positive_logits = self(embedding_pep, embedding_prot)
        negative_logits = self(embedding_pep[rows,:,:], embedding_prot[cols,:,:], int_prob=0.0)
        
        logit_matrix = torch.zeros((embedding_pep.size(0),embedding_pep.size(0)),device=self.device)
        logit_matrix[rows, cols] = negative_logits
        logit_matrix[cols, rows] = negative_logits
        
        diag_indices = torch.arange(embedding_pep.size(0), device=self.device)
        logit_matrix[diag_indices, diag_indices] = positive_logits.squeeze()
        
        return logit_matrix

In [19]:
model = MiniCLIP_w_transformer_crossattn(embed_dimension=embedding_dimension, num_recycles=number_of_recycles).to("cuda")
model

MiniCLIP_w_transformer_crossattn(
  (transformerencoder): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
    )
    (linear1): Linear(in_features=512, out_features=512, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=512, out_features=512, bias=True)
    (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (cross_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
  )
  (prot_embedder): Sequential(
    (0): Linear(in_features=512, out_features=640, bias=True)
    (1): ReLU()
    (2): Linear(in_features=640, out_features=320, bias=True

### Trianing loop

In [20]:
def batch(iterable, n=1):
    """Takes any indexable iterable (e.g., a list of observation IDs) and yields contiguous slices of length n."""
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]

class TrainWrapper():

    def __init__(self, 
                 model, 
                 train_loader,
                 test_loader,
                 val_loader,
                 test_df,
                 test_dataset,
                 optimizer, 
                 epochs, 
                 runID, 
                 device, 
                 test_indexes_for_auROC = None,
                 auROC_batch_size=10, 
                 model_save_steps=False, 
                 model_save_path=False, 
                 v=False, 
                 wandb_tracker=False):
        
        self.model = model 
        self.training_loader = train_loader
        self.testing_loader = test_loader
        self.validation_loader = val_loader
        self.test_dataset = test_dataset
        self.test_df = test_df
        self.auROC_batch_size = auROC_batch_size
        
        self.EPOCHS = epochs
        self.optimizer = optimizer
        self.device = device
        
        self.wandb_tracker = wandb_tracker
        self.model_save_steps = model_save_steps
        self.verbose = v
        self.best_vloss = 1_000_000
        self.runID = runID
        self.trained_model_dir = model_save_path
        self.print_frequency_loss = 1
        self.test_indexes_for_auROC = test_indexes_for_auROC

    def train_one_epoch(self):

        self.model.train() 
        running_loss = 0

        for batch in tqdm(self.training_loader, total=len(self.training_loader), desc="Running through epoch"):
            
            if batch[0].size(0) == 1: 
                continue
            
            self.optimizer.zero_grad()
            loss = self.model.training_step(batch, self.device)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()

            del loss, batch
            torch.cuda.empty_cache()
            
        return running_loss / len(self.training_loader)

    def calc_auroc_aupr_on_indexes(self, model, dataset, dataframe, nondimer_indexes, batch_size = 10):

        self.model.eval()
        all_TP_scores, all_FP_scores = [], []
        accessions = [dataframe.loc[index].interface_id for index in nondimer_indexes]  # <-- use dataframe
        batches_local = batch(accessions, n=batch_size)
        
        with torch.no_grad():
            for index_batch in tqdm(batches_local, total=int(len(accessions)/batch_size), desc="Calculating AUC"):

                binder_emb, target_emb, labels = dataset._get_by_name(index_batch)
                binder_emb, target_emb = binder_emb.to(self.device), target_emb.to(self.device)

                # Make sure this matches your model's signature:
                logit_matrix = self.model.calculate_logit_matrix(binder_emb, target_emb)
                
                TP_scores = logit_matrix.diag().detach().cpu().tolist()
                all_TP_scores += TP_scores
                
                # Get FP scores from upper triangle (excluding diagonal)
                n = logit_matrix.size(0)
                rows, cols = torch.triu_indices(n, n, offset=1)
                FP_scores = logit_matrix[rows, cols].detach().cpu().tolist()
                all_FP_scores += FP_scores
            
        all_score_predictions = np.array(all_TP_scores + all_FP_scores)
        all_labels = np.array([1]*len(all_TP_scores) + [0]*len(all_FP_scores))
                
        fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_score_predictions)
        auroc = metrics.roc_auc_score(all_labels, all_score_predictions)
        aupr  = metrics.average_precision_score(all_labels, all_score_predictions)
        
        return auroc, aupr, all_TP_scores, all_FP_scores

    def validate(self):
        
        self.model.eval()
        
        running_loss_Meta = 0.0
        all_logits = []
        all_lbls = []
        used_batches_meta = 0

        # --- MetaDataset validation ---
        with torch.no_grad():
            for batch in tqdm(self.validation_loader, total=len(self.validation_loader)):
                if batch[0].size(0) == 1:
                    continue
                embedding_binder, embedding_target, labels = batch
                logits, loss = self.model.validation_step_MetaDataset(batch, self.device)
                
                running_loss_Meta += loss.item()
                all_logits.append(logits.detach().view(-1).cpu())
                all_lbls.append(labels.detach().view(-1).cpu())
                used_batches_meta += 1
                
            if used_batches_meta > 0:
                val_loss_Meta = running_loss_Meta / used_batches_meta
                all_logits = torch.cat(all_logits).numpy()
                all_lbls   = torch.cat(all_lbls).numpy()
            
                fpr, tpr, thresholds = metrics.roc_curve(all_lbls, all_logits)
                meta_auroc = metrics.roc_auc_score(all_lbls, all_logits)
                meta_aupr  = metrics.average_precision_score(all_lbls, all_logits)

                y_pred = (all_logits >= 0).astype(int)
                y_true = all_lbls.astype(int)
                val_acc_Meta = (y_pred == y_true).mean()
            else:
                val_loss_Meta = float("nan")
                meta_auroc = float("nan")
                meta_aupr = float("nan")
                val_acc_Meta = float("nan")

        # --- PPint validation ---
        running_loss_ValPPint = 0.0
        running_accuracy_ValPPint = 0.0
        running_topk_accuracy_ValPPint = 0.0
        used_batches_ppint = 0

        with torch.no_grad():
            for batch in tqdm(self.testing_loader, total=len(self.testing_loader)):
                if batch[0].size(0) == 1:
                    continue
                loss, partner_accuracy, peptide_topk_accuracy = self.model.validation_step_PPint(batch, self.device)
                running_loss_ValPPint += loss.item()
                running_accuracy_ValPPint += partner_accuracy.item()
                running_topk_accuracy_ValPPint += peptide_topk_accuracy.item()
                used_batches_ppint += 1
                
            if used_batches_ppint > 0:
                val_loss_PPint = running_loss_ValPPint / used_batches_ppint
                val_accuracy_PPint = running_accuracy_ValPPint / used_batches_ppint
                val_topk_accuracy_PPint = running_topk_accuracy_ValPPint / used_batches_ppint
            else:
                val_loss_PPint = float("nan")
                val_accuracy_PPint = float("nan")
                val_topk_accuracy_PPint = float("nan")

        # --- AUROC on specific indexes (optional) ---
        if self.test_indexes_for_auROC is not None:
            non_dimer_auc, non_dimer_aupr, ___, ___ = self.calc_auroc_aupr_on_indexes(
                model=self.model, 
                dataset=self.test_dataset,
                dataframe=self.test_df,
                nondimer_indexes=self.test_indexes_for_auROC,
                batch_size=self.auROC_batch_size
            )
            
            return (val_loss_PPint, val_accuracy_PPint, val_topk_accuracy_PPint,
                    non_dimer_auc, non_dimer_aupr,
                    val_loss_Meta, val_acc_Meta, meta_auroc, meta_aupr)

        else:
            return (val_loss_PPint, val_accuracy_PPint, val_topk_accuracy_PPint,
                    val_loss_Meta, val_acc_Meta, meta_auroc, meta_aupr)

    def train_model(self):
        
        torch.cuda.empty_cache()
        
        if self.verbose:
            print(f"Training model {str(self.runID)}")

        # --- initial validation before training
        print("Initial validation before starting training")
        if self.test_indexes_for_auROC is not None:
            (val_loss_PPint, val_accuracy_PPint, val_topk_accuracy_PPint,
             non_dimer_auc, non_dimer_aupr,
             val_loss_Meta, val_acc_Meta, meta_auroc, meta_aupr) = self.validate()
        else:
            (val_loss_PPint, val_accuracy_PPint, val_topk_accuracy_PPint,
             val_loss_Meta, val_acc_Meta, meta_auroc, meta_aupr) = self.validate()
            non_dimer_auc, non_dimer_aupr = None, None
                
        if self.verbose: 
            print(f'Before training:')
            print(f'Meta Val-Loss {round(val_loss_Meta,4)}')
            print(f'Meta Accuracy: {round(val_acc_Meta,4)}')
            print(f'Meta AUROC: {round(meta_auroc,4)}')
            print(f'Meta AUPR: {round(meta_aupr,4)}')
            print(f'PPint Test-Loss: {round(val_loss_PPint,4)}')
            print(f'PPint Accuracy: {round(val_accuracy_PPint,4)}')
            if non_dimer_auc is not None:
                print(f'PPint non-dimer AUROC: {round(non_dimer_auc,4)}')
                print(f'PPint non-dimer AUPR: {round(non_dimer_aupr,4)}')
        
        if self.wandb_tracker:
            metrics_to_log = {
                "PPint Test-Loss": val_loss_PPint,
                "Meta Val-loss": val_loss_Meta,
                "PPint Accuracy": val_accuracy_PPint,
                "Meta Accuracy": val_acc_Meta,
                "Meta Val-AUROC": meta_auroc,
                "Meta Val-AUPR": meta_aupr,
            }
            if non_dimer_auc is not None:
                metrics_to_log.update({
                    "PPint non-dimer AUROC": non_dimer_auc,
                    "PPint non-dimer AUPR": non_dimer_aupr,
                })
            self.wandb_tracker.log(metrics_to_log, step=0)
        
        # --- training loop
        for epoch in tqdm(range(1, self.EPOCHS + 1), total=self.EPOCHS, desc="Epochs"):
            
            torch.cuda.empty_cache()
            
            train_loss = self.train_one_epoch()
            
            # validation after epoch
            if self.test_indexes_for_auROC is not None:
                (val_loss_PPint, val_accuracy_PPint, val_topk_accuracy_PPint,
                 non_dimer_auc, non_dimer_aupr,
                 val_loss_Meta, val_acc_Meta, meta_auroc, meta_aupr) = self.validate()
            else:
                (val_loss_PPint, val_accuracy_PPint, val_topk_accuracy_PPint,
                 val_loss_Meta, val_acc_Meta, meta_auroc, meta_aupr) = self.validate()
                non_dimer_auc, non_dimer_aupr = None, None
            
            torch.cuda.empty_cache()
            
            # checkpoint save
            # if self.model_save_steps and epoch % self.model_save_steps == 0:
            #     check_point_folder = os.path.join(self.trained_model_dir, f"{str(self.runID)}_checkpoint_{str(epoch)}")
            #     if self.verbose:
            #         print("Saving model to:", check_point_folder)
            #     os.makedirs(check_point_folder, exist_ok=True)
            #     checkpoint_path = os.path.join(check_point_folder, f"{str(self.runID)}_checkpoint_epoch_{str(epoch)}.pth")
            #     torch.save({'epoch': epoch, 
            #                 'model_state_dict': self.model.state_dict(),
            #                 'optimizer_state_dict': self.optimizer.state_dict(), 
            #                 'val_loss_PPint': val_loss_PPint,
            #                 'val_loss_Meta': val_loss_Meta},
            #                checkpoint_path)
            
            # console logging
            if self.verbose and epoch % self.print_frequency_loss == 0:
                print(f'EPOCH {epoch}:')
                print(f'Meta Val Loss {round(val_loss_Meta,4)}')
                print(f'Meta Accuracy: {round(val_acc_Meta,4)}')
                print(f'Meta AUROC: {round(meta_auroc,4)}')
                print(f'Meta AUPR: {round(meta_aupr,4)}')
                print(f'PPint Test-Loss: {round(val_loss_PPint,4)}')
                print(f'PPint Accuracy: {round(val_accuracy_PPint,4)}')
                if non_dimer_auc is not None:
                    print(f'PPint non-dimer AUROC: {round(non_dimer_auc,4)}')
                    print(f'PPint non-dimer AUPR: {round(non_dimer_aupr,4)}')
            
            # wandb logging
            if self.wandb_tracker:
                metrics_to_log_epoch = {
                    "PPint Train-loss": train_loss,
                    "PPint Test-Loss": val_loss_PPint,
                    "Meta Val-loss": val_loss_Meta,
                    "PPint Accuracy": val_accuracy_PPint,
                    "Meta Accuracy": val_acc_Meta,
                    "Meta Val-AUROC": meta_auroc,
                    "Meta Val-AUPR": meta_aupr,
                }
                if non_dimer_auc is not None:
                    metrics_to_log_epoch.update({
                        "PPint non-dimer AUROC": non_dimer_auc,
                        "PPint non-dimer AUPR": non_dimer_aupr,
                    })
                self.wandb_tracker.log(metrics_to_log_epoch, step=epoch)

        if self.wandb_tracker:
            self.wandb_tracker.finish()

In [21]:
learning_rate = 2e-5
EPOCHS = 12
g = torch.Generator().manual_seed(SEED)
batch_size = 10
optimizer = AdamW(model.parameters(), lr=learning_rate)
accelerator = Accelerator()
device = accelerator.device

train_dataloader = DataLoader(training_Dataset, batch_size=10, shuffle=True, drop_last = True)
test_dataloader = DataLoader(testing_Dataset, batch_size=10, shuffle=False)
val_dataloader = DataLoader(validation_Dataset, batch_size=20, shuffle=False)

# accelerator
model, optimizer, train_dataloader, test_dataloader, val_dataloader = accelerator.prepare(model, optimizer, train_dataloader, test_dataloader, val_dataloader)

In [22]:
for i in val_dataloader:
    __, __, lbls = i
    print(lbls.to(device))
    break

tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')


In [23]:
# wandb
if use_wandb:
    run = wandb.init(
        project="CSSP_all_Losses",
        name=f"Boosted_combined2.0_CS0.35",
        config={"learning_rate": learning_rate, 
                "batch_size": batch_size, 
                "epochs": EPOCHS,
                "architecture": "MiniCLIP_w_transformer_crossattn", 
                "dataset": 
                "PPint"},
    )
    wandb.watch(accelerator.unwrap_model(model), log="all", log_freq=100)
else:
    run = None

# train
training_wrapper = TrainWrapper(
            model=model,
            train_loader=train_dataloader,
            test_loader=test_dataloader,
            val_loader=val_dataloader,
            test_df=Df_test,
            test_dataset=testing_Dataset,
            optimizer=optimizer,
            epochs=EPOCHS,
            runID=runID,
            device=device,
            test_indexes_for_auROC=indices_non_dimers_val,
            auROC_batch_size=10,
            model_save_steps=model_save_steps,
            model_save_path=trained_model_dir,
            v=True,
            wandb_tracker=wandb
)

training_wrapper.train_model() # start training

Training model d8c23469-56b7-4f64-8790-11febf5467d3
Initial validation before starting training


100%|███████████████████████████████████████████████████████████████████████████████████████| 177/177 [00:34<00:00,  5.08it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:24<00:00,  2.05it/s]
Calculating AUC: 13it [00:05,  2.23it/s]                                                                                      


Before training:
Meta Val-Loss 5.4796
Meta Accuracy: 0.1107
Meta AUROC: 0.481
Meta AUPR: 0.0978
PPint Test-Loss: 3.4744
PPint Accuracy: 0.85
PPint non-dimer AUROC: 0.6957
PPint non-dimer AUPR: 0.4657


Epochs:   0%|                                                                                          | 0/12 [00:00<?, ?it/s]
Running through epoch:   0%|                                                                          | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                                 | 1/197 [00:01<03:37,  1.11s/it][A
Running through epoch:   1%|▋                                                                 | 2/197 [00:02<04:02,  1.24s/it][A
Running through epoch:   2%|█                                                                 | 3/197 [00:03<03:50,  1.19s/it][A
Running through epoch:   2%|█▎                                                                | 4/197 [00:04<04:01,  1.25s/it][A
Running through epoch:   3%|█▋                                                                | 5/197 [00:06<03:51,  1.21s/it][A
Running through epoch:   3%|██                                                               

EPOCH 1:
Meta Val Loss 0.457
Meta Accuracy: 0.8216
Meta AUROC: 0.4757
Meta AUPR: 0.0979
PPint Test-Loss: 0.2036
PPint Accuracy: 0.874
PPint non-dimer AUROC: 0.8073
PPint non-dimer AUPR: 0.5351



Running through epoch:   0%|                                                                          | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                                 | 1/197 [00:00<01:23,  2.36it/s][A
Running through epoch:   1%|▋                                                                 | 2/197 [00:00<01:24,  2.30it/s][A
Running through epoch:   2%|█                                                                 | 3/197 [00:01<01:25,  2.26it/s][A
Running through epoch:   2%|█▎                                                                | 4/197 [00:01<01:24,  2.28it/s][A
Running through epoch:   3%|█▋                                                                | 5/197 [00:02<01:23,  2.29it/s][A
Running through epoch:   3%|██                                                                | 6/197 [00:02<01:23,  2.29it/s][A
Running through epoch:   4%|██▎                                                          

EPOCH 2:
Meta Val Loss 0.4562
Meta Accuracy: 0.8338
Meta AUROC: 0.4664
Meta AUPR: 0.1011
PPint Test-Loss: 0.1869
PPint Accuracy: 0.892
PPint non-dimer AUROC: 0.8396
PPint non-dimer AUPR: 0.5606



Running through epoch:   0%|                                                                          | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                                 | 1/197 [00:00<01:23,  2.35it/s][A
Running through epoch:   1%|▋                                                                 | 2/197 [00:00<01:24,  2.30it/s][A
Running through epoch:   2%|█                                                                 | 3/197 [00:01<01:24,  2.30it/s][A
Running through epoch:   2%|█▎                                                                | 4/197 [00:01<01:24,  2.28it/s][A
Running through epoch:   3%|█▋                                                                | 5/197 [00:02<01:24,  2.27it/s][A
Running through epoch:   3%|██                                                                | 6/197 [00:02<01:25,  2.25it/s][A
Running through epoch:   4%|██▎                                                          

EPOCH 3:
Meta Val Loss 0.4254
Meta Accuracy: 0.8604
Meta AUROC: 0.4828
Meta AUPR: 0.1038
PPint Test-Loss: 0.1659
PPint Accuracy: 0.894
PPint non-dimer AUROC: 0.8424
PPint non-dimer AUPR: 0.5808



Running through epoch:   0%|                                                                          | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                                 | 1/197 [00:00<01:22,  2.36it/s][A
Running through epoch:   1%|▋                                                                 | 2/197 [00:00<01:23,  2.34it/s][A
Running through epoch:   2%|█                                                                 | 3/197 [00:01<01:23,  2.32it/s][A
Running through epoch:   2%|█▎                                                                | 4/197 [00:01<01:23,  2.31it/s][A
Running through epoch:   3%|█▋                                                                | 5/197 [00:02<01:28,  2.17it/s][A
Running through epoch:   3%|██                                                                | 6/197 [00:02<01:26,  2.21it/s][A
Running through epoch:   4%|██▎                                                          

EPOCH 4:
Meta Val Loss 0.3875
Meta Accuracy: 0.8882
Meta AUROC: 0.5253
Meta AUPR: 0.121
PPint Test-Loss: 0.164
PPint Accuracy: 0.904
PPint non-dimer AUROC: 0.8616
PPint non-dimer AUPR: 0.6042



Running through epoch:   0%|                                                                          | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                                 | 1/197 [00:00<01:22,  2.38it/s][A
Running through epoch:   1%|▋                                                                 | 2/197 [00:00<01:23,  2.33it/s][A
Running through epoch:   2%|█                                                                 | 3/197 [00:01<01:23,  2.32it/s][A
Running through epoch:   2%|█▎                                                                | 4/197 [00:01<01:23,  2.31it/s][A
Running through epoch:   3%|█▋                                                                | 5/197 [00:02<01:23,  2.30it/s][A
Running through epoch:   3%|██                                                                | 6/197 [00:02<01:23,  2.28it/s][A
Running through epoch:   4%|██▎                                                          

EPOCH 5:
Meta Val Loss 0.4003
Meta Accuracy: 0.8783
Meta AUROC: 0.5138
Meta AUPR: 0.125
PPint Test-Loss: 0.1667
PPint Accuracy: 0.902
PPint non-dimer AUROC: 0.8677
PPint non-dimer AUPR: 0.6121



Running through epoch:   0%|                                                                          | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                                 | 1/197 [00:00<01:24,  2.32it/s][A
Running through epoch:   1%|▋                                                                 | 2/197 [00:00<01:24,  2.30it/s][A
Running through epoch:   2%|█                                                                 | 3/197 [00:01<01:24,  2.28it/s][A
Running through epoch:   2%|█▎                                                                | 4/197 [00:01<01:24,  2.29it/s][A
Running through epoch:   3%|█▋                                                                | 5/197 [00:02<01:23,  2.30it/s][A
Running through epoch:   3%|██                                                                | 6/197 [00:02<01:23,  2.30it/s][A
Running through epoch:   4%|██▎                                                          

EPOCH 6:
Meta Val Loss 0.4953
Meta Accuracy: 0.7902
Meta AUROC: 0.515
Meta AUPR: 0.124
PPint Test-Loss: 0.1801
PPint Accuracy: 0.9
PPint non-dimer AUROC: 0.8576
PPint non-dimer AUPR: 0.5912



Running through epoch:   0%|                                                                          | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                                 | 1/197 [00:00<01:25,  2.29it/s][A
Running through epoch:   1%|▋                                                                 | 2/197 [00:00<01:24,  2.30it/s][A
Running through epoch:   2%|█                                                                 | 3/197 [00:01<01:24,  2.30it/s][A
Running through epoch:   2%|█▎                                                                | 4/197 [00:01<01:24,  2.29it/s][A
Running through epoch:   3%|█▋                                                                | 5/197 [00:02<01:23,  2.30it/s][A
Running through epoch:   3%|██                                                                | 6/197 [00:02<01:23,  2.30it/s][A
Running through epoch:   4%|██▎                                                          

EPOCH 7:
Meta Val Loss 0.7151
Meta Accuracy: 0.6082
Meta AUROC: 0.5105
Meta AUPR: 0.1201
PPint Test-Loss: 0.1578
PPint Accuracy: 0.904
PPint non-dimer AUROC: 0.866
PPint non-dimer AUPR: 0.6027



Running through epoch:   0%|                                                                          | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                                 | 1/197 [00:00<01:23,  2.34it/s][A
Running through epoch:   1%|▋                                                                 | 2/197 [00:00<01:24,  2.29it/s][A
Running through epoch:   2%|█                                                                 | 3/197 [00:01<01:24,  2.30it/s][A
Running through epoch:   2%|█▎                                                                | 4/197 [00:01<01:24,  2.29it/s][A
Running through epoch:   3%|█▋                                                                | 5/197 [00:02<01:23,  2.29it/s][A
Running through epoch:   3%|██                                                                | 6/197 [00:02<01:23,  2.29it/s][A
Running through epoch:   4%|██▎                                                          

EPOCH 8:
Meta Val Loss 0.4198
Meta Accuracy: 0.8749
Meta AUROC: 0.5071
Meta AUPR: 0.1206
PPint Test-Loss: 0.2035
PPint Accuracy: 0.891
PPint non-dimer AUROC: 0.8579
PPint non-dimer AUPR: 0.602



Running through epoch:   0%|                                                                          | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                                 | 1/197 [00:00<01:22,  2.37it/s][A
Running through epoch:   1%|▋                                                                 | 2/197 [00:00<01:24,  2.32it/s][A
Running through epoch:   2%|█                                                                 | 3/197 [00:01<01:23,  2.31it/s][A
Running through epoch:   2%|█▎                                                                | 4/197 [00:01<01:23,  2.31it/s][A
Running through epoch:   3%|█▋                                                                | 5/197 [00:02<01:23,  2.30it/s][A
Running through epoch:   3%|██                                                                | 6/197 [00:02<01:23,  2.29it/s][A
Running through epoch:   4%|██▎                                                          

EPOCH 9:
Meta Val Loss 0.4178
Meta Accuracy: 0.8621
Meta AUROC: 0.5507
Meta AUPR: 0.1324
PPint Test-Loss: 0.1989
PPint Accuracy: 0.9
PPint non-dimer AUROC: 0.8639
PPint non-dimer AUPR: 0.6044



Running through epoch:   0%|                                                                          | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                                 | 1/197 [00:00<01:24,  2.32it/s][A
Running through epoch:   1%|▋                                                                 | 2/197 [00:00<01:24,  2.30it/s][A
Running through epoch:   2%|█                                                                 | 3/197 [00:01<01:24,  2.30it/s][A
Running through epoch:   2%|█▎                                                                | 4/197 [00:01<01:23,  2.30it/s][A
Running through epoch:   3%|█▋                                                                | 5/197 [00:02<01:23,  2.30it/s][A
Running through epoch:   3%|██                                                                | 6/197 [00:02<01:23,  2.28it/s][A
Running through epoch:   4%|██▎                                                          

EPOCH 10:
Meta Val Loss 0.4644
Meta Accuracy: 0.8112
Meta AUROC: 0.5384
Meta AUPR: 0.1264
PPint Test-Loss: 0.198
PPint Accuracy: 0.9
PPint non-dimer AUROC: 0.86
PPint non-dimer AUPR: 0.6001



Running through epoch:   0%|                                                                          | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                                 | 1/197 [00:00<01:23,  2.35it/s][A
Running through epoch:   1%|▋                                                                 | 2/197 [00:00<01:23,  2.33it/s][A
Running through epoch:   2%|█                                                                 | 3/197 [00:01<01:23,  2.32it/s][A
Running through epoch:   2%|█▎                                                                | 4/197 [00:01<01:23,  2.31it/s][A
Running through epoch:   3%|█▋                                                                | 5/197 [00:02<01:23,  2.31it/s][A
Running through epoch:   3%|██                                                                | 6/197 [00:02<01:22,  2.31it/s][A
Running through epoch:   4%|██▎                                                          

EPOCH 11:
Meta Val Loss 0.4122
Meta Accuracy: 0.8856
Meta AUROC: 0.5824
Meta AUPR: 0.1428
PPint Test-Loss: 0.2085
PPint Accuracy: 0.906
PPint non-dimer AUROC: 0.8626
PPint non-dimer AUPR: 0.5926



Running through epoch:   0%|                                                                          | 0/197 [00:00<?, ?it/s][A
Running through epoch:   1%|▎                                                                 | 1/197 [00:00<01:23,  2.36it/s][A
Running through epoch:   1%|▋                                                                 | 2/197 [00:00<01:24,  2.31it/s][A
Running through epoch:   2%|█                                                                 | 3/197 [00:01<01:23,  2.31it/s][A
Running through epoch:   2%|█▎                                                                | 4/197 [00:01<01:23,  2.31it/s][A
Running through epoch:   3%|█▋                                                                | 5/197 [00:02<01:23,  2.30it/s][A
Running through epoch:   3%|██                                                                | 6/197 [00:02<01:22,  2.31it/s][A
Running through epoch:   4%|██▎                                                          

EPOCH 12:
Meta Val Loss 0.4373
Meta Accuracy: 0.8573
Meta AUROC: 0.556
Meta AUPR: 0.1296
PPint Test-Loss: 0.2038
PPint Accuracy: 0.902
PPint non-dimer AUROC: 0.8717
PPint non-dimer AUPR: 0.6172





0,1
Meta Accuracy,▁▇████▇▅██▇██
Meta Val-AUPR,▁▁▂▂▅▅▅▄▅▆▅█▆
Meta Val-AUROC,▂▂▁▂▅▄▄▄▃▆▅█▆
Meta Val-loss,█▁▁▁▁▁▁▁▁▁▁▁▁
PPint Accuracy,▁▄▆▆█▇▇█▆▇▇█▇
PPint Test-Loss,█▁▁▁▁▁▁▁▁▁▁▁▁
PPint Train-loss,█▄▃▂▂▂▂▁▁▁▁▁
PPint non-dimer AUPR,▁▄▅▆▇█▇▇▇▇▇▇█
PPint non-dimer AUROC,▁▅▇▇██▇█▇████

0,1
Meta Accuracy,0.8573
Meta Val-AUPR,0.12963
Meta Val-AUROC,0.55601
Meta Val-loss,0.43734
PPint Accuracy,0.902
PPint Test-Loss,0.20379
PPint Train-loss,0.05903
PPint non-dimer AUPR,0.61721
PPint non-dimer AUROC,0.87175
