In [1]:
import blackjax
import jax
import jax.numpy as jnp
from blackjax.util import run_inference_algorithm
from typing import NamedTuple

# Target: unnormalised log-density -|x|^2 / 2 (isotropic Gaussian shape).
logdensity_fn = lambda x: -jnp.sum(x**2)/2
# Integrator handed to the samplers/adaptation routines in later cells.
integrator = blackjax.mcmc.integrators.isokinetic_velocity_verlet
# Target acceptance rate for the adjusted (Metropolis-corrected) warmup.
target_acc_rate = 0.8
# Fixed seed so that a Restart & Run All reproduces the same numbers.
rng_key = jax.random.PRNGKey(0)
# Two-dimensional starting point.
initial_position = jnp.ones(2)
# NOTE(review): this NUTS state is not read by the definitions in the next
# cells and `state` is rebound by later warmup runs — confirm it is still needed.
state = blackjax.nuts.init(initial_position, logdensity_fn)
# Sampling / tuning budgets used by the driver cells.
num_steps = 10000
num_tuning_steps = 10000
return_only_final = True







In [2]:
class RobnikStepSizeTuningState(NamedTuple):
    """Running state of the energy-variance based step size tuner.

    The field order is part of the interface: instances may be built
    positionally, so do not reorder the fields.
    """

    # Exponentially decayed sum of the weights seen so far (normaliser).
    time : jnp.ndarray
    # Current step size estimate.
    step_size: float
    # Exponentially decayed, weighted sum of ``xi / step_size**6``.
    x_average: float
    # Upper bound applied to every new step size estimate.
    step_size_max: float
    # Dimension of the target; normalises the squared energy change.
    num_dimensions: int

def robnik_step_size_tuning(desired_energy_var, trust_in_estimate=1.5, num_effective_samples=150, step_size_max=jnp.inf):
    """Build a step size tuner that targets a given energy error variance.

    The tuner keeps an exponentially weighted average of the normalised
    squared energy change and inverts the ``Var[E] = O(eps^6)`` relation to
    propose a step size.

    Parameters
    ----------
    desired_energy_var
        Target energy error variance per dimension.
    trust_in_estimate
        Width of the log-normal weight that down-weights steps whose energy
        error is far from the target.
    num_effective_samples
        Memory of the running averages; sets the decay rate
        ``(n - 1) / (n + 1)``.
    step_size_max
        Upper bound for the step size. The default ``inf`` means "no bound".

    Returns
    -------
    ``(init, update, final)``:

    * ``init(initial_step_size, num_dimensions)`` -> state
    * ``update(state, energy_change)`` -> state
    * ``final(state)`` -> step size
    """

    decay_rate = (num_effective_samples - 1.0) / (num_effective_samples + 1.0)

    def init(initial_step_size, num_dimensions):
        """Return a fresh state with empty running averages."""
        return RobnikStepSizeTuningState(time=0.0, x_average=0.0, step_size=initial_step_size, step_size_max=step_size_max, num_dimensions=num_dimensions)

    def update(robnik_state, energy_change):
        """Fold one observed energy change into the step size estimate."""

        xi = (
            jnp.square(energy_change) / (robnik_state.num_dimensions * desired_energy_var)
        ) + 1e-8  # 1e-8 is added to avoid divergences in log xi
        weight = jnp.exp(
            -0.5 * jnp.square(jnp.log(xi) / (6.0 * trust_in_estimate))
        )  # the weight reduces the impact of step sizes which are much larger or much smaller than the desired one.

        x_average = decay_rate * robnik_state.x_average + weight * (
            xi / jnp.power(robnik_state.step_size, 6.0)
        )

        time = decay_rate * robnik_state.time + weight
        step_size = jnp.power(
            x_average / time, -1.0 / 6.0
        )  # We use the Var[E] = O(eps^6) relation here.

        # Clamp to the bound carried in the state. The previous arithmetic mask
        # ``(s < max) * s + (s > max) * max`` evaluated ``False * inf`` (NaN when
        # not fused by the compiler) for the default unbounded case and yielded
        # 0 whenever the estimate was exactly equal to the bound.
        step_size = jnp.minimum(step_size, robnik_state.step_size_max)

        return RobnikStepSizeTuningState(
            time=time,
            x_average=x_average,
            step_size=step_size,
            # Propagate the bound stored in the state rather than the closure
            # default, so a state with an updated bound keeps it.
            step_size_max=robnik_state.step_size_max,
            num_dimensions=robnik_state.num_dimensions,
        )


    def final(robnik_state):
        """Return the current step size estimate."""
        return robnik_state.step_size

    return init, update, final


