In [3]:
import argparse
import time

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim

import utils
# from model_kvmem import ModelKvmem

# Fix both RNG streams up front so weight initialisation and DataLoader
# shuffling start from the same state on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

ValueError: module functions cannot set METH_CLASS or METH_STATIC

In [None]:
# Wall-clock reference for the "Total time" figure printed after each epoch by train().
start_time = time.time()

# Run configuration. Defaults describe the 3-hop BERT setup; edit the defaults
# (or pass a list to parse_args below) to change the run from inside the notebook.
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=0, help="GPU device ID. Use -1 for CPU training")
parser.add_argument('--epochs', type=int, default=100, help="Number of training epochs")
parser.add_argument('--hops', type=int, default=3, help="Number of hops")
parser.add_argument('--qemb', default='bert', choices=['kewer', 'blstatic', 'bldynamic', 'bert'],
                    help="How to embed question text. "
                         "kewer: mean of KEWER embeddings of tokens and linked entities, "
                         "bldynamic: Bi-LSTM embedding trained as part of the model, "
                         "blstatic: Static pre-trained Bi-LSTM embedding, "
                         "bert: frozen pre-trained BERT (bert-base-uncased) embedding "
                         "projected by a trainable linear layer")
parser.add_argument('--baseline', default='baseline-3', help="Baseline method triples")
parser.add_argument('--savemodel', default='models/model-kvmem-3-3hops-bert.pt', help="Path to save the model")
parser.add_argument('--loadmodel', help='Load this model checkpoint before training')
# args=[] keeps argparse from reading the Jupyter kernel's own command line.
args = parser.parse_args(args=[])
print(args)

In [None]:
# Optionally resume from a saved checkpoint; `checkpoint` stays None for a fresh run.
checkpoint = None
if args.loadmodel:
    # map_location='cpu' lets a checkpoint written on cuda:N be opened on a
    # CPU-only host or a host with fewer GPUs. train() moves the model to the
    # selected device before the optimizer state is restored.
    checkpoint = torch.load(args.loadmodel, map_location='cpu')
    print(checkpoint['args'])

In [None]:
# Load the KEWER embedding model. Later cells look entities, predicates and
# values up through `kewer.wv` (presumably gensim-style keyed vectors — confirm in utils).
kewer = utils.load_kewer()

In [None]:
# When resuming, take the baseline name stored in the checkpoint rather than
# the one in the current args.
kvmem_triples = utils.load_kvmem_triples(
    checkpoint['args'].baseline if checkpoint else args.baseline
)

In [None]:
def load_question_set(qblink_split, kvmem_triples, kewer):
    """Build the list of training/evaluation examples for one QBLink split.

    qblink_split: iterable of question sequences; each sequence is indexed by
        'q1', 'q2', 'q3' and every question provides 't_id', 'wiki_page' and
        'quetsion_text'.
    kvmem_triples: mapping from str(t_id) to an iterable of (subj, pred, obj) triples.
    kewer: object whose `wv` attribute supports `in` and `[]` lookups that
        return embedding vectors.

    Returns a list of dicts with the keys 'key_embeddings', 'value_embeddings',
    'candidate_embeddings' (float32 arrays), 'target_index' (int) and
    'question_text'. A question is skipped when it has no triples or when its
    target entity is not among the usable triple objects.

    Candidates are listed in first-seen order of the triple objects, so the
    candidate order and 'target_index' are the same on every run.
    """
    question_set = []
    for sequence in qblink_split:
        for question in ('q1', 'q2', 'q3'):
            entry = sequence[question]
            question_id = str(entry['t_id'])
            target_entity = f"<http://dbpedia.org/resource/{entry['wiki_page']}>"
            if question_id not in kvmem_triples:
                continue

            key_embeddings = []
            value_embeddings = []
            # dict used as an insertion-ordered set: a plain set of strings is
            # iterated in hash order, which changes between interpreter runs.
            value_entities = {}
            for subj, pred, obj in kvmem_triples[question_id]:
                # Only triples whose three parts all have an embedding are usable.
                if subj in kewer.wv and pred in kewer.wv and obj in kewer.wv:
                    key_embedding = kewer.wv[subj] + kewer.wv[pred]
                    key_embeddings.append(key_embedding / np.linalg.norm(key_embedding))
                    # Values are stored unnormalised, unlike keys and candidates.
                    value_embeddings.append(kewer.wv[obj])
                    value_entities[obj] = None

            # The answer cannot be ranked if it is not a candidate.
            if target_entity not in value_entities:
                continue

            candidate_embeddings = []
            target_index = None
            for index, value_entity in enumerate(value_entities):
                candidate_embedding = kewer.wv[value_entity]
                candidate_embeddings.append(candidate_embedding / np.linalg.norm(candidate_embedding))
                if value_entity == target_entity:
                    target_index = index

            question_set.append({
                'key_embeddings': np.array(key_embeddings, dtype=np.float32),
                'value_embeddings': np.array(value_embeddings, dtype=np.float32),
                'candidate_embeddings': np.array(candidate_embeddings, dtype=np.float32),
                'target_index': target_index,
                # NOTE(review): 'quetsion_text' looks misspelled but appears to be
                # the field name used by the dataset itself — confirm before renaming.
                'question_text': entry['quetsion_text']
            })
    return question_set

