In [1]:
# Extracting embeddings for non-domain terms (nDTs); BERT's tokenizer is used as-is for the nDTs.
# This cell builds the question representations.
from transformers import AutoTokenizer, AutoModel
import torch
import re
import os
import pandas as pd

# Lookup table of KG concepts; the 'Term' and 'Embedding' columns are used below.
DTE_BERT_Lookup_Table = pd.read_pickle(
    os.path.join(os.path.abspath('UMLS_KG'), 'embeddings', 'distmult', 'DTE_to_BERT.pkl')
)

# Per-question MetaMap output; the 'Question', 'Tokenization' and 'Mappings'
# columns are read by custom_question_rep_gen.
Metamap_Tokenizations = pd.read_pickle('Metamap_Tokenizations.pkl')

model_name = 'bert-base-uncased'

model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The model's input (token-id -> vector) embedding layer, used to embed ids directly.
BERT_embeddings = model.get_input_embeddings()

# Take the special-token ids from the tokenizer instead of hard-coding 101/102
# (the bert-base-uncased values), so they stay correct if model_name changes.
CLS_embedding = BERT_embeddings(torch.LongTensor([tokenizer.cls_token_id]))
SEP_embedding = BERT_embeddings(torch.LongTensor([tokenizer.sep_token_id]))

# Concepts that have an entry in the lookup table.
all_entities = DTE_BERT_Lookup_Table['Term'].to_list()

def custom_question_rep_gen(ques):
    """Build the embedding sequence that represents one question.

    The question's MetaMap tokenization and domain-term (DT) mappings are
    read from ``Metamap_Tokenizations``. Each token is then embedded:

    * a DT whose mapped concept appears in ``DTE_BERT_Lookup_Table``
      ('Term' column) gets that row's 'Embedding';
    * every other token -- a non-DT, or a DT whose concept has no entry in
      the lookup table -- gets BERT's input embeddings for its word pieces.

    The [CLS] and [SEP] embeddings are placed around the sequence, so they do
    not have to be added during fine-tuning.

    Args:
        ques: The question text. It must equal a 'Question' value in
            ``Metamap_Tokenizations`` exactly (including whitespace).

    Returns:
        ``torch.cat`` of the [CLS] embedding, the token embeddings and the
        [SEP] embedding along dim 0, with a size-1 axis inserted at dim 1.

    Raises:
        KeyError: If ``ques`` has no row in ``Metamap_Tokenizations``.
    """
    def clean_term(word):
        # Drop non-word characters/whitespace and lowercase, so a token and the
        # DT surface form reported by MetaMap compare equal.
        return re.sub(r'[\W\s]', '', word).lower()

    tup = Metamap_Tokenizations.query("Question==@ques")
    if tup.empty:
        raise KeyError(f"question not found in Metamap_Tokenizations: {ques!r}")

    # Positional access: query() keeps the original index labels, so label 0
    # exists only when the match happens to be the table's first row.
    row = tup.iloc[0]
    metamap_tokenized_question = row['Tokenization']

    # Cleaned DT surface form -> mapped concept. This builds a new dict instead
    # of overwriting entry[0] in place, because in-place writes would mutate
    # the lists stored in the shared DataFrame. setdefault keeps the first
    # mapping for a repeated term (first-match semantics).
    dt_to_concept = {}
    for entry in row['Mappings']:
        dt_to_concept.setdefault(clean_term(entry[0]), entry[1])

    question_embeddings = []
    for word in metamap_tokenized_question:
        filtered_word = clean_term(word)

        # Use the DT embedding only when the word is a DT *and* its concept has
        # a KG expansion (an entry in the lookup table).
        if filtered_word in dt_to_concept and dt_to_concept[filtered_word] in all_entities:
            mapped_concept = dt_to_concept[filtered_word]  # referenced via @ in query()
            question_embeddings.append(
                DTE_BERT_Lookup_Table.query("Term==@mapped_concept")['Embedding'].values[0]
            )
        else:
            # Non-DT, or a DT without a KG expansion: use BERT's input
            # embeddings. DTs lacking an expansion must land here, otherwise
            # the word would be missing from the representation.
            subword_indices = tokenizer(word)['input_ids'][1:-1]  # tokens between [CLS] & [SEP]
            for index in subword_indices:
                question_embeddings.append(BERT_embeddings(torch.LongTensor([index])))

    # One concatenation also copes with a question that yields no token
    # embeddings (torch.cat on an empty list raises).
    final_representation = torch.unsqueeze(
        torch.cat([CLS_embedding, *question_embeddings, SEP_embedding]), dim=1
    )

    return final_representation

# Example run. The lookup is an exact '==' match on the 'Question' column, so the
# text (including the trailing space) must match the stored question verbatim.
custom_question_embeddings = custom_question_rep_gen(
    ques='What is the main cause of HIV-1 infection in children? ',
)