In [1]:
import os
import math
import copy
import pickle

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import dgl

from dataset_construction.utils import preprocess_kg, create_split, process_disease_area_split, create_dgl_graph, evaluate_graph_construct, convert2str, data_download_wrapper

import warnings
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm
Using backend: pytorch


In [2]:
# Directory containing KG.csv.
# (Alternative location used previously: /data/pj20/txgnn/kg/)
kg_dir = './dataset_construction/'

# os.path.join keeps the path valid whether or not kg_dir ends with a separator.
df = pd.read_csv(os.path.join(kg_dir, 'KG.csv'))
# Keep only the edge-list columns this notebook uses.
df = df[['x_type', 'x_id', 'relation', 'y_type', 'y_id']]

In [3]:
def categorize_values(series, n_bins=10):
    """Bin a Series into quantile-based categories labelled 1..k.

    Bin edges are the interior quantiles at 1/n_bins, 2/n_bins, ..., (n_bins-1)/n_bins
    (deciles for the default n_bins=10), padded with -inf and +inf. Tied quantiles
    are collapsed, so fewer than n_bins categories are produced when many values
    are repeated.

    Parameters
    ----------
    series : pandas.Series
        Values to bin. Entries are coerced with ``pd.to_numeric(errors='coerce')`` so
        numbers stored as strings (as a mixed-type CSV column can yield) are binned
        too; entries that cannot be parsed become NaN and stay uncategorised.
    n_bins : int, default 10
        Number of quantile bins requested.

    Returns
    -------
    pandas.Series
        Categorical Series aligned with ``series`` whose labels run from 1 to the
        number of distinct bins.
    """
    values = pd.to_numeric(series, errors='coerce')

    # k / n_bins reproduces the literals 0.1, 0.2, ... exactly for n_bins=10.
    levels = [k / n_bins for k in range(1, n_bins)]
    # dropna: empty / all-NaN input yields NaN quantiles, which pd.cut rejects as edges.
    quantiles = values.quantile(levels).dropna()

    # pd.unique keeps edge order and removes tied quantiles; labels must number len(edges) - 1.
    edges = pd.unique(np.array([-np.inf, *quantiles, np.inf], dtype=float))
    return pd.cut(values, bins=edges, labels=range(1, len(edges)), duplicates='drop')

# Replace every 'value' tail node with its quantile-bin label. Binning happens per
# (relation, y_type) group, so each relation gets its own bin edges.
value_rows = df['y_type'].eq('value')
df.loc[value_rows, 'y_id'] = (
    df.loc[value_rows]
    .groupby(['relation', 'y_type'])['y_id']
    .transform(categorize_values)
)

In [4]:
# Sorted distinct relation names present in the edge list.
unique_relation = np.unique(df['relation'].to_numpy())
unique_relation

array(['closematch', 'contraindication', 'cooccurence_molecule_disease',
       'cooccurence_molecule_gene/protein',
       'cooccurence_molecule_molecule', 'covalent_unit_count',
       'defined_atom_stereo_count', 'defined_bond_stereo_count',
       'drug_drug', 'drug_effect', 'drug_protein', 'exact_mass',
       'has_component', 'has_isotopologue', 'has_parent',
       'has_same_connectivity', 'has_stereoisomer',
       'hydrogen_bond_acceptor_count', 'hydrogen_bond_donor_count',
       'in_pathway', 'indication', 'isotope_atom_count',
       'molecular_weight', 'mono_isotopic_weight', 'neighbor_2d',
       'neighbor_3d', 'non-hydrogen_atom_count', 'off-label use',
       'rotatable_bond_count', 'structure_complexity', 'tautomer_count',
       'to_drug', 'total_formal_charge', 'tpsa', 'type',
       'undefined_atom_stereo_count', 'undefined_bond_stereo_count',
       'xlogp3', 'xlogp3-aa'], dtype=object)

In [5]:
# For each relation, collect the row indices to keep so that each edge is stored once.
undirected_index = []

