From 2f9ed684b565e0def168d2ea17a6ab4331fc4d12 Mon Sep 17 00:00:00 2001
From: Ronghang Hu
Date: Sun, 8 May 2022 04:08:52 +0000
Subject: [PATCH] allow disabling layout pinning in optimizer_step and
 reduce_gradients

---
 torch_xla/core/xla_model.py | 21 +++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 167f878eaee7..89b89137cf7b 100644
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -836,7 +836,7 @@ def wait_device_ops(devices=[]):
   torch_xla._XLAC._xla_wait_device_ops(devices=devices)


-def reduce_gradients(optimizer, groups=None):
+def reduce_gradients(optimizer, groups=None, pin_layout=True):
   """Reduces all the gradients handled by an optimizer.

   Args:
@@ -847,16 +847,27 @@ def reduce_gradients(optimizer, groups=None):
         defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
         the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
         all the replicas in it.
+    pin_layout (bool, optional): whether to pin the layout when reducing gradients.
+      See `xm.all_reduce` for details.
   """
   cctx = CollectiveContext()
   count = max(cctx.replica_devcount, cctx.world_size)
   if count > 1:
     gradients = _fetch_gradients(optimizer)
     all_reduce(
-        REDUCE_SUM, gradients, scale=1.0 / count, groups=groups, cctx=cctx)
+        REDUCE_SUM,
+        gradients,
+        scale=1.0 / count,
+        groups=groups,
+        cctx=cctx,
+        pin_layout=pin_layout)


-def optimizer_step(optimizer, barrier=False, optimizer_args={}, groups=None):
+def optimizer_step(optimizer,
+                   barrier=False,
+                   optimizer_args={},
+                   groups=None,
+                   pin_layout=True):
   """Run the provided optimizer step and issue the XLA device step computation.

   Args:
@@ -875,11 +886,13 @@ def optimizer_step(optimizer, barrier=False, optimizer_args={}, groups=None):
         defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
         the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
         all the replicas in it.
+    pin_layout (bool, optional): whether to pin the layout when reducing gradients.
+      See `xm.all_reduce` for details.

   Returns:
     The same value returned by the `optimizer.step()` call.
   """
-  reduce_gradients(optimizer, groups=groups)
+  reduce_gradients(optimizer, groups=groups, pin_layout=pin_layout)
   loss = optimizer.step(**optimizer_args)
   if barrier:
     mark_step()
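
A minimal usage sketch of the new `pin_layout` flag, assuming the usual torch_xla multi-replica training loop; the model, data, target, and optimizer names below are placeholders, not part of the patch:

    import torch.nn.functional as F
    import torch_xla.core.xla_model as xm

    def train_step(model, data, target, optimizer):
      optimizer.zero_grad()
      loss = F.nll_loss(model(data), target)
      loss.backward()
      # pin_layout=False skips layout pinning in the gradient all-reduce;
      # the default pin_layout=True preserves the previous behavior.
      return xm.optimizer_step(optimizer, pin_layout=False)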