### General

In [1]:
import os
import sys

import re
import numpy as np
import pandas as pd
import json
import random
import copy
import h5py
from tqdm import tqdm

In [2]:
# Project root. Defaults to the original local checkout, but can be overridden with the
# PROJECT_PATH environment variable so the notebook runs on other machines unchanged.
path = os.environ.get('PROJECT_PATH', '/Users/soulofshadow/Downloads/Project')

# Relative paths used further down (e.g. 'data/...') are resolved against the project root.
os.chdir(path)
# Guard the append so re-running this cell does not keep adding duplicate entries.
if path not in sys.path:
    sys.path.append(path)

### Load data

In [44]:
# Load the text -> entity table produced earlier in the pipeline.
text_to_entity_csv = path + '/data/text_to_entity.csv'
dataset = pd.read_csv(text_to_entity_csv)

In [45]:
# Turn every cell of the 'entity' column from its string form back into a Python object
# (later cells iterate over it and read the 'cui' and 'preferred_name' keys of each element).
# NOTE(review): eval executes whatever expression the CSV cell contains, so only run this on
# a trusted file. ast.literal_eval would be a safer choice if the cells are plain literals —
# TODO confirm against the file contents before switching.
dataset['entity'] = dataset['entity'].apply(eval)

In [None]:
'''
Disabled one-off step: reads the relations from the UMLS .RRF files and writes
umls_triplet.csv. It only has to run once; later cells read the generated file.

from utils.load_umls import UMLS
umls = UMLS(path + '/data')

# build the triplets
umls_triplet = []

for rel in umls.rel:
    triplet = rel.strip().split("\t")
    # skip malformed rows: exactly four tab-separated fields are expected
    if len(triplet) != 4:
        continue

    head = triplet[0]
    tail = triplet[1]
    relation = triplet[3]

    umls_triplet.append([head, tail, relation])

print("triplets count:", len(umls_triplet))

umls_triplet = pd.DataFrame(umls_triplet, columns=['head', 'tail', 'relation'])
# drop_duplicates returns a new frame, so the result has to be kept
umls_triplet = umls_triplet.drop_duplicates()

# written to the location the next cell reads from
umls_triplet.to_csv(path + '/data/build_tri/umls_triplet.csv', index=False)

'''

In [None]:
#transfer from str CUI to int CUI
#like "C1022345" to 1022345
def tran_to_index(x):
    """Convert a string CUI such as "C1022345" to its integer form 1022345.

    Args:
        x: a CUI string with a one-character prefix, or a value that is already an
            integer (Python int or NumPy integer, e.g. a value taken from a pandas column).

    Returns:
        The integer CUI. Integer inputs are returned unchanged, so the function is safe
        to apply more than once to the same data.

    Raises:
        ValueError: if the text after the first character is not a valid integer.
    """
    # NumPy integer scalars are not instances of int, so they are checked explicitly;
    # otherwise x[1:] would fail on an already-converted value.
    if isinstance(x, (int, np.integer)):
        return x
    return int(x[1:])


# Read the UMLS relations and convert both CUI columns to integers.
umls_triplets = pd.read_csv(path + '/data/build_tri/umls_triplet.csv')

for cui_column in ('head', 'tail'):
    umls_triplets[cui_column] = umls_triplets[cui_column].map(tran_to_index)

# Lookup table (head CUI, tail CUI) -> relation. When a pair occurs several times in the
# file, the relation of its last occurrence is the one that is kept.
cui_pairs = zip(umls_triplets['head'], umls_triplets['tail'])
dict_triplets = dict(zip(cui_pairs, umls_triplets['relation']))

print("triplets count:", len(umls_triplets))

In [46]:
# Preview the first five converted triplets.
umls_triplets.head(5)

Unnamed: 0,head,tail,relation
0,934536,505381,inverse_isa
1,11008,943468,has_component
2,60617,533932,mapped_to
3,1442351,486234,has_system
4,2700007,1514011,is_abnormal_cell_of_disease


In [47]:
# Define the relations_map (index -> relation name) and its inverse relations_r_map.
umls_relations = set(umls_triplets['relation'])

# The indices are written to build_cui_triplets.json and decoded again later, so they must be
# identical in every kernel session. Iterating a set of strings gives a session-dependent
# order; sorting fixes the numbering. key=str keeps the sort working if a relation is missing
# (NaN) in the CSV.
relations_map = dict(enumerate(sorted(umls_relations, key=str)))
relations_r_map = {relation: idx for idx, relation in relations_map.items()}

print("There are total {} kinds of relationship".format(len(umls_relations)))

There are total 608 kinds of relationship


### Align

In [106]:
# Align every text row with the UMLS triplets that connect two entities mentioned in it.
build_entities = {}          # int CUI -> preferred name of every entity seen
build_triplets = set()       # (head name, relation name, tail name)
build_cui_triplets = set()   # (head CUI, relation index, tail CUI)