In [None]:
# batch_size=1 throughout: each question has its own number of memory slots
# and candidates, so examples cannot be stacked by the default collate function.
train_split = utils.load_qblink_split('train')
train_set = load_question_set(train_split, kvmem_triples, kewer)
print('Training examples:', len(train_set))
train_loader = DataLoader(train_set, batch_size=1, shuffle=True)

dev_split = utils.load_qblink_split('dev')
dev_set = load_question_set(dev_split, kvmem_triples, kewer)
print('Dev examples:', len(dev_set))
dev_loader = DataLoader(dev_set, batch_size=1, shuffle=False)

# Pick the compute device. The default --gpu 0 would otherwise yield a CUDA
# device on a CPU-only machine and fail only later, when the model is moved.
if args.gpu >= 0 and torch.cuda.is_available():
    device = torch.device('cuda:%d' % args.gpu)
else:
    if args.gpu >= 0:
        print(f'CUDA is not available; ignoring --gpu {args.gpu} and using the CPU.')
    device = torch.device('cpu')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForMaskedLM
from transformers import pipeline

class ModelKvmem(nn.Module):
    """Key-Value Memory Network, after 'Key-Value Memory Networks for Directly Reading Documents'."""

    def __init__(self, qemb: str, num_hops: int = 3, input_dim: int = 300, question_emb_dim: int = 768):
        """Set up the question projection, the per-hop transforms and a frozen BERT.

        qemb: question embedding mode; 'kewer' bypasses the projection W in forward().
        num_hops: number of hops H (one linear transform R_j per hop).
        input_dim: dimensionality d of key, value and candidate embeddings.
        question_emb_dim: dimensionality of the raw question embedding fed to W.
        """
        super().__init__()
        self.qemb = qemb
        self.num_hops = num_hops
        self.input_dim = input_dim
        self.question_emb_dim = question_emb_dim

        # Pre-trained encoder and its tokenizer are kept on the module;
        # forward() itself does not call them.
        self.BERT = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

        # Freeze every encoder weight.
        self.BERT.requires_grad_(False)

        # Projects a question embedding from question_emb_dim down to input_dim.
        self.W = nn.Linear(self.question_emb_dim, self.input_dim, bias=False)

        # One bias-free d x d query update per hop.
        self.R = nn.ModuleList(
            [nn.Linear(input_dim, input_dim, bias=False) for _ in range(num_hops)]
        )

    def forward(self, question_embedding: torch.Tensor, key_embeddings: torch.Tensor,
                value_embeddings: torch.Tensor, candidate_embeddings: torch.Tensor) -> torch.Tensor:
        """Score every candidate entity against the question.

        question_embedding: 2D tensor, 1 x input_dim when qemb == 'kewer',
            otherwise 1 x question_emb_dim (it is projected by W first)
        key_embeddings: 3D tensor of shape 1 x 'no. of triple memory slots' x input_dim
        value_embeddings: 3D tensor of shape 1 x 'no. of triple memory slots' x input_dim
        candidate_embeddings: 3D tensor of shape 1 x 'no. of candidate entities' x input_dim
        returns 2D tensor of shape 1 x 'no. of candidate_entities'
        """
        query = question_embedding if self.qemb == 'kewer' else self.W(question_embedding)

        for hop in self.R:
            # Attention over memory slots: 1 x no. of triple memory slots x 1
            slot_weights = F.softmax((query * key_embeddings).sum(dim=-1, keepdim=True), dim=-2)
            # Attention-weighted read of the values: 1 x input_dim
            memory_read = (slot_weights * value_embeddings).sum(dim=-2)
            # Updated, unit-length query: 1 x input_dim
            query = F.normalize(hop(query + memory_read), dim=-1)

        # Unit-length candidates, so each score is a cosine similarity with the query.
        unit_candidates = F.normalize(candidate_embeddings, dim=-1)
        return (query * unit_candidates).sum(dim=-1)  # 1 x no. of candidate_entities

