# **1. Run Breadth-First Search on the KG**

The root of each subgraph is an admission (`hadm`) node.

In [9]:
import random

# Directory holding train2id.txt / entity2id.txt / relation2id.txt and the note file;
# also selects the abstract-node scheme in section 2-2 ('px' or 'dxprx').
ROOT_DIR = 'dxprx'
# Offset added to every raw KG entity id so that ids 0..NUM_SPECIAL_TOKENS-1 stay reserved
# for special tokens (PAD=0, MASK=1, CLS=2 in section 2-2).
NUM_SPECIAL_TOKENS = 3
# Maximum number of samples placed in each of the valid/test splits (section 2-0).
eval_size = 500

## 1-1. Non-unified abstract embedding

In [2]:
from tqdm import tqdm
import pickle
import torch
import spacy, scispacy
nlp = spacy.load("en_core_sci_sm")
import numpy as np
import os

def get_childs(subgraph, depth, heads):
    """Collect the children of every node in ``heads`` at level ``depth``.

    ``subgraph[depth]`` maps a node to the list of its children. The result is a
    new flat list that keeps the order of ``heads`` and preserves duplicates.
    """
    return [child for node in heads for child in subgraph[depth][node]]

def _read_body_lines(path):
    """Return the lines of a KG text file without its first line.

    The first line is skipped as in OpenKE-style ``*2id.txt`` files, where it
    presumably holds the record count (TODO confirm for these exports).
    """
    with open(path) as fh:
        return fh.read().splitlines()[1:]


# Build dictionaries from file
# triples: [head_id, tail_id, relation_id] string triples
triples = [x.split() for x in _read_body_lines(os.path.join(ROOT_DIR,'train2id.txt'))]
node2edge = {(h,t):r for h,t,r in triples}
# nodes: entity name -> raw entity id (string), from the tab-separated entity2id.txt
nodes = dict()
for line in _read_body_lines(os.path.join(ROOT_DIR,'entity2id.txt')):
    fields = line.split('\t')
    nodes[fields[0]] = fields[-1]
# literals: entity names carrying a '^^' datatype suffix -> shifted id
literals = {k:int(v)+NUM_SPECIAL_TOKENS for (k,v) in nodes.items() if '^^' in k}
edges = {x.split()[0]:x.split()[1] for x in _read_body_lines(os.path.join(ROOT_DIR,'relation2id.txt'))}

# Extract Admission Nodes (raw ids of every entity whose name contains 'hadm')
adm_node = [idx for name, idx in nodes.items() if 'hadm' in name]

# Initialize subgraph: subgraph_norel[level] maps each node of that level to its children
subgraph_norel = [{node:list() for node in adm_node}]

# Level-by-level (breadth-first) expansion from the admission roots: at each level every
# triple whose head is a node of that level is attached; the rest is retried one level deeper.
print('start preprocessing')
level = 0
while len(triples)>0:
    frontier = subgraph_norel[level]
    queue = list()
    print('level:{}'.format(level))
    for triple in tqdm(triples):
        if triple[0] in frontier:
            frontier[triple[0]].append(triple[1])
        else:
            queue.append(triple)
    print('{}/{}'.format(len(queue),len(triples)))
    if len(queue) == len(triples):
        # Nothing was attached, so every deeper level would be empty as well: the remaining
        # triples are unreachable from the admission nodes. Stop instead of idling to the cap.
        print('{} triples unreachable from admission nodes'.format(len(queue)))
        break
    new_head = [child for children in frontier.values() for child in children]
    subgraph_norel.append({k:list() for k in new_head})
    triples = queue
    level += 1
    if level > 30:
        break

# Build subgraph: flatten each admission's levels into one id sequence, prefixed with the
# CLS id 2 and right-padded with 0, so every kept row has max_len + 1 ids
subgraphs = dict()
max_len = 239
for head in tqdm(list(subgraph_norel[0].keys())):
    seq = [head]
    heads = [head]
    for depth in range(level):
        heads = get_childs(subgraph_norel,depth,heads)
        seq += heads
    if len(seq)>max_len:
        continue  # over-long subgraphs are dropped, not truncated
    subgraphs[head]=[2]+[int(x)+NUM_SPECIAL_TOKENS for x in seq]+[0]*(max_len-len(seq))


# Align subgraph and note, remove unmatched samples.
# p_sections yields (hadm_id, note) pairs; each note's (header, text) items are re-tokenised
# with the spaCy model after stripping double quotes. Samples whose admission has no kept
# subgraph, or whose note cannot be processed, are skipped.
note_aid_pair = list()
skipped_notes = 0
for aid, note in torch.load(os.path.join(ROOT_DIR,'p_sections')):
    adm_id = nodes.get(f'</hadm_id/{aid}>')
    if adm_id not in subgraphs:
        continue
    try:
        note_refined = {header.replace('"',''):' '.join([token.text for token in nlp(text.replace('"',''))]) for header, text in note.items()}
    except (AttributeError, TypeError, ValueError):
        # malformed note (not a mapping of strings) or text rejected by spaCy (e.g. too long)
        skipped_notes += 1
        continue
    note_aid_pair.append((adm_id,note_refined))
print('skipped notes : {}'.format(skipped_notes))
print(len(note_aid_pair))
print(max((len(x) for x in subgraphs.values()), default=0))
print('num_literals : {}'.format(len(literals)))

# Re-indexing nodes in current subgraph after filtering
# TODO: unfinished - new_nodes / old2new are collected but not used yet.
new_nodes = {node for head, _note in note_aid_pair for node in subgraphs[head]}
old2new = dict()


 11%|█         | 126924/1171955 [00:00<00:00, 1269229.11it/s]

