diff --git a/test/quantization/core/test_quantized_tensor.py b/test/quantization/core/test_quantized_tensor.py
index 96d5cea156af3..4e83b044a7239 100644
--- a/test/quantization/core/test_quantized_tensor.py
+++ b/test/quantization/core/test_quantized_tensor.py
@@ -1478,6 +1478,18 @@ def test_decomposed_quantize_per_tensor(self):
         self.assertEqual(quantized_decomposed_X.dtype, dtype)
         self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X)
 
+    def test_decomposed_quantize_per_tensor_bfloat16_input(self):
+        # register the ops
+        import torch.ao.quantization.fx._decomposed
+        X = torch.randint(1, 10, (5, 5)).to(torch.float32)
+        scale, zero_point = _calculate_dynamic_qparams(X, torch.quint8)
+        quantized_X = torch.quantize_per_tensor(X, scale, zero_point, torch.quint8)
+        quantized_decomposed_X = \
+            torch.ops.quantized_decomposed.quantize_per_tensor(
+                X.to(torch.bfloat16), scale, zero_point, 0, 255, torch.uint8)
+        self.assertEqual(quantized_decomposed_X.dtype, torch.uint8)
+        self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X)
+
     def test_decomposed_dequantize_per_tensor(self):
         import torch.ao.quantization.fx._decomposed
         X = torch.randn(5, 10)
@@ -1541,6 +1553,24 @@ def test_decomposed_quantize_per_channel(self):
         self.assertEqual(quantized_decomposed_X.dtype, dtype)
         self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X)
 
+    def test_decomposed_quantize_per_channel_bfloat16_input(self):
+        # register the ops
+        import torch.ao.quantization.fx._decomposed
+        X = torch.randint(1, 10, (5, 5)).to(torch.float32)
+        qdtype = torch.quint8
+        dtype = torch.uint8
+        scales = torch.randn(5,)
+        zero_points = torch.randint(0, 100, (5,))
+        quant_min, quant_max = 0, 255
+        axis = 0
+
+        quantized_X = torch.quantize_per_channel(X, scales, zero_points, axis, qdtype)
+        quantized_decomposed_X = \
+            torch.ops.quantized_decomposed.quantize_per_channel(
+                X.to(torch.bfloat16), scales, zero_points, axis, quant_min, quant_max, dtype)
+        self.assertEqual(quantized_decomposed_X.dtype, dtype)
+        self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X)
+
     def test_decomposed_dequantize_per_channel(self):
         # register the ops
         import torch.ao.quantization.fx._decomposed
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index eb0d394f204fc..a93f8941de826 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -423,6 +423,8 @@ def quantize_per_tensor_default_decomp_impl(
     quant_max: int,
     dtype: torch.dtype,
 ) -> torch.Tensor:
+    if input.dtype == torch.bfloat16:
+        input = input.to(torch.float32)
     inv_scale = 1.0 / scale
     return torch.clamp(
         torch.round(input * inv_scale) + zero_point, quant_min, quant_max
@@ -452,6 +454,8 @@ def quantize_per_tensor_tensor_decomp_impl(
     quant_max: int,
     dtype: torch.dtype,
 ) -> torch.Tensor:
+    if input.dtype == torch.bfloat16:
+        input = input.to(torch.float32)
     inv_scale = 1.0 / scale
     return torch.clamp(
         torch.round(input * inv_scale) + zero_point, quant_min, quant_max
diff --git a/torch/ao/quantization/fx/_decomposed.py b/torch/ao/quantization/fx/_decomposed.py
index 08f0608fa788a..c64defa89f700 100644
--- a/torch/ao/quantization/fx/_decomposed.py
+++ b/torch/ao/quantization/fx/_decomposed.py
@@ -46,7 +46,7 @@ def quantize_per_tensor(
        from floating point to quantized values
 
    Args:
-       input (torch.Tensor): original float32 Tensor
+       input (torch.Tensor): original float32 or bfloat16 Tensor
        scale (float): quantization parameter for affine quantization
        zero_point (int): quantization parameter for affine quantization
        quant_min (int): minimum quantized value for output Tensor
@@ -57,6 +57,9 @@ def quantize_per_tensor(
        Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
        are not stored in the Tensor, we are storing them in function arguments instead
     """
+    if input.dtype == torch.bfloat16:
+        input = input.to(torch.float32)
+
     assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
 
@@ -355,7 +358,7 @@ def quantize_per_channel(
        parameters for each channel/axis to map from floating point to quantized values
 
    Args:
-       input (torch.Tensor): original float32 Tensor
+       input (torch.Tensor): original float32 or bfloat16 Tensor
        scales (torch.Tensor): a list of scale quantization parameter for
        affine quantization, one per channel
        zero_point (torch.Tensor): a list of zero_point quantization parameter for
@@ -368,6 +371,9 @@ def quantize_per_channel(
        Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
        are not stored in the Tensor, we are storing them in function arguments instead
     """
+    if input.dtype == torch.bfloat16:
+        input = input.to(torch.float32)
+
     assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
     assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
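Usage note: a minimal sketch of the behavior this patch enables, mirroring the added tests. Before the change, the decomposed quantize ops asserted float32 input and rejected bfloat16; with the change they upcast bfloat16 to float32 internally before quantizing. The scale/zero_point values below are illustrative placeholders, not values taken from the patch.

    import torch
    import torch.ao.quantization.fx._decomposed  # noqa: F401  # registers torch.ops.quantized_decomposed

    # bfloat16 input; illustrative qparams
    x_bf16 = torch.randn(5, 5).to(torch.bfloat16)
    scale, zero_point = 0.1, 10

    # The op now upcasts the bfloat16 input to float32 and quantizes to uint8.
    q = torch.ops.quantized_decomposed.quantize_per_tensor(
        x_bf16, scale, zero_point, 0, 255, torch.uint8)
    assert q.dtype == torch.uint8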