
torch.compile error with dynamic=True: Found <class 'sympy.core.relational.Unequality'>, which is not a supported top level IR node #103587

Closed
sunhs opened this issue Jun 14, 2023 · 13 comments
Labels: module: dynamic shapes, oncall: pt2, triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


sunhs commented Jun 14, 2023

🐛 Describe the bug

Trying to run torch.compile with dynamic=True on the UNet from Hugging Face's StableDiffusionPipeline raises an error.

With dynamic=False it works, but every time the prompt is changed (and thus the prompt embedding can have a different shape), compilation is triggered again, which takes a long time.

Reproduction code

from diffusers import StableDiffusionPipeline
import torch


pipe = StableDiffusionPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V2.0", torch_dtype=torch.float16
)
pipe.to("cuda:0")
# pipe.unet = torch.compile(pipe.unet, dynamic=False)  # This is OK.
pipe.unet = torch.compile(pipe.unet, dynamic=True)

pipe(prompt="prompt")

Log

vae/diffusion_pytorch_model.safetensors not found
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.

  0%|          | 0/50 [00:00<?, ?it/s][2023-06-14 19:08:23,752] torch._inductor.graph: [ERROR] Error from lowering
Traceback (most recent call last):
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_inductor/graph.py", line 333, in call_function
    out = lowerings[target](*args, **kwargs)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 226, in wrapped
    validate_ir(out)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_inductor/ir.py", line 105, in validate_ir
    _check_tensorbox(node_or_nodes)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_inductor/ir.py", line 90, in _check_tensorbox
    assert isinstance(
AssertionError: Found <class 'sympy.core.relational.Unequality'>, which is not a supported top level IR node. See [Note: Inductor IR]

  0%|          | 0/50 [00:02<?, ?it/s]
Traceback (most recent call last):
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_inductor/graph.py", line 333, in call_function
    out = lowerings[target](*args, **kwargs)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 226, in wrapped
    validate_ir(out)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_inductor/ir.py", line 105, in validate_ir
    _check_tensorbox(node_or_nodes)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_inductor/ir.py", line 90, in _check_tensorbox
    assert isinstance(
AssertionError: Found <class 'sympy.core.relational.Unequality'>, which is not a supported top level IR node. See [Note: Inductor IR]

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 670, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/__init__.py", line 1390, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 401, in compile_fx
    return compile_fx(
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 455, in compile_fx
    return aot_autograd(
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 48, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2822, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2515, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1715, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1328, in aot_dispatch_base
    compiled_fw = aot_config.fw_compiler(fw_module, flat_args_with_views_handled)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 430, in fw_compiler
    return inner_compile(
  File "/opt/conda/envs/pt2/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 595, in debug_wrapper
    compiled_fn = compiler_fn(gm, example_inputs)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_inductor/debug.py", line 239, in inner
    return fn(*args, **kwargs)
  File "/opt/conda/envs/pt2/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 176, in compile_fx_inner
    graph.run(*example_inputs)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_inductor/graph.py", line 194, in run
    return super().run(*args)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/fx/interpreter.py", line 136, in run
    self.env[node] = self.run_node(node)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_inductor/graph.py", line 407, in run_node
    result = super().run_node(n)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/fx/interpreter.py", line 177, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_inductor/graph.py", line 337, in call_function
    raise LoweringException(e, target, args, kwargs) from e
torch._inductor.exc.LoweringException: AssertionError: Found <class 'sympy.core.relational.Unequality'>, which is not a supported top level IR node. See [Note: Inductor IR]
  target: <built-in function ne>
  args[0]: Mod(s2, 8)
  args[1]: 0

While executing %ne : [#users=1] = call_function[target=operator.ne](args = (%mod, 0), kwargs = {})
Original traceback:
None

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "compile.py", line 12, in <module>
    pipe(prompt="prompt")
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 728, in __call__
    noise_pred = self.unet(
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 337, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 404, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
    return _compile(
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
    out_code = transform_code_object(code, transform)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
    tracer.run()
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
    super().run()
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 372, in wrapper
    self.output.compile_subgraph(self, reason=reason)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 541, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 588, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 675, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised LoweringException: AssertionError: Found <class 'sympy.core.relational.Unequality'>, which is not a supported top level IR node. See [Note: Inductor IR]
  target: <built-in function ne>
  args[0]: Mod(s2, 8)
  args[1]: 0

While executing %ne : [#users=1] = call_function[target=operator.ne](args = (%mod, 0), kwargs = {})
Original traceback:
None

Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

Versions

PyTorch version: 2.0.1
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 16.04.7 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~16.04) 9.4.0
Clang version: 11.1.0-++20210314110124+1fdec59bffc1-1~exp1~20210314220751.162
CMake version: version 3.19.3
Libc version: glibc-2.23

Python version: 3.10.11 (main, Apr 20 2023, 19:02:41) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-4.9.70-040970-generic-x86_64-with-glibc2.23
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti
GPU 2: GeForce RTX 2080 Ti
GPU 3: GeForce RTX 2080 Ti

Nvidia driver version: 455.45.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.2.2
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:          x86_64
CPU op-mode(s):        32-bit, 64-bit
Byte Order:            Little Endian
CPU(s):                32
On-line CPU(s) list:   0-31
Thread(s) per core:    2
Core(s) per socket:    8
Socket(s):             2
NUMA node(s):          2
Vendor ID:             GenuineIntel
CPU family:            6
Model:                 85
Model name:            Intel(R) Xeon(R) Silver 4110 CPU @ 2.10GHz
Stepping:              4
CPU MHz:               2100.000
BogoMIPS:              4201.49
Virtualization:        VT-x
L1d cache:             32K
L1i cache:             32K
L2 cache:              1024K
L3 cache:              11264K
NUMA node0 CPU(s):     0-7,16-23
NUMA node1 CPU(s):     8-15,24-31
Flags:                 fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc aperfmperf eagerfpu pni pclmulqdq dtes64 ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch epb intel_pt tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke

Versions of relevant libraries:
[pip3] flake8==6.0.0
[pip3] numpy==1.23.5
[pip3] open-clip-torch==2.7.0
[pip3] pytorch-lightning==1.9.4
[pip3] torch==2.0.1
[pip3] torch-tensorrt==1.4.0
[pip3] torchaudio==2.0.2
[pip3] torchdiffeq==0.2.3
[pip3] torchmetrics==0.11.4
[pip3] torchsde==0.2.5
[pip3] torchvision==0.15.2
[pip3] triton==2.0.0
[conda] blas                      1.0                         mkl  
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] mkl                       2023.1.0         h6d00ec8_46342  
[conda] mkl-service               2.4.0           py310h5eee18b_1  
[conda] mkl_fft                   1.3.6           py310h1128e8f_1  
[conda] mkl_random                1.2.2           py310h1128e8f_1  
[conda] numpy                     1.23.5                   pypi_0    pypi
[conda] open-clip-torch           2.7.0                    pypi_0    pypi
[conda] pytorch                   2.0.1           py3.10_cuda11.7_cudnn8.5.0_0    pytorch
[conda] pytorch-cuda              11.7                 h778d358_5    pytorch
[conda] pytorch-lightning         1.9.4                    pypi_0    pypi
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torch-tensorrt            1.4.0                    pypi_0    pypi
[conda] torchaudio                2.0.2               py310_cu117    pytorch
[conda] torchdiffeq               0.2.3                    pypi_0    pypi
[conda] torchmetrics              0.11.4                   pypi_0    pypi
[conda] torchsde                  0.2.5                    pypi_0    pypi
[conda] torchtriton               2.0.0                     py310    pytorch
[conda] torchvision               0.15.2              py310_cu117    pytorch

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305


ezyang commented Jun 16, 2023

I'll take a closer look in a bit, but try this patch:

diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 52931d14a81..78318212d50 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -113,6 +113,7 @@ def validate_ir(node_or_nodes):
                     TensorBox,
                     sympy.Symbol,
                     sympy.core.relational.Relational,
+                    sympy.core.relational.Unequality,
                     Expr,
                     torch._inductor.ir.ExpandView,
                 ),

Also, as an alternative, try instead setting this config:

torch._dynamo.config.dynamic_shapes = True
torch._dynamo.config.automatic_dynamic_shapes = True
torch._dynamo.config.assume_static_by_default = True

which will enable dynamic shapes in a less aggressive way (we will be turning this on as the default soon).
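
For reference, a minimal sketch of how these flags slot into the reproduction script above (the flag names come from the suggestion itself; everything else just mirrors the original repro and is only illustrative):

import torch
import torch._dynamo
from diffusers import StableDiffusionPipeline

# Less aggressive dynamic shapes: assume static until a recompile
# observes a shape actually changing.
torch._dynamo.config.dynamic_shapes = True
torch._dynamo.config.automatic_dynamic_shapes = True
torch._dynamo.config.assume_static_by_default = True

pipe = StableDiffusionPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V2.0", torch_dtype=torch.float16
)
pipe.to("cuda:0")
pipe.unet = torch.compile(pipe.unet)  # note: no dynamic=True here

pipe(prompt="prompt")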


sunhs commented Jun 19, 2023

Hi, I saw that validate_ir's function signature in PyTorch 2.0.1 differs from the one in your diff, so I suppose a PyTorch nightly is expected.

Now I'm using pytorch 2.1.0.dev20230618 and ran with both of your suggestions; the logs are the same in both cases, shown below:

  0%|          | 0/50 [00:00<?, ?it/s]/conda/envs/ptdev/lib/python3.10/site-packages/torch/overrides.py:111: UserWarning: 'has_cuda' is deprecated, please use 'torch.backends.cuda.is_built()'
  torch.has_cuda,
/conda/envs/ptdev/lib/python3.10/site-packages/torch/overrides.py:112: UserWarning: 'has_cudnn' is deprecated, please use 'torch.backends.cudnn.is_available()'
  torch.has_cudnn,
/conda/envs/ptdev/lib/python3.10/site-packages/torch/overrides.py:118: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'
  torch.has_mps,
/conda/envs/ptdev/lib/python3.10/site-packages/torch/overrides.py:119: UserWarning: 'has_mkldnn' is deprecated, please use 'torch.backends.mkldnn.is_available()'
  torch.has_mkldnn,

  0%|          | 0/50 [00:04<?, ?it/s]
Traceback (most recent call last):
  File "torch_compile.py", line 18, in <module>
    pipe(prompt="prompt")
  File "/conda/envs/ptdev/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/conda/envs/ptdev/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 728, in __call__
    noise_pred = self.unet(
  File "/conda/envs/ptdev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/conda/envs/ptdev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/conda/envs/ptdev/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 295, in _fn
    return fn(*args, **kwargs)
  File "/conda/envs/ptdev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/conda/envs/ptdev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/conda/envs/ptdev/lib/python3.10/site-packages/diffusers/models/unet_2d_condition.py", line 636, in forward
    def forward(
  File "/conda/envs/ptdev/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 295, in _fn
    return fn(*args, **kwargs)
  File "/conda/envs/ptdev/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/conda/envs/ptdev/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 3761, in forward
    return compiled_fn(full_args)
  File "/conda/envs/ptdev/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1449, in g
    return f(*args)
  File "/conda/envs/ptdev/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2407, in runtime_wrapper
    all_outs = call_func_with_args(
  File "/conda/envs/ptdev/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1473, in call_func_with_args
    out = normalize_as_list(f(args))
  File "/conda/envs/ptdev/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1561, in rng_functionalization_wrapper
    return compiled_fw(args)
  File "/conda/envs/ptdev/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 457, in run
    return model(new_inputs)
  File "/tmp/torchinductor/oi/coiilopdfkj6pdlbmvmpfvlesgzeoi5v3asbtndawlkleexuodaz.py", line 31, in call
    return (Ne(Mod(s2, 8), 0), Ne(Mod(s2, 8), 0), )
NameError: name 'Ne' is not defined

And this is the content of /tmp/torchinductor/oi/coiilopdfkj6pdlbmvmpfvlesgzeoi5v3asbtndawlkleexuodaz.py:

from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile

from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels

aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()


async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1 = args
    args.clear()
    s0 = arg0_1
    s1 = arg1_1
    s2 = arg2_1
    assert_size_stride(arg3_1, (s0, s1, s2, s2), (s1*(s2*s2), s2*s2, s2, 1))
    return (Ne(Mod(s2, 8), 0), Ne(Mod(s2, 8), 0), )


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = 2
    arg1_1 = 4
    arg2_1 = 64
    arg3_1 = rand_strided((2, 4, 64, 64), (16384, 4096, 64, 1), device='cuda:0', dtype=torch.float16)
    return print_performance(lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]), times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.utils import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)

Am I missing something?


ezyang commented Jun 19, 2023

No, you just hit another bug. I'll send a patch tomorrow.
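
For context on the NameError: Ne and Mod in the generated call are sympy constructors that the code generator printed into the wrapper without importing. A standalone sketch of what those names denote (assuming only that sympy is installed):

# Ne(Mod(s2, 8), 0) is a sympy relational (an Unequality), not a plain
# Python bool, which is why the bare name `Ne` raises NameError in the
# generated wrapper unless it is imported or printed differently.
from sympy import Mod, Ne, symbols

s2 = symbols("s2", integer=True)
expr = Ne(Mod(s2, 8), 0)
print(expr)               # Ne(Mod(s2, 8), 0)
print(expr.subs(s2, 64))  # False, since 64 % 8 == 0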

mlazos added the triaged label Jun 20, 2023

ezyang commented Jun 23, 2023

Confirmed with #104104 the test script runs all the way to completion.


sunhs commented Jun 26, 2023

Confirmed it's working on pytorch-nightly 2.1.0.dev20230625. Thanks a lot!


sunhs commented Jun 28, 2023

> Confirmed it's working on pytorch-nightly 2.1.0.dev20230625. Thanks a lot!

Sorry, it works for the first run, but if the input shapes are changed for the second run, it seems to hang for a long time until I have to Ctrl-C.

Could you try this and confirm?

from diffusers import StableDiffusionPipeline
import torch


pipe = StableDiffusionPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V2.0", torch_dtype=torch.float16
)
pipe.to("cuda:0")
pipe.unet = torch.compile(pipe.unet, dynamic=True)

pipe(prompt="prompt", height=512, width=512)
pipe(prompt="prompt", height=768, width=768)


ezyang commented Jul 5, 2023

@sunhs Instead of dynamic=True, could you please try torch._dynamo.config.automatic_dynamic_shapes = True?


sunhs commented Jul 7, 2023

@ezyang It ran to completion, but still took a long time when the input shape changed. I've already upgraded PyTorch to 2.1.0.dev20230706.

from diffusers import StableDiffusionPipeline
import torch
import torch._dynamo
import datetime


torch._dynamo.config.dynamic_shapes = True
torch._dynamo.config.automatic_dynamic_shapes = True
torch._dynamo.config.assume_static_by_default = True


pipe = StableDiffusionPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V2.0", torch_dtype=torch.float16
)
pipe.to("cuda:0")
pipe.unet = torch.compile(pipe.unet)

start = datetime.datetime.now()
pipe(prompt="prompt", height=512, width=512)
first_done = datetime.datetime.now()
pipe(prompt="prompt", height=768, width=768)
second_done = datetime.datetime.now()

print("first elapsed:", first_done - start)
print("second elapsed:", second_done - first_done)

Log:

  0%|          | 0/50 [00:00<?, ?it/s]Using FallbackKernel: aten._scaled_dot_product_efficient_attention
reduction over non-contiguous dims
(the line above is repeated 26 times in total)
100%|██████████| 50/50 [01:14<00:00,  1.49s/it]
  0%|          | 0/50 [00:00<?, ?it/s]Using FallbackKernel: aten._scaled_dot_product_efficient_attention
reduction over non-contiguous dims
(the line above is repeated 36 times in total)
100%|██████████| 50/50 [00:51<00:00,  1.04s/it]
first elapsed: 0:01:15.543955
second elapsed: 0:00:52.173735

Not sure if this "reduction over non-contiguous dims" warning indicates anything?


ezyang commented Jul 9, 2023

So, I'm still checking up on things, but at least in the example you posted, I do expect one recompile: the first time we assume that you only want one image size, and the second time we recompile, trying to keep the kernel dynamic over sizes. I'm not sure whether the resulting kernel is actually dynamic, but when I asked it to do a third run at 512x768, at least it didn't recompile. Do you have a list of sizes you want to work, by any chance?

ezyang reopened this Jul 9, 2023

ezyang commented Jul 9, 2023

I'm wrong, it actually does recompile the third time.
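
One way to see why, as a sketch: recent nightlies can log the guard that failed on each recompile. The sizes below are the ones from the script above, with the third call being the 512x768 case, and torch._logging.set_logs with recompiles=True is assumed to be available on the nightly in use.

import torch

# Print the reason (failed guard) each time dynamo recompiles;
# equivalent to running with the TORCH_LOGS=recompiles env var.
torch._logging.set_logs(recompiles=True)

pipe(prompt="prompt", height=512, width=512)  # compile #1 (assumed static)
pipe(prompt="prompt", height=768, width=768)  # compile #2 (goes dynamic)
pipe(prompt="prompt", height=512, width=768)  # recompiles again, unexpectedly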


ezyang commented Jul 10, 2023

@sunhs Please feel free to file a new issue with more details about which sizes/kernels you need for your use case. My impression with SD is that there are not too many sizes people typically want to run generation at, and it is not a big burden to compile each of them, but I could be wrong. The Inductor-generated kernels are specialized in somewhat hard-to-parse ways, so it will take some more serious debugging to diagnose.

ezyang closed this as completed Jul 10, 2023
AlphaAtlas commented:

@ezyang It depends on the use case.

For instance, an AI horde worker might get any arbitrary shape and size coming in from clients.

Or even a single user running an img2img pipeline might swap resolutions very frequently.

And that's just the tip of the iceberg... In the wild, outside of pure diffusers, users constantly adjust augmentations like TomeSD or ControlNet. They also run "high-res fix", which runs two different resolutions back to back.

512x512 to 1024x1024, in 64-pixel non-square increments, is probably a sane range?

I noticed that AITemplate allows manually specifying a dynamic range and changing weights, which is very nice. @sunhs If you are running Linux or WSL, you should probably give it a shot, as it's very quick after the initial compile: https://github.com/facebookincubator/AITemplate/tree/main/examples/05_stable_diffusion#alternative-pipeline
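
Until the dynamic path is solid, one workaround consistent with the above is to warm the compile cache over the expected resolution grid up front (a sketch only; the grid follows the 512-1024 / 64-pixel range suggested above, and num_inference_steps=1 just keeps each warmup call cheap):

# Pre-trigger compilation for every resolution the worker expects to
# serve, so user-facing requests don't pay compile latency. Assumes the
# compiled `pipe` from the reproduction script above.
for h in range(512, 1025, 64):
    for w in range(512, 1025, 64):
        pipe(prompt="warmup", height=h, width=w, num_inference_steps=1)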


ezyang commented Jul 10, 2023

Ok. That suggests there aren't fundamental problems with dynamic compilation here; we just need to knock out the bugs. I will take a closer look at some point.
