<a href="https://colab.research.google.com/github/soulofshadow/KELM_for_UMLS/blob/main/KELM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## library and dataset

In [1]:
from google.colab import drive
drive.mount("/content/drive")

# Here is the path of the root dir of this folder in your google drive
path="/content/drive/My Drive/Colab_Notebooks/KELM"


import os
import sys
os.chdir(path)
sys.path.append(path)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import torch
import os
import numpy as np
import pandas as pd
import json
import random
import copy
from tqdm import tqdm
tqdm.pandas()

dataset_path = os.path.join(path, "data")

#### Load text datasets

In [3]:
#get the text dataset

'''
Notice the dataset load here are in the shape of
{ 'text': the original sentence
  'entity': a list of entities extracted using MetaMap
}

each entity is a dict with {'cui': the cui identifier of this entity,
                            'name': the name of this entity,
                            'type': the type of this entity,
                            'pos': the position of this entiy in the sentence
                                for multiple apparence in the sentence, here only list one instance,
                            'triger': the substring in the sentence which recognised as this entity
                            }
'''

text_file = 'sentences.csv'
text_path = os.path.join(dataset_path, text_file)

sentences = pd.read_csv(text_path)
sentences = sentences.drop(['Unnamed: 0'], axis = 1)

# sentences = pd.read_csv(text_path, header=None, sep="\t")
# sentences.columns = ['idk','text']
# sentences = sentences.drop(['idk'], axis = 1)

In [4]:
sentences.head()

Unnamed: 0,text,entity
0,The hepatic ultrastructural aspect and the hep...,"[{'cui': 'C0041623', 'name': 'Ultrastructure',..."
1,The group G streptococcus may be a more common...,"[{'cui': 'C0014118', 'name': 'Endocarditis', '..."
2,Effects of a 24 hour fast were studied in 21 o...,"[{'cui': 'C0028754', 'name': 'Obesity', 'type'..."
3,The present study has been undertaken to elici...,"[{'cui': 'C0342895', 'name': 'Fish-Eye Disease..."
4,Selected androgenic and nonandrogenic steroids...,"[{'cui': 'C0002844', 'name': 'Androgens', 'typ..."


In [5]:
len(sentences)

50000

In [6]:
#idk why when I write the csv it turn None to nan, so here I transfer it back
sentences['triplet'] = None
#sentences['triplet'] = sentences['triplets'].where(sentences['triplet'].notnull(), None)
#the entity default by pandas is string, turn to list of dict
sentences['entity'] = sentences['entity'].apply(eval)

#### if not have triplets.csv file, build triplets from umls

In [None]:
#get the umls
#from the umls, MRREL.RRF file we can build the triplets
from utils.load_umls import UMLS
umls = UMLS(dataset_path)

8751471it [02:22, 61570.64it/s]


cui count: 3695485
str2cui count: 13396819
MRCONSO count: 6131827


25369590it [02:33, 165162.42it/s]


rel count: 18702888


4010842it [00:18, 216076.72it/s]

sty count: 3695485





In [None]:
#build the triplets
triplets = []

for rel in umls.rel:
    triplet = rel.strip().split("\t")
    if len(triplet) != 4:
        continue;
    
    sub = triplet[0]
    obj= triplet[1]
    relation = triplet[3]

    tri = [sub, obj, relation]
    triplets.append(tri)

print("triplets count:", len(triplets))

triplets count: 11110308


In [None]:
triplets = pd.DataFrame(triplets)

In [None]:
triplets.columns = ['subject', 'object', 'relation']
triplets.drop_duplicates()

triplets.to_csv('triplets.csv')

#### Load triplets

In [7]:
#get triplets
triplet_file = 'triplets.csv'
triplet_path = os.path.join(dataset_path, triplet_file)

umls_triplets = pd.read_csv(triplet_path)
umls_triplets = umls_triplets.drop(['Unnamed: 0'], axis = 1)

In [8]:
umls_triplets.head()

Unnamed: 0,subject,object,relation
0,C2347441,C4762419,has_ingredient
1,C0022877,C0803531,has_class
2,C0301042,C3160584,has_inactive_ingredient
3,C0027530,C1953956,has_system
4,C0040300,C1507501,has_system


## Align

In [9]:
#transfer from str CUI to int CUI
#like "C1022345" to 1022345

def tran_to_index(x):
    return int(x[1:])

In [10]:
#transfer it to dict of shape
'''
{(subject, object): relation}
'''

umls_triplets['subject'] = umls_triplets['subject'].map(tran_to_index)
umls_triplets['object'] = umls_triplets['object'].map(tran_to_index)

dict_triplets = dict(zip(zip(umls_triplets['subject'], umls_triplets['object']), umls_triplets['relation']))

