# MANNs for bAbI

In [1]:
%matplotlib inline
from __future__ import division, print_function
import os
import random
import re
import numpy as np

import torch
from torch import FloatTensor, LongTensor
from torch.nn import functional as F, Embedding, Linear, Module, Parameter
from torch.autograd import Variable

# Data

In [2]:
class Statement(object):
    def __init__(self, slot, text):
        self.slot = slot
        self.text = text
    
    def torchify(self, d):
        text = Variable(LongTensor([d[i] for i in self.text]))
        return Statement(self.slot, text)
        
    def __repr__(self):
        return 'S[{0.slot:02d}](text={0.text})'.format(self)


class Query(object):
    def __init__(self, slot, text, answer, refs):
        self.slot = slot
        self.text = text
        self.answer = answer
        self.refs = refs

    def torchify(self, d):
        text = Variable(LongTensor([d[i] for i in self.text]))
        answer = Variable(LongTensor([d[i] for i in self.answer]))
        return Query(self.slot, text, answer, list(self.refs))
        
    def __repr__(self):
        return 'Q[{0.slot:02d}](text={0.text}, answer={0.answer}, refs={0.refs})'.format(self)


class TaskData(object):
    def __init__(self, task, data_dir='babi-tasks_1-20_v1-2/en/'):
        self.task = task
        def match_one_file(p):
            ms = [f for f in os.listdir(data_dir) if re.match(p, f)]
            if len(ms) != 1:
                raise ValueError('{} matched the wrong number of items: {}'.format(p, ms))
            return os.path.join(data_dir, ms[0])

        self.train_file = match_one_file(r'qa{}_.*_train.txt'.format(task))
        with open(self.train_file) as fp:
            self.train = self.parse(fp)

        self.test_file = match_one_file(r'qa{}_.*_test.txt'.format(task))
        with open(self.test_file) as fp:
            self.test = self.parse(fp)
    
    @staticmethod
    def parse(lines):
        prev_slot = None
        episodes = [[]]
        for line in lines:
            slot, text = line.lower().strip().split(' ', 1)
            slot = int(slot)
            if '\t' in text:
                text, answer, refs = [re.findall(r'\w+', s) for s in text.split('\t')]
                refs = [int(ref) for ref in refs]
                item = Query(slot, text, answer, refs)
            else:
                text = re.findall(r'\w+', text)
                item = Statement(slot, text)
            if prev_slot is None or slot >  prev_slot:
                episodes[-1].append(item)
            else:
                episodes.append([item])
            prev_slot = slot
        return episodes
    
    def __repr__(self):
        return 'TaskData(task={}, num_train={}, num_test={})'.format(self.task, len(self.train), len(self.test))


In [20]:
data = TaskData(16)
print(data)
print('  train_file =', data.train_file)
print('  test_file  =', data.test_file)
print()
for x in data.train[0]:
    print(x)

TaskData(task=16, num_train=1000, num_test=1000)
  train_file = babi-tasks_1-20_v1-2/en/qa16_basic-induction_train.txt
  test_file  = babi-tasks_1-20_v1-2/en/qa16_basic-induction_test.txt

S[01](text=['lily', 'is', 'a', 'frog'])
S[02](text=['bernhard', 'is', 'a', 'frog'])
S[03](text=['bernhard', 'is', 'green'])
S[04](text=['brian', 'is', 'a', 'lion'])
S[05](text=['brian', 'is', 'white'])
S[06](text=['julius', 'is', 'a', 'swan'])
S[07](text=['julius', 'is', 'green'])
S[08](text=['lily', 'is', 'green'])
S[09](text=['greg', 'is', 'a', 'swan'])
Q[10](text=['what', 'color', 'is', 'greg'], answer=['green'], refs=[9, 6, 7])


In [21]:
word2id = {}
id2word = {}
for episode in data.train + data.test:
    for item in episode:
        for word in item.text + (item.answer if isinstance(item, Query) else []):
            try:
                i = word2id[word]
            except KeyError:
                word2id[word] = len(word2id)
                id2word[word2id[word]] = word
print('Vocab')
for i, word in sorted(id2word.items()):
    print(i, '=>', word)

Vocab
0 => lily
1 => is
2 => a
3 => frog
4 => bernhard
5 => green
6 => brian
7 => lion
8 => white
9 => julius
10 => swan
11 => greg
12 => what
13 => color
14 => rhino
15 => gray
16 => yellow


# Register Network

### Update
$\begin{align*}
    v &= Ax_t \\
    z &= \sigma(W_a v + b_a) \\
    s_t &= \text{lstm}(v, s_{t-1}) \\
    g &= \text{softmax}(W_g s_t^h + b_g) \\
    m_{i,t} &= g_i \cdot z + (1 - g_i) * m_{i,t-1}
\end{align*}$

### Answer
$\begin{align*}
    v &= Aq_t \\
    k &= \sigma(W_k v + b_k) \\
    g_i &= \frac{k^T m_{i,t}}{||k||\cdot||m_{i,t-1}||} \\
    h_t &= \sum_i g_i \cdot m_{i,t-1}
