In [1]:
import os
import math
import copy
import pickle

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import dgl

from dataset_construction.utils import preprocess_kg, create_split, process_disease_area_split, create_dgl_graph, evaluate_graph_construct, convert2str, data_download_wrapper

import warnings
# Blanket filter: every warning category is ignored from this point on, so any
# warning raised by the cells below will not be shown.
warnings.filterwarnings("ignore")

In [3]:
# Folder that holds the knowledge-graph CSV files.
kg_dir = '/data/pj20/txgnn/kg/'

# Read the raw edge list and keep only the five columns that describe an edge:
# source type/id, relation name, target type/id.
kg_csv_path = kg_dir + 'kg.csv'
df = pd.read_csv(kg_csv_path).loc[:, ['x_type', 'x_id', 'relation', 'y_type', 'y_id']]

In [5]:
# Sorted array of the distinct relation names present in the edge list.
unique_relation = np.unique(df['relation'].values)
unique_relation

array(['anatomy_anatomy', 'anatomy_protein_absent',
       'anatomy_protein_present', 'bioprocess_bioprocess',
       'bioprocess_protein', 'cellcomp_cellcomp', 'cellcomp_protein',
       'contraindication', 'disease_disease',
       'disease_phenotype_negative', 'disease_phenotype_positive',
       'disease_protein', 'drug_drug', 'drug_effect', 'drug_protein',
       'exposure_bioprocess', 'exposure_cellcomp', 'exposure_disease',
       'exposure_exposure', 'exposure_molfunc', 'exposure_protein',
       'indication', 'molfunc_molfunc', 'molfunc_protein',
       'off-label use', 'pathway_pathway', 'pathway_protein',
       'phenotype_phenotype', 'phenotype_protein', 'protein_protein'],
      dtype=object)

In [13]:
# For every relation, collect the row labels (index values of `df`) of the edges
# to keep, so that each edge is represented once. One list per relation.
undirected_index = []

for rel in tqdm(unique_relation):
    rel_rows = df[df.relation == rel]
    name_parts = rel.split('_')
    if len(name_parts) > 1 and name_parts[0] == name_parts[1]:
        # Relation named "<t>_<t>" (homogeneous graph): a->b and b->a describe
        # the same edge, so build an order-independent key per row and keep
        # only the first row seen for each key.
        # The key lives in its own Series instead of a new column on the
        # filtered slice, so the slice is never written to, and zip() replaces
        # the much slower row-wise DataFrame.apply.
        pair_key = pd.Series(
            ['_'.join(sorted((str(src), str(dst))))
             for src, dst in zip(rel_rows['x_id'], rel_rows['y_id'])],
            index=rel_rows.index,
        )
        undirected_index.append(rel_rows.index[~pair_key.duplicated()].values.tolist())
    else:
        # Any other relation: de-duplicate by direction (of a->b and b->a only
        # one is recorded) by keeping the rows whose source type matches the
        # source type of the relation's first row.
        first_source_type = rel_rows.x_type.iloc[0]
        undirected_index.append(rel_rows[rel_rows.x_type == first_source_type].index.values.tolist())

100%|██████████| 30/30 [00:32<00:00,  1.08s/it]


In [18]:
# Flatten the per-relation lists of retained row labels into a single list.
flat_list = []
for kept_labels in undirected_index:
    flat_list.extend(kept_labels)

# Restrict the edge list to the retained rows.
df = df.loc[df.index.isin(flat_list)]

# Sorted distinct node types that occur at either end of an edge.
unique_node_types = np.unique(np.append(df.x_type.values, df.y_type.values))

In [19]:
# Show the node types found in the filtered edge list.
unique_node_types

array(['anatomy', 'biological_process', 'cellular_component', 'disease',
       'drug', 'effect/phenotype', 'exposure', 'gene/protein',
       'molecular_function', 'pathway'], dtype=object)

In [22]:
# Inspect the current edge table.
df

Unnamed: 0,x_type,x_id,relation,y_type,y_id,x_idx,y_idx
0,gene/protein,9796.0,protein_protein,gene/protein,56992.0,,
1,gene/protein,7918.0,protein_protein,gene/protein,9240.0,,
2,gene/protein,8233.0,protein_protein,gene/protein,23548.0,,
3,gene/protein,4899.0,protein_protein,gene/protein,11253.0,,
4,gene/protein,5297.0,protein_protein,gene/protein,8601.0,,
...,...,...,...,...,...,...,...
5949711,disease,16982.0,disease_phenotype_positive,effect/phenotype,30448.0,,
5949712,disease,19314_19023_2726,disease_phenotype_positive,effect/phenotype,8069.0,,
5949713,disease,19314_19023_2726,disease_phenotype_positive,effect/phenotype,100495.0,,
5949714,disease,4747.0,disease_phenotype_positive,effect/phenotype,202.0,,


In [21]:
# Placeholder columns for the per-type integer node indices; they start as NaN.
df['x_idx'] = np.nan
df['y_idx'] = np.nan

# Pass every source/target id through convert2str.
for id_col in ('x_id', 'y_id'):
    df[id_col] = df[id_col].apply(convert2str)


In [27]:
# For each node type, map every node id to a contiguous integer index and write
# those indices into the x_idx / y_idx columns. idx_map keeps the mapping per type.
idx_map = {}
for node_type in tqdm(unique_node_types):
    as_source = df.x_type == node_type
    as_target = df.y_type == node_type

    # Sorted distinct ids of this type seen on either side of an edge.
    names = np.unique(np.append(df.loc[as_source, 'x_id'].values, df.loc[as_target, 'y_id'].values))
    names2idx = {name: position for position, name in enumerate(names)}

    df.loc[as_source, 'x_idx'] = df.loc[as_source, 'x_id'].apply(names2idx.__getitem__)
    df.loc[as_target, 'y_idx'] = df.loc[as_target, 'y_id'].apply(names2idx.__getitem__)
    idx_map[node_type] = names2idx

100%|██████████| 10/10 [00:18<00:00,  1.87s/it]


In [39]:
# Spot check: all edges whose source is the gene/protein node with index 1.
is_protein_source = df.x_type == "gene/protein"
df_test = df[is_protein_source & (df.x_idx == 1.0)]

In [40]:
# Show the rows selected for the spot check.
df_test

Unnamed: 0,x_type,x_id,relation,y_type,y_id,x_idx,y_idx
21007,gene/protein,10.0,protein_protein,gene/protein,8493.0,1.0,25662.0
25217,gene/protein,10.0,protein_protein,gene/protein,6790.0,1.0,22307.0
39248,gene/protein,10.0,protein_protein,gene/protein,6794.0,1.0,22310.0
49356,gene/protein,10.0,protein_protein,gene/protein,8930.0,1.0,26207.0
50145,gene/protein,10.0,protein_protein,gene/protein,25818.0,1.0,10753.0
...,...,...,...,...,...,...,...
4706870,gene/protein,10.0,anatomy_protein_present,anatomy,2108.0,1.0,5872.0
4706871,gene/protein,10.0,anatomy_protein_present,anatomy,2113.0,1.0,5878.0
4706872,gene/protein,10.0,anatomy_protein_present,anatomy,2114.0,1.0,5879.0
5376898,gene/protein,10.0,anatomy_protein_absent,anatomy,1377.0,1.0,2220.0


In [42]:
# Load the train / valid / test edge tables, in that order.
split_paths = (
    '/data/pj20/txgnn/kg/full_graph_42/train.csv',
    '/data/pj20/txgnn/kg/full_graph_42/valid.csv',
    '/data/pj20/txgnn/kg/full_graph_42/test.csv',
)
df_train, df_valid, df_test = (pd.read_csv(path) for path in split_paths)

In [46]:
# Step 1 of building the DGL graph: list each distinct
# (source type, relation, target type) triple found in the training edges.
triple_cols = ['x_type', 'relation', 'y_type']
unique_graph = df_train.drop_duplicates(subset=triple_cols)[triple_cols]
unique_graph


Unnamed: 0,x_type,relation,y_type
0,gene/protein,protein_protein,gene/protein
305021,drug,drug_protein,gene/protein
329391,drug,contraindication,disease
358532,drug,indication,disease
367451,drug,off-label use,disease
369891,drug,drug_drug,drug
1639389,gene/protein,phenotype_protein,effect/phenotype
1642553,effect/phenotype,phenotype_phenotype,effect/phenotype
1660352,disease,disease_phenotype_negative,effect/phenotype
1661485,disease,disease_phenotype_positive,effect/phenotype


In [54]:
# Build the mapping handed to dgl.heterograph:
#   (source type, relation, target type) -> (source index array, target index array)
DGL_input = {}
for x_type, relation, y_type in unique_graph.values:
    # Select on the full triple rather than the relation name alone. If one
    # relation name were shared by several type pairs, filtering by name only
    # would put the indices of every pair into each of those edge types.
    in_edge_type = (
        (df_train.x_type == x_type)
        & (df_train.relation == relation)
        & (df_train.y_type == y_type)
    )
    endpoints = df_train.loc[in_edge_type, ['x_idx', 'y_idx']].values.T
    DGL_input[(x_type, relation, y_type)] = (endpoints[0].astype(int), endpoints[1].astype(int))

