Skip to content

[Inductor] Joint graph ConstantFolder error: AttributeError: 'float' object has no attribute 'meta' #103924

@yanboliang

Description

@yanboliang

🐛 Describe the bug

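Repro (the eager call succeeds; the call to the module compiled with `torch.compile` and the inductor backend fails):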

import torch
import torch._dynamo
import logging
import torch.nn.functional as F
import numpy as np

# Enable DEBUG-level logging for the dynamo, AOT (AOTAutograd) and inductor
# components so their debug output is visible while reproducing the failure.
torch._logging.set_logs(dynamo=logging.DEBUG, aot=logging.DEBUG, inductor=logging.DEBUG)

class MyModule(torch.nn.Module):
    """Minimal module used to reproduce the inductor ConstantFolder failure.

    The forward pass divides each row of a 2-D input by a tensor of ones that
    has been multiplied by the Python float ``1.0``, then applies a softmax
    over dim 1. Numerically the result equals ``softmax(x, dim=1)``.

    NOTE(review): the ``1.0 * torch.ones(...)`` form is kept deliberately.
    The traceback shows ``remove_no_ops`` reading ``replacement.meta`` on a
    ``float``; presumably that float is this scalar multiplier. Confirm
    before simplifying it, because doing so may stop the repro from failing.
    """

    def __init__(self):
        super().__init__()
        # Row-wise softmax applied at the end of forward().
        self.layer = torch.nn.Softmax(dim=1)
        # Not read by forward(); kept so the module's attributes match the
        # original report.
        self.temperature = 1

    def forward(self, x):
        """Return ``softmax(x / (1.0 * ones(rows))[..., None], dim=1)``.

        ``x`` must be 2-D; unpacking ``x.shape`` into two names raises
        ``ValueError`` for any other rank.
        """
        rows, _ = x.shape
        # Per-row divisor built with a scalar multiply. Keep this exact form
        # so the compiled graph still exercises the failing pass.
        divisor = 1.0 * torch.ones(rows, dtype=x.dtype, device=x.device)
        # (rows,) -> (rows, 1) so the divisor broadcasts across columns.
        scaled = x / divisor[..., None]
        return self.layer(scaled)


# 2-D input (MyModule.forward unpacks exactly two dims); 4x4 matches the report.
x = torch.rand([4, 4])

# Eager run: completes and prints the row-wise softmax of x.
m = MyModule()
print(m(x))

# Compiled run: compilation happens on the first call. Per the traceback in
# this report, that call raises BackendCompilerFailed (AttributeError: 'float'
# object has no attribute 'meta') from inductor's joint-graph passes.
opt_m = torch.compile(backend="inductor")(m)
print(opt_m(x))

Error:

Traceback (most recent call last):
  File "/scratch/ybliang/work/repos/debug/debug7.py", line 29, in <module>
    print(opt_m(x))
  File "/scratch/ybliang/work/repos/pytorch/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/ybliang/work/repos/pytorch/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/eval_frame.py", line 295, in _fn
    return fn(*args, **kwargs)
  File "/scratch/ybliang/work/repos/pytorch/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/ybliang/work/repos/pytorch/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/eval_frame.py", line 448, in catch_errors
    return callback(frame, cache_size, hooks, frame_state)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 526, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 127, in _fn
    return fn(*args, **kwargs)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 360, in _convert_frame_assert
    return _compile(
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 180, in time_wrapper
    r = func(*args, **kwargs)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 430, in _compile
    out_code = transform_code_object(code, transform)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
    transformations(instructions, code_options)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 415, in transform
    tracer.run()
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 2026, in run
    super().run()
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 708, in run
    and self.step()
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 668, in step
    getattr(self, inst.opname)(inst)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 2114, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/output_graph.py", line 763, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/scratch/ybliang/work/env/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/output_graph.py", line 859, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 180, in time_wrapper
    r = func(*args, **kwargs)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/output_graph.py", line 915, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/output_graph.py", line 911, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/scratch/ybliang/work/repos/pytorch/torch/__init__.py", line 1530, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/scratch/ybliang/work/repos/pytorch/torch/_inductor/compile_fx.py", line 912, in compile_fx
    return aot_autograd(
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/backends/common.py", line 55, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/scratch/ybliang/work/repos/pytorch/torch/_functorch/aot_autograd.py", line 3713, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 180, in time_wrapper
    r = func(*args, **kwargs)
  File "/scratch/ybliang/work/repos/pytorch/torch/_functorch/aot_autograd.py", line 3252, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/scratch/ybliang/work/repos/pytorch/torch/_functorch/aot_autograd.py", line 2068, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/scratch/ybliang/work/repos/pytorch/torch/_functorch/aot_autograd.py", line 2248, in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
  File "/scratch/ybliang/work/repos/pytorch/torch/_functorch/aot_autograd.py", line 1522, in aot_dispatch_base
    compiled_fw = compiler(fw_module, flat_args)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 180, in time_wrapper
    r = func(*args, **kwargs)
  File "/scratch/ybliang/work/repos/pytorch/torch/_inductor/compile_fx.py", line 799, in fw_compiler_base
    joint_graph_passes(model)
  File "/scratch/ybliang/work/repos/pytorch/torch/_inductor/fx_passes/joint_graph.py", line 195, in joint_graph_passes
    constant_fold_uniform_value(graph)
  File "/scratch/ybliang/work/env/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/scratch/ybliang/work/repos/pytorch/torch/_inductor/fx_passes/joint_graph.py", line 184, in constant_fold_uniform_value
    remove_no_ops(gm, zeros, ones)
  File "/scratch/ybliang/work/env/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/scratch/ybliang/work/repos/pytorch/torch/_inductor/fx_passes/joint_graph.py", line 92, in remove_no_ops
    replace_no_op(node, replace_input_index)
  File "/scratch/ybliang/work/repos/pytorch/torch/_inductor/fx_passes/joint_graph.py", line 48, in replace_no_op
    if not fake_tensors_eq(node.meta["val"], replacement.meta["val"]):
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AttributeError: 'float' object has no attribute 'meta'

Versions

N/A

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78

Metadata

Metadata

Assignees

Labels

module: inductor, oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions