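We compare the cost of common operations (construction, matrix-vector products, trace, solve, ...) across different FIM representations, on two dataset/architecture pairs: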

 - A small convnet with 1,001 parameters, on the full MNIST training set (60,000 examples)
 - A ResNet50 with 23,467,722 parameters, on the first 100 examples of the CIFAR10 training set
 
Because the convnet has so few parameters, we can also visualize its FIMs as dense matrices.

In [None]:
import torch
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import Subset, DataLoader
import torch.nn.functional as tF
import torch.nn as nn
import time

import matplotlib.pyplot as plt


from nngeometry.layercollection import LayerCollection
from nngeometry.metrics import FIM_MonteCarlo
from nngeometry.object.vector import random_pvector

from nngeometry.object import PMatDiag, PMatKFAC, PMatEKFAC, PMatLowRank, PMatBlockDiag, PMatQuasiDiag, PMatImplicit, PVector, PMatDense

# Small convnet on MNIST

In [None]:
class ConvNet(nn.Module):
    """Tiny 3-conv + 1-linear classifier for single-channel 28x28 images.

    Each convolution is followed by a ReLU and a 2x2 max-pool. Together they
    shrink a 28x28 input down to a 1x1x7 feature map, which a linear layer
    maps to 10 outputs. ``forward`` returns log-probabilities of shape
    (batch, 10).
    """

    def __init__(self):
        super().__init__()
        # Attribute names (and creation order) define the parameter names and
        # the order in which random initialisation is drawn; keep them stable.
        self.conv1 = nn.Conv2d(1, 5, 3, 1)
        self.conv2 = nn.Conv2d(5, 6, 4, 1)
        self.conv3 = nn.Conv2d(6, 7, 3, 1)
        self.fc1 = nn.Linear(7, 10)

    def forward(self, x):
        # conv -> ReLU -> 2x2 max-pool, three times: 28 -> 13 -> 5 -> 1 spatially.
        for conv in (self.conv1, self.conv2, self.conv3):
            x = tF.max_pool2d(tF.relu(conv(x)), 2, 2)
        flat = x.view(-1, 7)
        logits = self.fc1(flat)
        return tF.log_softmax(logits, dim=1)

# Full MNIST training set, iterated in a fixed order (no shuffling) in
# batches of 1000.
trainset = datasets.MNIST(root='/tmp/', train=True, download=True,
                          transform=transforms.ToTensor())
trainloader = DataLoader(
    dataset=trainset,
    batch_size=1000,
    shuffle=False)

convnet = ConvNet().to('cuda')

# Build the LayerCollection once and reuse it for the random probe vector
# (used below for mv / vTMv / solve) instead of constructing a second,
# identical collection.
layer_collection = LayerCollection.from_model(convnet)
v = random_pvector(layer_collection, device='cuda')

# Number of parameters covered by the layer collection (cell output).
layer_collection.numel()

In [None]:
def compute_correlation(M):
    """Return a damped absolute correlation matrix for the square matrix ``M``.

    Entry (i, j) equals |M[i, j]| / sqrt(M[j, j] + c) / sqrt(M[i, i] + c),
    where c = mean(diag(M)) / 100. The shift c keeps the normalisation
    finite when individual diagonal entries are zero, provided the diagonal
    mean is positive.
    """
    d = torch.diag(M)
    scale = (d + d.mean() / 100) ** .5
    row_scale = scale[None, :]   # broadcasts over rows: divides column j by scale[j]
    col_scale = scale[:, None]   # broadcasts over columns: divides row i by scale[i]
    return M.abs() / row_scale / col_scale

In [None]:
# Compute timings and display FIMs for the small convnet.
#
# CUDA kernels run asynchronously, so every timed call is bracketed by
# torch.cuda.synchronize(). Without it the clock would mostly measure
# kernel-launch overhead rather than the actual computation.
import os


def _timed(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` and return ``(result, elapsed_seconds)``.

    Pending CUDA work is flushed before the clock starts and awaited before
    it stops, so asynchronous GPU kernels are included in the measurement.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return result, time.perf_counter() - start


# Figures are saved here; create it up front so savefig cannot fail on a
# missing directory.
os.makedirs('repr_img', exist_ok=True)

# timings[<representation class>][<operation>] -> elapsed seconds
timings = dict()

for pmat_cls in [PMatDense, PMatBlockDiag, PMatImplicit, PMatDiag, PMatEKFAC, PMatKFAC, PMatQuasiDiag]:
    name = pmat_cls.__name__
    t = timings[pmat_cls] = dict()

    F, t['init'] = _timed(FIM_MonteCarlo,
                          model=convnet,
                          loader=trainloader,
                          representation=pmat_cls,
                          device='cuda')

    if pmat_cls is PMatEKFAC:
        # EKFAC additionally calls update_diag with another pass over trainloader.
        _, t['update_diag'] = _timed(F.update_diag, examples=trainloader)

    _, t['Mv'] = _timed(F.mv, v)
    _, t['vTMv'] = _timed(F.vTMv, v)
    _, t['tr'] = _timed(F.trace)

    # The remaining operations are optional: not every representation
    # supports them, so failures are reported and the benchmark continues.
    try:
        _, t['frob'] = _timed(F.frobenius_norm)
    except NotImplementedError:
        pass

    try:
        _, t['solve'] = _timed(F.solve, v)
    except Exception as err:
        print(f'{name}: solve skipped ({type(err).__name__}: {err})')

    # Reset every iteration so a dense matrix from an earlier representation
    # can never be plotted under this representation's name.
    F_dense = None
    try:
        F_dense, t['get_dense'] = _timed(F.get_dense_tensor)
    except Exception as err:
        print(f'{name}: get_dense_tensor skipped ({type(err).__name__}: {err})')

    if F_dense is not None:
        try:
            plt.figure(figsize=(10, 10))
            plt.imshow(compute_correlation(F_dense).cpu())
            plt.title(name)
            plt.savefig('repr_img/' + name + '.png')
            plt.show()
        except Exception as err:
            print(f'{name}: plotting failed ({type(err).__name__}: {err})')
        finally:
            plt.close()
        del F_dense

    del F

In [None]:
# Display the collected timings (seconds), keyed by representation name
# instead of by the class object so the output is readable.
{pmat_cls.__name__: t for pmat_cls, t in timings.items()}

# ResNet50 on CIFAR10

In [None]:
# CIFAR10 training images normalised per channel.
# NOTE(review): this std triple is the one widely copied between CIFAR
# examples; the per-pixel std of the training set is closer to
# (0.247, 0.243, 0.261). Irrelevant for these timings, but confirm before
# reusing this transform for training.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                         std=(0.2023, 0.1994, 0.2010)),
])

# Only the first 100 training examples, served in two fixed-order batches of 50.
trainset = Subset(
    datasets.CIFAR10(root='/tmp/data', train=True, download=True,
                     transform=transform),
    range(100),
)
trainloader = DataLoader(trainset, batch_size=50, shuffle=False,
                         num_workers=1)

In [None]:
from resnet import ResNet50  # presumably a local resnet.py next to this notebook
resnet = ResNet50().cuda()

# Build the LayerCollection once and reuse it for the random probe vector
# (used below for mv / vTMv / solve) instead of constructing a second,
# identical collection.
layer_collection = LayerCollection.from_model(resnet)
v = random_pvector(layer_collection, device='cuda')

# Number of parameters covered by the layer collection (cell output).
layer_collection.numel()

In [None]:
# Compute timings for the ResNet50. Only the representations in the list
# below are benchmarked, and no dense FIM is built or plotted.
#
# CUDA kernels run asynchronously, so every timed call is bracketed by
# torch.cuda.synchronize(). Without it the clock would mostly measure
# kernel-launch overhead rather than the actual computation.


def _timed(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` and return ``(result, elapsed_seconds)``.

    Pending CUDA work is flushed before the clock starts and awaited before
    it stops, so asynchronous GPU kernels are included in the measurement.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return result, time.perf_counter() - start


# timings[<representation class>][<operation>] -> elapsed seconds
timings = dict()

for pmat_cls in [PMatImplicit, PMatDiag, PMatEKFAC, PMatKFAC, PMatQuasiDiag]:
    name = pmat_cls.__name__
    t = timings[pmat_cls] = dict()

    F, t['init'] = _timed(FIM_MonteCarlo,
                          model=resnet,
                          loader=trainloader,
                          representation=pmat_cls,
                          device='cuda')

    if pmat_cls is PMatEKFAC:
        # EKFAC additionally calls update_diag with another pass over trainloader.
        _, t['update_diag'] = _timed(F.update_diag, examples=trainloader)

    _, t['Mv'] = _timed(F.mv, v)
    _, t['vTMv'] = _timed(F.vTMv, v)
    _, t['tr'] = _timed(F.trace)

    # The remaining operations are optional: not every representation
    # supports them, so failures are reported and the benchmark continues.
    try:
        _, t['frob'] = _timed(F.frobenius_norm)
    except NotImplementedError:
        pass

    try:
        _, t['solve'] = _timed(F.solve, v)
    except Exception as err:
        print(f'{name}: solve skipped ({type(err).__name__}: {err})')

    del F

In [None]:
# Display the collected timings (seconds), keyed by representation name
# instead of by the class object so the output is readable.
{pmat_cls.__name__: t for pmat_cls, t in timings.items()}