In [1]:
from transformers import AutoTokenizer
import torch
import numpy as np
import functools
from sklearn import metrics
from torch.utils.data import DataLoader
import sys

sys.path.append('/home/skrhakv/cryptic-nn/src')
import finetuning_utils
from finetuning_utils import FinetunedEsmModel

# ESM-2 (3B) checkpoint name; its tokenizer encodes every dataset below.
MODEL_NAME = 'facebook/esm2_t36_3B_UR50D'

# Output locations (the AUC/AUPRC dir is only used by the optional np.savez exports).
PATH_TO_MODELS = '/home/skrhakv/cryptic-nn/final-data/trained-models'
PATH_TO_AUC_AUPRC_DATA = '/home/skrhakv/cryptic-nn/src/auc-auprc/data'

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Last expression of the cell, so Jupyter actually displays the chosen device
# (a bare expression in the middle of a cell is silently discarded).
device

# Get class weights
Class weights differ between datasets, so compute them from the labels of the enhanced scPDB training set.

In [4]:
import baseline_utils

# Tokenize the enhanced scPDB training set and gather every residue label so the
# class imbalance can be turned into loss weights.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
partial_collate_fn = functools.partial(finetuning_utils.collate_fn, tokenizer=tokenizer)

train_dataset = finetuning_utils.process_sequence_dataset('/home/skrhakv/cryptoshow-analysis/data/E-regular-binding-site-predictor/scPDB_enhanced_binding_sites_translated.csv', tokenizer)
# A single batch spanning the whole dataset. Labels are still gathered from every
# batch, so the result stays correct if the batch size is ever reduced.
train_dataloader1 = DataLoader(train_dataset, batch_size=int(train_dataset.num_rows), collate_fn=partial_collate_fn)

label_chunks = []
for batch in train_dataloader1:
    flat_labels = batch['labels'].cpu().numpy().reshape(-1)
    # Negative labels (-100) mark positions that are not scored; keep only real residues.
    label_chunks.append(flat_labels[flat_labels >= 0])
class_labels = np.concatenate(label_chunks)

weights = baseline_utils.compute_class_weights(class_labels)
weights

tensor([0.5875, 3.3580])

## Base fine-tuned model (trained on enhanced scPDB), evaluated on LIGYSIS

In [2]:
import gc

import baseline_utils

# Model fine-tuned on the enhanced scPDB binding sites, evaluated on LIGYSIS.
MODEL_PATH = '/home/skrhakv/cryptoshow-analysis/data/E-regular-binding-site-predictor/model-enhanced-scPDB.pt'
# The whole pickled module is stored, hence weights_only=False (only load trusted
# checkpoints: unpickling can execute arbitrary code). map_location lets a
# GPU-saved checkpoint load on whatever device is available here.
finetuned_model = torch.load(MODEL_PATH, weights_only=False, map_location=device)
# torch.load restores the train/eval flag the module had when saved; force
# inference mode so dropout-style layers are deterministic during evaluation.
finetuned_model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# LIGYSIS is only evaluated on here (in other experiments it served as training data).
val_dataset = finetuning_utils.process_sequence_dataset('/home/skrhakv/cryptoshow-analysis/data/E-regular-binding-site-predictor/ligysis_without_unobserved.csv', tokenizer)

partial_collate_fn = functools.partial(finetuning_utils.collate_fn, tokenizer=tokenizer)

# Roughly 20 batches over the dataset; never let the batch size drop to 0 for tiny datasets.
val_dataloader = DataLoader(val_dataset, batch_size=max(1, int(val_dataset.num_rows / 20)), collate_fn=partial_collate_fn)

