Skip to content
Browse files

Move criterion to compute device automatically (#455)

Move criterion to compute device automatically

Previously the user had to make sure that criterion parameters
such as class weights are on the correct computing device.
Now the criterion is moved to the compute device using `.to`
which moves the parameters as well.

Non nn.Module classes still work as criterion.
  • Loading branch information...
ottonemo authored and BenjaminBossan committed Apr 15, 2019
1 parent ea28546 commit 81ae225e976ec5343c5d50d9dc4369d3621e13e8
Showing with 31 additions and 0 deletions.
  1. +1 −0
  2. +2 −0 skorch/
  3. BIN skorch/tests/net_cuda.pkl
  4. +28 −0 skorch/tests/
@@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](
previously `"criterion_"` would not match `net.criterion__weight` as set by
- skorch pickle format changed in order to improve CUDA compatibility, if you have pickled models, please re-pickle them to be able to load them in the future
- `net.criterion_` and its parameters are now moved to target device when using criteria that inherit from `torch.nn.Module`. Previously the user had to make sure that parameters such as class weight are on the compute device

### Fixed

@@ -434,6 +434,8 @@ def initialize_criterion(self):
"""Initializes the criterion."""
criterion_params = self._get_params_for('criterion')
self.criterion_ = self.criterion(**criterion_params)
if isinstance(self.criterion_, torch.nn.Module):
self.criterion_ =
return self

def _format_reinit_msg(self, name, kwargs=None, triggered_directly=True):
BIN +1.28 KB (100%) skorch/tests/net_cuda.pkl
Binary file not shown.
@@ -1031,6 +1031,34 @@ def test_criterion_set_params(self, net_cls, module_cls):
assert mock.call_count == 2
assert mock.call_args_list[1][1]['spam'] == 'eggs'

def test_criterion_non_module(self, net_cls, module_cls, data):
# test non-nn.Module classes passed as criterion
class SimpleCriterion:
def __call__(self, y_pred, y_true):
return y_pred.mean()

net = net_cls(module_cls, criterion=SimpleCriterion)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device")
@pytest.mark.parametrize('device', ['cpu', 'cuda'])
def test_criterion_params_on_device(self, net_cls, module_cls, device):
# attributes like criterion.weight should be automatically moved
# to the Net's device.
criterion = torch.nn.NLLLoss
weight = torch.ones(2)
net = net_cls(

assert weight.device.type == 'cpu'
assert net.criterion_.weight.device.type == device

def test_callback_with_name_init_with_params(self, net_cls, module_cls):
mock = Mock()
net = net_cls(

0 comments on commit 81ae225

Please sign in to comment.
You can’t perform that action at this time.