In [61]:
# Largest node index seen per node type, on the source side and on the target side.
temp = dict(df.groupby('x_type')['x_idx'].max())
temp2 = dict(df.groupby('y_type')['y_idx'].max())
# Insert a zero placeholder only when 'effect/phenotype' never occurs as a
# source type. Overwriting unconditionally would throw away its real
# source-side maximum and could under-count the nodes of that type.
temp.setdefault('effect/phenotype', 0.0)

# Merge both sides: for each node type keep the larger of the two maxima.
output = {}
for d in (temp, temp2):
    for k, v in d.items():
        output[k] = max(output.get(k, float('-inf')), v)


In [63]:
# Inspect the per-edge-type (source indices, target indices) arrays.
DGL_input

{('gene/protein',
  'protein_protein',
  'gene/protein'): (array([27422, 23886, 24822, ..., 23283, 18775,  7948]), array([19536, 26764, 10205, ...,  8412, 15799, 25835])),
 ('drug',
  'drug_protein',
  'gene/protein'): (array([5810, 5810, 5819, ..., 7705, 6020, 6020]), array([ 9202,  9200, 13100, ..., 18963, 14238, 14254])),
 ('drug',
  'contraindication',
  'disease'): (array([3448, 3448, 7534, ...,   73,  335, 1178]), array([12675,  1569, 12675, ..., 10809,   311,  1484])),
 ('drug',
  'indication',
  'disease'): (array([ 478,  478, 6458, ...,   34,   34,  948]), array([12675,  1569, 12675, ...,  8492, 16381, 14069])),
 ('drug',
  'off-label use',
  'disease'): (array([884, 884, 868, ..., 396, 396, 187]), array([12675,  1569, 12675, ..., 14964,  2546,  2546])),
 ('drug',
  'drug_drug',
  'drug'): (array([   0,    4,    5, ..., 5797, 5797, 5797]), array([3979, 3979, 3979, ..., 2642, 6960, 7627])),
 ('gene/protein',
  'phenotype_protein',
  'effect/phenotype'): (array([    0,  8793,  9

In [62]:
# Inspect the largest node index recorded for each node type.
output

{'anatomy': 14032.0,
 'biological_process': 28641.0,
 'cellular_component': 4175.0,
 'disease': 17079.0,
 'drug': 7956.0,
 'effect/phenotype': 15310.0,
 'exposure': 817.0,
 'gene/protein': 27609.0,
 'molecular_function': 11168.0,
 'pathway': 2515.0}

In [64]:
# Indices start at 0, so each type has (largest index + 1) nodes.
node_counts = {ntype: int(max_idx) + 1 for ntype, max_idx in output.items()}
g = dgl.heterograph(DGL_input, num_nodes_dict=node_counts)

In [68]:
import torch

# Number the node types in the order the graph reports them.
node_dict = {ntype: position for position, ntype in enumerate(g.ntypes)}

# Number the edge types the same way and tag every edge with the id of its
# type, stored as a long tensor under the 'id' edge feature.
edge_dict = {}
for etype in g.etypes:
    type_id = len(edge_dict)
    edge_dict[etype] = type_id
    g.edges[etype].data['id'] = torch.full((g.number_of_edges(etype),), type_id, dtype=torch.long)

In [79]:
# Check the 'id' feature of the third edge type (position 2 in g.etypes).
g.edges[g.etypes[2]].data['id']

tensor([2, 2, 2,  ..., 2, 2, 2])

In [80]:
from dataset_construction import TxData as TxDataLoader

# Build the data object from the KG folder and prepare the 'full_graph' split
# with seed 42. The name TxData is bound to the instance, not the class.
TxData = TxDataLoader(data_folder_path='/data/pj20/txgnn/kg/')
TxData.prepare_split(split='full_graph', seed=42, no_kg=False)

Found local copy...
Found local copy...
Found local copy...
Found saved processed KG... Loading...
Splits detected... Loading splits....
Creating DGL graph....
Done!


In [84]:
from txgnn import TxGNN as TxGNNModel

# Create the model wrapper around the prepared data, with tracking switched off.
# The name TxGNN is bound to the instance, not the class.
TxGNN = TxGNNModel(
    data=TxData,
    weight_bias_track=False,
    proj_name='TxGNN',
    exp_name='TxGNN',
)


In [None]:
# Settings passed to model_initialize, gathered in one place for easy tweaking.
model_config = dict(
    n_hid=100,
    n_inp=100,
    n_out=100,
    proto=True,
    proto_num=3,
    attention=False,
    sim_measure='all_nodes_profile',
    bert_measure='disease_name',
    agg_measure='rarity',
    num_walks=200,
    walk_mode='bit',
    path_length=2,
)
TxGNN.model_initialize(**model_config)

In [77]:
# Look up the id assigned to the second edge type (position 1 in g.etypes).
edge_dict[g.etypes[1]]

1