# **1. Run Breadth-First (level-by-level) Search on KG (Node only)** 

Root is an admission node.

In [1]:
# Directory that holds the KG files (train2id.txt, entity2id.txt,
# relation2id.txt) and the note files (p_hadm_ids.txt, p_sections.txt).
# It is also used as the prefix of the saved '<ROOT_DIR>.db' split file and
# of the per-task output directories.
ROOT_DIR = 'px'

In [2]:
from tqdm import tqdm
import pickle
import torch
import numpy as np
import os

# Offset added to every raw entity id when it is turned into an input id,
# which keeps the lowest ids free. Id 0 is used as the padding value when
# subgraph sequences are padded to max_len.
NUM_SPECIAL_TOKENS = 2

def get_childs(subgraph, depth, heads):
    """Collect the children of `heads` at one level of the layered graph.

    Args:
        subgraph: sequence of dicts, one per level, each mapping a node to
            the list of its child nodes.
        depth: index of the level to read from.
        heads: nodes whose children are wanted, in order.

    Returns:
        A new flat list holding each head's children, concatenated in the
        order of `heads`. Repeated heads contribute their children again.

    Raises:
        KeyError: if a head is not a key of `subgraph[depth]`.
    """
    return [child for head in heads for child in subgraph[depth][head]]

# Build dictionaries from file
def _read_rows(filename):
    """Return the whitespace-split fields of each line of a ROOT_DIR file.

    The first line of the file is skipped. The file is closed as soon as it
    has been read.
    """
    with open(os.path.join(ROOT_DIR, filename)) as f:
        return [line.split() for line in f.read().splitlines()[1:]]

# One three-field row per line of train2id.txt, unpacked below as (h, t, r)
triples = _read_rows('train2id.txt')
# (h, t) -> r; if a pair occurs more than once, the last row wins
node2edge = {(h, t): r for h, t, r in triples}
# Entity name (every field but the last, re-joined with single spaces) -> id string
nodes = {' '.join(row[:-1]): row[-1] for row in _read_rows('entity2id.txt')}
# Literals are the entities whose name contains '^^'; ids are shifted by NUM_SPECIAL_TOKENS
literals = {k: int(v) + NUM_SPECIAL_TOKENS for k, v in nodes.items() if '^^' in k}
# Relation name (first field) -> relation id string (second field)
edges = {row[0]: row[1] for row in _read_rows('relation2id.txt')}

# Collect the ids of admission nodes, i.e. entities whose name contains 'hadm'
adm_node = list()
for node in nodes.items():
    if 'hadm' not in node[0]:
        continue
    adm_node.append(node[1])

# Level 0 of the layered subgraph: every admission node starts with no children
subgraph_norel = [{root: list() for root in adm_node}]

# Level-by-level expansion starting from the admission nodes.
# Each pass moves every triple whose head sits in the current level into that
# head's child list; the remaining triples are retried at the next level.
print('start preprocessing')
level = 0
while len(triples) > 0:
    print('level:{}'.format(level))
    frontier = subgraph_norel[level]
    queue = list()
    for triple in tqdm(triples):
        if triple[0] in frontier:
            frontier[triple[0]].append(triple[1])
        else:
            queue.append(triple)
    print('{}/{}'.format(len(queue), len(triples)))
    # The tails collected at this level become the heads of the next one
    new_head = [tail for tails in frontier.values() for tail in tails]
    subgraph_norel.append({k: list() for k in new_head})
    triples = queue
    level += 1
    # Hard cap on the number of levels so the loop always ends
    if level > 30:
        break

# Flatten each admission's layered neighbourhood into one fixed-length id sequence
subgraphs = dict()
max_len = 768
cnt = 0
for head in tqdm(list(subgraph_norel[0].keys())):
    seq = [head]
    heads = [head]
    # Append the nodes of every level in turn, root first
    for depth in range(level):
        heads = get_childs(subgraph_norel, depth, heads)
        seq += heads
    # Admissions whose sequence exceeds max_len are dropped, not truncated
    if len(seq) <= max_len:
        padding = [0] * (max_len - len(seq))
        subgraphs[head] = [int(x) + NUM_SPECIAL_TOKENS for x in seq] + padding

# Align subgraph and note, remove unmatched samples.
# Blank lines are dropped from each file independently; after that, line i of
# p_hadm_ids.txt is taken to belong to line i of p_sections.txt.
with open(os.path.join(ROOT_DIR, 'p_hadm_ids.txt')) as f:
    hadm_ids = [x for x in f.read().splitlines() if len(x) > 0]
