import torch
from torch.autograd import Function
from torch.nn import Module, Parameter


class Tracer(Function):
    """Identity op that prints its name when its backward runs."""

    @staticmethod
    def forward(ctx, input, name):
        ctx.name = name
        return input

    @staticmethod
    def backward(ctx, grad_output):
        print(ctx.name)
        return grad_output, None


tracer = Tracer.apply


class Model(Module):
    def __init__(self):
        super(Model, self).__init__()
        self.a = Parameter(torch.cuda.FloatTensor([1]))
        self.b = Parameter(torch.cuda.FloatTensor([1]))
        self.c = Parameter(torch.cuda.FloatTensor([1]))
        self.d = Parameter(torch.cuda.FloatTensor([1]))
        # Hook on a leaf parameter: fires when d's AccumulateGrad node runs.
        self.d.register_hook(lambda grad: print("d hook"))

    def forward(self):
        e = tracer(self.a, 'a') + tracer(self.b, 'b')
        f = tracer(self.c, 'c') + tracer(self.d, 'd')
        g = tracer(e, 'e') + tracer(f, 'f')
        # Hook on an intermediate tensor that is recreated every forward pass.
        g.register_hook(lambda grad: print("g hook"))
        loss = tracer(g, 'g').sum()
        return loss


model = Model()

"""
During the first iteration, "g hook" and "d hook" are printed in the expected
places: the reverse of where g and d were used in the forward pass.

During the second iteration, "g hook" is still printed in the expected place,
because g is a temporary. However, "d hook" is deferred to the end of the
backward pass, because d's AccumulateGrad node has lower priority than all of
the more recent (second-iteration) autograd ops.
"""

for i in (0, 1):
    print("\nIteration {}".format(i))
    model.zero_grad()
    loss = model()
    loss.backward()

"""
Observed output:

Iteration 0
g
g hook
f
e
d
d hook
c
b
a

Iteration 1
g
g hook
f
e
d
c
b
a
d hook
"""
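
"""
If the goal is to observe d's gradient at the point where d was used in the
forward pass (rather than whenever its AccumulateGrad eventually runs), one
possible workaround is to hook an in-graph alias of d created inside forward():
the alias's grad_fn belongs to the current iteration's graph, so its hook is
not deferred the way the parameter's is. The sketch below is illustrative and
not part of the original repro; AliasModel and d_alias are hypothetical names.
"""

class AliasModel(Model):
    def forward(self):
        # view_as creates a fresh autograd node for d on every forward pass
        d_alias = self.d.view_as(self.d)
        d_alias.register_hook(lambda grad: print("d alias hook"))
        e = tracer(self.a, 'a') + tracer(self.b, 'b')
        f = tracer(self.c, 'c') + tracer(d_alias, 'd')
        g = tracer(e, 'e') + tracer(f, 'f')
        g.register_hook(lambda grad: print("g hook"))
        return tracer(g, 'g').sum()

# Expectation under this workaround: "d alias hook" should print right after
# 'd' in every iteration, while "d hook" on the parameter itself may still be
# deferred as described above.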