Skip to content

Commit

Permalink
Normalisation for RGB imshow (#1819)
Browse files Browse the repository at this point in the history
* Normalisation for RGB imshow

* Add test for error checking
  • Loading branch information
Zac-HD authored and Joe Hamman committed Jan 19, 2018
1 parent f3deb2f commit 6aa225f
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 3 deletions.
2 changes: 2 additions & 0 deletions doc/plotting.rst
Expand Up @@ -305,6 +305,8 @@ example, consider the original data in Kelvins rather than Celsius:
The Celsius data contain 0, so a diverging color map was used. The
Kelvins do not have 0, so the default color map was used.

.. _robust-plotting:

Robust
~~~~~~

Expand Down
1 change: 1 addition & 0 deletions doc/whats-new.rst
Expand Up @@ -41,6 +41,7 @@ Enhancements
- Support for using `Zarr`_ as storage layer for xarray.
By `Ryan Abernathey <https://github.com/rabernat>`_.
- :func:`xarray.plot.imshow` now handles RGB and RGBA images.
Saturation can be adjusted with ``vmin`` and ``vmax``, or with ``robust=True``.
By `Zac Hatfield-Dodds <https://github.com/Zac-HD>`_.
- Experimental support for parsing ENVI metadata to coordinates and attributes
in :py:func:`xarray.open_rasterio`.
Expand Down
47 changes: 45 additions & 2 deletions xarray/plot/plot.py
Expand Up @@ -15,8 +15,8 @@
import pandas as pd
from datetime import datetime

from .utils import (_determine_cmap_params, _infer_xy_labels, get_axis,
import_matplotlib_pyplot)
from .utils import (ROBUST_PERCENTILE, _determine_cmap_params,
_infer_xy_labels, get_axis, import_matplotlib_pyplot)
from .facetgrid import FacetGrid
from xarray.core.pycompat import basestring

Expand Down Expand Up @@ -326,6 +326,39 @@ def line(self, *args, **kwargs):
return line(self._da, *args, **kwargs)


def _rescale_imshow_rgb(darray, vmin, vmax, robust):
assert robust or vmin is not None or vmax is not None
# There's a cyclic dependency via DataArray, so we can't import from
# xarray.ufuncs in global scope.
from xarray.ufuncs import maximum, minimum
# Calculate vmin and vmax automatically for `robust=True`
if robust:
if vmax is None:
vmax = np.nanpercentile(darray, 100 - ROBUST_PERCENTILE)
if vmin is None:
vmin = np.nanpercentile(darray, ROBUST_PERCENTILE)
# If not robust and one bound is None, calculate the default other bound
# and check that an interval between them exists.
elif vmax is None:
vmax = 255 if np.issubdtype(darray.dtype, np.integer) else 1
if vmax < vmin:
raise ValueError(
'vmin=%r is less than the default vmax (%r) - you must supply '
'a vmax > vmin in this case.' % (vmin, vmax))
elif vmin is None:
vmin = 0
if vmin > vmax:
raise ValueError(
'vmax=%r is less than the default vmin (0) - you must supply '
'a vmin < vmax in this case.' % vmax)
# Scale interval [vmin .. vmax] to [0 .. 1], with darray as 64-bit float
# to avoid precision loss, integer over/underflow, etc with extreme inputs.
# After scaling, downcast to 32-bit float. This substantially reduces
# memory usage after we hand `darray` off to matplotlib.
darray = ((darray.astype('f8') - vmin) / (vmax - vmin)).astype('f4')
return minimum(maximum(darray, 0), 1)


def _plot2d(plotfunc):
"""
Decorator for common 2d plotting logic
Expand Down Expand Up @@ -449,6 +482,11 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
if imshow_rgb:
# Don't add a colorbar when showing an image with explicit colors
add_colorbar = False
# Matplotlib does not support normalising RGB data, so do it here.
# See eg. https://github.com/matplotlib/matplotlib/pull/10220
if robust or vmax is not None or vmin is not None:
darray = _rescale_imshow_rgb(darray, vmin, vmax, robust)
vmin, vmax, robust = None, None, False

# Handle facetgrids first
if row or col:
Expand Down Expand Up @@ -625,6 +663,11 @@ def imshow(x, y, z, ax, **kwargs):
dimension can be interpreted as RGB or RGBA color channels and
allows this dimension to be specified via the kwarg ``rgb=``.
Unlike matplotlib, Xarray can apply ``vmin`` and ``vmax`` to RGB or RGBA
data, by applying a single scaling factor and offset to all bands.
Passing ``robust=True`` infers ``vmin`` and ``vmax``
:ref:`in the usual way <robust-plotting>`.
.. note::
This function needs uniformly spaced coordinates to
properly label the axes. Call DataArray.plot() to check.
Expand Down
4 changes: 3 additions & 1 deletion xarray/plot/utils.py
Expand Up @@ -11,6 +11,9 @@
from ..core.utils import is_scalar


ROBUST_PERCENTILE = 2.0


def _load_default_cmap(fname='default_colormap.csv'):
"""
Returns viridis color map
Expand Down Expand Up @@ -165,7 +168,6 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
cmap_params : dict
Use depends on the type of the plotting function
"""
ROBUST_PERCENTILE = 2.0
import matplotlib as mpl

calc_data = np.ravel(plot_data[~pd.isnull(plot_data)])
Expand Down
20 changes: 20 additions & 0 deletions xarray/tests/test_plot.py
Expand Up @@ -1126,6 +1126,26 @@ def test_rgb_errors_bad_dim_sizes(self):
with pytest.raises(ValueError):
arr.plot.imshow(rgb='band')

def test_normalize_rgb_imshow(self):
for kwds in (
dict(vmin=-1), dict(vmax=2),
dict(vmin=-1, vmax=1), dict(vmin=0, vmax=0),
dict(vmin=0, robust=True), dict(vmax=-1, robust=True),
):
da = DataArray(easy_array((5, 5, 3), start=-0.6, stop=1.4))
arr = da.plot.imshow(**kwds).get_array()
assert 0 <= arr.min() <= arr.max() <= 1, kwds

def test_normalize_rgb_one_arg_error(self):
da = DataArray(easy_array((5, 5, 3), start=-0.6, stop=1.4))
# If passed one bound that implies all out of range, error:
for kwds in [dict(vmax=-1), dict(vmin=2)]:
with pytest.raises(ValueError):
da.plot.imshow(**kwds)
# If passed two that's just moving the range, *not* an error:
for kwds in [dict(vmax=-1, vmin=-1.2), dict(vmin=2, vmax=2.1)]:
da.plot.imshow(**kwds)


class TestFacetGrid(PlotTestCase):
def setUp(self):
Expand Down

0 comments on commit 6aa225f

Please sign in to comment.