In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import Subset
import torch.nn as nn

import matplotlib.pyplot as plt

# 1. define your dataloader

In [2]:
# Preprocessing: PIL image -> float tensor, then normalize with fixed statistics.
# NOTE(review): 0.131 / 0.289 are hard-coded; presumably precomputed MNIST pixel
# mean/std — confirm (the commonly quoted MNIST std is 0.3081).
transform_list = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.131], std=[0.289]),
])

# Keep only the first 2000 training images (presumably to keep the FIM cheap).
dataset = Subset(
    datasets.MNIST(root='/tmp/', train=True, download=True, transform=transform_list),
    range(2000),
)

# Fixed order (no shuffling), 500 examples per batch -> 4 batches.
loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=500, shuffle=False)

# 2. define your model

In [3]:
class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) one into a single axis.

    An input of shape ``(N, d1, d2, ...)`` comes out as ``(N, d1 * d2 * ...)``;
    in this notebook it turns image tensors into flat vectors for ``nn.Linear``.
    """

    def forward(self, input):
        return torch.flatten(input, start_dim=1)

# Architecture: `n_hidden` hidden layers of width `hidden_size`, then 10 classes.
n_hidden = 2
hidden_size = 10
# Fall back to the CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed the global torch RNG so weight initialisation (and the later torch
# random draws in this notebook) are repeatable under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

layers = [Flatten(), nn.Linear(28 * 28, hidden_size), nn.ReLU()]
for _ in range(n_hidden - 1):
    # Build fresh modules on every iteration. Note that `[nn.Linear(...), nn.ReLU()] * k`
    # would repeat the *same* Linear object k times, silently tying the weights of
    # all hidden layers (and under-counting parameters) whenever n_hidden > 2.
    layers += [nn.Linear(hidden_size, hidden_size), nn.ReLU()]
# Output layer: log-probabilities over the 10 classes.
layers += [nn.Linear(hidden_size, 10), nn.LogSoftmax(dim=1)]
model = nn.Sequential(*layers).to(device)

# 3. define your loss function
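In this example we compute a Monte Carlo estimate of the Fisher Information Matrix (FIM) using a single sample per example: for each input, one label is drawn from the model's own predictive distribution $p(y \mid x)$ (the true target is not used).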

In [4]:
def loss_fim_mc_estimate(input, target):
    """Log-likelihood of one label sampled from the model's own predictive distribution.

    Used for a 1-sample Monte Carlo estimate of the Fisher Information Matrix:
    the label is drawn from the global ``model``'s output distribution instead
    of being taken from ``target``.

    Args:
        input: batch of inputs passed to the global ``model``.
        target: ground-truth labels; not used (kept so the signature matches the
            ``loss_function`` callback — presumably called as ``(input, target)``, verify).

    Returns:
        Tensor of shape ``(batch, 1)`` with ``log p(y_sampled | x)`` per example.
    """
    log_probs = model(input)
    # One sampled class index per row, drawn proportionally to exp(log_probs).
    sampled_labels = torch.multinomial(log_probs.exp(), num_samples=1)
    return log_probs.gather(1, sampled_labels)

# 4. create your generator

In [5]:
from nngeometry.pspace import M2Gradients

# Generator that runs `loss_fim_mc_estimate` on `model` over every batch of `loader`;
# the KFAC representation below is built from it.
m2_generator = M2Gradients(
    model=model,
    dataloader=loader,
    loss_function=loss_fim_mc_estimate,
)
n_parameters = m2_generator.get_n_parameters()
print(f'{n_parameters} parameters')

8070 parameters


# 5. KFACMatrix representation

We now compute the KFAC (Kronecker-factored approximate curvature) representation of the FIM using the generator defined above.

In [6]:
from nngeometry.representations import KFACMatrix
from nngeometry.vector import Vector
from nngeometry.vector import from_model

# KFAC representation of the FIM, built from the MC-sampled generator above.
F_kfac = KFACMatrix(m2_generator)

# Reference point in parameter space, built from `model`
# (presumably its current parameter values — confirm whether this copies or aliases).
v0 = from_model(model)

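Draw a random parameter vector `v`, tracked by autograd so we can differentiate with respect to it, and wrap it in an nngeometry `Vector` associated with `model`.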

In [7]:
# Random point in parameter space (uniform in [0, 1)), created as a leaf tensor
# that tracks gradients so we can differentiate with respect to it later.
v = torch.rand(n_parameters, device=device, requires_grad=True)

# Wrap the flat tensor as an nngeometry Vector associated with `model`.
v_vec = Vector(model=model, vector_repr=v)

In [8]:
# Presumably the quadratic form (v - v0)^T F (v - v0) under the KFAC approximation.
displacement = v_vec - v0
regularizer = F_kfac.vTMv(displacement)

You can now backpropagate through the regularizer, for example to get its gradient with respect to `v`:

In [9]:
# Gradient of the regularizer with respect to the raw parameter tensor `v`
# (a single input, so autograd.grad returns a 1-tuple).
(g,) = torch.autograd.grad(regularizer, v)
g.size(), g

(torch.Size([8070]),
 tensor([ 0.0271,  0.0271,  0.0271,  ..., -0.1660,  0.1082, -0.1702],
        device='cuda:0'))