In [1]:
import torch
import torch_geometric
from graphein.protein.config import ProteinGraphConfig
from graphein.protein.edges.distance import (add_peptide_bonds,
                                             add_hydrogen_bond_interactions,
                                             add_disulfide_interactions,
                                             add_ionic_interactions,
                                             add_aromatic_interactions,
                                             add_aromatic_sulphur_interactions,
                                             add_cation_pi_interactions
                                            )
from graphein.protein.graphs import construct_graph
import networkx as nx
import numpy as np
from torch_geometric.data import Data
import os

  from .autonotebook import tqdm as notebook_tqdm
To use the Graphein submodule graphein.protein.features.sequence.embeddings, you need to install: biovec 
biovec cannot be installed via conda
To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d 
To do so, use the following command: conda install -c pytorch3d pytorch3d


In [2]:
import pandas as pd
import numpy as np
from math import ceil,sin,cos,sqrt,pi
class parser:
    def __init__(self,filepath,label_path):
        self.residue_feature=[]
        self.a3toa1={'GLN': 'Q', 'LYS': 'K', 'PRO': 'P', 'TYR': 'Y', 'GLU': 'E', 'THR': 'T', 'GLY': 'G', 'ILE': 'I', 'SER': 'S', 'HIS': 'H', 'ARG': 'R', 'ASP': 'D', 'MET': 'M', 'LEU': 'L', 'ASN': 'N', 'VAL': 'V', 'PHE': 'F', 'ALA': 'A', 'TRP': 'W', 'CYS': 'C', 'OCS': 'C', 'MSE': 'M', 'OCY': 'C', 'ACE': 'X', 'MLZ': 'K', 'FME': 'M', 'CXM': 'M', 'CME': 'C', 'KCX': 'K', 'PTR': 'Y', 'DDZ': 'A', 'LLP': 'K', 'MAA': 'A', 'MLY': 'K', 'CSO': 'C', 'CSS': 'C', 'PCA': 'Q', 'SEC': 'U', 'ALS': 'A', 'CSX': 'C', 'CSD': 'C', 'SEP': 'S', 'SCY': 'C', 'SNN': 'N', 'NEP': 'H', '0AF': 'W', 'NH2': 'X', 'PHD': 'D', 'KPI': 'K', 'HYP': 'P', 'SME': 'M', 'ZBZ': 'C', 'AYA': 'A', 'LP6': 'K', 'CAS': 'C', 'FOR': 'X', 'DAL': 'A', 'TPO': 'T', 'HIC': 'H', 'DAH': 'F', 'YCM': 'C', 'SNC': 'C', 'CS4': 'C', '2LT': 'Y', 'BCS': 'C', 'CYG': 'C', 'P1L': 'C', '2MR': 'R', 'BFD': 'D', 'MHS': 'H', 'GL3': 'G', 'MHO': 'M', 'XPC': 'X', 'B3E': 'E', 'XCP': 'X', 'CYD': 'C', 'NLE': 'L', 'ABA': 'A', 'TYI': 'Y', 'OMT': 'M', 'CSR': 'C', 'SMC': 'C', 'CGV': 'C', 'FGL': 'G', 'MSO': 'M', 'LYZ': 'K', 'PBF': 'F', 'AME': 'M', 'KYN': 'W', 'M3L': 'K', 'CSB': 'C', 'TRW': 'W', 'CSA': 'C', 'SC2': 'X', 'MEN': 'N', 'LED': 'L', 'SEB': 'S', 'TYS': 'Y', 'MLE': 'L', '9EV': 'X', 'SVY': 'S', 'MLI': 'X', 'PYL': 'O', 'CCS': 'C', 'ALY': 'K', 'APK': 'K', 'ACY': 'X', 'TY2': 'Y', 'DHA': 'S', 'TRN': 'W', 'ORN': 'A', 'ZAL': 'A', 'ALN': 'A', 'PXU': 'P', 'HS8': 'H', 'TNQ': 'W', 'PHI': 'F', 'QCS': 'C', 'OAS': 'S', 'TRQ': 'W', 'NIY': 'Y', 'DMG': 'X', 'MME': 'M', 'MIR': 'S', '4M9': 'X', 'PVO': 'X', 'TRO': 'W', 'IYR': 'Y', 'LCK': 'K', 'MYR': 'X', '4HH': 'S', 'CGN': 'X', '55I': 'F', 'UF0': 'S', 'AAR': 'R', 'GPL': 'K', 'PR7': 'P', '143': 'C', 'SAC': 'S', 'SCH': 'C', 'KYQ': 'K', 'OSE': 'S', 'KST': 'K', 'AGM': 'R', 'I2M': 'I', 'MGN': 'Q', 'CGU': 'E', 'LVN': 'V', 'FGP': 'S', 'CRO': 'X', 'CRQ': 'X', 'GYS': 'X', 'MDO': 'X', 'XYG': 'X', 'CR2': 'X', 'NRQ': 'X', 'KWS': 'X'}
        self._parse_design_pdb(filepath)
        self.feature_column=['fa_atr', 'fa_rep', 'fa_sol', 'fa_intra_rep', 'fa_intra_sol_xover4', 'lk_ball_wtd', 'fa_elec', 'pro_close', 'hbond_sr_bb', 'hbond_lr_bb', 'hbond_bb_sc', 'hbond_sc', 'dslf_fa13', 'omega', 'fa_dun', 'p_aa_pp', 'yhh_planarity', 'ref', 'rama_prepro']
        new_edge_funcs = {"edge_construction_functions": [add_peptide_bonds, add_hydrogen_bond_interactions,add_ionic_interactions,add_disulfide_interactions,add_cation_pi_interactions]}
        self.config = ProteinGraphConfig(**new_edge_funcs)
        self.filepath=filepath
        self.label={}
        self._get_label(label_path)
        self.name=os.path.basename(self.filepath)
    def _get_label(self,path):
        with open(path) as f:
            for x in f.readlines():
                self.label[x.split(',')[0]]=float(x.split(',')[1])
    def _parse_design_pdb(self,filepath):
        columns=['label', 'fa_atr', 'fa_rep', 'fa_sol', 'fa_intra_rep', 'fa_intra_sol_xover4', 'lk_ball_wtd', 'fa_elec', 'pro_close', 'hbond_sr_bb', 'hbond_lr_bb', 'hbond_bb_sc', 'hbond_sc', 'dslf_fa13', 'omega', 'fa_dun', 'p_aa_pp', 'yhh_planarity', 'ref', 'rama_prepro']
        with open(filepath) as g:
            lines = g.readlines()
            parse=False            
            for line in lines:
                if '#BEGIN_POSE_ENERGIES_TABLE' in line:
                    parse=True
                elif '#END_POSE_ENERGIES_TABLE' in line:
                    parse=False
                elif parse:
                    if 'label' in line or 'weights' in line:
                        continue
                    elif 'pose' in line:
                        self.pose=line.strip().split(' ')[1:]
                    else:
                        strip_line=line.strip().split()[:-1]
                        self.residue_feature.append([float(x) for x in strip_line[1:]])
    def get_pose(self):
        columns=['fa_atr', 'fa_rep', 'fa_sol', 'fa_intra_rep', 'fa_intra_sol_xover4', 'lk_ball_wtd', 'fa_elec', 'pro_close', 'hbond_sr_bb', 'hbond_lr_bb', 'hbond_bb_sc', 'hbond_sc', 'dslf_fa13', 'omega', 'fa_dun', 'p_aa_pp', 'yhh_planarity', 'ref', 'rama_prepro']
        return {x:float(self.pose[i]) for (i,x) in enumerate(columns)}
    def get_design_seq(self):
        design_residue=self.parsed_pdb[self.parsed_pdb['Design']]
        return ''.join(self.a3toa1[y] for y in [ x['atm_resname'].unique()[0] for (_,x) in design_residue[['atm_resname','atm_reseq']].groupby(by='atm_reseq')])
    def get_torch_data(self):
        g = construct_graph(config=self.config, pdb_path=self.filepath)
        edges=torch.from_numpy(np.array(nx.to_numpy_array(g).nonzero()))
        nodes=torch.from_numpy(np.array(self.residue_feature))
        data = Data(x=nodes, edge_index=edges,name=self.name,y=self.label[self.name],num_nodes=len(nodes))
        return data

