
Dynamo doesn't log backward graphs in compilation metrics #125313

Closed
yanboliang opened this issue May 1, 2024 · 3 comments
Assignees: yanboliang
Labels: module: dynamo, oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

yanboliang (Contributor) commented May 1, 2024

🐛 Describe the bug

Repro:

import torch
import torch.nn.functional as F

class BasicModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.scale = torch.randn(1, 10)

    def forward(self, x):
        x = self.linear1(x)
        torch._dynamo.graph_break()
        return F.relu(x) * self.scale

x = torch.ones(1, 10)
mod = BasicModule()
mod = torch.compile(mod)
print(mod(x))

mod(x).sum().backward()

with the following changes:

diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index 7f3b2ec95f..b98dabb6f0 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -420,6 +420,7 @@ def compile_fx_inner(
     layout_opt: Optional[bool] = None,
     extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
 ) -> Union[CompiledFxGraph, str]:
+    print("running compile_fx_inner")
     """
     Inductor API that compiles a single graph.
 
diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py
index 3da8bc2186..7dcd65e4f5 100644
--- a/torch/_utils_internal.py
+++ b/torch/_utils_internal.py
@@ -92,7 +92,7 @@ def signpost_event(category: str, name: str, parameters: Dict[str, Any]):
 
 
 def log_compilation_event(metrics):
-    log.info("%s", metrics)
+    print("%s", metrics)

Output:

running compile_fx_inner
%s CompilationMetrics(frame_key='1', co_name='forward', co_filename='/data/users/ybliang/debug/debug2.py', co_firstlineno=10, cache_size=0, accumulated_cache_size=0, guard_count=11, shape_env_guard_count=0, graph_op_count=1, graph_node_count=3, graph_input_count=1, start_time=1716076212.7113512, entire_frame_compile_time_s=6.625730991363525, backend_compile_time_s=6.587674379348755, inductor_compile_time_s=4.015706300735474, code_gen_time_s=3.9247825145721436, fail_type=None, fail_reason=None, fail_user_frame_filename=None, fail_user_frame_lineno=None, non_compliant_ops=set(), compliant_custom_ops=set(), restart_reasons={"'skip function graph_break in file /home/ybliang/local/pytorch/torch/_dynamo/decorators.py'"}, dynamo_time_before_restart_s=0.0244600772857666, has_guarded_code=True)
running compile_fx_inner
%s CompilationMetrics(frame_key='2', co_name='torch_dynamo_resume_in_forward_at_12', co_filename='/data/users/ybliang/debug/debug2.py', co_firstlineno=12, cache_size=0, accumulated_cache_size=0, guard_count=10, shape_env_guard_count=0, graph_op_count=2, graph_node_count=5, graph_input_count=1, start_time=1716076219.3381996, entire_frame_compile_time_s=0.10369133949279785, backend_compile_time_s=0.09059739112854004, inductor_compile_time_s=0.03840017318725586, code_gen_time_s=0.02279829978942871, fail_type=None, fail_reason=None, fail_user_frame_filename=None, fail_user_frame_lineno=None, non_compliant_ops=set(), compliant_custom_ops=set(), restart_reasons=set(), dynamo_time_before_restart_s=0.0, has_guarded_code=True)
tensor([[-0.5529,  0.0000, -0.1201, -0.0000, -0.0802,  0.1059, -0.1784, -0.0000,
         -0.0401,  0.0000]], grad_fn=<CompiledFunctionBackward>)
running compile_fx_inner
running compile_fx_inner

compile_fx_inner has been called twice, once for the forward graph and once for the backward graph. However, we log compilation metrics only for the forward graph, because the backward graph's Inductor compilation is not captured by convert_frame. This causes confusion when debugging compilation-time issues, since the metrics are incomplete.
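As an aside, the two backend invocations are easy to see with a trivial counting backend. The sketch below is hypothetical and is not part of the repro or of the fix; it uses the aot_autograd backend wrapper, and counting_compiler, counting_backend, and graphs_seen are made-up names. AOTAutograd routes the partitioned forward and backward graphs through the fw_compiler / bw_compiler hooks, analogous to how the inductor backend routes them through compile_fx_inner.

```python
import torch
from functorch.compile import make_boxed_func
from torch._dynamo.backends.common import aot_autograd

graphs_seen = []

def counting_compiler(gm, example_inputs):
    # Called once per partitioned graph that reaches the backend.
    graphs_seen.append(gm)
    print(f"backend compiler invoked, graphs so far: {len(graphs_seen)}")
    return make_boxed_func(gm.forward)  # run the captured graph unmodified

counting_backend = aot_autograd(
    fw_compiler=counting_compiler, bw_compiler=counting_compiler
)

mod = torch.compile(torch.nn.Linear(10, 10), backend=counting_backend)
out = mod(torch.ones(1, 10))  # forward graph reaches the compiler here
out.sum().backward()          # backward graph may only reach it now (lazily)
```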

Versions

N/A

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng

ezyang (Contributor) commented May 1, 2024

Backwards compilation is sometimes lazy. I recently added a structured_trace for AOTAutograd backwards compilation and we can also populate the scuba table with it too, although it must be separate...

yanboliang (Contributor, Author) replied:

> Backwards compilation is sometimes lazy. I recently added a structured_trace for AOTAutograd backwards compilation and we can also populate the scuba table with it too, although it must be separate...

Right, we should populate the backward graph rows in the Scuba table. But I think we also need a key that connects the fwd and bwd graphs, so we can tell whether graph X's fwd or bwd is the problematic one.

yanboliang self-assigned this May 1, 2024
ezyang (Contributor) commented May 1, 2024

That's the compile id ;)
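For reference, here is a minimal sketch of how the compile id could serve as that key, assuming the metrics records carry the compile_id and is_fwd fields added by the fix referenced below; metrics_log and pair_fwd_bwd are hypothetical names for a collected list of CompilationMetrics objects and the grouping helper:

```python
from collections import defaultdict

def pair_fwd_bwd(metrics_log):
    """Group CompilationMetrics records by compile_id so each forward entry
    can be matched with its (possibly lazily compiled) backward entry."""
    by_id = defaultdict(dict)
    for m in metrics_log:
        by_id[m.compile_id]["fwd" if m.is_fwd else "bwd"] = m
    return by_id

# usage: by_id = pair_fwd_bwd(metrics_log)
#        compare by_id["1/0"]["fwd"].inductor_compile_time_s
#        with    by_id["1/0"]["bwd"].inductor_compile_time_s
```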

xmfan added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on May 20, 2024
petrex pushed a commit to petrex/pytorch that referenced this issue Jun 5, 2024
Fixes pytorch#125313

Compilation metric logs for the code example at pytorch#125313:
```
%s CompilationMetrics(compile_id='0/0', frame_key='1', co_name='forward', co_filename='/data/users/ybliang/debug/debug2.py', co_firstlineno=10, cache_size=0, accumulated_cache_size=0, guard_count=11, shape_env_guard_count=0, graph_op_count=1, graph_node_count=3, graph_input_count=1, start_time=1716247236.6165977, entire_frame_compile_time_s=7.926939964294434, backend_compile_time_s=7.887059926986694, inductor_compile_time_s=4.108498811721802, code_gen_time_s=3.97833514213562, fail_type=None, fail_reason=None, fail_user_frame_filename=None, fail_user_frame_lineno=None, non_compliant_ops=set(), compliant_custom_ops=set(), restart_reasons={"'skip function graph_break in file /home/ybliang/local/pytorch/torch/_dynamo/decorators.py'"}, dynamo_time_before_restart_s=0.025330543518066406, has_guarded_code=True, is_fwd=True)
%s CompilationMetrics(compile_id='1/0', frame_key='2', co_name='torch_dynamo_resume_in_forward_at_12', co_filename='/data/users/ybliang/debug/debug2.py', co_firstlineno=12, cache_size=0, accumulated_cache_size=0, guard_count=10, shape_env_guard_count=0, graph_op_count=2, graph_node_count=5, graph_input_count=1, start_time=1716247244.544928, entire_frame_compile_time_s=0.10148310661315918, backend_compile_time_s=0.08753013610839844, inductor_compile_time_s=0.03691983222961426, code_gen_time_s=0.022417306900024414, fail_type=None, fail_reason=None, fail_user_frame_filename=None, fail_user_frame_lineno=None, non_compliant_ops=set(), compliant_custom_ops=set(), restart_reasons=set(), dynamo_time_before_restart_s=0.0, has_guarded_code=True, is_fwd=True)
tensor([[-0.1622, -0.0000, -0.0000,  0.5643, -0.0000,  0.0000, -0.5087,  0.0914,
         -0.0000, -0.0421]], grad_fn=<CompiledFunctionBackward>)
%s CompilationMetrics(compile_id='1/0', frame_key=None, co_name=None, co_filename=None, co_firstlineno=None, cache_size=None, accumulated_cache_size=None, guard_count=None, shape_env_guard_count=None, graph_op_count=None, graph_node_count=None, graph_input_count=None, start_time=None, entire_frame_compile_time_s=None, backend_compile_time_s=None, inductor_compile_time_s=0.026738643646240234, code_gen_time_s=0.016446352005004883, fail_type=None, fail_reason=None, fail_user_frame_filename=None, fail_user_frame_lineno=None, non_compliant_ops=None, compliant_custom_ops=None, restart_reasons=None, dynamo_time_before_restart_s=None, has_guarded_code=None, is_fwd=False)
%s CompilationMetrics(compile_id='0/0', frame_key=None, co_name=None, co_filename=None, co_firstlineno=None, cache_size=None, accumulated_cache_size=None, guard_count=None, shape_env_guard_count=None, graph_op_count=None, graph_node_count=None, graph_input_count=None, start_time=None, entire_frame_compile_time_s=None, backend_compile_time_s=None, inductor_compile_time_s=0.14563536643981934, code_gen_time_s=0.08652091026306152, fail_type=None, fail_reason=None, fail_user_frame_filename=None, fail_user_frame_lineno=None, non_compliant_ops=None, compliant_custom_ops=None, restart_reasons=None, dynamo_time_before_restart_s=None, has_guarded_code=None, is_fwd=False)
```

Pull Request resolved: pytorch#126629
Approved by: https://github.com/ezyang