-
Notifications
You must be signed in to change notification settings - Fork 25.6k
Closed
Description
🐛 Describe the bug
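Re-exporting the module recovered from an ExportedProgram (`ep.module()`) produces a program, `ep2`, that contains an empty `submod_1`.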
import torch
class SetGradCase(torch.nn.Module):
    """Minimal repro module: scales its input by 4 inside ``torch.no_grad()``.

    Exporting this module, then exporting ``ep.module()`` again, is what
    produces the empty ``submod_1`` shown in this report.
    """

    def forward(self, x):
        # The multiplication runs with grad mode disabled; export records this
        # span as a set-grad region (wrap_with_set_grad_enabled(False, ...)).
        with torch.no_grad():
            y = x * 4
        # Returned after leaving the no_grad block, so grad mode is restored first.
        return y
# Step 1: non-strict export of the eager module.
eager_model = SetGradCase()
example_args = (torch.randn(6),)
ep = torch.export.export(eager_model, example_args, strict=False)
print(ep)

# Step 2: export the module recovered from the first program again, with no
# explicit `strict` argument. This second printout is where the empty
# submod_1 appears.
exported_module = ep.module()
ep2 = torch.export.export(exported_module, (torch.randn(6),))
print(ep2)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[6]"):
# No stacktrace found for following nodes
submod_3 = self.submod_1
mul = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, x); submod_3 = x = None
# File: /data/users/shangdiy/torchnative/tmp.py:22 in forward, code: y = x * 4
getitem: "f32[6]" = mul[0]; mul = None
return (getitem,)
class submod_1(torch.nn.Module):
def forward(self, x: "f32[6]"):
# File: /data/users/shangdiy/torchnative/tmp.py:22 in forward, code: y = x * 4
mul: "f32[6]" = torch.ops.aten.mul.Tensor(x, 4); x = None
return (mul,)
Graph signature:
# inputs
x: USER_INPUT
# outputs
getitem: USER_OUTPUT
Range constraints: {}
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[6]"):
# No stacktrace found for following nodes
submod_4 = self.submod_1
wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_4); submod_4 = wrap_with_set_grad_enabled = None
submod_5 = self.submod_2
mul = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_5, x); submod_5 = x = None
# File: <eval_with_key>.12:5 in forward, code: mul = torch.ops.aten.mul.Tensor(x, 4); x = None
getitem: "f32[6]" = mul[0]; mul = None
return (getitem,)
class submod_1(torch.nn.Module):
def forward(self):
return ()
class submod_2(torch.nn.Module):
def forward(self, x: "f32[6]"):
# File: <eval_with_key>.12:5 in forward, code: mul = torch.ops.aten.mul.Tensor(x, 4); x = None
mul: "f32[6]" = torch.ops.aten.mul.Tensor(x, 4); x = None
return (mul,)
Graph signature:
# inputs
x: USER_INPUT
# outputs
getitem: USER_OUTPUT
Range constraints: {}
Some initial investigation:
After re-tracing, and before the set-grad regions are replaced with the higher-order op (HOP), the graph looks like this:
# FX graph of the re-traced program, captured before the set-grad regions are
# replaced with wrap_with_set_grad_enabled higher-order ops.
def forward(self, arg0_1):
# Grad mode is switched off here...
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
# ...and switched off again immediately, with no ops in between.
# NOTE(review): presumably this empty span between the two calls is what
# becomes the empty submod_1 (wrapped with grad disabled, no inputs) in ep2.
_set_grad_enabled_1 = torch._C._set_grad_enabled(False); _set_grad_enabled_1 = None
mul = torch.ops.aten.mul.Tensor(arg0_1, 4); arg0_1 = None
# Grad mode is restored after the multiplication.
_set_grad_enabled_2 = torch._C._set_grad_enabled(True); _set_grad_enabled_2 = None
return (mul,)
cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4
Versions
main