Fix issues with unwrapping modules with accelerate (#963)
There were a few issues at once that are now being fixed:

- When unwrapping, pass keep_fp32_wrapper=False, or else the modules are
  not completely unwrapped and cannot be pickled.
- Only net.module_ was unwrapped, but there can be other modules and
  criteria. Those are now also unwrapped.
- In some circumstances, unwrapping is undesired, for instance when users
  want to continue training or benefit from AMP during inference.
  Therefore, an option was added to prevent automatic unwrapping:
  set unwrap_after_train=False (see the usage sketch below).
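
A minimal usage sketch of the new option (not part of the commit itself). `AcceleratedNet`, `MyModule`, and the data variables are placeholders, and a CUDA-capable machine is assumed since fp16 mixed precision requires a GPU:

```python
from accelerate import Accelerator
from skorch import NeuralNetClassifier
from skorch.hf import AccelerateMixin


class AcceleratedNet(AccelerateMixin, NeuralNetClassifier):
    """NeuralNetClassifier that delegates device handling and AMP to accelerate."""


accelerator = Accelerator(mixed_precision='fp16')
net = AcceleratedNet(
    MyModule,                   # placeholder: your torch.nn.Module class
    accelerator=accelerator,
    unwrap_after_train=False,   # keep the accelerate wrappers after fit
)
net.fit(X_train, y_train)          # placeholder data
net.partial_fit(X_extra, y_extra)  # continued training still benefits from AMP
# Trade-off: while the modules stay wrapped, pickling the net is expected to fail.
```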
BenjaminBossan committed May 16, 2023
1 parent 994cca1 commit a218ebc
Showing 3 changed files with 72 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add support for compiled PyTorch modules using the `torch.compile` function, introduced in [PyTorch 2.0 release](https://pytorch.org/get-started/pytorch-2.0/), which can greatly improve performance on new GPU architectures; to use it, initialize your net with the `compile=True` argument, further compilation arguments can be specified using the dunder notation, e.g. `compile__dynamic=True`
- Add a class [`DistributedHistory`](https://skorch.readthedocs.io/en/latest/history.html#skorch.history.DistributedHistory) which should be used when training in a multi GPU setting (#955)
- `SkorchDoctor`: A helper class that assists in understanding and debugging the neural net training, see [this notebook](https://nbviewer.org/github/skorch-dev/skorch/blob/master/notebooks/Skorch_Doctor.ipynb) (#912)
- When using `AccelerateMixin`, it is now possible to prevent unwrapping of the modules by setting `unwrap_after_train=False`

### Changed

@@ -21,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `_get_param_names` returns a list instead of a generator so that subsequent
error messages return useful information instead of a generator `repr`
string (#925)
- Fixed a bug that caused modules to not be sufficiently unwrapped at the end of training when using `AccelerateMixin`, which could prevent them from being pickled

## [0.12.1] - 2022-11-18

25 changes: 24 additions & 1 deletion skorch/hf.py
@@ -888,6 +888,19 @@ class was added: Using this mixin in conjunction with the accelerate library
leave device handling to accelerate. Therefore, it is best to leave this
argument to be None, which means that skorch does not set the device.
unwrap_after_train : bool (default=True)
By default, with this option being ``True``, the module(s) and criterion
are automatically "unwrapped" after training. This means that their
initial state -- from before they were prepared by the ``accelerator`` --
is restored. This is necessary to pickle the net.
There are circumstances where you might want to disable this behavior. For
instance, when you want to further train the model with AMP enabled (using
``net.partial_fit`` or ``warm_start=True``). Also, unwrapping the modules
means that the advantage of using mixed precision is lost during
inference. In those cases, if you don't need to pickle the net, you should
set ``unwrap_after_train=False``.
callbacks__print_log__sink : 'auto' or callable
If 'auto', uses the ``print`` function of the accelerator, if it has one.
This avoids printing the same output multiple times when training
@@ -900,6 +913,7 @@ def __init__(
*args,
accelerator,
device=None,
unwrap_after_train=True,
callbacks__print_log__sink='auto',
**kwargs
):
@@ -910,6 +924,7 @@
**kwargs
)
self.accelerator = accelerator
self.unwrap_after_train = unwrap_after_train

def _validate_params(self):
super()._validate_params()
@@ -1009,7 +1024,15 @@ def _step_optimizer(self, step_fn):
# pylint: disable=unused-argument
def on_train_end(self, net, X=None, y=None, **kwargs):
super().on_train_end(net, X=X, y=y, **kwargs)
self.module_ = self.accelerator.unwrap_model(self.module_)
if not self.unwrap_after_train:
return self

for name in self._modules + self._criteria:
module = getattr(self, name + '_')
if isinstance(module, torch.nn.Module):
orig = self.accelerator.unwrap_model(module, keep_fp32_wrapper=False)
setattr(self, name + '_', orig)
return self

def evaluation_step(self, batch, training=False):
# More context:
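For context (not part of the diff), a rough sketch of why `keep_fp32_wrapper=False` matters: with mixed precision, accelerate patches the module's `forward`, and only a full unwrap restores a picklable module. The sketch assumes a CUDA device and an accelerate release whose `Accelerator.unwrap_model` accepts the `keep_fp32_wrapper` argument.

```python
import pickle

import torch
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision='fp16')
module = accelerator.prepare(torch.nn.Linear(10, 2))
# With mixed precision, accelerate patches forward; on the accelerate versions
# targeted here, the patched method exposes __wrapped__ and breaks pickling.
assert hasattr(module.forward, '__wrapped__')

unwrapped = accelerator.unwrap_model(module, keep_fp32_wrapper=False)
assert not hasattr(unwrapped.forward, '__wrapped__')
pickle.dumps(unwrapped)  # works again once the fp32 wrapper is removed
```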
50 changes: 46 additions & 4 deletions skorch/tests/test_hf.py
@@ -562,9 +562,9 @@ def test_mixed_precision(self, net_cls, accelerator_cls, data, mixed_precision):
assert np.isfinite(net.history[:, "train_loss"]).all()

@pytest.mark.parametrize('mixed_precision', [
pytest.param('fp16', marks=pytest.mark.xfail(raises=pickle.PicklingError)),
pytest.param('bf16', marks=pytest.mark.xfail(raises=pickle.PicklingError)),
'no', # no acceleration works because forward is left the same
'fp16',
pytest.param('bf16', marks=pytest.mark.xfail(raises=pickle.PicklingError)),
])
def test_mixed_precision_pickling(
self, net_cls, accelerator_cls, data, mixed_precision
@@ -601,9 +601,51 @@ def test_mixed_precision_pickling(

accelerator = accelerator_cls(mixed_precision=mixed_precision)
net = net_cls(accelerator=accelerator)
net.initialize()
X, y = data
net.fit(X[:100], y[:100])
pickle.loads(pickle.dumps(net))

def test_unwrapping_all_modules(self, module_cls, accelerator_cls, data):
# This test is for a bug we had previously where only 'module_' was
# unwrapped, not all possible modules and criteria.
if not torch.cuda.is_available():
pytest.skip('skipping test because device does not support it')

class MyNet(AcceleratedNet):
"""Net with two different modules"""
def initialize_module(self):
super().initialize_module()
self.module2_ = module_cls()
return self

accelerator = accelerator_cls(mixed_precision='fp16')
net = MyNet(module_cls, accelerator=accelerator, unwrap_after_train=True)
X, y = data
net.fit(X[:100], y[:100])

# there isn't really an elegant way to check if the modules have been
# correctly unwrapped
assert not hasattr(net.criterion_.forward, '__wrapped__')
assert not hasattr(net.module_.forward, '__wrapped__')
assert not hasattr(net.module2_.forward, '__wrapped__')

def test_not_unwrapping_modules(self, net_cls, accelerator_cls, data):
# Make it possible not to unwrap the modules after training. This is
# useful, e.g., to allow further training with warm start or to do
# inference with AMP, but it prevents the model from being pickled.
if not torch.cuda.is_available():
pytest.skip('skipping test because device does not support it')

accelerator = accelerator_cls(mixed_precision='fp16')
net = net_cls(accelerator=accelerator, unwrap_after_train=False)
X, y = data
net.fit(X[:100], y[:100])

# there isn't really an elegant way to check that the modules are
# still wrapped
assert hasattr(net.criterion_.forward, '__wrapped__')
assert hasattr(net.module_.forward, '__wrapped__')

@pytest.mark.parametrize('mixed_precision', ['fp16', 'bf16', 'no'])
def test_mixed_precision_save_load_params(
self, net_cls, accelerator_cls, data, mixed_precision, tmp_path
@@ -693,7 +735,7 @@ def backward(self, loss, **kwargs):
loss.backward(**kwargs)
loss.backward_was_called = True

def unwrap_model(self, model):
def unwrap_model(self, model, keep_fp32_wrapper=True):
return model

def gather_for_metrics(self, output):
