
torch.export of ResNet with dynamic height fails due to constraint violation #124507

Open
lopuhin opened this issue Apr 19, 2024 · 2 comments

lopuhin commented Apr 19, 2024

🐛 Describe the bug

torch.export of ResNet with dynamic height fails due to constraint violation -- apparently it wants the height to be even.

Here is an example repro:

import torch
from torchvision.models import resnet34

device = torch.device('cuda')
model = resnet34().to(device)
model.eval()

k = 1  # works with k = 2
height = torch.export.Dim('height', min=320 // k, max=7680 // k)
exported = torch.export.export(
    model,
    args=(
        torch.randn((1, 3, 640, 320)).to(device),
    ),
    dynamic_shapes=(
        {2: height * k},
    ),
)
example_input = torch.randn((1, 3, 1242, 320)).to(device)
print(exported.module()(example_input).shape)

This fails with

torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (height)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of height = L['x'].size()[2] in the specified range 320 <= height <= 7680 satisfy the generated guard Ne(Mod(((L['x'].size()[2] - 1)//2), 2), 0).

Suggested fixes:
  height = Dim('height', min=320, max=7680)

If we set k to 2 (so that height is always even), it works.
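
For reference, this is the k = 2 variant that exports successfully -- the same repro with the dynamic dimension written as height * 2, so the traced height is always even:

k = 2
height = torch.export.Dim('height', min=320 // k, max=7680 // k)  # 160..3840
exported = torch.export.export(
    model,
    args=(torch.randn((1, 3, 640, 320)).to(device),),
    dynamic_shapes=({2: height * k},),
)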

Versions

This is using torch nightly.

Collecting environment information...                                                                                                  
PyTorch version: 2.4.0.dev20240419+cu121                                                                                               
Is debug build: False                                                                                                                  
CUDA used to build PyTorch: 12.1             
ROCM used to build PyTorch: N/A                          

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35
                                                                   
Python version: 3.10.6 (main, May 29 2023, 11:10:38) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-73-generic-x86_64-with-glibc2.35
Is CUDA available: True              
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:               
GPU 0: NVIDIA GeForce RTX 2080 Ti

Nvidia driver version: 550.67
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
...

Versions of relevant libraries:
[pip3] mypy==1.7.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.1
[pip3] onnx==1.15.0
[pip3] onnxruntime==1.17.1
[pip3] onnxruntime-gpu==1.17.1
[pip3] pytorch-triton==3.0.0+989adb9a29
[pip3] torch==2.4.0.dev20240419+cu121
[pip3] torch-model-archiver==0.5.3
[pip3] torchserve==0.7.1
[pip3] torchvision==0.19.0.dev20240419+cu121
[conda] Could not collect

cc @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4


lopuhin commented Apr 19, 2024

One more (unrelated) issue is that the performance of the exported model, when called from Python, is a bit slower than both eager evaluation and torch.compile (which is the fastest, ignoring compilation time), although it's a small model. I was hoping it would be similar in performance to torch.compile -- just curious whether this is expected?
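
For reference, the comparison looks roughly like this (a hypothetical timing harness, not the exact script behind the observation; model, exported, and example_input come from the repro above, using the k = 2 export):

import time

def bench(fn, x, iters=100):
    for _ in range(10):  # warmup (also triggers compilation for torch.compile)
        fn(x)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters * 1e3  # ms per call

with torch.inference_mode():
    print('eager   ', bench(model, example_input))
    print('exported', bench(exported.module(), example_input))
    print('compiled', bench(torch.compile(model), example_input))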

zhxchen17 commented

> One more (unrelated) issue is that the performance of the exported model, when called from Python, is a bit slower than both eager evaluation and torch.compile (which is the fastest, ignoring compilation time), although it's a small model. I was hoping it would be similar in performance to torch.compile -- just curious whether this is expected?

Note that if you run the exported model directly, it won't be performant by default, because we are just running it op by op. To get it optimized, it's better to use a backend such as torch._inductor.aot_compile.
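
For example (a minimal sketch, not from this thread -- torch._export.aot_compile is the wrapper that drives that Inductor AOT path, and both it and torch._export.aot_load are private APIs whose signatures may change between nightlies):

so_path = torch._export.aot_compile(
    model,
    args=(torch.randn((1, 3, 640, 320), device=device),),
    dynamic_shapes=({2: height * k},),  # with k = 2, as above
)
optimized = torch._export.aot_load(so_path, device='cuda')  # callable shared library
print(optimized(example_input).shape)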
