# "Neural Turing Machines" paper implementation - https://arxiv.org/pdf/1410.5401.pdf

In [5]:
import random
import torch
from torch import nn
from torch import optim
import numpy as np
from model.model import EncapsulatedNTM

class CopyTaskParams:
    """Plain container for the copy-task settings.

    ``name`` is stored exactly as given. Every other argument is coerced on
    construction: the size/count settings go through ``int()`` and the three
    RMSprop settings go through ``float()``, so numeric strings are accepted
    and values that cannot be converted raise from the respective builtin.
    """

    def __init__(self, name="copy-task", controller_size=100, controller_layers=1, num_heads=1,
                 sequence_width=8, sequence_min_len=1, sequence_max_len=20, memory_n=128, memory_m=20,
                 num_batches=50000, batch_size=1, rmsprop_lr=1e-4, rmsprop_momentum=0.9, rmsprop_alpha=0.95):
        self.name = name

        # Settings stored as integers, in the order they are assigned.
        integer_settings = (
            ("controller_size", controller_size),
            ("controller_layers", controller_layers),
            ("num_heads", num_heads),
            ("sequence_width", sequence_width),
            ("sequence_min_len", sequence_min_len),
            ("sequence_max_len", sequence_max_len),
            ("memory_n", memory_n),
            ("memory_m", memory_m),
            ("num_batches", num_batches),
            ("batch_size", batch_size),
        )
        for attribute, raw_value in integer_settings:
            setattr(self, attribute, int(raw_value))

        # Optimizer settings stored as floats.
        float_settings = (
            ("rmsprop_lr", rmsprop_lr),
            ("rmsprop_momentum", rmsprop_momentum),
            ("rmsprop_alpha", rmsprop_alpha),
        )
        for attribute, raw_value in float_settings:
            setattr(self, attribute, float(raw_value))

class CopyTaskModelTraining:
    """Holds everything needed for one copy-task training run.

    On construction the instance stores its parameters and then builds, in
    this order, ``net``, ``dataloader``, ``criterion`` and ``optimizer`` via
    the ``default_*`` factory methods below.
    """

    def __init__(self, params=None):
        # Fall back to the stock settings only when nothing was supplied.
        self.params = CopyTaskParams() if params is None else params
        # The optimizer factory reads self.net, so the net is built first.
        self.net = self.default_net()
        self.dataloader = self.default_dataloader()
        self.criterion = self.default_criterion()
        self.optimizer = self.default_optimizer()

    def default_net(self):
        """Build the EncapsulatedNTM from the stored parameters."""
        cfg = self.params
        # First argument is one wider than the sequence width; this matches
        # the extra channel that `dataloader` adds to its input tensors.
        # NOTE(review): the meaning of each positional argument is defined by
        # EncapsulatedNTM, which is not visible here - confirm against it.
        return EncapsulatedNTM(
            cfg.sequence_width + 1,
            cfg.sequence_width,
            cfg.controller_size,
            cfg.controller_layers,
            cfg.num_heads,
            cfg.memory_n,
            cfg.memory_m,
        )

    def default_dataloader(self):
        """Create the batch generator configured by the stored parameters."""
        cfg = self.params
        return dataloader(
            cfg.num_batches,
            cfg.batch_size,
            cfg.sequence_width,
            cfg.sequence_min_len,
            cfg.sequence_max_len,
        )

    def default_criterion(self):
        """Return a fresh ``nn.BCELoss`` instance."""
        return nn.BCELoss()

    def default_optimizer(self):
        """Return an RMSprop optimizer over ``self.net``'s parameters."""
        cfg = self.params
        return optim.RMSprop(
            self.net.parameters(),
            lr=cfg.rmsprop_lr,
            alpha=cfg.rmsprop_alpha,
            momentum=cfg.rmsprop_momentum,
        )

def dataloader(num_batches, batch_size, seq_width, min_len, max_len):
    """Generate random binary sequences for the copy task.

    Yields ``num_batches`` tuples ``(batch_num, inp, outp)`` where
    ``batch_num`` counts from 1. For each batch a length is drawn uniformly
    from ``[min_len, max_len]`` (inclusive) and a 0/1 tensor of shape
    ``(length, batch_size, seq_width)`` is sampled with probability 0.5.

    ``outp`` is that tensor as floats. ``inp`` has shape
    ``(length + 1, batch_size, seq_width + 1)``: the first ``length`` steps
    carry the sampled bits in the first ``seq_width`` channels, and the one
    extra final step has only the additional last channel set to 1.
    """
    for index in range(1, num_batches + 1):
        length = random.randint(min_len, max_len)
        bits = torch.from_numpy(
            np.random.binomial(1, 0.5, (length, batch_size, seq_width))
        )

        # One extra time step and one extra channel, both zero-initialised.
        source = torch.zeros(length + 1, batch_size, seq_width + 1)
        source[:length, :, :seq_width] = bits
        # Flag the step after the sequence in the extra channel.
        source[length, :, seq_width] = 1.0

        target = bits.clone()
        yield index, source.float(), target.float()