# MANNs for bAbI

In [1]:
%matplotlib inline
from __future__ import division, print_function
import os
import random
import re

import torch
from torch import FloatTensor, LongTensor
from torch.nn import functional as F, Embedding, Linear, Module, Parameter
from torch.autograd import Variable

# Data

In [4]:
class Statement(object):
    def __init__(self, slot, text):
        self.slot = slot
        self.text = text
    
    def torchify(self, d):
        text = Variable(LongTensor([d[i] for i in self.text]))
        return Statement(self.slot, text)
        
    def __repr__(self):
        return 'S[{0.slot:02d}](text={0.text})'.format(self)


class Query(object):
    def __init__(self, slot, text, answer, refs):
        self.slot = slot
        self.text = text
        self.answer = answer
        self.refs = refs

    def torchify(self, d):
        text = Variable(LongTensor([d[i] for i in self.text]))
        answer = Variable(LongTensor([d[i] for i in self.answer]))
        return Query(self.slot, text, answer, list(self.refs))
        
    def __repr__(self):
        return 'Q[{0.slot:02d}](text={0.text}, answer={0.answer}, refs={0.refs})'.format(self)


class TaskData(object):
    def __init__(self, task, data_dir='babi-tasks_1-20_v1-2/en/'):
        self.task = task
        def match_one_file(p):
            ms = [f for f in os.listdir(data_dir) if re.match(p, f)]
            if len(ms) != 1:
                raise ValueError('{} matched the wrong number of items: {}'.format(p, ms))
            return os.path.join(data_dir, ms[0])

        self.train_file = match_one_file(r'qa{}_.*_train.txt'.format(task))
        with open(self.train_file) as fp:
            self.train = self.parse(fp)

        self.test_file = match_one_file(r'qa{}_.*_test.txt'.format(task))
        with open(self.test_file) as fp:
            self.test = self.parse(fp)
    
    @staticmethod
    def parse(lines):
        prev_slot = None
        episodes = [[]]
        for line in lines:
            slot, text = line.lower().strip().split(' ', 1)
            slot = int(slot)
            if '\t' in text:
                text, answer, refs = [re.findall(r'\w+', s) for s in text.split('\t')]
                refs = [int(ref) for ref in refs]
                item = Query(slot, text, answer, refs)
            else:
                text = re.findall(r'\w+', text)
                item = Statement(slot, text)
            if prev_slot is None or slot >  prev_slot:
                episodes[-1].append(item)
            else:
                episodes.append([item])
            prev_slot = slot
        return episodes
    
    def __repr__(self):
        return 'TaskData(task={}, num_train={}, num_test={})'.format(self.task, len(self.train), len(self.test))


In [5]:
data = TaskData(1)
print(data)
print('  train_file =', data.train_file)
print('  test_file  =', data.test_file)
print()
for x in data.train[0]:
    print(x)

TaskData(task=1, num_train=200, num_test=200)
  train_file = babi-tasks_1-20_v1-2/en/qa1_single-supporting-fact_train.txt
  test_file  = babi-tasks_1-20_v1-2/en/qa1_single-supporting-fact_test.txt

