# `clemkoa/ntm`

## `utils.py`

In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt


def roll(t, n):
    temp = t.flip(1)
    return torch.cat((temp[:, -(n+1):], temp[:, :-(n+1)]), dim=1)


def circular_convolution(w, s):
    temp_cat = torch.t(torch.cat([roll(s, i) for i in range(w.shape[1])]))
    return torch.mm(w, temp_cat)


def _convolve(w, s):
    """Circular convolution implementation."""
    assert s.size(0) == 3
    t = torch.cat([w[-1:], w, w[:1]], dim=0)
    c = F.conv1d(t.view(1, 1, -1), s.view(1, 1, -1)).view(-1)
    return c


def plot_copy_results(target, y, vector_length):
    plt.set_cmap('jet')
    fig = plt.figure()
    ax1 = fig.add_subplot(211)
    ax1.set_ylabel("target", rotation=0, labelpad=20)
    ax1.imshow(torch.t(target.view(-1, vector_length)))
    ax1.tick_params(axis="both", which="both", length=0)
    ax2 = fig.add_subplot(212)
    ax2.set_ylabel("output", rotation=0, labelpad=20)
    ax2.imshow(torch.t(y.clone().data.view(-1, vector_length)))
    ax2.tick_params(axis="both", which="both", length=0)
    plt.setp(ax1.get_xticklabels(), visible=False)
    plt.setp(ax1.get_yticklabels(), visible=False)
    plt.setp(ax2.get_xticklabels(), visible=False)
    plt.setp(ax2.get_yticklabels(), visible=False)
    plt.show()

## `memory.py`

In [2]:
import torch
from torch import nn
from torch.nn import Parameter


class Memory(nn.Module):
    def __init__(self, memory_size):
        super(Memory, self).__init__()
        self._memory_size = memory_size

        # Initialize memory bias
        initial_state = torch.ones(memory_size) * 1e-6
        self.register_buffer('initial_state', initial_state.data)

        # Initial read vector is a learnt parameter
        self.initial_read = Parameter(torch.randn(1, self._memory_size[1]) * 0.01)

    def get_size(self):
        return self._memory_size

    def reset(self, batch_size):
        self.memory = self.initial_state.clone().repeat(batch_size, 1, 1)

    def get_initial_read(self, batch_size):
        return self.initial_read.clone().repeat(batch_size, 1)

    def read(self):
        return self.memory

    def write(self, w, e, a):
        self.memory = self.memory * (1 - torch.matmul(w.unsqueeze(-1), e.unsqueeze(1)))
        self.memory = self.memory + torch.matmul(w.unsqueeze(-1), a.unsqueeze(1))
        return self.memory

    def size(self):
        return self._memory_size

## `head.py`

In [3]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Parameter
# from ntm.utils import _convolve


class Head(nn.Module):
    def __init__(self, memory, hidden_size):
        super(Head, self).__init__()
        self.memory = memory
        memory_length, memory_vector_length = memory.get_size()
        # (k : vector, beta: scalar, g: scalar, s: vector, gamma: scalar)
        self.k_layer = nn.Linear(hidden_size, memory_vector_length)
        self.beta_layer = nn.Linear(hidden_size, 1)
        self.g_layer = nn.Linear(hidden_size, 1)
        self.s_layer = nn.Linear(hidden_size, 3)
        self.gamma_layer = nn.Linear(hidden_size, 1)
        for layer in [self.k_layer, self.beta_layer, self.g_layer, self.s_layer, self.gamma_layer]:
            nn.init.xavier_uniform_(layer.weight, gain=1.4)
            nn.init.normal_(layer.bias, std=0.01)

        self._initial_state = Parameter(torch.randn(1, self.memory.get_size()[0]) * 1e-5)

    def get_initial_state(self, batch_size):
        # Softmax to ensure weights are normalized
        return F.softmax(self._initial_state, dim=1).repeat(batch_size, 1)

    def get_head_weight(self, x, previous_state, memory_read):
        k = self.k_layer(x)
        beta = F.softplus(self.beta_layer(x))
        g = F.sigmoid(self.g_layer(x))
        s = F.softmax(self.s_layer(x), dim=1)
        gamma = 1 + F.softplus(self.gamma_layer(x))
        # Focusing by content
        w_c = F.softmax(beta * F.cosine_similarity(memory_read + 1e-16, k.unsqueeze(1) + 1e-16, dim=-1), dim=1)
        # Focusing by location
        w_g = g * w_c + (1 - g) * previous_state
        w_t = self.shift(w_g, s)
        w = w_t ** gamma
        w = torch.div(w, torch.sum(w, dim=1).unsqueeze(1) + 1e-16)
        return w

    def shift(self, w_g, s):
        result = w_g.clone()
        for b in range(len(w_g)):
            result[b] = _convolve(w_g[b], s[b])
        return result


