Commit 8525185
[Float8] add non-decomposed version of quantize/dequantize ops for fp8 (#2961)

* register fp8 quant/dequant only on CPU
* add non-decomposed quantize_affine_float8 and dequantize_affine_float8

1 parent 22819f4 · commit 8525185
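
The existing quantize_affine_float8/dequantize_affine_float8 ops decompose into their constituent aten ops when traced, while the `_non_decomposed` variants added here stay as single custom-op calls under torch.compile (the new test checks this with FileCheck). As a rough sketch of the per-tensor scaled-cast math both variants compute (a sketch only, since torchao's actual rounding behavior lives in `_RoundToFloat8`, visible in the second diff below):

```python
import torch

# Minimal sketch, not torchao's implementation: quantize divides by the
# scale, clamps to the float8 dtype's finite range, and casts down;
# dequantize upcasts and multiplies the scale back in.
def quantize_fp8_sketch(tensor, scale, float8_dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(float8_dtype)
    scaled = tensor.to(torch.float32) / scale
    return scaled.clamp(min=finfo.min, max=finfo.max).to(float8_dtype)

def dequantize_fp8_sketch(tensor, scale, output_dtype=torch.float32):
    return (tensor.to(torch.float32) * scale).to(output_dtype)
```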

File tree

2 files changed: +84 −3 lines changed
test/dtypes/test_affine_quantized_float.py

Lines changed: 42 additions & 0 deletions

```diff
@@ -733,6 +733,48 @@ def test_preprocess_scale_3d_reshape(self):
         expected_shape = (8, 1)  # Flattened (2*2*2, 1)
         self.assertEqual(result.shape, expected_shape)
 
+    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @common_utils.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
+    def test_quantize_dequantize_fp8_inductor(self, float8_dtype, hp_dtype):
+        quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8_non_decomposed
+        dequantize_affine_float8 = (
+            torch.ops.torchao.dequantize_affine_float8_non_decomposed
+        )
+        input = torch.randn(10, 10)
+        with torch.no_grad():
+            torch._dynamo.reset()
+            expected_scale = torch.tensor(2.0)
+            expected_quantized = quantize_affine_float8(
+                input,
+                expected_scale,
+                float8_dtype=float8_dtype,
+            )
+            expected_dequantized = dequantize_affine_float8(
+                expected_quantized,
+                expected_scale,
+                output_dtype=hp_dtype,
+            )
+            test_q, (code_q,) = torch._inductor.utils.run_and_get_code(
+                torch.compile(quantize_affine_float8),
+                input,
+                expected_scale,
+                float8_dtype=float8_dtype,
+            )
+            torch.testing.FileCheck().check(f"{quantize_affine_float8}.default").run(
+                code_q
+            )
+            test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code(
+                torch.compile(dequantize_affine_float8),
+                test_q,
+                expected_scale,
+                hp_dtype,
+            )
+            torch.testing.FileCheck().check(f"{dequantize_affine_float8}.default").run(
+                code_dq
+            )
+            torch.testing.assert_close(expected_quantized, test_q)
+            torch.testing.assert_close(expected_dequantized, test_dq)
+
     @torch.no_grad()
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
```
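
Once registered, the new ops are ordinary entries in the `torch.ops.torchao` namespace and can also be called eagerly. A minimal round-trip using the signatures exercised by the test above (values illustrative; assumes a torchao build containing this commit, with the op registrations triggered by the torchao import):

```python
import torch
import torchao  # assumed to trigger the quant_lib op registrations

x = torch.randn(10, 10)
scale = torch.tensor(2.0)

q = torch.ops.torchao.quantize_affine_float8_non_decomposed(
    x, scale, float8_dtype=torch.float8_e4m3fn
)
dq = torch.ops.torchao.dequantize_affine_float8_non_decomposed(
    q, scale, output_dtype=torch.float32
)
# dq approximates x up to float8 rounding error. Under torch.compile, the
# Inductor-generated code keeps each call as a single
# torchao.<op>.default node, which is what the FileCheck assertions verify.
```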

torchao/quantization/quant_primitives.py

Lines changed: 42 additions & 3 deletions

```diff
@@ -2310,8 +2310,6 @@ def _quantize_affine_float8(
     return _RoundToFloat8.apply(tensor_clamped, float8_dtype)
 
 
-# TODO: don't register as custom op?
-@_register_custom_op(quant_lib, False)
 def _dequantize_affine_float8(
     tensor: torch.Tensor,
     scale: torch.Tensor,
@@ -2329,7 +2327,48 @@ def _dequantize_affine_float8(
     return hp_tensor.to(output_dtype)
 
 
-@_register_meta_op(quant_lib, "dequantize_affine_float8")
+@_register_custom_op(quant_lib, False)
+def _quantize_affine_float8_non_decomposed(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    float8_dtype: torch.dtype = torch.float8_e4m3fn,
+) -> torch.Tensor:
+    """
+    Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.
+    """
+    return _quantize_affine_float8(
+        tensor=tensor,
+        scale=scale,
+        float8_dtype=float8_dtype,
+    )
+
+
+@_register_meta_op(quant_lib, "quantize_affine_float8_non_decomposed")
+def _quantize_affine_float8_meta(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    float8_dtype: torch.dtype = torch.float8_e4m3fn,
+) -> torch.Tensor:
+    return torch.empty_like(tensor, dtype=float8_dtype)
+
+
+@_register_custom_op(quant_lib, False)
+def _dequantize_affine_float8_non_decomposed(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """
+    Dequantizes the float8 tensor to high precision tensor.
+    """
+    return _dequantize_affine_float8(
+        tensor=tensor,
+        scale=scale,
+        output_dtype=output_dtype,
+    )
+
+
+@_register_meta_op(quant_lib, "dequantize_affine_float8_non_decomposed")
 def _dequantize_affine_float8_meta(
     tensor: torch.Tensor,
     scale: torch.Tensor,
```