diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index 4016eba65cd..a085970ee09 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -143,3 +143,8 @@ ufuncs.tan ufuncs.tanh ufuncs.trunc + + plot.FacetGrid.map_dataarray + plot.FacetGrid.set_titles + plot.FacetGrid.set_ticks + plot.FacetGrid.map diff --git a/doc/api.rst b/doc/api.rst index f141ee43378..582df933d98 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -411,3 +411,5 @@ Plotting plot.imshow plot.line plot.pcolormesh + plot.FacetGrid + diff --git a/doc/plotting.rst b/doc/plotting.rst index 5ccc3dd855a..78de158e2f5 100644 --- a/doc/plotting.rst +++ b/doc/plotting.rst @@ -16,6 +16,9 @@ Xray's plotting capabilities are centered around :py:class:`xray.DataArray` objects. To plot :py:class:`xray.Dataset` objects simply access the relevant DataArrays, ie ``dset['var1']``. +Here we focus mostly on arrays 2d or larger. If your data fits +nicely into a pandas DataFrame then you're better off using one of the more +developed tools there. Xray plotting functionality is a thin wrapper around the popular `matplotlib `_ library. @@ -51,10 +54,11 @@ The following imports are necessary for all of the examples. .. ipython:: python import numpy as np + import pandas as pd import matplotlib.pyplot as plt import xray -We'll use the North American air temperature dataset. +For these examples we'll use the North American air temperature dataset. .. ipython:: python @@ -306,6 +310,125 @@ since levels are chosen automatically). air2d.plot(levels=10, cmap='husl') +Faceting +-------- + +Faceting here refers to splitting an array along one or two dimensions and +plotting each group. +Xray's basic plotting is useful for plotting two dimensional arrays. What +about three or four dimensional arrays? That's where facets become helpful. + +Consider the temperature data set. There are 4 observations per day for two +years which makes for 2920 values along the time dimension. 
+One way to visualize this data is to make a +separate plot for each time period. + +The faceted dimension should not have too many values; +faceting on the time dimension will produce 2920 plots. That's +too much to be helpful. To handle this situation try performing +an operation that reduces the size of the data in some way. For example, we +could compute the average air temperature for each month and reduce the +size of this dimension from 2920 -> 12. A simpler way is +to just take a slice on that dimension. +So let's use a slice to pick 6 times throughout the first year. + +.. ipython:: python + + t = air.isel(time=slice(0, 365 * 4, 250)) + t.coords + +Simple Example +~~~~~~~~~~~~~~ + +TODO - replace with the convenience method from plot + +We can use :py:meth:`xray.plot.FacetGrid.map_dataarray` on a DataArray: + +.. ipython:: python + + g = xray.plot.FacetGrid(t, col='time', col_wrap=3) + + @savefig plot_facet_dataarray.png height=12in + g.map_dataarray(xray.plot.imshow, 'lon', 'lat') + +FacetGrid Objects +~~~~~~~~~~~~~~~~~ + +:py:class:`xray.plot.FacetGrid` is used to control the behavior of the +multiple plots. +It borrows an API and code from `Seaborn +`_. +The structure is contained within the ``axes`` and ``name_dicts`` +attributes, both 2d Numpy object arrays. + +.. ipython:: python + + g.axes + + g.name_dicts + +It's possible to select the :py:class:`xray.DataArray` corresponding to the FacetGrid +through the ``name_dicts``. + +.. ipython:: python + + g.data.loc[g.name_dicts[0, 0]] + +Here is an example of modifying the axes after they have been plotted. + +.. 
ipython:: python + + g = (xray.plot + .FacetGrid(t, col='time', col_wrap=3) + .map_dataarray(xray.plot.imshow, 'lon', 'lat') + ) + + for i, ax in enumerate(g.axes.flat): + ax.set_title('Air Temperature %d' % i) + + bottomright = g.axes[-1, -1] + bottomright.annotate('bottom right', (240, 40)) + + @savefig plot_facet_iterator.png height=12in + plt.show() + +4 dimensional +~~~~~~~~~~~~~~ + +For 4 dimensional arrays we can use the rows and columns of the grids. +Here we create a 4 dimensional array by taking the original data and adding +a fixed amount. Now we can see how the temperature maps would compare if +one were much hotter. + +.. ipython:: python + + t2 = t.isel(time=slice(0, 2)) + t4d = xray.concat([t2, t2 + 40], pd.Index(['normal', 'hot'], name='fourth_dim')) + # This is a 4d array + t4d.coords + + g = xray.plot.FacetGrid(t4d, col='time', row='fourth_dim') + + @savefig plot_facet_4d.png height=12in + g.map_dataarray(xray.plot.imshow, 'lon', 'lat') + +Other features +~~~~~~~~~~~~~~ + +Faceted plotting supports other arguments common to xray 2d plots. + +.. 
# Overrides axes.labelsize, xtick.major.size, ytick.major.size
# from mpl.rcParams
_FONTSIZE = 'small'
# For major ticks on x, y axes
_NTICKS = 5


