
make backward function explicit in a layer which is a combination of some ops #107435

@guoyejun

Description


🚀 The feature, motivation and pitch

To get a top-down performance analysis, we'd like to know how time is spent in the different components of the model. Currently, we are able to measure the forward time of each layer, but we have no separate timing for the backward (bwd) pass of the individual layers.

For example, I have a layer (MyLayer) which is a combination of some PyTorch standard ops, see below. I can add something in the forward function (for example, record_function, or events at the beginning/end) and measure the time of the forward pass; a minimal sketch is shown after the layer definition below. But I'm unable to know how much time is consumed by the backward of MyLayer. I can get op-level profiling from the PyTorch profiler, but it is not easy to aggregate those ops back into MyLayer (for many reasons; for example, the fwd+bwd of other layers may call the same ops).

import torch
import torch.nn as nn

class MyLayer(torch.nn.Module):
    def __init__(self):
        super(MyLayer, self).__init__()
        # args.hs is the hidden size, defined elsewhere in the original script
        self.mlp_up_proj = nn.Linear(args.hs, 4 * args.hs)
        self.mlp_act = nn.GELU(approximate='none')
        self.mlp_down_proj = nn.Linear(4 * args.hs, args.hs)

    def forward(self, x):
        y1 = self.mlp_up_proj(x)
        y2 = self.mlp_act(y1)
        y3 = self.mlp_down_proj(y2)
        return y3
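
For concreteness, here is a minimal sketch (not part of the original report, and assuming args.hs is defined as above) of how the forward time of MyLayer can already be captured with record_function, while the backward time cannot be attributed to the layer as a whole:

import torch
from torch.profiler import profile, record_function, ProfilerActivity

layer = MyLayer()
x = torch.randn(8, args.hs, requires_grad=True)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    with record_function("MyLayer.forward"):
        y = layer(x)
    # The backward ops below are profiled individually, but they are not
    # grouped under a "MyLayer.backward" range -- this is the missing piece.
    y.sum().backward()

print(prof.key_averages().table(sort_by="cpu_time_total"))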

So I tried to change MyLayer as shown below to illustrate the idea. I know it does not work; it even crashes at runtime. If we could modify the original MyLayer with the least change needed to make the backward function explicit, we could do all kinds of profiling on it.

# it does not work
class MyLayerImpl(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, mlp_up_proj, mlp_act, mlp_down_proj):
        with torch.enable_grad():
            y1 = mlp_up_proj(x)
            y2 = mlp_act(y1)
            y3 = mlp_down_proj(y2)
        y4 = y3.detach()
        ctx.save_for_backward(x, y1, y2, y3)
        return y4

    @staticmethod
    def backward(ctx, dout, *args):
        x, y1, y2, y3 = ctx.saved_tensors
        # Tensor.backward() returns None, so din ends up as None here; and
        # since x was not detached before building the inner graph, this call
        # likely tries to walk back into the outer graph as well.
        din = y3.backward(dout)
        return din, None, None, None
    

class MyLayer(torch.nn.Module):
    def __init__(self):
        super(MyLayer, self).__init__()
        self.mlp_up_proj = nn.Linear(args.hs, 4 * args.hs)
        self.mlp_act = nn.GELU(approximate='none')
        self.mlp_down_proj = nn.Linear(4*args.hs, args.hs)
        
    def forward(self, x):
        return MyLayerImpl.apply(x, self.mlp_up_proj, self.mlp_act, self.mlp_down_proj)
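
For reference, here is a sketch of one possible workaround (not from the original report; it mirrors the pattern used by torch.utils.checkpoint and assumes double backward is not needed): build the inner graph on a detached input under enable_grad, keep it alive on ctx, and replay it in backward, so the whole backward of MyLayer can be wrapped in a single record_function range.

import torch
import torch.nn as nn
from torch.profiler import record_function

class MyLayerImpl(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, mlp_up_proj, mlp_act, mlp_down_proj):
        with torch.enable_grad():
            # Build the inner graph on a detached leaf so backward of this
            # layer stops at x_inner instead of walking into the outer graph.
            x_inner = x.detach().requires_grad_(True)
            y3 = mlp_down_proj(mlp_act(mlp_up_proj(x_inner)))
        # Keep the inner graph alive until backward (save_for_backward is meant
        # for plain tensors; stashing graph-carrying tensors on ctx is the usual trick).
        ctx.x_inner = x_inner
        ctx.y3 = y3
        return y3.detach()

    @staticmethod
    def backward(ctx, dout):
        with record_function("MyLayer.backward"):
            # Replay the inner graph: parameter grads of the Linear modules are
            # accumulated into their .grad as usual, and the input grad lands
            # in x_inner.grad because x_inner is a leaf of the inner graph.
            torch.autograd.backward(ctx.y3, dout)
            din = ctx.x_inner.grad
        return din, None, None, None

One caveat of this pattern (as with the original sketch): x must require grad for the custom backward to run at all, otherwise the parameter gradients of the sub-modules are never computed.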

Alternatives

No response

Additional context

No response

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @robieta @chaekit @aaronenyeshi @nbcsm @guotuofeng @guyang3532 @gaoteng-git @tiffzhaofb @dzhulgakov @davidberard98


Labels: module: nn, oncall: profiler, triaged