In [3]:
import torch
from torch_geometric.data import InMemoryDataset

class Dataset(InMemoryDataset):
    def __init__(self,design_list,save_dir,root='.',transform=None, pre_transform=None, pre_filter=None):
        self.design_list=design_list
        self.save_dir=save_dir
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])
    @property
    def raw_file_names(self):
        return self.design_list
    @property
    def processed_file_names(self):
        return ['data.pt']
    @property
    def processed_dir(self):
        return self.save_dir
    def process(self):
        # Read data into huge `Data` list.
        data_list = []
        for name in self.design_list:
            if 'ipynb_checkpoints' in name:
                continue
            pars=parser('design/'+name,'ground-truth.txt')
            data_list.append(pars.get_torch_data())
        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        print(self.processed_paths)
        torch.save((data, slices), self.processed_paths[0])

In [4]:
from sklearn.model_selection import train_test_split

In [5]:
train_list,test_list=train_test_split(os.listdir('design/'))

In [6]:
train_dataset=Dataset(train_list,'train/')
test_dataset=Dataset(test_list,'test/')

Processing...


['train/data.pt']


Done!
Processing...


['test/data.pt']


Done!


In [7]:
from torch_geometric.data import DataLoader

In [8]:
train_loder=DataLoader(train_dataset,batch_size=2)
test_loder=DataLoader(test_dataset)



