transformers+tokenizers not imported on module lvl (#948)
With this PR, the Hugging Face packages transformers and tokenizers are
no longer imported at module level. This allows users to use the parts
of skorch.hf that are independent of those packages, e.g. AccelerateMixin,
without installing them.
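
As a minimal sketch of the pattern (simplified, not the actual skorch code; the class name below is made up), the heavy import now happens inside the method that needs it:

class SomeTokenizer:
    def fit(self, X, y=None):
        # deferred import: only triggered when fit() is actually called, so
        # importing the module itself does not require transformers
        from transformers import PreTrainedTokenizerFast
        ...
        return self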


Previously, I had imported them at class level, wrongly thinking that
would be enough to defer the import.

I wanted to add a test for this change: patch the import machinery to
raise an error when either package is imported, then import something
else from skorch.hf and check that no ImportError is raised (see the
sketch after the list below). However, this test did not work, because
AccelerateMixin is imported at module level in test_hf.py (this is
necessary to define AcceleratedNet at module level, as it would
otherwise not be picklable). This leads to the following situation:

- the hf.py module is already loaded when the tests are collected by
  pytest
- therefore, the imports are triggered before the patch is applied
- therefore, the imports are cached and the patch is useless
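
For reference, the attempted test looked roughly like the sketch below (the test name and the exact patching mechanism are illustrative, not taken from the PR):

import builtins
from unittest.mock import patch

_real_import = builtins.__import__

def _block_hf_packages(name, *args, **kwargs):
    # simulate an environment in which the Hugging Face packages are missing
    if name.split(".")[0] in ("transformers", "tokenizers"):
        raise ImportError(f"No module named '{name}'")
    return _real_import(name, *args, **kwargs)

def test_accelerate_mixin_importable_without_hf_packages():
    with patch("builtins.__import__", side_effect=_block_hf_packages):
        # should not raise now that the heavy imports are deferred; in
        # test_hf.py, however, skorch.hf is already in sys.modules when this
        # runs, so the patched __import__ never sees the deferred imports
        from skorch.hf import AccelerateMixin  # noqa: F401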

Well, if someone knows a way out of this conundrum, let me know. But I
don't think this test is particularly important, so I would be fine
merging without it.
BenjaminBossan committed Apr 14, 2023
1 parent ff99a26 commit 51bba10
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions skorch/hf.py
@@ -314,9 +314,6 @@ class HuggingfaceTokenizer(_HuggingfaceTokenizerBase):
     .. _tokenizers: https://huggingface.co/docs/tokenizers/python/latest/index.html
     """
-    import transformers as _transformers
-    import tokenizers as _tokenizers
-
     prefixes_ = [
         'model', 'normalizer', 'post_processor', 'pre_tokenizer', 'tokenizer', 'trainer'
     ]
@@ -550,6 +547,8 @@ def fit(self, X, y=None, **fit_params):
             The fitted instance of the tokenizer.
         """
+        from transformers import PreTrainedTokenizerFast
+
         # from sklearn, triggers a parameter validation
         if isinstance(X, str):
             raise ValueError(
@@ -562,7 +561,7 @@ def fit(self, X, y=None, **fit_params):
         trainer = self.initialize_trainer()
         self.tokenizer_.train_from_iterator(X, trainer)
         self.tokenizer_.add_special_tokens([self.pad_token])
-        self.fast_tokenizer_ = self._transformers.PreTrainedTokenizerFast(
+        self.fast_tokenizer_ = PreTrainedTokenizerFast(
             tokenizer_object=self.tokenizer_,
             pad_token=self.pad_token,
         )
@@ -745,8 +744,6 @@ class HuggingfacePretrainedTokenizer(_HuggingfaceTokenizerBase):
     """

-    import transformers as _transformers
-
     def __init__(
             self,
             tokenizer,
@@ -789,6 +786,8 @@ def fit(self, X, y=None, **fit_params):
             The fitted instance of the tokenizer.
         """
+        from transformers import AutoTokenizer
+
         # from sklearn, triggers a parameter validation
         # even though X is not used, we leave this check in for consistency
         if isinstance(X, str):
@@ -800,7 +799,7 @@ def fit(self, X, y=None, **fit_params):
             raise ValueError("Setting vocab_size has no effect if train=False")

         if isinstance(self.tokenizer, (str, os.PathLike)):
-            self.fast_tokenizer_ = self._transformers.AutoTokenizer.from_pretrained(
+            self.fast_tokenizer_ = AutoTokenizer.from_pretrained(
                 self.tokenizer
             )
         else:
