# "Neural Turing Machines" paper implementation - https://arxiv.org/pdf/1410.5401.pdf

In [1]:
import random
import torch
from torch import nn
from torch import optim
import numpy as np
from model.model import EncapsulatedNTM

class CopyTaskParams:
    """Hyper-parameters for the NTM copy task.

    Every numeric argument is coerced with ``int()``/``float()`` on
    construction, so numeric strings such as ``"10"`` are accepted.
    """

    def __init__(self, name="copy-task", controller_size=100, controller_layers=1, num_heads=1,
                 sequence_width=8, sequence_min_len=1, sequence_max_len=20, memory_n=128, memory_m=20,
                 num_batches=5000, batch_size=1, rmsprop_lr=1e-4, rmsprop_momentum=0.9, rmsprop_alpha=0.95):
        self.name = name
        self.controller_size = int(controller_size)
        self.controller_layers = int(controller_layers)
        self.num_heads = int(num_heads)
        # Number of bits per time step of the sequence to copy.
        self.sequence_width = int(sequence_width)
        # Inclusive bounds of the random sequence length.
        self.sequence_min_len = int(sequence_min_len)
        self.sequence_max_len = int(sequence_max_len)
        # Memory dimensions handed to the NTM model (presumably rows x width; confirm in model).
        self.memory_n = int(memory_n)
        self.memory_m = int(memory_m)
        self.num_batches = int(num_batches)
        self.batch_size = int(batch_size)
        # RMSprop optimizer settings.
        self.rmsprop_lr = float(rmsprop_lr)
        self.rmsprop_momentum = float(rmsprop_momentum)
        self.rmsprop_alpha = float(rmsprop_alpha)

    def __repr__(self):
        """List every field, so that logging an instance shows its configuration."""
        fields = ", ".join("{}={!r}".format(key, value) for key, value in vars(self).items())
        return "{}({})".format(type(self).__name__, fields)

class CopyTaskModelTraining:
    """Bundle the network, batch generator, loss and optimizer for the copy task.

    Args:
        params: hyper-parameter object; a default ``CopyTaskParams()`` is
            created when omitted.
    """

    def __init__(self, params=None):
        self.params = CopyTaskParams() if params is None else params
        # Order matters: the optimizer is built from the net's parameters.
        self.net = self.default_net()
        self.dataloader = self.default_dataloader()
        self.criterion = self.default_criterion()
        self.optimizer = self.default_optimizer()

    def default_net(self):
        """Build the NTM; the input has one extra channel for the delimiter bit."""
        p = self.params
        return EncapsulatedNTM(p.sequence_width + 1, p.sequence_width,
                               p.controller_size, p.controller_layers,
                               p.num_heads,
                               p.memory_n, p.memory_m)

    def default_dataloader(self):
        """Return the generator of random copy-task batches."""
        p = self.params
        return dataloader(p.num_batches, p.batch_size, p.sequence_width,
                          p.sequence_min_len, p.sequence_max_len)

    def default_criterion(self):
        """Binary cross-entropy between predicted and target bits."""
        return nn.BCELoss()

    def default_optimizer(self):
        """RMSprop over the net's parameters, configured from ``self.params``."""
        p = self.params
        return optim.RMSprop(self.net.parameters(),
                             lr=p.rmsprop_lr,
                             momentum=p.rmsprop_momentum,
                             alpha=p.rmsprop_alpha)

