
Commit 44778d0

tp on routed experts working
1 parent efd993f commit 44778d0

File tree

3 files changed (+54 lines, −15 lines)


torchao/prototype/moe_training/conversion_utils.py

Lines changed: 11 additions & 1 deletion
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+import logging
 from typing import Callable, Optional
 
 from torch import nn
@@ -8,6 +14,8 @@
     register_quantize_module_handler,
 )
 
+logger: logging.Logger = logging.getLogger(__name__)
+
 
 class MoETrainingConfig(AOBaseConfig):
     """
@@ -105,7 +113,9 @@ def post_order_traversal(
                 ScaledGroupedMMTensor(param), requires_grad=param.requires_grad
             )
             setattr(module, param_name, new_param)
-            print(f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor")
+            logger.info(
+                f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor"
+            )
 
     post_order_traversal(root_module)
     return root_module
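
With the print statement replaced by a module-level logger, the "Swapped ... to ScaledGroupedMMTensor" messages only show up when logging is configured. Below is a minimal sketch of how a caller might surface them; the logger name simply follows the module path, and this setup snippet is an assumption for illustration, not part of the commit:

import logging

# Assumed setup, not from this commit: emit INFO-level records (including the
# parameter-swap messages logged by conversion_utils) via the root handler.
logging.basicConfig(level=logging.INFO)
logging.getLogger("torchao.prototype.moe_training.conversion_utils").setLevel(logging.INFO)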

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 39 additions & 10 deletions
@@ -39,7 +39,6 @@ def _scaled_grouped_mm(
3939
and in column-major memory layout.
4040
offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
4141
out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
42-
use_triton_for_per_group_scales (bool): Whether to use custom triton kernels to compute per-group scales. Default is True.
4342
"""
4443
logger.info("Using differentiable _scaled_grouped_mm")
4544
return _Float8GroupedMM.apply(
@@ -61,8 +60,8 @@ def forward(
6160
offs: Optional[torch.Tensor] = None,
6261
out_dtype: Optional[torch.dtype] = torch.bfloat16,
6362
) -> torch.Tensor:
64-
# torchao _scaled_grouped_mm only supports A=2D, B=3D.
65-
assert A.ndim == 2, "A must be 2D"
63+
# torchao _scaled_grouped_mm only supports A=2D|3D + B=3D.
64+
assert A.ndim == 2 or A.ndim == 3, "A must be 2D or 3D"
6665
assert B_t.ndim == 3, "B must be 3D"
6766

6867
assert A.size(-1) % 16 == 0, (
@@ -151,12 +150,25 @@ def forward(
151150
assert _is_column_major(B_t_fp8_col_major), (
152151
"B must be column-major for output = A @ B"
153152
)
153+
154+
# TODO: remove excessive logging once prototype is more mature.
155+
logger.debug(
156+
(
157+
f"forward scaled_grouped_mm: A_fp8_row_major.shape={A_fp8_row_major.shape}, "
158+
f"A_scale.shape={A_scales.squeeze(-1).shape}, "
159+
f"B_t_fp8_col_major.shape={B_t_fp8_col_major.shape}, "
160+
f"B_t_scale.shape={B_t_scales.squeeze(1).shape}, "
161+
f"offs={offs if offs is not None else None}"
162+
)
163+
)
154164
return torch._scaled_grouped_mm(
155165
A_fp8_row_major,
156166
B_t_fp8_col_major,
157-
A_scales.squeeze().reciprocal(),
158-
B_t_scales.squeeze().reciprocal(),
159-
offs,
167+
# Squeeze A scales to: (B, S, 1) => (B, M), or (B*S, 1) => (B*S)
168+
A_scales.squeeze(-1).reciprocal(),
169+
# Squeeze B scales to: (B, 1, N) => (B, N)
170+
B_t_scales.squeeze(1).reciprocal(),
171+
offs=offs,
160172
out_dtype=out_dtype,
161173
use_fast_accum=True,
162174
)
@@ -193,12 +205,20 @@ def backward(ctx, grad_output: torch.Tensor):
193205
assert _is_column_major(B_fp8_col_major), (
194206
"B must be column-major for grad_A = grad_output @ B"
195207
)
208+
logger.debug(
209+
(
210+
f"backward grad_A: grad_output_fp8_row_major.shape={grad_output_fp8_row_major.shape}, "
211+
f"grad_output_scale.shape={grad_output_scales.shape}, "
212+
f"B_fp8_col_major.shape={B_fp8_col_major.shape}, "
213+
f"B_scale.shape={B_scales.shape}, "
214+
)
215+
)
196216
grad_A = torch._scaled_grouped_mm(
197217
grad_output_fp8_row_major,
198218
B_fp8_col_major,
199-
grad_output_scales.squeeze().reciprocal(),
200-
B_scales.squeeze().reciprocal(),
201-
offs,
219+
grad_output_scales.squeeze(-1).reciprocal(),
220+
B_scales.squeeze(1).reciprocal(),
221+
offs=offs,
202222
out_dtype=out_dtype,
203223
use_fast_accum=True,
204224
)
@@ -238,12 +258,21 @@ def backward(ctx, grad_output: torch.Tensor):
238258
assert _is_column_major(A_fp8_col_major), (
239259
"A must be column-major for grad_B = grad_output_t @ A"
240260
)
261+
262+
logger.debug(
263+
(
264+
f"backward grad_B: grad_output_t_fp8_row_major.shape={grad_output_t_fp8_row_major.shape}, "
265+
f"grad_output_t_scale.shape={grad_output_t_scales.shape}, "
266+
f"A_fp8_col_major.shape={A_fp8_col_major.shape}, "
267+
f"A_scale.shape={A_scales.shape}, "
268+
)
269+
)
241270
grad_B = torch._scaled_grouped_mm(
242271
grad_output_t_fp8_row_major,
243272
A_fp8_col_major,
244273
grad_output_t_scales.reciprocal(),
245274
A_scales.reciprocal(),
246-
offs,
275+
offs=offs,
247276
out_dtype=out_dtype,
248277
use_fast_accum=True,
249278
)
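
The switch from a bare .squeeze() to .squeeze(-1) / .squeeze(1) matters now that A may be 3D: .squeeze() drops every size-1 dimension, including a batch dimension of 1, while targeting the known scale axis removes only the per-row or per-column scale dimension. A small illustrative sketch; the shapes below are assumptions for illustration, not the library's internals:

import torch

# Per-row scales for a 3D A of shape (B, S, K) are kept as (B, S, 1);
# per-column scales for a 3D B_t of shape (B, K, N) are kept as (B, 1, N).
B, S, N = 1, 4, 8
A_scales = torch.rand(B, S, 1)
B_t_scales = torch.rand(B, 1, N)

# A bare .squeeze() also drops the batch dim when B == 1, which can break the
# scale layout the grouped mm expects; a targeted squeeze keeps the batch dim.
assert A_scales.squeeze().shape == (S,)        # batch dim lost
assert A_scales.squeeze(-1).shape == (B, S)    # only the trailing 1 removed
assert B_t_scales.squeeze(1).shape == (B, N)   # only the middle 1 removed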

torchao/prototype/moe_training/tensor.py

Lines changed: 4 additions & 4 deletions
@@ -75,12 +75,12 @@ def __torch_function__(cls, func, types, args, kwargs={}):
             # used for shared experts. This is basically the grouped_mm
             # kernel handling a bmm.
             A, B = args[0], args[1]
-            A_is_2d = A.dim() == 2
+            A_is_2d_or_3d = A.dim() in (2, 3)
             B_is_3d = B.dim() == 3
             has_offs = kwargs.get(cls.offs_arg_name) is not None
-            logger.info(f"A.shape={A.shape}, B.shape={B.shape}, has_offs={has_offs}")
-
-            if A_is_2d and B_is_3d:
+            logger.debug(f"A.shape={A.shape}, B.shape={B.shape}, has_offs={has_offs}")
+
+            if A_is_2d_or_3d and B_is_3d:
                 return _scaled_grouped_mm(
                     *args,
                     **kwargs,
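
The dispatch in __torch_function__ now accepts both the 2D token-choice case (A is (total_tokens, K) with offs marking group boundaries, used for routed experts) and the 3D bmm-style case (used for shared experts), as long as the stacked expert weights B are 3D. A standalone sketch of the condition; the helper name and concrete shapes are assumptions used only for illustration:

import torch

def routes_to_scaled_grouped_mm(A: torch.Tensor, B: torch.Tensor) -> bool:
    # Mirrors the updated check above: A may be 2D (tokens, K) or 3D (B, S, K),
    # while the stacked expert weights B must be 3D (num_experts, K, N).
    return A.dim() in (2, 3) and B.dim() == 3

# 2D routed-experts case and 3D shared-experts (bmm-like) case both qualify.
assert routes_to_scaled_grouped_mm(torch.empty(128, 16), torch.empty(4, 16, 32))
assert routes_to_scaled_grouped_mm(torch.empty(2, 64, 16), torch.empty(2, 16, 32))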
