# **1. Run Breadth-First (Level-Order) Search on KG**

Each search is rooted at an admission (`hadm`) node.

In [1]:
import random

# Dataset directory holding train2id.txt / entity2id.txt / relation2id.txt / p_sections.
ROOT_DIR = "dxprx"
# Token ids 0-2 are reserved (PAD, MASK, CLS); raw KG entity ids are shifted by this offset.
NUM_SPECIAL_TOKENS = 3
# Upper bound on the number of samples placed in each of the valid and test splits.
eval_size = 1000

## 1-1. Non-unified abstract embedding

In [None]:
from tqdm import tqdm
import pickle
import torch
import spacy, scispacy
nlp = spacy.load("en_core_sci_sm")
import numpy as np
import os

def get_childs(subgraph, depth, heads):
    """Return the concatenated child lists of ``heads`` at level ``depth``.

    ``subgraph[depth]`` maps a node to the list of its children. The children
    of every node in ``heads`` are returned in one flat list, in the order the
    heads are given. A head missing from ``subgraph[depth]`` raises KeyError.
    """
    return [child for parent in heads for child in subgraph[depth][parent]]

# Build dictionaries from file
def _read_kg_lines(filename):
    """Return the lines of ROOT_DIR/filename without the first (header) line.

    The file is closed as soon as it has been read.
    """
    with open(os.path.join(ROOT_DIR, filename)) as fin:
        return fin.read().splitlines()[1:]


triples = [x.split() for x in _read_kg_lines('train2id.txt')]
# Keyed by the raw (string) ids; later cells rebuild node2edge with shifted int ids.
node2edge = {(h, t): r for h, t, r in triples}
# Entity name -> entity id (string). Split on tab only (presumably because
# literal entity names can contain spaces).
nodes = {x.split('\t')[0]: x.split('\t')[-1] for x in _read_kg_lines('entity2id.txt')}
# Literal entities (names containing '^^') -> token id shifted past the special tokens.
literals = {k: int(v) + NUM_SPECIAL_TOKENS for k, v in nodes.items() if '^^' in k}
edges = {x.split()[0]: x.split()[1] for x in _read_kg_lines('relation2id.txt')}

# Extract admission nodes: ids of every entity whose name contains 'hadm'.
adm_node = [node_id for name, node_id in nodes.items() if 'hadm' in name]
        
# Initialize subgraph: level 0 maps every admission node to its (still empty) child list.
subgraph_norel = [{node: list() for node in adm_node}]

# Level-by-level (breadth-first) expansion. Each pass attaches every remaining
# triple whose head is a node of the current level. All other triples are
# deferred to the next pass. The children found become the next level's keys.
# Afterwards `level` is the number of passes run, and subgraph_norel has
# level + 1 entries.
MAX_LEVEL = 30
print('start preprocessing')
level = 0
while len(triples) > 0:
    current = subgraph_norel[level]
    queue = list()
    print('level:{}'.format(level))
    for triple in tqdm(triples):
        if triple[0] in current:
            current[triple[0]].append(triple[1])
        else:
            queue.append(triple)
    print('{}/{}'.format(len(queue), len(triples)))
    new_head = list()
    for heads in current.values():
        new_head += heads
    subgraph_norel.append({k: list() for k in new_head})
    made_progress = len(queue) < len(triples)
    triples = queue
    level += 1
    # If nothing was attached at this level, the next level has no nodes.
    # No later pass could attach anything either, so stop instead of
    # rescanning the leftovers until MAX_LEVEL.
    if not made_progress or level > MAX_LEVEL:
        break

# Build subgraph: flatten each admission's level-ordered expansion into a
# fixed-length token sequence: [CLS] + node ids (shifted by NUM_SPECIAL_TOKENS) + PAD.
CLS_ID = 2  # same id as 'CLS' in the unified vocabulary built later
PAD_ID = 0
subgraphs = dict()
# Max number of KG nodes per sample; stored sequences are max_len + 1 long (with [CLS]).
max_len = 239
for head in tqdm(list(subgraph_norel[0].keys())):
    seq = [head]
    heads = [head]
    for depth in range(level):
        heads = get_childs(subgraph_norel, depth, heads)
        # Once a level is empty, every deeper level is empty too.
        if not heads:
            break
        seq += heads
    # Admissions whose subgraph does not fit are dropped entirely.
    if len(seq) > max_len:
        continue
    subgraphs[head] = [CLS_ID] + [int(x) + NUM_SPECIAL_TOKENS for x in seq] + [PAD_ID] * (max_len - len(seq))
 