def dataloader(num_batches, batch_size, seq_width, min_len, max_len):
    """Yield ``num_batches`` random copy-task batches.

    Each item is ``(batch_num, inp, outp)`` with a 1-based ``batch_num``:
      * ``inp`` has size ``(seq_len + 1, batch_size, seq_width + 1)``. It holds
        the random bits followed by one delimiter step whose only set entry
        is the extra channel.
      * ``outp`` has size ``(seq_len, batch_size, seq_width)`` and holds the
        bits to reproduce.
    ``seq_len`` is drawn uniformly from ``[min_len, max_len]`` per batch.
    """
    for idx in range(num_batches):
        length = random.randint(min_len, max_len)
        bits = torch.from_numpy(np.random.binomial(1, 0.5, (length, batch_size, seq_width)))

        # One extra time step and one extra channel for the end-of-sequence marker.
        inp = torch.zeros(length + 1, batch_size, seq_width + 1)
        inp[:length, :, :seq_width] = bits
        inp[length, :, seq_width] = 1.0

        yield idx + 1, inp.float(), bits.clone().float()

In [2]:
import argparse
import json
import logging
import time
import random
import re
import sys

import torch
import numpy as np


# Module-level logger; output formatting/level is set up by init_logging().
LOGGER = logging.getLogger(__name__)

# Registry of trainable tasks: task name -> (training-harness class, parameter class).
TASKS = dict(copy=(CopyTaskModelTraining, CopyTaskParams))

# Defaults consumed by init_arguments().
RANDOM_SEED = 1000
REPORT_INTERVAL = 200       # batches between progress log lines
CHECKPOINT_INTERVAL = 1000  # batches between checkpoints (0 disables checkpointing)


def get_ms():
    """Return the current wall-clock time in milliseconds, as a float."""
    seconds = time.time()
    return seconds * 1000


def init_seed(seed=None):
    """Seed NumPy, PyTorch and Python's ``random`` with the same value.

    Args:
        seed: integer seed; when ``None`` the current Unix time in whole
            seconds is used.
    """
    seed = int(get_ms() // 1000) if seed is None else seed

    LOGGER.info("Using seed=%d", seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)


def progress_clean():
    """Blank out the current terminal line (used to erase the progress bar)."""
    blank = " " * 80
    print("\r" + blank, end="\r")


def progress_bar(batch_num, report_interval, last_loss):
    """Redraw a 40-character in-place progress bar for the current report window.

    Args:
        batch_num: 1-based index of the batch just processed.
        report_interval: number of batches per reporting window.
        last_loss: loss of the latest batch, shown with 4 decimals.
    """
    width = 40
    # Position within the current window, in (0, 1]; it reaches 1 on the reporting batch.
    fraction = ((batch_num - 1) % report_interval + 1) / report_interval
    done = int(fraction * width)
    bar = "=" * done + " " * (width - done)
    print("\r[{}]: {} (Loss: {:.4f})".format(bar, batch_num, last_loss), end='')


def save_checkpoint(net, name, args, batch_num, losses, costs, seq_lengths):
    """Write the model weights and the training history to disk.

    Two files are written under ``args['checkpoint_path']``, sharing the stem
    ``{name}-{args['seed']}-batch-{batch_num}``:
      * ``.model`` holds ``net.state_dict()`` saved with ``torch.save``;
      * ``.json`` holds ``{"loss": ..., "cost": ..., "seq_lengths": ...}``.
    """
    progress_clean()

    basename = "{}/{}-{}-batch-{}".format(args['checkpoint_path'], name, args['seed'], batch_num)
    model_fname = basename + ".model"
    LOGGER.info("Saving model checkpoint to: '%s'", model_fname)
    torch.save(net.state_dict(), model_fname)

    train_fname = basename + ".json"
    LOGGER.info("Saving model training history to '%s'", train_fname)
    content = {
        "loss": losses,
        "cost": costs,
        "seq_lengths": seq_lengths
    }
    # Serialise first so a JSON error cannot leave a half-written file. The
    # context manager then guarantees the handle is flushed and closed; the
    # previous bare open().write() left closing to the garbage collector.
    serialized = json.dumps(content)
    with open(train_fname, 'wt') as f:
        f.write(serialized)


def clip_grads(net, clip_value=10):
    """Clamp every existing gradient of *net* element-wise to ``[-clip_value, clip_value]``.

    Works in place. Parameters whose ``.grad`` is ``None`` are skipped.

    Args:
        net: module whose ``parameters()`` gradients are clipped.
        clip_value: symmetric clipping bound (default 10, the previous
            hard-coded value).
    """
    # Standard PyTorch utility; avoids mutating gradients through the legacy ``.data``.
    torch.nn.utils.clip_grad_value_(net.parameters(), clip_value)

In [3]:
def train_batch(net, criterion, optimizer, X, Y):
    """Run one optimisation step on a single copy-task batch.

    The net first consumes every time step of ``X`` (its outputs are ignored).
    It is then called ``Y.size(0)`` times with no input to produce the copy,
    which is scored against ``Y`` with ``criterion``. Gradients are clipped
    with ``clip_grads`` before ``optimizer.step()``.

    Args:
        net: model exposing ``init_sequence(batch_size)`` and ``__call__``
            returning ``(output, state)``.
        criterion: loss applied as ``criterion(y_out, Y)``.
        optimizer: optimizer over ``net``'s parameters.
        X: input sequence, indexed along dim 0 (time).
        Y: target of size ``(outp_seq_len, batch_size, _)``.

    Returns:
        ``(loss, cost)``: the loss as a float, and the number of mismatched
        bits per sequence after thresholding the output at 0.5.
    """
    optimizer.zero_grad()
    inp_seq_len = X.size(0)
    outp_seq_len, batch_size, _ = Y.size()

    net.init_sequence(batch_size)

    # Read phase: feed the whole input (including the delimiter step).
    for i in range(inp_seq_len):
        net(X[i])

    # Write phase: query the net with no input to emit the copy.
    y_out = torch.zeros(Y.size())
    for i in range(outp_seq_len):
        y_out[i], _ = net()

    loss = criterion(y_out, Y)
    loss.backward()
    clip_grads(net)
    optimizer.step()

    # Threshold at 0.5, vectorised. The previous per-element Tensor.apply_
    # called a Python lambda per value and only works on CPU tensors.
    # `< 0.5` picks 0 exactly as the old lambda did (NaN maps to 1).
    y_detached = y_out.detach()
    y_out_binarized = torch.where(y_detached < 0.5,
                                  torch.zeros_like(y_detached),
                                  torch.ones_like(y_detached))

    cost = torch.sum(torch.abs(y_out_binarized - Y.detach()))

    return loss.item(), cost.item() / batch_size

In [4]:
def evaluate(net, criterion, X, Y):
    """Run one copy-task sequence through *net* without updating it.

    Args:
        net: model exposing ``init_sequence(batch_size)`` and ``__call__``.
            It is called with one input slice per step while reading and with
            no argument while writing; each call returns ``(output, state)``.
        criterion: loss applied as ``criterion(y_out, Y)``.
        X: input sequence, indexed along dim 0 (time).
        Y: target of size ``(outp_seq_len, batch_size, _)``.

    Returns:
        dict with:
          * ``loss`` (float);
          * ``cost`` (float: mismatched bits per sequence after thresholding
            at 0.5);
          * ``y_out`` and ``y_out_binarized``;
          * ``states``: every state returned by the net, read phase first.
    """
    inp_seq_len = X.size(0)
    outp_seq_len, batch_size, _ = Y.size()

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        net.init_sequence(batch_size)

        states = []
        for i in range(inp_seq_len):
            _, state = net(X[i])
            states.append(state)

        y_out = torch.zeros(Y.size())
        for i in range(outp_seq_len):
            y_out[i], state = net()
            states.append(state)

        loss = criterion(y_out, Y)

    # Threshold at 0.5. This matches the old `0 if x < 0.5 else 1` lambda
    # element-wise, but is vectorised and device-agnostic.
    y_out_binarized = torch.where(y_out < 0.5, torch.zeros_like(y_out), torch.ones_like(y_out))

    cost = torch.sum(torch.abs(y_out_binarized - Y))

    result = {
        # .item() replaces the legacy `loss.data[0]`, which raises IndexError
        # on the 0-dim loss tensor in PyTorch >= 0.4.
        'loss': loss.item(),
        # Plain float, consistent with train_batch.
        'cost': cost.item() / batch_size,
        'y_out': y_out,
        'y_out_binarized': y_out_binarized,
        'states': states
    }

    return result

In [5]:
def _coerce_param_value(value, current):
    """Convert the override string *value* to the type of the attribute's *current* value.

    Numbers (``int``/``float``) and ``bool`` are converted. Anything else
    (e.g. ``str`` or ``None``) is returned unchanged. Raises ``ValueError``
    when the string cannot be converted.
    """
    # bool must be checked before int because bool is a subclass of int.
    if isinstance(current, bool):
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 'yes'):
            return True
        if lowered in ('0', 'false', 'no'):
            return False
        raise ValueError(value)
    if isinstance(current, (int, float)):
        return type(current)(value)
    return value


