Parameter "onnx_shape_inference" can't be successfully passed to the "_export" interface in "torch/onnx/utils.py" via the "torch.onnx.export" interface #156480

@Taot-chen

Description

🐛 Describe the bug

I'm trying to export the qwen2_7B model to ONNX with a custom script using the torch.onnx.export interface, setting onnx_shape_inference=False, like this:

torch.onnx.export(
    model,
    inputs,
    onnx_path,
    input_names=input_names,
    output_names=output_names,
    onnx_shape_inference=False
)

I set it because the _export interface in torch/onnx/utils.py has an onnx_shape_inference parameter, and that parameter controls whether shape type inference runs:

[Screenshot: the _export signature in torch/onnx/utils.py, which includes onnx_shape_inference=True]

[Screenshot: the code in _export that gates shape type inference on onnx_shape_inference]
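
For reference, the relevant part of torch/onnx/utils.py in torch 2.5.1 looks roughly like the sketch below. This is a paraphrase, heavily abbreviated and not a verbatim copy; GLOBALS here is the module-level state object from torch/onnx/_globals.py:

def _export(
    model,
    args,
    f,
    export_params=True,
    # ... many other keyword parameters elided ...
    onnx_shape_inference=True,  # the parameter in question, defaulting to True
    export_modules_as_functions=False,
    autograd_inlining=True,
):
    ...
    # The flag is stored in module-level state; graph construction later
    # runs ONNX shape type inference only while it is set.
    GLOBALS.onnx_shape_inference = onnx_shape_inference
    ...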

But the onnx_shape_inference value set through the torch.onnx.export interface can't be successfully passed down to the _export interface in torch/onnx/utils.py. Stepping through the code, I saw that the argument is discarded in the export function defined in torch/onnx/__init__.py, whose signature has no onnx_shape_inference parameter to forward:

[Screenshot: the export signature in torch/onnx/__init__.py, with no onnx_shape_inference parameter]
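
If I read torch/onnx/__init__.py in 2.5.1 correctly, the public export signature ends with a keyword catch-all, so unknown arguments are swallowed without even a warning. A paraphrased sketch of what I believe the signature looks like:

def export(
    model,
    args=(),
    f=None,
    *,
    export_params=True,
    # ... documented keyword parameters; onnx_shape_inference is not among them ...
    autograd_inlining=True,
    **_,  # unrecognized kwargs such as onnx_shape_inference are silently dropped here
):
    ...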

And the place where the _export interface is called does not pass an actual value for onnx_shape_inference:

[Screenshot: the internal call to _export, with no onnx_shape_inference argument]
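
This is easy to confirm empirically. A minimal check (a sketch, assuming GLOBALS from torch/onnx/_globals.py exposes onnx_shape_inference, as it appears to in torch 2.5.x):

import torch
from torch.onnx._globals import GLOBALS

model = torch.nn.Linear(4, 2)  # tiny stand-in model
torch.onnx.export(
    model,
    (torch.randn(1, 4),),
    "linear.onnx",
    onnx_shape_inference=False,  # silently discarded by the public wrapper
)
# _export ran with its default, so the global flag is still True.
print(GLOBALS.onnx_shape_inference)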

So the _export interface always uses the default value True for onnx_shape_inference, and shape type inference always runs while exporting a model to ONNX.

However, there are existing issues about catching exceptions from shape type inference so the export can continue (#147259, #140962), so controlling whether shape type inference runs is clearly expected. With that control we could skip shape type inference for the LLMs that consistently fail in this step; a possible workaround sketch follows.
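
Until the flag is either plumbed through or rejected loudly, one possible workaround is to pre-bind the flag on the private helper. This is only a sketch: it relies on torch 2.5.x private internals (torch.onnx.utils._export) and on the internal call site not passing onnx_shape_inference itself, so it may break in other versions:

import functools
import torch
import torch.onnx.utils as _onnx_utils

# Rebind the module-level _export so its onnx_shape_inference default
# becomes False. torch.onnx.utils.export resolves the _export name at
# call time, so later torch.onnx.export calls pick up the wrapped version.
_onnx_utils._export = functools.partial(
    _onnx_utils._export, onnx_shape_inference=False
)

model = torch.nn.Linear(4, 2)  # stand-in for the real model (e.g. qwen2_7B)
torch.onnx.export(
    model,
    (torch.randn(1, 4),),
    "model.onnx",
    input_names=["x"],
    output_names=["y"],
)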

Versions

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-2ubuntu1~20.04) 11.4.0
Clang version: 14.0.0
CMake version: version 3.28.0-rc3
Libc version: glibc-2.35

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] onnx==1.14.0
[pip3] onnxruntime==1.13.1
[pip3] onnxsim==0.4.35
[pip3] torch==2.5.1
[pip3] torchvision==0.20.1
[pip3] triton==3.1.0
[conda] numpy 1.26.4 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.4.5.8 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.2.1.3 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.5.147 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.6.1.9 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.3.1.170 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.21.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.4.127 pypi_0 pypi
[conda] torch 2.5.1 pypi_0 pypi
[conda] torchvision 0.20.1 pypi_0 pypi
[conda] triton 3.1.0 pypi_0 pypi

cc @justinchuby


Labels: module: onnx, triaged
