Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/dnouri/inferno
Browse files Browse the repository at this point in the history
  • Loading branch information
ottonemo committed Oct 30, 2017
2 parents 2f8aa27 + 2149a5f commit 708dd8d
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 1 deletion.
12 changes: 12 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,20 @@
:width: 30%
======

|docs|

A scikit-learn compatible neural network library that wraps pytorch.

.. |docs| image:: https://readthedocs.org/projects/skorch/badge/?version=latest
:alt: Documentation Status
:scale: 100%
:target: https://skorch.readthedocs.io/en/latest/?badge=latest

Resources:

- `Documentation <https://skorch.readthedocs.io/en/latest/?badge=latest>`_
- `Source Code <https://github.com/dnouri/skorch/>`_

Example
-------

Expand Down
2 changes: 1 addition & 1 deletion skorch/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def initialize_module(self):
"""
kwargs = self._get_params_for('module')
module = self.module
is_initialized = not isinstance(module, type)
is_initialized = isinstance(module, torch.nn.Module)

if kwargs or not is_initialized:
if is_initialized:
Expand Down
5 changes: 5 additions & 0 deletions skorch/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,11 @@ def test_module_params_in_init(self, net_cls, module_cls, data):
assert net.module_.dense1.in_features == 20
assert net.module_.nonlin is F.tanh

def test_module_initialized_with_partial_module(self, net_cls, module_cls):
net = net_cls(partial(module_cls, num_units=123))
net.initialize()
assert net.module_.dense0.out_features == 123

def test_criterion_init_with_params(self, net_cls, module_cls):
mock = Mock()
net = net_cls(module_cls, criterion=mock, criterion__spam='eggs')
Expand Down

0 comments on commit 708dd8d

Please sign in to comment.