def update_params(params, update_list):
    """Apply ``key=value`` overrides to *params* in place and return it.

    Each value is converted to the type of the attribute's current value, so
    ``"num_batches=100"`` stores the int ``100`` rather than the string
    ``"100"``. Previously the raw string was stored, which broke ``range()``
    and the optimizer downstream.

    Exits the process with status 1 on a malformed entry, an unknown key, or
    a value that cannot be converted.
    """
    for param_str in update_list:
        param_split = param_str.split('=')
        if len(param_split) != 2:
            LOGGER.error("Invalid parameter format: '%s'", param_str)
            sys.exit(1)
        key, value = param_split
        if not hasattr(params, key):
            LOGGER.error("Invalid parameter '%s'", key)
            sys.exit(1)
        try:
            value = _coerce_param_value(value, getattr(params, key))
        except ValueError:
            LOGGER.error("Invalid value '%s' for parameter '%s'", value, key)
            sys.exit(1)
        setattr(params, key, value)
    return params

def train_model(model, args):
    """Train ``model.net`` on every batch yielded by ``model.dataloader``.

    Every ``args['report_interval']`` batches, the mean loss, mean cost and
    time per sequence over that window are logged. Every
    ``args['checkpoint_interval']`` batches a checkpoint is saved (0 disables
    checkpointing).
    """
    params = model.params
    num_batches, batch_size = params.num_batches, params.batch_size
    report_every = args['report_interval']
    checkpoint_every = args['checkpoint_interval']

    LOGGER.info("Training model for %d batches (batch_size=%d)...",
                num_batches, batch_size)

    losses, costs, seq_lengths = [], [], []
    start_ms = get_ms()

    for batch_num, x, y in model.dataloader:
        loss, cost = train_batch(model.net, model.criterion, model.optimizer, x, y)
        losses.append(loss)
        costs.append(cost)
        seq_lengths.append(y.size(0))

        progress_bar(batch_num, report_every, loss)

        if batch_num % report_every == 0:
            mean_loss = np.mean(losses[-report_every:])
            mean_cost = np.mean(costs[-report_every:])
            elapsed = get_ms() - start_ms
            mean_time = int((elapsed / report_every) / batch_size)
            progress_clean()
            LOGGER.info("Batch %d Loss: %.6f Cost: %.2f Time: %d ms/sequence",
                        batch_num, mean_loss, mean_cost, mean_time)
            start_ms = get_ms()

        if checkpoint_every != 0 and batch_num % checkpoint_every == 0:
            save_checkpoint(model.net, params.name, args,
                            batch_num, losses, costs, seq_lengths)

    LOGGER.info("Done training.")

