StreamingMCMC class #2857

Merged
merged 5 commits into from
Jun 16, 2021
Conversation

mtsokol
Contributor

@mtsokol mtsokol commented Jun 3, 2021

Hi @fritzo!

Here's a workspace PR for StreamingMCMC, currently with only an initial draft (I will rebase all pyro.ops.streaming changes).

I decided to introduce AbstractMCMC to factor out a few shared lines. The initial test for StreamingMCMC only prints the statistics, which look correct, so at least it runs.

Right now I'm wondering how to unify those two classes (their methods):

| AbstractMCMC | MCMC | StreamingMCMC |
|---|---|---|
| run | run | run |
| -- | get_samples | -- |
| -- | diagnostics | ? (r_hat is expected to be a statistic) |
| -- | summary (only prints to the console) | summary (returns statistics) |

If we care about backward compatibility then I think summary in StreamingMCMC could be changed to get_statistics. If not, then summary can be unified as a pure method, declared abstract in AbstractMCMC and implemented by both classes.

I'm also thinking about how to rewrite the test suite, since every current test relies on get_samples (that's another argument for making summary return statistics instead of printing them).

WDYT?

@fritzo
Member

fritzo commented Jun 3, 2021

Looks great so far!

> If we care about backward compatibility then I think summary in StreamingMCMC could be changed to get_statistics.

We do care about backwards compatibility. I like the idea of a new method either .get_summary() or .get_statistics(), together with a refactoring that makes the old .summary() call .get_*() under the hood then print it.
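The refactoring suggested here can be sketched in plain Python. This is only an illustration: `ToyMCMC` and its internals are hypothetical stand-ins, not Pyro's classes; the one idea taken from the thread is that a pure `get_statistics()` returns a dict and the old `summary()` keeps printing by calling it.

```python
import math


class ToyMCMC:
    """Hypothetical stand-in for an MCMC class; not Pyro's implementation."""

    def __init__(self, samples):
        # samples: dict mapping site name -> list of float draws
        self._samples = samples

    def get_statistics(self):
        """Pure method: return a stats dict and print nothing."""
        stats = {}
        for name, draws in self._samples.items():
            n = len(draws)
            mean = sum(draws) / n
            # Unbiased sample variance; undefined for a single draw.
            var = sum((x - mean) ** 2 for x in draws) / (n - 1) if n > 1 else math.nan
            stats[name] = {"count": n, "mean": mean, "variance": var}
        return stats

    def summary(self):
        """Backward-compatible method: still prints, but delegates the math."""
        for name, s in self.get_statistics().items():
            print(f"{name}: count={s['count']} mean={s['mean']:.4f} "
                  f"variance={s['variance']:.4f}")


mcmc = ToyMCMC({"y": [1.0, 2.0, 3.0, 4.0]})
stats = mcmc.get_statistics()  # usable in tests without capturing stdout
mcmc.summary()                 # existing callers keep their console output
```

With this split, tests can assert on the returned dict while existing users of `summary()` see no change in behavior.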

> I'm also thinking about how to rewrite the test suite, since every current test relies on get_samples.

Hmm... one option is to use the new StackStats which simply collects samples

# in a test file:
mcmc = StreamingMCMC(..., statistics=StatsOfDict(default=StackStats))

and use this in .get_samples(). Maybe .get_samples() can return only those samples for which the "samples" key is present in the stats dict returned by .get()? Another option might be to add an interface to get something like (count, mean, variance) statistics from the old non-streaming MCMC class. I defer to your refactoring judgement 😄

@mtsokol mtsokol mentioned this pull request Jun 3, 2021
3 tasks
@fritzo
Member

fritzo commented Jun 3, 2021

> ? (r_hat is expected to be a statistic)

I'll need to think more about computing r_hat and n_eff = ess in streaming, so let's plan to add those to pyro.ops.streaming in a follow-up PR.
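For orientation, here is a rough sketch of how the classic (non-split) Gelman-Rubin statistic could be computed from per-chain count, mean, and variance alone. It is not what pyro.ops.streaming does; the function name and the dict layout are assumptions, chains are assumed to have equal length, and the variance is assumed to be the unbiased within-chain estimate. A split R-hat would additionally need statistics for each half of every chain, and ess needs autocorrelations, which this sketch does not address.

```python
import math


def streaming_r_hat(chain_stats):
    """Non-split Gelman-Rubin R-hat from per-chain summary statistics.

    chain_stats: list of dicts with keys "count", "mean", "variance"
    (hypothetical layout; variance is the unbiased within-chain estimate).
    """
    m = len(chain_stats)
    n = chain_stats[0]["count"]
    if m < 2 or any(s["count"] != n for s in chain_stats):
        raise ValueError("need at least two chains of equal length")
    means = [s["mean"] for s in chain_stats]
    grand_mean = sum(means) / m
    # Variance of the chain means (B / n in the usual notation).
    var_between = sum((mu - grand_mean) ** 2 for mu in means) / (m - 1)
    # Average within-chain variance (W).
    var_within = sum(s["variance"] for s in chain_stats) / m
    # Pooled estimate of the posterior variance.
    var_estimate = (n - 1) / n * var_within + var_between
    return math.sqrt(var_estimate / var_within)


# Illustrative inputs: two well-mixed chains targeting a unit normal.
r_hat = streaming_r_hat([
    {"count": 2000, "mean": -0.0008, "variance": 0.9980},
    {"count": 2000, "mean": -0.0336, "variance": 0.9671},
])
print(r_hat)  # close to 1 when the chains agree
```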

@mtsokol
Contributor Author

mtsokol commented Jun 3, 2021

Thanks for the (chain_id, site["name"]) idea! Now StreamingMCMC is chain-aware, and get_statistics(self, group_by_chain=True) can easily be implemented with merge().

A draft test parametrized with @pytest.mark.parametrize('group_by_chain', [True, False]) prints the result of get_statistics as:

True:

{ (0, 'y'): {'count': 2000, 'mean': tensor([-0.0008]), 'variance': tensor([0.9980])}, 
  (1, 'y'): {'count': 2000, 'mean': tensor([-0.0336]), 'variance': tensor([0.9671])}  }

False:

{'y': {'count': 4000, 'mean': tensor([-0.0172]), 'variance': tensor([0.9826])}}
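The merged row can be reproduced by hand from the two per-chain rows with the standard pairwise combination of count, mean, and variance. The sketch below uses plain floats and a hypothetical `merge_stats` helper rather than the actual merge() in pyro.ops.streaming, and it assumes the reported variance is the unbiased (n - 1) estimate.

```python
def merge_stats(a, b):
    """Combine two {"count", "mean", "variance"} summaries into one.

    Hypothetical helper; assumes unbiased variances and counts >= 2.
    """
    n_a, n_b = a["count"], b["count"]
    n = n_a + n_b
    delta = b["mean"] - a["mean"]
    mean = a["mean"] + delta * n_b / n
    # Sums of squared deviations, plus a correction for the differing means.
    m2 = (a["variance"] * (n_a - 1)
          + b["variance"] * (n_b - 1)
          + delta ** 2 * n_a * n_b / n)
    return {"count": n, "mean": mean, "variance": m2 / (n - 1)}


chain_0 = {"count": 2000, "mean": -0.0008, "variance": 0.9980}
chain_1 = {"count": 2000, "mean": -0.0336, "variance": 0.9671}
merged = merge_stats(chain_0, chain_1)
print(merged["count"], round(merged["mean"], 4), round(merged["variance"], 4))
# prints: 4000 -0.0172 0.9826
```

The rounded result agrees with the group_by_chain=False output above, which is consistent with the merged statistics being derived from the per-chain ones.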

I'm now moving on to implementing the tests.

As for refactoring the old MCMC to unify both classes under get_summary(), I'm not so sure about it right now. MCMC.summary(), which prints statistics, uses summary(samples, prob=0.9, group_by_chain=True) from utils, which is exactly the implementation that get_summary() would need. But I can't move summary into MCMC, because those utils are used directly in other places (like the baseball tutorial), so moving it could break compatibility.

I can either:

  1. Leave MCMC as is, so the two classes are only loosely connected, sharing just the run() method inherited from the base class.
  2. Attempt to refactor MCMC and utils a bit so that get_summary() also becomes part of the common interface (but I'm not sure this PR is the right place to do so).

WDYT?

@fritzo
Member

fritzo commented Jun 3, 2021

Hmm... it looks like we could maybe refactor print_summary() since it is used only by MCMC and is not exposed in public docs. Maybe we could refactor print_summary() to input a stats dict similar to that output by StreamingStats.get(), and then move the samples -> stats dict computation into either MCMC or utils.py? I defer to your judgement; the safest thing to do is duck type StreamingMCMC in this PR (your option 1.) and leave refactoring to a follow-up PR.

@mtsokol
Contributor Author

mtsokol commented Jun 13, 2021

@fritzo Hi! After a short break I continued this PR.

In the latest commit I rebased onto the merged pyro.ops.streaming. After inspecting the MCMC implementation I managed to write StreamingMCMC so that the test_mcmc_api.py test suite can be parametrized with:

@pytest.mark.parametrize("run_mcmc_cls", [run_default_mcmc, run_streaming_mcmc])

This applies only to tests that don't require rhat or ess diagnostics (one test that checks the usage of _UnarySampler and _MultiSampler is also excluded). The rest of the tests pass.

As for the implementation, we discussed to what degree the default MCMC should be refactored. In the end I decided to leave MCMC intact and make as few changes to the existing codebase as possible.

This resulted in some forced workarounds, for example in the test parametrization: making the output of StreamingMCMC compatible with the existing tests required a bit of boilerplate and dull output manipulation (let's discuss it in the review).

Profiling example

Here's a small memory profile (via memory_profiler) that I ran on a modified baseball.py tutorial to compare MCMC and StreamingMCMC.

Gist containing modified script for reproducibility (commands are at the bottom):
https://gist.github.com/mtsokol/cc10c0d57ac0050d4cf8ea6774acbde5

And the results (the former is MCMC and the latter is StreamingMCMC, run with the --stream option):

[Figure: memory profile of MCMC]
[Figure: memory profile of StreamingMCMC]
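The same kind of comparison can be illustrated on a toy problem without Pyro. The sketch below is not the gist's script: it uses the standard library's tracemalloc instead of memory_profiler, and the two functions are hypothetical stand-ins for "keep every sample" versus "update a running statistic".

```python
import random
import tracemalloc


def peak_bytes(fn):
    """Run fn and return the peak number of bytes allocated while it ran."""
    tracemalloc.start()
    try:
        fn()
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return peak


def stored_mean(n=50_000, seed=0):
    """Keep all draws in memory, then reduce (like storing every sample)."""
    rng = random.Random(seed)
    samples = [rng.gauss(0.0, 1.0) for _ in range(n)]
    return sum(samples) / n


def streaming_mean(n=50_000, seed=0):
    """Update a running mean draw by draw (constant memory)."""
    rng = random.Random(seed)
    count, mean = 0, 0.0
    for _ in range(n):
        count += 1
        mean += (rng.gauss(0.0, 1.0) - mean) / count
    return mean


stored_peak = peak_bytes(stored_mean)
streaming_peak = peak_bytes(streaming_mean)
print(f"stored: {stored_peak} bytes, streaming: {streaming_peak} bytes")
```

The stored variant's peak grows linearly with the number of draws, while the streaming variant's peak stays roughly flat, which is the qualitative behavior the profiles above are meant to show.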

Questions

  1. The remaining tests in the test suite in question require rhat and ess diagnostics. I've started looking at the current implementation, and at first sight rhat would (probably) require mean and variance streaming statistics. For ess I don't have any idea yet (considering e.g. the internals of the autocorrelation method). Is it possible to compute both of them in a streaming fashion?

  2. While reading the numpyro and Stan documentation I found that numpyro's MCMC implements "thinning", which Stan's documentation on ESS describes as a method for reducing memory usage (last paragraph of section 15.4.4):

https://mc-stan.org/docs/2_18/reference-manual/effective-sample-size-section.html

It looks pretty easy (apply only every nth sample to the provided streaming statistics).

Can I add it to this PR?

  3. Also, some time ago I created a PR, "Predictive and Deterministic tutorial" #2852, for the issue "[FR]: Example of Predictive with MCMC sampling" #2803; it is ready for review.
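The thinning idea from question 2 can be sketched as follows. Everything here is hypothetical (the `RunningMeanVar` class, the `thinned_update` helper, and the choice to keep draws at indices 0, n, 2n, ...); it only illustrates feeding every nth draw into a streaming statistic and is not numpyro's or Pyro's implementation.

```python
class RunningMeanVar:
    """Welford's online algorithm for count, mean, and unbiased variance."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self._m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, x):
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self._m2 += delta * (x - self.mean)

    def get(self):
        variance = self._m2 / (self.count - 1) if self.count > 1 else float("nan")
        return {"count": self.count, "mean": self.mean, "variance": variance}


def thinned_update(stats, draws, thinning=1):
    """Feed only every `thinning`-th draw (indices 0, n, 2n, ...) to stats."""
    for i, x in enumerate(draws):
        if i % thinning == 0:
            stats.update(x)


stats = RunningMeanVar()
thinned_update(stats, [float(i) for i in range(10)], thinning=5)
result = stats.get()  # only the draws 0.0 and 5.0 were used
print(result)
```

Note that the running statistic is already constant-memory, so thinning here only reduces the number of updates; it does not shrink the memory footprint.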

@mtsokol mtsokol marked this pull request as ready for review June 13, 2021 22:01
Member

@fritzo fritzo left a comment


Implementation and tests look great, and thanks for adding plots to verify memory overhead 👍 I think you'll need to merge dev since we recently changed from travis-ci to github actions for CI. Answering your specific questions:

> 1. remaining tests ... require rhat and ess

Yeah, it will take some effort to implement those; each is pretty complex and probably worth a separate PR. In the meantime I think it's fine to xfail in the tests.

> 2. numpyro ... implements "Thinning"

Hmm I guess thinning could help the original MCMC class (or simply setting max_tree_depth to a large value). However our streaming statistics are already constant-memory, so I don't see how thinning would help (well it would reduce computational complexity but also increase statistical error of e.g. mean and variance). Maybe one way to implement thinning would be to add a thinned StreamingStats subclass, possibly

  • adding an optional thinning argument to StackStats, defaulting to 1;
  • creating a new ThinnedStackStats; or
  • creating a constant-memory ReservoirStackStats that implements reservoir sampling.

WDYT?
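To make the last option concrete, here is a minimal sketch of reservoir sampling (the classic "Algorithm R") in plain Python. The `ReservoirSampler` class is hypothetical and is not the proposed ReservoirStackStats or any existing Pyro class; it only shows how a fixed-size, uniformly random subset of a stream can be kept in constant memory.

```python
import random


class ReservoirSampler:
    """Keep a uniform random subset of at most `capacity` items from a stream."""

    def __init__(self, capacity, seed=None):
        self.capacity = capacity
        self.count = 0       # number of items seen so far
        self.samples = []    # the current reservoir
        self._rng = random.Random(seed)

    def update(self, x):
        self.count += 1
        if len(self.samples) < self.capacity:
            # Fill phase: keep everything until the reservoir is full.
            self.samples.append(x)
        else:
            # Replace a random slot with probability capacity / count.
            j = self._rng.randrange(self.count)
            if j < self.capacity:
                self.samples[j] = x


reservoir = ReservoirSampler(capacity=10, seed=0)
for draw in range(1000):
    reservoir.update(draw)
print(reservoir.count, len(reservoir.samples))
```

Unlike fixed-stride thinning, the reservoir size does not depend on the number of draws, so the total chain length need not be known in advance.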

> 3. Also some time ago I created a PR

Sorry I lost track of that, review sent!

fritzo
fritzo previously approved these changes Jun 15, 2021
@fritzo
Member

fritzo commented Jun 15, 2021

@mtsokol is there anything else you'd like to add, or is this ready to merge? As mentioned above, I believe thinning would better fit into pyro.ops.streaming than in StreamingMCMC itself.

@mtsokol
Contributor Author

mtsokol commented Jun 16, 2021

@fritzo I think it's now ready to be merged. I just pushed a minor docs fix (and checked that the HTML is generated correctly).

@fritzo fritzo merged commit 8fd0bf5 into pyro-ppl:dev Jun 16, 2021