
[onnx] UnsupportedOperatorError: Exporting the operator 'aten::l1_loss' to ONNX opset version 17 is not supported #100913

Closed
shingjan opened this issue May 8, 2023 · 3 comments
Labels
module: onnx Related to torch.onnx oncall: pt2 triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

@shingjan
Contributor

shingjan commented May 8, 2023

🐛 Describe the bug

My repro:

python benchmarks/dynamo/torchbench.py --only Super_SloMo --backend onnxrt -dcuda --performance --inference

with the line opset_version=17, added to the torch.onnx.export(...) call in torch/_dynamo/backends/onnxrt.py.

This fails with the following error:

cuda eval  Super_SloMo                         ========= Diagnostic Run torch.onnx.export version 2.0.0a0+gitc263bd4 ==========
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 1 ERROR ========================
ERROR: missing-standard-symbolic-function
=========================================
Exporting the operator 'aten::l1_loss' to ONNX opset version 17 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.
None
<Set verbose=True to see more details>


ERROR:common:Backend dynamo failed in warmup()
Traceback (most recent call last):
  File "/home/yj/pytorch/benchmarks/dynamo/common.py", line 1485, in warmup
    fn(model, example_inputs)
  File "/home/yj/pytorch/torch/_dynamo/eval_frame.py", line 280, in _fn
    return fn(*args, **kwargs)
  File "/home/yj/pytorch/torch/_dynamo/eval_frame.py", line 433, in catch_errors
    return callback(frame, cache_size, hooks, frame_state)
  File "/home/yj/pytorch/torch/_dynamo/convert_frame.py", line 519, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/home/yj/pytorch/torch/_dynamo/convert_frame.py", line 122, in _fn
    return fn(*args, **kwargs)
  File "/home/yj/pytorch/torch/_dynamo/convert_frame.py", line 355, in _convert_frame_assert
    return _compile(
  File "/home/yj/pytorch/torch/_dynamo/utils.py", line 177, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/yj/pytorch/torch/_dynamo/convert_frame.py", line 425, in _compile
    out_code = transform_code_object(code, transform)
  File "/home/yj/pytorch/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
    transformations(instructions, code_options)
  File "/home/yj/pytorch/torch/_dynamo/convert_frame.py", line 410, in transform
    tracer.run()
  File "/home/yj/pytorch/torch/_dynamo/symbolic_convert.py", line 2010, in run
    super().run()
  File "/home/yj/pytorch/torch/_dynamo/symbolic_convert.py", line 703, in run
    and self.step()
  File "/home/yj/pytorch/torch/_dynamo/symbolic_convert.py", line 663, in step
    getattr(self, inst.opname)(inst)
  File "/home/yj/pytorch/torch/_dynamo/symbolic_convert.py", line 2098, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/home/yj/pytorch/torch/_dynamo/output_graph.py", line 723, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/home/yj/anaconda3/envs/dynamite/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/yj/pytorch/torch/_dynamo/output_graph.py", line 800, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/yj/pytorch/torch/_dynamo/utils.py", line 177, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/yj/pytorch/torch/_dynamo/output_graph.py", line 859, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/yj/pytorch/torch/_dynamo/output_graph.py", line 855, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/home/yj/pytorch/torch/_dynamo/repro/after_dynamo.py", line 108, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/yj/pytorch/torch/_dynamo/backends/common.py", line 125, in wrapper
    return fn(model, inputs, **kwargs)
  File "/home/yj/pytorch/torch/_dynamo/backends/onnxrt.py", line 55, in onnxrt
    return onnxrt(gm, example_inputs, filename=tmp.name)
  File "/home/yj/pytorch/torch/_dynamo/backends/common.py", line 125, in wrapper
    return fn(model, inputs, **kwargs)
  File "/home/yj/pytorch/torch/_dynamo/backends/onnxrt.py", line 72, in onnxrt
    torch.onnx.export(
  File "/home/yj/pytorch/torch/onnx/utils.py", line 507, in export
    _export(
  File "/home/yj/pytorch/torch/onnx/utils.py", line 1567, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/yj/pytorch/torch/onnx/utils.py", line 1128, in _model_to_graph
    graph = _optimize_graph(
  File "/home/yj/pytorch/torch/onnx/utils.py", line 666, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File "/home/yj/pytorch/torch/onnx/utils.py", line 1918, in _run_symbolic_function
    raise errors.UnsupportedOperatorError(
torch._dynamo.exc.BackendCompilerFailed: backend='onnxrt' raised:
UnsupportedOperatorError: Exporting the operator 'aten::l1_loss' to ONNX opset version 17 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
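For context on what the missing symbolic function would have to produce: ONNX defines no dedicated L1-loss operator, so exporting aten::l1_loss means lowering it to primitives such as Sub, Abs, and ReduceMean or ReduceSum. The sketch below is a hypothetical pure-Python reference (l1_loss_reference is not a PyTorch or ONNX API) that spells out that arithmetic. It assumes the reduction semantics of torch.nn.functional.l1_loss ("none", "mean" as the default, "sum") and, unlike the real op, only handles flat, equal-length inputs with no broadcasting.

```python
from typing import List, Sequence, Union


def l1_loss_reference(
    prediction: Sequence[float],
    target: Sequence[float],
    reduction: str = "mean",
) -> Union[float, List[float]]:
    """Hypothetical helper: |prediction - target| reduced like F.l1_loss.

    Mirrors a possible ONNX lowering: Sub -> Abs -> ReduceMean / ReduceSum.
    Assumes flat, equal-length inputs (the real op also broadcasts).
    """
    if len(prediction) != len(target):
        raise ValueError("prediction and target must have the same length")
    # Equivalent of ONNX Sub followed by Abs, element by element.
    abs_diff = [abs(p - t) for p, t in zip(prediction, target)]
    if reduction == "none":
        return abs_diff
    if reduction == "sum":
        # Equivalent of ONNX ReduceSum over all elements.
        return sum(abs_diff)
    if reduction == "mean":
        # Equivalent of ONNX ReduceMean over all elements; empty input gives NaN.
        return sum(abs_diff) / len(abs_diff) if abs_diff else float("nan")
    raise ValueError(f"unsupported reduction: {reduction!r}")


print(l1_loss_reference([1.0, 2.0, 3.0], [1.5, 2.0, 5.0], reduction="sum"))  # prints 2.5
```

Keeping the three reduction branches separate makes it easy to see that only the final reduce node differs between modes; the Sub and Abs steps are shared.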

Versions

PyTorch version: 2.0.0a0+gitc263bd4
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.25.0
Libc version: glibc-2.31

Python version: 3.8.15 (default, Nov 24 2022, 15:19:38) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-69-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3070
Nvidia driver version: 520.61.05
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.6.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 48 bits physical, 48 bits virtual
CPU(s): 24
On-line CPU(s) list: 0-23
Thread(s) per core: 2
Core(s) per socket: 12
Socket(s): 1
NUMA node(s): 1
Vendor ID: AuthenticAMD
CPU family: 25
Model: 33
Model name: AMD Ryzen 9 5900X 12-Core Processor
Stepping: 0
Frequency boost: enabled
CPU MHz: 2786.283
CPU max MHz: 3700.0000
CPU min MHz: 2200.0000
BogoMIPS: 7399.70
Virtualization: AMD-V
L1d cache: 384 KiB
L1i cache: 384 KiB
L2 cache: 6 MiB
L3 cache: 64 MiB
NUMA node0 CPU(s): 0-23
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm

Versions of relevant libraries:
[pip3] bert-pytorch==0.0.1a4
[pip3] clip-anytorch==2.5.2
[pip3] CoCa-pytorch==0.0.7
[pip3] dalle2-pytorch==1.12.4
[pip3] ema-pytorch==0.1.4
[pip3] functorch==1.14.0a0+408bcf1
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.5
[pip3] open-clip-torch==2.16.0
[pip3] pytorch-transformers==1.2.0
[pip3] pytorch-triton==2.1.0+46672772b4
[pip3] pytorch-warmup==0.1.1
[pip3] rotary-embedding-torch==0.2.1
[pip3] torch==2.1.0a0+git2f95380
[pip3] torch-fidelity==0.3.0
[pip3] torch-scatter==2.1.1+pt20cpu
[pip3] torch-sparse==0.6.17+pt20cpu
[pip3] torch-struct==0.5
[pip3] torchaudio==2.1.0a0+d5b2996
[pip3] torchdata==0.7.0a0+f083d52
[pip3] torchmetrics==0.11.0
[pip3] torchrec-nightly==2023.1.25
[pip3] torchtext==0.16.0a0+79100a6
[pip3] torchvision==0.16.0a0+0d75d9e
[pip3] torchx==0.4.0
[pip3] vector-quantize-pytorch==0.10.15
[conda] bert-pytorch 0.0.1a4 dev_0
[conda] clip-anytorch 2.5.2 pypi_0 pypi
[conda] coca-pytorch 0.0.7 pypi_0 pypi
[conda] dalle2-pytorch 1.12.4 pypi_0 pypi
[conda] ema-pytorch 0.1.4 pypi_0 pypi
[conda] functorch 1.14.0a0+408bcf1 pypi_0 pypi
[conda] magma-cuda118 2.6.1 1 pytorch
[conda] mkl 2023.1.0 h6d00ec8_46342
[conda] mkl-include 2023.1.0 h06a4308_46342
[conda] numpy 1.23.5 pypi_0 pypi
[conda] open-clip-torch 2.16.0 pypi_0 pypi
[conda] pytorch-transformers 1.2.0 pypi_0 pypi
[conda] pytorch-triton 2.1.0+46672772b4 pypi_0 pypi
[conda] pytorch-warmup 0.1.1 pypi_0 pypi
[conda] rotary-embedding-torch 0.2.1 pypi_0 pypi
[conda] torch 2.1.0a0+git2f95380 dev_0
[conda] torch-fidelity 0.3.0 pypi_0 pypi
[conda] torch-scatter 2.1.1+pt20cpu pypi_0 pypi
[conda] torch-sparse 0.6.17+pt20cpu pypi_0 pypi
[conda] torch-struct 0.5 pypi_0 pypi
[conda] torchaudio 2.1.0a0+d5b2996 dev_0
[conda] torchdata 0.7.0a0+f083d52 pypi_0 pypi
[conda] torchmetrics 0.11.0 pypi_0 pypi
[conda] torchrec-nightly 2023.1.25 pypi_0 pypi
[conda] torchtext 0.16.0a0+79100a6 dev_0
[conda] torchvision 0.15.0a0+85983a5 pypi_0 pypi
[conda] torchx 0.4.0 pypi_0 pypi
[conda] vector-quantize-pytorch 0.10.15 pypi_0 pypi

cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh @anijain2305

@soulitzer soulitzer added module: onnx Related to torch.onnx oncall: pt2 labels May 8, 2023
@shingjan shingjan changed the title [dynamo] UnsupportedOperatorError: Exporting the operator 'aten::l1_loss' to ONNX opset version 17 is not supported [onnx] UnsupportedOperatorError: Exporting the operator 'aten::l1_loss' to ONNX opset version 17 is not supported May 9, 2023
@ezyang ezyang added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label May 25, 2023
@tugsbayasgalan
Contributor

@BowenBao do you know if this is still an issue?

@BowenBao
Collaborator

This needs re-evaluation since a lot has changed. The stack trace shows that the onnxrt backend is still running the TorchScript-based exporter, which we have already swapped for the dynamo-based one.

@BowenBao
Collaborator

Closing, since this model works with the dynamo-based ONNX exporter. Feel free to reopen if you still encounter issues.

@BowenBao BowenBao removed their assignment May 16, 2024
Projects
Status: Done
Development

No branches or pull requests

5 participants