def init_arguments():
    """Return the default run configuration as a dict.

    Keys: ``seed``, ``task``, ``param`` (list of ``key=value`` overrides),
    ``checkpoint_interval``, ``checkpoint_path`` (trailing slashes stripped)
    and ``report_interval``.
    """
    checkpoint_path = './'.rstrip('/')
    return {
        'seed': RANDOM_SEED,
        'task': 'copy',
        'param': [],
        'checkpoint_interval': CHECKPOINT_INTERVAL,
        'checkpoint_path': checkpoint_path,
        'report_interval': REPORT_INTERVAL,
    }

def init_model(args):
    """Build the training harness for ``args['task']``.

    The parameters are created with the task's parameter class, then the
    ``args['param']`` overrides are applied. Exits with status 1 for an
    unknown task.
    """
    task = args['task']
    LOGGER.info("Training for the **%s** task", task)

    model_cls, params_cls = TASKS.get(task, (None, None))
    if model_cls is None or params_cls is None:
        LOGGER.error("Invalid task '%s'", task)
        sys.exit(1)

    params = update_params(params_cls(), args['param'])
    LOGGER.info(params)

    return model_cls(params=params)

def init_logging():
    """Configure root logging at DEBUG level with a timestamped line format."""
    line_format = '[%(asctime)s] [%(levelname)s] [%(name)s]  %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=line_format)


In [6]:
def main():
    """Entry point: set up logging and seeding, build the task model and train it."""
    init_logging()
    args = init_arguments()
    init_seed(args['seed'])

    model = init_model(args)
    n_params = model.net.calculate_num_params()
    LOGGER.info("Total number of parameters: %d", n_params)

    train_model(model, args)


# Start training only when run as a script or notebook, not when imported.
if __name__ == "__main__":
    main()

[2023-09-13 19:40:39,907] [INFO] [__main__]  Using seed=1000
[2023-09-13 19:40:39,911] [INFO] [__main__]  Training for the **copy** task
[2023-09-13 19:40:39,911] [INFO] [__main__]  <__main__.CopyTaskParams object at 0x7f33ab6647d0>
[2023-09-13 19:40:39,916] [INFO] [__main__]  Total number of parameters: 62860
[2023-09-13 19:40:39,917] [INFO] [__main__]  Training model for 5000 batches (batch_size=1)...




[2023-09-13 19:41:12,323] [INFO] [__main__]  Batch 200 Loss: 0.691536 Cost: 38.49 Time: 162 ms/sequence




[2023-09-13 19:41:36,559] [INFO] [__main__]  Batch 400 Loss: 0.682120 Cost: 39.23 Time: 121 ms/sequence




[2023-09-13 19:42:01,197] [INFO] [__main__]  Batch 600 Loss: 0.669924 Cost: 35.20 Time: 123 ms/sequence




[2023-09-13 19:42:28,596] [INFO] [__main__]  Batch 800 Loss: 0.665752 Cost: 38.42 Time: 136 ms/sequence




[2023-09-13 19:42:57,612] [INFO] [__main__]  Batch 1000 Loss: 0.656061 Cost: 36.27 Time: 145 ms/sequence
[2023-09-13 19:42:57,613] [INFO] [__main__]  Saving model checkpoint to: './copy-task-1000-batch-1000.model'
[2023-09-13 19:42:57,617] [INFO] [__main__]  Saving model training history to './copy-task-1000-batch-1000.json'




[2023-09-13 19:43:25,714] [INFO] [__main__]  Batch 1200 Loss: 0.654393 Cost: 36.38 Time: 140 ms/sequence




[2023-09-13 19:43:54,329] [INFO] [__main__]  Batch 1400 Loss: 0.647553 Cost: 36.85 Time: 143 ms/sequence




[2023-09-13 19:44:21,010] [INFO] [__main__]  Batch 1600 Loss: 0.641334 Cost: 34.49 Time: 133 ms/sequence




[2023-09-13 19:44:47,598] [INFO] [__main__]  Batch 1800 Loss: 0.633838 Cost: 35.28 Time: 132 ms/sequence




