In [4]:
import os
import csv
import torch
import shutil
import sqlite3
from datetime import datetime

from utils.evaluation.process_mimic_db.utils import *
from utils.evaluation.process_mimic_db.process_tables import *

# Specify the path to the downloaded MIMIC III data
data_dir = 'mimic_table'
# Path to the generated mimic.db. No need to update.
out_dir = 'mimic_db'

# Generate five tables and the database with all admissions
# if os.path.exists(out_dir):
#     shutil.rmtree(out_dir)
# os.mkdir(out_dir)
'''
conn = sqlite3.connect(os.path.join(out_dir, 'mimic_all.db'))
build_demographic_table(data_dir, out_dir, conn)
build_diagnoses_table(data_dir, out_dir, conn)
build_procedures_table(data_dir, out_dir, conn)
build_prescriptions_table(data_dir, out_dir, conn)
build_lab_table(data_dir, out_dir, conn)
'''

'''
1. We did not emumerate all possible questions about MIMIC III.
MIMICSQL data is generated based on the patient information 
related to 100 randomly selected admissions.
2. The following codes are used for sampling the admissions 
from the large database. 
3. The parameter 'random_state=0' in line 41 will provide you 
the same set of sampled admissions and the same database as we used.
'''

print('Begin sampling ...')
# DEMOGRAPHIC
print('Processing DEMOGRAPHIC')
conn = sqlite3.connect(os.path.join(out_dir, 'mimic.db'))
data_demo = pandas.read_csv(os.path.join(out_dir, "DEMOGRAPHIC.csv"))
h_adm_list = [elem[0] for elem in torch.load('result-dxprx/p_sections')]
data_demo_sample = data_demo[data_demo['HADM_ID'].isin(h_adm_list)]
for k, v in data_demo_sample.dtypes.items():
    data_demo_sample[k] = data_demo_sample[k].apply(lambda x: x.lower() if type(x) == str else x)
print(len(data_demo_sample))
#data_demo_sample.to_sql('DEMOGRAPHIC', conn, if_exists='replace', index=False)
sampled_id = data_demo_sample['HADM_ID'].values


# DIAGNOSES
print('Processing DIAGNOSES')
data_input = pandas.read_csv(os.path.join(out_dir, "DIAGNOSES.csv"))
data_filter = []
cnt = 0
for itm in sampled_id:
    msg = 'HADM_ID=='+str(itm)
    data_filter.append(data_input.query(msg))
    cnt += 1
    show_progress(cnt, len(sampled_id))
data_out = pandas.concat(data_filter, ignore_index=True)
for k, v in data_out.dtypes.items():
    data_out[k] = data_out[k].apply(lambda x: x.lower() if type(x) == str else x)
data_out.to_sql('DIAGNOSES', conn, if_exists='replace', index=False)


# PROCEDURES
print('Processing PROCEDURES')
data_input = pandas.read_csv(os.path.join(out_dir, "PROCEDURES.csv"))
data_filter = []
cnt = 0
for itm in sampled_id:
    msg = 'HADM_ID=='+str(itm)
    data_filter.append(data_input.query(msg))
    cnt += 1
    show_progress(cnt, len(sampled_id))
data_out = pandas.concat(data_filter, ignore_index=True)
for k, v in data_out.dtypes.items():
    data_out[k] = data_out[k].apply(lambda x: x.lower() if type(x) == str else x)
data_out.to_sql('PROCEDURES', conn, if_exists='replace', index=False)

'''
# PRESCRIPTIONS
print('Processing PRESCRIPTIONS')
data_input = pandas.read_csv(os.path.join(out_dir, "PRESCRIPTIONS.csv"))
data_filter = []
cnt = 0
for itm in sampled_id:
    msg = 'HADM_ID=='+str(itm)
    data_filter.append(data_input.query(msg))
    cnt += 1
    show_progress(cnt, len(sampled_id))
data_out = pandas.concat(data_filter, ignore_index=True)
for k, v in data_out.dtypes.items():
    data_out[k] = data_out[k].apply(lambda x: x.lower() if type(x) == str else x)
data_out.to_sql('PRESCRIPTIONS', conn, if_exists='replace', index=False)

'''
'''
# LAB
print('Processing LAB')
data_input = pandas.read_csv(os.path.join(out_dir, "LAB.csv"))
data_filter = []
cnt = 0
for itm in sampled_id:
    msg = 'HADM_ID=='+str(itm)
    data_filter.append(data_input.query(msg))
    cnt += 1
    show_progress(cnt, len(sampled_id))
data_out = pandas.concat(data_filter, ignore_index=True)
for k, v in data_out.dtypes.items():
    data_out[k] = data_out[k].apply(lambda x: x.lower() if type(x) == str else x)
data_out.to_sql('LAB', conn, if_exists='replace', index=False)
'''

