nn.Module hook fix and improvements #12573

Closed. albanD wants to merge 6 commits into base: master from albanD:module_hook

Conversation

@albanD (Collaborator) commented Oct 11, 2018

This PR adds the following:

  • Allow forward pre hooks and forward hooks to modify the input by returning a non-None value (a usage sketch follows below). Some checks are done to make sure the returned value is similar to the original input.
  • Fix the backward hooks for the case where all inputs and outputs are Tensors, and raise an error if any input or output is not a Tensor.

Do we actually want to check the returned values from the forward hooks?

Adding support for Modules that take inputs other than Tensors (and returning a None gradient for those) is left for later.
This is a breaking change for the current nn.Module backward hooks, but nobody should be relying on them since they are heavily broken at the moment anyway.
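
A minimal sketch of the forward pre hook usage described above, assuming this PR's semantics (the hook's non-None return value replaces the module input; the names below are illustrative, not taken from the diff):

import torch
import torch.nn as nn

module = nn.Linear(5, 10)

def double_input(mod, input):
    # `input` is a tuple of Tensors; returning a tuple with the same
    # structure and shapes replaces the original input.
    return tuple(t * 2 for t in input)

module.register_forward_pre_hook(double_input)
out = module(torch.randn(2, 5))  # forward() sees the doubled input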

@ezyang (Contributor) commented Oct 11, 2018

Hey @albanD, could you rebase this on master? That will get CircleCI tests working for you.

@albanD albanD force-pushed the albanD:module_hook branch 2 times, most recently from 10d3448 to 8979085 Oct 11, 2018

@ezyang ezyang added the bc-breaking label Oct 17, 2018

@ezyang (Contributor) commented Oct 17, 2018

Sorry, this might take a little time to review. @apaszke, do you think you have time to look at this? Otherwise I'll try to squeeze some in later.

@facebook-github-bot (Contributor) left a comment

ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

def hook(grad_input, _):
    return self.user_hook(self.module, grad_input, self.grad_output)
# Make error message more user-friendly
hook.__name__ = self.user_hook.__name__

@SsnL (Contributor), Oct 17, 2018

use @functools.wraps :)
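
Applied to the wrapper quoted above, the suggestion would read roughly as follows (a sketch that keeps the snippet's self-based context):

import functools

@functools.wraps(self.user_hook)  # copies __name__, __doc__, etc. onto the wrapper
def hook(grad_input, _):
    return self.user_hook(self.module, grad_input, self.grad_output)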

@@ -21,6 +21,16 @@ def backward(ctx, grad_output):
return grad_output.type(ctx.input_type), None


class Noop(Function):

@SsnL (Contributor), Oct 17, 2018

@colesbury How does a Python autograd Function work with our View logic? My guess is that this function is similar to an input[:], except that the output only shares storage and does not have a ViewImpl (a non-differentiable view, to use the terminology from #12502). I'm mainly worried that rebase_history of the base will move this Function around, so that the hooks registered on it no longer see the accurate grad_input/grad_output. I'll need to think about this more.

@apaszke (Member), Oct 18, 2018

I think it's fine. We make it into a new view and add this thing as its grad_fn. My only concern is that it will get dropped if someone writes to the base variable (the view's grad fn becomes AsStridedBackward in that case).

@albanD albanD force-pushed the albanD:module_hook branch from 3826d96 to 78eae14 Oct 17, 2018

@ezyang (Contributor) commented Oct 17, 2018

I'm not sure why, but this is not triggering CircleCI tests. CC @yf225

@facebook-github-bot (Contributor) left a comment

ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@albanD (Collaborator) commented Oct 17, 2018

@ezyang CircleCI tests were failing because of weird git-related issues when it was pulling the branch, so I rebased on top of master. Not sure what's happening :/

@yf225 (Contributor) commented Oct 17, 2018

@albanD there have been some CircleCI changes since your rebase; we might need to rebase again or pull master into your branch to get it working.

@yf225 (Contributor) commented Oct 17, 2018

@ezyang it seems that CircleCI has a history for it (https://circleci.com/gh/pytorch/workflows/pytorch/tree/pull%2F12573), but the statuses don't show up here. I just submitted a ticket to CircleCI.

@albanD albanD force-pushed the albanD:module_hook branch from 78eae14 to 1c2986a Oct 17, 2018

@albanD (Collaborator) commented Oct 17, 2018

@yf225 I did the rebase, it seems to work now!

@apaszke (Member) commented Oct 17, 2018

@ezyang yes, I'd like to take a look at this to make sure everything is sound. I should have some time tomorrow in the evening.

@apaszke (Member) left a comment

I have some suggestions for improvements, but the PR generally looks OK. It's definitely a step in the right direction compared to the fully broken behavior we have today, but I'm worried that it might still fail in cases like this one (related to the resetting of the grad_fns of views):

input = torch.randn(2, 5, requires_grad=True)
module = nn.Linear(5, 10)
module.register_backward_hook(lambda *args: print('Hook!'))
output = module(input)
# Now do either
# input.mul_(2)
# output.mul_(2)
output.sum().backward()
@@ -635,13 +635,26 @@ def bw_hook(inc, h_module, grad_input, grad_output):
test_fwd.remove()
test_bwd.remove()

def test_hook_backward_size(self):
    module = nn.Linear(5, 10)

@apaszke (Member), Oct 18, 2018

It would be better to test this with a module that's guaranteed to contain at least two autograd ops, to make sure that we slice whole subgraphs correctly
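
For illustration, a sketch of such a test setup (not taken from the PR's tests; any small stack of layers would do):

import torch
import torch.nn as nn

module = nn.Sequential(nn.Linear(5, 10), nn.Sigmoid())  # contains several autograd ops

def bw_hook(module, grad_input, grad_output):
    # grad_input / grad_output should correspond to the whole Sequential,
    # not only to its last internal op
    pass

module.register_backward_hook(bw_hook)
module(torch.randn(2, 5, requires_grad=True)).sum().backward()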

@@ -21,6 +21,16 @@ def backward(ctx, grad_output):
return grad_output.type(ctx.input_type), None


class Noop(Function):

    @staticmethod
    def backward(ctx, *grad_outputs):
        return Noop.apply(*grad_outputs)

@apaszke (Member), Oct 18, 2018

We don't support double-backwards hooks, so is that really useful? Can't we just return grad_outputs?

@@ -20,6 +20,43 @@ def _addindent(s_, numSpaces):
return s


def _check_same_shape(base, new, caller_name):

@apaszke (Member), Oct 18, 2018

This has pretty strict assumptions (most importantly that all inputs to a module are tensors or flat tuples of tensors). This is completely unnecessary in the case of e.g. forward hooks, and is likely to break backward compatibility.

@albanD (Collaborator), Oct 19, 2018

This is actually less strict than that: it only checks that each element of *args has the same type and, for tensors, the same size. Maybe the function name is misleading; the goal was to make sure the arguments were not completely changed.
But we can drop this if you prefer?

    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
else:
    var = var[0]
return var.grad_fn

@apaszke (Member), Oct 18, 2018

This assumes that it will succeed, but it can actually end up throwing a weird error when you try var[0] on something that doesn't support indexing.

@@ -37,6 +38,24 @@ def __exit__(self, type, value, tb):
self.remove()


class BackwardHook():

@apaszke (Member), Oct 18, 2018

Inherit from object. Otherwise it will be an old-style class in Python 2.
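
That is, the declaration would become (a trivial sketch):

class BackwardHook(object):  # new-style class on Python 2 as well as Python 3
    pass  # body unchanged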

if len(self._backward_hooks) > 0:
backward_hooks = self._get_backward_hooks()

self._validate_backward_hook_args("input", input)

@apaszke (Member), Oct 18, 2018

This completely ignores kwargs. We should at least warn that they will be ignored if they are present.
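
A rough sketch of what such a warning could look like (placement and wording are illustrative, not from the diff):

import warnings

# inside Module.__call__(self, *input, **kwargs):
if kwargs and len(self._backward_hooks) > 0:
    warnings.warn("keyword arguments passed to forward() are ignored by backward hooks")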

backward_hooks = []
for user_hook in self._backward_hooks.values():
    backward_hooks.append(hooks.BackwardHook(self, user_hook))
return backward_hooks

@apaszke (Member), Oct 18, 2018

This is simply

return [hooks.BackwardHook(self, user_hook) for user_hook in self._backward_hooks.values()]
self._validate_backward_hook_args("input", input)
input = Noop.apply(*input)

noop_fn = _get_prev_function(input)

@apaszke (Member), Oct 18, 2018

Can we write an internal alternative to .apply, which also returns the function object as the second value (e.g. _apply_ret_self)? That would be much more robust.

hook(self, input)
hook_result = hook(self, input)
if hook_result is not None:
    _check_same_shape(input, hook_result, "forward pre hook '{}'".format(hook))

@apaszke (Member), Oct 18, 2018

I really think that those checks are unnecessary. If someone wants to modify those inputs, then let's let them do it. It's forward code so it should be relatively easy to debug.

@albanD (Collaborator), Oct 19, 2018

I'm happy to drop these.

@albanD albanD force-pushed the albanD:module_hook branch from ea1ad8a to 59c5202 Oct 31, 2018

@albanD (Collaborator) commented Nov 3, 2018

This strategy is promising but causes a lot of problems if inputs or outputs are modified in-place; in particular, hooks are no longer called.
The behavior of hooks will need to be cleaned up before we can make this possible, so I'm closing this PR in the meantime.

@albanD albanD closed this Nov 3, 2018
