In [None]:
from transformers import AutoTokenizer
import torch
import numpy as np
import functools
from sklearn import metrics
from torch.utils.data import DataLoader
import sys

sys.path.append('/home/skrhakv/cryptic-nn/src')
import baseline_utils
import finetuning_utils
from finetuning_utils import FinetunedEsmModel, MultitaskFinetunedEsmModel, MultitaskFinetunedEsmModelWithCnn

# ESM-2 checkpoint name; the evaluation cells load the matching tokenizer from it.
MODEL_NAME = 'facebook/esm2_t36_3B_UR50D'
# Evaluation labels are moved to this device; falls back to CPU when no GPU is visible.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Trained checkpoints are read from f'{PATH_TO_MODELS}/{MODEL}.pt'.
PATH_TO_MODELS = '/home/skrhakv/cryptic-nn/final-data/trained-models'
# ROC and precision-recall curve arrays are written here as .npz files.
PATH_TO_AUC_AUPRC_DATA = '/home/skrhakv/cryptic-nn/src/auc-auprc/data'

## Base fine-tuned model

In [None]:
MODEL = 'base-finetuned-model'
MODEL_PATH = f'{PATH_TO_MODELS}/{MODEL}.pt'
# The checkpoint is a whole pickled module (hence weights_only=False); unpickling
# can execute arbitrary code, so only load trusted checkpoints.
finetuned_model = torch.load(MODEL_PATH, weights_only=False)
# A pickled module keeps the train/eval flag it was saved with; force eval mode so
# dropout-style layers are inactive and the scores below are deterministic.
finetuned_model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# NOTE(review): train_dataset / train_dataloader are not used by the evaluation in
# this cell; kept only in case other cells rely on them.
train_dataset = finetuning_utils.process_sequence_dataset('/home/skrhakv/cryptic-nn/data/cryptobench/train.txt', tokenizer)
val_dataset = finetuning_utils.process_sequence_dataset('/home/skrhakv/cryptic-nn/data/cryptobench/test.txt', tokenizer)

partial_collate_fn = functools.partial(finetuning_utils.collate_fn, tokenizer=tokenizer)

train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=partial_collate_fn)
# The whole test set forms a single batch, so the loop below runs exactly once.
val_dataloader = DataLoader(val_dataset, batch_size=val_dataset.num_rows, collate_fn=partial_collate_fn)

with torch.no_grad():
    for batch in val_dataloader:
        output = finetuned_model(batch)

        labels = batch['labels'].to(device)

        # Positions labelled -100 are excluded from scoring (presumably padding /
        # special tokens -- TODO confirm against finetuning_utils.collate_fn).
        flattened_labels = labels.flatten()
        valid_mask = flattened_labels != -100
        cbs_logits = output.flatten()[valid_mask]
        valid_flattened_labels = flattened_labels[valid_mask]

        # Convert once to NumPy for all sklearn metrics below.
        labels = valid_flattened_labels.cpu().float().numpy()
        probabilities = torch.sigmoid(cbs_logits).cpu().float().numpy()

        # Sweep thresholds 0.10 ... 0.90 and keep the one with the highest MCC
        # (strict '>' means ties keep the lower threshold).
        # NOTE(review): the threshold is tuned on the same test set the final
        # metrics are reported on, so those metrics are optimistically biased.
        best_threshold, previous_mcc = 0.0, -100
        for threshold in np.arange(0.1, 0.95, 0.05):
            rounded_predictions = (probabilities > threshold).astype(int)
            acc = metrics.accuracy_score(labels, rounded_predictions)

            mcc = metrics.matthews_corrcoef(labels, rounded_predictions)
            if mcc > previous_mcc:
                previous_mcc = mcc
                best_threshold = threshold
            # 'weighted' averages per-class F1 by support; with imbalanced labels
            # it is dominated by the majority class.
            f1 = metrics.f1_score(labels, rounded_predictions, average='weighted')

            print(f"\tThreshold: {threshold:.2f} | Accuracy: {acc:.4f} | MCC: {mcc:.4f} | F1: {f1:.4f}")

        # Binary predictions at the selected threshold (float 0/1 tensor).
        predictions = (torch.sigmoid(cbs_logits) > best_threshold).float()
        binary_predictions = predictions.cpu().numpy()

        # compute metrics on test dataset
        test_acc = baseline_utils.accuracy_fn(y_true=valid_flattened_labels,
                                              y_pred=predictions)

        fpr, tpr, thresholds1 = metrics.roc_curve(labels, probabilities)
        roc_auc = metrics.auc(fpr, tpr)

        mcc = metrics.matthews_corrcoef(labels, binary_predictions)
        f1 = metrics.f1_score(labels, binary_predictions, average='weighted')

        precision, recall, thresholds2 = metrics.precision_recall_curve(labels, probabilities)
        # Trapezoidal area under the PR curve; metrics.average_precision_score is the
        # less optimistic summary if comparability with earlier runs is not needed.
        auprc = metrics.auc(recall, precision)

