Loading extra arguments w/ cuda dependency on CPU (#880)
* Loading extra arguments w/ cuda dependency on CPU

Supersedes #877

This bug could occur if a user has set a parameter with a CUDA
dependency and then tries to load the net without CUDA. Now, this
works (again) as expected.

Underlying reason

The problem started occurring after PR #751, which introduced storing
parameters set via set_params in the private attribute _kwargs.
Normally, for attributes, we make sure that they can be loaded without
CUDA, but attributes within _kwargs were not checked. Thus, loading
those without CUDA failed. Unfortunately, this was not caught by CI
because CI is not CUDA-enabled.

The bugfix consists of making sure that we don't store any values in
_kwargs. Since only the keys (parameter names) are needed, not the
values, this is more efficient anyway. Thus, there are no more possibly
CUDA-dependent values that can "slip through".
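
The idea of keeping only the parameter names can be sketched roughly as follows. This is a minimal illustration, not the actual skorch implementation: `ToyEstimator` and its arguments are hypothetical, and only the two lines inside `__init__` mirror what the diff shows.

```python
class ToyEstimator:
    """Hypothetical stand-in for a net; not skorch's actual class."""

    def __init__(self, module, **kwargs):
        self.module = module
        # Record only the *names* of the extra parameters so that they
        # can be validated later.
        self._params_to_validate = set(kwargs.keys())
        # The values are stored exactly once, as regular attributes.
        vars(self).update(kwargs)


net = ToyEstimator(
    module=None,
    optimizer__momentum=0.9,
    callbacks__print_log__sink=print,
)
print(sorted(net._params_to_validate))
print(net.optimizer__momentum)
```

Because the set holds plain strings, it cannot contain a device-dependent object, whatever the user passed as a value.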

After discussion, we decided to also rename the attribute, as _kwargs
was not very specific. The new attribute is called _params_to_validate
and it is a set instead of a dict. Also, the _check_kwargs method was
renamed to _validate_params and it doesn't take a kwargs argument
anymore. And on top of that, I changed the raised error from TypeError
to ValueError.

The reason for making this change is that the method is now similar to
sklearn's _validate_params method on BaseEstimator (same signature and
same error type). However, we don't make use of the actual sklearn
machinery since our validation does a few things differently (e.g.
proposing possible fixes when the name is wrong).
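
To illustrate what "proposing possible fixes" can look like, here is a simplified sketch. The function `validate_param_names` and its arguments `known_names` and `prefixes` are hypothetical and much cruder than the real method; only the wording of the suggestion message follows the one shown in the diff.

```python
def validate_param_names(params_to_validate, known_names, prefixes):
    """Hypothetical helper: raise ValueError for names that look wrong."""
    msgs = []
    for key in sorted(params_to_validate):
        if key.endswith('_'):
            continue  # names with a trailing underscore are skipped
        if key in known_names:
            continue
        if '__' in key and key.split('__', 1)[0] in prefixes:
            continue  # well-formed "prefix__name"
        # Did the user write "prefix_name" instead of "prefix__name"?
        for prefix in sorted(prefixes, key=len, reverse=True):
            if key.startswith(prefix + '_'):
                suggestion = prefix + '__' + key[len(prefix) + 1:]
                msgs.append(
                    "Got an unexpected argument {}, did you mean {}?"
                    .format(key, suggestion))
                break
        else:
            msgs.append("Got an unexpected argument {}.".format(key))
    if msgs:
        raise ValueError('\n'.join(msgs))


message = None
try:
    validate_param_names(
        {'iterator_train_shuffle'},
        known_names={'lr', 'max_epochs'},
        prefixes={'iterator_train', 'optimizer'},
    )
except ValueError as exc:
    message = str(exc)
    print(message)
```

Iterating over the sorted names keeps the order of the error messages deterministic, which matters now that the names live in a set rather than a dict.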

As the attribute was renamed, we would normally get an error when
unpickling nets stored with the old attribute. To prevent this, we catch
the old attribute _kwargs and convert it to the new attribute
_params_to_validate.
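
The conversion of the old attribute can be sketched with a toy class. `LegacyAwareNet` is a hypothetical stand-in, and the error message is shortened. The sketch uses `copy.copy` to trigger `__setstate__`, on the assumption that this is close enough to unpickling for illustration purposes (`pickle.loads` also passes the stored state to `__setstate__`).

```python
import copy


class LegacyAwareNet:
    """Hypothetical stand-in for a net; not skorch's actual class."""

    def __init__(self, **kwargs):
        self._params_to_validate = set(kwargs.keys())
        vars(self).update(kwargs)

    def __setstate__(self, state):
        state = dict(state)  # work on a copy, leave the caller's dict alone
        if '_kwargs' in state:
            if '_params_to_validate' in state:
                raise ValueError(
                    "State contains both _kwargs and _params_to_validate")
            # Old-style state: keep only the keys, drop the values.
            old_kwargs = state.pop('_kwargs')
            state['_params_to_validate'] = set(old_kwargs.keys())
        self.__dict__.update(state)


# Emulate an object that was created before the rename.
old_style = LegacyAwareNet(foo=123)
del old_style._params_to_validate
old_style._kwargs = {'foo': 123, 'bar__baz': 456}

loaded = copy.copy(old_style)  # calls __setstate__ with the old state
print(sorted(loaded._params_to_validate))
print(hasattr(loaded, '_kwargs'))
```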

Coincidental changes

- Moved an entry in CHANGES.md to a different section
- Added a reference to an existing entry in CHANGES.md
- Adapted the code in hf.py to use the same new scheme

* Add TODO comment for removing transition code

Give a one-year grace period during which old skorch models can still
be loaded with the new version.
BenjaminBossan committed Sep 5, 2022
1 parent 862f205 commit 078f5c5
Showing 5 changed files with 95 additions and 47 deletions.
5 changes: 3 additions & 2 deletions CHANGES.md
@@ -11,16 +11,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `load_best` attribute to `EarlyStopping` callback to automatically load module weights of the best result at the end of training
- Added a method, `trim_for_prediction`, on the net classes, which trims the net from everything not required for using it for prediction; call this after fitting to reduce the size of the net
- Added experimental support for [huggingface accelerate](https://github.com/huggingface/accelerate); use the provided mixin class to add advanced training capabilities provided by the accelerate library to skorch
- Add integration for Huggingface tokenizers; use `skorch.hf.HuggingfaceTokenizer` to train a Huggingface tokenizer on your custom data; use `skorch.hf.HuggingfacePretrainedTokenizer` to load a pre-trained Huggingface tokenizer

### Changed
- The minimum required scikit-learn version has been bumped to 0.22.0
- Initialize data loaders for training and validation dataset once per fit call instead of once per epoch ([migration guide](https://skorch.readthedocs.io/en/stable/user/FAQ.html#migration-from-0-11-to-0-12))
- It is now possible to call `np.asarray` with `SliceDataset`s (#858)
- Add integration for Huggingface tokenizers; use `skorch.hf.HuggingfaceTokenizer` to train a Huggingface tokenizer on your custom data; use `skorch.hf.HuggingfacePretrainedTokenizer` to load a pre-trained Huggingface tokenizer

### Fixed
- Fix a bug in `SliceDataset` that prevented it to be used with `to_numpy` (#858)
- Fix a bug that occurred when loading a net that has device set to None
- Fix a bug that occurred when loading a net that has device set to None (#876)
- Fix a bug that in some cases could prevent loading a net that was trained with CUDA without CUDA

## [0.11.0] - 2021-10-11

38 changes: 18 additions & 20 deletions skorch/hf.py
@@ -350,30 +350,25 @@ def __init__(
self.pad_token = pad_token
self.verbose = verbose

self._kwargs = kwargs
self._params_to_validate = set(kwargs.keys())
vars(self).update(kwargs)

def _check_kwargs(self, kwargs):
def _validate_params(self):
"""Check argument names passed at initialization.
Raises
------
TypeError
Raises a TypeError if one or more arguments don't seem to
ValueError
Raises a ValueError if one or more arguments don't seem to
match or are malformed.
Returns
-------
kwargs: dict
Return the passed keyword arguments.
"""
# This whole method is taken from NeuralNet

# check for wrong arguments
unexpected_kwargs = []
missing_dunder_kwargs = []
for key in kwargs:
for key in sorted(self._params_to_validate):
if key.endswith('_'):
continue

@@ -406,9 +401,7 @@ def _check_kwargs(self, kwargs):

if msgs:
full_msg = '\n'.join(msgs)
raise TypeError(full_msg)

return kwargs
raise ValueError(full_msg)

def initialized_instance(self, instance_or_cls, kwargs):
"""Return an instance initialized with the given parameters
@@ -517,7 +510,7 @@ def initialize_trainer(self):

def initialize(self):
"""Initialize the individual tokenizer components"""
self._check_kwargs(self._kwargs)
self._validate_params()

model = self.initialize_model()
tokenizer = self.initialize_tokenizer(model)
@@ -586,7 +579,12 @@ def __getstate__(self):

def get_params(self, deep=False):
params = super().get_params(deep=deep)
params.update(self._kwargs)
if deep:
for key in self._params_to_validate:
# We cannot assume that the attribute is already set because
# sklearn's set_params calls get_params first.
if hasattr(self, key):
params[key] = getattr(self, key)
return params

def set_params(self, **kwargs):
@@ -605,10 +603,10 @@ def set_params(self, **kwargs):
for key, val in kwargs.items():
if any(key.startswith(prefix) for prefix in self.prefixes_):
special_params[key] = val
self._kwargs[key] = val
self._params_to_validate.add(key)
elif '__' in key:
special_params[key] = val
self._kwargs[key] = val
self._params_to_validate.add(key)
else:
normal_params[key] = val

@@ -630,7 +628,7 @@ def set_params(self, **kwargs):
return self

# if transformer is initialized, checking kwargs is possible
self._check_kwargs(self._kwargs)
self._validate_params()

# Re-initializing of tokenizer necessary
self.initialize()
@@ -907,8 +905,8 @@ def __init__(
)
self.accelerator = accelerator

def _check_kwargs(self, kwargs):
super()._check_kwargs(kwargs)
def _validate_params(self):
super()._validate_params()

if self.accelerator.device_placement and (self.device is not None):
raise ValueError(
55 changes: 35 additions & 20 deletions skorch/net.py
@@ -307,7 +307,7 @@ def __init__(
initialized = kwargs.pop('initialized_', False)
virtual_params = kwargs.pop('virtual_params_', dict())

self._kwargs = kwargs
self._params_to_validate = set(kwargs.keys())
vars(self).update(kwargs)

self.history_ = history
@@ -822,7 +822,7 @@ def initialize(self):
self._initialize_optimizer()
self._initialize_history()

self._check_kwargs(self._kwargs)
self._validate_params()

self.initialized_ = True
return self
@@ -1896,30 +1896,29 @@ def get_params(self, deep=True, **kwargs):
to_exclude = {'_modules', '_criteria', '_optimizers'}
return {key: val for key, val in params.items() if key not in to_exclude}

def _check_kwargs(self, kwargs):
def _validate_params(self):
"""Check argument names passed at initialization.
Note: This method is similar to
:meth:`sklearn.base.BaseEstimator._validate_params` but doesn't use its
machinery.
Raises
------
TypeError
Raises a TypeError if one or more arguments don't seem to
ValueError
Raises a ValueError if one or more arguments don't seem to
match or are malformed.
Returns
-------
kwargs: dict
Return the passed keyword arguments.
Example
-------
>>> net = NeuralNetClassifier(MyModule, iterator_train_shuffle=True)
TypeError: Got an unexpected argument iterator_train_shuffle,
ValueError: Got an unexpected argument iterator_train_shuffle,
did you mean iterator_train__shuffle?
"""
# warn about usage of iterator_valid__shuffle=True, since this
# is almost certainly not what the user wants
if kwargs.get('iterator_valid__shuffle'):
if 'iterator_valid__shuffle' in self._params_to_validate:
warnings.warn(
"You set iterator_valid__shuffle=True; this is most likely not "
"what you want because the values returned by predict and "
@@ -1929,7 +1928,7 @@ def _check_kwargs(self, kwargs):
# check for wrong arguments
unexpected_kwargs = []
missing_dunder_kwargs = []
for key in kwargs:
for key in sorted(self._params_to_validate):
if key.endswith('_'):
continue

@@ -1963,9 +1962,7 @@ def _check_kwargs(self, kwargs):

if msgs:
full_msg = '\n'.join(msgs)
raise TypeError(full_msg)

return kwargs
raise ValueError(full_msg)

def _check_deprecated_params(self, **kwargs):
pass
@@ -1989,13 +1986,13 @@ def set_params(self, **kwargs):
virtual_params[key] = val
elif key.startswith('callbacks'):
cb_params[key] = val
self._kwargs[key] = val
self._params_to_validate.add(key)
elif any(key.startswith(prefix) for prefix in self.prefixes_):
special_params[key] = val
self._kwargs[key] = val
self._params_to_validate.add(key)
elif '__' in key:
special_params[key] = val
self._kwargs[key] = val
self._params_to_validate.add(key)
else:
normal_params[key] = val

@@ -2024,7 +2021,7 @@ def set_params(self, **kwargs):
return self

# if net is initialized, checking kwargs is possible
self._check_kwargs(self._kwargs)
self._validate_params()

######################################################
# Below: Re-initialize parts of the net if necessary #
@@ -2138,6 +2135,24 @@ def __getstate__(self):
return state

def __setstate__(self, state):
# TODO remove after 2023-09
# in skorch 0.11 -> 0.12, we made a change to parameter validation. We
# don't store key/vals in self._kwargs anymore, as the values were
# redundant and were not considered as possibly CUDA dependent. Instead,
# we now use the attribute '_params_to_validate', which only stores
# keys. The code below is to make the net backwards compatible.
if '_kwargs' in state:
if '_params_to_validate' in state:
# there should not be _kwargs AND _params_to_validate
raise ValueError(
"Something went wrong here. Please open an issue on "
"https://github.com/skorch-dev/skorch/issues detailing what "
"caused this error and the used skorch version."
)
kwargs = state.pop('_kwargs')
params_to_validate = set(kwargs.keys())
state['_params_to_validate'] = params_to_validate

# get_map_location will automatically choose the
# right device in cases where CUDA is not available.
map_location = get_map_location(state['device'])
Binary file modified skorch/tests/net_cuda.pkl
Binary file not shown.
44 changes: 39 additions & 5 deletions skorch/tests/test_net.py
@@ -171,7 +171,7 @@ def test_train_net_after_copy(self, net_cls, module_cls, data,
assert param is opt_param

def test_net_init_one_unknown_argument(self, net_cls, module_cls):
with pytest.raises(TypeError) as e:
with pytest.raises(ValueError) as e:
net_cls(module_cls, unknown_arg=123).initialize()

expected = ("__init__() got unexpected argument(s) unknown_arg. "
@@ -181,7 +181,7 @@ def test_net_init_one_unknown_argument(self, net_cls, module_cls):
assert e.value.args[0] == expected

def test_net_init_two_unknown_arguments(self, net_cls, module_cls):
with pytest.raises(TypeError) as e:
with pytest.raises(ValueError) as e:
net_cls(module_cls, lr=0.1, mxa_epochs=5,
warm_start=False, bathc_size=20).initialize()

@@ -202,7 +202,7 @@ def test_net_init_two_unknown_arguments(self, net_cls, module_cls):
def test_net_init_missing_dunder_in_prefix_argument(
self, net_cls, module_cls, name, suggestion):
# forgot to use double-underscore notation
with pytest.raises(TypeError) as e:
with pytest.raises(ValueError) as e:
net_cls(module_cls, **{name: 123}).initialize()

tmpl = "Got an unexpected argument {}, did you mean {}?"
@@ -212,7 +212,7 @@ def test_net_init_missing_dunder_in_prefix_argument(
def test_net_init_missing_dunder_in_2_prefix_arguments(
self, net_cls, module_cls):
# forgot to use double-underscore notation in 2 arguments
with pytest.raises(TypeError) as e:
with pytest.raises(ValueError) as e:
net_cls(
module_cls,
max_epochs=7, # correct
@@ -228,7 +228,7 @@ def test_net_init_missing_dunder_in_2_prefix_arguments(
def test_net_init_missing_dunder_and_unknown(
self, net_cls, module_cls):
# unknown argument and forgot to use double-underscore notation
with pytest.raises(TypeError) as e:
with pytest.raises(ValueError) as e:
net_cls(
module_cls,
foobar=123,
@@ -435,6 +435,40 @@ def test_pickle_load(self, cuda_available, pickled_cuda_net_path):
with open(pickled_cuda_net_path, 'rb') as f:
pickle.load(f)

def test_load_net_with_kwargs_attribute_to_net_without(self, net_pickleable):
# TODO remove after 2023-09
# in skorch 0.11 -> 0.12, we made a change to parameter validation. We
# don't store key/vals in self._kwargs anymore, as the values were
# redundant and were not considered as possibly CUDA dependent, which
# can cause errors when loading to CPU. Since we remove one attribute
# and add a new one ('_params_to_validate'), we have to take extra steps
# to ensure that old models can still be loaded correctly.

# emulate old net:
del net_pickleable._params_to_validate
net_pickleable._kwargs = {'foo': 123, 'bar__baz': 456}

# after loading, behaves like new net
net_loaded = pickle.loads(pickle.dumps(net_pickleable))
assert net_loaded._params_to_validate == {'foo', 'bar__baz'}
assert not hasattr(net_loaded, '_kwargs')

def test_load_net_with_both_kwargs_and_params_to_validate_attributes_raises(
self, net_pickleable
):
# TODO remove after 2023-09
# Check test_load_net_with_kwargs_attribute_to_net_without for more
# details
net_pickleable._kwargs = {'foo': 123}
net_pickleable._params_to_validate = {'foo'}
msg = (
"Something went wrong here. Please open an issue on "
"https://github.com/skorch-dev/skorch/issues detailing what "
"caused this error and the used skorch version."
)
with pytest.raises(ValueError, match=msg):
pickle.loads(pickle.dumps(net_pickleable))

@pytest.mark.parametrize('device', ['cpu', 'cuda'])
def test_device_torch_device(self, net_cls, module_cls, device):
# Check if native torch.device works as well.
