Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Address google/jax#19885 for numpyro. #1743

Merged
merged 4 commits into from
Feb 28, 2024

Conversation

tillahoffmann
Copy link
Contributor

This should address performance issues for LowRankMultivariateNormal distributions with batch dimensions. Only a single change was required to fix the issue. There are other [matmul] + [identity] expressions in the codebase, but, for some reason, they don't cause any issues. The tests verify that no warning is emitted.

@@ -1860,7 +1860,7 @@ def _batch_capacitance_tril(W, D):
Wt_Dinv = jnp.swapaxes(W, -1, -2) / jnp.expand_dims(D, -2)
K = jnp.matmul(Wt_Dinv, W)
# could be inefficient
return jnp.linalg.cholesky(jnp.add(K, jnp.identity(K.shape[-1])))
return jnp.linalg.cholesky(jnp.subtract(K, -jnp.identity(K.shape[-1])))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is interesting. I couldn't come up with a good solution to compute K + eye. Switching to subtract is a bit unfortunate to me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it's not great. I'll have a think if K.at[..., <diag indices>].add(1) could do the job.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For some reason, I couldn't reproduce the issue locally. :(

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to only appear on some systems. For example, I can't reproduce the issue on my MacBook Pro with M1 chip, but it is reproducible in GitHub Actions (see here for an example run).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is a comparison of different implementations. Using the at indexing seems to do a reasonable job. I've just pushed an update.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You would need to add .block_until_ready for a fair comparision: see https://jax.readthedocs.io/en/latest/tutorials/quickstart.html#using-jit-to-speed-up-functions. I guess add/substract is more performant than slice update.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, good point. Forgot about that. Using the block_until_ready call, I get

_original_batch_capacitance_tril
677 ms ± 26.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
_sub_batch_capacitance_tril
5.97 ms ± 713 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
_add_batch_capacitance_tril
5.44 ms ± 900 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

on the machine that's raising the warning. On my local M1 chip, I get

_original_batch_capacitance_tril
11.5 ms ± 3.52 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
_sub_batch_capacitance_tril
12 ms ± 716 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
_add_batch_capacitance_tril
8.41 ms ± 1.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, Till!

@@ -71,7 +71,7 @@
validate_sample,
vec_to_tril_matrix,
)
from numpyro.util import is_prng_key
from numpyro.util import add_diag, is_prng_key
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super cool!!

Could you move this helper to numpyro.distributions.util instead?

@fehiepsi fehiepsi merged commit e6c187c into pyro-ppl:master Feb 28, 2024
4 checks passed
OlaRonning pushed a commit to aleatory-science/numpyro that referenced this pull request May 6, 2024
* Address google/jax#19885 for numpyro.

* Implement function to add constant or batch of vectors to diagonal.

* Use `add_diag` helper function in `distributions` module.

* Move `add_diag` to `distributions.util` module.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants