In [22]:
import read_pdb_file_20200302 as readpb
import os
import json
import random
from json import loads, dumps
from ast import literal_eval
import pandas
import torch
import torch.nn as nn
import dgllife
from dgllife.utils import CanonicalAtomFeaturizer
import numpy as np
import matplotlib.pyplot as plt
from rdkit import Chem
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Pad
import torch.nn.functional as F

In [2]:
mapping_atom2int = {'C':0, 'D':1, 'N':2, 'O':3, 'P':4, 'S':5}

In [3]:
# Import Protein Data Base Files
'''
Contains atoms and locations that make up the structure of a protein
'''
proteins_dict = {}
for protein in os.listdir("./cs5242_project_data/pdbs/"):
    if protein.startswith("."):
        pass
    else:
        x,y,z,group,atom = readpb.read_pdb(os.path.join("cs5242_project_data/pdbs/",protein))
        atom = [mapping_atom2int[i] for i in atom]
    temp = {}
    for i in range(len(x)):
        temp[i] = {"coords":[x[i],y[i],z[i]], 'group':group[i], "atom":atom[i]}
        proteins_dict[protein[:4]] = temp

In [4]:
# proteins_dict['102D']
def gen_amino_list(protein):
    
    list_of_groups = [i['group'] for i in protein.values()]
    # average = np.average(list_of_coords, axis=0)
    # print(average)
    set_of_groups = set(list_of_groups)
    set_of_groups  = dict.fromkeys(list_of_groups, 0)
    for k in set_of_groups:
        set_of_groups[k] = {'atom':[i['atom'] for i in protein.values() if i['group'] == k], 
                            'coords':[i['coords'] for i in protein.values() if i['group'] == k]}
        
        set_of_groups[k]['centroid'] = list(np.average(set_of_groups[k]['coords'],0))
    return set_of_groups

# set_of_amino = gen_amino_list(proteins_dict['102D'])

In [119]:
# Produce full list of amino-acid group to embedding map
# Code takes a long time to run, instead a json object is saved for loading

# full_amino_list = []

# for i in proteins_dict:
#     amino_group = gen_amino_list(proteins_dict[i])
    
#     for k in amino_group.values():
        
#         if k['atom'] not in full_amino_list:
#             full_amino_list.append(k['atom'])
            
# amino_db = {}
# for i,item in enumerate(full_amino_list):
#     amino_db[tuple(item)] = i
    
# # save: convert each tuple key to a string before saving as json object
# s = dumps({str(k): v for k, v in amino_db.items()})
  
# # Writing
# with open("amino_db.json", "w") as outfile:
#     outfile.write(s)

In [9]:
# (i) load json object

with open('amino_db.json', 'r') as openfile:
    json_object = json.load(openfile)

# (ii) convert loaded keys from string back to tuple
amino_db = {literal_eval(k): v for k, v in json_object.items()}

In [11]:
# Load centroids
'''
Location of ligand binding site on protein
'''
centroid_df = pandas.read_csv('./cs5242_project_data/centroids.csv')
# null_df = centroid_df.isnull()
# row_with_NA = null_df.any(axis=1)
# centroid_df[row_with_NA]

# Remove NA - PID 1NDE
centroid_df = centroid_df.dropna()
centroid_df = centroid_df.reset_index(drop=True)
centroid_df.head()

Unnamed: 0,PID,x,y,z
0,102D,9.819391,24.178348,71.561739
1,110M,35.189667,6.802667,12.175667
2,112M,34.8922,7.174,12.4984
3,11BA,-14.688256,14.944487,0.193744
4,11BG,5.319879,55.114576,66.171818


In [12]:
### JOIN PROTEIN AND CENTROID
proteins_with_binding = {}
for i, protein in enumerate(centroid_df['PID']):
    proteins_with_binding[protein] = {'centroid':[centroid_df.loc[i,'x'],centroid_df.loc[i,"y"],centroid_df.loc[i,"z"]]
                                     }
proteins_with_binding['102D']

