Fix broken tests with new accelerate version (#873)
* Fix broken tests with new accelerate version

Recent changes in accelerate broke our tests. Unfortunately, the fix
requires calling a private method on AcceleratorState, but the accelerate
devs said that this is currently the best solution.

This PR also raises the requirement for accelerate to v0.11.0 or higher.
This not only simplifies our code, but also opens the door to supporting
gradient accumulation with accelerate (which will come in a separate PR;
an illustrative sketch of that API follows below).

* Skip AMP tests if no CUDA device is found
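
For context, a minimal sketch of the gradient accumulation API that the accelerate v0.11 floor makes available. The module, optimizer, and data here are illustrative placeholders, not part of this commit:

    import torch
    from accelerate import Accelerator

    # Illustrative toy setup; any model/optimizer/dataloader would do.
    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    dataset = torch.utils.data.TensorDataset(
        torch.randn(32, 10), torch.randn(32, 1)
    )
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)

    # accelerate v0.11 introduced gradient_accumulation_steps together with
    # the accumulate() context manager; for a prepared optimizer, step() and
    # zero_grad() are skipped on batches where gradients are still being
    # accumulated.
    accelerator = Accelerator(gradient_accumulation_steps=2)
    model, optimizer, dataloader = accelerator.prepare(
        model, optimizer, dataloader
    )

    for X, y in dataloader:
        with accelerator.accumulate(model):
            loss = torch.nn.functional.mse_loss(model(X), y)
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

The assumption here is that a later skorch PR would build on this accumulate() API; this commit only makes the version floor compatible with it.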
BenjaminBossan committed Jul 21, 2022
1 parent 764a1a0 commit 4b9e765
Showing 2 changed files with 13 additions and 9 deletions.
requirements-dev.txt: 2 changes (1 addition, 1 deletion)

@@ -1,4 +1,4 @@
-accelerate
+accelerate>=0.11.0
 fire
 flaky
 future>=0.17.1
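
As a quick sanity check, one hypothetical way (not part of this commit) to confirm that an installed accelerate satisfies the new floor:

    # Hypothetical check, not part of this commit: assert the installed
    # accelerate meets the requirements-dev.txt floor of 0.11.0.
    from importlib.metadata import version
    from packaging.version import Version

    assert Version(version('accelerate')) >= Version('0.11.0')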
skorch/tests/test_helper.py: 20 changes (12 additions, 8 deletions)
@@ -798,7 +798,13 @@ def accelerator_cls(self):
         pytest.importorskip('accelerate')
 
         from accelerate import Accelerator
+        from accelerate.state import AcceleratorState
 
+        # We have to use this private method because otherwise, the
+        # AcceleratorState is not completely reset, which results in an error
+        # initializing the Accelerator more than once in the same process.
+        # pylint: disable=protected-access
+        AcceleratorState._reset_state()
         return Accelerator
 
     @pytest.fixture
@@ -821,16 +827,14 @@ def test_mixed_precision(self, net_cls, accelerator_cls, data, mixed_precision):
         # Only test if training works at all, no specific test of whether the
         # indicated precision is actually used, since that depends on the
         # underlying hardware.
-        import accelerate
+        from accelerate.utils import is_bf16_available
 
-        if LooseVersion(accelerate.__version__) > '0.5.1':
-            accelerator = accelerator_cls(mixed_precision=mixed_precision)
-        elif mixed_precision == 'bf16':
-            pytest.skip('bf16 only supported in accelerate version > 0.5.1')
-        else:
-            fp16 = mixed_precision == 'fp16'
-            accelerator = accelerator_cls(fp16=fp16)
+        if (mixed_precision != 'no') and not torch.cuda.is_available():
+            pytest.skip('skipping AMP test because device does not support it')
+        if (mixed_precision == 'bf16') and not is_bf16_available():
+            pytest.skip('skipping bf16 test because device does not support it')
 
+        accelerator = accelerator_cls(mixed_precision=mixed_precision)
         net = net_cls(accelerator=accelerator)
         X, y = data
         net.fit(X, y)  # does not raise
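
For context, the feature exercised by these tests is skorch's AccelerateMixin from skorch.helper. A minimal usage sketch following the documented pattern; the toy module and random data are illustrative, not taken from the test suite:

    import numpy as np
    import torch
    from accelerate import Accelerator
    from skorch import NeuralNetClassifier
    from skorch.helper import AccelerateMixin

    class AcceleratedNet(AccelerateMixin, NeuralNetClassifier):
        """NeuralNetClassifier that defers device placement to accelerate."""

    # Toy classifier module; skorch classifiers expect probability outputs.
    module = torch.nn.Sequential(
        torch.nn.Linear(20, 2),
        torch.nn.Softmax(dim=-1),
    )

    # Pass device=None so that accelerate, not skorch, manages the device.
    accelerator = Accelerator(mixed_precision='no')
    net = AcceleratedNet(
        module,
        accelerator=accelerator,
        device=None,
        max_epochs=2,
    )

    X = np.random.randn(64, 20).astype(np.float32)
    y = np.random.randint(0, 2, size=64)
    net.fit(X, y)  # does not raise, mirroring the test's assertion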
