In [1]:
import argparse
import time

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim

import utils

In [2]:
# Experiment configuration. `start_time` is recorded here; it is not read in the
# cells shown, presumably used later for elapsed-time reporting (TODO confirm).
start_time = time.time()
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=0, help="GPU device ID. Use -1 for CPU training")
parser.add_argument('--epochs', type=int, default=100, help="Number of training epochs")
parser.add_argument('--hidden', type=str, default='20,10', help="Sizes of hidden layers, comma-separated")
utils.add_bool_arg(parser, 'same-w', True)  # use the same matrix W for all features
parser.add_argument('--interaction', default='mult', choices=['mult', 'add', 'dot'],
                    help="Interaction function to use")
parser.add_argument('--qemb', default='kewer', choices=['kewer', 'blstatic', 'bldynamic'],
                    help="How to embed question text. "
                         "kewer: mean of KEWER embeddings of tokens and linked entities, "
                         "bldynamic: Bi-LSTM embedding trained as part of the model, "
                         "blstatic: Static pre-trained Bi-LSTM embedding")
parser.add_argument('--features-mask', nargs='+', type=int, help="Features mask for feature ablation study")
parser.add_argument('--savemodel', default='models/model-mult-same.pt', help="Path to save the model")
parser.add_argument('--loadmodel', help='Load this model checkpoint before training')
parser.add_argument('--baseline', default='baseline-4', help="Baseline method triples")
# Parse an explicit empty list so only the defaults above are used; inside a notebook
# sys.argv belongs to the kernel process, not to this parser.
args = parser.parse_args(args=[])
print(args)

Namespace(baseline='baseline-4', epochs=100, features_mask=None, gpu=0, hidden='20,10', interaction='mult', loadmodel=None, qemb='kewer', same_w=True, savemodel='models/model-mult-same.pt')


In [3]:
# Per-entity feature inputs. load_question_set reads, for each entity URI,
# ['counts'][t] and ['feature_inputs'][t] for feature types t in 'p', 'lit', 'cat', 'ent'.
feature_inputs = utils.load_feature_inputs()
# KEWER embedding model; its `.wv` vectors embed entities and question text below.
kewer = utils.load_kewer()

In [4]:
# Optional inputs that only some configurations use. Define them all up front so
# later cells can always reference them, even on a fresh kernel (Restart & Run All).
word_probs = None
question_entities = None
train_question_embeddings = None
# Passed to load_question_set when building the dev set; it is only read there
# when --qemb is 'blstatic'. Leaving this undefined makes that cell raise NameError.
dev_question_embeddings = None

In [5]:
# Inputs to utils.embed_question for KEWER question embeddings (--qemb kewer):
# word_probs is passed through as-is; question_entities is indexed by question id (str).
word_probs = utils.load_word_probs()
question_entities = utils.load_question_entities()

