Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 55 additions & 27 deletions pymc3/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def traceplot(trace, vars=None, figsize=None,
lines=None, combined=False, grid=True):
lines=None, combined=False, grid=True, ax=None):
"""Plot samples histograms and values

Parameters
Expand All @@ -27,11 +27,13 @@ def traceplot(trace, vars=None, figsize=None,
(default), chains will be plotted separately.
grid : bool
Flag for adding gridlines to histogram. Defaults to True.
ax : axes
Matplotlib axes. Defaults to None.

Returns
-------

fig : figure object
ax : matplotlib axes

"""
import matplotlib.pyplot as plt
Expand All @@ -43,7 +45,11 @@ def traceplot(trace, vars=None, figsize=None,
if figsize is None:
figsize = (12, n*2)

fig, ax = plt.subplots(n, 2, squeeze=False, figsize=figsize)
if ax is None:
fig, ax = plt.subplots(n, 2, squeeze=False, figsize=figsize)
elif ax.shape != (n,2):
print('traceplot requires n*2 subplots')
return None

for i, v in enumerate(vars):
for d in trace.get_values(v, combine=combined, squeeze=False):
Expand All @@ -69,7 +75,7 @@ def traceplot(trace, vars=None, figsize=None,
pass

plt.tight_layout()
return fig
return ax

def histplot_op(ax, data):
for i in range(data.shape[1]):
Expand Down Expand Up @@ -128,23 +134,45 @@ def kde2plot_op(ax, x, y, grid=200):
extent=[xmin, xmax, ymin, ymax])


def kdeplot(data):
f, ax = subplots(1, 1, squeeze=True)
def kdeplot(data, ax=None):
if ax is None:
f, ax = subplots(1, 1, squeeze=True)
kdeplot_op(ax, data)
return f
return ax


def kde2plot(x, y, grid=200):
f, ax = subplots(1, 1, squeeze=True)
def kde2plot(x, y, grid=200, ax=None):
if ax is None:
f, ax = subplots(1, 1, squeeze=True)
kde2plot_op(ax, x, y, grid)
return f
return ax


def autocorrplot(trace, vars=None, fontmap=None, max_lag=100,burn=0, thin=1):
"""Bar plot of the autocorrelation function for a trace"""
def autocorrplot(trace, vars=None, max_lag=100, burn=0, ax=None):
"""Bar plot of the autocorrelation function for a trace

Parameters
----------

trace : result of MCMC run
vars : list of variable names
Variables to be plotted, if None all variable are plotted
max_lag : int
Maximum lag to calculate autocorrelation. Defaults to 100.
burn : int
Number of samples to discard from the beginning of the trace.
Defaults to 0.
ax : axes
Matplotlib axes. Defaults to None.

Returns
-------

ax : matplotlib axes

"""

import matplotlib.pyplot as plt
if fontmap is None:
fontmap = {1: 10, 2: 8, 3: 6, 4: 5, 5: 4}

if vars is None:
vars = trace.varnames
Expand All @@ -153,13 +181,13 @@ def autocorrplot(trace, vars=None, fontmap=None, max_lag=100,burn=0, thin=1):

chains = trace.nchains

f, ax = plt.subplots(len(vars), chains, squeeze=False)
fig, ax = plt.subplots(len(vars), chains, squeeze=False)

max_lag = min(len(trace) - 1, max_lag)

for i, v in enumerate(vars):
for j in range(chains):
d = np.squeeze(trace.get_values(v, chains=[j],burn=burn,thin=thin))
d = np.squeeze(trace.get_values(v, chains=[j], burn=burn))

ax[i, j].acorr(d, detrend=plt.mlab.detrend_mean, maxlags=max_lag)

Expand All @@ -169,13 +197,8 @@ def autocorrplot(trace, vars=None, fontmap=None, max_lag=100,burn=0, thin=1):

if chains > 1:
ax[i, j].set_title("chain {0}".format(j+1))

# Smaller tick labels
tlabels = plt.gca().get_xticklabels()
plt.setp(tlabels, 'fontsize', fontmap[1])

tlabels = plt.gca().get_yticklabels()
plt.setp(tlabels, 'fontsize', fontmap[1])

return (fig, ax)


def var_str(name, shape):
Expand All @@ -200,7 +223,7 @@ def var_str(name, shape):

def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
main=None, xtitle=None, xrange=None, ylabels=None,
chain_spacing=0.05, vline=0):
chain_spacing=0.05, vline=0, gs=None):
""" Forest plot (model summary plot)

Generates a "forest plot" of 100*(1-alpha)% credible intervals for either
Expand Down Expand Up @@ -245,6 +268,14 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,

vline (optional): numeric
Location of vertical reference line (defaults to 0).

gs : GridSpec
Matplotlib GridSpec object. Defaults to None.

Returns
-------

gs : matplotlib GridSpec

"""
import matplotlib.pyplot as plt
Expand All @@ -270,9 +301,6 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
# Number of chains
chains = None

# Gridspec
gs = None

# Subplots
interval_plot = None
rhat_plot = None
Expand Down