with open(os.path.join(ROOT_DIR, 'p_sections.txt')) as f:
    note = [x for x in f.read().splitlines() if len(x) > 0]
if len(hadm_ids) != len(note):
    print('WARNING: {} admission ids vs {} notes; entries past the shorter list are ignored'.format(len(hadm_ids), len(note)))

# Pair ids with notes BEFORE dropping admissions that are missing from the KG.
# Filtering the id list on its own and zipping afterwards would shift every
# later note onto the wrong admission.
id_note_pair = list()
for hadm_id, text in zip(hadm_ids, note):
    key = '</hadm_id/{}>'.format(hadm_id)
    if key in nodes:
        id_note_pair.append((nodes[key], text))
aid = [node_id for node_id, _ in id_note_pair]
# Keep only admissions that still have a subgraph (over-long ones were dropped)
note_aid_pair = [(node_id, text) for node_id, text in id_note_pair if node_id in subgraphs]
print('{}/{}'.format(len(aid),len(adm_node)))
print(len(note))
print(len(note_aid_pair))
# default=0 keeps this summary from raising when no subgraph survived
print(max((len(x) for x in subgraphs.values()), default=0))
print('num_literals : {}'.format(len(literals.items())))

# Gather the distinct ids (padding included) that occur in the kept subgraphs
new_nodes = set()
for head, note in note_aid_pair:
    new_nodes.update(subgraphs[head])
# Old-id -> new-id table; it is only created here and left empty
old2new = dict()


  0%|          | 75321/19499414 [00:00<00:25, 753208.92it/s]

start preprocessing
level:0


100%|██████████| 19499414/19499414 [00:14<00:00, 1373368.82it/s]


16561968/19499414


  0%|          | 40451/16561968 [00:00<00:40, 404502.67it/s]

level:1


100%|██████████| 16561968/16561968 [00:29<00:00, 553861.96it/s]


0/16561968


100%|██████████| 33425/33425 [00:08<00:00, 4089.93it/s]


33684/33425
33692
26168
768


In [4]:
# Report how many entities are literals and how many entities exist in total
print('num_literals : {}'.format(len(literals)))
print('num_nodes : {}'.format(len(nodes)))

num_literals : 9291
num_nodes : 3015342


# **2. Build DB**

## 2-0. TVT split

In [None]:
# Build Input: randomly assign each (admission id, note) pair to a split, or
# reuse the assignment saved in '<ROOT_DIR>.db' when that file already exists.
if not os.path.exists('{}.db'.format(ROOT_DIR)):
    DB = {'train':[],'valid':[],'test':[]}
    num_samples = len(note_aid_pair)
    for sample in tqdm(note_aid_pair):
        # str() turns the numpy string scalar into a plain str key
        split = str(np.random.choice(list(DB.keys()),p=[0.8,0.1,0.1]))
        # Cap each split at its target share. The quantity to compare is the
        # number of samples already in the split, len(DB[split]); len(split)
        # would only measure the length of the split's name.
        if (len(DB[split])>0.1*num_samples) and (split in ['valid', 'test']):
            split = 'train'
        elif (len(DB[split])>0.8*num_samples) and (split in ['train']):
            split = str(np.random.choice(['valid','test'],p=[0.5,0.5]))
        DB[split].append(sample)
    torch.save(DB,'{}.db'.format(ROOT_DIR))
else:
    DB = torch.load('{}.db'.format(ROOT_DIR))

## 2-1-(1). Masked Literal Prediction

In [None]:
# Output directory for the variant without a graph encoder
task = '{}_NoKGenc'.format(ROOT_DIR)
if not os.path.isdir(task):
    os.mkdir(task)

# Map each literal's shifted node id to its position in literals.values()
literal_id2label = {literal_id: label for label, literal_id in enumerate(literals.values())}
torch.save(literal_id2label,'{}/id2label'.format(task))

for split in DB:
    print('[{}] set size : {}'.format(split, len(DB[split])))
    split_dir = os.path.join(task,split)
    if not os.path.isdir(split_dir):
        os.mkdir(split_dir)
    inputs, labels, label_masks, notes = list(), list(), list(), list()
    for head, note in DB[split]:
        node_ids = subgraphs[head]
        inputs.append(node_ids)
        # Literal positions carry their label; every other position carries -100
        labels.append([literal_id2label.get(x, -100) for x in node_ids])
        # 1 where the position holds a literal, 0 elsewhere
        label_masks.append([1 if x in literal_id2label else 0 for x in node_ids])
        notes.append(note)
    payload = {'input':inputs,
               'label':labels,
               'label_mask':label_masks,
               'text':notes}
    torch.save(payload, '{}/db'.format(split_dir))