# Align subgraph and note. Keep only admissions that have a KG node and whose
# subgraph survived the length filter; unmatched samples are skipped.
# NOTE(review): torch.load unpickles the file -- only load trusted data.
note_aid_pair = list()
for aid, note in torch.load(os.path.join(ROOT_DIR, 'p_sections')):
    node_id = nodes.get(f'</hadm_id/{aid}>')
    if node_id is None or node_id not in subgraphs:
        continue
    # Strip double quotes, tokenize with scispaCy, and re-join tokens with single spaces.
    note_refined = {
        header.replace('"', ''): ' '.join(token.text for token in nlp(text.replace('"', '')))
        for header, text in note.items()
    }
    note_aid_pair.append((node_id, note_refined))

print(len(note_aid_pair))
print(max((len(x) for x in subgraphs.values()), default=0))
print('num_literals : {}'.format(len(literals)))

 12%|█▏        | 144515/1171955 [00:00<00:00, 1445141.04it/s]

start preprocessing
level:0


100%|██████████| 1171955/1171955 [00:00<00:00, 1452395.16it/s]


589909/1171955


 14%|█▍        | 82319/589909 [00:00<00:00, 823188.82it/s]

level:1


100%|██████████| 589909/589909 [00:00<00:00, 837489.66it/s]


7863/589909


100%|██████████| 7863/7863 [00:00<00:00, 876679.67it/s]
  9%|▊         | 2804/32696 [00:00<00:01, 28038.15it/s]

level:2
0/7863


100%|██████████| 32696/32696 [00:01<00:00, 30396.74it/s]


In [None]:
# Report the vocabulary sizes built in section 1.
print(f'num_literals : {len(literals)}')
print(f'num_nodes : {len(nodes)}')

# **2. Build DB**

## 2-0. Train/valid/test (TVT) split

In [None]:
# Build Input: assign each aligned (node, note) pair to a random split
# (80/10/10). Valid and test are capped at eval_size samples; overflow goes to
# train. The split is cached in '<ROOT_DIR>.db' and reused on later runs.
db_path = '{}.db'.format(ROOT_DIR)
if os.path.exists(db_path):
    DB = torch.load(db_path)
else:
    split_names = ['train', 'valid', 'test']
    DB = {name: [] for name in split_names}
    for sample in tqdm(note_aid_pair):
        split = np.random.choice(split_names, p=[0.8, 0.1, 0.1])
        if split in ['valid', 'test'] and len(DB[split]) >= eval_size:
            split = 'train'
        DB[split].append(sample)
    torch.save(DB, db_path)

## 2-1-(1). Masked Literal Prediction

In [None]:
task = '{}_NoKGenc'.format(ROOT_DIR)
with open(os.path.join(ROOT_DIR, 'train2id.txt')) as fin:
    triples = [x.split() for x in fin.read().splitlines()[1:]]
# (head token id, tail token id) -> relation id; token ids are KG ids shifted past the special tokens.
node2edge = {(int(h)+NUM_SPECIAL_TOKENS, int(t)+NUM_SPECIAL_TOKENS): int(r) for h, t, r in triples}
# Number of distinct relations. Also used as the extra "not connected" class in rc_index.
# Loop-invariant: computed once instead of once per sample (node2edge has one entry per triple).
num_edge_types = len(set(node2edge.values()))

os.makedirs(task, exist_ok=True)

# Re-index literals for labeling: literal token id -> contiguous label in [0, #literals).
literal_id2label = {k: v for (v, k) in enumerate(literals.values())}
torch.save(literal_id2label, '{}/id2label'.format(task))

