# MANNs for bAbI

In [290]:
%matplotlib inline
from __future__ import division, print_function
import os
import random
import re

import torch
from torch import FloatTensor, LongTensor
from torch.nn import functional as F, Embedding, Linear, Module, Parameter
from torch.autograd import Variable

# Data

In [263]:
class Statement(object):
    def __init__(self, slot, text):
        self.slot = slot
        self.text_strs = text
        x.text = Variable(LongTensor([d[i] for i in self.text]))
    
    def torchify(self, d):
        text = Variable(LongTensor([d[i] for i in self.text]))
        return Statement(self.slot, text)
        
    def __repr__(self):
        return 'S[{0.slot:02d}](text={0.text})'.format(self)


class Query(object):
    def __init__(self, slot, text, answer, refs):
        self.slot = slot
        self.text = text
        self.answer = answer
        self.refs = refs

    def torchify(self, d):
        text = Variable(LongTensor([d[i] for i in self.text]))
        answer = Variable(LongTensor([d[i] for i in self.answer]))
        return Query(self.slot, text, answer, list(self.refs))
        
    def __repr__(self):
        return 'Q[{0.slot:02d}](text={0.text}, answer={0.answer}, refs={0.refs})'.format(self)


class TaskData(object):
    def __init__(self, task, data_dir='babi-tasks_1-20_v1-2/en/'):
        self.task = task
        def match_one_file(p):
            ms = [f for f in os.listdir(data_dir) if re.match(p, f)]
            if len(ms) != 1:
                raise ValueError('{} matched the wrong number of items: {}'.format(p, ms))
            return os.path.join(data_dir, ms[0])

        self.train_file = match_one_file(r'qa{}_.*_train.txt'.format(task))
        with open(self.train_file) as fp:
            self.train = self.parse(fp)

        self.test_file = match_one_file(r'qa{}_.*_test.txt'.format(task))
        with open(self.test_file) as fp:
            self.test = self.parse(fp)
    
    @staticmethod
    def parse(lines):
        prev_slot = None
        episodes = [[]]
        for line in lines:
            slot, text = line.lower().strip().split(' ', 1)
            slot = int(slot)
            if '\t' in text:
                text, answer, refs = [re.findall(r'\w+', s) for s in text.split('\t')]
                refs = [int(ref) for ref in refs]
                item = Query(slot, text, answer, refs)
            else:
                text = re.findall(r'\w+', text)
                item = Statement(slot, text)
            if prev_slot is None or slot >  prev_slot:
                episodes[-1].append(item)
            else:
                episodes.append([item])
            prev_slot = slot
        return episodes
    
    def __repr__(self):
        return 'TaskData(task={}, num_train={}, num_test={})'.format(self.task, len(self.train), len(self.test))


In [350]:
data = TaskData(6)
print(data)
print('  train_file =', data.train_file)
print('  test_file  =', data.test_file)
print()
for x in data.train[0]:
    print(x)

TaskData(task=6, num_train=200, num_test=200)
  train_file = babi-tasks_1-20_v1-2/en/qa6_yes-no-questions_train.txt
  test_file  = babi-tasks_1-20_v1-2/en/qa6_yes-no-questions_test.txt

