
bayesian neural nets #40

Closed

ngoodman opened this issue Jul 8, 2017 · 6 comments


@ngoodman (Collaborator) commented Jul 8, 2017

pyro should have a really beautiful, idiomatic way to describe bayesian neural nets. currently the only way to do it is to construct the model net from raw tensors + tensor math (i.e. not using the predefined nn modules from pytorch). i've been thinking through some options.

here is the one i like best so far. This only requires the addition of a pyro.random_module(module, prior) helper that intercepts the params of module and samples them from prior instead of registering them as parameters.

Comments on this approach, or alternatives, are welcome! If folks like this, then I or someone can add the helper and an example.
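
For concreteness, here is a minimal sketch of how a model might use the proposed helper (the nn.Linear regressor and the Normal prior are illustrative placeholders, and random_module itself does not exist yet):

import torch.nn as nn
import pyro
import pyro.distributions as dist

regressor = nn.Linear(10, 1)     # an ordinary pytorch module
prior = dist.Normal(0.0, 1.0)    # hypothetical prior over each parameter

def model(x):
    # proposed: sample the module's params from `prior` instead of
    # registering them with pyro.param
    bayesian_regressor = pyro.random_module(regressor, prior)
    return bayesian_regressor(x)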

@karalets (Collaborator) commented Jul 8, 2017

I like this and agree we need a compact way to write that up; currently it is much too contrived.

Will spend some time playing with it and see if I have any useful comments.

@eb8680 (Member) commented Jul 8, 2017

I like the idea of pyro.random_module (but hopefully with a shorter, punchier name - pyro.nn.lift?) for lifting an nn.Module to a stochastic function that returns new nn.Modules with parameters sampled from a prior. Here's a slightly different way to generate a guide automatically that seems more Pyronic (?):

import pyro
import pyro.poutine as poutine

def make_guide(fn, sites=None):
    def guide(*args, **kwargs):
        # run the model once, recording its trace
        model_trace = poutine.block(poutine.trace(fn))(*args, **kwargs)
        # use a local name here; assigning to `sites` inside guide
        # would raise UnboundLocalError on the closure variable
        if sites is not None:
            guide_sites = sites
        else:
            guide_sites = {name: name for name in model_trace.keys()}
        for name in model_trace.keys():
            if model_trace[name]["type"] == "sample" and name in guide_sites:
                # make_site_posterior: sketched per-site variational
                # posterior factory (not defined here)
                pyro.sample(guide_sites[name],
                            make_site_posterior(model_trace[name], *args, **kwargs))

    return guide

guide_dist = make_guide(pyro.random_module(mod, prior))

As written this generates mean-field guides, but you can write more sophisticated guides in a similar style.

@eb8680 (Member) commented Jul 8, 2017

Riffing on this some more because I quite like it: there's no reason the parameter-lifting operation has to be nn-specific. Imagine a poutine operation poutine.lift(fn, prior) that overrides each pyro.param call in fn with a pyro.sample call internally using the provided prior. Then for nn.Modules we can just write

pyro.random_module = lambda name, mod, prior: poutine.lift(pyro.module, prior)(name, mod)

but now the same principles, as well as guide generators like the one above, can be applied to any stochastic function that has pyro.param calls.
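
As a concrete (hypothetical) illustration of that last point, here is a plain stochastic function with a pyro.param call that the proposed poutine.lift would promote to a sample; the function, its param, and the distribution spelling are all placeholders:

import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

prior = dist.Normal(0.0, 1.0)    # hypothetical prior

def noisy_mean():
    # an ordinary stochastic function with one learnable param
    loc = pyro.param("loc", torch.zeros(1))
    return pyro.sample("z", dist.Normal(loc, 1.0))

# proposed: the pyro.param("loc", ...) call inside noisy_mean is
# replaced internally by pyro.sample("loc", prior)
lifted_noisy_mean = poutine.lift(noisy_mean, prior)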

I'm not completely happy with this structure, though, because conceptually it would be nicer if, like the proposed pyro.random_module, lift(fn, prior) returned a distribution over fn-like callables that could be called to sample a single fn with new values but no more randomness at the original pyro.param sites. I'm not sure how to do this, since different execution traces may contain entirely different sets of pyro.param sites that can only be determined at runtime.

Edit: Ok, I thought about the last problem some more. I don't think there's a way to do that in general, and it's not even a probabilistically coherent request because the joint distribution over fns and traces doesn't factor that way when the appearance of a pyro.param in a trace is determined by pyro.sample or pyro.observe statements in the trace (or else, if it does, it's no longer guaranteed to have a density).
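
A tiny hypothetical example of the problem: if a sample decides whether a param site appears at all, the set of lifted sites is itself random (all names here are made up):

import torch
import pyro
import pyro.distributions as dist

def gated(data):
    use_bias = pyro.sample("use_bias", dist.Bernoulli(0.5))
    if use_bias.item() == 1:
        # this param site appears only in some traces, so a distribution
        # over fns with a fixed set of param sites is not well defined
        bias = pyro.param("bias", torch.zeros(1))
        return data + bias
    return data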

However, suppose we happen to know that all execution traces of a stochastic function fn will contain the same pyro.param sites. In that case the distribution does factor that way and we should in principle be able to create a nn.Module from fn and lift it with pyro.random_module:

import torch.nn as nn
import pyro
import pyro.poutine as poutine

class LiftableFunction(nn.Module):
    def __init__(self, fn, *args, **kwargs):
        super(LiftableFunction, self).__init__()  # required before registering nn.Parameters
        self.fn = fn
        # trace one execution of fn to discover its param sites
        initial_trace = poutine.block(poutine.trace(fn))(*args, **kwargs)
        for name in initial_trace.keys():
            if initial_trace[name]["type"] == "param":
                # XXX something like this? not exactly correct
                setattr(self, "_weight_" + name, nn.Parameter(pyro.param(name, ...)))

    def forward(self, *args, **kwargs):
        return self.fn(*args, **kwargs)
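
Under that fixed-sites assumption, the wrapper could then be lifted like any other module. A hypothetical usage sketch (`fn`, `example_input`, and `prior` are placeholders; whether the lifted result is called directly or first sampled is exactly the design question raised above):

# trace fn once on example arguments to register its param sites
wrapped = LiftableFunction(fn, example_input)
# lift the wrapper so its nn.Parameters are sampled from `prior`
stoch_fn = pyro.random_module("fn", wrapped, prior)
output = stoch_fn(example_input)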

@ngoodman (Collaborator, Author) commented Jul 10, 2017

i like the idea of a more general "lift" function that promotes params to samples! you're right that it couldn't cleanly separate the new randomness from the original randomness in the fn. it's not totally clear to me if this is an important separation. without that separation the bayesian nn example would look something like:

stoch_classifier = pyro.random_module("classifier", classify, prior)
class_weights = stoch_classifier(data)  # use the net as an ordinary *stochastic* fn

which is actually simpler! we've basically just upgraded the deterministic (but parameterized) function defined by the module to a stochastic function of the same signature.

btw, the make_guide function defined in the comment above makes sense only if the samples don't affect control flow (which is true in the module case, but maybe not generally?).

ngoodman assigned ngoodman and jpchen and unassigned ngoodman on Jul 12, 2017
@ngoodman (Collaborator, Author) commented:

for future reference, a nice but pretty straightforward use of bayesian rnns (Fortunato et al., "Bayesian Recurrent Neural Networks"): https://arxiv.org/pdf/1704.02798.pdf

@eb8680 (Member) commented Oct 7, 2017

Addressed by #121
