# External libraries
We use the `baseline_utils` and `finetuning_utils` modules from this repository: https://github.com/skrhakv/cryptic-nn

In [1]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
import csv
from torch import nn
from sklearn.utils import class_weight
import numpy as np
from sklearn import metrics
import matplotlib.pyplot as plt
import sys 

sys.path.append('/home/skrhakv/cryptic-nn/src')
import baseline_utils
import finetuning_utils

# Fix the PyTorch RNG seed for this session.
torch.manual_seed(0)

# Dataset name and the locations derived from it.
DATASET = 'cryptobench'
DATA_PATH = f'/home/skrhakv/cryptic-nn/data/{DATASET}'
ESM_EMBEDDINGS_PATH = f'{DATA_PATH}/embeddings'

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Bare expression: shown as the cell output.
device

'cuda'

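## Finetuning ESM2 model
Let's finetune the whole ESM2 model. Parameters whose names start with `llm` are frozen for the first two epochs and unfrozen for the third, so only the last epoch updates the whole model.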

In [2]:
from transformers import AutoTokenizer
import torch
import torch.nn as nn
import functools
from sklearn import metrics
import gc
import bitsandbytes as bnb
from torch.utils.data import DataLoader
import warnings

# Name of the pretrained checkpoint that the following cells load.
MODEL_NAME = 'facebook/esm2_t36_3B_UR50D'

# Re-seed the PyTorch RNG for the fine-tuning run.
torch.manual_seed(42)

# Silence every Python warning for the rest of the session.
warnings.filterwarnings(action='ignore')


In [3]:
# Build the model, cast it with `.half()` and move it to the selected device.
finetuned_model = finetuning_utils.FinetunedEsmModel(MODEL_NAME).half().to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Training and validation sets are read from separate CSV files.
train_dataset = finetuning_utils.process_sequence_dataset('/home/skrhakv/cryptoshow-analysis/data/E-regular-binding-site-predictor/scPDB_enhanced_binding_sites_translated.csv', tokenizer)
val_dataset = finetuning_utils.process_sequence_dataset('/home/skrhakv/cryptoshow-analysis/data/E-regular-binding-site-predictor/ligysis_without_unobserved.csv', tokenizer)

# Bind the tokenizer so the DataLoader can call the collate function with the batch only.
partial_collate_fn = functools.partial(finetuning_utils.collate_fn, tokenizer=tokenizer)

# Training uses one sequence per step; validation is split into roughly 20 batches.
# `max(1, ...)` keeps the batch size valid when the validation set has fewer than
# 20 rows (a batch size of 0 is rejected by DataLoader).
val_batch_size = max(1, val_dataset.num_rows // 20)

train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=partial_collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=val_batch_size, collate_fn=partial_collate_fn)

# The optimizer receives every parameter up front, including the ones frozen
# below, so they are already registered when they are unfrozen later.
# NOTE(review): eps=1e-4 is much larger than the usual default - presumably
# chosen for stability with the half-precision weights; confirm.
optimizer = bnb.optim.AdamW8bit(finetuned_model.parameters(), lr=0.0001, eps=1e-4)

EPOCHS = 3

# precomputed class weights; only the second entry is used, as the
# positive-class weight of the loss
class_weights = torch.tensor([0.5590, 4.7378], device=device)
loss_fn = nn.BCEWithLogitsLoss(pos_weight=class_weights[1])

# Freeze every parameter whose name starts with 'llm'; the remaining parameters
# stay trainable.
for name, param in finetuned_model.named_parameters():
    if name.startswith('llm'):
        param.requires_grad = False

# Per-epoch loss history, filled by the training loop.
test_losses = []
train_losses = []

# One iteration per epoch: validate first, then train. Because validation runs
# before training, the metrics printed for epoch N describe the model as it was
# after epoch N-1, and the model produced by the last epoch is not validated here.
for epoch in range(EPOCHS):
    # From the third epoch (index 2) on, every parameter is trainable again.
    if epoch > 1:
        for param in finetuned_model.parameters():
            param.requires_grad = True

    finetuned_model.eval()

    # VALIDATION LOOP
    with torch.no_grad():
        logits_list = []
        labels_list = []

        for batch in val_dataloader:
            output = finetuned_model(batch)

            labels = batch['labels'].to(device)

            flattened_labels = labels.flatten()
            # Positions labelled -100 are excluded from the loss and the metrics
            # (presumably padding / special tokens - confirm against collate_fn).
            valid_mask = flattened_labels != -100

            cbs_logits = output.flatten()[valid_mask]
            valid_flattened_labels = flattened_labels[valid_mask]

            # Keep the per-batch results on the CPU as float32 numpy arrays.
            logits_list.append(cbs_logits.cpu().float().detach().numpy())
            labels_list.append(valid_flattened_labels.cpu().float().detach().numpy())

            # Release the batch tensors (including the model output) before the next batch.
            del output, labels, cbs_logits, valid_flattened_labels, flattened_labels, valid_mask
            gc.collect()
            torch.cuda.empty_cache()

        # TODO: is it going to fail on memory or not when using LIGYSIS?
        cbs_logits = torch.tensor(np.concatenate(logits_list)).to(device)
        valid_flattened_labels = torch.tensor(np.concatenate(labels_list)).to(device)

        probabilities = torch.sigmoid(cbs_logits)
        # Hard predictions: probabilities rounded to 0 or 1.
        predictions = torch.round(probabilities)

        test_loss = loss_fn(cbs_logits, valid_flattened_labels)
        test_losses.append(test_loss.cpu().float().detach().numpy())

        # compute metrics on test dataset
        test_acc = baseline_utils.accuracy_fn(y_true=valid_flattened_labels,
                                              y_pred=predictions)

        # Convert once and reuse the arrays for every sklearn metric.
        y_true = valid_flattened_labels.cpu().float().numpy()
        y_prob = probabilities.cpu().float().numpy()
        y_pred = predictions.cpu().float().numpy()

        fpr, tpr, thresholds = metrics.roc_curve(y_true, y_prob)
        roc_auc = metrics.auc(fpr, tpr)

        mcc = metrics.matthews_corrcoef(y_true, y_pred)

        f1 = metrics.f1_score(y_true, y_pred, average='weighted')

        precision, recall, thresholds = metrics.precision_recall_curve(y_true, y_prob)
        auprc = metrics.auc(recall, precision)

        # Number of positive predictions, computed with one tensor reduction
        # instead of iterating over the tensor with Python's built-in sum().
        predicted_positives = predictions.to(dtype=torch.int).sum().item()

    finetuned_model.train()

    batch_losses = []

    # TRAIN
    for batch in train_dataloader:

        output = finetuned_model(batch)
        labels = batch['labels'].to(device)

        flattened_labels = labels.flatten()
        valid_mask = flattened_labels != -100

        cbs_logits = output.flatten()[valid_mask]
        valid_flattened_labels = flattened_labels[valid_mask]

        loss = loss_fn(cbs_logits, valid_flattened_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.cpu().float().detach().numpy())

        # `loss` is intentionally kept: it is reported after the loop.
        del labels, output, cbs_logits, valid_flattened_labels, flattened_labels, valid_mask
        gc.collect()
        torch.cuda.empty_cache()

    # The epoch mean goes into the history; the line below prints the loss of
    # the LAST training batch only.
    train_losses.append(sum(batch_losses) / len(batch_losses))
    print(f"Epoch: {epoch} | Loss: {loss:.5f}, Accuracy: {test_acc:.2f}% | Test loss: {test_loss:.5f}, AUC: {roc_auc:.4f}, MCC: {mcc:.4f}, F1: {f1:.4f}, AUPRC: {auprc:.4f}, sum: {predicted_positives}")


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t36_3B_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch: 0 | Loss: 0.66142, Accuracy: 25.90% | Test loss: 0.97030, AUC: 0.4947, MCC: -0.0189, F1: 0.3136, AUPRC: 0.1027, sum: 706830
Epoch: 1 | Loss: 2.89185, Accuracy: 77.26% | Test loss: 0.73298, AUC: 0.7945, MCC: 0.3046, F1: 0.8120, AUPRC: 0.3893, sum: 230600
Epoch: 2 | Loss: 0.23628, Accuracy: 75.54% | Test loss: 0.74321, AUC: 0.7909, MCC: 0.2933, F1: 0.7997, AUPRC: 0.3862, sum: 248394


In [4]:
# Serialize the whole fine-tuned model object (not only its state_dict) to disk.
OUTPUT_PATH = '/home/skrhakv/cryptoshow-analysis/data/E-regular-binding-site-predictor/model-enhanced-scPDB.pt'
torch.save(obj=finetuned_model, f=OUTPUT_PATH)

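# Try one extra epoch

(This didn't work: in the output below, the validation metrics after the extra epoch are not clearly better than before it.)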

In [None]:
# Reload the model saved by the previous cell.
# NOTE(review): `weights_only=False` unpickles arbitrary Python objects, so this
# must only ever be pointed at a checkpoint file you created yourself.
MODEL_PATH = OUTPUT_PATH # f'/home/skrhakv/cryptoshow-analysis/src/E-regular-binding-site-predictor/model.pt'
finetuned_model = torch.load(MODEL_PATH, weights_only=False)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Alternative datasets (CryptoBench train/test), kept for reference:
# train_dataset = finetuning_utils.process_sequence_dataset('/home/skrhakv/cryptic-nn/data/cryptobench/train.txt', tokenizer)
# val_dataset = finetuning_utils.process_sequence_dataset('/home/skrhakv/cryptic-nn/data/cryptobench/test.txt', tokenizer)

train_dataset = finetuning_utils.process_sequence_dataset('/home/skrhakv/Near-Hit-Scoring/data/input/scPDB_90_SI.csv', tokenizer)
val_dataset = finetuning_utils.process_sequence_dataset('/home/skrhakv/cryptic-nn/data/ligysis/train.txt', tokenizer) # it is called train because in other instance it was used for training but here we can use it for validation without problems

# Bind the tokenizer so the DataLoader can call the collate function with the batch only.
partial_collate_fn = functools.partial(finetuning_utils.collate_fn, tokenizer=tokenizer)

# Training uses one sequence per step; validation is split into roughly 20 batches.
# `max(1, ...)` keeps the batch size valid when the validation set has fewer than
# 20 rows (a batch size of 0 is rejected by DataLoader).
val_batch_size = max(1, val_dataset.num_rows // 20)

train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=partial_collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=val_batch_size, collate_fn=partial_collate_fn)

# A fresh optimizer: no optimizer state is carried over from the earlier run.
optimizer = bnb.optim.AdamW8bit(finetuned_model.parameters(), lr=0.0001, eps=1e-4)

EPOCHS = 1

# precomputed class weights; only the second entry is used, as the
# positive-class weight of the loss
class_weights = torch.tensor([0.5590, 4.7378], device=device)
loss_fn = nn.BCEWithLogitsLoss(pos_weight=class_weights[1])

# Per-epoch loss history, filled by the training loop.
test_losses = []
train_losses = []

def run_validation(model, dataloader, criterion):
    """Score `model` on every batch of `dataloader` and return validation metrics.

    The model is switched to eval mode and left in that mode. Positions whose
    label is -100 are excluded from the loss and from all metrics.

    Args:
        model: callable that takes a batch and returns one logit per label position.
        dataloader: iterable of batches; each batch must provide `batch['labels']`.
        criterion: loss called as `criterion(logits, labels)` on the concatenated
            float32 logits and labels of the whole dataloader.

    Returns:
        dict with the keys 'loss' (0-d tensor), 'accuracy', 'roc_auc', 'mcc',
        'f1' (weighted average), 'auprc' and 'predicted_positives' (number of
        positions predicted as positive).
    """
    model.eval()

    logits_list = []
    labels_list = []

    with torch.no_grad():
        for batch in dataloader:
            output = model(batch)
            labels = batch['labels'].to(device)

            flattened_labels = labels.flatten()
            # -100 marks positions to ignore (presumably padding / special
            # tokens - confirm against collate_fn).
            valid_mask = flattened_labels != -100

            # Keep the per-batch results on the CPU as float32 numpy arrays.
            logits_list.append(output.flatten()[valid_mask].cpu().float().detach().numpy())
            labels_list.append(flattened_labels[valid_mask].cpu().float().detach().numpy())

            # Release the batch tensors before the next batch is processed.
            del output, labels, flattened_labels, valid_mask
            gc.collect()
            torch.cuda.empty_cache()

        # TODO: is it going to fail on memory or not when using LIGYSIS?
        logits = torch.tensor(np.concatenate(logits_list)).to(device)
        targets = torch.tensor(np.concatenate(labels_list)).to(device)

        probabilities = torch.sigmoid(logits)
        # Hard predictions: probabilities rounded to 0 or 1.
        predictions = torch.round(probabilities)

        loss = criterion(logits, targets)
        accuracy = baseline_utils.accuracy_fn(y_true=targets, y_pred=predictions)

    # Convert once and reuse the arrays for every sklearn metric.
    y_true = targets.cpu().float().numpy()
    y_prob = probabilities.cpu().float().numpy()
    y_pred = predictions.cpu().float().numpy()

    fpr, tpr, _ = metrics.roc_curve(y_true, y_prob)
    precision, recall, _ = metrics.precision_recall_curve(y_true, y_prob)

    return {
        'loss': loss,
        'accuracy': accuracy,
        'roc_auc': metrics.auc(fpr, tpr),
        'mcc': metrics.matthews_corrcoef(y_true, y_pred),
        'f1': metrics.f1_score(y_true, y_pred, average='weighted'),
        'auprc': metrics.auc(recall, precision),
        # One tensor reduction instead of iterating with Python's built-in sum().
        'predicted_positives': predictions.to(dtype=torch.int).sum().item(),
    }


def format_report(epoch, train_loss, val):
    """Build the one-line progress report from a training loss and a `run_validation` result."""
    return (
        f"Epoch: {epoch} | Loss: {train_loss:.5f}, Accuracy: {val['accuracy']:.2f}% | "
        f"Test loss: {val['loss']:.5f}, AUC: {val['roc_auc']:.4f}, MCC: {val['mcc']:.4f}, "
        f"F1: {val['f1']:.4f}, AUPRC: {val['auprc']:.4f}, sum: {val['predicted_positives']}"
    )


for epoch in range(EPOCHS):

    # Validate BEFORE training: these metrics describe the model as loaded
    # (or as left by the previous epoch).
    val_metrics = run_validation(finetuned_model, val_dataloader, loss_fn)
    test_losses.append(val_metrics['loss'].cpu().float().detach().numpy())

    finetuned_model.train()

    batch_losses = []

    # TRAIN
    for batch in train_dataloader:

        output = finetuned_model(batch)
        labels = batch['labels'].to(device)

        flattened_labels = labels.flatten()
        valid_mask = flattened_labels != -100

        cbs_logits = output.flatten()[valid_mask]
        valid_flattened_labels = flattened_labels[valid_mask]

        loss = loss_fn(cbs_logits, valid_flattened_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.cpu().float().detach().numpy())

        # `loss` is intentionally kept: it is reported after the loop.
        del labels, output, cbs_logits, valid_flattened_labels, flattened_labels, valid_mask
        gc.collect()
        torch.cuda.empty_cache()

    # The epoch mean goes into the history; the report prints the loss of the
    # LAST training batch only.
    train_losses.append(sum(batch_losses) / len(batch_losses))
    print(format_report(epoch, loss, val_metrics))

# Final validation of the model after the last training epoch. The report
# reuses `epoch` and `loss` from the loop above, so those two fields repeat the
# values of the last in-loop report.
val_metrics = run_validation(finetuned_model, val_dataloader, loss_fn)
test_losses.append(val_metrics['loss'].cpu().float().detach().numpy())
print(format_report(epoch, loss, val_metrics))


Epoch: 0 | Loss: 0.07614, Accuracy: 92.78% | Test loss: 0.49911, AUC: 0.8205, MCC: 0.4211, F1: 0.9300, AUPRC: 0.4373, sum: 87069
Epoch: 0 | Loss: 0.07614, Accuracy: 91.79% | Test loss: 0.53121, AUC: 0.8252, MCC: 0.4135, F1: 0.9241, AUPRC: 0.4542, sum: 106653
