Skip to content

Commit

Permalink
Fix bug that occurred when partialed module was passed.
Browse files Browse the repository at this point in the history
  • Loading branch information
benjamin-work authored and ottonemo committed Oct 27, 2017
1 parent 59e6f23 commit 2149a5f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
2 changes: 1 addition & 1 deletion skorch/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def initialize_module(self):
"""
kwargs = self._get_params_for('module')
module = self.module
is_initialized = not isinstance(module, type)
is_initialized = isinstance(module, torch.nn.Module)

if kwargs or not is_initialized:
if is_initialized:
Expand Down
5 changes: 5 additions & 0 deletions skorch/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,11 @@ def test_module_params_in_init(self, net_cls, module_cls, data):
assert net.module_.dense1.in_features == 20
assert net.module_.nonlin is F.tanh

def test_module_initialized_with_partial_module(self, net_cls, module_cls):
net = net_cls(partial(module_cls, num_units=123))
net.initialize()
assert net.module_.dense0.out_features == 123

def test_criterion_init_with_params(self, net_cls, module_cls):
mock = Mock()
net = net_cls(module_cls, criterion=mock, criterion__spam='eggs')
Expand Down

0 comments on commit 2149a5f

Please sign in to comment.