Merge pull request #38 from tfjgeorge/conv_switch
falls back to unfold implementation of convolution gradients
tfjgeorge committed Dec 1, 2021
2 parents bcbaa78 + 13b89b9 commit c5a0381
Showing 5 changed files with 315 additions and 4 deletions.
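The change is exposed as two context managers, use_unfold_impl_for_convs and use_conv_impl_for_convs (defined in grads_conv.py below and re-exported from nngeometry.generator.jacobian). A minimal usage sketch — not part of the commit; model and loader stand in for any nn.Module / DataLoader pair:

    from nngeometry.generator import jacobian
    from nngeometry.metrics import FIM_MonteCarlo
    from nngeometry.object import PMatKFAC

    # select the per-example conv-gradient implementation while building a
    # metric; the previous choice is restored when the block exits
    with jacobian.use_unfold_impl_for_convs():   # unfold-based (the default)
        F_unfold = FIM_MonteCarlo(model=model, loader=loader,
                                  representation=PMatKFAC, device='cuda')

    with jacobian.use_conv_impl_for_convs():     # conv-based alternative
        F_conv = FIM_MonteCarlo(model=model, loader=loader,
                                representation=PMatKFAC, device='cuda')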
130 changes: 130 additions & 0 deletions benchmark/resnet.py
@@ -0,0 +1,130 @@
'''ResNet in PyTorch.
from: https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        # self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        # self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                # nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = self.conv2(out)
        out = F.relu(out + self.shortcut(x))
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        # self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        # self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion *
                               planes, kernel_size=1, bias=False)
        # self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                # nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        out = self.conv3(out)
        out = F.relu(out + self.shortcut(x))
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        # self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # for 32x32 (CIFAR) inputs the feature map is 4x4 at this point,
        # so a 4x4 average pool reduces it to 1x1 before the classifier
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])


def ResNet34():
    return ResNet(BasicBlock, [3, 4, 6, 3])


def ResNet50():
    return ResNet(Bottleneck, [3, 4, 6, 3])


def ResNet101():
    return ResNet(Bottleneck, [3, 4, 23, 3])


def ResNet152():
    return ResNet(Bottleneck, [3, 8, 36, 3])


def test():
    net = ResNet18()
    y = net(torch.randn(1, 3, 32, 32))
    print(y.size())

# test()
107 changes: 107 additions & 0 deletions benchmark/timings_resnet.py
@@ -0,0 +1,107 @@
# %%
import torch
from torchvision import datasets, transforms
from torch.utils.data import Subset, DataLoader
import time
import pprint


from nngeometry.layercollection import LayerCollection
from nngeometry.metrics import FIM_MonteCarlo
from nngeometry.object.vector import random_pvector
from nngeometry.generator import jacobian as nnj

from nngeometry.object import PMatDiag, PMatKFAC, PMatEKFAC, PMatQuasiDiag, PMatImplicit


# # ResNet50 on CIFAR10

# %%
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

trainset = datasets.CIFAR10(root='/tmp/data', train=True,
                            download=True, transform=transform)
trainset = Subset(trainset, range(100))
trainloader = DataLoader(trainset, batch_size=50,
                         shuffle=False, num_workers=1)

# %%
from resnet import ResNet50
resnet = ResNet50().cuda()

layer_collection = LayerCollection.from_model(resnet)
v = random_pvector(LayerCollection.from_model(resnet), device='cuda')

print(f'{layer_collection.numel()} parameters')

# %%
# compute timings and display FIMs

def perform_timing():
    timings = dict()

    for repr in [PMatImplicit, PMatDiag, PMatEKFAC, PMatKFAC, PMatQuasiDiag]:

        print('Timing representation:')
        pprint.pprint(repr)

        timings[repr] = dict()

        time_start = time.time()
        F = FIM_MonteCarlo(model=resnet,
                           loader=trainloader,
                           representation=repr,
                           device='cuda')
        time_end = time.time()
        timings[repr]['init'] = time_end - time_start

        if repr == PMatEKFAC:
            time_start = time.time()
            F.update_diag(examples=trainloader)
            time_end = time.time()
            timings[repr]['update_diag'] = time_end - time_start

        time_start = time.time()
        F.mv(v)
        time_end = time.time()
        timings[repr]['Mv'] = time_end - time_start

        time_start = time.time()
        F.vTMv(v)
        time_end = time.time()
        timings[repr]['vTMv'] = time_end - time_start

        time_start = time.time()
        F.trace()
        time_end = time.time()
        timings[repr]['tr'] = time_end - time_start

        try:
            time_start = time.time()
            F.frobenius_norm()
            time_end = time.time()
            timings[repr]['frob'] = time_end - time_start
        except NotImplementedError:
            pass

        try:
            time_start = time.time()
            F.solve(v)
            time_end = time.time()
            timings[repr]['solve'] = time_end - time_start
        except:  # solve is not available for every representation
            pass

        del F

    pprint.pprint(timings)

# %%

with nnj.use_unfold_impl_for_convs():
    perform_timing()

with nnj.use_conv_impl_for_convs():
    perform_timing()
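A side note on the timings above (not part of the commit): CUDA kernels launch asynchronously, so time.time() deltas around GPU work can under-report the true cost unless the device is synchronized first. A minimal sketch of a synchronized timer, assuming a CUDA device:

    import time
    import torch

    def timed(fn, *args, **kwargs):
        torch.cuda.synchronize()   # flush kernels queued before the call
        start = time.time()
        result = fn(*args, **kwargs)
        torch.cuda.synchronize()   # wait for fn's kernels to finish
        return result, time.time() - start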
1 change: 1 addition & 0 deletions nngeometry/generator/jacobian/__init__.py
@@ -4,6 +4,7 @@
from nngeometry.object.vector import PVector, FVector
from nngeometry.layercollection import LayerCollection
from .grads import FactoryMap
from .grads_conv import use_conv_impl_for_convs, use_unfold_impl_for_convs

class Jacobian:
    """
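The one-line addition above re-exports the two switches at the subpackage level; this is what lets callers write, for instance (sketch, not part of the diff):

    from nngeometry.generator import jacobian as nnj

    with nnj.use_conv_impl_for_convs():
        ...  # any nngeometry computation involving Conv2d layers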
59 changes: 55 additions & 4 deletions nngeometry/generator/jacobian/grads_conv.py
@@ -3,10 +3,9 @@
# These functions are borrowed from https://github.com/owkin/grad-cnns

import numpy as np
import torch
import torch.nn.functional as F


def conv_backward(input, grad_output, in_channels, out_channels, kernel_size,
                  stride=1, dilation=1, padding=0, groups=1, nd=1):
    '''Computes per-example gradients for nn.Conv1d and nn.Conv2d layers.
@@ -74,7 +73,7 @@ def conv1d_backward(*args, **kwargs):
    return conv_backward(*args, nd=1, **kwargs)


def conv2d_backward_using_conv(mod, x, gy):
    '''Computes per-example gradients for nn.Conv2d layers.'''
    return conv_backward(x, gy, nd=2,
                         in_channels=mod.in_channels,
@@ -83,4 +82,56 @@ def conv2d_backward(mod, x, gy):
                         stride=mod.stride,
                         dilation=mod.dilation,
                         padding=mod.padding,
                         groups=mod.groups)


def conv2d_backward_using_unfold(mod, x, gy):
    '''Computes per-example gradients for nn.Conv2d layers.'''
    ks = (mod.weight.size(2), mod.weight.size(3))
    gy_s = gy.size()
    bs = gy_s[0]
    # extract every receptive field as a column:
    # (bs, in_channels * kh * kw, n_positions)
    x_unfold = F.unfold(x, kernel_size=ks, stride=mod.stride,
                        padding=mod.padding, dilation=mod.dilation)
    x_unfold_s = x_unfold.size()
    # batched matmul (bs, out_channels, n_positions) x
    # (bs, n_positions, in_channels * kh * kw) -> per-example weight gradients
    return torch.bmm(gy.view(bs, gy_s[1], -1),
                     x_unfold.view(bs, x_unfold_s[1], -1).permute(0, 2, 1))


def conv2d_backward(*args, **kwargs):
    # dispatch to whichever implementation is currently selected
    return _conv_grad_impl.get_impl()(*args, **kwargs)


class ConvGradImplManager:
    # global switch between the unfold- and conv-based implementations

    def __init__(self):
        self._use_unfold = True

    def use_unfold(self, choice=True):
        self._use_unfold = choice

    def get_impl(self):
        if self._use_unfold:
            return conv2d_backward_using_unfold
        else:
            return conv2d_backward_using_conv


_conv_grad_impl = ConvGradImplManager()


class use_unfold_impl_for_convs:
    # context manager: select the unfold implementation, then restore
    # the previous choice on exit

    def __enter__(self):
        self.prev = _conv_grad_impl._use_unfold
        _conv_grad_impl.use_unfold(True)

    def __exit__(self, exc_type, exc_value, traceback):
        _conv_grad_impl._use_unfold = self.prev


class use_conv_impl_for_convs:
    # context manager: select the conv implementation, then restore
    # the previous choice on exit

    def __enter__(self):
        self.prev = _conv_grad_impl._use_unfold
        _conv_grad_impl.use_unfold(False)

    def __exit__(self, exc_type, exc_value, traceback):
        _conv_grad_impl._use_unfold = self.prev
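To make the shapes concrete, here is a hedged verification sketch (not part of the commit) checking the unfold-based per-example gradients against plain autograd, one example at a time; following the bmm above, the result has shape (bs, out_channels, in_channels * kh * kw):

    import torch
    import torch.nn as nn
    from nngeometry.generator.jacobian.grads_conv import \
        conv2d_backward_using_unfold

    mod = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
    x = torch.randn(4, 3, 10, 10)
    gy = torch.randn(4, 8, 10, 10)   # stands in for a backpropagated gradient

    per_ex = conv2d_backward_using_unfold(mod, x, gy)   # (4, 8, 27)

    for i in range(x.size(0)):
        mod.zero_grad()
        (mod(x[i:i+1]) * gy[i:i+1]).sum().backward()
        ref = mod.weight.grad.view(mod.out_channels, -1)   # (8, 27)
        assert torch.allclose(per_ex[i], ref, atol=1e-5)

(The repository's own test below checks the unfold and conv implementations against each other via check_tensors.)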
22 changes: 22 additions & 0 deletions tests/test_conv_switch.py
@@ -0,0 +1,22 @@
from nngeometry.generator import Jacobian, jacobian
from nngeometry.object.pspace import PMatDense
from tasks import get_conv_task
from utils import check_tensors


def test_conv_impl_switch():
    loader, lc, parameters, model, function, n_output = get_conv_task()
    generator = Jacobian(layer_collection=lc,
                         model=model,
                         function=function,
                         n_output=n_output)

    with jacobian.use_unfold_impl_for_convs():
        PMat_dense_unfold = PMatDense(generator=generator,
                                      examples=loader)

    with jacobian.use_conv_impl_for_convs():
        PMat_dense_conv = PMatDense(generator=generator,
                                    examples=loader)

    # both implementations must produce the same dense matrix
    check_tensors(PMat_dense_unfold.get_dense_tensor(),
                  PMat_dense_conv.get_dense_tensor())