# One list of aligned string triplets per row, in row order. The lists are collected here and
# assigned as a column after the loop: writing to the row object yielded by iterrows() is not
# guaranteed to reach the DataFrame, which would leave the 'triplet' column empty.
aligned_per_row = []

for index, item in tqdm(dataset.iterrows(), total=len(dataset)):

    # Work on a deep copy so converting the CUIs does not alter dataset['entity'].
    entities = copy.deepcopy(item['entity'])

    for entity in entities:
        entity['cui'] = tran_to_index(entity['cui'])
        # keep the first preferred name seen for each CUI
        build_entities.setdefault(entity['cui'], entity['preferred_name'])

    # Try every ordered pair of distinct entities of this row against the UMLS triplets.
    aligned_triplets = []
    for entity1 in entities:
        for entity2 in entities:
            if entity1['cui'] == entity2['cui']:
                continue

            cui_pair = (entity1['cui'], entity2['cui'])
            if cui_pair not in dict_triplets:
                continue

            # A relation exists: label the row with it. The row keeps the readable names,
            # while the CUI/index form is stored separately for the subgraph extraction.
            rel = dict_triplets[cui_pair]

            cui_triplet = (entity1['cui'], relations_r_map[rel], entity2['cui'])
            triplet = [entity1['preferred_name'], rel, entity2['preferred_name']]

            build_cui_triplets.add(cui_triplet)
            build_triplets.add(tuple(triplet))
            aligned_triplets.append(triplet)

    aligned_per_row.append(aligned_triplets)

# Rows without any aligned triplet get an empty list.
dataset['triplet'] = pd.Series(aligned_per_row, index=dataset.index, dtype=object)

27it [00:01, 21.65it/s]


In [89]:
# Persist the aligned entities and triplets as JSON. The sets are converted to lists because
# JSON has no set type; the integer keys of build_entities are written as strings.
json_outputs = [
    (path + "/data/build_tri/build_entities.json", build_entities),
    (path + "/data/build_tri/build_triplets.json", list(build_triplets)),
    (path + "/data/build_tri/build_cui_triplets.json", list(build_cui_triplets)),
]

for target_file, payload in json_outputs:
    with open(target_file, "w") as json_file:
        json.dump(payload, json_file)


In [None]:
# Serialize every (head, relation, tail) triplet into one plain-text phrase.
# Each triplet is a sequence of three strings, so it is joined first and the underscores
# (e.g. in relation names such as "has_component") are then replaced by spaces. Calling
# .split on the triplet itself fails because tuples and lists have no split method.
triplets = [' '.join(triplet).replace('_', ' ') for triplet in build_triplets]

with open(path + "/data/triplets.json", "w") as json_file:
    json.dump(triplets, json_file)

### Extract subgraph

In [108]:
# Reload the entities and triplets saved by the alignment step.

with open(path + "/data/build_tri/build_entities.json", "r") as json_file:
    # JSON object keys are strings; convert them back to integer CUIs.
    build_entities = {int(cui): name for cui, name in json.load(json_file).items()}

with open(path + "/data/build_tri/build_triplets.json", "r") as json_file:
    build_triplets = json.load(json_file)

with open(path + "/data/build_tri/build_cui_triplets.json", "r") as json_file:
    build_cui_triplets = json.load(json_file)

In [80]:
'''
Co-occurrence statistics of relations, stored as a dict
{(rel1, rel2) : count}
'''
rel_pairs = {}

def count_rel_pairs(triplets):
    """Update rel_pairs with every pair of relations that occur together in `triplets`.

    Args:
        triplets: sequence of (head, relation, tail) items; only the relation name at
            position 1 is read and mapped to its index through relations_r_map.

    A pair of different relations is stored under the orientation in which it was first
    seen and its count is incremented on later occurrences in either orientation.
    """
    # Relation indices are used as keys instead of the relation names.
    rels = [relations_r_map[triplet[1]] for triplet in triplets]

    for position, rel_i in enumerate(rels):
        for rel_j in rels[position + 1:]:
            forward = (rel_i, rel_j)
            backward = (rel_j, rel_i)
            different = rel_i != rel_j

            if different and forward in rel_pairs:
                rel_pairs[forward] += 1
            elif different and backward in rel_pairs:
                rel_pairs[backward] += 1
            else:
                # First occurrence of the pair.
                # NOTE(review): a relation paired with itself always lands here, so its
                # count stays at 1 however often it occurs — confirm this is intended.
                rel_pairs[forward] = 1

In [81]:
# Accumulate the relation co-occurrence counts over every aligned row.
for index, item in tqdm(dataset.iterrows()):

    triplets = item['triplet'] #column 3
    # rows that were never aligned carry None and contribute nothing
    if triplets is not None:
        count_rel_pairs(triplets)

