In [1]:
import json

import numpy as np

import torch
from torch import nn
from torch.nn import functional as F

In [2]:
# Device the model is moved to when it is constructed (`.to(DEVICE)`).
DEVICE = "cpu"
# CUDA auto-selection is left disabled.
# NOTE(review): before re-enabling, confirm that every tensor allocated in this
# notebook is created on the model's device.
# if torch.cuda.is_available():
#     DEVICE = "cuda"

In [3]:
DEVICE

'cpu'

In [4]:
class Memory(nn.Module):
    """External memory bank of the NTM with read, write and addressing.

    A fixed, randomly initialised ``memory_bias`` of shape ``(N, M)`` is stored
    as a buffer. ``reset`` copies it ``batch_size`` times into ``self.memory``
    (shape ``(batch_size, N, M)``), which ``read``/``write`` then operate on.
    """

    def __init__(self, N, M):
        """
        Args:
            N: number of memory rows.
            M: width of each memory row.
        """
        super().__init__()

        self.N    = N
        self.M    = M
        self.size = [self.N, self.M]

        # A buffer is saved in the state_dict and follows .to(), but is not
        # returned by parameters(), so the optimizer never updates it.
        self.register_buffer("memory_bias", torch.empty(N, M))
        stdev = 1 / (np.sqrt(N + M))
        nn.init.uniform_(self.memory_bias, -stdev, stdev)

    def reset(self, batch_size=1):
        """Re-create the working memory from ``memory_bias`` for a new batch."""
        self.batch_size = batch_size
        self.memory     = self.memory_bias.clone().repeat(batch_size, 1, 1)

    def read(self, w):
        """Return the weighted sum of memory rows.

        Args:
            w: weighting of shape ``(batch, N)``.

        Returns:
            Tensor of shape ``(batch, M)``.
        """
        data = torch.matmul(w.unsqueeze(1), self.memory).squeeze(1)
        return data

    def write(self, w, e_gate, a_gate):
        """Erase-then-add update of the memory.

        Args:
            w: weighting of shape ``(batch, N)``.
            e_gate: erase vector of shape ``(batch, M)``.
            a_gate: add vector of shape ``(batch, M)``.
        """
        # The pre-write memory stays reachable as ``flashback``.
        self.flashback = self.memory
        # Outer products w ⊗ e and w ⊗ a, each of shape (batch, N, M).
        erase          = torch.matmul(w.unsqueeze(-1), e_gate.unsqueeze(1))
        add            = torch.matmul(w.unsqueeze(-1), a_gate.unsqueeze(1))
        self.memory    = self.flashback * (1 - erase) + add

    def address(self, k, b, g, s, y, w_prev):
        """Compute a new weighting: content focus, interpolation, shift, sharpen.

        Args:
            k: key, reshaped internally to ``(batch, 1, M)``.
            b: key strength multiplying the cosine similarity.
            g: interpolation gate between content weighting and ``w_prev``.
            s: shift kernel of shape ``(batch, 3)``.
            y: sharpening exponent.
            w_prev: previous weighting of shape ``(batch, N)``.

        Returns:
            Weighting of shape ``(batch, N)``.
        """
        wc = self._similarity(k, b)
        wg = self._interpolate(w_prev, wc, g)
        w_ = self._shift(wg, s)
        w  = self._sharpen(w_, y)
        return w

    def _similarity(self, k, b):
        """Softmax over rows of the strength-scaled cosine similarity to ``k``."""
        k              = k.view(self.batch_size, 1, -1)
        # The small offset keeps the cosine defined for all-zero vectors.
        similarity     = F.cosine_similarity(self.memory + 1e-16, k + 1e-16, dim=-1)
        content_weight = F.softmax(b * similarity, dim=1)
        return content_weight

    def _interpolate(self, w_prev, wc, g):
        """Blend the content weighting with the previous weighting using gate ``g``."""
        focus = g * wc + (1 - g) * w_prev
        return focus

    def _shift(self, wg, s):
        """Circularly shift each weighting by its own kernel.

        For every batch row,
        ``out[i] = wg[i-1] * s[0] + wg[i] * s[1] + wg[i+1] * s[2]`` with
        wrap-around indices. The result lives on the same device and has the
        same dtype as ``wg``.
        """
        # Wrap one element around each end, then slide a window of kernel width.
        padded  = torch.cat([wg[:, -1:], wg, wg[:, :1]], dim=1)
        windows = padded.unfold(1, s.size(1), 1)
        return torch.sum(windows * s.unsqueeze(1), dim=-1)

    def _sharpen(self, w_, y):
        """Raise the weighting to the power ``y`` and renormalise each row."""
        w = w_ ** y
        # The epsilon guards against division by zero for an all-zero row.
        w = torch.div(w, torch.sum(w, dim=1).view(-1, 1) + 1e-16)
        return w
    
    

