In [1]:
import torch
from transformers import AutoTokenizer, EsmModel
import torch.nn as nn
import numpy as np

# --- Configuration -----------------------------------------------------------
# Pretrained ESM-2 checkpoint name: FinetuneESM builds its backbone from it and
# the tokenizer is loaded from it further down.
MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
MAX_LENGTH = 1024  # not referenced by any cell shown in this notebook
OUTPUT_SIZE = 1    # output width of each linear head in FinetuneESM
DROPOUT = 0.25     # dropout rate read by FinetuneESM.__init__

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Absolute, machine-specific path to the fine-tuned checkpoint -- UPDATE THIS!
MODEL_PATH = '/home/skrhakv/cryptic-nn/src/650M-version/finetuned-model-650M.pt'

# The class must be defined before torch.load below: the checkpoint appears to be
# a whole pickled module, and unpickling it fails if this class is missing.
class FinetuneESM(nn.Module):
    """ESM-2 encoder with three per-token linear heads.

    The attribute names (llm, dropout, classifier, plDDT_regressor,
    distance_regressor) are what a pickled instance carries and what forward()
    reads, so they must not be renamed.
    """

    def __init__(self, esm_model: str) -> None:
        """Load the pretrained backbone ``esm_model`` and attach three
        Linear(hidden_size -> OUTPUT_SIZE) heads."""
        super().__init__()
        self.llm = EsmModel.from_pretrained(esm_model)
        # NOTE(review): never applied in forward(); presumably kept from the
        # training setup -- confirm before relying on it.
        self.dropout = nn.Dropout(DROPOUT)
        hidden_size = self.llm.config.hidden_size
        self.classifier = nn.Linear(hidden_size, OUTPUT_SIZE)
        self.plDDT_regressor = nn.Linear(hidden_size, OUTPUT_SIZE)
        self.distance_regressor = nn.Linear(hidden_size, OUTPUT_SIZE)

    def forward(
        self, batch: dict[str, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode the batch and apply every head to each token embedding.

        Args:
            batch: mapping with "input_ids" and "attention_mask", passed
                unchanged to the ESM backbone (presumably tokenizer output as
                tensors -- TODO confirm against eval_utils.compute_prediction).

        Returns:
            A 3-tuple (classifier output, pLDDT regression, distance
            regression). Each is the raw (no activation) head output on
            ``last_hidden_state``, so its last dimension is OUTPUT_SIZE.
        """
        token_embeddings = self.llm(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        ).last_hidden_state

        return (
            self.classifier(token_embeddings),
            self.plDDT_regressor(token_embeddings),
            self.distance_regressor(token_embeddings),
        )

# Load the fine-tuned model. It is a whole pickled module (it has .eval()), which
# is why FinetuneESM must be defined above.
# SECURITY NOTE: weights_only=False unpickles arbitrary Python objects -- only
# load checkpoints from a trusted source.
model = torch.load(MODEL_PATH, weights_only=False)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model.eval()  # evaluation mode (affects layers such as dropout)

# Query protein sequence = first line of the file. readline() keeps the trailing
# newline, so strip it to avoid leaking whitespace into the sequence.
with open('/home/skrhakv/cryptoshow-analysis/src/E-regular-binding-site-predictor/evaluation/creatine-kinase/3b6rB.txt', 'r') as seq_file:
    sequence = seq_file.readline().strip()


In [2]:
import sys

# Make the project-local `eval_utils` helpers importable.
# NOTE(review): absolute, machine-specific path -- adjust for your checkout.
sys.path.append('/home/skrhakv/cryptoshow-analysis/src/B-evaluate-cryptoshow')
import eval_utils

# Hyperparameters of the CryptoBench smoothing classifier defined below.
# Input width: presumably two concatenated 1280-d ESM-2 embeddings -- TODO confirm
# against eval_utils.process_single_sequence.
ESM2_DIM  = 2 * 1280
LAYER_WIDTH = 256
# NOTE: this rebinds the DROPOUT (0.25) set in the first cell. FinetuneESM.__init__
# reads that global, but the fine-tuned model is unpickled by torch.load rather than
# constructed, so it is not affected.
DROPOUT = 0.3

DECISION_THRESHOLD = 0.8  # not referenced by any cell shown in this notebook

class CryptoBenchClassifier(nn.Module):
    """Three-layer MLP mapping each ``input_dim``-wide feature vector to one value.

    Submodule attribute names double as state-dict keys; the checkpoint is
    loaded with ``strict=True``, so they must stay unchanged.
    """

    def __init__(self, input_dim=ESM2_DIM):
        super().__init__()
        # Creation order is preserved: it fixes the order in which the Linear
        # layers draw their random initial weights.
        self.layer_1 = nn.Linear(input_dim, LAYER_WIDTH)
        self.dropout1 = nn.Dropout(DROPOUT)

        self.layer_2 = nn.Linear(LAYER_WIDTH, LAYER_WIDTH)
        self.dropout2 = nn.Dropout(DROPOUT)

        self.layer_3 = nn.Linear(LAYER_WIDTH, 1)

        self.relu = nn.ReLU()  # shared activation for both hidden layers

    def forward(self, x):
        """(Linear -> ReLU -> Dropout) twice, then a final Linear with no activation."""
        hidden = self.dropout1(self.relu(self.layer_1(x)))
        hidden = self.dropout2(self.relu(self.layer_2(hidden)))
        return self.layer_3(hidden)
    

# Passed as decision_threshold to eval_utils.compute_clusters (residues scoring above
# it are presumably the ones clustered).
HIGH_SCORE_THRESHOLD = 0.7
# Cut-off on the smoothing model's output; chosen on the training data, where it
# gave the best F1 score.
SMOOTHENED_THRESHOLD = 0.7
SMOOTHING_MODEL_STATE_DICT_PATH = '/home/skrhakv/cryptoshow-analysis/src/E-regular-binding-site-predictor/evaluation/creatine-kinase/cryptobench_classifier.pt'

# A freshly constructed nn.Module is in training mode, so its Dropout layers would be
# active at inference and make smoothing predictions stochastic. eval() (which returns
# the module itself) fixes that, mirroring model.eval() in the first cell.
smoothing_model = CryptoBenchClassifier().to(DEVICE).eval()
# Kept as the cell's last expression so the load report is displayed.
smoothing_model.load_state_dict(torch.load(SMOOTHING_MODEL_STATE_DICT_PATH, map_location=DEVICE), strict=True)

<All keys matched successfully>

In [3]:
import pickle

# Target chain for this case study and the directory holding its inputs/outputs.
OUTPUT_PATH = '/home/skrhakv/cryptoshow-analysis/src/E-regular-binding-site-predictor/evaluation/creatine-kinase'
pdb_id = '3b6r'
chain_id = 'B'

embedding_path = f'{OUTPUT_PATH}/{pdb_id}{chain_id}_embedding.npy'
coordinates_path = f'{OUTPUT_PATH}/{pdb_id}{chain_id}.npy'

# Per-residue scores from the fine-tuned ESM model loaded in the first cell.
prediction = eval_utils.compute_prediction(
    sequence,
    embedding_path,
    model,
    tokenizer
)

# Spatially cluster the high-scoring residues; label -1 is skipped below.
coordinates = np.load(coordinates_path)
clusters = eval_utils.compute_clusters(
    coordinates,
    prediction,
    decision_threshold=HIGH_SCORE_THRESHOLD,
    method='meanshift'
)

distance_matrix = eval_utils.compute_distance_matrix(coordinates)

# Enhance each predicted pocket with the smoothing model: candidate residues it
# scores above SMOOTHENED_THRESHOLD are prepended to the cluster's own residues.
predicted_binding_sites = []
for cluster_label in np.unique(clusters):
    if cluster_label == -1:
        continue
    cluster_residue_indices = np.where(clusters == cluster_label)[0]
    embeddings, indices = eval_utils.process_single_sequence(
        pdb_id, chain_id, cluster_residue_indices, embedding_path, distance_matrix
    )

    # Distinct name so `prediction` keeps holding the per-residue scores after this cell.
    smoothing_prediction = eval_utils.predict_single_sequence(embeddings, indices, smoothing_model)
    added_residue_indices = indices[smoothing_prediction['predictions'] > SMOOTHENED_THRESHOLD]

    # NOTE(review): no de-duplication -- if `indices` can overlap the cluster
    # residues, the pocket may contain repeated indices.
    enhanced_residue_indices = np.concatenate((added_residue_indices, cluster_residue_indices))
    predicted_binding_sites.append(enhanced_residue_indices)

with open(f'{OUTPUT_PATH}/{pdb_id}{chain_id}.pkl', 'wb') as pkl_file:
    pickle.dump(predicted_binding_sites, pkl_file)

predicted_binding_sites

[array([290,  63, 185, 314, 315, 329, 316, 317, 318, 319, 320, 321]),
 array([277, 278, 279,  52,  53,  54,  55,  56,  57,  63,  64, 195,  69,
        199, 198, 276,  22,  23,  65,  66,  67,  68]),
 array([286,  66, 199, 201, 276, 277, 278, 279, 280]),
 array([ 22,  23,  52,  55,  65,  66,  67,  68, 195, 196, 197, 198, 199,
        200,  53,  54,  56,  57]),
 array([276,  53, 195, 196, 198, 199, 202, 225, 200, 201]),
 array([276,  53,  54,  55,  56, 195, 196,  67, 200, 201, 198, 199]),
 array([286, 186, 222, 230, 124, 126, 232]),
 array([286, 126, 123, 124]),
 array([230, 222]),
 array([196, 201, 202, 222, 226, 230, 188, 225]),
 array([222, 225, 126, 230]),
 array([195, 198, 199, 200, 201, 202, 203, 225,  53, 196]),
 array([280, 124, 126, 286]),
 array([329, 286, 290, 122, 123, 124, 126, 288]),
 array([321, 318, 319, 315, 316, 317, 314]),
 array([278, 279, 280, 286, 335])]