From 5763fc0c1f58a0f7723cf1fa503135582088678e Mon Sep 17 00:00:00 2001 From: karm-patel Date: Fri, 17 Jun 2022 17:23:22 +0530 Subject: [PATCH 1/2] Arviz trace method added in `blackjax_utils` --- probml_utils/__init__.py | 1 + probml_utils/blackjax_utils.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+) create mode 100644 probml_utils/blackjax_utils.py diff --git a/probml_utils/__init__.py b/probml_utils/__init__.py index bd6b933..85c819d 100644 --- a/probml_utils/__init__.py +++ b/probml_utils/__init__.py @@ -1,5 +1,6 @@ from ._version import version as __version__ from .plotting import savefig, latexify, _get_fig_name, is_latexify_enabled +from .blackjax_utils import arviz_trace_from_states from .pyprobml_utils import ( hinton_diagram, plot_ellipse, diff --git a/probml_utils/blackjax_utils.py b/probml_utils/blackjax_utils.py new file mode 100644 index 0000000..1b9c1d1 --- /dev/null +++ b/probml_utils/blackjax_utils.py @@ -0,0 +1,34 @@ +import arviz as az +import jax.numpy as jnp + +def arviz_trace_from_states(states, info, burn_in=0): + """ + args: + ........... + states: contains samples returned by blackjax model (i.e HMCState) + info: conatins the meta info returned by blackjax model (i.e HMCinfo) + + returns: + ........... + trace: arviz trace object + """ + if isinstance(states.position, jnp.DeviceArray): #if states.position is array of samples + samples = {"samples":jnp.swapaxes(states.position,0,1)} + divergence = jnp.swapaxes(info.is_divergent, 0, 1) + + else: # if states.position is dict + samples = {} + for param in states.position.keys(): + ndims = len(states.position[param].shape) + if ndims == 2: + samples[param] = jnp.swapaxes(states.position[param], 0, 1)[:, burn_in:] # swap n_samples and n_chains + divergence = jnp.swapaxes(info.is_divergent[burn_in:], 0, 1) + + if ndims == 1: + divergence = info.is_divergent + samples[param] = states.position[param] + + trace_posterior = az.convert_to_inference_data(samples) + trace_sample_stats = az.convert_to_inference_data({"diverging": divergence}, group="sample_stats") + trace = az.concat(trace_posterior, trace_sample_stats) + return trace \ No newline at end of file From 7435012261d9b386a93f0505f2949768421ae352 Mon Sep 17 00:00:00 2001 From: karm-patel Date: Fri, 17 Jun 2022 17:46:09 +0530 Subject: [PATCH 2/2] Added `inference_loop_multiple_chains` in `blackjax_utils` --- probml_utils/blackjax_utils.py | 18 +++++++++++++++++- requirements.txt | 3 ++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/probml_utils/blackjax_utils.py b/probml_utils/blackjax_utils.py index 1b9c1d1..f0a05d6 100644 --- a/probml_utils/blackjax_utils.py +++ b/probml_utils/blackjax_utils.py @@ -1,5 +1,6 @@ import arviz as az import jax.numpy as jnp +import jax def arviz_trace_from_states(states, info, burn_in=0): """ @@ -31,4 +32,19 @@ def arviz_trace_from_states(states, info, burn_in=0): trace_posterior = az.convert_to_inference_data(samples) trace_sample_stats = az.convert_to_inference_data({"diverging": divergence}, group="sample_stats") trace = az.concat(trace_posterior, trace_sample_stats) - return trace \ No newline at end of file + return trace + +def inference_loop_multiple_chains(rng_key, kernel, initial_states, num_samples, num_chains): + ''' + returns dict: {"states": states, "info": info} + Visit this page for more info: https://blackjax-devs.github.io/blackjax/examples/Introduction.html + ''' + def one_step(states, rng_key): + keys = jax.random.split(rng_key, num_chains) + states, info = jax.vmap(kernel)(keys, states) + return states, {"states": states, "info": info} + + keys = jax.random.split(rng_key, num_samples) + _, states_and_info = jax.lax.scan(one_step, initial_states, keys) + + return states_and_info diff --git a/requirements.txt b/requirements.txt index 2c4ae80..3733f7a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,5 @@ pandas TexSoup firebase_admin regex -umap-learn \ No newline at end of file +umap-learn +arviz \ No newline at end of file