In [1]:
import pandas as pd
import numpy as np

import torch

from torch_geometric.nn.models import MetaPath2Vec
from torch_geometric.data import HeteroData

In [2]:
torch_device = 'cpu'
print('Using device:', torch_device)

Using device: cpu


Specify which dataset is used, 1 (original knowledge graph) or 2 (restructured knowledge graph)

In [3]:
dataset_nr = 2
assert dataset_nr == 1 or 2

# Load nodes and edges

Load the nodes

In [4]:
nodes = pd.read_csv(f'output/indexed_nodes_{dataset_nr}.csv')
nodes

Unnamed: 0,index_id,id,semantic,label,semantic_id
0,0,MP:0004187,phenotype,cardia bifida,9
1,1,ZP:0100138,phenotype,muscle tendon junction myotome increased amoun...,9
2,2,MGI:1346525,gene,Sgcd,5
3,3,OMIM:300377.0044,variant,"DMD, LYS770TER",11
4,4,ZP:0002210,phenotype,posterior lateral line neuromast primordium mi...,9
...,...,...,...,...,...
10270,10270,ZP:0014934,phenotype,atrioventricular valve development process qua...,9
10271,10271,ENSEMBL:ENSCAFG00000011207,gene,ENSEMBL:ENSCAFG00000011207,5
10272,10272,ENSEMBL:ENSXETG00000039922,gene,ENSEMBL:ENSXETG00000039922,5
10273,10273,ENSEMBL:ENSACAG00000010058,gene,ENSEMBL:ENSACAG00000010058,5


In [5]:
nodes.shape

(10275, 5)

In [6]:
node_semantics = nodes[['semantic', 'semantic_id']].drop_duplicates().set_index('semantic_id').to_dict()
node_semantics_dict = node_semantics['semantic']
node_semantics_dict

{9: 'phenotype',
 5: 'gene',
 11: 'variant',
 6: 'gene product',
 4: 'drug',
 10: 'taxon',
 0: 'biological artifact',
 7: 'genotype',
 1: 'biological process',
 3: 'disease',
 2: 'cellular component',
 8: 'molecular function'}

Load the edges

In [7]:
edges = pd.read_csv(f'output/indexed_edges_{dataset_nr}.csv')
edges

