
[Inductor] [CPU] scaled_dot_product_attention() unexpected a value type caused crash in xcit_large_24_p8_224  #99124

@yudongsi

🐛 Describe the bug

Found in #93531 (comment) (Single-core Single-thread)

Error message:

cpu  eval  xcit_large_24_p8_224                [2023-04-13 18:22:45,243] torch._inductor.utils: [WARNING] make_fallback(aten.cumprod): a decomposition exists, we should switch to it
ERROR:common:Backend dynamo failed in warmup()
Traceback (most recent call last):
  File "/workspace/pytorch/benchmarks/dynamo/common.py", line 1365, in warmup
    fn(model, example_inputs)
  File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 254, in _fn
    return fn(*args, **kwargs)
  File "benchmarks/dynamo/timm_models.py", line 327, in forward_pass
    return mod(*inputs)
  File "/workspace/pytorch/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 401, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 474, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 118, in _fn
    return fn(*args, **kwargs)
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 317, in _convert_frame_assert
    return _compile(
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 169, in time_wrapper
    r = func(*args, **kwargs)
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 382, in _compile
    out_code = transform_code_object(code, transform)
  File "/workspace/pytorch/torch/_dynamo/bytecode_transformation.py", line 683, in transform_code_object
    transformations(instructions, code_options)
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 369, in transform
    tracer.run()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 1892, in run
    super().run()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 611, in run
    and self.step()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 571, in step
    getattr(self, inst.opname)(inst)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 1979, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/workspace/pytorch/torch/_dynamo/output_graph.py", line 622, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/workspace/pytorch/torch/_dynamo/output_graph.py", line 692, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 169, in time_wrapper
    r = func(*args, **kwargs)
  File "/workspace/pytorch/torch/_dynamo/output_graph.py", line 774, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/workspace/pytorch/torch/_dynamo/output_graph.py", line 770, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/workspace/pytorch/torch/_dynamo/debug_utils.py", line 1098, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/workspace/pytorch/torch/_dynamo/backends/inductor.py", line 9, in inductor
    return compile_fx(*args, **kwargs)
  File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 736, in compile_fx
    return aot_autograd(
  File "/workspace/pytorch/torch/_dynamo/backends/common.py", line 62, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/workspace/pytorch/torch/_functorch/aot_autograd.py", line 3093, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 169, in time_wrapper
    r = func(*args, **kwargs)
  File "/workspace/pytorch/torch/_functorch/aot_autograd.py", line 2728, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/workspace/pytorch/torch/_functorch/aot_autograd.py", line 1827, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/workspace/pytorch/torch/_functorch/aot_autograd.py", line 1993, in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
  File "/workspace/pytorch/torch/_functorch/aot_autograd.py", line 1292, in aot_dispatch_base
    compiled_fw = compiler(fw_module, flat_args)
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 169, in time_wrapper
    r = func(*args, **kwargs)
  File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 685, in fw_compiler_base
    joint_graph_passes(model)
  File "/workspace/pytorch/torch/_inductor/fx_passes/joint_graph.py", line 28, in joint_graph_passes
    if patterns.apply(graph.graph):
  File "/workspace/pytorch/torch/_inductor/pattern_matcher.py", line 591, in apply
    if m and entry.extra_check(m):
  File "/workspace/pytorch/torch/_inductor/pattern_matcher.py", line 509, in check_fn
    match.replacement_graph = trace_fn(replace_fn, args)
  File "/workspace/pytorch/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/pytorch/torch/_inductor/pattern_matcher.py", line 673, in inference_graph
    gm = make_fx(fn, select_decomp_table())(*args)
  File "/workspace/pytorch/torch/fx/experimental/proxy_tensor.py", line 761, in wrapped
    t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_autograd), tracer=fx_tracer, concrete_args=tuple(phs))
  File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 254, in _fn
    return fn(*args, **kwargs)
  File "/workspace/pytorch/torch/fx/experimental/proxy_tensor.py", line 467, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 254, in _fn
    return fn(*args, **kwargs)
  File "/workspace/pytorch/torch/fx/_symbolic_trace.py", line 778, in trace
    (self.create_arg(fn(*args)),),
  File "/workspace/pytorch/torch/fx/experimental/proxy_tensor.py", line 484, in wrapped
    out = f(*tensors)
  File "/workspace/pytorch/torch/_inductor/fx_passes/fuse_attention.py", line 47, in _sfdp_replacement_2
    return aten.scaled_dot_product_attention(
  File "/workspace/pytorch/torch/_ops.py", line 646, in __call__
    return self._op(*args, **kwargs or {})
  File "/workspace/pytorch/torch/_inductor/overrides.py", line 33, in __torch_function__
    return replace_fn(func)(*args, **kwargs)
  File "/workspace/pytorch/torch/_ops.py", line 646, in __call__
    return self._op(*args, **kwargs or {})
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: aten::scaled_dot_product_attention() Expected a value of type 'Optional[float]' for argument 'scale' but instead found type 'FakeTensor'.
Position: 6
Value: FakeTensor(FakeTensor(..., device='meta', size=(16, 1, 1)), cpu)
Declaration: aten::scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0., bool is_causal=False, *, float? scale=None) -> Tensor
Cast error details: Unable to cast Python instance to C++ type (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

ERROR
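The error above reports that the value bound to `scale` was a tensor of size (16, 1, 1), while the operator schema declares `float? scale`. As a rough illustration only (this is not the Inductor code, and `FakeScaleTensor`, `scale_is_python_number`, and `pick_attention_path` are hypothetical names invented for this sketch), a guard of the following kind would reject a rewrite when the matched scale is not a plain Python number:

```python
from numbers import Real
from typing import Any


class FakeScaleTensor:
    """Hypothetical stand-in for the (16, 1, 1) FakeTensor in the error above."""

    def __init__(self, size):
        self.size = tuple(size)


def scale_is_python_number(scale: Any) -> bool:
    """Return True for None or a plain Python number, False for anything else."""
    if scale is None:
        return True  # None means "use the operator's default scale"
    # bool is a subclass of int, so exclude it explicitly.
    return isinstance(scale, Real) and not isinstance(scale, bool)


def pick_attention_path(scale: Any) -> str:
    """Hypothetical helper: choose the fused rewrite only when the scale is safe."""
    return "fused" if scale_is_python_number(scale) else "unfused"


print(pick_attention_path(0.125))
print(pick_attention_path(FakeScaleTensor((16, 1, 1))))
```

Under this assumption, a tensor-valued scale would leave the original unfused attention graph in place instead of failing during tracing of the replacement.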

SW information:

| SW | Nightly commit | Master/Main commit |
|---|---|---|
| Pytorch | 84ec5d9 | 6ff32b5 |
| Torchbench | / | 83a316df |
| torchaudio | 375e751 | a8f4e97 |
| torchtext | 9749082 | 46e7eef |
| torchvision | 8d15ca7 | 98c5815 |
| torchdata | b3048d5 | f1283eb |
| dynamo_benchmarks | 1238ae3 | / |

Log with TORCH_COMPILE_DEBUG=1:
single_thread_xcit_large_24_p8_224_rerun__bench_20230414_023546.log

Repro:

python -m torch.backends.xeon.run_cpu --core_list 0 --ncores_per_instance 1 benchmarks/dynamo/timm_models.py --performance --float32 -dcpu -n50 --inductor --no-skip --dashboard --only xcit_large_24_p8_224 --cold_start_latency --batch_size 1 --threads 1

cc @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire

Labels: module: inductor, triaged