print('Done!')

Begin sampling ...
Processing DEMOGRAPHIC


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data_demo_sample[k] = data_demo_sample[k].apply(lambda x: x.lower() if type(x) == str else x)


32699
Processing DIAGNOSES
Processing PROCEDURES>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>]100%
Done!>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>]100%


In [6]:
import sys
sys.path.append('..')
sys.path.append('.')
import os
import pandas as pd
import sqlite3

from utils.schema_mimic import *
from utils.evaluation.utils import query

PJT_ROOT_PATH = './'
print('PJT_ROOT_PATH: ',PJT_ROOT_PATH)

if __name__ == '__main__':
    db_conn = sqlite3.connect(os.path.join(PJT_ROOT_PATH, 'mimic_db/mimic.db'))
    '''
    patient_cols = list(patient_demographic_dtype.keys())
    addmission_cols = list(hadm_demographic_dtype.keys())

    demographic = pd.read_sql_query("SELECT * FROM DEMOGRAPHIC", db_conn)
    demographic.info()
    demographic.head()

    patients_df = demographic.loc[:,patient_cols]
    patients_df.info()

    addmissions_df = demographic.loc[:,addmission_cols] # primary key: HADM_ID
    addmissions_df.info()
    '''
    diagenoses = pd.read_sql_query("SELECT * FROM DIAGNOSES", db_conn)
    diagenoses = diagenoses.reset_index().rename({'index': 'DIAGNOSES'}, axis=1)
    diagenoses.info()

    diagnoses_cols = list(diagnoses_dtype.keys())
    d_icd_dagnoses_cols = list(d_icd_diagnoses_dtype.keys())
    diagnoses_cols = ['ICD9_CODE' if c == 'DIAGNOSES_ICD9_CODE' else c for c in diagnoses_cols]
    d_icd_dagnoses_cols = ['ICD9_CODE' if c == 'DIAGNOSES_ICD9_CODE' else c for c in d_icd_dagnoses_cols]
    d_icd_dagnoses_cols = ['LONG_TITLE' if c == 'DIAGNOSES_LONG_TITLE' else c for c in d_icd_dagnoses_cols]
    #d_icd_dagnoses_cols = ['SHORT_TITLE' if c == 'DIAGNOSES_SHORT_TITLE' else c for c in d_icd_dagnoses_cols]

    diagenoses_df = diagenoses.loc[:, diagnoses_cols]
    diagenoses_df.info()

    d_icd_diagenoses_df = diagenoses.loc[:, d_icd_dagnoses_cols]
    d_icd_diagenoses_df.drop_duplicates(inplace=True)
    d_icd_diagenoses_df.reset_index(inplace=True, drop=True)
    d_icd_diagenoses_df.info()
    
    procedures = pd.read_sql_query("SELECT * FROM PROCEDURES", db_conn)
    procedures = procedures.reset_index().rename({'index': 'PROCEDURES'}, axis=1)
    procedures.info()

    procedures_cols = list(procedures_dtype.keys())
    d_icd_procedures_cols = list(d_icd_procedures_dtype.keys())
    procedures_cols = ['ICD9_CODE' if c == 'PROCEDURES_ICD9_CODE' else c for c in procedures_cols]
    d_icd_procedures_cols = ['ICD9_CODE' if c == 'PROCEDURES_ICD9_CODE' else c for c in d_icd_procedures_cols]
    d_icd_procedures_cols = ['LONG_TITLE' if c == 'PROCEDURES_LONG_TITLE' else c for c in d_icd_procedures_cols]
    #d_icd_procedures_cols = ['SHORT_TITLE' if c == 'PROCEDURES_SHORT_TITLE' else c for c in d_icd_procedures_cols]

    procedures_df = procedures.loc[:, procedures_cols]
    procedures_df.info()

    d_icd_procedures_df = procedures.loc[:, d_icd_procedures_cols]
    d_icd_procedures_df.drop_duplicates(inplace=True)
    d_icd_procedures_df.reset_index(inplace=True, drop=True)
    d_icd_procedures_df.info()
    '''
    lab_cols = list(lab_dtype.keys())
    d_labitem_cols = list(d_labitem_dtype.keys())

    lab = pd.read_sql_query("SELECT * FROM LAB", db_conn)
    lab = lab.reset_index().rename({'index': 'LAB'}, axis=1)
    lab.info()

    lab_df = lab.loc[:, lab_cols]
    lab_df.info()

    d_labitem_df = lab.loc[:, d_labitem_dtype]
    d_labitem_df.drop_duplicates(inplace=True)
    d_labitem_df.reset_index(inplace=True, drop=True)
    d_labitem_df.info()
    
    prescriptions_cols = list(prescriptions_dtype.keys())

    prescriptions = pd.read_sql_query("SELECT * FROM PRESCRIPTIONS", db_conn)
    prescriptions = prescriptions.reset_index().rename({'index': 'PRESCRIPTIONS'}, axis=1)
    prescriptions_df = prescriptions.loc[:, prescriptions_cols]
    prescriptions_df.info()
    '''
    conn = sqlite3.connect(os.path.join(PJT_ROOT_PATH , 'mimicsqlstar.db')) 
    '''
    patients_df.to_sql('PATIENTS', conn, if_exists='replace', index=False)
    addmissions_df.to_sql('ADMISSIONS', conn, if_exists='replace', index=False)
    '''
    diagenoses_df.to_sql('DIAGNOSES', conn, if_exists='replace', index=False)
    d_icd_diagenoses_df.to_sql('D_ICD_DIAGNOSES', conn, if_exists='replace', index=False)
    
    procedures_df.to_sql('PROCEDURES', conn, if_exists='replace', index=False)
    d_icd_procedures_df.to_sql('D_ICD_PROCEDURES', conn, if_exists='replace', index=False)
    '''
    prescriptions_df.to_sql('PRESCRIPTIONS', conn, if_exists='replace', index=False)
    
    lab_df.to_sql('LAB', conn, if_exists='replace', index=False)
    d_labitem_df.to_sql('D_LABITEM', conn, if_exists='replace', index=False)
    '''
    print(f'LOAD DB ...')

    db_file = os.path.join(PJT_ROOT_PATH,'mimicsqlstar.db')
    new_model = query(db_file)
    print('DONE')

