Problem with backward hook function #598

Open
ludc opened this Issue Jan 26, 2017 · 13 comments

ludc commented Jan 26, 2017

Hi,

there is something strange in the backward step (or maybe something I don't understand). If I define a Module that takes 3 inputs, grad_input should have 3 elements, right? But that is not the case here (from the backward_hook's point of view):

```python
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

# module backward hook: hook(module, grad_input, grad_output)
def bh(module, grad_input, grad_output):
    print("Grad Input")
    print(grad_input)
    print("Grad Output")
    print(grad_output)

class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.register_backward_hook(bh)

    def forward(self, x, y, z):
        return x + y + z

x = Variable(torch.randn(1, 5), requires_grad=True)
y = Variable(torch.randn(1, 5), requires_grad=True)
z = Variable(torch.randn(1, 5), requires_grad=True)

criterion = nn.MSELoss()
mod = M()
out = mod(x, y, z)
loss = criterion(out, Variable(torch.randn(1, 5)))
loss.backward()
```

In that case, when I print grad_input through the hook function, it only contains two elements... Could you tell me where I am wrong? Yet `x.grad`, `y.grad` and `z.grad` all seem to be computed correctly.

Member

apaszke commented Jan 26, 2017

OK, so the problem is that module hooks are actually registered on the last function that the module has created. In your case x + y + z is computed as ((x + y) + z), so the hook is registered on that (_ + z) operation, and this is why you're getting only two grad inputs.

We'll definitely have to resolve this but it will need a large change in the autograd internals. However, right now @colesbury is rewriting them to make it possible to have multiple functions dispatched in parallel, and they would heavily conflict with his work. For now use only Variable hooks (or module hooks, but not on containers). Sorry!
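
As a concrete illustration of the Variable-hook workaround (a minimal sketch; the hook names and printout are only illustrative), each input Variable gets its own hook and therefore reports its own gradient:

```python
import torch
from torch.autograd import Variable
import torch.nn as nn

x = Variable(torch.randn(1, 5), requires_grad=True)
y = Variable(torch.randn(1, 5), requires_grad=True)
z = Variable(torch.randn(1, 5), requires_grad=True)

def make_hook(name):
    # a Variable hook receives only the gradient of that Variable
    def hook(grad):
        print(name, grad)
    return hook

for name, v in (("x", x), ("y", y), ("z", z)):
    v.register_hook(make_hook(name))

out = x + y + z
loss = nn.MSELoss()(out, Variable(torch.randn(1, 5)))
loss.backward()  # prints one gradient per input, unlike the module hook
```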

@apaszke apaszke added on hold and removed high priority labels Jan 26, 2017

Contributor

karandwivedi42 commented Jun 28, 2017

@apaszke Is the refactor you referred to this one: #1016? Since that is merged, is this bug solvable now?

Member

apaszke commented Jun 29, 2017

Yes. We can remove the on hold label, but I don't have any good solutions in mind.

@soumith soumith added this to Uncategorized in Issue Status Aug 23, 2017

@soumith soumith removed the on hold label Aug 31, 2017

@soumith soumith added this to correctness/stability in Issue Categories Sep 13, 2017

Contributor

shubhamjain0594 commented Dec 6, 2017

@apaszke Any progress or ideas on this issue, or on how it might be solved? I may write a PR if required.

Member

apaszke commented Dec 12, 2017

No progress on this. It's actually a pretty hard problem to solve properly. I think I'd be ok with disallowing registering hooks on modules that have submodules as a partial workaround.

whr94621 commented Apr 22, 2018

@apaszke Any progress or ideas on this issue? I also encountered this problem when the output of the forward function is the result of torch.chunk.

Member

apaszke commented Apr 23, 2018

Not yet. It's really hard to solve, and not such a common pitfall. I think we might just disable the function for now.

Contributor

t-vi commented Apr 27, 2018

So here is a copy-paste from the duplicate I filed today.

The PyTorch documentation for nn.Module.register_backward_hook says:

The grad_input and grad_output may be tuples if the module has multiple inputs or outputs. The hook should not modify its arguments, but it can optionally return a new gradient with respect to input that will be used in place of grad_input in subsequent computation.

This is not quite accurate:

  • grad_output is indeed the gradient of the loss w.r.t. the layer output. So if you have a layer l and do, say, `y = l(x); loss = y.sum(); loss.backward()`, you get the gradient of loss w.r.t. y. So far so good.
  • What is unexpected to most users is that grad_input is the gradient w.r.t. the inputs of the last operation in the layer, not w.r.t. the layer's own inputs. For linear layers this happens to be fairly complete, as the last op is torch.addmm, which multiplies the input with the weight and adds the bias. For other modules (e.g. a Sequential) it will be the last op of the last layer, whose inputs are not even remotely related to the Sequential's own inputs. You can see which function will be used by looking at y.grad_fn (see the sketch right after this list).
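
A minimal sketch of that grad_fn inspection (the module and shapes are illustrative, and the exact grad_fn class name depends on the PyTorch version):

```python
import torch
import torch.nn as nn

l = nn.Sequential(nn.Linear(5, 5), nn.ReLU())
x = torch.randn(1, 5, requires_grad=True)
y = l(x)

# A backward hook registered on `l` would attach to this function,
# i.e. the backward of the ReLU, the last op inside the Sequential.
print(y.grad_fn)
# Its immediate inputs in the graph, which is what grad_input refers to:
print(y.grad_fn.next_functions)
```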

This may be confusing to users, e.g. https://discuss.pytorch.org/t/exact-meaning-of-grad-input-and-grad-output/14186

So improving the docs is one thing, but should we make the hook actually live up to the description?

I see three strategies:

  • Documentation only.
  • The straightforward way of providing input gradients: collect the grad_ins with Variable hooks and call the module hook once we have all of them (a rough sketch follows this list). We lose the ability to return a different gradient.
  • The somewhat convoluted way: if the module has hooks, wrap the module's forward in an autograd function, similar to checkpointing. Then the Variable hook for the output would do the right thing.
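
For illustration, a rough sketch of the second strategy (this is not the actual implementation; the helper name and structure are made up, and in this naive form hooks accumulate across repeated forward calls):

```python
import torch
import torch.nn as nn

def register_input_grad_hook(module, callback):
    """Call callback(module, grad_inputs) once the gradient of every
    forward input has been produced, using per-Variable hooks."""
    def forward_pre_hook(mod, inputs):
        grads = [None] * len(inputs)
        remaining = [sum(1 for inp in inputs if inp.requires_grad)]

        def make_hook(i):
            def hook(grad):
                grads[i] = grad
                remaining[0] -= 1
                if remaining[0] == 0:
                    callback(mod, tuple(grads))
            return hook

        for i, inp in enumerate(inputs):
            if inp.requires_grad:
                inp.register_hook(make_hook(i))
    module.register_forward_pre_hook(forward_pre_hook)
```

As the bullet notes, this can only observe the input gradients; unlike a real backward hook it cannot hand a modified gradient back to autograd.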

Member

apaszke commented Apr 30, 2018

I think we should just remove this functionality for now (raise an error when register_backward_hook is called). It never worked correctly and is really hard to implement. I don't think the checkpointing way is good in this case, because it can cause certain backward hooks to trigger multiple times, and that's generally not what you want. It also doesn't work with .grad.

Contributor

karandwivedi42 commented Apr 30, 2018

@apaszke Remove this functionality from all modules or only from Container-type modules? If it is removed entirely, how should we obtain, for example, gradient norms?
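
For context, one way gradient norms can still be observed without module backward hooks is to put plain tensor/Variable hooks on the intermediate activations (a sketch; the model and layer choice are illustrative, not an answer from the maintainers):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(5, 5), nn.ReLU(), nn.Linear(5, 1))
x = torch.randn(1, 5)

h = model[0](x)  # intermediate activation (requires grad via the Linear's params)
h.register_hook(lambda g: print("grad norm at layer 0 output:", g.norm()))

out = model[2](model[1](h))
out.sum().backward()
```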

marcoancona commented Jun 2, 2018

I encountered the same issue. Is there any alternative way to replace the gradient of a sequence of operations, similar to custom_gradient in TensorFlow 1.8?

What I am trying to do is override the gradient of an entire layer at once. The use-case is to implement some attribution methods as in https://github.com/marcoancona/DeepExplain, by modifying the gradient of the output with respect to the input features.
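
One workaround along those lines (a sketch, assuming it is acceptable to wrap the whole block in a custom autograd.Function; the forward computation and gradient rule here are placeholders):

```python
import torch
from torch.autograd import Function

class OverrideGrad(Function):
    @staticmethod
    def forward(ctx, x):
        # the "layer" (sequence of ops) whose gradient we want to replace
        return x.clamp(min=0) * 2

    @staticmethod
    def backward(ctx, grad_output):
        # custom attribution-style rule; here we simply pass the incoming
        # gradient straight through instead of the true gradient
        return grad_output

x = torch.randn(3, requires_grad=True)
y = OverrideGrad.apply(x)
y.sum().backward()
print(x.grad)  # the overridden gradient, not the gradient of clamp(min=0) * 2
```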

tete1030 commented Jul 4, 2018

Could this problem be properly noted in the docs until there is a solution? It has existed for so long and misleads many people, including me.

Contributor

zou3519 commented Jul 9, 2018

Could this problem be properly noted in the docs until there is a solution? It has existed for so long and misleads many people, including me.

Sure, let's do that for now.
