diff --git a/test/test_nn.py b/test/test_nn.py
index db595ff382da3..cf2c1680fc6f8 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -1577,6 +1577,31 @@ def test_module_to_argparse(self):
         with self.assertRaises(TypeError):
             net.to(cpu, torch.tensor(3, dtype=torch.long), non_blocking=True)
 
+    def test_module_apply_inplace_op(self):
+        def add_one_inplace(t):
+            return t.add_(1.0)
+
+        # Test that applying an in-place operation to a module would bump
+        # the module's parameters' version counter.
+        m = nn.Linear(20, 10)
+        pvm = m.weight.mul(m.weight)
+        m_weight_version_saved = m.weight._version
+        m = m._apply(add_one_inplace)
+        self.assertGreater(m.weight._version, m_weight_version_saved)
+        with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
+            pvm.backward(torch.randn(10, 20))
+
+        # Test that applying an in-place operation to a module would bump
+        # the module's parameters' gradients' version counter.
+        m = nn.Linear(20, 10)
+        m.weight.grad = torch.randn(10, 20).requires_grad_()
+        pgm = m.weight.grad.mul(m.weight.grad)
+        m_weight_grad_version_saved = m.weight.grad._version
+        m = m._apply(add_one_inplace)
+        self.assertGreater(m.weight.grad._version, m_weight_grad_version_saved)
+        with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
+            pgm.backward(torch.randn(10, 20))
+
     def test_type(self):
         l = nn.Linear(10, 20)
         net = nn.Module()
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 70b2180ce7960..d94b938a090ce 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -195,11 +195,13 @@ def _apply(self, fn):
 
         for param in self._parameters.values():
             if param is not None:
-                # Tensors stored in modules are graph leaves, and we don't
-                # want to create copy nodes, so we have to unpack the data.
-                param.data = fn(param.data)
+                with torch.no_grad():
+                    param_applied = fn(param)
+                param.data = param_applied
                 if param._grad is not None:
-                    param._grad.data = fn(param._grad.data)
+                    with torch.no_grad():
+                        grad_applied = fn(param._grad)
+                    param._grad.data = grad_applied
 
         for key, buf in self._buffers.items():
             if buf is not None:
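
A minimal standalone sketch (not part of the patch) of the version-counter mechanics the new test exercises. In-place ops bump a tensor's _version; at backward time autograd compares the version it saved during the forward pass against the current one and raises "modified by an inplace operation" on a mismatch. This assumes a PyTorch build that exposes the private Tensor._version attribute and permits in-place updates to leaf tensors inside torch.no_grad(), mirroring what the patched _apply does:

import torch

x = torch.randn(3, requires_grad=True)
y = x.mul(x)                 # mul saves x (and its current version) for backward
version_saved = x._version
with torch.no_grad():
    x.add_(1.0)              # in-place op on the tensor itself bumps _version
assert x._version > version_saved
try:
    y.backward(torch.ones(3))
except RuntimeError as e:
    # Autograd detects the stale saved tensor via the version counter.
    assert "modified by an inplace operation" in str(e)

The old code path, param.data = fn(param.data), side-stepped this check: per the PyTorch docs, in-place changes to the detached .data tensor are not reported to autograd, so saved tensors could go stale silently. Calling fn(param) directly, wrapped in no_grad() so the op is not recorded in the graph, is what lets the new tests in test_nn.py observe the bumped counter and the expected RuntimeError.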