Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 57 additions & 38 deletions proplot/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ def default_crs(self, func, *args, crs=None, **kwargs):
except TypeError as err: # duplicate keyword args, i.e. crs is positional
if not args:
raise err
result = func(self, *args[:-1], crs=args[-1], **kwargs)
args, crs = args[:-1], args[-1]
result = func(self, *args, crs=crs, **kwargs)
# Fix extent, so axes tight bounding box gets correct box!
# From this issue:
# https://github.com/SciTools/cartopy/issues/1207#issuecomment-439975083
Expand All @@ -217,15 +218,18 @@ def default_crs(self, func, *args, crs=None, **kwargs):
return result


def _standard_label(data, axis=None, units=True):
def _axis_labels_title(data, axis=None, units=True):
"""
Get data and label for pandas or xarray objects or their coordinates.
Get data and label for pandas or xarray objects or their coordinates
along axis `axis`. If `units` is ``True`` also look for units on xarray
data arrays.
"""
label = ''
_load_objects()
if isinstance(data, ndarray):
if axis is not None and data.ndim > axis:
data = np.arange(data.shape[axis])

# Xarray with common NetCDF attribute names
elif isinstance(data, DataArray):
if axis is not None and data.ndim > axis:
Expand All @@ -239,6 +243,7 @@ def _standard_label(data, axis=None, units=True):
label = f'{label} ({units})'
elif units:
label = units

# Pandas object with name attribute
# if not label and isinstance(data, DataFrame) and data.columns.size == 1:
elif isinstance(data, (DataFrame, Series, Index)):
Expand All @@ -251,6 +256,7 @@ def _standard_label(data, axis=None, units=True):
# DataFrame has no native name attribute but user can add one:
# https://github.com/pandas-dev/pandas/issues/447
label = getattr(data, 'name', '') or ''

return data, str(label).strip()