start preprocessing
level:0


100%|██████████| 1171955/1171955 [00:00<00:00, 1221176.73it/s]


589909/1171955


 13%|█▎        | 78429/589909 [00:00<00:00, 784277.66it/s]

level:1


100%|██████████| 589909/589909 [00:00<00:00, 800492.31it/s]


7863/589909


100%|██████████| 7863/7863 [00:00<00:00, 745559.23it/s]
  8%|▊         | 2766/32696 [00:00<00:01, 27658.51it/s]

level:2
0/7863


100%|██████████| 32696/32696 [00:01<00:00, 27073.52it/s]


32915
240
num_literals : 7863


In [3]:
# Report the size of the literal table and of the full entity table
for stat_name, collection in (('num_literals', literals), ('num_nodes', nodes)):
    print('{} : {}'.format(stat_name, len(collection)))

num_literals : 7863
num_nodes : 630450


# **2. Build DB**

## 2-0. TVT split

In [10]:
# Build Input: reuse the cached train/valid/test split when it exists; otherwise draw a
# random 80/10/10 split (valid/test capped at eval_size samples each) and cache it.
db_path = '{}.db'.format(ROOT_DIR)
if os.path.exists(db_path):
    DB = torch.load(db_path)
else:
    DB = {'train': [], 'valid': [], 'test': []}
    for sample in tqdm(note_aid_pair):
        split = np.random.choice(list(DB.keys()), p=[0.8, 0.1, 0.1])
        # overflow of a full evaluation split is sent to train
        if split in ['valid', 'test'] and len(DB[split]) >= eval_size:
            split = 'train'
        DB[split].append(sample)
    torch.save(DB, db_path)

## 2-1-(1). Masked Literal Prediction

In [11]:
task = '{}_NoKGenc'.format(ROOT_DIR)
with open(os.path.join(ROOT_DIR,'train2id.txt')) as fh:
    triples = [x.split() for x in fh.read().splitlines()[1:]]
# (head, tail) in the shifted id space used by `subgraphs` -> relation id
node2edge = {(int(h)+NUM_SPECIAL_TOKENS,int(t)+NUM_SPECIAL_TOKENS):int(r) for h,t,r in triples}
# Relation label for sampled pairs that are not directly connected: the number of distinct
# relation ids (the first unused id if relation ids are contiguous from 0). Computed once
# here rather than re-scanning every edge for each sampled pair.
no_rel_label = len(set(node2edge.values()))

os.makedirs(task, exist_ok=True)

# Re-index literals for labeling: shifted literal id -> class index 0..num_literals-1
literal_id2label = {k:v for (v,k) in enumerate(literals.values())}
torch.save(literal_id2label,'{}/id2label'.format(task))

for split in DB:
    print('[{}] set size : {}'.format(split, len(DB[split])))
    os.makedirs(os.path.join(task,split), exist_ok=True)
    inputs = list()
    labels = list()
    label_masks = list()
    rc_indeces = list()
    notes = list()
    for head, note in tqdm(DB[split]):
        subgraph = subgraphs[head]
        inputs.append(subgraph)
        # Literal positions get their class index; every other position is ignored (-100)
        labels.append([literal_id2label.get(x, -100) for x in subgraph])
        label_masks.append([1 if x in literal_id2label else 0 for x in subgraph])
        # Add RC index for sample (presumably relation classification): about 10% of the
        # non-padding positions as (i, j, relation) triples, of which at most ~5% are
        # unconnected pairs labelled no_rel_label.
        num_nodes = sum(1 for x in subgraph if x != 0)
        rc_index = list()
        not_conn = 0
        while len(rc_index) < 0.1*num_nodes:
            i = random.randint(0,num_nodes-1)
            j = random.randint(0,num_nodes-1)
            if i == j:
                continue
            a, b = subgraph[i], subgraph[j]
            if (a, b) in node2edge:
                rc_index.append((i, j, node2edge[(a, b)]))
            elif (b, a) in node2edge:
                rc_index.append((i, j, node2edge[(b, a)]))
            elif not_conn < 0.05*num_nodes:
                rc_index.append((i, j, no_rel_label))
                not_conn += 1
        rc_indeces.append(rc_index)
        notes.append(note)
    db = {'input':inputs,
          'label':labels,
          'label_mask':label_masks,
          'rc_index':rc_indeces,
          'text':notes}
    torch.save(db,'{}/db'.format(os.path.join(task,split)))

 20%|██        | 2/10 [00:00<00:00, 12.49it/s]

[valid] set size : 10


100%|██████████| 10/10 [00:00<00:00, 13.19it/s]
 20%|██        | 2/10 [00:00<00:00, 17.06it/s]

[test] set size : 10


100%|██████████| 10/10 [00:00<00:00, 15.99it/s]


**Sanity Check**

In [16]:
IDX = 1
# shifted entity id -> entity name (literal datatype suffix after '^^' stripped)
with open(os.path.join(ROOT_DIR,'entity2id.txt')) as fh:
    id2entity = {int(line.split('\t')[1])+NUM_SPECIAL_TOKENS:line.split('\t')[0].split('^^')[0] for line in fh.read().splitlines()[1:]}
