🚀 The feature, motivation and pitch
To get a top-down performance analysis, we'd like to know how time is spent in the different components of the model. Currently we can measure the forward time of each layer, but there is no per-layer breakdown of the backward time.
For example, I have a layer (MyLayer) that is a combination of standard PyTorch ops, shown below. I can add instrumentation (for example, record_function, or events at the beginning/end) in the forward function and measure the forward time, but I have no way to know how much time the backward of MyLayer takes. I can get op-level profiling from the PyTorch profiler, but it is not easy to attribute those ops back to MyLayer (for many reasons; for example, the forward+backward of other layers may call the same ops). A sketch of the forward-side instrumentation I mean follows the class definition below.
import torch
import torch.nn as nn

class MyLayer(torch.nn.Module):
    def __init__(self):
        super(MyLayer, self).__init__()
        # args.hs is the model's hidden size
        self.mlp_up_proj = nn.Linear(args.hs, 4 * args.hs)
        self.mlp_act = nn.GELU(approximate='none')
        self.mlp_down_proj = nn.Linear(4 * args.hs, args.hs)

    def forward(self, x):
        y1 = self.mlp_up_proj(x)
        y2 = self.mlp_act(y1)
        y3 = self.mlp_down_proj(y2)
        return y3
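For reference, this is roughly what I mean by measuring the forward time today. It is only a rough sketch of mine; the TimedMyLayer wrapper and the "MyLayer.fwd" label are illustrative names I picked, not an existing API:

from torch.profiler import profile, record_function, ProfilerActivity

class TimedMyLayer(MyLayer):
    def forward(self, x):
        # The whole forward shows up as one "MyLayer.fwd" region in the
        # profiler trace; there is no equivalent handle for the backward.
        with record_function("MyLayer.fwd"):
            return super().forward(x)

# Example usage (assuming args.hs is defined and a GPU is available):
# layer = TimedMyLayer().cuda()
# x = torch.randn(8, args.hs, device='cuda', requires_grad=True)
# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
#     layer(x).sum().backward()
# print(prof.key_averages().table(sort_by="cuda_time_total"))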
So I tried to rewrite MyLayer as below to illustrate the idea. I know it does not work as written; it even crashes at runtime. But if the original MyLayer could be modified, with minimal changes, to make its backward function explicit, we could do all kinds of profiling on it.
# it does not work
class MyLayerImpl(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, mlp_up_proj, mlp_act, mlp_down_proj):
        with torch.enable_grad():
            y1 = mlp_up_proj(x)
            y2 = mlp_act(y1)
            y3 = mlp_down_proj(y2)
        y4 = y3.detach()
        ctx.save_for_backward(x, y1, y2, y3)
        return y4

    @staticmethod
    def backward(ctx, dout, *args):
        x, y1, y2, y3 = ctx.saved_tensors
        # Tensor.backward() returns None, so din is None here; the gradients
        # are accumulated into .grad fields instead of being returned.
        din = y3.backward(dout)
        return din, None, None, None


class MyLayer(torch.nn.Module):
    def __init__(self):
        super(MyLayer, self).__init__()
        self.mlp_up_proj = nn.Linear(args.hs, 4 * args.hs)
        self.mlp_act = nn.GELU(approximate='none')
        self.mlp_down_proj = nn.Linear(4 * args.hs, args.hs)

    def forward(self, x):
        return MyLayerImpl.apply(x, self.mlp_up_proj, self.mlp_act, self.mlp_down_proj)
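For what it's worth, a variant along the lines below gets closer to the idea: it re-roots the sub-graph on a detached input and replays it explicitly in backward, similar in spirit to what torch.utils.checkpoint does. This is only a rough, untested sketch of mine (the "MyLayer.bwd" label is an arbitrary name I chose), not a proposal for the actual mechanism:

from torch.profiler import record_function

class MyLayerImpl(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, mlp_up_proj, mlp_act, mlp_down_proj):
        # Re-root the sub-graph on a detached copy of x so it can be
        # replayed explicitly (and timed) in backward().
        with torch.enable_grad():
            inner_x = x.detach().requires_grad_()
            inner_y = mlp_down_proj(mlp_act(mlp_up_proj(inner_x)))
        ctx.save_for_backward(inner_x, inner_y)
        return inner_y.detach()

    @staticmethod
    def backward(ctx, dout):
        inner_x, inner_y = ctx.saved_tensors
        # The layer's backward is now an explicit, labeled region; parameter
        # gradients are accumulated into .grad by the inner backward call.
        with record_function("MyLayer.bwd"):
            torch.autograd.backward(inner_y, dout)
        return inner_x.grad, None, None, None

Even if something like this runs, it has caveats: for example, if x does not require grad, backward is never invoked and the parameter gradients are never produced unless the parameters are also passed to apply as tensor arguments. So a built-in way to attribute backward time to a module would be much more convenient.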
Alternatives
No response
Additional context
No response
cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @robieta @chaekit @aaronenyeshi @nbcsm @guotuofeng @guyang3532 @gaoteng-git @tiffzhaofb @dzhulgakov @davidberard98