Skip to content

Commit

Permalink
Fix a bug that led to double-registration (#781)
Browse files Browse the repository at this point in the history
After cloning a net, _module, _criteria, and _optimizers are already
populated. Then, when loading params, there is yet another registration,
i.e. a double registration. As a consequence, there would be two
'modules' etc. This is a fix for that.
  • Loading branch information
BenjaminBossan committed Jun 20, 2021
1 parent 852383e commit 1165a78
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 4 deletions.
13 changes: 9 additions & 4 deletions skorch/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -1792,7 +1792,10 @@ def get_params(self, deep=True, **kwargs):
# special treatment.
params_cb = self._get_params_callbacks(deep=deep)
params.update(params_cb)
return params

# don't include the following attributes
to_exclude = {'_modules', '_criteria', '_optimizers'}
return {key: val for key, val in params.items() if key not in to_exclude}

def _check_kwargs(self, kwargs):
"""Check argument names passed at initialization.
Expand Down Expand Up @@ -2095,11 +2098,13 @@ def _register_attribute(
self.cuda_dependent_attributes_ = (
self.cuda_dependent_attributes_[:] + [name + '_'])

if self.init_context_ == 'module':
# make sure to not double register -- this should never happen, but
# still better to check
if (self.init_context_ == 'module') and (name not in self._modules):
self._modules = self._modules[:] + [name]
elif self.init_context_ == 'criterion':
elif (self.init_context_ == 'criterion') and (name not in self._criteria):
self._criteria = self._criteria[:] + [name]
elif self.init_context_ == 'optimizer':
elif (self.init_context_ == 'optimizer') and (name not in self._optimizers):
self._optimizers = self._optimizers[:] + [name]

def _unregister_attribute(
Expand Down
47 changes: 47 additions & 0 deletions skorch/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,44 @@ def test_save_load_state_dict_str(
score_after = accuracy_score(y, net.predict(X))
assert np.isclose(score_after, score_before)

def test_save_load_state_dict_no_duplicate_registration_after_initialize(
self, net_cls, module_cls, net_fit, tmpdir):
# #781
net = net_cls(module_cls).initialize()

p = tmpdir.mkdir('skorch').join('testmodel.pkl')
with open(str(p), 'wb') as f:
net_fit.save_params(f_params=f)
del net_fit

with open(str(p), 'rb') as f:
net.load_params(f_params=f)

# check that there are no duplicates in _modules, _criteria, _optimizers
# pylint: disable=protected-access
assert net._modules == ['module']
assert net._criteria == ['criterion']
assert net._optimizers == ['optimizer']

def test_save_load_state_dict_no_duplicate_registration_after_clone(
self, net_fit, tmpdir):
# #781
net = clone(net_fit).initialize()

p = tmpdir.mkdir('skorch').join('testmodel.pkl')
with open(str(p), 'wb') as f:
net_fit.save_params(f_params=f)
del net_fit

with open(str(p), 'rb') as f:
net.load_params(f_params=f)

# check that there are no duplicates in _modules, _criteria, _optimizers
# pylint: disable=protected-access
assert net._modules == ['module']
assert net._criteria == ['criterion']
assert net._optimizers == ['optimizer']

@pytest.fixture(scope='module')
def net_fit_adam(self, net_cls, module_cls, data):
net = net_cls(
Expand Down Expand Up @@ -1426,6 +1464,15 @@ def test_get_params_works(self, net_cls, module_cls):
# now initialized
assert 'callbacks__myscore__scoring' in params

def test_get_params_no_unwanted_params(self, net, net_fit):
# #781
# make sure certain keys are not returned
keys_unwanted = {'_modules', '_criteria', '_optimizers'}
for net_ in (net, net_fit):
keys_found = set(net_.get_params())
overlap = keys_found & keys_unwanted
assert not overlap

def test_get_params_with_uninit_callbacks(self, net_cls, module_cls):
from skorch.callbacks import EpochTimer

Expand Down

0 comments on commit 1165a78

Please sign in to comment.