### Comparison of RNN Approaches
(see https://distill.pub/2016/augmented-rnns/)

This creates a sample task, based on number sequences, with
different goals in order to compare the performance of vanilla RNNs,
LSTMs, NTMs, and attention-based networks.

In [68]:
import torch

In [114]:
from torch import nn, einsum
from random import randint
from torch.nn import functional as f

In [70]:
class Numbers:
    ''' Generates number streams and converts them to/from one-hot tensors

    A stream is a list of numbers built one element at a time. Each new
    element is produced either by the increment function or, when the
    reset condition fires, by the reset-value function. The class also
    knows the goal (target output) for a stream and how to one-hot
    encode numbers, with one extra class (-1) for unrecognized numbers.
    '''
    def __init__(self, max_number = 9,
                 increment_func = lambda stream: stream[-1]+1,
                 reset_func = None,
                 reset_value_func = lambda stream: 0,
                 goal_func = lambda stream: stream[0]):
        ''' Initialize the controller
        Args:
            max_number: largest recognized number; recognized numbers
                are 0..max_number inclusive
            increment_func: called with the stream so far to decide
                the next number in the stream
            reset_func: evaluates the stream to determine if a reset
                condition has been met causing a restart on a new number.
                When None, a reset happens once the last number equals
                max_number
            reset_value_func: if a reset is triggered, this is called to
                determine the new value to restart the stream with
            goal_func: a function that returns the target output value
                for the stream (i.e. the goal output)
        '''
        self.max_number = max_number
        self.all_numbers = list(range(max_number+1))
        # -1 is the catch-all OTHER class and is always the last class
        self.all_classes = self.all_numbers + [-1]

        self.increment_func = increment_func
        self.reset_func = reset_func
        self.reset_value_func = reset_value_func
        self.goal_func = goal_func


    def create_stream(self, length, seed_stream = None):
        ''' Create a stream (sequence) of numbers

        Args:
            length: length of the stream to generate. The stream always
                contains at least the starting number, so a length
                below 1 still yields a single-element stream
            seed_stream: seed stream passed to reset_value_func to
                generate the first number (defaults to [0])

        Returns:
            A sequence (list) of numbers in the stream
        '''
        # A fresh list per call avoids sharing one mutable default object
        if seed_stream is None:
            seed_stream = [0]

        # Resolve the default reset rule locally rather than overwriting
        # the attribute, so it keeps following self.max_number
        reset_func = self.reset_func
        if reset_func is None:
            reset_func = lambda stream: stream[-1] == self.max_number

        stream = [self.reset_value_func(seed_stream)]

        for _ in range(max(0, length-1)):
            if reset_func(stream):
                new_num = self.reset_value_func(stream)
            else:
                new_num = self.increment_func(stream)
            stream.append(new_num)

        return stream


    def encode_stream(self, num_stream):
        ''' Converts number or number_stream into one-hot encoded tensor
        An additional category is added for OTHER (unrecognized numbers)

        Args:
            num_stream: A single number or a sequence (list or tuple)
                of numbers that will be encoded

        Returns:
            Encoded tensor of shape (number_index, 1, one_hot_encoded_number)
        '''
        if not isinstance(num_stream, (list, tuple)):
            num_stream = [num_stream]

        num_classes = len(self.all_classes)
        other_ind = num_classes - 1
        output = torch.zeros(len(num_stream), 1, num_classes)

        for i, num in enumerate(num_stream):
            if num in self.all_numbers:
                ind = self.all_numbers.index(num)
            else:
                ind = other_ind
            output[i, 0, ind] = 1

        return output

    def get_stream_goal(self, stream):
        ''' Return the goal output for the provided stream
        Args:
            stream: a list of numbers in the sequence

        Returns:
            Simple evaluation of the goal function (type depends
            on the goal function provided)
        '''
        return self.goal_func(stream)


    def decode_row(self, encoded_number):
        ''' Decode a one-hot encoded vector back to the number

        Args:
            encoded_number: tensor whose largest entry marks the class

        Returns:
            The recognized number for that class, or -1 when the
            largest entry falls on the OTHER class
        '''
        # Convert to a plain int so the position is compared as an
        # index, not matched against the list of number values
        ind = int(encoded_number.argmax())
        return self.all_numbers[ind] if ind < len(self.all_numbers) else -1


    def get_dim(self):
        ''' Return dimension of each encoded tensor '''
        return len(self.all_classes)

In [72]:
def train_model(model, number_tool, criterion, optim, train_size = 100e3, stream_size = 200, print_interval = 1e3):
    ''' Runs a full training pass for a given model

    Each iteration generates one new stream, runs it through the model
    from a fresh hidden state, and takes one optimizer step on the loss
    between the model output and the stream goal.

    Args:
        model: module exposing init_hidden() and accepting
            (encoded_stream, hidden), returning (output, hidden)
        number_tool: provides create_stream, encode_stream and
            get_stream_goal
        criterion: loss called as criterion(output.squeeze(0), target)
        optim: optimizer stepped once per generated stream
        train_size: number of streams to train on (cast to int)
        stream_size: length requested for each generated stream
        print_interval: number of streams between progress reports
            (cast to int, minimum of 1)
    '''
    train_size = int(train_size)
    # Guard against a zero interval, which would divide by zero below
    print_interval = max(1, int(print_interval))
    error_sum = 0.0
    model_device = next(model.parameters()).device

    for i in range(train_size):
        # Generate a new random sequence for training
        stream = number_tool.create_stream(stream_size)
        obs = number_tool.encode_stream(stream).to(model_device)

        # Reset the model gradients and hidden layer
        model.zero_grad()
        hidden = model.init_hidden()

        # Score the model (call the module itself so hooks run)
        output, hidden = model(obs, hidden)

        # Determine the target output and calculate the loss
        goal = number_tool.get_stream_goal(stream)
        target = torch.tensor([goal], dtype=torch.long, device=model_device)
        loss = criterion(output.squeeze(0), target)

        # Run backprop with the errors and update model/trackers
        loss.backward()
        optim.step()
        # Track a plain float so no autograd tensors are kept alive
        error_sum += loss.item()

        # Report after each full window so the average covers exactly
        # print_interval samples
        step = i + 1
        if step % print_interval == 0:
            print(f'[{step}] Error: {error_sum / print_interval}')
            error_sum = 0.0

## Model Definitions

In [71]:
class BasicRNN(nn.Module):
    ''' Basic RNN with a log-softmax output layer

    At every step the current input is concatenated with the hidden
    state and passed through a linear layer to get the next hidden
    state. The output is computed once, from the concatenated input
    of the final step.
    '''

    def __init__(self, input_size, hidden_size, output_size):
        ''' Init the RNN

        Args:
            input_size: size of each encoded input vector
            hidden_size: size of the hidden state
            output_size: number of output classes
        '''
        super().__init__()

        self.hidden_size = hidden_size

        self.input_to_output = nn.Linear(input_size + hidden_size, output_size)
        self.input_to_hidden = nn.Linear(input_size + hidden_size, hidden_size)
        self.softmax = nn.LogSoftmax(dim=2)

    def forward(self, input, hidden):
        ''' Run a forward pass for a single number or sequence of numbers

        Args:
            input: tensor indexed by step along its first dimension;
                each step is concatenated with hidden along dim 2
            hidden: hidden state to start the sequence from

        Returns:
            The output and hidden layer at the termination of the sequence

        Raises:
            ValueError: if input contains no steps
        '''
        # Without at least one step there is nothing to compute the
        # output from
        if input.shape[0] == 0:
            raise ValueError('input sequence must contain at least one step')

        for step in input:
            combined_input = torch.cat((step.unsqueeze(0), hidden), dim=2)
            hidden = self.input_to_hidden(combined_input)

        # Output is taken from the last step's combined input, i.e. the
        # last input together with the hidden state that preceded it
        output = self.softmax(self.input_to_output(combined_input))

        return output, hidden

    def init_hidden(self):
        ''' Returns new hidden layers for the start of a new sequence '''
        model_device = next(self.parameters()).device
        return torch.zeros(1, 1, self.hidden_size, device=model_device)

In [14]:
class BasicLSTM(nn.Module):
    ''' Basic single-LSTM with output calculated using a single
        linear layer with a log-softmax activation '''

    def __init__(self, input_size, hidden_size, output_size):
        ''' Init the LSTM

        Args:
            input_size: size of each encoded input vector
            hidden_size: size of the LSTM hidden and cell states
            output_size: number of output classes
        '''
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.hidden_to_output = nn.Linear(hidden_size, output_size)
        self.hidden_size = hidden_size
        self.softmax = nn.LogSoftmax(dim=2)

    def forward(self, input, hidden):
        ''' Run a forward pass for a single number or sequence of numbers

        Args:
            input: encoded sequence passed straight to the LSTM
            hidden: tuple of (hidden, cell_state) to start from

        Returns:
            The output and hidden layer at the termination of the sequence
        '''
        # Hidden is made up of (hidden, cell_state). The per-step history
        # is not needed; only the final hidden state feeds the output.
        # Calling the module (not .forward) keeps any hooks working.
        _, hidden = self.lstm(input, hidden)
        output = self.softmax(self.hidden_to_output(hidden[0]))
        return output, hidden

    def init_hidden(self):
        ''' Returns new (randomly drawn) hidden layers for the start of
            a new sequence, as a tuple of (hidden, cell_state) '''
        model_device = next(self.parameters()).device
        return (torch.randn(1, 1, self.hidden_size, device=model_device),
                torch.randn(1, 1, self.hidden_size, device=model_device))

In [157]:
class Memory:
    ''' Implements an attention-based memory system as outlined in
        the Neural Turing Machine paper

        The store is a (memory_cells, memory_dim) tensor. Reads return
        an attention-weighted sum of the rows; writes scale each row by
        an erase term and then add new content, both weighted by the
        attention on that row.
    '''

    def __init__(self, memory_cells, memory_dim):
        ''' Initialize the memory store with random normal values

        Args:
            memory_cells: number of rows (cells) in the memory
            memory_dim: size of the vector stored in each cell
        '''
        self.memory = torch.randn(memory_cells, memory_dim)

    def read_memory(self, weights):
        ''' Reads stored memory based on the attention

        Args:
            weights: weights of size (memory_cells) indicating how
                much attention to put on each cell

        Returns:
            Weighted output from memory of size (memory_dim)
                based on the provided attention weights
        '''
        # L1-normalize so the attention magnitudes sum to 1
        attention = f.normalize(weights, p=1, dim=0)
        # Attention-weighted sum over the memory rows
        return einsum('ij,i->j', self.memory, attention)

    def write_memory(self, weights, erase, add):
        ''' Updates attended memory with the erase and add
        Args:
            weights: weights of size (memory_cells) indicating how
                much attention to put on each cell
            erase: weights of size (memory_dim) in [0,1] indicating
                how much desired forgetting happens for each position
                (values outside [0,1] are clamped)
            add: weights of size (memory_dim) of how much to add to
                each memory position

        Returns:
            The updated memory tensor (also stored on the instance)
        '''
        attention = f.normalize(weights, p=1, dim=0)

        # Outer products give a per-cell, per-position erase / add amount
        forget = einsum('i,j->ij', attention, erase.clamp(0, 1))
        decay = torch.ones_like(forget) - forget
        remember = einsum('i,j->ij', attention, add)

        # Rebinds to a new tensor; views of the old memory are unchanged
        self.memory = self.memory * decay + remember
        return self.memory

In [None]:
class NTM_LSTM(nn.Module):
    ''' Placeholder for an LSTM controller with NTM-style memory

    NOTE: the memory is not wired in yet. The network currently
    computes a single LSTM followed by a linear layer with a
    log-softmax activation; memory_cells and memory_dim are only
    recorded on the instance.
    '''

    def __init__(self, input_size, hidden_size, output_size, memory_cells, memory_dim):
        ''' Init the NTM-LSTM

        Args:
            input_size: size of each encoded input vector
            hidden_size: size of the LSTM hidden and cell states
            output_size: number of output classes
            memory_cells: number of memory cells (not used by forward yet)
            memory_dim: size of each memory cell (not used by forward yet)
        '''
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.hidden_to_output = nn.Linear(hidden_size, output_size)
        self.hidden_size = hidden_size
        self.softmax = nn.LogSoftmax(dim=2)

        # Kept so the requested memory shape is not silently dropped
        # TODO: attach the external memory and use it in forward
        self.memory_cells = memory_cells
        self.memory_dim = memory_dim

    def forward(self, input, hidden):
        ''' Run a forward pass for a single number or sequence of numbers

        Args:
            input: encoded sequence passed straight to the LSTM
            hidden: tuple of (hidden, cell_state) to start from

        Returns:
            The output and hidden layer at the termination of the sequence
        '''
        # Hidden is made up of (hidden, cell_state). Calling the module
        # (not .forward) keeps any hooks working.
        _, hidden = self.lstm(input, hidden)
        output = self.softmax(self.hidden_to_output(hidden[0]))
        return output, hidden

    def init_hidden(self):
        ''' Returns new (randomly drawn) hidden layers for the start of
            a new sequence, as a tuple of (hidden, cell_state) '''
        model_device = next(self.parameters()).device
        return (torch.randn(1, 1, self.hidden_size, device=model_device),
                torch.randn(1, 1, self.hidden_size, device=model_device))

## Problem Setup

In [12]:
# Number generator and goal configuration
max_number = 9
goal_dim = max_number + 2  # numbers 0..max_number plus one extra class


def goal_func(stream):
    ''' Target output for a stream: its first number '''
    return stream[0]


number_tool = Numbers(
    max_number,
    # Every start or reset draws a random number in [0, max_number]
    reset_value_func = lambda _stream: randint(0, max_number),
    goal_func = goal_func,
)

In [None]:
# Shared training configuration
criterion = nn.NLLLoss()  # negative log-likelihood loss
train_size = 1e5          # number of training streams, kept as a float

In [13]:
# Keyword arguments shared by every training run; each experiment
# copies this and fills in its own 'model' and 'optim'
setup_kwargs = dict(
    train_size = train_size,
    number_tool = number_tool,
    criterion = criterion,
    stream_size = 200,
    model = None,
    optim = None,
)

In [73]:
# is_available must be called: the bare function object is always
# truthy, which would select CUDA even on machines without a GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Basic Vanilla RNN Testing

In [None]:
# Vanilla RNN with a 5-unit hidden state, trained with plain SGD
rnn = BasicRNN(number_tool.get_dim(), 5, goal_dim)
rnn_optim = torch.optim.SGD(rnn.parameters(), lr = 0.001)

# Shared settings with this experiment's model and optimizer filled in
rnn_kwargs = {**setup_kwargs, 'model': rnn, 'optim': rnn_optim}

In [None]:
# The kwargs dict for the RNN run is named rnn_kwargs
train_model(**rnn_kwargs)

## Basic LSTM Model

In [15]:
# LSTM with a 5-unit hidden state, moved to the selected device and
# trained with SGD plus momentum
lstm = BasicLSTM(number_tool.get_dim(), 5, goal_dim)
lstm.to(device)
lstm_optim = torch.optim.SGD(lstm.parameters(), lr = 0.001, momentum = 0.9)

# Shared settings with this experiment's model and optimizer filled in
lstm_kwargs = {**setup_kwargs, 'model': lstm, 'optim': lstm_optim}

In [None]:
# The kwargs dict for the LSTM run is named lstm_kwargs
train_model(**lstm_kwargs)

### ----------

-----
## Functional Testing
-----

In [168]:
# ---- Memory.read_memory checks ----
mem = Memory(10, 4)

# Attention patterns: uniform over every cell, and one-hot on cell 4
flat = torch.ones(10)
just_4 = torch.zeros(10)
just_4[4] = 1

# One-hot attention should return exactly the attended row
exp = mem.memory[4]
act = mem.read_memory(just_4)
print(f'[RM TEST1] act:{act} exp:{exp}')

# Uniform attention should return the mean of all rows
exp = mem.memory.mean(dim=0)
act = mem.read_memory(flat)
print(f'[RM TEST2] act:{act} exp:{exp}')

# ---- Memory.write_memory checks ----
mem_dim = 3

# Same two attention patterns as above
att_flat = torch.ones(10)
att_just_4 = torch.zeros(10)
att_just_4[4] = 1

# Erase vectors: forget everything, half, or nothing
forget_ones = torch.ones(mem_dim)
forget_half = forget_ones / 2
forget_zeros = torch.zeros(mem_dim)

# Add vectors: write nothing, or write ones
add_zero = torch.zeros(mem_dim)
add_ones = torch.ones(mem_dim)

# Each check uses a fresh memory; the expected value is captured
# BEFORE the write because write_memory replaces mem.memory

# No erase, no add: the attended row is unchanged
mem = Memory(10, mem_dim)
exp = mem.memory[4]
act = mem.write_memory(att_just_4, forget_zeros, add_zero)[4]
print(f'[WM TEST1] act:{act} exp:{exp}')

# No erase, add ones: the attended row grows by one
mem = Memory(10, mem_dim)
exp = mem.memory[4] + 1
act = mem.write_memory(att_just_4, forget_zeros, add_ones)[4]
print(f'[WM TEST2] act:{act} exp:{exp}')

# Full erase, no add: the attended row is zeroed
mem = Memory(10, mem_dim)
exp = torch.zeros_like(mem.memory[4])
act = mem.write_memory(att_just_4, forget_ones, add_zero)[4]
print(f'[WM TEST3] act:{act} exp:{exp}')

# Half erase spread over 10 cells: every row decays by 0.5 / 10
mem = Memory(10, mem_dim)
exp = mem.memory * (1 - 0.5 / 10)
act = mem.write_memory(att_flat, forget_half, add_zero)
print(f'[WM TEST4] act:{act.flatten()} exp:{exp.flatten()}')

# Same decay, plus ones spread over 10 cells: every row gains 0.1
mem = Memory(10, mem_dim)
exp = mem.memory * (1 - 0.5 / 10) + 0.1
act = mem.write_memory(att_flat, forget_half, add_ones)
print(f'[WM TEST5] act:{act.flatten()} exp:{exp.flatten()}')

[RM TEST1] act:tensor([ 0.5261, -1.5163, -0.6044,  0.5245]) exp:tensor([ 0.5261, -1.5163, -0.6044,  0.5245])
[RM TEST2] act:tensor([ 0.1845, -0.4467,  0.1367, -0.5949]) exp:tensor([ 0.1845, -0.4467,  0.1367, -0.5949])
[WM TEST1] act:tensor([-0.9003,  0.9227, -0.7564]) exp:tensor([-0.9003,  0.9227, -0.7564])
[WM TEST2] act:tensor([1.1107, 3.1953, 1.5835]) exp:tensor([1.1107, 3.1953, 1.5835])
[WM TEST3] act:tensor([0., 0., 0.]) exp:tensor([0., 0., 0.])
[WM TEST4] act:tensor([ 1.8990, -0.5792,  0.9731,  0.5017, -0.8982,  0.1428,  1.1492, -0.0978,
         0.2197,  1.9246,  1.2118, -1.0656,  1.6490, -0.5932, -0.5040,  1.5088,
         0.0564,  2.4903, -0.4258, -0.0668, -1.1677, -0.6731, -1.2875,  1.8346,
        -1.0690,  0.2515,  1.2229, -0.3647,  1.2093, -1.9404]) exp:tensor([ 1.8990, -0.5792,  0.9731,  0.5017, -0.8982,  0.1428,  1.1492, -0.0978,
         0.2197,  1.9246,  1.2118, -1.0656,  1.6490, -0.5932, -0.5040,  1.5088,
         0.0564,  2.4903, -0.4258, -0.0668, -1.1677, -0.6731, -