Skip to content

Commit

Permalink
Merge branch 'master' into test-accelerate-pickling-and-docs
Browse files Browse the repository at this point in the history
  • Loading branch information
BenjaminBossan committed Sep 28, 2022
2 parents d576840 + 57ea797 commit bf7c3bb
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 50 deletions.
5 changes: 3 additions & 2 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `load_best` attribute to `EarlyStopping` callback to automatically load module weights of the best result at the end of training
- Added a method, `trim_for_prediction`, on the net classes, which trims the net from everything not required for using it for prediction; call this after fitting to reduce the size of the net
- Added experimental support for [huggingface accelerate](https://github.com/huggingface/accelerate); use the provided mixin class to add advanced training capabilities provided by the accelerate library to skorch
- Add integration for Huggingface tokenizers; use `skorch.hf.HuggingfaceTokenizer` to train a Huggingface tokenizer on your custom data; use `skorch.hf.HuggingfacePretrainedTokenizer` to load a pre-trained Huggingface tokenizer

### Changed
- The minimum required scikit-learn version has been bumped to 0.22.0
- Initialize data loaders for training and validation dataset once per fit call instead of once per epoch ([migration guide](https://skorch.readthedocs.io/en/stable/user/FAQ.html#migration-from-0-11-to-0-12))
- It is now possible to call `np.asarray` with `SliceDataset`s (#858)
- Add integration for Huggingface tokenizers; use `skorch.hf.HuggingfaceTokenizer` to train a Huggingface tokenizer on your custom data; use `skorch.hf.HuggingfacePretrainedTokenizer` to load a pre-trained Huggingface tokenizer

### Fixed
- Fix a bug in `SliceDataset` that prevented it to be used with `to_numpy` (#858)
- Fix a bug that occurred when loading a net that has device set to None
- Fix a bug that occurred when loading a net that has device set to None (#876)
- Fix a bug that in some cases could prevent loading a net that was trained with CUDA without CUDA

## [0.11.0] - 2021-10-11

Expand Down
38 changes: 18 additions & 20 deletions skorch/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,30 +350,25 @@ def __init__(
self.pad_token = pad_token
self.verbose = verbose

self._kwargs = kwargs
self._params_to_validate = set(kwargs.keys())
vars(self).update(kwargs)

def _check_kwargs(self, kwargs):
def _validate_params(self):
"""Check argument names passed at initialization.
Raises
------
TypeError
Raises a TypeError if one or more arguments don't seem to
ValueError
Raises a ValueError if one or more arguments don't seem to
match or are malformed.
Returns
-------
kwargs: dict
Return the passed keyword arguments.
"""
# This whole method is taken from NeuralNet

# check for wrong arguments
unexpected_kwargs = []
missing_dunder_kwargs = []
for key in kwargs:
for key in sorted(self._params_to_validate):
if key.endswith('_'):
continue

Expand Down Expand Up @@ -406,9 +401,7 @@ def _check_kwargs(self, kwargs):

if msgs:
full_msg = '\n'.join(msgs)
raise TypeError(full_msg)

return kwargs
raise ValueError(full_msg)

def initialized_instance(self, instance_or_cls, kwargs):
"""Return an instance initialized with the given parameters
Expand Down Expand Up @@ -517,7 +510,7 @@ def initialize_trainer(self):

def initialize(self):
"""Initialize the individual tokenizer components"""
self._check_kwargs(self._kwargs)
self._validate_params()

model = self.initialize_model()
tokenizer = self.initialize_tokenizer(model)
Expand Down Expand Up @@ -586,7 +579,12 @@ def __getstate__(self):

def get_params(self, deep=False):
params = super().get_params(deep=deep)
params.update(self._kwargs)
if deep:
for key in self._params_to_validate:
# We cannot assume that the attribute is already set because
# sklearn's set_params calls get_params first.
if hasattr(self, key):
params[key] = getattr(self, key)
return params

def set_params(self, **kwargs):
Expand All @@ -605,10 +603,10 @@ def set_params(self, **kwargs):
for key, val in kwargs.items():
if any(key.startswith(prefix) for prefix in self.prefixes_):
special_params[key] = val
self._kwargs[key] = val
self._params_to_validate.add(key)
elif '__' in key:
special_params[key] = val
self._kwargs[key] = val
self._params_to_validate.add(key)
else:
normal_params[key] = val

Expand All @@ -630,7 +628,7 @@ def set_params(self, **kwargs):
return self

# if transformer is initialized, checking kwargs is possible
self._check_kwargs(self._kwargs)
self._validate_params()

# Re-initializing of tokenizer necessary
self.initialize()
Expand Down Expand Up @@ -912,8 +910,8 @@ def __init__(
)
self.accelerator = accelerator

def _check_kwargs(self, kwargs):
super()._check_kwargs(kwargs)
def _validate_params(self):
super()._validate_params()

if self.accelerator.device_placement and (self.device is not None):
raise ValueError(
Expand Down
55 changes: 35 additions & 20 deletions skorch/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def __init__(
initialized = kwargs.pop('initialized_', False)
virtual_params = kwargs.pop('virtual_params_', dict())

self._kwargs = kwargs
self._params_to_validate = set(kwargs.keys())
vars(self).update(kwargs)

self.history_ = history
Expand Down Expand Up @@ -822,7 +822,7 @@ def initialize(self):
self._initialize_optimizer()
self._initialize_history()

self._check_kwargs(self._kwargs)
self._validate_params()

self.initialized_ = True
return self
Expand Down Expand Up @@ -1896,30 +1896,29 @@ def get_params(self, deep=True, **kwargs):
to_exclude = {'_modules', '_criteria', '_optimizers'}
return {key: val for key, val in params.items() if key not in to_exclude}

def _check_kwargs(self, kwargs):
def _validate_params(self):
"""Check argument names passed at initialization.
Note: This method is similar to
:meth:`sklearn.base.BaseEstimator._validate_params` but doesn't use its
machinery.
Raises
------
TypeError
Raises a TypeError if one or more arguments don't seem to
ValueError
Raises a ValueError if one or more arguments don't seem to
match or are malformed.
Returns
-------
kwargs: dict
Return the passed keyword arguments.
Example
-------
>>> net = NeuralNetClassifier(MyModule, iterator_train_shuffle=True)
TypeError: Got an unexpected argument iterator_train_shuffle,
ValueError: Got an unexpected argument iterator_train_shuffle,
did you mean iterator_train__shuffle?
"""
# warn about usage of iterator_valid__shuffle=True, since this
# is almost certainly not what the user wants
if kwargs.get('iterator_valid__shuffle'):
if 'iterator_valid__shuffle' in self._params_to_validate:
warnings.warn(
"You set iterator_valid__shuffle=True; this is most likely not "
"what you want because the values returned by predict and "
Expand All @@ -1929,7 +1928,7 @@ def _check_kwargs(self, kwargs):
# check for wrong arguments
unexpected_kwargs = []
missing_dunder_kwargs = []
for key in kwargs:
for key in sorted(self._params_to_validate):
if key.endswith('_'):
continue

Expand Down Expand Up @@ -1963,9 +1962,7 @@ def _check_kwargs(self, kwargs):

if msgs:
full_msg = '\n'.join(msgs)
raise TypeError(full_msg)

return kwargs
raise ValueError(full_msg)

def _check_deprecated_params(self, **kwargs):
pass
Expand All @@ -1989,13 +1986,13 @@ def set_params(self, **kwargs):
virtual_params[key] = val
elif key.startswith('callbacks'):
cb_params[key] = val
self._kwargs[key] = val
self._params_to_validate.add(key)
elif any(key.startswith(prefix) for prefix in self.prefixes_):
special_params[key] = val
self._kwargs[key] = val
self._params_to_validate.add(key)
elif '__' in key:
special_params[key] = val
self._kwargs[key] = val
self._params_to_validate.add(key)
else:
normal_params[key] = val

Expand Down Expand Up @@ -2024,7 +2021,7 @@ def set_params(self, **kwargs):
return self

# if net is initialized, checking kwargs is possible
self._check_kwargs(self._kwargs)
self._validate_params()

######################################################
# Below: Re-initialize parts of the net if necessary #
Expand Down Expand Up @@ -2138,6 +2135,24 @@ def __getstate__(self):
return state

def __setstate__(self, state):
# TODO remove after 2023-09
# in skorch 0.11 -> 0.12, we made a change to parameter validation. We
# don't store key/vals in self._kwargs anymore, as the values were
# redundant and were not considered as possibly CUDA dependent. Instead,
# we now use the attribute '_params_to_validate', which only stores
# keys. The code below is to make the net backwards compatible.
if '_kwargs' in state:
if '_params_to_validate' in state:
# there should not be _kwargs AND _params_to_validate
raise ValueError(
"Something went wrong here. Please open an issue on "
"https://github.com/skorch-dev/skorch/issues detailing what "
"caused this error and the used skorch version."
)
kwargs = state.pop('_kwargs')
params_to_validate = set(kwargs.keys())
state['_params_to_validate'] = params_to_validate

# get_map_location will automatically choose the
# right device in cases where CUDA is not available.
map_location = get_map_location(state['device'])
Expand Down
Binary file modified skorch/tests/net_cuda.pkl
Binary file not shown.
11 changes: 8 additions & 3 deletions skorch/tests/test_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,14 +293,20 @@ def test_set_params(self, data):
tokenizer.set_params(
model__dropout=0.123,
trainer__vocab_size=123,
pre_tokenizer__delimiter='*',
max_length=456,
# With v0.13 of tokenizers, it seems like delimiter always needs to
# be " ", otherwise this error is raised: Error while attempting to
# unpickle Tokenizer: data did not match any variant of untagged
# enum ModelWrapper at line 1 column 2586. So we cannot change its
# value in this test but we should still ensure that set_params
# doesn't fail, so we keep it.
pre_tokenizer__delimiter=' ',
)
tokenizer.fit(data)

assert tokenizer.tokenizer_.model.dropout == pytest.approx(0.123)
assert len(tokenizer.vocabulary_) == pytest.approx(123, abs=5)
assert tokenizer.tokenizer_.pre_tokenizer.delimiter == '*'
assert tokenizer.tokenizer_.pre_tokenizer.delimiter == ' '
assert tokenizer.max_length == 456


Expand Down Expand Up @@ -382,7 +388,6 @@ def tokenizer(self, tokenizer_not_fitted, data):
def test_fixed_vocabulary(self, tokenizer):
assert tokenizer.fixed_vocabulary_ is False

@pytest.mark.xfail
def test_clone(self, tokenizer):
# This might get fixed in a future release of tokenizers
# https://github.com/huggingface/tokenizers/issues/941
Expand Down
44 changes: 39 additions & 5 deletions skorch/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def test_train_net_after_copy(self, net_cls, module_cls, data,
assert param is opt_param

def test_net_init_one_unknown_argument(self, net_cls, module_cls):
with pytest.raises(TypeError) as e:
with pytest.raises(ValueError) as e:
net_cls(module_cls, unknown_arg=123).initialize()

expected = ("__init__() got unexpected argument(s) unknown_arg. "
Expand All @@ -181,7 +181,7 @@ def test_net_init_one_unknown_argument(self, net_cls, module_cls):
assert e.value.args[0] == expected

def test_net_init_two_unknown_arguments(self, net_cls, module_cls):
with pytest.raises(TypeError) as e:
with pytest.raises(ValueError) as e:
net_cls(module_cls, lr=0.1, mxa_epochs=5,
warm_start=False, bathc_size=20).initialize()

Expand All @@ -202,7 +202,7 @@ def test_net_init_two_unknown_arguments(self, net_cls, module_cls):
def test_net_init_missing_dunder_in_prefix_argument(
self, net_cls, module_cls, name, suggestion):
# forgot to use double-underscore notation
with pytest.raises(TypeError) as e:
with pytest.raises(ValueError) as e:
net_cls(module_cls, **{name: 123}).initialize()

tmpl = "Got an unexpected argument {}, did you mean {}?"
Expand All @@ -212,7 +212,7 @@ def test_net_init_missing_dunder_in_prefix_argument(
def test_net_init_missing_dunder_in_2_prefix_arguments(
self, net_cls, module_cls):
# forgot to use double-underscore notation in 2 arguments
with pytest.raises(TypeError) as e:
with pytest.raises(ValueError) as e:
net_cls(
module_cls,
max_epochs=7, # correct
Expand All @@ -228,7 +228,7 @@ def test_net_init_missing_dunder_in_2_prefix_arguments(
def test_net_init_missing_dunder_and_unknown(
self, net_cls, module_cls):
# unknown argument and forgot to use double-underscore notation
with pytest.raises(TypeError) as e:
with pytest.raises(ValueError) as e:
net_cls(
module_cls,
foobar=123,
Expand Down Expand Up @@ -435,6 +435,40 @@ def test_pickle_load(self, cuda_available, pickled_cuda_net_path):
with open(pickled_cuda_net_path, 'rb') as f:
pickle.load(f)

def test_load_net_with_kwargs_attribute_to_net_without(self, net_pickleable):
# TODO remove after 2023-09
# in skorch 0.11 -> 0.12, we made a change to parameter validation. We
# don't store key/vals in self._kwargs anymore, as the values were
# redundant and were not considered as possibly CUDA dependent, which
# can cause errors when loading to CPU. Since we remove one attribute
# and add a new one ('_params_to_validate'), we have to take extra steps
# to ensure that old models can still be loaded correctly.

# emulate old net:
del net_pickleable._params_to_validate
net_pickleable._kwargs = {'foo': 123, 'bar__baz': 456}

# after loading, behaves like new net
net_loaded = pickle.loads(pickle.dumps(net_pickleable))
assert net_loaded._params_to_validate == {'foo', 'bar__baz'}
assert not hasattr(net_loaded, '_kwargs')

def test_load_net_with_both_kwargs_and_params_to_validate_attributes_raises(
self, net_pickleable
):
# TODO remove after 2023-09
# Check test_load_net_with_kwargs_attribute_to_net_without for more
# details
net_pickleable._kwargs = {'foo': 123}
net_pickleable._params_to_validate = {'foo'}
msg = (
"Something went wrong here. Please open an issue on "
"https://github.com/skorch-dev/skorch/issues detailing what "
"caused this error and the used skorch version."
)
with pytest.raises(ValueError, match=msg):
pickle.loads(pickle.dumps(net_pickleable))

@pytest.mark.parametrize('device', ['cpu', 'cuda'])
def test_device_torch_device(self, net_cls, module_cls, device):
# Check if native torch.device works as well.
Expand Down

0 comments on commit bf7c3bb

Please sign in to comment.