PJT_ROOT_PATH:  ./
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 434178 entries, 0 to 434177
Data columns (total 6 columns):
 #   Column       Non-Null Count   Dtype 
---  ------       --------------   ----- 
 0   DIAGNOSES    434178 non-null  int64 
 1   SUBJECT_ID   434178 non-null  int64 
 2   HADM_ID      434178 non-null  int64 
 3   ICD9_CODE    434178 non-null  object
 4   SHORT_TITLE  434178 non-null  object
 5   LONG_TITLE   434178 non-null  object
dtypes: int64(3), object(3)
memory usage: 19.9+ MB
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 434178 entries, 0 to 434177
Data columns (total 4 columns):
 #   Column      Non-Null Count   Dtype 
---  ------      --------------   ----- 
 0   SUBJECT_ID  434178 non-null  int64 
 1   DIAGNOSES   434178 non-null  int64 
 2   HADM_ID     434178 non-null  int64 
 3   ICD9_CODE   434178 non-null  object
dtypes: int64(3), object(1)
memory usage: 13.3+ MB
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6035 entries, 0 to 6034
Dat

In [7]:
import sys
import os
import gc
sys.path.append('..')
sys.path.append('.')
from rdflib import Graph, URIRef
import sqlite3
import pandas as pd
from rdflib import Literal
from tqdm import tqdm
from utils.kg_complex_schema import addmissions_dtype, patients_dtype, procedures_dtype, prescriptions_dtype,\
    diagnoses_dtype, lab_dtype, d_icd_procedures_dtype, d_icd_diagnoses_dtype, d_labitem_dtype

PJT_ROOT_PATH = './'
print('PJT_ROOT_PATH: ', PJT_ROOT_PATH)

domain = ''


def isNoneNan(val):
    if val is None:
        return True

    if (type(val) == str) and (val.lower() in ['none', 'nan']):
        return True

    if val != val:
        return True

    return False


def clean_text(val):
    if type(val) == str:
        val = val.replace("\\", ' ')
    return val