class ReadHead(Head):
    def forward(self, x, previous_state):
        memory_read = self.memory.read()
        w = self.get_head_weight(x, previous_state, memory_read)
        return torch.matmul(w.unsqueeze(1), memory_read).squeeze(1), w


class WriteHead(Head):
    def __init__(self, memory, hidden_size):
        super(WriteHead, self).__init__(memory, hidden_size)
        memory_length, memory_vector_length = memory.get_size()
        self.e_layer = nn.Linear(hidden_size, memory_vector_length)
        self.a_layer = nn.Linear(hidden_size, memory_vector_length)
        for layer in [self.e_layer, self.a_layer]:
            nn.init.xavier_uniform_(layer.weight, gain=1.4)
            nn.init.normal_(layer.bias, std=0.01)

    def forward(self, x, previous_state):
        memory_read = self.memory.read()
        w = self.get_head_weight(x, previous_state, memory_read)
        e = F.sigmoid(self.e_layer(x))
        a = self.a_layer(x)

        # write to memory (w, memory, e , a)
        self.memory.write(w, e, a)
        return w

## `controller.py`

In [4]:
import torch
from torch import nn
from torch.nn import Parameter
import numpy as np
import torch.nn.functional as F


class Controller(nn.Module):
    def __init__(self, lstm_controller, vector_length, hidden_size):
        super(Controller, self).__init__()
        # We allow either a feed-forward network or a LSTM for the controller
        self._lstm_controller = lstm_controller
        if self._lstm_controller:
            self._controller = LSTMController(vector_length, hidden_size)
        else:
            self._controller = FeedForwardController(vector_length, hidden_size)

    def forward(self, x, state):
        return self._controller(x, state)

    def get_initial_state(self, batch_size):
        return self._controller.get_initial_state(batch_size)


class LSTMController(nn.Module):
    def __init__(self, vector_length, hidden_size):
        super(LSTMController, self).__init__()
        self.layer = nn.LSTM(input_size=vector_length, hidden_size=hidden_size)
        # The hidden state is a learned parameter
        self.lstm_h_state = Parameter(torch.randn(1, 1, hidden_size) * 0.05)
        self.lstm_c_state = Parameter(torch.randn(1, 1, hidden_size) * 0.05)
        for p in self.layer.parameters():
            if p.dim() == 1:
                nn.init.constant_(p, 0)
            else:
                stdev = 5 / (np.sqrt(vector_length + hidden_size))
                nn.init.uniform_(p, -stdev, stdev)

    def forward(self, x, state):
        output, state = self.layer(x.unsqueeze(0), state)
        return output.squeeze(0), state

    def get_initial_state(self, batch_size):
        lstm_h = self.lstm_h_state.clone().repeat(1, batch_size, 1)
        lstm_c = self.lstm_c_state.clone().repeat(1, batch_size, 1)
        return lstm_h, lstm_c


class FeedForwardController(nn.Module):
    def __init__(self, vector_length, hidden_size):
        super(FeedForwardController, self).__init__()
        self.layer_1 = nn.Linear(vector_length, hidden_size)
        self.layer_2 = nn.Linear(hidden_size, hidden_size)
        stdev = 5 / (np.sqrt(vector_length + hidden_size))
        nn.init.uniform_(self.layer_1.weight, -stdev, stdev)
        nn.init.uniform_(self.layer_2.weight, -stdev, stdev)

    def forward(self, x, state):
        x1 = F.relu(self.layer_1(x))
        output = F.relu(self.layer_2(x1))
        return output, state

    def get_initial_state(self):
        return 0, 0

## `ntm.py`

In [5]:
import torch
from torch import nn
import torch.nn.functional as F
# from ntm.controller import Controller
# from ntm.memory import Memory
# from ntm.head import ReadHead, WriteHead