# Number of residues predicted positive (a plain int rather than a tensor repr).
n_predicted_positive = int(predictions.sum().item())
print(f"Best threshold: {best_threshold:.2f}:")
print(f"Accuracy: {test_acc:.2f}% | AUC: {roc_auc:.4f}, MCC: {mcc:.4f}, F1: {f1:.4f}, AUPRC: {auprc:.4f}, sum: {n_predicted_positive}")

np.savez(f'{PATH_TO_AUC_AUPRC_DATA}/{MODEL}-rocauc.npz', fpr, tpr, thresholds1)
np.savez(f'{PATH_TO_AUC_AUPRC_DATA}/{MODEL}-auprc.npz', precision, recall, thresholds2)


	Threshold: 0.10 | Accuracy: 0.3888 | MCC: 0.1554 | F1: 0.5012
	Threshold: 0.15 | Accuracy: 0.5491 | MCC: 0.2133 | F1: 0.6592
	Threshold: 0.20 | Accuracy: 0.6447 | MCC: 0.2536 | F1: 0.7386
	Threshold: 0.25 | Accuracy: 0.7153 | MCC: 0.2895 | F1: 0.7917
	Threshold: 0.30 | Accuracy: 0.7663 | MCC: 0.3197 | F1: 0.8278
	Threshold: 0.35 | Accuracy: 0.8065 | MCC: 0.3452 | F1: 0.8551
	Threshold: 0.40 | Accuracy: 0.8400 | MCC: 0.3720 | F1: 0.8774
	Threshold: 0.45 | Accuracy: 0.8663 | MCC: 0.3937 | F1: 0.8945
	Threshold: 0.50 | Accuracy: 0.8886 | MCC: 0.4170 | F1: 0.9090
	Threshold: 0.55 | Accuracy: 0.9050 | MCC: 0.4320 | F1: 0.9196
	Threshold: 0.60 | Accuracy: 0.9180 | MCC: 0.4423 | F1: 0.9277
	Threshold: 0.65 | Accuracy: 0.9282 | MCC: 0.4477 | F1: 0.9339
	Threshold: 0.70 | Accuracy: 0.9361 | MCC: 0.4435 | F1: 0.9382
	Threshold: 0.75 | Accuracy: 0.9425 | MCC: 0.4387 | F1: 0.9411
	Threshold: 0.80 | Accuracy: 0.9465 | MCC: 0.4127 | F1: 0.9411
	Threshold: 0.85 | Accuracy: 0.9481 | MCC: 0.3669 | F1:

## Multitask fine-tuned model

In [None]:
MODEL = 'multitask-finetuned-model'
MODEL_PATH = f'{PATH_TO_MODELS}/{MODEL}.pt'

