
JIT tscan and HMC warmup steps #115

Merged
merged 11 commits into master from scan-fix on Apr 19, 2019
Conversation

@neerajprad (Member) commented Apr 18, 2019

Fixes #114.

This fixes the issue with tscan not being jittable. With this change and jitting the warmup_update step, we get significantly faster run times, especially for NUTS. e.g. for test_beta_bernoulli, the run time goes down from 21s to 15s.
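As a rough sketch of the kind of change involved (the function name, signature, and adaptation rule below are illustrative only, not numpyro's actual API), jitting a warmup update means the per-step adaptation logic is traced and compiled once, and later calls hit XLA's compile cache:

```python
import jax
import jax.numpy as jnp

# Hypothetical warmup update -- the name and the crude adaptation rule are
# assumptions for illustration, not numpyro's actual implementation.
@jax.jit
def warmup_update(step_size, accept_prob, target=0.8):
    # Grow the step size when the acceptance probability is above target,
    # shrink it otherwise.
    return step_size * jnp.where(accept_prob > target, 1.02, 0.98)

step_size = 0.1
for accept_prob in [0.9, 0.7, 0.85]:
    # After the first call, this dispatches to already-compiled XLA code.
    step_size = warmup_update(step_size, accept_prob)
```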

I have also added a minimal version of @fehiepsi's #92 as an example, to test how this change fares on the benchmark and to make it easy to tweak things and see how the benchmark moves (which takes a few more steps in a Jupyter notebook). It seems that this is competitive for HMC, but NUTS is still quite slow. @fehiepsi - is it important to start with the step size and initial param values as in the notebook?

TODO:

  • Figure out why NUTS is still slow relative to the original benchmark.

@fehiepsi (Member) left a comment

Many thanks for making the benchmark script! We can keep the notebook for a while to compare different strategies; then we should remove it. If we make a script, it would be nice to store the benchmark results somewhere (e.g. in the wiki) so we can keep track of performance.

About the slowness of NUTS: if you set num_samples=100, it will take some time to finish because NUTS has a larger trajectory length than HMC. In my run (with the version in the notebook), it took 65322 leapfrog steps to get 100 samples, so it will basically take 65x longer than HMC. It took me 5m on GPU to get 100 NUTS samples (with HMC init params/step size), so on CPU I expect it will take more than half an hour. ^^!

# TODO: Remove with jax v0.1.26
@patch_dependency('jax.interpreters.partial_eval.trace_unwrapped_to_jaxpr', jax)
def _trace_unwrapped_to_jaxpr(fun, pvals, **kwargs):
    return pe.trace_to_jaxpr(lu.wrap_init(fun, kwargs), pvals)

+1

Resolved review threads on test/conftest.py, numpyro/mcmc.py, and numpyro/examples/covtype.py.
@neerajprad (Member Author)

We can keep the notebook for a while to compare different strategies; then we should remove it. If we make a script, it would be nice to store the benchmark results somewhere (e.g. in the wiki) so we can keep track of performance.

We can retain both. I think it's nice to have a notebook that gives more detail, but a script is handier for day-to-day comparisons.

About the slowness of NUTS: if you set num_samples=100, it will take some time to finish because NUTS has a larger trajectory length than HMC. In my run (with the version in the notebook), it took 65322 leapfrog steps to get 100 samples, so it will basically take 65x longer than HMC. It took me 5m on GPU to get 100 NUTS samples (with HMC init params/step size), so on CPU I expect it will take more than half an hour. ^^!

I see. From one of the runs, the time per leapfrog step was also quite a bit higher; let me run it for half an hour, then, and see what I get.

@neerajprad (Member Author)

it took 65322 leapfrog steps to get 100 samples

This seems a bit surprising. With the benchmark script, NUTS terminates within 1138 leapfrog steps, but each step takes around 0.1s, which is much higher than the 0.06s for HMC. Could you disable fast math mode and try the benchmark on your system?

@fehiepsi (Member) commented Apr 18, 2019

@neerajprad That's the run on GPU, where precision is much better than on CPU. On CPU with fast math disabled, NUTS on my system gives:

100%|██████████| 100/100 [01:08<00:00,  1.92it/s]

number of leapfrog steps: 1138
avg. time for each step : 0.06028635409678642

while in HMC,

100%|██████████| 100/100 [00:59<00:00,  1.68it/s]

number of leapfrog steps: 1000
avg. time for each step : 0.05941483545303345

If you disable fast math and plot the samples of the first coef, you will see that the HMC samples are constant at zero and the NUTS samples are highly correlated.
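For reproducibility, one common way to toggle XLA's CPU fast-math behavior is the XLA_FLAGS environment variable, set before JAX is first imported. The exact flag name is an assumption to check against your XLA version:

```python
import os

# The flag name below is an assumption; verify it against your XLA build.
# It must be set before jax is first imported, or it has no effect.
os.environ["XLA_FLAGS"] = "--xla_cpu_enable_fast_math=false"

# import jax  # import only after the flag is set
```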

@neerajprad (Member Author) commented Apr 18, 2019

It is interesting that you get the same number of leapfrog steps for NUTS, but your timings look better (0.06 vs. 0.1 sec per step). Just to confirm: this is with covtype.py, using the same initialization for both?

@fehiepsi (Member)

Yes, both with step_size = np.sqrt(0.5 / N) and init_params = {"coefs": np.zeros(dim)}.
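For concreteness, a sketch of that initialization. N is the number of rows in the covtype dataset (581,012); dim is assumed here to be 54, the number of covtype features, though the actual model may differ:

```python
import jax.numpy as np

N = 581012   # rows in the covtype dataset
dim = 54     # assumed number of coefficients (covtype has 54 features)

# The initialization under discussion: a small step size scaled by the
# dataset size, and all-zero regression coefficients.
step_size = np.sqrt(0.5 / N)
init_params = {"coefs": np.zeros(dim)}
```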

@fehiepsi (Member)

@neerajprad All the info above from my side is from modifications (step size, init params) of the notebook (which does not use tscan). I have not used your script.

@neerajprad (Member Author)

@neerajprad All the info above from my side is from modifications (step size, init params) of the notebook (which does not use tscan). I have not used your script.

Then the script seems to be slower than your notebook for NUTS; I'll need to take a look at why that's the case. Hopefully, if this (or a similar) approach works, we can have a fast implementation by default instead of relying on users to bypass warmup and compile sample_kernel.

@fehiepsi (Member)

I hope so too. Let me play around with your idea to see if we are missing something.

@neerajprad (Member Author)

Feel free to push to this PR itself; I might not be able to get to this until tomorrow.

@neerajprad (Member Author) commented Apr 19, 2019

@fehiepsi - With fast math disabled, I'm getting 0.1 sec per step for HMC (0.06 with fast math) and 0.09 for NUTS. I have also changed the initialization so that HMC is not just producing zeros, which may have been the reason for the faster 0.06 sec run time I was getting earlier.

I'm happy with the benchmark in that both HMC and NUTS have similar times. I haven't compared this with your original benchmark on my system, but I think it will be similar. All the benchmark numbers are uniformly better on your system, though (probably because you have a more powerful CPU). In any case, this is ready to merge unless you have further comments.

@fehiepsi (Member) left a comment

Looks great overall! I guess we still haven't solved the problem of compiling two times (#88)? I just have a few small comments:

Resolved review threads on numpyro/examples/covtype.py.
from numpyro.mcmc import hmc
from numpyro.util import tscan


Review comment (Member) - nit: lint error for the 5 blank lines here.

@fehiepsi (Member)

I'm happy with the benchmark in that both HMC and NUTS have similar times. I haven't compared this with your original benchmark on my system, but I think it will be similar. All the benchmark numbers are uniformly better on your system, though (probably because you have a more powerful CPU).

I just want to verify something to make sure we don't have different interpretations of the problem:

  • With the current init_params, the first num_steps of NUTS is 1, so the step_size=1. hack is not important. In addition, we run NUTS with num_samples=100, so the effect of the first sampling step is small. When the first num_steps is large, compile time contributes a lot to the benchmark. I think that is the reason you didn't use the NUTS init_params and step_size from the notebook's benchmark (NUTS would take a long time to finish its first step). If that is the case, I would prefer to keep init_params and step_size as in the notebook for further enhancement (in later PRs) and change back to the default init_params when the problems are resolved.
  • The following timings on my system (CPU, fastmath disabled) show that the first step is important:
HMC script (n=10): 122ms
HMC notebook (n=10): 61ms
HMC script (n=100): 64ms
HMC notebook (n=100): 61ms

@neerajprad (Member Author)

When the first num_steps is large, compile time contributes a lot to the benchmark. I think that is the reason you didn't use the NUTS init_params and step_size from the notebook's benchmark (NUTS would take a long time to finish its first step). If that is the case, I would prefer to keep init_params and step_size as in the notebook for further enhancement (in later PRs) and change back to the default init_params when the problems are resolved.

Good point. I didn't use init_params from the notebook because it's hard to reliably compare our benchmark against the paper due to a variety of system-level differences. I think we can use any reasonable initialization and run the benchmarks for the other implementations, so I tried to keep it simple. But if that is a particularly problematic initialization, let us use it for the time being and move to a simpler one later, when that problem is fixed.

@neerajprad (Member Author) commented Apr 19, 2019

@fehiepsi - let me address all your comments. I think it would be best to just use the notebook's initial params for both HMC and NUTS.

@neerajprad (Member Author)

Okay, I think this should address all your comments. I just had a couple of questions based on your last comment:

The following timings on my system (CPU, fastmath disabled) show that the first step is important.

Are these numbers with the earlier initialization of 0s or the one that is in the notebook? I changed to the one in the notebook, and the performance per leapfrog step remains the same. It is possible that the initial compilation time is large and gets amortized over the larger number of leapfrog steps. But if that's the case, I think we probably should not overly focus our efforts on fixed costs (like compiling the sample_kernel twice) that are likely to get amortized away on larger models / larger number of samples. We can probably rely on JAX fixing that over the longer term. What would be nice is if we are very competitive on large models (I'm not sure if we are there yet), even if we are 10 seconds slower on smaller models (this is important too, but it is kind of optimizing for the tail so early in the project). What do you think?
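The compile-vs-amortized split is easy to measure directly. In this sketch the kernel is a toy stand-in for a single sampler step, and block_until_ready() keeps JAX's async dispatch from skewing the timings:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def kernel(x):
    # toy stand-in for one sample_kernel step
    return x - 0.1 * jnp.sin(x)

x = jnp.ones(10)

t0 = time.perf_counter()
kernel(x).block_until_ready()   # first call: trace + compile + run
first_call = time.perf_counter() - t0

t0 = time.perf_counter()
for _ in range(100):
    x = kernel(x)                # later calls hit the compile cache
x.block_until_ready()
per_step = (time.perf_counter() - t0) / 100
```

On a small kernel like this, the first call is dominated by compilation, which is exactly the fixed cost that amortizes away over many leapfrog steps.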

@fehiepsi (Member) left a comment

LGTM. Could you address the following two nits?

Resolved review threads on numpyro/examples/covtype.py.
@fehiepsi (Member)

Are these numbers with the earlier initialization of 0s or the one that is in the notebook?

Those are numbers from init_params = {'coefs': random.normal(key=random.PRNGKey(0), shape=(dim,))}.

But if that's the case, I think we probably should not overly focus our efforts on fixed costs (like compiling the sample_kernel twice) that are likely to get amortized away on larger models / larger number of samples. We can probably rely on JAX fixing that over the longer term.

Yes, it would be great if this issue were fixed; then we would just need to jit sample_kernel in the hmc implementation and use it across the init and sample stages. But right now, compile time is a big problem:

  • when the first step of NUTS takes hundreds of leapfrog steps
  • when there are many sample statements in the model

To make a fair benchmark (while waiting for the issue to be fixed upstream), I think we can use the non-prim version for a while, with the trade-off that we have to concatenate samples without compiling. The other option is to be aggressive and make mcmc work as follows:

mcmc = MCMC(sample_kernel, state, num_samples=1000)  # build a jitted version of tscan here; num_samples, i.e. the length of `bs`, is known
mcmc.compile()  # run tscan with 1 step; we can also trigger the hack for fast compiling here
mcmc.run()  # run tscan with 1000 steps

But I don't like this idea because it is only suitable for benchmarking: we would have to recompile to draw another 1000 samples. So for the time being, I would stick with the non-prim version and wait for the issue to be fixed upstream. ^^
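The "jit a scan whose length is known" idea can be sketched with jax.lax.scan; the kernel here is a toy stand-in, not numpyro's sample_kernel. Because the length is baked in at trace time, drawing a different number of samples forces a recompile (the trade-off described above):

```python
import jax
import jax.numpy as jnp
from jax import lax

def sample_kernel(state, _):
    # toy stand-in kernel: decay the state toward zero
    new_state = 0.99 * state
    return new_state, new_state  # (carry, sample to collect)

@jax.jit
def run_chain(init_state):
    # length=1000 is fixed at trace time, so the whole sampling loop
    # compiles into one XLA program; changing the length recompiles.
    _, samples = lax.scan(sample_kernel, init_state, None, length=1000)
    return samples

samples = run_chain(jnp.ones(3))
```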

What would be nice is if we are very competitive on large models (I'm not sure if we are there yet), even if we are 10 seconds slower on smaller models

Yes, I also think so.

@fehiepsi fehiepsi merged commit 45eac0a into master Apr 19, 2019
@neerajprad neerajprad deleted the scan-fix branch November 19, 2019 19:05
Successfully merging this pull request may close these issues.

Make tscan jittable