
HMM model using HMC/NUTS is slow #1511

Closed
neerajprad opened this issue Nov 2, 2018 · 19 comments

@neerajprad
Member

neerajprad commented Nov 2, 2018

Since we started using einsum to evaluate the joint log density of the trace with discrete parameters, we can do discrete site enumeration across many more models in HMC/NUTS without going OOM.

For HMMs, however, NUTS / HMC is extremely slow to the point of being unusable beyond a few time steps, say 10. Refer to this test for profiling.

There are a few issues that I noticed:

  • NUTS crawls to a halt due to an extremely small step size in the warmup phase itself. I suspect the issue might be starting from a random initialization, but I am not sure. It might be worth initializing with the MAP estimate and then running NUTS so that it does not hang with an unworkably small step size. I think we have seen this issue with a few other models too, e.g. Significantly different answers for PyMC3 NUTS vs Pyro NUTS #1470. cc @fehiepsi. I don't observe this issue on pytorch 1.0.
  • Converting the values returned from the integrator into a model trace (_get_trace) also takes more than 3s. While comparatively small, I believe this can be optimized if we assume our models are static and use a different data structure inside HMC, so that we do not need to run the model each time.
  • The trace log density evaluation (and gradient computation) takes the bulk of the time, as expected. It is not immediately clear how this can be improved given that we need to call this many times per integrator step to generate even a single sample, and it does seem like NUTS will continue to be slow until we can make this much faster. Profiling viz below.

[Profiling screenshot: screen shot 2018-11-01 at 5 24 41 pm]

@fehiepsi
Member

fehiepsi commented Nov 2, 2018

@neerajprad I tested with pytorch 0.4.0 and didn't observe the slowness. As I mentioned in #1487, trace + Gamma/Dirichlet + pytorch 1.0rc gives wrong grads in the backward pass, and that bug is not related to HMC/NUTS. I don't know what more I can do with that issue, so I am skipping it and implementing statistics such as effective number of samples and the Gelman-Rubin convergence diagnostic instead.

@fehiepsi
Member

fehiepsi commented Nov 2, 2018

About initialization: Stan does not use MAP. PyMC3 used to use MAP, but it no longer does by default. There may be empirical reasons for those decisions (which I can't track down). In my opinion, allowing users to set starting points is enough; from my experience it is extremely useful for some models: when I got stuck with random initialization, I set the starting points to the mean and things went smoothly. These starting points can come from intuition, from the means of the priors, or from MAP.

@neerajprad
Member Author

neerajprad commented Nov 2, 2018

@fehiepsi - This is a known issue that we would like to address or at least conclude that we cannot run HMM type models with the current constraints. I just cc'd you as an FYI - you shouldn't feel compelled to work on this! :)

I tested with pytorch 0.4.0 and didn't catch the slowness.

The HMM test is only sampling 10 traces, but if you run it for longer than 10 time steps, you will find the issue of step sizes getting very small and making extremely slow progress. This is without using JIT. My hunch is that this could be caused by a bad initialization, with transforms warping the potential energy surface in a way that makes trajectories extremely unstable, so we keep lowering the step size, making progress extremely slow. This is just a guess though and needs to be investigated further.

Stan does not use MAP. Previously, PyMC3 used MAP, but now they don't by default.

Even if it is not available by default, I am interested in exploring whether initializing with the MAP estimate does better on these kinds of models. If so, it would be useful to provide an optional kwarg initialize_with_map=False to the HMC/NUTS kernels.
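
If we do experiment with this, a rough sketch of the workflow (not the proposed kwarg, which does not exist yet) would be to fit a MAP point estimate with SVI and an AutoDelta guide and use it as the starting point. The import paths and the final hand-off to the kernel below are assumptions, and the model is just a stand-in:

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDelta  # path assumed; older releases expose this under pyro.contrib.autoguide
from pyro.optim import Adam

def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 10.))
    scale = pyro.sample("scale", dist.LogNormal(0., 1.))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, scale), obs=data)

data = 3.0 + 2.0 * torch.randn(100)

# Step 1: MAP estimate via SVI with a Delta (point-mass) guide.
guide = AutoDelta(model)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
for _ in range(1000):
    svi.step(data)

# Step 2: read off the point estimates; calling an autoguide returns a dict
# of site values. Passing these to HMC/NUTS as initial values is exactly the
# part this issue would add.
map_estimates = {name: value.detach() for name, value in guide(data).items()}
print(map_estimates)
```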

implementing statistics such as effective number of samples and the Gelman-Rubin convergence diagnostic instead.

That will be really useful! You should also check out arviz (which implements traceplot and diagnostics like gelman rubin), and this PR arviz-devs/arviz#309 by @ColCarroll, which extends support for Pyro.

@fehiepsi
Member

fehiepsi commented Nov 2, 2018

The HMM test is only sampling 10 traces, but if you run it for longer than 10 time steps, you will find the issue of step sizes getting very small and making extremely slow progress.

You are right that this might be another issue. I tested with 100 num_samples and 100 warmup_steps on pytorch 1.0rc and pytorch 0.4. Pytorch 1.0rc is a bit slower than pytorch 0.4, and einsum is slower than not using einsum. But I didn't observe the very small step_size problem; the step_size was around 0.0001-0.0004 in all the tests.

As for MAP, I discourage using it with HMC. I have implemented dozens of models on various small datasets, and unless I specified good initial values, MAP gave very bad answers no matter how I varied the learning rate and num_steps. For example, take a simple linear regression with Gaussian likelihood: y = Normal(ax + b, sigma). When the scale of a is large, we have to set the learning rate to a large value unless we run SVI for tens of thousands of steps, but with a large learning rate, sigma tends to move to a very large value! At the end of MAP, I can't get the answer I need. So the performance depends heavily on MAP's initial values, learning rate, num_steps, ...

I haven't faced problems with the initial trace (other than the nan issue, which we have addressed), so I can't say much. To avoid extreme values in the initial trace due to random initialization, Stan's approach to initialization might be helpful: they initialize values uniformly at random in the interval (-2, 2) in unconstrained space.
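
For reference, that Stan-style scheme is simple to sketch: draw each latent uniformly in (-2, 2) in unconstrained space and push it back through the constraining transform. The helper below is illustrative, not an existing Pyro function; the shapes and supports would normally be read off a model trace.

```python
import torch
from torch.distributions import biject_to, constraints

def stan_style_init(site_shapes, site_supports):
    """Sample uniformly in (-2, 2) in unconstrained space, then map each
    value back to the constrained support of its site."""
    init = {}
    for name, shape in site_shapes.items():
        unconstrained = torch.empty(shape).uniform_(-2.0, 2.0)
        transform = biject_to(site_supports[name])  # e.g. exp for positive supports
        init[name] = transform(unconstrained)
    return init

# Toy usage for a model with an unconstrained location and a positive scale.
init_values = stan_style_init(
    {"loc": (), "scale": ()},
    {"loc": constraints.real, "scale": constraints.positive},
)
```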

You should also check out arviz (which implements traceplot and diagnostics like gelman rubin), and this PR arviz-devs/arviz#309 by @ColCarroll, which extends support for Pyro.

Thanks for the suggestion!!! I didn't know about it. If arviz already supports these diagnostics, should we still implement them? It might be better to explore how to integrate with arviz instead. What do you think? Edit: I took a look at their implementation. All the calculation depends on numpy, which makes me a little uncomfortable. I will implement these diagnostics in pytorch instead.
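
In case it helps, the basic Gelman-Rubin statistic is only a few lines of PyTorch. A minimal sketch (the function name and interface are just illustrative, and this is the plain, non-split variant):

```python
import torch

def gelman_rubin(samples):
    """Potential scale reduction factor (R-hat) for one scalar parameter.

    samples: tensor of shape (num_chains, num_draws).
    """
    num_chains, num_draws = samples.shape
    chain_means = samples.mean(dim=1)                     # per-chain means
    within = samples.var(dim=1, unbiased=True).mean()     # W: mean within-chain variance
    between = num_draws * chain_means.var(unbiased=True)  # B: between-chain variance
    var_hat = (num_draws - 1) / num_draws * within + between / num_draws
    return torch.sqrt(var_hat / within)

# Two well-mixed chains from the same distribution should give R-hat close to 1.
chains = torch.randn(2, 1000)
print(gelman_rubin(chains))
```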

@fritzo
Member

fritzo commented Nov 2, 2018

@neerajprad could you send me a .prof file and/or the steps to reproduce your profile? I'd like to inspect the profiling numbers in contract_to_tensor().

@neerajprad
Member Author

@neerajprad could you send me a .prof file and/or the steps to reproduce your profile? I'd like to inspect the profiling numbers in contract_to_tensor().

Thanks @fritzo, I will send you the profiling script and .prof file shortly.

I haven't faced problems with the initial trace (other than the nan issue, which we have addressed), so I can't say much. To avoid extreme values in the initial trace due to random initialization, Stan's approach to initialization might be helpful: they initialize values uniformly at random in the interval (-2, 2) in unconstrained space.

Thanks for all the suggestions, @fehiepsi. I will play around with different initializations first to see if it improves the performance.

I didn't know about it. If arviz already supports these diagnostics, should we still implement them? It might be better to explore how to integrate with arviz instead. What do you think?

I think if the integration is straightforward and arviz has all the diagnostics you were looking to implement, we could just suggest that users go with that (and even add it to our example). If you find the integration lacking in any way and have ideas on what can be improved, feel free to open an issue to discuss! I think you might need to change the interface a bit (to preserve chain information). I am not too worried about numpy, because that conversion will happen at the end of inference (unless you'd like to provide some online diagnostics) and converting a CPU tensor to numpy is low overhead.

@neerajprad
Member Author

@fritzo - You can run the profiler using:

python -m cProfile -o hmm.prof tests/perf/prof_hmc_hmm.py

on the prof-hmm branch. Attaching the .prof file. I have turned off step size adaptation so as not to take too much time. Most of the time is actually just taken by einsum so I am not sure if there is much room for optimization here.
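
For anyone who wants to inspect the attached dump without re-running the script, the standard-library pstats module is enough (the file name is assumed to match the command above):

```python
import pstats

# Load the cProfile dump and print the 20 most expensive calls by cumulative time.
stats = pstats.Stats("hmm.prof")
stats.strip_dirs().sort_stats("cumulative").print_stats(20)
```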

hmm.tar.gz

@ColCarroll

The only thing I think you need to worry about in using ArviZ is that we are writing the library with no regard for Python 2. In particular, we use matplotlib 3.x, which is Python 3 only, and the rest of the Python data science infrastructure seems to be phasing out Python 2 support over the next year, so we did not want to start a new project with that technical debt. I understand this may hurt adoption in legacy stacks!

Beyond that, please tag one of us here or open an issue if we can help with the integration at all. We have found xarray and netcdf to be very natural ways of storing inference data.

@neerajprad
Member Author

neerajprad commented Nov 2, 2018

@ColCarroll - Thanks for offering to help with the integration! Regarding the python 2 incompatibility, there is already another feature (CUDA with parallel MCMC chains) that isn't supported on Python 2. Given that Python 2 will not be supported in a year or so, I think it is fine if certain (non-critical) features are only available in Python 3 going forward, but this is worth discussing internally for sure.

EDIT: This however means that we cannot have arviz as an official dependency until we drop support for python 2.

@fritzo
Member

fritzo commented Nov 2, 2018

@ColCarroll we plan for Pyro to support Python 2 as long as PyTorch supports Python 2.

@fehiepsi
Member

fehiepsi commented Dec 2, 2018

@neerajprad Could you point me to the profiling test that is slow? It seems that the file tests/perf/prof_hmc_hmm.py is not available in the dev branch.

@neerajprad
Member Author

@fehiepsi - I updated the example in the prof-hmm branch. It should be in tests/perf/prof_hmc_hmm.py. I don't think there are any immediate TODOs for this one, and this is more of an enhancement issue than a perf issue. Something we could experiment with in the future is JITing the grad step itself (once PyTorch supports it).

@fehiepsi
Member

fehiepsi commented Dec 13, 2018

@neerajprad Totally agree! I just did profiling with both jit and nojit (hmm_2.zip). Most of the time is spent computing _potential_energy and its grad, so the slowness is not related to hmc/nuts.

It is surprising to me that the distributions' log_prob takes just 40s out of the total 250s spent computing trace_log_prob. Lots of time is spent on sumproduct (140s) and pack_tensors (40s). I guess this is expected? In addition, get_trace spends a lot of time (more than 100s) on process_message and post_process_message. I believe that, to get samples for this small model, only about 1s (out of that 100s) is needed for the actual sampling.

@fritzo
Member

fritzo commented Dec 13, 2018

It is surprising to me that the distributions' log_prob takes just 40s

In the HMM example the distribution log_prob computation is merely a gather, i.e. a memcopy; all the actual computation is done by sumproduct when combining log_prob tensors from multiple sites, i.e. matmul and einsum.
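
For intuition, here is a toy version of that contraction (in probability space rather than the log space Pyro's sumproduct actually works in, and with made-up tensors): eliminating the hidden state of a T-step HMM is a chain of einsum contractions, one per time step, and that chain, not the per-site log_prob lookups, is where the time goes.

```python
import torch

num_states, T = 3, 10
init = torch.full((num_states,), 1.0 / num_states)                  # p(z_0)
trans = torch.softmax(torch.randn(num_states, num_states), dim=-1)  # p(z_t | z_{t-1})
obs_lik = torch.rand(T, num_states)                                 # p(x_t | z_t), gathered per step

# Forward-style elimination of the discrete state: each step is an einsum
# contraction, analogous to the (log-space) contractions Pyro performs when
# enumerating the HMM.
belief = init * obs_lik[0]
for t in range(1, T):
    belief = torch.einsum("i,ij,j->j", belief, trans, obs_lik[t])
marginal_likelihood = belief.sum()
print(marginal_likelihood)
```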

@fehiepsi
Member

@fritzo That makes sense! So the slowness is expected for models with discrete variables.

@fehiepsi fehiepsi changed the title HMM model using HMC/NUTS is extremely slow HMM model using HMC/NUTS is slow Dec 14, 2018
@fehiepsi
Member

fehiepsi commented Dec 28, 2018

I put a profiling write-up here for reference: https://gist.github.com/fehiepsi/75dfbea31b993f165f51524776185be6

For the same model, Pyro took 66s while Stan took 33s. The inference results are quite comparable.

But the important point is: it took 32s to compile the Stan model, and only 1s for sampling! So compilation plays an important role here. I hope the PyTorch JIT will improve in the future. :)

@eb8680
Member

eb8680 commented Dec 28, 2018

@fehiepsi the difference may be related to this PyTorch issue about einsum performance: pytorch/pytorch#10661. Oops: there's no enumeration happening in supervised_hmm, so this might not be a problem.

@elbamos
Contributor

elbamos commented Feb 13, 2019

I'm jumping in here because I've also seen serious performance issues with NUTS/HMC.

In my testing, performance starts out adequate and then begins to drop precipitously after around 20 iterations. I have observed that, as this is occurring, the step size is increasing.

To me, the interesting thing is that the performance isn't constant; it declines over time. That suggests to me that the issue is not limited to the time it takes to calculate the energy, which shouldn't vary that much from iteration to iteration.

I have two suspicions: the first is that the HMC/NUTS implementation is trying too hard to increase the step size, so that it ends up producing lots and lots of divergences. The second is that this has to do with memory fragmentation, because of the very large number of tensors that are created as intermediate steps and then retained for the gradient calculation.

@fehiepsi
Member

I believe the slowness is expected when running Pyro MCMC on Markov models, so I would like to close this issue. We can point users to NumPyro if speed is needed.
