-
Notifications
You must be signed in to change notification settings - Fork 25.6k
Closed
Labels
module: inductor · oncall: pt2 · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🐛 Describe the bug
Repro:
import torch
import torch._dynamo
import logging
import torch.nn.functional as F
import numpy as np
# Turn on DEBUG-level logging for the dynamo, aot and inductor components.
torch._logging.set_logs(dynamo=logging.DEBUG, aot=logging.DEBUG, inductor=logging.DEBUG)
class MyModule(torch.nn.Module):
    """Minimal module used by this bug report's repro.

    ``forward`` divides its input by a tensor built as
    ``1.0 * torch.ones(...)`` and then applies softmax over dim 1.
    """

    def __init__(self):
        super().__init__()
        # Not read by forward(); kept so the repro matches the original report.
        self.temperature = 1
        self.layer = torch.nn.Softmax(dim=1)

    def forward(self, x):
        """Return softmax over dim 1 of ``x`` divided by a ones tensor.

        ``x`` must be 2-D: its shape is unpacked into exactly two values.
        """
        n_samples, _ = x.shape
        # Keep the `1.0 *` multiply and the ones tensor exactly as written.
        # NOTE(review): presumably this float-times-ones pattern is what the
        # failing no-op removal pass trips over -- confirm against the trace.
        y = 1.0 * torch.ones(n_samples, dtype=x.dtype, device=x.device)
        # y[..., None] appends a trailing axis so the division broadcasts
        # across the second dimension of x.
        inp = x / y[..., None]
        return self.layer(inp)
# Eager baseline: run the module on a random 4x4 input and print the result.
x = torch.rand([4, 4])
m = MyModule()
print(m(x))
# Wrap the same module with torch.compile using the inductor backend.
opt_m = torch.compile(backend="inductor")(m)
# This call is the `print(opt_m(x))` frame at the top of the traceback below.
print(opt_m(x))
Error:
Traceback (most recent call last):
File "/scratch/ybliang/work/repos/debug/debug7.py", line 29, in <module>
print(opt_m(x))
File "/scratch/ybliang/work/repos/pytorch/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/scratch/ybliang/work/repos/pytorch/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/eval_frame.py", line 295, in _fn
return fn(*args, **kwargs)
File "/scratch/ybliang/work/repos/pytorch/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/scratch/ybliang/work/repos/pytorch/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/eval_frame.py", line 448, in catch_errors
return callback(frame, cache_size, hooks, frame_state)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 526, in _convert_frame
result = inner_convert(frame, cache_size, hooks, frame_state)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 127, in _fn
return fn(*args, **kwargs)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 360, in _convert_frame_assert
return _compile(
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 180, in time_wrapper
r = func(*args, **kwargs)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 430, in _compile
out_code = transform_code_object(code, transform)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
transformations(instructions, code_options)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 415, in transform
tracer.run()
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 2026, in run
super().run()
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 708, in run
and self.step()
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 668, in step
getattr(self, inst.opname)(inst)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 2114, in RETURN_VALUE
self.output.compile_subgraph(
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/output_graph.py", line 763, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/scratch/ybliang/work/env/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/output_graph.py", line 859, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 180, in time_wrapper
r = func(*args, **kwargs)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/output_graph.py", line 915, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/output_graph.py", line 911, in call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/scratch/ybliang/work/repos/pytorch/torch/__init__.py", line 1530, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/scratch/ybliang/work/repos/pytorch/torch/_inductor/compile_fx.py", line 912, in compile_fx
return aot_autograd(
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/backends/common.py", line 55, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/scratch/ybliang/work/repos/pytorch/torch/_functorch/aot_autograd.py", line 3713, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 180, in time_wrapper
r = func(*args, **kwargs)
File "/scratch/ybliang/work/repos/pytorch/torch/_functorch/aot_autograd.py", line 3252, in create_aot_dispatcher_function
compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
File "/scratch/ybliang/work/repos/pytorch/torch/_functorch/aot_autograd.py", line 2068, in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
File "/scratch/ybliang/work/repos/pytorch/torch/_functorch/aot_autograd.py", line 2248, in aot_wrapper_synthetic_base
return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
File "/scratch/ybliang/work/repos/pytorch/torch/_functorch/aot_autograd.py", line 1522, in aot_dispatch_base
compiled_fw = compiler(fw_module, flat_args)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 180, in time_wrapper
r = func(*args, **kwargs)
File "/scratch/ybliang/work/repos/pytorch/torch/_inductor/compile_fx.py", line 799, in fw_compiler_base
joint_graph_passes(model)
File "/scratch/ybliang/work/repos/pytorch/torch/_inductor/fx_passes/joint_graph.py", line 195, in joint_graph_passes
constant_fold_uniform_value(graph)
File "/scratch/ybliang/work/env/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/scratch/ybliang/work/repos/pytorch/torch/_inductor/fx_passes/joint_graph.py", line 184, in constant_fold_uniform_value
remove_no_ops(gm, zeros, ones)
File "/scratch/ybliang/work/env/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/scratch/ybliang/work/repos/pytorch/torch/_inductor/fx_passes/joint_graph.py", line 92, in remove_no_ops
replace_no_op(node, replace_input_index)
File "/scratch/ybliang/work/repos/pytorch/torch/_inductor/fx_passes/joint_graph.py", line 48, in replace_no_op
if not fake_tensors_eq(node.meta["val"], replacement.meta["val"]):
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AttributeError: 'float' object has no attribute 'meta'
Versions
N/A
cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78
Metadata
Metadata
Assignees
Labels
module: inductor · oncall: pt2 · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)