-
Notifications
You must be signed in to change notification settings - Fork 72
[torchlib] Implement quantize_per_channel and dequantize_per_channel #2390
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,7 +14,9 @@ | |
from onnxscript.function_libs.torch_lib.ops import common | ||
from onnxscript.function_libs.torch_lib.registration import torch_op | ||
from onnxscript.onnx_opset import opset18 as op | ||
from onnxscript.onnx_opset import opset23 as op23 | ||
from onnxscript.onnx_types import TensorType | ||
from typing import Optional | ||
|
||
|
||
@torch_op( | ||
|
@@ -61,3 +63,61 @@ | |
return dequantized | ||
assert out_dtype > 0, f"out_dtype must be -1 or > 0 not {out_dtype}" | ||
return op.Cast(dequantized, to=out_dtype) | ||
|
||
|
||
@torch_op(
    (
        "quantized_decomposed::quantize_per_channel",
        "quantized_decomposed::quantize_per_channel.tensor",
        "quantized_decomposed::quantize_per_channel.tensor2",
    ),
    trace_only=True,
)
def quantized_decomposed_quantize_per_channel(
    input: TensorType,
    scales: TensorType,
    zero_points: TensorType,
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: int,
) -> TensorType:
    """Affine per-channel quantization: map ``input`` from floating point to
    quantized values using one (scale, zero_point) pair per slice along ``axis``.

    ``quant_min`` and ``quant_max`` are accepted only for signature parity with
    the PyTorch reference op; ONNX QuantizeLinear derives the representable
    range from the target ``dtype``, so they are not consumed here.
    """
    # QuantizeLinear's per-axis quantization plus the output_dtype attribute
    # require opset 23, hence op23 rather than the module-default opset.
    quantized = op23.QuantizeLinear(
        input, scales, zero_points, axis=axis, output_dtype=dtype
    )
    return quantized
|
||
|
||
@torch_op(
    (
        "quantized_decomposed::dequantize_per_channel",
        "quantized_decomposed::dequantize_per_channel.tensor",
        "quantized_decomposed::dequantize_per_channel.tensor2",
    ),
    trace_only=True,
)
def quantized_decomposed_dequantize_per_channel(
    input: TensorType,
    scales: TensorType,
    zero_points: Optional[TensorType],
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: int,
    out_dtype: int = -1,
) -> TensorType:
    """Affine per-channel dequantization: map quantized ``input`` back to
    floating point using one (scale, zero_point) pair per slice along ``axis``.

    ``quant_min``, ``quant_max``, and ``dtype`` are accepted only for signature
    parity with the PyTorch reference op; DequantizeLinear does not need them.
    ``out_dtype`` of -1 (or None) means "use the default output type".
    """
    # DequantizeLinear's per-axis support and output_dtype attribute require
    # opset 23, hence op23 rather than the module-default opset.
    if out_dtype not in (-1, None):
        assert out_dtype > 0, f"out_dtype must be -1 or > 0 not {out_dtype}"
        return op23.DequantizeLinear(
            input, scales, zero_points, axis=axis, output_dtype=out_dtype
        )
    # Default path: the output type follows the scales tensor's element type.
    return op23.DequantizeLinear(input, scales, zero_points, axis=axis)
Uh oh!
There was an error while loading. Please reload this page.