In [1]:
#installs for mmseq
!pip install -q condacolab
import condacolab
condacolab.install()

✨🍰✨ Everything looks OK!


In [None]:
#mmseq
!conda install -c conda-forge -c bioconda mmseqs2

In [None]:
#installs
!pip install biopython
!pip install py3dmol
!pip install transformers
!pip install fair-esm
!pip install scikit-learn
!pip install torch-geometric

In [None]:
#imports
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pandas as pd
from tqdm import tqdm
import pickle
import torch
import esm
import numpy as np
import matplotlib.pyplot as plt
import random
import io
from google.colab import drive
from transformers import EsmModel, EsmTokenizer, EsmConfig, AutoTokenizer
from sklearn.metrics import roc_auc_score
from google.colab import drive
from Bio import SeqIO
from sklearn.model_selection import train_test_split
import os
from Bio.PDB import PDBParser
from torch_geometric.data import Data, DataLoader
from sklearn.preprocessing import LabelEncoder
from torch_geometric.nn import GATConv, global_max_pool
from torch.optim import Adam
import requests
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

#connect to drive
drive.mount('/content/drive')

In [7]:
#load in train df
df = pd.read_csv('/content/drive/MyDrive/Challenge (Rishab)/ml_hands_on_challenge/final_data.csv')

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# **Initial Dataset Pre-Processing**

In [None]:
architecture_names = {
    (1, 10): "Mainly Alpha: Orthogonal Bundle",
    (1, 20): "Mainly Alpha: Up-down Bundle",
    (2, 30): "Mainly Beta: Roll",
    (2, 40): "Mainly Beta: Beta Barrel",
    (2, 60): "Mainly Beta: Sandwich",
    (3, 10): "Alpha Beta: Roll",
    (3, 20): "Alpha Beta: Alpha-Beta Barrel",
    (3, 30): "Alpha Beta: 2-Layer Sandwich",
    (3, 40): "Alpha Beta: 3-Layer(aba) Sandwich",
    (3, 90): "Alpha Beta: Alpha-Beta Complex"
}

In [None]:
def get_architecture_name(row):
    key = (row['class'], row['architecture'])
    return architecture_names.get(key, "Unknown")

#add names
df['architecture_domain'] = df.apply(get_architecture_name, axis=1)

# **Clustering For Homology**

In [8]:
#check for invalid amino acids

