# Intro
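Code copied from the history of the `cryptic-finetuning` repository to reproduce previously computed results. The fine-tuned multitask ESM2 model is evaluated on the CryptoBench test set twice: first with the smoothing classifier applied to its predictions, then without it.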

In [1]:
import torch
from transformers import AutoTokenizer, EsmModel
import numpy as np
import torch.nn as nn
import sklearn.metrics as metrics
import functools
import sys
sys.path.append('/home/skrhakv/cryptic-nn/src')
import finetuning_utils
import baseline_utils

torch.manual_seed(420)
OUTPUT_PATH = "/home/skrhakv/cryptic-nn/final-data/trained-models/multitask-finetuned-model-with-ligysis.pt"
ESM_MODEL_NAME = 'facebook/esm2_t36_3B_UR50D'
MAX_LENGTH = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DECISION_THRESHOLD = 0.7  # sigmoid cut-off for calling a residue "binding"

# The checkpoint is a whole pickled module (hence weights_only=False), so it must come
# from a trusted location. map_location lets a checkpoint saved on a GPU also load on
# the CPU fallback selected above.
loaded_model = torch.load(OUTPUT_PATH, map_location=DEVICE, weights_only=False).to(DEVICE)
# Inference only: switch off train-mode behaviour such as dropout (same as smoothing_model below).
loaded_model.eval()

tokenizer = AutoTokenizer.from_pretrained(ESM_MODEL_NAME)
partial_collate_fn = functools.partial(finetuning_utils.collate_fn, tokenizer=tokenizer)
# Test set with protein ids; the smoothing evaluation needs them to locate per-protein files.
val_dataset = finetuning_utils.process_sequence_dataset(
    '/home/skrhakv/cryptic-nn/data/cryptobench/test.txt', 
    tokenizer,
    load_ids=True,
    )

In [2]:
# Smoothing classifier: a pickled module loaded from SMOOTHING_MODEL_PATH.
SMOOTHING_MODEL_PATH = '/home/skrhakv/cryptoshow-analysis/data/C-optimize-smoother/smoother.pt'
SMOOTHING_DECISION_THRESHOLD = 0.4  # see src/C-optimize-smoother/classifier-for-cryptoshow.ipynb

# Architecture hyper-parameters of the smoother.
DROPOUT = 0.5
LAYER_WIDTH = 2048
ESM2_DIM = 2560
# The input is two ESM2-sized vectors concatenated (residue embedding + surrounding embedding).
INPUT_DIM = 2 * ESM2_DIM


class CryptoBenchClassifier(nn.Module):
    """Three-layer MLP producing a single logit per input vector.

    Layout: Linear -> ReLU -> Dropout -> Linear -> ReLU -> Dropout -> Linear(1).

    The class must be defined before the smoother is unpickled with ``torch.load``.
    Keep the attribute names unchanged: pickle restores the saved submodules under
    these names, and ``forward`` reads them back.
    """

    def __init__(self, dim=LAYER_WIDTH, dropout=DROPOUT):
        super().__init__()
        # Registration order matches the saved model (it fixes repr and state_dict order).
        self.layer_1 = nn.Linear(in_features=INPUT_DIM, out_features=dim)
        self.dropout1 = nn.Dropout(dropout)
        self.layer_2 = nn.Linear(in_features=dim, out_features=dim)
        self.dropout2 = nn.Dropout(dropout)
        self.layer_3 = nn.Linear(in_features=dim, out_features=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a (..., INPUT_DIM) tensor to (..., 1) logits."""
        first_hidden = self.dropout1(self.relu(self.layer_1(x)))
        second_hidden = self.dropout2(self.relu(self.layer_2(first_hidden)))
        return self.layer_3(second_hidden)


# Whole pickled module (weights_only=False): needs CryptoBenchClassifier defined above and a
# trusted file. map_location lets a GPU-saved checkpoint load on the CPU fallback too.
smoothing_model = torch.load(SMOOTHING_MODEL_PATH, map_location=DEVICE, weights_only=False).to(DEVICE)
smoothing_model.eval()


CryptoBenchClassifier(
  (layer_1): Linear(in_features=5120, out_features=2048, bias=True)
  (dropout1): Dropout(p=0.5, inplace=False)
  (layer_2): Linear(in_features=2048, out_features=2048, bias=True)
  (dropout2): Dropout(p=0.5, inplace=False)
  (layer_3): Linear(in_features=2048, out_features=1, bias=True)
  (relu): ReLU()
)

In [3]:
DATA_PATH = '/home/skrhakv/cryptic-nn/data/cryptobench'
ESM_EMBEDDINGS_PATH = f'{DATA_PATH}/embeddings'
DISTANCE_MATRICES_PATH = f'{DATA_PATH}/distance-matrices'
POSITIVE_DISTANCE_THRESHOLD = 15


def smooth_predictions(predictions, distance_matrix, embeddings, smoother,
                       distance_threshold, decision_threshold):
    """Return a copy of `predictions` in which some non-binding residues are flipped to binding.

    Every residue predicted 0 that has at least one residue predicted 1 closer than
    `distance_threshold` is re-scored by `smoother`. The smoother's input is the residue's
    own embedding concatenated with the mean embedding of those close binding residues.
    The residue becomes 1 when sigmoid(logit) > `decision_threshold`.

    Neighbourhoods are always taken from the *input* predictions, so a flipped residue
    never triggers further flips.

    Args:
        predictions: 1-D numpy array of 0.0/1.0 predictions, one entry per residue.
        distance_matrix: square numpy array; row i holds the distances from residue i.
        embeddings: numpy array with one embedding row per residue.
        smoother: module mapping one concatenated embedding to a single logit.
        distance_threshold: residues strictly closer than this count as neighbours.
        decision_threshold: sigmoid cut-off for flipping a residue to binding.

    Returns:
        A new numpy array; `predictions` itself is left unchanged.
    """
    smoothed = predictions.copy()
    # Loop-invariant: residues the base model already calls binding.
    binding_indices = np.where(predictions == 1.0)[0]

    for residue_idx in np.where(predictions == 0.0)[0]:
        close_indices = np.where(distance_matrix[residue_idx] < distance_threshold)[0]
        close_binding_indices = np.intersect1d(close_indices, binding_indices)
        if len(close_binding_indices) == 0:
            continue  # no binding neighbourhood to smooth towards

        # The mean over a single row is that row, so one code path covers both cases.
        surrounding_embedding = embeddings[close_binding_indices].mean(axis=0)
        smoother_input = torch.tensor(
            np.concatenate((embeddings[residue_idx], surrounding_embedding), axis=0),
            dtype=torch.float32,
        ).to(DEVICE)

        smoothing_logit = smoother(smoother_input).squeeze()
        if torch.sigmoid(smoothing_logit) > decision_threshold:
            smoothed[residue_idx] = 1
    return smoothed


with torch.no_grad():
    all_test_logits = []
    all_test_pred = []
    all_y_test = []

    for sample in val_dataset:
        protein_id = sample['ids'][0]
        # Drop 'ids' without mutating the dataset row, so this cell can be re-run.
        model_input = {key: value for key, value in sample.items() if key != 'ids'}
        batch = finetuning_utils.collate_fn([model_input], tokenizer=tokenizer)
        output1, _, _ = loaded_model(batch)

        flattened_labels = batch['labels'].to(DEVICE).flatten()
        is_labelled = flattened_labels != -100  # positions labelled -100 are excluded from evaluation
        y_test = flattened_labels[is_labelled]
        logits = output1.flatten()[is_labelled]

        test_pred = (torch.sigmoid(logits) > DECISION_THRESHOLD).float()
        raw_pred = test_pred.cpu().numpy()

        # Per-protein inputs of the smoother: pairwise distances and per-residue embeddings.
        distance_matrix = np.load(f'{DISTANCE_MATRICES_PATH}/{protein_id}.npy')
        X_test = np.load(f'{ESM_EMBEDDINGS_PATH}/{protein_id}.npy')
        assert distance_matrix.shape[0] == distance_matrix.shape[1]
        assert distance_matrix.shape[0] == raw_pred.shape[0]
        assert X_test.shape[0] == distance_matrix.shape[0]

        smoothed_pred = smooth_predictions(
            raw_pred, distance_matrix, X_test, smoothing_model,
            distance_threshold=POSITIVE_DISTANCE_THRESHOLD,
            decision_threshold=SMOOTHING_DECISION_THRESHOLD,
        )
        assert len(y_test) == len(smoothed_pred)

        all_test_logits.append(logits.cpu().numpy())
        all_y_test.append(y_test.cpu().numpy())
        all_test_pred.append(smoothed_pred)

    test_logits = torch.tensor(np.concatenate(all_test_logits, axis=0), dtype=torch.float32).to(DEVICE)
    test_pred = torch.tensor(np.concatenate(all_test_pred, axis=0), dtype=torch.float32).to(DEVICE)
    y_test = torch.tensor(np.concatenate(all_y_test, axis=0), dtype=torch.float32).to(DEVICE)

    # Metrics over the whole test set (all proteins pooled).
    test_acc = baseline_utils.accuracy_fn(y_true=y_test, y_pred=test_pred)
    y_true_np = y_test.cpu().numpy()
    y_pred_np = test_pred.cpu().numpy()
    mcc = metrics.matthews_corrcoef(y_true_np, y_pred_np)
    f1 = metrics.f1_score(y_true_np, y_pred_np, average='weighted')

print(f"Accuracy: {test_acc:.2f}% |  MCC: {mcc:.4f}, F1: {f1:.4f}")


Accuracy: 92.36% |  MCC: 0.4604, F1: 0.9318


## Comparison with the model without smoothing


In [4]:
from torch.utils.data import DataLoader

# Test set loaded again *without* ids. It gets its own name: overwriting `val_dataset`
# would break a re-run of the smoothing evaluation, which reads the 'ids' column.
baseline_dataset = finetuning_utils.process_sequence_dataset('/home/skrhakv/cryptic-nn/data/cryptobench/test.txt', tokenizer)
# The whole test set in one batch, as in the original run. Metrics are pooled over all
# batches below, so this can be lowered if memory is tight.
EVAL_BATCH_SIZE = baseline_dataset.num_rows
val_dataloader = DataLoader(baseline_dataset, batch_size=EVAL_BATCH_SIZE, collate_fn=partial_collate_fn)

with torch.no_grad():
    cbs_logits_per_batch = []
    labels_per_batch = []
    for batch in val_dataloader:
        output1, _, _ = loaded_model(batch)

        flattened_labels = batch['labels'].to(DEVICE).flatten()
        is_labelled = flattened_labels != -100  # positions labelled -100 are excluded from evaluation
        cbs_logits_per_batch.append(output1.flatten()[is_labelled])
        labels_per_batch.append(flattened_labels[is_labelled])

    cbs_logits = torch.cat(cbs_logits_per_batch)
    valid_flattened_labels = torch.cat(labels_per_batch)
    cbs_probabilities = torch.sigmoid(cbs_logits)
    predictions = (cbs_probabilities > DECISION_THRESHOLD).float()

    y_true = valid_flattened_labels.cpu().float().numpy()
    y_score = cbs_probabilities.cpu().float().numpy()
    y_pred = predictions.cpu().float().numpy()

    # compute metrics on test dataset
    test_acc = baseline_utils.accuracy_fn(y_true=valid_flattened_labels, y_pred=predictions)
    fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
    roc_auc = metrics.auc(fpr, tpr)
    mcc = metrics.matthews_corrcoef(y_true, y_pred)
    f1 = metrics.f1_score(y_true, y_pred, average='weighted')
    # AUPRC is the trapezoidal area under the PR curve, kept for comparability with earlier
    # runs. metrics.average_precision_score is the step-wise alternative.
    precision, recall, _ = metrics.precision_recall_curve(y_true, y_score)
    auprc = metrics.auc(recall, precision)

print(f"Accuracy: {test_acc:.2f}% | AUC: {roc_auc:.4f}, MCC: {mcc:.4f}, F1: {f1:.4f}, AUPRC: {auprc:.4f}")


Accuracy: 93.51% | AUC: 0.8935, MCC: 0.4720, F1: 0.9389, AUPRC: 0.4752
