Move criterion to compute device automatically (#455)

Previously, the user had to make sure that criterion parameters
such as class weights were on the correct compute device.
Now the criterion is moved to the compute device using `.to`,
which moves its parameters as well.

Non-`nn.Module` classes still work as criteria.
ottonemo authored and BenjaminBossan committed Apr 15, 2019
1 parent ea28546 commit 81ae225e976ec5343c5d50d9dc4369d3621e13e8
Showing with 31 additions and 0 deletions.
  1. +1 −0
  2. +2 −0 skorch/
  3. BIN skorch/tests/net_cuda.pkl
  4. +28 −0 skorch/tests/
@@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](
previously `"criterion_"` would not match `net.criterion__weight` as set by
- The skorch pickle format changed in order to improve CUDA compatibility. If you have pickled models, please re-pickle them to be able to load them in the future.
- `net.criterion_` and its parameters are now moved to the target device when using criteria that inherit from `torch.nn.Module`. Previously, the user had to make sure that parameters such as class weights were on the compute device.

### Fixed

@@ -434,6 +434,8 @@ def initialize_criterion(self):
"""Initializes the criterion."""
criterion_params = self._get_params_for('criterion')
self.criterion_ = self.criterion(**criterion_params)
if isinstance(self.criterion_, torch.nn.Module):
self.criterion_ =
return self

def _format_reinit_msg(self, name, kwargs=None, triggered_directly=True):
@@ -1031,6 +1031,34 @@ def test_criterion_set_params(self, net_cls, module_cls):
assert mock.call_count == 2
assert mock.call_args_list[1][1]['spam'] == 'eggs'

def test_criterion_non_module(self, net_cls, module_cls, data):
# test non-nn.Module classes passed as criterion
class SimpleCriterion:
def __call__(self, y_pred, y_true):
return y_pred.mean()

net = net_cls(module_cls, criterion=SimpleCriterion)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device")
@pytest.mark.parametrize('device', ['cpu', 'cuda'])
def test_criterion_params_on_device(self, net_cls, module_cls, device):
# attributes like criterion.weight should be automatically moved
# to the Net's device.
criterion = torch.nn.NLLLoss
weight = torch.ones(2)
net = net_cls(

assert weight.device.type == 'cpu'
assert net.criterion_.weight.device.type == device

def test_callback_with_name_init_with_params(self, net_cls, module_cls):
mock = Mock()
net = net_cls(

