
retracing set_grad HOP creates an empty submod #163294


Description

@yushangdi

🐛 Describe the bug

After re-tracing, ep2 contains an empty submod_1: a submodule that takes no inputs and returns an empty tuple.

import torch

class SetGradCase(torch.nn.Module):
    def forward(self, x):
        # The no_grad region is what export captures as a
        # wrap_with_set_grad_enabled HOP with its own submodule.
        with torch.no_grad():
            y = x * 4
        return y

ep = torch.export.export(
    SetGradCase(),
    (torch.randn(6),),
    strict=False,
)
print(ep)

# Re-exporting the unlifted module should round-trip, but instead
# produces an extra, empty submod_1.
ep2 = torch.export.export(ep.module(), (torch.randn(6),))
print(ep2)
print(ep) shows the no_grad region captured as a single submodule, as expected:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[6]"):
            # No stacktrace found for following nodes
            submod_3 = self.submod_1
            mul = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, x);  submod_3 = x = None
            
            # File: /data/users/shangdiy/torchnative/tmp.py:22 in forward, code: y = x * 4
            getitem: "f32[6]" = mul[0];  mul = None
            return (getitem,)
            
        class submod_1(torch.nn.Module):
            def forward(self, x: "f32[6]"):
                # File: /data/users/shangdiy/torchnative/tmp.py:22 in forward, code: y = x * 4
                mul: "f32[6]" = torch.ops.aten.mul.Tensor(x, 4);  x = None
                return (mul,)
                
Graph signature: 
    # inputs
    x: USER_INPUT
    
    # outputs
    getitem: USER_OUTPUT
    
Range constraints: {}

print(ep2) shows the extra, empty submod_1:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[6]"):
            # No stacktrace found for following nodes
            submod_4 = self.submod_1
            wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_4);  submod_4 = wrap_with_set_grad_enabled = None
            submod_5 = self.submod_2
            mul = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_5, x);  submod_5 = x = None
            
            # File: <eval_with_key>.12:5 in forward, code: mul = torch.ops.aten.mul.Tensor(x, 4);  x = None
            getitem: "f32[6]" = mul[0];  mul = None
            return (getitem,)
            
        class submod_1(torch.nn.Module):
            def forward(self):
                return ()
                
        class submod_2(torch.nn.Module):
            def forward(self, x: "f32[6]"):
                # File: <eval_with_key>.12:5 in forward, code: mul = torch.ops.aten.mul.Tensor(x, 4);  x = None
                mul: "f32[6]" = torch.ops.aten.mul.Tensor(x, 4);  x = None
                return (mul,)
                
Graph signature: 
    # inputs
    x: USER_INPUT
    
    # outputs
    getitem: USER_OUTPUT
    
Range constraints: {}
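
For reference, one quick way to confirm that the extra submodule really is empty, using the attribute name submod_1 from the printout above:

# submod_1 is the empty child module of ep2's graph module.
empty = ep2.graph_module.get_submodule("submod_1")
print(list(empty.graph.nodes))  # just the output node, no compute
print(empty())                  # ()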

Some initial investigation:

After re-tracing, and before the _set_grad_enabled calls are replaced with the HOP, the graph looks like this:

def forward(self, arg0_1):
    _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None
    _set_grad_enabled_1 = torch._C._set_grad_enabled(False);  _set_grad_enabled_1 = None
    mul = torch.ops.aten.mul.Tensor(arg0_1, 4);  arg0_1 = None
    _set_grad_enabled_2 = torch._C._set_grad_enabled(True);  _set_grad_enabled_2 = None
    return (mul,)
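
Note the two back-to-back _set_grad_enabled(False) calls with nothing between them: the pass that groups nodes into grad-mode regions appears to turn that empty span into its own submodule, which becomes the empty submod_1. A minimal sketch of a cleanup that drops the redundant toggle before regions are formed (dedupe_set_grad_enabled is a hypothetical helper, not the actual fix location):

import torch
import torch.fx

def dedupe_set_grad_enabled(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Track the most recent _set_grad_enabled node; a later call with the
    # same argument is a semantic no-op, and the empty span between the two
    # calls is what ends up as its own (empty) submodule.
    last = None
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target is torch._C._set_grad_enabled:
            if last is not None and node.args == last.args:
                # Safe to erase: these nodes have no users in the graph.
                gm.graph.erase_node(node)
                continue
            last = node
    gm.graph.lint()
    gm.recompile()
    return gm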

cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

Versions

main