for split in DB:
    print('[{}] set size : {}'.format(split, len(DB[split])))
    os.makedirs(os.path.join(task, split), exist_ok=True)
    inputs = list()
    labels = list()
    label_masks = list()
    rc_indeces = list()
    notes = list()
    for head, note in tqdm(DB[split]):
        if head not in subgraphs:
            continue
        subgraph = subgraphs[head]
        inputs.append(subgraph)
        # Literal positions get their class label; all other positions are ignored (-100).
        labels.append([literal_id2label.get(x, -100) for x in subgraph])
        label_masks.append([1 if x in literal_id2label else 0 for x in subgraph])
        # Relation-classification targets: (head position, tail position, relation id).
        # Position 0 is [CLS]; positions 1..num_nodes hold the non-padding nodes.
        num_nodes = sum(1 for x in subgraph[1:] if x != 0)
        rc_index = list()
        head_indeces = [random.randint(1, num_nodes) for _ in range(int(0.1*num_nodes))]
        not_conn = 0
        for head_idx in head_indeces:
            for tail_idx in range(1, num_nodes+1):
                # Randomly skip ~25% of candidate tails; never pair a node with itself.
                if (random.random() < 0.25) or (head_idx == tail_idx):
                    continue
                idx_pair = (head_idx, tail_idx)
                node_pair = (subgraph[head_idx], subgraph[tail_idx])
                inv_node_pair = (subgraph[tail_idx], subgraph[head_idx])
                if node_pair in node2edge:
                    rc_index.append(idx_pair + (node2edge[node_pair],))
                    break
                elif inv_node_pair in node2edge:
                    rc_index.append(idx_pair + (node2edge[inv_node_pair],))
                    break
                elif not_conn < 0.1*num_nodes/(num_edge_types+1):
                    # Budgeted negative sample: unconnected pair labelled with the extra class.
                    rc_index.append(idx_pair + (num_edge_types,))
                    not_conn += 1
                    break
        rc_indeces.append(rc_index)
        notes.append(note)
    db = {'input': inputs,
          'label': labels,
          'label_mask': label_masks,
          'rc_index': rc_indeces,
          'text': notes}
    torch.save(db, '{}/db'.format(os.path.join(task, split)))

**Sanity Check**

In [None]:
# Decode sample IDX of the most recently exported `db` back into entity names.
IDX = 1
with open(os.path.join(ROOT_DIR, 'entity2id.txt')) as fin:
    id2entity = {
        int(fields[1]) + NUM_SPECIAL_TOKENS: fields[0].split('^^')[0]
        for fields in (line.split('\t') for line in fin.read().splitlines()[1:])
    }
for key, value in db.items():
    print(f'{key}:')
    sample = value[IDX]
    if key == 'input':
        # Skip [CLS] and padding.
        print([id2entity[tok] for tok in sample[1:] if tok != 0])
    elif key == 'rc_index':
        tokens = db['input'][IDX]
        print([(id2entity[tokens[h]], id2entity[tokens[t]], r) for h, t, r in sample])
    else:
        print(sample)

## 2-1-(3). Masked Literal Prediction, Graph Enc, UniKGenc

In [None]:
task = '{}_UniKGenc'.format(ROOT_DIR)
with open(os.path.join(ROOT_DIR, 'train2id.txt')) as fin:
    triples = [x.split() for x in fin.read().splitlines()[1:]]
# (head token id, tail token id) -> relation id; token ids are KG ids shifted past the special tokens.
node2edge = {(int(h)+NUM_SPECIAL_TOKENS, int(t)+NUM_SPECIAL_TOKENS): int(r) for h, t, r in triples}
# Number of distinct relations. Also used as the extra "not connected" class in rc_index.
# Loop-invariant: computed once instead of once per sample (node2edge has one entry per triple).
num_edge_types = len(set(node2edge.values()))

os.makedirs(task, exist_ok=True)

# Re-index literals for labeling: literal token id -> contiguous label in [0, #literals).
literal_id2label = {k: v for (v, k) in enumerate(literals.values())}
torch.save(literal_id2label, '{}/id2label'.format(task))

