Skip to content

Commit

Permalink
Merge 6461c6b into 2debf77
Browse files Browse the repository at this point in the history
  • Loading branch information
Paul Miles committed Jul 17, 2019
2 parents 2debf77 + 6461c6b commit 2ebf2ed
Show file tree
Hide file tree
Showing 9 changed files with 245 additions and 58 deletions.
2 changes: 1 addition & 1 deletion mcmcplot/__init__.py
@@ -1 +1 @@
__version__ = "0.1.0rc1"
__version__ = "1.0.0rc1"
1 change: 0 additions & 1 deletion mcmcplot/__version__.py

This file was deleted.

112 changes: 73 additions & 39 deletions mcmcplot/mcmatplot.py
Expand Up @@ -8,12 +8,12 @@

# import required packages
from __future__ import division
import math
import matplotlib.pyplot as plt
from pylab import hist
from .utilities import generate_names, make_x_grid
from .utilities import check_settings, generate_subplot_grid
from .utilities import generate_ellipse_plot_points
from .utilities import setup_subsample
import numpy as np

import warnings
Expand All @@ -27,7 +27,9 @@


# --------------------------------------------
def plot_density_panel(chains, names=None, settings=None):
def plot_density_panel(chains, names=None, settings=None,
return_kde=False, hist_on=False,
return_settings=False):
'''
Plot marginal posterior densities
Expand All @@ -50,14 +52,14 @@ def plot_density_panel(chains, names=None, settings=None):
'plot': dict(color='k', marker=None, linestyle='-', linewidth=3),
'xlabel': {},
'ylabel': {},
'hist_on': False,
'hist': dict(density=True),
}
settings = check_settings(
default_settings=default_settings, user_settings=settings)
nsimu, nparam = chains.shape # number of rows, number of columns
ns1, ns2 = generate_subplot_grid(nparam)
names = generate_names(nparam, names)
kdehandle = []
f = plt.figure(**settings['fig']) # initialize figure
for ii in range(nparam):
# define chain
Expand All @@ -67,20 +69,30 @@ def plot_density_panel(chains, names=None, settings=None):
# Compute kernel density estimate
kde = KDEMultivariate(chain, **settings['kde'])
# plot density on subplot
plt.subplot(ns1, ns2, ii+1)
if settings['hist_on'] is True: # include histograms
plt.subplot(ns1, ns2, ii + 1)
if hist_on is True: # include histograms
hist(chain, **settings['hist'])
plt.plot(chain_grid, kde.pdf(chain_grid), **settings['plot'])
# format figure
plt.xlabel(names[ii], **settings['xlabel'])
plt.ylabel(str('$\\pi$({}$|M^{}$)'.format(names[ii], '{data}')),
**settings['ylabel'])
plt.tight_layout(rect=[0, 0.03, 1, 0.95], h_pad=1.0) # adjust spacing
return f, settings
kdehandle.append(kde)
# setup output
if return_kde is True and return_settings is True:
return f, settings, kdehandle
elif return_kde is True and return_settings is False:
return f, kdehandle
elif return_kde is False and return_settings is True:
return f, settings
else:
return f


# --------------------------------------------
def plot_histogram_panel(chains, names=None, settings=None):
def plot_histogram_panel(chains, names=None,
settings=None, return_settings=False):
"""
Plot histogram from each parameter's sampling history
Expand Down Expand Up @@ -115,19 +127,23 @@ def plot_histogram_panel(chains, names=None, settings=None):
# define chain
chain = chains[:, ii].reshape(nsimu, 1) # check indexing
# plot density on subplot
ax = plt.subplot(ns1, ns2, ii+1)
ax = plt.subplot(ns1, ns2, ii + 1)
hist(chain, **settings['hist'])
# format figure
plt.xlabel(names[ii], **settings['xlabel'])
plt.ylabel(**settings['ylabel'])
if settings['turn_yticks_on'] is False:
ax.get_yaxis().set_ticks([])
plt.tight_layout(rect=[0, 0.03, 1, 0.95], h_pad=1.0) # adjust spacing
return f, settings
if return_settings is True:
return f, settings
else:
return f


# --------------------------------------------
def plot_chain_panel(chains, names=None, settings=None):
def plot_chain_panel(chains, names=None, settings=None,
skip=1, maxpoints=500, return_settings=False):
"""
Plot sampling chain for each parameter
Expand All @@ -138,13 +154,16 @@ def plot_chain_panel(chains, names=None, settings=None):
of each parameter
* **settings** (:py:class:`dict`): Settings for features \
of this method.
* **skip** (:py:class:`int`): Indicates step size to be used when
plotting elements from the chain
* **maxpoints** (:py:class:`int`): Max number of display points
- keeps scatter plot from becoming overcrowded
Returns:
* (:py:class:`tuple`): (figure handle, settings actually \
used in program)
"""
default_settings = {
'maxpoints': 500,
'fig': dict(figsize=(5, 4), dpi=100),
'plot': dict(color='b', marker='.', linestyle='none'),
'xlabel': {'xlabel': 'Iteration'},
Expand All @@ -159,54 +178,61 @@ def plot_chain_panel(chains, names=None, settings=None):
nsimu, nparam = chains.shape # number of rows, number of columns
ns1, ns2 = generate_subplot_grid(nparam)
names = generate_names(nparam, names)
skip = 1
if nsimu > settings['maxpoints']:
skip = int(math.floor(nsimu/settings['maxpoints']))
# setup sample indices
inds = setup_subsample(skip, maxpoints, nsimu)
f = plt.figure(**settings['fig']) # initialize figure
for ii in range(nparam):
# define chain
chain = chains[:, ii].reshape(nsimu, 1) # check indexing
chain = chains[inds, ii] # check indexing
# plot chain on subplot
plt.subplot(ns1, ns2, ii+1)
plt.plot(range(0, nsimu, skip), chain[range(0, nsimu, skip), 0],
plt.subplot(ns1, ns2, ii + 1)
plt.plot(inds, chain,
**settings['plot'])
# format figure
plt.xlabel(**settings['xlabel'])
plt.ylabel(str('{}'.format(names[ii])), **settings['ylabel'])
if ii+1 <= ns1*ns2 - ns2:
if ii + 1 <= ns1*ns2 - ns2:
plt.xlabel('')
plt.tight_layout(rect=[0, 0.03, 1, 0.95], h_pad=1.0) # adjust spacing
if settings['add_pm2std'] is True:
mu = np.mean(chain)
sig = np.std(chain)
plt.plot(range(0, nsimu), np.ones([nsimu, 1])*mu,
plt.plot(inds, np.ones(inds.shape)*mu,
**settings['mean'])
plt.plot(range(0, nsimu), np.ones([nsimu, 1])*mu + 2*sig,
plt.plot(inds, np.ones(inds.shape)*mu + 2*sig,
**settings['sig'])
plt.plot(range(0, nsimu), np.ones([nsimu, 1])*mu - 2*sig,
plt.plot(inds, np.ones(inds.shape)*mu - 2*sig,
**settings['sig'])
return f, settings
if return_settings is True:
return f, settings
else:
return f


# --------------------------------------------
def plot_pairwise_correlation_panel(chains, names=None, settings=None):
def plot_pairwise_correlation_panel(chains, names=None, settings=None,
skip=1, maxpoints=500,
return_settings=False):
"""
Plot pairwise correlation for each parameter
Args:
* **chains** (:class:`~numpy.ndarray`): Sampling chain \
for each parameter
for each parameter
* **names** (:py:class:`list`): List of strings - name \
of each parameter
* **settings** (:py:class:`dict`): Settings for features \
of this method.
of each parameter
* **settings** (:py:class:`dict`): Settings for figure \
features made by this method.
* **skip** (:py:class:`int`): Indicates step size to be used when
plotting elements from the chain
* **maxpoints** (py:class:`int`): Maximum allowable number of points
in plot.
Returns:
* (:py:class:`tuple`): (figure handle, settings actually \
used in program)
"""
default_settings = {
'skip': 1,
'fig': dict(figsize=(7, 5), dpi=100),
'plot': dict(color='b', marker='.', linestyle='none'),
'xlabel': {},
Expand All @@ -225,30 +251,31 @@ def plot_pairwise_correlation_panel(chains, names=None, settings=None):
nsimu, nparam = chains.shape # number of rows, number of columns
ns1, ns2 = generate_subplot_grid(nparam)
names = generate_names(nparam, names)
inds = range(0, nsimu, settings['skip'])
inds = setup_subsample(skip, maxpoints, nsimu)
f = plt.figure(**settings['fig']) # initialize figure
for jj in range(2, nparam+1):
for jj in range(2, nparam + 1):
for ii in range(1, jj):
chain1 = chains[inds, ii-1]
chain1 = chains[inds, ii - 1]
chain1 = chain1.reshape(nsimu, 1)
chain2 = chains[inds, jj-1]
chain2 = chains[inds, jj - 1]
chain2 = chain2.reshape(nsimu, 1)
# plot density on subplot
ax = plt.subplot(nparam-1, nparam-1, (jj-2)*(nparam-1)+ii)
ax = plt.subplot(nparam - 1, nparam - 1, (jj - 2)*(nparam - 1)+ii)
plt.plot(chain1, chain2, **settings['plot'])
# format figure
if jj != nparam: # rm xticks
ax.set_xticklabels([])
if ii != 1: # rm yticks
ax.set_yticklabels([])
if ii == 1: # add ylabels
plt.ylabel(str('{}'.format(names[jj-1])), **settings['ylabel'])
plt.ylabel(str('{}'.format(names[jj - 1])),
**settings['ylabel'])
if ii == jj - 1:
if nparam == 2: # add xlabels
plt.xlabel(str('{}'.format(names[ii-1])),
plt.xlabel(str('{}'.format(names[ii - 1])),
**settings['xlabel'])
else: # add title
plt.title(str('{}'.format(names[ii-1])),
plt.title(str('{}'.format(names[ii - 1])),
**settings['title'])
if settings['add_5095_contours'] is True:
contours = generate_ellipse_plot_points(
Expand All @@ -263,11 +290,15 @@ def plot_pairwise_correlation_panel(chains, names=None, settings=None):
ax = plt.gca()
h, labs = ax.get_legend_handles_labels()
plt.figlegend(h, labs, **settings['legend'])
return f, settings
if return_settings is True:
return f, settings
else:
return f


# --------------------------------------------
def plot_chain_metrics(chain, name=None, settings=None):
def plot_chain_metrics(chain, name=None, settings=None,
return_settings=False):
'''
Plot chain metrics for individual chain
Expand Down Expand Up @@ -311,7 +342,10 @@ def plot_chain_metrics(chain, name=None, settings=None):
plt.xlabel(name, **settings['xlabel'])
plt.ylabel(str('Histogram of {}-chain'.format(name)), **settings['ylabel'])
plt.tight_layout(rect=[0, 0.03, 1, 0.95], h_pad=1.0) # adjust spacing
return f, settings
if return_settings is True:
return f, settings
else:
return f


