# Test Message-Passing Transformer

In [1]:
# model init
import torch
from transformers import BertTokenizer, BertConfig

from MPBert_model import MessagePassingBert
from utils import adj

# fix random seed for reproducibility
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


# model configuration
model_name = 'bert-base-uncased'
# 2 x 57942 in dev set need to be trimmed or sampled
num_entities = 12605  # size of the output layer, i.e., maximum number of entities in the subgraph that are candidate answers
tokenizer = BertTokenizer.from_pretrained(model_name)
# regression task for matching predicate/entity label to the input question
config = BertConfig.from_pretrained(model_name, num_labels=1)

model = MessagePassingBert(config, num_entities, mp_layer=True)
# run model on the GPU
# model.cuda()

In [2]:
# test inference with a sample input, where input is a question and a predicate label along with the list of edges for this predicate
question1 = "When were Beatles founded?"
output = 1

# build input tensors
input_ids = torch.tensor([tokenizer.encode(question1)] * num_relations)  # Batch size num_relations
labels = torch.tensor([output]).unsqueeze(0)  # Batch size 1
indices, relation_mask = adj(adjacencies, num_entities, num_relations)
entities = torch.zeros(num_entities, 1)
entities[[0, 3]] = 1
print(relation_mask.shape)
print(input_ids.shape)

# run inference
outputs = model(input_ids, [indices, relation_mask, entities], labels=labels)
loss, logits = outputs[:2]
print(loss, logits)

torch.Size([52])
torch.Size([2, 7])
tensor(9.1537, grad_fn=<NllLossBackward>) tensor([ 0.0000,  0.2882, -0.1879,  ...,  0.0000,  0.0000,  0.0000],
       grad_fn=<SumBackward1>)


In [3]:
# train model
model.train()
outputs = model(input_ids, [indices, relation_mask, entities], labels=labels)
loss = outputs[0]
current_loss = loss.item()
print(current_loss)

8.525114059448242


# Prepare the Dataset

In [4]:
# load graph
from hdt import HDTDocument, TripleComponentRole
from settings import *

hdt_file = 'wikidata2018_09_11.hdt'
kg = HDTDocument(hdt_path+hdt_file)
namespace = 'predef-wikidata2018-09-all'
PREFIX_E = 'http://www.wikidata.org/entity/'

# prepare to retrieve all adjacent nodes including literals
predicates_ids = []
kg.configure_hops(1, predicates_ids, namespace, True, False)

# load all predicate labels
from predicates import properties

relationid2label = {}
for p in properties['results']['bindings']:
    _id = p['property']['value'].split('/')[-1]
    label = p['propertyLabel']['value']
    relationid2label[_id] = label

# print(relationid2label)

In [5]:
# load dataset
import json
from collections import Counter, defaultdict

train_conversations_path = './data/train_set/train_set_ALL.json'
dev_conversations_path = './data/dev_set/dev_set_ALL.json'


def lookup_predicate_labels(predicate_ids):
    p_labels_map = defaultdict(list)
    for p_id in predicate_ids:
        p_uri = kg.global_id_to_string(p_id, TripleComponentRole.PREDICATE)
        label = p_uri.split('/')[-1]
        if label in relationid2label:
            label = relationid2label[label]
        else:
            label = label.split('#')[-1]
        p_labels_map[label].append(p_id)
    return p_labels_map


def check_answer_in_subgraph(conversation, entity_ids):
    answer1 = conversation['questions'][0]['answer']
    # consider only answers which are entities
    if ('www.wikidata.org' in answer1):
        answer1_id = kg.string_to_global_id(PREFIX_E+answer1.split('/')[-1], TripleComponentRole.OBJECT)
        in_subgraph = answer1_id in entity_ids
        # consider only answer entities that are in the subgraph
        if in_subgraph:
            answer1_idx = entity_ids.index(answer1_id)
            return answer1_idx


