In [1]:
import numpy as np
import torch

import matplotlib.pyplot as plt
%matplotlib inline

import EKFAC

## Testing gradient passed to the backward hook function

In the EKFAC code, we register a backward hook on every **linear** layer; the hook runs each time ``loss.backward()`` is called. One of the arguments passed to the hook (which our hook saves) is the gradient of the loss function with respect to the **output** of the linear layer. In this notebook, we check that this argument matches analytic calculations for some simple networks.
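As a rough illustration of the mechanism described above, here is a minimal sketch of a backward hook that saves the gradient with respect to a layer's output. The `stored` dictionary and the `save_grad_wrt_output` function are hypothetical names for this sketch; the actual `EKFAC` class may organize its hook and storage differently.

```python
import torch

# Hypothetical storage for this sketch, keyed by module.
stored = {}

def save_grad_wrt_output(module, grad_input, grad_output):
    # grad_output is a tuple with one entry per module output;
    # a Linear layer returns a single tensor, so we keep entry 0.
    stored[module] = grad_output[0].detach().clone()

layer = torch.nn.Linear(3, 1, bias=False)
handle = layer.register_full_backward_hook(save_grad_wrt_output)

# The input is tracked by autograd so that gradients flow through the layer.
x = torch.randn(4, 3, requires_grad=True)
loss = layer(x).sum()
loss.backward()  # triggers the hook

# For loss = sum of outputs, d(loss)/d(output_i) = 1 for every sample.
print(stored[layer])

handle.remove()  # detach the hook once it is no longer needed
```

Using the sum of the outputs as the loss makes the expected result easy to read off: every entry of the saved gradient should be 1.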

### Single-layer linear network with a single output, and no bias

The first network we test is simply a linear network with an arbitrary input dimension but a one-dimensional output, i.e. taking the dot product of the input with a weights vector.  

That is, 
\begin{equation}
y = \vec{W}\cdot\vec{x}
\end{equation}

We will use two different loss functions: the mean squared error (MSE) or the total squared error (TSE), where the mean or sum, respectively, is taken over all $N_b$ inputs in the mini-batch. If the data in a mini-batch is denoted $\{ ( \vec{x}_i, y_i ) \}$, then the MSE loss is given by 
\begin{equation}
L_{MSE} = \frac{1}{N_b}\sum_i (y_i - \vec{W} \cdot \vec{x}_i)^2 = \frac{1}{N_b}\sum_i (y_i - y^{mod}_i)^2
\end{equation}
and the total squared loss is given by 
\begin{equation}
L_{TSE} = \sum_i (y_i - \vec{W} \cdot \vec{x}_i)^2 = \sum_i (y_i - y^{mod}_i)^2,
\end{equation}
where $y^{mod}_i$ is defined via $y^{mod}_i = y(\vec{x}_i)$.  

Taking gradients of the loss function with respect to the outputs $y^{mod}_i$ gives, for each loss function,
\begin{equation}
\frac{\partial L_{MSE}}{\partial y^{mod}_i} = -\frac{2}{N_b} (y_i - y^{mod}_i) 
\end{equation}
and
\begin{equation}
\frac{\partial L_{TSE}}{\partial y^{mod}_i} = -2 (y_i - y^{mod}_i) 
\end{equation}

The only difference between the two is the factor of $1/N_b$, which comes from the averaging in the definition of $L_{MSE}$.
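The two gradient formulas above can be checked directly against autograd, independently of any hook. The following is a minimal sketch; the helper `grad_wrt_output` and the explicit weight tensor `W` are assumptions made for this illustration and are not part of the EKFAC code.

```python
import torch

torch.manual_seed(0)
N_b, D_in = 10, 2
x = torch.randn(N_b, D_in)
y = torch.randn(N_b, 1)
W = torch.randn(D_in, 1, requires_grad=True)

def grad_wrt_output(reduction):
    """Return (dL/dy_mod, y_mod) for the given MSE reduction."""
    y_mod = x @ W
    y_mod.retain_grad()  # keep the gradient of this non-leaf tensor
    loss = torch.nn.functional.mse_loss(y_mod, y, reduction=reduction)
    loss.backward()
    return y_mod.grad, y_mod.detach()

g_mse, y_mod = grad_wrt_output('mean')
g_tse, _ = grad_wrt_output('sum')

# Compare with the analytic expressions derived above.
print(torch.allclose(g_mse, -2 / N_b * (y - y_mod), atol=1e-6))
print(torch.allclose(g_tse, -2 * (y - y_mod), atol=1e-6))
```

Note that `reduction='mean'` averages over all elements of the output, so it coincides with division by $N_b$ only because the output here is one-dimensional.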

In [29]:
Nbatch = 10
D_in = 2
D_out = 1

linear_model = torch.nn.Sequential(
    torch.nn.Linear(D_in, D_out, bias=False),
)

W = list(linear_model.parameters())[0]

In [30]:
EKFAC_lin = EKFAC.EKFAC(linear_model)

In [31]:
# Create random Tensors to hold inputs and outputs
x = torch.randn(Nbatch, D_in)
y = torch.randn(Nbatch, D_out)

In [32]:
y_mod = linear_model(x)

loss_functions = {
    'MSE': torch.nn.MSELoss(reduction='mean'),
    'TSE': torch.nn.MSELoss(reduction='sum'),
}

# Analytic gradients of each loss with respect to the model outputs y_mod
loss_gradients = {
    'MSE': lambda y, y_mod, N_batch: -2 / N_batch * (y - y_mod),
    'TSE': lambda y, y_mod, N_batch: -2 * (y - y_mod),
}

In [33]:
loss_type = 'TSE'

loss_fun = loss_functions[loss_type]
loss_grad = loss_gradients[loss_type]

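# MSELoss expects (input, target); the squared error is symmetric in the two
l = loss_fun(y_mod, y)
l.backward()  # triggers the backward hook registered by EKFAC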

Now we can compare the stored items in the ``EKFAC`` object to the gradients we calculated analytically.

In [34]:
# The network has a single linear layer, so take the first stored entry
first_layer_items = next(iter(EKFAC_lin.stored_items.values()))
gradient_passed_to_hook = first_layer_items['grad_wrt_output']
gradient_analytic = loss_grad(y, y_mod, Nbatch)

In [35]:
gradient_passed_to_hook

tensor([[  0.7901],
        [-14.3245],
        [-15.9768],
        [ 17.5171],
        [-16.2203],
        [ 10.7378],
        [ -6.6827],
        [  9.1003],
        [ 10.4937],
        [ -0.4674]])

In [36]:
gradient_analytic

tensor([[ 0.0790],
        [-1.4325],
        [-1.5977],
        [ 1.7517],
        [-1.6220],
        [ 1.0738],
        [-0.6683],
        [ 0.9100],
        [ 1.0494],
        [-0.0467]], grad_fn=<MulBackward0>)
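
Rather than comparing the two printed tensors by eye, the comparison can be made numerically. The sketch below copies the values printed above into two tensors (`hook` and `analytic` are names chosen for this illustration) and looks at their elementwise ratio. In the notebook itself, one would pass `gradient_passed_to_hook` and `gradient_analytic.detach()` instead.

```python
import torch

# Values copied from the two outputs printed above.
hook = torch.tensor([0.7901, -14.3245, -15.9768, 17.5171, -16.2203,
                     10.7378, -6.6827, 9.1003, 10.4937, -0.4674]).unsqueeze(1)
analytic = torch.tensor([0.0790, -1.4325, -1.5977, 1.7517, -1.6220,
                         1.0738, -0.6683, 0.9100, 1.0494, -0.0467]).unsqueeze(1)

# Direct equality check between the two gradients.
print(torch.allclose(hook, analytic))

# Elementwise ratio: a constant ratio indicates a pure scale factor.
ratio = hook / analytic
print(ratio.squeeze())

# Check whether the two agree up to a factor of Nbatch = 10
# (loose tolerance because the printed values are rounded).
print(torch.allclose(hook, 10 * analytic, rtol=1e-2))
```

For the values printed here, the two tensors are not equal but agree up to a constant factor of about 10, which is the value of `Nbatch`. The sketch only measures that factor; it does not identify where it comes from.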