# The checkpoint is a whole pickled module (hence weights_only=False); unpickling
# can execute arbitrary code, so only load trusted checkpoints.
finetuned_model = torch.load(MODEL_PATH, weights_only=False)
# A pickled module keeps the train/eval flag it was saved with; force eval mode so
# dropout-style layers are inactive and the scores below are deterministic.
finetuned_model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# NOTE(review): plDDT_scaler is not passed to the test dataset in this cell, so this
# (potentially slow) scaler construction has no effect on the evaluation below.
plDDT_path = '/home/skrhakv/cryptic-nn/data/ligysis/plDDT'
plDDT_scaler = finetuning_utils.train_scaler('/home/skrhakv/cryptic-nn/data/ligysis/train.txt', plDDT_path=plDDT_path, uniprot_ids=True)
val_dataset = finetuning_utils.process_sequence_dataset('/home/skrhakv/cryptic-nn/data/cryptobench/test.txt', tokenizer)

partial_collate_fn = functools.partial(finetuning_utils.collate_fn, tokenizer=tokenizer)

# The whole test set forms a single batch, so the loop below runs exactly once.
val_dataloader = DataLoader(val_dataset, batch_size=val_dataset.num_rows, collate_fn=partial_collate_fn)

with torch.no_grad():
    for batch in val_dataloader:
        # The model returns three outputs; only the first is scored here as the
        # binding-site logits, the other two heads are ignored.
        output1, _, _ = finetuned_model(batch)

        labels = batch['labels'].to(device)

        # Positions labelled -100 are excluded from scoring (presumably padding /
        # special tokens -- TODO confirm against finetuning_utils.collate_fn).
        flattened_labels = labels.flatten()
        valid_mask = flattened_labels != -100
        cbs_logits = output1.flatten()[valid_mask]
        valid_flattened_labels = flattened_labels[valid_mask]

        # Convert once to NumPy for all sklearn metrics below.
        labels = valid_flattened_labels.cpu().float().numpy()
        probabilities = torch.sigmoid(cbs_logits).cpu().float().numpy()

        # Sweep thresholds 0.10 ... 0.90 and keep the one with the highest MCC
        # (strict '>' means ties keep the lower threshold).
        # NOTE(review): the threshold is tuned on the same test set the final
        # metrics are reported on, so those metrics are optimistically biased.
        best_threshold, previous_mcc = 0.0, -100
        for threshold in np.arange(0.1, 0.95, 0.05):
            rounded_predictions = (probabilities > threshold).astype(int)
            acc = metrics.accuracy_score(labels, rounded_predictions)

            mcc = metrics.matthews_corrcoef(labels, rounded_predictions)
            if mcc > previous_mcc:
                previous_mcc = mcc
                best_threshold = threshold
            # 'weighted' averages per-class F1 by support; with imbalanced labels
            # it is dominated by the majority class.
            f1 = metrics.f1_score(labels, rounded_predictions, average='weighted')

            print(f"\tThreshold: {threshold:.2f} | Accuracy: {acc:.4f} | MCC: {mcc:.4f} | F1: {f1:.4f}")

        # Binary predictions at the selected threshold (float 0/1 tensor).
        predictions = (torch.sigmoid(cbs_logits) > best_threshold).float()
        binary_predictions = predictions.cpu().numpy()

        # compute metrics on test dataset
        test_acc = baseline_utils.accuracy_fn(y_true=valid_flattened_labels,
                                              y_pred=predictions)

        fpr, tpr, thresholds1 = metrics.roc_curve(labels, probabilities)
        roc_auc = metrics.auc(fpr, tpr)

        mcc = metrics.matthews_corrcoef(labels, binary_predictions)
        f1 = metrics.f1_score(labels, binary_predictions, average='weighted')

        precision, recall, thresholds2 = metrics.precision_recall_curve(labels, probabilities)
        # Trapezoidal area under the PR curve; metrics.average_precision_score is the
        # less optimistic summary if comparability with earlier runs is not needed.
        auprc = metrics.auc(recall, precision)

# Number of residues predicted positive (a plain int rather than a tensor repr).
n_predicted_positive = int(predictions.sum().item())
print(f"Best threshold: {best_threshold:.2f}:")
print(f"Accuracy: {test_acc:.2f}% | AUC: {roc_auc:.4f}, MCC: {mcc:.4f}, F1: {f1:.4f}, AUPRC: {auprc:.4f}, sum: {n_predicted_positive}")

