Skip to content

Commit

Permalink
Merge pull request #36 from tfjgeorge/efficient_conv_grads
Browse files Browse the repository at this point in the history
adds efficient per_example grads for convs
  • Loading branch information
tfjgeorge committed Nov 26, 2021
2 parents 2e6dea0 + 7677f8d commit 391526d
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 13 deletions.
1 change: 0 additions & 1 deletion nngeometry/generator/jacobian/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from nngeometry.utils import per_example_grad_conv
from nngeometry.object.vector import PVector, FVector
from nngeometry.layercollection import LayerCollection
from .grads import FactoryMap
Expand Down
6 changes: 5 additions & 1 deletion nngeometry/generator/jacobian/grads.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from nngeometry.layercollection import (Affine1dLayer, Cosine1dLayer, LinearLayer, Conv2dLayer, BatchNorm1dLayer,
BatchNorm2dLayer, GroupNormLayer, WeightNorm1dLayer,
WeightNorm2dLayer)
from nngeometry.utils import per_example_grad_conv
from .grads_conv import conv2d_backward as per_example_grad_conv

import torch.nn.functional as F


Expand Down Expand Up @@ -261,8 +262,11 @@ def flat_grad(cls, buffer, mod, layer, x, gy):
out_dim = mod.weight.size(0)
norm2 = (mod.weight**2).sum(dim=(1, 2, 3)) + mod.eps
gw = per_example_grad_conv(mod, x, gy / torch.sqrt(norm2).view(1, out_dim, 1, 1))
gw = gw.view(bs, out_dim, -1)
wn2_out = F.conv2d(x, mod.weight / norm2.view(out_dim, 1, 1, 1)**1.5, None,
stride=mod.stride, padding=mod.padding, dilation=mod.dilation)
t = (gy * wn2_out).sum(dim=(2, 3)).view(bs, out_dim, 1) * mod.weight.view(1, out_dim, -1)
print(gw.size(), t.size())
gw -= (gy * wn2_out).sum(dim=(2, 3)).view(bs, out_dim, 1) * mod.weight.view(1, out_dim, -1)
buffer.add_(gw.view(bs, -1))

Expand Down
86 changes: 86 additions & 0 deletions nngeometry/generator/jacobian/grads_conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Author(s): Gaspar Rochette <gaspar.rochette@ens.fr>
# License: BSD 3 clause
# These functions are borrowed from https://github.com/owkin/grad-cnns

import numpy as np
import torch.nn as nn
import torch.nn.functional as F


def conv_backward(input, grad_output, in_channels, out_channels, kernel_size,
stride=1, dilation=1, padding=0, groups=1, nd=1):
'''Computes per-example gradients for nn.Conv1d and nn.Conv2d layers.
This function is used in the internal behaviour of bnn.Linear.
'''

# Change format of stride from int to tuple if necessary.
if isinstance(kernel_size, int):
kernel_size = (kernel_size,) * nd
if isinstance(stride, int):
stride = (stride,) * nd
if isinstance(dilation, int):
dilation = (dilation,) * nd
if isinstance(padding, int):
padding = (padding,) * nd

# Get some useful sizes
batch_size = input.size(0)
input_shape = input.size()[-nd:]
output_shape = grad_output.size()[-nd:]

# Reshape to extract groups from the convolutional layer
# Channels are seen as an extra spatial dimension with kernel size 1
input_conv = input.view(1, batch_size * groups, in_channels // groups, *input_shape)

# Compute convolution between input and output; the batchsize is seen
# as channels, taking advantage of the `groups` argument
grad_output_conv = grad_output.view(-1, 1, 1, *output_shape)

stride = (1, *stride)
dilation = (1, *dilation)
padding = (0, *padding)

if nd == 1:
convnd = F.conv2d
s_ = np.s_[..., :kernel_size[0]]
elif nd == 2:
convnd = F.conv3d
s_ = np.s_[..., :kernel_size[0], :kernel_size[1]]
elif nd == 3:
raise NotImplementedError('3d convolution is not available with current per-example gradient computation')

conv = convnd(
input_conv, grad_output_conv,
groups=batch_size * groups,
stride=dilation,
dilation=stride,
padding=padding
)

# Because of rounding shapes when using non-default stride or dilation,
# convolution result must be truncated to convolution kernel size
conv = conv[s_]

# Reshape weight gradient to correct shape
new_shape = [batch_size, out_channels, in_channels // groups, *kernel_size]
weight_bgrad = conv.view(*new_shape).contiguous()

return weight_bgrad


def conv1d_backward(*args, **kwargs):
'''Computes per-example gradients for nn.Conv1d layers.'''
return conv_backward(*args, nd=1, **kwargs)


def conv2d_backward(mod, x, gy):
'''Computes per-example gradients for nn.Conv2d layers.'''
return conv_backward(x, gy, nd=2,
in_channels=mod.in_channels,
out_channels=mod.out_channels,
kernel_size=mod.kernel_size,
stride=mod.stride,
dilation=mod.dilation,
padding=mod.padding,
groups=mod.groups)
11 changes: 0 additions & 11 deletions nngeometry/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,6 @@
from nngeometry.object.vector import PVector


def per_example_grad_conv(mod, x, gy):
ks = (mod.weight.size(2), mod.weight.size(3))
gy_s = gy.size()
bs = gy_s[0]
x_unfold = F.unfold(x, kernel_size=ks, stride=mod.stride,
padding=mod.padding, dilation=mod.dilation)
x_unfold_s = x_unfold.size()
return torch.bmm(gy.view(bs, gy_s[1], -1),
x_unfold.view(bs, x_unfold_s[1], -1).permute(0, 2, 1))


def display_correl(M, axis):

M = M.get_dense_tensor()
Expand Down

0 comments on commit 391526d

Please sign in to comment.