🐛 Describe the bug
I'm trying to export the Qwen2-7B model to ONNX with a custom script that calls `torch.onnx.export`, passing `onnx_shape_inference=False`, like this:

```python
torch.onnx.export(
    model,
    inputs,
    onnx_path,
    input_names=input_names,
    output_names=output_names,
    onnx_shape_inference=False,
)
```
I did this because the `_export` function in `torch/onnx/utils.py` has an `onnx_shape_inference` parameter, and that parameter controls whether shape/type inference runs.
However, the `onnx_shape_inference` value passed to `torch.onnx.export` never reaches `_export` in `torch/onnx/utils.py`. Tracing through the code, I found that the argument is silently discarded by the `export` function defined in `torch/onnx/__init__.py`.
The `export` signature there has no `onnx_shape_inference` parameter to forward, and the call site that invokes `_export` does not pass `onnx_shape_inference` with the value supplied by the user.
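For illustration only (these functions are stand-ins, not the actual PyTorch source), this is the kind of behavior you get when a wrapper accepts a catch-all keyword dict but never forwards it to the inner function:

```python
# Toy illustration of the silent-drop behavior described above.
def _export(model, args, f, onnx_shape_inference=True):
    print(f"_export sees onnx_shape_inference={onnx_shape_inference}")

def export(model, args, f, input_names=None, output_names=None, **kwargs):
    # onnx_shape_inference lands in **kwargs and is never forwarded.
    _export(model, args, f)

export("model", (), "model.onnx", onnx_shape_inference=False)
# prints: _export sees onnx_shape_inference=True
```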
As a result, `_export` always falls back to its default `onnx_shape_inference=True`, so shape/type inference runs on every ONNX export.
However, there are already issues about catching shape/type inference exceptions so that export can continue (#147259, #140962), so being able to control whether shape/type inference runs seems desirable. With that option we could skip shape/type inference for LLMs that consistently fail during this step.
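As a stopgap, one workaround sketch is to rebind the private `_export` helper so it always receives `onnx_shape_inference=False`. This assumes the private `torch/onnx/utils.py` internals described above (the helper and its keyword can change without notice across PyTorch versions); `model`, `inputs`, etc. are the same placeholders as in the snippet at the top:

```python
import functools

import torch
from torch.onnx import utils as onnx_utils

# Workaround sketch (relies on private internals, may break between releases):
# rebind the module-level _export so internal calls default to skipping
# shape/type inference.
_original_export = onnx_utils._export
onnx_utils._export = functools.partial(_original_export, onnx_shape_inference=False)

torch.onnx.export(
    model,
    inputs,
    onnx_path,
    input_names=input_names,
    output_names=output_names,
)
```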
Versions
OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-2ubuntu1~20.04) 11.4.0
Clang version: 14.0.0
CMake version: version 3.28.0-rc3
Libc version: glibc-2.35
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] onnx==1.14.0
[pip3] onnxruntime==1.13.1
[pip3] onnxsim==0.4.35
[pip3] torch==2.5.1
[pip3] torchvision==0.20.1
[pip3] triton==3.1.0
[conda] numpy 1.26.4 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.4.5.8 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.2.1.3 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.5.147 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.6.1.9 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.3.1.170 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.21.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.4.127 pypi_0 pypi
[conda] torch 2.5.1 pypi_0 pypi
[conda] torchvision 0.20.1 pypi_0 pypi
[conda] triton 3.1.0 pypi_0 pypi
cc @justinchuby