Skip to content

Commit

Permalink
Added flexibility to plot_gp_dist (#3013)
Browse files Browse the repository at this point in the history
  • Loading branch information
fonnesbeck authored and Junpeng Lao committed Jun 15, 2018
1 parent 7f11149 commit 260a0f7
Showing 1 changed file with 39 additions and 4 deletions.
43 changes: 39 additions & 4 deletions pymc3/gp/util.py
Expand Up @@ -69,10 +69,42 @@ def setter(self, val):
return gp_wrapper


def plot_gp_dist(ax, samples, x, plot_samples=True, palette="Reds"):
""" A helper function for plotting 1D GP posteriors from trace """
def plot_gp_dist(ax, samples, x, plot_samples=True, palette="Reds", fill_alpha=0.8, samples_alpha=0.1, fill_kwargs=None, samples_kwargs=None):
""" A helper function for plotting 1D GP posteriors from trace
Parameters
----------
ax : axes
Matplotlib axes.
samples : trace or list of traces
Trace(s) or posterior predictive sample from a GP.
x : array
Grid of X values corresponding to the samples.
plot_samples: bool
Plot the GP samples along with posterior (defaults True).
palette: str
Palette for coloring output (defaults to "Reds").
fill_alpha : float
Alpha value for the posterior interval fill (defaults to 0.8).
samples_alpha : float
Alpha value for the sample lines (defaults to 0.1).
fill_kwargs : dict
Additional arguments for posterior interval fill (fill_between).
samples_kwargs : dict
Additional keyword arguments for samples plot.
Returns
-------
ax : Matplotlib axes
"""
import matplotlib.pyplot as plt

if fill_kwargs is None:
fill_kwargs = {}
if samples_kwargs is None:
samples_kwargs = {}

cmap = plt.get_cmap(palette)
percs = np.linspace(51, 99, 40)
colors = (percs - np.min(percs)) / (np.max(percs) - np.min(percs))
Expand All @@ -82,8 +114,11 @@ def plot_gp_dist(ax, samples, x, plot_samples=True, palette="Reds"):
upper = np.percentile(samples, p, axis=1)
lower = np.percentile(samples, 100-p, axis=1)
color_val = colors[i]
ax.fill_between(x, upper, lower, color=cmap(color_val), alpha=0.8)
ax.fill_between(x, upper, lower, color=cmap(color_val), alpha=fill_alpha, **fill_kwargs)
if plot_samples:
# plot a few samples
idx = np.random.randint(0, samples.shape[1], 30)
ax.plot(x, samples[:,idx], color=cmap(0.9), lw=1, alpha=0.1)
ax.plot(x, samples[:,idx], color=cmap(0.9), lw=1, alpha=samples_alpha,
**samples_kwargs)

return ax

0 comments on commit 260a0f7

Please sign in to comment.