def _convolve(w, s):
    t = torch.cat([w[-1:], w, w[:1]])
    c = F.conv1d(t.view(1, 1, -1), s.view(1, 1, -1)).view(-1)
    return c

In [5]:
class ReadHead(nn.Module):
    """Head that derives addressing parameters from the controller and reads memory."""

    def __init__(self, memory, controller_size):
        """
        Args:
            memory: object exposing ``size`` (``[N, M]``), ``address(...)`` and
                ``read(w)``.
            controller_size: width of the controller output fed to this head.
        """
        super().__init__()

        self.memory             = memory
        self.N, self.M          = self.memory.size
        self.controller_size    = controller_size

        # One linear projection per addressing parameter.
        self.key                = nn.Linear(self.controller_size, self.M)
        self.key_strength       = nn.Linear(self.controller_size, 1)
        self.interpolation_gate = nn.Linear(self.controller_size, 1)
        self.shift_weighting    = nn.Linear(self.controller_size, 3)
        self.sharpen_factor     = nn.Linear(self.controller_size, 1)

        self.is_read_head       = True

        self.reset()

    def _projections(self):
        """Linear layers in the order in which ``reset`` initialises them."""
        return (self.key, self.key_strength, self.interpolation_gate,
                self.shift_weighting, self.sharpen_factor)

    def _address(self, k, b, g, s, y, w_prev):
        """Constrain the raw projections and ask the memory for a weighting."""
        k = k.clone()
        b = F.softplus(b)            # strictly positive key strength
        g = torch.sigmoid(g)         # gate in (0, 1)
        s = torch.softmax(s, dim=1)  # shift kernel sums to 1
        y = 1 + F.softplus(y)        # sharpening exponent above 1

        w = self.memory.address(k, b, g, s, y, w_prev)

        return w

    def forward(self, controller_state, w_prev):
        """Read from memory.

        Args:
            controller_state: controller output of shape ``(batch, controller_size)``.
            w_prev: this head's weighting from the previous step.

        Returns:
            ``(data, w)``: the read vector and the new weighting.
        """
        k = self.key(controller_state)
        b = self.key_strength(controller_state)
        g = self.interpolation_gate(controller_state)
        s = self.shift_weighting(controller_state)
        y = self.sharpen_factor(controller_state)

        w    = self._address(k, b, g, s, y, w_prev)

        data = self.memory.read(w)

        return data, w

    def create_new_state(self, batch_size):
        """Return an all-zero initial weighting of shape ``(batch_size, N)``.

        The tensor is created with the device and dtype of the head's
        parameters so it can be combined with them after ``.to(...)``.
        """
        ref = self.key.weight
        return torch.zeros(batch_size, self.N, device=ref.device, dtype=ref.dtype)

    def reset(self):
        """Re-initialise all projections (Xavier weights, small normal biases)."""
        # All weights first, then all biases: keeps the order in which random
        # numbers are drawn stable for a given seed.
        for layer in self._projections():
            nn.init.xavier_uniform_(layer.weight, gain=1.4)
        for layer in self._projections():
            nn.init.normal_(layer.bias, std=0.01)