np.savez(f'{PATH_TO_AUC_AUPRC_DATA}/{MODEL}-rocauc.npz', fpr, tpr, thresholds1)
np.savez(f'{PATH_TO_AUC_AUPRC_DATA}/{MODEL}-auprc.npz', precision, recall, thresholds2)


	Threshold: 0.10 | Accuracy: 0.5195 | MCC: 0.2030 | F1: 0.6326
	Threshold: 0.15 | Accuracy: 0.6391 | MCC: 0.2519 | F1: 0.7342
	Threshold: 0.20 | Accuracy: 0.7085 | MCC: 0.2870 | F1: 0.7868
	Threshold: 0.25 | Accuracy: 0.7585 | MCC: 0.3164 | F1: 0.8224
	Threshold: 0.30 | Accuracy: 0.7968 | MCC: 0.3397 | F1: 0.8486
	Threshold: 0.35 | Accuracy: 0.8298 | MCC: 0.3666 | F1: 0.8708
	Threshold: 0.40 | Accuracy: 0.8559 | MCC: 0.3889 | F1: 0.8879
	Threshold: 0.45 | Accuracy: 0.8769 | MCC: 0.4081 | F1: 0.9015
	Threshold: 0.50 | Accuracy: 0.8931 | MCC: 0.4261 | F1: 0.9122
	Threshold: 0.55 | Accuracy: 0.9059 | MCC: 0.4333 | F1: 0.9202
	Threshold: 0.60 | Accuracy: 0.9173 | MCC: 0.4405 | F1: 0.9273
	Threshold: 0.65 | Accuracy: 0.9266 | MCC: 0.4459 | F1: 0.9329
	Threshold: 0.70 | Accuracy: 0.9339 | MCC: 0.4431 | F1: 0.9369
	Threshold: 0.75 | Accuracy: 0.9400 | MCC: 0.4315 | F1: 0.9395
	Threshold: 0.80 | Accuracy: 0.9448 | MCC: 0.4186 | F1: 0.9409
	Threshold: 0.85 | Accuracy: 0.9475 | MCC: 0.3823 | F1:

## Multitask model with additional LIGYSIS training data

In [4]:
MODEL = 'multitask-finetuned-model-with-ligysis'
MODEL_PATH = f'{PATH_TO_MODELS}/{MODEL}.pt'

# The checkpoint is a whole pickled module (hence weights_only=False); unpickling
# can execute arbitrary code, so only load trusted checkpoints.
finetuned_model = torch.load(MODEL_PATH, weights_only=False)
# A pickled module keeps the train/eval flag it was saved with; force eval mode so
# dropout-style layers are inactive and the scores below are deterministic.
finetuned_model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# NOTE(review): plDDT_scaler is not passed to the test dataset in this cell, so this
# (potentially slow) scaler construction has no effect on the evaluation below.
plDDT_path = '/home/skrhakv/cryptic-nn/data/ligysis/plDDT'
plDDT_scaler = finetuning_utils.train_scaler('/home/skrhakv/cryptic-nn/data/ligysis/train.txt', plDDT_path=plDDT_path, uniprot_ids=True)
val_dataset = finetuning_utils.process_sequence_dataset('/home/skrhakv/cryptic-nn/data/cryptobench/test.txt', tokenizer)

partial_collate_fn = functools.partial(finetuning_utils.collate_fn, tokenizer=tokenizer)

# The whole test set forms a single batch, so the loop below runs exactly once.
val_dataloader = DataLoader(val_dataset, batch_size=val_dataset.num_rows, collate_fn=partial_collate_fn)

