@@ -54,6 +54,12 @@ def addmm_float8_unwrapped(
         a_inverse_scale = a_inverse_scale.new_ones(())
         b_inverse_scale = a_inverse_scale.new_ones(())
 
+    # work around torch._scaled_mm not having float32 output type
+    # TODO(pytorch/pytorch#156771): remove this once torch._scaled_mm supports float32 output
+    orig_dtype = output_dtype
+    if orig_dtype in (torch.float16, torch.float32) and is_rowwise_scaling:
+        output_dtype = torch.bfloat16
+
     post_bias = None
     if output_dtype == torch.float32:
         # Bias is not supported by _scaled_mm when output is fp32
@@ -76,6 +82,9 @@ def addmm_float8_unwrapped(
     if post_bias is not None:
         output += post_bias
 
+    if orig_dtype in (torch.float16, torch.float32) and is_rowwise_scaling:
+        output = output.to(orig_dtype)
+
     return output
 
 
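The functional change above follows a common cast-back pattern: when the underlying kernel cannot produce the requested output dtype (here, torch._scaled_mm with rowwise scales), compute in a dtype it does support and convert the result afterwards. A minimal standalone sketch of that pattern follows; matmul_with_dtype_fallback is a hypothetical helper for illustration, not a torchao or PyTorch API:

import torch

def matmul_with_dtype_fallback(
    a: torch.Tensor, b: torch.Tensor, requested_dtype: torch.dtype
) -> torch.Tensor:
    # Assumption for this sketch: the kernel only emits bfloat16, mirroring
    # the rowwise-scaling limitation worked around in the diff above.
    if requested_dtype in (torch.float16, torch.float32):
        compute_dtype = torch.bfloat16
    else:
        compute_dtype = requested_dtype
    out = a.to(compute_dtype) @ b.to(compute_dtype)
    # Cast back to the caller's dtype, mirroring `output.to(orig_dtype)`.
    return out.to(requested_dtype)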
@@ -260,24 +269,24 @@ def float8_cat(aten_op, args, kwargs=None):
     gemm_input_role = chunked_tensors[0]._gemm_input_role
     chunk_data = []
     for chunk in chunked_tensors:
-        assert isinstance(chunk, Float8Tensor), (
-            "Expecting all chunks to be of type Float8Tensor"
-        )
-        assert chunk._orig_dtype == orig_dtype, (
-            "Expecting all chunks to be of the same dtype"
-        )
-        assert chunk._scale is scale, (
-            "Expecting all chunks to have the same scale as a result of a split"
-        )
-        assert chunk._linear_mm_config is mm_config, (
-            "Expecting all chunks to have the same mm config as a result of a split"
-        )
-        assert chunk._data.dtype == fp8_dtype, (
-            "Expecting all chunks to be of the same dtype as a result of a split"
-        )
-        assert chunk._gemm_input_role is gemm_input_role, (
-            "Expecting all chunks to have the same gemm_input_role as a result of a split"
-        )
+        assert isinstance(
+            chunk, Float8Tensor
+        ), "Expecting all chunks to be of type Float8Tensor"
+        assert (
+            chunk._orig_dtype == orig_dtype
+        ), "Expecting all chunks to be of the same dtype"
+        assert (
+            chunk._scale is scale
+        ), "Expecting all chunks to have the same scale as a result of a split"
+        assert (
+            chunk._linear_mm_config is mm_config
+        ), "Expecting all chunks to have the same mm config as a result of a split"
+        assert (
+            chunk._data.dtype == fp8_dtype
+        ), "Expecting all chunks to be of the same dtype as a result of a split"
+        assert (
+            chunk._gemm_input_role is gemm_input_role
+        ), "Expecting all chunks to have the same gemm_input_role as a result of a split"
         _assert_tensorwise_scale(aten_op, chunk._scale)
         chunk_data.append(chunk._data.view(torch.uint8))
 
@@ -320,9 +329,9 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
     )
 
     if scaled_mm_config.pad_inner_dim:
-        assert a._data.size(1) == b._data.size(0), (
-            f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}"
-        )
+        assert a._data.size(1) == b._data.size(
+            0
+        ), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}"
         a_data = pad_tensor_for_matmul(a_data, dims=1)
         b_data = pad_tensor_for_matmul(b_data, dims=0)
 
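For background on the pad_inner_dim branch: torch._scaled_mm requires the matmul dimensions to be divisible by 16 on typical backends, so torchao can zero-pad the shared inner dimension (a's dim 1, b's dim 0) before calling the kernel. Below is a minimal sketch of rounding one dimension up to a multiple of 16, assuming zero padding is acceptable for the scaled values; pad_dim_to_multiple is a hypothetical stand-in for torchao's pad_tensor_for_matmul, not its actual implementation:

import torch

def pad_dim_to_multiple(t: torch.Tensor, dim: int, multiple: int = 16) -> torch.Tensor:
    # Round the given dim up to the next multiple by appending zeros.
    pad_len = (-t.size(dim)) % multiple
    if pad_len == 0:
        return t
    pad_shape = list(t.shape)
    pad_shape[dim] = pad_len
    return torch.cat([t, t.new_zeros(pad_shape)], dim=dim)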
@@ -353,10 +362,10 @@ def float8_mm(aten_op, args, kwargs=None):
     a = args[0]
     b = args[1]
 
-    assert isinstance(a, Float8Tensor) and isinstance(b, Float8Tensor), (
-        "Expecting both Float8Tensor for mm inputs but found {} and {}".format(
-            type(a), type(b)
-        )
+    assert isinstance(a, Float8Tensor) and isinstance(
+        b, Float8Tensor
+    ), "Expecting both Float8Tensor for mm inputs but found {} and {}".format(
+        type(a), type(b)
     )
     a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
     output_dtype = a._orig_dtype
@@ -434,9 +443,9 @@ def autocast_to_copy(aten_op, args, kwargs=None):
     """
     _assert_tensorwise_scale(aten_op, args[0]._scale)
     assert isinstance(args[0], Float8Tensor)
-    assert len(kwargs) == 1 and "dtype" in kwargs, (
-        "Only support dtype kwarg for autocast"
-    )
+    assert (
+        len(kwargs) == 1 and "dtype" in kwargs
+    ), "Only support dtype kwarg for autocast"
     assert kwargs["dtype"] in {
         torch.float16,
         torch.bfloat16,
@@ -462,9 +471,9 @@ def allgather_fp8(aten_op, args, kwargs=None):
     """
    _assert_tensorwise_scale(aten_op, args[0]._scale)
     fp8_input = args[0]
-    assert isinstance(fp8_input, Float8Tensor), (
-        f"expecting a Float8Tensor for allgather but found {type(fp8_input)}"
-    )
+    assert isinstance(
+        fp8_input, Float8Tensor
+    ), f"expecting a Float8Tensor for allgather but found {type(fp8_input)}"
 
     fp8_data = fp8_input._data
     fp8_data = fp8_data.contiguous()
@@ -536,21 +545,21 @@ def copy_fp8(aten_op, args, kwargs=None):
         return aten_op(self, src_hp, *args[2:], **kwargs)
     elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
         _assert_tensorwise_scale(aten_op, src._scale)
-        assert self._orig_dtype == src._orig_dtype, (
-            "Expecting both Float8Tensors to be of the same dtype"
-        )
-        assert self._scale == src._scale, (
-            "Expecting both Float8Tensors to have the same scale"
-        )
-        assert self._linear_mm_config == src._linear_mm_config, (
-            "Expecting both Float8Tensors to have the same mm config"
-        )
-        assert self._data.dtype == src._data.dtype, (
-            "Expecting both Float8Tensors to be of the same dtype"
-        )
-        assert self._gemm_input_role == src._gemm_input_role, (
-            "Expecting both Float8Tensors to have the same gemm_input_role"
-        )
+        assert (
+            self._orig_dtype == src._orig_dtype
+        ), "Expecting both Float8Tensors to be of the same dtype"
+        assert (
+            self._scale == src._scale
+        ), "Expecting both Float8Tensors to have the same scale"
+        assert (
+            self._linear_mm_config == src._linear_mm_config
+        ), "Expecting both Float8Tensors to have the same mm config"
+        assert (
+            self._data.dtype == src._data.dtype
+        ), "Expecting both Float8Tensors to be of the same dtype"
+        assert (
+            self._gemm_input_role == src._gemm_input_role
+        ), "Expecting both Float8Tensors to have the same gemm_input_role"
         fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs)
         return Float8Tensor(
             fp8_out,