
Speeding up NUTS and MCMC in tests #1740

Closed · Jacob-Stevens-Haas opened this issue Feb 20, 2024 · 3 comments
Labels
jax This issue is specific to JAX

Comments

@Jacob-Stevens-Haas commented Feb 20, 2024

In a package that builds regression problems with a variety of solvers, we recently added numpyro and a regularized horseshoe prior. The tests, however, take longer than those for any other method, and I'm trying to speed them up. Even a small test (num_warmup=1, num_samples=4, data shape (10, 2)) takes around 5 seconds - substantially longer than our tests for discrete optimizations with gurobipy. Fixing the geometry would probably help in this specific case, but in general, are there any kwargs that will let NUTS/MCMC find a very quick (and bad) solution, like setting max_iter=1 in a gradient descent algorithm? I tried setting max_tree_depth to 1 and target_accept_prob to .1, but that didn't change the timing appreciably. This would be useful in tests like "make sure we can pickle our models after fitting" and "make sure a fit model has certain attributes".
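
For reference, a minimal sketch of the knobs I tried, plus init_strategy=init_to_sample as a guess at cutting initialization time (these are all real NUTS/MCMC kwargs, but I haven't verified that the init_strategy change helps):

from numpyro.infer import MCMC, NUTS, init_to_sample

# Sketch: settings intended to trade sample quality for speed.
kernel = NUTS(
    model,
    max_tree_depth=1,              # shallow trees; tried, little effect on wall time
    target_accept_prob=0.1,        # loose acceptance target; tried, little effect
    init_strategy=init_to_sample,  # assumption: might reduce find_valid_initial_params time
)
mcmc = MCMC(kernel, num_warmup=1, num_samples=4, progress_bar=False)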

In profiling, the test spends about 41% of its time initializing:

- 10% in hmc:186 (initialize kernel)
- 31% in util:303 (find_valid_initial_params)

and 44% in util:266 (fori_collect), which I assume is actual execution.
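
(For anyone reproducing this, a sketch of how such a profile can be collected with the standard-library profiler; the mcmc/rng_key/x/y names refer to the example code below:)

import cProfile
import pstats

# Profile a single fit and print the heaviest cumulative-time entries.
cProfile.run("mcmc.run(rng_key, x=x, y=y)", "nuts.prof")
pstats.Stats("nuts.prof").sort_stats("cumtime").print_stats(10)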

Totally understand if it's not available, though - it's a great package as is!

Example code:

# %%
import numpy as np

import jax.numpy as jnp
from jax import random
import numpyro
from numpyro.diagnostics import summary
from numpyro.distributions import HalfCauchy, InverseGamma, Normal
from numpyro.infer import MCMC, NUTS


def model(x, y):
    # Simplified model: a single scalar coefficient in place of the horseshoe prior.
    # beta = reg_horseshoe_prior(1e-1, 5, 3, (1, x.shape[1]))
    beta = numpyro.sample("beta", Normal())
    preds = jnp.dot(x, beta.T)
    numpyro.sample("obs", Normal(preds, 1e-1), obs=y)

def reg_horseshoe_prior(
    global_sparsity: float,
    degrees_of_freedom: float,
    slab_var: float,
    shape: tuple[int, ...],
):
    # Global shrinkage scale.
    tau = numpyro.sample("tau", HalfCauchy(global_sparsity))
    # Slab variance, controlling how far large coefficients escape shrinkage.
    c_sq = numpyro.sample(
        "c_sq",
        InverseGamma(
            degrees_of_freedom / 2,
            degrees_of_freedom / 2 * slab_var**2,
        ),
    )
    # Local shrinkage scales, one per coefficient.
    lamb = numpyro.sample("lambda", HalfCauchy(1.0), sample_shape=shape)
    lamb_squiggle = jnp.sqrt(c_sq) * lamb / jnp.sqrt(c_sq + tau**2 * lamb**2)
    beta = numpyro.sample(
        "beta",
        Normal(jnp.zeros_like(lamb_squiggle), jnp.sqrt(lamb_squiggle**2 * tau**2)),
    )
    return beta

# %%

# Ten samples, two features; the target is the first feature.
x = np.random.normal(size=(10, 2))
y = x[:, :1]
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1, num_samples=4)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, x=x, y=y)
summary_dict = summary(mcmc.get_samples(), group_by_chain=False)
@martinjankowiak (Collaborator)
Don't think there's much that you can do, since presumably much of that time is going into jax compilation.
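
(A quick way to check how much of the wall time is compilation - a sketch, relying on NumPyro reusing the compiled program when a second run sees the same shapes:)

import time

t0 = time.perf_counter()
mcmc.run(rng_key, x=x, y=y)  # first run: tracing + XLA compilation + sampling
t1 = time.perf_counter()
mcmc.run(rng_key, x=x, y=y)  # same shapes: should mostly be sampling
t2 = time.perf_counter()
print(f"first run: {t1 - t0:.2f}s, second run: {t2 - t1:.2f}s")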

@fehiepsi added the jax (This issue is specific to JAX) label Feb 26, 2024
@fehiepsi (Member)

Closing because this is more specific to jax. Please check out https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html. As far as I know, it supports TPU and GPU for now and will support CPU in the near future.
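
(For later readers, a sketch of enabling that cache, assuming a jax version that exposes the jax_compilation_cache_dir config option described on the linked page:)

import jax

# Persist compiled XLA programs to disk so repeated test runs skip recompilation.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")

If I recall correctly, older jax releases exposed the same feature through jax.experimental.compilation_cache instead of a config option.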

@Jacob-Stevens-Haas (Author) commented May 13, 2024

Thanks! I'll check it out