In [6]:
class WriteHead(nn.Module):
    """Head that derives addressing/erase/add vectors from the controller and writes memory."""

    def __init__(self, memory, controller_size):
        """
        Args:
            memory: object exposing ``size`` (``[N, M]``), ``address(...)`` and
                ``write(w, e, a)``.
            controller_size: width of the controller output fed to this head.
        """
        super().__init__()

        self.memory             = memory
        self.N, self.M          = self.memory.size
        self.controller_size    = controller_size

        # One linear projection per addressing parameter, plus erase and add vectors.
        self.key                = nn.Linear(self.controller_size, self.M)
        self.key_strength       = nn.Linear(self.controller_size, 1)
        self.interpolation_gate = nn.Linear(self.controller_size, 1)
        self.shift_weighting    = nn.Linear(self.controller_size, 3)
        self.sharpen_factor     = nn.Linear(self.controller_size, 1)
        self.erase              = nn.Linear(self.controller_size, self.M)
        self.add                = nn.Linear(self.controller_size, self.M)

        self.is_read_head       = False

        self.reset()

    def _projections(self):
        """Linear layers in the order in which ``reset`` initialises them."""
        return (self.key, self.key_strength, self.interpolation_gate,
                self.shift_weighting, self.sharpen_factor, self.erase, self.add)

    def _address(self, k, b, g, s, y, w_prev):
        """Constrain the raw projections and ask the memory for a weighting."""
        k = k.clone()
        b = F.softplus(b)            # strictly positive key strength
        g = torch.sigmoid(g)         # gate in (0, 1)
        s = torch.softmax(s, dim=1)  # shift kernel sums to 1
        y = 1 + F.softplus(y)        # sharpening exponent above 1

        w = self.memory.address(k, b, g, s, y, w_prev)

        return w

    def forward(self, controller_state, w_prev):
        """Write to memory.

        Args:
            controller_state: controller output of shape ``(batch, controller_size)``.
            w_prev: this head's weighting from the previous step.

        Returns:
            The new weighting; the memory is updated as a side effect.
        """
        k = self.key(controller_state)
        b = self.key_strength(controller_state)
        g = self.interpolation_gate(controller_state)
        s = self.shift_weighting(controller_state)
        y = self.sharpen_factor(controller_state)
        e = self.erase(controller_state)
        a = self.add(controller_state)

        # Erase vector is squashed to (0, 1); the add vector is left unbounded.
        e = torch.sigmoid(e)

        w    = self._address(k, b, g, s, y, w_prev)

        self.memory.write(w, e, a)

        return w

    def create_new_state(self, batch_size):
        """Return an all-zero initial weighting of shape ``(batch_size, N)``.

        The tensor is created with the device and dtype of the head's
        parameters so it can be combined with them after ``.to(...)``.
        """
        ref = self.key.weight
        return torch.zeros(batch_size, self.N, device=ref.device, dtype=ref.dtype)

    def reset(self):
        """Re-initialise all projections (Xavier weights, small normal biases)."""
        # All weights first, then all biases: keeps the order in which random
        # numbers are drawn stable for a given seed.
        for layer in self._projections():
            nn.init.xavier_uniform_(layer.weight, gain=1.4)
        for layer in self._projections():
            nn.init.normal_(layer.bias, std=0.01)
    