In [18]:
from torch_geometric.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool


class GCN(torch.nn.Module):
    def __init__(self,node_features,hidden_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, 1)

    def forward(self, data):
        x, edge_index, batch=data.x.float(),data.edge_index,data.batch
        
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        
        return x

model = GCN(node_features=19,hidden_channels=64)
print(model)

GCN(
  (conv1): GCNConv(19, 64)
  (conv2): GCNConv(64, 64)
  (conv3): GCNConv(64, 64)
  (lin): Linear(64, 1, bias=True)
)


In [21]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()

def train():
    model.train()
    for data in train_loder:  # Iterate in batches over the training dataset.
        out = model(data)  # Perform a single forward pass.
        loss = criterion(out, data.y.view(-1,1))  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.

def test(loader):
    model.eval()
    correct = 0
    for data in loader:  # Iterate in batches over the training/test dataset.
        out = model(data)  
        pred = (out>=0.5)*1.0  # Use the class with highest probability.
        correct += int((pred == data.y.view(-1,1)).sum())  # Check against ground-truth labels.
    return correct / len(loader.dataset)  # Derive ratio of correct predictions.


for epoch in range(1, 171):
    train()
    train_acc = test(train_loder)
    test_acc = test(test_loder)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

Epoch: 001, Train Acc: 0.8571, Test Acc: 1.0000
Epoch: 002, Train Acc: 0.8571, Test Acc: 1.0000
Epoch: 003, Train Acc: 0.8571, Test Acc: 1.0000
Epoch: 004, Train Acc: 0.8571, Test Acc: 1.0000
Epoch: 005, Train Acc: 0.8571, Test Acc: 1.0000
Epoch: 006, Train Acc: 0.8571, Test Acc: 1.0000
Epoch: 007, Train Acc: 0.8571, Test Acc: 1.0000
Epoch: 008, Train Acc: 0.8571, Test Acc: 1.0000
Epoch: 009, Train Acc: 0.8571, Test Acc: 1.0000
Epoch: 010, Train Acc: 0.8571, Test Acc: 1.0000
Epoch: 011, Train Acc: 0.8571, Test Acc: 1.0000
Epoch: 012, Train Acc: 0.8571, Test Acc: 1.0000
Epoch: 013, Train Acc: 0.8571, Test Acc: 1.0000
Epoch: 014, Train Acc: 0.8571, Test Acc: 1.0000
Epoch: 015, Train Acc: 0.8571, Test Acc: 1.0000
Epoch: 016, Train Acc: 0.8571, Test Acc: 1.0000
Epoch: 017, Train Acc: 0.8571, Test Acc: 1.0000
Epoch: 018, Train Acc: 0.8571, Test Acc: 1.0000
Epoch: 019, Train Acc: 0.8571, Test Acc: 1.0000
Epoch: 020, Train Acc: 0.8571, Test Acc: 1.0000
Epoch: 021, Train Acc: 0.8571, Test Acc:

KeyboardInterrupt: 