In [None]:
#installs
!pip install biopython
!pip install py3dmol
!pip install transformers
!pip install fair-esm
!pip install scikit-learn
!pip install torch-geometric

In [None]:
#imports
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pandas as pd
from tqdm import tqdm
import pickle
import torch
import esm
import numpy as np
import matplotlib.pyplot as plt
import random
import io
from google.colab import drive
from transformers import EsmModel, EsmTokenizer, EsmConfig, AutoTokenizer
from sklearn.metrics import roc_auc_score
from google.colab import drive
import os
from Bio.PDB import PDBParser
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from torch_geometric.data import Data
import torch
from torch_geometric.nn import GATConv, global_max_pool
from sklearn.metrics import precision_score, recall_score, accuracy_score

#connect to drive: mount Google Drive at a fixed location inside the Colab runtime
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [5]:
#load in test data. Columns used later in this notebook:
#  'sequences'    - amino-acid sequence (ESM input; also the key for both embedding dicts)
#  'architecture' - architecture class label (ground truth for the metrics)
#  'cath_id'      - identifier that must equal the PDB file name without its extension
TEST_CSV_PATH = 'your test dataset path'
df = pd.read_csv(TEST_CSV_PATH)

#fail here with a clear message instead of with a KeyError several cells later
_missing_columns = {'sequences', 'architecture', 'cath_id'} - set(df.columns)
if _missing_columns:
    raise ValueError(f"test CSV is missing required column(s): {sorted(_missing_columns)}")

In [None]:
#load esm model (ESM-2, 33 layers, 650M parameters)
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
#disables dropout for deterministic results; the trailing ';' suppresses the very long module repr
esm_model.eval();

In [None]:
#run embeddings: one mean-pooled layer-33 ESM-2 representation per unique sequence
sequence_embeddings = {}

#send tokens to wherever the model currently lives (CPU or GPU)
_esm_device = next(esm_model.parameters()).device

for sequence in tqdm(df['sequences']):
    if sequence in sequence_embeddings:
        continue  # duplicate sequence: the model is in eval mode, so reuse the stored embedding

    #standard esm calculations
    _, _, batch_tokens = batch_converter([("", sequence)])
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

    with torch.no_grad():
        #contacts are never used, so skip the contact head; the representations are unaffected
        results = esm_model(batch_tokens.to(_esm_device), repr_layers=[33], return_contacts=False)
    token_representations = results["representations"][33].cpu()

    #drop the special tokens at position 0 and at the last position before averaging
    avg_embedding = token_representations[0, 1 : batch_lens[0] - 1].mean(0)
    sequence_embeddings[sequence] = avg_embedding

In [11]:
#encode architecture labels as integer class ids.
#Both classifiers below output NUM_ARCHITECTURE_CLASSES logits, so the id <-> label mapping
#must be the same one the models were trained with (presumably an encoder fit on the training
#labels). LabelEncoder orders classes by sorting, so fitting on the full training label set
#reproduces that mapping; fitting on the test labels alone only does so if every class occurs.
NUM_ARCHITECTURE_CLASSES = 10
TRAIN_ARCHITECTURE_CLASSES = None  # optionally: list of every architecture label used in training

label_encoder = LabelEncoder()

#fit encoder on the architecture labels
if TRAIN_ARCHITECTURE_CLASSES is not None:
    label_encoder.fit(TRAIN_ARCHITECTURE_CLASSES)
else:
    label_encoder.fit(df['architecture'])

if len(label_encoder.classes_) != NUM_ARCHITECTURE_CLASSES:
    raise ValueError(
        f"label encoder knows {len(label_encoder.classes_)} classes but the models predict "
        f"{NUM_ARCHITECTURE_CLASSES}; set TRAIN_ARCHITECTURE_CLASSES to the full training label "
        f"set so predicted ids decode to the right architectures"
    )

#transform (not fit_transform): raises on labels the encoder has never seen
df['encoded_architecture'] = label_encoder.transform(df['architecture'])

In [None]:
#unzip pdb files
!unzip 'path to your pdb files zipped' -d /content/pdb_files #you can change output path to your preference as well

In [None]:
#run this to get coordinate and edges data
#One torch_geometric graph per PDB file: every atom is a node whose features are its xyz
#coordinates, and every pair of atoms closer than EDGE_CUTOFF_ANGSTROM is joined by an edge
#in both directions.

