
Commit 92c3668

y-sq authored and facebook-github-bot committed
Add the workaround to support rowwise scaled_gemm for fp32 outputs (pytorch#2431)
Summary:

Running rowwise scaling on fp32 tensors hit the following error (P1794222725):

```
RuntimeError: Only bf16 high precision output types are supported for row-wise scaling.
```

This PR adds an option to explicitly use bfloat16 as the output of the rowwise-scaled gemm and cast it back to the original precision. It can be enabled by setting:

```
config = dataclasses.replace(config, convert_dtypes_for_rowwise_scaled_mm=True)
```

Differential Revision: D73552660
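An illustrative usage sketch (not taken from this commit) of enabling the workaround when training a float32 model with a rowwise-scaling recipe. The `"rowwise"` recipe name and the import paths are assumptions based on the torchao float8 APIs exercised in the tests below, and the snippet assumes a GPU with fp8 support; adjust to your torchao version.

```python
# Sketch only: enable the fp32-output workaround for rowwise scaling.
import dataclasses

import torch
import torch.nn as nn

from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear_utils import convert_to_float8_training

# "rowwise" recipe name is an assumption; any rowwise-scaling recipe applies.
config = Float8LinearConfig.from_recipe_name("rowwise")
# Opt in to the bf16 round-trip so fp16/fp32 outputs work with rowwise scaled_mm.
config = dataclasses.replace(config, convert_dtypes_for_rowwise_scaled_mm=True)

m = nn.Sequential(nn.Linear(16, 32, device="cuda", dtype=torch.float32))
convert_to_float8_training(m, config=config)

x = torch.randn(4, 16, device="cuda", dtype=torch.float32)
y = m(x)
assert y.dtype == torch.float32  # output is cast back to the original precision
```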
1 parent 2025b75 commit 92c3668

File tree

2 files changed: +69 −59 lines changed


test/float8/test_base.py

Lines changed: 14 additions & 13 deletions
@@ -34,9 +34,7 @@
     e5m2_dtype,
 )
 from torchao.float8.float8_linear import Float8Linear
-from torchao.float8.float8_linear_utils import (
-    convert_to_float8_training,
-)
+from torchao.float8.float8_linear_utils import convert_to_float8_training
 from torchao.float8.float8_ops import addmm_float8_unwrapped
 from torchao.float8.float8_scaling_utils import (
     get_maybe_axiswise_dim,
@@ -379,12 +377,16 @@ def test_linear_from_config_params(
     )
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_bias", [True, False])
+    @pytest.mark.parametrize(
+        "linear_dtype", [torch.bfloat16, torch.float16, torch.float32]
+    )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     @skip_if_rocm("ROCm enablement in progress")
     def test_linear_from_recipe(
         self,
         recipe_name,
         x_shape,
+        linear_dtype: torch.dtype,
         linear_bias: bool,
     ):
         if torch.cuda.get_device_capability() < (9, 0):
@@ -393,7 +395,6 @@ def test_linear_from_recipe(
             )
             pytest.skip()
 
-        linear_dtype = torch.bfloat16
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
         config = Float8LinearConfig.from_recipe_name(recipe_name)
@@ -436,9 +437,9 @@ def test_autocast_outputs(
 
         with torch.autocast("cuda", dtype=torch.bfloat16):
             y = m(x)
-        assert y.dtype == torch.bfloat16, (
-            f"y.dtype is {y.dtype}, expected {torch.bfloat16}"
-        )
+        assert (
+            y.dtype == torch.bfloat16
+        ), f"y.dtype is {y.dtype}, expected {torch.bfloat16}"
 
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
@@ -467,9 +468,9 @@ def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
 
         with torch.autocast("cuda", dtype=torch.bfloat16):
             y = m(x)
-        assert y.dtype == torch.bfloat16, (
-            f"y.dtype is {y.dtype}, expected {torch.bfloat16}"
-        )
+        assert (
+            y.dtype == torch.bfloat16
+        ), f"y.dtype is {y.dtype}, expected {torch.bfloat16}"
 
     def test_repr(self):
         m = nn.Linear(32, 16)
@@ -500,9 +501,9 @@ def test_quantize(self):
         from torchao.quantization.quant_api import float8_weight_only, quantize_
 
         quantize_(m, float8_weight_only())
-        assert m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn, (
-            "Post quantization dtype should be torch.float8_e4m3fn"
-        )
+        assert (
+            m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn
+        ), "Post quantization dtype should be torch.float8_e4m3fn"
         with torch.no_grad():
             m(x)
 
torchao/float8/float8_ops.py

Lines changed: 55 additions & 46 deletions
@@ -54,6 +54,12 @@ def addmm_float8_unwrapped(
         a_inverse_scale = a_inverse_scale.new_ones(())
         b_inverse_scale = a_inverse_scale.new_ones(())
 
+    # work around torch._scaled_mm not having float32 output type
+    # TODO(pytorch/pytorch#156771): remove this once torch._scaled_mm supports float32 output
+    orig_dtype = output_dtype
+    if orig_dtype in (torch.float16, torch.float32) and is_rowwise_scaling:
+        output_dtype = torch.bfloat16
+
     post_bias = None
     if output_dtype == torch.float32:
         # Bias is not supported by _scaled_mm when output is fp32
@@ -76,6 +82,9 @@
     if post_bias is not None:
         output += post_bias
 
+    if orig_dtype in (torch.float16, torch.float32) and is_rowwise_scaling:
+        output = output.to(orig_dtype)
+
     return output
 
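The two hunks above contain the functional change; the remaining hunks in this file only restructure existing assert statements without changing behavior. For reference, here is a minimal standalone sketch (not code from this commit) of the same round-trip: when the caller asks for an fp16/fp32 output of a rowwise-scaled matmul, compute in bf16 and cast back, because `torch._scaled_mm` only supports bf16 high-precision outputs for row-wise scaling. It assumes a GPU with fp8 support; the rowwise-scaling detection and scale layout are simplified assumptions, and bias handling is omitted.

```python
# Sketch of the bf16 round-trip workaround, not torchao's implementation.
import torch


def rowwise_scaled_mm(a_fp8, b_fp8_t, a_inv_scale, b_inv_scale, output_dtype):
    # a_fp8: (M, K) row-major fp8; b_fp8_t: (K, N) column-major fp8
    # a_inv_scale: (M, 1) fp32; b_inv_scale: (1, N) fp32
    is_rowwise_scaling = a_inv_scale.shape == (a_fp8.shape[0], 1)  # simplified check
    requested_dtype = output_dtype
    if requested_dtype in (torch.float16, torch.float32) and is_rowwise_scaling:
        output_dtype = torch.bfloat16  # compute in bf16 ...
    out = torch._scaled_mm(
        a_fp8,
        b_fp8_t,
        scale_a=a_inv_scale,
        scale_b=b_inv_scale,
        out_dtype=output_dtype,
        use_fast_accum=True,
    )
    if requested_dtype != output_dtype:
        out = out.to(requested_dtype)  # ... then cast back to the requested dtype
    return out
```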

@@ -260,24 +269,24 @@ def float8_cat(aten_op, args, kwargs=None):
     gemm_input_role = chunked_tensors[0]._gemm_input_role
     chunk_data = []
     for chunk in chunked_tensors:
-        assert isinstance(chunk, Float8Tensor), (
-            "Expecting all chunks to be of type Float8Tensor"
-        )
-        assert chunk._orig_dtype == orig_dtype, (
-            "Expecting all chunks to be of the same dtype"
-        )
-        assert chunk._scale is scale, (
-            "Expecting all chunks to have thee same scale as a result of a split"
-        )
-        assert chunk._linear_mm_config is mm_config, (
-            "Expecting all chunks to have thee same mm config as a result of a split"
-        )
-        assert chunk._data.dtype == fp8_dtype, (
-            "Expecting all chunks to be of the same dtype as a result of a split"
-        )
-        assert chunk._gemm_input_role is gemm_input_role, (
-            "Expecting all chunks to have the same gemm_input_role as a result of a split"
-        )
+        assert isinstance(
+            chunk, Float8Tensor
+        ), "Expecting all chunks to be of type Float8Tensor"
+        assert (
+            chunk._orig_dtype == orig_dtype
+        ), "Expecting all chunks to be of the same dtype"
+        assert (
+            chunk._scale is scale
+        ), "Expecting all chunks to have thee same scale as a result of a split"
+        assert (
+            chunk._linear_mm_config is mm_config
+        ), "Expecting all chunks to have thee same mm config as a result of a split"
+        assert (
+            chunk._data.dtype == fp8_dtype
+        ), "Expecting all chunks to be of the same dtype as a result of a split"
+        assert (
+            chunk._gemm_input_role is gemm_input_role
+        ), "Expecting all chunks to have the same gemm_input_role as a result of a split"
         _assert_tensorwise_scale(aten_op, chunk._scale)
         chunk_data.append(chunk._data.view(torch.uint8))
 
@@ -320,9 +329,9 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
     )
 
     if scaled_mm_config.pad_inner_dim:
-        assert a._data.size(1) == b._data.size(0), (
-            f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}"
-        )
+        assert a._data.size(1) == b._data.size(
+            0
+        ), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}"
         a_data = pad_tensor_for_matmul(a_data, dims=1)
         b_data = pad_tensor_for_matmul(b_data, dims=0)
 
@@ -353,10 +362,10 @@ def float8_mm(aten_op, args, kwargs=None):
     a = args[0]
     b = args[1]
 
-    assert isinstance(a, Float8Tensor) and isinstance(b, Float8Tensor), (
-        "Expecting both Float8Tensor for mm inputs but found {} and {}".format(
-            type(a), type(b)
-        )
+    assert isinstance(a, Float8Tensor) and isinstance(
+        b, Float8Tensor
+    ), "Expecting both Float8Tensor for mm inputs but found {} and {}".format(
+        type(a), type(b)
     )
     a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
     output_dtype = a._orig_dtype
@@ -434,9 +443,9 @@ def autocast_to_copy(aten_op, args, kwargs=None):
     """
     _assert_tensorwise_scale(aten_op, args[0]._scale)
     assert isinstance(args[0], Float8Tensor)
-    assert len(kwargs) == 1 and "dtype" in kwargs, (
-        "Only support dtype kwarg for autocast"
-    )
+    assert (
+        len(kwargs) == 1 and "dtype" in kwargs
+    ), "Only support dtype kwarg for autocast"
     assert kwargs["dtype"] in {
         torch.float16,
         torch.bfloat16,
@@ -462,9 +471,9 @@ def allgather_fp8(aten_op, args, kwargs=None):
     """
     _assert_tensorwise_scale(aten_op, args[0]._scale)
     fp8_input = args[0]
-    assert isinstance(fp8_input, Float8Tensor), (
-        f"expecting a Float8Tensor for allgather but found {type(fp8_input)}"
-    )
+    assert isinstance(
+        fp8_input, Float8Tensor
+    ), f"expecting a Float8Tensor for allgather but found {type(fp8_input)}"
 
     fp8_data = fp8_input._data
     fp8_data = fp8_data.contiguous()
@@ -536,21 +545,21 @@ def copy_fp8(aten_op, args, kwargs=None):
         return aten_op(self, src_hp, *args[2:], **kwargs)
     elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
         _assert_tensorwise_scale(aten_op, src._scale)
-        assert self._orig_dtype == src._orig_dtype, (
-            "Expecting both Float8Tensors to be of the same dtype"
-        )
-        assert self._scale == src._scale, (
-            "Expecting both Float8Tensors to have thee same scale"
-        )
-        assert self._linear_mm_config == src._linear_mm_config, (
-            "Expecting both Float8Tensors to have thee same mm config"
-        )
-        assert self._data.dtype == src._data.dtype, (
-            "Expecting both Float8Tensors to be of the same dtypet"
-        )
-        assert self._gemm_input_role == src._gemm_input_role, (
-            "Expecting both Float8Tensors to have the same gemm_input_role"
-        )
+        assert (
+            self._orig_dtype == src._orig_dtype
+        ), "Expecting both Float8Tensors to be of the same dtype"
+        assert (
+            self._scale == src._scale
+        ), "Expecting both Float8Tensors to have thee same scale"
+        assert (
+            self._linear_mm_config == src._linear_mm_config
+        ), "Expecting both Float8Tensors to have thee same mm config"
+        assert (
+            self._data.dtype == src._data.dtype
+        ), "Expecting both Float8Tensors to be of the same dtypet"
+        assert (
+            self._gemm_input_role == src._gemm_input_role
+        ), "Expecting both Float8Tensors to have the same gemm_input_role"
         fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs)
         return Float8Tensor(
             fp8_out,
