In [1]:
import numpy as np
import torch

import matplotlib.pyplot as plt
%matplotlib inline

import EKFAC

## Testing the implementation of Kronecker-factoring for linear layers with bias

In [2]:
# Problem sizes: batch size, input width, (hidden width, not used in this
# cell) and output width.
Nbatch, D_in, H, D_out = 10, 2, 4, 1

# A single affine layer with a bias term, wrapped in a Sequential container.
linear_model = torch.nn.Sequential(torch.nn.Linear(D_in, D_out, bias=True))

# Attach the EKFAC helper to the model; its `stored_items` attribute is
# inspected further down in the notebook.
EKFAC_lin = EKFAC.EKFAC(linear_model)

# The model has exactly two parameters, in this order: weight, then bias.
W, b = linear_model.parameters()

In [3]:
# Draw a random batch of inputs and matching regression targets from a
# standard normal distribution (inputs first, then targets).
# NOTE(review): no RNG seed is set in the visible cells, so these values
# differ on every fresh run.
x, y = torch.randn(Nbatch, D_in), torch.randn(Nbatch, D_out)

In [4]:
# Forward and backward pass on the random batch.
#
# Clear any gradients left over from a previous execution of this cell:
# `backward()` accumulates into `.grad`, so re-running the cell without this
# reset would make W.grad / b.grad drift away from a single-pass gradient.
linear_model.zero_grad()

y_mod = linear_model(x)
loss = torch.nn.MSELoss()
# MSELoss is called as loss(input, target): the model prediction goes first
# and the fixed target second. The value is the same either way for MSE, but
# this order matches the documented API.
l = loss(y_mod, y)
l.backward()

In [5]:
# Display what the EKFAC helper recorded during the forward/backward pass.
# In the output below this is a mapping from layer to a dict with the keys
# 'grad_wrt_output' and 'input'.
EKFAC_lin.stored_items

{Linear(in_features=2, out_features=1, bias=True): {'grad_wrt_output': tensor([[ 2.1316],
          [ 2.1818],
          [ 0.2731],
          [-2.7090],
          [-3.4923],
          [ 0.0565],
          [ 1.2099],
          [ 0.9729],
          [-0.9527],
          [-0.0302]]), 'input': tensor([[-0.0072,  0.5359,  1.0000],
          [-0.9406,  0.8435,  1.0000],
          [-0.5916, -0.2597,  1.0000],
          [-0.4772, -0.7506,  1.0000],
          [ 1.1458,  1.3147,  1.0000],
          [ 1.7045, -0.4041,  1.0000],
          [ 0.3381,  1.1061,  1.0000],
          [ 0.3160,  0.9787,  1.0000],
          [-1.5471, -0.1625,  1.0000],
          [-0.9177,  0.8845,  1.0000]])}}

In [6]:
# Recorded 'input' of the first stored layer (the only one shown in the
# output above). In that output each row ends with a constant 1.0 column.
h = next(iter(EKFAC_lin.stored_items.values()))['input']

In [7]:
# Recorded 'grad_wrt_output' of the first stored layer (the only one shown
# in the displayed `stored_items` output).
delta = next(iter(EKFAC_lin.stored_items.values()))['grad_wrt_output']

In [8]:
# Rebuild the parameter gradient by hand from the stored quantities:
# (input^T @ grad_wrt_output) divided by the batch size. This is meant to be
# compared with the W.grad and b.grad values displayed in the next cells.
torch.matmul(h.t(), delta) / Nbatch

tensor([[-0.2623],
        [ 0.2749],
        [-0.0358]])

In [9]:
# Autograd gradient of the loss with respect to the weight, for comparison
# with the hand-computed product h.t() @ delta / Nbatch.
W.grad

tensor([[-0.2623,  0.2749]])

In [10]:
# Autograd gradient of the loss with respect to the bias, for comparison
# with the hand-computed product h.t() @ delta / Nbatch.
b.grad

tensor([-0.0358])