{'centroid': [9.819391304, 24.17834783, 71.56173913]}

In [14]:
# Generate Protein Embeddings
def protein2numeric(protein_id, max_len):    # Get numeric list of amino acids
    amino_list = gen_amino_list(proteins_dict[protein_id])
    
    # sort numeric by distance
    neighbors = []
    distances = []
    centroid = proteins_with_binding[protein_id]['centroid']
    for i in amino_list.values():
        numeric = amino_db.get(tuple(i['atom']))
        dist = np.linalg.norm(np.array(i['centroid'])-np.array(centroid))
        distances.append((numeric, dist))
    distances.sort(key=lambda tup:tup[1])
    
    for i in range(len(distances)):
        if i < max_len:
            neighbors.append(distances[i][0])
    if len(neighbors)<max_len:
        list = [0] * (max_len-len(neighbors))
        neighbors += list
        
    return torch.Tensor(neighbors)
    # return numeric tensor
# protein2numeric('102D')

In [15]:
# Load ligand
'''
Loads ligand ID and SMILES information 
'''
ligand_df = pandas.read_csv('cs5242_project_data/ligand.csv')
ligand2class = {}
for i,lid in enumerate(ligand_df['LID']):
    ligand2class[lid] = i

In [16]:
CHAR_SMI_SET = {"(": 1, ".": 2, "0": 3, "2": 4, "4": 5, "6": 6, "8": 7, "@": 8,
                "B": 9, "D": 10, "F": 11, "H": 12, "L": 13, "N": 14, "P": 15, "R": 16,
                "T": 17, "V": 18, "Z": 19, "\\": 20, "b": 21, "d": 22, "f": 23, "h": 24,
                "l": 25, "n": 26, "r": 27, "t": 28, "#": 29, "%": 30, ")": 31, "+": 32,
                "-": 33, "/": 34, "1": 35, "3": 36, "5": 37, "7": 38, "9": 39, "=": 40,
                "A": 41, "C": 42, "E": 43, "G": 44, "I": 45, "K": 46, "M": 47, "O": 48,
                "S": 49, "U": 50, "W": 51, "Y": 52, "[": 53, "]": 54, "a": 55, "c": 56,
                "e": 57, "g": 58, "i": 59, "m": 60, "o": 61, "s": 62, "u": 63, "y": 64,
                "p":65, ' ':66}

max_len = 50

# def ligand2numeric(smiles):
#     mol = Chem.MolFromSmiles(smiles)
#     atom_featurizer = CanonicalAtomFeaturizer(atom_data_field='feat')
#     num = atom_featurizer(mol)['feat']
#     # layer_pad = Pad([0,0,0,num.shape[1]-num.shape[0]])
#     num = num.view(-1)
#     num = F.pad(num, (0, 6734 - len(num)), 'constant', 0)
#     return num

# list_of_ligands = [ligand2numeric(smiles) for smiles in ligand_df['Smiles']]

def ligand2numeric(line, max_len):
    X = np.zeros(max_len)
    for i, ch in enumerate(line[:max_len]):
        X[i] = CHAR_SMI_SET[ch] 

    return torch.Tensor(X)

# ligand2numeric(ligand_df.loc[2,'Smiles'], max_smi_len)

In [18]:
# Load Pair
'''
Loads Information of which ligand binds to which protein
'''
pair_df = pandas.read_csv('cs5242_project_data/pair.csv')
pairs_lookup = {}
for i in range(len(pair_df)):
    pairs_lookup[pair_df.loc[i,'LID']] = pair_df.loc[i,'PID'] 

In [19]:
# Merge Ligand Smiles to PID
df = pandas.merge(centroid_df, pair_df, on = 'PID')
df = pandas.merge(df, ligand_df, on = 'LID')

# labels = np.array(df['LID'])
# zeros_matrix = np.zeros((labels.size, labels.max()+1))
# zeros_matrix[np.arange(labels.size),labels] = 1

