In [None]:
#Run this cell once to install pykg2vec
!git clone https://github.com/Sujit-O/pykg2vec.git
%cd pykg2vec/
!python setup.py install

In [None]:
#Generate tuples (Question, Token, SemType)
%cd ~/Desktop/covid-project

from tqdm.notebook import tqdm
import pandas as pd
s = pd.read_json('sample.json')

documents = s.iloc[0][0]['Document']['Utterances']
Metamap_Tokenizations = [] #To be used for final QA

def retrieve_tokens(SyntaxUnits):
    '''
    Return the surface text ('InputMatch') of every syntax unit, preserving order.

    SyntaxUnits: sequence of dicts, each carrying an 'InputMatch' key.
    '''
    return [unit['InputMatch'] for unit in SyntaxUnits]

def retrieve_mappings(Mappings):
    '''
    Return [matched phrase, first semantic type] for each candidate of the top mapping.

    Mappings: list of mapping dicts; only Mappings[0]['MappingCandidates'] is read.
    Returns [] when there is no mapping (those words get their embeddings from BERT).
    Inner items are lists, not tuples, because the semantic-type slot (index 1) is
    overwritten in place later in this notebook.
    '''
    if not Mappings:
        return []
    top_candidates = Mappings[0]['MappingCandidates']  #Only the top mapping is used
    return [[' '.join(cand['MatchedWords']), cand['SemTypes'][0]] for cand in top_candidates]
    
#Build one (utterance text, flat token list, flat mapping list) tuple per utterance.
for doc in tqdm(documents):
    Phrases = doc['Phrases']
    #Accumulate straight into flat per-utterance lists instead of nesting per phrase and flattening after
    Phrase_Tokenizations, Mappings = [], []
    for ph in Phrases:
        Phrase_Tokenizations += retrieve_tokens(ph['SyntaxUnits'])
        Mappings += retrieve_mappings(ph['Mappings'])
    Metamap_Tokenizations.append((doc['UttText'], Phrase_Tokenizations, Mappings))

In [None]:
#Replacing each shorthand mapping with KG concept:
#each semantic-type abbreviation (SRDEF.ABR) is replaced by its full name (SRDEF.STY_RL), in place.
import os
from getpass import getpass
import mysql.connector

# Never hard-code DB credentials in a notebook: read them from the environment, prompt as a fallback.
mydb = mysql.connector.connect(
    host=os.environ.get("UMLS_DB_HOST", "localhost"),
    user=os.environ.get("UMLS_DB_USER", "root"),
    password=os.environ.get("UMLS_DB_PASSWORD") or getpass("UMLS MySQL password: "),
    database=os.environ.get("UMLS_DB_NAME", "umls"),
)
mycursor = mydb.cursor()

abbr_to_sty = {}  # cache: each distinct abbreviation is queried only once
try:
    for _, _, mappings in Metamap_Tokenizations:
        for mapping in mappings:  # mapping is a mutable [matched words, semantic type] list
            abbr = mapping[1]
            if abbr not in abbr_to_sty:
                # Parameterized query: the driver escapes the value instead of string-built SQL.
                mycursor.execute("select STY_RL from SRDEF where ABR = %s", (abbr,))
                rows = mycursor.fetchall()
                if not rows:
                    raise LookupError(
                        f"No SRDEF row with ABR = {abbr!r}; was this cell already run on these mappings?")
                abbr_to_sty[abbr] = rows[0][0]
            mapping[1] = abbr_to_sty[abbr]
finally:
    mycursor.close()
    mydb.close()

In [None]:
#Generating the KG triples (KGT)
import os
from collections import Counter
from getpass import getpass
from itertools import permutations
import mysql.connector

All_Mappings = [y for x in Metamap_Tokenizations for y in x[2]]

# Only the semantic types (index 1) reach the query, so query each distinct ordered type pair once
# instead of every ordered pair of mappings (n*(n-1) mostly-duplicate queries).
# A type is paired with itself only when at least two mappings carry it -- exactly the type pairs
# that permutations(All_Mappings, 2) would have produced, so KGT ends up the same.
sem_type_counts = Counter(m[1] for m in All_Mappings)
Semantic_Type_Pairs = list(permutations(sem_type_counts, 2)) + \
    [(t, t) for t, n in sem_type_counts.items() if n > 1]

