
Added dynamic rhat #77

Merged
merged 1 commit into rlouf:master on Jul 29, 2021

Conversation

sidravi1
Contributor

@sidravi1 sidravi1 commented Feb 7, 2021

Overview

Displays Rhat values in the progress bar. Partially addresses #8.

Details

Here it is in action:

rhat_dynamic.mov

To do

  1. Probably should add variable names to Rhat
  2. Probably move it to a new line else will get messy with a lot of vars
  3. Basic offline Rhat and warning if too high (like pymc3)
  4. More testing - especially for mvn 🙊

@rlouf
Owner

rlouf commented Feb 8, 2021

The code looks great and I really like the result!

  1. Probably should add variable names to Rhat
  2. Probably move it to a new line else will get messy with a lot of vars

Since this is a rough indicator, I thought we could only display the "worst" value of Rhat among all variables (in terms of distance to 1). Other values can be shown in the inference summary. The ideal would be to update a graph with all the values of Rhat over time, but that's a project in itself.
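That "worst value" selection can be sketched as follows; the variable names and Rhat values are purely illustrative, not mcx's actual data structures:

```python
# Hypothetical sketch: pick the "worst" Rhat (farthest from 1) among all
# variables, to show a single rough indicator in the progress bar.
rhats = {"sigma": 1.01, "rho": 1.12, "coeffs": 0.97}

worst_var = max(rhats, key=lambda name: abs(rhats[name] - 1.0))
worst_rhat = rhats[worst_var]

print(f"worst Rhat: {worst_var}={worst_rhat:.2f}")  # → worst Rhat: rho=1.12
```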

  3. Basic offline Rhat and warning if too high (like pymc3)

As discussed, implementing the rank-normalized Rhat for the inference summary would be best. Adding a warning if a value is too high is a good idea, and it is even better if that warning is actionable: what can I do as a modeler with this information?
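An actionable warning along those lines could look like this minimal sketch; the 1.01 threshold and the wording of the advice are assumptions, not mcx's implementation:

```python
# Hypothetical sketch of an actionable Rhat warning: name the offending
# variables and suggest what the modeler can do about it.
import warnings

def check_rhat(rhats: dict, threshold: float = 1.01) -> None:
    """Warn when any variable's Rhat exceeds the threshold."""
    bad = {name: r for name, r in rhats.items() if r > threshold}
    if bad:
        warnings.warn(
            f"Rhat > {threshold} for {sorted(bad)}: the chains may not have "
            "mixed. Consider drawing more samples or reparametrizing the model."
        )

check_rhat({"sigma": 1.2, "rho": 1.0})  # warns about 'sigma'
```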

Where should it be displayed? After the progress bar or do we print (at least part of) the inference summary first?

  4. More testing - especially for mvn 🙊

Indeed, multivariate random variables are more error-prone :) The best is probably to take examples from the paper and check that computing Rhat on these chains gives the expected result.
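For reference, the classic split-Rhat (not the rank-normalized variant discussed above) can be sketched and sanity-checked like this; the `(n_chains, n_samples)` chain layout is an assumption:

```python
# A sketch of the classic split-Rhat diagnostic, with a quick sanity check:
# well-mixed chains should give ~1.0, while shifted chains should not.
import numpy as np

def split_rhat(chains: np.ndarray) -> float:
    """Split each chain in half, then compute the potential scale reduction."""
    n_chains, n_samples = chains.shape
    half = n_samples // 2
    splits = np.concatenate([chains[:, :half], chains[:, half : 2 * half]])
    m, n = splits.shape
    chain_means = splits.mean(axis=1)
    B = n * chain_means.var(ddof=1)           # between-chain variance
    W = splits.var(axis=1, ddof=1).mean()     # within-chain variance
    var_hat = (n - 1) / n * W + B / n         # marginal posterior variance estimate
    return float(np.sqrt(var_hat / W))

rng = np.random.default_rng(0)
good = split_rhat(rng.normal(size=(4, 1000)))                          # ~1.0
bad = split_rhat(rng.normal(size=(4, 1000)) + np.arange(4)[:, None])   # >> 1.0
print(good, bad)
```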

@rlouf
Owner

rlouf commented Feb 9, 2021

Btw for the sake of making incremental changes it would be better to address (3) in a separate PR.

@rlouf rlouf force-pushed the master branch 3 times, most recently from f8f3e6b to 965f6dd on February 23, 2021 11:28
@rlouf
Owner

rlouf commented Apr 12, 2021

Hey @sidravi1 what's the status on this PR?

@sidravi1
Contributor Author

sidravi1 commented Apr 13, 2021 via email

@rlouf
Owner

rlouf commented Apr 15, 2021

No problem, this is open source, not paid work 🙂

@sidravi1
Contributor Author

Hi @rlouf - I've got dynamic rhat working, though using set_postfix slows performance down by ~50%.

Tested it with this mvnormal model as well.

@mcx.model
def linear_regression_mvn(x, lmbda=1.):
    sigma <~ dist.Exponential(lmbda)
    sigma2 <~ dist.Exponential(lmbda)
    rho <~ dist.Uniform(-1, 1)
    cov = jnp.array([[sigma**2, rho*sigma*sigma2],[rho*sigma*sigma2, sigma2**2]])
    coeffs_init = jnp.ones(x.shape[-1])
    coeffs <~ dist.MvNormal(coeffs_init, cov)
    y = jnp.dot(x, coeffs.T)
    preds <~ dist.Normal(y, sigma)
    return preds

sampler = mcx.sampler(
    rng_key,
    linear_regression_mvn,
    (x_data_mvn,),
    {'preds': y_data_mvn},
    HMC(10),
)
posterior = sampler.run()

If all looks good, I'll clean up the commit history before the merge.

@sidravi1 sidravi1 requested a review from rlouf June 13, 2021 20:33
@sidravi1 sidravi1 changed the title [WIP] Added dynamic rhat Added dynamic rhat Jun 13, 2021
@rlouf
Owner

rlouf commented Jun 14, 2021

Great! Could you try using the mininterval flag and setting it to something like .5s or 1s and report the slowdown then? (https://github.com/tqdm/tqdm/blob/master/tqdm/std.py#L873-L880)
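The suggestion above can be sketched as follows; the Rhat value is a stand-in for mcx's real online computation, which is what actually dominated the runtime in the end:

```python
# Sketch: throttle progress-bar redraws with tqdm's `mininterval` flag so
# the postfix (Rhat) display refreshes at most every 0.5 seconds.
from tqdm import tqdm

progress = tqdm(range(1000), mininterval=0.5)  # redraw at most every 0.5s
for i in progress:
    worst_rhat = 1.0 + 1.0 / (i + 1)  # stand-in for an online Rhat update
    # refresh=False defers drawing to tqdm's own mininterval schedule
    progress.set_postfix(rhat=f"{worst_rhat:.3f}", refresh=False)
progress.close()
```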

@sidravi1
Contributor Author

sidravi1 commented Jun 14, 2021

Great! Could you try using the mininterval flag and setting it to something like .5s or 1s and report the slowdown then? (https://github.com/tqdm/tqdm/blob/master/tqdm/std.py#L873-L880)

mininterval doesn't seem to help much. The bottleneck is actually the rhat updating, not tqdm. What are your thoughts on making it optional? We could also use a pattern where you can register callbacks to run a bunch of other online stats.

[screenshot: timing comparison showing multiple progress bars]

I should also point out that the bottleneck is most noticeable when the model is simple (the linear example); when it's more complex (the multivariate example), the relative slowdown is much smaller.

@rlouf
Owner

rlouf commented Jun 15, 2021

It's all a question of user interface. The original idea was that, since we spend 99% of our time debugging models, the sample function would be interactive by default: it displays as much information as possible to see when issues arise and can be interrupted at any time to diagnose these issues. compile=True would show nothing but the progress bar and would correspond to situations where we need inference to be as fast as possible; we could also define a fast_sample function for that purpose.

Now, if you have to wait an extra few seconds for simple models but it does not affect large models, it is not really a problem.

Nevertheless, I like your idea of designing these online metrics as callbacks. This would allow users to customize the metrics being displayed and/or follow their own metrics. It is also cleaner from a code perspective. This way sample would be called with callbacks=[rhat, ess, divergences] by default.
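The callback design described above could look like this minimal sketch; `rhat`, `ess`, and `sample_loop` here are hypothetical stand-ins, not mcx's actual API:

```python
# Minimal sketch of the online-metrics callback pattern: each callback
# receives the sampler state and returns metrics to display.
from typing import Callable, Dict, List

def rhat(state: Dict) -> Dict:
    # Stand-in for an online Rhat update based on the sampler state.
    return {"rhat": 1.0 + 1.0 / state["num_steps"]}

def ess(state: Dict) -> Dict:
    # Stand-in for an online effective-sample-size estimate.
    return {"ess": 0.8 * state["num_steps"]}

def sample_loop(num_steps: int, callbacks: List[Callable]) -> Dict:
    metrics: Dict = {}
    for step in range(1, num_steps + 1):
        state = {"num_steps": step}
        for callback in callbacks:  # each callback contributes its own metrics
            metrics.update(callback(state))
    return metrics

metrics = sample_loop(10, callbacks=[rhat, ess])
print(metrics)  # → {'rhat': 1.1, 'ess': 8.0}
```

Users could then pass their own functions in the `callbacks` list to follow custom metrics, which keeps the display logic out of the sampler itself.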

PS: is the multiple progress bar a bug?

@sidravi1
Contributor Author

Ok. Makes sense.

Should we merge this in and switch to callback design pattern in another PR (when we implement ESS or divergences) or do you want me to update this one?

The multiple progress bars are because of the %%timeit cell magic on top. Just runs it multiple times to get the average run time.

@rlouf
Owner

rlouf commented Jun 15, 2021

Would you mind updating this one?

@sidravi1
Contributor Author

Yep! Can do :)

@sidravi1
Contributor Author

Thanks for your patience @rlouf - I've made those changes. Let me know what you think.

mcx/sample.py (comment on lines 532 to 533):
call_backs:
The functions to run after each state update
Owner


This does not appear in the function's signature

Contributor Author


Oops. Fixed! I'll also squash all the commits so it's ready to merge in.

Owner

@rlouf rlouf left a comment


Apart from my small comment on a docstring, everything is perfect. Ready to merge once that's fixed.

allows online metrics to be passed to sample_loop
@sidravi1
Contributor Author

@rlouf - Thanks for reviewing. I've made that one docstring fix and squashed all the commits.

@rlouf rlouf merged commit af1adc9 into rlouf:master Jul 29, 2021
@rlouf
Owner

rlouf commented Jul 29, 2021

Great work, the code was really clean and self-explanatory!