def _nicetitle(coord, value, maxchar, template):
    """
    Put coord, value in template and truncate at maxchar

    Titles longer than ``maxchar`` are cut so that, including the trailing
    '...', they are exactly ``maxchar`` characters long.
    """
    prettyvalue = format_item(value)
    title = template.format(coord=coord, value=prettyvalue)

    if len(title) > maxchar:
        title = title[:(maxchar - 3)] + '...'

    return title


class FacetGrid(object):
    """
    Initialize the matplotlib figure and FacetGrid object.

    The :class:`FacetGrid` is an object that links a xray DataArray to
    a matplotlib figure with a particular structure.

    In particular, :class:`FacetGrid` is used to draw plots with multiple
    Axes where each Axes shows the same relationship conditioned on
    different levels of some dimension. It's possible to condition on up to
    two variables by assigning variables to the rows and columns of the
    grid.

    The general approach to plotting here is called "small multiples",
    where the same kind of plot is repeated multiple times, and the
    specific use of small multiples to display the same relationship
    conditioned on one or more other variables is often called a "trellis
    plot".

    The basic workflow is to initialize the :class:`FacetGrid` object with
    the DataArray and the variable names that are used to structure the grid.
    Then plotting functions can be applied to each subset by calling
    :meth:`FacetGrid.map_dataarray` or :meth:`FacetGrid.map`.

    Attributes
    ----------
    axes : numpy object array
        Contains axes in corresponding position, as returned from
        plt.subplots
    fig : matplotlib.Figure
        The figure containing all the axes
    name_dicts : numpy object array
        Contains dictionaries mapping coordinate names to values. None is
        used as a sentinel value for axes which should remain empty, ie.
        sometimes the bottom right grid

    """

    def __init__(self, data, col=None, row=None, col_wrap=None,
                 aspect=1, size=3):
        """
        Parameters
        ----------
        data : DataArray
            xray DataArray to be plotted
        row, col : strings
            Dimension names that define subsets of the data, which will be
            drawn on separate facets in the grid.
        col_wrap : int, optional
            "Wrap" the column variable at this width, so that the column
            facets span multiple rows. Ignored (with a warning) when both
            ``row`` and ``col`` are passed.
        aspect : scalar, optional
            Aspect ratio of each facet, so that ``aspect * size`` gives the
            width of each facet in inches
        size : scalar, optional
            Height (in inches) of each facet. See also: ``aspect``

        Raises
        ------
        ValueError
            If neither ``row`` nor ``col`` is given, or if a coordinate used
            for faceting contains repeated values.

        """

        import matplotlib.pyplot as plt

        # Handle corner case of nonunique coordinates
        rep_col = col is not None and not data[col].to_index().is_unique
        rep_row = row is not None and not data[row].to_index().is_unique
        if rep_col or rep_row:
            raise ValueError('Coordinates used for faceting cannot '
                             'contain repeated (nonunique) values.')

        # single_group is the grouping variable, if there is exactly one
        if col and row:
            single_group = False
            nrow = len(data[row])
            ncol = len(data[col])
            if col_wrap is not None:
                warnings.warn('Ignoring col_wrap since both col and row '
                              'were passed')
        elif row and not col:
            single_group = row
        elif not row and col:
            single_group = col
        else:
            raise ValueError(
                'Pass a coordinate name as an argument for row or col')

        # Compute grid shape
        if single_group:
            nfacet = len(data[single_group])
            if col:
                # idea - could add heuristic for nice shapes like 3x4
                ncol = nfacet
            if row:
                ncol = 1
            if col_wrap is not None:
                # Overrides previous settings
                ncol = col_wrap
            # float() keeps this a true division even where
            # ``from __future__ import division`` is not in effect
            nrow = int(np.ceil(nfacet / float(ncol)))

        # Calculate the base figure size with extra horizontal space for a
        # colorbar
        cbar_space = 1
        figsize = (ncol * size * aspect + cbar_space, nrow * size)

        fig, axes = plt.subplots(nrow, ncol,
                                 sharex=True, sharey=True,
                                 squeeze=False, figsize=figsize)

        # Set up the lists of names for the row and column facet variables
        col_names = list(data[col].values) if col else []
        row_names = list(data[row].values) if row else []

        if single_group:
            full = [{single_group: x} for x in
                    data[single_group].values]
            # Pad with the None sentinel so the list fills the whole grid
            empty = [None for x in range(nrow * ncol - len(full))]
            name_dicts = full + empty
        else:
            rowcols = itertools.product(row_names, col_names)
            name_dicts = [{row: r, col: c} for r, c in rowcols]

        name_dicts = np.array(name_dicts).reshape(nrow, ncol)

        # Set up the class attributes
        # ---------------------------

        # First the public API
        self.data = data
        self.name_dicts = name_dicts
        self.fig = fig
        self.axes = axes
        self.row_names = row_names
        self.col_names = col_names

        # Next the private variables
        self._single_group = single_group
        self._nrow = nrow
        self._row_var = row
        self._ncol = ncol
        self._col_var = col
        self._col_wrap = col_wrap
        self._x_var = None
        self._y_var = None

        self.set_titles()

    def map_dataarray(self, func, x, y, **kwargs):
        """
        Apply a plotting function to a 2d facet's subset of the data.

        This is more convenient and less general than ``FacetGrid.map``

        Parameters
        ----------
        func : callable
            A plotting function with the same signature as a 2d xray
            plotting method such as `xray.plot.imshow`
        x, y : string
            Names of the coordinates to plot on x, y axes
        kwargs :
            additional keyword arguments to func. ``add_colorbar`` controls
            the single colorbar shared by the whole figure (default True);
            individual facets never draw their own.

        Returns
        -------
        self : FacetGrid object

        """

        # The shared colorbar is drawn once below. Take the flag out of
        # kwargs so an explicit add_colorbar=True is not forwarded to every
        # facet, which would add one colorbar per subplot as well.
        add_colorbar = kwargs.pop('add_colorbar', True)

        # These should be consistent with xray.plot._plot2d
        cmap_kwargs = {'plot_data': self.data.values,
                       'vmin': None,
                       'vmax': None,
                       'cmap': None,
                       'center': None,
                       'robust': False,
                       'extend': None,
                       # MPL default
                       'levels': 7 if 'contour' in func.__name__ else None,
                       'filled': func.__name__ != 'contour',
                       }

        # Allow kwargs to override these defaults
        for param in kwargs:
            if param in cmap_kwargs:
                cmap_kwargs[param] = kwargs[param]

        # colormap inference has to happen here since all the data in
        # self.data is required to make the right choice
        cmap_params = _determine_cmap_params(**cmap_kwargs)

        if 'contour' in func.__name__:
            # extend is a keyword argument only for contour and contourf, but
            # passing it to the colorbar is sufficient for imshow and
            # pcolormesh
            kwargs['extend'] = cmap_params['extend']
            kwargs['levels'] = cmap_params['levels']

        defaults = {
            'add_colorbar': False,
            'add_labels': False,
            'norm': cmap_params.pop('cnorm'),
        }

        # Order is important
        defaults.update(cmap_params)
        defaults.update(kwargs)

        for d, ax in zip(self.name_dicts.flat, self.axes.flat):
            # None is the sentinel value
            if d is not None:
                subset = self.data.loc[d]
                mappable = func(subset, x, y, ax=ax, **defaults)

        # Left side labels
        for ax in self.axes[:, 0]:
            ax.set_ylabel(y)

        # Bottom labels
        for ax in self.axes[-1, :]:
            ax.set_xlabel(x)

        self.fig.tight_layout()

        if self._single_group:
            # Hide the padding cells that have no data
            for d, ax in zip(self.name_dicts.flat, self.axes.flat):
                if d is None:
                    ax.set_visible(False)

        # colorbar
        if add_colorbar:
            cbar = self.fig.colorbar(mappable,
                                     ax=list(self.axes.flat),
                                     extend=cmap_params['extend'])

            if self.data.name:
                cbar.set_label(self.data.name, rotation=270,
                               verticalalignment='bottom')

        self._x_var = x
        self._y_var = y

        return self

    def set_titles(self, template="{coord} = {value}", maxchar=30,
                   fontsize=_FONTSIZE, **kwargs):
        """
        Draw titles either above each facet or on the grid margins.

        Parameters
        ----------
        template : string
            Template for plot titles containing {coord} and {value}
        maxchar : int
            Truncate titles at maxchar
        fontsize : string or int
            Passed to matplotlib.text
        kwargs : keyword args
            additional arguments to matplotlib.text

        Returns
        -------
        self: FacetGrid object

        """

        kwargs['fontsize'] = fontsize

        # maxchar and template are bound once here
        nicetitle = functools.partial(_nicetitle, maxchar=maxchar,
                                      template=template)

        if self._single_group:
            for d, ax in zip(self.name_dicts.flat, self.axes.flat):
                # Only label the ones with data
                if d is not None:
                    # Single-group dicts hold exactly one item
                    coord, value = list(d.items()).pop()
                    ax.set_title(nicetitle(coord, value), **kwargs)
        else:
            # The row titles on the right edge of the grid
            for ax, row_name in zip(self.axes[:, -1], self.row_names):
                title = nicetitle(coord=self._row_var, value=row_name)
                ax.annotate(title, xy=(1.02, .5), xycoords="axes fraction",
                            rotation=270, ha="left", va="center", **kwargs)

            # The column titles on the top row
            for ax, col_name in zip(self.axes[0, :], self.col_names):
                title = nicetitle(coord=self._col_var, value=col_name)
                ax.set_title(title, **kwargs)

        return self

    def set_ticks(self, max_xticks=_NTICKS, max_yticks=_NTICKS,
                  fontsize=_FONTSIZE):
        """
        Set and control tick behavior

        Parameters
        ----------
        max_xticks, max_yticks : int, optional
            Maximum number of labeled ticks to plot on x, y axes
        fontsize : string or int
            Font size as used by matplotlib text

        Returns
        -------
        self : FacetGrid object

        """
        from matplotlib.ticker import MaxNLocator

        for ax in self.axes.flat:
            # A locator instance is bound to the axis it is installed on, so
            # build a fresh pair for every axes instead of sharing one.
            # Both x and y are necessary.
            ax.xaxis.set_major_locator(MaxNLocator(nbins=max_xticks))
            ax.yaxis.set_major_locator(MaxNLocator(nbins=max_yticks))
            for tick in itertools.chain(ax.xaxis.get_major_ticks(),
                                        ax.yaxis.get_major_ticks()):
                # label1 is the primary tick label; ``tick.label`` was only
                # a legacy alias for it
                tick.label1.set_fontsize(fontsize)

        return self

    def map(self, func, *args, **kwargs):
        """
        Apply a plotting function to each facet's subset of the data.

        Parameters
        ----------
        func : callable
            A plotting function that takes data and keyword arguments. It
            must plot to the currently active matplotlib Axes.
        args : strings
            Names in self.data that identify variables with data to
            plot. The values for each name are passed to `func` positionally
            in the order the names are specified in the call.
        kwargs : keyword arguments
            All keyword arguments are passed to the plotting function.

        Returns
        -------
        self : FacetGrid object

        """
        import matplotlib.pyplot as plt

        for ax, namedict in zip(self.axes.flat, self.name_dicts.flat):
            if namedict is not None:
                # name_dicts hold coordinate *values*, so select by label
                # with .loc, exactly as map_dataarray does
                data = self.data.loc[namedict]
                plt.sca(ax)
                innerargs = [data[a].values for a in args]
                func(*innerargs, **kwargs)

        return self
