Skip to content

Commit

Permalink
Fixed y-axis bug in forestplot; added transform argument to summary
Browse files Browse the repository at this point in the history
  • Loading branch information
fonnesbeck committed Mar 7, 2017
1 parent f2f82b5 commit 79f0174
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pymc3/plots/forestplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
gs.update(left=left_margin, right=0.95, top=0.9, bottom=0.05)

# Define range of y-axis
interval_plot.set_ylim(-var + 0.5, -0.5)
interval_plot.set_ylim(-var-0.5, -0.5)

datarange = plotrange[1] - plotrange[0]
interval_plot.set_xlim(plotrange[0] - 0.05 * datarange, plotrange[1] + 0.05 * datarange)
Expand Down
8 changes: 5 additions & 3 deletions pymc3/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,8 +632,8 @@ def _hpd_df(x, alpha):
return pd.DataFrame(hpd(x, alpha), columns=cnames)


def summary(trace, varnames=None, alpha=0.05, start=0, batches=None, roundto=3,
include_transformed=False, to_file=None):
def summary(trace, varnames=None, transform=lambda x: x, alpha=0.05, start=0,
batches=None, roundto=3, include_transformed=False, to_file=None):
R"""
Generate a pretty-printed summary of the node.
Expand All @@ -644,6 +644,8 @@ def summary(trace, varnames=None, alpha=0.05, start=0, batches=None, roundto=3,
varnames : list of strings
List of variables to summarize. Defaults to None, which results
in all variables summarized.
transform : callable
Function to transform data (defaults to identity)
alpha : float
The alpha level for generating posterior intervals. Defaults to
0.05.
Expand Down Expand Up @@ -682,7 +684,7 @@ def summary(trace, varnames=None, alpha=0.05, start=0, batches=None, roundto=3,

for var in varnames:
# Extract sampled values
sample = trace.get_values(var, burn=start, combine=True)
sample = transform(trace.get_values(var, burn=start, combine=True))

fh.write('\n%s:\n\n' % var)

Expand Down

0 comments on commit 79f0174

Please sign in to comment.