In [None]:
import blackjax
import jax
import jax.numpy as jnp
from blackjax.util import run_inference_algorithm

def logdensity_fn(x):
    """Unnormalized log-density of the target: ``-sum(x**2)``.

    This is an isotropic Gaussian whose per-coordinate variance is 1/2, since
    log N(x; 0, s^2) ~ -x^2 / (2 s^2) gives -x^2 when s^2 = 1/2.
    """
    return -jnp.sum(x**2)


integrator = blackjax.mcmc.integrators.velocity_verlet
# Acceptance rate that window adaptation is asked to tune the step size towards.
target_acc_rate = 0.8
# Fixed seed for the post-warmup sampling run.
rng_key = jax.random.PRNGKey(0)
initial_position = jnp.ones(1)
# No `blackjax.nuts.init` here: `state` is assigned by `warmup.run(...)` below
# before it is ever read, so a state built at this point would be discarded.
num_steps = 10000
num_tuning_steps = 10000
# NOTE(review): not read anywhere in this cell; kept in case later cells use it.
return_only_final = True

# Fixed seed for the warmup, distinct from `rng_key` above.
warmup_key = jax.random.PRNGKey(1)


# Window adaptation for NUTS. It tunes the step size towards target_acc_rate and
# yields the `step_size` / `inverse_mass_matrix` parameters used by the sampler.
warmup = blackjax.window_adaptation(
    blackjax.nuts,
    logdensity_fn,
    integrator=integrator,
    target_acceptance_rate=target_acc_rate,
    # Optional experiment toggle, currently disabled:
    # cos_angle_termination=cos_angle_termination,
)

# `state` is the last warmup state; `params` holds the adapted parameters.
(state, params), adaptation_info = warmup.run(
    warmup_key, initial_position, num_tuning_steps
)

# Mean over every warmup iteration, including early ones taken while the step
# size was still being adapted. It need not equal target_acc_rate exactly.
print(
    "Reported acceptance rate: ",
    adaptation_info.info.acceptance_rate.mean(),
)

# NUTS kernel built from the step size and inverse mass matrix found by warmup.
alg = blackjax.nuts(
    logdensity_fn=logdensity_fn,
    step_size=params["step_size"],
    inverse_mass_matrix=params["inverse_mass_matrix"],
    integrator=integrator,
    # Optional experiment toggle, currently disabled:
    # cos_angle_termination=cos_angle_termination,
)

# Run num_steps NUTS iterations starting from the adapted warmup state. The
# transform keeps the state and the info of every step, so the second return
# value unpacks into (state history, info history).
# The state history gets a real name rather than `_`: in IPython/Jupyter,
# assigning `_` shadows the last-output variable and stops IPython from
# updating `_`, `__` and `___`.
final_output, (state_history, info) = run_inference_algorithm(
    rng_key=rng_key,
    initial_state=state,
    inference_algorithm=alg,
    num_steps=num_steps,
    transform=lambda step_state, step_info: (step_state, step_info),
)

# Mean acceptance over the sampling run, at the fixed adapted step size.
print("Observed acceptance rate: ", info.acceptance_rate.mean())