for rel in tqdm(unique_relation):
    rel_rows = df[df.relation == rel]
    name_parts = rel.split('_')
    if '_' in rel and name_parts[0] == name_parts[1]:
        # Relations named "<type>_<type>" (e.g. drug_drug): (a, b) and (b, a) are the
        # same edge, so build an order-independent key per row and keep the first
        # occurrence of each unordered pair. Equivalent to
        # '_'.join(sorted([str(x), str(y)])), computed without mutating a slice of df.
        x_str = rel_rows['x_id'].astype(str)
        y_str = rel_rows['y_id'].astype(str)
        pair_key = (x_str + '_' + y_str).where(x_str <= y_str, y_str + '_' + x_str)
        undirected_index.append(rel_rows.index[~pair_key.duplicated()].tolist())
    else:
        # All other relations: deduplicate by direction (of a->b and b->a keep only one)
        # by keeping rows whose head type equals the head type of the relation's first row.
        head_type = rel_rows.x_type.iloc[0]
        undirected_index.append(rel_rows[rel_rows.x_type == head_type].index.tolist())

100%|██████████| 39/39 [00:12<00:00,  3.18it/s]


In [6]:
# Flatten the per-relation row indices kept by the deduplication step.
flat_list = [item for sublist in undirected_index for item in sublist]
# .copy() makes df an independent frame: later cells add and overwrite columns on it,
# which on a boolean-mask slice is the chained-assignment (SettingWithCopy) pattern.
df = df[df.index.isin(flat_list)].copy()
# Sorted union of head and tail node types.
unique_node_types = np.union1d(df.x_type.unique(), df.y_type.unique())

In [7]:
# Display the node types found in the deduplicated edge list.
unique_node_types

array(['disease', 'drug', 'effect/phenotype', 'gene/protein', 'molecule',
       'pathway', 'value'], dtype=object)

In [8]:
# Preview the filtered edge list (pandas truncates the rendered frame).
df

Unnamed: 0,x_type,x_id,relation,y_type,y_id
0,molecule,23978.0,drug_protein,gene/protein,F8
1,molecule,23978.0,drug_protein,gene/protein,F5
2,molecule,977.0,drug_protein,gene/protein,HBA2
3,molecule,82153.0,drug_protein,gene/protein,SERPINA6
4,molecule,5311000.0,drug_protein,gene/protein,SERPINA6
...,...,...,...,...,...
2526886,molecule,439153.0,in_pathway,pathway,PWID1228018
2526887,molecule,644102.0,in_pathway,pathway,PWID1239777
2526888,molecule,668.0,in_pathway,pathway,PWID1324467
2526889,molecule,644102.0,in_pathway,pathway,PWID1234075


In [9]:
# Placeholder columns for the per-type integer node indices filled in below.
for side in ('x', 'y'):
    df[side + '_idx'] = np.nan
# Normalise node identifiers with the project helper convert2str.
for side in ('x', 'y'):
    df[side + '_id'] = df[side + '_id'].apply(convert2str)


In [10]:
# NOTE(review): repeats the unique_node_types display from an earlier cell; candidate for removal.
unique_node_types

array(['disease', 'drug', 'effect/phenotype', 'gene/protein', 'molecule',
       'pathway', 'value'], dtype=object)

In [11]:
# Per node type: map each distinct identifier (sorted) to a contiguous index 0..n-1,
# and write that index into x_idx / y_idx for rows whose head / tail has that type.
idx_map = {}
for node_type in tqdm(unique_node_types):
    head_mask = df.x_type == node_type
    tail_mask = df.y_type == node_type
    names = np.unique(np.concatenate((df.loc[head_mask, 'x_id'].values,
                                      df.loc[tail_mask, 'y_id'].values)))
    name_to_idx = {name: pos for pos, name in enumerate(names)}
    df.loc[head_mask, 'x_idx'] = df.loc[head_mask, 'x_id'].map(name_to_idx.__getitem__)
    df.loc[tail_mask, 'y_idx'] = df.loc[tail_mask, 'y_id'].map(name_to_idx.__getitem__)
    idx_map[node_type] = name_to_idx

100%|██████████| 7/7 [00:12<00:00,  1.76s/it]


In [12]:
# Identifier -> index mapping for the binned 'value' nodes.
idx_map['value']

{'1.0': 0,
 '10.0': 1,
 '2.0': 2,
 '3.0': 3,
 '4.0': 4,
 '5.0': 5,
 '6.0': 6,
 '7.0': 7,
 '8.0': 8,
 '9.0': 9}

In [13]:
# Number of distinct disease and molecule nodes.
tuple(len(idx_map[node_type]) for node_type in ('disease', 'molecule'))