with torch.no_grad():
    for batch in val_dataloader:
        # The model returns three outputs; only the first is scored here as the
        # binding-site logits, the other two heads are ignored.
        output1, _, _ = finetuned_model(batch)

        labels = batch['labels'].to(device)

        # Positions labelled -100 are excluded from scoring (presumably padding /
        # special tokens -- TODO confirm against finetuning_utils.collate_fn).
        flattened_labels = labels.flatten()
        valid_mask = flattened_labels != -100
        cbs_logits = output1.flatten()[valid_mask]
        valid_flattened_labels = flattened_labels[valid_mask]

        # Convert once to NumPy for all sklearn metrics below.
        labels = valid_flattened_labels.cpu().float().numpy()
        probabilities = torch.sigmoid(cbs_logits).cpu().float().numpy()

        # Sweep thresholds 0.10 ... 0.90 and keep the one with the highest MCC
        # (strict '>' means ties keep the lower threshold).
        # NOTE(review): the threshold is tuned on the same test set the final
        # metrics are reported on, so those metrics are optimistically biased.
        best_threshold, previous_mcc = 0.0, -100
        for threshold in np.arange(0.1, 0.95, 0.05):
            rounded_predictions = (probabilities > threshold).astype(int)
            acc = metrics.accuracy_score(labels, rounded_predictions)

            mcc = metrics.matthews_corrcoef(labels, rounded_predictions)
            if mcc > previous_mcc:
                previous_mcc = mcc
                best_threshold = threshold
            # 'weighted' averages per-class F1 by support; with imbalanced labels
            # it is dominated by the majority class.
            f1 = metrics.f1_score(labels, rounded_predictions, average='weighted')

            print(f"\tThreshold: {threshold:.2f} | Accuracy: {acc:.4f} | MCC: {mcc:.4f} | F1: {f1:.4f}")

        # Binary predictions at the selected threshold (float 0/1 tensor).
        predictions = (torch.sigmoid(cbs_logits) > best_threshold).float()
        binary_predictions = predictions.cpu().numpy()

        # compute metrics on test dataset
        test_acc = baseline_utils.accuracy_fn(y_true=valid_flattened_labels,
                                              y_pred=predictions)

        fpr, tpr, thresholds1 = metrics.roc_curve(labels, probabilities)
        roc_auc = metrics.auc(fpr, tpr)

        mcc = metrics.matthews_corrcoef(labels, binary_predictions)
        f1 = metrics.f1_score(labels, binary_predictions, average='weighted')

        precision, recall, thresholds2 = metrics.precision_recall_curve(labels, probabilities)
        # Trapezoidal area under the PR curve; metrics.average_precision_score is the
        # less optimistic summary if comparability with earlier runs is not needed.
        auprc = metrics.auc(recall, precision)

# Number of residues predicted positive (a plain int rather than a tensor repr).
n_predicted_positive = int(predictions.sum().item())
print(f"Best threshold: {best_threshold:.2f}:")
print(f"Accuracy: {test_acc:.2f}% | AUC: {roc_auc:.4f}, MCC: {mcc:.4f}, F1: {f1:.4f}, AUPRC: {auprc:.4f}, sum: {n_predicted_positive}")