# Credentials from the environment (or a prompt), never hard-coded.
mydb = mysql.connector.connect(
    host=os.environ.get("UMLS_DB_HOST", "localhost"),
    user=os.environ.get("UMLS_DB_USER", "root"),
    password=os.environ.get("UMLS_DB_PASSWORD") or getpass("UMLS MySQL password: "),
    database=os.environ.get("UMLS_DB_NAME", "umls"),
)
mycursor = mydb.cursor()

KGT = set()

try:
    for semantic_type1, semantic_type2 in tqdm(Semantic_Type_Pairs):
        # Parameterized query: the driver escapes the values instead of string-built SQL.
        mycursor.execute("select RL from SRSTR where STY_RL1 = %s and STY_RL2 = %s",
                         (semantic_type1, semantic_type2))
        relation = mycursor.fetchall()
        if relation:
            # NOTE(review): only the first returned relation is kept and the query has no ORDER BY,
            # so when SRSTR holds several relations for a pair the choice is up to MySQL -- confirm.
            KGT.add((semantic_type1, relation[0][0], semantic_type2))
finally:
    mycursor.close()
    mydb.close()

In [None]:
#Creating Train/Validation/Test splits for training KGE's
import pandas as pd
import numpy as np
import os

#Converting the triple set to a DataFrame for easy splitting. Sorting first matters: set iteration
#order for string tuples depends on per-process hash randomization, so without it the seeded shuffle
#below would still produce different splits in every kernel session.
#The isinstance guard keeps the cell safe to re-run (KGT is already a DataFrame after the first run).
if not isinstance(KGT, pd.DataFrame):
    KGT = pd.DataFrame(sorted(KGT), columns=["E1", "Rel", "E2"])

#80/10/10 split of a seeded shuffle (iloc slicing; np.split on a DataFrame is deprecated in pandas)
shuffled = KGT.sample(frac=1, random_state=42)
train_end, valid_end = int(.8*len(KGT)), int(.9*len(KGT))
train = shuffled.iloc[:train_end]
validation = shuffled.iloc[train_end:valid_end]
test = shuffled.iloc[valid_end:]

#Saving datasets as tab-separated .txt files (one triple per line); create the folder if missing
dataset_path = os.path.abspath('UMLS_KG')
os.makedirs(dataset_path, exist_ok=True)
for split_name, split_df in (('train', train), ('valid', validation), ('test', test)):
    np.savetxt(os.path.join(dataset_path, f'UMLS_KG-{split_name}.txt'), split_df.values, delimiter="\t", fmt="%s")

In [None]:
#Run this cell to execute pykg2vec programs.
#The tune/train cells below call pykg2vec_tune.py / pykg2vec_train.py by relative path,
#so the working directory must be the pykg2vec checkout's scripts/ folder.
%cd ~/Desktop/covid-project/pykg2vec/scripts

In [None]:
#Tune KGE model
#Hyperparameter search for DistMult on the custom UMLS_KG dataset. -dsp points at the folder holding the
#UMLS_KG-{train,valid,test}.txt splits written earlier; -hpf is presumably the search-space file (confirm).
#NOTE(review): tuned values are not fed back automatically -- the training cell hard-codes its own flags.
!python pykg2vec_tune.py -mn DistMult -ds UMLS_KG -dsp ~/Desktop/covid-project/UMLS_KG \
-hpf ~/Desktop/covid-project/UMLS_KG/hyperparams.yaml

In [None]:
#Train KGE
#Train DistMult on the UMLS_KG splits. -k 768 presumably sets the embedding size; whatever it sets, the
#exported vectors must be 768-wide because they are fed to bert-base as inputs_embeds later in this notebook.
#Other flags, presumably per pykg2vec's CLI (confirm): -lr learning rate, -l1 L1-norm flag, -b batch size,
#-l epochs, -mg margin, -opt optimizer, -s negative-sampling strategy, -ngr negative-sample rate.
!python pykg2vec_train.py -mn DistMult -ds UMLS_KG -dsp ~/Desktop/covid-project/UMLS_KG \
-lr 0.01 -l1 True -k 768 -b 128 -l 1000 -mg 1.00 -opt "sgd" -s "bern" -ngr 1