sample_input = db['input'][IDX]
for k, v in db.items():
    print(f'{k}:')
    if k=='input':
        # skip the leading CLS token and the zero padding
        print([id2entity[x] for x in v[IDX][1:] if x!=0])
    elif k=='rc_index':
        # sampled pairs may include position 0 (the CLS token), which is not a KG entity
        print([(id2entity.get(sample_input[h], '[CLS]'),id2entity.get(sample_input[t], '[CLS]'),r) for h,t,r in v[IDX]])
    else:
        print(v[IDX])

input:
['</hadm_id/129414>', '</diagnoses/406609>', '</diagnoses/406605>', '</diagnoses/406610>', '</diagnoses/406613>', '</diagnoses/406606>', '</diagnoses/406612>', '</diagnoses/406616>', '</diagnoses/406608>', '</diagnoses/406615>', '</diagnoses/406614>', '</diagnoses/406611>', '</diagnoses/406607>', '</diagnoses_icd9_code/2662>', '</diagnoses_icd9_code/30000>', '</diagnoses_icd9_code/486>', '</diagnoses_icd9_code/79029>', '</diagnoses_icd9_code/4019>', '</diagnoses_icd9_code/27652>', '</diagnoses_icd9_code/2859>', '</diagnoses_icd9_code/49390>', '</diagnoses_icd9_code/2761>', '</diagnoses_icd9_code/32723>', '</diagnoses_icd9_code/27800>', '</diagnoses_icd9_code/30500>', '"other b-complex deficiencies"', '"anxiety state, unspecified"', '"pneumonia, organism unspecified"', '"other abnormal glucose"', '"unspecified essential hypertension"', '"hypovolemia"', '"anemia, unspecified"', '"asthma, unspecified type, unspecified"', '"hyposmolality and/or hyponatremia"', '"obstructive sleep ap

## 2-1-(3). Masked Literal Prediction, Graph Enc, UniKGenc

In [18]:
task = '{}_UniKGenc'.format(ROOT_DIR)
with open(os.path.join(ROOT_DIR,'train2id.txt')) as fh:
    triples = [x.split() for x in fh.read().splitlines()[1:]]
# (head, tail) in the shifted id space used by `subgraphs` -> relation id
node2edge = {(int(h)+NUM_SPECIAL_TOKENS,int(t)+NUM_SPECIAL_TOKENS):int(r) for h,t,r in triples}
# Relation label for sampled pairs that are not directly connected, computed once
# (number of distinct relation ids = first unused id if ids are contiguous from 0)
no_rel_label = len(set(node2edge.values()))

os.makedirs(task, exist_ok=True)

# Re-index literals for labeling: shifted literal id -> class index 0..num_literals-1
literal_id2label = {k:v for (v,k) in enumerate(literals.values())}
torch.save(literal_id2label,'{}/id2label'.format(task))

for split in DB:
    print('[{}] set size : {}'.format(split, len(DB[split])))
    os.makedirs(os.path.join(task,split), exist_ok=True)
    inputs = list()
    masks = list()
    labels = list()
    label_masks = list()
    rc_indeces = list()
    notes = list()

    for head, note in tqdm(DB[split],total=len(DB[split])):
        subgraph = subgraphs[head]
        # Append input
        inputs.append(subgraph)
        # Append label: literal class index at literal positions, -100 elsewhere
        labels.append([literal_id2label.get(x, -100) for x in subgraph])
        label_masks.append([1 if x in literal_id2label else 0 for x in subgraph])
        # Attention mask for graph encoder: identity, plus 1 (both ways) for every pair of
        # non-padding positions joined by a KG edge stored in either direction
        seq_len = len(subgraph)
        mask = torch.eye(seq_len)
        for i in range(seq_len):
            a = subgraph[i]
            if a == 0:
                continue
            for j in range(i, seq_len):
                b = subgraph[j]
                if b == 0:
                    continue
                if (a, b) in node2edge or (b, a) in node2edge:
                    mask[i, j] = 1.0
                    mask[j, i] = 1.0
        masks.append(mask)
        # Add RC index for sample: about 10% of the non-padding positions as
        # (i, j, relation) triples, of which at most ~5% are unconnected pairs
        num_nodes = sum(1 for x in subgraph if x != 0)
        rc_index = list()
        not_conn = 0
        while len(rc_index) < 0.1*num_nodes:
            i = random.randint(0,num_nodes-1)
            j = random.randint(0,num_nodes-1)
            if i == j:
                continue
            a, b = subgraph[i], subgraph[j]
            if (a, b) in node2edge:
                rc_index.append((i, j, node2edge[(a, b)]))
            elif (b, a) in node2edge:
                rc_index.append((i, j, node2edge[(b, a)]))
            elif not_conn < 0.05*num_nodes:
                rc_index.append((i, j, no_rel_label))
                not_conn += 1
        rc_indeces.append(rc_index)
        # Append note for text encoder
        notes.append(note)

    db = {'input':inputs,
          'mask':masks,
          'label':labels,
          'label_mask':label_masks,
          'text':notes,
          'rc_index':rc_indeces}
    torch.save(db,'{}/db'.format(os.path.join(task,split)))

 10%|█         | 1/10 [00:00<00:00,  9.59it/s]

[valid] set size : 10


100%|██████████| 10/10 [00:00<00:00, 11.82it/s]
 20%|██        | 2/10 [00:00<00:00, 15.35it/s]

[test] set size : 10


100%|██████████| 10/10 [00:00<00:00, 14.58it/s]


**Sanity Check**

In [21]:
IDX = 1
# shifted entity id -> entity name (literal datatype suffix after '^^' stripped)
with open(os.path.join(ROOT_DIR,'entity2id.txt')) as fh:
    id2entity = {int(line.split('\t')[1])+NUM_SPECIAL_TOKENS:line.split('\t')[0].split('^^')[0] for line in fh.read().splitlines()[1:]}
sample_input = db['input'][IDX]
for k, v in db.items():
    print(f'{k}:')
    if k=='input':
        # skip the leading CLS token and the zero padding
        print([id2entity[x] for x in v[IDX][1:] if x!=0])
    elif k=='mask':
        # attention row of position 1 (the admission node)
        print(v[IDX][1])
    elif k=='rc_index':
        # sampled pairs may include position 0 (the CLS token), which is not a KG entity
        print([(id2entity.get(sample_input[h], '[CLS]'),id2entity.get(sample_input[t], '[CLS]'),r) for h,t,r in v[IDX]])
    else:
        print(v[IDX])

input:
['</hadm_id/129414>', '</diagnoses/406609>', '</diagnoses/406605>', '</diagnoses/406610>', '</diagnoses/406613>', '</diagnoses/406606>', '</diagnoses/406612>', '</diagnoses/406616>', '</diagnoses/406608>', '</diagnoses/406615>', '</diagnoses/406614>', '</diagnoses/406611>', '</diagnoses/406607>', '</diagnoses_icd9_code/2662>', '</diagnoses_icd9_code/30000>', '</diagnoses_icd9_code/486>', '</diagnoses_icd9_code/79029>', '</diagnoses_icd9_code/4019>', '</diagnoses_icd9_code/27652>', '</diagnoses_icd9_code/2859>', '</diagnoses_icd9_code/49390>', '</diagnoses_icd9_code/2761>', '</diagnoses_icd9_code/32723>', '</diagnoses_icd9_code/27800>', '</diagnoses_icd9_code/30500>', '"other b-complex deficiencies"', '"anxiety state, unspecified"', '"pneumonia, organism unspecified"', '"other abnormal glucose"', '"unspecified essential hypertension"', '"hypovolemia"', '"anemia, unspecified"', '"asthma, unspecified type, unspecified"', '"hyposmolality and/or hyponatremia"', '"obstructive sleep ap

## 2-2-(1). Unified Abstract Embedding, NoKGenc

In [22]:
task = '{}_UnifiedNoKGenc'.format(ROOT_DIR)
with open(os.path.join(ROOT_DIR,'train2id.txt')) as fh:
    triples = [x.split() for x in fh.read().splitlines()[1:]]
# (head, tail) in the shifted id space used by `subgraphs` -> relation id
node2edge = {(int(h)+NUM_SPECIAL_TOKENS,int(t)+NUM_SPECIAL_TOKENS):int(r) for h,t,r in triples}
# Relation label for sampled pairs that are not directly connected, computed once
no_rel_label = len(set(node2edge.values()))

os.makedirs(task, exist_ok=True)

# Labels in this task are unified-node ids, so only the set of shifted literal ids is
# needed to decide which positions are labelled (no id2label file is written).
literal_ids = set(literals.values())

# Abstract Node Unification: each literal keeps its own unified id; every other node is
# collapsed onto one id per abstract type. The first type name contained in the node name
# wins, so '..._icd9_code' types are listed before the shorter name they contain.
abstract_types_by_dir = {
    'px': ['hadm', 'prescript', 'icustay'],
    'dxprx': ['hadm', 'diagnoses_icd9_code', 'diagnoses', 'procedures_icd9_code', 'procedures'],
}
if ROOT_DIR not in abstract_types_by_dir:
    raise ValueError('no abstract node types defined for ROOT_DIR={!r}'.format(ROOT_DIR))
abstract_types = abstract_types_by_dir[ROOT_DIR]
unified_node = {'PAD':0,'MASK':1,'CLS':2}
for type_name in abstract_types:
    unified_node[type_name] = len(unified_node)
node2uninode = {k:k for k in range(NUM_SPECIAL_TOKENS)}
for key in nodes:
    node_id = int(nodes[key])+NUM_SPECIAL_TOKENS
    if key in literals:
        node2uninode[node_id] = len(unified_node)
        unified_node[key] = len(unified_node)
        continue
    type_name = next((t for t in abstract_types if t in key), None)
    if type_name is None:
        raise ValueError('node {!r} matches no abstract type'.format(key))
    node2uninode[node_id] = unified_node[type_name]

torch.save(unified_node,'{}/unified_node'.format(task))
torch.save(node2uninode,'{}/node2uninode'.format(task))
print('# Unified nodes : {}'.format(len(unified_node)))

for split in DB:
    print('[{}] set size : {}'.format(split, len(DB[split])))
    os.makedirs(os.path.join(task,split), exist_ok=True)
    inputs = list()
    labels = list()
    label_masks = list()
    rc_indeces = list()
    notes = list()
    for head, note in tqdm(DB[split]):
        subgraph = subgraphs[head]
        # Inputs and labels use unified ids; non-literal positions are ignored (-100)
        inputs.append([node2uninode[x] for x in subgraph])
        labels.append([node2uninode[x] if x in literal_ids else -100 for x in subgraph])
        label_masks.append([1 if x in literal_ids else 0 for x in subgraph])
        # Add RC index for sample (computed on the raw shifted ids): about 10% of the
        # non-padding positions as (i, j, relation) triples, at most ~5% unconnected
        num_nodes = sum(1 for x in subgraph if x != 0)
        rc_index = list()
        not_conn = 0
        while len(rc_index) < 0.1*num_nodes:
            i = random.randint(0,num_nodes-1)
            j = random.randint(0,num_nodes-1)
            if i == j:
                continue
            a, b = subgraph[i], subgraph[j]
            if (a, b) in node2edge:
                rc_index.append((i, j, node2edge[(a, b)]))
            elif (b, a) in node2edge:
                rc_index.append((i, j, node2edge[(b, a)]))
            elif not_conn < 0.05*num_nodes:
                rc_index.append((i, j, no_rel_label))
                not_conn += 1
        rc_indeces.append(rc_index)
        notes.append(note)
    db = {'input':inputs,
          'label':labels,
          'label_mask':label_masks,
          'rc_index':rc_indeces,
          'text':notes}
    torch.save(db,'{}/db'.format(os.path.join(task,split)))

# Unified nodes : 7871
[valid] set size : 10
[test] set size : 10


**Sanity Check**

In [26]:
IDX = 1
# unified id -> display name (literal datatype suffix after '^^' stripped); special tokens
# and abstract types are included, so every unified input id has a name
uninode2name = {v:k.split('^^')[0] for k,v in unified_node.items()}
for k, v in db.items():
    print(f'{k}:')
    if k=='input':
        # drop the zero padding
        print([uninode2name[x] for x in v[IDX] if x!=0])
    elif k=='rc_index':
        print([(uninode2name[db['input'][IDX][h]],uninode2name[db['input'][IDX][t]],r) for h,t,r in v[IDX]])
    else:
        print(v[IDX])

input:
['CLS', 'hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses_icd9_code', 'diagnoses_icd9_code', 'diagnoses_icd9_code', 'diagnoses_icd9_code', 'diagnoses_icd9_code', 'diagnoses_icd9_code', 'diagnoses_icd9_code', 'diagnoses_icd9_code', 'diagnoses_icd9_code', 'diagnoses_icd9_code', 'diagnoses_icd9_code', 'diagnoses_icd9_code', '"other b-complex deficiencies"', '"anxiety state, unspecified"', '"pneumonia, organism unspecified"', '"other abnormal glucose"', '"unspecified essential hypertension"', '"hypovolemia"', '"anemia, unspecified"', '"asthma, unspecified type, unspecified"', '"hyposmolality and/or hyponatremia"', '"obstructive sleep apnea (adult)(pediatric)"', '"obesity, unspecified"', '"alcohol abuse, unspecified"']
label:
[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100

## 2-2-(2). Unified Abstract Embedding, UniKGenc

In [27]:
task = '{}_UnifiedUniKGenc'.format(ROOT_DIR)
with open(os.path.join(ROOT_DIR,'train2id.txt')) as fh:
    triples = [x.split() for x in fh.read().splitlines()[1:]]
# (head, tail) in the shifted id space used by `subgraphs` -> relation id
node2edge = {(int(h)+NUM_SPECIAL_TOKENS,int(t)+NUM_SPECIAL_TOKENS):int(r) for h,t,r in triples}
# Relation label for sampled pairs that are not directly connected, computed once
no_rel_label = len(set(node2edge.values()))

os.makedirs(task, exist_ok=True)

# Re-index literals for labeling: shifted literal id -> class index 0..num_literals-1
literal_id2label = {k:v for (v,k) in enumerate(literals.values())}
torch.save(literal_id2label,'{}/id2label'.format(task))

# Abstract Node Unification: each literal keeps its own unified id; every other node is
# collapsed onto one id per abstract type. The first type name contained in the node name
# wins, so '..._icd9_code' types are listed before the shorter name they contain.
abstract_types_by_dir = {
    'px': ['hadm', 'prescript', 'icustay'],
    'dxprx': ['hadm', 'diagnoses_icd9_code', 'diagnoses', 'procedures_icd9_code', 'procedures'],
}
if ROOT_DIR not in abstract_types_by_dir:
    raise ValueError('no abstract node types defined for ROOT_DIR={!r}'.format(ROOT_DIR))
abstract_types = abstract_types_by_dir[ROOT_DIR]
unified_node = {'PAD':0,'MASK':1,'CLS':2}
for type_name in abstract_types:
    unified_node[type_name] = len(unified_node)
node2uninode = {k:k for k in range(NUM_SPECIAL_TOKENS)}
for key in nodes:
    node_id = int(nodes[key])+NUM_SPECIAL_TOKENS
    if key in literals:
        node2uninode[node_id] = len(unified_node)
        unified_node[key] = len(unified_node)
        continue
    type_name = next((t for t in abstract_types if t in key), None)
    if type_name is None:
        raise ValueError('node {!r} matches no abstract type'.format(key))
    node2uninode[node_id] = unified_node[type_name]

torch.save(unified_node,'{}/unified_node'.format(task))
torch.save(node2uninode,'{}/node2uninode'.format(task))
print('# Unified nodes : {}'.format(len(unified_node)))

for split in DB:
    print('[{}] set size : {}'.format(split, len(DB[split])))
    os.makedirs(os.path.join(task,split), exist_ok=True)
    inputs = list()
    masks = list()
    labels = list()
    label_masks = list()
    rc_indeces = list()
    notes = list()

    for head, note in tqdm(DB[split],total=len(DB[split])):
        subgraph = subgraphs[head]
        # Append input (unified ids)
        inputs.append([node2uninode[x] for x in subgraph])
        # Append label: unified id at literal positions, -100 elsewhere
        labels.append([node2uninode[x] if x in literal_id2label else -100 for x in subgraph])
        label_masks.append([1 if x in literal_id2label else 0 for x in subgraph])
        # Attention mask for graph encoder (on the raw shifted ids): identity, plus 1 (both
        # ways) for every pair of non-padding positions joined by a KG edge in either direction
        seq_len = len(subgraph)
        mask = torch.eye(seq_len)
        for i in range(seq_len):
            a = subgraph[i]
            if a == 0:
                continue
            for j in range(i, seq_len):
                b = subgraph[j]
                if b == 0:
                    continue
                if (a, b) in node2edge or (b, a) in node2edge:
                    mask[i, j] = 1.0
                    mask[j, i] = 1.0
        # Add RC index for sample: about 10% of the non-padding positions as
        # (i, j, relation) triples, of which at most ~5% are unconnected pairs
        num_nodes = sum(1 for x in subgraph if x != 0)
        rc_index = list()
        not_conn = 0
        while len(rc_index) < 0.1*num_nodes:
            i = random.randint(0,num_nodes-1)
            j = random.randint(0,num_nodes-1)
            if i == j:
                continue
            a, b = subgraph[i], subgraph[j]
            if (a, b) in node2edge:
                rc_index.append((i, j, node2edge[(a, b)]))
            elif (b, a) in node2edge:
                rc_index.append((i, j, node2edge[(b, a)]))
            elif not_conn < 0.05*num_nodes:
                rc_index.append((i, j, no_rel_label))
                not_conn += 1
        rc_indeces.append(rc_index)
        masks.append(mask)
        notes.append(note)
    db = {'input':inputs,
          'mask':masks,
          'label':labels,
          'label_mask':label_masks,
          'rc_index':rc_indeces,
          'text':notes}
    torch.save(db,'{}/db'.format(os.path.join(task,split)))

 10%|█         | 1/10 [00:00<00:00,  9.88it/s]

# Unified nodes : 7871
[valid] set size : 10


100%|██████████| 10/10 [00:00<00:00, 12.09it/s]
 20%|██        | 2/10 [00:00<00:00, 15.56it/s]

[test] set size : 10


100%|██████████| 10/10 [00:00<00:00, 14.37it/s]


**Sanity Check**

In [28]:
IDX = 1
# unified id -> display name (literal datatype suffix after '^^' stripped); special tokens
# and abstract types are included, so every unified input id has a name
uninode2name = {v:k.split('^^')[0] for k,v in unified_node.items()}
for k, v in db.items():
    print(f'{k}:')
    if k=='input':
        # drop the zero padding
        print([uninode2name[x] for x in v[IDX] if x!=0])
    elif k=='mask':
        # attention row of position 1 (the admission node)
        print(v[IDX][1])
    elif k=='rc_index':
        print([(uninode2name[db['input'][IDX][h]],uninode2name[db['input'][IDX][t]],r) for h,t,r in v[IDX]])
    else:
        print(v[IDX])

input:
['CLS', 'hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses_icd9_code', 'diagnoses_icd9_code', 'diagnoses_icd9_code', 'diagnoses_icd9_code', 'diagnoses_icd9_code', 'diagnoses_icd9_code', 'diagnoses_icd9_code', 'diagnoses_icd9_code', 'diagnoses_icd9_code', 'diagnoses_icd9_code', 'diagnoses_icd9_code', 'diagnoses_icd9_code', '"other b-complex deficiencies"', '"anxiety state, unspecified"', '"pneumonia, organism unspecified"', '"other abnormal glucose"', '"unspecified essential hypertension"', '"hypovolemia"', '"anemia, unspecified"', '"asthma, unspecified type, unspecified"', '"hyposmolality and/or hyponatremia"', '"obstructive sleep apnea (adult)(pediatric)"', '"obesity, unspecified"', '"alcohol abuse, unspecified"']
mask:
tensor([0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

---

## 2-3. Literal Bucket Prediction _(in progress)_

In [None]:
task = 'masked_literal_prediction'
os.makedirs(task, exist_ok=True)
# Build Input: a fresh random 80/10/10 split (this rebinds the notebook-wide DB)
DB = {'train':[],'valid':[],'test':[]}
n_samples = len(note_aid_pair)
for sample in note_aid_pair:
    split = np.random.choice(list(DB.keys()),p=[0.8,0.1,0.1])
    # Keep valid/test at roughly 10% and train at roughly 80% of the samples, judged by
    # the number of samples already placed in the chosen split
    if (len(DB[split])>0.1*n_samples) and (split in ['valid', 'test']):
        split = 'train'
    elif (len(DB[split])>0.8*n_samples) and (split in ['train']):
        split = np.random.choice(['valid','test'],p=[0.5,0.5])
    DB[split].append(sample)

# Re-index literals for labeling: shifted literal id -> class index 0..num_literals-1
literal_id2label = {k:v for (v,k) in enumerate(literals.values())}
literal_values = list(literals.values())

# Load Bucket ID: shifted literal id -> bucket label
literalID2bucketID = torch.load('literalID2bucketID')

for split in DB:
    print('[{}] set size : {}'.format(split, len(DB[split])))
    split_dir = os.path.join(task,split)
    os.makedirs(split_dir, exist_ok=True)
    torch.save(literal_id2label,'{}/id2label'.format(split_dir))
    torch.save([note for (head,note) in DB[split]],'{}/note'.format(split_dir))
    # mask: 1 at non-literal positions, 0 at literal positions
    torch.save({'input':[subgraphs[head] for (head,note) in DB[split]],
                'mask':[(~np.isin(np.array(subgraphs[head]),literal_values)).astype(np.int64).tolist() for (head,note) in DB[split]],
                'label':[[literalID2bucketID[x] if x in literalID2bucketID else -100 for x in subgraphs[head]] for (head,note) in DB[split]]},
               '{}/{}/kg_norel'.format(task,split))

## 2-4. Contrastive Learning _(in progress)_

In [None]:
def negative_sampling(input, mask):
    """Draw negative samples for contrastive learning (not implemented yet).

    Args:
        input: a subgraph id sequence (presumably one entry of the 'input' field - TODO confirm).
        mask: the matching mask (presumably the 'mask' field - TODO confirm).

    Raises:
        NotImplementedError: always; section 2-4 is still in progress.
    """
    raise NotImplementedError('negative_sampling is not implemented yet')


# NOTE(review): this writes to the same 'masked_literal_prediction' directory as section 2-3
# and will overwrite its files - confirm the intended task name.
task = 'masked_literal_prediction'
os.makedirs(task, exist_ok=True)
# Build Input: a fresh random 80/10/10 split (this rebinds the notebook-wide DB)
DB = {'train':[],'valid':[],'test':[]}
n_samples = len(note_aid_pair)
for sample in note_aid_pair:
    split = np.random.choice(list(DB.keys()),p=[0.8,0.1,0.1])
    # Keep valid/test at roughly 10% and train at roughly 80% of the samples, judged by
    # the number of samples already placed in the chosen split
    if (len(DB[split])>0.1*n_samples) and (split in ['valid', 'test']):
        split = 'train'
    elif (len(DB[split])>0.8*n_samples) and (split in ['train']):
        split = np.random.choice(['valid','test'],p=[0.5,0.5])
    DB[split].append(sample)

# Re-index literals for labeling: shifted literal id -> class index 0..num_literals-1
literal_id2label = {k:v for (v,k) in enumerate(literals.values())}
literal_values = list(literals.values())

for split in DB:
    print('[{}] set size : {}'.format(split, len(DB[split])))
    split_dir = os.path.join(task,split)
    os.makedirs(split_dir, exist_ok=True)
    torch.save(literal_id2label,'{}/id2label'.format(split_dir))
    torch.save([note for (head,note) in DB[split]],'{}/note'.format(split_dir))
    # mask: 1 at non-literal positions, 0 at literal positions
    torch.save({'input':[subgraphs[head] for (head,note) in DB[split]],
                'mask':[(~np.isin(np.array(subgraphs[head]),literal_values)).astype(np.int64).tolist() for (head,note) in DB[split]],
                'label':[[literal_id2label.get(x, -100) for x in subgraphs[head]] for (head,note) in DB[split]]},
               '{}/{}/kg_norel'.format(task,split))

**Supp 1. Save DB in torch.tensor format**

In [None]:
# Only for DB in tensor form: stack the padded id sequence of every kept subgraph into one
# LongTensor. Assumes `subgraphs` is the dict built in section 1 (equal-length rows of
# max_len + 1 ids); iterating the dict itself would yield its keys (admission ids).
tensorized_subgraphs = torch.LongTensor(list(subgraphs.values()))
print(max_len)
print(len(subgraphs))
print(tensorized_subgraphs[0,:20])
print('Saving...')
torch.save(tensorized_subgraphs,'subgraph_norel')
print('Done')

**Supp 2. Check input for debugging purpose**

In [None]:
split = 'test'
# First sample of the split; named sample_input so the builtin `input` stays available
sample_input = subgraphs[DB[split][0][0]]
# 1 at non-literal positions, 0 at literal positions
mask = (~np.isin(np.array(sample_input),list(literals.values()))).astype(np.int64).tolist()
label = [literal_id2label[x] if x in literal_id2label else -100 for x in sample_input]
print(sample_input)
print(mask)
print(label)
print(len(literals))
print(list(literals.items())[-1])
print(list(literal_id2label.items())[-1])

**Supp 3. Run Breadth-First Search on the KG (Nodes & Relations)**

The root of each subgraph is an admission (`hadm`) node.

In [None]:
from tqdm import tqdm
import pickle
import torch
import numpy as np
import os

# NOTE(review): this rebinds the notebook-wide NUM_SPECIAL_TOKENS (3 in section 1) to 2;
# any cell run after this one (e.g. 2-1-(2) MultiKGenc) shifts ids by 2 - confirm intended.
NUM_SPECIAL_TOKENS = 2

def get_childs_withrel(subgraph, depth, heads,node2edge):
    """Expand ``heads`` by one level, interleaving relation tokens with children.

    Returns ``(seq, children)``. For every (head, child) edge, in order, ``seq``
    holds the token ``'r' + node2edge[(head, child)]`` followed by the child.
    ``children`` is the flat list of children, i.e. the next level's heads.
    """
    seq = []
    children = []
    for head in heads:
        kids = subgraph[depth][head]
        for kid in kids:
            seq.extend(('r' + node2edge[(head, kid)], kid))
        children.extend(kids)
    return seq, children

# NOTE(review): this cell rebinds triples / nodes / literals / subgraphs / note_aid_pair,
# which the section-2 cells read - rerun section 1 before using them again.
with open('train2id.txt') as fh:
    triples = [x.split() for x in fh.read().splitlines()[1:]]
node2edge = {(h,t):r for h,t,r in triples}
with open('entity2id.txt') as fh:
    # The entity name may contain spaces; its id is the last whitespace-separated field
    nodes = {' '.join(x.split()[:-1]):x.split()[-1] for x in fh.read().splitlines()[1:]}
# Literal nodes are the entities whose name carries a '^^' datatype suffix
literals = {k:int(v)+NUM_SPECIAL_TOKENS for (k,v) in nodes.items() if '^^' in k}
with open('relation2id.txt') as fh:
    edges = {x.split()[0]:x.split()[1] for x in fh.read().splitlines()[1:]}

# Extract Admission Nodes (raw ids of every entity whose name contains 'hadm')
adm_node = [idx for name, idx in nodes.items() if 'hadm' in name]

# Initialize subgraph: subgraph_norel[level] maps each node of that level to its children
subgraph_norel = [{node:list() for node in adm_node}]

# Level-by-level (breadth-first) expansion from the admission roots
print('start preprocessing')
level = 0
while len(triples)>0:
    frontier = subgraph_norel[level]
    queue = list()
    print('level:{}'.format(level))
    for triple in tqdm(triples):
        if triple[0] in frontier:
            frontier[triple[0]].append(triple[1])
        else:
            queue.append(triple)
    print('{}/{}'.format(len(queue),len(triples)))
    if len(queue) == len(triples):
        # Nothing attached: the remaining triples are unreachable from the admission
        # nodes and would otherwise keep this loop running forever.
        print('{} triples unreachable from admission nodes'.format(len(queue)))
        break
    new_head = [child for children in frontier.values() for child in children]
    subgraph_norel.append({k:list() for k in new_head})
    triples = queue
    level += 1

# Build subgraph: flatten each admission's levels into one sequence of node ids (shifted by
# NUM_SPECIAL_TOKENS) interleaved with relation tokens 'r<id>', encoded as -(id + 1)
subgraphs = list()
max_len = 0
for head in tqdm(list(subgraph_norel[0].keys())):
    seq = [head]
    heads = [head]
    for depth in range(level):
        seqs, heads = get_childs_withrel(subgraph_norel,depth,heads,node2edge)
        seq += seqs
    subgraphs.append([-(int(x.split('r')[-1])+1) if 'r' in x else int(x)+NUM_SPECIAL_TOKENS for x in seq])
    max_len = max(max_len, len(seq))

# Align subgraph and note. Non-empty id lines and non-empty note lines are paired first
# and filtered afterwards, so an id missing from the KG cannot shift later notes onto the
# wrong admission (assumes the non-empty lines of both files correspond 1:1 - TODO confirm).
subgraph_heads = set(subgraph_norel[0])
with open('p_hadm_ids.txt') as fh:
    raw_ids = [x for x in fh.read().splitlines() if len(x)>0]
with open('p_sections.txt') as fh:
    note = [x for x in fh.read().splitlines() if len(x)>0]
aid = [nodes['</hadm_id/{}>'.format(x)] for x in raw_ids if '</hadm_id/{}>'.format(x) in nodes]
note_aid_pair = list()
for raw_id, text in zip(raw_ids, note):
    adm_id = nodes.get('</hadm_id/{}>'.format(raw_id))
    if adm_id in subgraph_heads:
        note_aid_pair.append((adm_id, text))
print('{}/{}'.format(len(aid),len(adm_node)))
print(len(note))
print(len(note_aid_pair))

## 2-1-(2). Masked Literal Prediction, Graph Enc, MultiKGenc

In [None]:
task = '{}_MultiKGenc'.format(ROOT_DIR)
# True: only mark head -> tail (earlier -> later position); False: mark both directions
UniDirectional = False
with open(os.path.join(ROOT_DIR,'train2id.txt')) as fh:
    triples = [x.split() for x in fh.read().splitlines()[1:]]
# Relation ids are stored as int so they can index the relation axis of the mask
node2edge = {(int(h)+NUM_SPECIAL_TOKENS,int(t)+NUM_SPECIAL_TOKENS):int(r) for h,t,r in triples}

os.makedirs(task, exist_ok=True)

# Re-index literals for labeling: shifted literal id -> class index 0..num_literals-1
literal_id2label = {k:v for (v,k) in enumerate(literals.values())}
torch.save(literal_id2label,'{}/id2label'.format(task))

for split in DB:
    print('[{}] set size : {}'.format(split, len(DB[split])))
    os.makedirs(os.path.join(task,split), exist_ok=True)
    inputs = list()
    masks = list()
    labels = list()
    label_masks = list()
    notes = list()

    for head, note in tqdm(DB[split],total=len(DB[split])):
        subgraph = subgraphs[head]
        # Append input
        inputs.append(subgraph)
        # Append label: literal class index at literal positions, -100 elsewhere
        labels.append([literal_id2label.get(x, -100) for x in subgraph])
        label_masks.append([1 if x in literal_id2label else 0 for x in subgraph])
        # Per-relation attention mask for the graph encoder, shape (seq_len, seq_len, len(edges)):
        # the identity on every relation slice, plus 1 at [i, j, r] for each KG edge of relation r
        seq_len = len(subgraph)
        mask = torch.eye(seq_len).unsqueeze(2).repeat(1, 1, len(edges))
        for i, a in enumerate(subgraph):
            if a == 0:
                continue
            for j in range(i, seq_len):
                b = subgraph[j]
                if b == 0 or (a, b) not in node2edge:
                    continue
                rel = node2edge[(a, b)]
                mask[i, j, rel] = 1.0
                if not UniDirectional:
                    mask[j, i, rel] = 1.0
        masks.append(mask)
        notes.append(note)

    torch.save({'input':inputs,
                'mask':masks,
                'label':labels,
                'label_mask':label_masks,
                'text':notes},
                '{}/db'.format(os.path.join(task,split)))