(4173, 88811)

In [14]:
# Scratch inspection: edges whose head is the molecule with index 1.
# NOTE(review): df_test is reassigned from test.csv in the next cell.
df_test = df.loc[(df['x_type'] == "molecule") & (df['x_idx'] == 1.0)]

In [16]:
# Load the saved train/valid/test splits from the same directory as KG.csv, resolving
# the paths from kg_dir instead of repeating the directory literal.
# NOTE(review): these splits are read from disk, not derived from the re-indexed df above.
# The node-ID check further down reports 'value' tail IDs far above
# len(idx_map['value']), which suggests the saved x_idx / y_idx columns came from a
# different indexing. Confirm before building the graph.
df_train, df_valid, df_test = (
    pd.read_csv(os.path.join(kg_dir, split_name + '.csv'))
    for split_name in ('train', 'valid', 'test')
)

In [16]:
## create_dgl_graph
# Distinct (head type, relation, tail type) triples present in the training split.
unique_graph = df_train.loc[:, ['x_type', 'relation', 'y_type']].drop_duplicates()
unique_graph.to_numpy()


array([['molecule', 'drug_protein', 'gene/protein'],
       ['molecule', 'contraindication', 'disease'],
       ['molecule', 'indication', 'disease'],
       ['molecule', 'off-label use', 'disease'],
       ['molecule', 'drug_drug', 'molecule'],
       ['molecule', 'drug_effect', 'effect/phenotype'],
       ['molecule', 'defined_atom_stereo_count', 'value'],
       ['molecule', 'exact_mass', 'value'],
       ['molecule', 'rotatable_bond_count', 'value'],
       ['molecule', 'defined_bond_stereo_count', 'value'],
       ['molecule', 'hydrogen_bond_acceptor_count', 'value'],
       ['molecule', 'mono_isotopic_weight', 'value'],
       ['molecule', 'structure_complexity', 'value'],
       ['molecule', 'isotope_atom_count', 'value'],
       ['molecule', 'hydrogen_bond_donor_count', 'value'],
       ['molecule', 'total_formal_charge', 'value'],
       ['molecule', 'non-hydrogen_atom_count', 'value'],
       ['molecule', 'covalent_unit_count', 'value'],
       ['molecule', 'undefined_atom_st

In [17]:
# Edge lists per canonical edge type: {(head, relation, tail): (src_ids, dst_ids)}.
DGL_input = {}
for head_type, rel, tail_type in unique_graph.values:
    triple_mask = (
        (df_train.x_type == head_type)
        & (df_train.relation == rel)
        & (df_train.y_type == tail_type)
    )
    src_ids, dst_ids = df_train.loc[triple_mask, ['x_idx', 'y_idx']].values.T
    # Bare 'cooccurence' names are made unique by appending the head and tail types.
    if rel in ('cooccurence', 'rev_cooccurence'):
        rel = rel + '_' + head_type + '_' + tail_type
    DGL_input[(head_type, rel, tail_type)] = (src_ids.astype(int), dst_ids.astype(int))

In [18]:
# Number of distinct nodes per node type.
output = dict(zip(idx_map.keys(), map(len, idx_map.values())))

In [19]:
# Display the node count per type (length of each idx_map entry).
output

{'disease': 4173,
 'drug': 40543,
 'effect/phenotype': 990,
 'gene/protein': 13091,
 'molecule': 88811,
 'pathway': 35763,
 'value': 10}

In [20]:
# Node counts per type with a +3 safety margin.
# NOTE(review): dgl.heterograph further down is built with a +1 margin, so this check
# is more permissive than the actual graph construction.
num_nodes_dict = {i: int(output[i])+3 for i in output.keys()}

# Report every edge type whose head or tail IDs fall outside [0, num_nodes) for their
# node type. A valid ID is strictly smaller than the node count, hence ">=".
for k, (h_array, t_array) in DGL_input.items():
    h_type, t_type = k[0], k[2]
    # Empty edge lists have no max; -1 never triggers a report.
    h_max = h_array.max() if len(h_array) else -1
    t_max = t_array.max() if len(t_array) else -1
    if h_max >= num_nodes_dict[h_type]:
        print('h_type: ', h_type)
        print(k)
        print(h_max)
    if t_max >= num_nodes_dict[t_type]:
        print('t_type', t_type)
        print(k)
        print(t_max)


## Known issue: the 'cooccurrence' relation name is shared between edge types.

t_type value
('molecule', 'defined_atom_stereo_count', 'value')
15766
t_type value
('molecule', 'exact_mass', 'value')
16212
t_type value
('molecule', 'rotatable_bond_count', 'value')
16037
t_type value
('molecule', 'defined_bond_stereo_count', 'value')
15766
t_type value
('molecule', 'hydrogen_bond_acceptor_count', 'value')
16081
t_type value
('molecule', 'mono_isotopic_weight', 'value')
16211
t_type value
('molecule', 'structure_complexity', 'value')
16218
t_type value
('molecule', 'isotope_atom_count', 'value')
13891
t_type value
('molecule', 'hydrogen_bond_donor_count', 'value')
15766
t_type value
('molecule', 'total_formal_charge', 'value')
10675
t_type value
('molecule', 'non-hydrogen_atom_count', 'value')
16133
t_type value
('molecule', 'covalent_unit_count', 'value')
15766
t_type value
('molecule', 'undefined_atom_stereo_count', 'value')
15766
t_type value
('molecule', 'molecular_weight', 'value')
16204
t_type value
('molecule', 'undefined_bond_stereo_count', 'value')
15220
t_t

In [21]:
# Node counts with the +1 margin used for dgl.heterograph below.
{node_type: int(count) + 1 for node_type, count in output.items()}

{'disease': 4174,
 'drug': 40544,
 'effect/phenotype': 991,
 'gene/protein': 13092,
 'molecule': 88812,
 'pathway': 35764,
 'value': 11}

In [22]:
# Build the heterogeneous graph, reserving one node slot beyond each type's node count.
graph_num_nodes = {node_type: int(count) + 1 for node_type, count in output.items()}
g = dgl.heterograph(DGL_input, num_nodes_dict=graph_num_nodes)

DGLError: The given number of nodes of node type molecule must be larger than the max ID in the data, but got 11 and 15766.

In [None]:
import torch

# Integer ID per node type, in g.ntypes order.
node_dict = {ntype: pos for pos, ntype in enumerate(g.ntypes)}

# Integer ID per edge type, also stored on every edge of that type as the 'id' feature.
edge_dict = {}
for etype in g.etypes:
    edge_dict[etype] = len(edge_dict)
    g.edges[etype].data['id'] = torch.full(
        (g.num_edges(etype),), edge_dict[etype], dtype=torch.long
    )

In [None]:
# Inspect the 'id' edge feature assigned to the third edge type.
g.edges[g.etypes[2]].data['id']

tensor([2, 2, 2,  ..., 2, 2, 2])

In [None]:
# Import the class under an alias so binding the loaded dataset to `TxData` (the name
# later cells use) does not shadow the class itself.
from dataset_construction import TxData as TxDataClass

# Knowledge-graph folder for TxData. Override with the TXGNN_KG_DIR environment variable
# instead of editing a hard-coded absolute path; the default is the original location.
TXGNN_KG_DIR = os.environ.get('TXGNN_KG_DIR', '/data/pj20/txgnn/kg/')

TxData = TxDataClass(data_folder_path=TXGNN_KG_DIR)
TxData.prepare_split(split='full_graph', seed=42, no_kg=False)

Found local copy...
Found local copy...
Found local copy...
Found saved processed KG... Loading...
Splits detected... Loading splits....
Creating DGL graph....


KeyboardInterrupt: 

In [None]:
# Import under an alias so the model instance bound to `TxGNN` (the name later cells use)
# does not shadow the class; TxGNNModel stays available for creating further models.
from txgnn import TxGNN as TxGNNModel

TxGNN = TxGNNModel(
    data=TxData,
    weight_bias_track=False,
    proj_name='TxGNN',
    exp_name='TxGNN',
)


In [None]:
# Initialise the TxGNN model with fixed hyperparameters (keyword order is irrelevant).
TxGNN.model_initialize(
    n_inp=100, n_hid=100, n_out=100,
    proto=True, proto_num=3,
    attention=False,
    sim_measure='all_nodes_profile',
    bert_measure='disease_name',
    agg_measure='rarity',
    num_walks=200, walk_mode='bit', path_length=2,
)

In [None]:
# Integer ID assigned to the second edge type in edge_dict.
edge_dict[g.etypes[1]]

1