In [7]:
class Controller(nn.Module):
    """LSTM controller with learnable initial hidden and cell states."""

    def __init__(self, no_input, no_output, no_layer):
        """
        Args:
            no_input: size of the input vector.
            no_output: LSTM hidden size (also the size of the output).
            no_layer: number of stacked LSTM layers.
        """
        super().__init__()

        self.no_input, self.no_output, self.no_layer = no_input, no_output, no_layer
        self.size = [self.no_input, self.no_output]

        self.lstm = nn.LSTM(
            input_size=self.no_input,
            hidden_size=self.no_output,
            num_layers=self.no_layer,
        )

        # Learnable start states, one slice per layer, broadcast over the batch
        # in create_new_state.
        state_shape = (self.no_layer, 1, self.no_output)
        self.h_bias = nn.Parameter(torch.randn(*state_shape) * 0.05)
        self.c_bias = nn.Parameter(torch.randn(*state_shape) * 0.05)

        self.reset()

    def forward(self, data, prev_state):
        """Run one time step.

        ``data`` gets a leading sequence dimension of length 1 for the LSTM,
        which is removed again from the returned output.

        Returns:
            ``(output, state)`` where ``state`` is the LSTM's ``(h, c)`` tuple.
        """
        step_output, next_state = self.lstm(data.unsqueeze(0), prev_state)
        return step_output.squeeze(0), next_state

    def create_new_state(self, batch_size):
        """Return ``(h, c)`` start states repeated ``batch_size`` times along dim 1."""
        tile = (1, batch_size, 1)
        hidden = self.h_bias.clone().repeat(*tile)
        cell = self.c_bias.clone().repeat(*tile)
        return hidden, cell

    def reset(self):
        """Zero the 1-D LSTM parameters and draw the others uniformly in ±1/sqrt(no_input + no_output)."""
        for tensor in self.lstm.parameters():
            if tensor.dim() == 1:
                nn.init.constant_(tensor, 0)
                continue
            bound = 1 / (np.sqrt(self.no_input + self.no_output))
            nn.init.uniform_(tensor, -bound, bound)

