HfHubWriter: Save checkpoints on Hugging Face Hub (#881)
This PR adds the ability to create model checkpoints on the Hugging Face Hub.
For this to work, we introduce a new HfHubStorage class, which can be passed
instead of a file name to Checkpoint or TrainEndCheckpoint (and hopefully to
other callbacks out there in the wild).

The design goal of this PR was to re-use the existing checkpoint callbacks
instead of writing a new one. This scales better, since a similar design could
be used to enable storing on S3, GCS, etc.
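
To illustrate the general idea of passing an object instead of a file name, here is a minimal sketch using only the standard library. It is not the class added in this PR: `BufferedUploadWriter`, `fake_upload`, and the `uploaded` dict are hypothetical names, and the "upload" is just a write into a dict standing in for a remote store.

```python
import io
import pickle


class BufferedUploadWriter:
    """Hypothetical write-only, file-like object (not skorch's implementation).

    Bytes are collected in memory and handed to ``upload_fn`` on flush.
    """

    def __init__(self, upload_fn):
        self.upload_fn = upload_fn  # stand-in for a call to a remote store
        self._buffer = io.BytesIO()

    def write(self, data):
        # pickle.dump only needs a ``write`` method on its target
        return self._buffer.write(data)

    def flush(self):
        # synchronous "upload": the caller blocks until upload_fn returns
        self.upload_fn(self._buffer.getvalue())

    def read(self, *args):
        # mirrors the write-only limitation described in this message
        raise NotImplementedError("this sketch only supports writing")


uploaded = {}  # pretend remote storage


def fake_upload(payload):
    uploaded["model.pkl"] = payload


writer = BufferedUploadWriter(fake_upload)
pickle.dump({"weights": [1.0, 2.0, 3.0]}, writer)
writer.flush()

restored = pickle.loads(uploaded["model.pkl"])
```

Because the calling code only sees something that behaves like an open file, the same callback could in principle be pointed at a different backend by swapping the upload function.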

One of the difficulties here was to make this work with both pickle.dump and
torch.save. The latter takes a few different turns under the hood. Therefore, I
had to adjust open_file_like to be more similar to what torch does. It should,
however, still work with existing code (at least the tests still pass).
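
As a rough sketch of what a helper in this spirit can look like (an assumption for illustration, not the actual code of open_file_like or of torch): a context manager that opens the target if it is a path and otherwise passes a file-like object through, flushing it on exit when writing. The name `open_path_or_buffer` is hypothetical.

```python
import contextlib
import io
import os
import pickle
import tempfile


@contextlib.contextmanager
def open_path_or_buffer(target, mode):
    """Hypothetical helper: yield a file object for a path or a buffer."""
    if isinstance(target, (str, os.PathLike)):
        # a path: open the file here and close it on exit
        with open(target, mode) as f:
            yield f
    else:
        # a file-like object: hand it through unchanged
        try:
            yield target
        finally:
            # when writing, flush on exit so buffered data gets pushed out;
            # closing is left to whoever owns the object
            if "w" in mode:
                target.flush()


# works with a file name ...
path = os.path.join(tempfile.mkdtemp(), "model.pkl")
with open_path_or_buffer(path, "wb") as f:
    pickle.dump([1, 2, 3], f)

# ... and with an in-memory buffer
buffer = io.BytesIO()
with open_path_or_buffer(buffer, "wb") as f:
    pickle.dump([1, 2, 3], f)
```

Flushing on exit matters for objects that defer their real work (such as an upload) until flush is called.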

There are some limitations to the current design. For instance, it only supports
writing for now. Therefore, a checkpoint using the new feature does not work
with LoadInitState.

Furthermore, the uploads are performed synchronously. This is mostly to save us
some headaches. However, the obvious disadvantage is that if the upload is slow
(compared to training time), we can see a considerable slowdown. Therefore, I
recommend using this feature with TrainEndCheckpoint rather than Checkpoint.

Notebook

There is a notebook included to test this feature on the real HF Hub.

This PR also fixes a wrong URL used in Hugging_Face_Finetuning.ipynb.
BenjaminBossan committed Oct 6, 2022
1 parent 84dda2a commit fa9e813
Showing 8 changed files with 1,293 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added a method, `trim_for_prediction`, on the net classes, which trims the net from everything not required for using it for prediction; call this after fitting to reduce the size of the net
- Added experimental support for [huggingface accelerate](https://github.com/huggingface/accelerate); use the provided mixin class to add advanced training capabilities provided by the accelerate library to skorch
- Add integration for Huggingface tokenizers; use `skorch.hf.HuggingfaceTokenizer` to train a Huggingface tokenizer on your custom data; use `skorch.hf.HuggingfacePretrainedTokenizer` to load a pre-trained Huggingface tokenizer
+- Added support for creating model checkpoints on Hugging Face Hub using [`HfHubStorage`](https://skorch.readthedocs.io/en/latest/hf.html#skorch.hf.HfHubStorage)

### Changed
- The minimum required scikit-learn version has been bumped to 0.22.0
8 changes: 8 additions & 0 deletions docs/user/save_load.rst
@@ -302,3 +302,11 @@ load this checkpoint to predict with it:
In this case, it is important to initialize the neural net before
running :meth:`.NeuralNet.load_params`.

+Saving on Hugging Face Hub
+--------------------------
+
+:class:`.Checkpoint` and :class:`.TrainEndCheckpoint` can also be used to store
+models on the `Hugging Face Hub <https://huggingface.co/docs/hub/index>`__. For
+this to work, instead of indicating a file name for the component to be stored,
+use :class:`.skorch.hf.HfHubStorage`.
6 changes: 3 additions & 3 deletions notebooks/Hugging_Face_Finetuning.ipynb
@@ -5,7 +5,7 @@
"id": "dea00c9e-ab6a-4a56-8e91-782b06002f6c",
"metadata": {},
"source": [
-"# Fine-tuning a BERT model with skorch and huggingface"
+"# Fine-tuning a BERT model with skorch and Hugging Face"
]
},
{
@@ -27,10 +27,10 @@
"metadata": {},
"source": [
"<table align=\"left\"><td>\n",
-"<a target=\"_blank\" href=\"https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/Huggingface_Finetuning.ipynb\">\n",
+"<a target=\"_blank\" href=\"https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/Hugging_Face_Finetuning.ipynb\">\n",
" <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a> \n",
"</td><td>\n",
-"<a target=\"_blank\" href=\"https://github.com/skorch-dev/skorch/blob/master/notebooks/Huggingface_Finetuning.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>"
+"<a target=\"_blank\" href=\"https://github.com/skorch-dev/skorch/blob/master/notebooks/Hugging_Face_Finetuning.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>"
]
},
{