## DOWNLOADING THE FILES (if required)

In [None]:
! pip install gdown

In [None]:
! gdown --folder https://drive.google.com/drive/folders/1g4jlcdFpcDyQAvvkYhugSFKy_5w7qfpM

In [None]:
! mv /content/Sem8project/* /content/

## INSTALL AND IMPORT

In [None]:
!pip install ogb

In [None]:
! pip install rdkit-pypi
! pip install deepchem
from rdkit.Chem import MACCSkeys
from rdkit import Chem

In [None]:
!pip install dgl

In [None]:
import dgl
import torch as th
import numpy as np
import pandas as pd

In [None]:
import re, io, gzip
import xml.etree.ElementTree as ET
import collections
import requests
import pandas
import json

In [None]:
import numpy as np
import torch
import dgl
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn

In [None]:
import random
def set_seed(seed=0):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed handed to every generator (defaults to 0).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # The CUDA generator is only seeded when a GPU is actually visible.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)

## Mapping the drug IDs to integer node IDs and vice versa (same for the common genes and diseases)

In [None]:
csvFile = pandas.read_csv('./drug_id_list.csv')

print(len(csvFile))
print(csvFile)

In [None]:
drug_id_list = csvFile['DrugID']
print(drug_id_list[0])

In [None]:
# Two-way lookup between DrugBank IDs and graph node IDs.
# Node IDs are consecutive integers that follow the order of drug_id_list.
drugbank_id_to_graph_node_id_map = {}
graph_node_id_to_drugbank_id_map = {}

for node_id, drugbank_id in enumerate(drug_id_list):
    drugbank_id_to_graph_node_id_map[drugbank_id] = node_id
    graph_node_id_to_drugbank_id_map[node_id] = drugbank_id

# Number of drug nodes that were numbered.
count = len(drug_id_list)

In [None]:
print(len(drugbank_id_to_graph_node_id_map))
print(len(graph_node_id_to_drugbank_id_map))

In [None]:
csvFile = pd.read_csv('common_genes.csv')

In [None]:
print(csvFile.head())
common_gene_ids = csvFile["CommonGenes"]
print(common_gene_ids[0])
# common_gene_ids.append("NDUFA9")
print(len(csvFile))
print(csvFile)

In [None]:
# Two-way lookup between gene symbols and graph node IDs.
# Node IDs are consecutive integers that follow the order of common_gene_ids.
gene_to_graph_node_id_map = {
    gene: node_id for node_id, gene in enumerate(common_gene_ids)
}
graph_node_id_to_gene_map = dict(enumerate(common_gene_ids))

# Number of gene nodes that were numbered.
count = len(common_gene_ids)

In [None]:
print(len(gene_to_graph_node_id_map))
print(len(graph_node_id_to_gene_map))

In [None]:
csvFile = pd.read_csv('disease.csv')
print(csvFile.head())

In [None]:
print(csvFile.head())
disease_name = csvFile["disease"]
print(disease_name[0])
# common_gene_ids.append("NDUFA9")
print(len(csvFile))
print(csvFile)

In [None]:
# Two-way lookup between disease names and graph node IDs.
# Node IDs are consecutive integers that follow the order of disease_name.
graph_node_id_to_disease_map = dict(enumerate(disease_name))
disease_to_graph_node_id_map = {
    name: node_id for node_id, name in graph_node_id_to_disease_map.items()
}

# Number of disease nodes that were numbered.
count = len(graph_node_id_to_disease_map)

In [None]:
print(len(disease_to_graph_node_id_map))
print(len(graph_node_id_to_disease_map))

## Drug encoding

In [None]:
# Load the table that pairs each drug ID with its SMILES string.
csvFile = pandas.read_csv('./drugIDandSMILES.csv')
print(len(csvFile))
print(csvFile)

drug_id = csvFile['Drug ID']
smiles = csvFile['SMILES']
print(len(smiles))

# Drug ID -> SMILES, pairing the two columns row by row.
drug_id_to_smiles_id_map = dict(zip(drug_id, smiles))

In [None]:
print(len(drugbank_id_to_graph_node_id_map))
print(len(graph_node_id_to_drugbank_id_map))
print(drugbank_id_to_graph_node_id_map.keys())

In [None]:
# from rdkit.Chem import MACCSkeys
# from rdkit import Chem

# Fingerprint the first molecule once to find out how long a MACCS key
# vector is; that length becomes the drug feature width.
mol = Chem.MolFromSmiles(smiles[0])
first_fingerprint = list(MACCSkeys.GenMACCSKeys(mol))
print(first_fingerprint)
MACCSkeys_length = len(first_fingerprint)
print(MACCSkeys_length)

In [None]:
print(len(drugbank_id_to_graph_node_id_map))
print(len(graph_node_id_to_drugbank_id_map))

In [None]:
%%time
# CREATING THE NODE EMBEDDINGS

# One row of MACCS keys per drug node; rows stay all-zero for drugs that
# cannot be fingerprinted.
drug_features = torch.zeros(len(graph_node_id_to_drugbank_id_map), MACCSkeys_length)
# DrugBank IDs whose SMILES could not be parsed or fingerprinted.
molecules_with_no_MACCS_keys = []

for i in range(len(graph_node_id_to_drugbank_id_map)):
    # A dedicated name is used here so the `drug_id` column loaded earlier is
    # not clobbered (the old handler ended up indexing into the ID string).
    drugbank_id = graph_node_id_to_drugbank_id_map[i]

    smiles_string = drug_id_to_smiles_id_map.get(drugbank_id)
    if not isinstance(smiles_string, str):
        # No usable SMILES (drug missing from the table, or an empty/NaN
        # cell): keep the zero row.
        continue

    mol = Chem.MolFromSmiles(smiles_string)
    if mol is None:
        # RDKit returns None for SMILES it cannot parse.
        molecules_with_no_MACCS_keys.append(drugbank_id)
        continue

    try:
        drug_features[i] = torch.Tensor(list(MACCSkeys.GenMACCSKeys(mol)))
    except Exception:
        # Best effort: record the failing drug and keep its zero row.
        molecules_with_no_MACCS_keys.append(drugbank_id)

In [None]:
print(drug_features)
print(drug_features.shape)

## Common genes network 

In [None]:
#%%
import pandas as pd
import networkx as nx

df = pd.read_csv("gene-interact-gene.csv")
print(df.head())

In [None]:
# %%
# Every column after the two endpoint columns is treated as an edge feature;
# the columns are renamed positionally to feat_0 .. feat_{n-1}.
edge_keys = df.columns.to_list()[2:]
feats = [f"feat_{idx}" for idx in range(len(edge_keys))]
df.columns = ["node1", "node2", *feats]

# Directed multigraph that carries every feature column as an edge attribute.
G = nx.from_pandas_edgelist(
    df,
    source="node1",
    target="node2",
    edge_attr=feats,
    create_using=nx.MultiDiGraph(),
)

import dgl
import dgl.function as fn

g = dgl.from_networkx(G, edge_attrs=feats)

# For each edge feature feat_<k>, sum the values arriving at a node and store
# the total as node feature nfeat_<k>.
ndataList = [f"nfeat_{idx}" for idx in range(len(feats))]
for edge_feat, node_feat in zip(feats, ndataList):
    g.update_all(fn.copy_e(edge_feat, 'm'), fn.sum('m', node_feat))
    



In [None]:
print(ndataList)

In [None]:
print(g.ndata)
print(g.edata)
print(common_gene_ids)
print(G.nodes)
print(G.nodes.data())
g_converted = dgl.to_networkx(g, node_attrs= ndataList)
print(g_converted)
print(G.nodes, len(G.nodes))
print(g_converted.nodes, len(g_converted.nodes))
print(g_converted.nodes(data=True))
# print(g_converted.nodes.data())

In [None]:
# DGL node id -> gene symbol.
# dgl.from_networkx relabels nodes to consecutive integers following the
# *sorted* order of the NetworkX node labels, so the lookup must be built from
# the sorted labels. Using G's insertion order would attach the aggregated
# features to the wrong genes whenever the edge list is not already sorted.
mapping = dict(enumerate(sorted(G.nodes())))
mapping

In [None]:
g_converted = nx.relabel_nodes(g_converted, mapping)

In [None]:
print(g_converted.nodes(data = True))

In [None]:
print(len(common_gene_ids))

In [None]:
# Build one feature row per gene, in the order of common_gene_ids (which is
# also the gene node-id order).
#
# The row width is the number of aggregated node features rather than the
# attribute count of one hard-coded gene, so this no longer depends on a
# particular gene being present in the network.
common_len = len(ndataList)
print(common_len)

common_gene_data = []
for x in common_gene_ids:
    if x in g_converted.nodes:
        node_attrs = g_converted.nodes[x]
        # Read the attributes in ndataList order so columns are aligned
        # across genes.
        this_data = [float(node_attrs[name]) for name in ndataList]
    else:
        # Gene has no interactions in the network: use an all-zero row so the
        # list stays rectangular and can be converted to a tensor.
        this_data = [0.0] * common_len
    common_gene_data.append(this_data)
  # break

In [None]:
# print(common_gene_data.size())
common_gene_data_tensor = torch.Tensor(common_gene_data)
print(common_gene_data_tensor)
print(common_gene_data_tensor.shape)


In [None]:
gene_features = common_gene_data_tensor

In [None]:
print(gene_features, gene_features.dtype, gene_features.shape)

In [None]:
# Zero block that widens the gene feature matrix to the drug feature width so
# both node types share one input dimension. The sizes come from the tensors
# themselves instead of hard-coded row/column counts.
pad_rows = gene_features.shape[0]
pad_cols = drug_features.shape[1] - gene_features.shape[1]
n = np.zeros(shape=(pad_rows, pad_cols))
temp = torch.tensor(n, dtype=torch.float32)
print(temp.shape)

In [None]:
new_tensor = torch.cat((gene_features, temp), 1)
print(new_tensor, new_tensor.shape)

In [None]:
gene_features = new_tensor

In [None]:
print(gene_features, gene_features.dtype, gene_features.shape)
print(drug_features, drug_features.dtype, drug_features.shape)

## Forming the hetero graph

In [None]:
# Raw edge tables, one CSV per relation.
drug_drug = pd.read_csv('/content/drug-interact-drug.csv')
drug_gene = pd.read_csv('/content/drug-interact-gene.csv')
gene_gene = pd.read_csv('/content/gene-interact-gene.csv')
drug_disease = pd.read_csv('/content/drug-treat-disease.csv')

# Lookup tables built earlier in the notebook:
#   drugbank_id_to_graph_node_id_map / graph_node_id_to_drugbank_id_map
#   gene_to_graph_node_id_map        / graph_node_id_to_gene_map
#   disease_to_graph_node_id_map     / graph_node_id_to_disease_map

# Integer node ids of every entity type (the keys of the reverse maps).
graph_node_id_drugs = list(graph_node_id_to_drugbank_id_map)
graph_node_id_genes = list(graph_node_id_to_gene_map)
graph_node_id_diseases = list(graph_node_id_to_disease_map)

In [None]:
print(drug_drug.head())
print(drug_gene.head())
print(gene_gene.head())
print(drug_disease.head())

In [None]:
print(drug_id_list)
# Rebind to a set of the unique drug IDs for the membership filters below.
# NOTE(review): after this, drug_id_list no longer has a stable order, so the
# node-id mapping cell should not be re-run from this point.
drug_id_list = set(drug_id_list)

In [None]:
print(common_gene_ids)
# Rebind to a set of the unique gene symbols for the membership filters below.
# NOTE(review): after this, common_gene_ids no longer has a stable order, so
# the node-id mapping cell should not be re-run from this point.
common_gene_ids = set(common_gene_ids)

In [None]:
print(disease_name)
# Rebind to a set of the unique disease names for the membership filters below.
# NOTE(review): after this, disease_name no longer has a stable order, so the
# node-id mapping cell should not be re-run from this point.
disease_name = set(disease_name)

In [None]:
## Keep only the interactions whose endpoints both appear in the drug, gene and disease lists above

In [None]:
# drug drug: keep rows where both drugs are in the drug list
drug1_known = drug_drug["Drug1 ID"].isin(drug_id_list)
drug2_known = drug_drug["Drug2 ID"].isin(drug_id_list)
drug_drug_interactions = drug_drug[drug1_known & drug2_known]

drug_drug_interactions_drug1 = drug_drug_interactions["Drug1 ID"]
drug_drug_interactions_drug2 = drug_drug_interactions["Drug2 ID"]
print(len(drug_drug_interactions_drug1), len(drug_drug_interactions_drug2))

In [None]:
# drug gene: keep rows whose drug and gene are both in the respective lists
drug_known = drug_gene["drugbank_id"].isin(drug_id_list)
gene_known = drug_gene["gene_symbol"].isin(common_gene_ids)
drug_gene_interactions = drug_gene[drug_known & gene_known]

drug_gene_interactions_drug1 = drug_gene_interactions["drugbank_id"]
drug_gene_interactions_gene1 = drug_gene_interactions["gene_symbol"]
print(len(drug_gene_interactions_drug1), len(drug_gene_interactions_gene1))

In [None]:
# gene gene: keep rows where both genes are in the common gene list
gene1_known = gene_gene["node1"].isin(common_gene_ids)
gene2_known = gene_gene["node2"].isin(common_gene_ids)
gene_gene_interactions = gene_gene[gene1_known & gene2_known]

gene_gene_interactions_gene1 = gene_gene_interactions["node1"]
gene_gene_interactions_gene2 = gene_gene_interactions["node2"]

# Per-edge score column, attached to the graph as an edge feature later on.
gene_gene_interactions_egde_automated_textmining = gene_gene_interactions['automated_textmining']

print(len(gene_gene_interactions_gene1), len(gene_gene_interactions_gene2))

In [None]:
# drug disease: keep rows whose drug and disease are both in the respective lists
treating_drug_known = drug_disease["drugbank_id"].isin(drug_id_list)
disease_known = drug_disease["disease"].isin(disease_name)
drug_disease_interactions = drug_disease[treating_drug_known & disease_known]

drug_disease_interactions_drug1 = drug_disease_interactions["drugbank_id"]
drug_disease_interactions_disease1 = drug_disease_interactions["disease"]
print(len(drug_disease_interactions_drug1), len(drug_disease_interactions_disease1))

In [None]:
def map_drugbank_id_to_graph_node_id(x):
    """Return the graph node id of DrugBank ID ``x`` (KeyError if unknown)."""
    node_id = drugbank_id_to_graph_node_id_map[x]
    return node_id


def map_gene_id_to_graph_node_id(x):
    """Return the graph node id of gene symbol ``x`` (KeyError if unknown)."""
    node_id = gene_to_graph_node_id_map[x]
    return node_id


def map_disease_id_to_graph_node_id(x):
    """Return the graph node id of disease name ``x`` (KeyError if unknown)."""
    node_id = disease_to_graph_node_id_map[x]
    return node_id


In [None]:
def _node_id_tensor(values, lookup):
    """Translate an iterable of entity keys into an int64 tensor of node ids."""
    return th.tensor([lookup[value] for value in values], dtype=th.int64)


# (source type, relation, destination type) -> (source node ids, destination node ids)
graph_data = {
    ('drug', 'drug_drug', 'drug'): (
        _node_id_tensor(drug_drug_interactions_drug1, drugbank_id_to_graph_node_id_map),
        _node_id_tensor(drug_drug_interactions_drug2, drugbank_id_to_graph_node_id_map),
    ),
    # The ('drug', 'drug_disease', 'disease') relation is intentionally left
    # out; drug_disease_interactions_* and disease_to_graph_node_id_map are
    # available if it is ever enabled.
    ('drug', 'drug_gene', 'gene'): (
        _node_id_tensor(drug_gene_interactions_drug1, drugbank_id_to_graph_node_id_map),
        _node_id_tensor(drug_gene_interactions_gene1, gene_to_graph_node_id_map),
    ),
    ('gene', 'gene_gene', 'gene'): (
        _node_id_tensor(gene_gene_interactions_gene1, gene_to_graph_node_id_map),
        _node_id_tensor(gene_gene_interactions_gene2, gene_to_graph_node_id_map),
    ),
}

# Without num_nodes_dict DGL sizes each node type from the largest id that
# appears on an edge, so isolated nodes at the end of the id range would be
# dropped and the feature matrices (one row per mapped id) would no longer
# fit. Pin the node counts to the id maps instead.
g = dgl.heterograph(
    graph_data,
    num_nodes_dict={
        'drug': len(graph_node_id_to_drugbank_id_map),
        'gene': len(graph_node_id_to_gene_map),
    },
)

In [None]:
print(g)

In [None]:

# Attach the text-mining score of every kept gene-gene row as a float edge
# feature; rows and edges are in the same order.
textmining_scores = th.tensor(gene_gene_interactions_egde_automated_textmining.tolist())
g.edges['gene_gene'].data['automated_textmining'] = textmining_scores.float()


In [None]:
print(g.edges['gene_gene'].data['automated_textmining'])

In [None]:
print(g)

In [None]:
print(g.num_nodes('drug'))
print(g.num_nodes('gene'))
# print(g.num_nodes('disease'))
print(g.edges['gene_gene'])

In [None]:
print(g.num_edges(('drug', 'drug_gene', 'gene')))
print(g.num_edges('drug_gene'))

In [None]:
# Shared input width for every node type. It is read from the drug feature
# matrix (the MACCS key length) instead of being hard-coded, so it cannot
# drift from the features that are actually attached to the graph.
n_features = drug_features.shape[1]
n_drug_features = n_features
n_gene_features = n_features
n_disease_features = n_features

In [None]:
g.nodes['drug'].data['feature'] = drug_features
g.nodes['gene'].data['feature'] = gene_features
# g.nodes['disease'].data['feature'] = torch.randn(g.num_nodes('disease'), n_disease_features)


In [None]:

# NOTE(review): the per-edge 'label' values are random integers drawn from
# [1, number of drug_gene edges) and cast to float; no other cell visible in
# this notebook reads them — looks like placeholder data, confirm before use.
g.edges['drug_gene'].data['label'] = torch.randint(1, g.num_edges('drug_gene'), (g.num_edges('drug_gene'),)).float()


In [None]:
# https://docs.dgl.ai/en/0.6.x/tutorials/basics/5_hetero.html
# https://docs.dgl.ai/en/0.6.x/guide/training-link.html
# https://docs.dgl.ai/en/0.6.x/tutorials/blitz/4_link_predict.html

## Model, Training and Metrics

In [None]:
# Heterograph Conv model

class RGCN(nn.Module):
    """Three-layer heterogeneous graph convolution with concatenated skips.

    Each layer holds one ``GraphConv`` per relation and sums the per-relation
    results for every destination node type.

    Args:
        in_feats: Width of the input node features.
        hid_feats: Width of the first and second layer outputs.
        out_feats: Width of the third layer output.
        rel_names: Relation (edge type) names of the graph.
    """

    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()
        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        # The third layer consumes [layer-1 output | layer-2 output].
        self.conv3 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(2*hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs):
        """Compute node embeddings.

        Args:
            graph: Heterograph to convolve over.
            inputs: Dict mapping node type to its input feature tensor.

        Returns:
            Dict mapping node type to a tensor of width
            ``out_feats + 2 * hid_feats``: the third layer output followed by
            the activated concatenation of the first two layers.
        """
        first = {k: F.relu(v) for k, v in self.conv1(graph, inputs).items()}

        # Skip connection: every node type keeps its layer-1 activation next
        # to the layer-2 output. This is done for whatever node types the
        # convolution returns rather than for hard-coded type names.
        second = {
            k: F.relu(torch.cat((first[k], v), 1))
            for k, v in self.conv2(graph, first).items()
        }

        third = self.conv3(graph, second)
        return {k: torch.cat((v, second[k]), 1) for k, v in third.items()}


class HeteroDotProductPredictor(nn.Module):
    """Scores edges of one type by the dot product of their endpoint embeddings."""

    def forward(self, graph, h, etype):
        """Return the 'score' edge data of ``etype`` edges in ``graph``.

        Args:
            graph: Graph whose ``etype`` edges are scored.
            h: Node embeddings, assigned to ``graph.ndata['h']``.
            etype: Edge type to score.
        """
        # local_scope keeps the temporary 'h' and 'score' fields from
        # persisting on the caller's graph.
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            edge_scores = graph.edges[etype].data['score']
        return edge_scores


def construct_negative_graph(graph, k, etype):
    """Build a graph of corrupted ``etype`` edges for negative sampling.

    Every existing edge keeps its source node and is paired with ``k``
    destination nodes drawn uniformly at random. The k negatives of edge i
    occupy positions ``i*k .. i*k + k - 1``.

    Args:
        graph: Heterograph holding the positive edges.
        k: Number of negative edges generated per positive edge.
        etype: Canonical edge type ``(src_type, relation, dst_type)``.

    Returns:
        A heterograph with the same node counts as ``graph`` that contains
        only the sampled ``etype`` edges. Sampled pairs are not checked
        against the positive edges, so a negative may coincide with one.
    """
    _, _, vtype = etype
    src, dst = graph.edges(etype=etype)
    neg_src = src.repeat_interleave(k)
    # Sample with the id dtype and device of the existing edges so that source
    # and destination tensors always match, including for graphs that were
    # moved to a GPU or use 32-bit ids.
    neg_dst = torch.randint(
        0, graph.num_nodes(vtype), (len(src) * k,),
        dtype=dst.dtype, device=dst.device)
    return dgl.heterograph(
        {etype: (neg_src, neg_dst)},
        num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})


class Model(nn.Module):
    """RGCN encoder followed by a dot-product edge scorer."""

    def __init__(self, in_features, hidden_features, out_features, rel_names):
        super().__init__()
        # The attribute is called `sage` but holds the RGCN encoder above.
        self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
        self.pred = HeteroDotProductPredictor()

    def forward(self, g, neg_g, x, etype):
        """Score the ``etype`` edges of ``g`` and of ``neg_g``.

        The node embeddings are computed once on ``g`` and reused for both
        graphs.

        Returns:
            Tuple ``(positive_scores, negative_scores)``.
        """
        embeddings = self.sage(g, x)
        positive_scores = self.pred(g, embeddings, etype)
        negative_scores = self.pred(neg_g, embeddings, etype)
        return positive_scores, negative_scores


In [None]:


def compute_loss(pos_score, neg_score):
    """Margin (hinge) loss between each positive edge and its own negatives.

    Args:
        pos_score: Scores of the positive edges, shape ``(E,)`` or ``(E, 1)``.
        neg_score: Scores of the negative edges, ``E * k`` values ordered so
            that the k negatives of edge i are contiguous.

    Returns:
        Scalar tensor: mean of ``max(0, 1 - pos_i + neg_ij)`` over all
        positive/negative pairs.
    """
    n_edges = pos_score.shape[0]
    # Force an (E, 1) column so the subtraction broadcasts against the (E, k)
    # negatives row by row. With (E, 1) input, `unsqueeze(1)` produced
    # (E, 1, 1) and broadcast to an (E, E, k) all-pairs tensor instead.
    pos = pos_score.reshape(n_edges, 1)
    neg = neg_score.reshape(n_edges, -1)
    return (1 - pos + neg).clamp(min=0).mean()


### One forward pass of the model

In [None]:
set_seed()
output_features = 5
# Negative edges sampled per positive edge. This is its own setting so that
# changing the embedding size does not silently change the sampling ratio.
negatives_per_edge = 5
target_etype = ('drug', 'drug_gene', 'gene')

model = Model(n_features, n_features*2 , output_features, g.etypes)
drug_feats = g.nodes['drug'].data['feature']
gene_feats = g.nodes['gene'].data['feature']
# Disease nodes are not part of the graph, so only drug and gene features are fed in.
node_features = {'drug': drug_feats, 'gene': gene_feats}
opt = torch.optim.Adam(model.parameters())

# Single untrained forward pass: score the real and the corrupted drug-gene
# edges and report the resulting margin loss.
negative_graph = construct_negative_graph(g, negatives_per_edge, target_etype)
pos_score, neg_score = model(g, negative_graph, node_features, target_etype)
loss = compute_loss(pos_score, neg_score)
print(loss)

### AUC

In [None]:
from sklearn.metrics import roc_auc_score
def compute_auc(pos_score, neg_score):
    """Return the ROC AUC of positive-edge scores against negative-edge scores.

    Positive edges are labelled 1 and negative edges 0.
    """
    n_pos = pos_score.shape[0]
    n_neg = neg_score.shape[0]
    all_scores = torch.cat([pos_score, neg_score])
    all_labels = torch.cat([torch.ones(n_pos), torch.zeros(n_neg)])
    return roc_auc_score(all_labels.numpy(), all_scores.cpu().detach().numpy())

In [None]:
print(compute_auc(pos_score, neg_score))

### HITS@K

In [None]:
print(pos_score.shape, neg_score.shape, torch.flatten(pos_score).shape, torch.flatten(neg_score).shape)

In [None]:
from ogb.linkproppred import Evaluator
# Reuse the OGB link-prediction evaluator configured for 'ogbl-ddi', with its
# K attribute overridden, and read back the matching 'hits@K' entry.
K = 10000
evaluator = Evaluator(name='ogbl-ddi')
evaluator.K = K

flattened_scores = {
    'y_pred_pos': torch.flatten(pos_score),
    'y_pred_neg': torch.flatten(neg_score),
}
hits = evaluator.eval(flattened_scores)[f'hits@{K}']

In [None]:
print(hits)

### Training of the model

In [None]:
output_features = 5
# Negative edges sampled per positive edge. This is its own setting so that
# changing the embedding size does not silently change the sampling ratio.
negatives_per_edge = 5
num_epochs = 501
target_etype = ('drug', 'drug_gene', 'gene')

model = Model(n_features, n_features*2 , output_features, g.etypes)
drug_feats = g.nodes['drug'].data['feature']
gene_feats = g.nodes['gene'].data['feature']
# Disease nodes are not part of the graph, so only drug and gene features are fed in.
node_features = {'drug': drug_feats, 'gene': gene_feats}
opt = torch.optim.Adam(model.parameters())

for epoch in range(num_epochs):
    # Fresh negatives every epoch.
    negative_graph = construct_negative_graph(g, negatives_per_edge, target_etype)
    pos_score, neg_score = model(g, negative_graph, node_features, target_etype)
    loss = compute_loss(pos_score, neg_score)
    opt.zero_grad()
    loss.backward()
    opt.step()
    if epoch%10 == 0:
        print(loss.item())

# NOTE: pos_score / neg_score / negative_graph from the final epoch stay in
# scope; the metric cells below read them, so those metrics are computed on
# the training edges, using scores from before the last optimizer step.

### AUC

In [None]:
from sklearn.metrics import roc_auc_score
def compute_auc(pos_score, neg_score):
    """ROC AUC with positive-edge scores labelled 1 and negative-edge scores 0."""
    positives = torch.ones(pos_score.shape[0])
    negatives = torch.zeros(neg_score.shape[0])
    labels = torch.cat([positives, negatives]).numpy()
    scores = torch.cat([pos_score, neg_score]).cpu().detach().numpy()
    return roc_auc_score(labels, scores)

In [None]:
print(compute_auc(pos_score, neg_score))

### HITS@K

In [None]:
print(pos_score.shape, neg_score.shape, torch.flatten(pos_score).shape, torch.flatten(neg_score).shape)

In [None]:
from ogb.linkproppred import Evaluator
# Hits@K after training: the 'ogbl-ddi' evaluator with its K attribute
# overridden, fed the flattened positive and negative scores.
K = 10000
evaluator = Evaluator(name='ogbl-ddi')
evaluator.K = K

trained_scores = {
    'y_pred_pos': torch.flatten(pos_score),
    'y_pred_neg': torch.flatten(neg_score),
}
hits = evaluator.eval(trained_scores)[f'hits@{K}']

In [None]:
print(hits)

## Making predictions for the drug-gene interactions

In [None]:
pos_score, neg_score = model(g, negative_graph, node_features, ('drug', 'drug_gene', 'gene'))

print(pos_score.shape, neg_score.shape)
print(node_features)

In [None]:
print(g)

In [None]:
def predict_score_between(drug_node_id_in_graph, gene_node_id_in_graph):
    """Score a single candidate drug-gene edge with the trained model.

    The candidate edge is placed in its own one-edge graph, so the shared
    ``negative_graph`` is no longer grown by one edge per call.

    Args:
        drug_node_id_in_graph: Graph node id of the drug.
        gene_node_id_in_graph: Graph node id of the gene.

    Returns:
        The score tensor of the candidate edge (it is also printed).
    """
    etype = ('drug', 'drug_gene', 'gene')
    query_graph = dgl.heterograph(
        {etype: (torch.tensor([drug_node_id_in_graph], device=g.device),
                 torch.tensor([gene_node_id_in_graph], device=g.device))},
        num_nodes_dict={ntype: g.num_nodes(ntype) for ntype in g.ntypes})

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        _, query_score = model(g, query_graph, node_features, etype)

    score = query_score[-1]
    print(score)
    return score


### If no edge originally present between a gene and a drug

In [None]:
drug_node_id_in_graph = 6
gene_node_id_in_graph = 0

print(graph_node_id_to_drugbank_id_map[drug_node_id_in_graph], graph_node_id_to_gene_map[gene_node_id_in_graph])

print(g.has_edges_between(drug_node_id_in_graph, gene_node_id_in_graph, etype='drug_gene'))

print(negative_graph.has_edges_between(drug_node_id_in_graph, gene_node_id_in_graph, etype='drug_gene'))


In [None]:
predict_score_between(drug_node_id_in_graph, gene_node_id_in_graph)

### If edge originally present in graph

In [None]:
drug_node_id_in_graph = 4402
gene_node_id_in_graph = 1009

print(graph_node_id_to_drugbank_id_map[drug_node_id_in_graph], graph_node_id_to_gene_map[gene_node_id_in_graph])

print(g.has_edges_between(drug_node_id_in_graph, gene_node_id_in_graph, etype='drug_gene'))

print(negative_graph.has_edges_between(drug_node_id_in_graph, gene_node_id_in_graph, etype='drug_gene'))


In [None]:
predict_score_between(drug_node_id_in_graph, gene_node_id_in_graph)