In [2]:
from transformers import AutoTokenizer
import torch
import numpy as np
import functools
from sklearn import metrics
from torch.utils.data import DataLoader
import sys

# Make the project's local helper modules (finetuning_utils, baseline_utils) importable.
# NOTE(review): absolute, machine-specific path — the notebook only runs where this directory exists.
sys.path.append('/home/skrhakv/cryptic-nn/src')
import finetuning_utils
from finetuning_utils import FinetunedEsmModel

# Hugging Face model identifier; later cells pass it to AutoTokenizer.from_pretrained.
MODEL_NAME = 'facebook/esm2_t36_3B_UR50D'
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

# Project data locations. In the visible part of the notebook PATH_TO_MODELS is not
# read, and PATH_TO_AUC_AUPRC_DATA appears only in commented-out np.savez calls.
PATH_TO_MODELS = '/home/skrhakv/cryptic-nn/final-data/trained-models'
PATH_TO_AUC_AUPRC_DATA = '/home/skrhakv/cryptic-nn/src/auc-auprc/data'

# Compute class weights
The scPDB dataset needs its own class weights, so they are computed here from its labels.

In [None]:
# Tokenizer for MODEL_NAME and a collate function with that tokenizer pre-bound.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
partial_collate_fn = functools.partial(finetuning_utils.collate_fn, tokenizer=tokenizer)

train_dataset = finetuning_utils.process_sequence_dataset(
    '/home/skrhakv/Near-Hit-Scoring/data/input/scPDB_90_SI.csv',
    tokenizer,
)
# The batch size equals the number of rows, so the loader is built to hand back
# the whole dataset at once; the loop keeps the labels of the last batch it sees.
train_dataloader1 = DataLoader(
    train_dataset,
    batch_size=int(train_dataset.num_rows),
    collate_fn=partial_collate_fn,
)
for batch in train_dataloader1:
    labels = batch['labels']

import baseline_utils
# Flatten the label tensor once and keep only the non-negative entries
# before handing them to the class-weight helper.
flat_labels = labels.cpu().numpy().reshape(-1)
class_labels = flat_labels[flat_labels >= 0]
weights = baseline_utils.compute_class_weights(class_labels)
weights

## Base finetuned model

In [6]:
import gc

# Make the project's helper modules importable. Skip the append when the path is
# already registered so re-running the cell does not pile up duplicate entries.
if '/home/skrhakv/cryptic-nn/src' not in sys.path:
    sys.path.append('/home/skrhakv/cryptic-nn/src')
import baseline_utils

MODEL_PATH = '/home/skrhakv/cryptoshow-analysis/data/E-regular-binding-site-predictor/model-enhanced-scPDB.pt'
# NOTE(security): weights_only=False unpickles arbitrary Python objects;
# only load checkpoints from a trusted source.
finetuned_model = torch.load(MODEL_PATH, weights_only=False)
# Switch to inference mode so dropout-style layers are disabled while scoring.
# NOTE(review): assumes the loaded object is a torch.nn.Module — confirm.
finetuned_model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# This dataset served as a training set elsewhere; here it is used purely for validation.
val_dataset = finetuning_utils.process_sequence_dataset('/home/skrhakv/cryptoshow-analysis/data/E-regular-binding-site-predictor/ligysis_without_unobserved.csv', tokenizer)

partial_collate_fn = functools.partial(finetuning_utils.collate_fn, tokenizer=tokenizer)

