Add tests about saving/loading with accelerate
When using AccelerateMixin, the net cannot be pickled. This is caused by
accelerate replacing forward methods by wrapped methods:

https://github.com/huggingface/accelerate/blob/6d038e19a168f5a270bb4c535023255a0160fc10/src/accelerate/accelerator.py#L747-L753

Pickle does not know how to restore those methods and thus fails.

In this PR, we're adding unit tests to check this and documentation
indicating this missing feature (and pointing users towards using
net.save_params instead).

It is not impossible to make pickling work. We could save a copy of the
original methods before they're replaced (or dig out the original
function from __closure__.cell_contents) and after training restore the
methods.

It's however not that trivial, since preparing happens during
initialize, which can be called outside of fit. Therefore, a simple
context manager is not enough. Also, we do want to have acceleration
during inference, so methods such as predict need to prepare and restore
again.

We could also try to mess with __reduce__ but that seems to be equally
brittle. So maybe we just need to live with pickle not working with
accelerate.
BenjaminBossan committed Sep 27, 2022
1 parent 862f205 commit d576840
Showing 3 changed files with 65 additions and 6 deletions.
4 changes: 4 additions & 0 deletions docs/user/huggingface.rst
@@ -45,6 +45,10 @@ accelerate_ recommends leaving the device handling to the Accelerator_, which
is why ``device`` defaults to ``None`` (thus telling skorch not to change the
device).

Models using :class:`.AccelerateMixin` cannot be pickled. If you need to save
and load the net, use :py:meth:`skorch.net.NeuralNet.save_params`
and :py:meth:`skorch.net.NeuralNet.load_params` instead.

To install accelerate_, run the following command inside your Python environment:

.. code:: bash
5 changes: 5 additions & 0 deletions skorch/hf.py
@@ -857,6 +857,11 @@ class was added: Using this mixin in conjunction with the accelerate library
experimental feature. When accelerate's API stabilizes, we will consider
adding it to skorch proper.
Also, models accelerated this way cannot be pickled. If you need to save
and load the net, either use :py:meth:`skorch.net.NeuralNet.save_params`
and :py:meth:`skorch.net.NeuralNet.load_params` or don't use
``accelerate``.
Examples
--------
>>> from skorch import NeuralNetClassifier
62 changes: 56 additions & 6 deletions skorch/tests/test_hf.py
@@ -13,6 +13,9 @@
from sklearn.base import clone
from sklearn.exceptions import NotFittedError

from skorch import NeuralNetClassifier
from skorch.hf import AccelerateMixin


SPECIAL_TOKENS = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]

@@ -496,6 +499,11 @@ def test_fixed_vocabulary(self, tokenizer):
assert tokenizer.fixed_vocabulary_ is False


# The class is defined at the top level so that it can be pickled
class AcceleratedNet(AccelerateMixin, NeuralNetClassifier):
pass


class TestAccelerate:
@pytest.fixture(scope='module')
def data(self, classifier_data):
@@ -522,12 +530,6 @@ def accelerator_cls(self):

@pytest.fixture
def net_cls(self, module_cls):
from skorch import NeuralNetClassifier
from skorch.hf import AccelerateMixin

class AcceleratedNet(AccelerateMixin, NeuralNetClassifier):
pass

return partial(
AcceleratedNet,
module=module_cls,
@@ -553,6 +555,54 @@ def test_mixed_precision(self, net_cls, accelerator_cls, data, mixed_precision):
net.fit(X, y) # does not raise
assert np.isfinite(net.history[:, "train_loss"]).all()

@pytest.mark.parametrize('mixed_precision', [
pytest.param('fp16', marks=pytest.mark.xfail(raises=pickle.PicklingError)),
pytest.param('bf16', marks=pytest.mark.xfail(raises=pickle.PicklingError)),
'no', # no acceleration works because forward is left the same
])
def test_mixed_precision_pickling(
self, net_cls, accelerator_cls, data, mixed_precision
):
# Pickling currently doesn't work because the forward method on modules
# is overwritten with a modified version of the method using autocast.
# Pickle doesn't know how to restore those methods.

# Note: For accelerate <= v0.10, there was a bug that would make it
# seem that pickling works when the mixed_precision == 'no' condition
# was tested first. This was caused by some state being preserved
# between object initializations in the same process (now fixed through
# AcceleratorState._reset_state()). This bug should be fixed now but to
# be sure, still start out with 'fp16' before 'no'.
from accelerate.utils import is_bf16_available

if (mixed_precision != 'no') and not torch.cuda.is_available():
pytest.skip('skipping AMP test because device does not support it')
if (mixed_precision == 'bf16') and not is_bf16_available():
pytest.skip('skipping bf16 test because device does not support it')

accelerator = accelerator_cls(mixed_precision=mixed_precision)
net = net_cls(accelerator=accelerator)
net.initialize()
pickle.loads(pickle.dumps(net))

@pytest.mark.parametrize('mixed_precision', ['fp16', 'bf16', 'no'])
def test_mixed_precision_save_load_params(
self, net_cls, accelerator_cls, data, mixed_precision, tmp_path
):
from accelerate.utils import is_bf16_available

if (mixed_precision != 'no') and not torch.cuda.is_available():
pytest.skip('skipping AMP test because device does not support it')
if (mixed_precision == 'bf16') and not is_bf16_available():
pytest.skip('skipping bf16 test because device does not support it')

accelerator = accelerator_cls(mixed_precision=mixed_precision)
net = net_cls(accelerator=accelerator)
net.initialize()
filename = tmp_path / 'accel-net-params.pth'
net.save_params(f_params=filename)
net.load_params(f_params=filename)

def test_force_cpu(self, net_cls, accelerator_cls, data):
accelerator = accelerator_cls(device_placement=False, cpu=True)
net = net_cls(accelerator=accelerator)
