Skip to content

Commit

Permalink
Merge pull request #118 from dnouri/issues/optimizer-initialization
Browse files Browse the repository at this point in the history
make sure model is initialized before optim init
  • Loading branch information
benjamin-work committed Nov 20, 2017
2 parents 25264a5 + 9fc3801 commit 61815df
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
6 changes: 6 additions & 0 deletions skorch/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,12 @@ def set_params(self, **kwargs):
self.initialize_module()
self.initialize_optimizer()
if any(key.startswith('optimizer') for key in special_params):
# Model selectors such as GridSearchCV will set the
# parameters before .initialize() is called, therefore we
# need to make sure that we have an initialized model here
# as the optimizer depends on it.
if not hasattr(self, 'module_'):
self.initialize_module()
self.initialize_optimizer()

return self
Expand Down
6 changes: 6 additions & 0 deletions skorch/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,12 @@ def test_changing_model_reinitializes_optimizer(self, net, data):
# by the optimizer after 10 epochs.
assert (abs(d2 - d1) > 1e-05).all()

def test_setting_optimizer_needs_model(self, net_cls, module_cls):
net = net_cls(module_cls)
assert not hasattr(net, 'module_')
# should not break
net.set_params(optimizer=torch.optim.SGD)

def test_module_params_in_init(self, net_cls, module_cls, data):
X, y = data

Expand Down

0 comments on commit 61815df

Please sign in to comment.