TorchInductor compile with trace.graph_diagram errors out #125263

Open
AmoghDabholkar opened this issue Apr 30, 2024 · 0 comments
Labels: module: inductor, oncall: pt2, triaged

AmoghDabholkar commented Apr 30, 2024

🐛 Describe the bug

Repro

import torch
import torch._inductor.config as iconfig
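iconfig.trace.enabled = True  # write Inductor debug artifacts under torch_compile_debug/
iconfig.trace.graph_diagram = True  # additionally render graph_diagram.svg (the step that fails)

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(100, 100)

    def forward(self, x):
        return self.l(x)

m = ToyModel().to(device="cuda:0")
m = torch.compile(m)  # default backend is inductor
input_tensor = torch.randn(100).to(device="cuda:0")
out = m(input_tensor)  # first call triggers compilation and raises BackendCompilerFailed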

Logs

[2024-04-30 17:13:04,165] torch._dynamo.eval_frame: [DEBUG] Saving dynamo config and hash for new compiled object(s). Hash: 2efecedd99d1d731eac3660a5ea99e20
[2024-04-30 17:13:04,169] torch._dynamo.convert_frame: [DEBUG] skipping because no torch.* normcase             <frozen posixpath> 52
[2024-04-30 17:13:04,169] torch._dynamo.convert_frame: [DEBUG] skipping because no torch.* __instancecheck__             <frozen abc> 117
[2024-04-30 17:13:04,170] torch._dynamo.convert_frame: [DEBUG] skipping because no torch.* basename             <frozen posixpath> 140
[2024-04-30 17:13:04,170] torch._dynamo.convert_frame: [DEBUG] skipping because no torch.* _get_sep             <frozen posixpath> 41
[2024-04-30 17:13:04,170] torch._dynamo.convert_frame: [DEBUG] skipping because no torch.* splitext             <frozen posixpath> 117
[2024-04-30 17:13:04,170] torch._dynamo.convert_frame: [DEBUG] skipping because no torch.* _splitext             <frozen genericpath> 121
[2024-04-30 17:13:04,169] torch._dynamo.eval_frame: [DEBUG] Setting top-level compile config hash: 2efecedd99d1d731eac3660a5ea99e20
[2024-04-30 17:13:04,173] [0/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward /home/amodab01/mlrsh-filer2/amodab01/ml-training/foo.py:12
[2024-04-30 17:13:04,174] [0/0] torch.fx.experimental.symbolic_shapes: [INFO] create_env
[2024-04-30 17:13:04,176] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /home/amodab01/mlrsh-filer2/amodab01/ml-training/foo.py:12 in forward (ToyModel.forward)
[2024-04-30 17:13:04,176] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]         def forward(self, x):
[2024-04-30 17:13:04,178] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE RESUME 0 []
[2024-04-30 17:13:04,178] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /home/amodab01/mlrsh-filer2/amodab01/ml-training/foo.py:13 in forward (ToyModel.forward)
[2024-04-30 17:13:04,178] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]             return self.relu(self.l(x))
[2024-04-30 17:13:04,178] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST self []
[2024-04-30 17:13:04,178] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_METHOD relu [LazyVariableTracker()]
[2024-04-30 17:13:04,179] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST self [NullVariable, NNModuleVariable()]
[2024-04-30 17:13:04,179] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_METHOD l [NullVariable, NNModuleVariable(), NNModuleVariable()]
[2024-04-30 17:13:04,179] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST x [NullVariable, NNModuleVariable(), NullVariable, NNModuleVariable()]
[2024-04-30 17:13:04,179] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE PRECALL 1 [NullVariable, NNModuleVariable(), NullVariable, NNModuleVariable(), LazyVariableTracker()]
[2024-04-30 17:13:04,179] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL 1 [NullVariable, NNModuleVariable(), NullVariable, NNModuleVariable(), LazyVariableTracker()]
[2024-04-30 17:13:04,180] [0/0] torch._dynamo.output_graph: [DEBUG] create_graph_input L_x_ L['x']
[2024-04-30 17:13:04,180] [0/0] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['x'] (100,) [<DimDynamic.STATIC: 2>] [None]
[2024-04-30 17:13:04,181] [0/0] torch._dynamo.output_graph.__trace_call: [DEBUG] TRACE FX call l__self___l from /home/amodab01/mlrsh-filer2/amodab01/ml-training/foo.py:13 in forward (ToyModel.forward)
[2024-04-30 17:13:04,181] [0/0] torch._dynamo.output_graph.__trace_call: [DEBUG]         return self.relu(self.l(x))
[2024-04-30 17:13:04,181] [0/0] torch._dynamo.output_graph.__trace_call: [DEBUG]                          ~~~~~~^^^
[2024-04-30 17:13:04,188] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE PRECALL 1 [NullVariable, NNModuleVariable(), TensorVariable()]
[2024-04-30 17:13:04,188] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL 1 [NullVariable, NNModuleVariable(), TensorVariable()]
[2024-04-30 17:13:04,188] [0/0] torch._dynamo.output_graph.__trace_call: [DEBUG] TRACE FX call l__self___relu from /home/amodab01/mlrsh-filer2/amodab01/ml-training/foo.py:13 in forward (ToyModel.forward)
[2024-04-30 17:13:04,188] [0/0] torch._dynamo.output_graph.__trace_call: [DEBUG]         return self.relu(self.l(x))
[2024-04-30 17:13:04,188] [0/0] torch._dynamo.output_graph.__trace_call: [DEBUG]                ~~~~~~~~~^^^^^^^^^^^
[2024-04-30 17:13:04,190] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE RETURN_VALUE None [TensorVariable()]
[2024-04-30 17:13:04,190] [0/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing forward (RETURN_VALUE)
[2024-04-30 17:13:04,190] [0/0] torch._dynamo.symbolic_convert: [DEBUG] RETURN_VALUE triggered compile
[2024-04-30 17:13:04,190] [0/0] torch._dynamo.output_graph: [DEBUG] COMPILING GRAPH due to GraphCompileReason(reason='return_value', user_stack=[<FrameSummary file /home/amodab01/mlrsh-filer2/amodab01/ml-training/foo.py, line 13 in forward>], graph_break=False)
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] TRACED GRAPH
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]  ===== __compiled_fn_0 =====
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]  <eval_with_key>.0 class GraphModule(torch.nn.Module):
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]     def forward(self, L_x_ : torch.Tensor):
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         l_x_ = L_x_
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: /home/amodab01/mlrsh-filer2/amodab01/ml-training/foo.py:13, code: return self.relu(self.l(x))
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         l__self___l = self.L__self___l(l_x_);  l_x_ = None
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         l__self___relu = self.L__self___relu(l__self___l);  l__self___l = None
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         return (l__self___relu,)
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] 
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] Tabulate module missing, please install tabulate to log the graph in tabular format, logging code instead:
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] TRACED GRAPH
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]  ===== __compiled_fn_0 =====
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]  <eval_with_key>.0 class GraphModule(torch.nn.Module):
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]     def forward(self, L_x_ : torch.Tensor):
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]         l_x_ = L_x_
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]         
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]         # File: /home/amodab01/mlrsh-filer2/amodab01/ml-training/foo.py:13, code: return self.relu(self.l(x))
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]         l__self___l = self.L__self___l(l_x_);  l_x_ = None
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]         l__self___relu = self.L__self___relu(l__self___l);  l__self___l = None
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]         return (l__self___relu,)
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]         
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] 
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph_sizes: [DEBUG] TRACED GRAPH TENSOR SIZES
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph_sizes: [DEBUG] ===== __compiled_fn_0 =====
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph_sizes: [DEBUG] l_x_: (100,)
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph_sizes: [DEBUG] l__self___l: (100,)
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph_sizes: [DEBUG] l__self___relu: (100,)
[2024-04-30 17:13:04,191] [0/0] torch._dynamo.output_graph.__graph_sizes: [DEBUG] 
[2024-04-30 17:13:04,192] [0/0] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function inductor
[2024-04-30 17:13:05,170] [0/0] torch.fx.experimental.symbolic_shapes: [DEBUG] eval True == True [statically known]
Writing FX graph to file: /home/amodab01/mlrsh-filer2/amodab01/ml-training/torch_compile_debug/run_2024_04_30_17_13_05_077049-pid_3281505/torchinductor/model__0_forward_10.0/graph_diagram.svg
[2024-04-30 17:13:05,196] [0/0] torch._inductor.debug: [WARNING] model__0_forward_10 debug trace: /home/amodab01/mlrsh-filer2/amodab01/ml-training/torch_compile_debug/run_2024_04_30_17_13_05_077049-pid_3281505/torchinductor/model__0_forward_10.0
[2024-04-30 17:13:05,198] torch._dynamo.eval_frame: [DEBUG] Unsetting top-level compile config hash: 2efecedd99d1d731eac3660a5ea99e20
Traceback (most recent call last):
  File "/home/amodab01/mlrsh-filer2/amodab01/ml-training/foo.py", line 19, in <module>
    out = m(input_tensor)
          ^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 655, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 727, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
    compiled_product = _compile(
                       ^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 646, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 562, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 151, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 527, in transform
    tracer.run()
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2128, in run
    super().run()
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
        ^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2243, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 919, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1087, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1159, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1140, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/__init__.py", line 1668, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1168, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 55, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 887, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 600, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 425, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 630, in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 295, in aot_dispatch_autograd
    compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1100, in fw_compiler_base
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/repro/after_aot.py", line 83, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_inductor/debug.py", line 305, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 320, in compile_fx_inner
    compiled_graph = fx_codegen_and_compile(
                     ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 550, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
                  ^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1116, in compile_to_fn
    return self.compile_to_module().call
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1066, in compile_to_module
    self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
                                                             ^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1041, in codegen
    self.scheduler = Scheduler(self.buffers)
                     ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_inductor/scheduler.py", line 1250, in __init__
    V.debug.graph_diagram(self.nodes)
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_inductor/debug.py", line 469, in graph_diagram
    draw_buffers(nodes, fname=self.filename("graph_diagram.svg"))
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_inductor/debug.py", line 92, in draw_buffers
    draw_graph(
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_functorch/partitioners.py", line 935, in draw_graph
    g = graph_drawer.FxGraphDrawer(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
TypeError: FxGraphDrawer.__init__() got an unexpected keyword argument 'dot_graph_shape'

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

[2024-04-30 17:13:05,240] torch._dynamo.utils: [INFO] TorchDynamo compilation metrics:
[2024-04-30 17:13:05,240] torch._dynamo.utils: [INFO] Function, Runtimes (s)
[2024-04-30 17:13:05,240] torch._dynamo.utils: [INFO] _compile.<locals>.compile_inner, 0.0000
[2024-04-30 17:13:05,240] torch._dynamo.utils: [INFO] OutputGraph.call_user_compiler, 0.0000
[2024-04-30 17:13:05,240] torch._dynamo.utils: [INFO] create_aot_dispatcher_function, 0.0636
[2024-04-30 17:13:05,240] torch._dynamo.utils: [INFO] compile_fx.<locals>.fw_compiler_base, 0.0000
[2024-04-30 17:13:05,240] torch._dynamo.utils: [INFO] GraphLowering.run, 0.0255
[2024-04-30 17:13:05,240] torch._dynamo.utils: [INFO] GraphLowering.compile_to_module, 0.0000
[2024-04-30 17:13:05,240] torch._dynamo.utils: [INFO] Scheduler.__init__, 0.0000

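This is the output of TORCH_LOGS="+dynamo" python3 foo.py.

Nothing has been modified at all, and yet it errors out: while Scheduler.__init__ renders the graph diagram (graph_diagram -> draw_buffers -> draw_graph), draw_graph in torch/_functorch/partitioners.py passes a dot_graph_shape keyword argument that FxGraphDrawer.__init__() does not accept.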
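The TypeError itself is ordinary Python argument binding: the caller passes a keyword that the callee's __init__ does not declare, and Python rejects the call before the constructor body ever runs. The sketch below reproduces that failure mode with two hypothetical stand-in classes (FullDrawer and StubDrawer, which are not PyTorch code) and uses inspect.signature to check ahead of time whether a class accepts dot_graph_shape. Whether the installed FxGraphDrawer really is a reduced stub like StubDrawer (for example, a fallback defined when pydot cannot be imported) is an assumption this report does not establish.

```python
import inspect


class FullDrawer:
    """Hypothetical drawer whose __init__ declares the newer keyword."""

    def __init__(self, graph, name, ignore_getattr=False, dot_graph_shape=None):
        self.graph = graph
        self.name = name
        self.ignore_getattr = ignore_getattr
        self.dot_graph_shape = dot_graph_shape


class StubDrawer:
    """Hypothetical reduced stand-in with an older, shorter signature."""

    def __init__(self, graph, name, ignore_getattr=False):
        # This helpful message is never reached when an unknown keyword is
        # passed: argument binding fails first and raises TypeError instead.
        raise RuntimeError("drawing backend is not installed")


def accepts_kwarg(cls, kwarg):
    """Return True if cls(...) can receive `kwarg` as a keyword argument."""
    for param in inspect.signature(cls.__init__).parameters.values():
        if param.kind is inspect.Parameter.VAR_KEYWORD:
            return True  # a **kwargs parameter accepts any keyword
        if param.name == kwarg and param.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True
    return False


def draw(drawer_cls):
    # Like the call in the traceback, this passes dot_graph_shape.
    return drawer_cls("graph", "name", dot_graph_shape="record")


if __name__ == "__main__":
    print(accepts_kwarg(FullDrawer, "dot_graph_shape"))  # True
    print(accepts_kwarg(StubDrawer, "dot_graph_shape"))  # False
    try:
        draw(StubDrawer)
    except TypeError as exc:
        print(f"TypeError: {exc}")
```

Against a real install, the same accepts_kwarg check could be pointed at torch.fx.passes.graph_drawer.FxGraphDrawer to see whether the class that draw_graph calls declares dot_graph_shape.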

Versions

Collecting environment information...
PyTorch version: 2.2.2+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.23.2
Libc version: glibc-2.35

Python version: 3.11.8 (main, Feb 26 2024, 21:39:34) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-94-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.5.119
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: Quadro RTX 6000
GPU 1: Quadro RTX 6000
GPU 2: Quadro RTX 6000
GPU 3: Quadro RTX 6000

Nvidia driver version: 525.147.05
cuDNN version: Probably one of the following:
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn.so.8.4.1
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.4.1
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.4.1
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.4.1
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.4.1
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.4.1
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.4.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 32
On-line CPU(s) list: 0-31
Vendor ID: GenuineIntel
Model name: Intel(R) Core(TM) i9-9960X CPU @ 3.10GHz
CPU family: 6
Model: 85
Thread(s) per core: 2
Core(s) per socket: 16
Socket(s): 1
Stepping: 4
CPU max MHz: 4500.0000
CPU min MHz: 1200.0000
BogoMIPS: 6199.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single pti ssbd mba ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req md_clear flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 512 KiB (16 instances)
L1i cache: 512 KiB (16 instances)
L2 cache: 16 MiB (16 instances)
L3 cache: 22 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-31
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Mitigation; PTE Inversion; VMX conditional cache flushes, SMT vulnerable
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed: Mitigation; IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; IBRS, IBPB conditional, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; Clear CPU buffers; SMT vulnerable

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] rotary-embedding-torch==0.5.3
[pip3] torch==2.2.2+cu121
[pip3] torchsummary==1.5.1
[pip3] torchvision==0.17.2
[pip3] triton==2.2.0
[conda] numpy 1.26.4 pypi_0 pypi
[conda] rotary-embedding-torch 0.5.3 pypi_0 pypi
[conda] torch 2.2.2+cu121 pypi_0 pypi
[conda] torchsummary 1.5.1 pypi_0 pypi
[conda] torchvision 0.17.2 pypi_0 pypi
[conda] triton 2.2.0 pypi_0 pypi
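The library list above does not show whether pydot or the Graphviz dot executable is installed. FX's graph drawer is built on pydot, and pydot in turn calls Graphviz to render SVG, so both seem relevant to a graph_diagram failure. Below is a minimal standard-library sketch for reporting them. The choice of which packages and executables to check is an assumption, and PyTorch's environment script does not collect these entries itself.

```python
import importlib.metadata
import shutil
from typing import Dict, Iterable, Optional


def optional_dep_report(
    packages: Iterable[str] = ("pydot",),
    executables: Iterable[str] = ("dot",),
) -> Dict[str, Optional[str]]:
    """Map each Python distribution to its installed version and each
    executable to its resolved path; None means "not found"."""
    report: Dict[str, Optional[str]] = {}
    for pkg in packages:
        try:
            report[pkg] = importlib.metadata.version(pkg)
        except importlib.metadata.PackageNotFoundError:
            report[pkg] = None
    for exe in executables:
        # Graphviz's layout/rendering program is the "dot" binary on PATH.
        report[exe] = shutil.which(exe)
    return report


if __name__ == "__main__":
    for name, found in optional_dep_report().items():
        print(f"{name}: {found or 'not found'}")
```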

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire

AmoghDabholkar changed the title "TorchInductor graph erros out" to "TorchInductor graph errors out" Apr 30, 2024
AmoghDabholkar changed the title "TorchInductor graph errors out" to "TorchInductor compile with trace.graph_diagram errors out" Apr 30, 2024
yanboliang added the triaged, oncall: pt2 and module: inductor labels Apr 30, 2024