## End-To-End Memory Networks

#### Paper
[https://arxiv.org/abs/1503.08895](https://arxiv.org/abs/1503.08895)

#### Citation
```
@inproceedings{sukhbaatar2015end,
  title={End-to-end memory networks},
  author={Sukhbaatar, Sainbayar and Weston, Jason and Fergus, Rob and others},
  booktitle={Advances in neural information processing systems},
  pages={2440--2448},
  year={2015}
}
```

#### Notes
Unlike the paper, this implementation does not use any weight tying.

In [1]:
import torch
from torch import FloatTensor, LongTensor
from torch import optim, nn
from torch.autograd import Variable
from torch.nn import functional as F, init
from torch.nn.utils import clip_grad_norm


def Volatile(x):
    return Variable(x, volatile=True)


def add_grad_noise(parameters, scale=0.01):
    for p in parameters:
        p.grad.data.add_(torch.randn(p.size()).mul_(scale))


def binary_cross_entropy_with_logits(input, target, weight=None, size_average=True):
    """Function that measures Binary Cross Entropy between target and output logits.
    
    From torch/nn/functional.py. In master, but not in v0.1.12.
    """
    if not target.is_same_size(input):
        raise ValueError('Target size ({}) must be the same as input size ({})'.format(target.size(), input.size()))

    max_val = input.clamp(min=0)
    loss = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()

    if weight is not None:
        loss = loss * weight

    if size_average:
        return loss.mean()
    else:
        return loss.sum()

In [2]:
class MemNet(nn.Module):
    def __init__(self, input_size, query_size, hidden_size, n_layers=2):
        super(MemNet, self).__init__()
        self.input_size = input_size
        self.query_size = query_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers

        self.h = nn.Linear(query_size, hidden_size)
        
        self.fs = nn.ModuleList([
            nn.Linear(input_size, hidden_size) for i in range(n_layers)])
        
        self.gs = nn.ModuleList([
            nn.Linear(input_size, hidden_size) for i in range(n_layers)])
    
    def call1(self, x, q, return_masks=False):
        masks = [] if return_masks else None
        
        z = self.h(q)
        for f, g in zip(self.fs, self.gs):
            u = f(x)
            v = g(x)
            p = F.softmax(z.mm(u.t()))
            w = p.mm(v)
            z = z + w

            if return_masks:
                masks.append(p)
    
        return (z, masks) if return_masks else z

    def __call__(self, x, q, return_masks=False):
        if x.ndimension() == 2:
            return self.call1(x, q, return_masks=return_masks)
        
        masks = [] if return_masks else None

        # Flatten inputs
        batch_size, n_inputs, _ = x.size()
        x_flat = x.view(batch_size*n_inputs, self.input_size)

        # Flatten queries
        batch_size2, n_queries, _ = q.size()
        q_flat = q.view(batch_size2*n_queries, self.query_size)

        # Batch sizes should match
        assert batch_size == batch_size2

        # z = f(q)
        z_flat = self.h(q_flat)
        z = z_flat.view(batch_size, n_queries, self.hidden_size)

        for f, g in zip(self.fs, self.gs):
            # u = f(x)
            u_flat = f(x_flat)
            u = u_flat.view(batch_size, n_inputs, self.hidden_size)
            uT = u.transpose(1, 2)

            # p = softmax(zu^T)
            s = z.bmm(uT)
            s_flat = s.view(batch_size*n_queries, n_inputs)
            p_flat = F.softmax(s_flat)
            p = p_flat.view(batch_size, n_queries, n_inputs)
            
            # v = g(x)
            # w = pv
            # z = z + w
            v_flat = g(x_flat)
            v = v_flat.view(batch_size, n_inputs, self.hidden_size)
            w = p.bmm(v)
            z = z + w
            
            if return_masks:
                masks.append(p)

        return (z, masks) if return_masks else z

In [3]:
n_inputs = 6
n_queries = 5
input_size = 10
query_size = 9

hidden_size = 4
n_layers = 2

batch_size = 3

x = torch.randn(batch_size, n_inputs, input_size)
q = torch.randn(batch_size, n_queries, query_size)
y = (torch.rand(batch_size, n_queries) > 0.5).float()

memnet = MemNet(input_size, query_size, hidden_size, n_layers=2)

In [4]:
# z = memnet(Volatile(x[0]), Volatile(q[0]))
# print(z)
i = 1
z_i = memnet(Volatile(x[i]), Volatile(q[i]))
print(z_i)

Variable containing:
 1.2283  2.6824 -0.0098  0.3675
 0.9428  0.9278 -0.3740 -0.0285
 0.6457  0.2929 -0.4503  0.5719
 0.0582  0.9509  0.5408  1.1107
-0.7183  0.9618  0.1809 -1.0658
[torch.FloatTensor of size 5x4]



In [5]:
z, p = memnet(Volatile(x), Volatile(q), return_masks=True)
print(z[i])
print(p)
print(z)

Variable containing:
 1.2283  2.6824 -0.0098  0.3675
 0.9428  0.9278 -0.3740 -0.0285
 0.6457  0.2929 -0.4503  0.5719
 0.0582  0.9509  0.5408  1.1107
-0.7183  0.9618  0.1809 -1.0658
[torch.FloatTensor of size 5x4]

[Variable containing:
(0 ,.,.) = 
  0.1487  0.0907  0.1641  0.1450  0.0833  0.3682
  0.0762  0.2533  0.0448  0.0899  0.5001  0.0356
  0.2263  0.1909  0.1071  0.1175  0.3352  0.0230
  0.1715  0.1430  0.1556  0.1380  0.1878  0.2041
  0.1369  0.0571  0.1912  0.1224  0.0427  0.4497

(1 ,.,.) = 
  0.1799  0.2987  0.0106  0.0027  0.2075  0.3006
  0.1921  0.2673  0.1263  0.0601  0.1619  0.1922
  0.1812  0.2704  0.1383  0.0818  0.1310  0.1973
  0.1543  0.0732  0.0291  0.0251  0.2994  0.4189
  0.1366  0.0737  0.1424  0.2686  0.2131  0.1656

(2 ,.,.) = 
  0.1641  0.1386  0.2658  0.1293  0.1235  0.1787
  0.1747  0.1939  0.1723  0.1109  0.2143  0.1340
  0.1572  0.2407  0.1473  0.1610  0.1112  0.1826
  0.1872  0.2142  0.1419  0.0993  0.2557  0.1018
  0.1628  0.2501  0.2024  0.1245  0.0804

In [6]:
z = memnet(Volatile(x), Volatile(q))
z_flat = z.view(batch_size*n_queries, hidden_size)

In [8]:
clf = nn.Linear(hidden_size, 1)
y_hat_flat = clf(z_flat)
y_hat = y_hat_flat.view(batch_size, n_queries)
loss = binary_cross_entropy_with_logits(y_hat, Volatile(y))
print(loss)

Variable containing:
 0.7281
[torch.FloatTensor of size 1]



In [9]:
print(y_hat)
print(y)
print(y_hat > 0 == y)
print((y_hat > 0 == y).sum())

Variable containing:
-0.0762 -0.2696  0.5813  0.2067  0.0827
-0.3050  0.3262  0.6307  0.1017 -0.3580
-0.2444  0.1318  0.3082 -0.2508  0.0921
[torch.FloatTensor of size 3x5]


 0  1  1  0  0
 1  1  0  1  1
 0  0  1  1  0
[torch.FloatTensor of size 3x5]


 1  0  0  1  1
 0  0  1  0  0
 1  1  0  0  1
[torch.ByteTensor of size 3x5]

7