In [20]:
torch.manual_seed(0)
np.random.seed(0)
range_index = range(len(df))
random_index = np.random.permutation(range_index)
training_idx, valid_idx, test_idx = random_index[0:int(0.7*len(df))], random_index[int(0.7*len(df)):int(0.85*len(df))], random_index[int(0.85*len(df)):]

In [233]:
# # Generate list of ligand embeddings
# ligand_embeddings = [ligand2numeric(ligand_df.loc[k,'Smiles'], max_len) for k in range(len(ligand_df))]

# # Generate list of protein embeddings
# protein_embeddings = [protein2numeric(k, max_len) for k in proteins_with_binding]

In [116]:
### Generate a function to generate negative samples 
def random_sample(no_of_samples=no_of_samples):
    random_ligands = np.random.randint(0, len(ligand_df), no_of_samples)
    random_proteins = random.sample(list(proteins_with_binding), no_of_samples)
    wrong_tensor = torch.zeros(size=(no_of_samples,max_len*2))
    for i,item in enumerate(zip(random_ligands,random_proteins)):
        wrong_tensor[i,:] = torch.cat((ligand2numeric(ligand_df.loc[item[0],'Smiles'], max_len), protein2numeric(item[1], max_len)))
        
    label = torch.zeros(size=(no_of_samples,))
    for i,item in enumerate(zip(random_ligands,random_proteins)):
        if item[0] in pairs_lookup:
            if pairs_lookup[item[0]] == item[1]:
                label[i] = 1
        else:
            pass
    return wrong_tensor, label

In [102]:
traindf = df.loc[training_idx,:]
np.array(traindf['PID'])

array(['1OF1', '1UUM', '1PKE', ..., '2XFH', '2I0Y', '2VVT'], dtype=object)

In [106]:
def correct_sample(df, idx):
    traindf = df.loc[training_idx,:]
    correct_ligands = np.array(traindf['Smiles'])
    correct_proteins = np.array(traindf['PID'])
    correct_tensor = torch.zeros(size=(len(traindf),max_len*2))
    for i,item in enumerate(zip(correct_ligands,correct_proteins)):
        correct_tensor[i,:] = torch.cat((ligand2numeric(item[0], max_len), protein2numeric(item[1], max_len)))
        
    label = torch.ones(size=(len(traindf),))

    return correct_tensor, label

In [127]:
# DEFINE MODEL PARAMETERS
no_of_samples = 1000
max_len = 50
num_epochs = 10
num_classes = 3424
batch_size = 8
learning_rate = 0.001

In [128]:
# DEFINE DATASET AND DATALOADERS

class PLI_Dataset(Dataset):
    """Returns protein, ligand, and output sequence."""

    def __init__(self, dataframe, split):
        """
        Args:
            dataframe - df file
            split - train, valid, or test
        """
        self.dataframe = dataframe
        
        if split == 'train':
            correct_tensor, label = correct_sample(self.dataframe, training_idx)
            wrong_tensor, wlabel = random_sample(no_of_samples=no_of_samples)
            self.split = torch.cat((correct_tensor,wrong_tensor))
            self.label = torch.cat((label,wlabel))
            
        elif split == 'valid':
            correct_tensor, label = correct_sample(self.dataframe, valid_idx)
            wrong_tensor, wlabel = random_sample(no_of_samples=no_of_samples)
            self.split = torch.cat((correct_tensor,wrong_tensor))
            self.label = torch.cat((label,wlabel))
            
        elif split == 'test': 
            correct_tensor, label = correct_sample(self.dataframe, test_idx)
            wrong_tensor, wlabel = random_sample(no_of_samples=no_of_samples)
            self.split = torch.cat((correct_tensor,wrong_tensor))
            self.label = torch.cat((label,wlabel))

    def __len__(self):
        return len(self.split)

    def __getitem__(self, idx):

        embedding = self.split[idx]
        output = self.label[idx]

        return embedding, output 

