It recently came to my attention that everybody seems to be implementing NALUs (Neural Arithmetic Logic Units, from [this paper](https://arxiv.org/pdf/1808.00508.pdf)) in PyTorch incorrectly. When you Google "PyTorch NALU", every implementation on the first page of results (at the time of writing) contains the same mistake, which turns its NAC (neural accumulator) into a plain Linear module. In this short notebook, I'll walk through the incorrect approach and then give a correct implementation of the NAC.

In [1]:
import torch
from torch.nn.parameter import Parameter
from torch.nn import functional as F
from torch.nn import init
from torch.nn.modules import Module
import torch.optim as optim
import torch.utils.data
import numpy as np

## fauxNAC

In [2]:
class fauxNAC(Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.W_hat = Parameter(torch.Tensor(n_out, n_in))
        self.M_hat = Parameter(torch.Tensor(n_out, n_in))
        self.weights = Parameter(torch.tanh(self.W_hat) * torch.sigmoid(self.M_hat))
        self.reset_parameters()
    
    def reset_parameters(self):
        init.kaiming_uniform_(self.W_hat)
        init.kaiming_uniform_(self.M_hat)
    
    def forward(self, input):
        return F.linear(input, self.weights)

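The above is more or less what other PyTorch implementations I've seen look like, and it is **wrong**. The mistake is the line that defines `self.weights`: `weights` should not be stored as a member variable, and certainly not as a Parameter. Instead, it should be computed in `forward()` from the current values of `W_hat` and `M_hat`. Wrapping the product in `Parameter(...)` creates a new leaf tensor that is detached from `W_hat` and `M_hat`, so no gradient ever reaches them. As a result, `W_hat` and `M_hat` are never updated, and the whole module is just a bias-free Linear layer with worse memory usage. (As a side effect, `weights` is computed before `reset_parameters()` runs, so its initial values come from the uninitialised memory returned by `torch.Tensor(n_out, n_in)`.)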
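A quick way to confirm this without any training is to inspect autograd directly. The sketch below builds its own small random `W_hat` and `M_hat` (it does not use the `fauxNAC` class) and compares the two ways of producing the weights: wrapping the product in a new `Parameter`, as `fauxNAC` does, versus computing it on the fly, as a `forward()` method would.

```python
import torch
from torch.nn import functional as F
from torch.nn.parameter import Parameter

torch.manual_seed(0)
W_hat = Parameter(torch.randn(1, 2))
M_hat = Parameter(torch.randn(1, 2))
x = torch.tensor([[1.0, 2.0]])

# fauxNAC pattern: Parameter(...) detaches its input, so the result is a new
# leaf tensor with no history pointing back to W_hat and M_hat.
faux_weights = Parameter(torch.tanh(W_hat) * torch.sigmoid(M_hat))
F.linear(x, faux_weights).sum().backward()
grads_after_faux = (W_hat.grad, M_hat.grad)
print(faux_weights.is_leaf)   # True
print(faux_weights.grad)      # tensor([[1., 2.]])
print(grads_after_faux)       # (None, None): nothing reached W_hat or M_hat

# NAC pattern: compute the weights inside the graph, so gradients flow back.
weights = torch.tanh(W_hat) * torch.sigmoid(M_hat)
F.linear(x, weights).sum().backward()
print(W_hat.grad is not None, M_hat.grad is not None)  # True True
```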

Since seeing is believing, I'll show firsthand that `W_hat` and `M_hat` don't get updated at all by training `fauxNAC` on a simple addition task. If you're already convinced and just want the correct implementation, skip to the end of the notebook or take a look at the rest of this repo.

## Sum task

We start by writing a simple training loop:

In [3]:
def fit(m, dataloader, opt, crit):
    for epoch in range(100):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(dataloader):
            # get the inputs
            inputs, labels = data
            inputs = inputs.cuda().float()
            labels = labels.cuda().float()

            # zero the parameter gradients
            opt.zero_grad()

            # forward + backward + optimize
            outputs = m(inputs)
            loss = crit(outputs, labels)
            loss.backward()
            opt.step()

            # print statistics
            running_loss += loss.item()
            if i % 8 == 7 and epoch % 20 == 19:  # every 20th epoch, print the mean loss over its eight minibatches
                print('[%d] loss: %.3f' % (epoch + 1, running_loss / 8))
                running_loss = 0.0
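The loop above hard-codes `.cuda()`, so it needs a GPU to run. Below is a minimal device-agnostic sketch; the name `fit_on`, its `device`, `epochs` and `report_every` arguments, and the returned mean loss are hypothetical additions, not part of the original notebook. It assumes the model has already been moved to `device` before the optimiser was created.

```python
import torch
import torch.utils.data
from torch.nn import functional as F


def fit_on(m, dataloader, opt, crit, device="cpu", epochs=100, report_every=20):
    """Hypothetical device-agnostic variant of `fit`.

    Assumes `m` already lives on `device`. Returns the mean loss of the last epoch.
    """
    epoch_loss = float("nan")
    for epoch in range(epochs):
        running_loss, n_batches = 0.0, 0
        for inputs, labels in dataloader:
            inputs = inputs.to(device).float()
            labels = labels.to(device).float()

            opt.zero_grad()
            loss = crit(m(inputs), labels)
            loss.backward()
            opt.step()

            running_loss += loss.item()
            n_batches += 1
        # Average over however many minibatches the loader produced,
        # rather than assuming exactly eight per epoch.
        epoch_loss = running_loss / n_batches
        if (epoch + 1) % report_every == 0:
            print('[%d] loss: %.3f' % (epoch + 1, epoch_loss))
    return epoch_loss


# Usage on a small, freshly sampled sum dataset with a bias-free linear model.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
X = torch.empty(256, 2).uniform_(-5, 5)
Y = X.sum(dim=1, keepdim=True)
loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(X, Y), batch_size=64, shuffle=True)

model = torch.nn.Linear(2, 1, bias=False).to(device)
opt = torch.optim.Adam(model.parameters(), 1e-2)
final_loss = fit_on(model, loader, opt, F.mse_loss, device=device,
                    epochs=150, report_every=50)
```

Passing the device explicitly keeps the training loop unchanged whether or not a GPU is present.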

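Then we set up the training data, model, optimiser, etc. With 4096 samples and a batch size of 512, each epoch has exactly eight minibatches, which is what the `i % 8 == 7` check in `fit` assumes.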

In [4]:
data = np.random.uniform(-5, 5, (4096, 2))  # 4096 pairs (a, b), each drawn from U(-5, 5)
y = data.sum(axis=1, keepdims=True)          # target a + b, shape (4096, 1)

dataset = torch.utils.data.TensorDataset(torch.Tensor(data), torch.Tensor(y))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=512, shuffle=True, num_workers=0)

m = fauxNAC(2, 1).cuda()
opt = optim.Adam(m.parameters(), 1e-2)
crit = F.mse_loss

Let's check the values of our parameters `M_hat` and `W_hat`.

In [5]:
m.M_hat

Parameter containing:
tensor([[-0.3000, -1.3239]], device='cuda:0', requires_grad=True)

In [6]:
m.W_hat

Parameter containing:
tensor([[ 0.7358, -1.4481]], device='cuda:0', requires_grad=True)

Let's train and see what happens to those parameters:

In [7]:
fit(m, dataloader, opt, crit)

[20] loss: 0.062
[40] loss: 0.000
[60] loss: 0.000
[80] loss: 0.000
[100] loss: 0.000


In [8]:
m.M_hat

Parameter containing:
tensor([[-0.3000, -1.3239]], device='cuda:0', requires_grad=True)

In [9]:
m.W_hat

Parameter containing:
tensor([[ 0.7358, -1.4481]], device='cuda:0', requires_grad=True)

As expected, they haven't changed. The only Parameter that did change is `weights`, which makes `fauxNAC` nothing more than a linear module:

In [10]:
m.weights

Parameter containing:
tensor([[1.0000, 1.0000]], device='cuda:0', requires_grad=True)
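For comparison, a plain `nn.Linear(2, 1, bias=False)` trained on the same kind of sum data lands on the same weights, `[1, 1]`. This is a CPU sketch with its own freshly sampled data, not a rerun of the cells above, and it uses full-batch SGD instead of Adam only because its convergence on this noise-free least-squares problem is easy to reason about.

```python
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(0)
# Same task as above: y = a + b with a, b ~ U(-5, 5).
X = torch.empty(4096, 2).uniform_(-5, 5)
Y = X.sum(dim=1, keepdim=True)

linear = nn.Linear(2, 1, bias=False)
opt = torch.optim.SGD(linear.parameters(), lr=0.05)
for step in range(200):
    opt.zero_grad()
    loss = F.mse_loss(linear(X), Y)
    loss.backward()
    opt.step()

print(linear.weight)  # approximately [[1., 1.]], the same solution fauxNAC found
```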

## Correct implementation

Next, here is a correct implementation of NAC: only `W_hat` and `M_hat` are Parameters, and the weights are computed from them in the `forward()` method.

In [11]:
class NAC(Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.W_hat = Parameter(torch.Tensor(n_out, n_in))
        self.M_hat = Parameter(torch.Tensor(n_out, n_in))
        self.reset_parameters()
    
    def reset_parameters(self):
        init.kaiming_uniform_(self.W_hat)
        init.kaiming_uniform_(self.M_hat)
    
    def forward(self, input):
        # Recompute the weights on every call so gradients flow back to W_hat and M_hat.
        weights = torch.tanh(self.W_hat) * torch.sigmoid(self.M_hat)
        return F.linear(input, weights)
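As a sanity check of the fixed module, the sketch below repeats the `NAC` class so it runs on its own, takes a few hundred full-batch Adam steps on freshly sampled sum data on the CPU, and confirms that `W_hat` and `M_hat` now move and that the loss goes down. It also shows the constraint the NAC is meant to enforce: the effective weights `tanh(W_hat) * sigmoid(M_hat)` stay inside (-1, 1). The step count, learning rate and seed are arbitrary choices for this sketch.

```python
import torch
from torch.nn import functional as F, init
from torch.nn.modules import Module
from torch.nn.parameter import Parameter


class NAC(Module):
    # Same NAC as above, repeated so this snippet is self-contained.
    def __init__(self, n_in, n_out):
        super().__init__()
        self.W_hat = Parameter(torch.Tensor(n_out, n_in))
        self.M_hat = Parameter(torch.Tensor(n_out, n_in))
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.W_hat)
        init.kaiming_uniform_(self.M_hat)

    def forward(self, input):
        weights = torch.tanh(self.W_hat) * torch.sigmoid(self.M_hat)
        return F.linear(input, weights)


torch.manual_seed(0)
X = torch.empty(4096, 2).uniform_(-5, 5)
Y = X.sum(dim=1, keepdim=True)

m = NAC(2, 1)
W_hat_before = m.W_hat.detach().clone()
M_hat_before = m.M_hat.detach().clone()
opt = torch.optim.Adam(m.parameters(), 1e-2)

initial_loss = F.mse_loss(m(X), Y).item()
for step in range(500):
    opt.zero_grad()
    loss = F.mse_loss(m(X), Y)
    loss.backward()
    opt.step()
final_loss = F.mse_loss(m(X), Y).item()

weights = torch.tanh(m.W_hat) * torch.sigmoid(m.M_hat)
print(torch.equal(W_hat_before, m.W_hat.detach()),
      torch.equal(M_hat_before, m.M_hat.detach()))  # False False
print(initial_loss, final_loss)
print(weights)  # entries stay inside (-1, 1)
```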