In [3]:
import torch
from transformers import AutoTokenizer, EsmModel
import torch.nn as nn
import numpy as np
import finetuning_utils

# --- Configuration -----------------------------------------------------------
# Pretrained checkpoint id, passed to `from_pretrained` for the tokenizer/backbone.
MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
MAX_LENGTH = 1024   # tokenizer pad/truncate length
OUTPUT_SIZE = 1     # output width of each linear head
DROPOUT = 0.25      # probability given to nn.Dropout

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Path to the pickled fine-tuned model -- UPDATE THIS to point at your copy!
MODEL_PATH = '/home/skrhakv/nn-for-kamila/model.pt'

# define the model - if we do not define the model then the loading of the model will fail
class FinetuneESM(nn.Module):
    def __init__(self, esm_model: str) -> None:
        super().__init__()
        self.llm = EsmModel.from_pretrained(esm_model)
        self.dropout = nn.Dropout(DROPOUT)
        self.classifier = nn.Linear(self.llm.config.hidden_size, OUTPUT_SIZE)
        self.plDDT_regressor = nn.Linear(self.llm.config.hidden_size, OUTPUT_SIZE)
        self.distance_regressor = nn.Linear(self.llm.config.hidden_size, OUTPUT_SIZE)

    def forward(self, batch: dict[str, np.ndarray]) -> torch.Tensor:
        input_ids, attention_mask = batch["input_ids"], batch["attention_mask"]
        token_embeddings = self.llm(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        
        return self.classifier(token_embeddings), self.plDDT_regressor(token_embeddings), self.distance_regressor(token_embeddings)

# Load the fine-tuned model. weights_only=False is required because the file
# holds a pickled FinetuneESM object, not just a state dict. Unpickling can
# execute arbitrary code, so only load model files you trust.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine, and
# .to(DEVICE) puts the weights on the same device as the inputs built below.
model = torch.load(MODEL_PATH, weights_only=False, map_location=DEVICE)
model = model.to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model.eval()

KRAS_sequence = 'GMTEYKLVVVGACGVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVVIDGETSLLDILDTAGQEEYSAMRDQYMRTGEGFLLVFAINNTKSFEDIHHYREQIKRVKDSEDVPMVLVGNKSDLPSRTVDTKQAQDLARSYGIPFIETSAKTRQGVDDAFYTLVREIRKHKEK'

# tokenize the sequence (padded/truncated to MAX_LENGTH), add a batch dimension
# and move every tensor to DEVICE
tokenized_sequences = tokenizer(KRAS_sequence, max_length=MAX_LENGTH, padding='max_length', truncation=True)
tokenized_sequences = {k: torch.tensor([v]).to(DEVICE) for k, v in tokenized_sequences.items()}

# predict: forward() returns (classifier, pLDDT, distance) outputs, so unpack the
# tuple and keep only the classifier logits. no_grad skips building an autograd
# graph, which inference does not need.
with torch.no_grad():
    classifier_logits, _, _ = model(tokenized_sequences)
output = classifier_logits.flatten()

# keep only positions the attention mask marks as real tokens (drops padding)
mask = (tokenized_sequences['attention_mask'] == 1).flatten()

# [1:-1] drops the first and last kept positions -- presumably the special
# tokens the ESM tokenizer adds around the sequence -- then logits -> probabilities
output = torch.sigmoid(output[mask][1:-1]).cpu().numpy()
print(output)


[0.0763  0.1235  0.2544  0.4612  0.5312  0.874   0.3777  0.8735  0.3633
 0.7983  0.9287  0.971   0.9937  0.992   0.9897  0.983   0.989   0.991
 0.9893  0.9688  0.808   0.8687  0.926   0.588   0.6685  0.614   0.4863
 0.887   0.9834  0.97    0.9585  0.953   0.984   0.966   0.9497  0.9766
 0.944   0.952   0.943   0.915   0.9316  0.8594  0.5786  0.424   0.3486
 0.2622  0.3398  0.3137  0.283   0.2893  0.2218  0.1947  0.4055  0.2644
 0.7085  0.721   0.961   0.9907  0.979   0.971   0.963   0.9683  0.792
 0.6714  0.7046  0.5254  0.5977  0.826   0.7437  0.615   0.8     0.8945
 0.8438  0.7935  0.849   0.889   0.5747  0.273   0.7227  0.1622  0.2527
 0.258   0.3875  0.3535  0.4724  0.458   0.3364  0.2769  0.3774  0.2866
 0.1803  0.315   0.3628  0.2186  0.3477  0.5493  0.519   0.5347  0.465
 0.7666  0.801   0.4226  0.783   0.9053  0.9326  0.579   0.395   0.3943
 0.3418  0.434   0.2654  0.3457  0.1453  0.1099  0.2424  0.21    0.9814
 0.9917  0.9473  0.9844  0.961   0.54    0.3667  0.3564  0.3423  0.