def _infer_xy_labels(plotfunc, darray, x, y):
    """
    Determine x and y labels when some are missing. For use in _plot2d

    Parameters
    ----------
    plotfunc : callable
        The plotting function; only its ``__name__`` is used, in the error
        message for arrays that are not 2 dimensional.
    darray : DataArray
        Must be 2 dimensional; only ``darray.dims`` is read.
    x, y : string or None
        Dimension names requested for the x and y axes. A missing one is
        filled in with the remaining dimension. If both are missing,
        ``darray.dims[1]`` is used for x and ``darray.dims[0]`` for y.

    Returns
    -------
    xlab, ylab : strings
        Names of the dimensions to put on the x and y axes.

    Raises
    ------
    ValueError
        If darray does not have exactly two dimensions, or if x and y name
        the same dimension.
    KeyError
        If x or y is given but is not a dimension of darray.
    """
    dims = list(darray.dims)

    if len(dims) != 2:
        raise ValueError('{type} plots are for 2 dimensional DataArrays. '
                         'Passed DataArray has {ndim} dimensions'
                         .format(type=plotfunc.__name__, ndim=len(dims)))

    if x and x not in dims:
        raise KeyError('{0} is not a dimension of this DataArray. Use '
                       '{1} or {2} for x'
                       .format(x, *dims))

    if y and y not in dims:
        raise KeyError('{0} is not a dimension of this DataArray. Use '
                       '{1} or {2} for y'
                       .format(y, *dims))

    # Both axes on one dimension cannot describe a 2d plot; refuse instead
    # of silently producing a meaningless figure
    if x and x == y:
        raise ValueError('x and y cannot both be {0!r}; use {1} and {2}'
                         .format(x, *dims))

    # Get label names
    if x and y:
        return x, y

    if x:
        # y is whichever dimension is left over
        dims.remove(x)
        return x, dims[0]

    if y:
        # x is whichever dimension is left over
        dims.remove(y)
        return dims[0], y

    # Neither given: rows are y, columns are x
    ylab, xlab = dims
    return xlab, ylab
