Feedback about PyTorch register_backward_hook #12331
Maxim Naumov @mnaumovfb sends in the following feedback:
I thought I would share some feedback from my recent work with the register_backward_hook function. I will describe some inconsistencies in it that may be worth addressing in the future.
PyTorch supports different hook functions, including register_hook, register_forward_hook and register_backward_hook. The first is applied to a tensor, while the other two are applied to a layer (module). I will discuss the last of these here. It has the signature hook(module, grad_input, grad_output), documented at https://pytorch.org/docs/stable/nn.html
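For concreteness, here is a minimal sketch of the three hook kinds (layer and batch sizes are arbitrary):

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 3)
x = torch.randn(2, 4, requires_grad=True)

# Tensor-level hook: fires with dE/dx once the gradient reaches x.
x.register_hook(lambda grad: print("tensor hook:", grad.shape))

# Module-level hooks: fire around the module's forward/backward passes.
layer.register_forward_hook(
    lambda mod, inp, out: print("forward hook:", out.shape))
layer.register_backward_hook(
    lambda mod, grad_in, grad_out: print("backward hook:",
                                         len(grad_in), len(grad_out)))

layer(x).sum().backward()
```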
It is extremely important to understand the meaning of the gradients passed to register_backward_hook. Let a layer be defined as

z^(k) = f(w^(k) z^(k-1) + b^(k))

with overall loss E, error gradient dE/dz^(k) and weight gradient dE/dw^(k).
It seems that there are some inconsistencies in how this function handles the gradients for different layer types: the contents of grad_input do not mean the same thing from one layer to the next.
These discrepancies make handling different layers, bias terms and batch sizes quite cumbersome in code. It would help if they were treated more consistently in the future.
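A minimal sketch to observe the discrepancy: registering the same hook on two layer types and printing what arrives in grad_input. The lengths and shapes printed will not necessarily correspond to the module's actual inputs, which is exactly the inconsistency described above.

```python
import torch
import torch.nn as nn

def inspect(module, grad_input, grad_output):
    # grad_input and grad_output are tuples; their length and meaning
    # vary by layer type.
    print(type(module).__name__,
          [None if g is None else tuple(g.shape) for g in grad_input],
          [tuple(g.shape) for g in grad_output])

for layer, x in [(nn.Linear(4, 3), torch.randn(2, 4)),
                 (nn.Conv2d(1, 2, 3), torch.randn(2, 1, 8, 8))]:
    layer.register_backward_hook(inspect)
    layer(x).sum().backward()
```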
I'm not entirely sure if anyone fixed it, but if not, I think that as per #598, module backward hooks are more broken than they look, as they just hook into the last operation's output.
The main "fix" I can imagine would be to wrap the module in a single autograd function (similar to checkpointing but with enabling autograd during the first forward), so we can use the backward hook on that. As in the checkpointing case, there may still be limitations of how well it works, but I must admit that I cannot offer insight here without experimentation.
The recursive wrapping is extremely hard to get right. If you take checkpointing as an example, it already doesn't work in a number of cases.
Are you sure that the same limitations apply here? I was under the impression that part of the complexity of checkpointing comes from having two different forward passes involved in the backward. While that is true here in a sense, it seems like much less of a problem.
Yes, I'm sure the limitations apply here. Slicing autograd subgraphs is a very hard problem which took me over a month to solve for the JIT (and in the end we decided it's too complicated and chose a different approach).
Sure, they might be useful, but they are only harmful if we allow people to use them when they're broken.