Expand Down Expand Up @@ -307,7 +313,7 @@ def standardize_1d(self, func, *args, **kwargs):
if x is None:
axis = 1 if (name in ('hist', 'boxplot', 'violinplot') or any(
kwargs.get(s, None) for s in ('means', 'medians'))) else 0
x, _ = _standard_label(y, axis=axis)
x, _ = _axis_labels_title(y, axis=axis)
x = _to_array(x)
if x.ndim != 1:
raise ValueError(
Expand All @@ -332,20 +338,21 @@ def standardize_1d(self, func, *args, **kwargs):
kwargs['positions'] = xi
if name in ('boxplot', 'violinplot'):
kwargs['positions'] = xi

# Next handle labels if 'autoformat' is on
if self.figure._auto_format:
# Ylabel
y, label = _standard_label(y)
if label:
# for histogram, this indicates x coordinate
y, label = _axis_labels_title(y)
if label: # for histogram, this label is used for *x* coordinates
iaxis = xax if name in ('hist',) else yax
kw[iaxis + 'label'] = label
# Xlabel
x, label = _standard_label(x)
x, label = _axis_labels_title(x)
if label and name not in ('hist',):
kw[xax + 'label'] = label
if name != 'scatter' and len(x) > 1 and xi is None and x[1] < x[0]:
kw[xax + 'reverse'] = True

# Appply
if kw:
self.format(**kw)
Expand Down Expand Up @@ -550,7 +557,7 @@ def standardize_2d(self, func, *args, order='C', globe=False, **kwargs):
# Handle labels if 'autoformat' is on
if self.figure._auto_format:
for key, xy in zip(('xlabel', 'ylabel'), (x, y)):
_, label = _standard_label(xy)
_, label = _axis_labels_title(xy)
if label:
kw[key] = label
if len(xy) > 1 and all(isinstance(xy, Number)
Expand All @@ -562,8 +569,8 @@ def standardize_2d(self, func, *args, order='C', globe=False, **kwargs):
y = yi
# Handle figure titles
if self.figure._auto_format:
_, colorbar_label = _standard_label(Zs[0], units=True)
_, title = _standard_label(Zs[0], units=False)
_, colorbar_label = _axis_labels_title(Zs[0], units=True)
_, title = _axis_labels_title(Zs[0], units=False)
if title:
kw['title'] = title
if kw:
Expand Down Expand Up @@ -1789,7 +1796,7 @@ def cycle_changer(

# Plot susccessive columns
objs = []
label_leg = None # for colorbar or legend
label_leg_cbar = None # for colorbar or legend
for i in range(ncols):
# Prop cycle properties
kw = kwargs.copy()
Expand All @@ -1806,59 +1813,70 @@ def cycle_changer(
kw[key] = value

# Get x coordinates
ix, iy = x, ys[0] # samples
x_col, y_first = x, ys[0] # samples
if name in ('pie',):
kw['labels'] = _not_none(labels, ix) # TODO: move to pie wrapper?
kw['labels'] = _not_none(labels, x_col) # TODO: move to pie wrapper?
if name in ('bar',): # adjust
if not stacked:
ix = x + (i - ncols / 2 + 0.5) * width / ncols
elif stacked and iy.ndim > 1:
x_col = x + (i - ncols / 2 + 0.5) * width / ncols
elif stacked and y_first.ndim > 1:
key = 'x' if barh else 'bottom'
kw[key] = _to_indexer(iy)[:, :i].sum(axis=1)
kw[key] = _to_indexer(y_first)[:, :i].sum(axis=1)

# Get y coordinates and labels
if name in ('pie', 'boxplot', 'violinplot'):
iys = (iy,) # only ever have one y value, cannot have legend labs
# Only ever have one y value, cannot have legend labs
ys_col = (y_first,)

else:
# The coordinates
# WARNING: If stacked=True then we always *ignore* second
# argument passed to fill_between. Warning should be issued
# by fill_between_wrapper in this case.
if stacked and 'fill_between' in name:
iys = tuple(
iy if iy.ndim == 1 else _to_indexer(iy)[:, :ii].sum(axis=1)
ys_col = tuple(
y_first if y_first.ndim == 1
else _to_indexer(y_first)[:, :ii].sum(axis=1)
for ii in (i, i + 1)
)
else:
iys = tuple(
iy if iy.ndim == 1 else _to_indexer(iy)[:, i]
for iy in ys
ys_col = tuple(
y_i if y_i.ndim == 1 else _to_indexer(y_i)[:, i]
for y_i in ys
)

# Possible legend labels
# Several scenarios:
# 1. Always prefer input labels
# 2. Always add labels if this is a *named* dimension.
# 3. Even if not *named* dimension add labels if labels are string
if len(labels) != ncols:
raise ValueError(
f'Got {ncols} columns in data array, '
f'but {len(labels)} labels.'
)
label = labels[i]
values, label_leg = _standard_label(iy, axis=1)
if label_leg and label is None:
label = _to_ndarray(values)[i]
label = labels[i] # input labels
labels_cols, label_leg_cbar = _axis_labels_title(y_first, axis=1)
labels_cols = _to_ndarray(labels_cols)
if label is None and (
label_leg_cbar or labels_cols.size and isinstance(labels_cols[i], str)
):
label = labels_cols[i]
if label is not None:
kw['label'] = label

# Build coordinate arguments
xy = ()
x_ys_col = ()
if barh: # special, use kwargs only!
kw.update({'bottom': ix, 'width': iys[0]})
kw.update({'bottom': x_col, 'width': ys_col[0]})
kw.setdefault('x', kwargs.get('bottom', 0)) # required
elif name in ('pie', 'hist', 'boxplot', 'violinplot'):
xy = (*iys,)
x_ys_col = ys_col
else: # has x-coordinates, and maybe more than one y
xy = (ix, *iys)
x_ys_col = (x_col, *ys_col)

# Call plotting function
obj = func(self, *xy, *args, **kw)
obj = func(self, *x_ys_col, *args, **kw)
if isinstance(obj, (list, tuple)) and len(obj) == 1:
obj = obj[0]
objs.append(obj)
Expand All @@ -1873,8 +1891,8 @@ def cycle_changer(
# Add keywords
if loc != 'fill':
colorbar_kw.setdefault('loc', loc)
if label_leg:
colorbar_kw.setdefault('label', label_leg)
if label_leg_cbar:
colorbar_kw.setdefault('label', label_leg_cbar)
self._auto_colorbar[loc][1].update(colorbar_kw)

# Add legend
Expand All @@ -1887,8 +1905,8 @@ def cycle_changer(
# Add keywords
if loc != 'fill':
legend_kw.setdefault('loc', loc)
if label_leg:
legend_kw.setdefault('label', label_leg)
if label_leg_cbar:
legend_kw.setdefault('label', label_leg_cbar)
self._auto_legend[loc][1].update(legend_kw)

# Return
Expand Down Expand Up @@ -2270,6 +2288,7 @@ def cmap_changer(
colorbar_kw = colorbar_kw or {}

# Flexible user input
Z_sample = args[-1]
vmin = _not_none(vmin=vmin, norm_kw_vmin=norm_kw.pop('vmin', None))
vmax = _not_none(vmax=vmax, norm_kw_vmax=norm_kw.pop('vmax', None))
values = _not_none(values=values, centers=centers)
Expand Down Expand Up @@ -2349,7 +2368,7 @@ def cmap_changer(
ticks = None
if cmap is not None and name not in ('hexbin',):
norm, cmap, levels, ticks = _build_discrete_norm(
args[-1], # sample data for getting suitable levels
Z_sample, # sample data for getting suitable levels
levels=levels, values=values,
norm=norm, norm_kw=norm_kw,
locator=locator, locator_kw=locator_kw,
Expand Down Expand Up @@ -2471,7 +2490,7 @@ def cmap_changer(
if colorbar:
loc = self._loc_translate(colorbar, 'colorbar', allow_manual=False)
if 'label' not in colorbar_kw and self.figure._auto_format:
_, label = _standard_label(args[-1]) # last one is data, we assume
_, label = _axis_labels_title(Z_sample) # last one is data, we assume
if label:
colorbar_kw.setdefault('label', label)
if name in ('parametric',) and values is not None:
Expand Down