Implement StaticSVI #1562

Closed · wants to merge 5 commits

Conversation

fehiepsi (Member) commented Nov 23, 2018

This pull request implements StaticSVI, an SVI interface for model/guide pairs that do not create new params dynamically. This lets LBFGS work with this inference (addresses #1519).
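
For context, a minimal background sketch (not part of this PR; the parameter and loss are stand-ins): torch.optim.LBFGS needs the full parameter list at construction time and re-evaluates the loss through a closure, which is why parameters created lazily during an SVI step do not fit its interface.

import torch

# All parameters must exist before the optimizer is built.
params = [torch.randn(2, requires_grad=True)]
optimizer = torch.optim.LBFGS(params, lr=0.1)

def closure():
    # LBFGS calls this repeatedly within a single step to re-evaluate the loss.
    optimizer.zero_grad()
    loss = (params[0] ** 2).sum()  # stand-in quadratic loss
    loss.backward()
    return loss

for _ in range(10):
    optimizer.step(closure)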

fehiepsi requested a review from eb8680 on November 23, 2018 at 00:46
alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
pyro.sample("p_latent", dist.Beta(alpha_q, beta_q))

adam = optim.Adam({"lr": .001})
fritzo (Member) commented on the diff above:

Can you also add a test using LBFGS?

fehiepsi (Member, Author) replied:

@fritzo I have added another test for it. Using LBFGS for this model is quite flaky.

eb8680 (Member) commented Nov 23, 2018

I don't really think this is necessary (in fact, I think we should deprecate SVI entirely, but that's another discussion). In this situation, we should encourage users to write their models as nn.Modules and use Pyro losses and torch.optim optimizers together directly instead of going through the SVI and PyroOptim interfaces:

class MyModel(nn.Module):
    def __init__(self, ...):
        ...

    def model(self, batch):
        ...

    def guide(self, batch):
        ...

model = MyModel(...)

elbo = pyro.infer.Trace_ELBO()
optim = torch.optim.SGD(model.parameters(), lr=0.01)  # or LBFGS etc.; SGD requires an explicit lr
for batch in data:
    optim.zero_grad()
    elbo.loss_and_grads(model.model, model.guide, batch)
    optim.step()

fehiepsi (Member, Author) commented:

@eb8680 Do you mean that we'd have a Pyro nn.Module that collects all the parameters from the model and guide when we call mymodel.parameters()? Without that, users have to know about poutine.trace(param_only=True) to capture all the params for a PyTorch optimizer. In addition, SVI.run() is quite convenient for getting samples from the posterior; otherwise, users will have to learn about poutine.trace(...), poutine.replay(...), etc.
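
For reference, a minimal sketch (not from the thread; collect_pyro_params is a hypothetical helper) of the poutine.trace(param_only=True) pattern mentioned above for handing Pyro params to a plain PyTorch optimizer:

import pyro.poutine as poutine

def collect_pyro_params(model, guide, *args, **kwargs):
    # Run guide and model once, recording only pyro.param sites,
    # and return their unconstrained tensors for a torch.optim optimizer.
    params = set()
    for fn in (guide, model):
        trace = poutine.trace(fn, param_only=True).get_trace(*args, **kwargs)
        params.update(site["value"].unconstrained()
                      for site in trace.nodes.values()
                      if site["type"] == "param")
    return list(params)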

eb8680 (Member) commented Nov 23, 2018

@fehiepsi module.parameters() is a method of torch.nn.Module that PyTorch users would be familiar with. Most of our tutorials, e.g. the VAE tutorial, already wrap the model and guide as methods of a single nn.Module and never call pyro.param explicitly in the model or guide, but they use pyro.module (which calls module.named_parameters() under the hood) to pass the parameters to the Pyro parameter store via SVI.step. I don't think that extra layer of indirection is buying us much, and I think the snippet above is less opaque and more consistent with PyTorch idioms.
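
For contrast, a minimal sketch (not from the thread; Encoder and guide are illustrative) of the pyro.module indirection described above, where an nn.Module's parameters are pushed into Pyro's param store from inside the guide:

import torch.nn as nn
import pyro
import pyro.distributions as dist

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

encoder = Encoder()

def guide(x):
    # pyro.module registers encoder's named_parameters() with the param store,
    # which is how they become visible to SVI.step and PyroOptim.
    pyro.module("encoder", encoder)
    loc = encoder(x)
    pyro.sample("z", dist.Normal(loc, 1.).to_event(1))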

We shouldn't try to fix leaky abstractions by introducing more indirection, we should just deprecate/remove them and do our best to help users write code that's easier to understand. SVI, Trace, ParamStoreDict, ELBO, and PyroOptim are some of the worst offenders. On that note, re: trace and replay for parameter capture and posterior sampling instead of SVI, I actually think it'd be better to do that. We can always provide a couple of tiny wrappers that remove boilerplate without obfuscation.
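
For illustration, a minimal sketch (not from the thread; posterior_trace is a hypothetical tiny wrapper of the kind mentioned above) of the trace/replay pattern for drawing posterior samples without SVI.run():

import pyro.poutine as poutine

def posterior_trace(model, guide, *args, **kwargs):
    # Sample latent sites from the (trained) guide, then replay the model
    # against those values and return the resulting model trace.
    guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
    replayed_model = poutine.replay(model, trace=guide_trace)
    return poutine.trace(replayed_model).get_trace(*args, **kwargs)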

fehiepsi (Member, Author) commented Nov 24, 2018

@eb8680 It took me a while to feel that I understand what you mean. :) Did you mean that we define all the parameters ahead of time, so we don't need to use ParamStoreDict or PyroOptim? If that is the case, then I am on board with the future deprecation of these wrappers (together with SVI, of course). I'll try to think about that idiom for the gp module, by the way. But please correct me if my understanding is incorrect.

About this PR, how about moving this to contrib for a while?

eb8680 (Member) commented Nov 24, 2018

> Did you mean that we define all the parameters ahead, so we don't need to use ParamStoreDict, PyroOptim?

Yes, I mean that if all the parameters can be defined ahead of time, we should encourage users to write their models or model/guide pairs as nn.Modules and use torch.optim and standard PyTorch idioms, not just if they want to use LBFGS but also more generally. We've slowly but steadily accumulated technical debt in our API design (much of it my fault, like making Trace a networkx.DiGraph), and we should try to be diligent about not adding more.

Here's a rewritten version of the example from your test in this PR to illustrate:

import torch
import torch.nn as nn

import pyro
import pyro.distributions as dist
from pyro.infer import Trace_ELBO

x = 1 + torch.randn(10)

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.mu = nn.Parameter(torch.tensor(0.))
        self.sigma = nn.Parameter(torch.tensor(1.))  # log-scale; exponentiated below

    def forward(self):
        with pyro.plate("plate", len(x)):
            return pyro.sample("x", dist.Normal(self.mu, torch.exp(self.sigma)), obs=x)

model = MyModel()
optim = torch.optim.LBFGS(model.parameters())

def closure():
    optim.zero_grad()  # LBFGS re-evaluates the loss, so clear grads on each call
    return Trace_ELBO().loss_and_grads(model, lambda: None)  # trivial guide: no latent sites

for _ in range(100):
    optim.step(closure)

Or version 2:

x = 1 + torch.randn(10)

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.mu = nn.Parameter(torch.tensor(0.))
        self.sigma = nn.Parameter(torch.tensor(1.))

    def model(self):
        with pyro.plate("plate", len(x)):
            return pyro.sample("x", dist.Normal(self.mu, torch.exp(self.sigma)), obs=x)

    def guide(self):
        pass

model = MyModel()
optim = torch.optim.LBFGS(model.parameters())

def closure():
    optim.zero_grad()
    return Trace_ELBO().loss_and_grads(model.model, model.guide)

for _ in range(100):
    optim.step(closure)

> About this PR, how about moving this to contrib for a while?

I'm not sure that's the right type of thing to add to contrib. How about an example that uses torch.optim.LBFGS instead?

fehiepsi (Member, Author) commented:

In that case, I vote to close this PR. I'll use that pattern in a GP tutorial instead. :)
