diff --git a/pymc3/plots.py b/pymc3/plots.py index bc757267e9..7c1c8da645 100644 --- a/pymc3/plots.py +++ b/pymc3/plots.py @@ -7,7 +7,7 @@ def traceplot(trace, vars=None, figsize=None, - lines=None, combined=False, grid=True): + lines=None, combined=False, grid=True, ax=None): """Plot samples histograms and values Parameters @@ -27,11 +27,13 @@ def traceplot(trace, vars=None, figsize=None, (default), chains will be plotted separately. grid : bool Flag for adding gridlines to histogram. Defaults to True. + ax : axes + Matplotlib axes. Defaults to None. Returns ------- - fig : figure object + ax : matplotlib axes """ import matplotlib.pyplot as plt @@ -43,7 +45,11 @@ def traceplot(trace, vars=None, figsize=None, if figsize is None: figsize = (12, n*2) - fig, ax = plt.subplots(n, 2, squeeze=False, figsize=figsize) + if ax is None: + fig, ax = plt.subplots(n, 2, squeeze=False, figsize=figsize) + elif ax.shape != (n,2): + print('traceplot requires n*2 subplots') + return None for i, v in enumerate(vars): for d in trace.get_values(v, combine=combined, squeeze=False): @@ -69,7 +75,7 @@ def traceplot(trace, vars=None, figsize=None, pass plt.tight_layout() - return fig + return ax def histplot_op(ax, data): for i in range(data.shape[1]): @@ -128,23 +134,45 @@ def kde2plot_op(ax, x, y, grid=200): extent=[xmin, xmax, ymin, ymax]) -def kdeplot(data): - f, ax = subplots(1, 1, squeeze=True) +def kdeplot(data, ax=None): + if ax is None: + f, ax = subplots(1, 1, squeeze=True) kdeplot_op(ax, data) - return f + return ax -def kde2plot(x, y, grid=200): - f, ax = subplots(1, 1, squeeze=True) +def kde2plot(x, y, grid=200, ax=None): + if ax is None: + f, ax = subplots(1, 1, squeeze=True) kde2plot_op(ax, x, y, grid) - return f + return ax -def autocorrplot(trace, vars=None, fontmap=None, max_lag=100,burn=0, thin=1): - """Bar plot of the autocorrelation function for a trace""" +def autocorrplot(trace, vars=None, max_lag=100, burn=0, ax=None): + """Bar plot of the autocorrelation function for a trace + + Parameters + ---------- + + trace : result of MCMC run + vars : list of variable names + Variables to be plotted, if None all variable are plotted + max_lag : int + Maximum lag to calculate autocorrelation. Defaults to 100. + burn : int + Number of samples to discard from the beginning of the trace. + Defaults to 0. + ax : axes + Matplotlib axes. Defaults to None. + + Returns + ------- + + ax : matplotlib axes + + """ + import matplotlib.pyplot as plt - if fontmap is None: - fontmap = {1: 10, 2: 8, 3: 6, 4: 5, 5: 4} if vars is None: vars = trace.varnames @@ -153,13 +181,13 @@ def autocorrplot(trace, vars=None, fontmap=None, max_lag=100,burn=0, thin=1): chains = trace.nchains - f, ax = plt.subplots(len(vars), chains, squeeze=False) + fig, ax = plt.subplots(len(vars), chains, squeeze=False) max_lag = min(len(trace) - 1, max_lag) for i, v in enumerate(vars): for j in range(chains): - d = np.squeeze(trace.get_values(v, chains=[j],burn=burn,thin=thin)) + d = np.squeeze(trace.get_values(v, chains=[j], burn=burn)) ax[i, j].acorr(d, detrend=plt.mlab.detrend_mean, maxlags=max_lag) @@ -169,13 +197,8 @@ def autocorrplot(trace, vars=None, fontmap=None, max_lag=100,burn=0, thin=1): if chains > 1: ax[i, j].set_title("chain {0}".format(j+1)) - - # Smaller tick labels - tlabels = plt.gca().get_xticklabels() - plt.setp(tlabels, 'fontsize', fontmap[1]) - - tlabels = plt.gca().get_yticklabels() - plt.setp(tlabels, 'fontsize', fontmap[1]) + + return (fig, ax) def var_str(name, shape): @@ -200,7 +223,7 @@ def var_str(name, shape): def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True, main=None, xtitle=None, xrange=None, ylabels=None, - chain_spacing=0.05, vline=0): + chain_spacing=0.05, vline=0, gs=None): """ Forest plot (model summary plot) Generates a "forest plot" of 100*(1-alpha)% credible intervals for either @@ -245,6 +268,14 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True, vline (optional): numeric Location of vertical reference line (defaults to 0). + + gs : GridSpec + Matplotlib GridSpec object. Defaults to None. + + Returns + ------- + + gs : matplotlib GridSpec """ import matplotlib.pyplot as plt @@ -270,9 +301,6 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True, # Number of chains chains = None - # Gridspec - gs = None - # Subplots interval_plot = None rhat_plot = None