Skip to content

Commit e57d238

Browse files
committed
Fix bugs with pint quantity input, column iteration
1 parent f297017 commit e57d238

1 file changed

Lines changed: 20 additions & 10 deletions

File tree

proplot/axes/plot.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1153,6 +1153,16 @@ def _load_objects():
11531153

11541154

11551155
# Standardization utilities
1156+
def _is_array(data):
1157+
"""
1158+
Test whether input is numpy array or pint quantity.
1159+
"""
1160+
# NOTE: This is used in _iter_columns to identify 2D matrices that
1161+
# should be iterated over and omit e.g. scalar marker size or marker color.
1162+
_load_objects()
1163+
return isinstance(data, ndarray) or ndarray is not Quantity and isinstance(data, Quantity) # noqa: E501
1164+
1165+
11561166
def _is_numeric(data):
11571167
"""
11581168
Test whether input is numeric array rather than datetime or strings.
@@ -1222,6 +1232,7 @@ def _safe_mask(mask, *args):
12221232
Safely apply the mask to the input arrays, accounting for existing masked
12231233
or invalid values. Values matching ``False`` are set to `np.nan`.
12241234
"""
1235+
_load_objects()
12251236
invalid = ~mask # True if invalid
12261237
args_masked = []
12271238
for arg in args:
@@ -1248,6 +1259,7 @@ def _safe_range(data, lo=0, hi=100, automin=True, automax=True):
12481259
for masked values. Use min and max functions when possible for speed. Return
12491260
``None`` if we faile to get a valid range.
12501261
"""
1262+
_load_objects()
12511263
units = 1
12521264
if ndarray is not Quantity and isinstance(data, Quantity):
12531265
data, units = data.magnitude, data.units
@@ -1369,10 +1381,10 @@ def _get_labels(data, axis=0, always=True):
13691381
# data values metadata but that is incorrect. The paradigm for 1D plots
13701382
# is we have row coordinates representing x, data values representing y,
13711383
# and column coordinates representing individual series.
1372-
if axis not in (0, 1, 2):
1373-
raise ValueError(f'Invalid axis {axis}.')
13741384
labels = None
13751385
_load_objects()
1386+
if axis not in (0, 1, 2):
1387+
raise ValueError(f'Invalid axis {axis}.')
13761388
if isinstance(data, (ndarray, Quantity)):
13771389
if not always:
13781390
pass
@@ -1446,6 +1458,7 @@ def _get_units(data):
14461458
Get the unit string from the `xarray.DataArray` attributes or the
14471459
`pint.Quantity`. Format the latter with :rcraw:`unitformat`.
14481460
"""
1461+
_load_objects()
14491462
# Get units from the attributes
14501463
if ndarray is not DataArray and isinstance(data, DataArray):
14511464
units = data.attrs.get('units', None)
@@ -1859,6 +1872,7 @@ def _redirect_or_standardize(self, *args, **kwargs):
18591872
kwargs[key] = _get_data(data, kwargs[key])
18601873

18611874
# Auto-setup matplotlib with the input unit registry
1875+
_load_objects()
18621876
for arg in args:
18631877
if ndarray is not DataArray and isinstance(arg, DataArray):
18641878
arg = arg.data
@@ -2939,10 +2953,9 @@ def _iter_columns(self, *args, label=None, labels=None, values=None, **kwargs):
29392953
keyword arguments using the input label-list ``'labels'``.
29402954
"""
29412955
# Handle cycle args and label lists
2942-
# WARNING: Must convert to ndarray or can get singleton DataArrays
2943-
# WARNING: We do not handle color cycling here because we want to allow
2944-
# iterating over columns of scatter() color arrays. Handle in _parse_cycle().
2945-
n = max(1 if a.ndim < 2 else a.shape[1] for a in args if isinstance(a, ndarray))
2956+
# NOTE: Arrays here should have had metadata stripped by _standardize_1d
2957+
# but could still be pint quantities that get processed by axis converter.
2958+
n = max(1 if not _is_array(a) or a.ndim < 2 else a.shape[-1] for a in args)
29462959
labels = _not_none(label=label, values=values, labels=labels)
29472960
if not np.iterable(labels) or isinstance(labels, str):
29482961
labels = n * [labels]
@@ -2957,10 +2970,7 @@ def _iter_columns(self, *args, label=None, labels=None, values=None, **kwargs):
29572970
for i in range(n):
29582971
kw = kwargs.copy()
29592972
kw['label'] = labels[i] or None
2960-
a = tuple(
2961-
a if not isinstance(a, ndarray) or a.ndim == 1 else a[:, i]
2962-
for a in args
2963-
)
2973+
a = tuple(a if not _is_array(a) or a.ndim < 2 else a[..., i] for a in args)
29642974
yield (i, n, *a, kw)
29652975

29662976
def _parse_cycle(

0 commit comments

Comments
 (0)