# Memory Augmented Neural Network: Simple Illustration

As shown above, a Neural Turing Machine's controller can use content-based addressing, location-based addressing, or both. MANN, by contrast, uses a pure content-based memory writer.

MANN also uses a new addressing scheme called least recently used access (LRUA). The idea is that the least recently used memory location is determined from the read operation, and the read operation is performed by content-based addressing. So we basically perform content-based addressing for reading, and write to the location that was least recently used.




[Picture credits: MANN paper](https://arxiv.org/pdf/1605.06065.pdf)


<img src="Images/mann.png" width="1500"/>

In this tutorial, we will do the following things step by step:
1. Implement Read Operation
2. Implement Write Operation

##### Step 1: Let's first import all the libraries needed.

In [5]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import copy

##### Step 2: Implement Memory Module Similar to NTM Method but with some changes 


<img src="Images/MANN_read.png" width="500"/>


In [6]:
class Memory(nn.Module):
    """Shared base for the read and write heads: sizes plus content addressing.

    Args:
        M: Length of the initial weight vector kept in ``w_last``
            (presumably the number of memory rows -- confirm with the caller).
        N: Width contributed by the key to the head parameter vectors
            (presumably the size of one memory row -- confirm with the caller).
        controller_out: Controller output size. Not used by this base class;
            subclasses use it to size their linear layers.
    """

    def __init__(self, M, N, controller_out):
        super().__init__()
        self.M = M
        self.N = N
        # Sizes of the parameter vectors the heads ask the controller for:
        # key (N) + beta (1) + g (1) + shift (3) + gamma (1).
        base_params = self.N + 1 + 1 + 3 + 1
        self.read_lengths = base_params
        # The write head requests two extra N-sized vectors on top of that.
        self.write_lengths = base_params + self.N + self.N
        # History of weight vectors, one entry appended per time step.
        self.w_last = []
        self.reset_memory()

    def get_weights(self):
        """Return the list of weight vectors recorded so far."""
        return self.w_last

    def reset_memory(self):
        """Drop the weight history and start again from one all-zero vector.

        Intended to be called at the start of a new sequence (new input).
        """
        initial_weights = torch.zeros([1, self.M], dtype=torch.float32)
        self.w_last = [initial_weights]

    def address(self, k, beta, g, s, gamma, memory, w_last):
        """Return addressing weights using content focus only.

        ``beta``, ``g``, ``s``, ``gamma`` and ``w_last`` are accepted for
        interface compatibility but do not influence the result.
        """
        return self._similarity(k, beta, memory)

    def _similarity(self, k, beta, memory):
        """Softmax over the cosine similarity between ``memory`` and key ``k``.

        ``beta`` is accepted but not applied. The returned weights are the
        ones used for reading.
        """
        cosine_scores = F.cosine_similarity(memory, k, dim=-1, eps=1e-16)
        return F.softmax(cosine_scores, dim=-1)


#### Step 3: Implementing Read Operation
Here, we define the read head, which accesses memory and reads from it according to the read operation discussed in the chapter above.


In [7]:
class ReadHead(Memory):
    """Read head: content-addressed read from memory.

    A linear layer maps the controller output to the addressing parameters
    (k, beta, g, s, gamma). The weights come from ``Memory.address`` and the
    read vector is the weighted sum of the memory rows.

    Attributes:
        fc_read: Linear layer producing the ``read_lengths`` head parameters.
        w_read: Detached copy of the most recent read weights, or ``None``
            before the first call to ``forward``.
    """

    def __init__(self, M, N, controller_out):
        super().__init__(M, N, controller_out)
        self.fc_read = nn.Linear(controller_out, self.read_lengths)
        self.w_read = None
        self.intialize_parameters()

    def intialize_parameters(self):
        """Initialise the linear layer (Xavier weights, small normal bias)."""
        nn.init.xavier_uniform_(self.fc_read.weight, gain=1.4)
        nn.init.normal_(self.fc_read.bias, std=0.01)

    def read(self, memory, w):
        """Return the read vector ``w @ memory``."""
        return torch.matmul(w, memory)

    def forward(self, x, memory):
        """Compute the read weights for ``x`` and read from ``memory``.

        Args:
            x: Controller output with the feature axis at ``dim=1``.
            memory: Memory matrix whose last axis matches the key width ``N``.

        Returns:
            Tuple ``(mem, w_r)``: the read vector and the read weights.
        """
        # A `global` statement only affects the function it appears in, so it
        # has to live here (not in __init__) for the weights to be published
        # at module level, which is where the write head looks for `w_read`.
        global w_read

        param = self.fc_read(x)  # gather parameters
        # The split yields five tensors; all five must be unpacked.
        k, beta, g, s, gamma = torch.split(param, [self.N, 1, 1, 3, 1], dim=1)
        k = torch.tanh(k)
        beta = F.softplus(beta)
        g = torch.sigmoid(g)
        s = F.softmax(s, dim=1)
        gamma = 1 + F.softplus(gamma)
        # obtain current weight address vector from Memory
        w_r = self.address(k, beta, g, s, gamma, memory, self.w_last[-1])
        # keep track of the content-based locations over time
        self.w_last.append(w_r)
        # read the memory content selected by the weights
        mem = self.read(memory, w_r)
        # w_r is a non-leaf autograd tensor, which copy.deepcopy rejects;
        # detach().clone() gives an independent snapshot instead.
        snapshot = w_r.detach().clone()
        self.w_read = snapshot
        w_read = snapshot
        return mem, w_r

##### Step 4: Implement Write Operation
Similar to the read operation, here we implement the write operation.

Note: Both the read and write heads use a fully connected layer to produce the parameters (k, beta, g, s, gamma) for content addressing.

<img src="Images/MANN_write.png" width="500"/>

In [8]:
class WriteHead(Memory):
    """Write head: least-recently-used-access (LRUA) write into memory.

    A linear layer maps the controller output to the head parameters. Each
    write step updates the usage weights (F2), finds the least-used slot,
    blends it with the read weights (F3) and adds the key to memory (F4).

    Attributes:
        fc_write: Linear layer producing the ``write_lengths`` head parameters.
        prev_w_u: Usage weights from the previous step, shape ``(1, M)``.
        w_write: Detached copy of the most recent write weights.
    """

    def __init__(self, M, N, controller_out):
        super().__init__(M, N, controller_out)
        self.fc_write = nn.Linear(controller_out, self.write_lengths)
        self.intialize_parameters()

    def reset_memory(self):
        """Reset the weight history and the usage / write-weight state."""
        super().reset_memory()
        # State lives on the instance: the sizes come from self.M rather than
        # from undefined module-level names.
        self.prev_w_u = torch.zeros([1, self.M], dtype=torch.float32)
        self.w_write = torch.zeros([1, self.M], dtype=torch.float32)

    def intialize_parameters(self):
        """Initialise the linear layer (Xavier weights, small normal bias)."""
        nn.init.xavier_uniform_(self.fc_write.weight, gain=1.4)
        nn.init.normal_(self.fc_write.bias, std=0.01)

    def usage_weight_vector(self, prev_w_u, w_read, w_write, gamma):
        """Usage weights, Equation (F2): ``gamma * prev_w_u + w_read + w_write``.

        If ``w_read`` / ``w_write`` have one more axis than ``prev_w_u``
        (a per-head axis just before the memory axis) that axis is summed
        out; otherwise they are added as they are.
        """
        def collapse_heads(w):
            return torch.sum(w, dim=-2) if w.dim() > prev_w_u.dim() else w

        # NOTE(review): forward() passes gamma = 1 + softplus(.) >= 1, so the
        # previous usage is never decayed -- presumably a decay below 1 is
        # intended here; confirm against Equation (F2).
        return gamma * prev_w_u + collapse_heads(w_read) + collapse_heads(w_write)

    def least_used(self, w_u, memory_size=None, n_reads=4):
        """Locate the least-used memory slots.

        Args:
            w_u: Usage weights; the last axis indexes memory slots.
            memory_size: Number of slots; defaults to ``w_u.shape[-1]``.
            n_reads: How many least-used slots to select, clamped to the
                number of slots available.

        Returns:
            Tuple ``(indices, wlu_t)``: the selected slot indices and a 0/1
            mask over the slots, with a singleton axis before the slot axis.
        """
        if memory_size is None:
            memory_size = w_u.shape[-1]
        n_slots = min(n_reads, w_u.shape[-1])
        # Smallest usage == largest negated usage.
        _, indices = torch.topk(-w_u, k=n_slots)
        one_hot = F.one_hot(indices, memory_size).to(dtype=w_u.dtype)
        # Sum over the "selected slot" axis (second to last) so 1-D and
        # batched usage vectors are both handled.
        wlu_t = torch.sum(one_hot, dim=-2, keepdim=True)
        return indices, wlu_t

    def _write_step(self, memory, w_write, a, gamma, prev_w_u, w_read, k):
        """Run one LRUA write; return ``(memory, write weights, usage)``."""
        def as_row(w):
            # Treat a bare (M,) vector as a (1, M) row vector.
            return w.unsqueeze(0) if w.dim() == 1 else w

        w_read, w_write, prev_w_u = as_row(w_read), as_row(w_write), as_row(prev_w_u)
        # obtain the current usage weight vector
        w_u = self.usage_weight_vector(prev_w_u, w_read, w_write, gamma)
        # one slot per step: this head pairs with a single set of read weights
        _, w_lu = self.least_used(w_u, memory_size=w_u.shape[-1], n_reads=1)
        w_lu = w_lu.reshape(w_u.shape)
        # NOTE(review): `a` arrives N wide from forward(), while the gate in
        # (F3) must broadcast against M-wide weights; it is reduced to one
        # scalar per row here -- confirm this is the intended gate.
        alpha = a if a.dim() == 0 else a.mean(dim=-1, keepdim=True)
        gate = torch.sigmoid(alpha)
        # write weights as per Equation (F3)
        new_w_write = gate * w_read + (1 - gate) * w_lu
        # memory update as per Equation (F4): outer product of weights and key
        memory_update = memory + torch.matmul(new_w_write.transpose(-1, -2), k)
        return memory_update, new_w_write, w_u

    def mann_write(self, memory, w_write, a, gamma, prev_w_u, w_read, k):
        """Apply Equations (F2)-(F4) and return the updated memory."""
        memory_update, _, _ = self._write_step(
            memory, w_write, a, gamma, prev_w_u, w_read, k
        )
        return memory_update

    def forward(self, x, memory, w_read=None):
        """Write to ``memory`` based on controller output ``x``.

        Args:
            x: Controller output with the feature axis at ``dim=1``.
            memory: Memory matrix whose last axis matches the key width ``N``.
            w_read: Read weights to blend into the write weights. When
                omitted, a module-level ``w_read`` is used if one exists,
                otherwise this head's own content-based weights.

        Returns:
            Tuple ``(mem, w_write)``: the updated memory and write weights.
        """
        param = self.fc_write(x)  # gather parameters
        # The trailing N-sized chunk is split off but not used by this write.
        k, beta, g, s, gamma, a, _unused = torch.split(
            param, [self.N, 1, 1, 3, 1, self.N, self.N], dim=1
        )
        k = torch.tanh(k)
        beta = F.softplus(beta)
        g = torch.sigmoid(g)
        s = F.softmax(s, dim=-1)
        gamma = 1 + F.softplus(gamma)
        a = torch.tanh(a)
        # obtain current weight address vector from Memory
        w_content = self.address(k, beta, g, s, gamma, memory, self.w_last[-1])
        # keep track of the content-based locations over time
        self.w_last.append(w_content)
        if w_read is None:
            # Backward compatibility: this method used to read a module-level
            # `w_read`, so honour one if the surrounding code defines it.
            w_read = globals().get("w_read")
        if w_read is None:
            w_read = w_content
        # apply Equations (F2)-(F4)
        mem, w_write, w_u = self._write_step(
            memory, w_content, a, gamma, self.prev_w_u, w_read, k
        )
        # Snapshots are detached so the stored state does not keep the
        # autograd graph of earlier steps alive.
        self.prev_w_u = w_u.detach()
        self.w_write = w_write.detach().clone()
        return mem, w_write
