***Using our drug repurposing model: GDRnet***

In [1]:
import torch
import torch.nn as nn
import numpy as np
import pickle
import openpyxl
import re

## Defining the required functions

In [2]:
def predict(net,dis_list): #list of diseases - in Disease::MESH:D###### format
  """Score every candidate drug against each disease in ``dis_list``.

  Args:
    net: a GDRnet instance, already moved to ``device``.
    dis_list: disease node names, e.g. ``"Disease::MESH:D008288"``.

  Returns:
    (embed, dictionaries_norm):
      embed: first output of ``net`` (node embeddings) from the last forward
        pass. The encoder does not depend on the pair batch, so this is the
        same for every disease. It is ``None`` when ``dis_list`` is empty.
      dictionaries_norm: one dict per disease mapping drug display name (or
        the node name when no display name is known) to its standardized
        score.

  Relies on notebook globals: nodes_mapping, input_features, ax, a2x, device.
  """
  dis_batches,drug_dict = get_disease_batches(nodes_mapping,dis_list)
  dis_batches = torch.LongTensor(dis_batches)
  # Build the id -> name reverse lookup once. The original called
  # get_node_name per drug, which rebuilt two full lists on every call.
  # setdefault keeps the first name for an id, matching list.index semantics.
  id_to_name = {}
  for node_name, node_id in nodes_mapping.items():
    id_to_name.setdefault(node_id, node_name)
  # The graph inputs are identical for every disease, so move them to the
  # device once.
  features_dev = input_features.to(device)
  ax_dev = ax.to(device)
  a2x_dev = a2x.to(device)
  embed = None
  dictionaries_norm = []
  # Inference only: don't build an autograd graph for every pair.
  with torch.no_grad():
    for batch in dis_batches:
      embed,logits = net(features_dev,ax_dev,a2x_dev,batch.to(device))
      probs = standardize(logits).tolist()
      dct_norm = {}
      for (drug_id, _disease_id), score in zip(batch.tolist(), probs):
        x = id_to_name[drug_id]
        if x in drug_dict:
          dct_norm[drug_dict[x][0]] = score
        else:
          dct_norm[x] = score
      dictionaries_norm.append(dct_norm)
  return embed,dictionaries_norm

def load_variable(filename):
  """Load and return the single pickled object stored in ``filename``.

  Security: unpickling can execute arbitrary code. Only use this on trusted
  files, such as the model's own data files.
  """
  # The context manager closes the handle even if unpickling fails.
  # The original left the file open until garbage collection.
  with open(filename,'rb') as fh:
    return pickle.load(fh)

def get_node_name(id, mapping=None):
  """Return the node name that ``mapping`` maps to ``id``.

  Args:
    id: node id to look up.
    mapping: dict of node name -> node id. Defaults to the notebook-level
      ``nodes_mapping``.

  If several names share ``id``, the first in iteration order is returned.

  Raises:
    ValueError: if no name maps to ``id``.
  """
  if mapping is None:
    mapping = nodes_mapping
  # Scan lazily instead of materialising two full lists on every call.
  for node_name, node_id in mapping.items():
    if node_id == id:
      return node_name
  raise ValueError(f"{id!r} is not a node id in the mapping")

def get_node_id(name, mapping=None):
  """Return the id that ``mapping`` assigns to node ``name``.

  Args:
    name: node name, e.g. ``"Disease::MESH:D008288"``.
    mapping: dict of node name -> node id. Defaults to the notebook-level
      ``nodes_mapping``.

  Raises:
    ValueError: if ``name`` is not a key of ``mapping``. This matches the
      exception type the original list.index-based lookup raised.
  """
  if mapping is None:
    mapping = nodes_mapping
  # Direct O(1) dict lookup instead of scanning parallel key/value lists.
  try:
    return mapping[name]
  except KeyError:
    raise ValueError(f"{name!r} is not a node name in the mapping") from None

def load_model_on_cpu(model,path):
  """Load a saved state_dict from ``path`` into ``model`` and return ``model``.

  Tensors are mapped onto the CPU while loading, so a checkpoint saved on a
  GPU also loads on a CPU-only machine. Move the model afterwards with
  ``.to(...)``. ``torch.load`` unpickles the file, so only load trusted
  checkpoints.
  """
  cpu = torch.device('cpu')
  state_dict = torch.load(path, map_location=cpu)
  model.load_state_dict(state_dict)
  return model

