### Comparison of RNN Approaches
(see https://distill.pub/2016/augmented-rnns/)

This notebook sets up a sample task based on number sequences, with
configurable goals, to compare the performance of vanilla RNNs,
LSTMs, NTMs, and attention-based networks.

In [1]:
import torch

In [2]:
from torch import nn, einsum
from random import randint
from torch.nn import functional as f

In [3]:
from tasks import Numbers

from training import train_model

from models import BasicRNN
from models import BasicLSTM
from models import NTM_LSTM

from models.MemoryNN import MemoryNN

In [4]:
from importlib import reload

## Problem Setup

In [5]:
# --- Task definition: number generator and prediction goal ---
# Value handed to the Numbers helper as its first positional argument.
max_number = 9

number_tool = Numbers(
    max_number,
    # Ignores its argument and draws a uniform integer in [0, 9].
    # Alternative kept from the original cell: lambda x: randint(0,max_number)
    increment_func=lambda incr: randint(0, 9),
    # The goal is the first three entries of the stream.
    # Alternative kept from the original cell: lambda stream: stream[0]
    goal_func=lambda stream: stream[:3],
)

In [6]:
# Loss shared by every model in this notebook (negative log-likelihood).
criterion = nn.NLLLoss()

# Training size, forwarded to train_model as `train_size`.
# Kept as a float on purpose: 1e5 is the same value as 100e3.
train_size = 1e5

In [7]:
# Keyword arguments shared by every train_model call below.
# 'model' and 'optim' are placeholders; each model section fills them in
# on its own copy of this dict.
setup_kwargs = dict(
    train_size=train_size,
    number_tool=number_tool,
    criterion=criterion,
    stream_size=200,
    model=None,
    optim=None,
)

In [8]:
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
# `torch.cuda.is_available` has to be *called*: the bare function object is
# always truthy, which would select 'cuda' even on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Basic Vanilla RNN Model

In [None]:
# Vanilla RNN: the first and third positional arguments both use the number
# tool's dimension; the middle argument is 5 (presumably the hidden size --
# TODO confirm against models.BasicRNN).
rnn = BasicRNN(
    number_tool.get_dim(),
    5,
    number_tool.get_dim(),
    output_length=3,
)
rnn.to(device)

# SGD with momentum over the RNN's parameters.
rnn_optim = torch.optim.SGD(rnn.parameters(), lr=1e-3, momentum=0.9)

# Shared training settings with this model and optimizer filled in.
rnn_kwargs = {**setup_kwargs, 'model': rnn, 'optim': rnn_optim}

In [None]:
# Train the vanilla RNN using the shared settings plus its model and optimizer.
train_model(**rnn_kwargs)

## Basic LSTM Model

In [None]:
# LSTM baseline, constructed with the same arguments as the vanilla RNN:
# the number tool's dimension as first and third argument, 5 in between
# (presumably the hidden size -- TODO confirm against models.BasicLSTM).
lstm = BasicLSTM(
    number_tool.get_dim(),
    5,
    number_tool.get_dim(),
    output_length=3,
)
lstm.to(device)

# SGD with momentum over the LSTM's parameters.
lstm_optim = torch.optim.SGD(lstm.parameters(), lr=1e-3, momentum=0.9)

# Shared training settings with this model and optimizer filled in.
lstm_kwargs = {**setup_kwargs, 'model': lstm, 'optim': lstm_optim}

In [None]:
# Train the LSTM using the shared settings plus its model and optimizer.
train_model(**lstm_kwargs)

## Neural Turing Machine Model

In [10]:
# Quick NTM construction with 10 and 10 as the fourth and fifth positional
# arguments (the following cell names these memory_banks and memory_dim).
ntm = NTM_LSTM(
    number_tool.get_dim(),
    5,
    number_tool.get_dim(),
    10,
    10,
    output_length=3,
)

In [None]:
# External-memory configuration passed to the NTM constructor.
memory_banks, memory_dim = 10, 20

# Neural Turing Machine: same leading arguments as the RNN/LSTM cells,
# followed by the two memory settings.
ntm = NTM_LSTM(
    number_tool.get_dim(),
    5,
    number_tool.get_dim(),
    memory_banks,
    memory_dim,
    output_length=3,
)
ntm.to(device)

# SGD with momentum over the NTM's parameters.
ntm_optim = torch.optim.SGD(ntm.parameters(), lr=1e-3, momentum=0.9)

# Shared training settings with this model and optimizer filled in.
ntm_kwargs = {**setup_kwargs, 'model': ntm, 'optim': ntm_optim}

In [None]:
# Train the NTM using the shared settings plus its model and optimizer.
train_model(**ntm_kwargs)

### ----------

### ----------

### ----------

-----
## Functional Testing
-----

### Testing _convolve

In [None]:
# Smoke test for _convolve: show the input vector, then let the notebook
# display the result of convolving it with a one-hot 3-element kernel.
# NOTE(review): _convolve is not imported in the cells visible here --
# confirm where it is defined before running this cell.
s = torch.tensor([1.0, 0.0, 0.0])
x = torch.tensor([0.4, 0.05, 0.5, 0.0, 0.0, 0.05])

print(x)
_convolve(x, s)

### Testing Sequence Loss

In [None]:
# Testing for criterion = nn.NLLLoss()

# Build a fresh number helper, a 30-element stream and its encoded form.
# NOTE(review): unlike the Problem Setup cell, no max_number is passed here --
# confirm that Numbers provides a default for it.
n = Numbers(increment_func = lambda incr: randint(0,9), goal_func = lambda stream: stream[0:3])
stream = n.create_stream(30)
enc_stream = n.encode_stream(stream)

# The goal for this stream becomes the NLLLoss target (class indices).
goal = n.get_stream_goal(stream)
target = torch.LongTensor(goal)

# Small RNN; the fourth positional argument is 3 here
# (presumably output_length, as in the model cells -- TODO confirm).
l = BasicRNN(n.get_dim(), 5, n.get_dim(), 3)
hid = l.init_hidden()

the_out, the_hid = l.forward(enc_stream, hid)

# Overwrite one entry per output step in place so the effect on the loss can
# be inspected (assumes the_out is indexed [step, batch, class] -- TODO confirm).
the_out[0,0,0] = -2#5
the_out[1,0,5] = -2#5
the_out[2,0,6] = -2#5

# Drop dimension 1 before applying the loss; the result is stored in `loss`
# but not displayed by this cell.
loss = criterion(the_out.squeeze(1), target)

### Memory Testing

In [None]:
# READ_MEMORY TESTS
# NOTE(review): Memory is not imported in the cells visible here -- confirm
# where it is defined before running this cell.
mem = Memory(10,4)

# Two attention vectors over 10 slots: all ones, and one-hot on slot 4.
flat = torch.ones(10)
just_4 = torch.zeros(10)
just_4[4] = 1

# One-hot attention on slot 4 is expected to return exactly row 4.
act = mem.read(just_4)
exp = mem.memory[4]
print(f'[RM TEST1] act:{act} exp:{exp}')

# All-ones attention is compared against the mean over dim 0
# (implies read normalises the weights -- TODO confirm).
act = mem.read(flat)
exp = mem.memory.mean(dim=0)
print(f'[RM TEST2] act:{act} exp:{exp}')

# WRITE_MEMORY TESTS
# The write tests use a memory width of 3 (the read tests above used 4).
mem_dim = 3

att_flat = torch.ones(10)
att_just_4 = torch.zeros(10)
att_just_4[4] = 1

# Forget vectors: all ones, all 0.5, all zeros.
forget_ones = torch.ones(mem_dim)
forget_half = torch.ones(mem_dim) / 2
forget_zeros = torch.zeros(mem_dim)

# Add vectors: all zeros, all ones.
add_zero = torch.zeros(mem_dim)
add_ones = torch.ones(mem_dim)

# NOTE(review): in each test `exp` is taken from mem.memory before write() is
# called; if write() mutates mem.memory in place, `exp` may alias the updated
# data -- confirm against the Memory implementation.

# TEST1: forget nothing, add nothing -> row 4 expected unchanged.
mem = Memory(10, mem_dim)
exp = mem.memory[4]
act = mem.write(att_just_4, forget_zeros, add_zero)[4]
print(f'[WM TEST1] act:{act} exp:{exp}')

# TEST2: forget nothing, add ones -> row 4 expected to increase by 1.
mem = Memory(10, mem_dim)
exp = mem.memory[4] + 1
act = mem.write(att_just_4, forget_zeros, add_ones)[4]
print(f'[WM TEST2] act:{act} exp:{exp}')

# TEST3: forget everything, add nothing -> row 4 expected to be all zeros.
mem = Memory(10, mem_dim)
exp = torch.zeros_like(mem.memory[4])
act = mem.write(att_just_4, forget_ones, add_zero)[4]
print(f'[WM TEST3] act:{act} exp:{exp}')

# TEST4: all-ones attention with forget 0.5 -> every entry expected to be
# scaled by (1 - 0.5 / 10).
mem = Memory(10, mem_dim)
exp = mem.memory * (1 - 0.5 / 10)
act = mem.write(att_flat, forget_half, add_zero)
print(f'[WM TEST4] act:{act.view(-1)} exp:{exp.view(-1)}')

# TEST5: as TEST4, plus add ones -> an additional 0.1 expected on every entry.
mem = Memory(10, mem_dim)
exp = mem.memory * (1 - 0.5 / 10) + 0.1
act = mem.write(att_flat, forget_half, add_ones)
print(f'[WM TEST5] act:{act.view(-1)} exp:{exp.view(-1)}')

### Buffer Registration Test

In [None]:
class Test(nn.Module):
    """Child module holding one plain tensor attribute and one registered buffer."""

    def __init__(self):
        super().__init__()
        # Plain attribute: stored on the instance but not registered as a buffer.
        self.y1 = torch.ones((5, 5))
        # Registered buffer: tracked by the module under the name 'y2'.
        self.register_buffer('y2', torch.ones((5, 5)))


class Base(nn.Module):
    """Parent module with the same tensor/buffer pair plus a `Test` submodule."""

    def __init__(self):
        super().__init__()
        # Plain attribute: stored on the instance but not registered as a buffer.
        self.x1 = torch.ones((5, 5))
        # Registered buffer: tracked by the module under the name 'x2'.
        self.register_buffer('x2', torch.ones((5, 5)))
        # Assigning a Module attribute registers it as the child named 'test'.
        self.test = Test()

In [None]:
# Move a Base instance to the selected device, then print (one per line) the
# device of each tensor: plain attribute vs. registered buffer, on the parent
# and on the child module.
b = Base()
bc = b.to(device)

print(
    bc.x1.device,
    bc.x2.device,
    bc.test.y1.device,
    bc.test.y2.device,
    sep='\n',
)