[dtensor][7/n] remove reduction rule #109144

Closed · wants to merge 2 commits
49 changes: 1 addition & 48 deletions test/distributed/_tensor/test_common_rules.py
@@ -5,11 +5,7 @@
 from torch.distributed._tensor import DeviceMesh
 from torch.distributed._tensor.op_schema import OpSchema

-from torch.distributed._tensor.ops.common_rules import (
-    einop_rule,
-    pointwise_rule,
-    reduction_rule,
-)
+from torch.distributed._tensor.ops.common_rules import einop_rule, pointwise_rule
 from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
@@ -416,49 +412,6 @@ def test_pointwise_enforce_sharding_multi_sharding_on_mesh_dim(self):
         self.assertEqual(schema_suggestion.args_schema[0].dim_map, mat1)
         self.assertEqual(schema_suggestion.args_schema[1].dim_map, mat1)

-    @with_comms
-    def test_reduction_rule(self):
-        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
-
-        sum_call = aten.sum.default
-        # reduction on a 2d mat
-        mat1 = [0, -1]
-        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
-        mat1_spec = DTensorSpec.from_dim_map(
-            mesh, mat1, [], tensor_meta=mat1_tensor_meta
-        )
-        # reduction on dim 0
-        output_sharding_0 = reduction_rule(
-            OpSchema(sum_call, (mat1_spec, 0), {}),
-            dims=[0],
-            reduction_linear=True,
-        )
-        self.assertIsNotNone(output_sharding_0.output_spec)
-        self.assertEqual(output_sharding_0.output_spec.dim_map, [-1])
-        # pending sum on dim 0
-        self.assertEqual(output_sharding_0.output_spec.sums, [0])
-
-        # reduction on dim 1
-        output_sharding_1 = reduction_rule(
-            OpSchema(sum_call, (mat1_spec, 1), {}),
-            dims=[1],
-            reduction_linear=True,
-        )
-        self.assertIsNotNone(output_sharding_1.output_spec)
-        self.assertEqual(output_sharding_1.output_spec.dim_map, [0])
-        self.assertEqual(output_sharding_1.output_spec.sums, [])
-
-        # full reduction if no dim is specified
-        output_sharding_all_dim = reduction_rule(
-            OpSchema(sum_call, (mat1_spec,), {}),
-            dims=[0, 1],
-            reduction_linear=True,
-        )
-        self.assertIsNotNone(output_sharding_all_dim.output_spec)
-        self.assertEqual(output_sharding_all_dim.output_spec.dim_map, [])
-        # pending sum on mesh
-        self.assertEqual(output_sharding_all_dim.output_spec.sums, [0])
-

 if __name__ == "__main__":
     run_tests()
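A note on the encoding exercised by the deleted test: in a DTensorSpec, dim_map[i] names the mesh dimension that tensor dim i is sharded on (-1 means replicated), and sums lists the mesh dimensions carrying a pending partial sum. The sketch below is illustrative only; describe_dim_map is a hypothetical helper, not part of DTensor or this PR.

from typing import List


def describe_dim_map(dim_map: List[int], sums: List[int]) -> str:
    # Hypothetical helper (not part of DTensor): render the dim_map/sums
    # encoding asserted in the deleted test_reduction_rule in readable form.
    #   dim_map[i] == m  -> tensor dim i is sharded over mesh dim m
    #   dim_map[i] == -1 -> tensor dim i is replicated
    #   sums             -> mesh dims holding a pending partial sum
    parts = [
        f"tensor dim {i}: " + ("replicated" if m == -1 else f"sharded on mesh dim {m}")
        for i, m in enumerate(dim_map)
    ]
    if sums:
        parts.append(f"pending partial sum on mesh dims {sums}")
    return "; ".join(parts) if parts else "fully replicated scalar"


# The three outputs asserted in the deleted test_reduction_rule:
print(describe_dim_map([-1], [0]))  # sum over dim 0: replicated output, partial sum pending
print(describe_dim_map([0], []))    # sum over dim 1: sharding on tensor dim 0 preserved
print(describe_dim_map([], [0]))    # full reduction: scalar with a pending partial sum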
73 changes: 1 addition & 72 deletions torch/distributed/_tensor/ops/common_rules.py
@@ -1,5 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
-from typing import cast, Dict, List, Optional, Sequence, Tuple
+from typing import cast, Dict, List, Optional, Tuple

 import torch
 from torch.distributed._tensor._utils import compute_local_shape
@@ -294,74 +294,3 @@ def linear_pointwise_rule(op_schema: OpSchema) -> OutputSharding:
     pending sum as well without any communication overhead.
     """
     return pointwise_rule(op_schema, linearity=True)
-
-
-def reduction_rule(
-    op_schema: OpSchema,
-    *,
-    dims: Optional[Sequence[int]] = None,
-    keep_dim: bool = False,
-    reduction_linear: bool = False,
-) -> OutputSharding:
-    """
-    Propagate the sharding for reduction operations. Examples:
-        ij->i - sum on dim
-
-    reduction_linear means that the reduction `f` follows this rule:
-        f([f(a), f(b)]) = f([a, b])
-
-    Reduction linear should be a superset of linearity.
-    """
-    alphabet = "abcdefghijklmnopqrstuvwxyz"
-    # reduction ops usually begin with a single tensor
-    input_spec = cast(DTensorSpec, op_schema.args_schema[0])
-    reduce_dims = range(input_spec.ndim) if dims is None else dims
-
-    if not reduction_linear:
-        # If the reduction is not linear, we need to clear the pending sum
-        # on the input spec; also replicate the reducing dimension if the
-        # reducing dimension is sharded, then suggest a resharding.
-        reshard_dim_map = input_spec.dim_map
-        needs_reshard = False
-        for dim in reduce_dims:
-            if input_spec.dim_map[dim] != -1:
-                needs_reshard = True
-                reshard_dim_map[dim] = -1
-        needs_reshard = needs_reshard or len(input_spec.sums) > 0
-
-        if needs_reshard:
-            no_partial_spec = DTensorSpec.from_dim_map(
-                input_spec.mesh, reshard_dim_map, [], tensor_meta=input_spec.tensor_meta
-            )
-            schema_suggestion = OpSchema(op_schema.op, (no_partial_spec,), {})
-            schema_suggestion._inplace_rewrap_schema_suggestion(op_schema)
-            return OutputSharding(
-                output_spec=None, schema_suggestions=[schema_suggestion]
-            )
-
-    input_chars = alphabet[: input_spec.ndim]
-
-    if dims is None and not keep_dim:
-        # reducing to a single scalar tensor, we just mark output as empty
-        out_dimchars = ""
-    else:
-        # if we keep the reduction dim, we need to keep the dim char by marking
-        # it as a singleton "1" in the out_dimchars
-        reduce_dim_char = ord("1") if keep_dim else None
-        out_dimchars = input_chars.translate(
-            {ord(alphabet[dim]): reduce_dim_char for dim in reduce_dims}
-        )
-    fmt = f"{input_chars}->{out_dimchars}"
-
-    enforce_sharding: Dict[str, int] = {}
-    if _is_out_variant_op(op_schema.op):
-        out_spec = cast(DTensorSpec, op_schema.kwargs_schema["out"])
-        for out_dimchar, mesh_dim in zip(out_dimchars, out_spec.dim_map):
-            enforce_sharding[out_dimchar] = mesh_dim
-
-    return einop_rule(
-        fmt,
-        op_schema,
-        linearity=reduction_linear,
-        enforce_sharding=enforce_sharding,
-    )
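One clarification of the reduction_linear contract stated in the removed docstring, f([f(a), f(b)]) = f([a, b]): it is what lets the rule keep a sharded input and only emit a pending (partial) sum on the output, while non-linear reductions take the needs_reshard branch above. A minimal, self-contained illustration in plain torch (no DTensor involved; not part of this PR):

import torch

# Two unevenly sized "shards" of the same logical tensor.
a = torch.arange(6.0).reshape(2, 3)        # 6 elements, mean 2.5
b = torch.arange(6.0, 18.0).reshape(4, 3)  # 12 elements, mean 11.5
whole = torch.cat([a, b])                  # 18 elements, mean 8.5

# sum is reduction-linear: reducing per-shard results matches reducing the
# whole, so a sharded input only needs a pending-sum placement on the output.
assert torch.equal(torch.stack([a.sum(), b.sum()]).sum(), whole.sum())

# mean is not reduction-linear when shard sizes differ: 7.0 != 8.5. A rule
# called with reduction_linear=False therefore reshards/replicates first.
print(torch.stack([a.mean(), b.mean()]).mean().item())  # 7.0
print(whole.mean().item())                              # 8.5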