# Split the dataset into roughly 20 batches; clamp to 1 so that datasets with
# fewer than 20 rows do not produce an invalid batch_size of 0.
val_dataloader = DataLoader(val_dataset, batch_size=max(1, val_dataset.num_rows // 20), collate_fn=partial_collate_fn)

with torch.no_grad():
    logits_list = []
    labels_list = []

    # Run the model batch by batch and keep only positions with a real label.
    for batch in val_dataloader:
        output = finetuned_model(batch)

        labels = batch['labels'].to(device)
        flattened_labels = labels.flatten()
        # Positions labelled -100 are excluded from evaluation
        # (presumably padding / special tokens — confirm against collate_fn).
        valid_mask = flattened_labels != -100

        cbs_logits = output.flatten()[valid_mask]
        valid_flattened_labels = flattened_labels[valid_mask]

        # Move results to host memory so the device copies can be released.
        logits_list.append(cbs_logits.cpu().float().numpy())
        labels_list.append(valid_flattened_labels.cpu().float().numpy())

        del labels, cbs_logits, valid_flattened_labels, flattened_labels, valid_mask
        gc.collect()
        torch.cuda.empty_cache()

    cbs_logits = torch.tensor(np.concatenate(logits_list)).to(device)
    valid_flattened_labels = torch.tensor(np.concatenate(labels_list)).to(device)

    # Host-side copies, computed once and reused by every sklearn call below.
    labels = valid_flattened_labels.cpu().float().numpy()
    probabilities = torch.sigmoid(cbs_logits).cpu().float().numpy()

    # Sweep decision thresholds and keep the one with the highest MCC.
    # NOTE(review): the threshold is tuned on the same data the final metrics are
    # reported on, so the thresholded metrics are optimistic.
    best_threshold, previous_mcc = 0.0, -100
    for threshold in np.arange(0.1, 0.95, 0.05):
        rounded_predictions = (probabilities > threshold).astype(int)
        acc = metrics.accuracy_score(labels, rounded_predictions)

        mcc = metrics.matthews_corrcoef(labels, rounded_predictions)
        if mcc > previous_mcc:
            previous_mcc = mcc
            best_threshold = threshold
        f1 = metrics.f1_score(labels, rounded_predictions, average='weighted')
        binary_f1 = metrics.f1_score(labels, rounded_predictions)

        print(f"\tThreshold: {threshold:.2f} | Accuracy: {acc:.4f} | MCC: {mcc:.4f} | F1: {f1:.4f} | binary F1: {binary_f1:.4f}")

    # Final hard predictions at the selected threshold.
    predictions = (torch.sigmoid(cbs_logits) > best_threshold).float()
    predictions_np = predictions.cpu().float().numpy()

    # Metrics at the selected threshold.
    test_acc = baseline_utils.accuracy_fn(y_true=valid_flattened_labels,
                            y_pred=predictions)

    mcc = metrics.matthews_corrcoef(labels, predictions_np)

    f1 = metrics.f1_score(labels, predictions_np, average='weighted')
    binary_f1 = metrics.f1_score(labels, predictions_np)

    # Threshold-free metrics computed from the raw probabilities.
    fpr, tpr, thresholds1 = metrics.roc_curve(labels, probabilities)
    roc_auc = metrics.auc(fpr, tpr)

    precision, recall, thresholds2 = metrics.precision_recall_curve(labels, probabilities)
    auprc = metrics.auc(recall, precision)

    # Count positive predictions with a single tensor reduction; the Python
    # built-in sum() would iterate over the tensor one element at a time.
    n_predicted_positive = predictions.to(dtype=torch.int).sum().item()

print(f"Best threshold: {best_threshold:.2f}:")
print(f"Accuracy: {test_acc:.2f}% | AUC: {roc_auc:.4f}, MCC: {mcc:.4f}, F1: {f1:.4f}, binary F1: {binary_f1:.4f}, AUPRC: {auprc:.4f}, sum: {n_predicted_positive}")

# Disabled export of the curve data. NOTE(review): `MODEL` is not defined in this
# cell — define it before re-enabling these lines.
# np.savez(f'{PATH_TO_AUC_AUPRC_DATA}/{MODEL}-rocauc.npz', fpr, tpr, thresholds1)
# np.savez(f'{PATH_TO_AUC_AUPRC_DATA}/{MODEL}-auprc.npz', precision, recall, thresholds2)

	Threshold: 0.10 | Accuracy: 0.7174 | MCC: 0.3245 | F1: 0.7729 | binary F1: 0.3631
	Threshold: 0.15 | Accuracy: 0.7826 | MCC: 0.3608 | F1: 0.8209 | binary F1: 0.4064
	Threshold: 0.20 | Accuracy: 0.8201 | MCC: 0.3848 | F1: 0.8473 | binary F1: 0.4363
	Threshold: 0.25 | Accuracy: 0.8446 | MCC: 0.4036 | F1: 0.8643 | binary F1: 0.4591
	Threshold: 0.30 | Accuracy: 0.8618 | MCC: 0.4193 | F1: 0.8761 | binary F1: 0.4769
	Threshold: 0.35 | Accuracy: 0.8749 | MCC: 0.4324 | F1: 0.8850 | binary F1: 0.4910
	Threshold: 0.40 | Accuracy: 0.8850 | MCC: 0.4429 | F1: 0.8917 | binary F1: 0.5014
	Threshold: 0.45 | Accuracy: 0.8934 | MCC: 0.4521 | F1: 0.8972 | binary F1: 0.5095
	Threshold: 0.50 | Accuracy: 0.9004 | MCC: 0.4600 | F1: 0.9016 | binary F1: 0.5152
	Threshold: 0.55 | Accuracy: 0.9062 | MCC: 0.4664 | F1: 0.9051 | binary F1: 0.5181
	Threshold: 0.60 | Accuracy: 0.9109 | MCC: 0.4703 | F1: 0.9076 | binary F1: 0.5171
	Threshold: 0.65 | Accuracy: 0.9149 | MCC: 0.4735 | F1: 0.9094 | binary F1: 0.5132
	Thr