In [3]:
import jax
import jax.numpy as jnp

from typing import Callable, NamedTuple

import blackjax.mcmc as mcmc
from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info
from blackjax.adaptation.mass_matrix import (
    MassMatrixAdaptationState,
    mass_matrix_adaptation,
)
from blackjax.optimizers.dual_averaging import dual_averaging

from blackjax.base import AdaptationAlgorithm
from blackjax.progress_bar import gen_scan_fn
from blackjax.types import Array, ArrayLikeTree, PRNGKey
from blackjax.util import pytree_size
from blackjax.adaptation.window_adaptation import build_schedule
from jax.flatten_util import ravel_pytree
from blackjax.diagnostics import effective_sample_size
# class DualAveragingAdaptationState(NamedTuple):
#     log_step_size: float
#     log_step_size_avg: float
#     step: int
#     avg_error: float
#     mu: float

# def dual_averaging_adaptation(
#     target: float, t0: int = 10, gamma: float = 0.05, kappa: float = 0.75
# ) -> tuple[Callable, Callable, Callable]:
    
#     da_init, da_update, da_final = dual_averaging(t0, gamma, kappa)

#     def init(inital_step_size: float) -> DualAveragingAdaptationState:
        
#         return DualAveragingAdaptationState(*da_init(inital_step_size))

#     def update(
#         da_state: DualAveragingAdaptationState, value: float
#     ) -> DualAveragingAdaptationState:
        
#         gradient = target - value
#         return DualAveragingAdaptationState(*da_update(da_state, gradient))

#     def final(da_state: DualAveragingAdaptationState) -> float:
#         return jnp.exp(da_state.log_step_size_avg)

#     return init, update, final

class AlbaAdaptationState(NamedTuple):
    """Carry of the ALBA window adaptation scan.

    Instances are constructed positionally elsewhere in this notebook, so the
    field order must not change.
    """

    # State of the energy-variance step size tuner.
    ss_state: RobnikStepSizeTuningState  # step size
    # State of the mass matrix estimator.
    imm_state: MassMatrixAdaptationState  # inverse mass matrix
    # Step size passed to the kernel on the next step.
    step_size: float
    # Inverse mass matrix passed to the kernel on the next step.
    inverse_mass_matrix: Array
    # Value passed to the kernel as its ``L`` argument.
    L : float

def base(
    is_mass_matrix_diagonal: bool,
    v,
    target_eevpd,
) -> tuple[Callable, Callable, Callable]:
    """Build the ``(init, update, final)`` triple of the ALBA window adaptation.

    Parameters
    ----------
    is_mass_matrix_diagonal
        Forwarded to ``mass_matrix_adaptation``.
    v
        Divisor in ``L = sqrt(num_dimensions) / v``.
    target_eevpd
        Forwarded to ``robnik_step_size_tuning`` as ``desired_energy_var``.

    Returns
    -------
    ``init(position)``, ``update(state, stage, position, value)`` and
    ``final(state)``.
    """
    
    mm_init, mm_update, mm_final = mass_matrix_adaptation(is_mass_matrix_diagonal)

    # Step size is tuned from the energy change (not via dual averaging).
    step_size_init, step_size_update, step_size_final = robnik_step_size_tuning(desired_energy_var=target_eevpd)

    def init(
        position: ArrayLikeTree,
    ) -> AlbaAdaptationState:
        """Create the initial adaptation state for a chain starting at ``position``."""
        
        num_dimensions = pytree_size(position)
        imm_state = mm_init(num_dimensions)

        # Initial step size heuristic: sqrt(d) / 5.
        ss_state = step_size_init(initial_step_size=jnp.sqrt(num_dimensions)/5, num_dimensions=num_dimensions)

        return AlbaAdaptationState(
            ss_state,
            imm_state,
            ss_state.step_size,
            imm_state.inverse_mass_matrix,
            L = jnp.sqrt(num_dimensions)/v
        )

    def fast_update(
        position: ArrayLikeTree,
        value: float,
        warmup_state: AlbaAdaptationState,
    ) -> AlbaAdaptationState:
        """Update the adaptation state when in a "fast" window.

        Only the step size is adapted in fast windows; the mass matrix
        estimator state, the inverse mass matrix and ``L`` are carried over
        unchanged. ``value`` is handed to the step size tuner.

        """

        # The position is only needed for mass matrix estimation.
        del position


        new_ss_state =  step_size_update(warmup_state.ss_state, value)
        new_step_size = new_ss_state.step_size  # the tuner stores the step size directly
        
        return AlbaAdaptationState(
            new_ss_state,
            warmup_state.imm_state,
            new_step_size,
            warmup_state.inverse_mass_matrix,
            L = warmup_state.L
        )

    def slow_update(
        position: ArrayLikeTree,
        value: float,
        warmup_state: AlbaAdaptationState,
    ) -> AlbaAdaptationState:
        """Update the adaptation state when in a "slow" window.

        Both the mass matrix estimator and the step size tuner are updated.
        The inverse mass matrix used by the kernel stays at its previous
        value; it is only replaced in ``slow_final``.
        """
    
        new_imm_state = mm_update(warmup_state.imm_state, position)
        new_ss_state = step_size_update(warmup_state.ss_state, value)
        new_step_size = new_ss_state.step_size  # the tuner stores the step size directly

        return AlbaAdaptationState(
            new_ss_state, new_imm_state, new_step_size, warmup_state.inverse_mass_matrix, L = warmup_state.L
        )

    def slow_final(warmup_state: AlbaAdaptationState) -> AlbaAdaptationState:
        """Finish a slow window.

        Finalises the mass matrix estimate, starts using the new inverse mass
        matrix, and re-initialises the step size tuner at its current step
        size (its running averages are reset).
        """

        new_imm_state = mm_final(warmup_state.imm_state)
        new_ss_state = step_size_init(step_size_final(warmup_state.ss_state), warmup_state.ss_state.num_dimensions)
        new_step_size = new_ss_state.step_size  # equals the step size before the reset

        new_L = jnp.sqrt(warmup_state.ss_state.num_dimensions)/v  # same formula as in `init`

        return AlbaAdaptationState(
            new_ss_state,
            new_imm_state,
            new_step_size,
            new_imm_state.inverse_mass_matrix,
            L = new_L
        )

    def update(
        adaptation_state: AlbaAdaptationState,
        adaptation_stage: tuple,
        position: ArrayLikeTree,
        value: float,
    ) -> AlbaAdaptationState:
        """Update the adaptation state and parameter values.

        Parameters
        ----------
        adaptation_state
            Current adaptation state.
        adaptation_stage
            Pair ``(stage, is_middle_window_end)``: ``stage`` selects the
            fast (0) or slow (1) update, and the flag tells whether this is
            the last step of a slow window.
        position
            Current value of the model parameters.
        value
            Statistic fed to the step size tuner; the tuner treats it as the
            energy change of the last mcmc step.

        Returns
        -------
        The updated adaptation state.

        """
        stage, is_middle_window_end = adaptation_stage

        # Index 0 -> fast_update, index 1 -> slow_update.
        warmup_state = jax.lax.switch(
            stage,
            (fast_update, slow_update),
            position,
            value,
            adaptation_state,
        )

        # At the end of a slow window, swap in the new mass matrix.
        warmup_state = jax.lax.cond(
            is_middle_window_end,
            slow_final,
            lambda x: x,
            warmup_state,
        )

        return warmup_state

    def final(warmup_state: AlbaAdaptationState) -> tuple[float, float, Array]:
        """Return the final ``(step_size, L, inverse_mass_matrix)``."""
        step_size = warmup_state.ss_state.step_size 
        # The inverse mass matrix is read from the estimator state, not from
        # the value currently used by the kernel.
        inverse_mass_matrix = warmup_state.imm_state.inverse_mass_matrix
        L = warmup_state.L
        return step_size, L, inverse_mass_matrix

    return init, update, final

