Huggingface tokenizers (especially pre-trained ones) are widely used in NLP tasks. However, integrating them with skorch is not straightforward; there are a lot of small things to consider. With this PR, we add the `HuggingfaceTokenizer` and `HuggingfacePretrainedTokenizer` classes, which take care of all these small things. They provide a (mostly) sklearn-compatible interface that allows them to be used in a `Pipeline` with a skorch net. In this implementation, I did not opt for adding all existing options available in Huggingface, instead focusing on a (still comprehensive) list of (what seemed to me) the most useful options.

## Implementation

How to handle the Huggingface dependencies (tokenizers and transformers) could be up for debate. I think we should not strictly depend on them, which is why they are not imported at root level, even though that results in some ugly code. Furthermore, I put the code inside an `hf.py` module. Maybe there are better suggestions? If not, do we want to move `AccelerateMixin` to this module as well?

There is a minor issue at the moment with the tokenizer `Trainer` not being pickleable. I mitigated this mostly by popping it from the state and replacing it with `None`. This should be mostly acceptable because users will tend to pickle trained transformers, and after training, the trainer is no longer required.

UPDATE: This was worked on upstream, but there is no (non-yanked) new release yet at the time of writing.

However, this mitigation is still not enough when using sklearn's `clone`, which recursively traverses all attributes to copy them. Overriding the `__deepcopy__` or `__reduce_ex__` methods might help, but would probably still not allow cloning and then fitting (e.g. for grid search). I opened an issue so that this problem will hopefully be addressed soon: huggingface/tokenizers#941

## Incidental change

skorch's `check_is_fitted` function now makes use of sklearn's `check_is_fitted` function. It intercepts the sklearn error and raises a skorch error instead. This is in line with the comment on this function, which already claimed that the sklearn function was used under the hood, even though that was not true before.
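The incidental change can be sketched roughly as follows. This is a hedged illustration, not the actual skorch code: the error class here is a local stand-in modeled on `skorch.exceptions.NotInitializedError`.

```python
from sklearn.exceptions import NotFittedError
from sklearn.utils import validation


class NotInitializedError(NotFittedError):
    """Local stand-in for skorch's skorch.exceptions.NotInitializedError."""


def check_is_fitted(estimator, attributes=None, msg=None, all_or_any=all):
    # Delegate the actual fitted-check to sklearn, then translate the
    # exception so callers catching the skorch error keep working.
    try:
        validation.check_is_fitted(
            estimator, attributes, msg=msg, all_or_any=all_or_any
        )
    except NotFittedError as exc:
        raise NotInitializedError(str(exc)) from exc
```

Since `NotInitializedError` subclasses `NotFittedError`, code that catches the sklearn exception continues to work unchanged.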
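The deferred-import handling of the soft dependencies could look something like the sketch below. The helper name `_import_optional` is hypothetical; the point is only that `tokenizers`/`transformers` are imported inside the code paths that need them, not at module level.

```python
import importlib


def _import_optional(name, purpose):
    # Hypothetical helper: import a soft dependency only when it is
    # actually needed, and give an actionable error if it is missing.
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        raise ImportError(
            f"{name} is required for {purpose}; install it, e.g. with "
            f"'python -m pip install {name}'."
        ) from exc
```

A class method would then call, say, `transformers = _import_optional('transformers', 'pretrained tokenizers')` inside `fit`, keeping the `hf.py` module importable even when the dependency is absent.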
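The pickling mitigation described above amounts to dropping the trainer in `__getstate__`. A minimal sketch, assuming the wrapper stores the Huggingface trainer in a `trainer_` attribute (a lambda stands in for the unpicklable object here):

```python
import pickle


class TokenizerSketch:
    # Minimal stand-in for the tokenizer wrapper; only the pickling
    # workaround is shown.
    def fit(self, X):
        # A lambda is unpicklable, like the real Huggingface trainer.
        self.trainer_ = lambda: None
        return self

    def __getstate__(self):
        state = self.__dict__.copy()
        # Pop the trainer before pickling and replace it with None;
        # after fitting it is no longer required, so losing it on a
        # pickle round trip is mostly acceptable.
        state['trainer_'] = None
        return state
```

After `pickle.loads(pickle.dumps(obj))`, the restored object has `trainer_` set to `None` but keeps every other attribute, which is why sklearn's `clone` (which deep-copies all attributes instead of pickling) still fails.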