Skip to content

Commit

Permalink
Unwrap accelerated model on train end (#842)
Browse files Browse the repository at this point in the history
  • Loading branch information
BenjaminBossan committed Mar 10, 2022
1 parent 5e22da6 commit 472a664
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
5 changes: 5 additions & 0 deletions skorch/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,3 +669,8 @@ def _step_optimizer(self, step_fn):
for name in self._optimizers:
optimizer = getattr(self, name + '_')
optimizer.step()

# pylint: disable=unused-argument
def on_train_end(self, net, X=None, y=None, **kwargs):
super().on_train_end(net, X=X, y=y, **kwargs)
self.module_ = self.accelerator.unwrap_model(self.module_)
3 changes: 3 additions & 0 deletions skorch/tests/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,9 @@ def backward(self, loss, **kwargs):
loss.backward(**kwargs)
loss.backward_was_called = True

def unwrap_model(self, model):
return model

# pylint: disable=missing-class-docstring
class AcceleratedNet(AccelerateMixin, NeuralNetClassifier):
def get_iterator(self, *args, **kwargs):
Expand Down

0 comments on commit 472a664

Please sign in to comment.