def alba(
    algorithm,
    logdensity_fn: Callable,
    target_eevpd,
    v,
    is_mass_matrix_diagonal: bool = True,
    progress_bar: bool = False,
    adaptation_info_fn: Callable = return_all_adapt_info,
    integrator=mcmc.integrators.velocity_verlet,
    num_alba_steps: int = 500,
    alba_factor: float = 0.4,
    **extra_parameters,
) -> AdaptationAlgorithm:
    """Two-phase warmup: window adaptation followed by ESS-based tuning of ``L``.

    Phase 1 runs ``num_steps - num_alba_steps`` kernel steps while adapting the
    step size (from the energy change) and the inverse mass matrix on the
    schedule returned by ``build_schedule``. Phase 2 runs ``num_alba_steps``
    steps with the adapted parameters held fixed and sets
    ``L = alba_factor * step_size * mean(num_alba_steps / ess)``.

    Parameters
    ----------
    algorithm
        Object exposing ``build_kernel(integrator)`` and
        ``init(position, logdensity_fn, key)``.
    logdensity_fn
        Log-density of the target.
    target_eevpd, v, is_mass_matrix_diagonal
        Forwarded to ``base``.
    progress_bar
        Whether to show a progress bar for phase 1.
    adaptation_info_fn
        Maps ``(state, info, adaptation_state)`` to what is stored per step.
    integrator
        Integrator passed to ``algorithm.build_kernel``.
    num_alba_steps
        Number of phase 2 steps; ``0`` skips phase 2 and keeps the ``L`` from
        phase 1.
    alba_factor
        Multiplier in the phase 2 formula for ``L``.
    **extra_parameters
        Extra keyword arguments passed to the kernel in both phases and
        echoed back in the returned parameter dictionary.

    Returns
    -------
    An ``AdaptationAlgorithm`` whose ``run(rng_key, position, num_steps)``
    returns ``(AdaptationResults(state, parameters), info)``.
    """

    mcmc_kernel = algorithm.build_kernel(integrator)

    adapt_init, adapt_step, adapt_final = base(
        is_mass_matrix_diagonal,
        target_eevpd=target_eevpd,
        v=v,
    )

    def one_step(carry, xs):
        """One phase 1 step: advance the chain, then update the adaptation."""
        _, rng_key, adaptation_stage = xs
        state, adaptation_state = carry

        new_state, info = mcmc_kernel(
            rng_key=rng_key,
            state=state,
            logdensity_fn=logdensity_fn,
            step_size=adaptation_state.step_size,
            inverse_mass_matrix=adaptation_state.inverse_mass_matrix,
            L=adaptation_state.L,
            **extra_parameters,
        )
        new_adaptation_state = adapt_step(
            adaptation_state,
            adaptation_stage,
            new_state.position,
            info.energy_change,
        )

        return (
            (new_state, new_adaptation_state),
            adaptation_info_fn(new_state, info, new_adaptation_state),
        )

    def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000):
        """Run both warmup phases; ``num_steps`` is the total budget."""
        init_key, rng_key, alba_key = jax.random.split(rng_key, 3)
        init_state = algorithm.init(position, logdensity_fn, init_key)
        init_adaptation_state = adapt_init(position)

        # Phase 1 gets whatever is left of the budget after phase 2.
        num_window_steps = num_steps - num_alba_steps

        if progress_bar:
            print("Running window adaptation")
        scan_fn = gen_scan_fn(num_window_steps, progress_bar=progress_bar)
        start_state = (init_state, init_adaptation_state)
        keys = jax.random.split(rng_key, num_window_steps)
        schedule = build_schedule(num_window_steps)
        last_state, info = scan_fn(
            one_step,
            start_state,
            (jnp.arange(num_window_steps), keys, schedule),
        )

        last_chain_state, last_warmup_state, *_ = last_state
        step_size, window_L, inverse_mass_matrix = adapt_final(last_warmup_state)

        ###
        ### ALBA TUNING
        ###
        L = window_L
        if num_alba_steps > 0:

            def step(state, key):
                # Same kernel configuration as phase 1, including
                # `extra_parameters`, so the ESS is measured for the kernel
                # whose parameters are returned.
                next_state, _ = mcmc_kernel(
                    rng_key=key,
                    state=state,
                    logdensity_fn=logdensity_fn,
                    L=window_L,
                    step_size=step_size,
                    inverse_mass_matrix=inverse_mass_matrix,
                    **extra_parameters,
                )

                return next_state, next_state.position

            alba_keys = jax.random.split(alba_key, num_alba_steps)
            # NOTE(review): the chain state reached after these steps is
            # discarded; the returned state is the one from the end of phase 1.
            _, samples = jax.lax.scan(step, last_chain_state, alba_keys)
            flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples)
            # A leading axis of length one is added before computing the ESS
            # (presumably the chain axis — confirm against the diagnostics API).
            ess = effective_sample_size(flat_samples[None, ...])

            L = alba_factor * step_size * jnp.mean(num_alba_steps / ess)

        parameters = {
            "step_size": step_size,
            "inverse_mass_matrix": inverse_mass_matrix,
            "L": L,
            **extra_parameters,
        }

        return (
            AdaptationResults(
                last_chain_state,
                parameters,
            ),
            info,
        )

    return AdaptationAlgorithm(run)