with torch.no_grad():
    logits_list = []
    labels_list = []

    for batch in val_dataloader:
        output = finetuned_model(batch)

        labels = batch['labels'].to(device)

        flattened_labels = labels.flatten()
        # -100 marks positions that are not scored; keep only labelled residues.
        mask = flattened_labels != -100
        cbs_logits = output.flatten()[mask]
        valid_flattened_labels = flattened_labels[mask]

        logits_list.append(cbs_logits.cpu().float().numpy())
        labels_list.append(valid_flattened_labels.cpu().float().numpy())

        # Drop this batch's tensors (including the model output) before emptying the CUDA cache.
        del output, labels, cbs_logits, valid_flattened_labels, flattened_labels, mask
        gc.collect()
        torch.cuda.empty_cache()

    cbs_logits = torch.tensor(np.concatenate(logits_list)).to(device)
    valid_flattened_labels = torch.tensor(np.concatenate(labels_list)).to(device)

    # NumPy copies of ground truth and probabilities, reused by all sklearn metrics below.
    labels = valid_flattened_labels.cpu().float().numpy()
    probabilities = torch.sigmoid(cbs_logits).cpu().float().numpy()

    # Sweep decision thresholds and keep the one with the highest MCC.
    # NOTE(review): the threshold is tuned on the same data it is reported on,
    # so the final metrics below are optimistic.
    best_threshold, previous_mcc = 0.0, -100
    for threshold in np.arange(0.1, 0.95, 0.05):
        rounded_predictions = (probabilities > threshold).astype(int)
        acc = metrics.accuracy_score(labels, rounded_predictions)

        mcc = metrics.matthews_corrcoef(labels, rounded_predictions)
        if mcc > previous_mcc:
            previous_mcc = mcc
            best_threshold = threshold
        f1 = metrics.f1_score(labels, rounded_predictions, average='weighted')
        binary_f1 = metrics.f1_score(labels, rounded_predictions)

        print(f"\tThreshold: {threshold:.2f} | Accuracy: {acc:.4f} | MCC: {mcc:.4f} | F1: {f1:.4f} | binary F1: {binary_f1:.4f}")

    # Binary predictions at the selected threshold (tensor for accuracy_fn, NumPy for sklearn).
    predictions = (torch.sigmoid(cbs_logits) > best_threshold).float()
    binary_predictions = predictions.cpu().float().numpy()

    # compute metrics at the selected threshold
    test_acc = baseline_utils.accuracy_fn(y_true=valid_flattened_labels, y_pred=predictions)

    fpr, tpr, thresholds1 = metrics.roc_curve(labels, probabilities)
    roc_auc = metrics.auc(fpr, tpr)

    mcc = metrics.matthews_corrcoef(labels, binary_predictions)

    f1 = metrics.f1_score(labels, binary_predictions, average='weighted')
    binary_f1 = metrics.f1_score(labels, binary_predictions)

    precision, recall, thresholds2 = metrics.precision_recall_curve(labels, probabilities)
    # Trapezoidal area under the PR curve, kept for comparability with earlier runs;
    # metrics.average_precision_score is the less optimistic alternative.
    auprc = metrics.auc(recall, precision)

print(f"Best threshold: {best_threshold:.2f}:")
print(f"Accuracy: {test_acc:.2f}% | AUC: {roc_auc:.4f}, MCC: {mcc:.4f}, F1: {f1:.4f}, binary F1: {binary_f1:.4f}, AUPRC: {auprc:.4f}, sum: {int(predictions.to(dtype=torch.int64).sum())}")

# To export the curves, define MODEL (a name for this run) and uncomment:
# np.savez(f'{PATH_TO_AUC_AUPRC_DATA}/{MODEL}-rocauc.npz', fpr, tpr, thresholds1)
# np.savez(f'{PATH_TO_AUC_AUPRC_DATA}/{MODEL}-auprc.npz', precision, recall, thresholds2)

	Threshold: 0.10 | Accuracy: 0.6603 | MCC: 0.2786 | F1: 0.7287 | binary F1: 0.3218
	Threshold: 0.15 | Accuracy: 0.7583 | MCC: 0.3243 | F1: 0.8031 | binary F1: 0.3742
	Threshold: 0.20 | Accuracy: 0.8054 | MCC: 0.3480 | F1: 0.8363 | binary F1: 0.4048
	Threshold: 0.25 | Accuracy: 0.8339 | MCC: 0.3659 | F1: 0.8557 | binary F1: 0.4267
	Threshold: 0.30 | Accuracy: 0.8528 | MCC: 0.3795 | F1: 0.8684 | binary F1: 0.4425
	Threshold: 0.35 | Accuracy: 0.8662 | MCC: 0.3900 | F1: 0.8772 | binary F1: 0.4538
	Threshold: 0.40 | Accuracy: 0.8764 | MCC: 0.3993 | F1: 0.8839 | binary F1: 0.4627
	Threshold: 0.45 | Accuracy: 0.8845 | MCC: 0.4076 | F1: 0.8891 | binary F1: 0.4697
	Threshold: 0.50 | Accuracy: 0.8913 | MCC: 0.4160 | F1: 0.8935 | binary F1: 0.4759
	Threshold: 0.55 | Accuracy: 0.8971 | MCC: 0.4232 | F1: 0.8971 | binary F1: 0.4803
	Threshold: 0.60 | Accuracy: 0.9020 | MCC: 0.4296 | F1: 0.9000 | binary F1: 0.4831
	Threshold: 0.65 | Accuracy: 0.9063 | MCC: 0.4352 | F1: 0.9025 | binary F1: 0.4841
	Thr

# LIGYSIS_NI
Run the model trained on the non-enhanced scPDB dataset on the LIGYSIS_NI subset.

In [2]:
import gc

import baseline_utils

# Model trained on the non-enhanced scPDB dataset, evaluated on the LIGYSIS_NI subset.
MODEL_PATH = '/home/skrhakv/cryptoshow-analysis/data/E-regular-binding-site-predictor/old/model.pt'
# The whole pickled module is stored, hence weights_only=False (only load trusted
# checkpoints: unpickling can execute arbitrary code). map_location lets a
# GPU-saved checkpoint load on whatever device is available here.
finetuned_model = torch.load(MODEL_PATH, weights_only=False, map_location=device)
# torch.load restores the train/eval flag the module had when saved; force
# inference mode so dropout-style layers are deterministic during evaluation.
finetuned_model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# LIGYSIS_NI is only evaluated on here (in other experiments LIGYSIS served as training data).
val_dataset = finetuning_utils.process_sequence_dataset('/home/skrhakv/cryptoshow-analysis/data/E-regular-binding-site-predictor/ligysis_NI_without_unobserved.csv', tokenizer)

partial_collate_fn = functools.partial(finetuning_utils.collate_fn, tokenizer=tokenizer)

# Roughly 20 batches over the dataset; never let the batch size drop to 0 for tiny datasets.
val_dataloader = DataLoader(val_dataset, batch_size=max(1, int(val_dataset.num_rows / 20)), collate_fn=partial_collate_fn)

