<a href="https://colab.research.google.com/github/tfjgeorge/nngeometry-examples/blob/main/Gram_matrix_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install git+https://github.com/tfjgeorge/nngeometry.git

Collecting git+https://github.com/tfjgeorge/nngeometry.git
  Cloning https://github.com/tfjgeorge/nngeometry.git to /tmp/pip-req-build-axq6z45y
  Running command git clone -q https://github.com/tfjgeorge/nngeometry.git /tmp/pip-req-build-axq6z45y
Building wheels for collected packages: nngeometry
  Building wheel for nngeometry (setup.py) ... done
  Created wheel for nngeometry: filename=nngeometry-0.2.1-cp37-none-any.whl size=23027 sha256=ccc0b40b342878b495a0435bcc5c37746911579127b88bffa44bcfcf56453ced
  Stored in directory: /tmp/pip-ephem-wheel-cache-44j17pjy/wheels/0e/82/b3/42a1a59c9ab5dcb2a16c557430ef6bbdce07fe33ac46af6beb
Successfully built nngeometry
Installing collected packages: nngeometry
Successfully installed nngeometry-0.2.1


# PyTorch dataloader and model definition

The next cells contain a regular model and dataloader definition using standard PyTorch classes. Nothing here is specific to NNGeometry.

We start by defining our model: a ResNet18 variant without batch normalization layers.

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = self.conv2(out)
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18():
    return ResNet(BasicBlock, [2,2,2,2])

model = ResNet18().cuda()

Next, we define the dataloader on which we compute the Gram matrix. Note the specifics:

- In the `DataLoader` instantiation, we pass `shuffle=False` so that examples in the Gram matrix are arranged deterministically, i.e. the first example in the Gram matrix is the first example in the `DataLoader`, and so on.
- We use a subset of 100 examples from the original test set, since the size of the Gram matrix grows as $n^2$, where $n$ is the number of examples.
- To improve performance, we copy the dataset into GPU memory using the `to_tensordataset` function.
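
To make the $n^2$ growth concrete, here is a rough back-of-the-envelope sketch. It assumes the Gram matrix is stored densely in float32 (4 bytes per entry) and has one row per (output, example) pair; the helper name `gram_matrix_bytes` is hypothetical and not part of NNGeometry.

```python
def gram_matrix_bytes(n_examples, n_output=10, bytes_per_entry=4):
    """Size in bytes of a dense (n_output * n_examples)^2 Gram matrix.

    Assumes float32 storage (4 bytes per entry); this is an assumption,
    not something read from NNGeometry.
    """
    side = n_output * n_examples
    return side * side * bytes_per_entry

# Compare the 100-example subset with larger subsets of the test set
for n in (100, 1_000, 10_000):
    print(f"n={n:>6}: {gram_matrix_bytes(n) / 1e9:.3f} GB")
```

Under these assumptions, 100 examples need about 4 MB, while the full 10,000-example CIFAR-10 test set would need about 40 GB, which is why a small subset is used here.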

In [3]:
from torch.utils.data import DataLoader, TensorDataset, Subset

import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

testset = Subset(CIFAR10(root='/tmp', train=False, download=True,
                         transform=transform), range(100))

def to_tensordataset(dataset):
    d = next(iter(DataLoader(dataset,
                  batch_size=len(dataset))))
    return TensorDataset(d[0].to('cuda'), d[1].to('cuda'))

testloader = DataLoader(to_tensordataset(testset), batch_size=100, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /tmp/cifar-10-python.tar.gz



Extracting /tmp/cifar-10-python.tar.gz to /tmp
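
As a quick sanity check of the ordering guarantee, the following sketch uses a small synthetic `TensorDataset` on CPU as a stand-in for the CIFAR-10 subset (the data and variable names are made up for illustration). With `shuffle=False`, iterating over the loader should return examples in dataset order, even when the data is split into several batches.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for the CIFAR-10 subset: 6 "images" labelled 0..5
inputs = torch.arange(6, dtype=torch.float32).reshape(6, 1)
targets = torch.arange(6)
dataset = TensorDataset(inputs, targets)

# batch_size=4 forces two batches (4 + 2 examples)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

# Concatenate the targets in the order the loader yields them
seen = torch.cat([t for _, t in loader])
print(seen.tolist())  # [0, 1, 2, 3, 4, 5]
```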


Now that we are done with everything on the PyTorch side, let's get to NNGeometry!

# Computing a Gram matrix

In [4]:
from nngeometry.generator import Jacobian
from nngeometry.object import FMatDense

generator = Jacobian(model=model, n_output=10)
K = FMatDense(generator, examples=testloader)



`K` is an `FMatDense` object, which can be converted to a PyTorch tensor as follows:

In [5]:
K_torch = K.get_dense_tensor()

`K_torch` is a 10 x 100 x 10 x 100 tensor, indexed as (output, example, output, example), since this is a 10-class task evaluated on 100 examples.

In [6]:
K_torch.size()

torch.Size([10, 100, 10, 100])
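
To build intuition for this layout, here is a minimal sketch that computes a Gram matrix by hand for a tiny model on CPU. It assumes the usual definition, where entry `[i, a, j, b]` is the inner product between the parameter gradients of output `i` on example `a` and output `j` on example `b`. This is plain PyTorch written for illustration, not how NNGeometry computes it internally; `tiny_model` and the sizes are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

n_examples, n_output = 5, 3
# Hypothetical small model standing in for the ResNet18 above
tiny_model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, n_output))
x = torch.randn(n_examples, 4)
params = list(tiny_model.parameters())
n_params = sum(p.numel() for p in params)

# jac[i, a, :] holds the flattened gradient of output i on example a
jac = torch.zeros(n_output, n_examples, n_params)
out = tiny_model(x)
for a in range(n_examples):
    for i in range(n_output):
        grads = torch.autograd.grad(out[a, i], params, retain_graph=True)
        jac[i, a] = torch.cat([g.reshape(-1) for g in grads])

# Inner products between all pairs of (output, example) gradients
gram = torch.einsum('iap,jbp->iajb', jac, jac)
print(gram.shape)  # torch.Size([3, 5, 3, 5])

# Flattening the (output, example) pairs gives a symmetric 2D matrix
gram_2d = gram.reshape(n_output * n_examples, n_output * n_examples)
print(torch.allclose(gram_2d, gram_2d.T, atol=1e-5))
```

The nested loop is only practical for tiny models; it is meant to show what each entry represents, not to be efficient.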