
[pt2d] Module register_forward_pre_hook and register_forward_hook trigger a graph break when registered on the root module #117584

@wanchaol

🐛 Describe the bug

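The issue appears when the hook is registered on the "root" module (the module passed to torch.compile): Dynamo does not inline the hook call into the same graph as the module's forward, which causes a graph break.
If the hook is registered on a submodule instead, Dynamo inlines it into the graph as expected.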

# test_compile.py
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(18, 18, bias=False)

    def forward(self, x):
        return self.net1(x)

mod = ToyModel()
# Forward pre-hook on the root module: its return value replaces the input,
# so forward() receives input[0] + 1.
mod.register_forward_pre_hook(
    lambda mod, input: input[0] + 1
)

compiled_mod = torch.compile(mod)
compiled_mod(torch.rand(18, 18))
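In eager mode, when a forward pre-hook returns a non-tuple value, PyTorch wraps it in a one-element tuple and uses it as the new positional input. The hook above therefore makes forward() run on input[0] + 1. The following is a small eager-only check of this behavior on CPU. The handle variable and the comparison are illustrative only.

```python
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(18, 18, bias=False)

    def forward(self, x):
        return self.net1(x)

mod = ToyModel()
handle = mod.register_forward_pre_hook(lambda mod, input: input[0] + 1)

x = torch.rand(18, 18)
with_hook = mod(x)            # hook shifts the input: net1(x + 1)

handle.remove()               # detach the hook again
shifted_input = mod(x + 1)    # same computation, done by hand

print(torch.allclose(with_hook, shifted_input))  # True
```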

Run with:

TORCH_LOGS="graph" python test_compile.py

Output:

[2024-01-16 14:27:41,383] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] TRACED GRAPH
[2024-01-16 14:27:41,383] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]  __compiled_fn_0 <eval_with_key>.0 opcode         name        target                   args             kwargs
[2024-01-16 14:27:41,383] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] -------------  ----------  -----------------------  ---------------  --------
[2024-01-16 14:27:41,383] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] placeholder    l_input_0_  L_input_0_               ()               {}
[2024-01-16 14:27:41,383] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] call_function  add         <built-in function add>  (l_input_0_, 1)  {}
[2024-01-16 14:27:41,383] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] output         output      output                   ((add,),)        {}
[2024-01-16 14:27:41,383] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] 
[2024-01-16 14:27:46,032] [1/0] torch._dynamo.output_graph.__graph: [DEBUG] TRACED GRAPH
[2024-01-16 14:27:46,032] [1/0] torch._dynamo.output_graph.__graph: [DEBUG]  __compiled_fn_1 <eval_with_key>.36 opcode       name            target          args                  kwargs
[2024-01-16 14:27:46,032] [1/0] torch._dynamo.output_graph.__graph: [DEBUG] -----------  --------------  --------------  --------------------  --------
[2024-01-16 14:27:46,032] [1/0] torch._dynamo.output_graph.__graph: [DEBUG] placeholder  l_x_            L_x_            ()                    {}
[2024-01-16 14:27:46,032] [1/0] torch._dynamo.output_graph.__graph: [DEBUG] call_module  l__self___net1  L__self___net1  (l_x_,)               {}
[2024-01-16 14:27:46,032] [1/0] torch._dynamo.output_graph.__graph: [DEBUG] output       output          output          ((l__self___net1,),)  {}

Versions

main branch

cc @H-Huang @awgu @kwen2501 @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @XilunWu @tianyu-l @yf225 @ezyang @anijain2305 @rohan-varma @msaroufim @bdhirsh @zou3519 @aakhundov