for split in DB:
    print('[{}] set size : {}'.format(split, len(DB[split])))
    os.makedirs(os.path.join(task, split), exist_ok=True)
    inputs = list()
    masks = list()
    labels = list()
    label_masks = list()
    rc_indeces = list()
    notes = list()

    for head, note in tqdm(DB[split], total=len(DB[split])):
        if head not in subgraphs:
            continue
        subgraph = subgraphs[head]
        # Append input
        inputs.append(subgraph)
        # Append label: literal positions get their class label, the rest are ignored (-100).
        labels.append([literal_id2label.get(x, -100) for x in subgraph])
        label_masks.append([1 if x in literal_id2label else 0 for x in subgraph])
        # Attention mask for the graph encoder: identity, plus a symmetric 1 for every
        # pair of non-padding positions i <= j whose tokens form a KG edge (i -> j).
        mask = torch.eye(len(subgraph))
        real_positions = [pos for pos, tok in enumerate(subgraph) if tok != 0]
        for offset, i in enumerate(real_positions):
            for j in real_positions[offset:]:
                if (subgraph[i], subgraph[j]) in node2edge:
                    mask[i, j] = 1.0
                    mask[j, i] = 1.0
        masks.append(mask)
        # Relation-classification targets: (head position, tail position, relation id).
        # Position 0 is [CLS]; positions 1..num_nodes hold the non-padding nodes.
        num_nodes = sum(1 for x in subgraph[1:] if x != 0)
        rc_index = list()
        head_indeces = [random.randint(1, num_nodes) for _ in range(int(0.1*num_nodes))]
        not_conn = 0
        for head_idx in head_indeces:
            for tail_idx in range(1, num_nodes+1):
                # Randomly skip ~25% of candidate tails; never pair a node with itself.
                if (random.random() < 0.25) or (head_idx == tail_idx):
                    continue
                idx_pair = (head_idx, tail_idx)
                node_pair = (subgraph[head_idx], subgraph[tail_idx])
                inv_node_pair = (subgraph[tail_idx], subgraph[head_idx])
                if node_pair in node2edge:
                    rc_index.append(idx_pair + (node2edge[node_pair],))
                    break
                elif inv_node_pair in node2edge:
                    rc_index.append(idx_pair + (node2edge[inv_node_pair],))
                    break
                elif not_conn < 0.1*num_nodes/(num_edge_types+1):
                    # Budgeted negative sample: unconnected pair labelled with the extra class.
                    rc_index.append(idx_pair + (num_edge_types,))
                    not_conn += 1
                    break
        rc_indeces.append(rc_index)
        # Append note for text encoder
        notes.append(note)

    db = {'input': inputs,
          'mask': masks,
          'label': labels,
          'label_mask': label_masks,
          'text': notes,
          'rc_index': rc_indeces}
    torch.save(db, '{}/db'.format(os.path.join(task, split)))

**Sanity Check**

In [None]:
# Decode sample IDX of the most recently exported `db` back into entity names.
IDX = 1
with open(os.path.join(ROOT_DIR, 'entity2id.txt')) as fin:
    id2entity = {
        int(fields[1]) + NUM_SPECIAL_TOKENS: fields[0].split('^^')[0]
        for fields in (line.split('\t') for line in fin.read().splitlines()[1:])
    }
for key, value in db.items():
    print(f'{key}:')
    sample = value[IDX]
    if key == 'input':
        # Skip [CLS] and padding.
        print([id2entity[tok] for tok in sample[1:] if tok != 0])
    elif key == 'mask':
        # Row 1 of the attention mask (the first KG node after [CLS]).
        print(sample[1])
    elif key == 'rc_index':
        tokens = db['input'][IDX]
        print([(id2entity[tokens[h]], id2entity[tokens[t]], r) for h, t, r in sample])
    else:
        print(sample)

## 2-2-(1). Unified Abstract Embedding, NoKGenc

In [None]:
task = '{}_UnifiedNoKGenc'.format(ROOT_DIR)
with open(os.path.join(ROOT_DIR, 'train2id.txt')) as fin:
    triples = [x.split() for x in fin.read().splitlines()[1:]]
