
Commit abbed3a

y-sq authored and facebook-github-bot committed
Add the workaround to support rowwise scaled_gemm for fp32 outputs (#2431)
Summary: Running rowwise scaling on fp32 tensors fails with the error below (P1794222725):

```
RuntimeError: Only bf16 high precision output types are supported for row-wise scaling.
```

This PR adds an option to explicitly use bfloat16 as the output of the rowwise-scaled matmul and cast the result back to the original precision. It can be enabled by setting:

```
config = dataclasses.replace(config, convert_dtypes_for_rowwise_scaled_mm=True)
```

Differential Revision: D73552660
1 parent 2025b75 commit abbed3a
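
The workaround itself is just a dtype round-trip: when row-wise scaling is in use and the caller asks for fp16/fp32 output, the matmul is asked to emit bfloat16 and the result is cast back afterwards. Below is a minimal, self-contained sketch of that logic; the helper name and the stand-in tensor are illustrative, not part of this commit.

```python
import torch


def pick_rowwise_mm_dtypes(output_dtype: torch.dtype, is_rowwise_scaling: bool):
    """Return (compute_dtype, final_dtype) mirroring the cast-back workaround.

    torch._scaled_mm only supports bf16 high-precision outputs for row-wise
    scaling, so fp16/fp32 requests compute in bf16 and cast back at the end.
    """
    orig_dtype = output_dtype
    if is_rowwise_scaling and output_dtype in (torch.float16, torch.float32):
        output_dtype = torch.bfloat16
    return output_dtype, orig_dtype


compute_dtype, final_dtype = pick_rowwise_mm_dtypes(torch.float32, is_rowwise_scaling=True)
# stand-in for the scaled-mm result produced in compute_dtype
out = torch.zeros(4, 4, dtype=compute_dtype)
out = out.to(final_dtype)  # cast back to the dtype the caller asked for
assert out.dtype == torch.float32
```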

File tree

2 files changed: +14 −4 lines changed

test/float8/test_base.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -34,9 +34,7 @@
     e5m2_dtype,
 )
 from torchao.float8.float8_linear import Float8Linear
-from torchao.float8.float8_linear_utils import (
-    convert_to_float8_training,
-)
+from torchao.float8.float8_linear_utils import convert_to_float8_training
 from torchao.float8.float8_ops import addmm_float8_unwrapped
 from torchao.float8.float8_scaling_utils import (
     get_maybe_axiswise_dim,
@@ -379,12 +377,16 @@ def test_linear_from_config_params(
     )
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_bias", [True, False])
+    @pytest.mark.parametrize(
+        "linear_dtype", [torch.bfloat16, torch.float16, torch.float32]
+    )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     @skip_if_rocm("ROCm enablement in progress")
     def test_linear_from_recipe(
         self,
         recipe_name,
         x_shape,
+        linear_dtype: torch.dtype,
         linear_bias: bool,
     ):
         if torch.cuda.get_device_capability() < (9, 0):
@@ -393,7 +395,6 @@ def test_linear_from_recipe(
             )
             pytest.skip()
 
-        linear_dtype = torch.bfloat16
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
         config = Float8LinearConfig.from_recipe_name(recipe_name)
```

torchao/float8/float8_ops.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -54,6 +54,12 @@ def addmm_float8_unwrapped(
         a_inverse_scale = a_inverse_scale.new_ones(())
         b_inverse_scale = a_inverse_scale.new_ones(())
 
+    # work around torch._scaled_mm not having float32 output type
+    # TODO(pytorch/pytorch#156771): remove this once torch._scaled_mm supports float32 output
+    orig_dtype = output_dtype
+    if orig_dtype in (torch.float16, torch.float32) and is_rowwise_scaling:
+        output_dtype = torch.bfloat16
+
     post_bias = None
     if output_dtype == torch.float32:
         # Bias is not supported by _scaled_mm when output is fp32
@@ -76,6 +82,9 @@ def addmm_float8_unwrapped(
     if post_bias is not None:
         output += post_bias
 
+    if orig_dtype in (torch.float16, torch.float32) and is_rowwise_scaling:
+        output = output.to(orig_dtype)
+
     return output
 
 
```
