Fix bug that could lead to duplicate params (#783)
When a net's module references another of that net's modules, the former
will also yield the latter's parameters. As a result, when all parameters
are collected, the latter's parameters appear twice.

This bugfix remembers every yielded parameter in a set and skips
parameters that have already been encountered.

This bug was introduced very recently (#751) and should occur very
rarely, since modules don't typically reference each other.
BenjaminBossan committed Jun 23, 2021
1 parent 1165a78 commit 906932c
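
For illustration, here is a minimal standalone sketch of the duplication described in the commit message above (not part of the commit; the class and variable names are made up). PyTorch registers any nn.Module assigned as an attribute of another nn.Module as a submodule, so a criterion that keeps a reference to the net's module re-exposes that module's parameters under a different name:

from torch import nn

module = nn.Linear(2, 2)

class ReferencingCriterion(nn.NLLLoss):
    """Criterion that keeps a reference to the module, as in the test below."""
    def __init__(self, themodule):
        super().__init__()
        # Attribute assignment registers `themodule` as a submodule, so its
        # parameters are also yielded by criterion.named_parameters().
        self.themodule = themodule

criterion = ReferencingCriterion(module)

named_params = list(module.named_parameters()) + list(criterion.named_parameters())
# The same tensors appear twice, under different names: 'weight'/'bias' from
# the module and 'themodule.weight'/'themodule.bias' from the criterion.
params = [p for _, p in named_params]
assert len(params) == 4
assert len({id(p) for p in params}) == 2  # only two distinct tensors

This duplicated listing is exactly what the set-based filter in skorch/net.py below guards against.
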
Showing 2 changed files with 38 additions and 2 deletions.
17 changes: 15 additions & 2 deletions skorch/net.py
@@ -758,11 +758,24 @@ def initialize_optimizer(self, *args, **kwargs):
         pass the named parameters to :meth:`.get_params_for_optimizer`.
         """
+        # Note: we have to filter out potential duplicate parameters. This can
+        # happen when a module references another module (e.g. the criterion
+        # references the module), thus yielding that module's parameters again.
+        # The parameter name can be different, which is why we check only the
+        # identity of the parameter itself.
+        seen = set()
         for name in self._modules + self._criteria:
             module = getattr(self, name + '_')
             named_parameters = getattr(module, 'named_parameters', None)
-            if named_parameters:
-                yield from named_parameters()
+            if not named_parameters:
+                continue
+
+            for param_name, param in named_parameters():
+                if param in seen:
+                    continue
+
+                seen.add(param)
+                yield param_name, param
 
     def _initialize_optimizer(self, reason=None):
         with self._current_init_context('optimizer'):
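
The comment in the hunk above points out that duplicates must be detected by the parameter object itself rather than by its name, because the same tensor is yielded under a different name by each module that exposes it. Below is a minimal sketch of that identity-based de-duplication outside of skorch (hypothetical function and variable names; it relies on torch tensors hashing by object identity, which is also what makes the `seen` set in the commit work):

from torch import nn

def unique_named_parameters(*modules):
    # De-duplicate by the parameter object, not by its name: the same tensor
    # can show up as 'weight' in one module and as 'sub.weight' in another.
    seen = set()
    for m in modules:
        for name, param in m.named_parameters():
            if param in seen:
                continue
            seen.add(param)
            yield name, param

module = nn.Linear(2, 2)
wrapper = nn.Module()
wrapper.sub = module  # the wrapper now re-exposes the module's parameters

params = list(unique_named_parameters(module, wrapper))
assert len(params) == 2  # weight and bias, each yielded exactly once
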
23 changes: 23 additions & 0 deletions skorch/tests/test_net.py
@@ -2707,6 +2707,29 @@ def initialize_optimizer(self):
         # module is not re-initialized, since virtual parameter
         assert len(side_effects) == 1
 
+    def test_module_referencing_another_module_no_duplicate_params(
+            self, net_cls, module_cls
+    ):
+        # When a module references another module, it will yield that module's
+        # parameters. Therefore, if we collect all parameters, we have to make
+        # sure that there are no duplicate parameters.
+        class MyCriterion(torch.nn.NLLLoss):
+            """Criterion that references net.module_"""
+            def __init__(self, *args, themodule, **kwargs):
+                super().__init__(*args, **kwargs)
+                self.themodule = themodule
+
+        class MyNet(net_cls):
+            def initialize_criterion(self):
+                kwargs = self.get_params_for('criterion')
+                kwargs['themodule'] = self.module_
+                self.criterion_ = self.criterion(**kwargs)
+                return self
+
+        net = MyNet(module_cls, criterion=MyCriterion).initialize()
+        params = [p for _, p in net.get_all_learnable_params()]
+        assert len(params) == len(set(params))
+
     def test_custom_optimizer_lr_is_associated_with_optimizer(
             self, net_cls, module_cls,
     ):