In [None]:
#Converting KGE to BERT embeddings (Domain Term Encoding (DTE)) - part1 (generating associated triples)
#[Entity Expansion]
%cd ~/Desktop/covid-project/UMLS_KG/

import numpy as np
import pandas as pd
import pickle

#Mapping b/w entity and corresponding ID
with open('entity2idx.pkl', 'rb') as f:
    entity2id = pickle.load(f)

#Mapping b/w relation and corresponding ID
with open('relation2idx.pkl', 'rb') as f:
    relation2id = pickle.load(f)

def triple_gen(current_entity):
    '''
    Expand `current_entity` into every KG triple where it is the head (E1), flattened.

    Reads the module-level KGT DataFrame (columns E1, Rel, E2).
    Returns [head, rel_1, tail_1, head, rel_2, tail_2, ...] in KGT row order,
    or [] when the entity never appears as a head.
    '''
    outgoing = KGT.query("E1==@current_entity")
    flat = []
    for rel, tail in zip(outgoing.Rel.to_list(), outgoing.E2.to_list()):
        flat += [current_entity, rel, tail]
    return flat

#Expand every known entity (entity2id keys, in that order) into its flattened outgoing triples
triple_list = [triple_gen(entity) for entity in tqdm(entity2id.keys())]

In [None]:
#Converting KGE to BERT embeddings (Domain Term Encoding (DTE)) - part2 (each KG item -> (KG item, KGE))
#KGE located here
%cd ~/Desktop/covid-project/UMLS_KG/embeddings/distmult

ent_embeddings = pd.read_csv('ent_embedding.tsv', sep='\t', header=None)
rel_embeddings = pd.read_csv('rel_embedding.tsv', sep='\t', header=None)

'''
Associating each item in the triple list with respective embeddings. 
This is done to create an easy Domain Term BERT embedding matrix.
'''
for TL in tqdm(triple_list):
    if TL == []:
        continue
    i = 0
    for index, item in enumerate(TL):
        if (i%3 == 0) or (i%3 == 2): #This item is an entity
            TL[index] = (item, ent_embeddings.iloc[entity2id[item]].to_numpy())
        else: #This item is a relation
            TL[index] = (item, rel_embeddings.iloc[relation2id[item]].to_numpy())
        i += 1

In [None]:
#Converting KGE to BERT embeddings (Domain Term Encoding (DTE)) - part3 (Passing KGE's through BERT)
#[Creating DTE Lookup Table]
from transformers import BertModel
import numpy as np
import torch

model = BertModel.from_pretrained('bert-base-uncased')
model.eval()

matrix = model.get_input_embeddings() #BERT embeddings

# Built under no_grad so these (1, hidden) tensors carry no autograd graph into later cells.
with torch.no_grad():
    CLS_embedding = matrix(torch.LongTensor([101]))  # 101 = [CLS] in bert-base-uncased's vocab
    SEP_embedding = matrix(torch.LongTensor([102]))  # 102 = [SEP]

# Most whole (head, relation, tail) triples that fit in one BERT input next to [CLS] and [SEP].
MAX_TRIPLES_PER_CHUNK = (model.config.max_position_embeddings - 2) // 3

def _head_entity_embeddings(seq):
    '''
    Contextual BERT outputs at every head-entity position of one expanded entity.

    seq: flat list of (KG item, KGE vector) pairs laid out [head, rel, tail, head, rel, tail, ...].
         KGE vectors must have BERT's hidden size (pykg2vec was trained with -k 768).
    Returns a (num_triples, hidden) tensor. Long expansions are fed in chunks of whole
    triples so no single input exceeds BERT's position limit.
    '''
    kge = torch.from_numpy(np.stack([x[1] for x in seq])).float()
    step = 3 * MAX_TRIPLES_PER_CHUNK
    heads = []
    for start in range(0, len(kge), step):
        chunk = kge[start:start + step]
        # BertModel expects inputs_embeds as (batch, seq_len, hidden): the batch axis goes first.
        # unsqueeze(dim=1) would make every token its own length-1 sequence, i.e. no triple context.
        inputs = torch.cat((CLS_embedding, chunk, SEP_embedding)).unsqueeze(0)
        hidden = model(inputs_embeds=inputs).last_hidden_state[0]
        # Heads sit at chunk offsets 0, 3, 6, ...; +1 skips [CLS]; [SEP] is never selected.
        heads.append(hidden[1:1 + len(chunk):3])
    return torch.cat(heads)

