In [2]:
import sys, os
import pandas as pd
import numpy as np
import json,pickle
from collections import OrderedDict
from rdkit import Chem
from rdkit.Chem import MolFromSmiles
import networkx as nx

from models.gat_gcn import GAT_GCN
from sklearn import model_selection, preprocessing, metrics, decomposition
import matplotlib.pyplot as plt

from random import shuffle
import torch
import torch.nn as nn
from utils1 import *
if torch.cuda.is_available():  
    device = "cuda:4"
    print("cuda:4")
else:  
    device = "cpu" 



cuda:4


In [3]:
def webserver(drug1,drug2,cell_name):

    def atom_features(atom):
        return np.array(one_of_k_encoding_unk(atom.GetSymbol(),['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na','Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb','Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H','Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr','Cr', 'Pt', 'Hg', 'Pb', 'Unknown']) +
                        one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6,7,8,9,10]) +
                        one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4, 5, 6,7,8,9,10]) +
                        one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6,7,8,9,10]) +
                        [atom.GetIsAromatic()])

    def one_of_k_encoding(x, allowable_set):
        if x not in allowable_set:
            raise Exception("input {0} not in allowable set{1}:".format(x, allowable_set))
        return list(map(lambda s: x == s, allowable_set))

    def one_of_k_encoding_unk(x, allowable_set):
        """Maps inputs not in the allowable set to the last element."""
        if x not in allowable_set:
            x = allowable_set[-1]
        return list(map(lambda s: x == s, allowable_set))

    def smile_to_graph(smile):
        mol = Chem.MolFromSmiles(smile)

        c_size = mol.GetNumAtoms()

        features = []
        for atom in mol.GetAtoms():
            feature = atom_features(atom)
            features.append( feature / sum(feature) )

        edges = []
        for bond in mol.GetBonds():
            edges.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()])
        g = nx.Graph(edges).to_directed()
        edge_index = []
        for e1, e2 in g.edges:
            edge_index.append([e1, e2])

        return c_size, features, edge_index
    
    #Read Gene expression of cell-line
    def cell_exp(cell_line):
        
        cell_line_names= []
        cell_names_df = pd.read_csv('basic_info/cell_line_names.csv')
        cell_line_names += list( cell_names_df["x"] )
        Gene_exp= []
        cell_df = pd.read_csv('final_data/cell_line.csv')
        #cell_df['cell_line']=cell_names_df['x']
        Gene_exp = cell_df.values.tolist()

        cell_dic = {}
        for i, gene_exp in enumerate(Gene_exp):
            cell_dic[cell_line_names[i]]=gene_exp
    
        x = cell_dic[cell_line]
        #print(x)
        #print(len(x))
        cell_lst = []
        cell_lst.append(x)
        #x = np.array(x)
        #x = x[1:]
        #print(x)
        #print(x.shape)
        #x=x.replace('[','')
        #x=x.replace(']','')
        #x=x.replace(cell_line,'')
        #x=[float(y.strip()) for y in x.split(',')]
        return cell_lst
    
    # Read the Smiles and convert it to graph
    def drug_graph(drug1,drug2):
        Drug_name = []
        Drug_smiles=[]
        df = pd.read_csv('smiles.csv')
        Drug_name += list( df["Name"] )
        Drug_smiles += list( df["smile"] )

        Drug_smiles = set(Drug_smiles)
        smile_graph = {}
        smiles_dic = {}
        for i,smile in enumerate(Drug_smiles):
            g = smile_to_graph(smile)
            smile_graph[smile] = g
            smiles_dic[Drug_name[i]] = smile
        drug_1_gr = smiles_dic[drug1]
        drug_2_gr = smiles_dic[drug1]
        d1_lst=[]
        d2_lst=[]
        d1_lst.append(drug_1_gr)
        d2_lst.append(drug_2_gr)
        return d1_lst, d2_lst,smile_graph
    

    def predicting(model, device, loader1,loader2):
        model.eval()
        total_preds = torch.Tensor()
        total_labels = torch.Tensor()
        #print('Make prediction for {} samples...'.format(len(loader1.dataset)))
        with torch.no_grad():
            for data1,data2 in zip(loader1,loader2):
                data1 = data1.to(device)
                data2 = data2.to(device)
                output = model(data1,data2)
                total_preds = torch.cat((total_preds, output.cpu()), 0)
                #total_labels = torch.cat((total_labels, data1.y.view(-1, 1).cpu()), 0)
        return total_preds.numpy().flatten() #total_labels.numpy().flatten()

    d_1,d_2, smile_graph = drug_graph(drug1,drug2) 
    cell_line_exp = cell_exp(cell_name)
    dataset="human"
    test_data1 = TestbedDataset(root='data', dataset="human1", xd=d_1, xt=cell_line_exp,smile_graph=smile_graph)
    test_data2 = TestbedDataset(root='data', dataset="human2", xd=d_2, xt=cell_line_exp,smile_graph=smile_graph)
    test_loader1 = DataLoader(test_data1, batch_size=128, shuffle=False)
    test_loader2 = DataLoader(test_data2, batch_size=128, shuffle=False)
    model = GAT_GCN()
    model = model.to(device)
    model_st = "GCNNet"
    model_file_name = 'model_' + model_st + '_' + str(0) +  '.model'
    model.load_state_dict(torch.load(model_file_name))
    prediction = predicting(model, device, test_loader1,test_loader2) 
    
    return prediction


drug_1= "5-FU"
drug_2= "ABT-888"
cell_name= "A2058"
prediction = webserver(drug_1,drug_2,cell_name)
print(prediction)

Pre-processed data data/processed/human1.pt not found, doing pre-processing...
Converting SMILES to graph: 1/1
Graph construction done. Saving to file.
Pre-processed data data/processed/human2.pt not found, doing pre-processing...
Converting SMILES to graph: 1/1
Graph construction done. Saving to file.
[1.2952782]


In [None]:

"A2058","A2780","A375","A427","CAOV3","COLO320DM","DLD1","EFM192B","ES2","HCT116","HT144","HT29","KPL1","LNCAP","LOVO","MDAMB436","MSTO","NCIH1650","NCIH2122","NCIH23","NCIH460","NCIH520","OCUBM","OV90","OVCAR3","PA1","RKO","RPMI7951","SKMEL30","SKMES1","SKOV3","SW620","SW837","T47D","UACC62","UWB1289","UWB1289BRCA1","VCAP","ZR751"