## 2-1-(2). Masked Literal Prediction, Graph Enc, Multi-relational

In [None]:
# Output directory for the multi-relational graph-encoder variant
task = '{}_MultiKGenc'.format(ROOT_DIR)
# False: every edge is written in both directions; True: only head -> tail
UniDirectional = False
with open(os.path.join(ROOT_DIR,'train2id.txt')) as f:
    triples = [x.split() for x in f.read().splitlines()[1:]]
# Keys use the shifted node ids (raw id + 2) that appear in `subgraphs`;
# values stay the relation id strings read from the file.
node2edge = {(int(h)+2,int(t)+2):r for h,t,r in triples}

if not os.path.isdir(task):
    os.mkdir(task)

# Re-index literals for labeling
literal_id2label = {k:v for (v,k) in enumerate(list(literals.values()))}
torch.save(literal_id2label,'{}/id2label'.format(task))

num_relations = len(edges)

for split in DB:
    print('[{}] set size : {}'.format(split, len(DB[split])))
    if not os.path.isdir(os.path.join(task,split)):
        os.mkdir(os.path.join(task,split))
    inputs = list()
    masks = list()
    labels = list()
    label_masks = list()
    notes = list()

    for head, note in tqdm(DB[split],total=len(DB[split])):
        subgraph = subgraphs[head]
        # Append input
        inputs.append(subgraph)
        # Append label: literal positions get their label, the rest get -100
        labels.append([literal_id2label.get(x, -100) for x in subgraph])
        label_masks.append([1 if x in literal_id2label else 0 for x in subgraph])
        # Append attention mask for graph encoder: one identity matrix per
        # relation, stacked on the last axis -> (len, len, num_relations)
        mask = torch.stack([torch.eye(len(subgraph)) for _ in range(num_relations)],dim=2)
        for head_idx, head_id in enumerate(subgraph):
            # 0 is padding; shifted node ids are never 0, so padding cannot
            # be an endpoint of an edge
            if head_id == 0:
                continue
            # Only positions at or after head_idx are visited (upper triangle)
            for tail_idx in range(head_idx, len(subgraph)):
                tail_id = subgraph[tail_idx]
                if tail_id == 0:
                    continue
                relation = node2edge.get((head_id, tail_id))
                if relation is None:
                    continue
                # Relation ids are strings in node2edge; a tensor index must
                # be an int in both the uni- and the bi-directional case
                relation = int(relation)
                mask[head_idx, tail_idx, relation] = 1.0
                if not UniDirectional:
                    mask[tail_idx, head_idx, relation] = 1.0
        masks.append(mask)
        notes.append(note)

    torch.save({'input':inputs,
                'mask':masks,
                'label':labels,
                'label_mask':label_masks,
                'text':notes},
                '{}/db'.format(os.path.join(task,split)))

## 2-1-(3). Masked Literal Prediction, Graph Enc, Homogeneous

In [None]:
# Output directory for the homogeneous (relation-agnostic) graph-encoder variant
task = '{}_UniKGenc'.format(ROOT_DIR)
triples = [x.split() for x in open(os.path.join(ROOT_DIR,'train2id.txt')).read().splitlines()[1:]]
# Keys use the shifted node ids (raw id + 2) that appear in `subgraphs`
node2edge = {(int(h)+2,int(t)+2):r for h,t,r in triples}

if not os.path.isdir(task):
    os.mkdir(task)

# Map each literal's shifted node id to its position in literals.values()
literal_id2label = {literal_id: label for label, literal_id in enumerate(literals.values())}
torch.save(literal_id2label,'{}/id2label'.format(task))

