In [2]:
import torch
from torchtext.datasets import BABI20
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import torch.nn.init as I
import numpy as np
from torch.optim.lr_scheduler import StepLR

In [3]:
def dataloader(batch_size, memory_size, task, joint, tenK):
    train_iter, valid_iter, test_iter = BABI20.iters(
        batch_size=batch_size, memory_size=memory_size, task=task, joint=joint, tenK=tenK, device=torch.device("cpu"),
    shuffle=True)
    return train_iter, valid_iter, test_iter, train_iter.dataset.fields['query'].vocab

In [6]:
train_iter, valid_iter, test_iter, vocab = dataloader(64, 50, 6, True, False)

In [4]:
def print_story(stories):
    """
    function to print stories from padded sequence
    """
    for s in stories:
        if sum(s)>0:
            print(' '.join([vocab.itos[i] for i in s]))

In [7]:
count = 0
for _, batch in enumerate(train_iter, start=1):
    print_story(batch.story[0])
    print([vocab.itos[i] for i in batch.query[0]])
    print(vocab.itos[batch.answer[0]])
    break
count

The chest fits inside the suitcase <pad> <pad> <pad> <pad> <pad>
The box of chocolates fits inside the box <pad> <pad> <pad>
The container is bigger than the box <pad> <pad> <pad> <pad>
The box fits inside the suitcase <pad> <pad> <pad> <pad> <pad>
The suitcase is bigger than the chest <pad> <pad> <pad> <pad>
['Is', 'the', 'box', 'of', 'chocolates', 'bigger', 'than', 'the', 'suitcase', '<pad>', '<pad>']
no


0

In [8]:
class MemN2N(nn.Module):

    def __init__(self, params, vocab):
        super(MemN2N, self).__init__()
        self.input_size = len(vocab)
        self.embed_size = params['embed_size']
        self.memory_size = params['memory_size']
        self.num_hops = params['num_hops']
        self.use_bow = params['use_bow']
        self.use_lw = params['use_lw']
        self.use_ls = params['use_ls']
        self.vocab = vocab

        # create parameters according to different type of weight tying
        pad = self.vocab.stoi['<pad>']
        # First embedding for stories
        self.A = nn.ModuleList([nn.Embedding(self.input_size, self.embed_size, padding_idx=pad)])
        self.A[-1].weight.data.normal_(0, 0.1)
        # Second embedding for stories
        self.C = nn.ModuleList([nn.Embedding(self.input_size, self.embed_size, padding_idx=pad)])
        self.C[-1].weight.data.normal_(0, 0.1)
        if self.use_lw:
            for _ in range(1, self.num_hops):
                self.A.append(self.A[-1])
                self.C.append(self.C[-1])
            self.B = nn.Embedding(self.input_size, self.embed_size, padding_idx=pad)
            self.B.weight.data.normal_(0, 0.1)
            self.out = nn.Parameter(
                I.normal_(torch.empty(self.input_size, self.embed_size), 0, 0.1))
            self.H = nn.Linear(self.embed_size, self.embed_size)
            self.H.weight.data.normal_(0, 0.1)
        else:
            for _ in range(1, self.num_hops):
                self.A.append(self.C[-1])
                self.C.append(nn.Embedding(self.input_size, self.embed_size, padding_idx=pad))
                self.C[-1].weight.data.normal_(0, 0.1)
            self.B = self.A[0]
            self.out = self.C[-1].weight

        # temporal matrix
        self.TA = nn.Parameter(I.normal_(torch.empty(self.memory_size, self.embed_size), 0, 0.1))
        self.TC = nn.Parameter(I.normal_(torch.empty(self.memory_size, self.embed_size), 0, 0.1))

    def forward(self, story, query):
        sen_size = query.shape[-1]
        weights = self.compute_weights(sen_size)
        state = (self.B(query) * weights).sum(1)
        sen_size = story.shape[-1]
        weights = self.compute_weights(sen_size)
        for i in range(self.num_hops):
            memory = (self.A[i](story.view(-1, sen_size)) * weights).sum(1).view(
                *story.shape[:-1], -1)
            memory += self.TA
            output = (self.C[i](story.view(-1, sen_size)) * weights).sum(1).view(
                *story.shape[:-1], -1)
            output += self.TC
            probs = (memory @ state.unsqueeze(-1)).squeeze() # attention scores
            if not self.use_ls:
                probs = F.softmax(probs, dim=-1)
            response = (probs.unsqueeze(1) @ output).squeeze()
            if self.use_lw:
                state = self.H(response) + state
            else:
                state = response + state

        return F.log_softmax(F.linear(state, self.out), dim=-1)

    def compute_weights(self, J):
        # position encoding
        d = self.embed_size
        if self.use_bow:
            weights = torch.ones(J, d)
        else:
            func = lambda j, k: 1 - (j + 1) / J - (k + 1) / d * (1 - 2 * (j + 1) / J)    # 0-based indexing
            weights = torch.from_numpy(np.fromfunction(func, (J, d), dtype=np.float32))
        #return weights.cuda() if torch.cuda.is_available() else weights
        return weights


