
TorchInductor Doesn't Support Generators #92633

@simon-mo

Description

🐛 Describe the bug

Compiling a Stable Diffusion pipeline with the inductor backend fails while lowering aten.randn.generator, because inductor's randn lowering does not accept the generator keyword argument:

import torch
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline

model_id = "stabilityai/stable-diffusion-2"
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    scheduler=scheduler,
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

@torch.compile(backend="inductor")
def inference_func(prompt):
    image = pipe(prompt, num_inference_steps=1).images[0]
    return image

prompt = "a photo of an astronaut riding a horse on mars"
image = inference_func(prompt)

with the following output:

Fetching 13 files: 100%|███████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 9024.49it/s]
/home/xmo/miniconda3/lib/python3.10/site-packages/safetensors/torch.py:98: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  with safe_open(filename, framework="pt", device=device) as f:
/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_utils.py:771: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
/home/xmo/miniconda3/lib/python3.10/site-packages/torch/storage.py:899: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  storage = cls(wrap_storage=untyped_storage)
/home/xmo/miniconda3/lib/python3.10/site-packages/transformers/modeling_utils.py:386: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  with safe_open(checkpoint_file, framework="pt") as f:
[2023-01-19 11:56:46,579] torch._inductor.ir: [WARNING] DeviceCopy
[2023-01-19 11:56:46,625] torch._inductor.graph: [WARNING] Creating implicit fallback for:
  target: aten.triu_.default
  args[0]: TensorBox(StorageBox(
    ComputedBuffer(name='buf2', layout=FlexibleLayout('cpu', torch.float16, size=[1, 77, 77], stride=[5929, 77, 1]), data=Pointwise(
      'cpu',
      torch.float16,
      tmp0 = constant(-65504.0, torch.float32)
      tmp1 = to_dtype(tmp0, torch.float16)
      return tmp1
      ,
      ranges=[1, 77, 77],
      origins={lift_fresh_copy, fill_, empty, _tensor_constant0}
    ))
  ))
  args[1]: 1
[2023-01-19 11:56:46,626] torch._inductor.lowering: [WARNING] make_fallback(aten.triu_.default): a decomposition exists, we should switch to it
[2023-01-19 11:56:46,633] torch._inductor.graph: [WARNING] Using FallbackKernel: torch.ops.aten.triu_.default
[2023-01-19 11:56:46,634] torch._inductor.ir: [WARNING] DeviceCopy
[2023-01-19 11:57:09,516] torch._inductor.ir: [WARNING] DeviceCopy
[2023-01-19 11:57:09,526] torch._inductor.graph: [WARNING] Using FallbackKernel: torch.ops.aten.triu_.default
[2023-01-19 11:57:09,526] torch._inductor.ir: [WARNING] DeviceCopy
[2023-01-19 11:57:15,681] torch._inductor.graph: [ERROR] Error from lowering
Traceback (most recent call last):
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/graph.py", line 301, in call_function
    out = lowerings[target](*args, **kwargs)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 223, in wrapped
    return decomp_fn(*args, **kwargs)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 1111, in randn
    return fast_randn(*args, **kwargs)
TypeError: make_rand.<locals>.rand_or_randn() got an unexpected keyword argument 'generator'
Traceback (most recent call last):
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/graph.py", line 301, in call_function
    out = lowerings[target](*args, **kwargs)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 223, in wrapped
    return decomp_fn(*args, **kwargs)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 1111, in randn
    return fast_randn(*args, **kwargs)
TypeError: make_rand.<locals>.rand_or_randn() got an unexpected keyword argument 'generator'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 674, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 1047, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/__init__.py", line 1264, in __call__
    return self.compile_fn(model_, inputs_)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/optimizations/backends.py", line 24, in inner
    return fn(gm, example_inputs, **kwargs)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/optimizations/backends.py", line 61, in inductor
    return compile_fx(*args, **kwargs)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 411, in compile_fx
    return aot_autograd(
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/optimizations/training.py", line 78, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2453, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 96, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2150, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1412, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1062, in aot_dispatch_base
    compiled_fw = aot_config.fw_compiler(fw_module, flat_args)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 96, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 386, in fw_compiler
    return inner_compile(
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 586, in debug_wrapper
    compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/debug.py", line 224, in inner
    return fn(*args, **kwargs)
  File "/home/xmo/miniconda3/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 152, in compile_fx_inner
    graph.run(*example_inputs)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 96, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/graph.py", line 178, in run
    return super().run(*args)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/fx/interpreter.py", line 136, in run
    self.env[node] = self.run_node(node)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/graph.py", line 375, in run_node
    result = super().run_node(n)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/fx/interpreter.py", line 177, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/graph.py", line 305, in call_function
    raise LoweringException(e, target, args, kwargs) from e
torch._inductor.exc.LoweringException: TypeError: make_rand.<locals>.rand_or_randn() got an unexpected keyword argument 'generator'
  target: aten.randn.generator
  args[0]: [1, 4, 96, 96]
  kwargs: {'generator': None, 'dtype': torch.float16, 'device': device(type='cuda', index=0), 'pin_memory': False}

While executing %randn : [#users=1] = call_function[target=torch.ops.aten.randn.generator](args = ([1, 4, 96, 96],), kwargs = {generator: None, dtype: torch.float16, device: cuda:0, pin_memory: False})
Original traceback:
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 398, in prepare_latents
    latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/xmo/workspace/fijit-sys/torch-examples/repro.py", line 19, in <module>
    image = inference_func(prompt)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 211, in _fn
    return fn(*args, **kwargs)
  File "/home/xmo/workspace/fijit-sys/torch-examples/repro.py", line 15, in inference_func
    image = pipe(prompt, num_inference_steps=1).images[0]
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 506, in __call__
    latents = self.prepare_latents(
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 332, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 480, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 103, in _fn
    return fn(*args, **kwargs)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 96, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 339, in _convert_frame_assert
    return _compile(
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _compile
    out_code = transform_code_object(code, transform)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
    transformations(instructions, code_options)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 387, in transform
    tracer.run()
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1684, in run
    super().run()
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 538, in run
    and self.step()
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 501, in step
    getattr(self, inst.opname)(inst)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1750, in RETURN_VALUE
    self.output.compile_subgraph(self)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 551, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 598, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 679, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised LoweringException: TypeError: make_rand.<locals>.rand_or_randn() got an unexpected keyword argument 'generator'
  target: aten.randn.generator
  args[0]: [1, 4, 96, 96]
  kwargs: {'generator': None, 'dtype': torch.float16, 'device': device(type='cuda', index=0), 'pin_memory': False}

While executing %randn : [#users=1] = call_function[target=torch.ops.aten.randn.generator](args = ([1, 4, 96, 96],), kwargs = {generator: None, dtype: torch.float16, device: cuda:0, pin_memory: False})
Original traceback:
  File "/home/xmo/miniconda3/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 398, in prepare_latents
    latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)


Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
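
For what it's worth, here is my guess at a minimal repro (a sketch, not tested in isolation): passing generator= to torch.randn, even as None, makes dynamo emit the aten.randn.generator overload, which inductor's rand_or_randn lowering then rejects with the TypeError above.

import torch

@torch.compile(backend="inductor")
def repro():
    # Passing generator= (even None) selects the aten.randn.generator
    # overload, which inductor's randn lowering does not accept.
    return torch.randn(
        [1, 4, 96, 96], generator=None, device="cuda", dtype=torch.float16
    )

repro()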

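A possible workaround until this is fixed (a sketch, assuming the setup from the repro script above and that diffusers' latents argument behaves as I expect): sample the latents eagerly outside the compiled region, so torch.randn(..., generator=...) is never traced. If I am reading prepare_latents right, passing latents in skips the failing randn call entirely.

# Workaround sketch: precompute latents eagerly (shape taken from the failing
# call above) and pass them to the pipeline, so the compiled region never
# traces torch.randn(..., generator=...). Assumes pipe from the repro script.
latents = torch.randn([1, 4, 96, 96], device="cuda", dtype=torch.float16)

@torch.compile(backend="inductor")
def inference_func(prompt):
    return pipe(prompt, num_inference_steps=1, latents=latents).images[0]

image = inference_func("a photo of an astronaut riding a horse on mars")
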
Versions

This uses a nightly build of torch. I believe it worked on a build from before the new year, but I haven't had a chance to bisect the versions.

Collecting environment information...
PyTorch version: 2.0.0.dev20230119+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.25.0
Libc version: glibc-2.31

Python version: 3.10.8 (main, Nov 24 2022, 14:13:03) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-46-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: Tesla V100-SXM2-32GB-LS
GPU 1: Tesla V100-SXM2-32GB-LS
GPU 2: Tesla V100-SXM2-32GB-LS
GPU 3: Tesla V100-SXM2-32GB-LS
GPU 4: Tesla V100-SXM2-32GB-LS
GPU 5: Tesla V100-SXM2-32GB-LS
GPU 6: Tesla V100-SXM2-32GB-LS
GPU 7: Tesla V100-SXM2-32GB-LS

Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.24.1
[pip3] pytorch-triton==2.0.0+0d7e753227
[pip3] torch==2.0.0.dev20230119+cu117
[pip3] torchaudio==2.0.0.dev20230119+cu117
[pip3] torchvision==0.15.0.dev20230119+cu117
[conda] numpy                     1.24.1                   pypi_0    pypi
[conda] pytorch-triton            2.0.0+0d7e753227          pypi_0    pypi
[conda] torch                     2.0.0.dev20230119+cu117          pypi_0    pypi
[conda] torchaudio                2.0.0.dev20230119+cu117          pypi_0    pypi
[conda] torchvision               0.15.0.dev20230119+cu117          pypi_0    pypi

cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh @mlazos @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire

Labels

module: inductor, oncall: pt2, triaged
