In [3]:
import functools
import os
import sys

import numpy as np
import torch
from transformers import AutoTokenizer

sys.path.append('/home/skrhakv/cryptic-nn/src/models')
import baseline_utils
import finetuning_utils
from finetuning_utils import MultitaskFinetunedEsmModel

# --- Configuration: baseline checkpoint and dataset locations ---------------
# NOTE(review): the paths below are absolute and machine-specific; consider
# deriving them from a single configurable project root.
MODEL = 'baseline-model'
MODEL_PATH = f'/home/skrhakv/cryptic-nn/src/models/train-models/{MODEL}.pt'
MAX_LENGTH = 1024  # not referenced by any of the cells visible in this notebook

DATASET = 'cryptobench'
DATA_PATH = f'/home/skrhakv/cryptic-nn/data/{DATASET}'
ESM_EMBEDDINGS_PATH = f'{DATA_PATH}/embeddings'
ESM_MODEL_NAME = 'facebook/esm2_t36_3B_UR50D'

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Both returned objects are keyed by protein id (the prediction loop iterates
# `Xs_test_apo.items()` and indexes `Ys_test_apo[protein_id]`).
# The test file path is built from DATA_PATH so it follows DATASET.
Xs_test_apo, Ys_test_apo = baseline_utils.process_sequence_dataset(f'{DATA_PATH}/test.txt', [ESM_EMBEDDINGS_PATH])

# `map_location` is required for the CPU fallback above to work: without it a
# checkpoint saved from a CUDA device fails to deserialize on a CPU-only host.
# SECURITY: weights_only=False unpickles arbitrary Python objects; only load
# checkpoints from a trusted source.
baseline_model = torch.load(MODEL_PATH, map_location=device, weights_only=False).to(device)


In [26]:
# Run the baseline model on every test protein and store, per protein id, the
# prediction array and the matching ground-truth labels as .npy files.
BASELINE_PREDICTIONS_DIR = '/home/skrhakv/cryptic-nn/src/models/predict/predictions/baseline'
BASELINE_GROUND_TRUTH_DIR = '/home/skrhakv/cryptic-nn/src/models/predict/ground-truth/baseline'

# np.save does not create missing directories, so make sure they exist.
os.makedirs(BASELINE_PREDICTIONS_DIR, exist_ok=True)
os.makedirs(BASELINE_GROUND_TRUTH_DIR, exist_ok=True)

# Put the model in inference mode (the fine-tuned model is handled the same way
# elsewhere in this notebook). A module loaded from a pickle keeps whatever
# training flag it was saved with, so train-only layers such as dropout could
# otherwise still be active.
baseline_model.eval()

# NOTE(review): no sigmoid is applied here, whereas `predict` for the fine-tuned
# model applies one; presumably the baseline model already outputs
# probabilities. Confirm before comparing the two sets of predictions.
with torch.no_grad():  # inference only: skip building the autograd graph
    for protein_id, embedding in Xs_test_apo.items():
        embedding_tensor = torch.tensor(embedding, dtype=torch.float32).to(device)
        # squeeze(1) drops the model's second output dimension when it has size 1.
        prediction = baseline_model(embedding_tensor).squeeze(1).detach().cpu().numpy()
        np.save(f'{BASELINE_PREDICTIONS_DIR}/{protein_id}.npy', prediction)
        np.save(f'{BASELINE_GROUND_TRUTH_DIR}/{protein_id}.npy', Ys_test_apo[protein_id])


In [None]:
# --- Fine-tuned model: checkpoint, tokenizer and evaluation data ------------
# NOTE: this rebinds MODEL / MODEL_PATH from the baseline configuration cell.
MODEL = 'multitask-finetuned-model-with-ligysis'
MODEL_PATH = f'/home/skrhakv/cryptic-nn/src/models/train-models/{MODEL}.pt'

# `map_location` lets a checkpoint saved from a CUDA device load on a CPU-only
# host, matching the cpu fallback used when `device` is chosen.
# SECURITY: weights_only=False unpickles arbitrary Python objects; only load
# checkpoints from a trusted source.
finetuned_model = torch.load(MODEL_PATH, map_location=device, weights_only=False).to(device)
finetuned_model.eval()  # inference mode before any prediction is made

tokenizer = AutoTokenizer.from_pretrained(ESM_MODEL_NAME)

# Despite its name, `val_dataset` holds the *test* split (test.txt).
# load_ids=True is passed so each item carries an 'ids' entry, which the
# prediction loop reads to name its output files.
val_dataset = finetuning_utils.process_sequence_dataset(f'{DATA_PATH}/test.txt', tokenizer, load_ids=True)

# Collate function with the tokenizer pre-bound; not used by the cells visible here.
partial_collate_fn = functools.partial(finetuning_utils.collate_fn, tokenizer=tokenizer)


In [36]:
def predict(model, tokenized_sequences):
    """Return per-position sigmoid scores for a single tokenized sequence.

    Args:
        model: Callable module that accepts a dict of batched tensors and
            returns a 3-tuple; only the first element is used and is treated
            as logits (a sigmoid is applied to it).
        tokenized_sequences: Mapping of field name to a per-token list for ONE
            sequence. Must contain 'attention_mask'. Every value is wrapped
            into a batch of size 1.

    Returns:
        1-D numpy array of sigmoid scores for the attended positions, with the
        first and last attended position removed (presumably the tokenizer's
        special start/end tokens; confirm against the tokenizer in use).
    """
    # Put the inputs on the device the model lives on; fall back to the
    # notebook-level `device` for a model without parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    batch = {key: torch.tensor([value]).to(target_device) for key, value in tokenized_sequences.items()}

    # Inference only: no_grad avoids building an autograd graph for the whole
    # forward pass, which would otherwise be kept alive until the final detach.
    with torch.no_grad():
        output, _, _ = model(batch)
        output = output.flatten()

        # Keep attended (non-padding) positions only.
        mask = (batch['attention_mask'] == 1).flatten()
        scores = torch.sigmoid(output[mask][1:-1])

    return scores.detach().cpu().numpy()

# Run the fine-tuned model on every test protein and store, per protein id, the
# prediction array and the matching labels as .npy files.
#
# The results go to their own 'finetuned' directories. Writing them to the
# 'baseline' directories (as this cell previously did) overwrote the baseline
# predictions right after loading them for the shape check.
# NOTE(review): the 'finetuned' directory name is a reviewer choice; adjust it
# if downstream evaluation expects a different location.
BASELINE_PREDICTIONS_DIR = '/home/skrhakv/cryptic-nn/src/models/predict/predictions/baseline'
FINETUNED_PREDICTIONS_DIR = '/home/skrhakv/cryptic-nn/src/models/predict/predictions/finetuned'
FINETUNED_GROUND_TRUTH_DIR = '/home/skrhakv/cryptic-nn/src/models/predict/ground-truth/finetuned'

# np.save does not create missing directories, so make sure they exist.
os.makedirs(FINETUNED_PREDICTIONS_DIR, exist_ok=True)
os.makedirs(FINETUNED_GROUND_TRUTH_DIR, exist_ok=True)

for sample in val_dataset:
    protein_id = sample['ids'][0]
    # Build the model input without 'ids' instead of deleting the key, so the
    # dataset item is not mutated and the cell can be re-run safely.
    model_inputs = {key: value for key, value in sample.items() if key != 'ids'}
    prediction = predict(finetuned_model, model_inputs)

    # Sanity check: both models must score the same number of residues.
    baseline_prediction = np.load(f'{BASELINE_PREDICTIONS_DIR}/{protein_id}.npy')
    assert baseline_prediction.shape == prediction.shape, f"Shape mismatch for {protein_id}: {baseline_prediction.shape} vs {prediction.shape}"

    np.save(f'{FINETUNED_PREDICTIONS_DIR}/{protein_id}.npy', prediction)
    np.save(f'{FINETUNED_GROUND_TRUTH_DIR}/{protein_id}.npy', sample['labels'])
