Skip to content

Commit

Permalink
Fix a bug that was caused when dataset was partialed.
Browse files Browse the repository at this point in the history
  • Loading branch information
benjamin-work authored and ottonemo committed Oct 27, 2017
1 parent 6bc0007 commit 47e7f6f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
5 changes: 2 additions & 3 deletions skorch/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@ def get_dataset(self, X, y=None):
"""
dataset = self.dataset
is_initialized = not isinstance(dataset, type)
is_initialized = not callable(dataset)

kwargs = self._get_params_for('dataset')
if kwargs and is_initialized:
Expand All @@ -819,8 +819,7 @@ def get_dataset(self, X, y=None):
if 'use_cuda' not in kwargs:
kwargs['use_cuda'] = self.use_cuda

dataset = self.dataset(X, y, **kwargs)
return dataset
return dataset(X, y, **kwargs)

def get_iterator(self, dataset, train=False):
"""Get an iterator that allows to loop over the batches of the
Expand Down
10 changes: 10 additions & 0 deletions skorch/tests/test_net.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for net.py"""

from functools import partial
import pickle
from unittest.mock import Mock
from unittest.mock import patch
Expand Down Expand Up @@ -688,6 +689,15 @@ def test_net_initialized_with_initalized_dataset(
)
net.fit(*data) # does not raise

def test_net_initialized_with_partialed_dataset(
self, net_cls, module_cls, data, dataset_cls):
net = net_cls(
module_cls,
dataset=partial(dataset_cls, use_cuda=0),
max_epochs=1,
)
net.fit(*data) # does not raise

def test_net_initialized_with_initalized_dataset_and_kwargs_raises(
self, net_cls, module_cls, data, dataset_cls):
net = net_cls(
Expand Down

0 comments on commit 47e7f6f

Please sign in to comment.