### Comparison of RNN Approaches
(see https://distill.pub/2016/augmented-rnns/)

This notebook creates a sample task based on number sequences, with
configurable goals, to compare the performance of vanilla RNNs,
LSTMs, NTMs, and attention-based networks.

In [2]:
import torch

In [3]:
from torch import nn
from random import randint

In [6]:
class Numbers:
    ''' Generates number streams and converts them to/from one-hot tensors

    A stream is a list of numbers. Each new number comes from
    increment_func unless reset_func reports a reset, in which case
    reset_value_func supplies the value to restart with. Numbers outside
    0..max_number are encoded as an extra OTHER class, represented by -1.
    '''
    def __init__(self, max_number = 9,
                 increment_func = lambda stream: stream[-1]+1,
                 reset_func = None,
                 reset_value_func = lambda stream: 0,
                 goal_func = lambda stream: stream[0]):
        ''' Initialize the controller
        Args:
            max_number: largest number that gets its own one-hot class
            increment_func: called with the stream so far to decide
                the next number in the stream
            reset_func: evaluates the stream to determine if a reset
                condition has been met causing a restart on a new number.
                When None, create_stream installs a default that resets
                once the last number equals max_number
            reset_value_func: if a reset is triggered, this is called to
                determine the new value to restart the stream with
            goal_func: a function that returns the target output value
                for the stream (i.e. the goal output)
        '''
        self.max_number = max_number
        self.all_numbers = list(range(max_number+1))
        # The trailing -1 is the OTHER class for unrecognized numbers
        self.all_classes = self.all_numbers + [-1]

        self.increment_func = increment_func
        self.reset_func = reset_func
        self.reset_value_func = reset_value_func
        self.goal_func = goal_func


    def create_stream(self, length, seed_stream = None):
        ''' Create a stream (sequence) of numbers

        Args:
            length: length of the stream to generate; the result always
                holds at least the starting number
            seed_stream: seed stream passed to reset_value_func to
                generate the first number (defaults to a fresh [0])

        Returns:
            A sequence (list) of numbers in the stream
        '''
        # Build the default seed per call so a reset_value_func that
        # mutates its argument cannot leak state into later calls
        if seed_stream is None:
            seed_stream = [0]

        if self.reset_func is None:
            self.reset_func = lambda stream: stream[-1] == self.max_number

        stream = [self.reset_value_func(seed_stream)]

        for _ in range(max(0, length-1)):
            if self.reset_func(stream):
                new_num = self.reset_value_func(stream)
            else:
                new_num = self.increment_func(stream)
            stream.append(new_num)

        return stream


    def encode_stream(self, num_stream):
        ''' Converts number or number_stream into one-hot encoded tensor
        An additional category is added for OTHER (unrecognized numbers)

        Args:
            num_stream: A single number or a sequence (list or tuple)
                of numbers that will be encoded

        Returns:
            Encoded tensor of shape (number_index, 1, one_hot_encoded_number)
        '''
        if not isinstance(num_stream, (list, tuple)):
            num_stream = [num_stream]
        output = torch.zeros(len(num_stream), 1, len(self.all_classes))

        for i, num in enumerate(num_stream):
            # Anything that is not a known number maps to the OTHER class
            value = num if num in self.all_numbers else -1
            ind = self.all_classes.index(value)
            output[i,0,ind] = 1

        return output

    def get_stream_goal(self, stream):
        ''' Return the goal output for the provided stream
        Args:
            stream: a list of numbers in the sequence

        Returns:
            Simple evaluation of the goal function (type depends
            on the goal function provided)
        '''
        return self.goal_func(stream)


    def decode_row(self, encoded_number):
        ''' Decode a one-hot encoded vector back to the number

        Args:
            encoded_number: tensor whose argmax marks the encoded class

        Returns:
            The number at the argmax position, or -1 when that position
            is not one of the number classes
        '''
        ind = int(encoded_number.argmax())
        # Bound-check the index itself instead of relying on the numbers
        # happening to equal their own positions
        if 0 <= ind < len(self.all_numbers):
            return self.all_numbers[ind]
        return -1


    def get_dim(self):
        ''' Return dimension of each encoded tensor '''
        return len(self.all_classes)

In [7]:
class BasicRNN(nn.Module):
    ''' Basic RNN with a log-softmax output layer

    Both the next hidden state and the output are linear functions of
    the current input concatenated with the current hidden state.
    '''

    def __init__(self, input_size, hidden_size, output_size):
        ''' Init the RNN
        Args:
            input_size: size of the last dimension of each input step
            hidden_size: size of the hidden state
            output_size: number of output classes
        '''
        super().__init__()

        self.hidden_size = hidden_size

        self.input_to_output = nn.Linear(input_size + hidden_size, output_size)
        self.input_to_hidden = nn.Linear(input_size + hidden_size, hidden_size)
        self.softmax = nn.LogSoftmax(dim=2)

    def forward(self, input, hidden):
        ''' Run a forward pass for a single number or sequence of numbers

        Args:
            input: tensor indexed by step along dim 0; each step is
                concatenated with the hidden state along dim 2
            hidden: hidden state, as returned by init_hidden

        Returns:
            The output (log-probabilities over dim 2) and hidden layer
            at the termination of the sequence

        Raises:
            ValueError: if the input contains no steps
        '''
        if input.shape[0] == 0:
            # Without a step there is nothing to compute the output from
            raise ValueError('input sequence must contain at least one step')

        for step in input:
            combined_input = torch.cat((step.unsqueeze(0), hidden), dim=2)
            hidden = self.input_to_hidden(combined_input)

        # The output reads the last step's combined input, i.e. the hidden
        # state from before the final hidden update
        output = self.softmax(self.input_to_output(combined_input))

        return output, hidden

    def init_hidden(self):
        ''' Returns a new zeroed hidden layer for the start of a new sequence '''
        model_device = next(self.parameters()).device
        return torch.zeros(1, 1, self.hidden_size, device=model_device)

In [10]:
def train_model(model, number_tool, criterion, optim, train_size = 100e3, stream_size = 200, print_interval = 1e3):
    ''' Runs a full training pass for a given model

    Each iteration generates one stream, scores it with the model,
    computes the loss against the stream's goal and applies one
    optimizer step.

    Args:
        model: module providing init_hidden() and returning
            (output, hidden) when called with (input, hidden)
        number_tool: provides create_stream, encode_stream and
            get_stream_goal
        criterion: loss called with (output.squeeze(0), target)
        optim: optimizer stepped once per stream
        train_size: number of streams to train on (cast to int)
        stream_size: length requested for each generated stream
        print_interval: report the mean loss whenever the iteration
            index is a multiple of this value (cast to int)
    '''
    train_size = int(train_size)
    print_interval = int(print_interval)
    model_device = next(model.parameters()).device

    error_sum = 0.0
    steps_since_report = 0

    for i in range(train_size):
        # Generate a new sequence for training
        stream = number_tool.create_stream(stream_size)
        obs = number_tool.encode_stream(stream).to(model_device)

        # Reset the model gradients and hidden layer
        model.zero_grad()
        hidden = model.init_hidden()

        # Score the model; calling the module (rather than .forward)
        # lets any registered hooks run
        output, hidden = model(obs, hidden)

        # Determine the target output and calculate the loss
        goal = number_tool.get_stream_goal(stream)
        target = torch.tensor([goal], dtype=torch.long, device=model_device)
        loss = criterion(output.squeeze(0), target)

        # Run backprop with the errors and update model/trackers
        loss.backward()
        optim.step()
        # .item() keeps a plain float so no autograd history is retained
        error_sum += loss.item()
        steps_since_report += 1

        if i % print_interval == 0:
            # Average over the steps actually accumulated; the first
            # report (i == 0) covers a single step
            print(f'[{i}] Error: {error_sum / steps_since_report}')
            error_sum = 0.0
            steps_since_report = 0

In [11]:
# Use the GPU when CUDA is usable, otherwise the CPU. is_available has to be
# called: the bare function object is always truthy and would select CUDA
# even on machines without it.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [12]:
# Stream generator: restart values are drawn at random from 0..max_number
max_number = 9
goal_dim = max_number + 2


def goal_func(stream):
    ''' The training target is the first number of the stream '''
    return stream[0]


number_tool = Numbers(
    max_number,
    goal_func = goal_func,
    reset_value_func = lambda x: randint(0, max_number),
)

# Model, loss and optimizer for the basic RNN
rnn = BasicRNN(number_tool.get_dim(), 5, goal_dim)
criterion = nn.NLLLoss()
rnn_optim = torch.optim.SGD(rnn.parameters(), lr = 0.001)

# Number of streams to train on
train_size = 100e3

In [13]:
# Keyword arguments passed to train_model when training the basic RNN
rnn_setup_kwargs = dict(
    train_size = train_size,
    model = rnn,
    number_tool = number_tool,
    criterion = criterion,
    optim = rnn_optim,
    stream_size = 200,
)

In [None]:
# Train the basic RNN with the settings collected in rnn_setup_kwargs
train_model(**rnn_setup_kwargs)

In [14]:
class BasicLSTM(nn.Module):
    ''' Basic single-LSTM with output calculated using a single
        linear layer with a log-softmax activation '''

    def __init__(self, input_size, hidden_size, output_size):
        ''' Init the LSTM
        Args:
            input_size: size of the last dimension of each input step
            hidden_size: size of the LSTM hidden and cell states
            output_size: number of output classes
        '''
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.hidden_to_output = nn.Linear(hidden_size, output_size)
        self.hidden_size = hidden_size
        self.softmax = nn.LogSoftmax(dim=2)

    def forward(self, input, hidden):
        ''' Run a forward pass for a single number or sequence of numbers

        Args:
            input: sequence tensor passed straight to the LSTM
            hidden: tuple of (hidden, cell_state), as returned by
                init_hidden

        Returns:
            The output (log-probabilities over dim 2) and the
            (hidden, cell_state) tuple at the termination of the sequence
        '''
        # Call the submodule itself (not .forward) so registered hooks run;
        # the per-step history is not needed, only the final state
        _, hidden = self.lstm(input, hidden)
        output = self.softmax(self.hidden_to_output(hidden[0]))
        return output, hidden

    def init_hidden(self):
        ''' Returns new, randomly initialized hidden layers for the start
            of a new sequence as a (hidden, cell_state) tuple '''
        model_device = next(self.parameters()).device
        return (torch.randn(1, 1, self.hidden_size).to(model_device),
                torch.randn(1, 1, self.hidden_size).to(model_device))

In [15]:
# Build the LSTM model and place it on the selected device
lstm = BasicLSTM(number_tool.get_dim(), 5, goal_dim).to(device)

# SGD with momentum over the LSTM's parameters
lstm_optim = torch.optim.SGD(
    lstm.parameters(),
    lr = 0.001,
    momentum = 0.9,
)

In [16]:
# Same training settings as the RNN run, with the model and optimizer swapped
lstm_setup_kwargs = {**rnn_setup_kwargs, 'model': lstm, 'optim': lstm_optim}

In [None]:
# Train the LSTM with the settings collected in lstm_setup_kwargs
train_model(**lstm_setup_kwargs)

### ----------

In [None]:
# Note vocabulary used for one-hot encoding: twelve note names plus a PAUSE marker
all_notes = ['A','A#','B','C','C#','D','D#','E','F','F#','G','G#','PAUSE']

class Music:
    @staticmethod
    def encode_note(note):
        ''' One-hot encode the note if present otherwise return all zeros

        Args:
            note: note name to encode

        Returns:
            Tensor of shape (1, len(all_notes)) with a 1 at the note's
            position in all_notes, or all zeros for an unknown note
        '''
        one_hot_vector = torch.zeros(1, len(all_notes))
        if note in all_notes:
            # Index into the same module-level list used for the membership test
            ind = all_notes.index(note)
            one_hot_vector[0,ind] = 1
        return one_hot_vector
    
    def decode_vector(encoded_note):
        ''' Decode a one-hot encoded vector '''
        note_index = encoded_note.argmax()
        