In [1]:
import numpy as np
import torch

import matplotlib.pyplot as plt
%matplotlib inline

import EKFAC

## Testing the gradient passed to the backward hook function

In the EKFAC code, we register a backward hook on every **linear** layer, and the hook is called every time ``loss.backward()`` is called. One of the arguments passed to the hook (which our hook saves) is the gradient of the loss function with respect to the **output** of the linear layer. In this notebook, we verify that this argument matches analytic calculations for some simple networks.
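The hook mechanism can be illustrated in plain PyTorch without the ``EKFAC`` module. This is a minimal sketch that assumes PyTorch 1.8 or later (for ``register_full_backward_hook``). The names ``saved`` and ``save_grad_output`` are hypothetical and not taken from the EKFAC code. The hook's ``grad_output`` argument is a tuple whose first element is the gradient of the loss with respect to the layer's output. The sketch checks that element against the TSE gradient $-2(y_i - y^{mod}_i)$.

```python
import torch

torch.manual_seed(0)

N_batch, D_in, D_out = 10, 2, 1
layer = torch.nn.Linear(D_in, D_out, bias=False)

# Hypothetical storage for what the hook sees (not the EKFAC module's own structure).
saved = {}


def save_grad_output(module, grad_input, grad_output):
    # grad_output[0] is dLoss/d(layer output), shape (N_batch, D_out).
    saved["grad_wrt_output"] = grad_output[0].detach().clone()
    # Returning None leaves the gradients flowing backward unchanged.
    return None


# Full backward hooks must be registered before the forward pass.
handle = layer.register_full_backward_hook(save_grad_output)

# requires_grad=True so the hook also receives an input gradient; only grad_output is used.
x = torch.randn(N_batch, D_in, requires_grad=True)
y = torch.randn(N_batch, D_out)
y_mod = layer(x)
loss = torch.nn.MSELoss(reduction="sum")(y_mod, y)  # total squared error
loss.backward()
handle.remove()

# For the total squared error, dL/dy_mod_i = -2 (y_i - y_mod_i).
analytic = -2 * (y - y_mod.detach())
print(torch.allclose(saved["grad_wrt_output"], analytic))
```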

### Single-layer linear network with a single output, and no bias

The first network we test is a linear network with an arbitrary input dimension and a one-dimensional output. In other words, it takes the dot product of the input with a weight vector.

That is, 
\begin{equation}
y = \vec{W}\cdot\vec{x}
\end{equation}

We will use two different loss functions: the mean squared error (MSE) and the total squared error (TSE). These take the mean or the sum, respectively, of the squared errors over all inputs in the mini-batch. If the data in a mini-batch of size $N_b$ is denoted $\{ ( \vec{x}_i, y_i ) \}$, then the MSE loss is given by
\begin{equation}
L_{MSE} = \frac{1}{N_b}\sum_i (y_i - \vec{W} \cdot \vec{x}_i)^2 = \frac{1}{N_b}\sum_i (y_i - y^{mod}_i)^2
\end{equation}
and the total squared loss is given by 
\begin{equation}
L_{TSE} = \sum_i (y_i - \vec{W} \cdot \vec{x}_i)^2 = \sum_i (y_i - y^{mod}_i)^2,
\end{equation}
where $y^{mod}_i$ is defined via $y^{mod}_i = y(\vec{x}_i)$.  
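As a quick numerical illustration of these two definitions, here is a NumPy sketch. The function names ``mse_loss`` and ``tse_loss`` are ours, for illustration only. With a one-dimensional output, the two losses differ exactly by the factor $N_b$.

```python
import numpy as np

rng = np.random.default_rng(0)
N_b, D_in = 10, 2

W = rng.standard_normal(D_in)          # weight vector
x = rng.standard_normal((N_b, D_in))   # mini-batch inputs x_i, one per row
y = rng.standard_normal(N_b)           # targets y_i


def mse_loss(W, x, y):
    y_mod = x @ W                      # y_mod_i = W . x_i
    return np.mean((y - y_mod) ** 2)   # (1/N_b) * sum_i (y_i - y_mod_i)^2


def tse_loss(W, x, y):
    y_mod = x @ W
    return np.sum((y - y_mod) ** 2)    # sum_i (y_i - y_mod_i)^2


print(np.isclose(tse_loss(W, x, y), N_b * mse_loss(W, x, y)))
```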

Taking gradients of each loss function with respect to the outputs $y^{mod}_i$, we get
\begin{equation}
\frac{\partial L_{MSE}}{\partial y^{mod}_i} = -\frac{2}{N_b} (y_i - y^{mod}_i) 
\end{equation}
and
\begin{equation}
\frac{\partial L_{TSE}}{\partial y^{mod}_i} = -2 (y_i - y^{mod}_i) 
\end{equation}

The only difference between the two is the factor of $1/N_b$ in the MSE gradient. It comes directly from the $1/N_b$ in the definition of $L_{MSE}$.
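These analytic gradients can be sanity-checked independently of PyTorch with central finite differences. The NumPy sketch below treats the outputs $y^{mod}_i$ as free variables. The helper ``central_difference`` is a hypothetical name introduced here. For a quadratic loss, the central difference is exact up to rounding, so a tight tolerance is safe.

```python
import numpy as np

rng = np.random.default_rng(1)
N_b = 10
y = rng.standard_normal(N_b)       # targets y_i
y_mod = rng.standard_normal(N_b)   # model outputs y_mod_i, treated as free variables

losses = {
    "MSE": lambda y_mod: np.mean((y - y_mod) ** 2),
    "TSE": lambda y_mod: np.sum((y - y_mod) ** 2),
}
analytic = {
    "MSE": -2.0 / N_b * (y - y_mod),
    "TSE": -2.0 * (y - y_mod),
}


def central_difference(f, v, eps=1e-6):
    """Numerical gradient of the scalar function f at v, one component at a time."""
    grad = np.zeros_like(v)
    for i in range(v.size):
        step = np.zeros_like(v)
        step[i] = eps
        grad[i] = (f(v + step) - f(v - step)) / (2 * eps)
    return grad


numeric = {name: central_difference(f, y_mod) for name, f in losses.items()}
for name in losses:
    print(name, np.allclose(numeric[name], analytic[name], atol=1e-6))

# The two analytic gradients differ by exactly the factor N_b.
print(np.allclose(analytic["TSE"], N_b * analytic["MSE"]))
```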

In [65]:
Nbatch = 10
D_in = 2
D_out = 1

loss_type = 'TSE'

linear_model = torch.nn.Sequential(
    torch.nn.Linear(D_in, D_out, bias=False),
)

W = list(linear_model.parameters())[0]  # weight tensor, shape (D_out, D_in)

In [66]:
EKFAC_lin = EKFAC.EKFAC(linear_model)

In [67]:
# Create random Tensors to hold inputs and outputs
x = torch.randn(Nbatch, D_in)
y = torch.randn(Nbatch, D_out)

In [68]:
y_mod = linear_model(x)

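# Mean ('MSE') or sum ('TSE') of squared errors; with D_out = 1, 'mean' divides by N_batch
loss_functions = {'MSE': torch.nn.MSELoss(reduction='mean'),
                  'TSE': torch.nn.MSELoss(reduction='sum')}

# Analytic gradients dL/dy_mod_i for each loss, as derived above
loss_gradients = {'MSE': lambda y, y_mod, N_batch: -2/N_batch*(y-y_mod),
                  'TSE': lambda y, y_mod, N_batch: -2*(y-y_mod)}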

In [69]:
loss_fun = loss_functions[loss_type]
loss_grad = loss_gradients[loss_type]

# MSELoss takes (input, target); the model output is the input
l = loss_fun(y_mod, y)
l.backward()

Now we can compare the stored items in the ``EKFAC`` object to the gradients we calculated analytically.

In [70]:
with torch.no_grad():
    gradient_saved_by_hook = EKFAC_lin.stored_items[list(EKFAC_lin.stored_items.keys())[0]]['grad_wrt_output']
    gradient_analytic = loss_grad(y, y_mod, Nbatch)

In [71]:
print('Gradient saved by hook: {}'.format(gradient_saved_by_hook.t()))
print('Analytically calculated gradient: {}'.format(gradient_analytic.t()))

Gradient saved by hook: tensor([[-3.6643e+01, -1.3093e+01, -4.2351e+01,  1.6123e-02, -5.3818e+00,
          9.6941e+00, -3.9953e+00,  2.3490e+01, -5.5934e+00, -2.1213e+01]])
Analytically calculated gradient: tensor([[-3.6643e+00, -1.3093e+00, -4.2351e+00,  1.6123e-03, -5.3818e-01,
          9.6941e-01, -3.9953e-01,  2.3490e+00, -5.5934e-01, -2.1213e+00]])


Here, we can see that ``gradient_analytic`` is a factor of 10 smaller than ``gradient_saved_by_hook``. That factor is the batch size, ``Nbatch``. The difference arises because the hook multiplies the gradient passed to it (``gradient_passed_to_hook``) by ``batch_size`` before saving it. The same factor appears whether ``loss_type`` is set to MSE or TSE.
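The rescaling can be mimicked in a few lines of plain PyTorch. This is a hypothetical sketch, not the ``EKFAC`` module's code, and it assumes PyTorch 1.8 or later. The hook ``save_scaled_grad`` multiplies ``grad_output[0]`` by the batch size, read off its first dimension, before storing it. The stored tensor should then be ``N_batch`` times the analytic gradient for both losses.

```python
import torch

torch.manual_seed(0)
N_batch, D_in, D_out = 10, 2, 1
layer = torch.nn.Linear(D_in, D_out, bias=False)
saved = {}  # hypothetical storage, not the EKFAC module's stored_items


def save_scaled_grad(module, grad_input, grad_output):
    g = grad_output[0].detach()
    # Rescale by the batch size (first dimension) before saving, mirroring the
    # behaviour the text attributes to the EKFAC hook.
    saved["grad_wrt_output"] = g * g.size(0)


layer.register_full_backward_hook(save_scaled_grad)

x = torch.randn(N_batch, D_in, requires_grad=True)
y = torch.randn(N_batch, D_out)

reductions = {"MSE": "mean", "TSE": "sum"}
analytic_grads = {
    "MSE": lambda y, y_mod: -2 / N_batch * (y - y_mod),
    "TSE": lambda y, y_mod: -2 * (y - y_mod),
}

ok = {}
for name, reduction in reductions.items():
    y_mod = layer(x)
    torch.nn.MSELoss(reduction=reduction)(y_mod, y).backward()
    analytic = analytic_grads[name](y, y_mod.detach())
    # Up to rounding, the saved gradient should be N_batch times the analytic one.
    ok[name] = torch.allclose(saved["grad_wrt_output"], N_batch * analytic, atol=1e-6)
    print(name, ok[name])
```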

Now, from the EKFAC paper, we know the following. For a linear layer, define $h$ as the input to that layer and $\delta$ as the derivative of the loss function with respect to that layer's output. Then the gradient of the loss function with respect to the layer's weight is given by $h\delta^T$. Let's check whether this holds here.
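For a mini-batch, the per-example outer products $h_i\delta_i^T$ are summed. If the $h_i$ and $\delta_i$ are stacked as rows of matrices, that sum becomes a single matrix product. The NumPy sketch below is our own illustration, using the TSE loss with finite differences as an independent reference. It checks that $\sum_i h_i\delta_i^T$ equals the weight gradient. The weight is stored in the same ``(D_out, D_in)`` layout as ``torch.nn.Linear.weight``, so a transpose is needed for the comparison.

```python
import numpy as np

rng = np.random.default_rng(2)
N_b, D_in, D_out = 10, 2, 1

W = rng.standard_normal((D_out, D_in))   # same layout as torch.nn.Linear.weight
h = rng.standard_normal((N_b, D_in))     # layer inputs h_i, one per row
y = rng.standard_normal((N_b, D_out))    # targets


def tse(W):
    y_mod = h @ W.T                      # layer outputs, shape (N_b, D_out)
    return np.sum((y - y_mod) ** 2)


# delta_i = dL/dy_mod_i for the TSE loss, one row per example.
delta = -2.0 * (y - h @ W.T)

# Sum over the batch of outer products h_i delta_i^T, shape (D_in, D_out) ...
sum_outer = sum(np.outer(h[i], delta[i]) for i in range(N_b))
# ... which is the same as a single matrix product.
assert np.allclose(sum_outer, h.T @ delta)

# Central-difference gradient of the loss with respect to W, shape (D_out, D_in).
eps = 1e-6
numeric = np.zeros_like(W)
for idx in np.ndindex(*W.shape):
    step = np.zeros_like(W)
    step[idx] = eps
    numeric[idx] = (tse(W + step) - tse(W - step)) / (2 * eps)

# In this layout, h delta^T is the transpose of dL/dW.
print(np.allclose(sum_outer.T, numeric, atol=1e-5))
```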

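In the case of a single linear layer, $h$ is the network input $x$. $\delta$ is the gradient with respect to the layer output, which we take from ``gradient_saved_by_hook``, i.e. the analytic gradient above scaled by the batch size. To average over the mini-batch, we then divide by the batch size, ``x.size(0)``. Note that ``torch.nn.Linear`` stores its weight with shape ``(D_out, D_in)``, so $h\delta^T$ should match the transpose of ``W.grad``.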

In [76]:
# Batch average of h_i delta_i^T; shape (D_in, D_out), i.e. transposed relative to W.grad
test_weight_grad = x.t() @ gradient_saved_by_hook / x.size(0)

In [77]:
print('W.grad: {}'.format(W.grad))
print('test W grad: {}'.format(test_weight_grad))

W.grad: tensor([[10.7965, -1.1092]])
test W grad: tensor([[10.7965],
        [-1.1092]])
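The same comparison can be reproduced end to end without the ``EKFAC`` module. In this hypothetical sketch, which assumes PyTorch 1.8 or later, a hook stores ``N_batch`` times the gradient with respect to the layer output, following the scaling described earlier. The batch-averaged $h\delta^T$ is then compared with the transpose of the weight gradient for both the ``'mean'`` (MSE) and ``'sum'`` (TSE) reductions.

```python
import torch

torch.manual_seed(0)
N_batch, D_in, D_out = 10, 2, 1
model = torch.nn.Sequential(torch.nn.Linear(D_in, D_out, bias=False))
layer = model[0]
saved = {}


def save_scaled_grad(module, grad_input, grad_output):
    # Hypothetical hook: store N_batch * dL/d(output), as described in the text.
    g = grad_output[0].detach()
    saved["grad_wrt_output"] = g * g.size(0)


layer.register_full_backward_hook(save_scaled_grad)

x = torch.randn(N_batch, D_in, requires_grad=True)
y = torch.randn(N_batch, D_out)

results = {}
for reduction in ("mean", "sum"):              # MSE and TSE
    layer.weight.grad = None                   # reset the accumulated gradient
    torch.nn.MSELoss(reduction=reduction)(model(x), y).backward()
    delta = saved["grad_wrt_output"]           # shape (N_batch, D_out)
    # Batch average of h_i delta_i^T, shape (D_in, D_out).
    test_weight_grad = x.detach().t() @ delta / x.size(0)
    # layer.weight.grad has shape (D_out, D_in), so compare against the transpose.
    results[reduction] = torch.allclose(test_weight_grad.t(), layer.weight.grad, atol=1e-6)
    print(reduction, results[reduction])
```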


From this, we can see why we had to multiply by the batch size earlier, that is, why ``gradient_saved_by_hook`` needs to be a factor of ``batch_size`` larger than the true gradient of the loss function. Apart from the transpose, ``test_weight_grad`` matches ``W.grad``. Without the factor, the batch average of $h\delta^T$ would come out $N_b$ times too small. The reason is as follows. What is passed to the hook (``gradient_passed_to_hook``) is the gradient of the total loss function, $L$, with respect to the layer's output. But when we pass a mini-batch of inputs, what we actually want to store is the gradient of the per-example function that is averaged to get $L$. So if we write $L = \frac{1}{N_b} \sum_i L_i$, where each $L_i$ depends only on example $i$, we want to store the gradient of $L_i$, not the gradient of $L$. Since $\partial L / \partial y^{mod}_i = \frac{1}{N_b}\, \partial L_i / \partial y^{mod}_i$, this is why we multiply by $N_b$ before saving. For the TSE loss, the same decomposition holds with $L_i = N_b (y_i - y^{mod}_i)^2$.
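Written out with the chain rule, the batch-averaged product reproduces the true weight gradient because of the factor $N_b$. Here $\delta_i = \partial L/\partial y^{mod}_i$ and $L = \frac{1}{N_b}\sum_i L_i$, and the left-hand side is the transpose of ``W.grad``:
\begin{equation}
\sum_i h_i \delta_i^T = \frac{1}{N_b} \sum_i h_i \left( N_b\, \delta_i \right)^T = \frac{1}{N_b} \sum_i h_i \left( \frac{\partial L_i}{\partial y^{mod}_i} \right)^T
\end{equation}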