In [11]:
umls_triplets.head()

Unnamed: 0,subject,object,relation
0,2347441,4762419,has_ingredient
1,22877,803531,has_class
2,301042,3160584,has_inactive_ingredient
3,27530,1953956,has_system
4,40300,1507501,has_system


In [12]:
# Define the relations_map
umls_relations = set(umls_triplets['relation'])
relations_map = {i: relation for i, relation in enumerate(umls_relations)}
inverse_map = {v:k for k,v in relations_map.items()}

print("There are total {} kinds of relationship".format(len(umls_relations)))

There are total 608 kinds of relationship


In [13]:
'''
This step is main for Align text and triplets,
'''

for index, item in tqdm(sentences.iterrows()):

    #column 1 {sentence} we dont need for here
    entities = copy.deepcopy(item['entity']) #column 2, use copy cause we dont want to change the original data
    triplets = item['triplet'] #column 3

    if triplets is not None:
        continue;
    
    #also do str CUI to int CUI for entity of text
    for entity in entities:
        if not isinstance(entity['cui'], int):
            entity['cui'] = tran_to_index(entity['cui'])
    
    #start to align text with tripltes of umls
    aligned_triplets = []
    for entity1 in entities:
        for entity2 in entities:
            # we use a pair of entitis to detect whether they have a relation in the UMLS triplets
            if entity1 != entity2:
                if (entity1['cui'], entity2['cui']) in dict_triplets.keys():
                    #we find one, we use this triplet to label this row
                    #instead of store cui, we stored the string name of this entity
                    rel = dict_triplets[(entity1['cui'], entity2['cui'])]
                    
                    triplet = (entity1['name'], entity2['name'], rel)
                    aligned_triplets.append(triplet)
    #Align
    item['triplet'] = aligned_triplets

50000it [07:11, 115.89it/s]


In [14]:
sentences.head()

Unnamed: 0,text,entity,triplet
0,The hepatic ultrastructural aspect and the hep...,"[{'cui': 'C0041623', 'name': 'Ultrastructure',...","[(Anemia, Hemolytic, Chronic hemolytic anemia,..."
1,The group G streptococcus may be a more common...,"[{'cui': 'C0014118', 'name': 'Endocarditis', '...",[]
2,Effects of a 24 hour fast were studied in 21 o...,"[{'cui': 'C0028754', 'name': 'Obesity', 'type'...","[(Insulin, Insulin [EPC], has_structural_class..."
3,The present study has been undertaken to elici...,"[{'cui': 'C0342895', 'name': 'Fish-Eye Disease...","[(Calcium, Calcium [EPC], parent_of), (Rattus,..."
4,Selected androgenic and nonandrogenic steroids...,"[{'cui': 'C0002844', 'name': 'Androgens', 'typ...","[(Androgens, Etiocholanolone, isa), (Androgens..."


## Extract Subgraph

#### Build rel_pairs

In [15]:
'''
Here is for build rel_pairs as dict
{(rel1, rel2) : count}
'''
rel_pairs = {} 

def count_rel_pairs(triplets):
    #for better match speed, map it to index
    rels = [inverse_map[triplet[2]] for triplet in triplets]

    length = len(rels)
    for i in range(length):
        rel_i = rels[i]
        if i == length - 1:
            break;
        for j in range(i+1, length):
            rel_j = rels[j]
            if ((rel_i, rel_j) in rel_pairs.keys() or (rel_j, rel_i) in rel_pairs.keys()) and (rel_i != rel_j):
                if (rel_i, rel_j) in rel_pairs.keys():
                    rel_pairs[(rel_i, rel_j)] += 1
                else:
                    rel_pairs[(rel_j, rel_i)] += 1
            else:
                rel_pairs[(rel_i, rel_j)] = 1

In [98]:
# it is as dict of {cui(int) : name(string)}
dict_entities = {}

# it is a list of tuples as (subject cui, object cui, relation index)
# this is for store value to extract subgraph more efficiency
sentences_triplets = []

for index, item in tqdm(sentences.iterrows()):

    entities = copy.deepcopy(item['entity']) #column 2
    triplets = item['triplet'] #column 3

    if triplets is None:
        continue
    
    temp_dict = {}
    for entity in entities:
        if not isinstance(entity['cui'], int):
            entity['cui'] = tran_to_index(entity['cui'])

        dict_entities[entity['cui']] = entity['name']
        temp_dict[entity['name']] = entity['cui']

    for triplet in triplets:
        entity1 = temp_dict[triplet[0]]
        entity2 = temp_dict[triplet[1]]
        rel = inverse_map[triplet[2]]

        sentences_triplets.append((entity1, entity2, rel))

    count_rel_pairs(triplets)