pdb_dir = 'path to folder where all pdb files are stored' #assuming same format as what was given for training
EDGE_CUTOFF_ANGSTROM = 5.0 #using 5 Angstroms from literature


def _radius_edge_index(coords, cutoff):
    """Return a (2, E) int64 array of directed edges between points closer than `cutoff`.

    For every pair i < j within the cutoff, [i, j] is emitted followed by [j, i], in
    ascending (i, j) order (the same order as a naive double loop). Each point is
    compared against all later points in one vectorized step, so extra memory is O(N)
    per row instead of O(N^2). Returns an empty (2, 0) array when there are no edges.
    """
    rows, cols = [], []
    for i in range(len(coords) - 1):
        distances = np.linalg.norm(coords[i + 1:] - coords[i], axis=1)
        neighbours = np.nonzero(distances < cutoff)[0] + (i + 1)
        rows.append(np.full(len(neighbours), i, dtype=np.int64))
        cols.append(neighbours.astype(np.int64))

    if rows:
        src, dst = np.concatenate(rows), np.concatenate(cols)
    else:
        src = dst = np.empty(0, dtype=np.int64)

    edge_index = np.empty((2, 2 * len(src)), dtype=np.int64)
    edge_index[0, 0::2], edge_index[1, 0::2] = src, dst  # [i, j]
    edge_index[0, 1::2], edge_index[1, 1::2] = dst, src  # [j, i]
    return edge_index


data_list = []

parser = PDBParser(QUIET=True)

pdb_files = sorted(os.listdir(pdb_dir))  # sorted so the processing order is deterministic

for pdb_file in tqdm(pdb_files, desc="Loading PDB files"):
    pdb_path = os.path.join(pdb_dir, pdb_file)
    if not os.path.isfile(pdb_path):
        continue  # skip sub-directories and other non-files

    structure = parser.get_structure('protein', pdb_path)

    #every atom of every model / chain / residue becomes a node
    atom_coords = [
        atom.coord.tolist()
        for model in structure
        for chain in model
        for residue in chain
        for atom in residue
    ]
    if not atom_coords:
        continue

    x = torch.tensor(atom_coords, dtype=torch.float)
    #always shaped (2, E), including E == 0 (a bare empty tensor would be shaped (0,))
    edge_index = torch.from_numpy(
        _radius_edge_index(np.asarray(atom_coords, dtype=np.float64), EDGE_CUTOFF_ANGSTROM)
    )

    identifier = os.path.splitext(pdb_file)[0] #allows us to map back to cath_id

    data_list.append(Data(x=x, edge_index=edge_index, pdb_id=identifier))

In [15]:
#index graphs by file stem (if two files share a stem, the later one in data_list wins)
structural_data_dict = {data.pdb_id: data for data in data_list}

#a missing structure would otherwise become NaN and fail later with an opaque AttributeError
_missing_ids = [cath_id for cath_id in df['cath_id'].unique() if cath_id not in structural_data_dict]
if _missing_ids:
    raise ValueError(
        f"{len(_missing_ids)} cath_id value(s) have no parsed PDB structure "
        f"(e.g. {_missing_ids[:5]}); check pdb_dir and that file names match cath_id"
    )

#store the structural Data
df['structural_data'] = df['cath_id'].map(structural_data_dict)

