Merge pull request #43 from tfjgeorge/transposed
Transposed
tfjgeorge committed Dec 7, 2021
2 parents f3ad0dc + 3b8f397 commit a91cb2d
Showing 5 changed files with 136 additions and 40 deletions.
32 changes: 22 additions & 10 deletions nngeometry/generator/jacobian/grads.py
@@ -1,8 +1,8 @@
import torch
from nngeometry.layercollection import (Affine1dLayer, Cosine1dLayer, LinearLayer, Conv2dLayer, BatchNorm1dLayer,
BatchNorm2dLayer, GroupNormLayer, WeightNorm1dLayer,
WeightNorm2dLayer)
from .grads_conv import conv2d_backward as per_example_grad_conv
WeightNorm2dLayer, ConvTranspose2dLayer)
from .grads_conv import conv2d_backward, convtranspose2d_backward, unfold_transpose_conv2d

import torch.nn.functional as F

@@ -117,7 +117,7 @@ class Conv2dJacobianFactory(JacobianFactory):
def flat_grad(cls, buffer, mod, layer, x, gy):
bs = x.size(0)
w_numel = layer.weight.numel()
indiv_gw = per_example_grad_conv(mod, x, gy)
indiv_gw = conv2d_backward(mod, x, gy)
buffer[:, :w_numel].add_(indiv_gw.view(bs, -1))
if layer.bias is not None:
buffer[:, w_numel:].add_(gy.sum(dim=(2, 3)))
@@ -142,7 +142,7 @@ def kfac_xx(cls, buffer, mod, layer, x, gy):
.view(-1, A_tilda.size(1))
if layer.bias is not None:
A_tilda = torch.cat([A_tilda,
torch.ones_like(A_tilda[:, :1])],
dim=1)
# Omega_hat in KFC
buffer.add_(torch.mm(A_tilda.t(), A_tilda))
@@ -183,21 +183,32 @@ def kfe_diag(cls, buffer, mod, layer, x, gy, evecs_a, evecs_g):
@classmethod
def quasidiag(cls, buffer_diag, buffer_cross, mod, layer, x, gy):
w_numel = layer.weight.numel()
indiv_gw = per_example_grad_conv(mod, x, gy)
indiv_gw = conv2d_backward(mod, x, gy)
buffer_diag[:w_numel].add_((indiv_gw**2).sum(dim=0).view(-1))
if layer.bias is not None:
gb_per_example = gy.sum(dim=(2, 3))
buffer_diag[w_numel:].add_((gb_per_example**2).sum(dim=0))
y = (gy * gb_per_example.unsqueeze(2).unsqueeze(3))
cross_this = F.conv2d(x.transpose(0, 1),
y.transpose(0, 1),
stride=mod.dilation,
padding=mod.padding,
dilation=mod.stride).transpose(0, 1)
cross_this = cross_this[:, :, :mod.kernel_size[0], :mod.kernel_size[1]]
buffer_cross.add_(cross_this)


class ConvTranspose2dJacobianFactory(JacobianFactory):
@classmethod
def flat_grad(cls, buffer, mod, layer, x, gy):
bs = x.size(0)
w_numel = layer.weight.numel()
indiv_gw = convtranspose2d_backward(mod, x, gy)
buffer[:, :w_numel].add_(indiv_gw.view(bs, -1))
if layer.bias is not None:
buffer[:, w_numel:].add_(gy.sum(dim=(2, 3)))


def check_bn_training(mod):
# check that BN layers are in eval mode
if mod.training:
@@ -261,7 +272,7 @@ def flat_grad(cls, buffer, mod, layer, x, gy):
bs = x.size(0)
out_dim = mod.weight.size(0)
norm2 = (mod.weight**2).sum(dim=(1, 2, 3)) + mod.eps
gw = per_example_grad_conv(mod, x, gy / torch.sqrt(norm2).view(1, out_dim, 1, 1))
gw = conv2d_backward(mod, x, gy / torch.sqrt(norm2).view(1, out_dim, 1, 1))
gw = gw.view(bs, out_dim, -1)
wn2_out = F.conv2d(x, mod.weight / norm2.view(out_dim, 1, 1, 1)**1.5, None,
stride=mod.stride, padding=mod.padding, dilation=mod.dilation)
@@ -296,6 +307,7 @@ def flat_grad(cls, buffer, mod, layer, x, gy):
FactoryMap = {
LinearLayer: LinearJacobianFactory,
Conv2dLayer: Conv2dJacobianFactory,
ConvTranspose2dLayer: ConvTranspose2dJacobianFactory,
BatchNorm1dLayer: BatchNorm1dJacobianFactory,
BatchNorm2dLayer: BatchNorm2dJacobianFactory,
GroupNormLayer: GroupNormJacobianFactory,
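
A minimal sketch of how the new ConvTranspose2dJacobianFactory can be exercised directly (assuming nngeometry is installed from this branch; the module and input sizes below are arbitrary):

# Sketch only, not part of this diff: exercise the new factory directly.
import torch
import torch.nn as nn
from nngeometry.generator.jacobian.grads import FactoryMap
from nngeometry.layercollection import ConvTranspose2dLayer

mod = nn.ConvTranspose2d(1, 6, kernel_size=(3, 2), stride=2)
layer = ConvTranspose2dLayer(in_channels=1, out_channels=6,
                             kernel_size=(3, 2), bias=True)

x = torch.randn(4, 1, 10, 10)            # arbitrary batch of 4 inputs
gy = torch.randn_like(mod(x))            # stand-in for the backpropagated gradient

# flat_grad fills a (bs, numel) buffer: weight part first, then bias,
# mirroring Conv2dJacobianFactory above.
buffer = torch.zeros(x.size(0), layer.numel())
FactoryMap[ConvTranspose2dLayer].flat_grad(buffer, mod, layer, x, gy)
print(buffer.shape)                      # torch.Size([4, 42])
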
45 changes: 44 additions & 1 deletion nngeometry/generator/jacobian/grads_conv.py
@@ -4,6 +4,7 @@

import numpy as np
import torch
from torch._C import unify_type_list
import torch.nn.functional as F

def conv_backward(input, grad_output, in_channels, out_channels, kernel_size,
@@ -134,4 +135,46 @@ def __enter__(self):
_conv_grad_impl.use_unfold(False)

def __exit__(self, exc_type, exc_value, traceback):
_conv_grad_impl._use_unfold = self.prev


def convtranspose2d_backward(mod, x, gy):
'''Computes per-example gradients for nn.ConvTranspose2d layers.'''
bs = gy.size(0)
s_i, s_o, k_h, k_w = mod.weight.size()
x_unfold = unfold_transpose_conv2d(mod, x)

x_perm = x_unfold.view(bs, s_i*k_w*k_h, -1).permute(0, 2, 1)
o = torch.bmm(gy.view(bs, s_o, -1), x_perm)
o = o.view(bs, s_o, s_i, k_h, k_w).permute(0, 2, 1, 3, 4)
o = o.contiguous()
return o


def unfold_transpose_conv2d(mod, x):
unfold_filter = _filter_bank.get(mod)
return F.conv_transpose2d(x, unfold_filter, stride=mod.stride, padding=mod.padding,
output_padding=mod.output_padding, groups=mod.in_channels,
dilation=mod.dilation)

class TransposeConv_Unfold_Filter_Bank:

def __init__(self):
self.filters = dict()

def get(self, mod):
if mod not in self.filters:
self.filters[mod] = self._create_unfold_filter(mod)
return self.filters[mod]

def _create_unfold_filter(self, mod):
kw, kh = mod.kernel_size
unfold_filter = mod.weight.data.new(mod.in_channels, kw * kh, kw, kh)
unfold_filter.fill_(0)
for i in range(mod.in_channels):
for j in range(kw):
for k in range(kh):
unfold_filter[i, k + kh*j, j, k] = 1
return unfold_filter

_filter_bank = TransposeConv_Unfold_Filter_Bank()
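
One way to sanity-check convtranspose2d_backward (not part of this diff) is to compare its batched output against per-example gradients computed one sample at a time with autograd; the layer sizes below are arbitrary:

# Sketch only: compare the batched per-example gradients with an autograd loop.
import torch
import torch.nn as nn
from nngeometry.generator.jacobian.grads_conv import convtranspose2d_backward

torch.manual_seed(0)
mod = nn.ConvTranspose2d(3, 5, kernel_size=(3, 2), stride=2)
x = torch.randn(4, 3, 7, 7)
gy = torch.randn_like(mod(x))

indiv_gw = convtranspose2d_backward(mod, x, gy)   # shape (4, 3, 5, 3, 2)

max_err = 0.
for i in range(x.size(0)):
    mod.zero_grad()
    mod(x[i:i+1]).backward(gy[i:i+1])
    max_err = max(max_err, (indiv_gw[i] - mod.weight.grad).abs().max().item())
print(max_err)   # should be around 1e-6 if the two computations agree
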
33 changes: 32 additions & 1 deletion nngeometry/layercollection.py
@@ -15,7 +15,8 @@ class LayerCollection:

_known_modules = ['Linear', 'Conv2d', 'BatchNorm1d',
'BatchNorm2d', 'GroupNorm', 'WeightNorm1d',
'WeightNorm2d', 'Cosine1d', 'Affine1d']
'WeightNorm2d', 'Cosine1d', 'Affine1d',
'ConvTranspose2d']

def __init__(self, layers=None):
if layers is None:
@@ -91,6 +92,11 @@ def _module_to_layer(mod):
out_channels=mod.out_channels,
kernel_size=mod.kernel_size,
bias=(mod.bias is not None))
elif mod_class == 'ConvTranspose2d':
return ConvTranspose2dLayer(in_channels=mod.in_channels,
out_channels=mod.out_channels,
kernel_size=mod.kernel_size,
bias=(mod.bias is not None))
elif mod_class == 'BatchNorm1d':
return BatchNorm1dLayer(num_features=mod.num_features)
elif mod_class == 'BatchNorm2d':
@@ -172,6 +178,31 @@ def __eq__(self, other):
self.kernel_size == other.kernel_size)


class ConvTranspose2dLayer(AbstractLayer):

def __init__(self, in_channels, out_channels, kernel_size, bias=True):
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.weight = Parameter(out_channels, in_channels, kernel_size[0],
kernel_size[1])
if bias:
self.bias = Parameter(out_channels)
else:
self.bias = None

def numel(self):
if self.bias is not None:
return self.weight.numel() + self.bias.numel()
else:
return self.weight.numel()

def __eq__(self, other):
return (self.in_channels == other.in_channels and
self.out_channels == other.out_channels and
self.kernel_size == other.kernel_size)


class LinearLayer(AbstractLayer):

def __init__(self, in_features, out_features, bias=True):
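
For illustration (not in this diff): with 'ConvTranspose2d' added to _known_modules and handled in _module_to_layer, a model containing transposed convolutions can be turned into a LayerCollection. The toy model below and the use of the existing LayerCollection.from_model entry point are assumptions:

# Sketch only: register ConvTranspose2d layers in a LayerCollection.
import torch.nn as nn
from nngeometry.layercollection import LayerCollection

model = nn.Sequential(
    nn.ConvTranspose2d(1, 6, (3, 2), 2),
    nn.ReLU(),
    nn.ConvTranspose2d(6, 3, (2, 3), 3, bias=False),
)
lc = LayerCollection.from_model(model)
print(lc.numel())   # 42 + 108 = 150 parameters in total
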
56 changes: 33 additions & 23 deletions tests/tasks.py
@@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as tF
from torch.nn.modules.conv import ConvTranspose2d
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from nngeometry.layercollection import LayerCollection
@@ -109,11 +110,11 @@ def forward(self, x):
x = tF.relu(self.wn1(x))
else:
x = tF.relu(self.conv1(x))
x = tF.max_pool2d(x, 2, 2)
x = tF.avg_pool2d(x, 2, 2)
x = tF.relu(self.conv2(x))
x = tF.max_pool2d(x, 2, 2)
x = tF.avg_pool2d(x, 2, 2)
x = tF.relu(self.conv3(x), inplace=True)
x = x.view(-1, 1*1*7)
x = x.view(-1, 7)
if self.normalization == 'batch_norm':
x = self.bn2(self.fc1(x))
elif self.normalization == 'weight_norm':
@@ -132,10 +133,14 @@ def __init__(self, normalization='none'):
if normalization == 'weight_norm':
self.l1 = WeightNorm2d(1, 6, 3, 2)
self.l2 = WeightNorm2d(6, 3, 2, 3)
elif normalization == 'transpose':
self.l1 = ConvTranspose2d(1, 6, (3, 2), 2)
self.l2 = ConvTranspose2d(6, 3, (2, 3), 3, bias=False)
else:
raise NotImplementedError

def forward(self, x):
x = x[:, :, 5:-5, 5:-5]
x = tF.relu(self.l1(x))
x = tF.relu(self.l2(x))
return x.sum(dim=(2, 3))
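
For reference (not part of this diff), the 'transpose' variant of SmallConvNet reproduced standalone, with the intermediate shapes it produces on a 28x28 MNIST-like input; the batch size is arbitrary:

# Sketch only: shape walk-through of the 'transpose' branch above.
import torch
import torch.nn as nn
import torch.nn.functional as tF

l1 = nn.ConvTranspose2d(1, 6, (3, 2), 2)
l2 = nn.ConvTranspose2d(6, 3, (2, 3), 3, bias=False)

x = torch.randn(5, 1, 28, 28)
x = x[:, :, 5:-5, 5:-5]          # crop to 18x18, as in forward() above
x = tF.relu(l1(x))               # -> (5, 6, 37, 36)
x = tF.relu(l2(x))               # -> (5, 3, 110, 108)
print(x.sum(dim=(2, 3)).shape)   # torch.Size([5, 3]): 3 outputs per example
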
@@ -157,10 +162,10 @@ def forward(self, x):

def get_linear_fc_task():
train_set = get_mnist()
train_set = Subset(train_set, range(1000))
train_set = Subset(train_set, range(70))
train_loader = DataLoader(
dataset=train_set,
batch_size=300,
batch_size=30,
shuffle=False)
net = LinearFCNet()
to_device_model(net)
@@ -190,10 +195,10 @@ def forward(self, x):

def get_linear_conv_task():
train_set = get_mnist()
train_set = Subset(train_set, range(1000))
train_set = Subset(train_set, range(70))
train_loader = DataLoader(
dataset=train_set,
batch_size=300,
batch_size=30,
shuffle=False)
net = LinearConvNet()
to_device_model(net)
@@ -225,10 +230,10 @@ def forward(self, x):

def get_batchnorm_fc_linear_task():
train_set = get_mnist()
train_set = Subset(train_set, range(1000))
train_set = Subset(train_set, range(70))
train_loader = DataLoader(
dataset=train_set,
batch_size=300,
batch_size=30,
shuffle=False)
net = BatchNormFCLinearNet()
to_device_model(net)
@@ -267,10 +272,10 @@ def forward(self, x):

def get_batchnorm_conv_linear_task():
train_set = get_mnist()
train_set = Subset(train_set, range(1000))
train_set = Subset(train_set, range(70))
train_loader = DataLoader(
dataset=train_set,
batch_size=300,
batch_size=30,
shuffle=False)
net = BatchNormConvLinearNet()
to_device_model(net)
@@ -320,10 +325,10 @@ def forward(self, x):

def get_batchnorm_nonlinear_task():
train_set = get_mnist()
train_set = Subset(train_set, range(1000))
train_set = Subset(train_set, range(70))
train_loader = DataLoader(
dataset=train_set,
batch_size=1000,
batch_size=30,
shuffle=False)
net = BatchNormNonLinearNet()
to_device_model(net)
@@ -346,10 +351,10 @@ def get_mnist():

def get_fullyconnect_task(normalization='none'):
train_set = get_mnist()
train_set = Subset(train_set, range(1000))
train_set = Subset(train_set, range(70))
train_loader = DataLoader(
dataset=train_set,
batch_size=300,
batch_size=30,
shuffle=False)
net = FCNet(out_size=3, normalization=normalization)
to_device_model(net)
@@ -381,10 +386,10 @@ def get_fullyconnect_affine_task():

def get_conv_task(normalization='none'):
train_set = get_mnist()
train_set = Subset(train_set, range(1000))
train_set = Subset(train_set, range(70))
train_loader = DataLoader(
dataset=train_set,
batch_size=300,
batch_size=30,
shuffle=False)
net = ConvNet(normalization=normalization)
to_device_model(net)
@@ -413,12 +418,13 @@ def get_conv_wn_task():
def get_conv_cosine_task():
return get_conv_task(normalization='cosine')


def get_conv_task(normalization='none', small=False):
train_set = get_mnist()
train_set = Subset(train_set, range(1000))
train_set = Subset(train_set, range(70))
train_loader = DataLoader(
dataset=train_set,
batch_size=300,
batch_size=30,
shuffle=False)
if small:
net = SmallConvNet(normalization=normalization)
@@ -438,6 +444,10 @@ def output_fn(input, target):
def get_small_conv_wn_task():
return get_conv_task(normalization='weight_norm', small=True)


def get_small_conv_transpose_task():
return get_conv_task(normalization='transpose', small=True)

def get_fullyconnect_onlylast_task():
train_loader, lc_full, _, net, output_fn, n_output = \
get_fullyconnect_task()
@@ -450,10 +460,10 @@ def get_fullyconnect_onlylast_task():

def get_fullyconnect_segm_task():
train_set = get_mnist()
train_set = Subset(train_set, range(1000))
train_set = Subset(train_set, range(70))
train_loader = DataLoader(
dataset=train_set,
batch_size=300,
batch_size=30,
shuffle=False)
net = FCNetSegmentation(out_size=3)
to_device_model(net)
@@ -485,10 +495,10 @@ def forward(self, x):

def get_conv_skip_task():
train_set = get_mnist()
train_set = Subset(train_set, range(1000))
train_set = Subset(train_set, range(70))
train_loader = DataLoader(
dataset=train_set,
batch_size=300,
batch_size=30,
shuffle=False)
net = ConvNetWithSkipConnection()
to_device_model(net)