# (head token id, tail token id) -> relation id, keyed by the ORIGINAL (non-unified) token ids.
node2edge = {(int(h)+NUM_SPECIAL_TOKENS, int(t)+NUM_SPECIAL_TOKENS): int(r) for h, t, r in triples}
# Number of distinct relations. Also used as the extra "not connected" class in rc_index.
# Loop-invariant: computed once instead of once per sample.
num_edge_types = len(set(node2edge.values()))

os.makedirs(task, exist_ok=True)

# Literal token id -> label index. It is rebuilt here (and not saved) so this cell
# does not depend on an earlier cell having defined literal_id2label. Below it is
# only used to decide which positions hold literals.
literal_id2label = {k: v for (v, k) in enumerate(literals.values())}

# Abstract Node Unification: each literal keeps its own token. Every other entity
# collapses onto one shared token per abstract type. Types are matched by
# substring in the listed order, so '*_icd9_code' must precede 'diagnoses' and
# 'procedures'.
ABSTRACT_NODE_TYPES = {
    'px': ['hadm', 'prescript', 'icustay'],
    'dxprx': ['hadm', 'diagnoses_icd9_code', 'diagnoses', 'procedures_icd9_code', 'procedures'],
}
if ROOT_DIR not in ABSTRACT_NODE_TYPES:
    raise ValueError('No abstract node types defined for ROOT_DIR={!r}'.format(ROOT_DIR))
abstract_types = ABSTRACT_NODE_TYPES[ROOT_DIR]
node2uninode = {k: k for k in range(NUM_SPECIAL_TOKENS)}
unified_node = {'PAD': 0, 'MASK': 1, 'CLS': 2}
for type_name in abstract_types:
    unified_node[type_name] = len(unified_node)
for key in nodes:
    token_id = int(nodes[key]) + NUM_SPECIAL_TOKENS
    if key in literals:
        node2uninode[token_id] = len(unified_node)
        unified_node[key] = len(unified_node)
        continue
    type_name = next((t for t in abstract_types if t in key), None)
    if type_name is None:
        raise ValueError('Entity {!r} matches no abstract node type'.format(key))
    node2uninode[token_id] = unified_node[type_name]

torch.save(unified_node, '{}/unified_node'.format(task))
torch.save(node2uninode, '{}/node2uninode'.format(task))
print('# Unified nodes : {}'.format(len(unified_node)))

for split in DB:
    print('[{}] set size : {}'.format(split, len(DB[split])))
    os.makedirs(os.path.join(task, split), exist_ok=True)
    inputs = list()
    labels = list()
    label_masks = list()
    rc_indeces = list()
    notes = list()
    for head, note in tqdm(DB[split]):
        if head not in subgraphs:
            continue
        subgraph = subgraphs[head]
        inputs.append([node2uninode[x] for x in subgraph])
        # Literal positions are labelled with their unified-vocabulary id; others ignored (-100).
        labels.append([node2uninode[x] if x in literal_id2label else -100 for x in subgraph])
        label_masks.append([1 if x in literal_id2label else 0 for x in subgraph])
        # Relation-classification targets: (head position, tail position, relation id).
        # Edge lookup uses the original (non-unified) token ids in `subgraph`.
        num_nodes = sum(1 for x in subgraph[1:] if x != 0)
        rc_index = list()
        head_indeces = [random.randint(1, num_nodes) for _ in range(int(0.1*num_nodes))]
        not_conn = 0
        for head_idx in head_indeces:
            for tail_idx in range(1, num_nodes+1):
                # Randomly skip ~25% of candidate tails; never pair a node with itself.
                if (random.random() < 0.25) or (head_idx == tail_idx):
                    continue
                idx_pair = (head_idx, tail_idx)
                node_pair = (subgraph[head_idx], subgraph[tail_idx])
                inv_node_pair = (subgraph[tail_idx], subgraph[head_idx])
                if node_pair in node2edge:
                    rc_index.append(idx_pair + (node2edge[node_pair],))
                    break
                elif inv_node_pair in node2edge:
                    rc_index.append(idx_pair + (node2edge[inv_node_pair],))
                    break
                elif not_conn < 0.1*num_nodes/(num_edge_types+1):
                    # Budgeted negative sample: unconnected pair labelled with the extra class.
                    rc_index.append(idx_pair + (num_edge_types,))
                    not_conn += 1
                    break
        rc_indeces.append(rc_index)
        notes.append(note)
    db = {'input': inputs,
          'label': labels,
          'label_mask': label_masks,
          'rc_index': rc_indeces,
          'text': notes}
    torch.save(db, '{}/db'.format(os.path.join(task, split)))

