
[torchlib] Implement quantize_per_channel and dequantize_per_channel #2389

@justinchuby

Description

Implement both ops in onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py. Here is the reference implementation from PyTorch:

@impl(quantized_decomposed_lib, "quantize_per_channel", "CompositeExplicitAutograd")
def quantize_per_channel(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Affine per channel quantization for the Tensor using the same quantization
    parameters for each channel/axis to map from floating point to quantized values

    Args:
       input (torch.Tensor): original float32 or bfloat16 Tensor
       scales (torch.Tensor): a list of scale quantization parameters for
       affine quantization, one per channel
       zero_points (torch.Tensor): a list of zero_point quantization parameters for
       affine quantization, one per channel
       axis (int): channel axis of the input Tensor
       quant_min (int): minimum quantized value for output Tensor
       quant_max (int): maximum quantized value for output Tensor
       dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor

    Returns:
       Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
       are not stored in the Tensor, we are storing them in function arguments instead
    """
    if input.dtype in [torch.float16, torch.bfloat16]:
        input = input.to(torch.float32)
    assert (
        input.dtype == torch.float32
    ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    input, permute_axis_list = _permute_to_axis_zero(input, axis)

    new_shape = [1] * input.dim()
    new_shape[0] = scales.shape[0]
    scales = scales.view(new_shape)
    zero_points = zero_points.view(new_shape)

    res = torch.clamp(
        torch.round(input * (1.0 / scales)) + zero_points, quant_min, quant_max
    )
    out = res.permute(tuple(permute_axis_list))
    return out.to(dtype)


@impl(quantized_decomposed_lib, "quantize_per_channel", "Meta")
def quantize_per_channel_meta(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    if input.dtype in [torch.float16, torch.bfloat16]:
        input = input.to(torch.float32)
    assert (
        input.dtype == torch.float32
    ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    return torch.empty_like(input, dtype=dtype)


# Note: quant_min/quant_max/dtype are not used in the operator, but for now they are
# kept in the signature as metadata for the input Tensor; this might be useful for
# pattern matching in the future. We will revisit this later if we find there are no
# use cases for it.
quantized_decomposed_lib.define(
    "dequantize_per_channel(Tensor input, Tensor scales, Tensor? zero_points, int axis, "
    "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor"
)


@impl(quantized_decomposed_lib, "dequantize_per_channel", "CompositeExplicitAutograd")
def dequantize_per_channel(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: Optional[torch.Tensor],
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Affine per channel dequantization for the Tensor using the same quantization
    parameters for each channel/axis to map from quantized values to floating point values

    Args:
       input (torch.Tensor): Tensor with dtype matching the `dtype` argument
       (e.g. `torch.uint8`); it is a per channel quantized Tensor if combined with
       the quantization parameters in the arguments of this function (scales/zero_points/axis)

       scales (torch.Tensor): a list of scale quantization parameters for
       affine quantization, one per channel

       zero_points (torch.Tensor): a list of zero_point quantization parameters for
       affine quantization, one per channel

       axis (int): channel axis of the input Tensor

       quant_min (int): minimum quantized value for output Tensor (not used in computation,
       reserved for pattern matching)

       quant_max (int): maximum quantized value for output Tensor (not used in computation,
       reserved for pattern matching)

       dtype (torch.dtype): requested dtype for output Tensor (not used in computation,
       reserved for pattern matching)

       out_dtype (torch.dtype?): optional dtype for output Tensor

    Returns:
       dequantized Tensor with dtype `out_dtype` (float32 by default)
    """
    assert (
        input.dtype == dtype
    ), f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
    if out_dtype is None:
        out_dtype = torch.float32
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    input, permute_axis_list = _permute_to_axis_zero(input, axis)

    new_shape = [1] * input.dim()
    new_shape[0] = scales.shape[0]
    scales = scales.view(new_shape)
    if zero_points is not None:
        res = (input - zero_points.view(new_shape)) * scales
    else:
        res = input * scales

    res = res.to(out_dtype)

    out = res.permute(tuple(permute_axis_list))
    return out


@impl(quantized_decomposed_lib, "dequantize_per_channel", "Meta")
def dequantize_per_channel_meta(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: Optional[torch.Tensor],
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    assert (
        input.dtype == dtype
    ), f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
    if out_dtype is None:
        out_dtype = torch.float32
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    return torch.empty_like(input, dtype=out_dtype)

The reference above is from https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py.
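
For a quick sanity check, the decomposed ops can be exercised directly. A minimal sketch, assuming a PyTorch build where importing torch.ao.quantization.fx._decomposed registers the quantized_decomposed library:

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  -- registers the ops quoted above

x = torch.randn(2, 3)
scales = torch.tensor([0.1, 0.2])        # one scale per channel along axis 0
zero_points = torch.tensor([0, 0])       # one zero point per channel along axis 0

q = torch.ops.quantized_decomposed.quantize_per_channel(
    x, scales, zero_points, 0, -128, 127, torch.int8
)
dq = torch.ops.quantized_decomposed.dequantize_per_channel(
    q, scales, zero_points, 0, -128, 127, torch.int8
)

Round-tripping x through q and dq should reproduce x up to the per-channel quantization error.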

Here is the reference from the ONNX operator documentation:

DequantizeLinear - 23
Version
name: [DequantizeLinear (GitHub)](https://github.com/onnx/onnx/blob/main/docs/Operators.md#DequantizeLinear)

domain: main

since_version: 23

function: False

support_level: SupportType.COMMON

shape inference: True

This version of the operator has been available since version 23.

Summary
The linear dequantization operator. It consumes a quantized tensor, a scale, and a zero point to compute the full-precision tensor. The dequantization formula is y = (x - x_zero_point) * x_scale. x_scale and x_zero_point must have the same shape, determining the quantization's granularity: a scalar for per-tensor/per-layer quantization, a 1-D tensor for per-axis quantization, or have a rank identical to the input for blocked quantization. See QuantizeLinear for details on quantization granularity.

x_zero_point and x must have the same type. x and y must have the same shape. In the case of dequantizing int32, there's no zero point (zero point is supposed to be 0). Zero-point is usually not used in the case of float8 and 4-bit types quantization, but the dequantization formula remains the same for consistency. The output type is determined by the attribute output_dtype. If output_dtype is not supplied then the output type is the same as x_scale. The output type also determines the precision of the multiplication operation.

Attributes
axis - INT (default is '1'):

(Optional) The axis of the dequantizing dimension of the input tensor. Used for per-axis and blocked quantization. Negative value means counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(input).

block_size - INT (default is '0'):

(Optional) The size of the quantization block (number of times every scale is replicated). Used only for blocked quantization. The block size is a positive integer. Given x shape (D0, ..., Di, ..., Dn), y_scale shape (S0, ... Si, ...Sn) and axis=i, the accepted range is [ceil(Di/Si), ceil(Di/(Si-1))-1]

output_dtype - INT (default is '0'):

(Optional) The output data type. If not supplied, the output data type is inferred from x_scale data type (T2)

Inputs
Between 2 and 3 inputs.

x (heterogeneous) - T1:

N-D quantized input tensor to be de-quantized.

x_scale (heterogeneous) - T2:

Scale for input x. For per-tensor/layer dequantization the scale is a scalar, for per-axis dequantization it is a 1-D Tensor and for blocked dequantization it has the same shape as the input, except for one dimension in which blocking is performed.

x_zero_point (optional, heterogeneous) - T1:

Zero point for input x. Shape must match x_scale. It's optional. Zero point is 0 when it's not specified.

Outputs
y (heterogeneous) - T3:

N-D full precision output tensor. It has the same shape as input x. The data type is specified by the output_dtype attribute or, in its absence, the type of x_scale.

Type Constraints
T1 in ( tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8) ):

The type of the inputs 'x_zero_point' and 'x'.

T2 in ( tensor(bfloat16), tensor(float), tensor(float16) ):

The type of the input 'x_scale'.

T3 in ( tensor(bfloat16), tensor(float), tensor(float16) ):

The type of the output 'y'.
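
To make the per-axis formula concrete, here is a small illustration in plain PyTorch (not part of the spec): each scale/zero point pair applies to one slice along the axis and is broadcast across the remaining dimensions.

import torch

x = torch.tensor([[10, 20], [30, 40]], dtype=torch.int8)
scale = torch.tensor([0.1, 0.5])       # one scale per slice along axis 0
zero_point = torch.tensor([0, 2])      # one zero point per slice along axis 0
y = (x.to(torch.float32) - zero_point.view(2, 1)) * scale.view(2, 1)
# y == [[1.0, 2.0], [14.0, 19.0]]
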
QuantizeLinear - 23
Version
name: [QuantizeLinear (GitHub)](https://github.com/onnx/onnx/blob/main/docs/Operators.md#QuantizeLinear)

domain: main

since_version: 23

function: False

support_level: SupportType.COMMON

shape inference: True

This version of the operator has been available since version 23.

Summary
The linear quantization operator consumes a high-precision tensor, a scale, and a zero point to compute the low-precision/quantized tensor. The scale factor and zero point must have the same shape, determining the quantization granularity. The quantization formula is y = saturate((x / y_scale) + y_zero_point).

Saturation is done according to:

uint16: [0, 65535]

int16: [-32768, 32767]

uint8: [0, 255]

int8: [-128, 127]

uint4: [0, 15]

int4: [-8, 7]

For (x / y_scale), it rounds to the nearest even. Refer to https://en.wikipedia.org/wiki/Rounding for details.

y_zero_point and y must have the same type. y_zero_point is usually not used for quantization to float8 and 4bit types, but the quantization formula remains the same for consistency, and the type of the attribute y_zero_point still determines the quantization type. x and y_scale are allowed to have different types. The type of y_scale determines the precision of the division operation between x and y_scale, unless the precision attribute is specified.

There are three supported quantization granularities, determined by the shape of y_scale. In all cases, y_zero_point must have the same shape as y_scale.

Per-tensor (per-layer) quantization: y_scale is a scalar.

Per-axis quantization: The scale must be a 1-D tensor, with the length of the quantization axis. For an input shape (D0, ..., Di, ..., Dn) and axis=i, y_scale is a 1-D tensor of length Di.

Blocked quantization: The scale’s shape is identical to the input’s shape, except for one dimension, in which blocking is performed. Given x shape (D0, ..., Di, ..., Dn), axis=i, and block size B: y_scale shape is (D0, ..., ceil(Di/B), ..., Dn).
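
As an illustration of the formula and its round-half-to-even behavior (plain PyTorch, not part of the spec; torch.round also rounds half to even):

import torch

x = torch.tensor([0.5, 1.5, 2.5, 300.0])
y_scale, y_zero_point = 1.0, 0
y = torch.clamp(torch.round(x / y_scale) + y_zero_point, -128, 127).to(torch.int8)
# Round half to even: 0.5 -> 0, 1.5 -> 2, 2.5 -> 2; 300.0 saturates to 127 for int8.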

Attributes
axis - INT (default is '1'):

(Optional) The axis of the quantizing dimension of the input tensor. Used only for per-axis and blocked quantization. Negative value means counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(input). When the rank of the input is 1, per-tensor quantization is applied, rendering the axis unnecessary in this scenario.

block_size - INT (default is '0'):

(Optional) The size of the quantization block (number of times every scale is replicated). Used only for blocked quantization. The block size is a positive integer. Given x shape (D0, ..., Di, ..., Dn), y_scale shape (S0, ... Si, ...Sn) and axis=i, the accepted range is [ceil(Di/Si), ceil(Di/(Si-1))-1]

output_dtype - INT (default is '0'):

(Optional) The output data type. If not supplied, the output data type is inferred from y_zero_point data type (T3). If neither output_dtype nor y_zero_point are supplied, output data type is uint8. If both output_dtype and y_zero_point are specified, output_dtype must be T3.

precision - INT (default is '0'):

(Optional) The precision of the division operation between x and y_scale. If not provided, it will be the same as the type of y_scale.

saturate - INT (default is '1'):

The parameter defines how the conversion behaves if an input value is out of range of the destination type. It only applies for float 8 quantization (float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz). It is true by default. All cases are fully described in two tables inserted in the operator description.

Inputs
Between 2 and 3 inputs.

x (heterogeneous) - T1:

N-D full precision input tensor to be quantized.

y_scale (heterogeneous) - T2:

Scale for doing quantization to get y. For per-tensor/layer quantization the scale is a scalar, for per-axis quantization it is a 1-D Tensor and for blocked quantization it has the same shape as the input, except for one dimension in which blocking is performed.

y_zero_point (optional, heterogeneous) - T3:

Zero point for doing quantization to get y. Shape must match y_scale. Default is uint8 with zero point of 0 if it's not specified.

Outputs
y (heterogeneous) - T3:

N-D quantized output tensor. It has the same shape as the input x.

Type Constraints
T1 in ( tensor(bfloat16), tensor(float), tensor(float16), tensor(int32) ):

The type of the input ‘x’.

T2 in ( tensor(bfloat16), tensor(float), tensor(float16), tensor(int32) ):

The type of the input ‘y_scale’.

T3 in ( tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8) ):

The type of the input y_zero_point and the output y.
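
Putting the two references together, here is a rough sketch of what the torchlib implementations might look like. This is a sketch under stated assumptions, not a tested implementation: it assumes the torch_op registration helper and opset binding already used elsewhere in quantized_decomposed.py, and _torch_dtype_to_onnx is a hypothetical helper mapping torch dtypes (e.g. torch.int8) to the matching ONNX tensor types.

from onnxscript import opset18 as op  # assumes an opset with per-axis (de)quantization (13+)
from onnxscript.function_libs.torch_lib.registration import torch_op


@torch_op("quantized_decomposed::quantize_per_channel", trace_only=True)
def quantized_decomposed_quantize_per_channel(
    input, scales, zero_points, axis: int, quant_min: int, quant_max: int, dtype
):
    # QuantizeLinear saturates to the range of the zero-point type, so
    # quant_min/quant_max are carried only as metadata, mirroring the ATen op.
    zero_points = op.Cast(zero_points, to=_torch_dtype_to_onnx(dtype))  # hypothetical helper
    return op.QuantizeLinear(input, scales, zero_points, axis=axis)


@torch_op("quantized_decomposed::dequantize_per_channel", trace_only=True)
def quantized_decomposed_dequantize_per_channel(
    input, scales, zero_points, axis: int, quant_min: int, quant_max: int,
    dtype, out_dtype=None
):
    # DequantizeLinear treats a missing zero point as 0, which matches the
    # Optional zero_points semantics of the ATen op.
    output = op.DequantizeLinear(input, scales, zero_points, axis=axis)
    if out_dtype is not None:
        output = op.Cast(output, to=_torch_dtype_to_onnx(out_dtype))  # hypothetical helper
    return output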

Metadata

Labels: module: torchlib (Related to the torch/aten function lib in development)