Feedback about PyTorch register_backward_hook #12331

Open
ezyang opened this Issue Oct 4, 2018 · 8 comments

@ezyang
Contributor

ezyang commented Oct 4, 2018

Maxim Naumov @mnaumovfb sends in the following feedback:


I thought I would share some feedback from my recent work with the register_backward_hook function. I will describe some inconsistencies in it that may be worth addressing in the future.

Background

PyTorch supports different hook functions, including register_hook, register_forward_hook and register_backward_hook. The first is applied to a tensor, while the latter two are applied to a layer (module). I will discuss the last of these here. It has the following signature (see https://pytorch.org/docs/stable/nn.html): func(layer, grad_input, grad_output)
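
For concreteness, here is a minimal sketch of attaching such a hook to a module and printing the shapes of the gradient tuples it receives (the hook name and the layer chosen here are arbitrary; the exact tuple contents depend on the PyTorch version):

```python
import torch
import torch.nn as nn

# Hook with the signature described above; it only prints the shapes of the
# gradient tuples it receives (entries of grad_input may be None).
def print_grads(module, grad_input, grad_output):
    print(type(module).__name__,
          "grad_input:", [None if g is None else tuple(g.shape) for g in grad_input],
          "grad_output:", [tuple(g.shape) for g in grad_output])

layer = nn.Conv2d(3, 8, kernel_size=3)
layer.register_backward_hook(print_grads)

x = torch.randn(1, 3, 8, 8, requires_grad=True)
layer(x).sum().backward()   # the hook fires during this backward pass
```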

Details

It is extremely important to understand the meaning of the gradients passed to register_backward_hook. Let a layer be defined as

 output z  (grad_output)
     ____|____
    |__layer__|
         |
 input x  (grad_input)

with overall loss E, error gradient dE/dz^(k) and weight gradient dE/dw^(k) for layer k (so that z^(k-1) = x is the layer's input and z^(k) its output).
First, let us assume the simplest case: a layer with no bias. Then,

grad_output = [dE/dz^(k)]
grad_input  = [dE/dz^(k-1), dE/dw^(k)]
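
The following is a hedged sketch for checking this bias-free case on a given installation; exactly what appears in grad_input depends on the PyTorch version and on the module's last operation, as discussed in the comments below:

```python
import torch
import torch.nn as nn

# A fully connected layer with no bias, matching the simplest case above, so
# the correspondence grad_output = [dE/dz^(k)] and
# grad_input = [dE/dz^(k-1), dE/dw^(k)] can be inspected directly.
def inspect(module, grad_input, grad_output):
    print("grad_output shapes:", [tuple(g.shape) for g in grad_output])
    print("grad_input shapes: ",
          [None if g is None else tuple(g.shape) for g in grad_input])

layer = nn.Linear(4, 3, bias=False)   # weight: (3, 4)
layer.register_backward_hook(inspect)

x = torch.randn(1, 4, requires_grad=True)   # z^(k-1)
z = layer(x)                                # z^(k), shape (1, 3)
z.sum().backward()
```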

Inconsistencies

It seems that there are some inconsistencies in how gradients for different layers are handled by this function.

  1. Shape
    • in convolution layers the weight gradient has the same shape as the weights
    • in fully connected layers the weight gradient has the transposed shape of the weights
  2. Bias
    • in convolution layers the bias gradient is appended: grad_input = [dE/dz^(k-1), dE/dw^(k), dE/db^(k)]
    • in fully connected layers the bias gradient is prepended: grad_input = [dE/db^(k), dE/dz^(k-1), dE/dw^(k)]
  3. Batch size > 1
    • in convolution layers the bias gradient is already the gradient over the entire batch: grad_input = [dE/dz^(k-1), dE/dw^(k), dE/db^(k)]
    • in fully connected layers the bias gradient is given per data point j = 1, ..., r in the batch (and therefore has to be summed to obtain the gradient over the entire batch): grad_input = [[dE/db^(k,1), ..., dE/db^(k,r)], dE/dz^(k-1), dE/dw^(k)]

These discrepancies can make handling different layers, biases and batch sizes quite cumbersome in code. It would help if they were handled more consistently in the future.
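
As an illustration (the layer sizes and hook name below are arbitrary), the following sketch registers the same hook on a convolution layer and a fully connected layer, both with bias and with batch size 4, and prints the shapes each hook receives, so the shape, ordering and batch differences listed above can be checked on a given PyTorch version:

```python
import torch
import torch.nn as nn

def show(module, grad_input, grad_output):
    print(type(module).__name__,
          "grad_input:", [None if g is None else tuple(g.shape) for g in grad_input],
          "grad_output:", [tuple(g.shape) for g in grad_output])

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)   # weight: (8, 3, 3, 3), bias: (8,)
fc = nn.Linear(16, 5)                              # weight: (5, 16), bias: (5,)
conv.register_backward_hook(show)
fc.register_backward_hook(show)

# Batch size > 1: compare the bias-gradient entries between the two layers.
conv(torch.randn(4, 3, 8, 8, requires_grad=True)).sum().backward()
fc(torch.randn(4, 16, requires_grad=True)).sum().backward()
```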

@t-vi

Contributor

t-vi commented Oct 5, 2018

I'm not entirely sure whether anyone has fixed it, but if not then, as per #598, module backward hooks are more broken than they look, since they just hook into the last operation's output.
The inconsistencies are the result of the investigated modules consisting of "some simple input transformation" (like a transpose, so you immediately recognize the values) and a single final autograd operation.

The main "fix" I can imagine would be to wrap the module in a single autograd function (similar to checkpointing but with enabling autograd during the first forward), so we can use the backward hook on that. As in the checkpointing case, there may still be limitations of how well it works, but I must admit that I cannot offer insight here without experimentation.

@ezyang

Contributor

ezyang commented Oct 5, 2018

Yeah, I don't have a good enough handle on this part of the code to really know what a plausible technical approach is. If someone wants to look I'd be happy to advise.

@apaszke

Member

apaszke commented Oct 5, 2018

The recursive wrapping is extremely hard to get right. If you take checkpointing as an example, it already doesn't work in grad mode, can possibly greatly increase the complexity of backward, and we've had a few bugs in the implementation anyway. I think we might just need to create a whitelist of modules that can use backward callbacks and forbid them otherwise.
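
Purely as an illustration of the whitelist idea (the helper name and the set of allowed modules below are invented; nothing like this exists in PyTorch):

```python
import torch.nn as nn

# Hypothetical whitelist of modules whose backward-hook gradients are known
# to match the documented meaning; the actual set would need to be audited.
_BACKWARD_HOOK_WHITELIST = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)

def register_backward_hook_checked(module, hook):
    if not isinstance(module, _BACKWARD_HOOK_WHITELIST):
        raise RuntimeError(
            "register_backward_hook is not supported for %s; its "
            "grad_input/grad_output would not have the documented meaning"
            % type(module).__name__)
    return module.register_backward_hook(hook)
```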

@t-vi

Contributor

t-vi commented Oct 5, 2018

Are you sure that the same limitations apply here? I was under the impression that part of the complexity of checkpointing is due to having two different forwards involved in the backward. While that is the case here in a sense, the scope is much smaller.
I would think that the hooks are particularly interesting for debugging your own, new modules, so I'm not entirely sure that limiting them to certain modules is all that useful.

@apaszke

Member

apaszke commented Oct 5, 2018

Yes, I'm sure the limitations apply here. Slicing autograd subgraphs is a very hard problem which took me over a month to solve for the JIT (and in the end we decided it's too complicated and chose a different approach).

Sure, they might be useful, but they are only harmful if we allow people to use them when they're broken.

@t-vi

Contributor

t-vi commented Oct 13, 2018

How about removing them entirely?

@albanD

Collaborator

albanD commented Oct 13, 2018

#12573 should be the first step in this direction (it does not use the same approach as checkpointing).
It fixes the current backward_hook and makes sure that unsupported cases raise a proper error.

@t-vi

Contributor

t-vi commented Oct 13, 2018