**Sanity Check**

In [None]:
# Decode sample IDX of the most recently exported `db` into unified-vocabulary names.
IDX = 1
with open(os.path.join(ROOT_DIR, 'entity2id.txt')) as fin:
    id2entity = {
        int(fields[1]) + NUM_SPECIAL_TOKENS: fields[0].split('^^')[0]
        for fields in (line.split('\t') for line in fin.read().splitlines()[1:])
    }
uninode2name = {uni_id: name.split('^^')[0] for name, uni_id in unified_node.items()}
for key, value in db.items():
    print(f'{key}:')
    sample = value[IDX]
    if key == 'input':
        # Padding is skipped; [CLS] is kept.
        print([uninode2name[tok] for tok in sample if tok != 0])
    elif key == 'rc_index':
        tokens = db['input'][IDX]
        print([(uninode2name[tokens[h]], uninode2name[tokens[t]], r) for h, t, r in sample])
    else:
        print(sample)

## 2-2-(2). Unified Abstract Embedding, UniKGenc

In [None]:
task = '{}_UnifiedUniKGenc'.format(ROOT_DIR)
with open(os.path.join(ROOT_DIR, 'train2id.txt')) as fin:
    triples = [x.split() for x in fin.read().splitlines()[1:]]
# (head token id, tail token id) -> relation id, keyed by the ORIGINAL (non-unified) token ids.
node2edge = {(int(h)+NUM_SPECIAL_TOKENS, int(t)+NUM_SPECIAL_TOKENS): int(r) for h, t, r in triples}
# Number of distinct relations. Also used as the extra "not connected" class in rc_index.
# Loop-invariant: computed once instead of once per sample.
num_edge_types = len(set(node2edge.values()))

os.makedirs(task, exist_ok=True)

# Re-index literals for labeling: literal token id -> contiguous label in [0, #literals).
literal_id2label = {k: v for (v, k) in enumerate(literals.values())}
torch.save(literal_id2label, '{}/id2label'.format(task))

# Abstract Node Unification: each literal keeps its own token. Every other entity
# collapses onto one shared token per abstract type. Types are matched by
# substring in the listed order, so '*_icd9_code' must precede 'diagnoses' and
# 'procedures'.
ABSTRACT_NODE_TYPES = {
    'px': ['hadm', 'prescript', 'icustay'],
    'dxprx': ['hadm', 'diagnoses_icd9_code', 'diagnoses', 'procedures_icd9_code', 'procedures'],
}
if ROOT_DIR not in ABSTRACT_NODE_TYPES:
    raise ValueError('No abstract node types defined for ROOT_DIR={!r}'.format(ROOT_DIR))
abstract_types = ABSTRACT_NODE_TYPES[ROOT_DIR]
node2uninode = {k: k for k in range(NUM_SPECIAL_TOKENS)}
unified_node = {'PAD': 0, 'MASK': 1, 'CLS': 2}
for type_name in abstract_types:
    unified_node[type_name] = len(unified_node)
for key in nodes:
    token_id = int(nodes[key]) + NUM_SPECIAL_TOKENS
    if key in literals:
        node2uninode[token_id] = len(unified_node)
        unified_node[key] = len(unified_node)
        continue
    type_name = next((t for t in abstract_types if t in key), None)
    if type_name is None:
        raise ValueError('Entity {!r} matches no abstract node type'.format(key))
    node2uninode[token_id] = unified_node[type_name]