In [4]:
from blackjax.adaptation.step_size import (
    dual_averaging_adaptation,
)
from blackjax.mcmc.adjusted_mclmc_dynamic import rescale


def make_random_trajectory_length_fn(random_trajectory_length : bool):
    """Return a factory for "number of integration steps" functions.

    The returned factory takes ``avg_num_integration_steps`` and gives back a
    function of one argument that yields an ``int32`` step count:

    * if ``random_trajectory_length`` is truthy, the argument is used as a
      PRNG key and the count is
      ``ceil(uniform(key) * rescale(avg_num_integration_steps))``;
    * otherwise the argument is ignored and the count is
      ``ceil(avg_num_integration_steps)``.
    """

    def randomized(avg_num_integration_steps):
        def draw(k):
            scaled = jax.random.uniform(k) * rescale(avg_num_integration_steps)
            return jnp.ceil(scaled).astype('int32')

        return draw

    def deterministic(avg_num_integration_steps):
        def constant(_):
            rounded_up = jnp.ceil(avg_num_integration_steps)
            return rounded_up.astype('int32')

        return constant

    return randomized if random_trajectory_length else deterministic

def da_adaptation(
    algorithm,
    logdensity_fn: Callable,
    integration_steps_fn: Callable,
    inverse_mass_matrix,
    initial_step_size: float = 1.0,
    target_acceptance_rate: float = 0.80,
    integrator=blackjax.mcmc.integrators.velocity_verlet, 
):
    """Step size warmup driven by the acceptance rate, mass matrix held fixed.

    Each step runs the kernel built from ``algorithm`` with the step size
    ``exp(log_step_size)`` of the current adaptation state, then feeds the
    step's ``info.acceptance_rate`` to ``dual_averaging_adaptation``.

    Parameters
    ----------
    algorithm
        Object exposing ``build_kernel(integrator=...)`` and
        ``init(position, logdensity_fn, key)``.
    logdensity_fn
        Log-density of the target.
    integration_steps_fn
        Passed unchanged to the kernel.
    inverse_mass_matrix
        Passed unchanged to the kernel and echoed in the returned parameters.
    initial_step_size
        Starting step size for the adaptation.
    target_acceptance_rate
        Acceptance rate the adaptation aims for.
    integrator
        Integrator used to build the kernel.

    Returns
    -------
    An ``AdaptationAlgorithm`` whose ``run(rng_key, position, num_steps)``
    returns ``(kernel_state, parameters, info)`` where ``parameters`` holds
    ``"step_size"`` and ``"inverse_mass_matrix"``.
    """

    init_da, update_da, finalize_da = dual_averaging_adaptation(target_acceptance_rate)
    mcmc_kernel = algorithm.build_kernel(integrator=integrator)

    def one_step(carry, step_key):
        da_state, chain_state = carry
        current_step_size = jnp.exp(da_state.log_step_size)

        next_chain_state, step_info = mcmc_kernel(
            rng_key=step_key,
            state=chain_state,
            logdensity_fn=logdensity_fn,
            step_size=current_step_size,
            inverse_mass_matrix=inverse_mass_matrix,
            integration_steps_fn=integration_steps_fn,
        )
        next_da_state = update_da(da_state, step_info.acceptance_rate)

        return (next_da_state, next_chain_state), step_info

    def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000):
        init_key, scan_key = jax.random.split(rng_key)

        first_chain_state = algorithm.init(position, logdensity_fn, init_key)
        step_keys = jax.random.split(scan_key, num_steps)
        first_carry = (init_da(initial_step_size), first_chain_state)

        (last_da_state, last_chain_state), info = jax.lax.scan(
            one_step, first_carry, step_keys
        )

        tuned_parameters = {
            "step_size": finalize_da(last_da_state),
            "inverse_mass_matrix": inverse_mass_matrix,
        }
        return last_chain_state, tuned_parameters, info

    return AdaptationAlgorithm(run)


