[dtensor] switch softmax forward ops to OpStrategy (#117723)
**Summary**
This PR switches the softmax and log_softmax ops from sharding-propagation rules to OpStrategy. It also adds support for the case where the softmax dimension is sharded: the input is replicated along that dimension before the computation runs.
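
As a quick usage sketch of the new behavior (adapted from the updated tests below; the device type, mesh setup, and tensor shape are illustrative assumptions, and an initialized default process group is required):

```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, distribute_tensor
from torch.distributed._tensor.placement_types import Shard

# assumption: torch.distributed is already initialized; build a 1-D mesh over all ranks
mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))

x = torch.rand(8, 12, 16)
dist_x = distribute_tensor(x, mesh, [Shard(1)])  # shard tensor dim 1 across the mesh

# softmax along the sharded dim no longer raises: the input is replicated along dim 1 first
dist_y = torch.nn.functional.softmax(dist_x, dim=1, dtype=torch.float32)
assert dist_y.placements[0].is_replicate()

# softmax along an unsharded dim keeps the original sharding
dist_z = torch.nn.functional.softmax(dist_x, dim=2, dtype=torch.float32)
assert dist_z.placements[0].is_shard(dim=1)
```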

**Test**
`python test/distributed/_tensor/test_math_ops.py -k test_softmax_fwd`
`python test/distributed/_tensor/test_math_ops.py -k test_softmax_with_bwd`

Pull Request resolved: #117723
Approved by: https://github.com/XilunWu
tianyu-l authored and pytorchmergebot committed Jan 22, 2024
1 parent fdac55c commit 86e8551
Showing 2 changed files with 53 additions and 31 deletions.
test/distributed/_tensor/test_math_ops.py (20 additions, 23 deletions)

@@ -8,7 +8,7 @@
 
 from torch.distributed._tensor import DeviceMesh, distribute_module, distribute_tensor
 from torch.distributed._tensor.debug import CommDebugMode
-from torch.distributed._tensor.ops.utils import is_tensor_partial
+from torch.distributed._tensor.ops.utils import is_tensor_partial, normalize_dim
 from torch.distributed._tensor.placement_types import Replicate, Shard
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
@@ -73,18 +73,14 @@ def test_softmax_fwd(self):
                 x, dim=softmax_dim, dtype=torch.float32
             )
             dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
+            dist_y = torch.nn.functional.softmax(
+                dist_x, dim=softmax_dim, dtype=torch.float32
+            )
+            shard_dim = normalize_dim(shard_dim, dist_x.ndim)
             if dims[shard_dim] == dims[softmax_dim]:
-                with self.assertRaisesRegex(
-                    Exception, "Cannot run .* on sharding dimension!$"
-                ):
-                    dist_y = torch.nn.functional.softmax(
-                        dist_x, dim=softmax_dim, dtype=torch.float32
-                    )
+                self.assertTrue(dist_y.placements[0].is_replicate())
+                self.assertEqual(dist_y.to_local(), local_y)
             else:
-                dist_y = torch.nn.functional.softmax(
-                    dist_x, dim=softmax_dim, dtype=torch.float32
-                )
-                shard_dim = shard_dim + dist_y.ndim if shard_dim < 0 else shard_dim
                 self.assertTrue(dist_y.placements[0].is_shard(dim=shard_dim))
                 self.assertEqual(dist_y.full_tensor(), local_y)
 
@@ -112,22 +108,23 @@ def test_softmax_with_bwd(self):
 
             dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
             self.assertTrue(dist_x.requires_grad)
+            dist_softmax = dist_x.softmax(dim=softmax_dim)
+            shard_dim = normalize_dim(shard_dim, dist_x.ndim)
             if dims[softmax_dim] == dims[shard_dim]:
-                with self.assertRaisesRegex(
-                    Exception, "Cannot run .* on sharding dimension!$"
-                ):
-                    dist_softmax = dist_x.softmax(dim=softmax_dim)
+                self.assertTrue(dist_softmax.placements[0].is_replicate())
             else:
-                dist_softmax = dist_x.softmax(dim=softmax_dim)
-                shard_dim = shard_dim + dist_x.ndim if shard_dim < 0 else shard_dim
                 self.assertTrue(dist_softmax.placements[0].is_shard(dim=shard_dim))
-                dist_y = dist_softmax.sum()
+            dist_y = dist_softmax.sum()
+            if dims[softmax_dim] == dims[shard_dim]:
+                self.assertTrue(dist_y.placements[0].is_replicate())
+            else:
+                self.assertTrue(dist_y.placements[0].is_partial())
                 dist_y = dist_y.redistribute(device_mesh, [Replicate()])
-                self.assertEqual(dist_y.to_local(), local_y)
-                self.assertIsNone(dist_x.grad)
-                dist_y.backward()
-                self.assertIsNotNone(dist_x.grad)
-                self.assertEqual(dist_x.grad.full_tensor(), x.grad)
+            self.assertEqual(dist_y.to_local(), local_y)
+            self.assertIsNone(dist_x.grad)
+            dist_y.backward()
+            self.assertIsNotNone(dist_x.grad)
+            self.assertEqual(dist_x.grad.full_tensor(), x.grad)
 
     @with_comms
     def test_full_shard_math_ops(self):
torch/distributed/_tensor/ops/math_ops.py (33 additions, 8 deletions)

@@ -15,6 +15,7 @@
 from torch.distributed._tensor.ops.utils import (
     as_list,
     generate_redistribute_costs,
+    normalize_dim,
     normalize_dims,
     normalize_to_torch_size,
     register_op_strategy,
@@ -221,17 +222,41 @@ def var_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
     )
 
 
-@register_prop_rule(
+@register_op_strategy(
     [aten._log_softmax.default, aten._softmax.default], schema_info=RuntimeSchemaInfo(1)
 )
-def softmax_rule(op_schema: OpSchema) -> OutputSharding:
-    input_spec, softmax_dim, _ = op_schema.args_schema
-    input_spec = cast(DTensorSpec, input_spec)
+def softmax_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
+    input_strategy, softmax_dim, _ = op_schema.args_schema
+    input_strategy = cast(OpStrategy, input_strategy)
     softmax_dim = cast(int, softmax_dim)
-    dim_map = input_spec.dim_map
-    if softmax_dim < len(dim_map) and dim_map[softmax_dim] >= 0:
-        raise RuntimeError("Cannot run softmax on sharding dimension!")
-    return OutputSharding(input_spec)
+    softmax_dim = normalize_dim(softmax_dim, input_strategy.output_ndim)
+
+    output_strategy = OpStrategy([])
+    for idx, input_placement_strategy in enumerate(input_strategy.strategies):
+        redistribute_costs = []
+        input_src_spec = input_placement_strategy.out_spec
+
+        # make sure input is replicated along the softmax dim
+        input_target_spec = DTensorSpec(
+            mesh=mesh,
+            placements=replicate_reduction_dims(
+                input_src_spec.placements, [softmax_dim]
+            ),
+            tensor_meta=input_src_spec.tensor_meta,
+        )
+        redistribute_costs.append(
+            generate_redistribute_costs(input_strategy, input_target_spec)
+        )
+        output_target_spec = input_target_spec
+        output_strategy.strategies.append(
+            PlacementStrategy(
+                output_spec=output_target_spec,
+                input_specs=[input_target_spec],
+                redistribute_cost=redistribute_costs,
+            )
+        )
+
+    return output_strategy
 
 
 @register_prop_rule(
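
For intuition, a hedged sketch of the placement rewrite the new strategy relies on: any mesh dimension that shards the softmax dim is switched to `Replicate()`, and every other placement is kept. `replicate_softmax_dim` below is a hypothetical stand-in for the `replicate_reduction_dims(placements, [softmax_dim])` call used above, not the actual DTensor helper.

```python
from typing import Sequence, Tuple

from torch.distributed._tensor.placement_types import Placement, Replicate, Shard


def replicate_softmax_dim(
    placements: Sequence[Placement], softmax_dim: int
) -> Tuple[Placement, ...]:
    # hypothetical stand-in for replicate_reduction_dims(placements, [softmax_dim])
    return tuple(
        Replicate() if isinstance(p, Shard) and p.dim == softmax_dim else p
        for p in placements
    )


# a 1-D mesh sharding tensor dim 1, softmax along dim 1:
# replicate_softmax_dim((Shard(1),), softmax_dim=1) -> (Replicate(),)
# softmax along dim 2 leaves the sharding untouched:
# replicate_softmax_dim((Shard(1),), softmax_dim=2) -> (Shard(1),)
```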
