# Memory Augmented Neural Network using Omniglot Dataset

As we showed earlier, a Neural Turing Machine's controller can use content-based addressing, location-based addressing, or both. MANN, in contrast, uses a purely content-based memory writer.

MANN also uses a new addressing scheme called Least Recently Used Access (LRUA). Reads use content-based addressing, and the usage of each memory location is tracked from the read and write weights. A write then goes either to the least recently used location or to the most recently read location, with a learned gate deciding between the two. In short, we read with content-based addressing and write to the least recently used location (or to the one we just read).
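For reference, the LRUA update from the MANN paper (Santoro et al., 2016) can be written as below. Here $w^r_t$, $w^w_t$ and $w^u_t$ are the read, write and usage weights, $\gamma$ is a decay parameter, $\sigma$ is the sigmoid, $\alpha$ is a learned scalar gate, $k_t$ is the key, $n$ is the number of reads, and $m(v, n)$ denotes the $n$-th smallest element of $v$. Before writing, the least used location (computed from $w^u_{t-1}$) is set to zero.

$$
\begin{aligned}
w^u_t &= \gamma\, w^u_{t-1} + w^r_t + w^w_t \\
w^{lu}_t(i) &= \begin{cases} 0 & \text{if } w^u_t(i) > m(w^u_t, n) \\ 1 & \text{if } w^u_t(i) \le m(w^u_t, n) \end{cases} \\
w^w_t &= \sigma(\alpha)\, w^r_{t-1} + \bigl(1 - \sigma(\alpha)\bigr)\, w^{lu}_{t-1} \\
M_t(i) &= M_{t-1}(i) + w^w_t(i)\, k_t
\end{aligned}
$$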




[Picture credits: MANN paper (https://arxiv.org/pdf/1605.06065.pdf)]


<img src="mann.png" width="1500"/>

In this tutorial, we will do the following step by step:
1. Implement Read Operation
2. Implement Write Operation

##### Step 1: Let's first import all the libraries we need.

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

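##### Step 2: Implement Read Operation
Here, we define a read head that accesses memory according to the read operation discussed in the chapter above: it produces addressing parameters from the controller output, computes read weights over the memory locations, and returns the weighted sum of the memory rows.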

<img src="MANN_read.png" width="500"/>
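The content-based read on its own is small enough to sketch directly. A minimal version, assuming memory of shape (batch, slots, width) and a key of shape (batch, width); the function name `content_read` is hypothetical. With `beta=1.0` it is the plain softmax over cosine similarities used in the MANN paper, while a larger `beta` mimics the NTM's key strength.

```python
import torch
import torch.nn.functional as F


def content_read(memory, key, beta=1.0):
    """Content-based read: cosine similarity -> softmax -> weighted sum of rows.

    memory: (batch, slots, width); key: (batch, width).
    """
    # cosine similarity between the key and every memory row: (batch, slots)
    sim = F.cosine_similarity(memory, key.unsqueeze(1), dim=-1)
    # read weights sum to 1 over the slots; beta sharpens the distribution
    w_read = F.softmax(beta * sim, dim=-1)
    # read vector r_t = sum_i w_t(i) M_t(i): (batch, width)
    r = torch.bmm(w_read.unsqueeze(1), memory).squeeze(1)
    return r, w_read


memory = torch.tensor([[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]])
key = torch.tensor([[1.0, 0.0]])
r, w = content_read(memory, key, beta=10.0)
print(w.argmax(dim=-1))  # slot 0 matches the key best
```

Because the similarity is a cosine, the read ignores the magnitude of memory rows and looks only at their direction relative to the key.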

In [3]:
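# Memory is the NTM-style base class, assumed to be defined in another cell of this
# notebook; it provides address(), w_last, read_lengths and N.
class ReadHead(Memory):

    def __init__(self, M, N, controller_out):
        super(ReadHead, self).__init__(M, N, controller_out)
        # one linear layer produces all addressing parameters at once
        self.fc_read = nn.Linear(controller_out, self.read_lengths)
        self.initialize_parameters()

    def initialize_parameters(self):
        # Initialize the linear layers
        nn.init.xavier_uniform_(self.fc_read.weight, gain=1.4)
        nn.init.normal_(self.fc_read.bias, std=0.01)

    def read(self, memory, w):
        """Read from memory: weighted sum of memory rows.

        w: (batch, slots); memory: (slots, width) or (batch, slots, width).
        """
        return torch.matmul(w.unsqueeze(1), memory).squeeze(1)

    def forward(self, x, memory):
        param = self.fc_read(x)
        # key k, key strength beta, gate g, shift weighting s, sharpening gamma
        k, beta, g, s, gamma = torch.split(param, [self.N, 1, 1, 3, 1], dim=1)

        k = torch.tanh(k)
        beta = F.softplus(beta)
        g = torch.sigmoid(g)
        s = F.softmax(s, dim=1)
        gamma = 1 + F.softplus(gamma)

        # same address() signature as the write head, now including beta
        w = self.address(k, beta, g, s, gamma, memory, self.w_last[-1])
        self.w_last.append(w)
        mem = self.read(memory, w)
        return mem, w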
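To see how the single fully connected output is carved into addressing parameters, the sketch below repeats the split and activations from `forward` on a random tensor. The key width `N = 4`, the batch size of 2 and the assumption that `read_lengths == N + 6` are hypothetical values for illustration only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N = 4                                  # hypothetical key width
read_lengths = N + 1 + 1 + 3 + 1       # sizes of k, beta, g, s, gamma
param = torch.randn(2, read_lengths)   # stands in for self.fc_read(x)

k, beta, g, s, gamma = torch.split(param, [N, 1, 1, 3, 1], dim=1)
k = torch.tanh(k)              # key in (-1, 1)
beta = F.softplus(beta)        # key strength, non-negative
g = torch.sigmoid(g)           # interpolation gate in (0, 1)
s = F.softmax(s, dim=1)        # weighting over 3 shifts (-1, 0, +1 in the NTM)
gamma = 1 + F.softplus(gamma)  # sharpening factor >= 1

print(k.shape, beta.shape, s.shape)
```

Each activation only constrains the range of its parameter; the meaning of each slice comes entirely from how `address()` consumes it.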

The Omniglot dataset contains 1623 character classes in total: 1423 characters are assigned for training and 200 characters for validation.
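The split is made at the level of character classes, so validation characters are never seen during training. A minimal sketch, assuming classes are identified by integer indices and shuffled with a fixed seed; the function name `split_classes` and the seed are hypothetical, not taken from this notebook.

```python
import random


def split_classes(num_classes=1623, num_train=1423, seed=0):
    """Split character-class indices into disjoint train/validation sets."""
    classes = list(range(num_classes))
    random.Random(seed).shuffle(classes)  # reproducible shuffle
    return classes[:num_train], classes[num_train:]


train_classes, val_classes = split_classes()
print(len(train_classes), len(val_classes))  # 1423 200
```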


##### Step 5: Implement Write Operation
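Similar to the read operation, we now implement the write operation.

Note: both the read and the write head use a fully connected layer to produce the addressing parameters (k, beta, g, s, gamma). The key k and key strength beta drive content-based addressing, while g, s and gamma are the NTM's interpolation gate, shift weighting and sharpening factor. The write head additionally produces an add vector a and an erase vector e.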

<img src="MANN_write.png" width="500"/>
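The least-used rule and the usage update are what set LRUA apart from the NTM's location-based addressing. The sketch below isolates them, assuming usage vectors of shape (batch, slots); the helper names `least_used_weights` and `update_usage` and the decay `gamma=0.95` are hypothetical choices for illustration, not values from this notebook.

```python
import torch


def least_used_weights(w_u, n_reads=1):
    """w_lu(i) = 1 if w_u(i) <= m(w_u, n_reads) else 0, m(v, n) = n-th smallest entry."""
    nth_smallest = torch.kthvalue(w_u, n_reads, dim=-1, keepdim=True).values
    return (w_u <= nth_smallest).float()


def update_usage(w_u_last, w_read, w_write, gamma=0.95):
    """w_u_t = gamma * w_u_{t-1} + w_r_t + w_w_t."""
    return gamma * w_u_last + w_read + w_write


w_u = torch.tensor([[0.9, 0.1, 0.5, 0.3]])
print(least_used_weights(w_u))             # slot 1 is least used
print(least_used_weights(w_u, n_reads=2))  # slots 1 and 3

# reading and writing slot 1 raises its usage, so slot 3 becomes least used
w_step = torch.tensor([[0.0, 1.0, 0.0, 0.0]])
w_u = update_usage(w_u, w_read=w_step, w_write=w_step)
print(least_used_weights(w_u))
```

The decay `gamma` lets old accesses fade, so a slot that has not been touched for a while eventually becomes the least used one and is recycled.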

In [None]:
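class WriteHead(Memory):

    def __init__(self, M, N, controller_out):
        super(WriteHead, self).__init__(M, N, controller_out)
        # one linear layer produces the addressing parameters plus a and e
        self.fc_write = nn.Linear(controller_out, self.write_lengths)
        self.initialize_parameters()

    def initialize_parameters(self):
        # Initialize the linear layers
        nn.init.xavier_uniform_(self.fc_write.weight, gain=1.4)
        nn.init.normal_(self.fc_write.bias, std=0.01)

    def mann_write(self, memory, k, alpha, w_read_last, w_u_last, n_reads=1):
        """LRUA write (Santoro et al., 2016).

        memory: (batch, slots, width); k: (batch, width); alpha: (batch, 1);
        w_read_last, w_u_last: (batch, slots) from the previous time step.
        """
        # least-used weights from the previous usage: 1 where usage <= n-th smallest
        nth = torch.kthvalue(w_u_last, n_reads, dim=-1, keepdim=True).values
        w_lu_last = (w_u_last <= nth).float()
        # write weights: sigmoid gate between last read location and least-used location
        gate = torch.sigmoid(alpha)
        w_write = gate * w_read_last + (1 - gate) * w_lu_last
        # zero the least-used slot, then add the key weighted by w_write
        memory = memory * (1 - w_lu_last).unsqueeze(-1)
        memory = memory + w_write.unsqueeze(-1) * k.unsqueeze(1)
        # once this step's read weights w_read are known, update usage as
        # w_u = gamma * w_u_last + w_read + w_write
        return memory, w_write

    def forward(self, x, memory):
        param = self.fc_write(x)
        # NTM addressing parameters plus add vector a and erase vector e
        k, beta, g, s, gamma, a, e = torch.split(param, [self.N, 1, 1, 3, 1, self.N, self.N], dim=1)

        k = torch.tanh(k)
        beta = F.softplus(beta)
        g = torch.sigmoid(g)
        s = F.softmax(s, dim=-1)
        gamma = 1 + F.softplus(gamma)
        a = torch.tanh(a)
        e = torch.sigmoid(e)

        w = self.address(k, beta, g, s, gamma, memory, self.w_last[-1])
        self.w_last.append(w)
        # NTM-style erase/add write; self.write is assumed to be provided by Memory.
        # mann_write above is the LRUA alternative.
        mem = self.write(memory, w, e, a)
        return mem, w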
