Skip to content

Commit

Permalink
Merge 092faf2 into 90c7286
Browse files Browse the repository at this point in the history
  • Loading branch information
fonnesbeck committed Oct 12, 2016
2 parents 90c7286 + 092faf2 commit 9a97fe8
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
27 changes: 20 additions & 7 deletions pymc3/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,16 +746,28 @@ def set_key_if_doesnt_exist(d, key, value):
if rope is not None:
display_rope(rope)

def create_axes_grid(figsize, varnames):
n = np.ceil(len(varnames) / 2.0).astype(int)
def create_axes_grid(figsize, traces):
n = np.ceil(len(traces) / 2.0).astype(int)
if figsize is None:
figsize = (12, n * 2.5)
fig, ax = plt.subplots(n, 2, figsize=figsize)
ax = ax.reshape(2 * n)
if len(varnames) % 2 == 1:
if len(traces) % 2 == 1:
ax[-1].set_axis_off()
ax = ax[:-1]
return ax, fig

def get_trace_dict(tr, varnames):
traces = {}
for v in varnames:
vals = tr.get_values(v, combine=True, squeeze=True)
if vals.ndim>1:
vals_flat = vals.reshape(vals.shape[0], -1).T
for i,vi in enumerate(vals_flat):
traces['_'.join([v,str(i)])] = vi
else:
traces[v] = vals
return traces

if isinstance(trace, np.ndarray):
if figsize is None:
Expand All @@ -770,12 +782,13 @@ def create_axes_grid(figsize, varnames):
else:
varnames = [name for name in trace.varnames if not name.endswith('_')]

trace_dict = get_trace_dict(trace, varnames)

if ax is None:
ax, fig = create_axes_grid(figsize, varnames)
ax, fig = create_axes_grid(figsize, trace_dict)

for a, v in zip(ax, varnames):
tr_values = transform(trace.get_values(
v, combine=True, squeeze=True))
for a, v in zip(ax, trace_dict):
tr_values = transform(trace_dict[v])
plot_posterior_op(tr_values, ax=a)
a.set_title(v)

Expand Down
3 changes: 1 addition & 2 deletions pymc3/tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ def test_plots_multidimensional():
trace = sample(3000, step, start)

traceplot(trace)
# forestplot(trace)
# autocorrplot(trace)
plot_posterior(trace)


def test_multichain_plots():
Expand Down

0 comments on commit 9a97fe8

Please sign in to comment.