In [5]:
def adjusted_alba(
    unadjusted_algorithm,
    logdensity_fn: Callable,
    target_eevpd,
    v,
    adjusted_algorithm,
    num_dimensions: int,
    integrator,
    target_acceptance_rate: float = 0.80,
    num_alba_steps: int = 500,
    alba_factor: float = 0.4,
    **extra_parameters,
    ):
    """Chain the unadjusted ALBA warmup with an acceptance-rate warmup.

    Stage one runs ``alba`` with ``unadjusted_algorithm``. Its ``L`` and
    ``step_size`` define the average number of integration steps; its step
    size and inverse mass matrix seed stage two. Stage two runs
    ``da_adaptation`` with ``adjusted_algorithm``, starting from the position
    reached in stage one. Both stages receive the same ``num_steps``.

    The returned ``run(rng_key, position, num_steps)`` yields
    ``(state, params, adaptation_info)`` from stage two, with ``params["L"]``
    set to the mean number of integration steps times the tuned step size.

    NOTE(review): ``num_dimensions`` is accepted but never read here.
    NOTE(review): ``extra_parameters`` are forwarded to ``da_adaptation``;
    confirm it accepts the keys callers are expected to pass.
    """

    stage_one = alba(
        algorithm=unadjusted_algorithm,
        logdensity_fn=logdensity_fn,
        target_eevpd=target_eevpd,
        v=v,
        integrator=integrator,
        num_alba_steps=num_alba_steps,
        alba_factor=alba_factor,
        **extra_parameters,
    )

    def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000):
        stage_one_key, stage_two_key = jax.random.split(rng_key)

        (stage_one_state, stage_one_params), _ = stage_one.run(
            stage_one_key, position, num_steps
        )

        avg_num_integration_steps = stage_one_params["L"] / stage_one_params["step_size"]

        def integration_steps_fn(k):
            # Random trajectory length drawn from the key handed in by the kernel.
            return jnp.ceil(jax.random.uniform(k) * rescale(avg_num_integration_steps))

        stage_two = da_adaptation(
            algorithm=adjusted_algorithm,
            logdensity_fn=logdensity_fn,
            integration_steps_fn=integration_steps_fn,
            initial_step_size=stage_one_params["step_size"],
            target_acceptance_rate=target_acceptance_rate,
            inverse_mass_matrix=stage_one_params["inverse_mass_matrix"],
            integrator=integrator,
            **extra_parameters,
        )

        state, params, adaptation_info = stage_two.run(
            stage_two_key, stage_one_state.position, num_steps
        )
        mean_num_steps = adaptation_info.num_integration_steps.mean()
        params["L"] = mean_num_steps * params["step_size"]
        return state, params, adaptation_info

    return AdaptationAlgorithm(run)

    

In [1]:
from blackjax.adaptation.adjusted_abla import adjusted_alba
from blackjax.adaptation.unadjusted_step_size import robnik_step_size_tuning
from blackjax.adaptation.unadjusted_alba import unadjusted_alba

import blackjax
import jax
import jax.numpy as jnp
from blackjax.util import run_inference_algorithm
from typing import NamedTuple

# Target: unnormalised log-density -|x|^2 / 2 (isotropic Gaussian shape).
logdensity_fn = lambda x: -jnp.sum(x**2)/2
integrator = blackjax.mcmc.integrators.isokinetic_velocity_verlet
target_acc_rate = 0.8
rng_key = jax.random.PRNGKey(0)
# Two-dimensional starting point.
initial_position = jnp.ones(2)
# NOTE(review): this NUTS state is overwritten by the warmup result below.
state = blackjax.nuts.init(initial_position, logdensity_fn)
num_steps = 10000
num_tuning_steps = 10000
return_only_final = True



# Unadjusted warmup using the packaged `unadjusted_alba` imported at the top of this cell.
alg = blackjax.mclmc
# Same seed as `rng_key`; later cells reuse `warmup_key` as well.
warmup_key = jax.random.PRNGKey(0)
# NOTE(review): presumably `num_alba_steps` is taken out of the
# `num_tuning_steps` budget — confirm against the packaged implementation.
warmup = unadjusted_alba(algorithm=alg, logdensity_fn=logdensity_fn, integrator=integrator, target_eevpd=5e-4, v=1., num_alba_steps=5000)
(state, params), adaptation_info = warmup.run(warmup_key, initial_position, num_tuning_steps)

# Display the tuned parameters.
params