np.savez(f'{PATH_TO_AUC_AUPRC_DATA}/{MODEL}-rocauc.npz', fpr, tpr, thresholds1)
np.savez(f'{PATH_TO_AUC_AUPRC_DATA}/{MODEL}-auprc.npz', precision, recall, thresholds2)


	Threshold: 0.10 | Accuracy: 0.5258 | MCC: 0.2029 | F1: 0.6384
	Threshold: 0.15 | Accuracy: 0.6454 | MCC: 0.2551 | F1: 0.7391
	Threshold: 0.20 | Accuracy: 0.7116 | MCC: 0.2909 | F1: 0.7890
	Threshold: 0.25 | Accuracy: 0.7584 | MCC: 0.3200 | F1: 0.8223
	Threshold: 0.30 | Accuracy: 0.7938 | MCC: 0.3426 | F1: 0.8467
	Threshold: 0.35 | Accuracy: 0.8242 | MCC: 0.3672 | F1: 0.8671
	Threshold: 0.40 | Accuracy: 0.8500 | MCC: 0.3916 | F1: 0.8842
	Threshold: 0.45 | Accuracy: 0.8716 | MCC: 0.4138 | F1: 0.8984
	Threshold: 0.50 | Accuracy: 0.8895 | MCC: 0.4345 | F1: 0.9102
	Threshold: 0.55 | Accuracy: 0.9043 | MCC: 0.4499 | F1: 0.9198
	Threshold: 0.60 | Accuracy: 0.9166 | MCC: 0.4600 | F1: 0.9276
	Threshold: 0.65 | Accuracy: 0.9268 | MCC: 0.4681 | F1: 0.9340
	Threshold: 0.70 | Accuracy: 0.9350 | MCC: 0.4720 | F1: 0.9389
	Threshold: 0.75 | Accuracy: 0.9413 | MCC: 0.4634 | F1: 0.9419
	Threshold: 0.80 | Accuracy: 0.9461 | MCC: 0.4425 | F1: 0.9429
	Threshold: 0.85 | Accuracy: 0.9489 | MCC: 0.4090 | F1:

## Multitask model with additional LIGYSIS training data and a CNN feature extractor

In [5]:
MODEL = 'multitask-finetuned-model-with-CNN-with-ligysis'
MODEL_PATH = f'{PATH_TO_MODELS}/{MODEL}.pt'

# The checkpoint is a whole pickled module (hence weights_only=False); unpickling
# can execute arbitrary code, so only load trusted checkpoints.
finetuned_model = torch.load(MODEL_PATH, weights_only=False)
# A pickled module keeps the train/eval flag it was saved with; force eval mode so
# dropout-style layers are inactive and the scores below are deterministic.
finetuned_model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# The test set's plDDT values are scaled with the scaler built from the LIGYSIS
# training file, so test features go through the same transform as training.
plDDT_path = '/home/skrhakv/cryptic-nn/data/ligysis/plDDT'
plDDT_scaler = finetuning_utils.train_scaler('/home/skrhakv/cryptic-nn/data/ligysis/train.txt', plDDT_path=plDDT_path, uniprot_ids=True)
val_dataset = finetuning_utils.process_sequence_dataset('/home/skrhakv/cryptic-nn/data/cryptobench/test.txt', tokenizer, plDDT_path='/home/skrhakv/cryptic-nn/data/cryptobench/plDDT', plDDT_scaler=plDDT_scaler)

partial_collate_fn = functools.partial(finetuning_utils.collate_fn, tokenizer=tokenizer)

# The whole test set forms a single batch, so the loop below runs exactly once.
val_dataloader = DataLoader(val_dataset, batch_size=val_dataset.num_rows, collate_fn=partial_collate_fn)