In [6]:
def load_question_set(args, qblink_split, overlap_features, feature_inputs, question_embeddings, kewer, word_probs,
                      question_entities):
    """Build feature-based examples for every QBLink question that has overlap features.

    Each question ('q1', 'q2', 'q3') of each sequence whose ``t_id`` (as str) is a key of
    ``overlap_features`` yields at most one example dict. Rows of the per-candidate arrays
    follow the iteration order of ``overlap_features[question_id]``:

    * ``overlap_features`` -- the candidate's overlap feature vector.
    * ``p_inputs`` / ``lit_inputs`` / ``cat_inputs`` / ``ent_inputs`` -- the candidate's
      feature input of that type divided by its count (kept as-is, i.e. all zeros, when
      the count is 0).
    * ``s_inputs`` -- lit + cat + ent feature inputs divided by the summed lit/cat/ent
      counts (all zeros when that sum is 0).
    * ``previous_answer_embedding`` -- KEWER vector of the previous question's answer;
      zeros for 'q1' or when that entity is not in the KEWER vocabulary.
    * ``target_index`` -- row of the gold answer entity among the candidates.
    * the representation selected by ``args.qemb``: ``question_embedding`` for
      'kewer' / 'blstatic', or the raw ``question`` text for 'bldynamic'.

    Questions whose gold entity is not among their candidates are skipped.

    Args:
        args: parsed options; only ``args.qemb`` is read.
        qblink_split: iterable of QBLink sequences, each a dict keyed 'q1'/'q2'/'q3'.
        overlap_features: question id (str) -> {candidate entity URI -> overlap features}.
        feature_inputs: entity URI -> {'counts': {...}, 'feature_inputs': {...}}, both
            keyed by feature type ('p', 'lit', 'cat', 'ent').
        question_embeddings: pre-computed question embeddings; only used for 'blstatic'.
        kewer: KEWER model exposing ``.wv``.
        word_probs: passed to ``utils.embed_question`` (only for 'kewer').
        question_entities: question id (str) -> linked entities, passed to
            ``utils.embed_question`` (only for 'kewer').

    Returns:
        list of example dicts.
    """
    question_set = []
    for sequence in qblink_split:
        for question in ['q1', 'q2', 'q3']:
            question_id = str(sequence[question]['t_id'])
            if question_id not in overlap_features:
                continue
            # 'quetsion_text' (sic) is the key spelling this code reads from the split data.
            question_text = sequence[question]['quetsion_text']
            target_entity = f"<http://dbpedia.org/resource/{sequence[question]['wiki_page']}>"

            # The answer to the preceding question in the sequence, if any.
            previous_answer = None
            if question == 'q2':
                previous_answer = f"<http://dbpedia.org/resource/{sequence['q1']['wiki_page']}>"
            elif question == 'q3':
                previous_answer = f"<http://dbpedia.org/resource/{sequence['q2']['wiki_page']}>"
            if previous_answer is not None and previous_answer in kewer.wv:
                previous_answer_embedding = kewer.wv[previous_answer].copy()
            else:
                previous_answer_embedding = np.zeros(kewer.wv.vector_size, dtype=np.float32)

            overlap_feature_array = []
            feature_input_arrays = {'p': [], 'lit': [], 'cat': [], 'ent': [], 's': []}
            # Must start as None for every question: the gold entity may be absent from
            # the candidates, and a leftover value would point at the wrong row.
            target_index = None
            for i, (entity, entity_overlap_features) in enumerate(overlap_features[question_id].items()):
                assert entity in feature_inputs
                counts = feature_inputs[entity]['counts']
                inputs = feature_inputs[entity]['feature_inputs']
                overlap_feature_array.append(entity_overlap_features)
                for feature_type in ['p', 'lit', 'cat', 'ent']:
                    if counts[feature_type] > 0:
                        feature_input_arrays[feature_type].append(inputs[feature_type] / counts[feature_type])
                    else:
                        assert (inputs[feature_type] == 0).all()
                        feature_input_arrays[feature_type].append(inputs[feature_type])
                # Combined lit/cat/ent average. Guard the all-zero-count case, which would
                # otherwise compute 0 / 0 and put NaNs into the inputs.
                s_count = counts['lit'] + counts['cat'] + counts['ent']
                s_sum = inputs['lit'] + inputs['cat'] + inputs['ent']
                feature_input_arrays['s'].append(s_sum / s_count if s_count > 0 else s_sum)
                if entity == target_entity:
                    target_index = i

            if target_index is None:
                # Gold entity not among the candidates: no valid training target.
                continue

            question_set_item = {
                'overlap_features': np.array(overlap_feature_array, dtype=np.float32),
                'p_inputs': np.array(feature_input_arrays['p'], dtype=np.float32),
                'lit_inputs': np.array(feature_input_arrays['lit'], dtype=np.float32),
                'cat_inputs': np.array(feature_input_arrays['cat'], dtype=np.float32),
                'ent_inputs': np.array(feature_input_arrays['ent'], dtype=np.float32),
                's_inputs': np.array(feature_input_arrays['s'], dtype=np.float32),
                'previous_answer_embedding': previous_answer_embedding,
                'target_index': target_index
            }
            if args.qemb == 'kewer':
                question_set_item['question_embedding'] = utils.embed_question(question_text, kewer.wv, word_probs,
                                                                               question_entities[question_id])
            elif args.qemb == 'blstatic':
                question_set_item['question_embedding'] = utils.get_question_embedding(question_embeddings,
                                                                                       int(question_id))
            elif args.qemb == 'bldynamic':
                question_set_item['question'] = question_text

            question_set.append(question_set_item)
    return question_set

