@@ -39,7 +39,6 @@ def _scaled_grouped_mm(
             and in column-major memory layout.
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
-        use_triton_for_per_group_scales (bool): Whether to use custom triton kernels to compute per-group scales. Default is True.
     """
     logger.info("Using differentiable _scaled_grouped_mm")
     return _Float8GroupedMM.apply(
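
For orientation, a rough usage sketch of this entry point is shown below. It is not part of the diff: the shapes, device, offset values, and the exact call form are illustrative assumptions based only on the docstring and assertions visible in this change.

import torch

# Hypothetical sizes: 3 row-groups in A, K=256 (divisible by 16 per the assert below), N=512.
A = torch.randn(64, 256, device="cuda", dtype=torch.bfloat16)        # (total_M, K)
B = torch.randn(3, 512, 256, device="cuda", dtype=torch.bfloat16)    # (num_groups, N, K), row-major
B_t = B.transpose(-2, -1)                                            # (num_groups, K, N), column-major memory layout
offs = torch.tensor([16, 40, 64], dtype=torch.int32, device="cuda")  # per-group offsets along dim0 of A

out = _scaled_grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16)  # assumed call form; returns (total_M, N)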
@@ -61,8 +60,8 @@ def forward(
         offs: Optional[torch.Tensor] = None,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
     ) -> torch.Tensor:
-        # torchao _scaled_grouped_mm only supports A=2D, B=3D.
-        assert A.ndim == 2, "A must be 2D"
+        # torchao _scaled_grouped_mm only supports A=2D|3D + B=3D.
+        assert A.ndim == 2 or A.ndim == 3, "A must be 2D or 3D"
         assert B_t.ndim == 3, "B must be 3D"
 
         assert A.size(-1) % 16 == 0, (
@@ -151,12 +150,25 @@ def forward(
         assert _is_column_major(B_t_fp8_col_major), (
             "B must be column-major for output = A @ B"
         )
+
+        # TODO: remove excessive logging once prototype is more mature.
+        logger.debug(
+            (
+                f"forward scaled_grouped_mm: A_fp8_row_major.shape={A_fp8_row_major.shape}, "
+                f"A_scale.shape={A_scales.squeeze(-1).shape}, "
+                f"B_t_fp8_col_major.shape={B_t_fp8_col_major.shape}, "
+                f"B_t_scale.shape={B_t_scales.squeeze(1).shape}, "
+                f"offs={offs if offs is not None else None}"
+            )
+        )
         return torch._scaled_grouped_mm(
             A_fp8_row_major,
             B_t_fp8_col_major,
-            A_scales.squeeze().reciprocal(),
-            B_t_scales.squeeze().reciprocal(),
-            offs,
+            # Squeeze A scales to: (B, S, 1) => (B, M), or (B*S, 1) => (B*S)
+            A_scales.squeeze(-1).reciprocal(),
+            # Squeeze B scales to: (B, 1, N) => (B, N)
+            B_t_scales.squeeze(1).reciprocal(),
+            offs=offs,
             out_dtype=out_dtype,
             use_fast_accum=True,
         )
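
As a side note on the new explicit squeeze dims (not part of the diff), the shape effect can be checked with plain tensors; B=4, S=8, N=16 below are arbitrary:

import torch

A_scales = torch.rand(4, 8, 1)      # rowwise scales for A: (B, S, 1)
B_t_scales = torch.rand(4, 1, 16)   # columnwise scales for B_t: (B, 1, N)

print(A_scales.squeeze(-1).shape)   # torch.Size([4, 8]): only the trailing size-1 dim is dropped
print(B_t_scales.squeeze(1).shape)  # torch.Size([4, 16])

# A bare .squeeze() would drop every size-1 dim, which can collapse the
# scale tensors further than intended when a batch or group dim happens to be 1.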
@@ -193,12 +205,20 @@ def backward(ctx, grad_output: torch.Tensor):
         assert _is_column_major(B_fp8_col_major), (
             "B must be column-major for grad_A = grad_output @ B"
         )
+        logger.debug(
+            (
+                f"backward grad_A: grad_output_fp8_row_major.shape={grad_output_fp8_row_major.shape}, "
+                f"grad_output_scale.shape={grad_output_scales.shape}, "
+                f"B_fp8_col_major.shape={B_fp8_col_major.shape}, "
+                f"B_scale.shape={B_scales.shape}, "
+            )
+        )
         grad_A = torch._scaled_grouped_mm(
             grad_output_fp8_row_major,
             B_fp8_col_major,
-            grad_output_scales.squeeze().reciprocal(),
-            B_scales.squeeze().reciprocal(),
-            offs,
+            grad_output_scales.squeeze(-1).reciprocal(),
+            B_scales.squeeze(1).reciprocal(),
+            offs=offs,
             out_dtype=out_dtype,
             use_fast_accum=True,
         )
@@ -238,12 +258,21 @@ def backward(ctx, grad_output: torch.Tensor):
         assert _is_column_major(A_fp8_col_major), (
             "A must be column-major for grad_B = grad_output_t @ A"
         )
+
+        logger.debug(
+            (
+                f"backward grad_B: grad_output_t_fp8_row_major.shape={grad_output_t_fp8_row_major.shape}, "
+                f"grad_output_t_scale.shape={grad_output_t_scales.shape}, "
+                f"A_fp8_col_major.shape={A_fp8_col_major.shape}, "
+                f"A_scale.shape={A_scales.shape}, "
+            )
+        )
         grad_B = torch._scaled_grouped_mm(
             grad_output_t_fp8_row_major,
             A_fp8_col_major,
             grad_output_t_scales.reciprocal(),
             A_scales.reciprocal(),
-            offs,
+            offs=offs,
             out_dtype=out_dtype,
             use_fast_accum=True,
         )
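
The column-major assertions above use the `_is_column_major` helper defined elsewhere in this file. Its implementation is not shown in this diff; the stride-based check below is only a sketch of the general idea, not the file's actual code.

import torch

def _is_column_major_sketch(x: torch.Tensor) -> bool:
    # Column-major in the last two dims means elements within a column
    # are contiguous, i.e. the stride of dim -2 is 1.
    assert x.ndim in (2, 3), "expected a 2D or 3D tensor"
    return x.stride(-2) == 1 and x.stride(-1) > 1

m = torch.randn(64, 32)
print(_is_column_major_sketch(m))      # False: contiguous tensors are row-major
print(_is_column_major_sketch(m.t()))  # True: transposing flips the memory layout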