class NTM(nn.Module):
    def __init__(self, vector_length, hidden_size, memory_size, lstm_controller=True):
        super(NTM, self).__init__()
        self.controller = Controller(lstm_controller, vector_length + 1 + memory_size[1], hidden_size)
        self.memory = Memory(memory_size)
        self.read_head = ReadHead(self.memory, hidden_size)
        self.write_head = WriteHead(self.memory, hidden_size)
        self.fc = nn.Linear(hidden_size + memory_size[1], 5)
        nn.init.xavier_uniform_(self.fc.weight, gain=1)
        nn.init.normal_(self.fc.bias, std=0.01)

    def get_initial_state(self, batch_size=1):
        self.memory.reset(batch_size)
        controller_state = self.controller.get_initial_state(batch_size)
        read = self.memory.get_initial_read(batch_size)
        read_head_state = self.read_head.get_initial_state(batch_size)
        write_head_state = self.write_head.get_initial_state(batch_size)
        return (read, read_head_state, write_head_state, controller_state)

    def forward(self, x, previous_state):
        previous_read, previous_read_head_state, previous_write_head_state, previous_controller_state = previous_state
        controller_input = torch.cat([x, previous_read], dim=1)
        controller_output, controller_state = self.controller(controller_input, previous_controller_state)
        # Read
        read_head_output, read_head_state = self.read_head(controller_output, previous_read_head_state)
        # Write
        write_head_state = self.write_head(controller_output, previous_write_head_state)
        fc_input = torch.cat((controller_output, read_head_output), dim=1)
        state = (read_head_output, read_head_state, write_head_state, controller_state)
        return F.softmax(self.fc(fc_input)), state

## `repeat_task.py`

In [6]:
import random
import os
import torch
import torch.optim as optim
import torch.nn.functional as F
# from ntm.ntm import NTM
# from ntm.utils import plot_copy_results
import argparse
import numpy as np
# from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

In [7]:
# parser = argparse.ArgumentParser(description="Process some integers.")
# parser.add_argument("--train", help="Trains the model", action="store_true")
# parser.add_argument("--ff", help="Feed forward controller", action="store_true")
# parser.add_argument("--eval", help="Evaluates the model. Default path is models/repeat.pt", action="store_true")
# parser.add_argument("--modelpath", help="Specify the model path to load, for training or evaluation", type=str)
# parser.add_argument("--epochs", help="Specify the number of epochs for training", type=int, default=50_000)
# args = parser.parse_args()

In [8]:
from types import SimpleNamespace

args = SimpleNamespace()
args.train = True
args.ff = False
args.eval = False
args.modelpath = "models/repeat.pt"
args.epochs = 50000

In [9]:
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f5e98086f50>

In [10]:
def get_training_sequence(sequence_min_length, sequence_max_length, repeat_min, repeat_max, vector_length, batch_size=1):
    sequence_length = random.randint(sequence_min_length, sequence_max_length)
    repeat = random.randint(repeat_min, repeat_max)

    target = torch.bernoulli(torch.Tensor(sequence_length, batch_size, vector_length).uniform_(0, 1))

    input = torch.zeros(sequence_length + 2, batch_size, vector_length + 2)
    input[:sequence_length, :, :vector_length] = target
    # delimiter vector
    input[sequence_length, :, vector_length] = 1.0
    # repeat channel
    input[sequence_length + 1, :, vector_length + 1] = repeat / sequence_max_length

    output = torch.zeros(sequence_length * repeat + 1, batch_size, vector_length + 1)
    output[:sequence_length * repeat, :, :vector_length] = target.clone().repeat(repeat, 1, 1)
    # delimiter vector
    output[-1, :, -1] = 1.0
    return input, output

In [11]:
def train(epochs=50_000):
#     tensorboard_log_folder = f"runs/repeat-copy-task-{datetime.now().strftime('%Y-%m-%dT%H%M%S')}"
#     writer = SummaryWriter(tensorboard_log_folder)
#     print(f"Training for {epochs} epochs, logging in {tensorboard_log_folder}")
    sequence_min_length = 1
    sequence_max_length = 10
    repeat_min = 1
    repeat_max = 10
    vector_length = 8
    memory_size = (128, 20)
    hidden_layer_size = 100
    batch_size = 4
    lstm_controller = not args.ff

