In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.nn import Embedding
import tagme
from canlpy.helpers.ernie_helpers import load_name_to_QID,load_QID_to_eid,process_sentences

from canlpy.core.util.tokenization import BertTokenizer
from canlpy.core.models.ernie.model import ErnieForMaskedLM

# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
#import logging
#logging.basicConfig(level=logging.INFO)

# Device used for inference; it is also handed to `process_sentences`.
device = 'cpu'

# Root directories: entity-knowledge files and the pre-trained ERNIE checkpoint.
KNOWLEDGE_DIR = '../canlpy/knowledge/ernie/'
PRE_TRAINED_DIR = '../canlpy/pretrained_models/ernie/ernie_base/'

# Individual resource files (both directories above already end with '/').
NAME_TO_QID_FILE = f"{KNOWLEDGE_DIR}entity_map.txt"
QID_TO_EID_FILE = f"{KNOWLEDGE_DIR}entity2id.txt"
EID_TO_VEC_FILE = f"{PRE_TRAINED_DIR}entity2vec.pt"

In [3]:
# Load the pre-trained ERNIE masked-LM weights; the second return value is not used here.
model, _ = ErnieForMaskedLM.from_pretrained(PRE_TRAINED_DIR)
# Put the model on the same device the input tensors are built on (`device` is also
# passed to `process_sentences`); otherwise switching `device` to 'cuda' would leave
# the weights on the CPU and inference would fail with a device mismatch.
# NOTE(review): assumes ErnieForMaskedLM is a torch.nn.Module (so `.to` returns the model).
model = model.to(device)
# Evaluation mode (e.g. disables dropout). The trailing ';' suppresses the notebook's
# display of the returned module.
model.eval();

In [5]:
#Suppose to predict hensen for idx 8
def eval_sentence(text_a,text_b,model,tokenizer,masked_indices):

    tokens_tensor,ents_tensor,ent_mask,segments_tensors = process_sentences(text_a,text_b,masked_indices,name_to_QID,QID_to_eid,eid_to_embeddings,tokenizer,device=device)

    # Predict all tokens
    with torch.no_grad():
        predictions = model(tokens_tensor, ents_tensor, ent_mask, segments_tensors)

        # confirm we were able to predict 'henson'
        for masked_index in masked_indices:
            predicted_index = torch.argmax(predictions[0, masked_index]).item()
            predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
            print(f"predicted_token for index {masked_index} is {predicted_token}")

# Load the pre-trained tokenizer (vocabulary); it is used both to build the inputs
# and to decode the model's predictions.
tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_DIR)

# Entity name -> QID, e.g. 'Northern Ireland': 'Q26'
name_to_QID = load_name_to_QID(NAME_TO_QID_FILE)
# QID -> entity index (eid), e.g. {'Q11456633': 4525438, 'Q8863973': 1628631}
QID_to_eid = load_QID_to_eid(QID_TO_EID_FILE)

# Pre-trained entity vectors indexed by eid. `map_location` remaps storages onto
# `device`, so a tensor saved from a GPU still loads on a CPU-only machine.
eid_vectors = torch.load(EID_TO_VEC_FILE, map_location=device)
# Wrap the vectors in an Embedding layer (frozen by default) so entity indices can be
# looked up directly. `eval_sentence` reads this global by name.
eid_to_embeddings = Embedding.from_pretrained(eid_vectors)

text_a = "Who was Jim Henson ? "
text_b = "Jim Henson was a puppeteer ."

# Expected tokenization of the pair (masked_indices refer to these positions):
# ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', 'henson', 'was', 'a', 'puppet', '##eer', '.', '[SEP]']
masked_indices = [8, 11, 12]  # henson, puppet, ##eer
eval_sentence(text_a, text_b, model, tokenizer, masked_indices)

[101, 2040, 2001, 3958, 27227, 1029, 102, 3958, 103, 2001, 1037, 103, 103, 1012, 102]
predicted_token for index 8 is henson
predicted_token for index 11 is popular
predicted_token for index 12 is actor
