
Add an option to nn.Module.register_forward_hook/torch.nn.modules.module.register_module_forward_hook that guarantees a Module forward (post) hook is always run #103997

@mikaylagawarecki

Description


🚀 The feature, motivation and pitch

nn.Module forward hooks run after every call to nn.Module.forward(). There is also a global registration function, torch.nn.modules.module.register_module_forward_hook, which registers a forward hook that runs for every nn.Module.
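For reference, here is a minimal sketch of the two existing registration paths. It assumes PyTorch is installed, and `local_hook` and `global_hook` are illustrative names. A forward hook receives `(module, args, output)`, and each registration returns a handle whose `remove()` unregisters the hook:

```python
import torch
from torch import nn
from torch.nn.modules.module import register_module_forward_hook

calls = []

def local_hook(module, args, output):
    # Runs after this particular module's forward() returns.
    calls.append(("local", type(module).__name__))

def global_hook(module, args, output):
    # Runs after forward() of every nn.Module while it is registered.
    calls.append(("global", type(module).__name__))

model = nn.Linear(4, 2)
local_handle = model.register_forward_hook(local_hook)
global_handle = register_module_forward_hook(global_hook)

model(torch.randn(1, 4))

# Hooks stay active until their handles are removed.
local_handle.remove()
global_handle.remove()
print(calls)
```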

There have been recent requests for hooks that allow a context manager to be registered around nn.Module.forward().

Context managers have __enter__ and __exit__ methods defined.

One solution here is for users to

  1. register a forward pre-hook that constructs the context manager (its __init__) and calls its __enter__
  2. register a forward hook that calls the context manager's __exit__
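The two-hook workaround might look like the following sketch. `Tracker` is a hypothetical stand-in for any context manager, and the `active` dict keyed by module is one assumed way to hand the entered object from the pre-hook to the post hook:

```python
import torch
from torch import nn

class Tracker:
    """Hypothetical context manager used only for illustration."""
    def __init__(self, name, log):
        self.name = name
        self.log = log

    def __enter__(self):
        self.log.append(f"enter {self.name}")
        return self

    def __exit__(self, exc_type, exc, tb):
        self.log.append(f"exit {self.name}")
        return False  # do not suppress exceptions

log = []
active = {}  # module -> context manager entered by the pre-hook

def enter_pre_hook(module, args):
    cm = Tracker(type(module).__name__, log)  # step 1: __init__ ...
    cm.__enter__()                             # ... and __enter__
    active[module] = cm

def exit_post_hook(module, args, output):
    cm = active.pop(module)
    cm.__exit__(None, None, None)              # step 2: __exit__

model = nn.Linear(4, 2)
model.register_forward_pre_hook(enter_pre_hook)
model.register_forward_hook(exit_post_hook)

model(torch.randn(1, 4))
print(log)  # ['enter Linear', 'exit Linear']
```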

A gap in the above solution is that if nn.Module.forward(), or anything else that runs before the post hook, raises an exception, the forward (post) hook is not called and the context is never exited. To address this, we should add a new kwarg to nn.Module.register_forward_hook that guarantees the post hook will run, which ensures that the context is always properly exited.
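The sketch below illustrates the gap with a hypothetical `Failing` module whose `forward()` always raises: the pre-hook enters, but the post hook never runs. The second half approximates the requested guarantee in user code with `try`/`finally`. Note that `call_with_guaranteed_post_hook` is a hypothetical helper, not a PyTorch API, and passing `output=None` to the hook on failure is an assumption the issue does not settle:

```python
import torch
from torch import nn

log = []
active = set()  # modules whose "context" is currently entered

class Failing(nn.Module):
    def forward(self, x):
        raise RuntimeError("forward failed")

def enter_pre_hook(module, args):
    log.append("enter")
    active.add(module)

def exit_post_hook(module, args, output):
    log.append("exit")
    active.discard(module)

# Current behavior: the post hook is skipped when forward() raises.
model = Failing()
model.register_forward_pre_hook(enter_pre_hook)
model.register_forward_hook(exit_post_hook)
try:
    model(torch.zeros(1))
except RuntimeError:
    pass
print(log, model in active)  # ['enter'] True, so the context was never exited

def call_with_guaranteed_post_hook(module, post_hook, *args):
    # Hypothetical helper approximating the proposed kwarg: post_hook runs
    # in a `finally` block. Passing output=None on failure is an assumption.
    output = None
    try:
        output = module(*args)
        return output
    finally:
        post_hook(module, args, output)

log.clear()
guarded = Failing()
guarded.register_forward_pre_hook(enter_pre_hook)  # exit hook NOT registered
try:
    call_with_guaranteed_post_hook(guarded, exit_post_hook, torch.zeros(1))
except RuntimeError:
    pass
print(log, guarded in active)  # ['enter', 'exit'] False
```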

Compared with adding the ability to register a context manager hook on nn.Module.forward, this solution gives finer-grained control over when the context manager is entered and exited relative to other registered forward pre-hooks and forward hooks.
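To illustrate the ordering point: hooks of each kind run in registration order by default, so where the enter and exit hooks are registered among the other hooks decides which of them run inside the context. In this sketch, a module-level counter stands in for the context manager's state (an assumption made for brevity), and `make_hook` is a hypothetical helper:

```python
import torch
from torch import nn

log = []
depth = 0  # how many "contexts" are currently entered

def make_hook(name):
    def hook(module, *rest):
        # Usable as a pre-hook (module, args) or a post hook (module, args, output).
        log.append((name, depth))
    return hook

def enter(module, args):
    global depth
    depth += 1

def exit_(module, args, output):
    global depth
    depth -= 1

model = nn.Linear(2, 2)
model.register_forward_pre_hook(make_hook("pre_outside"))
model.register_forward_pre_hook(enter)
model.register_forward_pre_hook(make_hook("pre_inside"))
model.register_forward_hook(make_hook("post_inside"))
model.register_forward_hook(exit_)
model.register_forward_hook(make_hook("post_outside"))

model(torch.randn(1, 2))
print(log)
# [('pre_outside', 0), ('pre_inside', 1), ('post_inside', 1), ('post_outside', 0)]
```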

Alternatives

No response

Additional context

No response

cc @albanD @mruberry @jbschlosser @walterddr

Metadata


Labels

module: nn (Related to torch.nn)

triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

