Describe the issue
Summary
Conv combined with QuantizeLinear/DequantizeLinear produces incorrect output on CPUExecutionProvider.
The problem only occurs when the execution provider is CPU and the graph optimization level is ENABLE_EXTENDED or higher.
My hypothesis is that ORT rewrites the model from pattern A to pattern B (QDQ fusion into QLinearConv), but something goes wrong in that transformation. A sketch for dumping and inspecting the optimized graph follows the diagrams.
A:
(x) -> QuantizeLinear -> DequantizeLinear --+
(W) -> QuantizeLinear -> DequantizeLinear --+--> Conv --> QuantizeLinear -> DequantizeLinear
(b_int32) -------------> DequantizeLinear --+
B:
(x) -> QuantizeLinear --+
(W) -> QuantizeLinear --+----------------------> QLinearConv --> DequantizeLinear
(b_int32) --------------+
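To check which rewrite actually fires, ORT can write the post-optimization graph to disk via SessionOptions.optimized_model_filepath. A minimal sketch, assuming the model built by the repro script below has been serialized to qdq_conv.onnx (the file names here are placeholders):
import onnx
import onnxruntime as ort

sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
# Ask ORT to dump the post-optimization graph; optimization runs at session creation.
sess_options.optimized_model_filepath = "qdq_conv_optimized.onnx"
ort.InferenceSession("qdq_conv.onnx", sess_options=sess_options,
                     providers=["CPUExecutionProvider"])

# If the QDQ fusion fired, the node list should contain QLinearConv.
optimized = onnx.load("qdq_conv_optimized.onnx")
print([node.op_type for node in optimized.graph.node])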
To reproduce
Minimal Reproducible Example
Environment
- onnx==1.16.2
- onnxruntime-gpu==1.19.2
- CUDA 12.2
from collections import defaultdict

import numpy as np
import onnx
from onnx import helper, numpy_helper, TensorProto, OperatorSetIdProto
import onnxruntime as ort
import torch

# Quantization parameters: per-tensor for the activations, per-channel (3 output channels) for the weight.
input_scale = np.array(0.001, dtype=np.float32)
input_zero_point = np.array(128, dtype=np.uint8)
weight_scale = np.array([0.001, 0.001, 0.001], dtype=np.float32)
bias_scale = input_scale * weight_scale  # ONNX convention: bias scale = input_scale * weight_scale
output_scale = np.array(0.0001, dtype=np.float32)
output_zero_point = np.array(128, dtype=np.uint8)

# Random test data; the input is pre-snapped to the quantization grid.
input = (np.random.randint(-128, 128, size=(1, 3, 32, 32)) * input_scale).astype(np.float32)
weight = np.random.randn(3, 3, 3, 3).astype(np.float32)
bias_int32 = np.array([-5123, 123, 5151]).astype(np.int32)
def qdq_conv():
    """Build model A: a float Conv wrapped in QuantizeLinear/DequantizeLinear (QDQ) pairs."""
    op = OperatorSetIdProto()
    op.version = 21
    model = helper.make_model(
        graph=helper.make_graph(
            name="QDQ_Conv",
            inputs=[helper.make_tensor_value_info('input', TensorProto.FLOAT, shape=[1, 3, 32, 32])],
            outputs=[helper.make_tensor_value_info('output_qdq', TensorProto.FLOAT, shape=[1, 3, 30, 30])],
            initializer=[
                numpy_helper.from_array(weight, name='weight'),
                numpy_helper.from_array(bias_int32, name='bias_q'),
                numpy_helper.from_array(input_scale, name='input_scale'),
                numpy_helper.from_array(input_zero_point, name='input_zero_point'),
                numpy_helper.from_array(weight_scale, name='weight_scale'),
                numpy_helper.from_array(bias_scale, name='bias_scale'),
                numpy_helper.from_array(output_scale, name='output_scale'),
                numpy_helper.from_array(output_zero_point, name='output_zero_point'),
            ],
            nodes=[
                # QDQ pair around the activation (per-tensor, uint8)
                helper.make_node(
                    'QuantizeLinear',
                    inputs=['input', 'input_scale', 'input_zero_point'],
                    outputs=['input_q'],
                    output_dtype=TensorProto.UINT8
                ),
                helper.make_node(
                    'DequantizeLinear',
                    inputs=['input_q', 'input_scale', 'input_zero_point'],
                    outputs=['input_qdq'],
                ),
                # QDQ pair around the weight (per-channel on axis 0, int8)
                helper.make_node(
                    'QuantizeLinear',
                    inputs=['weight', 'weight_scale'],
                    outputs=['weight_q'],
                    axis=0,
                    output_dtype=TensorProto.INT8
                ),
                helper.make_node(
                    'DequantizeLinear',
                    inputs=['weight_q', 'weight_scale'],
                    outputs=['weight_qdq'],
                    axis=0
                ),
                # The bias is stored pre-quantized as int32, so it only needs a DequantizeLinear
                helper.make_node(
                    'DequantizeLinear',
                    inputs=['bias_q', 'bias_scale'],
                    outputs=['bias_qdq'],
                    axis=0
                ),
                helper.make_node(
                    'Conv',
                    inputs=['input_qdq', 'weight_qdq', 'bias_qdq'],
                    outputs=['output'],
                    name='conv'
                ),
                # QDQ pair on the output (per-tensor, uint8)
                helper.make_node(
                    'QuantizeLinear',
                    inputs=['output', 'output_scale', 'output_zero_point'],
                    outputs=['output_q'],
                    output_dtype=TensorProto.UINT8
                ),
                helper.make_node(
                    'DequantizeLinear',
                    inputs=['output_q', 'output_scale', 'output_zero_point'],
                    outputs=['output_qdq'],
                ),
            ]
        ),
        opset_imports=[op]
    )
    onnx.checker.check_model(model, True)
    return model
model = qdq_conv()

# Run the same model on every (provider, optimization level) combination.
outputs = defaultdict(dict)
for provider in ["CPUExecutionProvider", "CUDAExecutionProvider"]:
    for optim_level in [ort.GraphOptimizationLevel.ORT_DISABLE_ALL,
                        ort.GraphOptimizationLevel.ORT_ENABLE_ALL]:
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = optim_level
        sess = ort.InferenceSession(model.SerializeToString(),
                                    sess_options=sess_options,
                                    providers=[provider])
        output, = sess.run(None, {"input": input})
        outputs[provider][optim_level] = output
# Allow off-by-one error in the quantized output
atol = output_scale * 1

def compare(a, b):
    """Return an (in)equality marker within atol, plus mean cosine similarity."""
    marker = "==" if np.allclose(a, b, atol=atol) else "!="
    similarity = torch.nn.functional.cosine_similarity(
        torch.from_numpy(a), torch.from_numpy(b)
    ).mean()
    return marker, similarity

cpu_noopt = outputs["CPUExecutionProvider"][ort.GraphOptimizationLevel.ORT_DISABLE_ALL]
cpu_opt = outputs["CPUExecutionProvider"][ort.GraphOptimizationLevel.ORT_ENABLE_ALL]
cuda_noopt = outputs["CUDAExecutionProvider"][ort.GraphOptimizationLevel.ORT_DISABLE_ALL]
cuda_opt = outputs["CUDAExecutionProvider"][ort.GraphOptimizationLevel.ORT_ENABLE_ALL]

cpu_vs_cuda_no_opt, cpu_vs_cuda_no_opt_similarity = compare(cpu_noopt, cuda_noopt)
cpu_vs_cuda_opt, cpu_vs_cuda_opt_similarity = compare(cpu_opt, cuda_opt)
opt_vs_noopt_cpu, opt_vs_noopt_cpu_similarity = compare(cpu_opt, cpu_noopt)
opt_vs_noopt_cuda, opt_vs_noopt_cuda_similarity = compare(cuda_opt, cuda_noopt)
print(
    f"""
                  (similarity: {cpu_vs_cuda_opt_similarity:.2f})
CPUExecutionProvider <------- {cpu_vs_cuda_opt} ------> CUDAExecutionProvider
  (ORT_ENABLE_ALL)                                        (ORT_ENABLE_ALL)
         ^                                                       ^
         |                                                       |
(similarity: {opt_vs_noopt_cpu_similarity:.2f}) {opt_vs_noopt_cpu}                                 {opt_vs_noopt_cuda} (similarity: {opt_vs_noopt_cuda_similarity:.2f})
         |                                                       |
         v                                                       v
CPUExecutionProvider <------- {cpu_vs_cuda_no_opt} -------> CUDAExecutionProvider
 (ORT_DISABLE_ALL)   (similarity: {cpu_vs_cuda_no_opt_similarity:.2f})    (ORT_DISABLE_ALL)
"""
)
Output
                  (similarity: 0.80)
CPUExecutionProvider <------- != ------> CUDAExecutionProvider
  (ORT_ENABLE_ALL)                                        (ORT_ENABLE_ALL)
         ^                                                       ^
         |                                                       |
(similarity: 0.80) !=                                 == (similarity: 1.00)
         |                                                       |
         v                                                       v
CPUExecutionProvider <------- == -------> CUDAExecutionProvider
 (ORT_DISABLE_ALL)   (similarity: 1.00)    (ORT_DISABLE_ALL)
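If the QDQ fusion is indeed at fault, skipping only the QDQ-related optimization passes should make the optimized CPU output match the unoptimized one again. A sketch of that sanity check, assuming the session.disable_quant_qdq session config key behaves as documented:
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
# Keep the other optimizations but disable the QDQ fusion passes.
sess_options.add_session_config_entry("session.disable_quant_qdq", "1")
sess = ort.InferenceSession(model.SerializeToString(),
                            sess_options=sess_options,
                            providers=["CPUExecutionProvider"])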
Urgency
No response
Platform
Linux
OS Version
Ubuntu 22.04
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.19.2
ONNX Runtime API
Python
Architecture
X64
Execution Provider
Default CPU
Execution Provider Library Version
CUDA 12.2