In [None]:
import torch
from torch.nn import Embedding
import tagme
from canlpy.helpers.ernie_helpers import load_name_to_QID,load_QID_to_eid,process_sentences
from canlpy.helpers.cokebert_helpers import load_k_v_queryR
import pickle

from canlpy.core.components.tokenization import BertTokenizer
from canlpy.core.models.cokebert.model import CokeBertForMaskedLM

# Device string handed to the helper functions that build the input tensors.
device = 'cpu'

# Root directories (relative to this notebook). The two data directories end
# with '/' so file names can be appended directly.
KNOWLEDGE_DIR = '../canlpy/knowledge/cokebert/kg_embed/'
GRAPH_NEIGHBOR_DIR = '../canlpy/knowledge/cokebert/load_data_n/'
PRE_TRAINED_DIR = '../canlpy/pretrained_models/cokebert'

# Knowledge-graph lookup tables and embedding files.
NAME_TO_QID_FILE = f'{KNOWLEDGE_DIR}entity_map.txt'
QID_TO_EID_FILE = f'{KNOWLEDGE_DIR}entity2id.txt'
EID_TO_VEC_FILE = f'{KNOWLEDGE_DIR}entity2vec.vec'
REL_TO_VEC_FILE = f'{KNOWLEDGE_DIR}relation2vec.vec'

# Pickled per-entity graph data (neighbors, relations, edge direction).
ENT_AND_NEIGHBOR_FILE = f'{GRAPH_NEIGHBOR_DIR}e1_e2_list_2D_Tensor.pkl'
ENT_AND_RELATION_FILE = f'{GRAPH_NEIGHBOR_DIR}e1_r_list_2D_Tensor.pkl'
ENT_AND_INOUT_FILE = f'{GRAPH_NEIGHBOR_DIR}e1_outORin_list_2D_Tensor.pkl'

In [None]:
# Load the pre-trained masked-LM. from_pretrained returns a pair; only the
# model (first element) is used in this notebook. The second element is bound
# to a throwaway name other than `_` so IPython's last-output variable is not
# shadowed.
model, _unused = CokeBertForMaskedLM.from_pretrained(PRE_TRAINED_DIR)

# Put the model on the same device that is passed to the input-building
# helpers, so changing `device` does not produce a device mismatch.
model = model.to(device)

# Inference only: switch to evaluation mode. The trailing semicolon stops the
# notebook from echoing the full module repr as the cell output.
model.eval();

