Labels: module: inductor, oncall: pt2, triaged
🐛 Describe the bug
```python
import torch
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline

model_id = "stabilityai/stable-diffusion-2"
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    scheduler=scheduler,
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

@torch.compile(backend="inductor")
def inference_func(prompt):
    image = pipe(prompt, num_inference_steps=1).images[0]
    return image

prompt = "a photo of an astronaut riding a horse on mars"
image = inference_func(prompt)
```
which fails with the following output:

```
Fetching 13 files: 100%|███████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 9024.49it/s]
/home/xmo/miniconda3/lib/python3.10/site-packages/safetensors/torch.py:98: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
with safe_open(filename, framework="pt", device=device) as f:
/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_utils.py:771: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
return self.fget.__get__(instance, owner)()
/home/xmo/miniconda3/lib/python3.10/site-packages/torch/storage.py:899: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
storage = cls(wrap_storage=untyped_storage)
/home/xmo/miniconda3/lib/python3.10/site-packages/transformers/modeling_utils.py:386: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
with safe_open(checkpoint_file, framework="pt") as f:
[2023-01-19 11:56:46,579] torch._inductor.ir: [WARNING] DeviceCopy
[2023-01-19 11:56:46,625] torch._inductor.graph: [WARNING] Creating implicit fallback for:
target: aten.triu_.default
args[0]: TensorBox(StorageBox(
ComputedBuffer(name='buf2', layout=FlexibleLayout('cpu', torch.float16, size=[1, 77, 77], stride=[5929, 77, 1]), data=Pointwise(
'cpu',
torch.float16,
tmp0 = constant(-65504.0, torch.float32)
tmp1 = to_dtype(tmp0, torch.float16)
return tmp1
,
ranges=[1, 77, 77],
origins={lift_fresh_copy, fill_, empty, _tensor_constant0}
))
))
args[1]: 1
[2023-01-19 11:56:46,626] torch._inductor.lowering: [WARNING] make_fallback(aten.triu_.default): a decomposition exists, we should switch to it
[2023-01-19 11:56:46,633] torch._inductor.graph: [WARNING] Using FallbackKernel: torch.ops.aten.triu_.default
[2023-01-19 11:56:46,634] torch._inductor.ir: [WARNING] DeviceCopy
[2023-01-19 11:57:09,516] torch._inductor.ir: [WARNING] DeviceCopy
[2023-01-19 11:57:09,526] torch._inductor.graph: [WARNING] Using FallbackKernel: torch.ops.aten.triu_.default
[2023-01-19 11:57:09,526] torch._inductor.ir: [WARNING] DeviceCopy
[2023-01-19 11:57:15,681] torch._inductor.graph: [ERROR] Error from lowering
Traceback (most recent call last):
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/graph.py", line 301, in call_function
out = lowerings[target](*args, **kwargs)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 223, in wrapped
return decomp_fn(*args, **kwargs)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 1111, in randn
return fast_randn(*args, **kwargs)
TypeError: make_rand.<locals>.rand_or_randn() got an unexpected keyword argument 'generator'
Traceback (most recent call last):
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/graph.py", line 301, in call_function
out = lowerings[target](*args, **kwargs)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 223, in wrapped
return decomp_fn(*args, **kwargs)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 1111, in randn
return fast_randn(*args, **kwargs)
TypeError: make_rand.<locals>.rand_or_randn() got an unexpected keyword argument 'generator'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 674, in call_user_compiler
compiled_fn = compiler_fn(gm, self.fake_example_inputs())
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 1047, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/__init__.py", line 1264, in __call__
return self.compile_fn(model_, inputs_)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/optimizations/backends.py", line 24, in inner
return fn(gm, example_inputs, **kwargs)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/optimizations/backends.py", line 61, in inductor
return compile_fx(*args, **kwargs)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 411, in compile_fx
return aot_autograd(
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/optimizations/training.py", line 78, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2453, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 96, in time_wrapper
r = func(*args, **kwargs)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2150, in create_aot_dispatcher_function
compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1412, in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1062, in aot_dispatch_base
compiled_fw = aot_config.fw_compiler(fw_module, flat_args)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 96, in time_wrapper
r = func(*args, **kwargs)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 386, in fw_compiler
return inner_compile(
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 586, in debug_wrapper
compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/debug.py", line 224, in inner
return fn(*args, **kwargs)
File "/home/xmo/miniconda3/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 152, in compile_fx_inner
graph.run(*example_inputs)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 96, in time_wrapper
r = func(*args, **kwargs)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/graph.py", line 178, in run
return super().run(*args)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/fx/interpreter.py", line 136, in run
self.env[node] = self.run_node(node)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/graph.py", line 375, in run_node
result = super().run_node(n)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/fx/interpreter.py", line 177, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_inductor/graph.py", line 305, in call_function
raise LoweringException(e, target, args, kwargs) from e
torch._inductor.exc.LoweringException: TypeError: make_rand.<locals>.rand_or_randn() got an unexpected keyword argument 'generator'
target: aten.randn.generator
args[0]: [1, 4, 96, 96]
kwargs: {'generator': None, 'dtype': torch.float16, 'device': device(type='cuda', index=0), 'pin_memory': False}
While executing %randn : [#users=1] = call_function[target=torch.ops.aten.randn.generator](args = ([1, 4, 96, 96],), kwargs = {generator: None, dtype: torch.float16, device: cuda:0, pin_memory: False})
Original traceback:
File "/home/xmo/miniconda3/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 398, in prepare_latents
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/xmo/workspace/fijit-sys/torch-examples/repro.py", line 19, in <module>
image = inference_func(prompt)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 211, in _fn
return fn(*args, **kwargs)
File "/home/xmo/workspace/fijit-sys/torch-examples/repro.py", line 15, in inference_func
image = pipe(prompt, num_inference_steps=1).images[0]
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 506, in __call__
latents = self.prepare_latents(
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 332, in catch_errors
return callback(frame, cache_size, hooks)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 480, in _convert_frame
result = inner_convert(frame, cache_size, hooks)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 103, in _fn
return fn(*args, **kwargs)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 96, in time_wrapper
r = func(*args, **kwargs)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 339, in _convert_frame_assert
return _compile(
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _compile
out_code = transform_code_object(code, transform)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
transformations(instructions, code_options)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 387, in transform
tracer.run()
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1684, in run
super().run()
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 538, in run
and self.step()
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 501, in step
getattr(self, inst.opname)(inst)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1750, in RETURN_VALUE
self.output.compile_subgraph(self)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 551, in compile_subgraph
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 598, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/home/xmo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 679, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised LoweringException: TypeError: make_rand.<locals>.rand_or_randn() got an unexpected keyword argument 'generator'
target: aten.randn.generator
args[0]: [1, 4, 96, 96]
kwargs: {'generator': None, 'dtype': torch.float16, 'device': device(type='cuda', index=0), 'pin_memory': False}
While executing %randn : [#users=1] = call_function[target=torch.ops.aten.randn.generator](args = ([1, 4, 96, 96],), kwargs = {generator: None, dtype: torch.float16, device: cuda:0, pin_memory: False})
Original traceback:
File "/home/xmo/miniconda3/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 398, in prepare_latents
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
Set torch._dynamo.config.verbose=True for more information
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
```
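The failing op is `aten.randn.generator`: inductor's `rand_or_randn` lowering does not accept the `generator` keyword that diffusers' `prepare_latents` passes, even though it is `None` here. A minimal sketch that I would expect to hit the same lowering error, assuming the `.generator` overload is selected whenever the kwarg is passed explicitly (as the FX graph above suggests):

```python
import torch

@torch.compile(backend="inductor")
def f():
    # Passing generator= explicitly (even as None) appears to dispatch to
    # aten.randn.generator, which inductor's randn lowering rejects.
    return torch.randn([1, 4, 96, 96], generator=None,
                       device="cuda", dtype=torch.float16)

f()
```

Until the lowering handles this overload, a possible workaround (untested) is to sample the latents eagerly and hand them to the pipeline, so the compiled graph never contains the `randn` call. `latents` is an existing `StableDiffusionPipeline.__call__` argument, and `prepare_latents` skips its `torch.randn` call when latents are supplied. Using `pipe` and `prompt` from the repro above:

```python
@torch.compile(backend="inductor")
def inference_func(prompt, latents):
    return pipe(prompt, num_inference_steps=1, latents=latents).images[0]

# Shape matches args[0] from the log: [1, 4, 96, 96].
latents = torch.randn((1, pipe.unet.config.in_channels, 96, 96),
                      device="cuda", dtype=torch.float16)
image = inference_func(prompt, latents)
```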
Versions
This is using a nightly build of torch. I believe it worked on a version from before the new year, but I haven't had a chance to bisect to find the exact nightly that regressed.
```
Collecting environment information...
PyTorch version: 2.0.0.dev20230119+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.25.0
Libc version: glibc-2.31
Python version: 3.10.8 (main, Nov 24 2022, 14:13:03) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-46-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: Tesla V100-SXM2-32GB-LS
GPU 1: Tesla V100-SXM2-32GB-LS
GPU 2: Tesla V100-SXM2-32GB-LS
GPU 3: Tesla V100-SXM2-32GB-LS
GPU 4: Tesla V100-SXM2-32GB-LS
GPU 5: Tesla V100-SXM2-32GB-LS
GPU 6: Tesla V100-SXM2-32GB-LS
GPU 7: Tesla V100-SXM2-32GB-LS
Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.24.1
[pip3] pytorch-triton==2.0.0+0d7e753227
[pip3] torch==2.0.0.dev20230119+cu117
[pip3] torchaudio==2.0.0.dev20230119+cu117
[pip3] torchvision==0.15.0.dev20230119+cu117
[conda] numpy 1.24.1 pypi_0 pypi
[conda] pytorch-triton 2.0.0+0d7e753227 pypi_0 pypi
[conda] torch 2.0.0.dev20230119+cu117 pypi_0 pypi
[conda] torchaudio 2.0.0.dev20230119+cu117 pypi_0 pypi
[conda] torchvision 0.15.0.dev20230119+cu117 pypi_0 pypi
```
cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh @mlazos @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire