In [None]:
import uuid, sys, os

import pandas as pd
import numpy as np
from tqdm import tqdm
import ast
import math
import random

from sklearn import metrics
from scipy import stats

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler
import pytorch_lightning as pl
from torch.optim import AdamW

torch.manual_seed(0)

from accelerate import Accelerator

import matplotlib.pyplot as plt
import seaborn as sns

from Levenshtein import distance as Ldistance

import training_utils.dataset_utils as data_utils
import training_utils.partitioning_utils as pat_utils

import importlib
import training_utils.train_utils as train_utils
importlib.reload(train_utils)

In [5]:
import wandb
wandb.login()

# n_estimators, max_depth = 20, 10
# wandb.init(
#     project="CLIP_PPint",
#     notes="commit message for the run",
#     config={
#         "n_estimators":n_estimators,
#         "max_depth" : max_depth
#     }
# )

True

In [6]:
os.chdir("/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts")
# print(os.getcwd())

print("PyTorch:", torch.__version__)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print("Current location:", os.getcwd())

PyTorch: 2.5.1
Using device: cpu
Current location: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts


In [None]:
# Model parameters
memory_verbose = False
use_wandb = True # Used to track loss in real-time without printing
model_save_steps = 1
train_frac = 1.0
test_frac = 1.0

embedding_dimension = 1152 # 1280 | 960 | 1152
number_of_recycles = 2
padding_value = -5000

batch_size = 20
learning_rate = 2e-5
EPOCHS = 15

In [None]:
## Model Class
### MiniClip 
def gaussian_kernel(x, sigma):
    return np.exp(-x**2 / (2 * sigma**2))

def transform_vector(vector, sigma):

    interacting_indices = np.where(vector == 1)[0]   # positions where vector == 1
    transformed_vector = np.zeros_like(vector, dtype=float)
    
    for i in range(len(vector)):
        if vector[i] == 0:
            distances = np.abs(interacting_indices - i)   # distance to all "1"s
            min_distance = np.min(distances)              # closest "1"
            transformed_vector[i] = gaussian_kernel(min_distance, sigma)
        else:
            transformed_vector[i] = 1.0
    return transformed_vector

def safe_shuffle(n, device):
    shuffled = torch.randperm(n, device=device)
    while torch.any(shuffled == torch.arange(n, device=device)):
        shuffled = torch.randperm(n, device=device)
    return shuffled

def create_key_padding_mask(embeddings, padding_value=-5000, offset=10):
    return (embeddings < (padding_value + offset)).all(dim=-1)

def create_mean_of_non_masked(embeddings, padding_mask):
    # Use masked select and mean to compute the mean of non-masked elements
    # embeddings should be of shape (batch_size, seq_len, features)
    seq_embeddings = []
    for i in range(embeddings.shape[0]): # looping over all batch elements
        non_masked_embeddings = embeddings[i][~padding_mask[i]] # shape [num_real_tokens, features]
        if len(non_masked_embeddings) == 0:
            print("You are masking all positions when creating sequence representation")
            sys.exit(1)
        mean_embedding = non_masked_embeddings.mean(dim=0) # sequence is represented by the single vecotr [1152] [features]
        seq_embeddings.append(mean_embedding)
    return torch.stack(seq_embeddings)

class MiniCLIP_w_transformer_crossattn(pl.LightningModule):

    def __init__(self, padding_value = -5000, embed_dimension=1152, num_recycles=1):

        super().__init__()
        self.num_recycles = num_recycles # how many times you iteratively refine embeddings with self- and cross-attention (ALPHA-Fold-style recycling).
        self.padding_value = padding_value
        self.embed_dimension = embed_dimension

        self.logit_scale = nn.Parameter(torch.tensor(math.log(1/0.07)))  # ~CLIP init

        self.transformerencoder =  nn.TransformerEncoderLayer(
            d_model=self.embed_dimension,
            nhead=8,
            dropout=0.1,
            batch_first=True,
            dim_feedforward=self.embed_dimension
            )
 
        self.norm = nn.LayerNorm(self.embed_dimension)  # For residual additions

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=self.embed_dimension,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        self.prot_embedder = nn.Sequential(
            nn.Linear(self.embed_dimension, 640),
            nn.ReLU(),
            nn.Linear(640, 320),
        )
        
    def forward(self, pep_input, prot_input, label=None, pep_int_mask=None, prot_int_mask=None, int_prob=None, mem_save=True): # , pep_tokens, prot_tokens

        pep_mask = create_key_padding_mask(embeddings=pep_input, padding_value=self.padding_value)
        prot_mask = create_key_padding_mask(embeddings=prot_input, padding_value=self.padding_value)
 
        # Initialize residual states
        pep_emb = pep_input.clone()
        prot_emb = prot_input.clone()

 
        for _ in range(self.num_recycles):

            # Transformer encoding with residual
            pep_trans = self.transformerencoder(self.norm(pep_emb), src_key_padding_mask=pep_mask)
            prot_trans = self.transformerencoder(self.norm(prot_emb), src_key_padding_mask=prot_mask)

            # Cross-attention with residual
            pep_cross, _ = self.cross_attn(query=self.norm(pep_trans), key=self.norm(prot_trans), value=self.norm(prot_trans), key_padding_mask=prot_mask)
            prot_cross, _ = self.cross_attn(query=self.norm(prot_trans), key=self.norm(pep_trans), value=self.norm(pep_trans), key_padding_mask=pep_mask)
            
            # Additive update with residual connection
            pep_emb = pep_emb + pep_cross  
            prot_emb = prot_emb + prot_cross

        pep_seq_coding = create_mean_of_non_masked(pep_emb, pep_mask)
        prot_seq_coding = create_mean_of_non_masked(prot_emb, prot_mask)
        
        # Use self-attention outputs for embeddings
        pep_seq_coding = F.normalize(self.prot_embedder(pep_seq_coding))
        prot_seq_coding = F.normalize(self.prot_embedder(prot_seq_coding))
 
        if mem_save:
            torch.cuda.empty_cache()
        
        scale = torch.exp(self.logit_scale).clamp(max=100.0)
        logits = scale * (pep_seq_coding * prot_seq_coding).sum(dim=-1) # Dot-Product for comparison
        
        return logits
    

    def training_step(self, batch, device):
        # Passing the sequences to the models
 
        embedding_pep = batch[0]
        embedding_prot = batch[1]
        binder_label = batch[2]
 
        embedding_pep = embedding_pep.to(device)
        embedding_prot = embedding_prot.to(device)
        binder_label = binder_label.to(device)

        logits = self.forward(embedding_pep, embedding_prot)
        loss = F.binary_cross_entropy_with_logits(logits, binder_label)
 
        torch.cuda.empty_cache()
        return loss
    
    def validation_step(self, batch, device):
        # Predict on random batches of training batch size
        embedding_pep, embedding_prot, binder_label = batch
 
        # Move to device
        embedding_pep  = embedding_pep.to(device)
        embedding_prot = embedding_prot.to(device)
        # Ensure float dtype and shape matches logits later
        binder_label   = binder_label.to(device).float()

        with torch.no_grad():
            # Forward -> logits per pair [B]
            logits = self.forward(embedding_pep, embedding_prot)   # shape [B]

            # BCE loss expects float labels with same shape
            binder_label = binder_label.view_as(logits)
            loss = F.binary_cross_entropy_with_logits(logits, binder_label)

            # Simple metrics for binary classification
            probs = torch.sigmoid(logits)                          # [B]
            preds = (probs >= 0.5).float()                         # [B]
            acc   = (preds == binder_label).float().mean()         # scalar

            # Optionally return detached CPU scalars
            return loss.detach(), acc.detach()

In [9]:
## Output path
trained_model_dir = "/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts"

## Embeddings paths
binders_embeddings = "/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/meta_analysis/binders_embeddings"
targets_embeddings = "/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/meta_analysis/targets_embeddings"

# ## Training variables
runID = uuid.uuid4()

def print_mem_consumption():
    # 1. Total memory available on the GPU (device 0)
    t = torch.cuda.get_device_properties(0).total_memory
    # 2. How much memory PyTorch has *reserved* from CUDA
    r = torch.cuda.memory_reserved(0)
    # 3. How much of that reserved memory is actually *used* by tensors
    a = torch.cuda.memory_allocated(0)
    # 4. Reserved but not currently allocated (so “free inside PyTorch’s pool”)
    f = r - a

    print("Total memory: ", t/1e9)      # total VRAM in GB
    print("Reserved memory: ", r/1e9)   # PyTorch’s reserved pool in GB
    print("Allocated memory: ", a//1e9) # actually in use (integer division)
    print("Free memory: ", f/1e9)       # slack in the reserved pool in GB

#### Loading data frame

In [10]:
### Loading the dataset
interaction_df = pd.read_csv("../data/meta_analysis/interaction_df_metaanal.csv", index_col = 0).drop(columns = ["binder_id", "target_id"]).rename(columns={
    "A_seq" : "binder_seq",
    "B_seq" : "target_seq"
})

all_targets = interaction_df.target_id_mod.unique()
binder_nonbinder = interaction_df.binder.value_counts()
target_binder_nonbinder_Dict = dict(interaction_df.groupby("target_id_mod")["binder"].value_counts())
sorted_items = sorted(target_binder_nonbinder_Dict.items(), key=lambda kv: kv[1], reverse=True)

# %%
# Annotating each observation with a weight corresponding to whether it is considered a binder or not
N_bins = len(interaction_df["binder"].value_counts())
pr_class_uniform_weight = 1 / N_bins
pr_class_weight_informed_with_size_of_bins = pr_class_uniform_weight  / interaction_df["binder"].value_counts()
pr_class_weight_informed_with_size_of_bins = pr_class_weight_informed_with_size_of_bins.to_dict()
interaction_df["observation_weight"] = interaction_df.binder.apply(lambda x: pr_class_weight_informed_with_size_of_bins[x])
weights_Dict = dict(zip(interaction_df["target_binder_ID"], interaction_df["observation_weight"]))
interaction_df

Unnamed: 0,binder_chain,target_chains,binder,binder_seq,target_seq,target_id_mod,target_binder_ID,observation_weight
0,A,"[""B""]",False,LDFIVFAGPEKAIKFYKEMAKRNLEVKIWIDGDWAVVQVK,ANPYISVANIMLQNYVKQREKYNYDTLKEQFTFIKNASTSIVYMQF...,VirB8,VirB8_1,0.000159
1,A,"[""B""]",False,SEQDETMHRIVRSVIQHAYKHNDEMAEYFAQNAAEIYKEQNKSEEA...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_1,0.000159
2,A,"[""B""]",False,DYKQLKKHATKLLELAKKDPSSKRDLLRTAASYANKVLFEDSDPRA...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_2,0.000159
3,A,"[""B""]",False,DEKEELERRANRVAFLAIQIQNEEYHRILAELYVQFMKAAENNDTE...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_3,0.000159
4,A,"[""B""]",False,PDNKEKLMSIAVQLILRINEAARSEEQWRYANRAAFAAVEASSGSD...,RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...,FGFR2,FGFR2_4,0.000159
...,...,...,...,...,...,...,...,...
3527,A,"[""B""]",False,DLRKYAAELVDRLAEKYNLDSDQYNALVRLASELVWQGKSKEEIEK...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_62,0.000159
3528,A,"[""B""]",False,SKEEIKKEAEELIEELKKKGYNLPLRILEFALKEIEETNSEKYYEQ...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_63,0.000159
3529,A,"[""B""]",False,SPEYKKFLELIKEAEAARKAGDLDKAKELLEKALELAKKMKAKSLI...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_64,0.000159
3530,A,"[""B""]",False,DPLLAYKLLKLSQKALEKAYAEDRERAEELLEEAEAALRSLGDEAG...,ELCDDDPPEIPHATFKAMAYKEGTMLNCECKRGFRRIKSGSLYMLC...,IL2Ra,IL2Ra_65,0.000159


#### Setting aside some targets for validation

In [11]:
targets = 0
targets_all = 0
for i in target_binder_nonbinder_Dict.keys():
    targets_all += target_binder_nonbinder_Dict[i]
    if i[0] in ["IL7Ra", "sntx_2", "sntx", "VirB8","Mdm2", "IL10Ra", "IL2Ra"]:
        targets += target_binder_nonbinder_Dict[i]
print(targets_all-targets, targets)
print(targets / (targets_all-targets))

print('Targets for validation : ["IL7Ra", "sntx_2", "sntx", "VirB8","Mdm2", "IL10Ra", "IL2Ra"]')

3029 503
0.16606140640475403
Targets for validation : ["IL7Ra", "sntx_2", "sntx", "VirB8","Mdm2", "IL10Ra", "IL2Ra"]


#### Creating separate targets/ binder dataframes

In [12]:
# Targets df
target_df = interaction_df[["target_id_mod","target_seq"]].rename(columns={"target_seq":"sequence", "target_id_mod" : "ID"})
target_df["seq_len"] = target_df["sequence"].apply(len)
target_df = target_df.drop_duplicates(subset=["ID","sequence"])
target_df = target_df.set_index("ID")

# Binders df
binder_df = interaction_df[["target_binder_ID","binder_seq", "binder"]].rename(columns={"binder_seq":"sequence", "target_binder_ID" : "ID", "binder" : "label"})
binder_df["seq_len"] = binder_df["sequence"].apply(len)
binder_df = binder_df.set_index("ID")
binder_df["observation_weight"] = binder_df.index.map(weights_Dict)

# Interaction Dict
interaction_Dict = dict(enumerate(zip(interaction_df["target_id_mod"], interaction_df["target_binder_ID"]), start=1))

# Printing dataframes for review
print(target_df.head(3), "\n")
print(binder_df.head(3))

                                                sequence  seq_len
ID                                                               
VirB8  ANPYISVANIMLQNYVKQREKYNYDTLKEQFTFIKNASTSIVYMQF...      138
FGFR2  RSPHRPILQAGLPANASTVVGGDVEFVCKVYSDAQPHIQWIKHVPY...      101
IL7Ra  DYSFSCYSQLEVNGSQHSLTCAFEDPDVNTTNLEFEICGALVEVKC...      193 

                                                  sequence  label  seq_len  \
ID                                                                           
VirB8_1           LDFIVFAGPEKAIKFYKEMAKRNLEVKIWIDGDWAVVQVK  False       40   
FGFR2_1  SEQDETMHRIVRSVIQHAYKHNDEMAEYFAQNAAEIYKEQNKSEEA...  False       62   
FGFR2_2  DYKQLKKHATKLLELAKKDPSSKRDLLRTAASYANKVLFEDSDPRA...  False       61   

         observation_weight  
ID                           
VirB8_1            0.000159  
FGFR2_1            0.000159  
FGFR2_2            0.000159  


#### Creating separate targets/ binder dataframes (for validation/ training)

In [13]:
target_df_val = target_df.loc[["IL7Ra", "sntx_2", "sntx", "VirB8","Mdm2", "IL10Ra", "IL2Ra"]]
target_df_train = target_df.loc[target_df.index.difference(["IL7Ra", "sntx_2", "sntx", "VirB8","Mdm2", "IL10Ra", "IL2Ra"])]
# print(len(target_df_val), len(target_df_train))

idx = binder_df.index.astype(str)
mask = pd.Series(idx).str.startswith(tuple(["IL7Ra", "sntx_2", "sntx", "VirB8","Mdm2", "IL10Ra", "IL2Ra"])).to_numpy()

binder_df_val   = binder_df[mask]    # IDs whose index starts with any of the names
binder_df_train = binder_df[~mask]   # everything else
print(f"Number of instances for training : {len(binder_df_train)}, \nNumber of instances for validation : {len(binder_df_val)}")

Number of instances for training : 3029, 
Number of instances for validation : 503


In [14]:
binder_df = binder_df.reset_index()
target_df = target_df.reset_index()

binder_df_val = binder_df_val.reset_index()
binder_df_train = binder_df_train.reset_index()
target_df_val = target_df_val.reset_index()
target_df_train = target_df_train.reset_index()

binder_df_train

Unnamed: 0,ID,sequence,label,seq_len,observation_weight
0,FGFR2_1,SEQDETMHRIVRSVIQHAYKHNDEMAEYFAQNAAEIYKEQNKSEEA...,False,62,0.000159
1,FGFR2_2,DYKQLKKHATKLLELAKKDPSSKRDLLRTAASYANKVLFEDSDPRA...,False,61,0.000159
2,FGFR2_3,DEKEELERRANRVAFLAIQIQNEEYHRILAELYVQFMKAAENNDTE...,False,64,0.000159
3,FGFR2_4,PDNKEKLMSIAVQLILRINEAARSEEQWRYANRAAFAAVEASSGSD...,False,64,0.000159
4,FGFR2_5,DDKEHLTKVAREAAKELNDPRAEEAVKIWEHNIDRFSHAAQLAQSV...,False,63,0.000159
...,...,...,...,...,...
3024,EGFR_2_298,CKLVSATVTVDSSTGQAQVVAKNECLGVQTFTAATAAEALAKMQAAIAA,False,49,0.000159
3025,EGFR_2_299,EQEKADVINEYREKKAFAFF,False,20,0.000159
3026,EGFR_2_300,SEETKAKAEELKTKALEAKYKAAELLAKGDELYKEAPKSKEAADKA...,False,71,0.000159
3027,EGFR_2_301,SEETKAKAEELQTKALEAKYKAAELLAKGDELYKEAPKSKEAADKA...,False,71,0.000159


In [11]:
print(len(binder_df_train[binder_df_train.ID.str.startswith("IL7Ra")]))
print(len(binder_df_val[binder_df_val.ID.str.startswith("IL7Ra")]))

0
171


- `CLIP_meta_analysis_dataset` for defining binders/ target validation/ training datasets seperately
- `PairDataset` for pairing targets with binders for training and validation seperately

In [15]:
class CLIP_meta_analysis_dataset(Dataset):

    def __init__(self, sequence_df, esm_encoding_paths, embedding_dim=1152, padding_value=-5000):

        super(CLIP_meta_analysis_dataset, self).__init__()

        self.sigma = 1
        self.sequence_df = sequence_df
        self.max_length = sequence_df["seq_len"].max()
        # self.sequence_df["index_num"] = np.arange(len(self.sequence_df))
        # print(self.sequence_df)
        self.esm_encoding_paths = esm_encoding_paths
        num_samples = len(self.sequence_df)
        self.x = torch.full((num_samples, self.max_length, embedding_dim), padding_value, dtype=torch.float32)

        # inside CLIP_meta_analysis_dataset.__init__
        self.accessions = self.sequence_df["ID"].astype(str).tolist()
        self.name_to_row = {name: i for i, name in enumerate(self.accessions)}
        
        # Load embeddings into the pre-allocated tensor
        all_acc_list = self.sequence_df.ID.astype(str).tolist()
        all_acc_loaded_list = []

        iterator = tqdm(all_acc_list, position=0, leave=True, total=num_samples, desc="# Reading in ESM-embeddings from folder")
        for i, accession in enumerate(iterator):
            npy_path = os.path.join(esm_encoding_paths, f"{accession}.npy")
            try:
                embd = np.load(npy_path)[0]
                length_to_pad = self.max_length - len(embd)
                if length_to_pad > 0:
                    zero_padding = np.ones((length_to_pad, embd.shape[1])) * padding_value
                    padded_array = np.concatenate((embd, zero_padding), axis=0)
                else:
                    padded_array = embd[:self.max_length] 
                self.x[i] = torch.tensor(padded_array, dtype=torch.float32)
                all_acc_loaded_list.append(accession)
            except FileNotFoundError as e:
                raise FileNotFoundError(f"Embedding file {accession}.npy not found.")
            
        missing = sorted(set(all_acc_list) - set(all_acc_loaded_list))
        if missing:
            raise FileNotFoundError(
                f"Missing {len(missing)} embedding files in '{esm_encoding_paths}'. "
                f"Examples: {missing}")
        
                
    def __len__(self):
        return int(self.x.shape[0])

    def __getitem__(self, idx):
        return self.x[idx]
    
    # add a helper:
    def get_by_name(self, name: str):
        return self.x[self.name_to_row[name]]
    
class PairDataset(Dataset):
    def __init__(self, targets_dataset, binders_dataset, binder_df):
        self.targets = targets_dataset    # expects string-key access
        self.binders = binders_dataset    # expects string-key access
        self.binder_df = binder_df

    def __len__(self):
        # if you won't use DataLoader's default sampling, this can be any valid length;
        # keeping it tied to binders is reasonable if it implements __len__.
        return len(self.binders)

    def __getitem__(self, idx):
        # idx is integer position from DataLoader
        row = self.binder_df.iloc[idx]
        bname = row["ID"]                 # binder name string
        parts = bname.split("_")

        if parts[0] == "SARS":
            tname = "SARS_CoV2_RBD"
        elif len(parts) == 3:
            tname = f"{parts[0]}_{parts[1]}"
        elif len(parts) == 2:
            tname = parts[0]
        else:
            tname = parts[0]

        binder_emb = self.binders.get_by_name(bname)          # tensor [Lb, D]
        target_emb = self.targets.get_by_name(tname)    # tensor [Lt, D]
        label = float(row["label"])              # 0.0 or 1.0
        label = torch.tensor(label, dtype=torch.float32)

        return binder_emb, target_emb, label
    
    def _get_observation_weights(self):
        return self.binder_df["observation_weight"].tolist()
    
targets_dataset_val = CLIP_meta_analysis_dataset(target_df_val, esm_encoding_paths="/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/meta_analysis/targets_embeddings", embedding_dim=1152)
binders_dataset_val = CLIP_meta_analysis_dataset(binder_df_val, esm_encoding_paths="/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/meta_analysis/binders_embeddings", embedding_dim=1152)

targets_dataset_train = CLIP_meta_analysis_dataset(target_df_train, esm_encoding_paths="/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/meta_analysis/targets_embeddings", embedding_dim=1152)
binders_dataset_train = CLIP_meta_analysis_dataset(binder_df_train, esm_encoding_paths="/zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/data/meta_analysis/binders_embeddings", embedding_dim=1152)

training_dataset = PairDataset(targets_dataset_train, binders_dataset_train, binder_df_train)
validation_dataset = PairDataset(targets_dataset_val, binders_dataset_val, binder_df_val)

# Reading in ESM-embeddings from folder: 100%|██████████| 7/7 [00:00<00:00, 498.29it/s]


# Reading in ESM-embeddings from folder: 100%|██████████| 503/503 [00:00<00:00, 855.61it/s]
# Reading in ESM-embeddings from folder: 100%|██████████| 9/9 [00:00<00:00, 208.46it/s]
# Reading in ESM-embeddings from folder: 100%|██████████| 3029/3029 [00:03<00:00, 778.01it/s]


`validation_dataset` / `training_dataset` store embedding of binder, embedding of target and label (binder - 1, non-binder - 0)

In [16]:
binder_emb, target_emb, label = validation_dataset[45]

print(f"Label assigned (binder - 1, non-binder - 0) : {label}")
print(f"Shape of the target embedding : {target_emb.shape}")
print(f"Shape of the target embedding : {binder_emb.shape}")

print(f"Size of the validation dataset : {len(validation_dataset)}")
print(f"Size of training dataset : {len(training_dataset)}")

Label assigned (binder - 1, non-binder - 0) : 1.0
Shape of the target embedding : torch.Size([207, 1152])
Shape of the target embedding : torch.Size([200, 1152])
Size of the validation dataset : 503
Size of training dataset : 3029


### Loading pretrained model for finetuning

In [17]:
path = '../PPI_PLM/models/CLIP_no_structural_information/a1d0549b-3f90-4ce2-b795-7bca2276cb07_checkpoint_4/a1d0549b-3f90-4ce2-b795-7bca2276cb07_checkpoint_epoch_4.pth'
checkpoint = torch.load(path, weights_only=False, map_location=torch.device('cpu'))
# print(list(checkpoint["model_state_dict"]))
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = MiniCLIP_w_transformer_crossattn()
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
# model.train()

MiniCLIP_w_transformer_crossattn(
  (transformerencoder): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=1152, out_features=1152, bias=True)
    )
    (linear1): Linear(in_features=1152, out_features=1152, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=1152, out_features=1152, bias=True)
    (norm1): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (norm): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
  (cross_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=1152, out_features=1152, bias=True)
  )
  (prot_embedder): Sequential(
    (0): Linear(in_features=1152, out_features=640, bias=True)
    (1): ReLU()
    (2): Linear(in_features=640, out_features=32

### Loading training and vlaidation datasets (DataLoaders)

In [18]:
train_loader = DataLoader(training_dataset, batch_size=20)
val_loader = DataLoader(validation_dataset, batch_size=20)

# interaction_Dict = dict(enumerate(zip(interaction_df["target_id_mod"], interaction_df["target_binder_ID"]), start=1))

# ### Creating the DataLoaders
# train_weights = training_dataset._get_observation_weights()
# val_weights = validation_dataset._get_observation_weights()

batch = next(iter(val_loader))
print(f"Shape of the binders embeddings : {batch[0].shape}")
print(f"Shape of the targets embeddings : {batch[1].shape}")
print(f"Labels (0 - non-binder, 1 - binder) : {batch[2]}")

with torch.no_grad():
    loss = model.training_step(batch,device)
    print(loss)
    loss = model.validation_step(batch,device)
    print(loss)

Shape of the binders embeddings : torch.Size([20, 200, 1152])
Shape of the targets embeddings : torch.Size([20, 207, 1152])
Labels (0 - non-binder, 1 - binder) : tensor([0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.])
tensor(2.8326)
(tensor(2.8346), tensor(0.1000))


In [19]:
# import gc, torch
# del obj  # any large temps you created in the cell
# gc.collect()
# if torch.cuda.is_available():
#     torch.cuda.empty_cache()
#     torch.cuda.ipc_collect()

# print_mem_consumption()
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum(p.numel() for p in model.parameters() if p.requires_grad)

torch.cuda.empty_cache()
# print_mem_consumption()

In [26]:
__, __, label = validation_dataset[4]
label

items = [validation_dataset[i] for i in range(0, 17)]
binders = [binder_emb for (binder_emb, _, _) in items]
targets = [target_emb for (_, target_emb, _) in items]
labels  = torch.tensor([float(label) for (_, _, label) in items], dtype=torch.float32).detach().cpu().numpy()

with torch.no_grad():
    model.eval()
    scores = model.forward(binders, targets).detach().cpu().numpy()

pos_scores, neg_scores = [], []
if np.any(y == 1.0):
    pos_scores.extend(scores[y == 1.0].tolist())
if np.any(y == 0.0):
    neg_scores.extend(scores[y == 0.0].tolist())

TypeError: '<' not supported between instances of 'list' and 'int'

#### Training loop

In [None]:
def batch(iterable, n=18):
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]

class TrainWrapper_MetaAnal():

    def __init__(self, model, training_loader, validation_loader, test_dataset, 
                 optimizer, EPOCHS, runID, device, test_indexes_for_auROC=None,
                 auROC_batch_size=18, model_save_steps=False, model_save_path=False, 
                 v=False, wandb_tracker=False):
        
        self.model = model 
        self.training_loader = training_loader
        self.validation_loader = validation_loader
        self.EPOCHS = EPOCHS
        self.wandb_tracker = wandb_tracker
        self.model_save_steps = model_save_steps
        self.verbose = v
        self.best_vloss = float('inf')
        self.optimizer = optimizer
        self.runID = runID
        self.trained_model_dir = model_save_path
        self.print_frequency_loss = 1
        self.device = device
        self.test_indexes_for_auROC = test_indexes_for_auROC
        self.auROC_batch_size = auROC_batch_size
        self.test_dataset = test_dataset

    def train_one_epoch(self):

        self.model.train()
        running_loss = 0.0

        for batch_data in tqdm(self.training_loader, total=len(self.training_loader), desc="Running through epoch"):

            if batch_data[0].size(0) == 1:
                continue
            
            self.optimizer.zero_grad()
            loss = self.model.training_step(batch_data, self.device)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()

        return running_loss / max(1, len(self.training_loader))

    def calc_auroc_aupr_on_indexes(self, model, dataset, obs_indexes, batch_size=18, pad_value=-5000.0):

        model.eval()
        pos_scores, neg_scores = [], []
        batched_data = batch(obs_indexes, n=batch_size)

        with torch.no_grad():
            for index_batch in tqdm(batched_data, total= len(obs_indexes)/batch_size, desc="Calculating AUC"):

                # Gather items: dataset[i] -> (binder_emb [L,D], target_emb [L,D], label)
                items = [dataset[i] for i in index_batch]
                binders = [binder_emb for (binder_emb, _, _) in items]
                targets = [target_emb for (_, target_emb, _) in items]
                labels  = torch.tensor([float(label) for (_, _, label) in items], dtype=torch.float32, device=self.device).detach().cpu().numpy()

                # Forward: logits per pair [B]
                logits = model.forward(binders, targets).detach().cpu().numpy()

                all_scores.extend(logits.tolist())
                all_labels.extend(labels.tolist())

        all_scores = np.array(all_scores, dtype=np.float64)
        all_labels = np.array(all_labels, dtype=np.int64)

        fpr, tpr, _ = metrics.roc_curve(all_labels, all_scores)
        auroc = metrics.roc_auc_score(all_labels, all_scores)
        aupr  = metrics.average_precision_score(all_labels, all_scores)

        return auroc, aupr, fpr, tpr

    
    def validate(self, dataloader, indexes_for_auc=False, auROC_dataset=False):

        self.model.eval()
        running_loss, running_acc = 0.0, 0.0

        with torch.no_grad():
            for batch_data in tqdm(dataloader, total=len(dataloader), desc="Validation"):

                if batch_data[0].size(0) == 1:
                    continue

                loss, acc = self.model.validation_step(batch_data, self.device)  # your updated val step
                running_loss += loss.item()
                running_acc  += acc.item()

        val_loss = running_loss / max(1, len(dataloader))
        val_acc  = running_acc  / max(1, len(dataloader))

        return val_loss, val_acc

    def train_model(self):

        if self.verbose:
            print(f"Training model {str(self.runID)}")
        
        # Pre-training snapshot
        val_loss, val_accuracy  = self.validate(dataloader=self.validation_loader, indexes_for_auc=self.test_indexes_for_auROC, auROC_dataset=self.test_dataset)

        if self.verbose:
            print(f'Before training - Validation Loss {round(val_loss,4)} | Validation Accuracy: {round(val_accuracy,4)}')

        if self.test_indexes_for_auROC is not None:
            auroc, aupr, _, _ = self.calc_auroc_aupr_on_indexes(self.model, self.test_dataset, self.test_indexes_for_auROC, batch_size=self.auROC_batch_size)

        if self.wandb_tracker:
            log_items = {
                    "Training Loss": train_loss,
                    "Validation Loss": val_loss,
                    "Validation Accuracy": val_accuracy
                }
        if auroc is not None:
            log_items["Val AUROC"] = auroc
            log_items["Val AUPR"]  = aupr
        self.wandb_tracker.log(log_items)
            
        for epoch in tqdm(range(1, self.EPOCHS + 1), total=self.EPOCHS, desc="Epochs"):

            train_loss = self.train_one_epoch()
            val_loss, val_accuracy = self.validate(dataloader=self.validation_loader, indexes_for_auc=self.test_indexes_for_auROC, auROC_dataset=self.test_dataset)

            if self.test_indexes_for_auROC is not None:
                auroc, aupr, _, _ = self.calc_auroc_aupr_on_indexes(self.model, self.test_dataset, self.test_indexes_for_auROC, batch_size=self.auROC_batch_size)

            if self.model_save_steps and (epoch % self.model_save_steps == 0):

                check_point_folder = os.path.join(self.trained_model_dir, f"{str(self.runID)}_checkpoint_{str(epoch)}")
                
                if self.verbose:
                    print("Saving model to:", check_point_folder)

                os.makedirs(check_point_folder, exist_ok=True)
                checkpoint_path = os.path.join(check_point_folder, f"{str(self.runID)}_checkpoint_epoch_{str(epoch)}.pth")

                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'val_loss': val_loss
                    }, 
                    checkpoint_path)

            if self.verbose and (epoch % self.print_frequency_loss == 0):
                print(f'EPOCH {epoch} - Training Loss {round(train_loss,4)} | Validation Loss {round(val_loss,4)} | Validation Accuracy {round(val_accuracy,4)}')

            if self.wandb_tracker:
                log_items = {
                    "Training Loss": train_loss,
                    "Validation Loss": val_loss,
                    "Validation Accuracy": val_accuracy,
                    }
                
            if self.test_indexes_for_auROC is not None:
                log_items["Val AUROC"] = auroc
                log_items["Val AUPR"]  = aupr

        if self.wandb_tracker:
            self.wandb_tracker.finish()

In [None]:
batch_size = 20
learning_rate = 2e-5
EPOCHS = 15

if use_wandb:
    import wandb
    wandb.login(key = "f8a6d759fe657b095d56bddbdb4d586dfaebd468")
    wandb.init(
        project="CLIP_PPint",
        name=str(runID),
        config={
            "learning_rate": learning_rate,
            "architecture": "MiniCLIP_w_transformer_crossattn", # model.__class__.__name__
            "batch_size": batch_size,
            "dataset": "Meta analysis dataset",
            "training_procedure": "new_binary_cross",
        },
    )
else:
    print("WandB Tracking not used")
    wandb = None

# --- Optimizer ---
optimizer = AdamW(model.parameters(), lr=learning_rate)

# --- Accelerator ---
accelerator = Accelerator()  # device picked automatically (CPU if no GPU)
device = accelerator.device

# Prepare objects for distributed/mixed precision/device
model, optimizer, train_loader, val_loader = accelerator.prepare(model, optimizer, train_loader, val_loader)

# If using wandb, watch the UNWRAPPED model (Accelerate wraps the module)
if wandb is not None:
    wandb.watch(accelerator.unwrap_model(model), log="all", log_freq=100)

# --- Train wrapper ---

training_wrapper = TrainWrapper_MetaAnal(
                    model=model,
                    training_loader=train_loader,
                    validation_loader=val_loader,
                    test_dataset=validation_dataset,
                    optimizer=optimizer,
                    EPOCHS=EPOCHS,
                    runID=runID,
                    device=device,                      # pass accelerator.device
                    test_indexes_for_auROC = binder_df_val.index.to_list(),
                    model_save_steps=model_save_steps,
                    model_save_path=trained_model_dir,
                    v=True,
                    wandb_tracker=wandb                 # pass None when not using wandb
)

training_wrapper.train_model()

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /zhome/c9/0/203261/.netrc
[34m[1mwandb[0m: Currently logged in as: [33ms232958[0m ([33ms232958-danmarks-tekniske-universitet-dtu[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Training model 9897baa2-00d5-439e-9518-929c71219cc2


Validation: 100%|█████████████████████████████████████████████████████████████████████████████| 26/26 [00:01<00:00, 15.80it/s]


Before training - Val loss 1.9654 | Acc: 0.2731


Epochs:   0%|                                                                                          | 0/15 [00:00<?, ?it/s]
Running through epoch:   0%|                                                                          | 0/152 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                 | 1/152 [00:00<00:54,  2.80it/s][A
Running through epoch:   1%|▊                                                                 | 2/152 [00:00<00:43,  3.47it/s][A
Running through epoch:   2%|█▎                                                                | 3/152 [00:00<00:39,  3.76it/s][A
Running through epoch:   3%|█▋                                                                | 4/152 [00:01<00:37,  3.94it/s][A
Running through epoch:   3%|██▏                                                               | 5/152 [00:01<00:36,  4.03it/s][A
Running through epoch:   4%|██▌                                                              

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/9897baa2-00d5-439e-9518-929c71219cc2_checkpoint_1


Epochs:   7%|█████▍                                                                            | 1/15 [00:37<08:49, 37.84s/it]

EPOCH 1 - Train loss 0.3214 | Val loss 0.7833 | Val acc 0.65



Running through epoch:   0%|                                                                          | 0/152 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                 | 1/152 [00:00<00:53,  2.85it/s][A
Running through epoch:   1%|▊                                                                 | 2/152 [00:00<00:42,  3.55it/s][A
Running through epoch:   2%|█▎                                                                | 3/152 [00:00<00:38,  3.86it/s][A
Running through epoch:   3%|█▋                                                                | 4/152 [00:01<00:36,  4.03it/s][A
Running through epoch:   3%|██▏                                                               | 5/152 [00:01<00:35,  4.09it/s][A
Running through epoch:   4%|██▌                                                               | 6/152 [00:01<00:35,  4.16it/s][A
Running through epoch:   5%|███                                                          

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/9897baa2-00d5-439e-9518-929c71219cc2_checkpoint_2


Epochs:  13%|██████████▉                                                                       | 2/15 [01:16<08:17, 38.28s/it]

EPOCH 2 - Train loss 0.2862 | Val loss 0.7222 | Val acc 0.6462



Running through epoch:   0%|                                                                          | 0/152 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                 | 1/152 [00:00<00:48,  3.13it/s][A
Running through epoch:   1%|▊                                                                 | 2/152 [00:00<00:39,  3.75it/s][A
Running through epoch:   2%|█▎                                                                | 3/152 [00:00<00:37,  3.99it/s][A
Running through epoch:   3%|█▋                                                                | 4/152 [00:01<00:36,  4.10it/s][A
Running through epoch:   3%|██▏                                                               | 5/152 [00:01<00:32,  4.47it/s][A
Running through epoch:   4%|██▌                                                               | 6/152 [00:01<00:33,  4.38it/s][A
Running through epoch:   5%|███                                                          

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/9897baa2-00d5-439e-9518-929c71219cc2_checkpoint_3


Epochs:  20%|████████████████▍                                                                 | 3/15 [01:54<07:39, 38.28s/it]

EPOCH 3 - Train loss 0.2748 | Val loss 0.7203 | Val acc 0.6577



Running through epoch:   0%|                                                                          | 0/152 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                 | 1/152 [00:00<00:38,  3.87it/s][A
Running through epoch:   1%|▊                                                                 | 2/152 [00:00<00:36,  4.11it/s][A
Running through epoch:   2%|█▎                                                                | 3/152 [00:00<00:35,  4.22it/s][A
Running through epoch:   3%|█▋                                                                | 4/152 [00:00<00:34,  4.28it/s][A
Running through epoch:   3%|██▏                                                               | 5/152 [00:01<00:34,  4.32it/s][A
Running through epoch:   4%|██▌                                                               | 6/152 [00:01<00:33,  4.39it/s][A
Running through epoch:   5%|███                                                          

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/9897baa2-00d5-439e-9518-929c71219cc2_checkpoint_4


Epochs:  27%|█████████████████████▊                                                            | 4/15 [02:32<07:00, 38.23s/it]

EPOCH 4 - Train loss 0.2625 | Val loss 0.6912 | Val acc 0.6731



Running through epoch:   0%|                                                                          | 0/152 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                 | 1/152 [00:00<00:40,  3.74it/s][A
Running through epoch:   1%|▊                                                                 | 2/152 [00:00<00:38,  3.89it/s][A
Running through epoch:   2%|█▎                                                                | 3/152 [00:00<00:36,  4.05it/s][A
Running through epoch:   3%|█▋                                                                | 4/152 [00:00<00:35,  4.17it/s][A
Running through epoch:   3%|██▏                                                               | 5/152 [00:01<00:34,  4.22it/s][A
Running through epoch:   4%|██▌                                                               | 6/152 [00:01<00:34,  4.25it/s][A
Running through epoch:   5%|███                                                          

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/9897baa2-00d5-439e-9518-929c71219cc2_checkpoint_5


Epochs:  33%|███████████████████████████▎                                                      | 5/15 [03:11<06:23, 38.31s/it]

EPOCH 5 - Train loss 0.2484 | Val loss 0.6769 | Val acc 0.7077



Running through epoch:   0%|                                                                          | 0/152 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                 | 1/152 [00:00<00:38,  3.89it/s][A
Running through epoch:   1%|▊                                                                 | 2/152 [00:00<00:36,  4.10it/s][A
Running through epoch:   2%|█▎                                                                | 3/152 [00:00<00:35,  4.17it/s][A
Running through epoch:   3%|█▋                                                                | 4/152 [00:00<00:34,  4.23it/s][A
Running through epoch:   3%|██▏                                                               | 5/152 [00:01<00:34,  4.24it/s][A
Running through epoch:   4%|██▌                                                               | 6/152 [00:01<00:34,  4.26it/s][A
Running through epoch:   5%|███                                                          

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/9897baa2-00d5-439e-9518-929c71219cc2_checkpoint_6


Epochs:  40%|████████████████████████████████▊                                                 | 6/15 [03:49<05:43, 38.14s/it]

EPOCH 6 - Train loss 0.2331 | Val loss 0.6863 | Val acc 0.7385



Running through epoch:   0%|                                                                          | 0/152 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                 | 1/152 [00:00<00:39,  3.82it/s][A
Running through epoch:   1%|▊                                                                 | 2/152 [00:00<00:36,  4.13it/s][A
Running through epoch:   2%|█▎                                                                | 3/152 [00:00<00:35,  4.25it/s][A
Running through epoch:   3%|█▋                                                                | 4/152 [00:00<00:34,  4.31it/s][A
Running through epoch:   3%|██▏                                                               | 5/152 [00:01<00:33,  4.34it/s][A
Running through epoch:   4%|██▌                                                               | 6/152 [00:01<00:33,  4.36it/s][A
Running through epoch:   5%|███                                                          

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/9897baa2-00d5-439e-9518-929c71219cc2_checkpoint_7


Epochs:  47%|██████████████████████████████████████▎                                           | 7/15 [04:26<05:03, 37.91s/it]

EPOCH 7 - Train loss 0.2124 | Val loss 0.7291 | Val acc 0.7135



Running through epoch:   0%|                                                                          | 0/152 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                 | 1/152 [00:00<00:37,  3.98it/s][A
Running through epoch:   1%|▊                                                                 | 2/152 [00:00<00:35,  4.24it/s][A
Running through epoch:   2%|█▎                                                                | 3/152 [00:00<00:34,  4.29it/s][A
Running through epoch:   3%|█▋                                                                | 4/152 [00:00<00:34,  4.31it/s][A
Running through epoch:   3%|██▏                                                               | 5/152 [00:01<00:33,  4.34it/s][A
Running through epoch:   4%|██▌                                                               | 6/152 [00:01<00:33,  4.36it/s][A
Running through epoch:   5%|███                                                          

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/9897baa2-00d5-439e-9518-929c71219cc2_checkpoint_8


Epochs:  53%|███████████████████████████████████████████▋                                      | 8/15 [05:04<04:24, 37.85s/it]

EPOCH 8 - Train loss 0.1862 | Val loss 0.7773 | Val acc 0.7019



Running through epoch:   0%|                                                                          | 0/152 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                 | 1/152 [00:00<00:37,  4.00it/s][A
Running through epoch:   1%|▊                                                                 | 2/152 [00:00<00:35,  4.20it/s][A
Running through epoch:   2%|█▎                                                                | 3/152 [00:00<00:35,  4.25it/s][A
Running through epoch:   3%|█▋                                                                | 4/152 [00:00<00:34,  4.29it/s][A
Running through epoch:   3%|██▏                                                               | 5/152 [00:01<00:34,  4.32it/s][A
Running through epoch:   4%|██▌                                                               | 6/152 [00:01<00:33,  4.34it/s][A
Running through epoch:   5%|███                                                          

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/9897baa2-00d5-439e-9518-929c71219cc2_checkpoint_9


Epochs:  60%|█████████████████████████████████████████████████▏                                | 9/15 [05:41<03:46, 37.72s/it]

EPOCH 9 - Train loss 0.1527 | Val loss 0.8217 | Val acc 0.7058



Running through epoch:   0%|                                                                          | 0/152 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                 | 1/152 [00:00<00:38,  3.95it/s][A
Running through epoch:   1%|▊                                                                 | 2/152 [00:00<00:35,  4.17it/s][A
Running through epoch:   2%|█▎                                                                | 3/152 [00:00<00:35,  4.25it/s][A
Running through epoch:   3%|█▋                                                                | 4/152 [00:00<00:34,  4.29it/s][A
Running through epoch:   3%|██▏                                                               | 5/152 [00:01<00:34,  4.32it/s][A
Running through epoch:   4%|██▌                                                               | 6/152 [00:01<00:33,  4.31it/s][A
Running through epoch:   5%|███                                                          

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/9897baa2-00d5-439e-9518-929c71219cc2_checkpoint_10


Epochs:  67%|██████████████████████████████████████████████████████                           | 10/15 [06:19<03:08, 37.80s/it]

EPOCH 10 - Train loss 0.1102 | Val loss 0.8192 | Val acc 0.7173



Running through epoch:   0%|                                                                          | 0/152 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                 | 1/152 [00:00<00:37,  3.99it/s][A
Running through epoch:   1%|▊                                                                 | 2/152 [00:00<00:35,  4.23it/s][A
Running through epoch:   2%|█▎                                                                | 3/152 [00:00<00:34,  4.28it/s][A
Running through epoch:   3%|█▋                                                                | 4/152 [00:00<00:34,  4.28it/s][A
Running through epoch:   3%|██▏                                                               | 5/152 [00:01<00:34,  4.32it/s][A
Running through epoch:   4%|██▌                                                               | 6/152 [00:01<00:33,  4.34it/s][A
Running through epoch:   5%|███                                                          

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/9897baa2-00d5-439e-9518-929c71219cc2_checkpoint_11


Epochs:  73%|███████████████████████████████████████████████████████████▍                     | 11/15 [06:57<02:31, 37.87s/it]

EPOCH 11 - Train loss 0.0782 | Val loss 0.811 | Val acc 0.7692



Running through epoch:   0%|                                                                          | 0/152 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                 | 1/152 [00:00<00:38,  3.92it/s][A
Running through epoch:   1%|▊                                                                 | 2/152 [00:00<00:36,  4.17it/s][A
Running through epoch:   2%|█▎                                                                | 3/152 [00:00<00:35,  4.24it/s][A
Running through epoch:   3%|█▋                                                                | 4/152 [00:00<00:34,  4.28it/s][A
Running through epoch:   3%|██▏                                                               | 5/152 [00:01<00:33,  4.32it/s][A
Running through epoch:   4%|██▌                                                               | 6/152 [00:01<00:33,  4.33it/s][A
Running through epoch:   5%|███                                                          

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/9897baa2-00d5-439e-9518-929c71219cc2_checkpoint_12


Epochs:  80%|████████████████████████████████████████████████████████████████▊                | 12/15 [07:36<01:54, 38.00s/it]

EPOCH 12 - Train loss 0.0972 | Val loss 0.8944 | Val acc 0.7385



Running through epoch:   0%|                                                                          | 0/152 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                 | 1/152 [00:00<00:39,  3.81it/s][A
Running through epoch:   1%|▊                                                                 | 2/152 [00:00<00:36,  4.07it/s][A
Running through epoch:   2%|█▎                                                                | 3/152 [00:00<00:35,  4.17it/s][A
Running through epoch:   3%|█▋                                                                | 4/152 [00:00<00:34,  4.26it/s][A
Running through epoch:   3%|██▏                                                               | 5/152 [00:01<00:34,  4.24it/s][A
Running through epoch:   4%|██▌                                                               | 6/152 [00:01<00:34,  4.23it/s][A
Running through epoch:   5%|███                                                          

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/9897baa2-00d5-439e-9518-929c71219cc2_checkpoint_13


Epochs:  87%|██████████████████████████████████████████████████████████████████████▏          | 13/15 [08:21<01:20, 40.41s/it]

EPOCH 13 - Train loss 0.0638 | Val loss 0.8844 | Val acc 0.7635



Running through epoch:   0%|                                                                          | 0/152 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                 | 1/152 [00:00<00:38,  3.89it/s][A
Running through epoch:   1%|▊                                                                 | 2/152 [00:00<00:35,  4.21it/s][A
Running through epoch:   2%|█▎                                                                | 3/152 [00:00<00:34,  4.28it/s][A
Running through epoch:   3%|█▋                                                                | 4/152 [00:00<00:34,  4.34it/s][A
Running through epoch:   3%|██▏                                                               | 5/152 [00:01<00:33,  4.35it/s][A
Running through epoch:   4%|██▌                                                               | 6/152 [00:01<00:33,  4.39it/s][A
Running through epoch:   5%|███                                                          

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/9897baa2-00d5-439e-9518-929c71219cc2_checkpoint_14


Epochs:  93%|███████████████████████████████████████████████████████████████████████████▌     | 14/15 [09:17<00:44, 44.95s/it]

EPOCH 14 - Train loss 0.0487 | Val loss 0.9408 | Val acc 0.7731



Running through epoch:   0%|                                                                          | 0/152 [00:00<?, ?it/s][A
Running through epoch:   1%|▍                                                                 | 1/152 [00:00<00:38,  3.94it/s][A
Running through epoch:   1%|▊                                                                 | 2/152 [00:00<00:36,  4.08it/s][A
Running through epoch:   2%|█▎                                                                | 3/152 [00:00<00:35,  4.15it/s][A
Running through epoch:   3%|█▋                                                                | 4/152 [00:00<00:35,  4.17it/s][A
Running through epoch:   3%|██▏                                                               | 5/152 [00:01<00:35,  4.16it/s][A
Running through epoch:   4%|██▌                                                               | 6/152 [00:01<00:34,  4.18it/s][A
Running through epoch:   5%|███                                                          

Saving model to: /zhome/c9/0/203261/DBL046_PP_osaul/DBL046_PP_osaul/tmp/ona_drafts/9897baa2-00d5-439e-9518-929c71219cc2_checkpoint_15


Epochs: 100%|█████████████████████████████████████████████████████████████████████████████████| 15/15 [10:33<00:00, 42.23s/it]

EPOCH 15 - Train loss 0.0291 | Val loss 0.9404 | Val acc 0.7558





0,1
Train-loss,█▇▇▇▆▆▅▅▄▃▂▃▂▁▁
Val-acc,▁▆▆▆▇▇█▇▇▇▇█████
Val-loss,█▂▁▁▁▁▁▁▂▂▂▂▂▂▂▂

0,1
Train-loss,0.0291
Val-acc,0.75577
Val-loss,0.94041