Unnamed: 0,head,label_head,class_head,index_head,relation,tail,label_tail,class_tail,index_tail,type
0,WormBase:WBGene00006787,unc-52,5,304,interacts with,WormBase:WBGene00003929,pat-2,5,1542,0
1,WormBase:WBGene00006787,unc-52,5,304,interacts with,WormBase:WBGene00006789,unc-54,5,6544,0
2,WormBase:WBGene00006787,unc-52,5,304,in orthology relationship with,ENSEMBL:ENSSSCG00000015555,LAMC1,5,9268,1
3,WormBase:WBGene00006787,unc-52,5,304,in orthology relationship with,ZFIN:ZDB-GENE-021226-3,lamc1,5,5387,1
4,WormBase:WBGene00006787,unc-52,5,304,in orthology relationship with,ENSEMBL:ENSOANG00000001050,ENSEMBL:ENSOANG00000001050,5,2204,1
...,...,...,...,...,...,...,...,...,...,...
85987,458,scopolamine butylbromide,4,5945,targets,P11229,Muscarinic acetylcholine receptor M1,6,5919,17
85988,OMIM:300377.0080,"DMD, IVS62, A-G, -285",11,1578,is allele of,HGNC:2928,DMD,5,3310,15
85989,5297,dacomitinib,4,8798,targets,P12931,Proto-oncogene tyrosine-protein kinase Src,6,2379,17
85990,ClinVarVariant:981988,NC_000023.11:g.(31875374_31929595)_(31968515_3...,11,8189,has affected feature,HGNC:2928,DMD,5,3310,11


In [8]:
edges.replace({'class_head': node_semantics_dict, 'class_tail': node_semantics_dict}, inplace=True)
edges['relation'].fillna('na', inplace=True)
edges

Unnamed: 0,head,label_head,class_head,index_head,relation,tail,label_tail,class_tail,index_tail,type
0,WormBase:WBGene00006787,unc-52,gene,304,interacts with,WormBase:WBGene00003929,pat-2,gene,1542,0
1,WormBase:WBGene00006787,unc-52,gene,304,interacts with,WormBase:WBGene00006789,unc-54,gene,6544,0
2,WormBase:WBGene00006787,unc-52,gene,304,in orthology relationship with,ENSEMBL:ENSSSCG00000015555,LAMC1,gene,9268,1
3,WormBase:WBGene00006787,unc-52,gene,304,in orthology relationship with,ZFIN:ZDB-GENE-021226-3,lamc1,gene,5387,1
4,WormBase:WBGene00006787,unc-52,gene,304,in orthology relationship with,ENSEMBL:ENSOANG00000001050,ENSEMBL:ENSOANG00000001050,gene,2204,1
...,...,...,...,...,...,...,...,...,...,...
85987,458,scopolamine butylbromide,drug,5945,targets,P11229,Muscarinic acetylcholine receptor M1,gene product,5919,17
85988,OMIM:300377.0080,"DMD, IVS62, A-G, -285",variant,1578,is allele of,HGNC:2928,DMD,gene,3310,15
85989,5297,dacomitinib,drug,8798,targets,P12931,Proto-oncogene tyrosine-protein kinase Src,gene product,2379,17
85990,ClinVarVariant:981988,NC_000023.11:g.(31875374_31929595)_(31968515_3...,variant,8189,has affected feature,HGNC:2928,DMD,gene,3310,11


Below edge shuffle is necessary to avoid any errors with dataloader. TODO

In [9]:
if dataset_nr == 1:
    node_type_first = 'ORTH'
else:
    node_type_first = 'genotype'

matching_indices = edges.index[edges['class_head'] == node_type_first].tolist()

index_1 = 0
index_2 = matching_indices[0]
print('Switch rows with indices', index_1, 'and', index_2)

temp_row = edges.iloc[index_1].copy()
edges.iloc[index_1] = edges.iloc[index_2]
edges.iloc[index_2] = temp_row
        
edges

Switch rows with indices 0 and 898


Unnamed: 0,head,label_head,class_head,index_head,relation,tail,label_tail,class_tail,index_tail,type
0,MGI:3621491,Dmd<mdx>/Dmd<mdx> [C57BL/10ScSn-Dmd<mdx>/J],genotype,5796,causes condition,MP:0011966,abnormal auditory brainstem response waveform ...,phenotype,8236,2
1,WormBase:WBGene00006787,unc-52,gene,304,interacts with,WormBase:WBGene00006789,unc-54,gene,6544,0
2,WormBase:WBGene00006787,unc-52,gene,304,in orthology relationship with,ENSEMBL:ENSSSCG00000015555,LAMC1,gene,9268,1
3,WormBase:WBGene00006787,unc-52,gene,304,in orthology relationship with,ZFIN:ZDB-GENE-021226-3,lamc1,gene,5387,1
4,WormBase:WBGene00006787,unc-52,gene,304,in orthology relationship with,ENSEMBL:ENSOANG00000001050,ENSEMBL:ENSOANG00000001050,gene,2204,1
...,...,...,...,...,...,...,...,...,...,...
85987,458,scopolamine butylbromide,drug,5945,targets,P11229,Muscarinic acetylcholine receptor M1,gene product,5919,17
85988,OMIM:300377.0080,"DMD, IVS62, A-G, -285",variant,1578,is allele of,HGNC:2928,DMD,gene,3310,15
85989,5297,dacomitinib,drug,8798,targets,P12931,Proto-oncogene tyrosine-protein kinase Src,gene product,2379,17
85990,ClinVarVariant:981988,NC_000023.11:g.(31875374_31929595)_(31968515_3...,variant,8189,has affected feature,HGNC:2928,DMD,gene,3310,11


In [10]:
edges.iloc[index_2]

head          WormBase:WBGene00006787
label_head                     unc-52
class_head                       gene
index_head                        304
relation               interacts with
tail          WormBase:WBGene00003929
label_tail                      pat-2
class_tail                       gene
index_tail                       1542
type                                0
Name: 898, dtype: object

# Generate data formats

In [11]:
metapath_df = edges[['class_head', 'relation', 'class_tail']].drop_duplicates().reset_index().drop(columns=['index'])
metapath_df

Unnamed: 0,class_head,relation,class_tail
0,genotype,causes condition,phenotype
1,gene,interacts with,gene
2,gene,in orthology relationship with,gene
3,gene,causes condition,phenotype
4,gene,involved in,biological process
5,gene,found in,taxon
6,gene,is part of,cellular component
7,gene,enables,molecular function
8,gene,colocalizes with,gene
9,gene,has role in modeling,disease


In [12]:
metapaths = list()

metapath_array = metapath_df.to_records(index=False)
metapath_array = list(metapath_array)
for metapath_triplet in metapath_array:
    triplet = tuple(metapath_triplet)
    metapaths.append(triplet)
    
metapaths

[('genotype', 'causes condition', 'phenotype'),
 ('gene', 'interacts with', 'gene'),
 ('gene', 'in orthology relationship with', 'gene'),
 ('gene', 'causes condition', 'phenotype'),
 ('gene', 'involved in', 'biological process'),
 ('gene', 'found in', 'taxon'),
 ('gene', 'is part of', 'cellular component'),
 ('gene', 'enables', 'molecular function'),
 ('gene', 'colocalizes with', 'gene'),
 ('gene', 'has role in modeling', 'disease'),
 ('gene', 'contributes to condition', 'phenotype'),
 ('genotype', 'has role in modeling', 'disease'),
 ('genotype', 'expresses gene', 'gene'),
 ('variant', 'has affected feature', 'gene'),
 ('biological artifact', 'has role in modeling', 'disease'),
 ('biological artifact', 'is of', 'taxon'),
 ('gene', 'contributes to condition', 'disease'),
 ('variant', 'causes condition', 'disease'),
 ('variant', 'likely causes condition', 'disease'),
 ('variant', 'causes condition', 'phenotype'),
 ('variant', 'is variant in', 'genotype'),
 ('variant', 'has role in model

In [13]:
len(metapaths)

37

In [14]:
data = HeteroData() # Data object describing a heterogeneous graph holding multiple node and edge types

all_nodes = set()   # To check which nodes are included in the data
edges_count = 0

for metapath in metapaths:
    src_node_type, rel_type, dst_node_type = metapath
    metapath_edges = edges.loc[(edges['class_head'] == src_node_type) & (edges['relation'] == rel_type) & (edges['class_tail'] == dst_node_type)]
    
    edges_count += metapath_edges.shape[0]
    
    metapath_edge_head = metapath_edges['index_head'].values.tolist()
    metapath_edge_tail = metapath_edges['index_tail'].values.tolist()
    
    all_nodes.update(metapath_edge_head)
    all_nodes.update(metapath_edge_tail)
    
    metapath_edge_index = [metapath_edge_head, metapath_edge_tail]
    metapath_edge_index = torch.LongTensor(metapath_edge_index)
    
    data[src_node_type, rel_type, dst_node_type].edge_index = metapath_edge_index

print(edges_count)

85992


In [15]:
data.edge_index_dict

{('genotype',
  'causes condition',
  'phenotype'): tensor([[ 5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,
           5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,
           5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,
           5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,
           5796,  5796,   472,   472,   472,   472,   472,   472,   472,   472,
            780,   780,   780,   780,   780,   780,   780,   780,   780,   780,
            780,  5821,  5821,  5821,  5821,  5821,  5821,  5821,  5821,  5821,
           5821,  5821,  5821,  5821,  5821,  5821,  5821,  5821,  5821,  5821,
           5821,  5821,  5821,  8308,  8308,  8308,  8308,  8308,  8308,  8308,
           8308,  8308,  8308,  8308,  8308,  8308,  8308,  1168,  1168,  1168,
           1168,  1168,  1168,  1168,  1168,  1168,  1168,  1168,  8448,  8448,
           8448,  8448,  8448,  8448,  8448,  8448,  8448,  8448,  84

In [16]:
len(all_nodes)

10275

In [17]:
max(all_nodes)

10274

In [18]:
node_types = list(node_semantics_dict.values())
node_types

['phenotype',
 'gene',
 'variant',
 'gene product',
 'drug',
 'taxon',
 'biological artifact',
 'genotype',
 'biological process',
 'disease',
 'cellular component',
 'molecular function']

In [19]:
node_count = 0

for node_type in node_types:
    all_nodes = nodes.loc[nodes['semantic'] == node_type]
    
    node_count += all_nodes.shape[0]
    
    node_index_list = all_nodes['index_id'].values.tolist()
    node_index = torch.LongTensor(node_index_list)
    
    data[node_type].y_index = node_index
    data[node_type].num_nodes = len(node_index_list)

node_count

10275

In [20]:
data

HeteroData(
  [1mphenotype[0m={
    y_index=[5311],
    num_nodes=5311
  },
  [1mgene[0m={
    y_index=[3163],
    num_nodes=3163
  },
  [1mvariant[0m={
    y_index=[1277],
    num_nodes=1277
  },
  [1mgene product[0m={
    y_index=[38],
    num_nodes=38
  },
  [1mdrug[0m={
    y_index=[291],
    num_nodes=291
  },
  [1mtaxon[0m={
    y_index=[26],
    num_nodes=26
  },
  [1mbiological artifact[0m={
    y_index=[71],
    num_nodes=71
  },
  [1mgenotype[0m={
    y_index=[36],
    num_nodes=36
  },
  [1mbiological process[0m={
    y_index=[24],
    num_nodes=24
  },
  [1mdisease[0m={
    y_index=[12],
    num_nodes=12
  },
  [1mcellular component[0m={
    y_index=[17],
    num_nodes=17
  },
  [1mmolecular function[0m={
    y_index=[9],
    num_nodes=9
  },
  [1m(genotype, causes condition, phenotype)[0m={ edge_index=[2, 460] },
  [1m(gene, interacts with, gene)[0m={ edge_index=[2, 7813] },
  [1m(gene, in orthology relationship with, gene)[0m={ edge_index=[2,

In [21]:
data.num_nodes_dict

{'phenotype': 5311,
 'gene': 3163,
 'variant': 1277,
 'gene product': 38,
 'drug': 291,
 'taxon': 26,
 'biological artifact': 71,
 'genotype': 36,
 'biological process': 24,
 'disease': 12,
 'cellular component': 17,
 'molecular function': 9}

In [22]:
data.num_nodes

10275

Below the method of creating the `num_nodes_dict` for the given `edge_index_dict` as used in the `MetaPatch2Vec` class. \
Source: https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/metapath2vec.html#MetaPath2Vec.

In [35]:
num_nodes_dict = {}
for keys, edge_index in data.edge_index_dict.items():
    key = keys[0]   # head node type per metapath
    N = int(edge_index[0].max() + 1)    # head node of edge, maximum index + 1
    num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N))    # get maximum index found for node type

    key = keys[-1]  # tail node type per metapath
    N = int(edge_index[1].max() + 1)    # tail node of edge, maximum index + 1
    num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N))    # get maximum index found for node type
        
print(num_nodes_dict)

{'genotype': 10264, 'phenotype': 10275, 'gene': 10274, 'biological process': 9624, 'taxon': 10185, 'cellular component': 9480, 'molecular function': 9380, 'disease': 9796, 'variant': 10265, 'biological artifact': 10256, 'drug': 10251, 'gene product': 10003}


# Train MetaPath2Vec Model

In [36]:
data.edge_index_dict

{('genotype',
  'causes condition',
  'phenotype'): tensor([[ 5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,
           5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,
           5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,
           5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,  5796,
           5796,  5796,   472,   472,   472,   472,   472,   472,   472,   472,
            780,   780,   780,   780,   780,   780,   780,   780,   780,   780,
            780,  5821,  5821,  5821,  5821,  5821,  5821,  5821,  5821,  5821,
           5821,  5821,  5821,  5821,  5821,  5821,  5821,  5821,  5821,  5821,
           5821,  5821,  5821,  8308,  8308,  8308,  8308,  8308,  8308,  8308,
           8308,  8308,  8308,  8308,  8308,  8308,  8308,  1168,  1168,  1168,
           1168,  1168,  1168,  1168,  1168,  1168,  1168,  1168,  8448,  8448,
           8448,  8448,  8448,  8448,  8448,  8448,  8448,  8448,  84

In [37]:
if dataset_nr == 1:
    epochs = 10
    walks_per_node = 5
    embedding_dim = 32
    context_size = 7
    walk_length = 50
    lr = 0.004935371639375551
else:
    epochs = 5
    walks_per_node = 2
    walk_length = 35
    context_size = 7
    embedding_dim = 64
    lr = 0.010631597403622543

model = MetaPath2Vec(data.edge_index_dict, embedding_dim=embedding_dim,
                     metapath=metapaths, walk_length=walk_length, context_size=context_size,
                     walks_per_node=walks_per_node, num_negative_samples=5,
                     sparse=True).to(torch_device)

loader = model.loader(batch_size=128, shuffle=True, num_workers=8)
optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=lr)

In [38]:
def train(epoch, log_steps=10):
    model.train()

    total_loss = 0
    for i, (pos_rw, neg_rw) in enumerate(loader):
        optimizer.zero_grad()
        loss = model.loss(pos_rw.to(torch_device), neg_rw.to(torch_device))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        if (i + 1) % log_steps == 0:
            print((f'Epoch: {epoch}, Step: {i + 1:05d}/{len(loader)}, '
                   f'Loss: {total_loss / log_steps:.4f}'))
            total_loss = 0

@torch.no_grad()
def get_embedding(node_type):
    model.eval()
    z = model(node_type, batch=data.y_index_dict[node_type]).detach().numpy()
    return z

In [39]:
for epoch in range(0, epochs):
    train(epoch)
    emb = get_embedding(node_semantics_dict[0])
    print(f'Epoch: {epoch}, Embedding: {emb.shape}\n{emb}')

Epoch: 0, Step: 00010/81, Loss: 3.6279
Epoch: 0, Step: 00020/81, Loss: 3.5152
Epoch: 0, Step: 00030/81, Loss: 3.3965
Epoch: 0, Step: 00040/81, Loss: 3.2833
Epoch: 0, Step: 00050/81, Loss: 3.1596
Epoch: 0, Step: 00060/81, Loss: 3.0397
Epoch: 0, Step: 00070/81, Loss: 2.9221
Epoch: 0, Step: 00080/81, Loss: 2.8069
Epoch: 0, Embedding: (71, 64)
[[-1.0707939  -1.5289599   0.76332736 ...  1.2211771   1.1570678
  -0.8539459 ]
 [ 0.7744395  -0.11110125 -0.49265254 ... -0.42844203  0.9727086
  -0.12002938]
 [-0.71456265  0.7723284   0.34724465 ...  0.6114247  -0.16503572
  -1.3067291 ]
 ...
 [ 0.74517983  1.2139555  -0.03283018 ...  0.22727582  1.6879668
   0.47272813]
 [-0.29942197 -2.2087553   0.40289217 ...  0.48066884 -1.4345055
   0.11740419]
 [ 1.2846439   2.6278389  -1.6315192  ... -1.602601   -0.7813209
   1.1551247 ]]
Epoch: 1, Step: 00010/81, Loss: 2.6748
Epoch: 1, Step: 00020/81, Loss: 2.5820
Epoch: 1, Step: 00030/81, Loss: 2.4791
Epoch: 1, Step: 00040/81, Loss: 2.3802
Epoch: 1, Step:

# Collect Node Embeddings from Trained Model

In [40]:
data_node_types = data.collect('y_index')
index_emb_dict = {}

for data_node_type in list(data_node_types.keys()):
    indices = np.array(data[data_node_type].y_index) 
    emb = get_embedding(data_node_type)
    
    for index, embedding in zip(indices, emb):
        index_emb_dict[index] = embedding

In [41]:
metapath2vec_embedding = pd.DataFrame.from_dict(index_emb_dict, orient='index').sort_index()
metapath2vec_embedding

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,54,55,56,57,58,59,60,61,62,63
0,-0.180537,0.170517,0.156036,0.059366,0.039686,0.770157,0.296236,-0.108639,-0.123995,-0.222156,...,-0.251923,-0.354369,-0.153597,-0.234457,-0.295879,-0.117370,0.188819,-0.246503,-0.223934,-0.326248
1,0.166972,0.479243,0.090037,0.415786,0.310327,0.180114,-0.316574,0.276626,0.642796,-0.170525,...,-0.572297,-0.267137,0.309202,0.403249,-0.329871,0.049839,-0.138035,-0.216177,-0.160851,0.138391
2,-0.118998,0.247875,-0.003196,-0.127593,0.217997,0.522361,-0.334698,-0.403308,-0.264037,-0.605411,...,0.261394,-0.060682,0.162056,0.231253,-0.287371,0.414680,0.607859,0.153862,-0.066283,0.258656
3,-0.117873,-0.638776,0.715720,-1.488765,1.139481,0.537108,-1.016577,-0.472169,-0.714616,0.245649,...,0.796758,-0.834914,-0.050803,0.204769,0.394310,0.166139,-0.316913,-0.519357,0.405025,-0.726510
4,-0.116238,-0.431725,-0.701060,-0.101868,-0.320842,0.756190,-0.296119,-0.160058,-0.027608,0.225885,...,0.039230,-0.414010,0.310657,0.359817,0.028379,-0.505937,-0.037625,-0.142428,0.314041,0.522135
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10270,0.146269,0.514914,-0.287683,0.211831,-0.964015,0.776288,0.281142,0.405423,-0.070531,0.078691,...,-0.179447,-0.044464,-0.145146,0.409061,-0.201076,-0.300863,-0.069612,-0.265975,0.219710,-0.149174
10271,-0.092208,-0.288510,-0.047735,0.977450,0.672737,-0.565982,-0.017390,0.195413,0.605879,-0.072216,...,-0.488707,0.382078,0.397737,-0.086307,0.633217,0.029278,0.169720,-0.153595,-0.340175,-0.012609
10272,1.015316,0.785766,0.356682,0.097771,0.243979,-0.025777,-0.220744,0.416585,0.051846,0.344864,...,-0.252412,0.053219,0.737996,-0.710615,-0.385023,0.260013,0.060839,-0.666456,0.258599,0.410538
10273,1.484033,-0.250599,-0.031319,-0.283997,0.425242,-0.080822,0.223526,0.071684,-0.320865,0.709390,...,-0.394961,-0.266571,-0.500837,0.021376,0.343995,-0.494114,1.222481,-0.028218,-0.052168,0.549186


In [42]:
metapath2vec_embedding.to_csv(f'output/metapath2vec_embedding_{dataset_nr}.csv', index=False)