Skip to content

Commit

Permalink
Nets now can be loaded even if device=None (#876)
Browse files Browse the repository at this point in the history
In #600, we introduced the option to set device=None, which means that
skorch should not move any data to any device. However, this setting
introduced a bug when trying to load the net, as that code didn't
explicitly deal with device=None. With this PR, the bug is fixed.

Implementation

If device=None, we really have no way of knowing what device to map the
parameters to. The most reasonable thing to do is to use the fallback,
which is 'cpu'.
  • Loading branch information
BenjaminBossan committed Jul 28, 2022
1 parent eecf94f commit ef23ce4
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed
- Fix a bug in `SliceDataset` that prevented it to be used with `to_numpy` (#858)
- Fix a bug that occurred when loading a net that has device set to None

## [0.11.0] - 2021-10-11

Expand Down
9 changes: 9 additions & 0 deletions skorch/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -2413,6 +2413,15 @@ def _check_device(self, requested_device, map_device):
return the map device if it differs from the requested device
along with a warning.
"""
if requested_device is None:
# user has set net.device=None, we don't know the type, use fallback
msg = (
f"Setting self.device = {map_device} since the requested device "
f"was not specified"
)
warnings.warn(msg, DeviceWarning)
return map_device

type_1 = torch.device(requested_device)
type_2 = torch.device(map_device)
if type_1 != type_2:
Expand Down
16 changes: 16 additions & 0 deletions skorch/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,22 @@ def test_pickle_save_load(self, net_pickleable, data, tmpdir):
score_after = accuracy_score(y, net_new.predict(X))
assert np.isclose(score_after, score_before)

def test_pickle_save_load_device_is_none(self, net_pickleable):
# It is legal to set device=None, but in that case we cannot know what
# device was meant, so we should fall back to CPU.
from skorch.exceptions import DeviceWarning

net_pickleable.set_params(device=None)
msg = (
f"Setting self.device = cpu since the requested device "
f"was not specified"
)
with pytest.warns(DeviceWarning, match=msg):
net_loaded = pickle.loads(pickle.dumps(net_pickleable))

params = net_loaded.get_all_learnable_params()
assert all(param.device.type == 'cpu' for _, param in params)

def train_picklable_cuda_net(self, net_pickleable, data):
X, y = data
w = torch.FloatTensor([1.] * int(y.max() + 1)).to('cuda')
Expand Down

0 comments on commit ef23ce4

Please sign in to comment.