#define set {} of valid AAs
valid_AAs = {'A', 'R', 'N', 'D', 'C', 'E', 'Q', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V'}

#define method to check a sequence for invalid characters
def contains_invalid_char(seq):
  unique_chars = set(seq) # set of all characters in the protein; unique_chars = {A, C} for protein="AAACCC"

  if unique_chars.issubset(valid_AAs):  #unique_chars = {A,C}, and {A,C} is a subset of valid_AAs
    return ''
  else: # e.g. unique_chars = {A,X}. {A,X} is not a subset of valid_AAs because X is not in valid_AAs
    return unique_chars.difference(valid_AAs) # e.g. {A,X} - valid_AAs = {X}

In [None]:
#apply our method contains_invalid_char to the sequence column
df['invalid_chars'] = df['sequences'].apply(contains_invalid_char)
#display rows where there's an invalid character
df[df['invalid_chars'].str.len()>0].sort_values(by='sequences')

In [10]:
#remove invalid rows if necessary
df = df[df['invalid_chars'].str.len()==0].reset_index(drop=True).drop(columns=['invalid_chars'])

In [12]:
#reset df
df.reset_index(drop=True, inplace=True)

In [15]:
#add an ID for clustering
df['id'] = [f'seq{i}' for i in range(len(df))]

In [17]:
#write the fasta file
with open('sequences.fasta', 'w') as f:
  for i in range(len(df)):
    id = df.loc[i,'id']
    seq = df.loc[i,'sequences']
    f.write(f'>{id}\n{seq}\n')

In [None]:
#clustering with mmseqs to solve homology issue
!mmseqs easy-cluster sequences.fasta clusterRes mmseqs_results --min-seq-id 0.2 -c 0.3 --cov-mode 0

In [None]:
#get clusters
fasta_sequences = SeqIO.parse(open('clusterRes_rep_seq.fasta'),'fasta')
id_list = []
seq_list = []

for fasta in fasta_sequences:
  id, sequence = fasta.id, str(fasta.seq)

  id_list.append(id)
  seq_list.append(sequence)

cluster_reps = pd.DataFrame(
    data = {
        'representative id': id_list,
        'sequence': seq_list
    }
)

print('Total clusters: {}'.format(len(cluster_reps)))
cluster_reps.head()

In [None]:
#make sure all sequences are clusterred
clusters = pd.read_csv(f'clusterRes_cluster.tsv',sep='\t',header=None)
print('Total cluster members: {}'.format(len(clusters)))


print('All sequences were clustered: {}'.format(len(clusters)==len(df)))

clusters = clusters.rename(columns={
    0: 'representative id',
    1: 'member id'
})
clusters.head()

In [None]:
#make clusters df
clusters = pd.merge(clusters,
                    df.rename(columns={'id': 'member id',
                                            'Length':'length',
                                            'sequences':'sequence'}),
                    on='member id',how='left')
clusters

In [None]:
#grouping by representative id creates one row per cluster
#counting the member ID column for each cluster gives you the size of that cluster
sizedict = pd.DataFrame(clusters.groupby('representative id').count()['member id']).to_dict()['member id']

#include the size of each cluster
clusters['cluster_size'] = clusters['representative id'].apply(lambda x: sizedict[x])
#sort by cluster size
clusters = clusters.sort_values(by='cluster_size',ascending=False).reset_index(drop=True)

clusters

In [35]:
#prepare data for the random cluster split
X = list(clusters['representative id'])
y = ['']*len(X) # there are no target values for clusters - this array will effectively be blank

#fix the random state
rs=78

#perform the split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=rs) # 80-20 train-test

#split the dataframe describing the clusters into train and test
train_clusters = clusters[clusters['representative id'].isin(X_train)].sort_values(by=['cluster_size'],ascending=False).reset_index(drop=True)
test_clusters = clusters[clusters['representative id'].isin(X_test)].sort_values(by=['cluster_size'],ascending=False).reset_index(drop=True)

In [None]:
#get the train cluster representatives
train_cluster_reps = list(train_clusters['representative id'])
#use the clusters database to get the sequences of each protein in each cluster
train_sequences = list(clusters.loc[clusters['representative id'].isin(train_cluster_reps)]['sequence'])
#use the clusters database to get the flo values of each protein in each cluster
train_targets = list(clusters.loc[clusters['representative id'].isin(train_cluster_reps)]['architecture'])
print('train sequences: ', len(train_sequences))

test_cluster_reps = list(test_clusters['representative id'])
test_sequences = list(clusters.loc[clusters['representative id'].isin(test_cluster_reps)]['sequence'])
test_targets = list(clusters.loc[clusters['representative id'].isin(test_cluster_reps)]['architecture'])
print('test sequences: ', len(test_sequences))

# **Generating Structure Data**

In [None]:
#unzip the PDB files
!unzip '/content/drive/MyDrive/Challenge (Rishab)/ml_hands_on_challenge/pdb_share.zip' -d /content/pdb_files

In [None]:
#pdb directory
pdb_dir = '/content/pdb_files/pdb_share'

#store structure data
data_list = []

parser = PDBParser(QUIET=True)

pdb_files = os.listdir(pdb_dir)

for pdb_file in tqdm(pdb_files, desc="Loading PDB files"):
    pdb_path = os.path.join(pdb_dir, pdb_file)
    if os.path.isfile(pdb_path):
        structure = parser.get_structure('protein', pdb_path)

        atom_coords = []
        atom_types = []
        edge_index = []

        for model in structure:
            for chain in model:
                for residue in chain:
                    for atom in residue:
                        atom_coords.append(atom.coord.tolist())
                        atom_types.append(atom.element)

        num_atoms = len(atom_coords)
        edge_index = []

        if num_atoms == 0:
            continue

        for i in range(num_atoms):
            for j in range(i+1, num_atoms):
                distance = np.linalg.norm(np.array(atom_coords[i]) - np.array(atom_coords[j]))
                if distance < 5.0:  #using 5 Angstroms from literature
                    edge_index.append([i, j])
                    edge_index.append([j, i])

        x = torch.tensor(atom_coords, dtype=torch.float)
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()

        identifier = os.path.splitext(pdb_file)[0]  #allows us to map back to cath_id

        data = Data(x=x, edge_index=edge_index, pdb_id=identifier)

        data_list.append(data)

In [None]:
#load in saved structure data
data_list = torch.load('/content/drive/MyDrive/Challenge (Rishab)/ml_hands_on_challenge/structure_data.pt')

In [73]:
#add structure data to df
structural_data_dict = {data.pdb_id: data for data in data_list}

df['structural_data'] = df['cath_id'].map(structural_data_dict)

# **Training GNN For Structure Embeddings**

In [74]:
label_encoder = LabelEncoder()
df['encoded_architecture'] = label_encoder.fit_transform(df['architecture'])

In [75]:
#dataset setup

class StructureDataset(torch.utils.data.Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        structure_data = self.df.iloc[idx]['structural_data']
        architecture = self.df.iloc[idx]['encoded_architecture']
        architecture = torch.tensor(architecture, dtype=torch.long)

        return {
            'Structure': structure_data,
            'Architecture': architecture
        }

In [77]:
#model (use classifier to train, only take embeddings from model)
class GATClassifier(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_heads=8, dropout_rate=0.3):
        super(GATClassifier, self).__init__()
        self.conv1 = GATConv(in_channels, hidden_channels // num_heads, heads=num_heads)
        self.conv2 = GATConv(hidden_channels, hidden_channels // num_heads, heads=num_heads)
        self.conv3 = GATConv(hidden_channels, hidden_channels // num_heads, heads=num_heads)

        self.dropout = nn.Dropout(p=dropout_rate)

        #fully connected layers
        self.fc_layers = nn.Sequential(
            nn.Linear(hidden_channels, 320),
            nn.ReLU(),
            self.dropout,
            nn.Linear(320, 128),
            nn.ReLU(),
            self.dropout,
            nn.Linear(128, out_channels)
        )

    def forward(self, data):
        #get graph data
        structure_data = data['Structure']
        x, edge_index, batch = structure_data.x, structure_data.edge_index, structure_data.batch

        #add skip connections
        x1 = F.relu(self.conv1(x, edge_index))
        x2 = F.relu(self.conv2(x1, edge_index)) + x1
        x3 = self.conv3(x2, edge_index) + x2

        #max pooling
        x = global_max_pool(x3, batch)

        #pass through fc layers
        out = self.fc_layers(x)

        return out, x

In [None]:
#initialize model, optimizer, loss function
structure_model = GATClassifier(in_channels=3, hidden_channels=640, out_channels=10, num_heads=8).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [None]:
dataset = StructureDataset(df)

#split data (only split for embeddings)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

#dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
num_epochs = 10
best_accuracy = 0.0  #keep track of accuracy
best_model_path = "/content/drive/MyDrive/Challenge (Rishab)/ml_hands_on_challenge/structure_model.pt"

#train
for epoch in range(num_epochs):
    structure_model.train()
    total_loss = 0
    print(f"Epoch {epoch + 1}/{num_epochs}:")

    for data in tqdm(train_loader, desc="Training Batches"):
        optimizer.zero_grad()

        structure_data = data['Structure'].to(device)
        architecture = data['Architecture'].view(-1).to(device)

        out, _ = structure_model({'Structure': structure_data})

        loss = criterion(out, architecture)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    average_loss = total_loss / len(train_loader)

    #eval
    structure_model.eval()
    correct = 0

    with torch.no_grad():
        for data in tqdm(test_loader, desc="Testing Batches")
            structure_data = data['Structure'].to(device)
            architecture = data['Architecture'].view(-1).to(device)

            out, _ = structure_model({'Structure': structure_data})

            pred = out.argmax(dim=1)
            correct += (pred == architecture).sum().item()

    accuracy = correct / len(test_loader.dataset)

    #results
    print(f"Epoch {epoch + 1} - Loss: {average_loss:.4f}, Test Accuracy: {accuracy:.4f}")

    #save only best models
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(structure_model.state_dict(), best_model_path)
        print(f"Best model saved with accuracy: {best_accuracy:.4f}")

In [None]:
#get structure embeddings
structure_model.eval()

structure_model = structure_model.to(device)

structure_embeddings = {}

with torch.no_grad():
    for idx, row in tqdm(df.iterrows()):
        #get data
        structure_data = row['structural_data']
        sequence = row['sequences']

        structure_data = structure_data.to(device)

        input_data = {'Structure': structure_data}

        #get embeddings
        _, embedding = structure_model(input_data)

        structure_embeddings[sequence] = embedding

In [46]:
#save embeddings as pkl file
with open('/content/drive/MyDrive/Challenge (Rishab)/ml_hands_on_challenge/structure_embeddings.pkl', 'wb') as file:
  pickle.dump(structure_embeddings, file)

# **Generate ESM Sequence Embeddings**

In [None]:
#load esm model
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()

In [None]:
#run embeddings
sequence_embeddings = {}

for sequence in tqdm(df['sequences']):

    #standard esm calculations
    batch_labels, batch_strs, batch_tokens = batch_converter([("", sequence)])
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

    with torch.no_grad():
        results = esm_model(batch_tokens, repr_layers=[33], return_contacts=True)
    token_representations = results["representations"][33].cpu()
    del batch_tokens

    avg_embedding = token_representations[0, 1 : batch_lens[0] - 1].mean(0)
    sequence_embeddings[sequence] = avg_embedding

100%|██████████| 6263/6263 [1:05:24<00:00,  1.60it/s]


In [None]:
#save embeddings as pkl file
with open('/content/drive/MyDrive/Challenge (Rishab)/ml_hands_on_challenge/cath_proteins_embeddings.pkl', 'wb') as file:
  pickle.dump(sequence_embeddings, file)

# **Handling Missing Amino Acids**

In [None]:
#find indices of missing amino acids based on pdb
def extract_residue_indices(pdb_file):
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('', pdb_file)
    indices = set()
    for model in structure:
        for chain in model:
            for residue in chain.get_residues():
                indices.add(residue.id[1])
    print(f"Extracted indices from {pdb_file}: {sorted(indices)}")
    return indices

#check for gaps based on cath indices and pdb residue indices
def check_for_gaps(cath_indices, pdb_indices):
    missing_indices = []
    for start, end in cath_indices:
        missing = [index for index in range(start, end + 1) if index not in pdb_indices]
        if missing:
            print(f"Gap detected for range {start}-{end} with PDB indices: {sorted(pdb_indices)}")
            print(f"Missing indices: {missing}")
            missing_indices.extend(missing)
    return missing_indices

#directory of pdbs
pdb_dir = '/content/pdb_files/pdb_share'

results = []

for index, row in df.iterrows():
    pdb_id = row['pdb_id']
    cath_id = row['cath_id']

    #find cath indices
    cath_indices = row['cath_indices']

    if pd.isna(cath_indices):
        continue

    if isinstance(cath_indices, str):
        cath_indices = eval(cath_indices)
    elif not isinstance(cath_indices, list):
        raise ValueError(f"Unsupported format for CATH indices at row {index}: {cath_indices}")

    pdb_file = os.path.join(pdb_dir, cath_id)
    if os.path.exists(pdb_file):
        print(f"Processing PDB file: {pdb_file} for row {index}")
        pdb_indices = extract_residue_indices(pdb_file)
        missing_indices = check_for_gaps(cath_indices, pdb_indices)
        #add indices that are missing
        results.append((pdb_id, cath_id, missing_indices))

results_df = pd.DataFrame(results, columns=['pdb_id', 'cath_id', 'missing_indices'])

In [None]:
#go through pdb id/chain/subunit and find original sequence from either uniprot or rscb

def get_uniprot_id(pdb_id, polymer, chain_id=None):
    polymer_url = f"https://data.rcsb.org/rest/v1/core/polymer_entity/{pdb_id}/{polymer}"
    try:
        polymer_response = requests.get(polymer_url)
        polymer_response.raise_for_status()
        polymer_data = polymer_response.json()
        if chain_id:
            if chain_id in polymer_data['rcsb_polymer_entity_container_identifiers']['auth_asym_ids']:
                uniprot_id = polymer_data['rcsb_polymer_entity_container_identifiers']['reference_sequence_identifiers'][0]['database_accession']
                return uniprot_id
        else:
            uniprot_id = polymer_data['rcsb_polymer_entity_container_identifiers']['reference_sequence_identifiers'][0]['database_accession']
            return uniprot_id
    except requests.exceptions.RequestException as e:
        print(f"Error fetching UniProt ID for PDB ID {pdb_id}, polymer {polymer}: {e}")
        return None
    except (KeyError, IndexError):
        print(f"Error extracting UniProt ID from response for PDB ID {pdb_id}, polymer {polymer}")
        return None

def get_sequence_from_uniprot(uniprot_id):
    base_url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.fasta"
    try:
        response = requests.get(base_url)
        response.raise_for_status()
        fasta_data = response.text
        sequence = ''.join(fasta_data.split('\n')[1:])
        return sequence
    except requests.exceptions.RequestException as e:
        print(f"Error fetching sequence for UniProt ID {uniprot_id}: {e}")
        return None

def get_sequence_from_pdb(pdb_id, polymer):
    polymer_url = f"https://data.rcsb.org/rest/v1/core/polymer_entity/{pdb_id}/{polymer}"
    try:
        response = requests.get(polymer_url)
        response.raise_for_status()
        data = response.json()
        sequence = data['entity_poly']['pdbx_seq_one_letter_code_can']
        return sequence
    except requests.exceptions.RequestException as e:
        print(f"Error fetching sequence from PDB for PDB ID {pdb_id}, polymer {polymer}: {e}")
        return None
    except KeyError:
        print(f"Error extracting sequence from PDB response for PDB ID {pdb_id}, polymer {polymer}")
        return None

def get_sequence(pdb_id, identifier):
    base_url = "https://data.rcsb.org/rest/v1/core/entry"
    entry_url = f"{base_url}/{pdb_id}"
    try:
        response = requests.get(entry_url)
        response.raise_for_status()
        data = response.json()
    except requests.exceptions.RequestException as e:
        print(f"Error fetching entry details for PDB ID {pdb_id}: {e}")
        return None
    except ValueError:
        print(f"Error decoding JSON response for PDB ID {pdb_id}")
        print(f"Response content: {response.text}")
        return None

    chains = []
    subunits = []
    try:
        for polymer in data['rcsb_entry_container_identifiers']['polymer_entity_ids']:
            polymer_url = f"https://data.rcsb.org/rest/v1/core/polymer_entity/{pdb_id}/{polymer}"
            polymer_response = requests.get(polymer_url)
            polymer_response.raise_for_status()
            polymer_data = polymer_response.json()
            chains.extend(polymer_data['rcsb_polymer_entity_container_identifiers']['auth_asym_ids'])
            subunits.append(str(polymer))
    except requests.exceptions.RequestException as e:
        print(f"Error fetching polymer entity details for PDB ID {pdb_id}: {e}")
        return None
    except ValueError:
        print(f"Error decoding JSON response for polymer entity of PDB ID {pdb_id}")
        return None

    if identifier in chains:
        for polymer in data['rcsb_entry_container_identifiers']['polymer_entity_ids']:
            uniprot_id = get_uniprot_id(pdb_id, polymer, chain_id=identifier)
            if uniprot_id:
                sequence = get_sequence_from_uniprot(uniprot_id)
                if sequence:
                    return sequence
            sequence = get_sequence_from_pdb(pdb_id, polymer)
            if sequence:
                return sequence

    elif identifier in subunits:
        for polymer in data['rcsb_entry_container_identifiers']['polymer_entity_ids']:
            if identifier == str(polymer):
                uniprot_id = get_uniprot_id(pdb_id, polymer)
                if uniprot_id:
                    sequence = get_sequence_from_uniprot(uniprot_id)
                    if sequence:
                        return sequence
                sequence = get_sequence_from_pdb(pdb_id, polymer)
                if sequence:
                    return sequence
    else:
        print(f"Identifier {identifier} not found in chains or subunits.")
        return None


In [None]:
#get full sequences
for index, row in tqdm(df_filtered.iterrows(), total=df_filtered.shape[0]):
    pdb_id = row['pdb_id']
    identifier = row['cath_id'][4]
    sequence = get_sequence(pdb_id, identifier)
    df_filtered.at[index, 'full_sequence'] = sequence

  5%|▌         | 141/2695 [05:49<1:41:07,  2.38s/it]

Error extracting UniProt ID from response for PDB ID 3hpa, polymer 1


  6%|▋         | 174/2695 [07:08<1:36:19,  2.29s/it]

Error extracting UniProt ID from response for PDB ID 3fgx, polymer 1


  7%|▋         | 185/2695 [07:34<1:36:12,  2.30s/it]

Error extracting UniProt ID from response for PDB ID 3kwl, polymer 1


 19%|█▊        | 502/2695 [20:48<1:24:17,  2.31s/it]

Error extracting UniProt ID from response for PDB ID 2o55, polymer 1


 19%|█▉        | 518/2695 [21:25<1:31:42,  2.53s/it]

Error fetching sequence for UniProt ID 70834870: 400 Client Error: Bad Request for url: https://rest.uniprot.org/uniprotkb/70834870.fasta


 19%|█▉        | 521/2695 [21:33<1:29:11,  2.46s/it]

Error extracting UniProt ID from response for PDB ID 2o57, polymer 1


 22%|██▏       | 590/2695 [24:19<1:24:28,  2.41s/it]

Error extracting UniProt ID from response for PDB ID 3dip, polymer 1


 22%|██▏       | 601/2695 [24:45<1:23:23,  2.39s/it]

Error extracting UniProt ID from response for PDB ID 5vis, polymer 1


 25%|██▍       | 670/2695 [27:35<1:32:22,  2.74s/it]

Error extracting UniProt ID from response for PDB ID 2r3s, polymer 1


 26%|██▌       | 699/2695 [28:45<1:18:08,  2.35s/it]

Error extracting UniProt ID from response for PDB ID 3m6j, polymer 1


 30%|██▉       | 802/2695 [32:53<1:13:31,  2.33s/it]

Error extracting UniProt ID from response for PDB ID 3eo6, polymer 1


 33%|███▎      | 882/2695 [36:14<1:14:06,  2.45s/it]

Error extracting UniProt ID from response for PDB ID 3vjf, polymer 1


 34%|███▎      | 905/2695 [37:09<1:14:36,  2.50s/it]

Error extracting UniProt ID from response for PDB ID 3kwl, polymer 1


 36%|███▋      | 978/2695 [40:05<1:06:30,  2.32s/it]

Error extracting UniProt ID from response for PDB ID 4eog, polymer 1


 37%|███▋      | 990/2695 [40:35<1:09:40,  2.45s/it]

Error extracting UniProt ID from response for PDB ID 2wb7, polymer 1


 38%|███▊      | 1023/2695 [41:55<1:05:07,  2.34s/it]

Error extracting UniProt ID from response for PDB ID 2o5n, polymer 1


 45%|████▌     | 1214/2695 [50:05<1:00:11,  2.44s/it]

Error extracting UniProt ID from response for PDB ID 3lye, polymer 1


 50%|████▉     | 1345/2695 [55:26<51:47,  2.30s/it]

Error extracting UniProt ID from response for PDB ID 3e0z, polymer 1


 59%|█████▉    | 1585/2695 [1:05:09<43:22,  2.34s/it]

Error extracting UniProt ID from response for PDB ID 3kwl, polymer 1


 67%|██████▋   | 1808/2695 [1:14:40<34:36,  2.34s/it]

Error extracting UniProt ID from response for PDB ID 3jrt, polymer 1


 72%|███████▏  | 1934/2695 [1:19:44<29:24,  2.32s/it]

Error extracting UniProt ID from response for PDB ID 3ako, polymer 1


 86%|████████▌ | 2316/2695 [1:35:33<16:06,  2.55s/it]

Error extracting UniProt ID from response for PDB ID 1qys, polymer 1


 87%|████████▋ | 2348/2695 [1:36:50<13:33,  2.35s/it]

Error extracting UniProt ID from response for PDB ID 2j7q, polymer 1


 93%|█████████▎| 2503/2695 [1:43:06<07:55,  2.47s/it]

Error extracting UniProt ID from response for PDB ID 3s9x, polymer 1


 95%|█████████▌| 2563/2695 [1:45:38<05:18,  2.41s/it]

Error extracting UniProt ID from response for PDB ID 3qvq, polymer 1


100%|█████████▉| 2682/2695 [1:50:33<00:33,  2.60s/it]

Identifier A not found in chains or subunits.


100%|██████████| 2695/2695 [1:51:04<00:00,  2.47s/it]


In [None]:
#get real substring from full sequence based on cath indices and use this as sequence instead

def extract_substring(row):
    start, end = row['cath_indices']
    return row['full_sequence'][start-1:end]

# **Training Sequence-Based Model**

In [None]:
#load in sequence embeddings
train_embeddings_path = '/content/drive/MyDrive/Challenge (Rishab)/ml_hands_on_challenge/cath_proteins_embeddings.pkl'

with open(train_embeddings_path, 'rb') as file:
    sequence_embeddings = pickle.load(file)

In [None]:
label_encoder = LabelEncoder()

#fit encoder on the architecture labels
df['encoded_architecture'] = label_encoder.fit_transform(df['architecture'])

In [None]:
#dataset set up

class SeqDataset(torch.utils.data.Dataset):
   def __init__(self, df, sequence_embeddings):
    super().__init__()
    self.sequence_embeddings = sequence_embeddings
    self.df = df

   def __len__(self):
    return len(self.df)

   def __getitem__(self, index):
    prot, architecture = self.df.loc[index][['sequences', 'encoded_architecture']]
    prot_embedding = self.sequence_embeddings[prot]
    architecture = torch.tensor(architecture, dtype=torch.long)


    return_dict = {
        "Protein": prot,
        "Sequence Input": prot_embedding,
        "Architecture": architecture
    }

    return return_dict

In [None]:
#self-attention sequence-based model

class SeqClassifier(nn.Module):
    def __init__(self, num_classes, embedding_dim=1280):
        super(SeqClassifier, self).__init__()
        #self-attention layer
        self.attention = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=4, batch_first=True)

        #fully connected layers
        self.fc_layers = nn.Sequential(
            nn.Linear(embedding_dim * 2, 1280),
            nn.ReLU(),
            nn.Linear(1280, 640),
            nn.ReLU(),
            nn.Linear(640, 320),
            nn.ReLU(),
            nn.Linear(320, num_classes)
        )

    def forward(self, protein_embedding):
        #reshape protein embedding
        protein_embedding = protein_embedding.unsqueeze(1)

        #apply self-attention
        attention_output, _ = self.attention(protein_embedding, protein_embedding, protein_embedding)

        #concat original + attention output
        combined = torch.cat((protein_embedding.squeeze(1), attention_output.squeeze(1)), dim=1)

        #pass through layers
        output = self.fc_layers(combined)

        return output

In [None]:
#setup
from torch.utils.data import DataLoader, random_split
import torch.optim as optim

dataset = SeqDataset(df, sequence_embeddings)

#match train and val sequences from earlier clustering
train_df = df[df['sequences'].isin(train_sequences)]
val_df = df[df['sequences'].isin(test_sequences)]

train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)


train_dataset = SeqDataset(train_df, sequence_embeddings)
val_dataset = SeqDataset(val_df, sequence_embeddings)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


#model, loss function, and optimizer
seq_model = SeqClassifier(num_classes=10)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
seq_model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(seq_model.parameters(), lr=0.001)

In [None]:
#train and eval loop

best_val_loss = float('inf')
num_epochs = 10
train_losses = []
val_losses = []

for epoch in range(num_epochs):
    #train
    seq_model.train()
    train_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}"):
        optimizer.zero_grad()

        protein_inputs = batch["Sequence Input"].to(device)
        labels = batch["Architecture"].to(device)

        outputs = seq_model(protein_inputs)
        loss = criterion(outputs, labels)
        train_loss += loss.item()

        loss.backward()
        optimizer.step()

    train_loss /= len(train_loader)
    train_losses.append(train_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss:.4f}')

    #val
    seq_model.eval()
    val_predictions = []
    val_labels = []
    val_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch+1}/{num_epochs}"):
            protein_inputs = batch["Sequence Input"].to(device)
            labels = batch["Architecture"].to(device)

            outputs = seq_model(protein_inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            val_predictions.extend(torch.argmax(outputs, dim=1).cpu().numpy())
            val_labels.extend(labels.cpu().numpy())

    val_loss /= len(val_loader)
    val_losses.append(val_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Validation Loss: {val_loss:.4f}')

    #precision, recall, f1, and accuracy
    val_predictions = np.array(val_predictions)
    val_labels = np.array(val_labels)

    precision = precision_score(val_labels, val_predictions, average='macro')
    recall = recall_score(val_labels, val_predictions, average='macro')
    f1 = f1_score(val_labels, val_predictions, average='macro')
    accuracy = accuracy_score(val_labels, val_predictions)

    print(f'Epoch [{epoch+1}/{num_epochs}], Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}, Accuracy: {accuracy:.4f}')

    #save the best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(seq_model.state_dict(), '/content/drive/MyDrive/Challenge (Rishab)/ml_hands_on_challenge/seq.pt')
        print(f'Model saved at epoch {epoch+1} with validation loss {val_loss:.4f}')

#plot training and validation loss
plt.figure(figsize=(10, 5))
plt.plot(range(1, num_epochs + 1), train_losses, label='Training Loss')
plt.plot(range(1, num_epochs + 1), val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss Over Epochs')
plt.legend()
plt.show()

# **Training Seq+Struct Model**

In [None]:
structure_embeddings_path = '/content/drive/MyDrive/Challenge (Rishab)/ml_hands_on_challenge/structure_embeddings.pkl'

with open(structure_embeddings_path, 'rb') as file:
    structure_embeddings = pickle.load(file)

In [61]:
#dataset class setup

class SeqStructDataset(torch.utils.data.Dataset):
   def __init__(self, df, sequence_embeddings, structure_embeddings):
    super().__init__()
    self.sequence_embeddings = sequence_embeddings
    self.structure_embeddings = structure_embeddings
    self.df = df

   def __len__(self):
    return len(self.df)

   def __getitem__(self, index):
    prot, architecture = self.df.loc[index][['sequences', 'encoded_architecture']]
    prot_embedding = self.sequence_embeddings[prot]
    structure_embedding = self.structure_embeddings[prot].squeeze(0)
    architecture = torch.tensor(architecture, dtype=torch.long)


    return_dict = {
        "Protein": prot,
        "Sequence Input": prot_embedding,
        "Structure Input": structure_embedding,
        "Architecture": architecture
    }

    return return_dict

In [62]:
#sequence + structure attention model

class SeqStructureClassifier(nn.Module):
    def __init__(self, num_classes, sequence_embedding_dim=1280, structure_embedding_dim=640):
        super(SeqStructureClassifier, self).__init__()

        #self-attention layers for sequence and structure
        self.sequence_attention = nn.MultiheadAttention(embed_dim=sequence_embedding_dim, num_heads=4, batch_first=True)
        self.structure_attention = nn.MultiheadAttention(embed_dim=structure_embedding_dim, num_heads=4, batch_first=True)

        #fully connected layers
        self.fc_layers = nn.Sequential(
            nn.Linear((sequence_embedding_dim + structure_embedding_dim) * 2, 1280),
            nn.ReLU(),
            nn.Linear(1280, 640),
            nn.ReLU(),
            nn.Linear(640, 320),
            nn.ReLU(),
            nn.Linear(320, num_classes)
        )

    def forward(self, sequence_embedding, structure_embedding):

        #reshape embeddings
        sequence_embedding = sequence_embedding.unsqueeze(1)
        structure_embedding = structure_embedding.unsqueeze(1)

        #apply self-attention to sequence and structure embeddings
        sequence_attention_output, _ = self.sequence_attention(sequence_embedding, sequence_embedding, sequence_embedding)
        structure_attention_output, _ = self.structure_attention(structure_embedding, structure_embedding, structure_embedding)

        #concat original + attention embeddings
        sequence_combined = torch.cat((sequence_embedding.squeeze(1), sequence_attention_output.squeeze(1)), dim=1)
        structure_combined = torch.cat((structure_embedding.squeeze(1), structure_attention_output.squeeze(1)), dim=1)

        #concatenate sequence and structure embeddings
        combined = torch.cat((sequence_combined, structure_combined), dim=1)

        #pass through fc layers
        output = self.fc_layers(combined)

        return output

In [69]:
#setup

#match train and val sequences from earlier clustering
train_df = df[df['sequences'].isin(train_sequences)]
val_df = df[df['sequences'].isin(test_sequences)]

train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)

train_dataset = SeqStructDataset(train_df, sequence_embeddings, structure_embeddings)
val_dataset = SeqStructDataset(val_df, sequence_embeddings, structure_embeddings)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


#model, loss function, and optimizer
combined_model = SeqStructureClassifier(num_classes=10)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
combined_model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(combined_model.parameters(), lr=0.0001)

In [None]:
#train + eval

best_val_loss = float('inf')
num_epochs = 10
train_losses = []
val_losses = []

for epoch in range(num_epochs):
    #train
    combined_model.train()
    train_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}"):
        optimizer.zero_grad()

        sequence_inputs = batch["Sequence Input"].to(device)
        structure_inputs = batch["Structure Input"].to(device)
        labels = batch["Architecture"].to(device)

        outputs = combined_model(sequence_inputs, structure_inputs)
        loss = criterion(outputs, labels)
        train_loss += loss.item()

        loss.backward()
        optimizer.step()

    train_loss /= len(train_loader)
    train_losses.append(train_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss:.4f}')

    #val
    combined_model.eval()
    val_predictions = []
    val_labels = []
    val_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch+1}/{num_epochs}"):
            sequence_inputs = batch["Sequence Input"].to(device)
            structure_inputs = batch["Structure Input"].to(device)
            labels = batch["Architecture"].to(device)

            outputs = combined_model(sequence_inputs, structure_inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            val_predictions.extend(torch.argmax(outputs, dim=1).cpu().numpy())
            val_labels.extend(labels.cpu().numpy())

    val_loss /= len(val_loader)
    val_losses.append(val_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Validation Loss: {val_loss:.4f}')

    #precision, recall, f1, and accuracy
    val_predictions = np.array(val_predictions)
    val_labels = np.array(val_labels)

    precision = precision_score(val_labels, val_predictions, average='macro')
    recall = recall_score(val_labels, val_predictions, average='macro')
    f1 = f1_score(val_labels, val_predictions, average='macro')
    accuracy = accuracy_score(val_labels, val_predictions)

    print(f'Epoch [{epoch+1}/{num_epochs}], Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}, Accuracy: {accuracy:.4f}')

    #save the best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(combined_model.state_dict(), '/content/drive/MyDrive/Challenge (Rishab)/ml_hands_on_challenge/seq_struct.pt')
        print(f'Model saved at epoch {epoch+1} with validation loss {val_loss:.4f}')

#plot training and validation loss
plt.figure(figsize=(10, 5))
plt.plot(range(1, num_epochs + 1), train_losses, label='Training Loss')
plt.plot(range(1, num_epochs + 1), val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss Over Epochs')
plt.legend()
plt.show()
