Torch Dynamo Error when Eager-Compiling Diffusers Model #107437

Closed
gs-olive opened this issue Aug 18, 2023 · 2 comments
Labels
module: numpy Related to numpy support, and also numpy compatibility of our operators oncall: pt2 triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

gs-olive commented Aug 18, 2023

🐛 Describe the bug

Compiling the entire pipeline for this Stable Diffusion model via diffusers raises the following error with the August 11 PyTorch nightly, but not with the August 10 nightly. The error appears to originate from these lines in the code.

The issue is a regression from the August 10th nightly to the August 11th nightly, and persists up through the August 17th nightly.

Error logs

Traceback (most recent call last):
  File "~/test.py", line 106, in <module>
    image = pipe(prompt).images[0]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 333, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 649, in __call__
    self.scheduler.set_timesteps(num_inference_steps, device=device)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 493, in catch_errors
    return callback(frame, cache_size, hooks, frame_state)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 624, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 132, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 370, in _convert_frame_assert
    return _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 554, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 180, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 465, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 432, in transform
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2071, in run
    super().run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 168, in impl
    self.push(fn_var.call_function(self, self.popn(nargs), {}))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 560, in call_function
    return wrap_fx_proxy_cls(cls, tx, proxy, **options)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 1237, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1351, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1319, in get_fake_value
    return wrap_fake_exception(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 898, in wrap_fake_exception
    return fn()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1320, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1385, in run_node
    raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1372, in run_node
    return node.target(*args, **kwargs)
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function getitem>(*(FakeTensor(..., size=(51,), dtype=torch.int64), slice(None, None, -1)), **{}):
step must be greater than zero

from user code:
   File "/usr/local/lib/python3.10/dist-packages/diffusers/schedulers/scheduling_pndm.py", line 211, in set_timesteps
    self.plms_timesteps = np.concatenate([self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]])[
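The `slice(None, None, -1)` in the error message corresponds to a `[::-1]` reversal. The sketch below is not taken from the issue; it only illustrates, outside of Dynamo, that NumPy accepts a negative slice step while indexing a `torch.Tensor` the same way is rejected. It assumes the traced NumPy call was being executed on a tensor, as the `FakeTensor` argument in the traceback suggests. `torch.flip` is shown as the usual tensor-side equivalent of the reversal, not as the fix that was applied upstream.

```python
import numpy as np
import torch

# NumPy supports negative-step slicing and returns a reversed view.
timesteps = np.arange(51)
reversed_np = timesteps[::-1]

# The same slice on a torch.Tensor is rejected.
t = torch.arange(51)
message = None
try:
    t[::-1]
except Exception as e:  # the exact exception type is deliberately not assumed
    message = str(e)
print(message)

# torch.flip gives the same reversed ordering for a tensor.
reversed_torch = torch.flip(t, dims=[0])
```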

Minified repro

Use the tutorial featured here: https://huggingface.co/CompVis/stable-diffusion-v1-4#pytorch
Then add a torch.compile call with the "eager" backend to the pipeline, before inference, like so:

##### Setup Precedes
...

# torch.compile call
pipe = torch.compile(pipe, backend="eager")

##### Inference Follows
...
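A smaller standalone check, independent of diffusers, might look like the sketch below. It assumes the failure is triggered solely by the negative-step slice applied to the result of `np.concatenate`. The `[::-1]` is inferred from the `slice(None, None, -1)` in the error message, since the user-code line in the traceback is truncated. The function name `build_timesteps` and the input `np.arange(50)` are hypothetical, and this sketch has not been confirmed to fail on the affected nightlies.

```python
import numpy as np
import torch


def build_timesteps(ts):
    # Mirrors the shape of the concatenate call in the traceback, followed
    # by a reversal with a negative slice step (inferred, see lead-in).
    joined = np.concatenate([ts[:-1], ts[-2:-1], ts[-1:]])
    return joined[::-1].copy()


ts = np.arange(50)

# Plain NumPy result, used as the reference.
expected = build_timesteps(ts)

# Same function traced by Dynamo with the "eager" backend.
compiled = torch.compile(build_timesteps, backend="eager")
actual = compiled(ts)

print(len(expected), np.array_equal(np.asarray(actual), expected))
```

With 50 input timesteps the concatenated array has 51 elements, which matches the `size=(51,)` reported for the `FakeTensor` in the error log.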

Versions

Versions of relevant libraries:
[pip3] numpy==1.25.2
[pip3] diffusers==0.20.0
[pip3] torch==2.1.0.dev20230811+cu121  # Python 3.10

cc @mruberry @rgommers @ezyang @msaroufim @wconstab @bdhirsh @anijain2305

gs-olive changed the title from "Torch Dynamo Regression Error when Eager-Compiling Diffusers Stable Diffusion" to "Torch Dynamo Regression Error when Eager-Compiling Diffusers Model" Aug 18, 2023
gs-olive changed the title from "Torch Dynamo Regression Error when Eager-Compiling Diffusers Model" to "Torch Dynamo Error when Eager-Compiling Diffusers Model" Aug 18, 2023

ezyang commented Aug 18, 2023

This is almost certainly due to #106211

cc @lezcano

@ezyang ezyang added the module: numpy Related to numpy support, and also numpy compatibility of our operators label Aug 18, 2023
@lezcano lezcano self-assigned this Aug 21, 2023

lezcano commented Aug 21, 2023

I'll have a look

@wconstab wconstab added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Aug 21, 2023
lezcano added a commit that referenced this issue Aug 22, 2023
Fixes #107437

ghstack-source-id: 0ece075938578a045308133cc86f5a6e7f29abb0
Pull Request resolved: #107689
lezcano added a commit that referenced this issue Aug 22, 2023
Fixes #107437

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx chenyang78 aakhundov