In [None]:
def _question_embedding(model, sample, device):
    """Return the question embedding for one batch as a tensor on `device`.

    A precomputed `sample['question_embedding']` is used when the dataset
    provides one. Otherwise, for qemb == 'bert', `sample['question_text']` is
    encoded with the model's frozen BERT and the final-layer vector of the
    first ([CLS]) token is returned (batch x hidden size).

    Raises KeyError when neither source is usable.
    """
    if 'question_embedding' in sample:
        return sample['question_embedding'].to(device)
    if model.qemb != 'bert' or 'question_text' not in sample:
        raise KeyError("sample has no 'question_embedding' and cannot be embedded "
                       f"from text for qemb={model.qemb!r}")

    # The default collate function turns the per-example strings into a list.
    encoded = model.tokenizer(list(sample['question_text']), padding=True, truncation=True,
                              return_tensors='pt').to(device)
    # The encoder is frozen, so no graph is needed through it.
    with torch.no_grad():
        encoder_output = model.BERT(**encoded, output_hidden_states=True)
    # hidden_states[-1] is the last encoder layer (before the masked-LM head).
    # NOTE(review): [CLS] pooling is one reasonable choice; mean pooling is another.
    return encoder_output.hidden_states[-1][:, 0]


def _batch_loss(model, criterion, sample, device):
    """Run the model on one batch and return the loss tensor."""
    scores = model(_question_embedding(model, sample, device),
                   sample['key_embeddings'].to(device),
                   sample['value_embeddings'].to(device),
                   sample['candidate_embeddings'].to(device))
    return criterion(scores, sample['target_index'].to(device))


def train(args, device, trainloader, devloader, checkpoint=None):
    """Train a ModelKvmem, saving a checkpoint whenever the dev loss improves.

    args: namespace providing qemb, hops, epochs and savemodel.
    device: torch.device to train on.
    trainloader / devloader: loaders yielding dicts with 'key_embeddings',
        'value_embeddings', 'candidate_embeddings', 'target_index' and either
        'question_embedding' or 'question_text'.
    checkpoint: optional dict with 'args', 'model_state_dict',
        'optimizer_state_dict', 'epoch' and 'best_dev_loss'; when given, the
        model/optimizer state and epoch counter are restored from it.

    Relies on the module-level `start_time` for the total-time report.
    """
    import os  # local import: only needed to create the checkpoint directory

    # Losses are averaged per batch; max(..., 1) avoids ZeroDivisionError on an empty loader.
    train_samples = max(len(trainloader), 1)
    dev_samples = max(len(devloader), 1)

    if checkpoint:
        # Rebuild the architecture the checkpoint was trained with, not the current args.
        model = ModelKvmem(qemb=checkpoint['args'].qemb, num_hops=checkpoint['args'].hops)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model = ModelKvmem(qemb=args.qemb, num_hops=args.hops)
    model = model.to(device)

    # Frozen parameters never receive gradients, so Adam leaves them untouched.
    optimizer = optim.Adam(model.parameters())
    if checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        best_epoch = checkpoint['epoch']
        best_dev_loss = checkpoint['best_dev_loss']
        start_epoch = checkpoint['epoch'] + 1
    else:
        best_epoch = -1
        best_dev_loss = float('inf')
        start_epoch = 0

    criterion = nn.CrossEntropyLoss()

    # Create the target directory up front so an epoch of work is not lost
    # to a failing torch.save.
    save_dir = os.path.dirname(args.savemodel)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for epoch in range(start_epoch, args.epochs):
        epoch_start_time = time.time()

        train_epoch_loss = 0.0
        model.train()
        # Keep the frozen encoder in eval mode so dropout does not perturb
        # the question embeddings.
        model.BERT.eval()
        for sample in trainloader:
            optimizer.zero_grad()
            loss = _batch_loss(model, criterion, sample, device)
            loss.backward()
            optimizer.step()
            train_epoch_loss += loss.item()

        dev_epoch_loss = 0.0
        model.eval()
        # Evaluation needs no gradients; skipping them saves memory and time.
        with torch.no_grad():
            for sample in devloader:
                dev_epoch_loss += _batch_loss(model, criterion, sample, device).item()

        train_loss = train_epoch_loss / train_samples
        dev_loss = dev_epoch_loss / dev_samples
        print(f'Epoch {epoch} train loss: {train_loss:.4f}, '
              f'dev loss: {dev_loss:.4f}. Took {time.time() - epoch_start_time:.2f} seconds. '
              f'Total time: {(time.time() - start_time) / (60 * 60):.2f} hours.')

        if dev_loss < best_dev_loss:
            best_dev_loss = dev_loss
            best_epoch = epoch
            print(f'Saving model {args.savemodel}...')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_dev_loss': best_dev_loss,
                'args': args
            }, args.savemodel)
    print(f'Best dev loss {best_dev_loss} was achieved on epoch {best_epoch}.')

In [None]:
# Start training; `checkpoint` is None for a fresh run, otherwise train()
# restores model/optimizer state from it and continues from the next epoch.
train(args, device, train_loader, dev_loader, checkpoint)