JIT tscan and HMC warmup steps #115
Conversation
Many thanks for making the benchmark script! We can keep the notebook for a while to compare different strategies; then we should remove it. If we make a script, it would be nice to store the benchmark results somewhere (e.g. in the wiki) so we can keep track of performance.
About the slowness of NUTS: if you set num_samples=100, it will take some time to finish because NUTS has a larger trajectory length than HMC. In my run (with the version in the notebook), it took 65322 leapfrog steps to get 100 samples, so it will basically take ~65x longer than HMC. It took me 5m on GPU to finish getting 100 NUTS samples (with HMC init params/step size), so on CPU I expect it will take more than half an hour. ^^!
# TODO: Remove with jax v0.1.26
@patch_dependency('jax.interpreters.partial_eval.trace_unwrapped_to_jaxpr', jax)
def _trace_unwrapped_to_jaxpr(fun, pvals, **kwargs):
    return pe.trace_to_jaxpr(lu.wrap_init(fun, kwargs), pvals)
+1
We can retain both. I think it's nice to have a notebook that gives more details, but a script is handier for comparison on a day-to-day basis.
I see. From one of the runs, the time per leapfrog step was also quite a bit higher; let me let it run for half an hour then and see what I get.
This seems a bit surprising. With the benchmark script, NUTS terminates within 1138 leapfrog steps, but it takes around 0.1 s per step, which is much higher than the 0.06 s for HMC. Could you disable fast math mode and try the benchmark on your system?
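As a side note on measuring these per-step times: JAX dispatch is asynchronous and the first call pays compilation cost, so naive timing can mislead. Below is a minimal, hypothetical sketch of how one might time a jitted step — the `step` function is a stand-in for the real leapfrog kernel, not numpyro code:

```python
import time

import jax
import jax.numpy as jnp

# Hypothetical stand-in for a single leapfrog/sample step; a real benchmark
# would call the actual HMC/NUTS kernel here instead.
@jax.jit
def step(x):
    return jnp.sin(x) + 0.1

def time_per_step(num_steps=100):
    x = jnp.array(0.0)
    # Warm-up call so compilation time is excluded from the measurement.
    step(x).block_until_ready()
    start = time.time()
    for _ in range(num_steps):
        x = step(x)
    # Dispatch is asynchronous: wait for the last result before stopping the clock.
    x.block_until_ready()
    return (time.time() - start) / num_steps

print(time_per_step())
```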
@neerajprad That's the run on GPU, where precision is much better than on CPU. On CPU with fast math disabled, on my system NUTS gives:
while in HMC,
If you disable fast math and plot the samples of the first coef, you will see that the HMC samples are constant zeros and the NUTS samples are highly correlated.
It is interesting that you get the same number of leapfrog steps for NUTS, but your timings look better (0.06 vs 0.1 sec per step). Just to confirm, this is with
Yes, both with
@neerajprad All the above info from my side is from modifications (step size, init params) of the notebook (which didn't use tscan). I have not used your script.
Then the script seems to be slower than your notebook for NUTS. I'll need to take a look at why that's the case. Hopefully, if this (or a similar approach) works, we can have a fast implementation by default instead of relying on users to bypass warmup and compile
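To sketch what "fast by default" could mean here — using a hypothetical toy transition function, not the actual numpyro kernel — jitting the kernel once at definition time makes every subsequent call cheap, without users having to compile anything themselves:

```python
import jax
import jax.numpy as jnp

# Hypothetical transition function standing in for the HMC sample kernel.
def _sample_kernel(state):
    return 0.5 * state + 1.0

# Compile once so callers get fast steps by default.
sample_kernel = jax.jit(_sample_kernel)

state = jnp.array(0.0)
samples = []
for _ in range(3):
    state = sample_kernel(state)  # fast after the first (compiling) call
    samples.append(float(state))
# samples == [1.0, 1.5, 1.75]
```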
I hope so too. Let me play around with your idea too, to see if we are missing something.
Feel free to push to this PR itself; I might not be able to get to this until tomorrow.
@fehiepsi - With fast math disabled, I'm getting 0.1 sec per time step for HMC (0.06 with fast math), and 0.09 for NUTS. I have also changed the initialization so that HMC is actually not just giving 0s, which may have been the reason for the faster 0.06 s run time I was getting earlier. I'm happy with the benchmark in that both HMC and NUTS have similar times. I haven't compared this with your original benchmark on my system, but I think it will be similar. All the benchmark numbers are uniformly better on your system, though (probably because you have a more powerful CPU). In any case, this is ready to merge, unless you have further comments.
Looks great overall! I guess we still haven't solved the problem of compiling twice (#88) yet? I just have a few small comments:
from numpyro.mcmc import hmc
from numpyro.util import tscan
nit: lint for 5 blank lines
I just want to verify something to make sure we don't have different interpretations of the problem:
Good point. I didn't use
@fehiepsi - let me address all your comments. I think it would be best to just use the notebook's initial params for both HMC and NUTS.
Okay, I think this should address all your comments. I just had a couple of questions based on your last comment:
Are these numbers with the earlier initialization of 0s, or the one that is in the notebook? I changed to the one in the notebook and the performance per leapfrog step remains the same. It is possible that the initial compilation time is large and gets amortized over the larger number of leapfrog steps. But if that's the case, I think we probably should not overly focus our efforts on fixed costs (like compiling the
LGTM. Could you address the following two nits?
Those are numbers from
Yes, it would be great if this issue is fixed, so we just need to
To make a fair benchmark (while waiting for the issue to be fixed upstream), I think we can use the non-prim version for a while, with the trade-off that we have to concatenate samples without compiling. The other option is to go aggressive and make mcmc work as follows
But I don't like this idea because it is only suitable for benchmarking. We would have to recompile to draw another 1000 samples... So for the time being, I would stick with the non-prim version and wait for the issue to be fixed upstream. ^^
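To illustrate that recompilation trade-off with a toy example (the `kernel` below is a hypothetical stand-in for the real HMC/NUTS step): compiling the whole chain with `lax.scan` avoids per-step dispatch, but the number of samples becomes part of the compiled program, so asking for a different count triggers a recompile.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical transition standing in for one HMC/NUTS step.
def kernel(state, _):
    new_state = 0.5 * state + 1.0
    return new_state, new_state  # (carry, collected sample)

# num_samples must be static: it fixes the scan length inside the compiled
# program, so run_chain recompiles whenever it changes.
@partial(jax.jit, static_argnums=1)
def run_chain(init_state, num_samples):
    _, samples = jax.lax.scan(kernel, init_state, None, length=num_samples)
    return samples

samples = run_chain(jnp.array(0.0), 4)  # shape (4,)
```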
Yes, I also think so.
Fixes #114.

This fixes the issue with `tscan` not being jittable. With this change and jitting the `warmup_update` step, we get significantly faster run times, especially for NUTS; e.g. for `test_beta_bernoulli`, the run time goes down from 21s to 15s.

I have also added a minimal version of @fehiepsi's #92 as an example to test how this change fares on the benchmark, and also so that it is easy to make changes and see how the benchmark moves (which takes a few more steps in a jupyter notebook). It seems that this is competitive for HMC, but for NUTS, this is still quite slow. @fehiepsi - is it important to start with the step size and initial param values as in the notebook?
TODO: