diff --git a/test/spmd/test_train_spmd_linear_model.py b/test/spmd/test_train_spmd_linear_model.py
index de8b70a944b2..a8e8b459d7e9 100644
--- a/test/spmd/test_train_spmd_linear_model.py
+++ b/test/spmd/test_train_spmd_linear_model.py
@@ -74,10 +74,26 @@ def test_gradient_accumulation_matches(self):
     # Verify that the model losses are not zero, and that the runs match.
     assert all(loss != 0 for loss in baseline_grad_acc_losses)
     assert all(
-        torch.allclose(baseline_loss, checkpointing_loss, rtol=1e-4, atol=1e-8)
-        for baseline_loss, checkpointing_loss in zip(baseline_grad_acc_losses,
+        torch.allclose(baseline_loss, loop_grad_acc_loss, rtol=1e-4, atol=1e-8)
+        for baseline_loss, loop_grad_acc_loss in zip(baseline_grad_acc_losses,
                                                      loop_grad_acc_losses))
 
+    if not SKIP_GRADIENT_CHECKPOINTING:
+      print('Training loop with XLA\'s `While` gradient accumulation and '
+            'gradient checkpointing.')
+      with extended_argv(
+          COMMON_GRAD_ACC_ARGS +
+          ["--use_gradient_accumulation_loop", "--use_gradient_checkpointing"]):
+        loop_grad_acc_grad_chkpt_losses = train_and_evaluate_grad_acc()
+      assert all(
+          torch.allclose(
+              baseline_loss,
+              loop_grad_acc_grad_chkpt_loss,
+              rtol=1e-4,
+              atol=1e-8)
+          for baseline_loss, loop_grad_acc_grad_chkpt_loss in zip(
+              baseline_grad_acc_losses, loop_grad_acc_grad_chkpt_losses))
+
 
 if __name__ == '__main__':
   parser = argparse.ArgumentParser()
diff --git a/torch_xla/experimental/gradient_accumulation.py b/torch_xla/experimental/gradient_accumulation.py
index eacda86c12ca..a1f557d96483 100644
--- a/torch_xla/experimental/gradient_accumulation.py
+++ b/torch_xla/experimental/gradient_accumulation.py
@@ -181,6 +181,15 @@ def _prepare_fake_tensors(
   grads = [param.grad for param in params]
   body_fn_inputs = (init_iterator, init_loss, *fake_iterable_tensors,
                     *fake_carried_tensors, *params, *grads)
+  # TODO - Fake the gradients once we are able to create placeholder tensors.
+  # Since the body is expected to do an in-place mutation of the gradients, we
+  # clone the gradients and use the clones as inputs to the body. This ensures
+  # that we retain a device data IR node in the graph. The cloned gradient will
+  # be updated to denote an IR operation (e.g. %add), and that cannot be
+  # captured as a device data input for the other required computations, namely
+  # the condition and init for the XLA while loop.
+  for param in params:
+    param.grad = param.grad.clone()
   body_result = body_fn(init_iterator, init_loss, tuple(fake_iterable_tensors),
                         tuple(fake_carried_tensors), tuple(params),
                         tuple(grads))
@@ -375,10 +384,9 @@ def body_fn(iteri: torch.Tensor, _: torch.Tensor,
     else:
       loss, *carried_tensors = result
     loss /= context.num_gradient_steps
-    gradients = torch.autograd.grad(loss, model_parameters)
-    acc_grads = [prev_grad + grad for prev_grad, grad in zip(grads, gradients)]
-    return (iteri, loss, *iterable_tensors, *carried_tensors, *params,
-            *acc_grads)
+    loss.backward()
+    grads = [param.grad for param in params]
+    return (iteri, loss, *iterable_tensors, *carried_tensors, *params, *grads)
 
   if not torch_xla._XLAC._xla_get_enable_alias_with_buffer_donor_config():
     warnings.warn(
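The two hunks in gradient_accumulation.py rest on plain autograd semantics: `loss.backward()` accumulates into `param.grad` in place (matching the removed manual `prev_grad + grad` accumulation), and cloning `param.grad` up front leaves the previously captured grad tensors untouched while the body mutates the clones. Below is a minimal plain-PyTorch sketch of both points; the linear model, chunking, and variable names are illustrative only and not part of the patch.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
params = list(model.parameters())
data = torch.randn(8, 4)

# Removed style: functional grads accumulated by hand.
manual = [torch.zeros_like(p) for p in params]
for chunk in data.split(2):
  loss = model(chunk).sum()
  grads = torch.autograd.grad(loss, params)
  manual = [acc + g for acc, g in zip(manual, grads)]

# New style: backward() accumulates into .grad in place.
for p in params:
  p.grad = torch.zeros_like(p)
captured = [p.grad for p in params]  # analogous to `grads` captured for body_fn_inputs
for p in params:
  p.grad = p.grad.clone()            # analogous to the clone in _prepare_fake_tensors
for chunk in data.split(2):
  loss = model(chunk).sum()
  loss.backward()

# Both accumulation styles agree, and the captured originals stay untouched.
assert all(torch.allclose(m, p.grad) for m, p in zip(manual, params))
assert all(torch.equal(c, torch.zeros_like(c)) for c in captured)
```

This only checks the autograd behavior the patch relies on; the device-data vs. IR-op distinction that motivates the clone is XLA-specific and is not exercised here.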