S[01](text=['mary', 'moved', 'to', 'the', 'bathroom'])
S[02](text=['sandra', 'journeyed', 'to', 'the', 'bedroom'])
Q[03](text=['is', 'sandra', 'in', 'the', 'hallway'], answer=['no'], refs=[2])
S[04](text=['mary', 'went', 'back', 'to', 'the', 'bedroom'])
S[05](text=['daniel', 'went', 'back', 'to', 'the', 'hallway'])
Q[06](text=['is', 'daniel', 'in', 'the', 'bathroom'], answer=['no'], refs=[5])
S[07](text=['sandra', 'went', 'to', 'the', 'kitchen'])
S[08](text=['daniel', 'went', 'back', 'to', 'the', 'bathroom'])
Q[09](text=['is', 'daniel', 'in', 'the', 'office'], answer=['no'], refs=[8])
S[10](text=['daniel', 'picked', 'up', 'the', 'football', 'there'])
S[11](text=['daniel', 'went', 'to', 'the', 'bedroom'])
Q[12](text=['is', 'daniel', 'in', 'the', 'bedroom'], answer=['yes'], refs=[11])
S[13](text=['john', 

In [351]:
word2id = {}
id2word = {}
for episode in data.train + data.test:
    for item in episode:
        for word in item.text + (item.answer if isinstance(item, Query) else []):
            try:
                i = word2id[word]
            except KeyError:
                word2id[word] = len(word2id)
                id2word[word2id[word]] = word
for i, word in sorted(id2word.items()):
    print(i, '=>', word)

0 => mary
1 => moved
2 => to
3 => the
4 => bathroom
5 => sandra
6 => journeyed
7 => bedroom
8 => is
9 => in
10 => hallway
11 => no
12 => went
13 => back
14 => daniel
15 => kitchen
16 => office
17 => picked
18 => up
19 => football
20 => there
21 => yes
22 => john
23 => travelled
24 => garden
25 => got
26 => apple
27 => put
28 => down
29 => grabbed
30 => left
31 => dropped
32 => took
33 => milk
34 => discarded


# Neural Net

In [352]:
class LSTM(torch.nn.Module):
    def __init__(self, n_in, n_out):
        super(LSTM, self).__init__()
        self.w = Parameter(FloatTensor(n_in + n_out, 4 * n_out))
        self.b = Parameter(FloatTensor(1, 4 * n_out))
        self.h0 = Parameter(FloatTensor(1, n_out))
        self.c0 = Parameter(FloatTensor(1, n_out))
        self.reset_parameters()
    
    def reset_parameters(self):
        self.w.data.normal_(0, 0.01)
        self.b.data.zero_()
        self.h0.data.zero_()
        self.c0.data.zero_()
    
    @property
    def initial_state(self):
        return self.h0, self.c0
    
    def observable(self, s):
        return s[0]
    
    def step(self, x, s=None):
        h_prev, c_prev = s or self.initial_state
        z = torch.cat([x, h_prev], 1)
        a = z.mm(self.w) + self.b.expand(x.size(0), self.b.size(1))
        i, f, o, g = torch.chunk(a, 4, dim=1)
        i = F.sigmoid(i)
        f = F.sigmoid(f)
        o = F.sigmoid(o)
        g = F.tanh(g)
        c = f * c_prev + i * g
        h = o * c
        return h, c


class RegNet(torch.nn.Module):
    def __init__(self, n_words, embed_dim, controller_hid_dim, register_dim, n_registers):
        super(RegNet, self).__init__()
        self.n_words = n_words
        self.embed_dim = embed_dim
        self.controller_hid_dim = controller_hid_dim
        self.register_dim = register_dim
        self.n_registers = n_registers
        
        self.embedding = Embedding(n_words, embed_dim)
        self.add = Linear(embed_dim, register_dim)
        self.erase = Linear(embed_dim, register_dim)
        self.controller = LSTM(embed_dim, controller_hid_dim)
        self.attend = Linear(controller_hid_dim, n_registers)
        self.read_key = Linear(embed_dim, register_dim)
        self.output = Linear(register_dim, n_words)
    
    @property
    def initial_state(self):
        memory = [Variable(torch.zeros(1, self.register_dim)) for i in range(self.n_registers)]
        return (memory, self.controller.initial_state)
    
    def embed(self, x):
        return self.embedding(x).sum(0)
    
    def update(self, x, s=None):
        m, r = s or self.initial_state
        z = self.embed(x)
        a = F.sigmoid(self.add(z))
        e = F.sigmoid(self.erase(z))
        r = self.controller.step(z, r)
        g = F.softmax(self.attend(self.controller.observable(r)))
        g = [gi.expand(a.size()) for gi in g.chunk(self.n_registers, dim=1)]
        m = [g[i] * a + (1 - g[i]) * m[i] for i in range(self.n_registers)]
        return m, r
    
    def scores(self, x, s=None):
        m, _ = s or self.initial_state
        z = self.embed(x)
        k = self.read_key(z)
        k_norm = k.norm()
        g = [m[i].dot(k) / (m[i].norm() * k_norm) for i in range(self.n_registers)]
        g = [g[i].unsqueeze(0).expand(m[i].size()) for i in range(self.n_registers)]
        h = sum(g[i] * m[i] for i in range(self.n_registers))
        return self.output(h)
        
    def unfold(self, xs):
        s = self.initial_state
        ss = []
        ps = []
        for x in xs:
            if isinstance(x, Statement):
                s = self.update(x.text, s)
                ss.append(s)
                ps.append(None)
            elif isinstance(x, Query):
                p = self.scores(x.text, s)
                ss.append(s)
                ps.append(p)
            else:
                raise TypeError('expected Statement or Query, got {}'.format(type(x)))
        return ps, ss

    def is_output(self, p, x):
        if p is None:
            assert isinstance(x, Statement)
            return False
        assert isinstance(x, Query)
        return True
    
    def answer_cost(self, p, x):
        return F.cross_entropy(p, x.answer)
    
    def cost(self, xs):
        ps, _ = self.unfold(xs)
        return sum(self.answer_cost(p, x) 
                   for p, x in zip(ps, xs) 
                       if self.is_output(p, x))
    
    def predict(self, xs):
        ps, _ = self.unfold(xs)
        return [p.data.numpy().argmax(1) 
                if self.is_output(p, x) 
                else None 
                for p, x in zip(ps, xs)]


In [353]:
regnet = RegNet(len(word2id), 20, 15, 20, 5)

# Eval

In [358]:
data_train = [[x.torchify(word2id) for x in xs] for xs in data.train]
print('Number of samples:', len(data_train))
while True:
    random.shuffle(data_train)
    sum_nll = 0.0
    for xs in data_train:
        regnet.zero_grad()
        nll = regnet.cost(xs)
        nll.backward()
        for param in regnet.parameters():
            param.data.add_(-0.005, param.grad.data)
        sum_nll += nll.data.numpy()[0]
    print(sum_nll)

Number of samples: 200
665.530073643
670.551208138
667.196122169
662.009292603
665.453834414
666.770536423
663.650207877
661.660168767
660.056868196
658.267350435
656.219919324
657.227235556
653.741602302
651.653306127
648.709418893
647.69943881
648.588648677
648.852978706
645.459496737
642.948979735
644.346433163
638.017903447
637.239938259
634.41603446
631.183526516
627.824265003
626.683178067
626.420316815
622.437878966
621.611734509
617.344587088
613.819654465
612.776342273
608.654447794
607.068612933
601.205011845
595.352134705
592.224013209
588.018076777
579.454022169
576.218397677
574.99883306
570.005260229
563.462118983
558.412848175
555.055948973
544.094163418
543.726655543
536.603663445
531.339917421
527.623206377
517.883247137
519.240021586
506.177681446
498.757984102
501.606267273
493.918481052
484.60589993
474.38744086
467.902221322
467.738693118
464.850767568
452.331591189
446.38367337
441.085629642
436.164363235
425.328442663
419.085003018
416.797692537
410.601911604
403

KeyboardInterrupt: 

In [359]:
for i, xs in enumerate(data.test):
    if i >= 2:
        break
    xs_vars = [x.torchify(word2id) for x in xs]
    ps = regnet.predict(xs_vars)
    assert(len(ps) == len(xs))
    for x, p in zip(xs, ps):
        if p is None:
            print(x, None)
        else:
            print(x)
            print((x.answer, [word2id[w] for w in x.answer]),
                  ([id2word[j] for j in p.tolist()], p))

S[01](text=['mary', 'got', 'the', 'milk', 'there']) None
S[02](text=['john', 'moved', 'to', 'the', 'bedroom']) None
Q[03](text=['is', 'john', 'in', 'the', 'kitchen'], answer=['no'], refs=[2])
(['no'], [11]) (['no'], array([11]))
S[04](text=['mary', 'discarded', 'the', 'milk']) None
S[05](text=['john', 'went', 'to', 'the', 'garden']) None
Q[06](text=['is', 'john', 'in', 'the', 'kitchen'], answer=['no'], refs=[5])
(['no'], [11]) (['yes'], array([21]))
S[07](text=['daniel', 'moved', 'to', 'the', 'bedroom']) None
S[08](text=['daniel', 'went', 'to', 'the', 'garden']) None
Q[09](text=['is', 'john', 'in', 'the', 'garden'], answer=['yes'], refs=[5])
(['yes'], [21]) (['yes'], array([21]))
S[10](text=['daniel', 'travelled', 'to', 'the', 'bathroom']) None
S[11](text=['sandra', 'travelled', 'to', 'the', 'bedroom']) None
Q[12](text=['is', 'daniel', 'in', 'the', 'bathroom'], answer=['yes'], refs=[10])
(['yes'], [21]) (['no'], array([11]))
S[13](text=['mary', 'took', 'the', 'football', 'there']) None

In [342]:
def compute_error(ds):
    errors = 0.0
    total = 0.0
    for xs in ds:
        xs_vars = [x.torchify(word2id) for x in xs]
        ps = regnet.predict(xs_vars)
        assert(len(ps) == len(xs))
        for x, p in zip(xs, ps):
            if p is not None:
                if x.answer != [id2word[i] for i in p.tolist()]:
                    errors += 1
                total += 1
    print(errors, total)
    return errors, total

In [362]:
compute_error(data.test)

281.0 1000.0


(281.0, 1000.0)