Open
Labels
dynamo-must-fix, dynamo-nn-modules, dynamo-triage-june2024, module: dynamo, oncall: distributed, oncall: pt2, pt2d-triage-nov2024, triaged
Description
🐛 Describe the bug
The issue appears when a hook is registered on the "root" module (the module passed to torch.compile): the hook call is not inlined into the module's graph.
If the hook is instead registered on a non-root submodule, Dynamo inlines it into the graph as expected.
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(18, 18, bias=False)

    def forward(self, x):
        return self.net1(x)

mod = ToyModel()
mod.register_forward_pre_hook(
    lambda mod, input: input[0] + 1
)
compiled_mod = torch.compile(mod)
compiled_mod(torch.rand(18, 18))
TORCH_LOGS="graph" python test_compile.py
outputs:
[2024-01-16 14:27:41,383] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] TRACED GRAPH
[2024-01-16 14:27:41,383] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] __compiled_fn_0 <eval_with_key>.0 opcode         name        target                   args             kwargs
[2024-01-16 14:27:41,383] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] -------------  ----------  -----------------------  ---------------  --------
[2024-01-16 14:27:41,383] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] placeholder    l_input_0_  L_input_0_               ()               {}
[2024-01-16 14:27:41,383] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] call_function  add         <built-in function add>  (l_input_0_, 1)  {}
[2024-01-16 14:27:41,383] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] output         output      output                   ((add,),)        {}
[2024-01-16 14:27:41,383] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]
[2024-01-16 14:27:46,032] [1/0] torch._dynamo.output_graph.__graph: [DEBUG] TRACED GRAPH
[2024-01-16 14:27:46,032] [1/0] torch._dynamo.output_graph.__graph: [DEBUG] __compiled_fn_1 <eval_with_key>.36 opcode       name            target          args                  kwargs
[2024-01-16 14:27:46,032] [1/0] torch._dynamo.output_graph.__graph: [DEBUG] -----------  --------------  --------------  --------------------  --------
[2024-01-16 14:27:46,032] [1/0] torch._dynamo.output_graph.__graph: [DEBUG] placeholder  l_x_            L_x_            ()                    {}
[2024-01-16 14:27:46,032] [1/0] torch._dynamo.output_graph.__graph: [DEBUG] call_module  l__self___net1  L__self___net1  (l_x_,)               {}
[2024-01-16 14:27:46,032] [1/0] torch._dynamo.output_graph.__graph: [DEBUG] output       output          output          ((l__self___net1,),)  {}
Versions
main branch
cc @H-Huang @awgu @kwen2501 @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @XilunWu @tianyu-l @yf225 @ezyang @anijain2305 @rohan-varma @msaroufim @bdhirsh @zou3519 @aakhundov