Skip to content
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
95 lines (84 sloc) 3.01 KB
"""PyMC3 Plotting.

Plots are delegated to the ArviZ library, a general purpose library for
"exploratory analysis of Bayesian models." See the ArviZ documentation at
https://arviz-devs.github.io/arviz/ for details on plots.
"""
import functools
import sys
import warnings
import arviz as az
def map_args(func):
    """Wrap ``func`` so that renamed keyword arguments are still accepted.

    For every ``(old, new)`` pair in ``swaps``: when the caller passes
    ``old`` and does not pass ``new``, a warning is emitted and the value is
    forwarded to ``func`` under the ``new`` name.  When both names are
    given, the keyword arguments are passed through untouched.

    Parameters
    ----------
    func : callable
        The function to wrap.

    Returns
    -------
    callable
        A wrapper with the same name and docstring as ``func`` that performs
        the keyword renaming before delegating to ``func``.
    """
    swaps = [
        ('varnames', 'var_names'),
    ]

    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        for old, new in swaps:
            if old in kwargs and new not in kwargs:
                # stacklevel=2 attributes the warning to the caller's line
                # instead of this wrapper.
                warnings.warn('Keyword argument `{old}` renamed to `{new}`, and will be removed in pymc3 3.8'.format(old=old, new=new), stacklevel=2)
                kwargs[new] = kwargs.pop(old)
        # Delegate only after every rename has been applied.
        return func(*args, **kwargs)
    return wrapped
# pymc3 custom plots: override these names for custom behavior
# Each legacy pymc3 plot name below is bound to an ArviZ plotting function
# wrapped by `map_args`, so the old `varnames` keyword is still accepted
# (with a warning) and forwarded as `var_names`.
autocorrplot = map_args(az.plot_autocorr)
forestplot = map_args(az.plot_forest)
kdeplot = map_args(az.plot_kde)
plot_posterior = map_args(az.plot_posterior)
energyplot = map_args(az.plot_energy)
densityplot = map_args(az.plot_density)
pairplot = map_args(az.plot_pair)
# Use compact traceplot by default
@map_args
def traceplot(*args, **kwargs):
    """Call ``az.plot_trace`` with ``compact=True`` unless the caller set it.

    All positional and keyword arguments are forwarded to ``az.plot_trace``;
    see that function for the accepted parameters.  Like the other legacy
    plot names in this module, the deprecated ``varnames`` keyword is
    translated by ``map_args``.

    If the call raises ``TypeError``, it is retried once without the
    ``compact`` keyword (presumably for ArviZ versions whose ``plot_trace``
    does not accept ``compact`` -- confirm against the supported versions).

    Returns
    -------
    Whatever ``az.plot_trace`` returns.
    """
    kwargs.setdefault('compact', True)
    try:
        return az.plot_trace(*args, **kwargs)
    except TypeError:
        # Retrying with identical kwargs would fail the same way, so drop
        # `compact` first.  A TypeError unrelated to `compact` is raised
        # again by the second call.
        kwargs.pop('compact', None)
        return az.plot_trace(*args, **kwargs)
# additional arg mapping for compare plot
def compareplot(*args, **kwargs):
    """Call ``az.plot_compare`` after renaming legacy comparison columns.

    The comparison table is taken from the ``comp_df`` keyword when given,
    otherwise from the first positional argument.  A copy is renamed, so the
    caller's object is not modified.  If the copy has a ``WAIC`` column (or
    else a ``LOO`` column), its legacy column names are mapped to the
    lower-case names listed below and its index labels are converted with
    ``str``.  Tables with neither column are forwarded unchanged.

    Parameters
    ----------
    *args, **kwargs
        Forwarded to ``az.plot_compare``, with the comparison table replaced
        by the renamed copy in the position (or keyword) it was supplied in.

    Returns
    -------
    Whatever ``az.plot_compare`` returns.

    Raises
    ------
    IndexError
        If no ``comp_df`` keyword and no positional argument is given.
    """
    passed_by_keyword = 'comp_df' in kwargs
    if passed_by_keyword:
        comp_df = kwargs['comp_df'].copy()
    else:
        # Tuples are immutable; a list lets us swap in the renamed copy.
        args = list(args)
        comp_df = args[0].copy()

    if 'WAIC' in comp_df.columns:
        comp_df = comp_df.rename(index=str,
                                 columns={'WAIC': 'waic',
                                          'pWAIC': 'p_waic',
                                          'dWAIC': 'd_waic',
                                          'SE': 'se',
                                          'dSE': 'dse',
                                          'var_warn': 'warning'})
    elif 'LOO' in comp_df.columns:
        comp_df = comp_df.rename(index=str,
                                 columns={'LOO': 'loo',
                                          'pLOO': 'p_loo',
                                          'dLOO': 'd_loo',
                                          'SE': 'se',
                                          'dSE': 'dse',
                                          'shape_warn': 'warning'})

    # Put the renamed copy back exactly where the caller supplied it.
    if passed_by_keyword:
        kwargs['comp_df'] = comp_df
    else:
        args[0] = comp_df
    return az.plot_compare(*args, **kwargs)
from .posteriorplot import plot_posterior_predictive_glm
# Access to arviz plots: base plots provided by arviz
# Every name listed in `az.plots.__all__` is bound as an attribute of this
# module, wrapped by `map_args` so the legacy `varnames` keyword keeps working.
# Names assigned earlier in this module that also appear in that list
# (e.g. `plot_posterior`) are rebound here.
for plot in az.plots.__all__:
    setattr(sys.modules[__name__], plot, map_args(getattr(az.plots, plot)))
__all__ = tuple(az.plots.__all__) + (
You can’t perform that action at this time.