If None use darray.dims[0] ax : matplotlib axes object, optional If None, uses the current axis xincrease : None, True, or False, optional @@ -431,13 +486,13 @@ def _plot2d(plotfunc): artist : The same type of primitive artist that the wrapped matplotlib function returns - ''' + """ # Build on the original docstring plotfunc.__doc__ = '\n'.join((plotfunc.__doc__, commondoc)) @functools.wraps(plotfunc) - def newplotfunc(darray, ax=None, xincrease=None, yincrease=None, + def newplotfunc(darray, x=None, y=None, ax=None, xincrease=None, yincrease=None, add_colorbar=True, add_labels=True, vmin=None, vmax=None, cmap=None, center=None, robust=False, extend=None, levels=None, colors=None, **kwargs): @@ -463,28 +518,35 @@ def newplotfunc(darray, ax=None, xincrease=None, yincrease=None, if ax is None: ax = plt.gca() - # Handle the dimensions - try: - ylab, xlab = darray.dims - except ValueError: - raise ValueError('{name} plots are for 2 dimensional DataArrays. ' - 'Passed DataArray has {ndim} dimensions' - .format(name=plotfunc.__name__, ndim=len(darray.dims))) + xlab, ylab = _infer_xy_labels(plotfunc=plotfunc, darray=darray, x=x, y=y) + + # better to pass the ndarrays directly to plotting functions + xval = darray[xlab].values + yval = darray[ylab].values + zval = darray.to_masked_array(copy=False) - # some plotting functions only know how to handle ndarrays - x = darray[xlab].values - y = darray[ylab].values - z = darray.to_masked_array(copy=False) + # May need to transpose for correct x, y labels + if xlab == darray.dims[0]: + zval = zval.T - _ensure_plottable(x, y) + _ensure_plottable(xval, yval) if 'contour' in plotfunc.__name__ and levels is None: levels = 7 # this is the matplotlib default - filled = plotfunc.__name__ != 'contour' - cmap = colors if colors else cmap - cmap_params = _determine_cmap_params(z.data, vmin, vmax, cmap, center, - robust, extend, levels, filled) + cmap_kwargs = {'plot_data': zval.data, + 'vmin': vmin, + 'vmax': vmax, + 'cmap': colors if 
colors else cmap, + 'center': center, + 'robust': robust, + 'extend': extend, + 'levels': levels, + 'filled': plotfunc.__name__ != 'contour', + } + + cmap_params = _determine_cmap_params(**cmap_kwargs) + if 'contour' in plotfunc.__name__: # extend is a keyword argument only for contour and contourf, but # passing it to the colorbar is sufficient for imshow and @@ -495,7 +557,7 @@ def newplotfunc(darray, ax=None, xincrease=None, yincrease=None, # This allows the user to pass in a custom norm coming via kwargs kwargs.setdefault('norm', cmap_params['cnorm']) - ax, primitive = plotfunc(x, y, z, ax=ax, + ax, primitive = plotfunc(xval, yval, zval, ax=ax, cmap=cmap_params['cmap'], vmin=cmap_params['vmin'], vmax=cmap_params['vmax'], @@ -518,16 +580,16 @@ def newplotfunc(darray, ax=None, xincrease=None, yincrease=None, # For use as DataArray.plot.plotmethod @functools.wraps(newplotfunc) - def plotmethod(_PlotMethods_obj, ax=None, xincrease=None, yincrease=None, + def plotmethod(_PlotMethods_obj, x=None, y=None, ax=None, xincrease=None, yincrease=None, add_colorbar=True, add_labels=True, vmin=None, vmax=None, cmap=None, colors=None, center=None, robust=False, extend=None, levels=None, **kwargs): - ''' + """ The method should have the same signature as the function. This just makes the method work on Plotmethods objects, and passes all the other arguments straight through. 
def text_in_fig():
    """
    Return the set of all text in the current figure
    """
    alltxt = [t.get_text() for t in plt.gcf().findobj(mpl.text.Text)]
    # Set comprehension not compatible with Python 2.6
    return set(alltxt)


def substring_in_axes(substring, ax):
    """
    Return True if a substring is found anywhere in an axes
    """
    alltxt = set([t.get_text() for t in ax.findobj(mpl.text.Text)])
    return any(substring in txt for txt in alltxt)


def easy_array(shape, start=0, stop=1):
    """
    Make an array with desired shape using np.linspace

    shape is a tuple like (2, 3). The values run evenly from start to stop
    (both included) in row-major order.
    """
    # np.prod can return a float (it does for an empty shape), but linspace
    # needs an integer number of samples
    a = np.linspace(start, stop, num=int(np.prod(shape)))
    return a.reshape(shape)
@@ -148,7 +178,7 @@ def test_title_no_name(self): self.assertEqual('', plt.gca().get_title()) def test_title_uses_name(self): - self.darray.name = 'randompoints' + self.darray.name = 'testpoints' self.darray.plot.hist() self.assertIn(self.darray.name, plt.gca().get_title()) @@ -175,19 +205,21 @@ def test_plot_nans(self): @requires_matplotlib class TestDetermineCmapParams(TestCase): + + def setUp(self): + self.data = np.linspace(0, 1, num=100) + def test_robust(self): - data = np.random.RandomState(1).rand(100) - cmap_params = _determine_cmap_params(data, robust=True) - self.assertEqual(cmap_params['vmin'], np.percentile(data, 2)) - self.assertEqual(cmap_params['vmax'], np.percentile(data, 98)) + cmap_params = _determine_cmap_params(self.data, robust=True) + self.assertEqual(cmap_params['vmin'], np.percentile(self.data, 2)) + self.assertEqual(cmap_params['vmax'], np.percentile(self.data, 98)) self.assertEqual(cmap_params['cmap'].name, 'viridis') self.assertEqual(cmap_params['extend'], 'both') self.assertIsNone(cmap_params['levels']) self.assertIsNone(cmap_params['cnorm']) def test_center(self): - data = np.random.RandomState(2).rand(100) - cmap_params = _determine_cmap_params(data, center=0.5) + cmap_params = _determine_cmap_params(self.data, center=0.5) self.assertEqual(cmap_params['vmax'] - 0.5, 0.5 - cmap_params['vmin']) self.assertEqual(cmap_params['cmap'], 'RdBu_r') self.assertEqual(cmap_params['extend'], 'neither') @@ -195,7 +227,7 @@ def test_center(self): self.assertIsNone(cmap_params['cnorm']) def test_integer_levels(self): - data = 1 + np.random.RandomState(3).rand(100) + data = self.data + 1 cmap_params = _determine_cmap_params(data, levels=5, vmin=0, vmax=5, cmap='Blues') self.assertEqual(cmap_params['vmin'], cmap_params['levels'][0]) @@ -211,7 +243,7 @@ def test_integer_levels(self): self.assertEqual(cmap_params['extend'], 'max') def test_list_levels(self): - data = 1 + np.random.RandomState(3).rand(100) + data = self.data + 1 orig_levels = [0, 1, 2, 3, 
4, 5] # vmin and vmax should be ignored if levels are explicitly provided @@ -230,6 +262,7 @@ def test_list_levels(self): @requires_matplotlib class TestDiscreteColorMap(TestCase): + def setUp(self): x = np.arange(start=0, stop=10, step=2) y = np.arange(start=9, stop=-7, step=-3) @@ -314,9 +347,10 @@ class Common2dMixin: These tests assume that a staticmethod for `self.plotfunc` exists. Should have the same name as the method. """ + def setUp(self): - rs = np.random.RandomState(123) - self.darray = DataArray(rs.randn(10, 15), dims=['y', 'x']) + self.darray = DataArray(easy_array( + (10, 15), start=-1), dims=['y', 'x']) self.plotmethod = getattr(self.darray.plot, self.plotfunc.__name__) def test_label_names(self): @@ -329,12 +363,12 @@ def test_1d_raises_valueerror(self): self.plotfunc(self.darray[0, :]) def test_3d_raises_valueerror(self): - a = DataArray(np.random.randn(2, 3, 4)) + a = DataArray(easy_array((2, 3, 4))) with self.assertRaisesRegexp(ValueError, r'[Dd]im'): self.plotfunc(a) def test_nonnumeric_index_raises_typeerror(self): - a = DataArray(np.random.randn(3, 2), + a = DataArray(easy_array((3, 2)), coords=[['a', 'b', 'c'], ['d', 'e']]) with self.assertRaisesRegexp(TypeError, r'[Pp]lot'): self.plotfunc(a) @@ -380,7 +414,7 @@ def test_seaborn_palette_as_cmap(self): try: import seaborn cmap_name = self.plotmethod( - levels=2, cmap='husl').get_cmap().name + levels=2, cmap='husl').get_cmap().name self.assertEqual('husl', cmap_name) except ImportError: pass @@ -394,8 +428,34 @@ def test_diverging_color_limits(self): vmin, vmax = artist.get_clim() self.assertAlmostEqual(-vmin, vmax) + def test_xy_strings(self): + self.plotmethod('y', 'x') + ax = plt.gca() + self.assertEqual('y', ax.get_xlabel()) + self.assertEqual('x', ax.get_ylabel()) + + def test_positional_x_string(self): + self.plotmethod('y') + ax = plt.gca() + self.assertEqual('y', ax.get_xlabel()) + self.assertEqual('x', ax.get_ylabel()) + + def test_y_string(self): + self.plotmethod(y='x') + ax = 
plt.gca() + self.assertEqual('y', ax.get_xlabel()) + self.assertEqual('x', ax.get_ylabel()) + + def test_bad_x_string_exception(self): + with self.assertRaisesRegexp(KeyError, r'y'): + self.plotmethod('not_a_real_dim') + + self.darray.coords['z'] = 100 + with self.assertRaisesRegexp(KeyError, r'y'): + self.plotmethod('z') + def test_default_title(self): - a = DataArray(np.random.randn(4, 3, 2), dims=['a', 'b', 'c']) + a = DataArray(easy_array((4, 3, 2)), dims=['a', 'b', 'c']) a.coords['d'] = 10 self.plotfunc(a.isel(c=1)) title = plt.gca().get_title() @@ -404,19 +464,24 @@ def test_default_title(self): def test_colorbar_label(self): self.darray.name = 'testvar' self.plotmethod() - alltxt = [t.get_text() for t in plt.gcf().findobj(mpl.text.Text)] - # Set comprehension not compatible with Python 2.6 - alltxt = set(alltxt) - self.assertIn(self.darray.name, alltxt) + self.assertIn(self.darray.name, text_in_fig()) def test_no_labels(self): self.darray.name = 'testvar' self.plotmethod(add_labels=False) - alltxt = [t.get_text() for t in plt.gcf().findobj(mpl.text.Text)] - alltxt = set(alltxt) + alltxt = text_in_fig() for string in ['x', 'y', 'testvar']: self.assertNotIn(string, alltxt) + def test_facetgrid(self): + a = easy_array((10, 15, 3)) + d = DataArray(a, dims=['y', 'x', 'z']) + g = xplt.FacetGrid(d, col='z') + g.map_dataarray(self.plotfunc, 'x', 'y') + for ax in g.axes.flat: + self.assertTrue(ax.has_data()) + + class TestContourf(Common2dMixin, PlotTestCase): plotfunc = staticmethod(xplt.contourf) @@ -434,9 +499,13 @@ def test_extend(self): artist = self.plotmethod() self.assertEqual(artist.extend, 'neither') + self.darray[0, 0] = -100 + self.darray[-1, -1] = 100 artist = self.plotmethod(robust=True) self.assertEqual(artist.extend, 'both') + self.darray[0, 0] = 0 + self.darray[-1, -1] = 0 artist = self.plotmethod(vmin=-0, vmax=10) self.assertEqual(artist.extend, 'min') @@ -462,21 +531,22 @@ def _color_as_tuple(c): return tuple(c[:3]) artist = 
self.plotmethod(colors='k') self.assertEqual( - _color_as_tuple(artist.cmap.colors[0]), - (0.0,0.0,0.0)) + _color_as_tuple(artist.cmap.colors[0]), + (0.0, 0.0, 0.0)) - artist = self.plotmethod(colors=['k','b']) + artist = self.plotmethod(colors=['k', 'b']) self.assertEqual( - _color_as_tuple(artist.cmap.colors[1]), - (0.0,0.0,1.0)) + _color_as_tuple(artist.cmap.colors[1]), + (0.0, 0.0, 1.0)) def test_cmap_and_color_both(self): - with self.assertRaises(ValueError): + with self.assertRaises(ValueError): self.plotmethod(colors='k', cmap='RdBu') def list_of_colors_in_cmap_deprecated(self): - with self.assertRaises(DeprecationError): - self.plotmethod(cmap=['k','b']) + with self.assertRaises(Exception): + self.plotmethod(cmap=['k', 'b']) + class TestPcolormesh(Common2dMixin, PlotTestCase): @@ -524,3 +594,176 @@ def test_seaborn_palette_needs_levels(self): self.plotmethod(cmap='husl') except ImportError: pass + + +class TestFacetGrid(PlotTestCase): + + def setUp(self): + d = easy_array((10, 15, 3)) + self.darray = DataArray(d, dims=['y', 'x', 'z']) + self.g = xplt.FacetGrid(self.darray, col='z') + + def test_no_args(self): + self.g.map_dataarray(xplt.contourf, 'x', 'y') + + # Don't want colorbar labeled with 'None' + alltxt = text_in_fig() + self.assertNotIn('None', alltxt) + + for ax in self.g.axes.flat: + self.assertTrue(ax.has_data()) + + # default font size should be small + fontsize = ax.title.get_size() + self.assertLessEqual(fontsize, 12) + + def test_names_appear_somewhere(self): + self.darray.name = 'testvar' + self.g.map_dataarray(xplt.contourf, 'x', 'y') + for i, ax in enumerate(self.g.axes.flat): + self.assertEqual('z = {0}'.format(i), ax.get_title()) + + alltxt = text_in_fig() + self.assertIn(self.darray.name, alltxt) + for label in ['x', 'y']: + self.assertIn(label, alltxt) + + def test_text_not_super_long(self): + self.darray.coords['z'] = [100 * letter for letter in 'abc'] + g = xplt.FacetGrid(self.darray, col='z') + g.map_dataarray(xplt.contour, 'x', 
'y') + alltxt = text_in_fig() + maxlen = max(len(txt) for txt in alltxt) + self.assertLess(maxlen, 50) + + t0 = g.axes[0, 0].get_title() + self.assertTrue(t0.endswith('...')) + + def test_colorbar(self): + vmin = self.darray.values.min() + vmax = self.darray.values.max() + expected = np.array((vmin, vmax)) + + self.g.map_dataarray(xplt.imshow, 'x', 'y') + + for image in plt.gcf().findobj(mpl.image.AxesImage): + clim = np.array(image.get_clim()) + self.assertTrue(np.allclose(expected, clim)) + + # There's only one colorbar + cbar = plt.gcf().findobj(mpl.collections.QuadMesh) + self.assertEqual(1, len(cbar)) + + def test_empty_cell(self): + g = xplt.FacetGrid(self.darray, col='z', col_wrap=2) + g.map_dataarray(xplt.imshow, 'x', 'y') + + bottomright = g.axes[-1, -1] + self.assertFalse(bottomright.has_data()) + self.assertFalse(bottomright.get_visible()) + + def test_norow_nocol_error(self): + with self.assertRaisesRegexp(ValueError, r'[Rr]ow'): + xplt.FacetGrid(self.darray) + + def test_groups(self): + self.g.map_dataarray(xplt.imshow, 'x', 'y') + upperleft_dict = self.g.name_dicts[0, 0] + upperleft_array = self.darray[upperleft_dict] + z0 = self.darray.isel(z=0) + + self.assertDataArrayEqual(upperleft_array, z0) + + def test_float_index(self): + self.darray.coords['z'] = [0.1, 0.2, 0.4] + g = xplt.FacetGrid(self.darray, col='z') + g.map_dataarray(xplt.imshow, 'x', 'y') + + def test_nonunique_index_error(self): + self.darray.coords['z'] = [0.1, 0.2, 0.2] + with self.assertRaisesRegexp(ValueError, r'[Uu]nique'): + xplt.FacetGrid(self.darray, col='z') + + def test_robust(self): + z = np.zeros((20, 20, 2)) + darray = DataArray(z, dims=['y', 'x', 'z']) + darray[:, :, 1] = 1 + darray[2, 0, 0] = -1000 + darray[3, 0, 0] = 1000 + g = xplt.FacetGrid(darray, col='z') + g.map_dataarray(xplt.imshow, 'x', 'y', robust=True) + + # Color limits should be 0, 1 + # The largest number displayed in the figure should be less than 21 + numbers = set() + alltxt = text_in_fig() + for txt in 
alltxt: + try: + numbers.add(float(txt)) + except ValueError: + pass + largest = max(abs(x) for x in numbers) + self.assertLess(largest, 21) + + def test_can_set_vmin_vmax(self): + vmin, vmax = 50.0, 1000.0 + expected = np.array((vmin, vmax)) + self.g.map_dataarray(xplt.imshow, 'x', 'y', vmin=vmin, vmax=vmax) + + for image in plt.gcf().findobj(mpl.image.AxesImage): + clim = np.array(image.get_clim()) + self.assertTrue(np.allclose(expected, clim)) + + def test_figure_size(self): + + self.assertArrayEqual(self.g.fig.get_size_inches(), (10, 3)) + + g = xplt.FacetGrid(self.darray, col='z', size=6) + self.assertArrayEqual(g.fig.get_size_inches(), (19, 6)) + + g = xplt.FacetGrid(self.darray, col='z', size=4, aspect=0.5) + self.assertArrayEqual(g.fig.get_size_inches(), (7, 4)) + + def test_num_ticks(self): + nticks = 100 + maxticks = nticks + 1 + self.g.map_dataarray(xplt.imshow, 'x', 'y') + self.g.set_ticks(max_xticks=nticks, max_yticks=nticks) + + for ax in self.g.axes.flat: + xticks = len(ax.get_xticks()) + yticks = len(ax.get_yticks()) + self.assertLessEqual(xticks, maxticks) + self.assertLessEqual(yticks, maxticks) + self.assertGreaterEqual(xticks, nticks / 2.0) + self.assertGreaterEqual(yticks, nticks / 2.0) + + def test_map(self): + self.g.map(plt.contourf, 'x', 'y', Ellipsis) + + +class TestFacetGrid4d(PlotTestCase): + + def setUp(self): + a = easy_array((10, 15, 3, 2)) + darray = DataArray(a, dims=['y', 'x', 'col', 'row']) + darray.coords['col'] = np.array(['col' + str(x) for x in + darray.coords['col'].values]) + darray.coords['row'] = np.array(['row' + str(x) for x in + darray.coords['row'].values]) + + self.darray = darray + + def test_default_labels(self): + g = xplt.FacetGrid(self.darray, col='col', row='row') + self.assertEqual((2, 3), g.axes.shape) + + g.map_dataarray(xplt.imshow, 'x', 'y') + + # Rightmost column should be labeled + for label, ax in zip(self.darray.coords['row'].values, g.axes[:, -1]): + self.assertTrue(substring_in_axes(label, ax)) + + # 
Top row should be labeled + for label, ax in zip(self.darray.coords['col'].values, g.axes[0, :]): + self.assertTrue(substring_in_axes(label, ax))