In [None]:
import blackjax
import jax
import jax.numpy as jnp
from blackjax.util import run_inference_algorithm

def logdensity_fn(x):
    """Log-density handed to the sampler: minus the sum of squared coordinates."""
    return -jnp.sum(x**2)


# Settings shared by the adaptation phase and the sampling phase.
integrator = blackjax.mcmc.integrators.velocity_verlet
target_acc_rate = 0.8
initial_position = jnp.ones(1)
num_tuning_steps = 10000
num_steps = 10000
return_only_final = True  # not read anywhere in this cell

# Distinct PRNG keys: one for warmup, one for the sampling run.
warmup_key = jax.random.PRNGKey(1)
rng_key = jax.random.PRNGKey(0)

# NOTE: this state is replaced by the one returned from `warmup.run` below.
state = blackjax.nuts.init(initial_position, logdensity_fn)

# Warmup with window adaptation; the tuned `step_size` and
# `inverse_mass_matrix` are read back from `params` further down.
warmup = blackjax.window_adaptation(
    blackjax.nuts,
    logdensity_fn,
    integrator=integrator,
    target_acceptance_rate=target_acc_rate,
    #  cos_angle_termination=cos_angle_termination
)
(state, params), adaptation_info = warmup.run(
    warmup_key,
    initial_position,
    num_tuning_steps,
)

# Mean acceptance rate over every warmup step.
print("Reported acceptance rate: ", adaptation_info.info.acceptance_rate.mean())

# Build NUTS with the adapted parameters and sample from the post-warmup state.
alg = blackjax.nuts(
    logdensity_fn=logdensity_fn,
    step_size=params["step_size"],
    inverse_mass_matrix=params["inverse_mass_matrix"],
    integrator=integrator,
    # cos_angle_termination=cos_angle_termination,
)

# `transform` passes both of its arguments through unchanged so that `info`
# is available after the run.
final_output, (_, info) = run_inference_algorithm(
    rng_key=rng_key,
    initial_state=state,
    inference_algorithm=alg,
    num_steps=num_steps,
    transform=lambda x, i: (x, i),
)

# Mean acceptance rate over the sampling run, for comparison with the warmup value.
print("Observed acceptance rate: ", info.acceptance_rate.mean())

## Now with only dual-averaging (DA) step-size adaptation

In [None]:
from blackjax.adaptation.step_size import (
    dual_averaging_adaptation,
)
from blackjax.util import pytree_size

def da_adaptation(
    rng_key,
    initial_position,
    algorithm,
    logdensity_fn,
    num_steps: int = 1000,
    initial_step_size: float = 1.0,
    target_acceptance_rate: float = 0.80,
    integrator=blackjax.mcmc.integrators.velocity_verlet,
    # cos_angle_termination: float = 0.0,
):
    """Tune only the step size with dual averaging; the inverse mass matrix stays all ones.

    Parameters
    ----------
    rng_key
        PRNG key, split into one sub-key per adaptation step.
    initial_position
        Starting position; also determines the length of the all-ones
        inverse mass matrix via ``pytree_size``.
    algorithm
        Object providing ``build_kernel(integrator=...)`` and
        ``init(position, logdensity_fn)`` (e.g. ``blackjax.nuts``).
    logdensity_fn
        Log-density forwarded to ``algorithm.init`` and to the kernel.
    num_steps
        Number of adaptation steps to scan over.
    initial_step_size
        Step size used to initialise the dual-averaging state.
    target_acceptance_rate
        Target passed to ``dual_averaging_adaptation``.
    integrator
        Integrator forwarded to ``algorithm.build_kernel``.

    Returns
    -------
    A 3-tuple ``(kernel_state, params, info)``: the kernel state after the
    last step, a dict with keys ``"step_size"`` (the value produced by the
    dual-averaging finaliser) and ``"inverse_mass_matrix"`` (the all-ones
    vector), and the per-step kernel info collected by ``jax.lax.scan``.
    """
    init_da, update_da, finalize_da = dual_averaging_adaptation(target_acceptance_rate)
    transition = algorithm.build_kernel(
        integrator=integrator,
        # cos_angle_termination=cos_angle_termination
    )
    unit_inverse_mass = jnp.ones(pytree_size(initial_position))

    def adapt_once(carry, step_key):
        """Run one kernel transition, then update the dual-averaging state."""
        da_state, chain_state = carry

        # The adaptation state stores the log of the step size.
        current_step_size = jnp.exp(da_state.log_step_size)
        next_chain_state, step_info = transition(
            step_key,
            chain_state,
            logdensity_fn,
            current_step_size,
            unit_inverse_mass,
        )

        # Feed this step's acceptance rate back into dual averaging.
        next_da_state = update_da(da_state, step_info.acceptance_rate)
        return (next_da_state, next_chain_state), step_info

    initial_carry = (
        init_da(initial_step_size),
        algorithm.init(initial_position, logdensity_fn),
    )
    step_keys = jax.random.split(rng_key, num_steps)
    (final_da_state, final_chain_state), info_history = jax.lax.scan(
        adapt_once, initial_carry, step_keys
    )

    tuned_params = {
        "step_size": finalize_da(final_da_state),
        "inverse_mass_matrix": unit_inverse_mass,
    }
    return final_chain_state, tuned_params, info_history


# NOTE: this state is replaced by the one returned from `da_adaptation` below.
state = blackjax.nuts.init(initial_position, logdensity_fn)

warmup_key = jax.random.PRNGKey(1)

# Warmup using dual-averaging step-size adaptation only.
state, params, adaptation_info = da_adaptation(
    rng_key=warmup_key,
    initial_position=initial_position,
    algorithm=blackjax.nuts,
    logdensity_fn=logdensity_fn,
    num_steps=num_tuning_steps,
    target_acceptance_rate=target_acc_rate,
    integrator=integrator,
)

# Mean acceptance rate over every adaptation step.
print("Reported acceptance rate: ", adaptation_info.acceptance_rate.mean())

# Build NUTS with the adapted parameters and sample from the post-warmup state.
alg = blackjax.nuts(
    logdensity_fn=logdensity_fn,
    step_size=params["step_size"],
    inverse_mass_matrix=params["inverse_mass_matrix"],
    integrator=integrator,
    # cos_angle_termination=cos_angle_termination,
)

# `rng_key` is not redefined in this cell; `transform` passes both of its
# arguments through unchanged so that `info` is available after the run.
final_output, (_, info) = run_inference_algorithm(
    rng_key=rng_key,
    initial_state=state,
    inference_algorithm=alg,
    num_steps=num_steps,
    transform=lambda x, i: (x, i),
)

# Mean acceptance rate over the sampling run, for comparison with the warmup value.
print("Observed acceptance rate: ", info.acceptance_rate.mean())