[torchlib] Implement quantize_per_channel and dequantize_per_channel #2390

Open · wants to merge 3 commits into main · Changes from 2 commits
60 changes: 60 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py
@@ -14,7 +14,9 @@
from onnxscript.function_libs.torch_lib.ops import common
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_opset import opset23 as op23
from onnxscript.onnx_types import TensorType
from typing import Optional


@torch_op(
@@ -61,3 +63,61 @@
return dequantized
assert out_dtype > 0, f"out_dtype must be -1 or > 0 not {out_dtype}"
return op.Cast(dequantized, to=out_dtype)


@torch_op(
    (
        "quantized_decomposed::quantize_per_channel",
        "quantized_decomposed::quantize_per_channel.tensor",
        "quantized_decomposed::quantize_per_channel.tensor2",
    ),
    trace_only=True,
)
def quantized_decomposed_quantize_per_channel(
    input: TensorType,
    scales: TensorType,
    zero_points: TensorType,
    axis: int,
    quant_min: int,
    quant_max: int,
Comment on lines +82 to +83 (quant_min, quant_max):

Collaborator: Are these unused? Why is that?

Contributor Author: These parameters are unused because the ONNX QuantizeLinear/DequantizeLinear operators don't require explicit quant_min/quant_max parameters: they determine the quantization range from the data type and quantization parameters. The parameters are kept in the function signature for API compatibility with PyTorch's reference implementation, following the same pattern as the existing per-tensor functions above.

    dtype: int,
) -> TensorType:
    """Affine per channel quantization for the Tensor using the same quantization
    parameters for each channel/axis to map from floating point to quantized values.

    Uses ONNX QuantizeLinear with per-axis quantization support.
    """
    # Use opset23 for per-axis quantization support
    return op23.QuantizeLinear(input, scales, zero_points, axis=axis, output_dtype=dtype)

Codecov / codecov/patch check warning on line 91 in onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py: added line #L91 was not covered by tests.
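Not part of this PR, but for readers unfamiliar with per-axis QuantizeLinear: the following minimal NumPy sketch (the helper name per_channel_quantize_ref is illustrative only) approximates what the operator computes for int8 output. It also illustrates the point from the review thread above: the saturation bounds follow from the target dtype, so quant_min/quant_max do not need to be forwarded.

import numpy as np

def per_channel_quantize_ref(x, scales, zero_points, axis):
    # Broadcast the 1-D scales/zero_points along `axis`.
    shape = [1] * x.ndim
    shape[axis] = -1
    s = scales.reshape(shape)
    zp = zero_points.reshape(shape)
    # QuantizeLinear semantics: saturate(round(x / scale) + zero_point),
    # with round-half-to-even; the int8 bounds are implied by the output dtype.
    q = np.rint(x / s) + zp
    return np.clip(q, -128, 127).astype(np.int8)

x = np.random.randn(2, 3, 4).astype(np.float32)
scales = np.array([0.1, 0.2, 0.3], dtype=np.float32)
zero_points = np.zeros(3, dtype=np.int8)
q = per_channel_quantize_ref(x, scales, zero_points, axis=1)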


@torch_op(
    (
        "quantized_decomposed::dequantize_per_channel",
        "quantized_decomposed::dequantize_per_channel.tensor",
        "quantized_decomposed::dequantize_per_channel.tensor2",
    ),
    trace_only=True,
)
def quantized_decomposed_dequantize_per_channel(
    input: TensorType,
    scales: TensorType,
    zero_points: Optional[TensorType],
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: int,
    out_dtype: int = -1,
) -> TensorType:
"""Affine per channel dequantization for the Tensor using the same quantization
parameters for each channel/axis to map from quantized values to floating point values.

Uses ONNX DequantizeLinear with per-axis quantization support.
"""
# Use opset23 for per-axis quantization support with optional output_dtype
if out_dtype in (-1, None):
# Use default output type (same as scales type)
return op23.DequantizeLinear(input, scales, zero_points, axis=axis)

Codecov / codecov/patch check warning on line 120 in onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py: added line #L120 was not covered by tests.
    else:
        assert out_dtype > 0, f"out_dtype must be -1 or > 0 not {out_dtype}"
        return op23.DequantizeLinear(input, scales, zero_points, axis=axis, output_dtype=out_dtype)

Codecov / codecov/patch check warning on lines 122-123 in onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py: added lines #L122-L123 were not covered by tests.
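Not part of this PR: a matching NumPy sketch of per-axis DequantizeLinear (the helper name per_channel_dequantize_ref is illustrative only). Each slice along `axis` is dequantized with its own scale and zero point; out_dtype, as forwarded by the function above, only selects the floating-point type of the result.

import numpy as np

def per_channel_dequantize_ref(q, scales, zero_points, axis, out_dtype=np.float32):
    # Broadcast the 1-D scales/zero_points along `axis`.
    shape = [1] * q.ndim
    shape[axis] = -1
    s = scales.reshape(shape).astype(np.float32)
    zp = zero_points.reshape(shape).astype(np.float32)
    # DequantizeLinear semantics: (q - zero_point) * scale per channel.
    return ((q.astype(np.float32) - zp) * s).astype(out_dtype)

q = np.array([[-128, 0, 127], [1, 2, 3]], dtype=np.int8)
scales = np.array([0.1, 0.2, 0.3], dtype=np.float32)
zero_points = np.zeros(3, dtype=np.int8)
x = per_channel_dequantize_ref(q, scales, zero_points, axis=1)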