\end{align*}$

In [22]:
class LSTM(torch.nn.Module):
    def __init__(self, n_in, n_out):
        super(LSTM, self).__init__()
        self.w = Parameter(FloatTensor(n_in + n_out, 4 * n_out))
        self.b = Parameter(FloatTensor(1, 4 * n_out))
        self.h0 = Parameter(FloatTensor(1, n_out))
        self.c0 = Parameter(FloatTensor(1, n_out))
        self.reset_parameters()
    
    def reset_parameters(self):
        self.w.data.normal_(0, 0.01)
        self.b.data.zero_()
        self.h0.data.zero_()
        self.c0.data.zero_()
    
    @property
    def initial_state(self):
        return self.h0, self.c0
    
    def observable(self, s):
        return s[0]
    
    def step(self, x, s=None):
        h_prev, c_prev = s or self.initial_state
        z = torch.cat([x, h_prev], 1)
        a = z.mm(self.w) + self.b.expand(x.size(0), self.b.size(1))
        i, f, o, g = torch.chunk(a, 4, dim=1)
        i = F.sigmoid(i)
        f = F.sigmoid(f)
        o = F.sigmoid(o)
        g = F.tanh(g)
        c = f * c_prev + i * g
        h = o * c
        return h, c


class RegNet(torch.nn.Module):
    def __init__(self, n_words, embed_dim, controller_hid_dim, register_dim, n_registers):
        super(RegNet, self).__init__()
        self.n_words = n_words
        self.embed_dim = embed_dim
        self.controller_hid_dim = controller_hid_dim
        self.register_dim = register_dim
        self.n_registers = n_registers
        
        self.embedding = Embedding(n_words, embed_dim)
        self.candidate = Linear(embed_dim, register_dim)
        self.controller = LSTM(embed_dim, controller_hid_dim)
        self.attend = Linear(controller_hid_dim, n_registers)
        self.read_key = Linear(embed_dim, register_dim)
        self.output = Linear(register_dim, n_words)
    
    @property
    def initial_state(self):
        memory = [Variable(torch.zeros(1, self.register_dim)) for i in range(self.n_registers)]
        return (memory, self.controller.initial_state)
    
    def embed(self, x):
        return self.embedding(x).sum(0)
    
    def update(self, x, state=None):
        m, s = state or self.initial_state
        v = self.embed(x)
        z = F.sigmoid(self.candidate(v))
        s = self.controller.step(v, s)
        g = F.softmax(self.attend(self.controller.observable(s)))
        g = [gi.expand(z.size()) for gi in g.chunk(self.n_registers, dim=1)]
        m = [g[i] * z + (1 - g[i]) * m[i] for i in range(self.n_registers)]
        return m, s
    
    def scores(self, x, s=None):
        m, _ = s or self.initial_state
        v = self.embed(x)
        k = self.read_key(v)
        k_norm = k.norm()
        g = [m[i].dot(k) / (m[i].norm() * k_norm) for i in range(self.n_registers)]
        g = [g[i].unsqueeze(0).expand(m[i].size()) for i in range(self.n_registers)]
#         g = [(m[i] * k).sum(1) / (m[i].norm() * k_norm) for i in range(self.n_registers)]
#         g = F.softmax(torch.stack(g).squeeze().t()).chunk(self.n_registers, 1)
        h = sum(g[i].expand(m[i].size()) * m[i] for i in range(self.n_registers))
        return self.output(h)
        
    def unfold(self, xs):
        s = self.initial_state
        ss = []
        ps = []
        for x in xs:
            if isinstance(x, Statement):
                s = self.update(x.text, s)
                ss.append(s)
                ps.append(None)
            elif isinstance(x, Query):
                p = self.scores(x.text, s)
                ss.append(s)
                ps.append(p)
            else:
                raise TypeError('expected Statement or Query, got {}'.format(type(x)))
        return ps, ss

    def is_output(self, p, x):
        if p is None:
            assert isinstance(x, Statement)
            return False
        assert isinstance(x, Query)
        return True
    
    def answer_cost(self, p, x):
        return F.cross_entropy(p, x.answer)
    
    def cost(self, xs):
        ps, _ = self.unfold(xs)
        return sum(self.answer_cost(p, x) 
                   for p, x in zip(ps, xs) 
                       if self.is_output(p, x))
    
    def predict(self, xs):
        ps, _ = self.unfold(xs)
        return [p.data.numpy().argmax(1) 
                if self.is_output(p, x) 
                else None 
                for p, x in zip(ps, xs)]


In [23]:
regnet = RegNet(len(word2id), 20, 15, 20, 5)
opt = torch.optim.SGD(regnet.parameters(), lr=0.005)

# Eval

In [24]:
data_train = [[x.torchify(word2id) for x in xs] for xs in data.train]
print('Number of samples:', len(data_train))
while True:
    random.shuffle(data_train)
    sum_nll = 0.0
    for xs in data_train:
        regnet.zero_grad()
        nll = regnet.cost(xs)
        nll.backward()