In [None]:
def _read_pickle(path):
    """Return the object deserialized from the pickle file at ``path``."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# NOTE(review): pickle.load can execute arbitrary code; only point these paths
# at files from a trusted source.
# The file names suggest 2-D tensors indexed by entity — TODO confirm.
ent_to_neighbors = _read_pickle(ENT_AND_NEIGHBOR_FILE)
ent_to_relations = _read_pickle(ENT_AND_RELATION_FILE)
ent_to_outORin = _read_pickle(ENT_AND_INOUT_FILE)

In [None]:
# Entity lookup tables used when linking mentions to knowledge-graph entities.
# name_to_QID example entry: 'Northern Ireland': 'Q26'
name_to_QID = load_name_to_QID(NAME_TO_QID_FILE)
# QID_to_eid example: {'Q11456633': 4525438, 'Q8863973': 1628631}
QID_to_eid = load_QID_to_eid(QID_TO_EID_FILE)

# Tokenizer (vocabulary) stored alongside the pre-trained model.
tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_DIR)

def load_embedding_table(path, default_dim=100):
    """Read a tab-separated vector file into a float tensor with a zero row prepended.

    Each non-blank line of ``path`` is parsed as one vector of tab-separated
    floats. Row 0 of the result is an all-zero padding row, so the i-th vector
    of the file (0-based) ends up at row ``i + 1``.

    Args:
        path: path of the text file to read.
        default_dim: width of the padding row when the file contains no
            vectors; otherwise the width of the first vector is used.

    Returns:
        A ``torch.FloatTensor`` of shape ``(number_of_vectors + 1, dim)``.

    Raises:
        ValueError: if a field cannot be parsed as a float.
    """
    rows = [None]  # slot 0 is reserved for the padding row, filled in below
    with open(path, 'r') as fin:
        for line in fin:
            line = line.strip()
            if not line:
                # Skip blank lines (e.g. a trailing newline) instead of
                # crashing on float('').
                continue
            rows.append([float(x) for x in line.split('\t')])
    # Size the padding row from the data rather than a hard-coded constant.
    dim = len(rows[1]) if len(rows) > 1 else default_dim
    rows[0] = [0.0] * dim
    return torch.FloatTensor(rows)


# Entity embeddings: raw tensor plus an Embedding layer built from it
# (index -> vector lookup).
embed_ent = load_embedding_table(EID_TO_VEC_FILE)
ent_embed = torch.nn.Embedding.from_pretrained(embed_ent)

# Relation embeddings, with the same zero padding row at index 0.
r_embed = load_embedding_table(REL_TO_VEC_FILE)

In [None]:
def eval_sentence(text_a,text_b,model,tokenizer,masked_indices):
    """Run the model on a sentence pair and print the top-scoring token at each given position.

    Args:
        text_a: first sentence, forwarded to ``process_sentences``.
        text_b: second sentence, forwarded to ``process_sentences``.
        model: callable invoked with the token, entity, entity-mask and segment
            tensors plus the ``k_v_s`` keyword; its result is indexed as
            ``predictions[0, position]``.
        tokenizer: forwarded to ``process_sentences`` and used to map the
            predicted id back to a token via ``convert_ids_to_tokens``.
        masked_indices: token positions to report; also forwarded to
            ``process_sentences`` (presumably to mask them — TODO confirm).

    Returns:
        None. One line per position is printed.

    Note:
        Reads these notebook-level globals: ``name_to_QID``, ``QID_to_eid``,
        ``ent_embed``, ``device``, ``ent_to_neighbors``, ``ent_to_relations``,
        ``ent_to_outORin``, ``embed_ent`` and ``r_embed``.
    """
    token_ids, entity_ids, entity_mask, segment_ids = process_sentences(
        text_a, text_b, masked_indices,
        name_to_QID, QID_to_eid, ent_embed, tokenizer,
        device=device,
    )

    # Forward pass only; no gradients are needed.
    with torch.no_grad():
        # The graph lookup receives entity ids shifted by one; the model itself
        # receives the unshifted ids.
        shifted_entity_ids = (entity_ids + 1).to(torch.long)
        k_1, v_1, k_2, v_2 = load_k_v_queryR(
            shifted_entity_ids,
            ent_to_neighbors, ent_to_relations, ent_to_outORin,
            embed_ent, r_embed, device,
        )
        predictions = model(
            token_ids, entity_ids, entity_mask, segment_ids,
            k_v_s=[(k_1, v_1), (k_2, v_2)],
        )

    # Report the highest-scoring vocabulary entry at every requested position.
    for masked_index in masked_indices:
        best_id = predictions[0, masked_index].argmax().item()
        predicted_token = tokenizer.convert_ids_to_tokens([best_id])[0]
        print(f"predicted_token for index {masked_index} is {predicted_token}")

In [None]:
# Sentence pair for the demo; "Jim Henson" appears in both sentences.
text_a = "Who was Jim Henson ? "
text_b = "Jim Henson was a puppeteer ."

# Tokenization of the pair, with positions:
#   0 [CLS]  1 who  2 was  3 jim  4 henson  5 ?  6 [SEP]
#   7 jim  8 henson  9 was  10 a  11 puppet  12 ##eer  13 .  14 [SEP]
masked_indices = [
    8,   # henson
    11,  # puppet
    12,  # ##eer
]
eval_sentence(
    text_a=text_a,
    text_b=text_b,
    model=model,
    tokenizer=tokenizer,
    masked_indices=masked_indices,
)