# Test Message-Passing Transformer

In [61]:
# model init
import torch
from transformers import BertTokenizer, BertConfig

from MPBert_model import MessagePassingBert


# fix random seed for reproducibility
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


# model configuration
model_name = 'bert-base-uncased'
num_labels = 1
# TODO pass as the model parameter to init
num_entities = 12605  # size of the output layer, i.e., maximum number of entities in the subgraph that are candidate answers 

tokenizer = BertTokenizer.from_pretrained(model_name)

config = BertConfig.from_pretrained(model_name, num_labels=num_labels)
model = MessagePassingBert(config, num_entities)

In [62]:
# test inference with a sample input, where input is a question and a predicate label along with the list of edges for this predicate
question1 = "Hello, my dog is cute"
adjacencies = [[(0, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, 1), (9, 1), (10, 1), (11, 1), (12, 1), (13, 1), (14, 1), (15, 1), (16, 1), (17, 1), (18, 1), (19, 1), (20, 1), (21, 1), (22, 1), (23, 1), (24, 1)]]
output = 1

# build input tensors
input_ids = torch.tensor(tokenizer.encode(question1)).unsqueeze(0)  # Batch size 1
graph = torch.tensor(adjacencies).unsqueeze(0)  # Batch size 1
labels = torch.tensor([output]).unsqueeze(0)  # Batch size 1

# run inference
outputs = model(input_ids, graph, labels=labels)
loss, logits = outputs[:2]
print(loss, logits)

tensor(1.5585, grad_fn=<MseLossBackward>) tensor([[-0.2484]], grad_fn=<AddmmBackward>)


In [68]:
# train model
model.train()
outputs = model(input_ids, graph, labels=labels)
loss = outputs[0]
current_loss = loss.item()
print(current_loss)

0.5900720357894897


# Prepare the Dataset

In [70]:
# load dataset
import json
conversations_path = './data/train_set/train_set_ALL.json'

with open(conversations_path, "r") as data:
    conversations = json.load(data)
print("%d conversations loaded"%len(conversations))

# load graph
from hdt import HDTDocument, TripleComponentRole
from settings import *

hdt_file = 'wikidata2018_09_11.hdt'
kg = HDTDocument(hdt_path+hdt_file)
namespace = 'predef-wikidata2018-09-all'
PREFIX_E = 'http://www.wikidata.org/entity/'

# prepare to retrieve all adjacent nodes including literals
predicates_ids = []
kg.configure_hops(1, predicates_ids, namespace, True, False)

# load all predicate labels
from predicates import properties

relationid2label = {}
for p in properties['results']['bindings']:
    _id = p['property']['value'].split('/')[-1]
    label = p['propertyLabel']['value']
    relationid2label[_id] = label

# print(relationid2label)

6720 conversations loaded


In [81]:
from collections import Counter, defaultdict


def lookup_predicate_labels(predicate_ids):
    p_labels_map = defaultdict(list)
    for p_id in predicate_ids:
        p_uri = kg.global_id_to_string(p_id, TripleComponentRole.PREDICATE)
        label = p_uri.split('/')[-1]
        if label in relationid2label:
            label = relationid2label[label]
        else:
            label = label.split('#')[-1]
        p_labels_map[label].append(p_id)
    return p_labels_map


answers_in_subgraph = Counter()

def check_answer_in_subgraph(conversation, subgraph):
    answer1 = conversation['questions'][0]['answer']
    # consider only answers which are entities
    if ('www.wikidata.org' in answer1):
        answer1_id = kg.string_to_global_id(PREFIX_E+answer1.split('/')[-1], TripleComponentRole.OBJECT)
        in_subgraph = answer1_id in entity_ids
        answers_in_subgraph.update([in_subgraph])
        # consider only answer entities that are in the subgraph
        if in_subgraph:
            answer1_idx = entity_ids.index(answer1_id)
            return answer1_idx
        

# TODO tokenize and concatenate two text inputs with BERT tokenizer
def encode_text_pair(text_a, text_b, tokenizer):
    print(text_a)
    tokens_a = tokenizer.tokenize(text_a)
    tokens_b = tokenizer.tokenize(text_b)
    tokens = [tokenizer.cls_token] + tokens_a + [tokenizer.sep_token] + tokens_b + [tokenizer.sep_token]
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    return input_ids


max_triples = 50000000
offset = 0

# collect only samples where the answer is entity and it is adjacent to the seed entity
dataset = []
labels = []
for conversation in conversations:
    question1 = conversation['questions'][0]['question']
    # use oracle for the correct initial entity
    seed_entity = conversation['seed_entity'].split('/')[-1]
    seed_entity_id = kg.string_to_global_id(PREFIX_E+seed_entity, TripleComponentRole.OBJECT)
    
    # retrieve all adjacent nodes including literals
    subgraph1 = kg.compute_hops([seed_entity_id], max_triples, offset)
    entity_ids, predicate_ids, adjacencies = subgraph1
    assert len(predicate_ids) == len(adjacencies)

    # check that the answer is in the subgraph
    answer1_idx = check_answer_in_subgraph(conversation, entity_ids)
    if answer1_idx:
        # get labels for all candidate predicates
        p_labels_map = lookup_predicate_labels(predicate_ids)
        
        # TODO pad the graph to the maximum size in the number of entities
        
        # create one-hot vector for the correct answer to the input question
        correct_answer_vector = [0] * num_entities
        correct_answer_vector[answer1_idx] = 1
        
        # create a sample for each predicate label separately
        for p_label, p_ids in p_labels_map.items():
            data = [p_label, p_ids]
            # get adjacencies only for the predicates sharing the same label
            selected_adjacencies = []
            for p_id in p_ids:
                p_id_idx = predicate_ids.index(p_id)
                selected_adjacencies.append(adjacencies[p_id_idx])
                # concatenate question with a predicate label
                question_predicate = encode_text_pair(question1, p_label, BertTokenizer)
                dataset.append([question_predicate, subgraph1])
                labels.append(correct_answer_vector)


print(answers_in_subgraph)
print("Compiled dataset with %d samples"%len(dataset))

Which author wrote the novel 1Q84?


TypeError: tokenize() missing 1 required positional argument: 'text'