class Plot:
Expand Down
19 changes: 19 additions & 0 deletions mcmcplot/mcmcplot.py
@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 17 11:21:41 2019
@author: prmiles
"""

from . import mcmatplot as _mcmpl
from . import mcseaborn as _mcsns

plot_density_panel = _mcmpl.plot_density_panel
plot_histogram_panel = _mcmpl.plot_histogram_panel
plot_chain_panel = _mcmpl.plot_chain_panel
plot_pairwise_correlation_panel = _mcmpl.plot_pairwise_correlation_panel
plot_chain_metrics = _mcmpl.plot_chain_metrics

plot_joint_distributions = _mcsns.plot_joint_distributions
plot_paired_density_matrix = _mcsns.plot_paired_density_matrix
18 changes: 13 additions & 5 deletions mcmcplot/mcseaborn.py
Expand Up @@ -9,10 +9,11 @@
import pandas as pd
import seaborn as sns
from .utilities import generate_names, check_settings
from .utilities import setup_subsample


def plot_joint_distributions(chains, names=None, sns_style='white',
settings=None):
settings=None, maxpoints=500, skip=1):
"""
Plot joint distribution for each parameter set.
Expand All @@ -27,6 +28,10 @@ def plot_joint_distributions(chains, names=None, sns_style='white',
Default is `white`.
* **settings** (:py:class:`dict`): Settings for features \
of this method.
* **skip** (:py:class:`int`): Indicates step size to be used when
plotting elements from the chain
* **maxpoints** (:py:class:`int`): Max number of display points
- keeps scatter plot from becoming overcrowded
Returns:
* (:py:class:`tuple`): (figure handle, settings actually \
Expand All @@ -43,12 +48,15 @@ def plot_joint_distributions(chains, names=None, sns_style='white',
sns.set_style(settings['sns_style'], settings['sns'])
nsimu, nparam = chains.shape # number of rows, number of columns
names = generate_names(nparam=nparam, names=names)
inds = range(0, nsimu, settings['skip'])
# setup sample indices
inds = setup_subsample(skip, maxpoints, nsimu)
g = []
for jj in range(2, nparam+1):
for jj in range(2, nparam + 1):
for ii in range(1, jj):
chain1 = pd.Series(chains[inds, ii-1], name=names[ii-1])
chain2 = pd.Series(chains[inds, jj-1], name=names[jj-1])
chain1 = pd.Series(chains[inds, ii - 1],
name=names[ii - 1])
chain2 = pd.Series(chains[inds, jj - 1],
name=names[jj - 1])
# Show the joint distribution using kernel density estimation
a = sns.jointplot(x=chain1, y=chain2, **settings['jointplot'])
g.append(a)
Expand Down
26 changes: 26 additions & 0 deletions mcmcplot/utilities.py
Expand Up @@ -335,3 +335,29 @@ def append_to_nrow_ncol_based_on_shape(sh, nrow, ncol):
nrow.append(sh[0])
ncol.append(sh[1])
return nrow, ncol


def setup_subsample(skip, maxpoints, nsimu):
'''
Setup subsampling from posterior.
When plotting the sampling chain, it is often beneficial to subsample
in order to avoid to dense of plots. This routine determines the
appropriate step size based on the size of the chain (nsimu) and maximum
points allowed to plot (maxpoints). The function checks if the
size of the chain exceeds the maximum number of points allowed in the
plot. If yes, skip is defined such that every the max number of points
are used and sampled evenly from the start to end of the chain. Otherwise
the value of skip is return as defined by the user. A subsample index
is then generated based on the value of skip and the number of simulations.
Args:
* **skip** (:py:class:`int`): User defined skip value.
* **maxpoints** (:py:class:`int`): Maximum points allowed in each plot.
Returns:
* (:py:class:`int`): Skip value.
'''
if nsimu > maxpoints:
skip = int(np.floor(nsimu/maxpoints))
return np.arange(0, nsimu, skip)

0 comments on commit 2ebf2ed

Please sign in to comment.