diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst
index 4016eba65cd..a085970ee09 100644
--- a/doc/api-hidden.rst
+++ b/doc/api-hidden.rst
@@ -143,3 +143,8 @@
ufuncs.tan
ufuncs.tanh
ufuncs.trunc
+
+ plot.FacetGrid.map_dataarray
+ plot.FacetGrid.set_titles
+ plot.FacetGrid.set_ticks
+ plot.FacetGrid.map
diff --git a/doc/api.rst b/doc/api.rst
index f141ee43378..582df933d98 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -411,3 +411,5 @@ Plotting
plot.imshow
plot.line
plot.pcolormesh
+ plot.FacetGrid
+
diff --git a/doc/plotting.rst b/doc/plotting.rst
index 5ccc3dd855a..78de158e2f5 100644
--- a/doc/plotting.rst
+++ b/doc/plotting.rst
@@ -16,6 +16,9 @@ Xray's plotting capabilities are centered around
:py:class:`xray.DataArray` objects.
To plot :py:class:`xray.Dataset` objects
simply access the relevant DataArrays, ie ``dset['var1']``.
+Here we focus mostly on arrays 2d or larger. If your data fits
+nicely into a pandas DataFrame then you're better off using one of the more
+developed tools there.
Xray plotting functionality is a thin wrapper around the popular
`matplotlib `_ library.
@@ -51,10 +54,11 @@ The following imports are necessary for all of the examples.
.. ipython:: python
import numpy as np
+ import pandas as pd
import matplotlib.pyplot as plt
import xray
-We'll use the North American air temperature dataset.
+For these examples we'll use the North American air temperature dataset.
.. ipython:: python
@@ -306,6 +310,125 @@ since levels are chosen automatically).
air2d.plot(levels=10, cmap='husl')
+Faceting
+--------
+
+Faceting here refers to splitting an array along one or two dimensions and
+plotting each group.
+Xray's basic plotting is useful for plotting two dimensional arrays. What
+about three or four dimensional arrays? That's where facets become helpful.
+
+Consider the temperature data set. There are 4 observations per day for two
+years which makes for 2920 values along the time dimension.
+One way to visualize this data is to make a
+seperate plot for each time period.
+
+The faceted dimension should not have too many values;
+faceting on the time dimension will produce 2920 plots. That's
+too much to be helpful. To handle this situation try performing
+an operation that reduces the size of the data in some way. For example, we
+could compute the average air temperature for each month and reduce the
+size of this dimension from 2920 -> 12. A simpler way is
+to just take a slice on that dimension.
+So let's use a slice to pick 6 times throughout the first year.
+
+.. ipython:: python
+
+ t = air.isel(time=slice(0, 365 * 4, 250))
+ t.coords
+
+Simple Example
+~~~~~~~~~~~~~~
+
+TODO - replace with the convenience method from plot
+
+We can use :py:meth:`xray.plot.FacetGrid.map_dataarray` on a DataArray:
+
+.. ipython:: python
+
+ g = xray.plot.FacetGrid(t, col='time', col_wrap=3)
+
+ @savefig plot_facet_dataarray.png height=12in
+ g.map_dataarray(xray.plot.imshow, 'lon', 'lat')
+
+FacetGrid Objects
+~~~~~~~~~~~~~~~~~
+
+:py:class:`xray.plot.FacetGrid` is used to control the behavior of the
+multiple plots.
+It borrows an API and code from `Seaborn
+`_.
+The structure is contained within the ``axes`` and ``name_dicts``
+attributes, both 2d Numpy object arrays.
+
+.. ipython:: python
+
+ g.axes
+
+ g.name_dicts
+
+It's possible to select the :py:class:`xray.DataArray` corresponding to the FacetGrid
+through the ``name_dicts``.
+
+.. ipython:: python
+
+ g.data.loc[g.name_dicts[0, 0]]
+
+Here is an example of modifying the axes after they have been plotted.
+
+.. ipython:: python
+
+ g = (xray.plot
+ .FacetGrid(t, col='time', col_wrap=3)
+ .map_dataarray(xray.plot.imshow, 'lon', 'lat')
+ )
+
+ for i, ax in enumerate(g.axes.flat):
+ ax.set_title('Air Temperature %d' % i)
+
+ bottomright = g.axes[-1, -1]
+ bottomright.annotate('bottom right', (240, 40))
+
+ @savefig plot_facet_iterator.png height=12in
+ plt.show()
+
+4 dimensional
+~~~~~~~~~~~~~~
+
+For 4 dimensional arrays we can use the rows and columns of the grids.
+Here we create a 4 dimensional array by taking the original data and adding
+a fixed amount. Now we can see how the temperature maps would compare if
+one were much hotter.
+
+.. ipython:: python
+
+ t2 = t.isel(time=slice(0, 2))
+ t4d = xray.concat([t2, t2 + 40], pd.Index(['normal', 'hot'], name='fourth_dim'))
+ # This is a 4d array
+ t4d.coords
+
+ g = xray.plot.FacetGrid(t4d, col='time', row='fourth_dim')
+
+ @savefig plot_facet_4d.png height=12in
+ g.map_dataarray(xray.plot.imshow, 'lon', 'lat')
+
+Other features
+~~~~~~~~~~~~~~
+
+Faceted plotting supports other arguments common to xray 2d plots.
+
+.. ipython:: python
+
+ hasoutliers = t.isel(time=slice(0, 5)).copy()
+ hasoutliers[0, 0, 0] = -100
+ hasoutliers[-1, -1, -1] = 400
+
+ g = xray.plot.FacetGrid(hasoutliers, col='time', col_wrap=3)
+
+ @savefig plot_facet_robust.png height=12in
+ g.map_dataarray(xray.plot.contourf, 'lon', 'lat', robust=True, cmap='viridis')
+
+
Maps
----
diff --git a/xray/plot/__init__.py b/xray/plot/__init__.py
index ae2f6ad8f6f..3cd075b6b90 100644
--- a/xray/plot/__init__.py
+++ b/xray/plot/__init__.py
@@ -1,2 +1,4 @@
from .plot import (plot, line, contourf, contour,
hist, imshow, pcolormesh)
+
+from .facetgrid import FacetGrid
diff --git a/xray/plot/facetgrid.py b/xray/plot/facetgrid.py
new file mode 100644
index 00000000000..4e335f3b1d1
--- /dev/null
+++ b/xray/plot/facetgrid.py
@@ -0,0 +1,389 @@
+from __future__ import division
+
+import warnings
+import itertools
+import functools
+
+import numpy as np
+
+from ..core.formatting import format_item
+from .plot import _determine_cmap_params
+
+
+# Overrides axes.labelsize, xtick.major.size, ytick.major.size
+# from mpl.rcParams
+_FONTSIZE = 'small'
+# For major ticks on x, y axes
+_NTICKS = 5
+
+
+def _nicetitle(coord, value, maxchar, template):
+ """
+ Put coord, value in template and truncate at maxchar
+ """
+ prettyvalue = format_item(value)
+ title = template.format(coord=coord, value=prettyvalue)
+
+ if len(title) > maxchar:
+ title = title[:(maxchar - 3)] + '...'
+
+ return title
+
+
+class FacetGrid(object):
+ """
+ Initialize the matplotlib figure and FacetGrid object.
+
+ The :class:`FacetGrid` is an object that links a xray DataArray to
+ a matplotlib figure with a particular structure.
+
+ In particular, :class:`FacetGrid` is used to draw plots with multiple
+ Axes where each Axes shows the same relationship conditioned on
+ different levels of some dimension. It's possible to condition on up to
+ two variables by assigning variables to the rows and columns of the
+ grid.
+
+ The general approach to plotting here is called "small multiples",
+ where the same kind of plot is repeated multiple times, and the
+ specific use of small multiples to display the same relationship
+ conditioned on one ore more other variables is often called a "trellis
+ plot".
+
+ The basic workflow is to initialize the :class:`FacetGrid` object with
+ the DataArray and the variable names that are used to structure the grid.
+ Then plotting functions can be applied to each subset by calling
+ :meth:`FacetGrid.map_dataarray` or :meth:`FacetGrid.map`.
+
+ Attributes
+ ----------
+ axes : numpy object array
+ Contains axes in corresponding position, as returned from
+ plt.subplots
+ fig : matplotlib.Figure
+ The figure containing all the axes
+ name_dicts : numpy object array
+ Contains dictionaries mapping coordinate names to values. None is
+ used as a sentinel value for axes which should remain empty, ie.
+ sometimes the bottom right grid
+
+ """
+
+ def __init__(self, data, col=None, row=None, col_wrap=None,
+ aspect=1, size=3):
+ """
+ Parameters
+ ----------
+ data : DataArray
+ xray DataArray to be plotted
+ row, col : strings
+ Dimesion names that define subsets of the data, which will be drawn
+ on separate facets in the grid.
+ col_wrap : int, optional
+ "Wrap" the column variable at this width, so that the column facets
+ aspect : scalar, optional
+ Aspect ratio of each facet, so that ``aspect * size`` gives the
+ width of each facet in inches
+ size : scalar, optional
+ Height (in inches) of each facet. See also: ``aspect``
+
+ """
+
+ import matplotlib.pyplot as plt
+
+ # Handle corner case of nonunique coordinates
+ rep_col = col is not None and not data[col].to_index().is_unique
+ rep_row = row is not None and not data[row].to_index().is_unique
+ if rep_col or rep_row:
+ raise ValueError('Coordinates used for faceting cannot '
+ 'contain repeated (nonunique) values.')
+
+ # single_group is the grouping variable, if there is exactly one
+ if col and row:
+ single_group = False
+ nrow = len(data[row])
+ ncol = len(data[col])
+ nfacet = nrow * ncol
+ if col_wrap is not None:
+ warnings.warn('Ignoring col_wrap since both col and row '
+ 'were passed')
+ elif row and not col:
+ single_group = row
+ elif not row and col:
+ single_group = col
+ else:
+ raise ValueError(
+ 'Pass a coordinate name as an argument for row or col')
+
+ # Compute grid shape
+ if single_group:
+ nfacet = len(data[single_group])
+ if col:
+ # idea - could add heuristic for nice shapes like 3x4
+ ncol = nfacet
+ if row:
+ ncol = 1
+ if col_wrap is not None:
+ # Overrides previous settings
+ ncol = col_wrap
+ nrow = int(np.ceil(nfacet / ncol))
+
+ # Calculate the base figure size with extra horizontal space for a
+ # colorbar
+ cbar_space = 1
+ figsize = (ncol * size * aspect +
+ cbar_space, nrow * size)
+
+ fig, axes = plt.subplots(nrow, ncol,
+ sharex=True, sharey=True,
+ squeeze=False, figsize=figsize)
+
+ # Set up the lists of names for the row and column facet variables
+ col_names = list(data[col].values) if col else []
+ row_names = list(data[row].values) if row else []
+
+ if single_group:
+ full = [{single_group: x} for x in
+ data[single_group].values]
+ empty = [None for x in range(nrow * ncol - len(full))]
+ name_dicts = full + empty
+ else:
+ rowcols = itertools.product(row_names, col_names)
+ name_dicts = [{row: r, col: c} for r, c in rowcols]
+
+ name_dicts = np.array(name_dicts).reshape(nrow, ncol)
+
+ # Set up the class attributes
+ # ---------------------------
+
+ # First the public API
+ self.data = data
+ self.name_dicts = name_dicts
+ self.fig = fig
+ self.axes = axes
+ self.row_names = row_names
+ self.col_names = col_names
+
+ # Next the private variables
+ self._single_group = single_group
+ self._nrow = nrow
+ self._row_var = row
+ self._ncol = ncol
+ self._col_var = col
+ self._col_wrap = col_wrap
+ self._x_var = None
+ self._y_var = None
+
+ self.set_titles()
+
+ def map_dataarray(self, func, x, y, **kwargs):
+ """
+ Apply a plotting function to a 2d facet's subset of the data.
+
+ This is more convenient and less general than ``FacetGrid.map``
+
+ Parameters
+ ----------
+ func : callable
+ A plotting function with the same signature as a 2d xray
+ plotting method such as `xray.plot.imshow`
+ x, y : string
+ Names of the coordinates to plot on x, y axes
+ kwargs :
+ additional keyword arguments to func
+
+ Returns
+ -------
+ self : FacetGrid object
+
+ """
+
+ # These should be consistent with xray.plot._plot2d
+ cmap_kwargs = {'plot_data': self.data.values,
+ 'vmin': None,
+ 'vmax': None,
+ 'cmap': None,
+ 'center': None,
+ 'robust': False,
+ 'extend': None,
+ # MPL default
+ 'levels': 7 if 'contour' in func.__name__ else None,
+ 'filled': func.__name__ != 'contour',
+ }
+
+ # Allow kwargs to override these defaults
+ for param in kwargs:
+ if param in cmap_kwargs:
+ cmap_kwargs[param] = kwargs[param]
+
+ # colormap inference has to happen here since all the data in
+ # self.data is required to make the right choice
+ cmap_params = _determine_cmap_params(**cmap_kwargs)
+
+ if 'contour' in func.__name__:
+ # extend is a keyword argument only for contour and contourf, but
+ # passing it to the colorbar is sufficient for imshow and
+ # pcolormesh
+ kwargs['extend'] = cmap_params['extend']
+ kwargs['levels'] = cmap_params['levels']
+
+ defaults = {
+ 'add_colorbar': False,
+ 'add_labels': False,
+ 'norm': cmap_params.pop('cnorm'),
+ }
+
+ # Order is important
+ defaults.update(cmap_params)
+ defaults.update(kwargs)
+
+ for d, ax in zip(self.name_dicts.flat, self.axes.flat):
+ # None is the sentinel value
+ if d is not None:
+ subset = self.data.loc[d]
+ mappable = func(subset, x, y, ax=ax, **defaults)
+
+ # Left side labels
+ for ax in self.axes[:, 0]:
+ ax.set_ylabel(y)
+
+ # Bottom labels
+ for ax in self.axes[-1, :]:
+ ax.set_xlabel(x)
+
+ self.fig.tight_layout()
+
+ if self._single_group:
+ for d, ax in zip(self.name_dicts.flat, self.axes.flat):
+ if d is None:
+ ax.set_visible(False)
+
+ # colorbar
+ if kwargs.get('add_colorbar', True):
+ cbar = self.fig.colorbar(mappable,
+ ax=list(self.axes.flat),
+ extend=cmap_params['extend'])
+
+ if self.data.name:
+ cbar.set_label(self.data.name, rotation=270,
+ verticalalignment='bottom')
+
+ self._x_var = x
+ self._y_var = y
+
+ return self
+
+ def set_titles(self, template="{coord} = {value}", maxchar=30,
+ fontsize=_FONTSIZE, **kwargs):
+ """
+ Draw titles either above each facet or on the grid margins.
+
+ Parameters
+ ----------
+ template : string
+ Template for plot titles containing {coord} and {value}
+ maxchar : int
+ Truncate titles at maxchar
+ fontsize : string or int
+ Passed to matplotlib.text
+ kwargs : keyword args
+ additional arguments to matplotlib.text
+
+ Returns
+ -------
+ self: FacetGrid object
+
+ """
+
+ kwargs['fontsize'] = fontsize
+
+ nicetitle = functools.partial(_nicetitle, maxchar=maxchar,
+ template=template)
+
+ if self._single_group:
+ for d, ax in zip(self.name_dicts.flat, self.axes.flat):
+ # Only label the ones with data
+ if d is not None:
+ coord, value = list(d.items()).pop()
+ title = nicetitle(coord, value, maxchar=maxchar)
+ ax.set_title(title, **kwargs)
+ else:
+ # The row titles on the right edge of the grid
+ for ax, row_name in zip(self.axes[:, -1], self.row_names):
+ title = nicetitle(coord=self._row_var, value=row_name,
+ maxchar=maxchar)
+ ax.annotate(title, xy=(1.02, .5), xycoords="axes fraction",
+ rotation=270, ha="left", va="center", **kwargs)
+
+ # The column titles on the top row
+ for ax, col_name in zip(self.axes[0, :], self.col_names):
+ title = nicetitle(coord=self._col_var, value=col_name,
+ maxchar=maxchar)
+ ax.set_title(title, **kwargs)
+
+ return self
+
+ def set_ticks(self, max_xticks=_NTICKS, max_yticks=_NTICKS,
+ fontsize=_FONTSIZE):
+ """
+ Set and control tick behavior
+
+ Parameters
+ ----------
+ max_xticks, max_yticks : int, optional
+ Maximum number of labeled ticks to plot on x, y axes
+ fontsize : string or int
+ Font size as used by matplotlib text
+
+ Returns
+ -------
+ self : FacetGrid object
+
+ """
+ from matplotlib.ticker import MaxNLocator
+
+ # Both are necessary
+ x_major_locator = MaxNLocator(nbins=max_xticks)
+ y_major_locator = MaxNLocator(nbins=max_yticks)
+
+ for ax in self.axes.flat:
+ ax.xaxis.set_major_locator(x_major_locator)
+ ax.yaxis.set_major_locator(y_major_locator)
+ for tick in itertools.chain(ax.xaxis.get_major_ticks(),
+ ax.yaxis.get_major_ticks()):
+ tick.label.set_fontsize(fontsize)
+
+ return self
+
+ def map(self, func, *args, **kwargs):
+ """
+ Apply a plotting function to each facet's subset of the data.
+
+ Parameters
+ ----------
+ func : callable
+ A plotting function that takes data and keyword arguments. It
+ must plot to the currently active matplotlib Axes and take a
+ `color` keyword argument. If faceting on the `hue` dimension,
+ it must also take a `label` keyword argument.
+ args : strings
+ Column names in self.data that identify variables with data to
+ plot. The data for each variable is passed to `func` in the
+ order the variables are specified in the call.
+ kwargs : keyword arguments
+ All keyword arguments are passed to the plotting function.
+
+ Returns
+ -------
+ self : FacetGrid object
+
+ """
+ import matplotlib.pyplot as plt
+
+ for ax, namedict in zip(self.axes.flat, self.name_dicts.flat):
+ if namedict is not None:
+ data = self.data[namedict]
+ plt.sca(ax)
+ innerargs = [data[a].values for a in args]
+ func(*innerargs, **kwargs)
+
+ return self
diff --git a/xray/plot/plot.py b/xray/plot/plot.py
index cc65a375e10..ab18cae878f 100644
--- a/xray/plot/plot.py
+++ b/xray/plot/plot.py
@@ -6,8 +6,12 @@
DataArray.plot._____
"""
+from __future__ import division
import pkg_resources
import functools
+from textwrap import dedent
+from itertools import cycle
+from distutils.version import LooseVersion
import warnings
import numpy as np
@@ -17,11 +21,6 @@
from ..core.pycompat import basestring
-# TODO - implement this
-class FacetGrid():
- pass
-
-
# Maybe more appropriate to keep this in .utils
def _right_dtype(arr, types):
"""
@@ -56,6 +55,47 @@ def _load_default_cmap(fname='default_colormap.csv'):
return LinearSegmentedColormap.from_list('viridis', cm_data)
+def _infer_xy_labels(plotfunc, darray, x, y):
+ """
+ Determine x and y labels when some are missing. For use in _plot2d
+
+ darray is a 2 dimensional data array.
+ """
+ dims = list(darray.dims)
+
+ if len(dims) != 2:
+ raise ValueError('{type} plots are for 2 dimensional DataArrays. '
+ 'Passed DataArray has {ndim} dimensions'
+ .format(type=plotfunc.__name__, ndim=len(dims)))
+
+ if x and x not in dims:
+ raise KeyError('{0} is not a dimension of this DataArray. Use '
+ '{1} or {2} for x'
+ .format(x, *dims))
+
+ if y and y not in dims:
+ raise KeyError('{0} is not a dimension of this DataArray. Use '
+ '{1} or {2} for y'
+ .format(y, *dims))
+
+ # Get label names
+ if x and y:
+ xlab = x
+ ylab = y
+ elif x and not y:
+ xlab = x
+ del dims[dims.index(x)]
+ ylab = dims.pop()
+ elif y and not x:
+ ylab = y
+ del dims[dims.index(y)]
+ xlab = dims.pop()
+ else:
+ ylab, xlab = dims
+
+ return xlab, ylab
+
+
def plot(darray, ax=None, rtol=0.01, **kwargs):
"""
Default plot of DataArray using matplotlib / pylab.
@@ -218,11 +258,22 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
Adapted from Seaborn:
https://github.com/mwaskom/seaborn/blob/v0.6/seaborn/matrix.py#L158
+
+ Parameters
+ ==========
+ plot_data: Numpy array
+ Doesn't handle xray objects
+
+ Returns
+ =======
+ cmap_params : dict
+ Use depends on the type of the plotting function
"""
ROBUST_PERCENTILE = 2.0
import matplotlib as mpl
- calc_data = plot_data[~pd.isnull(plot_data)]
+ calc_data = np.ravel(plot_data[~pd.isnull(plot_data)])
+
if vmin is None:
vmin = np.percentile(calc_data, ROBUST_PERCENTILE) if robust else calc_data.min()
if vmax is None:
@@ -354,10 +405,10 @@ def _build_discrete_cmap(cmap, levels, extend, filled):
# MUST run before any 2d plotting functions are defined since
# _plot2d decorator adds them as methods here.
class _PlotMethods(object):
- '''
+ """
Enables use of xray.plot functions as attributes on a DataArray.
For example, DataArray.plot.imshow
- '''
+ """
def __init__(self, DataArray_instance):
self._da = DataArray_instance
@@ -380,11 +431,15 @@ def _plot2d(plotfunc):
Also adds the 2d plot method to class _PlotMethods
"""
- commondoc = '''
+ commondoc = """
Parameters
----------
darray : DataArray
- must be 2 dimensional.
+ Must be 2 dimensional
+ x : string, optional
+ Coordinate for x axis. If None use darray.dims[1]
+ y : string, optional
+ Coordinate for y axis. If None use darray.dims[0]
ax : matplotlib axes object, optional
If None, uses the current axis
xincrease : None, True, or False, optional
@@ -431,13 +486,13 @@ def _plot2d(plotfunc):
artist :
The same type of primitive artist that the wrapped matplotlib
function returns
- '''
+ """
# Build on the original docstring
plotfunc.__doc__ = '\n'.join((plotfunc.__doc__, commondoc))
@functools.wraps(plotfunc)
- def newplotfunc(darray, ax=None, xincrease=None, yincrease=None,
+ def newplotfunc(darray, x=None, y=None, ax=None, xincrease=None, yincrease=None,
add_colorbar=True, add_labels=True, vmin=None, vmax=None, cmap=None,
center=None, robust=False, extend=None, levels=None, colors=None,
**kwargs):
@@ -463,28 +518,35 @@ def newplotfunc(darray, ax=None, xincrease=None, yincrease=None,
if ax is None:
ax = plt.gca()
- # Handle the dimensions
- try:
- ylab, xlab = darray.dims
- except ValueError:
- raise ValueError('{name} plots are for 2 dimensional DataArrays. '
- 'Passed DataArray has {ndim} dimensions'
- .format(name=plotfunc.__name__, ndim=len(darray.dims)))
+ xlab, ylab = _infer_xy_labels(plotfunc=plotfunc, darray=darray, x=x, y=y)
+
+ # better to pass the ndarrays directly to plotting functions
+ xval = darray[xlab].values
+ yval = darray[ylab].values
+ zval = darray.to_masked_array(copy=False)
- # some plotting functions only know how to handle ndarrays
- x = darray[xlab].values
- y = darray[ylab].values
- z = darray.to_masked_array(copy=False)
+ # May need to transpose for correct x, y labels
+ if xlab == darray.dims[0]:
+ zval = zval.T
- _ensure_plottable(x, y)
+ _ensure_plottable(xval, yval)
if 'contour' in plotfunc.__name__ and levels is None:
levels = 7 # this is the matplotlib default
- filled = plotfunc.__name__ != 'contour'
- cmap = colors if colors else cmap
- cmap_params = _determine_cmap_params(z.data, vmin, vmax, cmap, center,
- robust, extend, levels, filled)
+ cmap_kwargs = {'plot_data': zval.data,
+ 'vmin': vmin,
+ 'vmax': vmax,
+ 'cmap': colors if colors else cmap,
+ 'center': center,
+ 'robust': robust,
+ 'extend': extend,
+ 'levels': levels,
+ 'filled': plotfunc.__name__ != 'contour',
+ }
+
+ cmap_params = _determine_cmap_params(**cmap_kwargs)
+
if 'contour' in plotfunc.__name__:
# extend is a keyword argument only for contour and contourf, but
# passing it to the colorbar is sufficient for imshow and
@@ -495,7 +557,7 @@ def newplotfunc(darray, ax=None, xincrease=None, yincrease=None,
# This allows the user to pass in a custom norm coming via kwargs
kwargs.setdefault('norm', cmap_params['cnorm'])
- ax, primitive = plotfunc(x, y, z, ax=ax,
+ ax, primitive = plotfunc(xval, yval, zval, ax=ax,
cmap=cmap_params['cmap'],
vmin=cmap_params['vmin'],
vmax=cmap_params['vmax'],
@@ -518,16 +580,16 @@ def newplotfunc(darray, ax=None, xincrease=None, yincrease=None,
# For use as DataArray.plot.plotmethod
@functools.wraps(newplotfunc)
- def plotmethod(_PlotMethods_obj, ax=None, xincrease=None, yincrease=None,
+ def plotmethod(_PlotMethods_obj, x=None, y=None, ax=None, xincrease=None, yincrease=None,
add_colorbar=True, add_labels=True, vmin=None, vmax=None, cmap=None,
colors=None, center=None, robust=False, extend=None, levels=None,
**kwargs):
- '''
+ """
The method should have the same signature as the function.
This just makes the method work on Plotmethods objects,
and passes all the other arguments straight through.
- '''
+ """
allargs = locals()
allargs['darray'] = _PlotMethods_obj._da
allargs.update(kwargs)
diff --git a/xray/test/test_plot.py b/xray/test/test_plot.py
index a6c56e34867..c533b8e9bba 100644
--- a/xray/test/test_plot.py
+++ b/xray/test/test_plot.py
@@ -21,6 +21,36 @@
pass
+def text_in_fig():
+ '''
+ Return the set of all text in the figure
+ '''
+ alltxt = [t.get_text() for t in plt.gcf().findobj(mpl.text.Text)]
+ # Set comprehension not compatible with Python 2.6
+ return set(alltxt)
+
+
+def substring_in_axes(substring, ax):
+ '''
+ Return True if a substring is found anywhere in an axes
+ '''
+ alltxt = set([t.get_text() for t in ax.findobj(mpl.text.Text)])
+ for txt in alltxt:
+ if substring in txt:
+ return True
+ return False
+
+
+def easy_array(shape, start=0, stop=1):
+ '''
+ Make an array with desired shape using np.linspace
+
+ shape is a tuple like (2, 3)
+ '''
+ a = np.linspace(start, stop, num=np.prod(shape))
+ return a.reshape(shape)
+
+
@requires_matplotlib
class PlotTestCase(TestCase):
@@ -47,13 +77,13 @@ def contourf_called(self, plotmethod):
class TestPlot(PlotTestCase):
def setUp(self):
- self.darray = DataArray(np.random.randn(2, 3, 4))
+ self.darray = DataArray(easy_array((2, 3, 4)))
def test1d(self):
self.darray[:, 0, 0].plot()
def test_2d_before_squeeze(self):
- a = DataArray(np.arange(5).reshape(1, 5))
+ a = DataArray(easy_array((1, 5)))
a.plot()
def test2d_uniform_calls_imshow(self):
@@ -98,7 +128,7 @@ def test_ylabel_is_data_name(self):
self.assertEqual(self.darray.name, plt.gca().get_ylabel())
def test_wrong_dims_raises_valueerror(self):
- twodims = DataArray(np.arange(10).reshape(2, 5))
+ twodims = DataArray(easy_array((2, 5)))
with self.assertRaises(ValueError):
twodims.plot.line()
@@ -138,7 +168,7 @@ def test_slice_in_title(self):
class TestPlotHistogram(PlotTestCase):
def setUp(self):
- self.darray = DataArray(np.random.randn(2, 3, 4))
+ self.darray = DataArray(easy_array((2, 3, 4)))
def test_3d_array(self):
self.darray.plot.hist()
@@ -148,7 +178,7 @@ def test_title_no_name(self):
self.assertEqual('', plt.gca().get_title())
def test_title_uses_name(self):
- self.darray.name = 'randompoints'
+ self.darray.name = 'testpoints'
self.darray.plot.hist()
self.assertIn(self.darray.name, plt.gca().get_title())
@@ -175,19 +205,21 @@ def test_plot_nans(self):
@requires_matplotlib
class TestDetermineCmapParams(TestCase):
+
+ def setUp(self):
+ self.data = np.linspace(0, 1, num=100)
+
def test_robust(self):
- data = np.random.RandomState(1).rand(100)
- cmap_params = _determine_cmap_params(data, robust=True)
- self.assertEqual(cmap_params['vmin'], np.percentile(data, 2))
- self.assertEqual(cmap_params['vmax'], np.percentile(data, 98))
+ cmap_params = _determine_cmap_params(self.data, robust=True)
+ self.assertEqual(cmap_params['vmin'], np.percentile(self.data, 2))
+ self.assertEqual(cmap_params['vmax'], np.percentile(self.data, 98))
self.assertEqual(cmap_params['cmap'].name, 'viridis')
self.assertEqual(cmap_params['extend'], 'both')
self.assertIsNone(cmap_params['levels'])
self.assertIsNone(cmap_params['cnorm'])
def test_center(self):
- data = np.random.RandomState(2).rand(100)
- cmap_params = _determine_cmap_params(data, center=0.5)
+ cmap_params = _determine_cmap_params(self.data, center=0.5)
self.assertEqual(cmap_params['vmax'] - 0.5, 0.5 - cmap_params['vmin'])
self.assertEqual(cmap_params['cmap'], 'RdBu_r')
self.assertEqual(cmap_params['extend'], 'neither')
@@ -195,7 +227,7 @@ def test_center(self):
self.assertIsNone(cmap_params['cnorm'])
def test_integer_levels(self):
- data = 1 + np.random.RandomState(3).rand(100)
+ data = self.data + 1
cmap_params = _determine_cmap_params(data, levels=5, vmin=0, vmax=5,
cmap='Blues')
self.assertEqual(cmap_params['vmin'], cmap_params['levels'][0])
@@ -211,7 +243,7 @@ def test_integer_levels(self):
self.assertEqual(cmap_params['extend'], 'max')
def test_list_levels(self):
- data = 1 + np.random.RandomState(3).rand(100)
+ data = self.data + 1
orig_levels = [0, 1, 2, 3, 4, 5]
# vmin and vmax should be ignored if levels are explicitly provided
@@ -230,6 +262,7 @@ def test_list_levels(self):
@requires_matplotlib
class TestDiscreteColorMap(TestCase):
+
def setUp(self):
x = np.arange(start=0, stop=10, step=2)
y = np.arange(start=9, stop=-7, step=-3)
@@ -314,9 +347,10 @@ class Common2dMixin:
These tests assume that a staticmethod for `self.plotfunc` exists.
Should have the same name as the method.
"""
+
def setUp(self):
- rs = np.random.RandomState(123)
- self.darray = DataArray(rs.randn(10, 15), dims=['y', 'x'])
+ self.darray = DataArray(easy_array(
+ (10, 15), start=-1), dims=['y', 'x'])
self.plotmethod = getattr(self.darray.plot, self.plotfunc.__name__)
def test_label_names(self):
@@ -329,12 +363,12 @@ def test_1d_raises_valueerror(self):
self.plotfunc(self.darray[0, :])
def test_3d_raises_valueerror(self):
- a = DataArray(np.random.randn(2, 3, 4))
+ a = DataArray(easy_array((2, 3, 4)))
with self.assertRaisesRegexp(ValueError, r'[Dd]im'):
self.plotfunc(a)
def test_nonnumeric_index_raises_typeerror(self):
- a = DataArray(np.random.randn(3, 2),
+ a = DataArray(easy_array((3, 2)),
coords=[['a', 'b', 'c'], ['d', 'e']])
with self.assertRaisesRegexp(TypeError, r'[Pp]lot'):
self.plotfunc(a)
@@ -380,7 +414,7 @@ def test_seaborn_palette_as_cmap(self):
try:
import seaborn
cmap_name = self.plotmethod(
- levels=2, cmap='husl').get_cmap().name
+ levels=2, cmap='husl').get_cmap().name
self.assertEqual('husl', cmap_name)
except ImportError:
pass
@@ -394,8 +428,34 @@ def test_diverging_color_limits(self):
vmin, vmax = artist.get_clim()
self.assertAlmostEqual(-vmin, vmax)
+ def test_xy_strings(self):
+ self.plotmethod('y', 'x')
+ ax = plt.gca()
+ self.assertEqual('y', ax.get_xlabel())
+ self.assertEqual('x', ax.get_ylabel())
+
+ def test_positional_x_string(self):
+ self.plotmethod('y')
+ ax = plt.gca()
+ self.assertEqual('y', ax.get_xlabel())
+ self.assertEqual('x', ax.get_ylabel())
+
+ def test_y_string(self):
+ self.plotmethod(y='x')
+ ax = plt.gca()
+ self.assertEqual('y', ax.get_xlabel())
+ self.assertEqual('x', ax.get_ylabel())
+
+ def test_bad_x_string_exception(self):
+ with self.assertRaisesRegexp(KeyError, r'y'):
+ self.plotmethod('not_a_real_dim')
+
+ self.darray.coords['z'] = 100
+ with self.assertRaisesRegexp(KeyError, r'y'):
+ self.plotmethod('z')
+
def test_default_title(self):
- a = DataArray(np.random.randn(4, 3, 2), dims=['a', 'b', 'c'])
+ a = DataArray(easy_array((4, 3, 2)), dims=['a', 'b', 'c'])
a.coords['d'] = 10
self.plotfunc(a.isel(c=1))
title = plt.gca().get_title()
@@ -404,19 +464,24 @@ def test_default_title(self):
def test_colorbar_label(self):
self.darray.name = 'testvar'
self.plotmethod()
- alltxt = [t.get_text() for t in plt.gcf().findobj(mpl.text.Text)]
- # Set comprehension not compatible with Python 2.6
- alltxt = set(alltxt)
- self.assertIn(self.darray.name, alltxt)
+ self.assertIn(self.darray.name, text_in_fig())
def test_no_labels(self):
self.darray.name = 'testvar'
self.plotmethod(add_labels=False)
- alltxt = [t.get_text() for t in plt.gcf().findobj(mpl.text.Text)]
- alltxt = set(alltxt)
+ alltxt = text_in_fig()
for string in ['x', 'y', 'testvar']:
self.assertNotIn(string, alltxt)
+ def test_facetgrid(self):
+ a = easy_array((10, 15, 3))
+ d = DataArray(a, dims=['y', 'x', 'z'])
+ g = xplt.FacetGrid(d, col='z')
+ g.map_dataarray(self.plotfunc, 'x', 'y')
+ for ax in g.axes.flat:
+ self.assertTrue(ax.has_data())
+
+
class TestContourf(Common2dMixin, PlotTestCase):
plotfunc = staticmethod(xplt.contourf)
@@ -434,9 +499,13 @@ def test_extend(self):
artist = self.plotmethod()
self.assertEqual(artist.extend, 'neither')
+ self.darray[0, 0] = -100
+ self.darray[-1, -1] = 100
artist = self.plotmethod(robust=True)
self.assertEqual(artist.extend, 'both')
+ self.darray[0, 0] = 0
+ self.darray[-1, -1] = 0
artist = self.plotmethod(vmin=-0, vmax=10)
self.assertEqual(artist.extend, 'min')
@@ -462,21 +531,22 @@ def _color_as_tuple(c):
return tuple(c[:3])
artist = self.plotmethod(colors='k')
self.assertEqual(
- _color_as_tuple(artist.cmap.colors[0]),
- (0.0,0.0,0.0))
+ _color_as_tuple(artist.cmap.colors[0]),
+ (0.0, 0.0, 0.0))
- artist = self.plotmethod(colors=['k','b'])
+ artist = self.plotmethod(colors=['k', 'b'])
self.assertEqual(
- _color_as_tuple(artist.cmap.colors[1]),
- (0.0,0.0,1.0))
+ _color_as_tuple(artist.cmap.colors[1]),
+ (0.0, 0.0, 1.0))
def test_cmap_and_color_both(self):
- with self.assertRaises(ValueError):
+ with self.assertRaises(ValueError):
self.plotmethod(colors='k', cmap='RdBu')
def list_of_colors_in_cmap_deprecated(self):
- with self.assertRaises(DeprecationError):
- self.plotmethod(cmap=['k','b'])
+ with self.assertRaises(Exception):
+ self.plotmethod(cmap=['k', 'b'])
+
class TestPcolormesh(Common2dMixin, PlotTestCase):
@@ -524,3 +594,176 @@ def test_seaborn_palette_needs_levels(self):
self.plotmethod(cmap='husl')
except ImportError:
pass
+
+
+class TestFacetGrid(PlotTestCase):
+
+ def setUp(self):
+ d = easy_array((10, 15, 3))
+ self.darray = DataArray(d, dims=['y', 'x', 'z'])
+ self.g = xplt.FacetGrid(self.darray, col='z')
+
+ def test_no_args(self):
+ self.g.map_dataarray(xplt.contourf, 'x', 'y')
+
+ # Don't want colorbar labeled with 'None'
+ alltxt = text_in_fig()
+ self.assertNotIn('None', alltxt)
+
+ for ax in self.g.axes.flat:
+ self.assertTrue(ax.has_data())
+
+ # default font size should be small
+ fontsize = ax.title.get_size()
+ self.assertLessEqual(fontsize, 12)
+
+ def test_names_appear_somewhere(self):
+ self.darray.name = 'testvar'
+ self.g.map_dataarray(xplt.contourf, 'x', 'y')
+ for i, ax in enumerate(self.g.axes.flat):
+ self.assertEqual('z = {0}'.format(i), ax.get_title())
+
+ alltxt = text_in_fig()
+ self.assertIn(self.darray.name, alltxt)
+ for label in ['x', 'y']:
+ self.assertIn(label, alltxt)
+
+ def test_text_not_super_long(self):
+ self.darray.coords['z'] = [100 * letter for letter in 'abc']
+ g = xplt.FacetGrid(self.darray, col='z')
+ g.map_dataarray(xplt.contour, 'x', 'y')
+ alltxt = text_in_fig()
+ maxlen = max(len(txt) for txt in alltxt)
+ self.assertLess(maxlen, 50)
+
+ t0 = g.axes[0, 0].get_title()
+ self.assertTrue(t0.endswith('...'))
+
+ def test_colorbar(self):
+ vmin = self.darray.values.min()
+ vmax = self.darray.values.max()
+ expected = np.array((vmin, vmax))
+
+ self.g.map_dataarray(xplt.imshow, 'x', 'y')
+
+ for image in plt.gcf().findobj(mpl.image.AxesImage):
+ clim = np.array(image.get_clim())
+ self.assertTrue(np.allclose(expected, clim))
+
+ # There's only one colorbar
+ cbar = plt.gcf().findobj(mpl.collections.QuadMesh)
+ self.assertEqual(1, len(cbar))
+
+ def test_empty_cell(self):
+ g = xplt.FacetGrid(self.darray, col='z', col_wrap=2)
+ g.map_dataarray(xplt.imshow, 'x', 'y')
+
+ bottomright = g.axes[-1, -1]
+ self.assertFalse(bottomright.has_data())
+ self.assertFalse(bottomright.get_visible())
+
+ def test_norow_nocol_error(self):
+ with self.assertRaisesRegexp(ValueError, r'[Rr]ow'):
+ xplt.FacetGrid(self.darray)
+
+ def test_groups(self):
+ self.g.map_dataarray(xplt.imshow, 'x', 'y')
+ upperleft_dict = self.g.name_dicts[0, 0]
+ upperleft_array = self.darray[upperleft_dict]
+ z0 = self.darray.isel(z=0)
+
+ self.assertDataArrayEqual(upperleft_array, z0)
+
+ def test_float_index(self):
+ self.darray.coords['z'] = [0.1, 0.2, 0.4]
+ g = xplt.FacetGrid(self.darray, col='z')
+ g.map_dataarray(xplt.imshow, 'x', 'y')
+
+ def test_nonunique_index_error(self):
+ self.darray.coords['z'] = [0.1, 0.2, 0.2]
+ with self.assertRaisesRegexp(ValueError, r'[Uu]nique'):
+ xplt.FacetGrid(self.darray, col='z')
+
+ def test_robust(self):
+ z = np.zeros((20, 20, 2))
+ darray = DataArray(z, dims=['y', 'x', 'z'])
+ darray[:, :, 1] = 1
+ darray[2, 0, 0] = -1000
+ darray[3, 0, 0] = 1000
+ g = xplt.FacetGrid(darray, col='z')
+ g.map_dataarray(xplt.imshow, 'x', 'y', robust=True)
+
+ # Color limits should be 0, 1
+ # The largest number displayed in the figure should be less than 21
+ numbers = set()
+ alltxt = text_in_fig()
+ for txt in alltxt:
+ try:
+ numbers.add(float(txt))
+ except ValueError:
+ pass
+ largest = max(abs(x) for x in numbers)
+ self.assertLess(largest, 21)
+
+ def test_can_set_vmin_vmax(self):
+ vmin, vmax = 50.0, 1000.0
+ expected = np.array((vmin, vmax))
+ self.g.map_dataarray(xplt.imshow, 'x', 'y', vmin=vmin, vmax=vmax)
+
+ for image in plt.gcf().findobj(mpl.image.AxesImage):
+ clim = np.array(image.get_clim())
+ self.assertTrue(np.allclose(expected, clim))
+
+ def test_figure_size(self):
+
+ self.assertArrayEqual(self.g.fig.get_size_inches(), (10, 3))
+
+ g = xplt.FacetGrid(self.darray, col='z', size=6)
+ self.assertArrayEqual(g.fig.get_size_inches(), (19, 6))
+
+ g = xplt.FacetGrid(self.darray, col='z', size=4, aspect=0.5)
+ self.assertArrayEqual(g.fig.get_size_inches(), (7, 4))
+
+ def test_num_ticks(self):
+ nticks = 100
+ maxticks = nticks + 1
+ self.g.map_dataarray(xplt.imshow, 'x', 'y')
+ self.g.set_ticks(max_xticks=nticks, max_yticks=nticks)
+
+ for ax in self.g.axes.flat:
+ xticks = len(ax.get_xticks())
+ yticks = len(ax.get_yticks())
+ self.assertLessEqual(xticks, maxticks)
+ self.assertLessEqual(yticks, maxticks)
+ self.assertGreaterEqual(xticks, nticks / 2.0)
+ self.assertGreaterEqual(yticks, nticks / 2.0)
+
+ def test_map(self):
+ self.g.map(plt.contourf, 'x', 'y', Ellipsis)
+
+
+class TestFacetGrid4d(PlotTestCase):
+
+ def setUp(self):
+ a = easy_array((10, 15, 3, 2))
+ darray = DataArray(a, dims=['y', 'x', 'col', 'row'])
+ darray.coords['col'] = np.array(['col' + str(x) for x in
+ darray.coords['col'].values])
+ darray.coords['row'] = np.array(['row' + str(x) for x in
+ darray.coords['row'].values])
+
+ self.darray = darray
+
+ def test_default_labels(self):
+ g = xplt.FacetGrid(self.darray, col='col', row='row')
+ self.assertEqual((2, 3), g.axes.shape)
+
+ g.map_dataarray(xplt.imshow, 'x', 'y')
+
+ # Rightmost column should be labeled
+ for label, ax in zip(self.darray.coords['row'].values, g.axes[:, -1]):
+ self.assertTrue(substring_in_axes(label, ax))
+
+ # Top row should be labeled
+ for label, ax in zip(self.darray.coords['col'].values, g.axes[0, :]):
+ self.assertTrue(substring_in_axes(label, ax))