In [1]:
import torch
import os
from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [2]:
path = "/Users/sarwar/Documents/protein_class"
os.chdir(path)

os.getcwd()

'/Users/sarwar/Documents/protein_class'

In [3]:
# Load saved model checkpoint and tokenizer
checkpoint_dir = './results_8M_c20_tmp/training_results/checkpoint-8725'  # Adjust path if necessary

model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir)

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D", do_lower_case=False, model_max_length=238)

# Example user sequence
user_sequence = "LTDGLSNLVLGKKTIDASLLEELEMILLSADIGIEATQSILNNLSQQVARKSLSDPKALIDALKIEL"

# Tokenize the sequence
encoded_sequence = tokenizer(
    user_sequence,
    truncation=True,
    padding='max_length',
    return_tensors="pt"
)

In [4]:
# Check if CUDA (GPU) is available
#if torch.cuda.is_available():
#    print(f"GPU is available: {torch.cuda.get_device_name(0)}")
#    print(f"Device Count: {torch.cuda.device_count()}")
#else:
#    print("GPU is not available. Using CPU.")

In [5]:
# Move data to device (GPU if available)
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#model.to(device)

In [6]:
# Check for MPS (Metal Performance Shaders) availability
if torch.backends.mps.is_available():
    device = torch.device("mps")  # Use MPS for GPU
    print("Using Apple GPU (MPS backend).")
else:
    device = torch.device("cpu")  # Fallback to CPU
    print("Using CPU.")
    
model.to(device)

Using Apple GPU (MPS backend).


EsmForSequenceClassification(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 320, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 320, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-5): 6 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=320, out_features=320, bias=True)
              (key): Linear(in_features=320, out_features=320, bias=True)
              (value): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((320,), eps=1e-05,

In [7]:
encoded_sequence = {key: val.to(device) for key, val in encoded_sequence.items()}

In [8]:
encoded_sequence

{'input_ids': tensor([[ 0,  4, 11, 13,  6,  4,  8, 17,  4,  7,  4,  6, 15, 15, 11, 12, 13,  5,
           8,  4,  4,  9,  9,  4,  9, 20, 12,  4,  4,  8,  5, 13, 12,  6, 12,  9,
           5, 11, 16,  8, 12,  4, 17, 17,  4,  8, 16, 16,  7,  5, 10, 15,  8,  4,
           8, 13, 14, 15,  5,  4, 12, 13,  5,  4, 15, 12,  9,  4,  2,  1,  1,  1,
           1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
           1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
           1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
           1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
           1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
           1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
           1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
           1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
   

In [9]:
# Make predictions
model.eval()  # Set model to evaluation mode
with torch.no_grad():
    outputs = model(**encoded_sequence)
    logits = outputs.logits

In [10]:
# Extract the predicted class index
predicted_class_idx = torch.argmax(logits, dim=1).item()
predicted_class_idx

17

In [11]:
# Decode the class index to original label
from sklearn.preprocessing import LabelEncoder
import joblib

# Load the saved LabelEncoder
le = joblib.load("label_encoder_class20.pkl")

predicted_class_label = le.inverse_transform([predicted_class_idx])[0]

In [12]:
#print("All labels:", le.classes_)

In [13]:
# Print the result
print(f"Predicted class for the input sequence: {predicted_class_label}")


Predicted class for the input sequence: SRP54_N
