Optim-wip: Improve ModuleOutputsHook, testing coverage #834
Conversation
* Added the `_remove_all_forward_hooks` function for easy cleanup and removal of hooks without requiring their handles (a rough sketch of the idea is shown below).
* Changed `ModuleOutputsHook`'s forward hook function name from `forward_hook` to `module_outputs_forward_hook`, to allow for easy removal of only the hooks that use that hook function.
* `ModuleOutputsHook`'s initialization function now runs the `_remove_all_forward_hooks` function on its targets, and only removes the hooks created by `ModuleOutputsHook`, to avoid breaking PyTorch.
* Added the `_count_forward_hooks` function for easy testing of hook creation & removal functionality.
* Added tests verifying that the 'ghost hook' bug has been fixed and that the new function works correctly.
* Added tests for `ModuleOutputsHook`. Previously we had no tests for this module.
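For reference, a rough sketch of the removal-by-name idea looks like the following. This is an illustration rather than the exact code added in this PR, and it relies on the private `_forward_hooks` dict that PyTorch modules use to store their forward hooks:

```python
from collections import OrderedDict
from typing import Optional

import torch.nn as nn


def _remove_all_forward_hooks(
    module: nn.Module, hook_fn_name: Optional[str] = None
) -> None:
    # Sketch only: walk the module and its children, and drop every entry in the
    # private _forward_hooks dict whose hook function matches hook_fn_name
    # (or drop all of them if no name is given).
    for m in module.modules():
        if hasattr(m, "_forward_hooks") and len(m._forward_hooks) > 0:
            m._forward_hooks = OrderedDict(
                (key, fn)
                for key, fn in m._forward_hooks.items()
                if hook_fn_name is not None
                and getattr(fn, "__name__", "") != hook_fn_name
            )
```

Restricting removal to hooks whose function name matches the one set by `ModuleOutputsHook` is what keeps this from clobbering hooks registered by other PyTorch systems.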
captum/optim/_core/output_hook.py
Outdated
any hook handles. This lets us clean up & remove any hooks that weren't properly
deleted.

Warning: Various PyTorch modules and systems make use of hooks, and thus extreme
As you mentioned, this is dangerous to do because we can remove hooks that we didn't set and whose function name happens to be the same. Why don't we remove only the hooks that we set, right after we are done with the hook? I thought that all hooks would be removed here, won't they?
captum/captum/optim/_core/optimization.py
Line 97 in 6e7f0bd
self.hooks.remove_hooks()
If there is a bug and we miss some of them, then we should rather make sure that we remove them after the optimization is finished.
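To illustrate the suggestion, the pattern amounts to a `try`/`finally` around the optimization loop so the hook handles are always released; the names below are hypothetical and only stand in for the real `InputOptimization` flow:

```python
import torch


def run_with_hooks(model: torch.nn.Module) -> None:
    # Hypothetical illustration: register a forward hook, then make sure its
    # handle is removed even if the optimization loop raises part-way through.
    handle = model.register_forward_hook(lambda mod, inp, out: out)
    try:
        _ = model(torch.zeros(1, 3))  # stands in for the optimization loop
    finally:
        handle.remove()  # no 'ghost' hooks are left behind


run_with_hooks(torch.nn.Linear(3, 2))
```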
The idea was that the problem would be fixed if the individual ran `InputOptimization` a second time, but I now realize that this would break the ability to use multiple instances of `InputOptimization`.
Maybe we could somehow let the user decide when to run the cleanup code?
[_remove_all_forward_hooks(module, "module_outputs_forward_hook") for module in target_modules]
There does not appear to be a way to detect whether or not the hooks are still being used, so letting the user decide if they want to perform this fix is probably the best choice.
We could probably add this function to the API for users:
```python
from typing import List, Union

import torch.nn as nn


def cleanup_module_hooks(modules: Union[nn.Module, List[nn.Module]]) -> None:
    """
    Remove any InputOptimization hooks from the specified modules. This may be useful
    in the event that something goes wrong in between creating the InputOptimization
    instance and running the optimization function, or if InputOptimization fails
    without properly removing its hooks.

    Warning: This function will remove all the hooks placed by InputOptimization
    instances on the target modules, and thus can interfere with using multiple
    InputOptimization instances.

    Args:
        modules (nn.Module or list of nn.Module): Any module instances that contain
            hooks created by InputOptimization, for which the removal of the hooks is
            required.
    """
    if not hasattr(modules, "__iter__"):
        modules = [modules]
    # Captum ModuleOutputsHook uses "module_outputs_forward_hook" hook functions
    for module in modules:
        _remove_all_forward_hooks(module, "module_outputs_forward_hook")
```
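If this were added, usage would presumably look something like the following (the model and layer choices here are placeholders):

```python
import torch.nn as nn

# Hypothetical usage: clear any leftover InputOptimization hooks from the
# targeted layers before constructing a new InputOptimization instance.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
cleanup_module_hooks([model[0], model[1]])
```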
@ProGamerGov, in your example: does it occur after we initialize `opt.InputOptimization(model, loss_fn, image)` a second time? Do you have a notebook where we can debug the error?
@NarineK Just set up / install captum with the optim module, and run this snippet of code to reproduce it:
```python
!git clone https://github.com/progamergov/captum
%cd captum
!git checkout "optim-wip-fix-hook-bug"
!pip3 install -e .

import sys
sys.path.append('/content/captum')
%cd ..

import torch
import captum.optim._core.output_hook as output_hook


def test_bug():
    model = torch.nn.Identity()
    for i in range(5):
        _ = output_hook.ModuleOutputsHook([model])
    print(model._forward_hooks.items())  # There will be 5 hooks


test_bug()
```
The `InputOptimization` init function just calls `ModuleOutputsHook`, so we can just do the same in order to make it easier to reproduce.
It occurs because we often reuse the same target instance in notebooks, for example, and thus the hooks attached to the target instance are not removed. I don't think that it's something we can avoid, but we can make users aware of it and provide the option to mitigate its effects.
cc-ing @vivekmig - Vivek can help with this issue
@NarineK I can remove the hook removal functions from this PR so we can merge it. Then in a later PR we can revisit this minor hook duplication bug, as it's a very niche issue that won't interfere with anything at the moment.
These two warnings in the forward hook are run every iteration if their condition is met, and there doesn't appear to be a way to make them only show up once. The first warning message is triggered if there are multiple loss objectives using the same target, and this is problematic. Edit: The first warning was being repeated in the case of duplicate targets. Second edit: This PR should fix the issue with the warning messages being repeated every iteration: #919
* Added the `_remove_all_forward_hooks` function for easy cleanup and removal of hooks without requiring their handles.
* Changed `ModuleOutputsHook`'s forward hook function name from `forward_hook` to `module_outputs_forward_hook` to allow for easier testing and detection of the hook function by name.
* Added the `_count_forward_hooks` function for easy testing of hook creation & removal functionality.
* Added tests for `ModuleOutputsHook`. Previously we had no tests for this module.

Currently there is a sort of 'bug' where unused forward hooks aren't removed. Every time an instance of `InputOptimization` is created, global hooks are created in `ModuleOutputsHook()`'s initialization function for the target modules. These hooks remain until `InputOptimization.optimize` has finished running or `InputOptimization.cleanup()` has been run. This means that there is the potential for unused 'ghost' hooks to remain that can interfere with things like TorchScript / JIT, as these unused hooks technically still exist.

To resolve this issue, my fix checks for the presence of forward hooks and then removes them (without requiring hook handles), so that every time an instance of `InputOptimization` is created, and by extension `ModuleOutputsHook` is run, old hooks are removed. (I removed these changes from this PR.)

I created a PyTorch feature request here for creating a proper API that can be used to solve this issue, and in the meantime the devs said that my solution is fine: pytorch/pytorch#70455