{'step_size': Array(0.636016, dtype=float32),
 'inverse_mass_matrix': Array([1.1305283 , 0.97562194], dtype=float32),
 'L': Array(1.373382, dtype=float32)}

In [2]:
# Adjusted warmup: unadjusted MCLMC first, then the dynamic adjusted MCLMC kernel.
# NOTE(review): `adjusted_alba` resolves to whichever binding was executed
# last (the in-notebook definition or the packaged import) — confirm which
# one produced the stored output.
warmup = adjusted_alba(
    unadjusted_algorithm=blackjax.mclmc,
    logdensity_fn=logdensity_fn,
    target_eevpd=5e-4,
    v=1.,
    adjusted_algorithm=blackjax.adjusted_mclmc_dynamic,
    target_acceptance_rate=0.8,
    num_dimensions=2,
    integrator=blackjax.mcmc.integrators.isokinetic_velocity_verlet
)

# Reuses `warmup_key`, i.e. the same seed as the previous warmup run.
# Rebinds the notebook-level `state`, `params` and `adaptation_info`.
state, params, adaptation_info = warmup.run(warmup_key, initial_position, num_tuning_steps)
# Display the tuned parameters.
params

{'step_size': Array(1.722377, dtype=float32),
 'inverse_mass_matrix': Array([1.0370758, 1.0950992], dtype=float32),
 'L': Array(3.3546734, dtype=float32)}

In [3]:
# Same adjusted warmup, now with the Langevin / dynamic MALT pair, the plain
# velocity Verlet integrator and v = sqrt(2).
warmup = adjusted_alba(
    unadjusted_algorithm=blackjax.langevin,
    logdensity_fn=logdensity_fn,
    target_eevpd=5e-4,
    v=jnp.sqrt(2),
    adjusted_algorithm=blackjax.dynamic_malt,
    target_acceptance_rate=0.8,
    num_dimensions=2,
    integrator=blackjax.mcmc.integrators.velocity_verlet
)

# Reuses `warmup_key` again and rebinds `state`, `params`, `adaptation_info`.
state, params, adaptation_info = warmup.run(warmup_key, initial_position, num_tuning_steps)
# Display the tuned parameters.
params

{'step_size': Array(1.2649875, dtype=float32),
 'inverse_mass_matrix': Array([1.0368719, 1.1122135], dtype=float32),
 'L': Array(2.4872184, dtype=float32)}

In [9]:

# adjusted_warmup_key = jax.random.PRNGKey(0)


# integration_steps_fn = lambda k: jnp.ceil(
#             jax.random.uniform(k) * rescale(avg_num_integration_steps)
#         )



# warmup = da_adaptation(
#                 # rng_key=adjusted_warmup_key,
#                 # initial_position=state.position,
#                 algorithm=blackjax.adjusted_mclmc_dynamic,
#                 integrator=integrator,
#                 logdensity_fn=logdensity_fn,
#                 num_dimensions=pytree_size(state.position),
#                 # num_steps=num_tuning_steps,
#                 target_acceptance_rate=target_acc_rate,
#                 initial_step_size=params["step_size"],
#                 integration_steps_fn=integration_steps_fn,
#             )

# state, params, adaptation_info = warmup.run(adjusted_warmup_key, state.position, num_tuning_steps)

# params["L"] = adaptation_info.num_integration_steps.mean()*params["step_size"]

RobnikStepSizeTuningState(time=Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, step_size=Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, x_average=Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, step_size_max=Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, num_dimensions=Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace>)
RobnikStepSizeTuningState(time=Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>, step_size=Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>, x_average=Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>, step_size_max=inf, num_dimensions=Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace>)
RobnikStepSizeTuningState(time=Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, step_size=Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, x_average=Traced<ShapedArray(float32[], weak_type=True)>with<Dyna

In [30]:
# NOTE(review): out-of-order execution counter; the stored output matches the
# adjusted MCLMC run rather than the most recent cell above — re-run in order.
params

{'step_size': Array(1.722377, dtype=float32),
 'inverse_mass_matrix': Array([1.0370758, 1.0950992], dtype=float32),
 'L': Array(3.3546734, dtype=float32)}

In [23]:
# NOTE(review): out-of-order execution counter; the stored output (identity
# inverse mass matrix) is not produced by any run shown above — likely stale.
params

{'step_size': Array(1.798047, dtype=float32),
 'inverse_mass_matrix': Array([1., 1.], dtype=float32),
 'L': Array(3.502056, dtype=float32)}