In [8]:
class NTM(nn.Module):
    """Neural Turing Machine: LSTM controller, shared memory and read/write heads.

    Call ``init_sequence(batch_size)`` before the first ``forward`` of every
    sequence; it resets the memory and the recurrent state kept in
    ``self.previous_state``.
    """

    def __init__(self, no_input, no_output, controller_size, controller_layer, no_head, N, M):
        """
        Args:
            no_input: width of the external input vector.
            no_output: width of the output vector.
            controller_size: hidden size of the controller LSTM.
            controller_layer: number of LSTM layers.
            no_head: number of (read head, write head) pairs.
            N: number of memory rows.
            M: width of each memory row.
        """
        super().__init__()

        self.no_input         = no_input
        self.no_output        = no_output
        self.controller_size  = controller_size
        self.controller_layer = controller_layer
        self.no_head          = no_head

        self.N                = N
        self.M                = M

        self.memory           = Memory(self.N, self.M)
        # The controller sees the external input concatenated with one read
        # vector (width M) per read head.
        self.controller       = Controller(self.no_input + (self.M * self.no_head), self.controller_size, self.controller_layer)
        self.head             = nn.ModuleList([])

        _, self.controller_size  = self.controller.size

        # Heads are stored interleaved: read, write, read, write, ...
        for head_no in range(self.no_head):
            self.head += [
                            ReadHead(self.memory, self.controller_size),
                            WriteHead(self.memory, self.controller_size)
            ]

        # One initial read vector per read head, registered as a buffer so it
        # is saved in the state_dict and follows the module through .to().
        self.no_read_head = 0
        for head in self.head:
            if head.is_read_head:
                self.register_buffer("read{}_bias".format(self.no_read_head),
                                     torch.randn(1, self.M) * 0.01)
                self.no_read_head += 1

        self.fc = nn.Linear(self.controller_size + self.no_read_head * self.M, self.no_output)
        self.reset()

    @property
    def read(self):
        """Current initial read vectors, one ``(1, M)`` tensor per read head.

        Looked up from the registered ``read{i}_bias`` buffers on every access,
        so the result reflects any device/dtype move or loaded state_dict.
        """
        return [getattr(self, "read{}_bias".format(i)) for i in range(self.no_read_head)]

    def create_new_state(self, batch_size):
        """Build the start-of-sequence state.

        Returns:
            ``(read, controller_state, head_state)``: a list of initial read
            vectors of shape ``(batch_size, M)``, the controller's ``(h, c)``
            tuple, and a list with one initial weighting per head.
        """
        read             = [r.clone().repeat(batch_size, 1) for r in self.read]
        controller_state = self.controller.create_new_state(batch_size)
        head_state       = [head.create_new_state(batch_size) for head in self.head]
        return read, controller_state, head_state

    def init_sequence(self, batch_size):
        """Reset the memory and the recurrent state for a new batch of sequences."""
        self.batch_size = batch_size
        self.memory.reset(batch_size)
        self.previous_state = self.create_new_state(batch_size)

    def forward(self, x=None):
        """Advance the machine by one time step.

        Args:
            x: input of shape ``(batch_size, no_input)``. When omitted, an
                all-zero input is used (output phase of the copy task).

        Returns:
            ``(out, state)``: sigmoid output of shape ``(batch_size, no_output)``
            and the new ``(reads, controller_state, head_state)`` tuple, which
            is also stored in ``self.previous_state``.
        """
        if x is None:
            ref = self.fc.weight
            x = torch.zeros(self.batch_size, self.no_input, device=ref.device, dtype=ref.dtype)

        prev_read, prev_controller_state, prev_head_states = self.previous_state

        inp                                 = torch.cat([x, *prev_read], dim=1)
        controller_output, controller_state = self.controller(inp, prev_controller_state)

        reads = []
        head_state = []
        # Heads run in list order, so each read head runs before the write
        # head it was paired with.
        for head, prev_w in zip(self.head, prev_head_states):
            if head.is_read_head:
                r, h_state = head(controller_output, prev_w)
                reads.append(r)
            else:
                h_state = head(controller_output, prev_w)
            head_state.append(h_state)

        out = torch.cat([controller_output, *reads], dim=1)
        out = torch.sigmoid(self.fc(out))

        self.previous_state = (reads, controller_state, head_state)

        return out, self.previous_state

    def reset(self):
        """Re-initialise the output layer."""
        nn.init.xavier_uniform_(self.fc.weight, gain=1)
        nn.init.normal_(self.fc.bias, std=0.01)

    def no_param(self):
        """Return the total number of scalar parameters in the model."""
        return sum(param.numel() for param in self.parameters())

In [9]:
def dataloader(no_batch, batch_size, seq_width, min_len, max_len):
    """Generate random batches for the copy task.

    Args:
        no_batch: number of batches to generate.
        batch_size: number of sequences per batch.
        seq_width: number of binary channels per time step.
        min_len: shortest sequence length (inclusive).
        max_len: longest sequence length (inclusive).

    Yields:
        ``(batch_no, inp, out)`` where ``batch_no`` counts from 1, ``inp`` is a
        float tensor of shape ``(seq_len + 1, batch_size, seq_width + 1)`` and
        ``out`` is a float tensor of shape ``(seq_len, batch_size, seq_width)``.
        ``inp`` holds the random bits followed by one extra time step in which
        only the extra last channel is set (the end-of-sequence delimiter).
    """
    for batch_no in range(no_batch):
        # np.random.randint excludes its upper bound; +1 makes max_len reachable
        # and allows min_len == max_len.
        seq_len = np.random.randint(min_len, max_len + 1)
        seq     = np.random.binomial(1, 0.5, (seq_len, batch_size, seq_width))
        seq     = torch.from_numpy(seq).float()

        inp     = torch.zeros(seq_len + 1, batch_size, seq_width + 1)
        inp[:seq_len, :, :seq_width] = seq
        inp[seq_len, :, seq_width]   = 1
        out     = seq.clone()

        yield batch_no + 1, inp.float(), out

