In [None]:
# Converting KGEs to BERT embeddings (Domain Term Encoding, DTE) - part 3 (passing the KGEs through BERT)
# [Creating the DTE lookup table]

%cd ~/Desktop/CDQA-project/UMLS_KG/embeddings/distmult
from transformers import AutoModel
import torch
import os
import pickle

def Create_DTE_BERT_LookUp_Table(model_name, pickle_path='expanded_entities.pkl',
                                 cls_token_id=101, sep_token_id=102):
    """Build the Domain Term Encoding (DTE) lookup table by passing KG expansions through BERT.

    Each non-empty entry of the pickled expansion list is a sequence of
    ``(name, vector)`` pairs. The vectors are wrapped between the model's [CLS]
    and [SEP] input embeddings and fed to the model as ONE sequence via
    ``inputs_embeds``, so every position is contextualised by the rest of the
    expansion. The expanded entity occurs at every 3rd element of the expansion
    (according to the expansion scheme). The model outputs at those positions are
    averaged to give the entity's DTE embedding.

    Args:
        model_name: Model name/path passed to ``AutoModel.from_pretrained``.
        pickle_path: Path of the pickled expansion list. The path is resolved
            relative to the current working directory.
        cls_token_id: Vocabulary id of [CLS]. The default (101) is the
            bert-base-uncased id; pass another id for a different vocabulary.
        sep_token_id: Vocabulary id of [SEP]. The default is 102, the
            bert-base-uncased id.

    Returns:
        A pandas DataFrame with columns 'Term' (``seq[0][0]`` of each expansion)
        and 'Embedding' (a ``(1, hidden_size)`` tensor per term).

    Note:
        ``tqdm`` is not imported in this cell; it is assumed to come from an
        earlier notebook cell.
    """
    import pandas as pd  # local import so this cell does not rely on an earlier one

    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(pickle_path, 'rb') as f:
        triple_list = pickle.load(f)

    model = AutoModel.from_pretrained(model_name)
    model.eval()

    # The expanded entity occurs at every `step`-th element of an expansion.
    step = 3
    # The model accepts at most max_position_embeddings positions, and 2 of them
    # go to [CLS]/[SEP]. Long expansions are split into chunks whose length is a
    # multiple of `step`, so entity occurrences keep the same offsets in every chunk.
    max_positions = getattr(model.config, 'max_position_embeddings', 512)
    chunk_size = (max_positions - 2) // step * step

    DTE_BERT_Matrix = {}

    with torch.no_grad():
        embed = model.get_input_embeddings()  # token-id -> input-embedding lookup
        cls_embedding = embed(torch.LongTensor([cls_token_id]))  # (1, hidden)
        sep_embedding = embed(torch.LongTensor([sep_token_id]))  # (1, hidden)

        for seq in tqdm(triple_list):
            if not seq:  # The entity has no KG expansion
                continue

            occurrences = []
            for start in range(0, len(seq), chunk_size):
                part = seq[start:start + chunk_size]
                body = torch.stack(
                    [torch.as_tensor(x[1], dtype=torch.float32) for x in part])
                # Shape (1, len(part) + 2, hidden): a batch of ONE sequence.
                # Unsqueezing at dim=1 would build a batch of one-token sequences,
                # and no position would then attend to the rest of the expansion.
                inputs_embeds = torch.cat(
                    (cls_embedding, body, sep_embedding)).unsqueeze(0)
                hidden = model(inputs_embeds=inputs_embeds)[0][0]  # (len(part) + 2, hidden)

                # Expansion elements occupy positions 1..len(part): [CLS] is at
                # position 0 and [SEP] follows the last element.
                for i in range(1, len(part) + 1, step):
                    occurrences.append(hidden[i:i + 1])  # keep the (1, hidden) shape

            # The entity's embedding is the average over all of its occurrences.
            DTE_BERT_Matrix[seq[0][0]] = torch.mean(torch.stack(occurrences), dim=0)

    return pd.DataFrame(list(DTE_BERT_Matrix.items()), columns=['Term', 'Embedding'])

# Build the DTE lookup table from the bert-base-uncased checkpoint.
DTE_BERT_Lookup_Table = Create_DTE_BERT_LookUp_Table(model_name='bert-base-uncased')

In [None]:
# Extracting embeddings for non-domain terms (nDTs): BERT's own tokenizer and embeddings are used for them.
# Creating the question representations in this block.
from transformers import AutoTokenizer, AutoModel
import torch
import re

def custom_question_rep_gen():
    

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    question_representations = []
    all_entities = s['Term'].to_list()

    for tup in Metamap_Tokenizations:
        metamap_tokenized_question = tup[1]

        #Removing punctuations/spaces from domain-terms for easy comparison
        domain_terms = [re.sub(r'[\W\s]','',x[0]).lower() for x in tup[2]] 

        '''
        Note: is_split_into_words is not the same as pre-tokenized. BERT uses subwords tokenization.
        Thus, when the above is set to True, it simply tells the tokenizer to run BERT's scheme on the resp. words.
        '''
        encoded_input = tokenizer(metamap_tokenized_question, is_split_into_words=True, return_tensors="pt")

        model.eval()
        with torch.no_grad():
            outputs = model(**encoded_input)

        question_embeddings = []
        start_index = 1
        for word in metamap_tokenized_question:
            filtered_word = re.sub(r'\W','',word).lower()
            number_of_subwords = len(tokenizer(word)['input_ids']) - 2 #1 for CLS & 1 for SEP
            end_index = start_index + number_of_subwords

            '''
            This means that the filtered_word has to be a domain term which also has a KG expansion. If if does not,
            then simply use its BERT embeddings.
            '''
            if filtered_word in domain_terms: #Use DTE_BERT_Matrix
                mapped_concept = tup[2][domain_terms.index(filtered_word)][1]
                if mapped_concept in all_entities:
                    question_embeddings.append(s.query("Term==@mapped_concept")['Embedding'].values[0])
                else: #The DT doesn't have an expansion in the KG & so its BERT embeddings are used.
                    question_embeddings.append(outputs.last_hidden_state[0][start_index:end_index])
            else: #Use Regular BERT subword embeddings
                question_embeddings.append(outputs.last_hidden_state[0][start_index:end_index])

            start_index = end_index

        #In this way, I don't have to add the CLS & SEP embeddings during fine-tuning.
        final_representation = torch.unsqueeze(torch.cat((CLS_embedding,\
                                                          torch.cat([*question_embeddings]),\
                                                          SEP_embedding)), dim=1)

        question_representations.append(final_representation)