for split in DB:
    print('[{}] set size : {}'.format(split, len(DB[split])))
    split_dir = os.path.join(task,split)
    if not os.path.isdir(split_dir):
        os.mkdir(split_dir)
    inputs, masks, labels, label_masks, notes = list(), list(), list(), list(), list()

    for head, note in tqdm(DB[split],total=len(DB[split])):
        subgraph = subgraphs[head]
        inputs.append(subgraph)
        # Literal positions carry their label; every other position carries -100
        labels.append([literal_id2label.get(x, -100) for x in subgraph])
        label_masks.append([1 if x in literal_id2label else 0 for x in subgraph])
        # Square mask: the diagonal plus a symmetric 1 for every connected pair
        mask = torch.eye(len(subgraph))
        for head_idx, head_id in enumerate(subgraph):
            if head_id == 0:
                continue
            for tail_idx in range(head_idx, len(subgraph)):
                tail_id = subgraph[tail_idx]
                if tail_id != 0 and (head_id, tail_id) in node2edge:
                    mask[head_idx, tail_idx] = 1.0
                    mask[tail_idx, head_idx] = 1.0
        masks.append(mask)
        notes.append(note)

    payload = {'input':inputs,
               'mask':masks,
               'label':labels,
               'label_mask':label_masks,
               'text':notes}
    torch.save(payload, '{}/db'.format(split_dir))

**Sanity Check**

In [None]:
# Index of the sample to inspect in the most recently built split
IDX = 0
# Shifted entity id -> entity name with any '^^...' suffix cut off
id2entity = dict()
for line in open(os.path.join(ROOT_DIR,'entity2id.txt')).read().splitlines()[1:]:
    fields = line.split('\t')
    id2entity[int(fields[1])+2] = fields[0].split('^^')[0]
sample_ids = inputs[IDX]
print(notes[IDX])
print([id2entity[x] for x in sample_ids if x!=0])
print(labels[IDX])
print(label_masks[IDX])
print(masks[IDX].shape)
# True at every non-padding position
print(torch.tensor(sample_ids) != 0)

## 2-2. Literal Bucket Prediction _(ongoing..)_

In [None]:
task = 'masked_literal_prediction'
if not os.path.isdir(task):
    os.mkdir(task)
# Build Input: randomly assign each (admission id, note) pair to a split
DB = {'train':[],'valid':[],'test':[]}
num_samples = len(note_aid_pair)
for sample in note_aid_pair:
    # str() turns the numpy string scalar into a plain str key
    split = str(np.random.choice(list(DB.keys()),p=[0.8,0.1,0.1]))
    # Cap each split at its target share. The quantity to compare is the
    # number of samples already in the split, len(DB[split]); len(split)
    # would only measure the length of the split's name.
    if (len(DB[split])>0.1*num_samples) and (split in ['valid', 'test']):
        split = 'train'
    elif (len(DB[split])>0.8*num_samples) and (split in ['train']):
        split = str(np.random.choice(['valid','test'],p=[0.5,0.5]))
    DB[split].append(sample)

# Load Bucket ID
literalID2bucketID = torch.load('literalID2bucketID')

# Built once: the same id list is needed for every sample's mask
literal_ids = list(literals.values())

for split in DB:
    print('[{}] set size : {}'.format(split, len(DB[split])))
    if not os.path.isdir(os.path.join(task,split)):
        os.mkdir(os.path.join(task,split))
    torch.save(literal_id2label,'{}/id2label'.format(os.path.join(task,split)))
    torch.save([note for (head,note) in DB[split]],'{}/note'.format(os.path.join(task,split)))
    # 'mask' is 0 at literal positions and 1 elsewhere; 'label' is the bucket
    # id at positions found in literalID2bucketID and -100 elsewhere
    torch.save({'input':[subgraphs[head] for (head,note) in DB[split]],
                'mask':[(~np.isin(np.array(subgraphs[head]),literal_ids)).astype(np.int64).tolist() for (head,note) in DB[split]],
                'label':[[literalID2bucketID[x] if x in literalID2bucketID else -100 for x in subgraphs[head]] for (head,note) in DB[split]]},
               '{}/{}/kg_norel'.format(task,split))

## 2-3. Contrastive Learning _(ongoing..)_

In [None]:
def negative_sampling(input, mask):
    """Placeholder for the negative-sampling step of contrastive learning.

    The function had no body, which made the whole cell a syntax error.
    It now fails loudly if it is called before being written.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError('negative_sampling is not implemented yet')

task = 'masked_literal_prediction'
if not os.path.isdir(task):
    os.mkdir(task)
# Build Input: randomly assign each (admission id, note) pair to a split
DB = {'train':[],'valid':[],'test':[]}
num_samples = len(note_aid_pair)
for sample in note_aid_pair:
    # str() turns the numpy string scalar into a plain str key
    split = str(np.random.choice(list(DB.keys()),p=[0.8,0.1,0.1]))
    # Cap each split at its target share. The quantity to compare is the
    # number of samples already in the split, len(DB[split]); len(split)
    # would only measure the length of the split's name.
    if (len(DB[split])>0.1*num_samples) and (split in ['valid', 'test']):
        split = 'train'
    elif (len(DB[split])>0.8*num_samples) and (split in ['train']):
        split = str(np.random.choice(['valid','test'],p=[0.5,0.5]))
    DB[split].append(sample)

# Re-index literals for labeling
literal_id2label = {k:v for (v,k) in enumerate(list(literals.values()))}

# Built once: the same id list is needed for every sample's mask
literal_ids = list(literals.values())

for split in DB:
    print('[{}] set size : {}'.format(split, len(DB[split])))
    if not os.path.isdir(os.path.join(task,split)):
        os.mkdir(os.path.join(task,split))
    torch.save(literal_id2label,'{}/id2label'.format(os.path.join(task,split)))
    torch.save([note for (head,note) in DB[split]],'{}/note'.format(os.path.join(task,split)))
    # 'mask' is 0 at literal positions and 1 elsewhere; 'label' is the
    # literal's label at literal positions and -100 elsewhere
    torch.save({'input':[subgraphs[head] for (head,note) in DB[split]],
                'mask':[(~np.isin(np.array(subgraphs[head]),literal_ids)).astype(np.int64).tolist() for (head,note) in DB[split]],
                'label':[[literal_id2label.get(x, -100) for x in subgraphs[head]] for (head,note) in DB[split]]},
               '{}/{}/kg_norel'.format(task,split))

**Supp 1. Save DB in torch.tensor format**

In [None]:
# Only for DB in tensor form
# Get id sequence of notes

print(subgraphs)
# When `subgraphs` is the dict built in section 1 (admission id -> padded id
# sequence), iterating it yields the keys, so the sequences must be taken
# from .values(). A plain list of sequences is used as it is.
if isinstance(subgraphs, dict):
    subgraph_rows = list(subgraphs.values())
else:
    subgraph_rows = list(subgraphs)
tensorized_subgraphs = torch.LongTensor(subgraph_rows)
print(max_len)
print(len(subgraphs))
print(tensorized_subgraphs[0,:20])
print('Saving...')
torch.save(tensorized_subgraphs,'subgraph_norel')
print('Done')

**Supp 2. Check input for debugging purpose**

In [None]:
# Inspect the first sample of one split
split = 'test'
first_sample = subgraphs[DB[split][0][0]]
input = first_sample
# 0 at literal positions, 1 everywhere else
is_literal = np.isin(np.array(first_sample), list(literals.values()))
mask = (~is_literal).astype(np.int64).tolist()
# Literal positions carry their label; every other position carries -100
label = [literal_id2label.get(x, -100) for x in first_sample]
print(input)
print(mask)
print(label)
#print(literals)
print(len(literals))
print(list(literals.items())[-1])
print(list(literal_id2label.items())[-1])
#list(literals.values())[:5]

**Supp 3. Run Breadth-First (level-by-level) Search on KG (Node & Relation)**

Root is an admission node.

In [None]:
from tqdm import tqdm
import pickle
import torch
import numpy as np
import os

# Offset added to every raw entity id when it is converted to an integer id
# in this cell (literal ids and the node tokens of each sequence).
NUM_SPECIAL_TOKENS = 2

def get_childs_withrel(subgraph, depth, heads,node2edge):
    """Collect relation-tagged children of `heads` at one level of the graph.

    Args:
        subgraph: sequence of dicts, one per level, each mapping a node to
            the list of its child nodes.
        depth: index of the level to read from.
        heads: nodes whose children are wanted, in order.
        node2edge: dict mapping a (head, tail) pair to a relation id string.

    Returns:
        A pair (sequence, next_heads). `sequence` alternates 'r' + relation
        id and the child node for every (head, child) edge, in order;
        `next_heads` is the flat list of children.

    Raises:
        KeyError: if a head is missing from `subgraph[depth]` or an edge is
            missing from `node2edge`.
    """
    sequence = []
    next_heads = []
    for head in heads:
        children = subgraph[depth][head]
        for tail in children:
            sequence.append('r' + node2edge[(head, tail)])
            sequence.append(tail)
        next_heads.extend(children)
    return sequence, next_heads

# Build dictionaries from the files in the current working directory.
# The first line of each file is skipped.
with open('train2id.txt') as f:
    triples = [x.split() for x in f.read().splitlines()[1:]]
# (h, t) -> r; if a pair occurs more than once, the last row wins
node2edge = {(h,t):r for h,t,r in triples}
with open('entity2id.txt') as f:
    entity_rows = [x.split() for x in f.read().splitlines()[1:]]
# Entity name (every field but the last, re-joined with single spaces) -> id string
nodes = {' '.join(row[:-1]):row[-1] for row in entity_rows}
# Literals are the entities whose own name `k` contains '^^'. The test must
# use `k`: `node` is not bound by this comprehension.
literals = {k:int(v)+NUM_SPECIAL_TOKENS for (k,v) in nodes.items() if '^^' in k}
with open('relation2id.txt') as f:
    relation_rows = [x.split() for x in f.read().splitlines()[1:]]
# Relation name (first field) -> relation id string (second field)
edges = {row[0]:row[1] for row in relation_rows}

# Collect the ids of admission nodes, i.e. entities whose name contains 'hadm'
adm_node = [node[1] for node in nodes.items() if 'hadm' in node[0]]

# Level 0 of the layered subgraph: every admission node starts with no children
subgraph_norel = [{root: list() for root in adm_node}]

#subgraph_rel = dict(adm_node)
#node_dict = list(subgraph_norel.keys())

# Level-by-level expansion starting from the admission nodes.
# Each pass moves every triple whose head sits in the current level into that
# head's child list; the remaining triples are retried at the next level.
print('start preprocessing')
level = 0
while len(triples)>0:
    queue = list()
    print('level:{}'.format(level))
    frontier = subgraph_norel[level]
    for triple in tqdm(triples):
        if triple[0] in frontier:
            frontier[triple[0]].append(triple[1])
        else:
            queue.append(triple)
    print('{}/{}'.format(len(queue),len(triples)))
    made_progress = len(queue) < len(triples)
    # The tails collected at this level become the heads of the next one
    new_head = list()
    for heads in frontier.values():
        new_head+=heads
    subgraph_norel.append({k:list() for k in new_head})
    triples = queue
    level += 1
    # If no triple was consumed, the next level is empty and the leftover
    # triples can never be reached; without this exit the loop never ends.
    if not made_progress:
        print('stopping: {} triples are unreachable from the admission nodes'.format(len(triples)))
        break

# Flatten each admission's layered neighbourhood into one integer sequence in
# which relation tokens and node tokens alternate after the root
subgraphs = list()
max_len = 0
for head in tqdm(list(subgraph_norel[0].keys())):
    seq = [head]
    heads = [head]
    for depth in range(level):
        seqs, heads = get_childs_withrel(subgraph_norel,depth,heads,node2edge)
        seq += seqs
    encoded = list()
    for token in seq:
        if 'r' in token:
            # Relation token 'r<id>' -> negative id -(id + 1)
            encoded.append(-(int(token.split('r')[-1])+1))
        else:
            # Node token -> raw id shifted by NUM_SPECIAL_TOKENS
            encoded.append(int(token)+NUM_SPECIAL_TOKENS)
    subgraphs.append(encoded)
    # Track the longest sequence seen so far
    max_len = max(max_len, len(seq))

# Align subgraph and note.
# Blank lines are dropped from each file independently; after that, line i of
# p_hadm_ids.txt is taken to belong to line i of p_sections.txt.
with open('p_hadm_ids.txt') as f:
    hadm_ids = [x for x in f.read().splitlines() if len(x) > 0]
with open('p_sections.txt') as f:
    note = [x for x in f.read().splitlines() if len(x) > 0]
if len(hadm_ids) != len(note):
    print('WARNING: {} admission ids vs {} notes; entries past the shorter list are ignored'.format(len(hadm_ids), len(note)))

# Pair ids with notes BEFORE dropping admissions that are missing from the KG.
# Filtering the id list on its own and zipping afterwards would shift every
# later note onto the wrong admission.
id_note_pair = list()
for hadm_id, text in zip(hadm_ids, note):
    key = '</hadm_id/{}>'.format(hadm_id)
    if key in nodes:
        id_note_pair.append((nodes[key], text))
aid = [node_id for node_id, _ in id_note_pair]
# In this cell `subgraphs` is a list of integer sequences, one per key of
# subgraph_norel[0], so "has a subgraph" is tested against those keys.
# Testing an id string against the list itself never matches.
roots = subgraph_norel[0]
note_aid_pair = [(node_id, text) for node_id, text in id_note_pair if node_id in roots]
print('{}/{}'.format(len(aid),len(adm_node)))
print(len(note))
print(len(note_aid_pair))