In [10]:
# Hyperparameters passed to the NTM constructor when COPIER is built.
no_input         = 9  # input width; the training cell uses seq_width=8, and dataloader adds 1 delimiter channel
no_output        = 8  # output width
controller_size  = 100  # hidden size of the controller LSTM
controller_layer = 1  # number of LSTM layers
no_head          = 1  # number of (read head, write head) pairs
N                = 128  # number of memory rows
M                = 20  # width of each memory row

In [11]:
# Build the copy-task model from the hyperparameters above and place it on DEVICE.
COPIER = NTM(
    no_input=no_input,
    no_output=no_output,
    controller_size=controller_size,
    controller_layer=controller_layer,
    no_head=no_head,
    N=N,
    M=M,
).to(DEVICE)

In [12]:
def progress_clean():
    """Overwrite the current terminal line with 80 spaces and return to its start."""
    blank_line = " " * 80
    print("\r{}".format(blank_line), end='\r')


def progress_bar(batch_num, report_interval, last_loss):
    """Redraw a 40-character bar showing how far the current report interval has advanced.

    Args:
        batch_num: 1-based index of the batch just processed.
        report_interval: number of batches per report.
        last_loss: loss value shown with four decimals.
    """
    bar_width = 40
    position = ((batch_num - 1) % report_interval) + 1
    fill = int(position / report_interval * bar_width)
    done = "=" * fill
    todo = " " * (bar_width - fill)
    print("\r[{}{}]: {} (Loss: {:.4f})".format(done, todo, batch_num, last_loss), end='')

def save_checkpoint(net, name, args, batch_num, losses, costs, seq_lengths):
    """Save the model weights and the training history next to each other.

    Two files are written, both named
    ``<args.checkpoint_path>/<name>-<args.seed>-batch-<batch_num>``: one with
    the suffix ``.model`` (the ``state_dict``) and one with the suffix
    ``.json`` (the history).

    Args:
        net: module whose ``state_dict()`` is saved.
        name: prefix of the file name.
        args: object providing ``checkpoint_path`` and ``seed``.
        batch_num: batch index embedded in the file name.
        losses: stored under the JSON key ``"loss"``.
        costs: stored under the JSON key ``"cost"``.
        seq_lengths: stored under the JSON key ``"seq_lengths"``.
    """
    progress_clean()

    basename = "{}/{}-{}-batch-{}".format(args.checkpoint_path, name, args.seed, batch_num)
    model_fname = basename + ".model"
    LOGGER.info("Saving model checkpoint to: '%s'", model_fname)
    torch.save(net.state_dict(), model_fname)

    # Save the training history
    train_fname = basename + ".json"
    LOGGER.info("Saving model training history to '%s'", train_fname)
    content = {
        "loss": losses,
        "cost": costs,
        "seq_lengths": seq_lengths
    }
    # The context manager closes the file even if serialisation fails.
    with open(train_fname, 'wt') as history_file:
        json.dump(content, history_file)


def clip_grads(net, clip_value=10.0):
    """Clamp every gradient element of ``net`` to ``[-clip_value, clip_value]`` in place.

    Args:
        net: module whose parameters' gradients are clipped.
        clip_value: symmetric bound; the default reproduces the range [-10, 10].

    Parameters without a gradient are skipped.
    """
    with torch.no_grad():
        for param in net.parameters():
            if param.grad is not None:
                param.grad.clamp_(-clip_value, clip_value)

In [13]:
import logging
import time

In [14]:
# Module-level logger used by save_checkpoint and the training loop.
# NOTE(review): no logging configuration is visible in this notebook; unless it
# is configured elsewhere, INFO records are not shown (Python's fallback handler
# only emits WARNING and above).
LOGGER = logging.getLogger(__name__)

def get_ms():
    """Return the current time in milliseconds."""
    seconds = time.time()
    return seconds * 1000

