Error in _adjust_weight_scale_for_int32_bias when quantizing mnist-12.onnx #24815

Open
@kuroihacker

Description

Describe the issue
Quantization fails with an error raised in the ONNX Runtime quantization code, specifically in _adjust_weight_scale_for_int32_bias. The traceback ends at:

File "/home/.pyenv/versions/dev/lib/python3.10/site-packages/onnxruntime/quantization/qdq_quantizer.py", line 478, in _adjust_weight_scale_for_int32_bias
    if (bias_candidate_scale < bias_smallest_valid_scale) and (bias_candidate_scale > 0.0):

This occurs when quantizing the MNIST model from the ONNX Model Zoo:
https://github.com/onnx/models/blob/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/mnist/model/mnist-12.onnx

Note that the model must first be preprocessed with onnxsim.simplify: simplification introduces a Gemm op, and that op triggers the error. Enabling the quantization config option QDQDisableWeightAdjustForInt32Bias makes the error disappear. This suggests a problem in how the weight scale adjustment handles a bias with a singleton dimension.
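
The singleton-dimension hypothesis can be illustrated with plain NumPy. The sketch below is not the ONNX Runtime implementation: `scale_check` is a hypothetical stand-in for a per-channel comparison like the one shown in the traceback, the threshold formula is a simplified assumption, and the bias shapes (10,) and (1, 10) are chosen only for illustration. It shows that indexing a 2-D bias along its first axis returns a whole row, so the comparison yields an array and the `if` statement raises a ValueError.

```python
import numpy as np

# Assumed int32 range, used only to build an illustrative threshold.
INT32_RANGE = float(np.iinfo(np.int32).max) - float(np.iinfo(np.int32).min)


def scale_check(bias: np.ndarray, candidate_scale: float, channel: int = 0) -> str:
    """Hypothetical per-channel check, NOT the ONNX Runtime code.

    Indexes the bias by channel and compares a candidate scale against
    a threshold derived from that bias element.
    """
    smallest_valid_scale = 2.0 * np.abs(bias[channel]) / INT32_RANGE
    try:
        # With a 1-D bias, bias[channel] is a scalar and this is a plain bool.
        # With a (1, N) bias, bias[channel] is a length-N row, so the
        # comparison is an array and `and` cannot reduce it to one bool.
        if (candidate_scale < smallest_valid_scale) and (candidate_scale > 0.0):
            return "adjust"
        return "keep"
    except ValueError as exc:
        return f"ValueError: {exc}"


flat_bias = np.full(10, 0.5, dtype=np.float32)   # shape (10,)
row_bias = flat_bias.reshape(1, 10)              # shape (1, 10)

print(scale_check(flat_bias, 1e-12))
print(scale_check(row_bias, 1e-12))
```

If the quantizer indexes the bias per channel in a similar way, a bias of shape (1, N) would fail at exactly this kind of comparison. That would be consistent with the report, although the issue does not state the exception type.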

To reproduce

Model:
https://github.com/onnx/models/blob/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/mnist/model/mnist-12.onnx

Source code for reproduction:

import numpy as np
import onnxruntime
import typing
from onnxruntime.quantization import quantize, StaticQuantConfig, QuantType, CalibrationMethod
import onnx
import onnxsim

class RandomDataReader(onnxruntime.quantization.calibrate.CalibrationDataReader):
    """Provides random data for calibration."""

    def __init__(
        self,
        input_value_name: str,
        input_shape: typing.Tuple[int, ...],
        num_samples: int = 100,
        input_mean: float = 0.0,
        input_std: float = 1.0,
    ):
        super().__init__()
        self._input_value_name = input_value_name
        self._input_shape = input_shape
        self._num_samples = num_samples
        self._input_mean = input_mean
        self._input_std = input_std
        self._current_index = 0

    def __len__(self):
        return self._num_samples

    def get_next(self):
        if self._current_index >= self._num_samples:
            return None

        # Generate random data between 0 and 255 (like MNIST pixel values)
        random_data = np.random.randint(0, 256, self._input_shape, dtype=np.uint8)
        random_data = np.expand_dims(random_data, (0, 1))

        # Normalized float image
        float_data = random_data.astype(np.float32) / 255.0
        float_data = (float_data - self._input_mean) / self._input_std

        result = {self._input_value_name: float_data}
        self._current_index += 1
        return result

    def rewind(self):
        self._current_index = 0

random_data_reader = RandomDataReader(
    input_value_name="Input3",
    input_shape=(28, 28),
    num_samples=1,
    input_mean=0.0,
    input_std=1.0,
)

quant_config = StaticQuantConfig(
    random_data_reader,
    quant_format=onnxruntime.quantization.QuantFormat.QDQ,
    activation_type=QuantType.QUInt8,
    per_channel=True,
    calibrate_method=CalibrationMethod.MinMax,
    op_types_to_quantize=["Conv", "Gemm", "Relu", "MaxPool"],
    # Workaround: pass this instead of None to skip the weight scale adjustment.
    # extra_options={"QDQDisableWeightAdjustForInt32Bias": True},
    extra_options=None,
)

orig_model = onnx.load("mnist-12.onnx")
model, simplify_result = onnxsim.simplify(orig_model)
if not simplify_result:
    raise AssertionError("Failed to simplify with onnxsim")

onnx.save(model, "/tmp/mnist-12-preprocessed.onnx")

quantize("/tmp/mnist-12-preprocessed.onnx", "mnist-12q.onnx", quant_config)

System information

  • Platform: Linux
  • OS Version: Ubuntu 22.04
  • ONNX Runtime Installation: Released Package
  • ONNX Runtime Version or Commit ID: 1.20.1
  • ONNX Runtime API: Python
  • Architecture: X64
  • Execution Provider: Default CPU
  • Execution Provider Library Version: N/A

Labels: quantization (issues related to quantization)