In [16]:
#model to get structure embeddings (use classifier to train, only take embeddings from model)
class GATClassifier(nn.Module):
    """Graph attention network over an atom graph, with an MLP classification head.

    Three GATConv layers (with skip connections around conv2 and conv3) are max-pooled
    into one vector per graph. That pooled vector is the structure embedding, and the
    MLP head turns it into class logits.

    Attribute names (conv1, conv2, conv3, dropout, fc_layers) and the layer order inside
    fc_layers determine the state_dict keys. Keep them unchanged or saved checkpoints
    will no longer load.

    Args:
        in_channels: number of node features (3 in this notebook: xyz coordinates).
        hidden_channels: output width of every GAT layer. Each head produces
            hidden_channels // num_heads features and the heads are concatenated.
        out_channels: number of output classes.
        num_heads: attention heads per GAT layer; must divide hidden_channels.
        dropout_rate: dropout probability used inside the MLP head.

    Raises:
        ValueError: if hidden_channels is not divisible by num_heads. The concatenated
            head outputs would otherwise not match the next layer's input width.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_heads=8, dropout_rate=0.3):
        if num_heads <= 0 or hidden_channels % num_heads != 0:
            raise ValueError(
                f"hidden_channels ({hidden_channels}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        super(GATClassifier, self).__init__()
        head_channels = hidden_channels // num_heads
        self.conv1 = GATConv(in_channels, head_channels, heads=num_heads)
        self.conv2 = GATConv(hidden_channels, head_channels, heads=num_heads)
        self.conv3 = GATConv(hidden_channels, head_channels, heads=num_heads)

        self.dropout = nn.Dropout(p=dropout_rate)

        #fully connected layers (the same Dropout module is reused at indices 2 and 5)
        self.fc_layers = nn.Sequential(
            nn.Linear(hidden_channels, 320),
            nn.ReLU(),
            self.dropout,
            nn.Linear(320, 128),
            nn.ReLU(),
            self.dropout,
            nn.Linear(128, out_channels)
        )

    def forward(self, data):
        """Run the network on the graph stored under data['Structure'].

        The graph must expose x, edge_index and batch (batch may be None for a single
        un-batched graph).

        Returns:
            (logits, pooled): class logits from the MLP head, and the max-pooled
            per-graph embedding of width hidden_channels that feeds the head.
        """
        #get graph data
        structure_data = data['Structure']
        x, edge_index, batch = structure_data.x, structure_data.edge_index, structure_data.batch

        #add skip connections (no ReLU after conv3, only the residual add)
        x1 = F.relu(self.conv1(x, edge_index))
        x2 = F.relu(self.conv2(x1, edge_index)) + x1
        x3 = self.conv3(x2, edge_index) + x2

        #max pooling over the nodes of each graph
        pooled = global_max_pool(x3, batch)

        #pass through fc layers
        out = self.fc_layers(pooled)

        return out, pooled

In [18]:
_use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if _use_cuda else torch.device('cpu')

In [30]:
#initialize model for embeddings: 3 node features (xyz), 640-wide GAT layers, 10 output classes
struct_model = GATClassifier(
    in_channels=3,
    hidden_channels=640,
    out_channels=10,
    num_heads=8,
)
struct_model = struct_model.to(device)

In [None]:
#load model path (you may need to change model path depending on how you are running the code)
struct_path = 'checkpoints/structure_model.pt'
#weights_only=True: only a state_dict is expected, so refuse to unpickle arbitrary objects
struct_model.load_state_dict(torch.load(struct_path, map_location=device, weights_only=True))

In [None]:
#store structure embeddings (similar to what was done w sequence)
#Each protein graph is run on its own and only the pooled GAT embedding is kept,
#keyed by sequence (if a sequence appears twice, the later row's structure wins).

struct_model = struct_model.to(device).eval()

structure_embeddings = {}

with torch.no_grad():
    rows = zip(df['structural_data'], df['sequences'])
    for structure_data, sequence in tqdm(rows, total=len(df)):
        input_data = {'Structure': structure_data.to(device)}

        #get embeddings only (the classifier logits are discarded)
        _, embedding = struct_model(input_data)

        #keep results on the CPU so GPU memory does not grow with the number of proteins
        structure_embeddings[sequence] = embedding.cpu()

In [42]:
#dataset class setup

class SeqStructDataset(torch.utils.data.Dataset):
    """Pairs each row of `df` with its precomputed sequence and structure embeddings.

    Args:
        df: DataFrame with 'sequences' and 'encoded_architecture' columns.
        sequence_embeddings: dict mapping sequence -> sequence embedding tensor.
        structure_embeddings: dict mapping sequence -> structure embedding tensor with a
            leading size-1 dimension, which is squeezed away here.
    """

    def __init__(self, df, sequence_embeddings, structure_embeddings):
        super().__init__()
        self.sequence_embeddings = sequence_embeddings
        self.structure_embeddings = structure_embeddings
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        #positional lookup: DataLoader asks for positions 0..len-1, which only coincide with
        #.loc labels when df happens to have a default RangeIndex
        row = self.df.iloc[index]
        prot = row['sequences']
        prot_embedding = self.sequence_embeddings[prot]
        structure_embedding = self.structure_embeddings[prot].squeeze(0)
        architecture = torch.tensor(row['encoded_architecture'], dtype=torch.long)

        return {
            "Protein": prot,
            "Sequence Input": prot_embedding,
            "Structure Input": structure_embedding,
            "Architecture": architecture,
        }

In [43]:
#sequence + structure attention model

class SeqStructureClassifier(nn.Module):
    """MLP classifier over a sequence embedding and a structure embedding.

    Each embedding is treated as a length-1 token sequence and passed through its own
    MultiheadAttention. The attention output is concatenated with the raw embedding,
    and the two augmented vectors are concatenated and fed to the MLP head.

    Module attribute names and the layer order inside fc_layers determine the
    state_dict keys, so they are kept stable for checkpoint loading.
    """

    def __init__(self, num_classes, sequence_embedding_dim=1280, structure_embedding_dim=640):
        super().__init__()

        #one self-attention block per modality (4 heads each, batch-first inputs)
        self.sequence_attention = nn.MultiheadAttention(
            embed_dim=sequence_embedding_dim, num_heads=4, batch_first=True
        )
        self.structure_attention = nn.MultiheadAttention(
            embed_dim=structure_embedding_dim, num_heads=4, batch_first=True
        )

        #MLP head over [sequence, sequence_attn, structure, structure_attn]
        fused_dim = 2 * (sequence_embedding_dim + structure_embedding_dim)
        self.fc_layers = nn.Sequential(
            nn.Linear(fused_dim, 1280),
            nn.ReLU(),
            nn.Linear(1280, 640),
            nn.ReLU(),
            nn.Linear(640, 320),
            nn.ReLU(),
            nn.Linear(320, num_classes),
        )

    @staticmethod
    def _with_self_attention(attention, embedding):
        """Concatenate `embedding` with its self-attention output along the feature axis."""
        tokens = embedding.unsqueeze(1)  # insert a length-1 sequence axis
        attended, _ = attention(tokens, tokens, tokens)
        return torch.cat((tokens.squeeze(1), attended.squeeze(1)), dim=1)

    def forward(self, sequence_embedding, structure_embedding):
        seq_features = self._with_self_attention(self.sequence_attention, sequence_embedding)
        struct_features = self._with_self_attention(self.structure_attention, structure_embedding)
        fused = torch.cat((seq_features, struct_features), dim=1)
        return self.fc_layers(fused)

In [44]:
#setup: wrap the precomputed embeddings; rows are served in their original order
TEST_BATCH_SIZE = 32
test_dataset = SeqStructDataset(df, sequence_embeddings, structure_embeddings)
test_loader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False)

In [None]:
#initialize model (10 output classes)
final_model = SeqStructureClassifier(num_classes=10)

#best model (you may need to change model path depending on how you are running the code)
model_path = 'checkpoints/seq_struct.pt'

#`device` was already chosen earlier in the notebook; reuse it rather than redefining it
print(f'Using device: {device}')

#weights_only=True: only a state_dict is expected, so refuse to unpickle arbitrary objects
final_model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
final_model.to(device)
final_model.eval()

In [None]:
#predict on test set

all_predictions = []
all_labels = []
all_sequences = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing"):
        sequence_inputs = batch["Sequence Input"].to(device)
        structure_inputs = batch["Structure Input"].to(device)

        #get predictions (index of the largest logit)
        outputs = final_model(sequence_inputs, structure_inputs)
        predictions = torch.argmax(outputs, dim=1)

        all_predictions.extend(predictions.cpu().tolist())
        #labels are only needed for the metrics, so they stay on the CPU
        all_labels.extend(batch["Architecture"].tolist())
        all_sequences.extend(batch["Protein"])

#metrics on integer class ids: the id <-> label mapping is one-to-one and order-preserving,
#so the scores equal those on decoded labels, and a predicted id the encoder does not know
#cannot crash them. zero_division=0 is the value sklearn falls back to anyway, set explicitly
#so it does not also emit a warning.
accuracy = accuracy_score(all_labels, all_predictions)
precision = precision_score(all_labels, all_predictions, average='weighted', zero_division=0)
recall = recall_score(all_labels, all_predictions, average='weighted', zero_division=0)

#decode for display; unknown predicted ids are shown instead of raising
known_classes = label_encoder.classes_
decoded_labels = label_encoder.inverse_transform(all_labels)
decoded_predictions = [
    known_classes[p] if 0 <= p < len(known_classes) else f"<unknown class id {p}>"
    for p in all_predictions
]

for seq, pred, true_label in zip(all_sequences, decoded_predictions, decoded_labels):
    print(f"Sequence: {seq}")
    print(f"Predicted Class: {pred}")
    print(f"True Class: {true_label}\n")

print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")