def wrap2uri(obj, literal_type):
    obj = obj.lower()
    if literal_type == 'entity':
        return URIRef(obj)

    elif literal_type == 'relation':
        return URIRef(obj)

    else:
        return Literal(clean_text(obj), datatype=literal_type)


def table2triples(knowgraph,df, parent_col, subject_col, col_types):
    #triples = []
    for col_name, _ in tqdm(col_types.items()):

        if col_name == parent_col:
            # triples += [(wrap2uri(f'{domain}/{col_name}/{sub}', col_types[parent_col]),
            #              wrap2uri(f'{domain}/{subject_col}', 'relation'),
            #              wrap2uri(f'{domain}/{subject_col}/{obj}', col_types[subject_col]))
            #             for (sub, obj) in zip(df[col_name], df[subject_col])]
            for (sub, obj) in zip(df[col_name], df[subject_col]):
                knowgraph.add((wrap2uri(f'{domain}/{col_name}/{sub}', col_types[parent_col]),
                             wrap2uri(f'{domain}/{subject_col}', 'relation'),
                             wrap2uri(f'{domain}/{subject_col}/{obj}', col_types[subject_col]))
                            )
            continue

        if col_name == subject_col:
            continue
        for (sub, obj) in zip(df[subject_col], df[col_name]):
            if not isNoneNan(obj):
                knowgraph.add(
                            (wrap2uri(f'{domain}/{subject_col}/{sub}', col_types[subject_col]),
                             wrap2uri(f'{domain}/{col_name}', 'relation'),
                             wrap2uri(f'{domain}/{col_name}/{obj}' if col_types[col_name] == 'entity' else f'{obj}',
                                      col_types[col_name]))
                             )

    return knowgraph


