Skip to content

Commit

Permalink
Allow passing axis kwargs to plot (#4020)
Browse files Browse the repository at this point in the history
* fix facecolor plot

* temp version

* finish fix facecolor + solves #3169

* black formatting

* add testing

* allow cartopy projection to be a kwarg

* fix PEP8 comment

* black formatting

* fix testing, plt not in parameterize

* fix testing, allows for no matplotlib

* black formating

* fix tests without matplotlib

* fix some mistakes

* isort, mypy

* fix mypy

* remove empty line

* correction from review

* correction from 2nd review

* updated tests

* updated tests

* black formatting

* follow up correction from review

* fix tests

* fix tests again

* fix bug in tests

* fix pb in tests

* remove useless line

* clean up tests

* fix

* Add whats-new

Co-authored-by: dcherian <deepak@cherian.net>
  • Loading branch information
raphaeldussin and dcherian committed Jul 2, 2020
1 parent 329cefb commit 834d4c4
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 14 deletions.
9 changes: 5 additions & 4 deletions doc/plotting.rst
Expand Up @@ -743,12 +743,13 @@ This script will plot the air temperature on a map.
air = xr.tutorial.open_dataset("air_temperature").air
ax = plt.axes(projection=ccrs.Orthographic(-80, 35))
air.isel(time=0).plot.contourf(ax=ax, transform=ccrs.PlateCarree())
ax.set_global()
p = air.isel(time=0).plot(
subplot_kws=dict(projection=ccrs.Orthographic(-80, 35), facecolor="gray"),
transform=ccrs.PlateCarree())
p.axes.set_global()
@savefig plotting_maps_cartopy.png width=100%
ax.coastlines()
p.axes.coastlines()
When faceting on maps, the projection can be transferred to the ``plot``
function using the ``subplot_kws`` keyword. The axes for the subplots created
Expand Down
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Expand Up @@ -54,6 +54,9 @@ Enhancements
By `Stephan Hoyer <https://github.com/shoyer>`_.
- :py:meth:`DataArray.reset_index` and :py:meth:`Dataset.reset_index` now keep
coordinate attributes (:pull:`4103`). By `Oriol Abril <https://github.com/OriolAbril>`_.
- Axes kwargs such as ``facecolor`` can now be passed to :py:meth:`DataArray.plot` in ``subplot_kws``.
This works for both single axes plots and FacetGrid plots.
By `Raphael Dussin <https://github.com/raphaeldussin`_.

New Features
~~~~~~~~~~~~
Expand Down
15 changes: 9 additions & 6 deletions xarray/plot/plot.py
Expand Up @@ -155,8 +155,7 @@ def plot(
Relative tolerance used to determine if the indexes
are uniformly spaced. Usually a small positive number.
subplot_kws : dict, optional
Dictionary of keyword arguments for matplotlib subplots. Only applies
to FacetGrid plotting.
Dictionary of keyword arguments for matplotlib subplots.
**kwargs : optional
Additional keyword arguments to matplotlib
Expand All @@ -177,10 +176,10 @@ def plot(

if ndims in [1, 2]:
if row or col:
kwargs["subplot_kws"] = subplot_kws
kwargs["row"] = row
kwargs["col"] = col
kwargs["col_wrap"] = col_wrap
kwargs["subplot_kws"] = subplot_kws
if ndims == 1:
plotfunc = line
kwargs["hue"] = hue
Expand All @@ -190,6 +189,7 @@ def plot(
kwargs["hue"] = hue
else:
plotfunc = pcolormesh
kwargs["subplot_kws"] = subplot_kws
else:
if row or col or hue:
raise ValueError(error_msg)
Expand Down Expand Up @@ -553,8 +553,8 @@ def _plot2d(plotfunc):
always infer intervals, unless the mesh is irregular and plotted on
a map projection.
subplot_kws : dict, optional
Dictionary of keyword arguments for matplotlib subplots. Only applies
to FacetGrid plotting.
Dictionary of keyword arguments for matplotlib subplots. Only used
for 2D and FacetGrid plots.
cbar_ax : matplotlib Axes, optional
Axes in which to draw the colorbar.
cbar_kwargs : dict, optional
Expand Down Expand Up @@ -724,7 +724,10 @@ def newplotfunc(
"plt.imshow's `aspect` kwarg is not available " "in xarray"
)

ax = get_axis(figsize, size, aspect, ax)
if subplot_kws is None:
subplot_kws = dict()
ax = get_axis(figsize, size, aspect, ax, **subplot_kws)

primitive = plotfunc(
xplt,
yplt,
Expand Down
14 changes: 10 additions & 4 deletions xarray/plot/utils.py
Expand Up @@ -406,9 +406,12 @@ def _assert_valid_xy(darray, xy, name):
raise ValueError(f"{name} must be one of None, '{valid_xy_str}'")


def get_axis(figsize, size, aspect, ax):
import matplotlib as mpl
import matplotlib.pyplot as plt
def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs):
try:
import matplotlib as mpl
import matplotlib.pyplot as plt
except ImportError:
raise ImportError("matplotlib is required for plot.utils.get_axis")

if figsize is not None:
if ax is not None:
Expand All @@ -427,8 +430,11 @@ def get_axis(figsize, size, aspect, ax):
elif aspect is not None:
raise ValueError("cannot provide `aspect` argument without `size`")

if kwargs and ax is not None:
raise ValueError("cannot use subplot_kws with existing ax")

if ax is None:
ax = plt.gca()
ax = plt.gca(**kwargs)

return ax

Expand Down
1 change: 1 addition & 0 deletions xarray/tests/__init__.py
Expand Up @@ -77,6 +77,7 @@ def LooseVersion(vstring):
has_numbagg, requires_numbagg = _importorskip("numbagg")
has_seaborn, requires_seaborn = _importorskip("seaborn")
has_sparse, requires_sparse = _importorskip("sparse")
has_cartopy, requires_cartopy = _importorskip("cartopy")

# some special cases
has_scipy_or_netCDF4 = has_scipy or has_netCDF4
Expand Down
40 changes: 40 additions & 0 deletions xarray/tests/test_plot.py
Expand Up @@ -15,6 +15,7 @@
_build_discrete_cmap,
_color_palette,
_determine_cmap_params,
get_axis,
label_from_attrs,
)

Expand All @@ -23,6 +24,7 @@
assert_equal,
has_nc_time_axis,
raises_regex,
requires_cartopy,
requires_cftime,
requires_matplotlib,
requires_nc_time_axis,
Expand All @@ -36,6 +38,11 @@
except ImportError:
pass

try:
import cartopy as ctpy # type: ignore
except ImportError:
ctpy = None


@pytest.mark.flaky
@pytest.mark.skip(reason="maybe flaky")
Expand Down Expand Up @@ -2393,3 +2400,36 @@ def test_facetgrid_single_contour():
ds["time"] = [0, 1]

ds.plot.contour(col="time", levels=[4], colors=["k"])


@requires_matplotlib
def test_get_axis():
# test get_axis works with different args combinations
# and return the right type

# cannot provide both ax and figsize
with pytest.raises(ValueError, match="both `figsize` and `ax`"):
get_axis(figsize=[4, 4], size=None, aspect=None, ax="something")

# cannot provide both ax and size
with pytest.raises(ValueError, match="both `size` and `ax`"):
get_axis(figsize=None, size=200, aspect=4 / 3, ax="something")

# cannot provide both size and figsize
with pytest.raises(ValueError, match="both `figsize` and `size`"):
get_axis(figsize=[4, 4], size=200, aspect=None, ax=None)

# cannot provide aspect and size
with pytest.raises(ValueError, match="`aspect` argument without `size`"):
get_axis(figsize=None, size=None, aspect=4 / 3, ax=None)

ax = get_axis()
assert isinstance(ax, mpl.axes.Axes)


@requires_cartopy
def test_get_axis_cartopy():

kwargs = {"projection": ctpy.crs.PlateCarree()}
ax = get_axis(**kwargs)
assert isinstance(ax, ctpy.mpl.geoaxes.GeoAxesSubplot)

0 comments on commit 834d4c4

Please sign in to comment.