@@ -534,8 +534,8 @@ def all_gather(value: torch.Tensor,
534534 groups : Optional [List [List [int ]]] = None ,
535535 output : Optional [torch .Tensor ] = None ,
536536 pin_layout : bool = True ,
537- channel_id = None ,
538- use_global_device_ids = None ) -> torch .Tensor :
537+ channel_id = None ,
538+ use_global_device_ids = None ) -> torch .Tensor :
539539 """Performs an all-gather operation along a given dimension.
540540
541541 Args:
@@ -588,7 +588,8 @@ def all_gather(value: torch.Tensor,
588588 return output
589589
590590 result = torch_xla ._XLAC ._xla_all_gather (value , dim , shard_count , groups or
591- [], pin_layout , channel_id , use_global_device_ids )
591+ [], pin_layout , channel_id ,
592+ use_global_device_ids )
592593 return result
593594
594595 # Now the input should be a list of Tensors.
@@ -875,8 +876,8 @@ def reduce_scatter(reduce_type: str,
875876 output : Optional [Union [torch .Tensor ,
876877 List [torch .Tensor ]]] = None ,
877878 pin_layout : bool = True ,
878- channel_id = None ,
879- use_global_device_ids = None ) -> torch .Tensor :
879+ channel_id = None ,
880+ use_global_device_ids = None ) -> torch .Tensor :
880881 """Performs a XLA `ReduceScatter()` operation on the input tensor.
881882
882883 See: https://www.tensorflow.org/xla/operation_semantics#reducescatter
0 commit comments