### Comparison of RNN Approaches
(see https://distill.pub/2016/augmented-rnns/)

This creates a sample task, based on number sequences, and
different goals to compare the performance of vanilla RNN's,
LSTM's, NTM's, and attention based networks.

In [1]:
import torch

In [3]:
from torch import nn, einsum
from random import randint
from torch.nn import functional as f

In [72]:
from tasks.Numbers import Numbers
import models
from models.BasicRNN import BasicRNN

In [73]:
from importlib import reload
reload(models)

<module 'models' from '/home/jupyter/rnn_exploration/models/__init__.py'>

In [53]:
def train_model(model, number_tool, criterion, optim, train_size = 100e3, stream_size = 200, print_interval = 1e3):
    ''' Runs a full training pass for a given model '''
    
    train_size = int(train_size)
    print_interval = int(print_interval)
    error_sum = 0
    model_device = next(model.parameters()).device

    for i in range(train_size):
        # Generate a new random sequence for training
        stream = number_tool.create_stream(stream_size)
        obs = number_tool.encode_stream(stream).to(model_device)

        # Reset the model gradients and hidden layer
        model.zero_grad()
        hidden = model.init_hidden()
        
        # Score the model
        output, hidden = model.forward(obs, hidden)

        # Determine the target output and calculate the loss
        goal = number_tool.get_stream_goal(stream)
        target = torch.LongTensor(goal).to(model_device)
        loss = criterion(output.squeeze(1), target)

        # Run backprop with the errors and update model/trackers
        loss.backward()
        optim.step()
        error_sum += loss

        if i % print_interval == 0:
            print(f'[{i}] Error: {error_sum / print_interval}')
            error_sum = 0

## Model Definitions

In [55]:
class BasicLSTM(nn.Module):
    ''' Basic single-LSTM with output calculated using a single
        linear layer with a softmax activation '''
    
    def __init__(self, input_size, hidden_size, output_size, output_length = 1):
        ''' Init the LSTM '''
        super(BasicLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.hidden_to_output = nn.Linear(hidden_size, output_size)
        self.hidden_size = hidden_size
        self.output_length = output_length
        self.output_size = output_size
        self.softmax = nn.LogSoftmax(dim=2)
        
    def forward(self, input, hidden):
        ''' Run a forward path for a single number of sequence of numbers
        
        Hidden is made up of (hidden, cell_state)
        
        Returns:
            output and hidden layer after last sequence input
        '''
        output = []
        
        history, hidden = self.lstm.forward(input, hidden)
        
        output.append(self.softmax(self.hidden_to_output(hidden[0])))
        
        for j in range(1, self.output_length):
            history, hidden = self.lstm.forward(output[-1], hidden)
            output.append(self.softmax(self.hidden_to_output(hidden[0])))
            
        output = torch.cat(output, dim=0)
        return output, hidden
    
    def init_hidden(self):
        ''' Returns new hidden layers for the start of a new sequence '''
        model_device = next(self.parameters()).device
        return (
            torch.randn(1, 1, self.hidden_size).to(model_device),
            torch.randn(1, 1, self.hidden_size).to(model_device)
        )

In [56]:
class NTM_LSTM(nn.Module):
    '''  '''
    
    def __init__(self, input_size, hidden_size, output_size, memory_banks, memory_dim):
        ''' Init the NTM-LSTM '''
        super(NTM_LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.memory_banks = memory_banks
        self.memory_dim = memory_dim
        
        # Core LSTM and Memory
        self.lstm = nn.LSTM(input_size + memory_dim, hidden_size)
        self.mem_nn = MemoryNN(hidden_size, memory_banks, memory_dim)
        
        # Output Layer
        self.hidden_to_output = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=2)
        
    
    def forward(self, input, state):
        ''' Run a forward path for a single number of sequence of numbers
        
        Args:
            state: contains [hidden, cell, memory, last_weight]
        
        Returns:
            output and hidden layer after last sequence input
        '''
        
        hidden, cell, memory, weight = state
        
        for i in range(input.shape[0]):
            ntm_input = torch.cat((input[i].unsqueeze(0), memory.view(1,1,-1)), dim=2)
        
            _, (hidden, cell) = self.lstm.forward(ntm_input, (hidden, cell))
            memory, weight = self.mem_nn.forward(hidden, weight)
        
        output = self.hidden_to_output(hidden)
        output = self.softmax(output)
        
        return output, (hidden, cell, memory, weight)
    
    
    def init_hidden(self):
        ''' Returns new hidden layers for the start of a new sequence '''
        memory, weight = self.mem_nn.reset_memory()
        
        model_device = next(self.parameters()).device

        return (
            torch.randn(1, 1, self.hidden_size).to(model_device),
            torch.randn(1, 1, self.hidden_size).to(model_device),
            memory.to(model_device),
            weight.to(model_device)
        )

## Problem Setup

In [57]:
# Number Generator + Goal Setup
max_number = 9
goal_func = lambda stream: stream[0]

'''number_tool = Numbers(
    max_number,
    reset_value_func = lambda x: randint(0,max_number),
    goal_func = goal_func
)'''

number_tool = Numbers(
    max_number,
    increment_func = lambda incr: randint(0,9),
    goal_func = lambda stream: stream[0:3]
)

In [58]:
# Training Size
train_size = 100e3
criterion = nn.NLLLoss()

In [59]:
setup_kwargs = {
    'train_size': train_size,
    'number_tool': number_tool,
    'criterion': criterion,
    'stream_size': 200,
    'model': None,
    'optim': None
}

In [60]:
device = torch.device('cuda') if torch.cuda.is_available else torch.device('cpu')

## Basic Vanilla RNN Model

In [74]:
rnn = BasicRNN(number_tool.get_dim(), 5, number_tool.get_dim(), output_length = 3)
rnn.to(device)

rnn_optim = torch.optim.SGD(rnn.parameters(), lr = 0.001)

rnn_kwargs = setup_kwargs.copy()
rnn_kwargs['model'] = rnn
rnn_kwargs['optim'] = rnn_optim

In [75]:
train_model(**rnn_kwargs)

RuntimeError: Expected a Tensor of type torch.FloatTensor but found a type torch.cuda.FloatTensor for sequence element 1 in sequence argument at position #1 'tensors'

## Basic LSTM Model

In [63]:
lstm = BasicLSTM(number_tool.get_dim(), 5, number_tool.get_dim(), output_length = 3)
lstm.to(device)

lstm_optim = torch.optim.SGD(lstm.parameters(), lr = 0.001, momentum = 0.9)

lstm_kwargs = setup_kwargs.copy()
lstm_kwargs['model'] = lstm
lstm_kwargs['optim'] = lstm_optim

In [64]:
#%debug
train_model(**lstm_kwargs)

[0] Error: 0.0024926080368459225
[1000] Error: 2.0672757625579834
[2000] Error: 1.67084538936615


KeyboardInterrupt: 

In [38]:
x = [torch.zeros(1,1,2), torch.ones(1,1,2)]

In [49]:
torch.cat(x, dim=0).shape

torch.Size([2, 1, 2])

## Neural Turing Machine Model

In [14]:
memory_banks = 10
memory_dim = 20

ntm = NTM_LSTM(number_tool.get_dim(), 5, goal_dim, memory_banks, memory_dim)
ntm_optim = torch.optim.SGD(ntm.parameters(), lr = 0.001)

ntm_kwargs = setup_kwargs.copy()
ntm_kwargs['model'] = ntm
ntm_kwargs['optim'] = ntm_optim

In [None]:
%debug
train_model(**ntm_kwargs)

ERROR:root:No traceback has been produced, nothing to debug.


[0] Error: 0.0026836078613996506


In [90]:
'''
x_pad = torch.cat((x[-1:], x, x[:1]))

s = s.view(1,1,1,3)
x_pad = x_pad.view(1,1,1,8)

out = f.conv2d(x_pad, s, padding=(0,1))

print(s)
print(x)
#print(x_pad)
#print(out)
print(out.view(-1)[1:-1])
'''

'\nx_pad = torch.cat((x[-1:], x, x[:1]))\n\ns = s.view(1,1,1,3)\nx_pad = x_pad.view(1,1,1,8)\n\nout = f.conv2d(x_pad, s, padding=(0,1))\n\nprint(s)\nprint(x)\n#print(x_pad)\n#print(out)\nprint(out.view(-1)[1:-1])\n'

In [91]:
s = torch.tensor([1.,0.,0.])
x = torch.tensor([0.40, .05, .50, 0, 0, .05])

print(x)
_convolve(x,s)

tensor([0.4000, 0.0500, 0.5000, 0.0000, 0.0000, 0.0500])


tensor([0.0500, 0.4000, 0.0500, 0.5000, 0.0000, 0.0000])

### ----------

In [185]:
criterion = nn.NLLLoss()

n = Numbers(increment_func = lambda incr: randint(0,9), goal_func = lambda stream: stream[0:3])
stream = n.create_stream(30)
enc_stream = n.encode_stream(stream)

goal = n.get_stream_goal(stream)
#enc_goal = n.encode_stream(goal)
target = torch.LongTensor(goal)

print(stream)
print(goal)

[0, 9, 0, 2, 1, 4, 2, 5, 8, 0, 5, 8, 3, 8, 7, 2, 4, 9, 0, 6, 3, 4, 0, 7, 7, 4, 6, 0, 7, 4, 'END']
[0, 9, 0]


In [186]:
l = BasicRNN(n.get_dim(), 5, n.get_dim(), 3)
hid = l.init_hidden()

In [187]:
the_out, the_hid = l.forward(enc_stream, hid)

In [188]:
print(the_out.shape)
print(target.shape)

torch.Size([3, 1, 12])
torch.Size([3])


In [189]:
print(the_out.type())#.dtype)
print(target.type())#.dtype)

torch.FloatTensor
torch.LongTensor


In [193]:
the_out[0,0,0] = -2#5
the_out[1,0,5] = -2#5
the_out[2,0,6] = -2#5
print(the_out)

tensor([[[-2.0000, -2.3569, -2.5708, -2.4759, -2.5534, -2.4921, -2.2407,
          -2.8699, -2.4082, -2.2069, -2.3062, -2.7231]],

        [[-2.6584, -2.5391, -2.2816, -2.3469, -2.5293, -2.0000, -2.3893,
          -2.6638, -2.5010, -2.3386, -2.4294, -2.6367]],

        [[-2.6425, -2.5536, -2.2648, -2.3740, -2.5386, -2.5792, -2.0000,
          -2.6355, -2.5383, -2.3667, -2.4502, -2.6071]]],
       grad_fn=<CopySlices>)


In [194]:
# Determine the target output and calculate the loss
loss = criterion(the_out.squeeze(1), target)

-----
## Functional Testing
-----

In [168]:
# READ_MEMORY TESTS
mem = Memory(10,4)

flat = torch.ones(10)
just_4 = torch.zeros(10)
just_4[4] = 1

act = mem.read(just_4)
exp = mem.memory[4]
print(f'[RM TEST1] act:{act} exp:{exp}')

act = mem.read(flat)
exp = mem.memory.mean(dim=0)
print(f'[RM TEST2] act:{act} exp:{exp}')

# WRITE_MEMORY TESTS
mem_dim = 3

att_flat = torch.ones(10)
att_just_4 = torch.zeros(10)
att_just_4[4] = 1

forget_ones = torch.ones(mem_dim)
forget_half = torch.ones(mem_dim) / 2
forget_zeros = torch.zeros(mem_dim)

add_zero = torch.zeros(mem_dim)
add_ones = torch.ones(mem_dim)

mem = Memory(10, mem_dim)
exp = mem.memory[4]
act = mem.write(att_just_4, forget_zeros, add_zero)[4]
print(f'[WM TEST1] act:{act} exp:{exp}')

mem = Memory(10, mem_dim)
exp = mem.memory[4] + 1
act = mem.write(att_just_4, forget_zeros, add_ones)[4]
print(f'[WM TEST2] act:{act} exp:{exp}')

mem = Memory(10, mem_dim)
exp = torch.zeros_like(mem.memory[4])
act = mem.write(att_just_4, forget_ones, add_zero)[4]
print(f'[WM TEST3] act:{act} exp:{exp}')

mem = Memory(10, mem_dim)
exp = mem.memory * (1 - 0.5 / 10)
act = mem.write(att_flat, forget_half, add_zero)
print(f'[WM TEST4] act:{act.view(-1)} exp:{exp.view(-1)}')

mem = Memory(10, mem_dim)
exp = mem.memory * (1 - 0.5 / 10) + 0.1
act = mem.write(att_flat, forget_half, add_ones)
print(f'[WM TEST5] act:{act.view(-1)} exp:{exp.view(-1)}')

[RM TEST1] act:tensor([ 0.5261, -1.5163, -0.6044,  0.5245]) exp:tensor([ 0.5261, -1.5163, -0.6044,  0.5245])
[RM TEST2] act:tensor([ 0.1845, -0.4467,  0.1367, -0.5949]) exp:tensor([ 0.1845, -0.4467,  0.1367, -0.5949])
[WM TEST1] act:tensor([-0.9003,  0.9227, -0.7564]) exp:tensor([-0.9003,  0.9227, -0.7564])
[WM TEST2] act:tensor([1.1107, 3.1953, 1.5835]) exp:tensor([1.1107, 3.1953, 1.5835])
[WM TEST3] act:tensor([0., 0., 0.]) exp:tensor([0., 0., 0.])
[WM TEST4] act:tensor([ 1.8990, -0.5792,  0.9731,  0.5017, -0.8982,  0.1428,  1.1492, -0.0978,
         0.2197,  1.9246,  1.2118, -1.0656,  1.6490, -0.5932, -0.5040,  1.5088,
         0.0564,  2.4903, -0.4258, -0.0668, -1.1677, -0.6731, -1.2875,  1.8346,
        -1.0690,  0.2515,  1.2229, -0.3647,  1.2093, -1.9404]) exp:tensor([ 1.8990, -0.5792,  0.9731,  0.5017, -0.8982,  0.1428,  1.1492, -0.0978,
         0.2197,  1.9246,  1.2118, -1.0656,  1.6490, -0.5932, -0.5040,  1.5088,
         0.0564,  2.4903, -0.4258, -0.0668, -1.1677, -0.6731, -