In [None]:
#installs
!pip install biopython
!pip install py3dmol
!pip install transformers
!pip install fair-esm
!pip install scikit-learn

In [None]:
#imports
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pandas as pd
from tqdm import tqdm
import pickle
import torch
import esm
import numpy as np
import matplotlib.pyplot as plt
import random
import io
from google.colab import drive
from transformers import EsmModel, EsmTokenizer, EsmConfig, AutoTokenizer
from sklearn.metrics import roc_auc_score
from google.colab import drive

# Mount Google Drive into the Colab filesystem.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# ESM-2 checkpoint esm2_t33_650M_UR50D; the loader returns (model, alphabet).
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()

# Callable that tokenizes [(label, sequence), ...] into (labels, strings, tokens).
batch_converter = alphabet.get_batch_converter()

# Inference mode: dropout is switched off so repeated embeddings of the same
# sequence are identical.
esm_model.eval()

In [None]:
# Load the test data. It needs at least the columns 'sequences' and 'architecture'.
TEST_DATA_PATH = 'your test data path'  # TODO: replace with the path to your CSV
df = pd.read_csv(TEST_DATA_PATH)

# Fail early with a clear message instead of a KeyError later in label
# encoding or inside the DataLoader.
REQUIRED_COLUMNS = {'sequences', 'architecture'}
missing_columns = REQUIRED_COLUMNS - set(df.columns)
if missing_columns:
    raise KeyError(f"Test data is missing required column(s): {sorted(missing_columns)}")

In [None]:
from sklearn.preprocessing import LabelEncoder
import warnings

# Classifier output index i means whatever class index i meant at training time.
# LabelEncoder maps the *sorted* distinct labels it is fit on to 0..k-1. So
# fitting it on the test labels reproduces the training mapping only when the
# test set contains exactly the training classes (assuming training also used a
# LabelEncoder). Otherwise indices shift and decoded predictions are mislabeled.
# Set this to the full list of architecture labels used in training (same values
# as in the 'architecture' column) to pin the mapping.
TRAIN_ARCHITECTURES = None
# The model below is built with num_classes=10; a fitted encoder should agree.
EXPECTED_NUM_CLASSES = 10

label_encoder = LabelEncoder()
if TRAIN_ARCHITECTURES is not None:
    label_encoder.fit(TRAIN_ARCHITECTURES)
else:
    # Fallback (original behaviour): fit on the test labels themselves.
    label_encoder.fit(df['architecture'])
    if len(label_encoder.classes_) != EXPECTED_NUM_CLASSES:
        warnings.warn(
            f"label_encoder was fit on {len(label_encoder.classes_)} test classes, but the "
            f"model has {EXPECTED_NUM_CLASSES} outputs; the index->label mapping likely differs "
            "from training. Set TRAIN_ARCHITECTURES to the training label set."
        )

# Encode the architecture labels. This raises ValueError for labels not seen by
# the encoder.
df['encoded_architecture'] = label_encoder.transform(df['architecture'])

In [None]:
#dataset class setup

class ProtDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding mean-pooled ESM embeddings for protein rows.

    Embeddings are computed on the fly for each item. Steps:

    - Tokenize the row's ``'sequences'`` string with ``batch_converter``.
    - Run the tokens through ``esm_model``.
    - Average the per-token representations of layer ``repr_layer``, dropping
      the first and last token positions (the start/end special tokens the
      batch converter adds).

    Args:
        df: DataFrame with at least ``'sequences'`` and ``'encoded_architecture'``
            columns. Rows are addressed by position, so any index (e.g. after
            filtering) works.
        esm_model: ESM model, called as
            ``esm_model(tokens, repr_layers=[repr_layer], return_contacts=False)``.
        batch_converter: ESM batch converter (``alphabet.get_batch_converter()``).
        alphabet: ESM alphabet; only ``padding_idx`` is read.
        repr_layer: which layer's representations to pool (default 33).

    Each item is a dict with keys:

    - ``"Protein"``: the sequence string.
    - ``"Protein Input"``: the pooled embedding (CPU tensor).
    - ``"Architecture"``: the encoded label as a ``torch.long`` scalar.
    """

    def __init__(self, df, esm_model, batch_converter, alphabet, repr_layer=33):
        super().__init__()
        self.df = df
        self.esm_model = esm_model
        self.batch_converter = batch_converter
        self.alphabet = alphabet
        self.repr_layer = repr_layer

    def __len__(self):
        return len(self.df)

    def _model_device(self):
        """Return the device holding the ESM weights, or CPU if none can be found."""
        try:
            return next(self.esm_model.parameters()).device
        except (AttributeError, StopIteration):
            return torch.device('cpu')

    def _compute_embedding(self, sequence):
        """Return the mean-pooled representation of ``sequence`` as a CPU tensor."""
        _, _, batch_tokens = self.batch_converter([("", sequence)])
        # Tokens must be on the same device as the model (e.g. if it was moved to GPU).
        batch_tokens = batch_tokens.to(self._model_device())
        # Count of non-padding tokens: the residues plus the special tokens.
        seq_len = int((batch_tokens[0] != self.alphabet.padding_idx).sum())

        with torch.no_grad():
            # Contact maps are not used; skipping them avoids computing attention maps.
            results = self.esm_model(
                batch_tokens, repr_layers=[self.repr_layer], return_contacts=False
            )
        token_representations = results["representations"][self.repr_layer].cpu()

        # Drop the first and last positions (special tokens) before averaging.
        return token_representations[0, 1:seq_len - 1].mean(0)

    def __getitem__(self, index):
        # Positional lookup: DataLoader indices run 0..len-1 regardless of df's index labels.
        row = self.df.iloc[index]
        prot = row['sequences']
        architecture = torch.tensor(row['encoded_architecture'], dtype=torch.long)

        return {
            "Protein": prot,
            "Protein Input": self._compute_embedding(prot),
            "Architecture": architecture,
        }

In [None]:
# Wrap the test frame so each item is embedded on the fly. Order is kept
# (no shuffling) so batches follow the dataset order.
test_dataset = ProtDataset(
    df=df,
    esm_model=esm_model,
    batch_converter=batch_converter,
    alphabet=alphabet,
)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=32)

In [None]:
#self-attention sequence-based model

class SeqClassifier(nn.Module):
    """Map a pooled protein embedding to ``num_classes`` logits.

    Steps:

    - The embedding is viewed as a length-1 sequence and passed through
      4-head self-attention.
    - The attention output is concatenated with the raw embedding.
    - The result goes through an MLP: 2*embedding_dim -> 1280 -> 640 -> 320
      -> num_classes, with ReLU between layers.

    Note: with a single token, the attention softmax runs over one key. The
    attention block therefore reduces to its value and output projections.

    The submodule names (``attention``, ``fc_layers``) and the layer order
    inside ``fc_layers`` determine the ``state_dict`` keys. They must stay as
    they are for saved checkpoints to load.

    Args:
        num_classes: number of output logits.
        embedding_dim: size of the input embedding (default 1280).
    """

    def __init__(self, num_classes, embedding_dim=1280):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=embedding_dim, num_heads=4, batch_first=True
        )

        # Hidden widths of the MLP head; each hidden Linear is followed by a ReLU.
        widths = [embedding_dim * 2, 1280, 640, 320]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers.append(nn.Linear(widths[-1], num_classes))
        self.fc_layers = nn.Sequential(*layers)

    def forward(self, protein_embedding):
        """Return logits for a ``(batch, embedding_dim)`` embedding batch."""
        seq = protein_embedding.unsqueeze(1)  # (batch, 1, embedding_dim)
        attended, _ = self.attention(seq, seq, seq)
        features = torch.cat([seq.squeeze(1), attended.squeeze(1)], dim=1)
        return self.fc_layers(features)

In [None]:
#initialize model
# Must equal the output size the checkpoint was trained with; otherwise
# load_state_dict fails on the last layer's shape.
NUM_CLASSES = 10
model = SeqClassifier(num_classes=NUM_CLASSES)

#best model (you may need to change model path depending on how you are running the code)
model_path = 'checkpoints/seq.pt'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute code. A plain state_dict loads unchanged.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

In [None]:
import torch
from sklearn.metrics import precision_score, recall_score, accuracy_score
from tqdm import tqdm

# Accumulators for predictions, true labels and sequences, in loader order.
all_predictions = []
all_labels = []
all_sequences = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing"):
        sequence_inputs = batch["Protein Input"].to(device)
        # Labels are only used for CPU-side metrics, so they stay on the CPU.
        labels = batch["Architecture"]
        sequences = batch["Protein"]

        outputs = model(sequence_inputs)
        predictions = torch.argmax(outputs, dim=1)

        all_predictions.extend(predictions.cpu().numpy())
        all_labels.extend(labels.numpy())
        all_sequences.extend(sequences)

# The model can emit any index below its number of outputs, but the encoder can
# only decode indices it was fit on. Explain the mismatch instead of raising
# sklearn's generic "unseen labels" error.
num_known_classes = len(label_encoder.classes_)
out_of_range = sorted({int(p) for p in all_predictions if p >= num_known_classes})
if out_of_range:
    raise ValueError(
        f"Model predicted class indices {out_of_range}, but label_encoder only knows "
        f"{num_known_classes} classes; fit the encoder on the training label set."
    )

# Map encoded indices back to the original architecture labels.
decoded_predictions = label_encoder.inverse_transform(all_predictions)
decoded_labels = label_encoder.inverse_transform(all_labels)

# Metrics. zero_division=0 gives the same value sklearn's default uses for
# classes that are never predicted, without emitting UndefinedMetricWarning.
accuracy = accuracy_score(decoded_labels, decoded_predictions)
precision = precision_score(decoded_labels, decoded_predictions, average='weighted', zero_division=0)
recall = recall_score(decoded_labels, decoded_predictions, average='weighted', zero_division=0)


for seq, pred, true_label in zip(all_sequences, decoded_predictions, decoded_labels):
    print(f"Sequence: {seq}")
    print(f"Predicted Class: {pred}")
    print(f"True Class: {true_label}\n")

print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")

Testing: 100%|██████████| 3/3 [00:53<00:00, 17.67s/it]

Accuracy: 0.8077
Precision: 0.8191
Recall: 0.8077
Sequence: IPFTIKLKTCLKMCIQRLRYAQEKQQAIAKQSRRQVAQLLLTNKEQKAHYRVETLIHDDIHIELLEILELYCELLLARVQVINDISTEEQLVKEHMDDGINEAIRSLIYAILFVDEVKELSQLKDLMAWKINVEFVNGVIADHIDVPEKIIKKCSPSVPKEELVDLYLKEIAKTYDVPYSKLENSL
Predicted Class: 20
True Class: 20

Sequence: TMFYGKGDVYVFRTYANPLKGLKQIPESNFTEKHNTIFGMNAKVALKGEQLLTSFTEGDNSLVVATDSMKNFIQRHAASYEGATLEGFLQYVCEAFLAKYSHLDAVRLEAKEYAFDDIQVGTDKGVVTSDLVFRKSRNEYATATVEVARTASGTEVVEQASGIADIQLIKVSGSSFYGYIIDEYTTLAEATDRPLYIFLNIGWAYENQDDAKGDNPANYVAAEQVRDIAASVFHTLDNKSIQHLIYHIGLTILDRFPQLTEVNFGTNNRTWDTVVEGTDGFKGAVFTEPRPPFGFQGFSVHQEDLAREKASANSEYVAL
Predicted Class: 10
True Class: 10

Sequence: PERFIRKRWAQIAKAMDAKKFGVIVSIKKGQLRLAEAKRIVKLLKKHGREARLIVMNDVNYHKLEGFPFEAYVVVACPRVPLDDYGAWRKPVLTPKEVEILL
Predicted Class: 40
True Class: 40

Sequence: SIDVPVQTLTVEAGNGLQLQLTKKNNDLVIVRFFGSVSNIQKGWNMSGTWVDRPFRPAAVQSLVGHFAGRDTSFHIDINPNGSITWWGANIDKTPIATRGNGSYFIK
Predicted Class: 60
True Class: 60

Sequence: QKIHEEFGAVFDQLIAEQTGEKKEVADLSMGDLLLHTS