#     writer.add_scalar("sequence_min_length", sequence_min_length)
#     writer.add_scalar("sequence_max_length", sequence_max_length)
#     writer.add_scalar("vector_length", vector_length)
#     writer.add_scalar("memory_size0", memory_size[0])
#     writer.add_scalar("memory_size1", memory_size[1])
#     writer.add_scalar("hidden_layer_size", hidden_layer_size)
#     writer.add_scalar("lstm_controller", lstm_controller)
#     writer.add_scalar("seed", seed)
#     writer.add_scalar("batch_size", batch_size)

    model = NTM(vector_length + 1, hidden_layer_size, memory_size, lstm_controller)

    optimizer = optim.RMSprop(model.parameters(), momentum=0.9, alpha=0.95, lr=1e-4)
    feedback_frequency = 100
    total_loss = []
    total_cost = []

    os.makedirs("models", exist_ok=True)
    if os.path.isfile(model_path):
        print(f"Loading model from {model_path}")
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint)

    for epoch in range(epochs + 1):
        optimizer.zero_grad()
        input, target = get_training_sequence(sequence_min_length, sequence_max_length, repeat_min, repeat_max, vector_length, batch_size)
        state = model.get_initial_state(batch_size)
        for vector in input:
            _, state = model(vector, state)
        y_out = torch.zeros(target.size())
        for j in range(len(target)):
            y_out[j], state = model(torch.zeros(batch_size, vector_length + 2), state)
        loss = F.binary_cross_entropy(y_out, target)
        loss.backward()
        optimizer.step()
        total_loss.append(loss.item())
        y_out_binarized = y_out.clone().data
        y_out_binarized.apply_(lambda x: 0 if x < 0.5 else 1)
        cost = torch.sum(torch.abs(y_out_binarized - target)) / len(target)
        total_cost.append(cost.item())
        if epoch % feedback_frequency == 0:
            running_loss = sum(total_loss) / len(total_loss)
            running_cost = sum(total_cost) / len(total_cost)
            print(f"Loss at step {epoch}: {running_loss}")
#             writer.add_scalar('training loss', running_loss, epoch)
#             writer.add_scalar('training cost', running_cost, epoch)
            total_loss = []
            total_cost = []

    torch.save(model.state_dict(), model_path)

In [12]:
def eval(model_path):
    vector_length = 8
    memory_size = (128, 20)
    hidden_layer_size = 100
    lstm_controller = not args.ff

    model = NTM(vector_length + 1, hidden_layer_size, memory_size, lstm_controller)

    print(f"Loading model from {model_path}")
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint)

    model.eval()

    input, target = get_training_sequence(10, 10, 10, 10, vector_length)
    y_out = infer_sequence(model, input, target, vector_length)
    plot_copy_results(target, y_out, vector_length + 1)

    input, target = get_training_sequence(10, 10, 20, 20, vector_length)
    y_out = infer_sequence(model, input, target, vector_length)
    plot_copy_results(target, y_out, vector_length + 1)

    input, target = get_training_sequence(20, 20, 10, 10, vector_length)
    y_out = infer_sequence(model, input, target, vector_length)
    plot_copy_results(target, y_out, vector_length + 1)

In [13]:
def infer_sequence(model, input, target, vector_length):
    state = model.get_initial_state()
    for vector in input:
        _, state = model(vector, state)
    y_out = torch.zeros(target.size())
    for j in range(len(target)):
        y_out[j], state = model(torch.zeros(1, vector_length + 2), state)
    return y_out

In [14]:
# if __name__ == "__main__":
#     model_path = "models/repeat.pt"
#     if args.modelpath:
#         model_path = args.modelpath
#     if args.train:
#         train(args.epochs)
#     if args.eval:
#         eval(model_path)

# My Code

## Omniglot

In [15]:
from torchmeta.datasets import Omniglot
from torchmeta.transforms import Categorical, ClassSplitter, Rotation
from torchvision.transforms import Compose, Resize, ToTensor
from torchmeta.utils.data import BatchMetaDataLoader

In [16]:
dataset = Omniglot(
    "data",
    # Number of ways
    num_classes_per_task=5,
    # Resize the images to 20x20 and converts them to PyTorch tensors (from Torchvision)
    transform=Compose([Resize(20), ToTensor()]),
    # Transform the labels to integers (e.g. ("Glagolitic/character01", "Sanskrit/character14", ...) to (0, 1, ...))
    target_transform=Categorical(num_classes=5),
    # Creates new virtual classes with rotated versions of the images (from Santoro et al., 2016)
    class_augmentations=[Rotation([90, 180, 270])],
    meta_train=True,
    download=True,
)
dataset = ClassSplitter(dataset, shuffle=True, num_train_per_class=5, num_test_per_class=5)
dataloader = BatchMetaDataLoader(dataset, batch_size=16, num_workers=4)

