Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions probml_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ._version import version as __version__
from .plotting import savefig, latexify, _get_fig_name, is_latexify_enabled
from .blackjax_utils import arviz_trace_from_states
from .pyprobml_utils import (
hinton_diagram,
plot_ellipse,
Expand Down
50 changes: 50 additions & 0 deletions probml_utils/blackjax_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import arviz as az
import jax.numpy as jnp
import jax

def arviz_trace_from_states(states, info, burn_in=0):
"""
args:
...........
states: contains samples returned by blackjax model (i.e HMCState)
info: conatins the meta info returned by blackjax model (i.e HMCinfo)

returns:
...........
trace: arviz trace object
"""
if isinstance(states.position, jnp.DeviceArray): #if states.position is array of samples
samples = {"samples":jnp.swapaxes(states.position,0,1)}
divergence = jnp.swapaxes(info.is_divergent, 0, 1)

else: # if states.position is dict
samples = {}
for param in states.position.keys():
ndims = len(states.position[param].shape)
if ndims == 2:
samples[param] = jnp.swapaxes(states.position[param], 0, 1)[:, burn_in:] # swap n_samples and n_chains
divergence = jnp.swapaxes(info.is_divergent[burn_in:], 0, 1)

if ndims == 1:
divergence = info.is_divergent
samples[param] = states.position[param]

trace_posterior = az.convert_to_inference_data(samples)
trace_sample_stats = az.convert_to_inference_data({"diverging": divergence}, group="sample_stats")
trace = az.concat(trace_posterior, trace_sample_stats)
return trace

def inference_loop_multiple_chains(rng_key, kernel, initial_states, num_samples, num_chains):
'''
returns dict: {"states": states, "info": info}
Visit this page for more info: https://blackjax-devs.github.io/blackjax/examples/Introduction.html
'''
def one_step(states, rng_key):
keys = jax.random.split(rng_key, num_chains)
states, info = jax.vmap(kernel)(keys, states)
return states, {"states": states, "info": info}

keys = jax.random.split(rng_key, num_samples)
_, states_and_info = jax.lax.scan(one_step, initial_states, keys)

return states_and_info
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ pandas
TexSoup
firebase_admin
regex
umap-learn
umap-learn
arviz