def prepare_dataset(conversations_path, n_limit=100):
    with open(conversations_path, "r") as data:
        conversations = json.load(data)
    print("%d conversations loaded"%len(conversations))
    
    max_triples = 50000000
    offset = 0

    # collect only samples where the answer is entity and it is adjacent to the seed entity
    train_dataset = []

    graph_sizes = []
    max_n_edges = 2409 # max size of the graph allowed in the number of edges
    if n_limit:
        conversations = conversations[:n_limit]
    for conversation in conversations:
        question1 = conversation['questions'][0]['question']
        # use oracle for the correct initial entity
        seed_entity = conversation['seed_entity'].split('/')[-1]
        seed_entity_id = kg.string_to_global_id(PREFIX_E+seed_entity, TripleComponentRole.OBJECT)

        # retrieve all adjacent nodes including literals
        subgraph = kg.compute_hops([seed_entity_id], max_triples, offset)
        entity_ids, predicate_ids, adjacencies = subgraph
        
        if not len(entity_ids) <= num_entities:
#             print(len(entity_ids))
            continue  # skip samples with large subgraphs
        assert len(predicate_ids) == len(adjacencies)
#         print("conversation")
        # check that the answer is in the subgraph
        answer1_idx = check_answer_in_subgraph(conversation, entity_ids)
        if answer1_idx:
            # activate seed entity
            entities = torch.zeros(num_entities, 1)
            entities[[entity_ids.index(seed_entity_id)]] = 1
            
            # get labels for all candidate predicates
            p_labels_map = lookup_predicate_labels(predicate_ids)

            # create a batch of samples for each predicate label separately
            input_ids = []
            attention_masks = []
            token_type_ids = []
            A = []

            for p_label, p_ids in p_labels_map.items():

                # encode a text pair of the question with a predicate label
                encoded_dict = tokenizer.encode_plus(question1, p_label,
                                                     add_special_tokens=True,
                                                     max_length=64,
                                                     pad_to_max_length=True,
                                                     return_attention_mask=True,
                                                     return_token_type_ids=True)
                input_ids.append(encoded_dict['input_ids'])
                token_type_ids.append(encoded_dict['token_type_ids'])
                attention_masks.append(encoded_dict['attention_mask'])

                # get adjacencies only for the predicates sharing the same label
                selected_adjacencies = []
                for p_id in p_ids:
                    p_id_idx = predicate_ids.index(p_id)
                    # add all edges together
                    for edge in adjacencies[p_id_idx]:
                        if edge not in selected_adjacencies:
                            selected_adjacencies.append(edge)
                A.append(selected_adjacencies)

            # create a single graph per example for all predicates
            indices, relation_mask = adj(A, num_entities, num_relations)

            train_dataset.append([torch.tensor(input_ids),
                                  torch.tensor(token_type_ids),
                                  torch.tensor(attention_masks),
                                  [indices, relation_mask, entities],
                                  torch.tensor([answer1_idx])])

    print("Compiled dataset with %d samples" % len(train_dataset))
    return train_dataset

train_dataset = prepare_dataset(train_conversations_path)
valid_dataset = prepare_dataset(dev_conversations_path)

6720 conversations loaded
Compiled dataset with 70 samples
2240 conversations loaded
Compiled dataset with 67 samples


In [6]:
# training setup
from transformers import get_linear_schedule_with_warmup, AdamW

epochs = 4
total_steps = len(train_dataset) * epochs

optimizer = AdamW(model.parameters(),
                  lr = 2e-5, # args.learning_rate - default is 5e-5
                  eps = 1e-8 # args.adam_epsilon  - default is 1e-8
                 )
# learning rate scheduler
scheduler = get_linear_schedule_with_warmup(optimizer, 
                                            num_warmup_steps = 0, # Default value in run_glue.py
                                            num_training_steps = total_steps)


In [7]:
# train model
import random
import numpy as np

# use CPU to train the model
device = torch.device("cpu")

print("%d training examples"%(len(train_dataset)))
print("%d validation examples"%(len(valid_dataset)))

for epoch_i in range(0, epochs):
    
    # ========================================
    #               Training
    # ========================================
    
    # Perform one full pass over the training set.

    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')
    
    # reset the total loss for this epoch
    total_train_loss = 0
    
    # put the model into training mode
    model.train()
    
    # for each sample of training data input as a batch of size 1
    for step, batch in enumerate(train_dataset):
        b_input_ids = batch[0].to(device)
        b_token_mask = batch[1].to(device)
        b_input_mask = batch[2].to(device)
        b_graphs = [tensor.to(device) for tensor in batch[3]]
        b_labels = batch[4].to(device)
