From 5ff53df2b66cbfa539d6027ad92c0301790647d6 Mon Sep 17 00:00:00 2001
From: Davide Libenzi
Date: Sat, 29 Feb 2020 10:10:44 -0800
Subject: [PATCH] Split gradient reduction into a separate utility API.

---
 torch_xla/core/xla_model.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 65dc14b455b..cb7bab67dc5 100644
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -426,6 +426,19 @@ def wait_device_ops(devices=[]):
   torch_xla._XLAC._xla_wait_device_ops(devices=devices)
 
 
+def reduce_gradients(optimizer):
+  """Reduces all the gradients handled by an optimizer.
+
+  Args:
+    optimizer (:class:`torch.Optimizer`): The `torch.Optimizer` instance
+      containing the gradients to be reduced.
+  """
+  count = torch_xla._XLAC._xla_get_replication_devices_count()
+  if count > 1:
+    gradients = _fetch_gradients(optimizer)
+    all_reduce('sum', gradients, scale=1.0 / count)
+
+
 def optimizer_step(optimizer, barrier=False, optimizer_args={}):
   """Run the provided optimizer step and issue the XLA device step
   computation.
@@ -444,11 +457,7 @@ def optimizer_step(optimizer, barrier=False, optimizer_args={}):
   Returns:
     The same value returned by the `optimizer.step()` call.
   """
-
-  count = torch_xla._XLAC._xla_get_replication_devices_count()
-  if count > 1:
-    gradients = _fetch_gradients(optimizer)
-    all_reduce('sum', gradients, scale=1.0 / count)
+  reduce_gradients(optimizer)
   loss = optimizer.step(**optimizer_args)
   if barrier:
     mark_step()
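
Note: below is a minimal usage sketch (not part of the patch) showing what splitting reduce_gradients() out of optimizer_step() enables: calling the reduction directly so that custom logic, such as gradient clipping, can run between the cross-replica all-reduce and the weight update. The model, loss_fn, data, and target names are hypothetical placeholders; xm is the conventional alias for torch_xla.core.xla_model.

    import torch
    import torch_xla.core.xla_model as xm

    def train_step(model, loss_fn, optimizer, data, target):
      # Hypothetical training step; model, loss_fn, data, and target are
      # placeholders supplied by the caller.
      optimizer.zero_grad()
      loss = loss_fn(model(data), target)
      loss.backward()
      # Average gradients across replicas. This is the new utility API; it
      # is a no-op when the replication device count is 1.
      xm.reduce_gradients(optimizer)
      # Custom logic can now run between the reduction and the update,
      # e.g. gradient clipping on the already-reduced gradients.
      torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
      optimizer.step()
      # Trigger the XLA device step computation (what barrier=True does in
      # optimizer_step()).
      xm.mark_step()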