
Commit fc30bd3

Revert "[dtensor] rewrite embedding ops using op strategy (#118079)"
This reverts commit e599a08. Reverted #118079 on behalf of https://github.com/DanilBaibak because it broke an internal build ([comment](#118079 (comment))).
1 parent bfb5e76 commit fc30bd3

5 files changed, +99 -206 lines changed


test/distributed/_tensor/test_embedding_ops.py

Lines changed: 23 additions & 38 deletions
@@ -3,14 +3,13 @@
 import sys
 
 import torch
-from torch.distributed._tensor import (
-    distribute_module,
-    distribute_tensor,
-    DTensor,
-    Replicate,
-    Shard,
+from torch.distributed._tensor import DTensor
+from torch.distributed._tensor.placement_types import Replicate
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    RowwiseParallel,
 )
-from torch.distributed._tensor.debug import CommDebugMode
 from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
@@ -54,16 +53,12 @@ def _run_embedding_op_test(
         sharded_embedding.weight = torch.nn.Parameter(
             local_embedding.weight.clone().detach()
         )
-
-        def shard_embedding_fn(name, module, device_mesh):
-            for name, param in module.named_parameters():
-                dist_param = torch.nn.Parameter(
-                    distribute_tensor(param, device_mesh, [Shard(shard_dim)])
-                )
-                module.register_parameter(name, dist_param)
-
-        sharded_embedding = distribute_module(
-            sharded_embedding, device_mesh, shard_embedding_fn
+        parallelize_module(
+            module=sharded_embedding,
+            device_mesh=device_mesh,
+            parallelize_plan=ColwiseParallel(output_layouts=Replicate())
+            if shard_dim == 1
+            else RowwiseParallel(),
         )
 
         # Run sharded computation
@@ -74,14 +69,8 @@ def shard_embedding_fn(name, module, device_mesh):
         target = torch.empty(
             *inp.size(), embedding_dim, dtype=torch.float, device=self.device_type
         ).random_(0, 1)
-        dist_inp = distribute_tensor(inp, device_mesh, [Replicate()])
-
-        # fwd computation, ensure no comm happened
-        with CommDebugMode() as fwd_mode:
-            dist_output = sharded_embedding(dist_inp)
-            self.assertEqual(fwd_mode.get_total_counts(), 0)
+        output = sharded_embedding(inp)
 
-        output = dist_output.full_tensor()
         # Run local computation
         local_output = local_embedding(inp)
 
@@ -90,24 +79,20 @@ def shard_embedding_fn(name, module, device_mesh):
 
         # Use a sample cross entry loss to verify backward and grad computation.
         loss = torch.nn.CrossEntropyLoss()
-        emb_loss = loss(
+        attn_loss = loss(
             output,
             target,
         )
-        emb_dup_loss = loss(
+        attn_dup_loss = loss(
             local_output,
             target,
         )
+        attn_loss.backward()
+        attn_dup_loss.backward()
 
-        # local embedding backward
-        emb_dup_loss.backward()
-
-        # sharded embedding bwd computation, ensure no comm happened
-        with CommDebugMode() as bwd_mode:
-            emb_loss.backward()
-            self.assertEqual(bwd_mode.get_total_counts(), 0)
-
-        gradient = sharded_embedding.weight.grad.full_tensor()
+        gradient = sharded_embedding.weight.grad.redistribute(
+            device_mesh, [Replicate()]
+        ).to_local()
 
         local_grad = local_embedding.weight.grad
 
@@ -138,10 +123,10 @@ def test_sharded_embedding_colwise(self):
         self._run_embedding_op_test(1, [8, 6, 5, 4], 23, 13, padding_idx=12)
 
     @with_comms
-    def test_sharded_embedding_colwise_max_norm_errors(self):
+    def test_sharded_embedding_colwise_errors(self):
        with self.assertRaisesRegex(
            NotImplementedError,
-            "aten.embedding_renorm_.default does not have a sharding strategy registered.",
+            "DTensor does not support sharded embedding operation with max_norm yet!",
        ):
            self._run_embedding_op_test(
                1, [8, 6, 5, 4], 23, 13, padding_idx=12, max_norm=2.0
@@ -151,7 +136,7 @@ def test_sharded_embedding_colwise_max_norm_errors(self):
     def test_sharded_embedding_rowwise(self):
         with self.assertRaisesRegex(
             NotImplementedError,
-            "row-wise sharded embedding operation yet",
+            "RowwiseParallel currently only support nn.Linear!",
         ):
             self._run_embedding_op_test(0, [5, 12], 16, 22)

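Context for the change above: after the revert, the test shards the embedding through the tensor-parallel style API instead of a hand-rolled distribute_module partition function. The snippet below is a minimal standalone sketch of that pattern, not code from this commit; it assumes a process group started via torchrun, and the gloo/CPU backend, module sizes, and variable names are illustrative.

# Sketch: shard an nn.Embedding column-wise with the TP style API.
# Assumption: launched with `torchrun --nproc-per-node=2 sketch.py`.
import torch
import torch.distributed as dist
from torch.distributed._tensor import Replicate
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# Column-wise sharding splits the embedding weight along embedding_dim (dim 1);
# output_layouts=Replicate() gathers the per-rank columns so callers see the
# full lookup result, which is what the test compares against the local module.
emb = torch.nn.Embedding(16, 8)
parallelize_module(emb, mesh, ColwiseParallel(output_layouts=Replicate()))

tokens = torch.arange(4).reshape(2, 2)
out = emb(tokens)  # full (2, 2, 8) lookup result on every rank
print(out.shape)

dist.destroy_process_group()
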
test/distributed/tensor/parallel/test_tp_style.py

Lines changed: 0 additions & 21 deletions
@@ -67,27 +67,6 @@ def test_colwise_parallel_style(self):
         self.assertEqual(comm_mode.get_comm_counts()[c10d_functional.reduce_scatter_tensor], 1)
         self.assertEqual(comm_mode.get_total_counts(), 2)
 
-    @with_comms
-    def test_colwise_parallel_embedding(self):
-        mesh = init_device_mesh(self.device_type, (self.world_size,))
-
-        comm_mode = CommDebugMode()
-        tensor = torch.arange(8, device=self.device_type).reshape(4, 2)
-        model = nn.Embedding(16, 16, device=self.device_type)
-
-        default_col_parallel = ColwiseParallel()
-        with comm_mode:
-            colwise_mod = parallelize_module(deepcopy(model), mesh, default_col_parallel)
-            out = colwise_mod(tensor)
-            # ensure output shard on the last dim
-            self.assertEqual(out.shape, (4, 2, 16 // self.world_size))
-            # ensure no communication happened in fwd
-            self.assertEqual(comm_mode.get_total_counts(), 0)
-
-            out.sum().backward()
-            # no comm in bwd
-            self.assertEqual(comm_mode.get_total_counts(), 0)
-
     @with_comms
     def test_rowwise_parallel_style(self):
         mesh = init_device_mesh(self.device_type, (self.world_size,))
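The deleted test above measured communication with CommDebugMode. A minimal sketch of that measurement pattern, again assuming a torchrun-launched process group; the Linear module and the printed breakdown are illustrative, not from this commit:

# Sketch: count collectives issued while running a parallelized module.
import torch
import torch.distributed as dist
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

model = torch.nn.Linear(8, 8)
parallelize_module(model, mesh, ColwiseParallel())

comm_mode = CommDebugMode()
with comm_mode:
    out = model(torch.randn(4, 8))
    out.sum().backward()

# get_total_counts() sums all collectives; get_comm_counts() breaks them
# down per functional collective op.
print(comm_mode.get_total_counts())
print(comm_mode.get_comm_counts())

dist.destroy_process_group()
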

torch/distributed/_tensor/op_schema.py

Lines changed: 1 addition & 1 deletion
@@ -246,7 +246,7 @@ def __str__(self) -> str:
                 args_sharding.append(str(arg))
             else:
                 args_sharding.append(str(arg))
-        return f"Op(op={self.op}, args_sharding={', '.join(args_sharding)} @ mesh: {mesh_shape})"
+        return f"Op(op={self.op}, args_sharding={', '.join(args_sharding)}@ mesh: {mesh_shape})"
 
     def __post_init__(self) -> None:
         has_symints = False
Lines changed: 74 additions & 145 deletions
@@ -1,168 +1,97 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 # implement matrix related ops for distributed tensor
-import itertools
-from typing import cast, List
 
 import torch
-from torch.distributed._tensor.op_schema import (
-    OpSchema,
-    OpStrategy,
-    PlacementStrategy,
-    StrategyType,
-)
-from torch.distributed._tensor.ops.utils import (
-    generate_redistribute_costs,
-    is_tensor_shardable,
-    register_op_strategy,
-)
+from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
+from torch.distributed._tensor.ops.utils import register_prop_rule
 
 from torch.distributed._tensor.placement_types import (
     _Partial,
     DTensorSpec,
-    Placement,
     Replicate,
     Shard,
 )
 
-from torch.distributed.device_mesh import DeviceMesh
-
 aten = torch.ops.aten
 
 
-@register_op_strategy(aten.embedding.default)
-def embedding_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
-    """
-    This strategy handles embedding op. We have two possible embedding shardings:
-    rowwise and colwise
-    # TODO: implement rowwise sharding
-    """
-    weight_strategy = cast(OpStrategy, op_schema.args_schema[0])
-    indices_strategy = cast(OpStrategy, op_schema.args_schema[1])
-
-    weight_shape = weight_strategy.output_shape
-    indices_shape = indices_strategy.output_shape
-    output_emd_dim = len(indices_shape)
-
-    # guard rowwise sharding not implemented for now
-    weight_spec = weight_strategy.strategies[0].output_spec
+# TODO: Enable BWD for embedding op.
+@register_prop_rule(aten.embedding.default)
+def embedding_rules(op_schema: OpSchema) -> OutputSharding:
+    weight_spec, inp_spec = op_schema.args_spec
     if any(placement.is_shard(0) for placement in weight_spec.placements):
         raise NotImplementedError(
             "DTensor does not support row-wise sharded embedding operation yet!"
         )
 
-    all_mesh_dim_strategies = []
-
-    for mesh_dim in range(mesh.ndim):
-        single_mesh_dim_strategies = []
-
-        # placement list stores placements of [output, weight, input_indices]
-        # first we always have replicate all for inputs and output
-        all_replicate: List[Placement] = [Replicate()] * 3
-        single_mesh_dim_strategies.append(all_replicate)
-
-        # colwise sharding, output shard on last dim, weight shard on dim 1, input replicate
-        colwise_sharding = [Shard(output_emd_dim), Shard(1), Replicate()]
-        single_mesh_dim_strategies.append(colwise_sharding)
-
-        # batch dim sharding, weight replicated, input can shard on any dim, output follows input
-        for input_dim in range(len(indices_shape)):
-            batch_sharding = [Shard(input_dim), Replicate(), Shard(input_dim)]
-            single_mesh_dim_strategies.append(batch_sharding)
-
-        all_mesh_dim_strategies.append(single_mesh_dim_strategies)
-
-    strategy_combs = itertools.product(*all_mesh_dim_strategies)
-
-    all_strategies = []
-    for strategy_comb in strategy_combs:
-        spec_list = []
-        for specs in zip(*strategy_comb):
-            spec_list.append(DTensorSpec(mesh, tuple(specs)))
-
-        if is_tensor_shardable(weight_shape, spec_list[1]) and is_tensor_shardable(
-            indices_shape, spec_list[2]
-        ):
-            # only add to the strategy list when both weight and indices are shardable
-            weight_spec, indices_spec = spec_list[1:]
-            redistribute_cost = [
-                generate_redistribute_costs(weight_strategy, weight_spec),
-                generate_redistribute_costs(indices_strategy, indices_spec),
-            ]
-            strat = PlacementStrategy(
-                output_specs=spec_list[0],
-                input_specs=spec_list[1:],
-                redistribute_cost=redistribute_cost,
-            )
-            all_strategies.append(strat)
-
-    return OpStrategy(all_strategies)
-
-
-@register_op_strategy(aten.embedding_dense_backward.default)
-def embedding_dense_backward_strategy(
-    mesh: DeviceMesh, op_schema: OpSchema
-) -> StrategyType:
-    """
-    This strategy handles embedding op. We have two possible embedding shardings:
-    rowwise and colwise
-    # TODO: implement rowwise sharding backward
-    """
-    grad_out_strategy = cast(OpStrategy, op_schema.args_schema[0])
-    indices_strategy = cast(OpStrategy, op_schema.args_schema[1])
-
-    grad_out_shape = grad_out_strategy.output_shape
-    indices_shape = indices_strategy.output_shape
-    grad_out_ndim = len(grad_out_shape)
-
-    all_mesh_dim_strategies = []
-
-    for mesh_dim in range(mesh.ndim):
-        single_mesh_dim_strategies = []
-
-        # placement list stores placements of [output, weight, input_indices]
-        # first we always have replicate all for inputs and output
-        all_replicate: List[Placement] = [Replicate()] * 3
-        single_mesh_dim_strategies.append(all_replicate)
-
-        # colwise sharding backward, grad_out shard on last dim, input replicate,
-        # weight grad shard colwise
-        colwise_sharding = [Shard(1), Shard(grad_out_ndim - 1), Replicate()]
-        single_mesh_dim_strategies.append(colwise_sharding)
-
-        # batch dim sharding, weight replicated, grad_out/input have same sharding
-        # that can shard on any dim, weight grad partial
-        for input_dim in range(len(indices_shape)):
-            batch_sharding = [_Partial(), Shard(input_dim), Shard(input_dim)]
-            single_mesh_dim_strategies.append(batch_sharding)
-
-        # grad_out partial, input replicate, weight grad keep partial
-        partial_sharding = [_Partial(), _Partial(), Replicate()]
-        single_mesh_dim_strategies.append(partial_sharding)
-
-        all_mesh_dim_strategies.append(single_mesh_dim_strategies)
-
-    strategy_combs = itertools.product(*all_mesh_dim_strategies)
+    if weight_spec.is_replicated() and inp_spec.placements == [Shard(0)]:
+        # Embedding table is replicated, input ids are sharded along batch
+        # dimension. Output lookups should match input sharding spec in this case.
+        return OutputSharding(
+            output_spec=DTensorSpec(mesh=inp_spec.mesh, placements=inp_spec.placements)
+        )
 
-    all_strategies = []
-    for strategy_comb in strategy_combs:
-        spec_list = []
-        for specs in zip(*strategy_comb):
-            spec_list.append(DTensorSpec(mesh, tuple(specs)))
+    if inp_spec.is_replicated():
+        weight_dim_map = weight_spec.dim_map
+        output_dim_map = inp_spec.dim_map
+        output_dim_map.append(weight_dim_map[1])
+        return OutputSharding(
+            output_spec=DTensorSpec.from_dim_map(inp_spec.mesh, output_dim_map, [])
+        )
 
-        if is_tensor_shardable(grad_out_shape, spec_list[1]) and is_tensor_shardable(
-            indices_shape, spec_list[2]
-        ):
-            # only add to the strategy list when both grad_out and indices are shardable
-            grad_out_spec, indices_spec = spec_list[1:]
-            redistribute_cost = [
-                generate_redistribute_costs(grad_out_strategy, grad_out_spec),
-                generate_redistribute_costs(indices_strategy, indices_spec),
-            ]
-            strat = PlacementStrategy(
-                output_specs=spec_list[0],
-                input_specs=spec_list[1:],
-                redistribute_cost=redistribute_cost,
+    return OutputSharding(
+        output_spec=None,
+        schema_suggestions=[
+            OpSchema(
+                op=op_schema.op,
+                args_schema=(
+                    weight_spec,
+                    DTensorSpec(
+                        mesh=inp_spec.mesh,
+                        placements=tuple([Replicate()] * len(inp_spec.placements)),
+                        tensor_meta=inp_spec.tensor_meta,
+                    ),
+                ),
+                kwargs_schema=op_schema.kwargs_schema,
             )
-            all_strategies.append(strat)
-
-    return OpStrategy(all_strategies)
+        ],
+    )
+
+
+@register_prop_rule(aten.embedding_renorm_.default)
+def embedding_renorm_rules(op_schema: OpSchema) -> OutputSharding:
+    raise NotImplementedError(
+        "DTensor does not support sharded embedding operation with max_norm yet!"
+    )
+
+
+@register_prop_rule(aten.embedding_dense_backward.default)
+def embedding_dense_backward_rules(op_schema: OpSchema) -> OutputSharding:
+    grad_output, indices = op_schema.args_schema[:2]
+    assert isinstance(grad_output, DTensorSpec)
+    assert isinstance(indices, DTensorSpec)
+    if grad_output.placements == indices.placements:
+        # The embedding table is replicated, and input/oupput activations are
+        # sharded. In this case, gradients for the embedding table should be
+        # Partial.
+        return OutputSharding(
+            output_spec=DTensorSpec(mesh=indices.mesh, placements=(_Partial(),))
+        )
+    elif grad_output.placements == [_Partial()] and indices.placements == [Replicate()]:
+        # The embedding table is replicated and the indices is also replicated
+        # (local is a more precise term). This is postional embedding. In this
+        # case, gradients for the embmedding table should be Partial.
+        return OutputSharding(
+            output_spec=DTensorSpec(mesh=indices.mesh, placements=(_Partial(),))
+        )
+    elif all(placement.is_replicate() for placement in indices.placements):
+        # BWD for colwise sharding case
+        return OutputSharding(
+            output_spec=DTensorSpec(mesh=indices.mesh, placements=(Shard(1),))
+        )
+    else:
+        raise NotImplementedError(
+            "Unsupported embedding dense backward schema:\n"
+            f"grad_output - {grad_output}\n"
+            f"indices - {indices}"
+        )
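Both the removed strategy and the restored prop rules encode the same identity for the column-wise case: looking up rows of a dim-1 shard of the weight yields exactly that shard of the full lookup, so a Shard(1) weight produces an output sharded on the last dim with no collectives in the forward. A single-process sketch that checks this identity (the sizes and the two-way split are illustrative, not from this commit):

# Check: F.embedding(idx, W[:, cols]) == F.embedding(idx, W)[..., cols]
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_embeddings, embedding_dim, world_size = 16, 8, 2
weight = torch.randn(num_embeddings, embedding_dim)
idx = torch.randint(num_embeddings, (4, 3))

full = F.embedding(idx, weight)           # (4, 3, 8) full lookup
shards = weight.chunk(world_size, dim=1)  # dim-1 ("column-wise") shards

for rank, shard in enumerate(shards):
    local = F.embedding(idx, shard)       # lookup against the local shard only
    lo = rank * embedding_dim // world_size
    hi = lo + embedding_dim // world_size
    # Each rank's local lookup equals its slice of the full output, which is
    # why a Shard(1) weight gives a last-dim-sharded output without comm.
    assert torch.equal(local, full[..., lo:hi])
print("column-wise embedding lookup matches the sliced full lookup")
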
