
Commit

Merge pull request #35 from tfjgeorge/lc_km
moves known modules to outside of LC methods
tfjgeorge committed Nov 17, 2021
2 parents 1970601 + ff14483 commit 556bec9
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions nngeometry/layercollection.py
@@ -13,6 +13,10 @@ class LayerCollection:
     :param layers:
     """
 
+    _known_modules = ['Linear', 'Conv2d', 'BatchNorm1d',
+                      'BatchNorm2d', 'GroupNorm', 'WeightNorm1d',
+                      'WeightNorm2d', 'Cosine1d', 'Affine1d']
+
     def __init__(self, layers=None):
         if layers is None:
             self.layers = OrderedDict()
@@ -37,9 +41,7 @@ def from_model(model, ignore_unsupported_layers=False):
         lc = LayerCollection()
         for layer, mod in model.named_modules():
             mod_class = mod.__class__.__name__
-            if mod_class in ['Linear', 'Conv2d', 'BatchNorm1d',
-                             'BatchNorm2d', 'GroupNorm', 'WeightNorm1d',
-                             'WeightNorm2d', 'Cosine1d', 'Affine1d']:
+            if mod_class in LayerCollection._known_modules:
                 lc.add_layer('%s.%s' % (layer, str(mod)),
                              LayerCollection._module_to_layer(mod))
             elif not ignore_unsupported_layers:
@@ -71,10 +73,7 @@ def add_layer_from_model(self, model, module):
         :param model: The model defining the neural network
         :param module: The layer to be added
         """
-        if module.__class__.__name__ not in \
-                ['Linear', 'Conv2d', 'BatchNorm1d',
-                 'BatchNorm2d', 'GroupNorm', 'WeightNorm1d',
-                 'WeightNorm2d', 'Cosine1d', 'Affine1d']:
+        if module.__class__.__name__ not in LayerCollection._known_modules:
             raise NotImplementedError
         for layer, mod in model.named_modules():
             if mod is module:
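For context, a minimal usage sketch of the refactored API (illustrative only, not part of the commit). It assumes PyTorch is installed; ignore_unsupported_layers=True is passed so that the nn.Sequential container, whose class name is not in _known_modules, is simply skipped rather than rejected.

# Hypothetical usage sketch, not taken from the repository.
import torch.nn as nn

from nngeometry.layercollection import LayerCollection

# A toy model built only from module types listed in
# LayerCollection._known_modules ('Linear', 'BatchNorm1d', ...).
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.BatchNorm1d(20),
    nn.Linear(20, 2),
)

# After this commit, from_model checks each module's class name against the
# shared LayerCollection._known_modules list instead of a duplicated literal.
# ignore_unsupported_layers=True lets the Sequential container itself pass
# through without being added or raising.
lc = LayerCollection.from_model(model, ignore_unsupported_layers=True)

# lc.layers is an OrderedDict, so this prints the number of registered layers.
print(len(lc.layers))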