with torch.no_grad():
    for batch in val_dataloader:
        # The model returns three outputs; only the first is scored here as the
        # binding-site logits, the other two heads are ignored.
        output1, _, _ = finetuned_model(batch)

        labels = batch['labels'].to(device)

        # Positions labelled -100 are excluded from scoring (presumably padding /
        # special tokens -- TODO confirm against finetuning_utils.collate_fn).
        flattened_labels = labels.flatten()
        valid_mask = flattened_labels != -100
        cbs_logits = output1.flatten()[valid_mask]
        valid_flattened_labels = flattened_labels[valid_mask]

        # Convert once to NumPy for all sklearn metrics below.
        labels = valid_flattened_labels.cpu().float().numpy()
        probabilities = torch.sigmoid(cbs_logits).cpu().float().numpy()

        # Sweep thresholds 0.10 ... 0.90 and keep the one with the highest MCC
        # (strict '>' means ties keep the lower threshold).
        # NOTE(review): the threshold is tuned on the same test set the final
        # metrics are reported on, so those metrics are optimistically biased.
        best_threshold, previous_mcc = 0.0, -100
        for threshold in np.arange(0.1, 0.95, 0.05):
            rounded_predictions = (probabilities > threshold).astype(int)
            acc = metrics.accuracy_score(labels, rounded_predictions)

            mcc = metrics.matthews_corrcoef(labels, rounded_predictions)
            if mcc > previous_mcc:
                previous_mcc = mcc
                best_threshold = threshold
            # 'weighted' averages per-class F1 by support; with imbalanced labels
            # it is dominated by the majority class.
            f1 = metrics.f1_score(labels, rounded_predictions, average='weighted')

            print(f"\tThreshold: {threshold:.2f} | Accuracy: {acc:.4f} | MCC: {mcc:.4f} | F1: {f1:.4f}")

        # Binary predictions at the selected threshold (float 0/1 tensor).
        predictions = (torch.sigmoid(cbs_logits) > best_threshold).float()
        binary_predictions = predictions.cpu().numpy()

        # compute metrics on test dataset
        test_acc = baseline_utils.accuracy_fn(y_true=valid_flattened_labels,
                                              y_pred=predictions)

        fpr, tpr, thresholds1 = metrics.roc_curve(labels, probabilities)
        roc_auc = metrics.auc(fpr, tpr)

        mcc = metrics.matthews_corrcoef(labels, binary_predictions)
        f1 = metrics.f1_score(labels, binary_predictions, average='weighted')

        precision, recall, thresholds2 = metrics.precision_recall_curve(labels, probabilities)
        # Trapezoidal area under the PR curve; metrics.average_precision_score is the
        # less optimistic summary if comparability with earlier runs is not needed.
        auprc = metrics.auc(recall, precision)

# Number of residues predicted positive (a plain int rather than a tensor repr).
n_predicted_positive = int(predictions.sum().item())
print(f"Best threshold: {best_threshold:.2f}:")
print(f"Accuracy: {test_acc:.2f}% | AUC: {roc_auc:.4f}, MCC: {mcc:.4f}, F1: {f1:.4f}, AUPRC: {auprc:.4f}, sum: {n_predicted_positive}")
np.savez(f'{PATH_TO_AUC_AUPRC_DATA}/{MODEL}-rocauc.npz', fpr, tpr, thresholds1)
np.savez(f'{PATH_TO_AUC_AUPRC_DATA}/{MODEL}-auprc.npz', precision, recall, thresholds2)


	Threshold: 0.10 | Accuracy: 0.4119 | MCC: 0.1597 | F1: 0.5266
	Threshold: 0.15 | Accuracy: 0.5316 | MCC: 0.2026 | F1: 0.6438
	Threshold: 0.20 | Accuracy: 0.6034 | MCC: 0.2337 | F1: 0.7055
	Threshold: 0.25 | Accuracy: 0.6546 | MCC: 0.2576 | F1: 0.7463
	Threshold: 0.30 | Accuracy: 0.6927 | MCC: 0.2778 | F1: 0.7752
	Threshold: 0.35 | Accuracy: 0.7258 | MCC: 0.2975 | F1: 0.7993
	Threshold: 0.40 | Accuracy: 0.7544 | MCC: 0.3146 | F1: 0.8196
	Threshold: 0.45 | Accuracy: 0.7793 | MCC: 0.3315 | F1: 0.8368
	Threshold: 0.50 | Accuracy: 0.8038 | MCC: 0.3504 | F1: 0.8534
	Threshold: 0.55 | Accuracy: 0.8251 | MCC: 0.3633 | F1: 0.8676
	Threshold: 0.60 | Accuracy: 0.8446 | MCC: 0.3781 | F1: 0.8804
	Threshold: 0.65 | Accuracy: 0.8637 | MCC: 0.3986 | F1: 0.8930
	Threshold: 0.70 | Accuracy: 0.8807 | MCC: 0.4133 | F1: 0.9041
	Threshold: 0.75 | Accuracy: 0.8977 | MCC: 0.4344 | F1: 0.9152
	Threshold: 0.80 | Accuracy: 0.9130 | MCC: 0.4473 | F1: 0.9250
	Threshold: 0.85 | Accuracy: 0.9269 | MCC: 0.4544 | F1: