In [4]:
import pandas as pd
import numpy as np
import networkx as nx

import pickle

import torch
from torch.utils.data import DataLoader

from torch_geometric.nn.models import MetaPath2Vec
from torch_geometric.data import HeteroData

from deepsnap.dataset import GraphDataset
from deepsnap.batch import Batch

from gnn.linkpred_model import LinkPredModel, train, test

from sklearn.metrics import roc_auc_score, roc_curve, f1_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt

In [5]:
# Device on which the model and the training batches are placed.
torch_device = 'cpu'
print(f'Using device: {torch_device}')

Using device: cpu


Specify which dataset to use: 1 (original knowledge graph) or 2 (restructured knowledge graph).

In [7]:
# Which knowledge graph to load: 1 = original, 2 = restructured.
dataset_nr = 1
# Both prepared variants are valid. Note that `assert x == 1, 2` would only
# accept 1, because the `2` is treated as the assertion message.
assert dataset_nr in (1, 2), f'dataset_nr must be 1 or 2, got {dataset_nr!r}'

# Load nodes and edges

Load the nodes

In [8]:
# Load the indexed node table for the selected dataset.
nodes_csv_path = f'output/indexed_nodes_{dataset_nr}.csv'
nodes = pd.read_csv(nodes_csv_path)
nodes

Unnamed: 0,index_id,id,semantic,label,semantic_id
0,0,WormBase:WBGene00000389,ORTH,cdc-25.4,5
1,1,ZP:0018675,DISO,right side lateral plate mesoderm mislocalised...,1
2,2,ZFIN:ZDB-GENE-040426-1197,ORTH,tbc1d5,5
3,3,5,DRUG,(S)-nicardipine,2
4,4,RGD:3443,ORTH,Ptk2,5
...,...,...,...,...,...
10029,10029,MP:0009763,DISO,increased sensitivity to induced morbidity/mor...,1
10030,10030,MP:0011057,DISO,absent brain ependyma motile cilia,1
10031,10031,MP:0001412,DISO,excessive scratching,1
10032,10032,WBPhenotype:0004023,DISO,frequency of body bend variant,1


In [9]:
# (number of rows, number of columns) of the node table
(len(nodes.index), len(nodes.columns))

(10034, 5)

In [10]:
# Map each numeric semantic_id to its semantic type name (e.g. 5 -> 'ORTH').
semantic_pairs = nodes[['semantic', 'semantic_id']].drop_duplicates()
# The id -> name mapping must be unambiguous. Otherwise to_dict() below would
# silently keep only one of the conflicting names for an id.
assert semantic_pairs['semantic_id'].is_unique, 'a semantic_id maps to several semantic names'
node_semantics = semantic_pairs.set_index('semantic_id').to_dict()
node_semantics_dict = node_semantics['semantic']
node_semantics_dict

{5: 'ORTH',
 1: 'DISO',
 2: 'DRUG',
 4: 'GENO',
 7: 'VARI',
 3: 'GENE',
 0: 'ANAT',
 6: 'PHYS'}

Load the edges

In [11]:
# Load the indexed edge list for the selected dataset, and translate the numeric
# class ids of both endpoints into semantic type names (e.g. 5 -> 'ORTH').
# Missing relation labels are filled with 'na', so every edge still gets a
# (head type, relation, tail type) key below. Filling through DataFrame.fillna
# with a per-column dict replaces the chained
# `edges['relation'].fillna(..., inplace=True)`, which is deprecated and does
# not modify `edges` under pandas copy-on-write.
edges = (
    pd.read_csv(f'output/indexed_edges_{dataset_nr}.csv')
    .replace({'class_head': node_semantics_dict, 'class_tail': node_semantics_dict})
    .fillna({'relation': 'na'})
)
edges

Unnamed: 0,head,label_head,class_head,index_head,relation,tail,label_tail,class_tail,index_tail,type
0,ZFIN:ZDB-GENE-050626-112,myl4,ORTH,5279,in orthology relationship with,FlyBase:FBgn0085464,CG34435,ORTH,6825,0
1,ZFIN:ZDB-GENE-050626-112,myl4,ORTH,5279,in orthology relationship with,HGNC:7585,MYL4,GENE,27,0
2,ZFIN:ZDB-GENE-050626-112,myl4,ORTH,5279,in orthology relationship with,FlyBase:FBgn0002772,Mlc1,ORTH,8901,0
3,ZFIN:ZDB-GENE-050626-112,myl4,ORTH,5279,in orthology relationship with,NCBIGene:396472,MYL4,GENE,9508,0
4,ZFIN:ZDB-GENE-050626-112,myl4,ORTH,5279,in 1 to 1 orthology relationship with,ENSEMBL:ENSECAG00000020967,ENSEMBL:ENSECAG00000020967,ORTH,8807,1
...,...,...,...,...,...,...,...,...,...,...
82908,4810,ibrutinib,DRUG,1618,targets,HGNC:11283,SRC,GENE,3279,14
82909,522,carvedilol,DRUG,184,targets,HGNC:620,APP,GENE,547,14
82910,OMIM:300377.0013,"DMD, EX18DEL",DISO,2822,is allele of,HGNC:2928,DMD,GENE,6612,16
82911,Coriell:GM05113,NIGMS-GM05113,GENO,8105,has role in modeling,MONDO:0010679,Duchenne muscular dystrophy,DISO,6315,15


# Build the metapath list and heterogeneous graph for MetaPath2Vec

In [12]:
# Every distinct (head type, relation, tail type) edge type, in first-seen order,
# with a fresh 0..n-1 index.
metapath_df = (
    edges[['class_head', 'relation', 'class_tail']]
    .drop_duplicates()
    .reset_index(drop=True)
)
metapath_df

Unnamed: 0,class_head,relation,class_tail
0,ORTH,in orthology relationship with,ORTH
1,ORTH,in orthology relationship with,GENE
2,ORTH,in 1 to 1 orthology relationship with,ORTH
3,ORTH,in 1 to 1 orthology relationship with,GENO
4,ORTH,expressed in,ANAT
...,...,...,...
66,GENE,is marker for,DISO
67,DRUG,na,DISO
68,ORTH,has affected feature,ORTH
69,ORTH,contributes to condition,DISO


In [13]:
# Convert each row of metapath_df into a plain (src type, relation, dst type) tuple.
metapaths = list(metapath_df.itertuples(index=False, name=None))

metapaths

[('ORTH', 'in orthology relationship with', 'ORTH'),
 ('ORTH', 'in orthology relationship with', 'GENE'),
 ('ORTH', 'in 1 to 1 orthology relationship with', 'ORTH'),
 ('ORTH', 'in 1 to 1 orthology relationship with', 'GENO'),
 ('ORTH', 'expressed in', 'ANAT'),
 ('ORTH', 'in 1 to 1 orthology relationship with', 'GENE'),
 ('ORTH', 'is part of', 'PHYS'),
 ('ORTH', 'has phenotype', 'DISO'),
 ('ORTH', 'enables', 'PHYS'),
 ('ORTH', 'interacts with', 'ORTH'),
 ('ORTH', 'involved in', 'PHYS'),
 ('ORTH', 'colocalizes with', 'ORTH'),
 ('ORTH', 'in orthology relationship with', 'GENO'),
 ('GENE', 'has phenotype', 'DISO'),
 ('GENE', 'is part of', 'PHYS'),
 ('GENE', 'expressed in', 'ANAT'),
 ('GENE', 'enables', 'PHYS'),
 ('GENE', 'interacts with', 'GENE'),
 ('GENE', 'involved in', 'PHYS'),
 ('GENE', 'in orthology relationship with', 'ORTH'),
 ('GENE', 'in 1 to 1 orthology relationship with', 'ORTH'),
 ('GENE', 'interacts with', 'ORTH'),
 ('GENE', 'in 1 to 1 orthology relationship with', 'GENO'),
 (

In [14]:
# Register one edge type in the heterogeneous graph per (head type, relation,
# tail type) triplet. Row 0 of edge_index holds the head index_ids and row 1
# holds the tail index_ids.
data = HeteroData()

for src_node_type, rel_type, dst_node_type in metapaths:
    is_this_edge_type = (
        edges['class_head'].eq(src_node_type)
        & edges['relation'].eq(rel_type)
        & edges['class_tail'].eq(dst_node_type)
    )
    selected_edges = edges[is_this_edge_type]
    heads = selected_edges['index_head'].tolist()
    tails = selected_edges['index_tail'].tolist()
    data[src_node_type, rel_type, dst_node_type].edge_index = torch.tensor(
        [heads, tails], dtype=torch.long
    )

data.edge_index_dict

{('ORTH',
  'in orthology relationship with',
  'ORTH'): tensor([[5279, 5279, 5279,  ..., 8274, 8274, 8683],
         [6825, 8901,  753,  ..., 7960, 4612, 3455]]),
 ('ORTH',
  'in orthology relationship with',
  'GENE'): tensor([[5279, 5279, 1869,  ..., 9121, 6818, 5960],
         [  27, 9508, 6980,  ..., 2331, 2971, 7581]]),
 ('ORTH',
  'in 1 to 1 orthology relationship with',
  'ORTH'): tensor([[5279, 5279, 5279,  ..., 8683, 8683, 8683],
         [8807, 6449,  904,  ..., 6576, 7251, 8175]]),
 ('ORTH',
  'in 1 to 1 orthology relationship with',
  'GENO'): tensor([[ 5279,  5279,  8727,  9633,  9633,  7771,  7771,  3230,  9038,  8022,
           8022,  7964,  3545,  3545,  9377,  9913,  9913,  5613,  1935,  1935,
           1857,  7119,  1496,  2298,  4919,  8596,  8596,  4474,  4474,  1836,
           1836,  7874,  7874,   333,  7853,  7853,  9796,  9283,  9283,  8418,
           4207,  4207,  2140,  1013,  1013,  8323,  8323,  5628,  6103,   859,
           9917,  2137,  2137,  1900, 

In [15]:
# For every node type, store the index_ids of its nodes. These ids are what is
# later passed to the model to look up that type's embeddings.
node_types = list(node_semantics_dict.values())

for node_type in node_types:
    type_mask = nodes['semantic'] == node_type
    type_index_ids = nodes.loc[type_mask, 'index_id'].tolist()
    data[node_type].y_index = torch.tensor(type_index_ids, dtype=torch.long)

data

HeteroData(
  [1mORTH[0m={ y_index=[2880] },
  [1mDISO[0m={ y_index=[5146] },
  [1mDRUG[0m={ y_index=[202] },
  [1mGENO[0m={ y_index=[409] },
  [1mVARI[0m={ y_index=[1125] },
  [1mGENE[0m={ y_index=[202] },
  [1mANAT[0m={ y_index=[20] },
  [1mPHYS[0m={ y_index=[50] },
  [1m(ORTH, in orthology relationship with, ORTH)[0m={ edge_index=[2, 18418] },
  [1m(ORTH, in orthology relationship with, GENE)[0m={ edge_index=[2, 1058] },
  [1m(ORTH, in 1 to 1 orthology relationship with, ORTH)[0m={ edge_index=[2, 24676] },
  [1m(ORTH, in 1 to 1 orthology relationship with, GENO)[0m={ edge_index=[2, 303] },
  [1m(ORTH, expressed in, ANAT)[0m={ edge_index=[2, 412] },
  [1m(ORTH, in 1 to 1 orthology relationship with, GENE)[0m={ edge_index=[2, 2225] },
  [1m(ORTH, is part of, PHYS)[0m={ edge_index=[2, 990] },
  [1m(ORTH, has phenotype, DISO)[0m={ edge_index=[2, 11299] },
  [1m(ORTH, enables, PHYS)[0m={ edge_index=[2, 566] },
  [1m(ORTH, interacts with, ORTH)[0m={ edg

# Train MetaPath2Vec Model

In [16]:
# MetaPath2Vec hyperparameters.
EMBEDDING_DIM = 128
WALK_LENGTH = 50
CONTEXT_SIZE = 7
WALKS_PER_NODE = 5
NUM_NEGATIVE_SAMPLES = 5

# NOTE(review): consecutive entries of `metapaths` do not chain. For example,
# ('ORTH', 'in orthology relationship with', 'GENE') is followed by
# ('ORTH', 'in 1 to 1 orthology relationship with', 'ORTH'). MetaPath2Vec walks
# assume the destination type of each step equals the source type of the next
# step, and recent PyG releases reject such a metapath. TODO: confirm the walks
# sampled here are meaningful.
# NOTE(review): edge_index values are global index_ids, not per-type indices.
# Presumably each type's embedding table is then sized by the largest global id.
# TODO: confirm this is intended.
model = MetaPath2Vec(
    data.edge_index_dict,
    embedding_dim=EMBEDDING_DIM,
    metapath=metapaths,
    walk_length=WALK_LENGTH,
    context_size=CONTEXT_SIZE,
    walks_per_node=WALKS_PER_NODE,
    num_negative_samples=NUM_NEGATIVE_SAMPLES,
    sparse=True,  # sparse embedding gradients, for the SparseAdam optimizer below
).to(torch_device)

# Each loader batch yields a (positive walks, negative walks) pair.
loader = model.loader(batch_size=128, shuffle=True, num_workers=6)
optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)

In [17]:
def train(epoch, log_steps=10):
    """Run one epoch of MetaPath2Vec training over the global `loader`.

    Prints the mean loss of every `log_steps` consecutive batches. Batches in a
    trailing window shorter than `log_steps` are trained on but not reported.

    Note: this definition shadows the `train` imported from gnn.linkpred_model
    in the imports cell.
    """
    model.train()

    running_loss = 0
    for step, (pos_rw, neg_rw) in enumerate(loader, start=1):
        optimizer.zero_grad()
        batch_loss = model.loss(pos_rw.to(torch_device), neg_rw.to(torch_device))
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        if step % log_steps:
            continue
        print(f'Epoch: {epoch}, Step: {step:05d}/{len(loader)}, '
              f'Loss: {running_loss / log_steps:.4f}')
        running_loss = 0

@torch.no_grad()
def get_embedding(node_type):
    """Return the learned embeddings of the nodes of `node_type` as a NumPy array.

    Looks up the model's embeddings for the ids stored in `data[node_type].y_index`.
    Puts the model in eval mode; `train()` switches it back at its start.
    """
    model.eval()
    z = model(node_type, batch=data.y_index_dict[node_type])
    # Copy to host memory before converting: .numpy() fails on tensors that live
    # on a GPU, e.g. when torch_device is set to 'cuda'.
    return z.detach().cpu().numpy()

In [18]:
NUM_EPOCHS = 5

for epoch in range(1, NUM_EPOCHS + 1):
    train(epoch)
    # Show the ORTH embeddings after every epoch to follow training progress.
    emb = get_embedding('ORTH')
    print(f'Epoch: {epoch}, Embedding: {emb.shape}\n{emb}')

Epoch: 1, Step: 00010/79, Loss: 5.5011
Epoch: 1, Step: 00020/79, Loss: 5.2528
Epoch: 1, Step: 00030/79, Loss: 5.0107
Epoch: 1, Step: 00040/79, Loss: 4.7622
Epoch: 1, Step: 00050/79, Loss: 4.5138
Epoch: 1, Step: 00060/79, Loss: 4.2774
Epoch: 1, Step: 00070/79, Loss: 4.0482
Epoch: 1, Embedding: (2880, 128)
[[ 0.71128476 -0.9895405  -0.10159381 ... -1.0421782   1.7574248
  -1.092581  ]
 [ 0.87899595  1.2059216   0.97764945 ...  0.4893869  -0.09185988
   0.02957912]
 [-0.22075064  1.468028    1.7073915  ... -1.130685    2.9690204
  -0.9671127 ]
 ...
 [-0.78040725 -0.9223599  -0.24881792 ... -1.9574789  -0.5536764
   1.006404  ]
 [ 0.36810803  0.6349618   1.3197838  ... -0.13620941  0.10044321
   1.082444  ]
 [ 0.2056579  -0.7279908  -1.1933768  ... -0.98145723  0.47630414
  -0.35820094]]
Epoch: 2, Step: 00010/79, Loss: 3.6031
Epoch: 2, Step: 00020/79, Loss: 3.4004
Epoch: 2, Step: 00030/79, Loss: 3.2070
Epoch: 2, Step: 00040/79, Loss: 3.0174
Epoch: 2, Step: 00050/79, Loss: 2.8479
Epoch: 2, 

# Collect Node Embeddings from Trained Model

In [23]:
# Gather the trained embedding of every node, keyed by its index_id.
data_node_types = data.collect('y_index')
index_emb_dict = {}

for data_node_type, type_y_index in data_node_types.items():
    indices = np.array(type_y_index)
    emb = get_embedding(data_node_type)
    # Pair each index_id with the embedding row in the same position.
    index_emb_dict.update(zip(indices, emb))

In [24]:
# One row per node, ordered by index_id.
metapath2vec_embedding = pd.DataFrame.from_dict(index_emb_dict, orient='index')
metapath2vec_embedding = metapath2vec_embedding.sort_index()
metapath2vec_embedding

Unnamed: 0,0
0,"[0.11802676, -0.5340533, -0.023817837, -0.2197..."
1,"[0.49198022, 0.12364149, -0.055373993, -0.2258..."
2,"[0.57802266, 0.7560725, 0.6505928, 0.06400763,..."
3,"[0.2976379, 0.9546674, 0.019457819, 1.5472721,..."
4,"[-0.23311405, 0.99747473, 1.1609797, -0.263719..."
...,...
10029,"[0.43762165, -0.9704756, 0.48259053, 0.3617342..."
10030,"[-0.3976929, -0.5542564, -0.51212436, -0.11746..."
10031,"[-0.5442572, -0.060033336, 0.2244823, -0.41929..."
10032,"[0.039930098, 0.64827967, 0.03239383, -0.12917..."


In [21]:
# The CSV is written without its index, so a reader can only recover each
# node's index_id from its row position. Check that row position and index_id
# actually coincide (ids contiguous from 0) before the index is dropped.
assert np.array_equal(
    metapath2vec_embedding.index.to_numpy(), np.arange(len(metapath2vec_embedding))
), 'embedding index is not 0..n-1; row order would not match index_id'
metapath2vec_embedding.to_csv(f'output/metapath2vec_embedding_{dataset_nr}.csv', index=False)