[dtensor] fix pointwise op linearity with strategy (pytorch#112107)
This PR fixes linearity handling in the pointwise op strategy and switches the
linear pointwise ops to use the strategy. It also adds tests showing that with
the new approach we can enable operations on fully sharded placements like
(S(0), S(0)).
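
As an illustration of the linearity behavior, here is a sketch distilled from
the pointwise-op tests below (it assumes torch.distributed is initialized and
device_mesh is an already-built 1-D DeviceMesh):

import torch
from torch.distributed._tensor import DTensor
from torch.distributed._tensor.placement_types import _Partial

# both operands carry a pending (partial) sum
d_1 = DTensor.from_local(torch.ones(2, 2), device_mesh, [_Partial()])
d_2 = DTensor.from_local(torch.ones(2, 2), device_mesh, [_Partial()])

# add is linear, so the pending sum propagates with no communication
d_3 = d_1 + d_2
assert d_3._spec.placements[0].is_partial()

# mul is not linear, so the partial operands are replicated (reduced) first
d_4 = d_1 * d_2
assert d_4._spec.placements[0].is_replicate()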

Why is this useful? For 2-D parallel-like patterns where the named parameters
may be fully sharded across all devices, placements like [S(0), S(0)] or
[S(1), S(0)] need to work. Since we no longer use the sharding rules, this is
now possible.
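
For reference, the 2-D full-shard pattern the new test exercises looks roughly
like this (a sketch assuming a world size of 4 and an initialized process
group; the test uses self.device_type rather than a hard-coded "cuda"):

import torch
from torch.distributed._tensor import DeviceMesh, distribute_tensor
from torch.distributed._tensor.placement_types import Replicate, Shard

mesh = DeviceMesh("cuda", torch.arange(4).reshape(2, 2))
global_tensor = torch.ones(4, 4)

# fully shard dim 0 across both mesh dimensions: [S(0), S(0)]
dt = distribute_tensor(global_tensor, mesh, [Shard(0), Shard(0)])

# pointwise ops now go through the op strategy, so no dedicated sharding
# rule is needed for this placement
out = (dt + 2).redistribute(mesh, [Replicate(), Replicate()])
assert torch.equal(out.to_local(), global_tensor + 2)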

@awgu
Pull Request resolved: pytorch#112107
Approved by: https://github.com/wz337
wanchaol authored and xuhancn committed Nov 8, 2023
1 parent 2b9119f commit 3a8c211
Showing 4 changed files with 61 additions and 22 deletions.
26 changes: 25 additions & 1 deletion test/distributed/_tensor/test_math_ops.py
@@ -5,7 +5,7 @@

import torch

from torch.distributed._tensor import distribute_tensor
from torch.distributed._tensor import DeviceMesh, distribute_tensor
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
@@ -114,6 +114,30 @@ def test_softmax_with_bwd(self):
        dist_x_grad = dist_x.grad.redistribute(device_mesh, [Replicate()])
        self.assertEqual(dist_x_grad.to_local(), x.grad)

    @with_comms
    def test_full_shard_math_ops(self):
        mesh_shape = (2, self.world_size // 2)
        mesh = DeviceMesh(
            self.device_type,
            torch.arange(self.world_size).reshape(*mesh_shape),
        )
        global_tensor = torch.ones(4, 4)
        double_shard_tensor = distribute_tensor(
            global_tensor, mesh, [Shard(0), Shard(0)]
        )
        fully_shard_tensor = distribute_tensor(
            global_tensor, mesh, [Shard(0), Shard(1)]
        )

        for op in [torch.add, torch.sub, torch.mul, torch.div]:
            expect_rs = op(global_tensor, 2)
            actual_rs = op(double_shard_tensor, 2).redistribute(
                mesh, [Replicate(), Replicate()]
            )
            actual_local_res = actual_rs.to_local()
            self.assertEqual(actual_local_res, expect_rs)


if __name__ == "__main__":
    run_tests()
10 changes: 9 additions & 1 deletion test/distributed/_tensor/test_pointwise_ops.py
@@ -148,7 +148,15 @@ def test_partial_add(self):
        d_1 = DTensor.from_local(torch.rand(2, 2), device_mesh, [_Partial()])
        d_2 = DTensor.from_local(torch.rand(2, 2), device_mesh, [_Partial()])
        d_3 = d_1 + d_2
        self.assertEqual(d_3._spec.placements[0].is_partial(), True)
        self.assertTrue(d_3._spec.placements[0].is_partial())

    def test_partial_mul_failure(self):
        device_mesh = self.build_device_mesh()
        d_1 = DTensor.from_local(torch.ones(2, 2), device_mesh, [_Partial()])
        d_2 = DTensor.from_local(torch.ones(2, 2), device_mesh, [_Partial()])
        d_3 = d_1 * d_2
        self.assertTrue(d_3._spec.placements[0].is_replicate())
        self.assertEqual(d_3.to_local(), torch.ones(2, 2) * (self.world_size**2))

    def test_activations(self):
        device_mesh = self.build_device_mesh()
9 changes: 0 additions & 9 deletions torch/distributed/_tensor/ops/common_rules.py
@@ -285,12 +285,3 @@ def pointwise_rule(op_schema: OpSchema, linearity: bool = False) -> OutputSharding:
        linearity=linearity,
        enforce_sharding=enforce_sharding,
    )


def linear_pointwise_rule(op_schema: OpSchema) -> OutputSharding:
    """
    Linear pointwise operators can propagate pending reductions.
    For example, c = add(a, b); if a is pending sum, then c will be
    pending sum as well without any communication overhead.
    """
    return pointwise_rule(op_schema, linearity=True)
38 changes: 27 additions & 11 deletions torch/distributed/_tensor/ops/pointwise_ops.py
Expand Up @@ -14,16 +14,20 @@
    StrategyType,
)

from torch.distributed._tensor.ops.common_rules import linear_pointwise_rule
from torch.distributed._tensor.ops.utils import (
    generate_redistribute_costs,
    infer_broadcast_dims_map,
    map_placements_after_broadcast,
    normalize_dim,
    register_op_strategy,
    register_prop_rule,
)
from torch.distributed._tensor.placement_types import DTensorSpec, Placement, Shard
from torch.distributed._tensor.placement_types import (
    _Partial,
    DTensorSpec,
    Placement,
    Replicate,
    Shard,
)


aten = torch.ops.aten
@@ -394,7 +398,9 @@
]


def pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
def pointwise_strategy(
    mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False
) -> StrategyType:
    max_shards_strategy_index = -1
    max_shards = -1
    # handle broadcasting
@@ -445,19 +451,18 @@ def pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
            common_ndim = len(common_shape)
            new_shard_dim = common_ndim - len(spec_to_follow.shape) + shard_dim
            out_placements.append(Shard(new_shard_dim))
        elif isinstance(placement, _Partial) and not linearity:
            # clear the partial placement if the op does not support linearity;
            # by default we just replicate the partial, need to see if this
            # is optimal for all cases
            out_placements.append(Replicate())
        else:
            out_placements.append(placement)

    input_specs = []
    redistribute_costs: List[List[float]] = []
    for idx, input_arg in enumerate(op_schema.args_schema):
        if isinstance(input_arg, OpStrategy):
            if idx == max_shards_strategy_index:
                # the current input arg is the one we want to follow
                input_specs.append(spec_to_follow)
                redistribute_costs.append([0] * len(input_arg.strategies))
                continue

            # every arg follows the out_placements, but we need to handle broadcasting
            input_arg_spec = input_arg.strategies[0].output_spec
            input_arg_dims_map = infer_broadcast_dims_map(
@@ -491,8 +496,19 @@ def pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
return pointwise_strategy


def linear_pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
    """
    Linear pointwise operators can propagate pending reductions.
    For example, c = add(a, b); if a is pending sum, then c will be
    pending sum as well without any communication overhead.
    """
    return pointwise_strategy(mesh, op_schema, linearity=True)


for op in linear_pointwise_ops:
    register_prop_rule(op)(linear_pointwise_rule)
    register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))(
        linear_pointwise_strategy
    )


for op in pointwise_ops:
