Commit

Fix missing argument when calling _get_quantize_input_nodes (microsoft#20245)

### Description
The current code calls a method with a missing argument: the dynamic-quantization helpers require an `initial_type` parameter, but their caller never passed one.



### Motivation and Context
The missing argument breaks Olive's unit tests.
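For context, the failure mode is the usual Python one: the dynamic-quantization helpers (see `_get_dynamic_input_quantization_params_int8` in the diff below) already take a required `initial_type` parameter, but the dispatcher still called them without it. A minimal sketch of that bug class, with hypothetical names rather than the real quantizer API:

# Minimal sketch (hypothetical names): a callee grew a required
# positional parameter, but one caller was never updated.
def helper(input_name, nodes_list, initial_type):
    return (input_name, initial_type)

def dispatcher(input_name, nodes_list):
    # Before the fix, this style of call raises:
    # TypeError: helper() missing 1 required positional argument: 'initial_type'
    return helper(input_name, nodes_list)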

---------

Co-authored-by: Xavier Dupré <xavier.dupre@gmail.com>
xadupre and sdpython committed Apr 24, 2024
1 parent a5182a2 commit 218b6b0
Showing 3 changed files with 77 additions and 9 deletions.
38 changes: 30 additions & 8 deletions onnxruntime/python/tools/quantization/onnx_quantizer.py
@@ -306,20 +306,19 @@ def is_float_tensor(self, tensor_name):
         )
         return False
 
-    def _get_dynamic_input_quantization_params(self, input_name, nodes_list, qType):
+    def _get_dynamic_input_quantization_params(self, input_name, nodes_list, qType, initial_type):
         """
         Create nodes for dynamic quantization of input and add them to nodes_list.
         parameter input_name: Name of the input.
         parameter nodes_list: new nodes are appended to this list.
         parameter qType: type to quantize to.
+        parameter initial_type: type to quantize from
         return: scale_name, zero_point_name, scale_shape, zero_point_shape.
         """
         if qType == onnx_proto.TensorProto.INT8:
-            return self._get_dynamic_input_quantization_params_int8(input_name, nodes_list)
+            return self._get_dynamic_input_quantization_params_int8(input_name, nodes_list, initial_type)
         if qType == onnx_proto.TensorProto.UINT8:
-            return self._get_dynamic_input_quantization_params_uint8(input_name, nodes_list)
-        if qType == onnx_proto.TensorProto.FLOAT8E4M3FN:
-            return self._get_dynamic_input_quantization_params_float8e4m3fn(input_name, nodes_list)
+            return self._get_dynamic_input_quantization_params_uint8(input_name, nodes_list, initial_type)
         raise ValueError(f"Unexpected value for qType={qType}.")
 
     def _get_dynamic_input_quantization_params_int8(self, input_name, nodes_list, initial_type):
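As background on what these `_get_dynamic_input_quantization_params_*` helpers encode as graph nodes: dynamic quantization computes the scale and zero point at runtime from the tensor's min/max. A NumPy sketch of the uint8 case, following the ONNX DynamicQuantizeLinear convention (illustrative only; the real helpers emit ONNX nodes, not NumPy):

import numpy as np

def dynamic_uint8_params(x):
    # The range is widened to include zero so that zero stays exactly representable.
    rmin = min(float(x.min()), 0.0)
    rmax = max(float(x.max()), 0.0)
    scale = (rmax - rmin) / 255.0 if rmax > rmin else 1.0
    zero_point = int(np.clip(round(-rmin / scale), 0, 255))
    return np.float32(scale), np.uint8(zero_point)

scale, zp = dynamic_uint8_params(np.random.randn(16).astype(np.float32))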
@@ -559,7 +558,9 @@ def _get_quantization_params(self, param_name, use_scale=None, use_zeropoint=Non
 
         return True, scale_name, zero_point_name, scale_shape, zero_point_shape
 
-    def _get_quantize_input_nodes(self, node, input_index, qType, given_scale_name=None, given_zp_name=None):
+    def _get_quantize_input_nodes(
+        self, node, input_index, qType, given_scale_name=None, given_zp_name=None, initial_type=None
+    ):
         """
         Given an input for a node (which is not a initializer), this function
@@ -571,6 +572,7 @@ def _get_quantize_input_nodes(self, node, input_index, qType, given_scale_name=N
         :param qType: type to quantize to.
         :param given_scale_name: if those inputs need to be quanitzed using this scale tensor.
         :param given_zp_name: if those inputs to be quantized using this zeropoint tensor.
+        :param initial_type: type of the weight to quantize
         :return: List of newly created nodes in NodeProto format.
         """
         input_name = node.input[input_index]
@@ -606,12 +608,16 @@ def _get_quantize_input_nodes(self, node, input_index, qType, given_scale_name=N
                     ql_node_name,
                 )
             else:
+                assert initial_type is not None, (
+                    f"Cannot quantize input without knowing the initial type, "
+                    f"input_name={input_name!r}, input_index={input_index}, qType={qType}, node={node}"
+                )
                 (
                     scale_name,
                     zp_name,
                     scale_shape,
                     zp_shape,
-                ) = self._get_dynamic_input_quantization_params(input_name, nodes, qType)
+                ) = self._get_dynamic_input_quantization_params(input_name, nodes, qType, initial_type=initial_type)
                 qlinear_node = onnx.helper.make_node(
                     "QuantizeLinear",
                     [input_name, scale_name, zp_name],
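The node assembled here is an ordinary QuantizeLinear whose second and third inputs are the freshly computed scale and zero point. A self-contained sketch of that construction, with illustrative tensor names (the quantizer derives them from the input tensor's name):

from onnx import helper

# Illustrative names, not the quantizer's actual naming scheme.
qlinear_node = helper.make_node(
    "QuantizeLinear",
    ["input_0", "input_0_scale", "input_0_zero_point"],  # x, y_scale, y_zero_point
    ["input_0_quantized"],
    "input_0_QuantizeLinear",  # node name
)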
@@ -794,7 +800,23 @@ def __quantize_inputs(
                     node_input + "_QuantizeLinear", self.new_nodes, self.model.graph()
                 )
                 if qlinear_node is None:
-                    quantize_input_nodes = self._get_quantize_input_nodes(node, input_index, self.activation_qType)
+                    input_name = node.input[input_index]
+                    if input_name in self.value_infos:
+                        value_info = self.value_infos[input_name]
+                        assert value_info.HasField("type"), f"value_info={value_info} has no type."
+                        assert value_info.type.HasField("tensor_type"), f"value_info={value_info} is not a tensor."
+                        initial_type = value_info.type.tensor_type.elem_type
+                    else:
+                        # Shape inference failed. Fallback to self.tensor_names.
+                        assert input_name in self.tensor_names, (
+                            f"shape inference failed for {input_name!r} and "
+                            f"attribute 'tensor_names' does not have any value for "
+                            f"this tensor."
+                        )
+                        initial_type = self.tensor_names[input_name]
+                    quantize_input_nodes = self._get_quantize_input_nodes(
+                        node, input_index, self.activation_qType, initial_type=initial_type
+                    )
                     if quantize_input_nodes is None:
                         return (None, None, None, None)
                     if from_subgraph:
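The fallback above recovers the element type from the graph's value_info entries, which ONNX shape inference populates. A minimal standalone illustration of that lookup, on a hypothetical toy model (the quantizer builds its `value_infos` map similarly from the loaded model):

import onnx
from onnx import helper, shape_inference

model = helper.make_model(
    helper.make_graph(
        [helper.make_node("Relu", ["x"], ["y"])],
        "g",
        [helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [1, 4])],
        [helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [1, 4])],
    )
)
inferred = shape_inference.infer_shapes(model)
# Collect value_info for intermediate tensors plus the declared graph outputs.
value_infos = {vi.name: vi for vi in inferred.graph.value_info}
value_infos.update({o.name: o for o in inferred.graph.output})
vi = value_infos["y"]
assert vi.type.HasField("tensor_type"), f"{vi} is not a tensor."
elem_type = vi.type.tensor_type.elem_type  # e.g. onnx.TensorProto.FLOAT == 1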
1 change: 1 addition & 0 deletions onnxruntime/python/tools/quantization/operators/pad.py
@@ -68,6 +68,7 @@ def quantize(self):
                 self.quantizer.activation_qType,
                 quantized_input_value.scale_name,
                 quantized_input_value.zp_name,
+                initial_type=scale_tensor.data_type,
             )
             self.quantizer.new_nodes.extend(pad_value_qnodes)
             node.input[2] = pad_value_qnodes[0].output[0]
47 changes: 46 additions & 1 deletion onnxruntime/test/python/quantization/test_op_gemm.py
@@ -784,7 +784,52 @@ def test_qgemm_ref_uint8_specific_example(self):
         got = ref.run(None, feeds)[0]
         assert_allclose(expected, got)
 
+    def test_dynamic_quantization(self):
+        # dummy_model.onnx from Olive
+        model = helper.make_model(
+            helper.make_graph(
+                [
+                    helper.make_node(
+                        "Gemm", ["input", "fc1.weight", "fc1.bias"], ["gemm0"], alpha=1.0, beta=1.0, transB=1
+                    ),
+                    helper.make_node("Relu", ["gemm0"], ["output"]),
+                ],
+                "g",
+                [helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 1])],
+                [helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [1, 10])],
+                [
+                    onnx.numpy_helper.from_array(np.random.randn(10, 1).astype(np.float32), name="fc1.weight"),
+                    onnx.numpy_helper.from_array(np.random.randn(10).astype(np.float32), name="fc1.bias"),
+                ],
+            ),
+            opset_imports=[helper.make_opsetid("", 18)],
+            ir_version=9,
+        )
+        onnx.checker.check_model(model)
+        run_config = {
+            "weight_type": QuantType.QInt8,
+            "op_types_to_quantize": None,
+            "nodes_to_quantize": None,
+            "nodes_to_exclude": None,
+            "per_channel": False,
+            "reduce_range": False,
+            "extra_options": {
+                "extra.Sigmoid.nnapi": False,
+                "ActivationSymmetric": False,
+                "WeightSymmetric": True,
+                "EnableSubgraph": False,
+                "ForceQuantizeNoInputCheck": False,
+                "MatMulConstBOnly": True,
+            },
+        }
+        model_path = "test_dynamic_quantization.onnx"
+        with open(model_path, "wb") as f:
+            f.write(model.SerializeToString())
+        qpath = "test_dynamic_quantization.quantized.onnx"
+        quantize_dynamic(model_input=model_path, model_output=qpath, use_external_data_format=True, **run_config)
+        onx = onnx.load(qpath)
+        self.assertIn("DynamicQuantizeLinear", set(n.op_type for n in onx.graph.node))
 
 if __name__ == "__main__":
     TestOpGemm().test_quantize_gemm_e4m3fn_p3()
     unittest.main(verbosity=2)
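To exercise just the new test locally, something like `python -m pytest onnxruntime/test/python/quantization/test_op_gemm.py -k test_dynamic_quantization` should work from a source checkout (assuming pytest and an installed onnxruntime build are available).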
