54 changes: 51 additions & 3 deletions torch_xla/core/xla_model.py
@@ -593,7 +593,49 @@ def all_reduce(reduce_type,
return results[0] if isinstance(inputs, torch.Tensor) else results


def all_gather(value, dim=0, groups=None, output=None, pin_layout=False):
def _all_gather_using_all_reduce(value, dim=0, groups=None, pin_layout=True):
"""Performs an all-gather operation using all-reduce along a given dimension.

Args:
value (torch.Tensor): The input tensor.
dim (int): The gather dimension.
Default: 0
groups (list, optional): A list of lists, representing the replica groups for
the `all_gather()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
all the replicas in it.
pin_layout (bool, optional): whether to pin the layout for this communication op.
Layout pinning can prevent potential data corruption when each process that
participates in the communication has a slightly different program, but it might
cause some XLA compilations to fail. Unpin the layout when you see an error message
like "HloModule has a mix of layout constrained".

Returns:
A tensor which has, in the ``dim`` dimension, all the values from the
participating replicas.
"""
if dim < 0:
dim = value.dim() + dim
size = value.size(dim)
padding = [0] * (2 * value.dim())
ordinal = get_ordinal()
if groups is None:
left, right = ordinal, xrt_world_size() - 1 - ordinal
else:
ordinals = dict()
for g in groups:
for i, x in enumerate(g):
ordinals[x] = (i, len(g) - 1 - i)
left, right = ordinals[ordinal]
idx = value.dim() - 1 - dim
padding[2 * idx] = left * size
padding[2 * idx + 1] = right * size
return all_reduce(REDUCE_SUM, F.pad(value, padding), groups=groups)
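
For intuition, here is a minimal, self-contained sketch (not part of this diff; it simulates a hypothetical 4-replica group locally with `dim=0`) of why the zero-padding plus sum-based all-reduce in `_all_gather_using_all_reduce` reproduces an all-gather:

```python
# Illustrative only: each "replica" k pads its slice to the full gathered
# shape, with its own values at offset k * size and zeros elsewhere, so an
# element-wise sum over replicas equals concatenation along dim.
import torch
import torch.nn.functional as F

world_size, size = 4, 2
# Simulated per-replica inputs: replica k holds a tensor filled with k + 1.
values = [torch.full((size,), float(k + 1)) for k in range(world_size)]

padded = []
for ordinal, value in enumerate(values):
    left, right = ordinal, world_size - 1 - ordinal
    # F.pad takes (pad_left, pad_right) pairs starting from the last dimension,
    # which is what the idx arithmetic in _all_gather_using_all_reduce computes.
    padded.append(F.pad(value, [left * size, right * size]))

# Element-wise sum over replicas stands in for all_reduce(REDUCE_SUM, ...).
gathered = torch.stack(padded).sum(dim=0)
assert torch.equal(gathered, torch.cat(values, dim=0))
```

Because every replica writes its values into a disjoint slice of the padded tensor, the sum computed by `all_reduce(REDUCE_SUM, ...)` is exactly the concatenation along `dim`.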


def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
"""Performs an all-gather operation along a given dimension.

Args:
@@ -616,6 +658,12 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=False):
A tensor which has, in the ``dim`` dimension, all the values from the
participating replicas.
"""
if pin_layout and xla_device_hw(
value.device) in ('TPU', 'GPU') and output == None:
    # There is no easy way to pin the all_gather layout on TPU and GPU, so use the
    # all_reduce-based all_gather for this purpose.
return _all_gather_using_all_reduce(
value, dim=dim, groups=groups, pin_layout=True)
if dim < 0:
dim = value.dim() + dim
token, devctx = _get_all_reduce_token()
@@ -645,7 +693,7 @@ def all_to_all(value,
concat_dimension,
split_count,
groups=None,
pin_layout=False):
pin_layout=True):
"""Performs an XLA `AllToAll()` operation on the input tensor.

See: https://www.tensorflow.org/xla/operation_semantics#alltoall
@@ -741,7 +789,7 @@ def reduce_scatter(reduce_type,
shard_count,
groups=None,
output=None,
pin_layout=False):
pin_layout=True):
"""Performs a XLA `ReduceScatter()` operation on the input tensor.

See: https://www.tensorflow.org/xla/operation_semantics#reducescatter
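To make the new defaults in `torch_xla/core/xla_model.py` concrete, here is a hedged usage sketch (hypothetical tensor and device; not taken from this diff) showing how callers opt out of layout pinning when needed:

```python
# Hypothetical usage: xm collectives such as all_gather now default to
# pin_layout=True. Pass pin_layout=False if compilation fails with an error
# like "HloModule has a mix of layout constrained".
import torch
import torch_xla.core.xla_model as xm

t = torch.ones(2, 2, device=xm.xla_device())
pinned = xm.all_gather(t, dim=0)                      # pin_layout=True by default
unpinned = xm.all_gather(t, dim=0, pin_layout=False)  # unpin when programs differ
```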
10 changes: 6 additions & 4 deletions torch_xla/distributed/xla_backend.py
@@ -58,12 +58,12 @@ def allreduce(self, tensors, all_reduce_options):
reduce_type = self._get_reduce_type(all_reduce_options.reduceOp)

# TODO(hjm-aws): implement all_reduce_options.timeout.
xm.all_reduce(reduce_type, tensors, groups=self._mesh)
xm.all_reduce(reduce_type, tensors, groups=self._mesh, pin_layout=False)
return WorkXla(tensors)

def allgather(self, output_tensors_list, input_tensors):
for input_tensor, output_tensors in zip(input_tensors, output_tensors_list):
result = xm.all_gather(input_tensor, groups=self._mesh)
result = xm.all_gather(input_tensor, groups=self._mesh, pin_layout=False)
for i, slice in enumerate(torch.split(result, input_tensor.shape[0])):
output_tensors[i].copy_(slice)

@@ -77,7 +77,8 @@ def broadcast(self, tensors, opts):
if root_rank != self.rank():
with torch.no_grad():
root_tensor.zero_()
xm.all_reduce(xm.REDUCE_SUM, [root_tensor], groups=self._mesh)
xm.all_reduce(
xm.REDUCE_SUM, [root_tensor], groups=self._mesh, pin_layout=False)

return WorkXla([root_tensor])

@@ -102,7 +103,8 @@ def reduce_scatter(self, output_tensors, input_tensors_list, opts):
shard_count=shard_count,
scale=1,
groups=groups,
output=output_tensor)
output=output_tensor,
pin_layout=False)

return WorkXla(output_tensors)
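
For context, a sketch of how these process-group methods are typically reached (the `xla` backend registration and `_mp_fn` wiring here are illustrative assumptions, not part of this diff); every collective issued through `torch.distributed` now delegates to the corresponding `xm.*` op with `pin_layout=False`:

```python
# Illustrative only: dist.all_reduce routed through the 'xla' backend ends up
# in ProcessGroupXla.allreduce, which calls xm.all_reduce(..., pin_layout=False).
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the 'xla' backend

def _mp_fn(index):
    dist.init_process_group('xla', rank=xm.get_ordinal(),
                            world_size=xm.xrt_world_size())
    t = torch.ones(2, 2, device=xm.xla_device())
    dist.all_reduce(t)  # sums across replicas without layout pinning
    xm.mark_step()
```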