50000it [01:17, 641.48it/s]


In [99]:
sentences_triplets = list(set(sentences_triplets))

In [100]:
len(sentences_triplets)

57154

In [101]:
len(dict_entities)

93221

In [102]:
import heapq

#Trun to a dict of 
#key as rel
#value as heapq of tuple (count, rel_2)
# rel_pairs = sorted(rel_pairs.items(), key=lambda x:x[1], reverse=True)

dict_of_maxheap = {}

for pairs, count in rel_pairs.items():
    rel_i = pairs[0]
    rel_j = pairs[1]

    if rel_i not in dict_of_maxheap.keys():
        heap = []
        heapq.heappush(heap, (-count, rel_j))
        dict_of_maxheap[rel_i] = heap
    else:
        heapq.heappush(dict_of_maxheap[rel_i], (-count, rel_j))
    
    if rel_j not in dict_of_maxheap.keys():
        heap = []
        heapq.heappush(heap, (-count, rel_i))
        dict_of_maxheap[rel_j] = heap
    else:
        heapq.heappush(dict_of_maxheap[rel_j], (-count, rel_i))

#### extract subgraph

In [103]:
DEPTH = 5

def map_to_string(triplet):

    sub = dict_entities[triplet[0]]
    obj = dict_entities[triplet[1]]
    rel = relations_map[triplet[2]]
    return (sub, obj, rel)

def get_triplet(rel_j, R):

    for triplet in R:
        if triplet[2] == rel_j:
            return triplet
    return None

In [104]:
all_triplet_sets = []

for cui, name in dict_entities.items():

    retrievl = [] # R
    for triplet in sentences_triplets:
        if cui == triplet[0]:
            retrievl.append(triplet)
    
    #cause Each entity subgraph consists of a maximum of five triples
    #so if the triplet in the whole KG of this entity <= 5, we don't search for rel_pairs
    #just add, and pass
    if len(retrievl) <= 5:
        if len(retrievl) != 0:
            #transfer from int to string of name 
            triplet_set = []
            for r in retrievl:
                triplet_set.append(map_to_string(r))
                sentences_triplets.remove(r)
            all_triplet_sets.append(triplet_set)
        continue;
    
    #if there are more than 5, we need to select those Rel by the rel_pair of the order of count
    while retrievl:
        triplet_set = []

        triplet_random = random.choice(retrievl)
        rel_random = triplet_random[2]
        triplet_set.append(map_to_string(triplet_random))

        retrievl.remove(triplet_random)
        sentences_triplets.remove(triplet_random)

        for i in range(2, DEPTH):
            maxheap = dict_of_maxheap[rel_random]
            while maxheap:
                max_element = heapq.heappop(maxheap)
                flag = 0
                for triplet in retrievl:
                    if triplet[2] == max_element[1]:
                        rel_random = max_element[1]
                        triplet_set.append(map_to_string(triplet))
                        retrievl.remove(triplet)
                        sentences_triplets.remove(triplet)
                        flag = 1
                        break;
                if flag == 1:
                    break;

        if len(triplet_set) != 0:
            all_triplet_sets.append(triplet_set)
        #append (entity1, entity2, relation)


In [105]:
len(all_triplet_sets)

43581

In [106]:
all_triplet_sets[0]

[('Anemia, Hemolytic', 'Hemolytic-Uremic Syndrome', 'use'),
 ('Anemia, Hemolytic', 'Autoimmune hemolytic anemia', 'isa'),
 ('Anemia, Hemolytic', 'Chronic hemolytic anemia', 'isa'),
 ('Anemia, Hemolytic', 'Coombs positive hemolytic anemia', 'isa'),
 ('Anemia, Hemolytic', 'Lupus Erythematosus, Systemic', 'has_manifestation')]

## Save aligned dataset and subgraph dataset

In [108]:
sentences.to_csv(os.path.join(dataset_path, 'aligned_dataset.csv'))

In [121]:
all_triplet_sets = [tuple(x) for x in all_triplet_sets]

pd_all_triplet_sets = pd.DataFrame({'subgraph': all_triplet_sets})

In [123]:
pd_all_triplet_sets.head()

Unnamed: 0,subgraph
0,"((Anemia, Hemolytic, Hemolytic-Uremic Syndrome..."
1,"((Aplastic Anemia, Anemia, related_to), (Aplas..."
2,"((Aplastic Anemia, Aplastic bone marrow, mappe..."
3,"((ITGA2B wt Allele, Homo sapiens, organism_has..."
4,"((Bilirubin, bilirubin glucuronate, mapped_to)..."


In [126]:
pd_all_triplet_sets.to_csv(os.path.join(dataset_path, 'subgraph_dataset.csv'))

## Fine-Tune T5