with torch.no_grad():
    logits_list = []
    labels_list = []

    for batch in val_dataloader:
        output = finetuned_model(batch)

        labels = batch['labels'].to(device)

        flattened_labels = labels.flatten()
        # -100 marks positions that are not scored; keep only labelled residues.
        mask = flattened_labels != -100
        cbs_logits = output.flatten()[mask]
        valid_flattened_labels = flattened_labels[mask]

        logits_list.append(cbs_logits.cpu().float().numpy())
        labels_list.append(valid_flattened_labels.cpu().float().numpy())

        # Drop this batch's tensors (including the model output) before emptying the CUDA cache.
        del output, labels, cbs_logits, valid_flattened_labels, flattened_labels, mask
        gc.collect()
        torch.cuda.empty_cache()

    cbs_logits = torch.tensor(np.concatenate(logits_list)).to(device)
    valid_flattened_labels = torch.tensor(np.concatenate(labels_list)).to(device)

    # NumPy copies of ground truth and probabilities, reused by all sklearn metrics below.
    labels = valid_flattened_labels.cpu().float().numpy()
    probabilities = torch.sigmoid(cbs_logits).cpu().float().numpy()

    # Sweep decision thresholds and keep the one with the highest MCC.
    # NOTE(review): the threshold is tuned on the same data it is reported on,
    # so the final metrics below are optimistic.
    best_threshold, previous_mcc = 0.0, -100
    for threshold in np.arange(0.1, 0.95, 0.05):
        rounded_predictions = (probabilities > threshold).astype(int)
        acc = metrics.accuracy_score(labels, rounded_predictions)

        mcc = metrics.matthews_corrcoef(labels, rounded_predictions)
        if mcc > previous_mcc:
            previous_mcc = mcc
            best_threshold = threshold
        f1 = metrics.f1_score(labels, rounded_predictions, average='weighted')
        binary_f1 = metrics.f1_score(labels, rounded_predictions)

        print(f"\tThreshold: {threshold:.2f} | Accuracy: {acc:.4f} | MCC: {mcc:.4f} | F1: {f1:.4f} | binary F1: {binary_f1:.4f}")

    # Binary predictions at the selected threshold (tensor for accuracy_fn, NumPy for sklearn).
    predictions = (torch.sigmoid(cbs_logits) > best_threshold).float()
    binary_predictions = predictions.cpu().float().numpy()

    # compute metrics at the selected threshold
    test_acc = baseline_utils.accuracy_fn(y_true=valid_flattened_labels, y_pred=predictions)

    fpr, tpr, thresholds1 = metrics.roc_curve(labels, probabilities)
    roc_auc = metrics.auc(fpr, tpr)

    mcc = metrics.matthews_corrcoef(labels, binary_predictions)

    f1 = metrics.f1_score(labels, binary_predictions, average='weighted')
    binary_f1 = metrics.f1_score(labels, binary_predictions)

    precision, recall, thresholds2 = metrics.precision_recall_curve(labels, probabilities)
    # Trapezoidal area under the PR curve, kept for comparability with earlier runs;
    # metrics.average_precision_score is the less optimistic alternative.
    auprc = metrics.auc(recall, precision)

print(f"Best threshold: {best_threshold:.2f}:")
print(f"Accuracy: {test_acc:.2f}% | AUC: {roc_auc:.4f}, MCC: {mcc:.4f}, F1: {f1:.4f}, binary F1: {binary_f1:.4f}, AUPRC: {auprc:.4f}, sum: {int(predictions.to(dtype=torch.int64).sum())}")

# To export the curves, define MODEL (a name for this run) and uncomment:
# np.savez(f'{PATH_TO_AUC_AUPRC_DATA}/{MODEL}-rocauc.npz', fpr, tpr, thresholds1)
# np.savez(f'{PATH_TO_AUC_AUPRC_DATA}/{MODEL}-auprc.npz', precision, recall, thresholds2)

	Threshold: 0.10 | Accuracy: 0.7007 | MCC: 0.3249 | F1: 0.7620 | binary F1: 0.3512
	Threshold: 0.15 | Accuracy: 0.7887 | MCC: 0.3814 | F1: 0.8274 | binary F1: 0.4150
	Threshold: 0.20 | Accuracy: 0.8303 | MCC: 0.4147 | F1: 0.8569 | binary F1: 0.4549
	Threshold: 0.25 | Accuracy: 0.8550 | MCC: 0.4391 | F1: 0.8743 | binary F1: 0.4835
	Threshold: 0.30 | Accuracy: 0.8713 | MCC: 0.4569 | F1: 0.8857 | binary F1: 0.5042
	Threshold: 0.35 | Accuracy: 0.8830 | MCC: 0.4706 | F1: 0.8938 | binary F1: 0.5197
	Threshold: 0.40 | Accuracy: 0.8916 | MCC: 0.4807 | F1: 0.8997 | binary F1: 0.5309
	Threshold: 0.45 | Accuracy: 0.8987 | MCC: 0.4903 | F1: 0.9046 | binary F1: 0.5407
	Threshold: 0.50 | Accuracy: 0.9044 | MCC: 0.4986 | F1: 0.9085 | binary F1: 0.5486
	Threshold: 0.55 | Accuracy: 0.9094 | MCC: 0.5055 | F1: 0.9118 | binary F1: 0.5546
	Threshold: 0.60 | Accuracy: 0.9135 | MCC: 0.5110 | F1: 0.9144 | binary F1: 0.5587
	Threshold: 0.65 | Accuracy: 0.9174 | MCC: 0.5160 | F1: 0.9168 | binary F1: 0.5615
	Thr