[2023-09-13 19:45:14,010] [INFO] [__main__]  Batch 2000 Loss: 0.628209 Cost: 33.15 Time: 132 ms/sequence
[2023-09-13 19:45:14,012] [INFO] [__main__]  Saving model checkpoint to: './copy-task-1000-batch-2000.model'
[2023-09-13 19:45:14,015] [INFO] [__main__]  Saving model training history to './copy-task-1000-batch-2000.json'




[2023-09-13 19:45:38,107] [INFO] [__main__]  Batch 2200 Loss: 0.619515 Cost: 33.32 Time: 120 ms/sequence




[2023-09-13 19:46:02,813] [INFO] [__main__]  Batch 2400 Loss: 0.621100 Cost: 32.38 Time: 123 ms/sequence




[2023-09-13 19:46:29,437] [INFO] [__main__]  Batch 2600 Loss: 0.607359 Cost: 31.95 Time: 133 ms/sequence




[2023-09-13 19:46:56,223] [INFO] [__main__]  Batch 2800 Loss: 0.611724 Cost: 35.51 Time: 133 ms/sequence




[2023-09-13 19:47:22,822] [INFO] [__main__]  Batch 3000 Loss: 0.601232 Cost: 33.44 Time: 132 ms/sequence
[2023-09-13 19:47:22,824] [INFO] [__main__]  Saving model checkpoint to: './copy-task-1000-batch-3000.model'
[2023-09-13 19:47:22,828] [INFO] [__main__]  Saving model training history to './copy-task-1000-batch-3000.json'




[2023-09-13 19:47:49,429] [INFO] [__main__]  Batch 3200 Loss: 0.599935 Cost: 33.64 Time: 133 ms/sequence




[2023-09-13 19:48:16,318] [INFO] [__main__]  Batch 3400 Loss: 0.602553 Cost: 34.11 Time: 134 ms/sequence




[2023-09-13 19:48:42,088] [INFO] [__main__]  Batch 3600 Loss: 0.596330 Cost: 33.89 Time: 128 ms/sequence




[2023-09-13 19:49:07,895] [INFO] [__main__]  Batch 3800 Loss: 0.581937 Cost: 34.45 Time: 129 ms/sequence




[2023-09-13 19:49:36,128] [INFO] [__main__]  Batch 4000 Loss: 0.581719 Cost: 32.81 Time: 141 ms/sequence
[2023-09-13 19:49:36,129] [INFO] [__main__]  Saving model checkpoint to: './copy-task-1000-batch-4000.model'
[2023-09-13 19:49:36,133] [INFO] [__main__]  Saving model training history to './copy-task-1000-batch-4000.json'




[2023-09-13 19:50:02,966] [INFO] [__main__]  Batch 4200 Loss: 0.593039 Cost: 34.19 Time: 134 ms/sequence




[2023-09-13 19:50:28,730] [INFO] [__main__]  Batch 4400 Loss: 0.575995 Cost: 32.62 Time: 128 ms/sequence




[2023-09-13 19:50:55,571] [INFO] [__main__]  Batch 4600 Loss: 0.584684 Cost: 33.38 Time: 134 ms/sequence




[2023-09-13 19:51:23,134] [INFO] [__main__]  Batch 4800 Loss: 0.586175 Cost: 33.01 Time: 137 ms/sequence




[2023-09-13 19:51:51,000] [INFO] [__main__]  Batch 5000 Loss: 0.572626 Cost: 33.08 Time: 139 ms/sequence
[2023-09-13 19:51:51,002] [INFO] [__main__]  Saving model checkpoint to: './copy-task-1000-batch-5000.model'
[2023-09-13 19:51:51,006] [INFO] [__main__]  Saving model training history to './copy-task-1000-batch-5000.json'
[2023-09-13 19:51:51,016] [INFO] [__main__]  Done training.


                                                                                