torch.save(unified_node, '{}/unified_node'.format(task))
torch.save(node2uninode, '{}/node2uninode'.format(task))
print('# Unified nodes : {}'.format(len(unified_node)))

for split in DB:
    print('[{}] set size : {}'.format(split, len(DB[split])))
    os.makedirs(os.path.join(task, split), exist_ok=True)
    inputs = list()
    masks = list()
    labels = list()
    label_masks = list()
    rc_indeces = list()
    notes = list()

    for head, note in tqdm(DB[split], total=len(DB[split])):
        if head not in subgraphs:
            continue
        subgraph = subgraphs[head]
        # Append input (unified vocabulary ids)
        inputs.append([node2uninode[x] for x in subgraph])
        # Append label: literal positions get their unified-vocabulary id; others ignored (-100).
        labels.append([node2uninode[x] if x in literal_id2label else -100 for x in subgraph])
        label_masks.append([1 if x in literal_id2label else 0 for x in subgraph])
        # Attention mask for the graph encoder: identity, plus a symmetric 1 for every
        # pair of non-padding positions i <= j whose original tokens form a KG edge (i -> j).
        mask = torch.eye(len(subgraph))
        real_positions = [pos for pos, tok in enumerate(subgraph) if tok != 0]
        for offset, i in enumerate(real_positions):
            for j in real_positions[offset:]:
                if (subgraph[i], subgraph[j]) in node2edge:
                    mask[i, j] = 1.0
                    mask[j, i] = 1.0
        # Relation-classification targets: (head position, tail position, relation id).
        num_nodes = sum(1 for x in subgraph[1:] if x != 0)
        rc_index = list()
        head_indeces = [random.randint(1, num_nodes) for _ in range(int(0.1*num_nodes))]
        not_conn = 0
        for head_idx in head_indeces:
            for tail_idx in range(1, num_nodes+1):
                # Randomly skip ~25% of candidate tails; never pair a node with itself.
                if (random.random() < 0.25) or (head_idx == tail_idx):
                    continue
                idx_pair = (head_idx, tail_idx)
                node_pair = (subgraph[head_idx], subgraph[tail_idx])
                inv_node_pair = (subgraph[tail_idx], subgraph[head_idx])
                if node_pair in node2edge:
                    rc_index.append(idx_pair + (node2edge[node_pair],))
                    break
                elif inv_node_pair in node2edge:
                    rc_index.append(idx_pair + (node2edge[inv_node_pair],))
                    break
                elif not_conn < 0.1*num_nodes/(num_edge_types+1):
                    # Budgeted negative sample: unconnected pair labelled with the extra class.
                    rc_index.append(idx_pair + (num_edge_types,))
                    not_conn += 1
                    break
        rc_indeces.append(rc_index)
        masks.append(mask)
        notes.append(note)
    db = {'input': inputs,
          'mask': masks,
          'label': labels,
          'label_mask': label_masks,
          'rc_index': rc_indeces,
          'text': notes}
    torch.save(db, '{}/db'.format(os.path.join(task, split)))

**Sanity Check**

In [None]:
# Decode sample IDX of the most recently exported `db` into unified-vocabulary names.
IDX = 1
with open(os.path.join(ROOT_DIR, 'entity2id.txt')) as fin:
    id2entity = {
        int(fields[1]) + NUM_SPECIAL_TOKENS: fields[0].split('^^')[0]
        for fields in (line.split('\t') for line in fin.read().splitlines()[1:])
    }
uninode2name = {uni_id: name.split('^^')[0] for name, uni_id in unified_node.items()}
for key, value in db.items():
    print(f'{key}:')
    sample = value[IDX]
    if key == 'input':
        # Padding is skipped; [CLS] is kept.
        print([uninode2name[tok] for tok in sample if tok != 0])
    elif key == 'mask':
        # Row 1 of the attention mask (the first KG node after [CLS]).
        print(sample[1])
    elif key == 'rc_index':
        tokens = db['input'][IDX]
        print([(uninode2name[tokens[h]], uninode2name[tokens[t]], r) for h, t, r in sample])
    else:
        print(sample)