
Model defined gradients in NUTS and HMC #3211

Open · ConnorStoneAstro opened this issue May 13, 2023 · 3 comments

Comments

@ConnorStoneAstro

For score-based diffusion models in machine learning, we have a learned function that computes the gradient of the potential (the score). It would be really nice if it were possible to optionally provide a score_fn or a potential_fn to HMC and NUTS.

I imagine this may be nice for very simple models as well, where the user can just write out the gradient manually. I think this would be a pretty simple change (an extra option at instantiation of the sampler objects). I might be able to do it myself, but I'm opening this issue to see if someone has a smarter idea. Also, if there is opposition to such an option, then I won't put in the work to make a PR.
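
For concreteness, the proposed option might look something like this docs-style signature (score_fn is hypothetical and does not exist in Pyro today; the name is illustrative only):

class HMC(model=None, potential_fn=None, score_fn=None, ...)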

@martinjankowiak
Collaborator

see docs:

class HMC(model=None, potential_fn=None, ...)
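
A minimal usage sketch, assuming Pyro's documented MCMC API: when a potential_fn is given instead of a model, it receives a dict of parameter tensors, and initial_params must be passed to MCMC. The quadratic potential here is just a stand-in:

import torch
from pyro.infer import HMC, MCMC

# Toy potential: Pyro calls potential_fn with a dict of site names to tensors.
def potential_fn(params):
    z = params["z"]
    return 0.5 * (z ** 2).sum()  # potential energy of a standard normal

kernel = HMC(potential_fn=potential_fn, step_size=0.1, num_steps=10)
mcmc = MCMC(kernel, num_samples=100, warmup_steps=100,
            initial_params={"z": torch.zeros(2)})
mcmc.run()
samples = mcmc.get_samples()["z"]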

@fritzo
Member

fritzo commented May 13, 2023

Hi @ConnorStoneAstro, assuming your nn.Module parameters are frozen, you should be able to wrap your modules in a custom torch.autograd.Function and pass that as the potential_fn to Pyro's HMC:

import torch

class CustomPotential(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, forward_module, backward_module):
        # Stash the score network and the input for use in backward().
        ctx.backward_module = backward_module
        ctx.save_for_backward(input)
        return forward_module(input)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Chain rule: the score network gives dU/d(input) directly.
        grad_input = ctx.backward_module(input) * grad_output
        return grad_input, None, None

# autograd.Functions are invoked via .apply, not instantiated:
def potential_fn(input):
    return CustomPotential.apply(input, forward_module, backward_module)
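
A quick sanity check with toy stand-ins (forward_module and backward_module here are placeholders for the trained networks), verifying the custom backward against autograd:

# Toy stand-ins: U(z) = 0.5 * ||z||^2, whose score is z itself.
forward_module = lambda z: 0.5 * (z ** 2).sum()
backward_module = lambda z: z

z = torch.randn(3, requires_grad=True)
u = potential_fn(z)
grad, = torch.autograd.grad(u, z)
assert torch.allclose(grad, z)  # matches the analytic score

Note that Pyro's HMC calls potential_fn with a dict of parameter tensors, so in practice the wrapper would unpack, e.g., params["z"] before calling CustomPotential.apply.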

@ConnorStoneAstro
Author

Hi @fritzo, thanks for the idea! That does seem easier!