In [129]:
datasets = {"train_set" : PLI_Dataset(dataframe = df, split = 'train'), 
            "val_set" : PLI_Dataset(dataframe = df, split = 'valid'),
            "test_set" : PLI_Dataset(dataframe = df, split = 'train')
           }
dataloaders = {"train_loader": torch.utils.data.DataLoader(dataset=datasets["train_set"], batch_size=8,shuffle=True,drop_last=True),
               "val_loader": torch.utils.data.DataLoader(dataset=datasets["val_set"], batch_size=8,shuffle=True,drop_last=True),
               "test_loader": torch.utils.data.DataLoader(dataset=datasets["test_set"], batch_size=8,shuffle=True,drop_last=True)
              }

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

KeyboardInterrupt: 

In [None]:
embedding, output = next(iter(dataloaders['train_loader']))

In [53]:
embedding.shape

torch.Size([8, 21, 100])

In [54]:
output.shape

torch.Size([8, 21])

In [61]:
class MLPNet(nn.Module):
    def __init__(self, num_classes=num_classes, device='cpu'):
        super(MLPNet, self).__init__()
        self.device = device

        # DEFINE NN MODULES
        self.relu = nn.ReLU()
        self.batchnorm1 = nn.BatchNorm1d(32)
        self.batchnorm2 = nn.BatchNorm1d(32)
        self.flatten = nn.Flatten()
        self.fc1_1 = nn.Linear(in_features=100, out_features=32)
        self.fc2_1 = nn.Linear(in_features=32, out_features=32)
        self.classifier = nn.Linear(in_features=32, out_features=2)

    
    def forward(self, x):

    # DEFINE FORWARD FUNCTIONS
    
        x_out = self.batchnorm1(self.fc1_1(x))
        x_out = self.relu(x_out)
        x_out = self.batchnorm2(self.fc2_1(x_out))
        x_out = self.relu(x_out)
        out = self.classifier(x_out)

        return out
    
    def inference(self, PID, PDBS_DIR, centroid, ligands):

        p = self.process_PDB(PID, PDBS_DIR).to(self.device)
        c = self.process_coord(centroid).to(self.device)
        l = self.batch_process_SMILE(ligands).to(self.device)
        return self.forward(p, c, l)

In [63]:
model = MLPNet().to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
total_step = len(dataloaders['train_loader'])
for epoch in range(num_epochs):
    for i, (embedding, target) in enumerate(dataloaders['train_loader']):
        embedding = embedding.reshape((8*21,100))
        target = target.reshape(-1)
        
        embedding = embedding.to(device)
        target = target.to(device)
        # Forward pass
        output = model(embedding)
        loss = criterion(output, target)
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))
            
    # model.eval()  
    # with torch.no_grad():
    #     correct = 0
    #     total = 0
    #     for i, (protein, ligand, target) in enumerate(dataloaders['val_loader']):
    #         ligand = ligand.to(device)
    #         protein = protein.to(device)
    #         target = target.to(device)
    #         output = model(protein, ligand)
    #         _, predicted = torch.topk(output.data, 10, 1)
    #         total += target.size(0)
    #         for i in range(target.size(0)):
    #             if target[i] in predicted[i]:
    #                 correct += 1
    #             else:
    #                 pass
    #     print('Val accuracy on validation set: {} %'.format(100 * correct / total))
        
# # Save the model checkpoint
# torch.save(model.state_dict(), 'model.ckpt')

ValueError: Target size (torch.Size([168])) must be the same as input size (torch.Size([168, 2]))

In [None]:
# Test the model
model.eval()  
with torch.no_grad():
    correct = 0
    total = 0
    for i, (protein, ligand, target) in enumerate(dataloaders['val_loader']):
        ligand = ligand.to(device)
        protein = protein.to(device)
        target = target.to(device)
        output = model(protein, ligand)
        _, predicted = torch.topk(output.data, 10, 1)
        total += target.size(0)
        for i in range(target.size(0)):
            if target[i] in predicted[i]:
                correct += 1
            else:
                pass
    print(total)
    print('Val accuracy on validation set: {} %'.format(100 * correct / total))