In [9]:
def train(train_iter, model, optimizer, epochs, max_clip, valid_iter=None):
    total_loss = 0
    valid_data = list(valid_iter)
    valid_loss = None
    next_epoch_to_report = 5
    pad = model.vocab.stoi['<pad>']

    for _, batch in enumerate(train_iter, start=1):
        story = batch.story
        query = batch.query
        answer = batch.answer

        optimizer.zero_grad()
        outputs = model(story, query)
        loss = F.nll_loss(outputs, answer.view(-1), ignore_index=pad, reduction='sum')
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_clip)
        optimizer.step()
        total_loss += loss.item()
        # linear start
        if model.use_ls:
            loss = 0
            for k, batch in enumerate(valid_data, start=1):
                story = batch.story
                query = batch.query
                answer = batch.answer
                outputs = model(story, query)
                loss += F.nll_loss(outputs, answer.view(-1), ignore_index=pad, reduction='sum').item()
            loss = loss / k
            if valid_loss and valid_loss <= loss:
                model.use_ls = False
            else:
                valid_loss = loss

        if train_iter.epoch == next_epoch_to_report:
            print("#! epoch {:d} average batch loss: {:5.4f}".format(
                int(train_iter.epoch), total_loss / len(train_iter)))
            next_epoch_to_report += 5
        if int(train_iter.epoch) == train_iter.epoch:
            total_loss = 0
        if train_iter.epoch == epochs:
            print("Done!")
            break
            
def eval(test_iter, model):
    total_error = 0

    for k, batch in enumerate(test_iter, start=1):
        story = batch.story
        query = batch.query
        answer = batch.answer
        outputs = model(story, query)
        _, outputs = torch.max(outputs, -1)
        total_error += torch.mean((outputs != answer.view(-1)).float()).item()
    print("#! average error: {:5.1f}".format(total_error / k * 100))

In [10]:
def train(train_iter, model, optimizer, epochs, max_clip, valid_iter=None):
    total_loss = 0
    next_epoch_to_report = 5
    pad = model.vocab.stoi['<pad>']
    scheduler = StepLR(optimizer, step_size=25, gamma=0.5)
    for epoch in range(epochs):
        scheduler.step()
        epoch_loss = 0
        training_data_size = 0
        for _, batch in enumerate(train_iter, start=1):
            training_data_size += batch.answer.shape[0]
            story = batch.story
            query = batch.query
            answer = batch.answer

            optimizer.zero_grad()
            outputs = model(story, query)
            loss = F.nll_loss(outputs, answer.view(-1), ignore_index=pad, reduction='sum')
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_clip)
            optimizer.step()
            epoch_loss += loss.item()

In [25]:
params = {'embed_size' : 20,
          'memory_size' : 50,
          'num_hops' : 3,
          'use_bow' : False,
          'use_lw' : True,
          'use_ls' : True}

model = MemN2N(params, vocab)
optimizer = optim.Adam(model.parameters(), 0.01)
train(train_iter, model, optimizer, 20, 40, valid_iter)