In [17]:
def inputs_targets_to_seq(inputs, targets):
    X = inputs.flatten(2, 4)
    y = F.one_hot(targets, num_classes=5)
    y = torch.cat((torch.zeros(y.shape[0], 1, y.shape[2]), y), dim=1)[:, :-1, :]
    seq = torch.cat((X, y), dim=2)
    seq = torch.swapaxes(seq, 0, 1)

    # Shape: (seq_len, batch, n_features)
    return seq

In [18]:
for episode_i, batch in enumerate(dataloader):
    train_inputs, train_targets = batch["train"]
    seq = inputs_targets_to_seq(train_inputs, train_targets)
    print(seq.shape)
    break



torch.Size([25, 16, 405])


## Model

In [19]:
vector_length = 404
hidden_layer_size = 200
memory_size = (128, 40)
lstm_controller = True
batch_size = 16

In [20]:
model = NTM(vector_length, hidden_layer_size, memory_size, lstm_controller)
optimizer = optim.RMSprop(model.parameters(), momentum=0.9, alpha=0.95, lr=1e-4)

## Train

In [21]:
loss_per_ep = []
acc_per_ep = []
for episode_i, batch in enumerate(dataloader):
    train_inputs, train_targets = batch["train"]
    seq = inputs_targets_to_seq(train_inputs, train_targets)
    
    optimizer.zero_grad()
    state = model.get_initial_state(batch_size)
    y_out = torch.zeros((len(seq), 16, 5))
    for j, vector in enumerate(seq):
        y_out[j], state = model(vector, state)
    loss = F.cross_entropy(y_out.permute(1, 2, 0), train_targets)
    loss.backward()
    optimizer.step()
    correct = torch.sum(y_out.permute(1, 2, 0).argmax(dim=1) == train_targets)
    acc = correct.item() / np.prod(train_targets.size())

    loss_per_ep.append(loss.item())
    acc_per_ep.append(acc)

    if episode_i % 10 == 0:
        print(f"{episode_i:5d} {loss.item():.4f} {acc:.4f}")

    if episode_i >= 1000:
        break

  return F.softmax(self.fc(fc_input)), state


    0 1.6131 0.1575
   10 1.5816 0.4250
   20 1.5070 0.5200
   30 1.4840 0.4850
   40 1.4069 0.6025
   50 1.3434 0.6800
   60 1.2976 0.7375
   70 1.2532 0.7575
   80 1.2153 0.7800
   90 1.2049 0.7750
  100 1.1748 0.8000
  110 1.1569 0.7975
  120 1.1408 0.8100
  130 1.1230 0.8225
  140 1.1144 0.8450
  150 1.1042 0.8300
  160 1.0916 0.8300
  170 1.0753 0.8650
  180 1.0733 0.8525
  190 1.0856 0.8275
  200 1.0537 0.8775
  210 1.0538 0.8700
  220 1.0506 0.8700
  230 1.0466 0.8625
  240 1.0550 0.8625
  250 1.0349 0.8950
  260 1.0332 0.8800
  270 1.0310 0.8750
  280 1.0210 0.8975
  290 1.0254 0.8800
  300 1.0223 0.8850
  310 1.0220 0.8900
  320 1.0121 0.8975
  330 1.0181 0.8925
  340 1.0147 0.8925
  350 1.0180 0.9000
  360 1.0124 0.8925
  370 1.0269 0.8650
  380 1.0139 0.8825
  390 1.0155 0.8900
  400 1.0093 0.9100
  410 1.0160 0.8725
  420 1.0131 0.8925
  430 1.0087 0.8950
  440 1.0022 0.9075
  450 1.0039 0.8950
  460 1.0098 0.8900
  470 1.0148 0.8800
  480 1.0075 0.8975
  490 1.0130 0.8800


In [22]:
test_acc_per_ep = []
for episode_i, batch in enumerate(dataloader):
    test_inputs, test_targets = batch["test"]
    seq = inputs_targets_to_seq(test_inputs, test_targets)

    state = model.get_initial_state(batch_size)
    y_out = torch.zeros((len(seq), 16, 5))
    for j, vector in enumerate(seq):
        y_out[j], state = model(vector, state)
    correct = torch.sum(y_out.permute(1, 2, 0).argmax(dim=1) == test_targets)
    acc = correct.item() / np.prod(test_targets.size())

    test_acc_per_ep.append(acc)

    if episode_i >= 10:
        break

  return F.softmax(self.fc(fc_input)), state


In [23]:
np.mean(test_acc_per_ep)

0.8879545454545453