DTE_BERT_Matrix = {}

with torch.no_grad():
    for seq in tqdm(triple_list):
        if not seq: #There is no expansion of the entity
            continue
        # The DTE of an entity is the average of its contextual embeddings over all head occurrences.
        # keepdim=True keeps shape (1, hidden) so it concatenates with per-subword BERT slices later.
        DTE_BERT_Matrix[seq[0][0]] = _head_entity_embeddings(seq).mean(dim=0, keepdim=True)

'''
Saving DTE_BERT embeddings to a lookup table (dataframe) & clearing DTE_BERT_Matrix.
Dataframes allow quicker lookups
'''
s = pd.DataFrame(list(DTE_BERT_Matrix.items()),columns = ['Term','Embedding'])
s.to_csv('DTE_BERT_Matrix.csv')
# to_csv stores each tensor as its printed repr (rounded per torch's print options), which cannot be
# parsed back losslessly; keep an exact copy alongside it.
torch.save(DTE_BERT_Matrix, 'DTE_BERT_Matrix.pt')
DTE_BERT_Matrix.clear()

In [None]:
#Extracting embeddings for non-domain terms. I'm simply using BERT's tokenizer for the nDT's.
#Creating question representations in this block.
from transformers import BertTokenizer, BertModel
import pickle
import torch
import re

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
question_representations = []
all_available_entity_embeddings = s['Term'].to_list()
# Term -> DTE embedding dict: O(1) membership/lookup instead of a list scan plus DataFrame.query per word.
dte_lookup = dict(zip(s['Term'], s['Embedding']))

model.eval()  # `model` is the BertModel loaded in the DTE cell; switching to eval once is enough

for tup in Metamap_Tokenizations:
    metamap_tokenized_question = tup[1]

    #Removing punctuations/spaces from domain-terms for easy comparison
    domain_terms = [re.sub(r'[\W\s]','',x[0]).lower() for x in tup[2]]

    '''
    Note: is_split_into_words is not the same as pre-tokenized. BERT uses subwords tokenization.
    Thus, when the above is set to True, it simply tells the tokenizer to run BERT's scheme on the resp. words.
    '''
    encoded_input = tokenizer(metamap_tokenized_question, is_split_into_words=True, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**encoded_input)
    token_states = outputs.last_hidden_state[0]

    question_embeddings = []
    start_index = 1  # position 0 is [CLS]
    for word in metamap_tokenized_question:
        filtered_word = re.sub(r'\W','',word).lower()
        number_of_subwords = len(tokenizer(word)['input_ids']) - 2 #1 for CLS & 1 for SEP
        end_index = start_index + number_of_subwords

        '''
        A word uses its DTE (KG-derived) embedding only if it is a domain term whose mapped concept
        has a KG expansion; otherwise it keeps its regular BERT subword embeddings.
        '''
        mapped_concept = None
        if filtered_word in domain_terms:
            mapped_concept = tup[2][domain_terms.index(filtered_word)][1]
        if mapped_concept is not None and mapped_concept in dte_lookup:
            question_embeddings.append(dte_lookup[mapped_concept])
        else:
            question_embeddings.append(token_states[start_index:end_index])

        start_index = end_index

    # The per-word subword counts must tile everything between [CLS] and [SEP]; otherwise the
    # slices above were taken from positions belonging to other words.
    assert start_index == encoded_input['input_ids'].shape[1] - 1, "subword/word alignment drifted"

    #In this way, I don't have to add the CLS & SEP embeddings during fine-tuning.
    # NOTE(review): CLS/SEP here are BERT *input* embeddings while the middle rows are *output* states,
    # and unsqueeze(dim=1) yields (seq_len, 1, hidden) rather than batch-first (1, seq_len, hidden) --
    # confirm this matches what the fine-tuning code expects before feeding it to BERT.
    final_representation = torch.unsqueeze(torch.cat((CLS_embedding,
                                                      torch.cat(question_embeddings),
                                                      SEP_embedding)), dim=1)

    question_representations.append(final_representation)

#Saving the question vectors to disk
with open('question_representation.data', 'wb') as filehandle:
    pickle.dump(question_representations, filehandle)