Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix interaction between PyroParam and torch.func.grad #3328

Merged
merged 9 commits into from
Feb 16, 2024

Conversation

eb8680
Copy link
Member

@eb8680 eb8680 commented Feb 15, 2024

Addresses a bug described downstream in BasisResearch/chirho#393

This PR adds a fix for compatibility of PyroModules and PyroParams with torch.func.grad and the other functional automatic differentiation transforms in torch.func. The fix is basically to replace each pyro.param statement or other interaction with the parameter store with a dummy version that does not store and retrieve a parameter tensor from a nonlocal state (which is invisible to the tracing machinery in torch.func).

Without this fix, gradient computations in torch.func.grad do not propagate to the unconstrained parameters behind constrained PyroParams even when using pyro.settings.set(module_local_param=True) and are always zero. After this fix, the functional AD system in torch.func behaves correctly with AutoGuides and other PyroModules when module_local_param=True, though it is still fundamentally incompatible with the global parameter store state when module_local_param=False.

Tested:

  • Added simple regression test that fails without the fix in this PR

@eb8680 eb8680 added this to the 1.9 release milestone Feb 15, 2024
@eb8680
Copy link
Member Author

eb8680 commented Feb 15, 2024

I think I'm missing some edge cases. Marking this WIP while I strengthen the test and see if it still works...

@eb8680
Copy link
Member Author

eb8680 commented Feb 15, 2024

Addressed the rest of the edge cases, should be ready for review

Copy link
Member

@ordabayevy ordabayevy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great. I just have a question about the effectful.

pyro/nn/module.py Show resolved Hide resolved
@eb8680 eb8680 merged commit 4a55960 into dev Feb 16, 2024
9 checks passed
@eb8680 eb8680 deleted the eb-fix-pyromodule-functorch branch February 16, 2024 20:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants