diff --git a/.readthedocs.yml b/.readthedocs.yml new file mode 100644 index 0000000..f315212 --- /dev/null +++ b/.readthedocs.yml @@ -0,0 +1,2 @@ +python: + version: 3.6 diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 1be377f..aef88e7 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,8 +1,18 @@ Changelog ========= -v1.8.0 (June 28, 2019) +v1.9.0 () --------- +- Added feature to results structure to store all parameters names, regardless of whether or not they are included in sampling. +- Added mcmcplot package to requirements. MCMCPlotting module is noted as deprecated. +- Added new module for uncertainty propagation. Aims to provide more flexible API for user to plot different combinations of credible and prediction intervals. +- Added a plotting routine so that you can plot a 2-D interval in 3-D space. + +v1.8.0 (June 28, 2019) +---------------------- +.. image:: https://zenodo.org/badge/DOI/10.5281/zenodo.3261287.svg + :target: https://doi.org/10.5281/zenodo.3261287 + - Added acceptance rate display feature when calling chain statistics - User can specify `skip` or `maxpoints` in pairwise correlation and chain panel plots in order to thin chain. - User can request item definitions when calling `chainstats`. diff --git a/README.rst b/README.rst index a92b17c..222f589 100644 --- a/README.rst +++ b/README.rst @@ -35,7 +35,7 @@ You can also clone the repository and run ``python setup.py install``. Getting Started =============== -- `Tutorial notebooks `_ +- `Tutorial notebooks `_ - `Documentation `_ - `Release history `_ - `Contributing guidelines `_ diff --git a/doc/source/pymcmcstat.rst b/doc/source/pymcmcstat.rst index 3d3b5e8..057a965 100644 --- a/doc/source/pymcmcstat.rst +++ b/doc/source/pymcmcstat.rst @@ -16,6 +16,14 @@ pymcmcstat.ParallelMCMC module :members: :undoc-members: :show-inheritance: + +pymcmcstat.propagation module +------------------------------ + +.. automodule:: pymcmcstat.propagation + :members: + :undoc-members: + :show-inheritance: Subpackages ----------- diff --git a/pymcmcstat/__init__.py b/pymcmcstat/__init__.py index 29654ee..78d7973 100644 --- a/pymcmcstat/__init__.py +++ b/pymcmcstat/__init__.py @@ -1 +1,3 @@ -__version__ = "1.8.0" +__version__ = "1.9.0" + +from mcmcplot import mcmcplot # noqa, flake8 issue diff --git a/pymcmcstat/plotting/MCMCPlotting.py b/pymcmcstat/plotting/MCMCPlotting.py index 983d811..87d86b9 100644 --- a/pymcmcstat/plotting/MCMCPlotting.py +++ b/pymcmcstat/plotting/MCMCPlotting.py @@ -13,6 +13,8 @@ from .utilities import generate_names, setup_plot_features, make_x_grid from .utilities import setup_subsample import warnings +from deprecated import deprecated + try: from statsmodels.nonparametric.kernel_density import KDEMultivariate @@ -21,7 +23,13 @@ - plot_density_panel will not work. 
{}".format(e))) +def deprecation(message): + warnings.warn(message, DeprecationWarning) + + # -------------------------------------------- +@deprecated(version='1.9.0', + reason='New function: "from pymcmcstat.mcmcplot import plot_density_panel"') def plot_density_panel(chains, names=None, hist_on=False, figsizeinches=None, return_kde=False): ''' @@ -33,6 +41,7 @@ def plot_density_panel(chains, names=None, hist_on=False, figsizeinches=None, * **hist_on** (:py:class:`bool`): Flag to include histogram on density plot * **figsizeinches** (:py:class:`list`): Specify figure size in inches [Width, Height] ''' + deprecation('Recommend using pymcmcstat.mcmcplot.plot_density_panel') nsimu, nparam = chains.shape # number of rows, number of columns ns1, ns2, names, figsizeinches = setup_plot_features( nparam=nparam, names=names, figsizeinches=figsizeinches) @@ -62,6 +71,8 @@ def plot_density_panel(chains, names=None, hist_on=False, figsizeinches=None, # -------------------------------------------- +@deprecated(version='1.9.0', + reason='New function: "from pymcmcstat.mcmcplot import plot_histogram_panel"') def plot_histogram_panel(chains, names=None, figsizeinches=None): """ Plot histogram from each parameter's sampling history @@ -90,6 +101,8 @@ def plot_histogram_panel(chains, names=None, figsizeinches=None): # -------------------------------------------- +@deprecated(version='1.9.0', + reason='New function:"from pymcmcstat.mcmcplot import plot_chain_panel"') def plot_chain_panel(chains, names=None, figsizeinches=None, skip=1, maxpoints=500): """ @@ -128,6 +141,8 @@ def plot_chain_panel(chains, names=None, figsizeinches=None, # -------------------------------------------- +@deprecated(version='1.9.0', + reason='New function: "from pymcmcstat.mcmcplot import plot_pairwise_correlation_panel"') def plot_pairwise_correlation_panel(chains, names=None, figsizeinches=None, skip=1, maxpoints=500): """ @@ -176,6 +191,8 @@ def plot_pairwise_correlation_panel(chains, names=None, figsizeinches=None, # -------------------------------------------- +@deprecated(version='1.9.0', + reason='New function: "from pymcmcstat.mcmcplot import plot_chain_metrics"') def plot_chain_metrics(chain, name=None, figsizeinches=None): ''' Plot chain metrics for individual chain diff --git a/pymcmcstat/propagation.py b/pymcmcstat/propagation.py new file mode 100644 index 0000000..7dfeb24 --- /dev/null +++ b/pymcmcstat/propagation.py @@ -0,0 +1,674 @@ +#!/usr/bin/env python2 +# -*- coding: utf-8 -*- +""" +Created on Wed Nov 8 12:00:11 2017 + +@author: prmiles +""" + +import numpy as np +import sys +from .utilities.progressbar import progress_bar +from .utilities.general import check_settings +import matplotlib.pyplot as plt +from matplotlib import cm +from matplotlib import colors as mplcolor +from mpl_toolkits.mplot3d import Axes3D +from mpl_toolkits.mplot3d.art3d import Poly3DCollection +from scipy.interpolate import interp1d + + +def calculate_intervals(chain, results, data, model, s2chain=None, + nsample=500, waitbar=True, sstype=0): + ''' + Calculate distribution of model response to form propagation intervals + + Samples values from chain, performs forward model evaluation, and + tabulates credible and prediction intervals (if obs. error var. included). + + Args: + * **chain** (:class:`~numpy.ndarray`): Parameter chains, expect + shape=(nsimu, npar). + * **results** (:py:class:`dict`): Results dictionary generated by + pymcmcstat. + * **data** (:class:`~.DataStructure`): Data + * **model**: User defined function. 
Note, if your model outputs + multiple quantities of interest (QoI) at the same time in a + multi-dimensional array, then make sure it is returned as a + (N, p) array where N is the number of evaluation points and + p is the number of QoI. + + Kwargs: + * **s2chain** (:py:class:`float`, :class:`~numpy.ndarray`, or None): + Observation error variance chain. + * **nsample** (:py:class:`int`): No. of samples drawn from posteriors. + * **waitbar** (:py:class:`bool`): Flag to display progress bar. + * **sstype** (:py:class:`int`): Sum-of-squares type. Can be 0 (normal), + 1 (sqrt), or 2 (log). + + Returns: + * :py:class:`dict` with two elements: 1) `credible` and 2) `prediction` + ''' + parind = results['parind'] + q = results['theta'] + nsimu, npar = chain.shape + s2chain = check_s2chain(s2chain, nsimu) + iisample, nsample = define_sample_points(nsample, nsimu) + if waitbar is True: + __wbarstatus = progress_bar(iters=int(nsample)) + + ci = [] + pi = [] + multiple = False + for kk, isa in enumerate(iisample): + # progress bar + if waitbar is True: + __wbarstatus.update(kk) + # extract chain set + q[parind] = chain[kk, :] + # evaluate model + y = model(q, data) + # check model output + if y.ndim == 2: + nrow, ncol = y.shape + if nrow != y.size and ncol != y.size: + multiple = True + if multiple is False: + # store model prediction in credible intervals + ci.append(y.reshape(y.size,)) # store model output + if s2chain is None: + continue + else: + # estimate prediction intervals + s2 = s2chain[kk] + obs = observation_sample(s2, y, sstype) + pi.append(obs.reshape(obs.size,)) + else: + # Model output contains multiple QoI + # Expect ncol = No. of QoI + if kk == 0: + cis = [] + pis = [] + for jj in range(ncol): + cis.append([]) + pis.append([]) + for jj in range(ncol): + # store model prediction in credible intervals + cis[jj].append(y[:, jj]) # store model output + if s2chain is None: + continue + else: + # estimate prediction intervals + if s2chain.ndim == 2: + if s2chain.shape[1] == ncol: + s2 = s2chain[kk, jj] + else: + s2 = s2chain[kk] + else: + s2 = s2chain[kk] + obs = observation_sample(s2, y[:, jj], sstype) + pis[jj].append(obs.reshape(obs.size,)) + + if multiple is False: + # Setup output + credible = np.array(ci) + if s2chain is None: + prediction = None + else: + prediction = np.array(pi) + return dict(credible=credible, + prediction=prediction) + else: + # Setup output for multiple QoI + out = [] + for jj in range(ncol): + credible = np.array(cis[jj]) + if s2chain is None: + prediction = None + else: + prediction = np.array(pis[jj]) + out.append(dict(credible=credible, + prediction=prediction)) + return out + + +# -------------------------------------------- +def plot_intervals(intervals, time, ydata=None, xdata=None, + limits=[95], + adddata=None, addmodel=True, addlegend=True, + addcredible=True, addprediction=True, + data_display={}, model_display={}, interval_display={}, + fig=None, figsize=None, legloc='upper left', + ciset=None, piset=None, + return_settings=False): + ''' + Plot propagation intervals in 2-D + + This routine takes the model distributions generated using the + :func:`~calculate_intervals` method and then plots specific + quantiles. The user can plot just the intervals, or also include the + median model response and/or observations. Specific settings for + credible intervals are controlled by defining the `ciset` dictionary. + Likewise, for prediction intervals, settings are defined using `piset`. 
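For reference, a minimal usage sketch of the new propagation API, assuming a simple linear model; the synthetic `chain`, `s2chain`, and `results` dictionary below are illustrative stand-ins for the output of an actual pymcmcstat simulation::

    import numpy as np
    import matplotlib.pyplot as plt
    from pymcmcstat.MCMC import DataStructure
    from pymcmcstat.propagation import calculate_intervals, plot_intervals

    def model(q, data):
        # simple linear response; return an (N, p) array instead if the
        # model produces p quantities of interest at once
        m, b = q
        return m*data.xdata[0] + b

    data = DataStructure()
    data.add_data_set(x=np.linspace(0, 1, 100), y=None)
    # stand-ins for the chain, error-variance chain, and results of a real run
    nsimu = 5000
    chain = np.array([2.0, 1.0]) + 0.05*np.random.standard_normal((nsimu, 2))
    s2chain = 0.01*np.ones((nsimu,))
    results = dict(parind=np.array([0, 1]), theta=np.array([2.0, 1.0]))

    intervals = calculate_intervals(chain, results, data, model,
                                    s2chain=s2chain, nsample=500, waitbar=False)
    fig, ax = plot_intervals(intervals, data.xdata[0], limits=[50, 90, 95],
                             ciset=dict(colors=['0.2', '0.5', '0.8']))
    plt.show()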
+ + The setting options available for each interval are as follows: + - `limits`: This should be a list of numbers between 0 and 100, e.g., + `limits=[50, 90]` will result in 50% and 90% intervals. + - `cmap`: The program is designed to "try" to choose colors that + are visually distinct. The user can specify the colormap to choose + from. + - `colors`: The user can specify the color they would like for each + interval in a list, e.g., ['r', 'g', 'b']. This list should have + the same number of elements as `limits` or the code will revert + back to its default behavior. + + Args: + * **intervals** (:py:class:`dict`): Interval dictionary generated + using :meth:`calculate_intervals` method. + * **time** (:class:`~numpy.ndarray`): Independent variable, i.e., + x-axis of plot + + Kwargs: + * **ydata** (:class:`~numpy.ndarray` or None): Observations, expect + 1-D array if defined. + * **xdata** (:class:`~numpy.ndarray` or None): Independent values + corresponding to observations. This is required if the observations + do not align with your times of generating the model response. + * **limits** (:py:class:`list`): Quantile limits that correspond to + percentage size of desired intervals. Note, this is the default + limits, but specific limits can be defined using the `ciset` and + `piset` dictionaries. + * **adddata** (:py:class:`bool`): Flag to include data + * **addmodel** (:py:class:`bool`): Flag to include median model + response + * **addlegend** (:py:class:`bool`): Flag to include legend + * **addcredible** (:py:class:`bool`): Flag to include credible + intervals + * **addprediction** (:py:class:`bool`): Flag to include prediction + intervals + * **model_display** (:py:class:`dict`): Display settings for median + model response + * **data_display** (:py:class:`dict`): Display settings for data + * **interval_display** (:py:class:`dict`): General display settings + for intervals. + * **fig**: Handle of previously created figure object + * **figsize** (:py:class:`tuple`): (width, height) in inches + * **legloc** (:py:class:`str`): Legend location - matplotlib help for + details. + * **ciset** (:py:class:`dict`): Settings for credible intervals + * **piset** (:py:class:`dict`): Settings for prediction intervals + * **return_settings** (:py:class:`bool`): Flag to return ciset and + piset along with fig and ax. 
+ + Returns: + * (:py:class:`tuple`) with elements + 1) Figure handle + 2) Axes handle + 3) Dictionary with `ciset` and `piset` inside (only + outputted if `return_settings=True`) + ''' + # unpack dictionary + credible = intervals['credible'] + prediction = intervals['prediction'] + # Check user-defined settings + ciset = __setup_iset(ciset, + default_iset=dict( + limits=limits, + cmap=None, + colors=None)) + piset = __setup_iset(piset, + default_iset=dict( + limits=limits, + cmap=None, + colors=None)) + # Check limits + ciset['limits'] = _check_limits(ciset['limits'], limits) + piset['limits'] = _check_limits(piset['limits'], limits) + # convert limits to ranges + ciset['quantiles'] = _convert_limits(ciset['limits']) + piset['quantiles'] = _convert_limits(piset['limits']) + # setup display settings + interval_display, model_display, data_display = setup_display_settings( + interval_display, model_display, data_display) + # Define colors + ciset['colors'] = setup_interval_colors(ciset, inttype='ci') + piset['colors'] = setup_interval_colors(piset, inttype='pi') + # Define labels + ciset['labels'] = _setup_labels(ciset['limits'], inttype='CI') + piset['labels'] = _setup_labels(piset['limits'], inttype='PI') + if fig is None: + fig = plt.figure(figsize=figsize) + ax = fig.gca() + time = time.reshape(time.size,) + # add prediction intervals + if addprediction is True: + for ii, quantile in enumerate(piset['quantiles']): + pi = generate_quantiles(prediction, np.array(quantile)) + ax.fill_between(time, pi[0], pi[1], facecolor=piset['colors'][ii], + label=piset['labels'][ii], **interval_display) + # add credible intervals + if addcredible is True: + for ii, quantile in enumerate(ciset['quantiles']): + ci = generate_quantiles(credible, np.array(quantile)) + ax.fill_between(time, ci[0], ci[1], facecolor=ciset['colors'][ii], + label=ciset['labels'][ii], **interval_display) + # add model (median model response) + if addmodel is True: + ci = generate_quantiles(credible, np.array(0.5)) + ax.plot(time, ci, **model_display) + # add data to plot + if ydata is not None and adddata is None: + adddata = True + if adddata is True and ydata is not None: + if xdata is None: + ax.plot(time, ydata, **data_display) + else: + ax.plot(xdata, ydata, **data_display) + # add legend + if addlegend is True: + handles, labels = ax.get_legend_handles_labels() + ax.legend(handles, labels, loc=legloc) + if return_settings is True: + return fig, ax, dict(ciset=ciset, piset=piset) + else: + return fig, ax + + +# -------------------------------------------- +def plot_3d_intervals(intervals, time, ydata=None, xdata=None, + limits=[95], + adddata=False, addlegend=True, + addmodel=True, figsize=None, model_display={}, + data_display={}, interval_display={}, + addcredible=True, addprediction=True, + fig=None, legloc='upper left', + ciset=None, piset=None, + return_settings=False): + ''' + Plot propagation intervals in 3-D + + This routine takes the model distributions generated using the + :func:`~calculate_intervals` method and then plots specific + quantiles. The user can plot just the intervals, or also include the + median model response and/or observations. Specific settings for + credible intervals are controlled by defining the `ciset` dictionary. + Likewise, for prediction intervals, settings are defined using `piset`. + + The setting options available for each interval are as follows: + - `limits`: This should be a list of numbers between 0 and 100, e.g., + `limits=[50, 90]` will result in 50% and 90% intervals. 
+ - `cmap`: The program is designed to "try" to choose colors that + are visually distinct. The user can specify the colormap to choose + from. + - `colors`: The user can specify the color they would like for each + interval in a list, e.g., ['r', 'g', 'b']. This list should have + the same number of elements as `limits` or the code will revert + back to its default behavior. + + Args: + * **intervals** (:py:class:`dict`): Interval dictionary generated + using :meth:`calculate_intervals` method. + * **time** (:class:`~numpy.ndarray`): Independent variable, i.e., + x- and y-axes of plot. Note, it must be a 2-D array with + shape=(N, 2), where N is the number of evaluation points. + + Kwargs: + * **ydata** (:class:`~numpy.ndarray` or None): Observations, expect + 1-D array if defined. + * **xdata** (:class:`~numpy.ndarray` or None): Independent values + corresponding to observations. This is required if the observations + do not align with your times of generating the model response. + * **limits** (:py:class:`list`): Quantile limits that correspond to + percentage size of desired intervals. Note, this is the default + limits, but specific limits can be defined using the `ciset` and + `piset` dictionaries. + * **adddata** (:py:class:`bool`): Flag to include data + * **addmodel** (:py:class:`bool`): Flag to include median model + response + * **addlegend** (:py:class:`bool`): Flag to include legend + * **addcredible** (:py:class:`bool`): Flag to include credible + intervals + * **addprediction** (:py:class:`bool`): Flag to include prediction + intervals + * **model_display** (:py:class:`dict`): Display settings for median + model response + * **data_display** (:py:class:`dict`): Display settings for data + * **interval_display** (:py:class:`dict`): General display settings + for intervals. + * **fig**: Handle of previously created figure object + * **figsize** (:py:class:`tuple`): (width, height) in inches + * **legloc** (:py:class:`str`): Legend location - matplotlib help for + details. + * **ciset** (:py:class:`dict`): Settings for credible intervals + * **piset** (:py:class:`dict`): Settings for prediction intervals + * **return_settings** (:py:class:`bool`): Flag to return ciset and + piset along with fig and ax. 
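A corresponding sketch for the 3-D routine, again with illustrative stand-ins; the key difference is that the independent variable passed to `plot_3d_intervals` must be an (N, 2) array::

    import numpy as np
    from pymcmcstat.MCMC import DataStructure
    from pymcmcstat.propagation import calculate_intervals, plot_3d_intervals

    def model3d(q, data):
        # response over two independent variables stored column-wise in xdata
        m, b = q
        x = data.xdata[0]
        return m*x[:, 0] + b*x[:, 1]

    data3d = DataStructure()
    data3d.add_data_set(x=np.random.random_sample((100, 2)), y=None)
    nsimu = 5000
    chain = np.array([2.0, 1.0]) + 0.05*np.random.standard_normal((nsimu, 2))
    results = dict(parind=np.array([0, 1]), theta=np.array([2.0, 1.0]))
    intervals3d = calculate_intervals(chain, results, data3d, model3d,
                                      s2chain=0.01, waitbar=False)
    fig, ax = plot_3d_intervals(intervals3d, data3d.xdata[0], limits=[90])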
+ + Returns: + * (:py:class:`tuple`) with elements + 1) Figure handle + 2) Axes handle + 3) Dictionary with `ciset` and `piset` inside (only + outputted if `return_settings=True`) + ''' + # unpack dictionary + credible = intervals['credible'] + prediction = intervals['prediction'] + # Check user-defined settings + ciset = __setup_iset(ciset, + default_iset=dict( + limits=limits, + cmap=None, + colors=None)) + piset = __setup_iset(piset, + default_iset=dict( + limits=limits, + cmap=None, + colors=None)) + # Check limits + ciset['limits'] = _check_limits(ciset['limits'], limits) + piset['limits'] = _check_limits(piset['limits'], limits) + # convert limits to ranges + ciset['quantiles'] = _convert_limits(ciset['limits']) + piset['quantiles'] = _convert_limits(piset['limits']) + # setup display settings + interval_display, model_display, data_display = setup_display_settings( + interval_display, model_display, data_display) + # Define colors + ciset['colors'] = setup_interval_colors(ciset, inttype='ci') + piset['colors'] = setup_interval_colors(piset, inttype='pi') + # Define labels + ciset['labels'] = _setup_labels(ciset['limits'], inttype='CI') + piset['labels'] = _setup_labels(piset['limits'], inttype='PI') + if fig is None: + fig = plt.figure(figsize=figsize) + ax = Axes3D(fig) + ax = fig.gca() + time1 = time[:, 0] + time2 = time[:, 1] + # add prediction intervals + if addprediction is True: + for ii, quantile in enumerate(piset['quantiles']): + pi = generate_quantiles(prediction, np.array(quantile)) + # Add a polygon instead of fill_between + rev = np.arange(time1.size - 1, -1, -1) + x = np.concatenate((time1, time1[rev])) + y = np.concatenate((time2, time2[rev])) + z = np.concatenate((pi[0], pi[1][rev])) + verts = [list(zip(x, y, z))] + surf = Poly3DCollection(verts, + color=piset['colors'][ii], + label=piset['labels'][ii]) + # Add fix for legend compatibility + surf._facecolors2d = surf._facecolors3d + surf._edgecolors2d = surf._edgecolors3d + ax.add_collection3d(surf) + # add credible intervals + if addcredible is True: + for ii, quantile in enumerate(ciset['quantiles']): + ci = generate_quantiles(credible, np.array(quantile)) + # Add a polygon instead of fill_between + rev = np.arange(time1.size - 1, -1, -1) + x = np.concatenate((time1, time1[rev])) + y = np.concatenate((time2, time2[rev])) + z = np.concatenate((ci[0], ci[1][rev])) + verts = [list(zip(x, y, z))] + surf = Poly3DCollection(verts, + color=ciset['colors'][ii], + label=ciset['labels'][ii]) + # Add fix for legend compatibility + surf._facecolors2d = surf._facecolors3d + surf._edgecolors2d = surf._edgecolors3d + ax.add_collection3d(surf) + # add model (median model response) + if addmodel is True: + ci = generate_quantiles(credible, np.array(0.5)) + ax.plot(time1, time2, ci, **model_display) + # add data to plot + if ydata is not None and adddata is None: + adddata = True + if adddata is True: + if xdata is None: + ax.plot(time1, time2, ydata.reshape(time1.shape), **data_display) + else: # User provided xdata array for observation points + ax.plot(xdata[:, 0], xdata[:, 1], + ydata.reshape(time1.shape), **data_display) + # add legend + if addlegend is True: + handles, labels = ax.get_legend_handles_labels() + ax.legend(handles, labels, loc=legloc) + if return_settings is True: + return fig, ax, dict(ciset=ciset, piset=piset) + else: + return fig, ax + + +def check_s2chain(s2chain, nsimu): + ''' + Check size of s2chain + + Args: + * **s2chain** (:py:class:`float`, :class:`~numpy.ndarray`, or `None`): + Observation error 
variance chain or value + * **nsimu** (:py:class:`int`): No. of elements in chain + + Returns: + * **s2chain** (:class:`~numpy.ndarray` or `None`) + ''' + if s2chain is None: + return None + else: + if isinstance(s2chain, float): + s2chain = np.ones((nsimu,))*s2chain + if s2chain.ndim == 2: + if s2chain.shape[0] != nsimu: + s2chain = s2chain * np.ones((nsimu, s2chain.size)) + else: + if s2chain.size != nsimu: # scalars provided for multiple QoI + s2chain = s2chain * np.ones((nsimu, s2chain.size)) + return s2chain + + +# -------------------------------------------- +def observation_sample(s2, y, sstype): + ''' + Calculate model response with observation errors. + + Args: + * **s2** (:class:`~numpy.ndarray`): Observation error(s). + * **y** (:class:`~numpy.ndarray`): Model responses. + * **sstype** (:py:class:`int`): Flag to specify sstype. + + Returns: + * **opred** (:class:`~numpy.ndarray`): Model responses with observation errors. + ''' + if sstype == 0: + opred = y + np.random.standard_normal(y.shape) * np.sqrt(s2) + elif sstype == 1: # sqrt + opred = (np.sqrt(y) + np.random.standard_normal(y.shape) * np.sqrt(s2))**2 + elif sstype == 2: # log + opred = y*np.exp(np.random.standard_normal(y.shape) * np.sqrt(s2)) + else: + sys.exit('Unknown sstype') + return opred + + +# -------------------------------------------- +def define_sample_points(nsample, nsimu): + ''' + Define indices to sample from posteriors. + + Args: + * **nsample** (:py:class:`int`): Number of samples to draw from posterior. + * **nsimu** (:py:class:`int`): Number of MCMC simulations. + + Returns: + * **iisample** (:class:`~numpy.ndarray`): Array of indices in posterior set. + * **nsample** (:py:class:`int`): Number of samples to draw from posterior. + ''' + # define sample points + if nsample >= nsimu: + iisample = range(nsimu) # sample all points from chain + nsample = nsimu + else: + # randomly sample from chain + iisample = np.ceil(np.random.rand(nsample)*nsimu) - 1 + iisample = iisample.astype(int) + return iisample, nsample + + +# -------------------------------------------- +def generate_quantiles(x, p=np.array([0.25, 0.5, 0.75])): + ''' + Calculate empirical quantiles. + + Args: + * **x** (:class:`~numpy.ndarray`): Observations from which to generate quantile. + * **p** (:class:`~numpy.ndarray`): Quantile limits. + + Returns: + * (:class:`~numpy.ndarray`): Interpolated quantiles. + ''' + # extract number of rows/cols from np.array + n = x.shape[0] + # define vector valued interpolation function + xpoints = np.arange(0, n, 1) + interpfun = interp1d(xpoints, np.sort(x, 0), axis=0) + # evaluation points + itpoints = (n - 1)*p + return interpfun(itpoints) + + +def setup_display_settings(interval_display, model_display, data_display): + ''' + Compare user defined display settings with defaults and merge. + + Args: + * **interval_display** (:py:class:`dict`): User defined settings for interval display. + * **model_display** (:py:class:`dict`): User defined settings for model display. + * **data_display** (:py:class:`dict`): User defined settings for data display. + + Returns: + * **interval_display** (:py:class:`dict`): Settings for interval display. + * **model_display** (:py:class:`dict`): Settings for model display. + * **data_display** (:py:class:`dict`): Settings for data display. 
+ ''' + # Setup interval display + default_interval_display = dict( + linestyle=':', + linewidth=1, + alpha=1.0, + edgecolor='k') + interval_display = check_settings(default_interval_display, interval_display) + # Setup model display + default_model_display = dict( + linestyle='-', + color='k', + marker='', + linewidth=2, + markersize=5, + label='Model') + model_display = check_settings(default_model_display, model_display) + # Setup data display + default_data_display = dict( + linestyle='', + color='b', + marker='.', + linewidth=1, + markersize=5, + label='Data') + data_display = check_settings(default_data_display, data_display) + return interval_display, model_display, data_display + + +def setup_interval_colors(iset, inttype='CI'): + ''' + Setup colors for empirical intervals + + This routine attempts to distribute the color of the UQ intervals + based on a normalize color map. Or, it will assign user-defined + colors; however, this only happens if the correct number of colors + are specified. + + Args: + * **iset** (:py:class:`dict`): This dictionary should contain the + following keys - `limits`, `cmap`, and `colors`. + + Kwargs: + * **inttype** (:py:class:`str`): Type of uncertainty interval + + Returns: + * **ic** (:py:class:`list`): List containing color for each interval + ''' + limits, cmap, colors = iset['limits'], iset['cmap'], iset['colors'] + norm = __setup_cmap_norm(limits) + cmap = __setup_default_cmap(cmap, inttype) + # assign colors using color map or using colors defined by user + ic = [] + if colors is None: # No user defined colors + for limits in limits: + ic.append(cmap(norm(limits))) + else: + if len(colors) == len(limits): # correct number of colors defined + for color in colors: + ic.append(color) + else: # User defined the wrong number of colors + print('Note, user-defined colors were ignored. Using color map. ' + + 'Expected a list of length {}, but received {}'.format( + len(limits), len(colors))) + for limits in limits: + ic.append(cmap(norm(limits))) + return ic + + +# -------------------------------------------- +def _setup_labels(limits, inttype='CI'): + ''' + Setup labels for prediction/credible intervals. 
+ ''' + labels = [] + for limit in limits: + labels.append(str('{}% {}'.format(limit, inttype))) + return labels + + +def _check_limits(limits, default_limits): + if limits is None: + limits = default_limits + limits.sort(reverse=True) + return limits + + +def _convert_limits(limits): + rng = [] + for limit in limits: + limit = limit/100 + rng.append([0.5 - limit/2, 0.5 + limit/2]) + return rng + + +def __setup_iset(iset, default_iset): + ''' + Setup interval settings by comparing user input to default + ''' + if iset is None: + iset = {} + iset = check_settings(default_iset, iset) + return iset + + +def __setup_cmap_norm(limits): + if len(limits) == 1: + norm = mplcolor.Normalize(vmin=0, vmax=100) + else: + norm = mplcolor.Normalize(vmin=min(limits), vmax=max(limits)) + return norm + + +def __setup_default_cmap(cmap, inttype): + if cmap is None: + if inttype.upper() == 'CI': + cmap = cm.autumn + else: + cmap = cm.winter + return cmap diff --git a/pymcmcstat/structures/ResultsStructure.py b/pymcmcstat/structures/ResultsStructure.py index d519bb2..7c1ae9e 100644 --- a/pymcmcstat/structures/ResultsStructure.py +++ b/pymcmcstat/structures/ResultsStructure.py @@ -161,6 +161,7 @@ def add_basic(self, nsimu, covariance, parameters, rejected, simutime, theta): self.results['qcov_scale'] = covariance._qcov_scale self.results['mean'] = covariance._meanchain self.results['names'] = [parameters._names[ii] for ii in parameters._parind] + self.results['allnames'] = [name for name in parameters._names] self.results['limits'] = [parameters._lower_limits[parameters._parind[:]], parameters._upper_limits[parameters._parind[:]]] self.results['nsimu'] = nsimu diff --git a/pymcmcstat/utilities/general.py b/pymcmcstat/utilities/general.py index 4bc386e..91a6c42 100644 --- a/pymcmcstat/utilities/general.py +++ b/pymcmcstat/utilities/general.py @@ -39,3 +39,37 @@ def removekey(d, key): r = dict(d) del r[key] return r + + +def check_settings(default_settings, user_settings=None): + ''' + Check user settings with default. + + Recursively checks elements of user settings against the defaults and updates settings + as it goes. If a user setting does not exist in the default, then the user setting + is added to the settings. If the setting is defined in both the user and default + settings, then the user setting overrides the default. Otherwise, the default + settings persist. + + Args: + * **default_settings** (:py:class:`dict`): Default settings for particular method. + * **user_settings** (:py:class:`dict`): User defined settings. + + Returns: + * (:py:class:`dict`): Updated settings. 
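For example, merging a partial user dictionary into the defaults behaves as follows (illustrative values)::

    from pymcmcstat.utilities.general import check_settings

    defaults = dict(linewidth=3, marker=dict(markersize=5, color='g'))
    user = dict(marker=dict(color='b'), label='Data')
    check_settings(defaults, user)
    # -> {'linewidth': 3, 'marker': {'markersize': 5, 'color': 'b'}, 'label': 'Data'}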
+ ''' + settings = default_settings.copy() # initially define settings as default + options = list(default_settings.keys()) # get default settings + if user_settings is None: # convert to empty dict + user_settings = {} + user_options = list(user_settings.keys()) # get user settings + for uo in user_options: # iterate through settings + if uo in options: + # check if checking a dictionary + if isinstance(settings[uo], dict): + settings[uo] = check_settings(settings[uo], user_settings[uo]) + else: + settings[uo] = user_settings[uo] + if uo not in options: + settings[uo] = user_settings[uo] + return settings diff --git a/requirements.txt b/requirements.txt index d15cc24..868fd3d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,9 @@ sphinxcontrib-bibtex sphinx_rtd_theme -matplotlib>=2.2.0 +mcmcplot>=1.0.1 coveralls h5py>=2.7.0 statsmodels>=0.9.0 numpy>=1.14 scipy>=1.0 +deprecated>=1.2.6 \ No newline at end of file diff --git a/setup.py b/setup.py index 0867da8..19359ba 100644 --- a/setup.py +++ b/setup.py @@ -38,8 +38,9 @@ def get_version(): package_dir={'pymcmcstat': 'pymcmcstat'}, packages=find_packages(), zip_safe=False, - install_requires=['numpy>=1.14', 'scipy>=1.0', 'matplotlib>=2.2.0', 'h5py>=2.7.0', 'statsmodels>=0.9.0'], - extras_require = {'docs':['sphinx'], 'plotting':['matplotlib', 'plotly'],}, + install_requires=['numpy>=1.14', 'scipy>=1.0', 'mcmcplot>=1.0.1', + 'h5py>=2.7.0', 'statsmodels>=0.9.0', 'deprecated>=1.2.6'], + extras_require = {'docs':['sphinx'], 'plotting':['matplotlib', 'seaborn'],}, classifiers=['License :: OSI Approved :: MIT License', 'Natural Language :: English', 'Operating System :: MacOS :: MacOS X', diff --git a/test/structures/test_ResultsStructure.py b/test/structures/test_ResultsStructure.py index 4101b55..29f1a1d 100644 --- a/test/structures/test_ResultsStructure.py +++ b/test/structures/test_ResultsStructure.py @@ -75,6 +75,25 @@ def test_addbasic_covariance(self): self.assertTrue(np.array_equal(RS.results['R'], covariance._R), msg = 'Cholesky matches') self.assertTrue(np.array_equal(RS.results['qcov'], np.dot(covariance._R.transpose(),covariance._R)), msg = 'Covariance matches') + def test_allnames(self): + model, options, parameters, data, covariance, rejected, chain, s2chain, sschain = gf.setup_mcmc_case_dr() + RS = ResultsStructure() + RS.add_basic(nsimu=options.nsimu, + covariance=covariance, + parameters=parameters, + rejected=rejected, + simutime=0.001, + theta=chain[-1,:]) + allnames = RS.results['allnames'] + names = RS.results['names'] + print(allnames) + print(names) + self.assertEqual(len(allnames), len(names) + 1, + msg='Expect extra name in allnames') + self.assertEqual(allnames[-1], 'b2', + msg='Expect final parameter is b2') + + # ------------------- class AddDRAM(unittest.TestCase): def test_addbasic_false(self): diff --git a/test/test_init.py b/test/test_init.py index 431dced..0306e8e 100644 --- a/test/test_init.py +++ b/test/test_init.py @@ -2,7 +2,7 @@ import unittest import pymcmcstat - +import re class ImportPymcmcstat(unittest.TestCase): @@ -10,5 +10,16 @@ def test_version_attribute(self): version = pymcmcstat.__version__ self.assertTrue(isinstance(version, str), msg='Expect string output') - self.assertEqual(len(version.split('.')), 3, - msg='Expect #.#.# format') + pattern = '\d+\.\d+\.\d.*' + mo = re.search(pattern, version, re.M) + divs = mo.group().split('.') + ndot = len(divs) + self.assertGreaterEqual(ndot, 3, + msg='Expect min of two dots to match symver.') + self.assertTrue(isinstance(float(divs[0]), 
float), + msg='Major version should be float') + self.assertTrue(isinstance(float(divs[1]), float), + msg='Minor version should be float') + self.assertTrue(isinstance(float(divs[2][0]), float), + msg='First element of bug version should be float') + diff --git a/test/test_propagation.py b/test/test_propagation.py new file mode 100644 index 0000000..6a6d35f --- /dev/null +++ b/test/test_propagation.py @@ -0,0 +1,417 @@ +import unittest +from mock import patch +import pymcmcstat.propagation as uqp +from pymcmcstat.MCMC import DataStructure +import general_functions as gf +import numpy as np +import matplotlib.pyplot as plt + +# -------------------------- +class CheckLimits(unittest.TestCase): + + def test_check(self): + limits = uqp._check_limits(None, [50, 90]) + self.assertEqual(limits, [90, 50], + msg='Expect default return') + limits = uqp._check_limits([75, 95], [50, 90]) + self.assertEqual(limits, [95, 75], + msg='Expect non-default return') + + +# -------------------------- +class ConvertLimits(unittest.TestCase): + + def test_conversion(self): + limits = uqp._convert_limits([90, 50]) + rng = [] + rng.append([0.05, 0.95]) + rng.append([0.25, 0.75]) + self.assertTrue(np.allclose(np.array(limits), np.array(rng)), + msg='Expect matching lists') + limits = uqp._convert_limits([90, 50]) + rng = [] + rng.append([0.05, 0.95]) + rng.append([0.25, 0.75]) + self.assertTrue(np.allclose(np.array(limits), np.array(rng)), + msg='Expect matching lists') + + +# -------------------------------------------- +class DefineSamplePoints(unittest.TestCase): + + def test_define_sample_points_nsample_gt_nsimu(self): + iisample, nsample = uqp.define_sample_points(nsample=1000, + nsimu=500) + self.assertEqual(iisample, range(500), + msg='Expect range(500)') + self.assertEqual(nsample, 500, + msg='Expect nsample updated to 500') + + @patch('numpy.random.rand') + def test_define_sample_points_nsample_lte_nsimu(self, mock_rand): + aa = np.random.rand([400, 1]) + mock_rand.return_value = aa + iisample, nsample = uqp.define_sample_points(nsample=400, + nsimu=500) + self.assertTrue(np.array_equal(iisample, np.ceil(aa*500) - 1), + msg='Expect range(500)') + self.assertEqual(nsample, 400, + msg='Expect nsample to stay 400') + + +# -------------------------------------------- +class Observation_Sample_Test(unittest.TestCase): + + def test_sstype(self): + s2elem = np.array([[2.0]]) + ypred = np.linspace(2.0, 3.0, num=5) + ypred = ypred.reshape(5, 1) + with self.assertRaises(SystemExit, msg='Unrecognized sstype'): + uqp.observation_sample(s2elem, ypred, 3) + opred = uqp.observation_sample(s2elem, ypred, 0) + self.assertEqual(opred.shape, ypred.shape, + msg='Shapes should match') + opred = uqp.observation_sample(s2elem, ypred, 1) + self.assertEqual(opred.shape, ypred.shape, + msg='Shapes should match') + opred = uqp.observation_sample(s2elem, ypred, 2) + self.assertEqual(opred.shape, ypred.shape, + msg='Shapes should match') + + +# -------------------------------------------- +class CheckS2Chain(unittest.TestCase): + + def test_checks2chain(self): + s2elem = np.array([2.0, 10.]) + s2chain = uqp.check_s2chain(s2elem, 5) + self.assertEqual(s2chain.shape, (5, 2), + msg='Expect (5, 2) array') + s2elem = None + s2chain = uqp.check_s2chain(s2elem, 5) + self.assertEqual(s2chain, None, msg='Expect None return') + s2elem = np.zeros(shape=(5,)) + s2chain = uqp.check_s2chain(s2elem, 5) + self.assertTrue(np.allclose(s2chain, s2elem), + msg='Expect as-is return') + s2elem = 0.01 + s2chain = uqp.check_s2chain(s2elem, 5) + 
self.assertTrue(np.allclose(s2chain, np.ones(5,)*s2elem), + msg='Expect float to be extended to array to match size of chain') + + +# -------------------------------------------- +class GenerateQuantiles(unittest.TestCase): + + def test_does_default_empirical_quantiles_return_3_element_array(self): + test_out = uqp.generate_quantiles(np.random.rand(10, 1)) + self.assertEqual(test_out.shape, (3, 1), + msg='Default output shape is (3,1)') + + def test_does_non_default_empirical_quantiles_return_2_element_array(self): + test_out = uqp.generate_quantiles(np.random.rand(10, 1), p=np.array([0.2, 0.5])) + self.assertEqual(test_out.shape, (2, 1), + msg='Non-default output shape should be (2, 1)') + + def test_empirical_quantiles_should_not_support_list_input(self): + with self.assertRaises(AttributeError): + uqp.generate_quantiles([-1, 0, 1]) + + def test_empirical_quantiles_vector(self): + out = uqp.generate_quantiles(np.linspace(10,20, num=10).reshape(10, 1), + p=np.array([0.22, 0.57345])) + exact = np.array([[12.2], [15.7345]]) + comp = np.linalg.norm(out - exact) + self.assertAlmostEqual(comp, 0) + + +# -------------------------------------------- +class SetupDisplaySettings(unittest.TestCase): + + def test_setup_display_settings(self): + model_display = {'label': 'hello'} + data_display = {'linewidth': 7} + interval_display = {'edgecolor': 'b'} + intd, modd, datd = uqp.setup_display_settings( + interval_display=interval_display, + model_display=model_display, + data_display=data_display) + self.assertEqual(modd['label'], model_display['label'], + msg='Expect label to match') + self.assertEqual(intd['edgecolor'], interval_display['edgecolor'], + msg='Expect edgecolor to match') + self.assertEqual(datd['linewidth'], data_display['linewidth'], + msg='Expect linewidth to match') + + +# -------------------------------------------- +class SetupLabels(unittest.TestCase): + + def test_label_setup(self): + limits = [90, 50] + labels = uqp._setup_labels(limits, inttype='HH') + self.assertEqual(labels[0], '90% HH', + msg='Expect matching strings') + self.assertEqual(labels[1], '50% HH', + msg='Expect matching strings') + + +class SetupIntervalColors(unittest.TestCase): + + def test_sic_colors_none(self): + iset = dict( + limits=[90, 50], + cmap=None, + colors=None) + ic = uqp.setup_interval_colors(iset) + self.assertEqual(len(ic), 2, + msg='Expect 2 colors') + + def test_sic_colors_not_none(self): + iset = dict( + limits=[90, 50], + cmap=None, + colors=['r', 'g']) + ic = uqp.setup_interval_colors(iset) + self.assertEqual(len(ic), 2, + msg='Expect 2 colors') + self.assertEqual(ic, ['r', 'g'], + msg='Expect matching lists') + + def test_sic_colors_not_none_but_wrong_size(self): + iset = dict( + limits=[90, 50], + cmap=None, + colors=['r', 'g', 'b']) + ic = uqp.setup_interval_colors(iset) + self.assertEqual(len(ic), 2, + msg='Expect 2 colors') + self.assertNotEqual(ic, ['r', 'g'], + msg='Expect non-matching lists') + + +def model(q, data): + m, b = q + return m*data.xdata[0] + b + + +def model3D(q, data): + m, b = q + return m*data.xdata[0][:, 0] + b*data.xdata[0][:, 1] + +def modelmultiple(q, data): + m, b = q + x = data.xdata[0] + y1 = m*x + b + y2 = m*x**2 + b + return np.stack((y1.reshape(y1.size,), y2.reshape(y2.size,)), axis=1) + + +class CalculateIntervals(unittest.TestCase): + + def test_credintcreation(self): + data = DataStructure() + data.add_data_set(x=np.linspace(0, 1), y=None) + results = gf.setup_pseudo_results() + chain = results['chain'] + intervals = uqp.calculate_intervals( + chain, 
results, data, model, waitbar=True) + self.assertTrue('credible' in intervals.keys(), + msg='Expect credible intervals') + self.assertTrue('prediction' in intervals.keys(), + msg='Expect prediction intervals') + self.assertTrue(isinstance(intervals['credible'], np.ndarray), + msg='Expect numpy array') + self.assertEqual(intervals['prediction'], None, + msg='Expect None') + + def test_predintcreation(self): + data = DataStructure() + data.add_data_set(x=np.linspace(0, 1), y=None) + results = gf.setup_pseudo_results() + chain = results['chain'] + s2chain = results['s2chain'] + intervals = uqp.calculate_intervals( + chain, results, data, model, s2chain=s2chain) + self.assertTrue('credible' in intervals.keys(), + msg='Expect credible intervals') + self.assertTrue('prediction' in intervals.keys(), + msg='Expect prediction intervals') + self.assertTrue(isinstance(intervals['credible'], np.ndarray), + msg='Expect numpy array') + self.assertTrue(isinstance(intervals['prediction'], np.ndarray), + msg='Expect numpy array') + intervals = uqp.calculate_intervals( + chain, results, data, model, s2chain=0.1) + self.assertTrue('credible' in intervals.keys(), + msg='Expect credible intervals') + self.assertTrue('prediction' in intervals.keys(), + msg='Expect prediction intervals') + self.assertTrue(isinstance(intervals['credible'], np.ndarray), + msg='Expect numpy array') + self.assertTrue(isinstance(intervals['prediction'], np.ndarray), + msg='Expect numpy array') + + def test_predintcreation_multimodel(self): + data = DataStructure() + data.add_data_set(x=np.linspace(0, 1), y=None) + results = gf.setup_pseudo_results() + chain = results['chain'] + s2chain = results['s2chain'] + mintervals = uqp.calculate_intervals( + chain, results, data, modelmultiple, s2chain=s2chain) + for intervals in mintervals: + self.assertTrue('credible' in intervals.keys(), + msg='Expect credible intervals') + self.assertTrue('prediction' in intervals.keys(), + msg='Expect prediction intervals') + self.assertTrue(isinstance(intervals['credible'], np.ndarray), + msg='Expect numpy array') + self.assertTrue(isinstance(intervals['prediction'], np.ndarray), + msg='Expect numpy array') + mintervals = uqp.calculate_intervals( + chain, results, data, modelmultiple, s2chain=0.1) + for intervals in mintervals: + self.assertTrue('credible' in intervals.keys(), + msg='Expect credible intervals') + self.assertTrue('prediction' in intervals.keys(), + msg='Expect prediction intervals') + self.assertTrue(isinstance(intervals['credible'], np.ndarray), + msg='Expect numpy array') + self.assertTrue(isinstance(intervals['prediction'], np.ndarray), + msg='Expect numpy array') + mintervals = uqp.calculate_intervals( + chain, results, data, modelmultiple, + s2chain=np.hstack((s2chain, s2chain))) + for intervals in mintervals: + self.assertTrue('credible' in intervals.keys(), + msg='Expect credible intervals') + self.assertTrue('prediction' in intervals.keys(), + msg='Expect prediction intervals') + self.assertTrue(isinstance(intervals['credible'], np.ndarray), + msg='Expect numpy array') + self.assertTrue(isinstance(intervals['prediction'], np.ndarray), + msg='Expect numpy array') + + +# -------------------------------------------- +class Plot2DIntervals(unittest.TestCase): + + def test_plot_intervals_basic(self): + data = DataStructure() + data.add_data_set(x=np.linspace(0, 1), y=None) + results = gf.setup_pseudo_results() + chain = results['chain'] + s2chain = results['s2chain'] + intervals = uqp.calculate_intervals( + chain, results, data, model, 
s2chain=s2chain) + fig, ax = uqp.plot_intervals(intervals, data.xdata[0], limits=[95]) + self.assertEqual(ax.get_legend_handles_labels()[1], + ['Model', '95% PI', '95% CI'], + msg=str('Strings should match: {}'.format( + ax.get_legend_handles_labels()[1]))) + plt.close() + fig, ax = uqp.plot_intervals(intervals, data.xdata[0], limits=[95], + adddata=True, ydata=data.xdata[0]) + self.assertEqual(ax.get_legend_handles_labels()[1], + ['Model', 'Data', '95% PI', '95% CI'], + msg=str('Strings should match: {}'.format( + ax.get_legend_handles_labels()[1]))) + plt.close() + + def test_check_settings_plot_intervals_basic(self): + data = DataStructure() + data.add_data_set(x=np.linspace(0, 1), y=None) + results = gf.setup_pseudo_results() + chain = results['chain'] + s2chain = results['s2chain'] + intervals = uqp.calculate_intervals( + chain, results, data, model, s2chain=s2chain) + fig, ax, isets = uqp.plot_intervals( + intervals, data.xdata[0], limits=[95], + return_settings=True) + self.assertEqual(ax.get_legend_handles_labels()[1], + ['Model', '95% PI', '95% CI'], + msg=str('Strings should match: {}'.format( + ax.get_legend_handles_labels()[1]))) + self.assertTrue(isinstance(isets, dict), + msg='Expect dictionary') + self.assertEqual(isets['ciset']['limits'], [95]) + self.assertEqual(isets['piset']['limits'], [95]) + plt.close() + plt.close() + fig, ax, isets = uqp.plot_intervals( + intervals, data.xdata[0], limits=[95], + adddata=True, ydata=data.xdata[0], + return_settings=True) + self.assertEqual(ax.get_legend_handles_labels()[1], + ['Model', 'Data', '95% PI', '95% CI'], + msg=str('Strings should match: {}'.format( + ax.get_legend_handles_labels()[1]))) + self.assertTrue(isinstance(isets, dict), + msg='Expect dictionary') + self.assertEqual(isets['ciset']['limits'], [95]) + self.assertEqual(isets['piset']['limits'], [95]) + plt.close() + + +# -------------------------------------------- +class Plot3DIntervals(unittest.TestCase): + + def test_plot_intervals_basic(self): + data = DataStructure() + data.add_data_set(x=np.random.random_sample((100, 2)), y=None) + results = gf.setup_pseudo_results() + chain = results['chain'] + s2chain = results['s2chain'] + intervals = uqp.calculate_intervals( + chain, results, data, model3D, s2chain=s2chain) + fig, ax = uqp.plot_3d_intervals(intervals, data.xdata[0], limits=[95]) + self.assertEqual(ax.get_legend_handles_labels()[1], + ['Model', '95% PI', '95% CI'], + msg=str('Strings should match: {}'.format( + ax.get_legend_handles_labels()[1]))) + plt.close() + fig, ax = uqp.plot_3d_intervals(intervals, data.xdata[0], limits=[95], + adddata=True, ydata=data.xdata[0][:, 0]) + self.assertEqual(ax.get_legend_handles_labels()[1], + ['Model', 'Data', '95% PI', '95% CI'], + msg=str('Strings should match: {}'.format( + ax.get_legend_handles_labels()[1]))) + plt.close() + + def test_check_settings_plot_intervals_basic(self): + data = DataStructure() + data.add_data_set(x=np.random.random_sample((100, 2)), y=None) + results = gf.setup_pseudo_results() + chain = results['chain'] + s2chain = results['s2chain'] + intervals = uqp.calculate_intervals( + chain, results, data, model3D, s2chain=s2chain) + fig, ax, isets = uqp.plot_3d_intervals( + intervals, data.xdata[0], limits=[95], + return_settings=True) + self.assertEqual(ax.get_legend_handles_labels()[1], + ['Model', '95% PI', '95% CI'], + msg=str('Strings should match: {}'.format( + ax.get_legend_handles_labels()[1]))) + self.assertTrue(isinstance(isets, dict), + msg='Expect dictionary') + 
self.assertEqual(isets['ciset']['limits'], [95]) + self.assertEqual(isets['piset']['limits'], [95]) + plt.close() + plt.close() + fig, ax, isets = uqp.plot_3d_intervals( + intervals, data.xdata[0], limits=[95], + adddata=True, ydata=data.xdata[0][:, 0], + return_settings=True) + self.assertEqual(ax.get_legend_handles_labels()[1], + ['Model', 'Data', '95% PI', '95% CI'], + msg=str('Strings should match: {}'.format( + ax.get_legend_handles_labels()[1]))) + self.assertTrue(isinstance(isets, dict), + msg='Expect dictionary') + self.assertEqual(isets['ciset']['limits'], [95]) + self.assertEqual(isets['piset']['limits'], [95]) + plt.close() + \ No newline at end of file diff --git a/test/utilities/test_general.py b/test/utilities/test_general.py index aaf1359..3fbb5f3 100644 --- a/test/utilities/test_general.py +++ b/test/utilities/test_general.py @@ -5,7 +5,7 @@ @author: prmiles """ -from pymcmcstat.utilities.general import message, removekey +from pymcmcstat.utilities.general import message, removekey, check_settings import unittest import io import sys @@ -36,4 +36,39 @@ def test_removekey(self): for ct in check_these: self.assertEqual(d[ct], r[ct], msg = str('Expect element agreement: {}'.format(ct))) - self.assertFalse('a2' in r, msg = 'Expect removal') \ No newline at end of file + self.assertFalse('a2' in r, msg = 'Expect removal') + + +# -------------------------- +class CheckSettings(unittest.TestCase): + + def test_settings_with_user_none(self): + user_settings = None + default_settings = dict(a = False, linewidth = 3, marker = dict(markersize = 5, color = 'g')) + settings = check_settings(default_settings = default_settings, user_settings = user_settings) + self.assertEqual(settings, default_settings, msg = str('Expect dictionaries to match: {} neq {}'.format(settings, default_settings))) + + def test_settings_with_subdict(self): + user_settings = dict(a = True, fontsize = 12) + default_settings = dict(a = False, linewidth = 3, marker = dict(markersize = 5, color = 'g')) + settings = check_settings(default_settings = default_settings, user_settings = user_settings) + self.assertEqual(settings['a'], user_settings['a'], msg = 'Expect user setting to overwrite') + self.assertEqual(settings['marker'], default_settings['marker'], msg = 'Expect default to persist') + + def test_settings_with_subdict_user_ow(self): + user_settings = dict(a = True, fontsize = 12, marker = dict(color = 'b')) + default_settings = dict(a = False, linewidth = 3, marker = dict(markersize = 5, color = 'g')) + settings = check_settings(default_settings = default_settings, user_settings = user_settings) + self.assertEqual(settings['a'], user_settings['a'], msg = 'Expect user setting to overwrite') + self.assertEqual(settings['marker']['color'], user_settings['marker']['color'], msg = 'Expect user to overwrite') + self.assertEqual(settings['marker']['markersize'], default_settings['marker']['markersize'], msg = 'Expect default to persist') + + def test_settings_with_subdict_user_has_new_setting(self): + user_settings = dict(a = True, fontsize = 12, marker = dict(color = 'b'), linestyle = '--') + default_settings = dict(a = False, linewidth = 3, marker = dict(markersize = 5, color = 'g')) + settings = check_settings(default_settings = default_settings, user_settings = user_settings) + self.assertEqual(settings['a'], user_settings['a'], msg = 'Expect user setting to overwrite') + self.assertEqual(settings['marker']['color'], user_settings['marker']['color'], msg = 'Expect user to overwrite') + 
self.assertEqual(settings['marker']['markersize'], default_settings['marker']['markersize'], msg = 'Expect default to persist') + self.assertEqual(settings['linestyle'], user_settings['linestyle'], msg = 'Expect user setting to be added') + diff --git a/tutorials/advanced_interval_plotting/advanced_interval_plotting.ipynb b/tutorials/advanced_interval_plotting/advanced_interval_plotting.ipynb index 9cb549c..1133682 100755 --- a/tutorials/advanced_interval_plotting/advanced_interval_plotting.ipynb +++ b/tutorials/advanced_interval_plotting/advanced_interval_plotting.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "
\n", + "\n", " \n", "
\n", "\n", @@ -324,7 +324,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.2" + "version": "3.6.8" }, "latex_envs": { "LaTeX_envs_menu_present": true, diff --git a/tutorials/algae/Algae.ipynb b/tutorials/algae/Algae.ipynb index 6ffce90..7bb1bc8 100755 --- a/tutorials/algae/Algae.ipynb +++ b/tutorials/algae/Algae.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "
\n", + "\n", " \n", "
\n", "\n", diff --git a/tutorials/banana/Banana.ipynb b/tutorials/banana/Banana.ipynb index 718a700..9a602c1 100755 --- a/tutorials/banana/Banana.ipynb +++ b/tutorials/banana/Banana.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "
\n", + "\n", " \n", "
\n", "\n", diff --git a/tutorials/beetle/Beetle.ipynb b/tutorials/beetle/Beetle.ipynb index 5c3cba7..cc56c77 100755 --- a/tutorials/beetle/Beetle.ipynb +++ b/tutorials/beetle/Beetle.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "
\n", + "\n", " \n", "
\n", "\n", diff --git a/tutorials/calling_models_using_ctypes/Running_Model_Using_Ctypes.ipynb b/tutorials/calling_models_using_ctypes/Running_Model_Using_Ctypes.ipynb index aa336ec..accef83 100755 --- a/tutorials/calling_models_using_ctypes/Running_Model_Using_Ctypes.ipynb +++ b/tutorials/calling_models_using_ctypes/Running_Model_Using_Ctypes.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "
\n", + "\n", " \n", "
\n", "\n", @@ -362,7 +362,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.2" + "version": "3.6.8" }, "latex_envs": { "LaTeX_envs_menu_present": true, diff --git a/tutorials/estimating_error_variance/estimating_error_variance_for_mutliple_data_sets.ipynb b/tutorials/estimating_error_variance/estimating_error_variance_for_mutliple_data_sets.ipynb index 88569e3..537b15a 100755 --- a/tutorials/estimating_error_variance/estimating_error_variance_for_mutliple_data_sets.ipynb +++ b/tutorials/estimating_error_variance/estimating_error_variance_for_mutliple_data_sets.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "
\n", + "\n", " \n", "
\n", "\n", @@ -382,7 +382,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.2" + "version": "3.6.8" }, "latex_envs": { "LaTeX_envs_menu_present": true, diff --git a/tutorials/index.ipynb b/tutorials/index.ipynb index e9bd463..e0c9727 100755 --- a/tutorials/index.ipynb +++ b/tutorials/index.ipynb @@ -8,6 +8,8 @@ "\n", "Author(s): Paul Miles | Date Created: August 31, 2018\n", "\n", + "Note, the [pymcmcstat](https://github.com/prmiles/pymcmcstat/wiki) tutorials have moved to a new location. To switch to the new index, please follow this [link](https://github.com/prmiles/pymcmcstat_examples). Otherwise, selecting any of the tutorials listed below will take you to the appropriate new location.\n", + "\n", "# Introduction\n", "The [pymcmcstat](https://github.com/prmiles/pymcmcstat/wiki) package is a Python program for running Markov Chain Monte Carlo (MCMC) simulations. Included in this package is the abilitity to use different Metropolis based sampling techniques:\n", "\n", @@ -40,8 +42,8 @@ "There are many built-in features to [pymcmcstat](https://github.com/prmiles/pymcmcstat/wiki) that allow it to be tailored to suit your particular problem. Below we have outlined features through a set of examples.\n", "\n", "---\n", - "[](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat/blob/master/tutorials/monod/Monod.ipynb)\n", - "## [Monod](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat/blob/master/tutorials/monod/Monod.ipynb)\n", + "[](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat_examples/blob/master/Monod.ipynb)\n", + "## [Monod](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat_examples/blob/master/Monod.ipynb)\n", "Key Features:\n", "- Basic MCMC settings\n", "- Data structure initialization\n", @@ -54,8 +56,8 @@ "
\n", "\n", "---\n", - "[](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat/blob/master/tutorials/beetle/Beetle.ipynb)\n", - "## [Beetle](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat/blob/master/tutorials/beetle/Beetle.ipynb) \n", + "[](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat_examples/blob/master/Beetle.ipynb)\n", + "## [Beetle](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat_examples/blob/master/Beetle.ipynb) \n", "Key Features:\n", "- Sending objects within MCMC data structure.\n", "- Managing objects within sum-of-squares evaluation.\n", @@ -69,8 +71,8 @@ "
\n", "\n", "---\n", - "[](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat/blob/master/tutorials/banana/Banana.ipynb)\n", - "## [Banana](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat/blob/master/tutorials/banana/Banana.ipynb) \n", + "[](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat_examples/blob/master/Banana.ipynb)\n", + "## [Banana](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat_examples/blob/master/Banana.ipynb) \n", "Key Features:\n", "- Sending class objects in MCMC data structure.\n", "- Defining parameter covariance matrix.\n", @@ -84,8 +86,8 @@ "
\n", "\n", "---\n", - "[](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat/blob/master/tutorials/algae/Algae.ipynb)\n", - "## [Algae](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat/blob/master/tutorials/algae/Algae.ipynb)\n", + "[](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat_examples/blob/master/Algae.ipynb)\n", + "## [Algae](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat_examples/blob/master/Algae.ipynb)\n", "Key Features:\n", "- Using multiple data sets.\n", "- Solving system of ODE's as model response. \n", @@ -93,8 +95,8 @@ "- Generating prediction/credible intervals for multiple quantities of interest.\n", "\n", "---\n", - "[](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat/blob/master/tutorials/viscoelasticity/viscoelastic_analysis_using_ctypes.ipynb)\n", - "## [Viscoelasticity](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat/blob/master/tutorials/viscoelasticity/viscoelastic_analysis_using_ctypes.ipynb)\n", + "[](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat_examples/blob/master/Viscoelasticity.ipynb)\n", + "## [Viscoelasticity](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat_examples/blob/master/Viscoelasticity.ipynb)\n", "Key Features:\n", "- Loading data from `*.mat` file.\n", "- Calling C++ model using `ctypes` packages.\n", @@ -107,8 +109,8 @@ "
\n", "\n", "---\n", - "[](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat/blob/master/tutorials/landau_energy/Landau_Energy.ipynb)\n", - "## [Landau Energy](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat/blob/master/tutorials/landau_energy/Landau_Energy.ipynb)\n", + "[](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat_examples/blob/master/Landau_Energy.ipynb)\n", + "## [Landau Energy](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat_examples/blob/master/Landau_Energy.ipynb)\n", "Key Features:\n", "- Evaluating multidimensional functions (3-D polarization space).\n", "- Loading data from `*.mat` file.\n", @@ -122,8 +124,8 @@ "
\n", "\n", "---\n", - "[](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat/blob/master/tutorials/radiation_source_localization/radiation_source_localization.ipynb)\n", - "## [Radiation Source Localization](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat/blob/master/tutorials/radiation_source_localization/radiation_source_localization.ipynb)\n", + "[](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat_examples/blob/master/Radiation_Source_Localization.ipynb)\n", + "## [Radiation Source Localization](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat_examples/blob/master/Radiation_Source_Localization.ipynb)\n", "Key Features:\n", "- Embedding user defined objects in the data structure.\n", "- Enhanced visualization using [mcmcplot](https://github.com/prmiles/mcmcplot).\n", @@ -140,8 +142,8 @@ "
\n", "\n", "---\n", - "[](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat/blob/master/tutorials/running_parallel_chains/running_parallel_chains.ipynb)\n", - "## [Running Parallel Chains](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat/blob/master/tutorials/running_parallel_chains/running_parallel_chains.ipynb)\n", + "[](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat_examples/blob/master/Running_Parallel_Chains.ipynb)\n", + "## [Running Parallel Chains](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat_examples/blob/master/Running_Parallel_Chains.ipynb)\n", "Key Features:\n", "- Running multiple chains simultaneously.\n", "- Using Gelman-Rubin chain diagnostics.\n", @@ -156,39 +158,39 @@ "# Advanced Topics\n", "These tutorials address very specific features of using the package.\n", "\n", - "## [Using Chain Log Files](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat/blob/master/tutorials/saving_to_log_files/Chain_Log_Files.ipynb)\n", + "## [Using Chain Log Files](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat_examples/blob/master/Chain_Log_Files.ipynb)\n", "Key Features:\n", "- Saving chain logs in binary and text formats.\n", "- Loading log files for post processing.\n", "- Assessing log history to ascertain status of simulation.\n", "\n", - "## [Setting the RNG Seed](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat/blob/master/tutorials/setting_the_random_seed/setting_the_random_seed.ipynb)\n", + "## [Setting the RNG Seed](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat_examples/blob/master/Setting_Random_Seed.ipynb)\n", "Key Features:\n", "- Set seed for random number generator within [pymcmcstat](https://github.com/prmiles/pymcmcstat/wiki).\n", "- Produce repeatable simulation results.\n", "\n", - "## [Calling Models Written in C++](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat/blob/master/tutorials/calling_models_using_ctypes/Running_Model_Using_Ctypes.ipynb)\n", + "## [Calling Models Written in C++](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat_examples/blob/master/Running_Model_Using_Ctypes.ipynb)\n", "Key Features:\n", "- Call arbitrarily complex models written in other languages (e.g., C++) using the `ctypes` package.\n", "- Generating credible/prediction intervals using C++ based model.\n", "\n", - "## [Specifying Sample Variables](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat/blob/master/tutorials/specifying_sample_variables/specifying_sample_variables.ipynb)\n", + "## [Specifying Sample Variables](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat_examples/blob/master/Specifying_Sample_Variables.ipynb)\n", "Key Features:\n", "- Specify which model parameters should be included in sampling chain.\n", "\n", - "## [Estimating Error Variance for Multiple Data Sets](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat/blob/master/tutorials/estimating_error_variance/estimating_error_variance_for_mutliple_data_sets.ipynb)\n", + "## [Estimating Error Variance for Multiple Data Sets](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat_examples/blob/master/Estimating_Error_Variance.ipynb)\n", "Key Features:\n", "- Setting up multiple data sets in the MCMC data structure.\n", "- Defining sum-of-squares function to accomodate multiple data sets.\n", "- Estimating a separate observation error variance for each data set.\n", "- Plotting prediction/credible intervals for each data set.\n", "\n", - "## [Using Normal Prior 
Distributions](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat/blob/master/tutorials/prior_function/prior_function.ipynb)\n",
+    "## [Using Normal Prior Distributions](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat_examples/blob/master/Using_Normal_Prior_Functions.ipynb)\n",
     "Key Features:\n",
     "- Enforcing normally distributed prior functions.\n",
     "- Defining non-linear parameter constraints via custom prior functions.\n",
     "\n",
-    "## [Advanced Interval Plotting](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat/blob/master/tutorials/advanced_interval_plotting/advanced_interval_plotting.ipynb)\n",
+    "## [Advanced Interval Plotting](https://nbviewer.jupyter.org/github/prmiles/pymcmcstat_examples/blob/master/Advanced_Interval_Plotting.ipynb)\n",
     "Key Features:\n",
     "- Change model, data, and interval display options when plotting credible and prediction intervals.\n",
     "- This highlights available features as of version 1.5.0."
@@ -211,7 +213,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.2"
+   "version": "3.6.8"
   },
   "latex_envs": {
    "LaTeX_envs_menu_present": true,
diff --git a/tutorials/landau_energy/Landau_Energy.ipynb b/tutorials/landau_energy/Landau_Energy.ipynb
index 65ca7d8..eacb5e2 100755
--- a/tutorials/landau_energy/Landau_Energy.ipynb
+++ b/tutorials/landau_energy/Landau_Energy.ipynb
@@ -542,7 +542,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.2"
+   "version": "3.6.8"
   },
   "latex_envs": {
    "LaTeX_envs_menu_present": true,
diff --git a/tutorials/monod/Monod.ipynb b/tutorials/monod/Monod.ipynb
index 351b83d..eeede02 100755
--- a/tutorials/monod/Monod.ipynb
+++ b/tutorials/monod/Monod.ipynb
diff --git a/tutorials/prior_function/prior_function.ipynb b/tutorials/prior_function/prior_function.ipynb
index 20e1fe5..49a47ac 100755
--- a/tutorials/prior_function/prior_function.ipynb
+++ b/tutorials/prior_function/prior_function.ipynb
@@ -452,7 +452,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.2"
+   "version": "3.6.8"
   },
   "latex_envs": {
    "LaTeX_envs_menu_present": true,
diff --git a/tutorials/radiation_source_localization/radiation_source_localization.ipynb b/tutorials/radiation_source_localization/radiation_source_localization.ipynb
index 40c7f60..ca33b9e 100755
--- a/tutorials/radiation_source_localization/radiation_source_localization.ipynb
+++ b/tutorials/radiation_source_localization/radiation_source_localization.ipynb
diff --git a/tutorials/running_parallel_chains/running_parallel_chains.ipynb b/tutorials/running_parallel_chains/running_parallel_chains.ipynb
index 009c880..98c44f9 100755
--- a/tutorials/running_parallel_chains/running_parallel_chains.ipynb
+++ b/tutorials/running_parallel_chains/running_parallel_chains.ipynb
diff --git a/tutorials/saving_to_log_files/Chain_Log_Files.ipynb b/tutorials/saving_to_log_files/Chain_Log_Files.ipynb
index a83ce69..526be60 100755
--- a/tutorials/saving_to_log_files/Chain_Log_Files.ipynb
+++ b/tutorials/saving_to_log_files/Chain_Log_Files.ipynb
@@ -605,7 +605,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.2"
+   "version": "3.6.8"
   },
   "latex_envs": {
    "LaTeX_envs_menu_present": true,
diff --git a/tutorials/setting_the_random_seed/setting_the_random_seed.ipynb b/tutorials/setting_the_random_seed/setting_the_random_seed.ipynb
index e8a9ba8..bf9058f 100755
--- a/tutorials/setting_the_random_seed/setting_the_random_seed.ipynb
+++ b/tutorials/setting_the_random_seed/setting_the_random_seed.ipynb
@@ -259,7 +259,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.2"
+   "version": "3.6.8"
   },
   "latex_envs": {
    "LaTeX_envs_menu_present": true,
diff --git a/tutorials/specifying_sample_variables/specifying_sample_variables.ipynb b/tutorials/specifying_sample_variables/specifying_sample_variables.ipynb
index ba6feaa..a2c9145 100755
--- a/tutorials/specifying_sample_variables/specifying_sample_variables.ipynb
+++ b/tutorials/specifying_sample_variables/specifying_sample_variables.ipynb
@@ -320,7 +320,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.2"
+   "version": "3.6.8"
   },
   "latex_envs": {
    "LaTeX_envs_menu_present": true,
diff --git a/tutorials/viscoelasticity/viscoelastic_analysis_using_ctypes.ipynb b/tutorials/viscoelasticity/viscoelastic_analysis_using_ctypes.ipynb
index 3270b43..ceef8c8 100755
--- a/tutorials/viscoelasticity/viscoelastic_analysis_using_ctypes.ipynb
+++ b/tutorials/viscoelasticity/viscoelastic_analysis_using_ctypes.ipynb
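For orientation, the tutorial index updated above repeatedly points at the same basic workflow: initialize a data structure, define the sampled parameters, supply a sum-of-squares function, run the simulation, and then propagate the chain through the model to form credible/prediction intervals. The sketch below is a minimal illustration of that workflow, assuming the standard `MCMC` class API and the `calculate_intervals` routine from the new `pymcmcstat.propagation` module introduced in this release; the straight-line model, the data, the parameter names, and the `(q, data)` model-function signature are illustrative assumptions rather than content taken from any of the notebooks.

```python
import numpy as np
from pymcmcstat.MCMC import MCMC
from pymcmcstat.propagation import calculate_intervals

# Illustrative data set: straight line y = m*x + b with additive noise
x = np.linspace(0.0, 1.0, 50).reshape(-1, 1)
y = 2.0 * x + 3.0 + 0.1 * np.random.standard_normal(x.shape)

def modelfun(q, data):
    # Forward model evaluated at a sampled parameter set q = [m, b]
    m, b = q
    return m * data.xdata[0] + b

def ssfun(q, data):
    # Sum-of-squares error between model response and observations
    res = data.ydata[0] - modelfun(q, data)
    return (res ** 2).sum(axis=0)

mcstat = MCMC()
mcstat.data.add_data_set(x, y)
mcstat.parameters.add_model_parameter(name='m', theta0=2.0, minimum=-10.0, maximum=10.0)
mcstat.parameters.add_model_parameter(name='b', theta0=3.0, minimum=-10.0, maximum=100.0)
mcstat.model_settings.define_model_settings(sos_function=ssfun)
mcstat.simulation_options.define_simulation_options(nsimu=int(5e3), updatesigma=True)
mcstat.run_simulation()

# Discard the first half of the chain as burn-in
results = mcstat.simulation_results.results
nsimu = results['chain'].shape[0]
chain = results['chain'][nsimu // 2:, :]
s2chain = results['s2chain'][nsimu // 2:, :]

# Propagate posterior samples through the model to tabulate
# credible (and, since s2chain is provided, prediction) intervals.
intervals = calculate_intervals(chain, results, mcstat.data, modelfun,
                                s2chain=s2chain, nsample=500, waitbar=True)
```

The returned interval structure can then be handed to the plotting routines added alongside the propagation module in this release, which is the topic of the Advanced Interval Plotting tutorial referenced in the index.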