In [7]:
dev_split = utils.load_qblink_split('dev')
dev_overlap_features = utils.load_overlap_features('dev')
# Feature-based dev examples, built with the overlap-feature version of load_question_set.
dev_set = load_question_set(
    args=args,
    qblink_split=dev_split,
    overlap_features=dev_overlap_features,
    feature_inputs=feature_inputs,
    question_embeddings=dev_question_embeddings,
    kewer=kewer,
    word_probs=word_probs,
    question_entities=question_entities,
)
print(f'Dev examples: {len(dev_set)}')

Dev examples: 1111


In [13]:
# Sanity check: number of dev questions that have overlap features; compare it with
# the number of feature-based dev examples printed above.
len(dev_overlap_features)

1111

In [9]:
def load_question_set(qblink_split, kvmem_triples, kewer, word_probs, question_entities):
    """Build key-value-memory examples for every QBLink question with baseline triples.

    NOTE: this redefines (shadows) the feature-based ``load_question_set`` defined in an
    earlier cell; every cell executed after this one calls this version.

    For each question whose ``t_id`` (as str) is a key of ``kvmem_triples``, every
    ``(subj, pred, obj)`` triple whose three terms are all in the KEWER vocabulary
    becomes one memory slot: key = L2-normalised ``wv[subj] + wv[pred]``,
    value = ``wv[obj]``. The distinct objects, in first-seen order, are the answer
    candidates (L2-normalised). Questions whose gold entity is not a candidate are
    skipped.

    Args:
        qblink_split: iterable of QBLink sequences, each a dict keyed 'q1'/'q2'/'q3'.
        kvmem_triples: question id (str) -> iterable of (subj, pred, obj) terms.
        kewer: KEWER model exposing ``.wv``.
        word_probs: passed to ``utils.embed_question``.
        question_entities: question id (str) -> linked entities, passed to
            ``utils.embed_question``.

    Returns:
        list of dicts with float32 arrays 'question_embedding', 'key_embeddings',
        'value_embeddings', 'candidate_embeddings' and int 'target_index'.
    """
    def _unit(vector):
        # L2-normalise; a zero vector is returned unchanged instead of becoming NaN.
        norm = np.linalg.norm(vector)
        return vector / norm if norm > 0 else vector

    question_set = []
    for sequence in qblink_split:
        for question in ['q1', 'q2', 'q3']:
            question_id = str(sequence[question]['t_id'])
            if question_id not in kvmem_triples:
                continue
            target_entity = f"<http://dbpedia.org/resource/{sequence[question]['wiki_page']}>"

            key_embeddings = []
            value_embeddings = []
            # A dict used as an insertion-ordered set, so the candidate order (and hence
            # target_index) does not depend on string hash randomisation between runs.
            value_entities = {}
            for subj, pred, obj in kvmem_triples[question_id]:
                if subj in kewer.wv and pred in kewer.wv and obj in kewer.wv:
                    key_embeddings.append(_unit(kewer.wv[subj] + kewer.wv[pred]))
                    value_embeddings.append(kewer.wv[obj])
                    value_entities[obj] = None

            if target_entity not in value_entities:
                # No valid training target for this question.
                continue
            candidate_entities = list(value_entities)
            target_index = candidate_entities.index(target_entity)
            candidate_embeddings = [_unit(kewer.wv[entity]) for entity in candidate_entities]

            # 'quetsion_text' (sic) is the key spelling this code reads from the split data.
            question_text = sequence[question]['quetsion_text']
            question_embedding = utils.embed_question(question_text, kewer.wv, word_probs,
                                                      question_entities[question_id])
            question_set.append({
                'question_embedding': np.array(question_embedding, dtype=np.float32),
                'key_embeddings': np.array(key_embeddings, dtype=np.float32),
                'value_embeddings': np.array(value_embeddings, dtype=np.float32),
                'candidate_embeddings': np.array(candidate_embeddings, dtype=np.float32),
                'target_index': target_index
            })
    return question_set

In [11]:
# Triples for the baseline selected with --baseline. The kvmem load_question_set
# indexes this by question id (str) and iterates (subject, predicate, object) entries.
kvmem_triples = utils.load_kvmem_triples(args.baseline)

In [12]:
# Key-value-memory dev examples, built with the kvmem version of load_question_set.
dev_set = load_question_set(
    qblink_split=dev_split,
    kvmem_triples=kvmem_triples,
    kewer=kewer,
    word_probs=word_probs,
    question_entities=question_entities,
)
print(f'Dev examples: {len(dev_set)}')

Dev examples: 839