S[01](text=['mary', 'moved', 'to', 'the', 'bathroom'])
S[02](text=['john', 'went', 'to', 'the', 'hallway'])
Q[03](text=['where', 'is', 'mary'], answer=['bathroom'], refs=[1])
S[04](text=['daniel', 'went', 'back', 'to', 'the', 'hallway'])
S[05](text=['sandra', 'moved', 'to', 'the', 'garden'])
Q[06](text=['where', 'is', 'daniel'], answer=['hallway'], refs=[4])
S[07](text=['john', 'moved', 'to', 'the', 'office'])
S[08](text=['sandra', 'journeyed', 'to', 'the', 'bathroom'])
Q[09](text=['where', 'is', 'daniel'], answer=['hallway'], refs=[4])
S[10](text=['mary', 'moved', 'to', 'the', 'hallway'])
S[11](text=['daniel', 'travelled', 'to', 'the', 'office'])
Q[12](text=['where', 'is', 'daniel'], answer=['office'], refs=[11])
S[13](text=['john', 'went', 'back', 'to', 'the', 'garden'])
S[14](text=['john'

In [22]:
word2id = {}
id2word = {}
for episode in data.train + data.test:
    for item in episode:
        for word in item.text + (item.answer if isinstance(item, Query) else []):
            try:
                i = word2id[word]
            except KeyError:
                word2id[word] = len(word2id)
                id2word[word2id[word]] = word
print('Vocab')
for i, word in sorted(id2word.items()):
    print(i, '=>', word)

Vocab
0 => mary
1 => moved
2 => to
3 => the
4 => bathroom
5 => john
6 => went
7 => hallway
8 => where
9 => is
10 => daniel
11 => back
12 => sandra
13 => garden
14 => office
15 => journeyed
16 => travelled
17 => bedroom
18 => kitchen


# Register Network

### Update
$\begin{align*}
    v &= Ax_t \\
    z &= \sigma(W_a v + b_a) \\
    s_t &= \text{lstm}(v, s_{t-1}) \\
    g &= \text{softmax}(W_g s_t^h + b_g) \\
    m_{i,t} &= g_i \cdot z + (1 - g_i) * m_{i,t-1}
\end{align*}$

### Answer
$\begin{align*}
    v &= Aq_t \\
    k &= \sigma(W_k v + b_k) \\
    g_i &= \frac{k^T m_{i,t}}{||k||\cdot||m_{i,t-1}||} \\
    h_t &= \sum_i g_i \cdot m_{i,t-1}
\end{align*}$

In [16]:
class LSTM(torch.nn.Module):
    def __init__(self, n_in, n_out):
        super(LSTM, self).__init__()
        self.w = Parameter(FloatTensor(n_in + n_out, 4 * n_out))
        self.b = Parameter(FloatTensor(1, 4 * n_out))
        self.h0 = Parameter(FloatTensor(1, n_out))
        self.c0 = Parameter(FloatTensor(1, n_out))
        self.reset_parameters()
    
    def reset_parameters(self):
        self.w.data.normal_(0, 0.01)
        self.b.data.zero_()
        self.h0.data.zero_()
        self.c0.data.zero_()
    
    @property
    def initial_state(self):
        return self.h0, self.c0
    
    def observable(self, s):
        return s[0]
    
    def step(self, x, s=None):
        h_prev, c_prev = s or self.initial_state
        z = torch.cat([x, h_prev], 1)
        a = z.mm(self.w) + self.b.expand(x.size(0), self.b.size(1))
        i, f, o, g = torch.chunk(a, 4, dim=1)
        i = F.sigmoid(i)
        f = F.sigmoid(f)
        o = F.sigmoid(o)
        g = F.tanh(g)
        c = f * c_prev + i * g
        h = o * c
        return h, c


class RegNet(torch.nn.Module):
    def __init__(self, n_words, embed_dim, controller_hid_dim, register_dim, n_registers):
        super(RegNet, self).__init__()
        self.n_words = n_words
        self.embed_dim = embed_dim
        self.controller_hid_dim = controller_hid_dim
        self.register_dim = register_dim
        self.n_registers = n_registers
        
        self.embedding = Embedding(n_words, embed_dim)
        self.candidate = Linear(embed_dim, register_dim)
        self.controller = LSTM(embed_dim, controller_hid_dim)
        self.attend = Linear(controller_hid_dim, n_registers)
        self.read_key = Linear(embed_dim, register_dim)
        self.output = Linear(register_dim, n_words)
    
    @property
    def initial_state(self):
        memory = [Variable(torch.zeros(1, self.register_dim)) for i in range(self.n_registers)]
        return (memory, self.controller.initial_state)
    
    def embed(self, x):
        return self.embedding(x).sum(0)
    
    def update(self, x, state=None):
        m, s = state or self.initial_state
        v = self.embed(x)
        z = F.sigmoid(self.candidate(v))
        s = self.controller.step(v, s)
        g = F.softmax(self.attend(self.controller.observable(s)))
        g = [gi.expand(z.size()) for gi in g.chunk(self.n_registers, dim=1)]
        m = [g[i] * z + (1 - g[i]) * m[i] for i in range(self.n_registers)]
        return m, s
    
    def scores(self, x, s=None):
        m, _ = s or self.initial_state
        v = self.embed(x)
        k = self.read_key(v)
        k_norm = k.norm()
        g = [m[i].dot(k) / (m[i].norm() * k_norm) for i in range(self.n_registers)]
        g = [g[i].unsqueeze(0).expand(m[i].size()) for i in range(self.n_registers)]
        h = sum(g[i] * m[i] for i in range(self.n_registers))
        return self.output(h)
        
    def unfold(self, xs):
        s = self.initial_state
        ss = []
        ps = []
        for x in xs:
            if isinstance(x, Statement):
                s = self.update(x.text, s)
                ss.append(s)
                ps.append(None)
            elif isinstance(x, Query):
                p = self.scores(x.text, s)
                ss.append(s)
                ps.append(p)
            else:
                raise TypeError('expected Statement or Query, got {}'.format(type(x)))
        return ps, ss

    def is_output(self, p, x):
        if p is None:
            assert isinstance(x, Statement)
            return False
        assert isinstance(x, Query)
        return True
    
    def answer_cost(self, p, x):
        return F.cross_entropy(p, x.answer)
    
    def cost(self, xs):
        ps, _ = self.unfold(xs)
        return sum(self.answer_cost(p, x) 
                   for p, x in zip(ps, xs) 
                       if self.is_output(p, x))
    
    def predict(self, xs):
        ps, _ = self.unfold(xs)
        return [p.data.numpy().argmax(1) 
                if self.is_output(p, x) 
                else None 
                for p, x in zip(ps, xs)]


In [17]:
regnet = RegNet(len(word2id), 20, 15, 20, 5)

# Eval

In [18]:
data_train = [[x.torchify(word2id) for x in xs] for xs in data.train]
print('Number of samples:', len(data_train))
while True:
    random.shuffle(data_train)
    sum_nll = 0.0
    for xs in data_train:
        regnet.zero_grad()
        nll = regnet.cost(xs)
        nll.backward()
        for param in regnet.parameters():
            param.data.add_(-0.005, param.grad.data)
        sum_nll += nll.data.numpy()[0]
    print(sum_nll)

Number of samples: 200
2019.81519651
1688.50849533
1505.36831737
1394.07288146
1304.93501854
1245.51178932
1211.32668853
1182.7722156
1160.56629539
1135.86292756
1123.11780107
1104.13347006
1092.08082724
1101.46472955
1093.65085626
1083.36482334
1087.56319308
1074.16418105
1069.72428298
1058.95768261
1062.40102416
1059.79730749
1045.54559696
1044.84016967
1013.67597759
974.118311048
920.633042693
878.022991657
838.997856259
807.104522824
778.487491727
748.365551949
719.048949778
676.36352545
646.833854616
626.400753856
588.101940274
538.134019911
498.6312747
467.582735419
441.136835396
418.767855883
396.122953534
379.045443714
367.836013973
356.14786762
347.692258805
340.22368452
335.874804467
329.50463146
324.056122273
316.782067686
312.514947042
309.623082504
305.949044272
302.717185125
298.643917307
297.692042604
291.593903109
286.751051918
271.134983495
234.063230202
200.039684802
180.252600722
164.541784376
146.013142951
126.340037674
108.903952651
94.7894595414
85.0233711153
76.6

KeyboardInterrupt: 

In [19]:
for i, xs in enumerate(data.test):
    if i >= 2:
        break
    xs_vars = [x.torchify(word2id) for x in xs]
    ps = regnet.predict(xs_vars)
    assert(len(ps) == len(xs))
    for x, p in zip(xs, ps):
        if p is None:
            print(x, None)
        else:
            print(x)
            print((x.answer, [word2id[w] for w in x.answer]),
                  ([id2word[j] for j in p.tolist()], p))

S[01](text=['john', 'travelled', 'to', 'the', 'hallway']) None
S[02](text=['mary', 'journeyed', 'to', 'the', 'bathroom']) None
Q[03](text=['where', 'is', 'john'], answer=['hallway'], refs=[1])
(['hallway'], [7]) (['hallway'], array([7]))
S[04](text=['daniel', 'went', 'back', 'to', 'the', 'bathroom']) None
S[05](text=['john', 'moved', 'to', 'the', 'bedroom']) None
Q[06](text=['where', 'is', 'mary'], answer=['bathroom'], refs=[2])
(['bathroom'], [4]) (['bathroom'], array([4]))
S[07](text=['john', 'went', 'to', 'the', 'hallway']) None
S[08](text=['sandra', 'journeyed', 'to', 'the', 'kitchen']) None
Q[09](text=['where', 'is', 'sandra'], answer=['kitchen'], refs=[8])
(['kitchen'], [18]) (['kitchen'], array([18]))
S[10](text=['sandra', 'travelled', 'to', 'the', 'hallway']) None
S[11](text=['john', 'went', 'to', 'the', 'garden']) None
Q[12](text=['where', 'is', 'sandra'], answer=['hallway'], refs=[10])
(['hallway'], [7]) (['hallway'], array([7]))
S[13](text=['sandra', 'went', 'back', 'to', 't

In [20]:
def compute_error(ds):
    errors = 0.0
    total = 0.0
    for xs in ds:
        xs_vars = [x.torchify(word2id) for x in xs]
        ps = regnet.predict(xs_vars)
        assert(len(ps) == len(xs))
        for x, p in zip(xs, ps):
            if p is not None:
                if x.answer != [id2word[i] for i in p.tolist()]:
                    errors += 1
                total += 1
    print(errors, total)
    return errors, total

In [21]:
compute_error(data.test)

0.0 1000.0


(0.0, 1000.0)