#         opt.step()
        for param in regnet.parameters():
            param.data.add_(-0.005, param.grad.data)
        sum_nll += nll.data.numpy()[0]
    print(sum_nll)

Number of samples: 1000
1528.01190928
1324.85786232
1256.07790552
1200.66384088
1182.84560367
1170.09475642
1153.78504046
1154.36563838
1145.14568537
1139.93336801
1134.16952732
1137.06689192
1130.83015932
1123.51893623
1119.56539039


KeyboardInterrupt: 

In [25]:
for i, xs in enumerate(data.test):
    if i >= 2:
        break
    xs_vars = [x.torchify(word2id) for x in xs]
    ps = regnet.predict(xs_vars)
    assert(len(ps) == len(xs))
    for x, p in zip(xs, ps):
        if p is None:
            print(x, None)
        else:
            print(x)
            print((x.answer, [word2id[w] for w in x.answer]),
                  ([id2word[j] for j in p.tolist()], p))

S[01](text=['lily', 'is', 'a', 'swan']) None
S[02](text=['bernhard', 'is', 'a', 'lion']) None
S[03](text=['greg', 'is', 'a', 'swan']) None
S[04](text=['bernhard', 'is', 'white']) None
S[05](text=['brian', 'is', 'a', 'lion']) None
S[06](text=['lily', 'is', 'gray']) None
S[07](text=['julius', 'is', 'a', 'rhino']) None
S[08](text=['julius', 'is', 'gray']) None
S[09](text=['greg', 'is', 'gray']) None
Q[10](text=['what', 'color', 'is', 'brian'], answer=['white'], refs=[5, 2, 4])
(['white'], [8]) (['gray'], array([15]))
S[01](text=['lily', 'is', 'a', 'rhino']) None
S[02](text=['brian', 'is', 'a', 'swan']) None
S[03](text=['bernhard', 'is', 'a', 'swan']) None
S[04](text=['lily', 'is', 'gray']) None
S[05](text=['brian', 'is', 'white']) None
S[06](text=['bernhard', 'is', 'white']) None
S[07](text=['julius', 'is', 'a', 'frog']) None
S[08](text=['julius', 'is', 'white']) None
S[09](text=['greg', 'is', 'a', 'frog']) None
Q[10](text=['what', 'color', 'is', 'greg'], answer=['white'], refs=[9, 7, 8])

In [26]:
def compute_error(ds):
    errors = 0
    total = 0
    for xs in ds:
        xs_vars = [x.torchify(word2id) for x in xs]
        ps = regnet.predict(xs_vars)
        assert(len(ps) == len(xs))
        for x, p in zip(xs, ps):
            if p is not None:
                if x.answer != [id2word[i] for i in p.tolist()]:
                    errors += 1
                total += 1
    return errors, total

In [27]:
err, num = compute_error(data.test)
print('error: {:0.03f} ({} of {})'.format(err / num, err, num))

error: 0.529 (529 of 1000)


In [124]:
R = range(5)
k = Variable(FloatTensor(np.random.normal(0, 1, (2, 3))))
m = [Variable(FloatTensor(np.random.normal(0, 1, (2, 3)))) for i in R]
d = [(m[i] * k).sum(1) for i in R]
s = F.softmax(torch.stack(d2).squeeze().t())
g = s.chunk(5, 1)
print('k =', k)
print('s =', s)
for i in R:
    print('m[{}] = {}'.format(i, m[i]))
    print('d1[{}] = {}'.format(i, d[i]))
    print('g[{}] = {}'.format(i, g[i]))
print('g[0] =', g[0].expand(2, 3))

k = Variable containing:
 1.0152 -1.2493 -0.0987
 1.1050  0.3949 -0.6933
[torch.FloatTensor of size 2x3]

s = Variable containing:
 0.0581  0.6610  0.0723  0.1448  0.0638
 0.1228  0.1073  0.1913  0.0773  0.5014
[torch.FloatTensor of size 2x5]

m[0] = Variable containing:
 0.6267  1.3359 -0.2588
-1.1040  0.0640  0.0412
[torch.FloatTensor of size 2x3]

d1[0] = Variable containing:
-1.0072
-1.2232
[torch.FloatTensor of size 2x1]

g[0] = Variable containing:
 0.0581
 0.1228
[torch.FloatTensor of size 2x1]

m[1] = Variable containing:
 0.9208  1.0184  0.5933
-0.8441 -0.1000 -1.1692
[torch.FloatTensor of size 2x3]

d1[1] = Variable containing:
-0.3960
-0.1616
[torch.FloatTensor of size 2x1]

g[1] = Variable containing:
 0.6610
 0.1073
[torch.FloatTensor of size 2x1]

m[2] = Variable containing:
 0.2443 -0.2024  0.4426
-0.9661 -2.4098 -0.4936
[torch.FloatTensor of size 2x3]

d1[2] = Variable containing:
 0.4571
-1.6769
[torch.FloatTensor of size 2x1]

g[2] = Variable containing:
 0.0723
 0.19