def get_disease_batches(nodes_mapping,disease_list,drug_dict=None): #disease_id in the form like Disease::MESH..
  """Build one list of (drug_id, disease_id) pairs per disease.

  Args:
    nodes_mapping: dict of node name -> node id.
    disease_list: disease node names, e.g. "Disease::MESH:D008288".
    drug_dict: optional drug-details dict (node name -> tuple), as returned
      by get_drug_name_desc_dict(). It is read from the spreadsheet when
      omitted.

  Returns:
    (batches, drug_dict):
      batches[k] pairs every candidate drug with disease_list[k].
      Candidate drugs are nodes whose name matches "Compound"
      (case-insensitive) and that have an entry in drug_dict.

  Raises:
    ValueError: if a disease name is not a key of nodes_mapping.
  """
  if drug_dict is None:
    drug_dict = get_drug_name_desc_dict()
  # The candidate drug set can be narrowed here, e.g. to drop
  # withdrawn/experimental drugs:
  #   groups = drug_dict[key][1].split(',')
  #   if not ((('experimental' in groups) and (len(groups)==1)) or 'withdrawn' in groups): ...
  # In the regex the '+' only repeats the final 'd'. re.search therefore
  # behaves like a case-insensitive "compound" substring test.
  drugs = [node_id for key, node_id in nodes_mapping.items()
           if re.search(r"Compound+",key,re.I) and key in drug_dict]
  batches = []
  for disease in disease_list:
    # Resolve the disease in the mapping that was passed in. The original
    # went through the global nodes_mapping and ignored this argument.
    try:
      disease_id = nodes_mapping[disease]
    except KeyError:
      raise ValueError(f"{disease!r} is not a node name in nodes_mapping") from None
    batches.append([(drug, disease_id) for drug in drugs])
  return batches,drug_dict

def get_drug_name_desc_dict(path='/content/drive/My Drive/Using_DR_model/Drug_details.xlsx'):
  """Read the drug-details spreadsheet into a dict keyed by its first column.

  Every row of the active sheet, row 1 included, becomes
  ``{col1: (col2, col3, col4)}``. Empty cells read as None, and a later
  duplicate key overwrites an earlier one.

  predict() uses element [0] as the drug's display name. Commented-out code
  in get_disease_batches treats element [1] as comma-separated status
  groups.

  Args:
    path: location of Drug_details.xlsx. Defaults to the Google Drive path
      used throughout this notebook; override it when running elsewhere.
  """
  sheet = openpyxl.load_workbook(path).active
  # values_only yields plain cell values. max_col=4 pads short rows with
  # None, matching the original per-cell sheet.cell(...) reads.
  rows = sheet.iter_rows(min_row=1, max_row=sheet.max_row, min_col=1, max_col=4, values_only=True)
  return {row[0]: tuple(row[1:4]) for row in rows}

def standardize(t):
  """Return ``t`` as z-scores: subtract its mean, then divide by its std.

  Uses torch's default (unbiased, N-1) standard deviation. A single-element
  or constant tensor therefore yields NaN.
  """
  centered = t - t.mean()
  return centered / t.std()

def get_rank(dct,key):
  """Return the 1-based rank of ``key`` when ``dct`` is ordered by value,
  highest first.

  Ties follow the original ordering rule: items are sorted ascending (a
  stable sort) and the list is then reversed. Among equal scores, the item
  that appears later in ``dct`` therefore ranks higher.

  Raises:
    KeyError: if ``key`` is not in ``dct``. The original silently returned
      the last rank in that case, and crashed with UnboundLocalError on an
      empty dict.
  """
  ordered = [name for name, _score in sorted(dct.items(), key=lambda item: item[1])[::-1]]
  try:
    return ordered.index(key) + 1
  except ValueError:
    raise KeyError(key) from None

## Model definition / blueprint

In [3]:
# Shared activation modules. GDRnet.forward uses `tanh` and `L_Relu`.
# `sig` and `Relu` are not referenced in the visible code.
tanh = nn.Tanh()
Relu = nn.ReLU()
sig = nn.Sigmoid()
L_Relu = nn.LeakyReLU(negative_slope=0.2)

class GDRnet(nn.Module):
  """Graph encoder plus bilinear decoder for scoring (drug, disease) pairs.

  Encoder: three tanh-activated linear branches over X, ax and a2x are
  concatenated, then fused by ``combine1`` and a LeakyReLU(0.2). The
  activations are the module-level ``tanh`` and ``L_Relu``.
  Decoder: for each (drug, disease) pair,
  ``score = <t[drug], layer8(t[disease])>``.

  Args:
    input_dim: width of each input feature matrix (default 400).
    decoder_dim: embedding width (default 250).

  The defaults match the pre-trained checkpoint. Layer names are part of
  the state_dict format, so keep them unchanged.
  """

  def __init__(self, input_dim=400, decoder_dim=250):
    super().__init__()
    r = 3  # number of concatenated encoder branches (X, ax, a2x)
    self.theta0 = nn.Linear(input_dim, decoder_dim)
    self.theta1 = nn.Linear(input_dim, decoder_dim)
    self.theta2 = nn.Linear(input_dim, decoder_dim)
    self.combine1 = nn.Linear(decoder_dim * r, decoder_dim)
    self.layer8 = nn.Linear(decoder_dim, decoder_dim)
    # Not used in forward. Kept so the state_dict keys match the saved
    # checkpoint: load_state_dict is strict by default.
    self.layer9 = nn.Linear(decoder_dim, decoder_dim)

  def decoder(self, t, batch):
    """Score index pairs against node embeddings ``t``.

    Args:
      t: node embeddings, one row per node.
      batch: integer pairs of shape (B, 2), as a tensor, array or list.
        Column 0 is the drug index and column 1 the disease index, as built
        by get_disease_batches.

    Returns:
      1-D tensor of B scores on ``t``'s device.
    """
    # Vectorised replacement for the per-row Python loop with .item() calls.
    # It also no longer depends on the global `device` or caches
    # intermediates on self. reshape(-1, 2) lets an empty batch yield an
    # empty result.
    idx = torch.as_tensor(batch, dtype=torch.long, device=t.device).reshape(-1, 2)
    drug_emb = t[idx[:, 0]]
    disease_emb = self.layer8(t[idx[:, 1]])
    return (drug_emb * disease_emb).sum(dim=1)

  def forward(self, X, ax, a2x, batch):
    """Return (node embeddings, pair scores) for the pairs in ``batch``."""
    t1 = tanh(self.theta0(X))
    t2 = tanh(self.theta1(ax))
    t3 = tanh(self.theta2(a2x))
    c = torch.cat((t1, t2, t3), dim=1)
    c = L_Relu(self.combine1(c))
    scores = self.decoder(c, batch)
    return c, scores
  

In [4]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Loading our pre-trained model


In [5]:
# Directory holding the pre-trained model and its input files. Change this
# single constant when running outside Colab / Google Drive.
DATA_DIR = "/content/drive/My Drive/Using_DR_model"

# NOTE: load_variable unpickles these files -- only load files you trust.
input_features = load_variable(f"{DATA_DIR}/input_features.p")
nodes_mapping = load_variable(f"{DATA_DIR}/nodes_mapping.p")  # node name -> node id
A_tilda = load_variable(f"{DATA_DIR}/A_tilda.p")

# Propagated features A~X and A~(A~X), computed in numpy, then converted.
# NOTE(review): `*` is a matrix product only if A_tilda is a scipy sparse
# matrix (or np.matrix). For a plain ndarray it would be element-wise --
# confirm the pickled type.
ax_np = A_tilda*np.array(input_features)
a2x_np = A_tilda*ax_np
ax = torch.tensor(ax_np,dtype=torch.float)
a2x = torch.tensor(a2x_np,dtype=torch.float)

empty_model = GDRnet()
net = load_model_on_cpu(empty_model,f"{DATA_DIR}/DR_model").to(device)

We provide a list of all ~4k diseases and ~8k drugs in our dataset, on which our model is trained. We can predict drugs for any of these diseases.

In [6]:
embeddings,drugs = predict(net,["Disease::MESH:D008288"])  # give a list of diseases in the same form as in "Disease_list.xlsx"
# embeddings - our 250-dimensional node embeddings for all the entities in the graph
# drugs - a list of dictionaries, one per disease given; in each dict the
#         keys are drug names and the values are the corresponding scores

'drugs - here "drugs" will be a list of dictionaries (each dict for a disease you give) with every dict\nfollowing keys = drug names and values = corresponding scores'

In [7]:
# Rank (1 = best) of a given drug in the predicted list for the first disease
get_rank(dct=drugs[0], key="Chloroquine")

7

In [8]:
# Top 30 predicted drugs for Disease::MESH:D008288 (Malaria), highest score first
list(reversed(sorted(drugs[0].items(), key=lambda pair: pair[1])))[:30]

[('Tetracycline', 2.903021812438965),
 ('Clindamycin', 2.861628532409668),
 ('Doxycycline', 2.7223026752471924),
 ('Metronidazole', 2.70158052444458),
 ('Minocycline', 2.687775135040283),
 ('Ivermectin', 2.589362621307373),
 ('Chloroquine', 2.574557065963745),
 ('Rifapentine', 2.5660512447357178),
 ('Erythromycin', 2.538343906402588),
 ('Proguanil', 2.5356037616729736),
 ('Sulfadiazine', 2.5127780437469482),
 ('Dapsone', 2.502814769744873),
 ('Clarithromycin', 2.467073917388916),
 ('Rifabutin', 2.4631710052490234),
 ('Trimethoprim', 2.445632219314575),
 ('Primaquine', 2.4132375717163086),
 ('Praziquantel', 2.4030725955963135),
 ('Demeclocycline', 2.4014358520507812),
 ('Atovaquone', 2.381624221801758),
 ('Sulfamethoxazole', 2.368584632873535),
 ('Terbinafine', 2.365410566329956),
 ('Rifaximin', 2.363856554031372),
 ('Rifampicin', 2.3388354778289795),
 ('Loperamide', 2.2780063152313232),
 ('Hydroxychloroquine', 2.2549023628234863),
 ('Telithromycin', 2.239393472671509),
 ('Ketoconazole'