Optim-wip: Improve ModuleOutputsHook, testing coverage #834

Merged: 5 commits merged into pytorch:optim-wip on May 17, 2022

Conversation

@ProGamerGov (Contributor) commented Dec 31, 2021

  • Added the _remove_all_forward_hooks function for easy cleanup and removal of hooks without requiring their handles.

  • Changed ModuleOutputsHook's forward hook function name from forward_hook to module_outputs_forward_hook to allow for easier testing and detection of the hook function by name.

  • Added the _count_forward_hooks function for easy testing of hook creation & removal functionality.

  • Added tests for ModuleOutputsHook. Previously we had no tests for this module.

Currently there is a sort of 'bug' where unused forward hooks aren't removed. Every time an instance of InputOptimization is created, global hooks are created in ModuleOutputsHook()'s initialization function for the target modules. These hooks remain until InputOptimization.optimize has finished running or InputOptimization.cleanup() has been run. This means that there is the potential for unused 'ghost' hooks to remain that can interfere with things like TorchScript / JIT, as these unused hooks technically still exist. I removed these changes from this PR.

To resolve this issue, my fix checks for the presence of forward hooks and then removes them (without requiring hook handles), so that every time an instance of InputOptimization is created, and by extension ModuleOutputsHook is run, old hooks are removed.

I created a PyTorch feature request here for creating a proper API that can be used to solve this issue, and in the meantime the devs said that my solution is fine: pytorch/pytorch#70455
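A minimal sketch of how handle-free removal and hook counting could work, assuming forward hooks are stored in each module's private _forward_hooks OrderedDict (the actual helpers in this PR may differ):

from collections import OrderedDict
from typing import Optional

import torch.nn as nn


def _remove_all_forward_hooks(
    module: nn.Module, hook_fn_name: Optional[str] = None
) -> None:
    # Forward hooks live in each module's private `_forward_hooks` dict, keyed
    # by handle id. Dropping entries removes the hooks without needing the
    # original RemovableHandle objects.
    for m in module.modules():
        if getattr(m, "_forward_hooks", None):
            m._forward_hooks = OrderedDict(
                (k, fn)
                for k, fn in m._forward_hooks.items()
                if hook_fn_name is not None
                and getattr(fn, "__name__", "") != hook_fn_name
            )


def _count_forward_hooks(
    module: nn.Module, hook_fn_name: Optional[str] = None
) -> int:
    # Count forward hooks (optionally only those whose function name matches),
    # so tests can assert that hook creation / removal worked as expected.
    return sum(
        1
        for m in module.modules()
        for fn in getattr(m, "_forward_hooks", {}).values()
        if hook_fn_name is None or getattr(fn, "__name__", "") == hook_fn_name
    )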

* Added the `_remove_all_forward_hooks` function for easy cleanup and removal of hooks without requiring their handles.

* Changed `ModuleOutputsHook`'s forward hook function name from `forward_hook` to `module_outputs_forward_hook` to allow for easy removal of only the hooks that use that hook function.

* `ModuleOutputsHook`'s initialization function now runs the `_remove_all_forward_hooks` function on targets, and only removes the hooks created by `ModuleOutputsHook` to avoid breaking PyTorch.

* Added the `_count_forward_hooks` function for easy testing of hook creation & removal functionality.

* Added tests for verifying that the 'ghost hook' bug has been fixed, and that the new function is working correctly.

* Added tests for `ModuleOutputsHook`. Previously we had no tests for this module.
[Excerpt from the new docstring under review:]

… any hook handles. This lets us clean up & remove any hooks that weren't properly deleted.

Warning: Various PyTorch modules and systems make use of hooks, and thus extreme …

@NarineK (Contributor) commented Jan 18, 2022

As you mentioned, this is dangerous to do because we could remove hooks that we didn't set and whose function name happens to be the same. Why don't we remove only the hooks that we set, right after we are done with the hook? I thought that all hooks would be removed here, won't they?

self.hooks.remove_hooks()

If there is a bug and we miss some of them, then we should rather make sure that we remove them after optimization is finished.
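A standalone sketch of the handle-based pattern suggested here (hypothetical example; the names are illustrative, not the PR's code):

import torch
import torch.nn as nn

# register_forward_hook returns a RemovableHandle; keeping it lets us remove
# exactly the hook we set, right after we are done with it.
model = nn.Identity()

def module_outputs_forward_hook(module, inputs, output):
    return output

handle = model.register_forward_hook(module_outputs_forward_hook)
try:
    _ = model(torch.zeros(1))
finally:
    handle.remove()  # the hook is removed even if optimization fails

assert len(model._forward_hooks) == 0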

@ProGamerGov (Contributor, Author) commented Jan 18, 2022

The idea was that the problem would be fixed when the user ran InputOptimization a second time, but I now realize that this would break the ability to use multiple instances of InputOptimization.

@ProGamerGov (Contributor, Author):

Maybe we could somehow let the user decide when to run the cleanup code?

[_remove_all_forward_hooks(module, "module_outputs_forward_hook") for module in target_modules]

@ProGamerGov (Contributor, Author):

There does not appear to be a way to detect whether or not the hooks are still being used, so letting the user decide if they want to perform this fix is probably the best choice.

@ProGamerGov (Contributor, Author) commented Jan 18, 2022

We could probably add this function to the API for users:

from typing import List, Union

import torch.nn as nn


def cleanup_module_hooks(modules: Union[nn.Module, List[nn.Module]]) -> None:
    """
    Remove any InputOptimization hooks from the specified modules. This may be useful
    in the event that something goes wrong in between creating the InputOptimization
    instance and running the optimization function, or if InputOptimization fails
    without properly removing its hooks.

    Warning: This function will remove all the hooks placed by InputOptimization
    instances on the target modules, and thus can interfere with using multiple
    InputOptimization instances.

    Args:

        modules (nn.Module or list of nn.Module): Any module instances that contain
            hooks created by InputOptimization, for which the removal of the hooks is
            required.
    """
    if not hasattr(modules, "__iter__"):
        modules = [modules]
    # Captum ModuleOutputsHook uses "module_outputs_forward_hook" hook functions
    [_remove_all_forward_hooks(module, "module_outputs_forward_hook") for module in modules]
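Usage might then look something like this (the layer attribute names below are hypothetical, purely for illustration):

# Clear any stale hooks left on the target layers by a previous, unfinished
# InputOptimization run before creating a fresh instance.
cleanup_module_hooks([model.mixed4a, model.mixed4b])
obj = opt.InputOptimization(model, loss_fn, image)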

@NarineK (Contributor):

@ProGamerGov, in your example: does it occur after we initialize opt.InputOptimization(model, loss_fn, image) a second time?

Do you have a notebook where we can debug the error ?

@ProGamerGov (Contributor, Author) commented Jan 18, 2022

@NarineK Just set up / install captum with the optim module, and run this snippet of code to reproduce it:

!git clone https://github.com/progamergov/captum
%cd captum
!git checkout "optim-wip-fix-hook-bug"
!pip3 install -e .
import sys
sys.path.append('/content/captum')
%cd ..
import torch
import captum.optim._core.output_hook as output_hook

def test_bug():
    model = torch.nn.Identity()
    for i in range(5):
        _ = output_hook.ModuleOutputsHook([model])
    print(model._forward_hooks.items()) # There will be 5 hooks

test_bug()

The InputOptimization init function just calls ModuleOutputsHook, so we can call ModuleOutputsHook directly to make the bug easier to reproduce.

@ProGamerGov (Contributor, Author) commented Jan 19, 2022

It occurs because we often reuse the same target instance (in notebooks, for example), and thus the hooks attached to the target instance are not removed. I don't think it's something we can avoid, but we can make users aware of it and provide an option to mitigate its effects.

@NarineK (Contributor):

cc-ing @vivekmig - Vivek can help with this issue

@ProGamerGov (Contributor, Author):

@NarineK I can remove the hook removal functions from this PR so we can merge it. Then in a later PR we can revisit this minor hook duplication bug, as it's a very niche issue that won't interfere with anything at the moment.

@ProGamerGov force-pushed the optim-wip-fix-hook-bug branch from e0b4588 to 1f6c2d5 on January 18, 2022 16:55
@ProGamerGov (Contributor, Author) commented Mar 31, 2022

These two warnings in the forward hook are triggered every iteration if their condition is met, and there doesn't appear to be a way to make them show up only once.

            if self.outputs[module] is None:
                self.outputs[module] = output
            else:
                warn(
                    f"Hook attached to {module} was called multiple times. "
                    "As of 2019-11-22 please don't reuse nn.Modules in your models."
                )
            if self.is_ready:
                warn(
                    "No outputs found from models. This can be ignored if you are "
                    "optimizing on inputs only, without models. Otherwise, check "
                    "that you are passing model layers in your losses."
                )

The first warning message is triggered if there are multiple loss objectives using the same target, and this is problematic as NaturalImage and the transforms can sometimes have multiple objectives using them as a target. The second warning message is just pointing out redundant information.
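For reference, one generic way to make a given warning fire at most once is a small guard object like the following (a hypothetical sketch, not the approach taken in the PR linked below):

from warnings import warn


class WarnOnce:
    """Emit each distinct warning message only once per instance."""

    def __init__(self) -> None:
        self._seen: set = set()

    def __call__(self, message: str) -> None:
        # Only call warn() the first time a given message is seen.
        if message not in self._seen:
            self._seen.add(message)
            warn(message)


warn_once = WarnOnce()
warn_once("Hook was called multiple times.")  # shown
warn_once("Hook was called multiple times.")  # skipped by the guard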

Edit: The first warning was being repeated in the case of duplicate targets.

Second edit:

This PR should fix the issue with the warning messages being repeated every iteration: #919

@ProGamerGov changed the title from "Optim-wip: Improve ModuleOutputsHook, testing coverage, & fix bug" to "Optim-wip: Improve ModuleOutputsHook, testing coverage" on Apr 6, 2022
@ProGamerGov requested a review from NarineK on April 8, 2022 19:21
@ProGamerGov (Contributor, Author):

The test_py36_pip_torch_1_6 test passed, but GitHub didn't update it to a green check mark for some reason.

@NarineK merged commit 4e5e50c into pytorch:optim-wip on May 17, 2022