
Commit e116a32
updates FIM for EWC example to add a mixed KFAC/BlockDiag example when using batch norm
tfjgeorge committed Apr 13, 2020
1 parent 2b50b48 commit e116a32
Showing 3 changed files with 258 additions and 75 deletions.
271 changes: 217 additions & 54 deletions examples/FIM for EWC.ipynb

Large diffs are not rendered by default.
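Since the notebook diff is not rendered, here is the gist of the new mixed example. K-FAC factors the FIM blocks of Linear and Conv2d layers as Kronecker products, but there is no such factorization for batch-norm parameters, so the notebook splits the network into two LayerCollections and represents each with a suitable PMat class. A minimal sketch, assuming the FIM helper accepts a layer_collection argument and a 10-class task; model and loader are defined elsewhere, and exact import paths and keyword names may differ from the version at this commit:

# Sketch only: splitting a model's layers between two FIM representations.
# Assumes nngeometry's FIM helper takes a layer_collection argument.
from nngeometry.layercollection import LayerCollection
from nngeometry.metrics import FIM
from nngeometry.object import PMatKFAC, PMatBlockDiag

lc_kfac = LayerCollection()   # Linear/Conv2d layers -> K-FAC blocks
lc_bn = LayerCollection()     # BatchNorm layers -> dense diagonal blocks
for _, mod in model.named_modules():
    cls = mod.__class__.__name__
    if cls in ['Linear', 'Conv2d']:
        lc_kfac.add_layer_from_model(model, mod)
    elif cls in ['BatchNorm1d', 'BatchNorm2d']:
        lc_bn.add_layer_from_model(model, mod)

F_kfac = FIM(model, loader, PMatKFAC, n_output=10,
             layer_collection=lc_kfac)
F_bn = FIM(model, loader, PMatBlockDiag, n_output=10,
           layer_collection=lc_bn)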

52 changes: 36 additions & 16 deletions nngeometry/layercollection.py
@@ -19,23 +19,10 @@ def from_model(model):
         lc = LayerCollection()
         for layer, mod in model.named_modules():
             mod_class = mod.__class__.__name__
-            if mod_class == 'Linear':
+            if mod_class in ['Linear', 'Conv2d', 'BatchNorm1d',
+                             'BatchNorm2d']:
                 lc.add_layer('%s.%s' % (layer, str(mod)),
-                             LinearLayer(in_features=mod.in_features,
-                                         out_features=mod.out_features,
-                                         bias=(mod.bias is not None)))
-            elif mod_class == 'Conv2d':
-                lc.add_layer('%s.%s' % (layer, str(mod)),
-                             Conv2dLayer(in_channels=mod.in_channels,
-                                         out_channels=mod.out_channels,
-                                         kernel_size=mod.kernel_size,
-                                         bias=(mod.bias is not None)))
-            elif mod_class == 'BatchNorm1d':
-                lc.add_layer('%s.%s' % (layer, str(mod)),
-                             BatchNorm1dLayer(num_features=mod.num_features))
-            elif mod_class == 'BatchNorm2d':
-                lc.add_layer('%s.%s' % (layer, str(mod)),
-                             BatchNorm2dLayer(num_features=mod.num_features))
+                             LayerCollection._module_to_layer(mod))
 
         return lc

@@ -54,6 +41,39 @@ def add_layer(self, name, layer):
         self.p_pos[name] = self._numel
         self._numel += layer.numel()
 
+    def add_layer_from_model(self, model, module):
+        """
+        Add a layer by specifying the module corresponding
+        to this layer (e.g. torch.nn.Linear or torch.nn.BatchNorm1d)
+
+        :param model: The model defining the neural network
+        :param module: The layer to be added
+        """
+        if module.__class__.__name__ not in \
+                ['Linear', 'Conv2d', 'BatchNorm1d',
+                 'BatchNorm2d']:
+            raise NotImplementedError
+        for layer, mod in model.named_modules():
+            if mod is module:
+                self.add_layer('%s.%s' % (layer, str(mod)),
+                               LayerCollection._module_to_layer(mod))
+
+    def _module_to_layer(mod):
+        mod_class = mod.__class__.__name__
+        if mod_class == 'Linear':
+            return LinearLayer(in_features=mod.in_features,
+                               out_features=mod.out_features,
+                               bias=(mod.bias is not None))
+        elif mod_class == 'Conv2d':
+            return Conv2dLayer(in_channels=mod.in_channels,
+                               out_channels=mod.out_channels,
+                               kernel_size=mod.kernel_size,
+                               bias=(mod.bias is not None))
+        elif mod_class == 'BatchNorm1d':
+            return BatchNorm1dLayer(num_features=mod.num_features)
+        elif mod_class == 'BatchNorm2d':
+            return BatchNorm2dLayer(num_features=mod.num_features)
+
     def numel(self):
         return self._numel

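The new add_layer_from_model entry point is what makes the per-representation split sketched above convenient: rather than re-implementing the named_modules walk, you hand the collection the model and the specific module to register. A hypothetical usage, where model.fc stands for any supported submodule:

# Hypothetical usage: track only the final Linear layer of a model.
lc = LayerCollection()
lc.add_layer_from_model(model, model.fc)
print(lc.numel())  # number of scalar parameters covered by this collection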
10 changes: 5 additions & 5 deletions nngeometry/object/pspace.py
@@ -289,11 +289,11 @@ def vTMv(self, vector):
         # TODO test
         vector_dict = vector.get_dict_representation()
         norm2 = 0
-        for mod in vector_dict.keys():
-            v = vector_dict[mod][0].view(-1)
-            if len(vector_dict[mod]) > 1:
-                v = torch.cat([v, vector_dict[mod][1].view(-1)])
-            norm2 += torch.dot(torch.mv(self.data[mod], v), v)
+        for layer_id, layer in self.generator.layer_collection.layers.items():
+            v = vector_dict[layer_id][0].view(-1)
+            if len(vector_dict[layer_id]) > 1:
+                v = torch.cat([v, vector_dict[layer_id][1].view(-1)])
+            norm2 += torch.dot(torch.mv(self.data[layer_id], v), v)
         return norm2
 
     def __add__(self, other):
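This vTMv change is the piece EWC actually uses: the penalty is (θ − θ*)ᵀ F (θ − θ*), i.e. vTMv applied to the difference between the current parameters and the snapshot saved after the previous task, and iterating over the generator's layer collection instead of the vector's own keys keeps the block lookup consistent with how self.data was built. A hedged sketch, assuming PVector.from_model, clone()/detach(), and PVector subtraction behave as in nngeometry's examples; lambda_ewc and task_loss are illustrative names:

# Sketch of the EWC regularizer built from the two FIM blocks above.
from nngeometry.object.vector import PVector

# Snapshot of the parameters after training on the previous task.
v_star = PVector.from_model(model).clone().detach()

# ... later, while training on the new task:
v = PVector.from_model(model)
diff = v - v_star
penalty = F_kfac.vTMv(diff) + F_bn.vTMv(diff)  # both blocks of the mixed FIM
loss = task_loss + lambda_ewc * penalty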
