81 changes: 74 additions & 7 deletions backends/cadence/aot/ref_implementations.py
@@ -24,7 +24,7 @@
 
 @impl(m, "quantize_per_tensor")
 def quantize_per_tensor(
-    input: torch.Tensor,
+    input_tensor: torch.Tensor,
     scale: float,
     zero_point: int,
     quant_min: int,
@@ -35,10 +35,10 @@ def quantize_per_tensor(
     Quantizes a floating-point tensor to an integral tensor.
 
     Args:
-    - input (Tensor): input tensor
-    - scale (float): Quantization scale. Derived from the ratio
+    - input_tensor (Tensor): input tensor
+    - scale (float): Inverse of quantization scale. Derived from the ratio
         between the min/max of the floating-point tensor and the
-        min/max of the quantized range.
+        min/max of the quantized range, and then inverted.
     - zero_point (int): The point which represents 0 in the quantized
         range. For example, consider the floating point range [-1., 2.] and
         quantized integer range [-7, 7]. In this case, 0 is 1/3 of way from
Expand All @@ -61,7 +61,12 @@ def quantize_per_tensor(
raise ValueError(
f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
)
return torch.round(input / scale + zero_point).to(dtype)

dequantized = torch.round(input_tensor * scale + zero_point).to(dtype)
return torch.max(
torch.min(dequantized, torch.tensor(quant_max)),
torch.tensor(quant_min),
)


@impl(m, "dequantize_per_tensor")
@@ -173,9 +178,16 @@ def quantized_add(
     dequant_X = X_scale * (X - X_zero_point)
     dequant_Y = Y_scale * (Y - Y_zero_point)
 
+    out_scale_inv = 1 / out_scale
+
     # q_min/q_max are unused args
     return quantize_per_tensor(
-        dequant_X + dequant_Y, out_scale, out_zero_point, -128, 127, dtype
+        dequant_X + dequant_Y,
+        out_scale_inv,
+        out_zero_point,
+        torch.iinfo(dtype).min,
+        torch.iinfo(dtype).max,
+        dtype,
     )
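
For reference, a minimal sketch (with made-up scales and zero points, not values from this PR) of the dequantize, add, and requantize flow that quantized_add follows, using the inverse output scale and the dtype-derived integer bounds:

import torch

X = torch.tensor([10], dtype=torch.int8)   # dequantizes to 10 * 0.1 = 1.0
Y = torch.tensor([5], dtype=torch.int8)    # dequantizes to 5 * 0.2 = 1.0
X_scale, X_zero_point = 0.1, 0
Y_scale, Y_zero_point = 0.2, 0
out_scale, out_zero_point = 0.05, 0

dequant_X = X_scale * (X - X_zero_point)   # 1.0
dequant_Y = Y_scale * (Y - Y_zero_point)   # 1.0
out_scale_inv = 1 / out_scale              # 20.0

q = torch.round((dequant_X + dequant_Y) * out_scale_inv + out_zero_point)
q = torch.clamp(q, torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max)
print(q.to(torch.int8))                    # tensor([40], dtype=torch.int8)
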


@@ -206,6 +218,7 @@ def quantized_linear(
     - offset (Tensor): Unused
     """
     out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])
+    out_scale_inv = 1 / out_scale
 
     N, K = weight.shape
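
A quick worked example of the scale arithmetic above: out_scale is out_multiplier scaled down by 2**31 and shifted by out_shift[0], and out_scale_inv is simply its reciprocal, matching quantize_per_tensor's new inverse-scale argument. The multiplier and shift values below are hypothetical (plain Python ints, chosen so the numbers are exact), not values from this PR:

out_multiplier = -(1 << 30)
out_shift = [0]

out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])  # 0.5
out_scale_inv = 1 / out_scale                                        # 2.0
assert (out_scale, out_scale_inv) == (0.5, 2.0)
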

@@ -223,10 +236,64 @@ def quantized_linear(
         src - in_zero_point, weight - weight_zero_point, bias
     )
     return quantize_per_tensor(
-        out, out_scale, out_zero_point, -128, 127, dtype
+        out,
+        out_scale_inv,
+        out_zero_point,
+        torch.iinfo(dtype).min,
+        torch.iinfo(dtype).max,
+        dtype,
     ).reshape(*leading_dims, N)
 
 
+@impl(m, "quantized_layer_norm_per_tensor")
+def quantized_layer_norm_per_tensor(
+    input_tensor: torch.Tensor,
+    X_scale: float,
+    X_zero_point: int,
+    normalized_shape: int,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    eps: float,
+    output_scale: float,
+    output_zero_point: int,
+) -> torch.Tensor:
+    """
+    Quantized layer norm operation.
+
+    Args:
+    - input_tensor (Tensor): The activations tensor
+    - X_scale (float): The scale of the input
+    - X_zero_point (int): The zero point of the input
+    - normalized_shape (int): The shape of the input
+    - weight (Tensor): The weight tensor
+    - bias (Tensor): The bias tensor
+    - eps (float): The epsilon value
+    - output_scale (float): The scale of the output
+    - output_zero_point (int): The zero point of the output
+    """
+    supported_dtypes = [torch.int8, torch.uint8]
+    if input_tensor.dtype not in supported_dtypes:
+        raise ValueError(
+            f"Input dtype must be one of {supported_dtypes}. Got {input_tensor.dtype}"
+        )
+
+    float_input_tensor = dequantize_per_tensor(
+        input_tensor, X_scale, X_zero_point, -128, 127, torch.float32
+    )
+    out = torch.nn.functional.layer_norm(
+        float_input_tensor, (normalized_shape,), weight, bias, eps=eps
+    )
+
+    return quantize_per_tensor(
+        out,
+        1 / output_scale,
+        output_zero_point,
+        torch.iinfo(input_tensor.dtype).min,
+        torch.iinfo(input_tensor.dtype).max,
+        input_tensor.dtype,
+    )
+
+
 @impl(m, "requantize")
 def requantize(
     input: torch.Tensor,
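
To make the new reference op concrete, here is a hedged usage sketch mirroring test case 1 from the new tests below. The import path is assumed from the test module rather than spelled out in this diff:

import torch
from executorch.backends.cadence.aot.ref_implementations import (
    quantized_layer_norm_per_tensor,
)

x = torch.tensor([[-1, 1]], dtype=torch.int8)  # dequantizes to [-0.1, 0.1]
out = quantized_layer_norm_per_tensor(
    x,
    0.1,                       # X_scale
    0,                         # X_zero_point
    2,                         # normalized_shape (size of the last dimension)
    torch.tensor([1.0, 1.0]),  # weight
    torch.tensor([0.0, 0.0]),  # bias
    1e-5,                      # eps
    0.1,                       # output_scale
    0,                         # output_zero_point
)
print(out)                     # expected: tensor([[-10, 10]], dtype=torch.int8)
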
100 changes: 98 additions & 2 deletions backends/cadence/aot/tests/test_ref_implementations.py
@@ -15,6 +15,7 @@
     dequantize_per_tensor,
     quantize_per_tensor,
     quantized_add,
+    quantized_layer_norm_per_tensor,
     quantized_linear,
 )
 from executorch.backends.cadence.aot.typing_stubs import expand
@@ -40,11 +41,12 @@ def test_quantize_per_tensor(
     ) -> None:
         input_tensor = torch.tensor([input_value])
         scale = (f_max - f_min) / (q_max - q_min)
-        zero_point = round(-f_min / scale) + q_min
+        inv_scale = 1.0 / scale
+        zero_point = round(-f_min * inv_scale) + q_min
         expected_output = torch.tensor([expected_value], dtype=target_dtype)
 
         output = quantize_per_tensor(
-            input_tensor, scale, zero_point, q_min, q_max, target_dtype
+            input_tensor, inv_scale, zero_point, q_min, q_max, target_dtype
         )
 
         self.assertEqual(
@@ -239,3 +241,97 @@ def test_quantized_linear(
             torch.equal(output, expected_output),
             f"Values don't match: got {output}, expected {expected_output}",
         )
+
+    @expand(
+        [
+            # Test case 1: Simple case with int8, zero mean input
+            (
+                torch.tensor(
+                    [[-1, 1]], dtype=torch.int8
+                ), # input: dequantized to [-0.1, 0.1]
+                0.1, # X_scale
+                0, # X_zero_point
+                2, # normalized_shape (last dimension)
+                torch.tensor([1.0, 1.0]), # weight
+                torch.tensor([0.0, 0.0]), # bias
+                1e-5, # eps
+                0.1, # output_scale
+                0, # output_zero_point
+                torch.int8, # dtype
+                torch.tensor([[-10, 10]], dtype=torch.int8), # expected_output
+            ),
+            # Test case 2: uint8 with zero_point offset
+            (
+                torch.tensor(
+                    [[127, 129]], dtype=torch.uint8
+                ), # input: dequantized to [-0.05, 0.05]
+                0.05, # X_scale
+                128, # X_zero_point
+                2, # normalized_shape (last dimension)
+                torch.tensor([1.0, 1.0]), # weight
+                torch.tensor([0.0, 0.0]), # bias
+                1e-5, # eps
+                0.05, # output_scale
+                128, # output_zero_point
+                torch.uint8, # dtype
+                torch.tensor([[108, 148]], dtype=torch.uint8), # expected_output
+            ),
+            # Test case 3: Test with weight and bias scaling
+            (
+                torch.tensor(
+                    [[-2, 2]], dtype=torch.int8
+                ), # input: dequantized to [-0.2, 0.2]
+                0.1, # X_scale
+                0, # X_zero_point
+                2, # normalized_shape (last dimension)
+                torch.tensor(
+                    [2.0, 0.5]
+                ), # weight: scale first element by 2, second by 0.5
+                torch.tensor(
+                    [0.1, -0.1]
+                ), # bias: add 0.1 to first, subtract 0.1 from second
+                1e-5, # eps
+                0.1, # output_scale
+                0, # output_zero_point
+                torch.int8, # dtype
+                torch.tensor([[-19, 4]], dtype=torch.int8), # expected_output
+            ),
+        ]
+    )
+    def test_quantized_layer_norm_per_tensor(
+        self,
+        input_tensor: torch.Tensor,
+        X_scale: float,
+        X_zero_point: int,
+        normalized_shape: int,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        eps: float,
+        output_scale: float,
+        output_zero_point: int,
+        dtype: torch.dtype,
+        expected_output: torch.Tensor,
+    ) -> None:
+        output = quantized_layer_norm_per_tensor(
+            input_tensor,
+            X_scale,
+            X_zero_point,
+            normalized_shape,
+            weight,
+            bias,
+            eps,
+            output_scale,
+            output_zero_point,
+        )
+
+        # Verify output properties
+        self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}")
+        self.assertEqual(
+            output.shape, input_tensor.shape, "Output shape should match input shape"
+        )
+
+        # Verify output matches expected values
+        self.assertTrue(
+            torch.equal(output, expected_output),
+            f"Output values don't match expected. Got {output}, expected {expected_output}",
+        )
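
As a sanity check on the first expected output (a standalone re-derivation, not part of the test file): the int8 input [-1, 1] with X_scale 0.1 dequantizes to [-0.1, 0.1]; layer norm over the last dimension gives roughly [-1.0, 1.0]; requantizing with output_scale 0.1 (multiply by 1/0.1 = 10, round, zero point 0) yields the expected [-10, 10].

import torch

x = 0.1 * torch.tensor([[-1.0, 1.0]])                      # dequantized input
normed = torch.nn.functional.layer_norm(x, (2,), eps=1e-5)  # ~[-0.9995, 0.9995]
q = torch.round(normed / 0.1).to(torch.int8)                # requantize, zero point 0
assert torch.equal(q, torch.tensor([[-10, 10]], dtype=torch.int8))
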