
add svi part iv tutorial #2770

Merged: 7 commits, Mar 18, 2021

Conversation

@martinjankowiak (Collaborator) commented Feb 20, 2021

as long promised... nbviewer link

@martinjankowiak (Collaborator, Author) commented:

this is very much a work-in-progress rough draft, but if anyone has suggestions for tips/tricks i should cover, please feel free to chime in @fritzo @eb8680 @fehiepsi @ordabayevy

@fritzo (Member) commented Feb 22, 2021

Looks great! Here are some comments:

Comments on existing sections

  1. Consider mentioning NaN checks like assert loss < math.inf?
    Also mention learning rate decay, e.g. ClippedAdam({'lrd': 0.1 ** (1 / num_steps)}). (See the sketch after this list.)

  2. 👍

  3. To ensure all constraints are available, let's encourage:

    - from torch.distributions import constraints
    + from pyro.distributions import constraints
  4. Consider leaning a little farther towards AutoGuide, e.g. label this section

    Before you build a custom guide, try a simple AutoDelta or AutoNormal
    As a rule of thumb, start with numerically stable guides and progress towards more accurate ones. We recommend the sequence: AutoDelta -> AutoNormal -> AutoLowRankMultivariateNormal -> a custom guide.

    (Also nit: combine the two section 4.s)

    • nit: "catastrophic failure" -> "failing catastrophically"
    • mention the init_scale and init_loc_fn parameters to autoguides (see the sketch after this list)
    • mention that in simple models, data-dependent initialization can dramatically speed up convergence and improve stability.
  5. Aside: great point on normalization; I often normalize the final loss simply for printing purposes even though it doesn't aid numerical stability. I wonder if we should add a loss_scale option to SVI to make this easier?

  6. Consider making this a tad more precise: "Scales matter" -> "Scales of numbers matter". I mention this because many distributions also have scale parameters with technical meaning.
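
To make items 1 and 4 concrete, here is a minimal sketch of a training loop combining the NaN check, ClippedAdam's lrd decay, and a conservatively initialized AutoNormal (the toy model, data, and hyperparameters are illustrative assumptions, not from the tutorial):

import math
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal, init_to_median
from pyro.optim import ClippedAdam

def model(data):
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

# data-dependent initialization and a small init_scale improve stability
guide = AutoNormal(model, init_loc_fn=init_to_median, init_scale=0.01)

num_steps = 1000
# decay the learning rate to 0.1x its initial value over training
optim = ClippedAdam({"lr": 0.01, "lrd": 0.1 ** (1 / num_steps)})
svi = SVI(model, guide, optim, Trace_ELBO())

data = torch.randn(100)
for step in range(num_steps):
    loss = svi.step(data)
    assert loss < math.inf  # fails on NaN or Inf losses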

Additional material

  • Tensor shapes are hard.
    If you see shape errors, first read the tensor shapes tutorial.
  • ClippedAdam can be more stable than Adam (if this hasn't already been covered)
  • Consider using poutine.reparam to improve geometry
    Coordinate-wise optimizers like Adam can have difficulty with the geometry of some models, especially models with chains of dependencies among latent variables. Pyro offers an effect handler, poutine.reparam, together with a number of reparametrization strategies that rewrite models in ways that preserve their posterior distributions but change their geometry to ease optimization. For example, to improve optimization in time-series models, we recommend experimenting with HaarReparam. These model transforms change the latent variables and so require changes to your guide; however, if you're already using an AutoGuide, no change is necessary. (A sketch follows below.)
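
To make the poutine.reparam suggestion concrete, a hedged sketch on a toy random-walk model (the model itself is an illustrative assumption, not from the tutorial):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide import AutoNormal
from pyro.infer.reparam import HaarReparam

def model(data):
    T = len(data)
    # increments of a random walk; the cumsum creates chained dependencies
    drift = pyro.sample("drift", dist.Normal(torch.zeros(T), 1.0).to_event(1))
    pyro.sample("obs", dist.Normal(drift.cumsum(-1), 1.0).to_event(1), obs=data)

# rewrite the model in the Haar wavelet domain; the posterior is preserved
reparam_model = poutine.reparam(model, config={"drift": HaarReparam(dim=-1)})
guide = AutoNormal(reparam_model)  # autoguides adapt to the new latents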

@eb8680 (Member) commented Feb 26, 2021

This is great!

Some comments on existing sections:

  • Introduction: add a plug and link for the forum e.g. "If you’re having trouble finding or understanding anything here, please don’t hesitate to ask a question on our forum!"
  • Learning rates: add a sentence explicitly justifying/recommending Adam ("try optim.Adam by default because ..."), and point to schedulers
  • Autoguides/custom guides: add a sentence at the end of one of these sections mentioning easyguide with a link to the easyguide tutorial

Some suggestions for new sections:

  • as @fritzo mentioned, enabling validation (pyro.enable_validation(True) and *ELBO(strict_enumeration_warning=True)) helps catch silent shape/support errors and track down NaNs. This is covered somewhat in the tensor shapes tutorial, but it's probably worth teasing here and telling people to go read the tensor shapes and enumeration tutorials.
  • basic tools for reducing variance: use more particles via num_particles=..., increase batch sizes, use gradient clipping or learning rate schedulers/decay, use enumeration or TraceMeanField_ELBO if applicable, try annealing the prior (with a link to the DMM example), etc. (see the sketch after this list)
  • perhaps out of scope for this tutorial, but you could suggest a high-level pattern for code organization e.g. making model and guide separate PyroModules and combining them in a top-level nn.Module with methods for training and prediction
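
A minimal sketch of the validation and variance-reduction knobs from the first two bullets (the toy model and the particular settings are assumptions):

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import ClippedAdam

pyro.enable_validation(True)  # catch silent shape/support errors early

def model(data):
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

guide = AutoNormal(model)

# average over several particles per step to reduce gradient variance;
# vectorize_particles=True evaluates them in parallel, which requires
# the model and guide to support batching
elbo = Trace_ELBO(num_particles=10, vectorize_particles=True)
# gradient clipping via clip_norm adds another layer of stability
svi = SVI(model, guide, ClippedAdam({"lr": 0.01, "clip_norm": 10.0}), elbo)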

@ordabayevy (Member) commented:

It looks great!

Some practical suggestions:

  1. Checkpointing. This is mentioned in https://pyro.ai/examples/dmm.html for nn.Module. Here is a snippet showing how to do it using the ParamStore:
import torch
import pyro

def save_checkpoint(optim, path_param, path_optim):
    # optim is the pyro.optim optimizer used by SVI
    # save only if there are no NaN or Inf values
    for k, v in pyro.get_param_store().items():
        if torch.isnan(v).any() or torch.isinf(v).any():
            raise ValueError(f"Detected NaN or Inf values in {k}")

    # save parameters and optimizer state
    pyro.get_param_store().save(path_param)
    optim.save(path_optim)

def load_checkpoint(optim, path_param, path_optim, device=None):
    pyro.clear_param_store()
    pyro.get_param_store().load(path_param, map_location=device)
    optim.load(path_optim)
  2. Brief guide on how to implement new distributions. This might deserve its own tutorial on how to implement & test new distributions, but maybe at least mention that Pyro distributions need to subclass TorchDistribution or TransformedDistribution, and point to TorchDistribution, which has a small section on how to implement Pyro-compatible distributions. (A sketch follows below.)
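
As a teaser for such a guide, a hedged sketch of a minimal Pyro-compatible distribution: a toy parameter-free standard Normal. A real implementation would also handle batch shapes, override expand(), and provide rsample() where applicable.

import math
import torch
from pyro.distributions import TorchDistribution, constraints

class UnitNormal(TorchDistribution):
    # a standard Normal with no parameters, for illustration only
    arg_constraints = {}
    support = constraints.real

    def __init__(self, validate_args=None):
        super().__init__(torch.Size(), torch.Size(), validate_args=validate_args)

    def sample(self, sample_shape=torch.Size()):
        return torch.randn(sample_shape)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        return -0.5 * value ** 2 - 0.5 * math.log(2 * math.pi)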

@fritzo (Member) commented Feb 26, 2021

> Brief guide on how to implement new distributions

I'm vaguely planning to create a dedicated tutorial on implementing custom distributions. Before that, I'd like to create more automated testing tools along the lines of goftests #2658.

@martinjankowiak (Collaborator, Author) commented:

@fritzo @eb8680 @ordabayevy thanks for your feedback! i've tried to incorporate what seemed in scope (basically things related to optimization).

this is now ready for review. please check the links, and let me know if you can think of additional in-scope content worth adding.

@fritzo (Member) commented Mar 9, 2021

Final comments:

  • 3. Add a comment on gamma like
    gamma = 0.1  # final learning rate will be 0.1 * initial learning rate
  • 4. Add a line comment on bad_guide() like
    pyro.sample("x", dist.Normal(loc, 1.0))  # Normal may sample x < 0
    (see the sketch after this list)
  • 8. Mention that "To use vectorize_particles=True you may need to ensure your model and guide support batching. See the tensor shapes tutorial for best practices."
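
To illustrate the bad_guide() comment, a hedged sketch of the support mismatch (the toy model and guides are assumptions, not from the tutorial):

import torch
import pyro
import pyro.distributions as dist

def model():
    # the latent x has support (0, inf)
    pyro.sample("x", dist.LogNormal(0.0, 1.0))

def bad_guide():
    loc = pyro.param("loc", lambda: torch.tensor(0.0))
    pyro.sample("x", dist.Normal(loc, 1.0))  # Normal may sample x < 0

def good_guide():
    loc = pyro.param("loc", lambda: torch.tensor(0.0))
    pyro.sample("x", dist.LogNormal(loc, 1.0))  # matches the model's support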

@martinjankowiak changed the title from "wip: add svi part iv tutorial" to "add svi part iv tutorial" on Mar 9, 2021
@martinjankowiak (Collaborator, Author) commented:

thanks @fritzo, addressed your comments. this should be good to merge unless @eb8680 has additional feedback

@eb8680 (Member) left a review comment

LGTM, we should push this to the website as soon as it's merged!