27it [00:00, 9733.24it/s]


In [82]:
import heapq

# Turn rel_pairs into a dict with
#   key:   a relation index
#   value: a heap of (-count, other relation) tuples
# heapq is a min-heap, so the counts are negated to pop the most frequent partner first.

dict_of_maxheap = {}

for (rel_i, rel_j), count in rel_pairs.items():
    # every pair is registered under both of its relations
    heapq.heappush(dict_of_maxheap.setdefault(rel_i, []), (-count, rel_j))
    heapq.heappush(dict_of_maxheap.setdefault(rel_j, []), (-count, rel_i))

In [83]:
DEPTH = 5

def map_to_string(triplet):
    """Translate an index triplet (head CUI, relation index, tail CUI) into names.

    The entity names are looked up in build_entities and the relation name in
    relations_map. Returns the tuple (head name, relation name, tail name).
    """
    head_cui, rel_index, tail_cui = triplet[0], triplet[1], triplet[2]

    head_name = build_entities[head_cui]
    tail_name = build_entities[tail_cui]
    return (head_name, relations_map[rel_index], tail_name)

In [109]:
# Build the subgraphs: for every entity, the triplets that start at it are split into sets of
# at most DEPTH triplets. all_entities[k] is the entity name of all_triplet_sets[k].
all_entities = []
all_triplet_sets = []

# Group the triplets by head CUI once instead of rescanning the full list for every entity.
# build_cui_triplets itself is left untouched, so this cell can be re-run.
triplets_by_head = {}
for triplet in build_cui_triplets:
    triplets_by_head.setdefault(triplet[0], []).append(triplet)

for cui, name in build_entities.items():

    retrievl = list(triplets_by_head.get(cui, []))  # R: triplets whose head is this entity
    if not retrievl:
        continue

    # Each entity subgraph consists of at most DEPTH triplets, so when the entity has no more
    # than that in the whole KG they form a single set and rel_pairs is not consulted.
    if len(retrievl) <= DEPTH:
        all_entities.append(name)
        all_triplet_sets.append([map_to_string(r) for r in retrievl])
        continue

    # More than DEPTH triplets: start each set from a random triplet and grow it by following
    # the relations that co-occur most often with the relation that was added last.
    while retrievl:
        triplet_random = random.choice(retrievl)
        retrievl.remove(triplet_random)

        rel_random = triplet_random[1]
        triplet_set = [map_to_string(triplet_random)]

        while len(triplet_set) < DEPTH and retrievl:
            # Pop from a copy: the heaps in dict_of_maxheap are shared by all entities and
            # must not be consumed. A relation that never co-occurred has no heap at all.
            maxheap = list(dict_of_maxheap.get(rel_random, []))

            matched = None
            while maxheap and matched is None:
                _, candidate_rel = heapq.heappop(maxheap)
                matched = next((t for t in retrievl if t[1] == candidate_rel), None)

            if matched is None:
                # none of the remaining triplets uses a co-occurring relation; close this set
                break

            rel_random = matched[1]
            triplet_set.append(map_to_string(matched))
            retrievl.remove(matched)

        all_entities.append(name)
        all_triplet_sets.append(triplet_set)

In [111]:
# Save one row per subgraph: the entity name next to its set of triplets.
combine = dict(entity=all_entities, subgraph=all_triplet_sets)

entity_subgraph = pd.DataFrame(data=combine)
entity_subgraph.to_csv('data/entity_to_subgraph.csv', index_label=False)

### Save triplet memory

In [None]:
# Reload the entity names and the string triplets used to build the triplet memory.
entities_file = path + "/data/build_tri/build_entities.json"
triplets_file = path + "/data/build_tri/build_triplets.json"

with open(entities_file, "r") as json_file:
    entities = json.load(json_file)

with open(triplets_file, "r") as json_file:
    triplets = json.load(json_file)

In [None]:
# Serialize the triplets of every entity into one sentence:
# "head rel tail, rel tail, ..." — the head is written only once, at the start.
sentences = []
for key, ent in tqdm(entities.items()):
    related_triplet = [x for x in triplets if x[0] == ent]
    if not related_triplet:
        continue

    first_triplet, *other_triplets = related_triplet
    fragments = [' '.join(first_triplet)]
    fragments.extend(' '.join(other[1:]) for other in other_triplets)

    sentences.append([ent, ', '.join(fragments)])

In [None]:
# Store the serialized sentences as the triplet memory table.
memory_columns = ['entity', 'serialized_sentence']
t_df = pd.DataFrame(sentences, columns=memory_columns)
t_df.to_csv('data/memories/triplets_memory.csv', index_label=False)