#         print(b_input_ids.shape)
#         print(b_labels.shape)
        model.zero_grad()
        # forward pass
        loss, logits = model(b_input_ids,
                             b_graphs,
                             token_type_ids=b_token_mask,
                             attention_mask=b_input_mask,
                             labels=b_labels)
#         print(loss.item())
        # accumulate the training loss over all of the batches
        total_train_loss += loss.item()

        # backward pass
        loss.backward()
        
        # clip gradient to prevent exploding
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        
        # update parameters
        optimizer.step()
        scheduler.step()
    
    # training epoch is over here
    
    # calculate average loss over all the batches
    avg_train_loss = total_train_loss / len(train_dataset) 
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    
    # ========================================
    #               Validation
    # ========================================
    # After the completion of each training epoch, measure our performance on
    # our validation set.

    print("")
    print("Running Validation...")
    
    # put the model in evaluation mode
    model.eval()
    
    total_eval_loss = 0
        
    # evaluate data for one epoch
    for step, batch in enumerate(valid_dataset):
        
        b_input_ids = batch[0].to(device)
        b_token_mask = batch[1].to(device)
        b_input_mask = batch[2].to(device)
        b_graphs = [tensor.to(device) for tensor in batch[3]]
        b_labels = batch[4].to(device)
        
        with torch.no_grad():
            # forward pass
            loss, logits = model(b_input_ids,
                                 b_graphs,
                                 token_type_ids=b_token_mask,
                                 attention_mask=b_input_mask,
                                 labels=b_labels)
#             print(loss.item())
            # accumulate validation loss
            total_eval_loss += loss.item()
    
    avg_val_loss = total_eval_loss / len(valid_dataset)
    print("  Validation Loss: {0:.2f}".format(avg_val_loss))


70 training examples
67 validation examples

Training...
  Average training loss: 8.31

Running Validation...
  Validation Loss: 8.95

Training...
  Average training loss: 7.98

Running Validation...
  Validation Loss: 7.74

Training...
  Average training loss: 6.56

Running Validation...
  Validation Loss: 6.13

Training...
  Average training loss: 5.40

Running Validation...
  Validation Loss: 6.37


In [36]:
# save and load model
import os

output_dir = './saved_models/1_1stquestion'

if not os.path.exists(output_dir):
    os.makedirs(output_dir)
    
print("Saving model to %s" % output_dir)

model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
# torch.save(args, os.path.join(output_dir, 'training_args.bin'))


# Load a trained model and vocabulary that you have fine-tuned
# model = MessagePassingBert.from_pretrained(output_dir)
# tokenizer = BertTokenizer.from_pretrained(output_dir)

# Copy the model to the GPU.
# model.to(device)

Saving model to ./saved_models/1_1stquestion


('./saved_models/1_1stquestion/vocab.txt',
 './saved_models/1_1stquestion/special_tokens_map.json',
 './saved_models/1_1stquestion/added_tokens.json')

In [35]:
# run inference and evaluate performance on train and dev sets with the target QA metrics

def run_inference(model, dataset):
    # put model in evaluation mode
    model.eval()
    
    # TODO add MRR
    p1s = []  # measure accuracy of the top answer: P@1
    for batch in dataset:
        b_input_ids = batch[0].to(device)
        b_token_mask = batch[1].to(device)
        b_input_mask = batch[2].to(device)
        b_graphs = [tensor.to(device) for tensor in batch[3]]
        b_labels = batch[4].to(device)
        
        with torch.no_grad():
            # forward pass
            loss, logits = model(b_input_ids,
                                 b_graphs,
                                 token_type_ids=b_token_mask,
                                 attention_mask=b_input_mask,
                                 labels=b_labels)
            predicted_label = np.argmax(logits.numpy()).flatten()[0]
#             print(predicted_label)
            true_label = b_labels.numpy()[0]
#             print(true_label)
            p1 = int(predicted_label == true_label)
#             print(p1)
            p1s.append(p1)
    
    return p1s

        
p1s = run_inference(model, train_dataset)
print("Train set P@1: %.2f" % np.mean(p1s))

p1s = run_inference(model, valid_dataset)
print("Dev set P@1: %.2f" % np.mean(p1s))

Train set P@1: 0.27
Dev set P@1: 0.24