if __name__ == '__main__':
    db_conn = sqlite3.connect(os.path.join(PJT_ROOT_PATH, 'mimicsqlstar.db'))
    '''
    patients = pd.read_sql_query("SELECT * FROM PATIENTS", db_conn)
    patients.info()

    admissions = pd.read_sql_query("SELECT * FROM ADMISSIONS", db_conn)
    admissions.info()
    '''
    diagnoses = pd.read_sql_query("SELECT * FROM DIAGNOSES", db_conn)
    diagnoses = diagnoses.rename({'ICD9_CODE': 'DIAGNOSES_ICD9_CODE'}, axis=1)
    diagnoses.info()

    d_icd_diagnoses = pd.read_sql_query("SELECT * FROM D_ICD_DIAGNOSES", db_conn)
    d_icd_diagnoses = d_icd_diagnoses.rename({'ICD9_CODE': 'DIAGNOSES_ICD9_CODE',
                                              #'SHORT_TITLE': 'DIAGNOSES_SHORT_TITLE',
                                              'LONG_TITLE': 'DIAGNOSES_LONG_TITLE'}, axis=1)
    d_icd_diagnoses.info()
    
    procedures = pd.read_sql_query("SELECT * FROM PROCEDURES", db_conn)
    procedures = procedures.rename({'ICD9_CODE': 'PROCEDURES_ICD9_CODE'}, axis=1)
    procedures.info()

    d_icd_procedures = pd.read_sql_query("SELECT * FROM D_ICD_PROCEDURES", db_conn)
    d_icd_procedures = d_icd_procedures.rename({'ICD9_CODE': 'PROCEDURES_ICD9_CODE',
                                                #'SHORT_TITLE': 'PROCEDURES_SHORT_TITLE',
                                                'LONG_TITLE': 'PROCEDURES_LONG_TITLE'}, axis=1)
    d_icd_procedures.info()
    '''
    prescriptions = pd.read_sql_query("SELECT * FROM PRESCRIPTIONS", db_conn)
    prescriptions['ICUSTAY_ID'] = prescriptions['ICUSTAY_ID'].apply(lambda x: str(x) if x == x else None)
    prescriptions.info()
    
    lab = pd.read_sql_query("SELECT * FROM LAB", db_conn)
    lab.info()

    d_labitem = pd.read_sql_query("SELECT * FROM D_LABITEM", db_conn)
    d_labitem.info()
    '''
    kg = Graph()
    #triples = []
    '''
    kg = table2triples(kg,patients, parent_col='', subject_col='SUBJECT_ID', col_types=patients_dtype)
    print('# total triples : {}'.format(len(kg)))
    # print(triples[:5])
    #print(triples[-5:])
    #print(len(triples))

    kg = table2triples(kg,admissions, parent_col='SUBJECT_ID', subject_col='HADM_ID',
                             col_types=addmissions_dtype)
    print('# total triples : {}'.format(len(kg)))
    # print(triples[-5:])
    
    '''
    kg = table2triples(kg,diagnoses, parent_col='HADM_ID', subject_col='DIAGNOSES', col_types=diagnoses_dtype)
    # print(triples[:5])
    print('# total triples : {}'.format(len(kg)))
    # print(triples[-5:])


    kg = table2triples(kg,d_icd_diagnoses, parent_col='', subject_col='DIAGNOSES_ICD9_CODE',
                             col_types=d_icd_diagnoses_dtype)
    # print(triples[:5])
    print('# total triples : {}'.format(len(kg)))
    # print(triples[-5:])

    
    kg = table2triples(kg,procedures, parent_col='HADM_ID', subject_col='PROCEDURES', col_types=procedures_dtype)
    # print(triples[:5])
    print('# total triples : {}'.format(len(kg)))
    # print(triples[-5:])


    kg = table2triples(kg,d_icd_procedures, parent_col='', subject_col='PROCEDURES_ICD9_CODE',
                             col_types=d_icd_procedures_dtype)
    # print(triples[:5])
    print('# total triples : {}'.format(len(kg)))
    # print(triples[-5:])
    #print(len(triples))
    '''
    kg = table2triples(kg,prescriptions, parent_col='HADM_ID', subject_col='PRESCRIPTIONS',
                             col_types=prescriptions_dtype)
    # print(triples[:5])
    print('# total triples : {}'.format(len(kg)))
    # print(triples[-5:])
    #print(len(triples))
    
    kg = table2triples(kg,lab, parent_col='HADM_ID', subject_col='LAB', col_types=lab_dtype)
    # print(triples[:5])
    print('# total triples : {}'.format(len(kg)))
    # print(triples[-5:])
    #print(len(triples))

    kg = table2triples(kg,d_labitem, parent_col='', subject_col='ITEMID', col_types=d_labitem_dtype)
    # print(triples[:5])
    print('# total triples : {}'.format(len(kg)))
    # print(triples[-5:])
    #print(len(triples))
'''
    q = """select * where { ?subject_id </gender> "f"^^<http://www.w3.org/2001/XMLSchema#string> }"""
    print(f"TEST QEURY... {q})")
    qres = kg.query(q)
    print("-" * 50)
    for res in qres:
        val = '|'
        for t in res:
            val += str(t.toPython()) + '|\t\t|'
        print(val[:-1])
    print()
    
    print('SAVE KG ...')
    kg.serialize('./mimic_sparqlstar_kg.xml', format='xml')
    print('SAVE DONE')
    print('LOAD TEST ...')
    kg = Graph()
    kg.parse('./mimic_sparqlstar_kg.xml', format='xml', publicID='/')

    print(len(kg))
    for i, t in enumerate(kg):
        print(i, t)
        if i == 5:
            break

    print('LOAD DONE')

PJT_ROOT_PATH:  ./
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 434178 entries, 0 to 434177
Data columns (total 4 columns):
 #   Column               Non-Null Count   Dtype 
---  ------               --------------   ----- 
 0   SUBJECT_ID           434178 non-null  int64 
 1   DIAGNOSES            434178 non-null  int64 
 2   HADM_ID              434178 non-null  int64 
 3   DIAGNOSES_ICD9_CODE  434178 non-null  object
dtypes: int64(3), object(1)
memory usage: 13.3+ MB
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6035 entries, 0 to 6034
Data columns (total 2 columns):
 #   Column                Non-Null Count  Dtype 
---  ------                --------------  ----- 
 0   DIAGNOSES_ICD9_CODE   6035 non-null   object
 1   DIAGNOSES_LONG_TITLE  6035 non-null   object
dtypes: object(2)
memory usage: 94.4+ KB


  0%|          | 0/3 [00:00<?, ?it/s]

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 147868 entries, 0 to 147867
Data columns (total 4 columns):
 #   Column                Non-Null Count   Dtype
---  ------                --------------   -----
 0   SUBJECT_ID            147868 non-null  int64
 1   PROCEDURES            147868 non-null  int64
 2   HADM_ID               147868 non-null  int64
 3   PROCEDURES_ICD9_CODE  147868 non-null  int64
dtypes: int64(4)
memory usage: 4.5 MB
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1828 entries, 0 to 1827
Data columns (total 2 columns):
 #   Column                 Non-Null Count  Dtype 
