### Comparison of RNN Approaches
(see https://distill.pub/2016/augmented-rnns/)

This creates a sample task, based on number sequences, and
different goals to compare the performance of vanilla RNN's,
LSTM's, NTM's, and attention based networks.

In [1]:
import torch

In [2]:
from torch import nn, einsum
from random import randint
from torch.nn import functional as f

In [3]:
from tasks.Numbers import Numbers
from models import BasicRNN
from models import BasicLSTM

from models.MemoryNN import MemoryNN

from training import train_model

In [4]:
from importlib import reload

## Model Definitions

In [5]:
class NTM_LSTM(nn.Module):
    '''  '''
    
    def __init__(self, input_size, hidden_size, output_size,
                 memory_banks, memory_dim, output_length = 1):
        ''' Init the NTM-LSTM '''
        super(NTM_LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.memory_banks = memory_banks
        self.memory_dim = memory_dim
        self.output_length = output_length
        
        # Core LSTM and Memory
        self.lstm = nn.LSTM(input_size + memory_dim, hidden_size)
        self.add_module('mem_nn', MemoryNN(hidden_size, memory_banks, memory_dim))
        #self.mem_nn = MemoryNN(hidden_size, memory_banks, memory_dim)
        
        # Output Layer
        self.hidden_to_output = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=2)
        
    
    def forward(self, input, state):
        ''' Run a forward path for a single number of sequence of numbers
        
        Args:
            state: contains [hidden, cell, memory, last_weight]
        
        Returns:
            output and hidden layer after last sequence input
        '''
        output = []
        hidden, cell, memory, weight = state
        
        for i in range(input.shape[0]):
            ntm_input = torch.cat((input[i].unsqueeze(0), memory.view(1,1,-1)), dim=2)
        
            _, (hidden, cell) = self.lstm.forward(ntm_input, (hidden, cell))
            memory, weight = self.mem_nn.forward(hidden, weight)
        
        output.append(self.softmax(self.hidden_to_output(hidden)))
        
        for j in range(1, self.output_length):
            ntm_input = torch.cat((output[-1], memory.view(1,1,-1)), dim=2)
        
            _, (hidden, cell) = self.lstm.forward(ntm_input, (hidden, cell))
            memory, weight = self.mem_nn.forward(hidden, weight)
            output.append(self.softmax(self.hidden_to_output(hidden)))
            
        output = torch.cat(output, dim=0)
        return output, (hidden, cell, memory, weight)
    
    
    def init_hidden(self):
        ''' Returns new hidden layers for the start of a new sequence '''
        memory, weight = self.mem_nn.reset_memory()
        
        model_device = next(self.parameters()).device

        return (
            torch.randn(1, 1, self.hidden_size).to(model_device),
            torch.randn(1, 1, self.hidden_size).to(model_device),
            memory.to(model_device),
            weight.to(model_device)
        )

## Problem Setup

In [6]:
# Number Generator + Goal Setup
max_number = 9
goal_func = lambda stream: stream[0]

'''number_tool = Numbers(
    max_number,
    reset_value_func = lambda x: randint(0,max_number),
    goal_func = goal_func
)'''

number_tool = Numbers(
    max_number,
    increment_func = lambda incr: randint(0,9),
    goal_func = lambda stream: stream[0:3]
)

In [7]:
# Training Size
train_size = 100e3
criterion = nn.NLLLoss()

In [8]:
setup_kwargs = {
    'train_size': train_size,
    'number_tool': number_tool,
    'criterion': criterion,
    'stream_size': 200,
    'model': None,
    'optim': None
}

In [9]:
device = torch.device('cuda') if torch.cuda.is_available else torch.device('cpu')

## Basic Vanilla RNN Model

In [12]:
rnn = BasicRNN(number_tool.get_dim(), 5, number_tool.get_dim(), output_length = 3)
rnn.to(device)

rnn_optim = torch.optim.SGD(rnn.parameters(), lr = 0.001)

rnn_kwargs = setup_kwargs.copy()
rnn_kwargs['model'] = rnn
rnn_kwargs['optim'] = rnn_optim

In [13]:
train_model(**rnn_kwargs)

[0] Error: tensor(0.0031, device='cuda:0', grad_fn=<DivBackward0>)


KeyboardInterrupt: 

## Basic LSTM Model

In [16]:
lstm = BasicLSTM(number_tool.get_dim(), 5, number_tool.get_dim(), output_length = 3)
lstm.to(device)

lstm_optim = torch.optim.SGD(lstm.parameters(), lr = 0.001, momentum = 0.9)

lstm_kwargs = setup_kwargs.copy()
lstm_kwargs['model'] = lstm
lstm_kwargs['optim'] = lstm_optim

In [17]:
train_model(**lstm_kwargs)

[0] Error: tensor(0.0026, device='cuda:0', grad_fn=<DivBackward0>)


KeyboardInterrupt: 

## Neural Turing Machine Model

In [10]:
memory_banks = 10
memory_dim = 20

ntm = NTM_LSTM(number_tool.get_dim(), 5, number_tool.get_dim(),
               memory_banks, memory_dim, output_length = 3)
ntm.to(device)

ntm_optim = torch.optim.SGD(ntm.parameters(), lr = 0.001, momentum = 0.9)

ntm_kwargs = setup_kwargs.copy()
ntm_kwargs['model'] = ntm
ntm_kwargs['optim'] = ntm_optim

In [12]:
%debug
train_model(**ntm_kwargs)

> [0;32m/opt/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py[0m(2217)[0;36mcosine_similarity[0;34m()[0m
[0;32m   2215 [0;31m        [0;34m>>[0m[0;34m>[0m [0mprint[0m[0;34m([0m[0moutput[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m   2216 [0;31m    """
[0m[0;32m-> 2217 [0;31m    [0mw12[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mx1[0m [0;34m*[0m [0mx2[0m[0;34m,[0m [0mdim[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m   2218 [0;31m    [0mw1[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mnorm[0m[0;34m([0m[0mx1[0m[0;34m,[0m [0;36m2[0m[0;34m,[0m [0mdim[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m   2219 [0;31m    [0mw2[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mnorm[0m[0;34m([0m[0mx2[0m[0;34m,[0m [0;36m2[0m[0;34m,[0m [0mdim[0m[0;34m)[0m[0;34m[0m[0m
[0m


ipdb>  up


> [0;32m/home/jupyter/rnn_exploration/models/Memory.py[0m(60)[0;36mget_content_similarity[0;34m()[0m
[0;32m     58 [0;31m            [0mkey_strength[0m[0;34m:[0m [0mscalar[0m [0mto[0m [0mamplify[0m[0;34m/[0m[0mattenuate[0m [0mattention[0m[0;34m[0m[0m
[0m[0;32m     59 [0;31m        '''
[0m[0;32m---> 60 [0;31m        [0msimilarity[0m [0;34m=[0m [0mf[0m[0;34m.[0m[0mcosine_similarity[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mmemory[0m[0;34m,[0m [0mquery[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;36m1[0m[0;34m,[0m[0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m     61 [0;31m        [0mcontent_attention[0m [0;34m=[0m [0mf[0m[0;34m.[0m[0msoftmax[0m[0;34m([0m[0mkey_strength[0m [0;34m*[0m [0msimilarity[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m0[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m     62 [0;31m        [0;32mreturn[0m [0mcontent_attention[0

ipdb>  self._buffer


*** AttributeError: 'Memory' object has no attribute '_buffer'


ipdb>  self._buffers


OrderedDict([('memory', tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    

ipdb>  self


Memory()


ipdb>  self.buffers()


*** AttributeError: 'Memory' object has no attribute 'buffers'


ipdb>  up


> [0;32m/home/jupyter/rnn_exploration/models/MemoryNN.py[0m(63)[0;36mcalculate_memory_weight[0;34m()[0m
[0;32m     61 [0;31m            [0mγ[0m[0;34m:[0m [0msharpening[0m [0mexponent[0m [0;32mfor[0m [0mfinal[0m [0mweights[0m[0;34m[0m[0m
[0m[0;32m     62 [0;31m        '''
[0m[0;32m---> 63 [0;31m        [0mw_c[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mmemory[0m[0;34m.[0m[0mget_content_similarity[0m[0;34m([0m[0mk[0m[0;34m,[0m [0mβ[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m     64 [0;31m        [0mw_g[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_interpolate[0m[0;34m([0m[0mw_c[0m[0;34m,[0m [0mw_old[0m[0;34m,[0m [0mg[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m     65 [0;31m        [0mw_tilde[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_convolve[0m[0;34m([0m[0mw_g[0m[0;34m,[0m [0ms[0m[0;34m)[0m[0;34m[0m[0m
[0m


ipdb>  self


MemoryNN(
  (memory): Memory()
  (sigmoid): Sigmoid()
  (softmax): Softmax()
  (hidden_to_dim): Linear(in_features=5, out_features=65, bias=True)
  (hidden_to_gate): Linear(in_features=5, out_features=10, bias=True)
)


ipdb>  q


RuntimeError: Expected object of type torch.FloatTensor but found type torch.cuda.FloatTensor for argument #2 'other'

### ----------

-----
## Functional Testing
-----

### Testing _convolve

In [None]:
# Testing for _convolve
s = torch.tensor([1.,0.,0.])
x = torch.tensor([0.40, .05, .50, 0, 0, .05])

print(x)
_convolve(x,s)

### Testing Sequence Loss

In [185]:
# Testing for criterion = nn.NLLLoss()

n = Numbers(increment_func = lambda incr: randint(0,9), goal_func = lambda stream: stream[0:3])
stream = n.create_stream(30)
enc_stream = n.encode_stream(stream)

goal = n.get_stream_goal(stream)
target = torch.LongTensor(goal)

l = BasicRNN(n.get_dim(), 5, n.get_dim(), 3)
hid = l.init_hidden()

the_out, the_hid = l.forward(enc_stream, hid)

the_out[0,0,0] = -2#5
the_out[1,0,5] = -2#5
the_out[2,0,6] = -2#5

loss = criterion(the_out.squeeze(1), target)

[0, 9, 0, 2, 1, 4, 2, 5, 8, 0, 5, 8, 3, 8, 7, 2, 4, 9, 0, 6, 3, 4, 0, 7, 7, 4, 6, 0, 7, 4, 'END']
[0, 9, 0]


### Memory Testing

In [168]:
# READ_MEMORY TESTS
mem = Memory(10,4)

flat = torch.ones(10)
just_4 = torch.zeros(10)
just_4[4] = 1

act = mem.read(just_4)
exp = mem.memory[4]
print(f'[RM TEST1] act:{act} exp:{exp}')

act = mem.read(flat)
exp = mem.memory.mean(dim=0)
print(f'[RM TEST2] act:{act} exp:{exp}')

# WRITE_MEMORY TESTS
mem_dim = 3

att_flat = torch.ones(10)
att_just_4 = torch.zeros(10)
att_just_4[4] = 1

forget_ones = torch.ones(mem_dim)
forget_half = torch.ones(mem_dim) / 2
forget_zeros = torch.zeros(mem_dim)

add_zero = torch.zeros(mem_dim)
add_ones = torch.ones(mem_dim)

mem = Memory(10, mem_dim)
exp = mem.memory[4]
act = mem.write(att_just_4, forget_zeros, add_zero)[4]
print(f'[WM TEST1] act:{act} exp:{exp}')

mem = Memory(10, mem_dim)
exp = mem.memory[4] + 1
act = mem.write(att_just_4, forget_zeros, add_ones)[4]
print(f'[WM TEST2] act:{act} exp:{exp}')

mem = Memory(10, mem_dim)
exp = torch.zeros_like(mem.memory[4])
act = mem.write(att_just_4, forget_ones, add_zero)[4]
print(f'[WM TEST3] act:{act} exp:{exp}')

mem = Memory(10, mem_dim)
exp = mem.memory * (1 - 0.5 / 10)
act = mem.write(att_flat, forget_half, add_zero)
print(f'[WM TEST4] act:{act.view(-1)} exp:{exp.view(-1)}')

mem = Memory(10, mem_dim)
exp = mem.memory * (1 - 0.5 / 10) + 0.1
act = mem.write(att_flat, forget_half, add_ones)
print(f'[WM TEST5] act:{act.view(-1)} exp:{exp.view(-1)}')

[RM TEST1] act:tensor([ 0.5261, -1.5163, -0.6044,  0.5245]) exp:tensor([ 0.5261, -1.5163, -0.6044,  0.5245])
[RM TEST2] act:tensor([ 0.1845, -0.4467,  0.1367, -0.5949]) exp:tensor([ 0.1845, -0.4467,  0.1367, -0.5949])
[WM TEST1] act:tensor([-0.9003,  0.9227, -0.7564]) exp:tensor([-0.9003,  0.9227, -0.7564])
[WM TEST2] act:tensor([1.1107, 3.1953, 1.5835]) exp:tensor([1.1107, 3.1953, 1.5835])
[WM TEST3] act:tensor([0., 0., 0.]) exp:tensor([0., 0., 0.])
[WM TEST4] act:tensor([ 1.8990, -0.5792,  0.9731,  0.5017, -0.8982,  0.1428,  1.1492, -0.0978,
         0.2197,  1.9246,  1.2118, -1.0656,  1.6490, -0.5932, -0.5040,  1.5088,
         0.0564,  2.4903, -0.4258, -0.0668, -1.1677, -0.6731, -1.2875,  1.8346,
        -1.0690,  0.2515,  1.2229, -0.3647,  1.2093, -1.9404]) exp:tensor([ 1.8990, -0.5792,  0.9731,  0.5017, -0.8982,  0.1428,  1.1492, -0.0978,
         0.2197,  1.9246,  1.2118, -1.0656,  1.6490, -0.5932, -0.5040,  1.5088,
         0.0564,  2.4903, -0.4258, -0.0668, -1.1677, -0.6731, -