Loss for epoch 0: 4.265222820070055
Loss for epoch 1: 0.6998750078413222
Loss for epoch 2: 0.6788641730414496
Loss for epoch 3: 0.46520428297254773
Loss for epoch 4: 0.39745666207207575
Loss for epoch 5: 0.372959529876709
Loss for epoch 6: 0.36126250563727486
Loss for epoch 7: 0.35352439075046116
Loss for epoch 8: 0.35828187582227916
Loss for epoch 9: 0.35167537763383655
Loss for epoch 10: 0.34342841127183704
Loss for epoch 11: 0.3414072438346015
Loss for epoch 12: 0.34585595565372046
Loss for epoch 13: 0.3344229125976563
Loss for epoch 14: 0.3314963086446126
Loss for epoch 15: 0.3279499698215061
Loss for epoch 16: 0.33579450289408364
Loss for epoch 17: 0.3312346231672499
Loss for epoch 18: 0.33089186032613116
Loss for epoch 19: 0.33001532713572185


In [26]:
for _, batch in enumerate(test_iter, start=1):
    vocab = test_iter.dataset.fields['query'].vocab
    with torch.no_grad():
        story = batch.story
        query = batch.query
        answer = batch.answer
        outputs = model(story, query)
        ex = 33
        print_story(story[ex])
        print(' '.join([vocab.itos[i] for i in query[ex]]))
        print(vocab.itos[answer[ex]])
        o = np.argmax(outputs[ex])
        print(vocab.itos[o])
        break

Daniel travelled to the kitchen <pad>
Daniel moved to the bedroom <pad>
Mary went to the garden <pad>
John took the football there <pad>
Is Daniel in the bathroom
no
no


In [27]:
acc = []
for _, batch in enumerate(test_iter, start=1):
    vocab = test_iter.dataset.fields['query'].vocab
    with torch.no_grad():
        story = batch.story
        query = batch.query
        answer = batch.answer
        outputs = model(story, query)
        o = np.argmax(outputs, axis=1)
        acc += (list(np.array(answer.reshape(-1,)==o)))
np.mean(acc)

0.828

In [11]:
params = {'embed_size' : 30,
          'memory_size' : 50,
          'num_hops' : 4,
          'use_bow' : False,
          'use_lw' : True,
          'use_ls' : False}

for task in range(1,21):
    train_iter, valid_iter, test_iter, vocab = dataloader(64, 50, task, False, True)
    model = MemN2N(params, vocab)
    optimizer = optim.Adam(model.parameters(), 0.05)
    train(train_iter, model, optimizer, 20, 40, valid_iter)
    acc = []
    with torch.no_grad():
        for _, batch in enumerate(test_iter, start=1):
            with torch.no_grad():
                story = batch.story
                query = batch.query
                answer = batch.answer
                outputs = model(story, query)
                o = np.argmax(outputs, axis=1)
                acc += (list(np.array(answer.reshape(-1,)==o)))
        print(f"Test Accuracy for Task {task}: {np.mean(acc)*100}%")
    np.mean(acc)



Test Accuracy for Task 1: 100.0%
Test Accuracy for Task 2: 14.499999999999998%
Test Accuracy for Task 3: 16.6%
Test Accuracy for Task 4: 61.5%
Test Accuracy for Task 5: 78.7%
Test Accuracy for Task 6: 49.3%
Test Accuracy for Task 7: 47.699999999999996%
Test Accuracy for Task 8: 85.8%
Test Accuracy for Task 9: 63.800000000000004%
Test Accuracy for Task 10: 44.3%
Test Accuracy for Task 11: 33.1%
Test Accuracy for Task 12: 77.2%
Test Accuracy for Task 13: 94.39999999999999%
Test Accuracy for Task 14: 20.1%
Test Accuracy for Task 15: 22.1%
Test Accuracy for Task 16: 25.6%
Test Accuracy for Task 17: 54.1%
Test Accuracy for Task 18: 65.4%
Test Accuracy for Task 19: 8.1%
Test Accuracy for Task 20: 90.9%
