Huggingface tokenizers (#841)

Huggingface tokenizers (especially pre-trained ones) are widely used in
NLP tasks. However, it's not straightforward to integrate them with
skorch; there are many small details to consider.

With this PR, we add the HuggingfaceTokenizer and
HuggingfacePretrainedTokenizer classes, which take care of all these
small things. They provide a (mostly) sklearn-compatible interface that
allows them to be used in a Pipeline with a skorch net.
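
To illustrate the intended usage, here is a minimal sketch (not taken from the
PR itself; the checkpoint name and the printed output key are assumptions):

```python
from skorch.hf import HuggingfacePretrainedTokenizer

# wrap the tokenizer of an existing checkpoint; "bert-base-uncased" is only
# an example, any tokenizer from the Huggingface hub should work
tokenizer = HuggingfacePretrainedTokenizer("bert-base-uncased")

X = ["a first example sentence", "and a second one"]
tokenizer.fit(X)             # for the pretrained variant, fit only loads the tokenizer
Xt = tokenizer.transform(X)  # dict-like output, e.g. input_ids and attention_mask
print(Xt["input_ids"].shape)
```

Since the classes behave like sklearn transformers, the same object can be
placed in front of a skorch net inside an sklearn Pipeline and tuned with the
usual sklearn tools.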

In this implementation, I did not opt to add every option available in
Huggingface, instead focusing on a (still comprehensive) list of what
seemed to me the most useful options.

Implementation

How to handle the Huggingface dependencies (tokenizers and transformers)
could be up for debate. I think we should not strictly depend on them,
which is why they're not imported at root level, even though that
results in some ugly code.
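
The kind of guarded import this entails looks roughly like the sketch below
(not the literal code from hf.py; the helper name and error message are made
up for illustration):

```python
def _check_tokenizers_installed():
    # the heavy Huggingface packages are only imported when actually needed,
    # so that importing skorch itself does not require them
    try:
        import tokenizers  # noqa: F401
    except ImportError as exc:
        raise ImportError(
            "skorch.hf requires the 'tokenizers' package; "
            "install it, e.g. with 'pip install tokenizers'."
        ) from exc
```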

Furthermore, I put the code inside a hf.py module. Maybe there are
better suggestions? If not, do we want to move the AccelerateMixin to
this module as well?

There is a minor issue at the moment with the tokenizer Trainer not
being pickleable. I mitigated this by popping it from the state and
replacing it with None. This should be mostly acceptable because users
will typically pickle trained tokenizers, and after training, the
trainer is no longer required.
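
As a rough illustration of that mitigation (a toy stand-in, not the PR's
actual class; the attribute name trainer_ is assumed):

```python
import pickle

class TokenizerSketch:
    """Toy example of dropping a non-picklable attribute from the state."""
    def __init__(self):
        self.trainer_ = lambda: None  # stand-in for the non-picklable HF trainer

    def __getstate__(self):
        state = self.__dict__.copy()
        # the trainer is only needed during fit, so drop it before pickling
        state["trainer_"] = None
        return state

restored = pickle.loads(pickle.dumps(TokenizerSketch()))
assert restored.trainer_ is None
```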

UPDATE: The upstream pickling issue has been worked on, but there is no
(non-yanked) new release yet at the time of writing.

However, this mitigation is still not enough when using sklearn's clone,
which recursively traverses all attributes to copy them. Further
mitigations might be possible by overriding the __deepcopy__ or
__reduce_ex__ methods, but those would probably still not allow cloning
and then fitting (e.g. for grid search).

I opened an issue, so that this problem will hopefully be addressed
soon:
huggingface/tokenizers#941

Incidental change

skorch's check_is_fitted function now makes use of sklearn's
check_is_fitted function. It intercepts the sklearn error and raises a
skorch error instead. This brings the behavior in line with the comment
on this function, which claimed that the sklearn function was being used
under the hood, even though that was not previously true.
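
Conceptually (a simplified sketch, not the literal diff; the exact signature
may differ), the change amounts to:

```python
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted as sk_check_is_fitted

from skorch.exceptions import NotInitializedError

def check_is_fitted(estimator, attributes=None, msg=None):
    """Delegate to sklearn's check but raise a skorch error instead."""
    try:
        sk_check_is_fitted(estimator, attributes=attributes)
    except NotFittedError as exc:
        raise NotInitializedError(msg or str(exc)) from exc
```
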
BenjaminBossan committed Jul 11, 2022
1 parent d17e3bc commit b889de9
Showing 7 changed files with 1,329 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- Initialize data loaders for training and validation dataset once per fit call instead of once per epoch ([migration guide](https://skorch.readthedocs.io/en/stable/user/FAQ.html#migration-from-0-11-to-0-12))
- It is now possible to call `np.asarray` with `SliceDataset`s (#858)
- Add integration for Huggingface tokenizers; use `skorch.hf.HuggingfaceTokenizer` to train a Huggingface tokenizer on your custom data; use `skorch.hf.HuggingfacePretrainedTokenizer` to load a pre-trained Huggingface tokenizer

### Fixed
- Fix a bug in `SliceDataset` that prevented it to be used with `to_numpy` (#858)
5 changes: 5 additions & 0 deletions docs/hf.rst
@@ -0,0 +1,5 @@
skorch.hf
=========

.. automodule:: skorch.hf
:members:
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -18,5 +18,6 @@ sacred
sphinx
sphinx_rtd_theme
tensorboard>=1.14.0
tokenizers
transformers
wandb>=0.12.17