In [15]:
# Optimisation setup for the copy task.
no_batch   = 50000  # number of batches drawn from dataloader
batch_size = 1  # sequences per batch
loss       = nn.BCELoss()  # binary cross-entropy; the model output is already passed through a sigmoid
optimizer  = torch.optim.RMSprop(COPIER.parameters(), momentum=0.9, alpha=0.95, lr=1e-4)

In [16]:
# Per-batch training history, appended to by the training loop.
errors     = []  # loss value of each batch
costs      = []  # number of wrongly predicted bits per batch, divided by batch_size
seq_length = []  # target sequence length of each batch
start_ms = get_ms()  # start of the current reporting window

In [17]:
REPORT_INTERVAL = 200  # number of batches between two progress reports

# Emitted once, before training starts, rather than once per batch.
LOGGER.info("Training model for %d batches (batch_size=%d)...",
            no_batch, batch_size)

for batch_no, x, y in dataloader(no_batch=no_batch, batch_size=batch_size, seq_width=8, min_len=1, max_len=20):

    optimizer.zero_grad()
    inp_seq_len = x.size(0)
    out_seq_len = y.size(0)

    # Fresh memory and recurrent state for every sequence.
    COPIER.init_sequence(batch_size)

    # Input phase: feed the sequence and the delimiter; outputs are ignored.
    for i in range(inp_seq_len):
        COPIER(x[i])

    # Output phase: no input is given, the model has to reproduce the sequence.
    y_ = torch.stack([COPIER()[0] for _ in range(out_seq_len)])

    error = loss(y_, y)
    error.backward()
    clip_grads(COPIER)

    optimizer.step()

    # Cost = number of bits that differ after thresholding the outputs at 0.5.
    y_binarized = (y_.detach() >= 0.5).float()
    cost = torch.sum(torch.abs(y_binarized - y))

    batch_loss = error.item()
    errors.append(batch_loss)
    costs.append(cost.item() / batch_size)
    seq_length.append(out_seq_len)

    progress_bar(batch_no, REPORT_INTERVAL, batch_loss)

    # Report
    if batch_no % REPORT_INTERVAL == 0:
        mean_loss = np.array(errors[-REPORT_INTERVAL:]).mean()
        mean_cost = np.array(costs[-REPORT_INTERVAL:]).mean()
        mean_time = int(((get_ms() - start_ms) / REPORT_INTERVAL) / batch_size)
        progress_clean()
        print("Mean Time: {} ms".format(mean_time))
        print("Mean Cost: {}".format(mean_cost))
        print("Mean Loss: {}".format(mean_loss))
        print("=====================================")
        LOGGER.info("Batch %d Loss: %.6f Cost: %.2f Time: %d ms/sequence",
                    batch_no, mean_loss, mean_cost, mean_time)
        start_ms = get_ms()

#         # Checkpoint
#     if (1000 != 0) and (batch_no % 1000 == 0):
#         save_checkpoint(copier, "copier"+str(batch_no), args,
#                         batch_0, losses, costs, seq_lengths)


Mean Time: 88 ms                                                                
Mean Cost: 39.815
Mean Loss: 0.6929117912054061
Mean Time: 88 ms                                                                
Mean Cost: 37.475
Mean Loss: 0.6854926851391793
Mean Time: 82 ms                                                                
Mean Cost: 32.75
Mean Loss: 0.6770105096697807
Mean Time: 86 ms                                                                
Mean Cost: 34.53
Mean Loss: 0.6714207538962365
Mean Time: 85 ms                                                                
Mean Cost: 33.145
Mean Loss: 0.653018511235714
Mean Time: 82 ms                                                                
Mean Cost: 31.55
Mean Loss: 0.6371622154116631
Mean Time: 88 ms                                                                
Mean Cost: 34.945
Mean Loss: 0.6375371615588665
Mean Time: 87 ms                                                                
Mean Cost: 34.515
Me

# Network has successfully learnt to copy memory elements