[Quant] [PT2] Decomposed quant per tensor/channel accept bfloat16 input
ghstack-source-id: 68b017b5a02de389cc8e238c4b71089c8051ac18
Pull Request resolved: #112225
leslie-fang-intel committed Oct 27, 2023 · 1 parent 0929ae1 · commit 0de4a7a
Showing 3 changed files with 42 additions and 2 deletions.
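In short: the decomposed quantize_per_tensor/quantize_per_channel ops now accept bfloat16 input and upcast it to float32 before quantizing. A minimal usage sketch (the scale and zero_point values below are illustrative, not taken from the patch):

    import torch
    import torch.ao.quantization.fx._decomposed  # registers the quantized_decomposed ops

    x = torch.randn(5, 5).to(torch.bfloat16)
    q = torch.ops.quantized_decomposed.quantize_per_tensor(
        x, 0.05, 10, 0, 255, torch.uint8)  # scale, zero_point, quant_min, quant_max, dtype
    assert q.dtype == torch.uint8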
30 changes: 30 additions & 0 deletions test/quantization/core/test_quantized_tensor.py
@@ -1478,6 +1478,18 @@ def test_decomposed_quantize_per_tensor(self):
             self.assertEqual(quantized_decomposed_X.dtype, dtype)
             self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X)
 
+    def test_decomposed_quantize_per_tensor_bfloat16_input(self):
+        # register the ops
+        import torch.ao.quantization.fx._decomposed
+        X = torch.randint(1, 10, (5, 5)).to(torch.float32)
+        scale, zero_point = _calculate_dynamic_qparams(X, torch.quint8)
+        quantized_X = torch.quantize_per_tensor(X, scale, zero_point, torch.quint8)
+        quantized_decomposed_X = \
+            torch.ops.quantized_decomposed.quantize_per_tensor(
+                X.to(torch.bfloat16), scale, zero_point, 0, 255, torch.uint8)
+        self.assertEqual(quantized_decomposed_X.dtype, torch.uint8)
+        self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X)
+
     def test_decomposed_dequantize_per_tensor(self):
         import torch.ao.quantization.fx._decomposed
         X = torch.randn(5, 10)
@@ -1541,6 +1553,24 @@ def test_decomposed_quantize_per_channel(self):
        self.assertEqual(quantized_decomposed_X.dtype, dtype)
        self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X)
 
+    def test_decomposed_quantize_per_channel_bfloat16_input(self):
+        # register the ops
+        import torch.ao.quantization.fx._decomposed
+        X = torch.randint(1, 10, (5, 5)).to(torch.float32)
+        qdtype = torch.quint8
+        dtype = torch.uint8
+        scales = torch.randn(5,)
+        zero_points = torch.randint(0, 100, (5,))
+        quant_min, quant_max = 0, 255
+        axis = 0
+
+        quantized_X = torch.quantize_per_channel(X, scales, zero_points, axis, qdtype)
+        quantized_decomposed_X = \
+            torch.ops.quantized_decomposed.quantize_per_channel(
+                X.to(torch.bfloat16), scales, zero_points, axis, quant_min, quant_max, dtype)
+        self.assertEqual(quantized_decomposed_X.dtype, dtype)
+        self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X)
+
     def test_decomposed_dequantize_per_channel(self):
         # register the ops
         import torch.ao.quantization.fx._decomposed
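Both new tests follow the file's existing pattern: quantize a float32 reference with the eager op, feed the same data cast to bfloat16 through the decomposed op, and compare integer representations. Assuming a checkout containing this commit, they should be runnable directly, e.g. with pytest test/quantization/core/test_quantized_tensor.py -k bfloat16.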
4 changes: 4 additions & 0 deletions torch/_inductor/decomposition.py
@@ -423,6 +423,8 @@ def quantize_per_tensor_default_decomp_impl(
     quant_max: int,
     dtype: torch.dtype,
 ) -> torch.Tensor:
+    if input.dtype == torch.bfloat16:
+        input = input.to(torch.float32)
     inv_scale = 1.0 / scale
     return torch.clamp(
         torch.round(input * inv_scale) + zero_point, quant_min, quant_max
@@ -452,6 +454,8 @@ def quantize_per_tensor_tensor_decomp_impl(
     quant_max: int,
     dtype: torch.dtype,
 ) -> torch.Tensor:
+    if input.dtype == torch.bfloat16:
+        input = input.to(torch.float32)
     inv_scale = 1.0 / scale
     return torch.clamp(
         torch.round(input * inv_scale) + zero_point, quant_min, quant_max
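The guard added in both inductor decompositions is the whole change: upcast bfloat16 to float32, then quantize as before. A self-contained restatement of the math (my own sketch, not the compiled code path; the final cast to dtype is implied by the op's contract):

    import torch

    def quantize_per_tensor_ref(x, scale, zero_point, quant_min, quant_max, dtype):
        if x.dtype == torch.bfloat16:
            # bfloat16 carries only ~8 mantissa bits, so rounding is done in
            # float32 to match the eager kernel.
            x = x.to(torch.float32)
        inv_scale = 1.0 / scale
        return torch.clamp(torch.round(x * inv_scale) + zero_point, quant_min, quant_max).to(dtype)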
10 changes: 8 additions & 2 deletions torch/ao/quantization/fx/_decomposed.py
@@ -46,7 +46,7 @@ def quantize_per_tensor(
     from floating point to quantized values
 
     Args:
-       input (torch.Tensor): original float32 Tensor
+       input (torch.Tensor): original float32 or bfloat16 Tensor
        scale (float): quantization parameter for affine quantization
        zero_point (int): quantization parameter for affine quantization
        quant_min (int): minimum quantized value for output Tensor
@@ -57,6 +57,9 @@
        Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
        are not stored in the Tensor, we are storing them in function arguments instead
     """
+    if input.dtype == torch.bfloat16:
+        input = input.to(torch.float32)
+
     assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
 
@@ -355,7 +358,7 @@ def quantize_per_channel(
     parameters for each channel/axis to map from floating point to quantized values
 
     Args:
-       input (torch.Tensor): original float32 Tensor
+       input (torch.Tensor): original float32 or bfloat16 Tensor
        scales (torch.Tensor): a list of scale quantization parameter for
        affine quantization, one per channel
        zero_point (torch.Tensor): a list of zero_point quantization parameter for
@@ -368,6 +371,9 @@
        Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
        are not stored in the Tensor, we are storing them in function arguments instead
     """
+    if input.dtype == torch.bfloat16:
+        input = input.to(torch.float32)
+
     assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
     assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
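For the per-channel variant, scales and zero_points carry one entry per slice along axis. A compact sketch of that mapping under the same upcast rule (the reshape-based broadcasting is my own illustration; the registered op may differ in implementation detail):

    import torch

    def quantize_per_channel_ref(x, scales, zero_points, axis, quant_min, quant_max, dtype):
        if x.dtype == torch.bfloat16:
            x = x.to(torch.float32)
        shape = [1] * x.dim()
        shape[axis] = -1  # align per-channel parameters with the chosen axis
        q = torch.round(x / scales.reshape(shape)) + zero_points.reshape(shape)
        return torch.clamp(q, quant_min, quant_max).to(dtype)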