---  ------                 --------------  ----- 
 0   PROCEDURES_ICD9_CODE   1828 non-null   int64 
 1   PROCEDURES_LONG_TITLE  1828 non-null   object
dtypes: int64(1), object(1)
memory usage: 28.7+ KB


100%|██████████| 3/3 [00:22<00:00,  7.56s/it]
  0%|          | 0/2 [00:00<?, ?it/s]

# total triples : 868356


100%|██████████| 2/2 [00:00<00:00,  7.89it/s]
  0%|          | 0/3 [00:00<?, ?it/s]

# total triples : 874391


100%|██████████| 3/3 [00:07<00:00,  2.65s/it]
100%|██████████| 2/2 [00:00<00:00, 23.39it/s]


# total triples : 1170127
# total triples : 1171955
TEST QEURY... select * where { ?subject_id </gender> "f"^^<http://www.w3.org/2001/XMLSchema#string> })
--------------------------------------------------

SAVE KG ...
SAVE DONE
LOAD TEST ...
1171955
0 (rdflib.term.URIRef('/diagnoses/45841'), rdflib.term.URIRef('/diagnoses_icd9_code'), rdflib.term.URIRef('/diagnoses_icd9_code/58381'))
1 (rdflib.term.URIRef('/diagnoses/421157'), rdflib.term.URIRef('/diagnoses_icd9_code'), rdflib.term.URIRef('/diagnoses_icd9_code/v4611'))
2 (rdflib.term.URIRef('/diagnoses/156394'), rdflib.term.URIRef('/diagnoses_icd9_code'), rdflib.term.URIRef('/diagnoses_icd9_code/4240'))
3 (rdflib.term.URIRef('/hadm_id/176260'), rdflib.term.URIRef('/diagnoses'), rdflib.term.URIRef('/diagnoses/195136'))
4 (rdflib.term.URIRef('/hadm_id/137446'), rdflib.term.URIRef('/diagnoses'), rdflib.term.URIRef('/diagnoses/109397'))
5 (rdflib.term.URIRef('/diagnoses/134834'), rdflib.term.URIRef('/diagnoses_icd9_code'), rdflib.term.URI

In [10]:
from rdflib import Graph, URIRef
from tqdm import tqdm
import pickle

def build_dict(triples, nodes, edges):
    h, r, t = triples
    #for (h,r,t) in triples:
    if h not in nodes:
        nodes[h]=1
    else:
        nodes[h]+=1
    if t not in nodes:
        nodes[t]=1
    else:
        nodes[t]+=1
    if r not in edges:
        edges[r]=1
    else:
        edges[r]+=1
    return nodes, edges

# triple 확인
nodes = dict()
edges = dict()
for triple in tqdm(kg):
    triples = [x.n3() for x in triple]
    #print(triples)
    nodes, edges = build_dict(triples, nodes, edges)
#matching = [s for s in tqdm(list(nodes.keys())) if "hadm_id" in list(nodes.keys())]
print(len(nodes))
print(len(edges))

f = open('node_dict','w')
g = open('edge_dict','w')
for node in list(nodes.keys()):
    f.write('{}\n'.format(node))
for edge in list(edges.keys()):
    g.write('{}\n'.format(edge))

100%|██████████| 1171955/1171955 [00:18<00:00, 62337.87it/s]


630450
6


In [11]:
# Make file & Build look-up table
train2id = open('train2id.txt','w')
train2id.write(str(len(kg))+'\n')
node2id = open('entity2id.txt','w')
node_lookup = {k:v for (v,k) in enumerate(nodes)}
node2id.write(str(len(nodes))+'\n')
edge2id = open('relation2id.txt','w')
edge_lookup = {k:v for (v,k) in enumerate(edges)}
edge2id.write(str(len(edges))+'\n')

# Build Node lookup
for (node, idx) in list(node_lookup.items()):
    node2id.write('{}\t{}\n'.format(node, idx))

# Build Edge lookup
for (edge, idx) in list(edge_lookup.items()):
    edge2id.write('{}\t{}\n'.format(edge, idx))

# Actual triple to id triple
for triple in tqdm(kg):
    triples = [x.n3() for x in triple]
    train2id.write('{}\t{}\t{}\n'.format(node_lookup[triples[0]],node_lookup[triples[2]],edge_lookup[triples[1]]))

100%|██████████| 1171955/1171955 [00:18<00:00, 61906.69it/s]
