In [2]:
from transformers import AutoTokenizer
import torch
import numpy as np
import functools
from sklearn import metrics
from torch.utils.data import DataLoader
import sys

sys.path.append('/home/skrhakv/cryptic-nn/src')
import baseline_utils
import finetuning_utils
from finetuning_utils import FinetunedEsmModel, MultitaskFinetunedEsmModel, MultitaskFinetunedEsmModelWithCnn

# Pretrained model identifier; the evaluation cell loads its tokenizer from it.
MODEL_NAME = 'facebook/esm2_t33_650M_UR50D'

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
device  # echoed as the cell's output


'cuda'

In [None]:
# Evaluate the fine-tuned checkpoint on the CryptoBench test split: sweep
# decision thresholds, pick the one with the best MCC, then report accuracy,
# ROC AUC, MCC, F1 and AUPRC at that threshold.
MODEL = 'base-finetuned-model'
MODEL_PATH = '/home/skrhakv/nn-for-kamila/model.pt'

# NOTE(review): weights_only=False unpickles arbitrary Python objects -- only
# load checkpoints from a trusted source.
# map_location lets a checkpoint saved on one device load on the current one
# (e.g. a GPU checkpoint on a CPU-only machine).
finetuned_model = torch.load(MODEL_PATH, weights_only=False, map_location=device)
# Switch to inference mode so dropout and other train-only behaviour are off
# and the metrics below are deterministic (assumes an nn.Module checkpoint).
finetuned_model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

train_dataset = finetuning_utils.process_sequence_dataset('/home/skrhakv/cryptic-nn/data/cryptobench/train.txt', tokenizer)
# Despite the name, this is the CryptoBench *test* split.
val_dataset = finetuning_utils.process_sequence_dataset('/home/skrhakv/cryptic-nn/data/cryptobench/test.txt', tokenizer)

partial_collate_fn = functools.partial(finetuning_utils.collate_fn, tokenizer=tokenizer)

# batch_size == number of rows, so the whole split is a single batch and the
# loop below runs exactly once.
val_dataloader = DataLoader(val_dataset, batch_size=val_dataset.num_rows, collate_fn=partial_collate_fn)

with torch.no_grad():
    for batch in val_dataloader:

        output = finetuned_model(batch)

        flattened_labels = batch['labels'].to(device).flatten()
        # Positions labelled -100 are excluded from scoring.
        valid_mask = flattened_labels != -100

        # assumes the model emits one logit per label position -- TODO confirm
        cbs_logits = output.flatten()[valid_mask]
        valid_flattened_labels = flattened_labels[valid_mask]

        # Convert once; every sklearn metric below works on these arrays.
        labels = valid_flattened_labels.cpu().float().numpy()
        probabilities = torch.sigmoid(cbs_logits).cpu().float().numpy()

        # NOTE(review): the threshold is tuned on the same split the final
        # metrics are reported on, so those metrics are optimistic.
        best_threshold, previous_mcc = 0.0, -100
        for threshold in np.arange(0.1, 0.95, 0.05):
            rounded_predictions = (probabilities > threshold).astype(int)
            acc = metrics.accuracy_score(labels, rounded_predictions)

            mcc = metrics.matthews_corrcoef(labels, rounded_predictions)
            if mcc > previous_mcc:
                previous_mcc = mcc
                best_threshold = threshold
            # 'weighted' averages per-class F1 by support, so on an imbalanced
            # label set it is dominated by the majority class.
            f1 = metrics.f1_score(labels, rounded_predictions, average='weighted')

            print(f"\tThreshold: {threshold:.2f} | Accuracy: {acc:.4f} | MCC: {mcc:.4f} | F1: {f1:.4f}")

        # Hard 0/1 predictions at the selected threshold (float tensor).
        predictions = (torch.sigmoid(cbs_logits) > best_threshold).float()
        predicted_labels = predictions.cpu().numpy()

        # compute metrics on test dataset
        test_acc = baseline_utils.accuracy_fn(y_true=valid_flattened_labels,
                                              y_pred=predictions)

        fpr, tpr, thresholds1 = metrics.roc_curve(labels, probabilities)
        roc_auc = metrics.auc(fpr, tpr)

        mcc = metrics.matthews_corrcoef(labels, predicted_labels)

        f1 = metrics.f1_score(labels, predicted_labels, average='weighted')

        # Trapezoidal area under the PR curve; metrics.average_precision_score
        # is the usual non-interpolated alternative.
        precision, recall, thresholds2 = metrics.precision_recall_curve(labels, probabilities)
        auprc = metrics.auc(recall, precision)

print(f"Best threshold: {best_threshold:.2f}:")
# Tensor.sum() reduces on-device in one call; the builtin sum() iterated the
# tensor element by element and printed a tensor repr instead of a number.
print(f"Accuracy: {test_acc:.2f}% | AUC: {roc_auc:.4f}, MCC: {mcc:.4f}, F1: {f1:.4f}, AUPRC: {auprc:.4f}, sum: {int(predictions.sum().item())}")


	Threshold: 0.10 | Accuracy: 0.1779 | MCC: 0.0838 | F1: 0.2223
	Threshold: 0.15 | Accuracy: 0.4048 | MCC: 0.1591 | F1: 0.5188
	Threshold: 0.20 | Accuracy: 0.5852 | MCC: 0.2238 | F1: 0.6904
	Threshold: 0.25 | Accuracy: 0.6934 | MCC: 0.2739 | F1: 0.7757
	Threshold: 0.30 | Accuracy: 0.7587 | MCC: 0.3117 | F1: 0.8225
	Threshold: 0.35 | Accuracy: 0.8034 | MCC: 0.3435 | F1: 0.8531
	Threshold: 0.40 | Accuracy: 0.8367 | MCC: 0.3710 | F1: 0.8753
	Threshold: 0.45 | Accuracy: 0.8612 | MCC: 0.3935 | F1: 0.8914
	Threshold: 0.50 | Accuracy: 0.8796 | MCC: 0.4107 | F1: 0.9033
	Threshold: 0.55 | Accuracy: 0.8955 | MCC: 0.4298 | F1: 0.9137
	Threshold: 0.60 | Accuracy: 0.9086 | MCC: 0.4462 | F1: 0.9223
	Threshold: 0.65 | Accuracy: 0.9191 | MCC: 0.4587 | F1: 0.9291
	Threshold: 0.70 | Accuracy: 0.9266 | MCC: 0.4664 | F1: 0.9338
	Threshold: 0.75 | Accuracy: 0.9334 | MCC: 0.4681 | F1: 0.9378
	Threshold: 0.80 | Accuracy: 0.9386 | MCC: 0.4643 | F1: 0.9405
	Threshold: 0.85 | Accuracy: 0.9430 | MCC: 0.4483 | F1: