BUG/Perf: Support ExtensionArrays in where #24114

Merged 26 commits on Dec 10, 2018
Changes from 1 commit
Commits (26):
c4604df
API: Added ExtensionArray.where
TomAugspurger Dec 3, 2018
56470c3
Fixups:
TomAugspurger Dec 5, 2018
6f79282
32-bit compat
TomAugspurger Dec 5, 2018
a69dbb3
warn for categorical
TomAugspurger Dec 5, 2018
911a2da
debug 32-bit issue
TomAugspurger Dec 5, 2018
badb5be
compat, revert
TomAugspurger Dec 6, 2018
edff47e
32-bit compat
TomAugspurger Dec 6, 2018
4715ef6
Merge remote-tracking branch 'upstream/master' into ea-where
TomAugspurger Dec 6, 2018
d90f384
deprecation note for categorical
TomAugspurger Dec 6, 2018
5e14414
where versionadded
TomAugspurger Dec 6, 2018
e9665b8
Merge remote-tracking branch 'upstream/master' into ea-where
TomAugspurger Dec 7, 2018
033ac9c
Setitem-based where
TomAugspurger Dec 7, 2018
1271d3d
Merge remote-tracking branch 'upstream/master' into ea-where
TomAugspurger Dec 7, 2018
9e0d87d
update docs, cleanup
TomAugspurger Dec 7, 2018
e05a597
wip
TomAugspurger Dec 7, 2018
796332c
cleanup
TomAugspurger Dec 7, 2018
cad0c4c
Merge remote-tracking branch 'upstream/master' into ea-where
TomAugspurger Dec 7, 2018
6edd286
py2 compat
TomAugspurger Dec 7, 2018
30775f0
Merge remote-tracking branch 'upstream/master' into ea-where
TomAugspurger Dec 7, 2018
4de8bb5
Updated
TomAugspurger Dec 7, 2018
ce04a75
Merge remote-tracking branch 'upstream/master' into ea-where
TomAugspurger Dec 9, 2018
f98a82c
Clarify
TomAugspurger Dec 9, 2018
bcfb8f8
Merge remote-tracking branch 'upstream/master' into ea-where
TomAugspurger Dec 10, 2018
8d9b20b
Simplify error message
TomAugspurger Dec 10, 2018
c0351fd
sparse whatsnew
TomAugspurger Dec 10, 2018
539d3cb
updates
TomAugspurger Dec 10, 2018
3 changes: 3 additions & 0 deletions doc/source/whatsnew/v0.24.0.rst
Expand Up @@ -994,6 +994,7 @@ update the ``ExtensionDtype._metadata`` tuple to match the signature of your
- :meth:`Series.astype` and :meth:`DataFrame.astype` now dispatch to :meth:`ExtensionArray.astype` (:issue:`21185`).
- Slicing a single row of a ``DataFrame`` with multiple ExtensionArrays of the same type now preserves the dtype, rather than coercing to object (:issue:`22784`)
- Added :meth:`pandas.api.types.register_extension_dtype` to register an extension type with pandas (:issue:`22664`)
- Added :meth:`pandas.api.extensions.ExtensionArray.where` (:issue:`24077`)
- Bug when concatenating multiple ``Series`` with different extension dtypes not casting to object dtype (:issue:`22994`)
- Series backed by an ``ExtensionArray`` now work with :func:`util.hash_pandas_object` (:issue:`23066`)
- Updated the ``.type`` attribute for ``PeriodDtype``, ``DatetimeTZDtype``, and ``IntervalDtype`` to be instances of the dtype (``Period``, ``Timestamp``, and ``Interval`` respectively) (:issue:`22938`)
Expand Down Expand Up @@ -1236,6 +1237,7 @@ Performance Improvements
- Improved performance of :meth:`DatetimeIndex.normalize` and :meth:`Timestamp.normalize` for timezone naive or UTC datetimes (:issue:`23634`)
- Improved performance of :meth:`DatetimeIndex.tz_localize` and various ``DatetimeIndex`` attributes with dateutil UTC timezone (:issue:`23772`)
- Improved performance of :class:`Categorical` constructor for `Series` objects (:issue:`23814`)
- Improved performance of :meth:`~DataFrame.where` for Categorical data (:issue:`24077`)

.. _whatsnew_0240.docs:

Expand All @@ -1262,6 +1264,7 @@ Categorical
- In :meth:`Series.unstack`, specifying a ``fill_value`` not present in the categories now raises a ``TypeError`` rather than ignoring the ``fill_value`` (:issue:`23284`)
- Bug when resampling :meth:`DataFrame.resample()` and aggregating on categorical data, the categorical dtype was getting lost. (:issue:`23227`)
- Bug in many methods of the ``.str``-accessor, which always failed on calling the ``CategoricalIndex.str`` constructor (:issue:`23555`, :issue:`23556`)
- Bug in :meth:`Series.where` losing the categorical dtype for categorical data (:issue:`24077`)

Datetimelike
^^^^^^^^^^^^
Expand Down
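
For context, the Categorical entry added above concerns `Series.where` keeping the categorical dtype. A minimal sketch of that behaviour follows, assuming a pandas version that includes this fix (0.24 or later); the variable names are illustrative and not taken from the PR.

```python
import numpy as np
import pandas as pd

# A categorical Series; entries where `cond` is False are replaced with NA.
ser = pd.Series(pd.Categorical(["a", "b", "c"]))
cond = np.array([True, False, True])

result = ser.where(cond)

# With the fix, the result should still be categorical rather than object.
print(result.dtype)
print(result.isna().tolist())
```

Only the middle element is masked, so the categories themselves are unchanged.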
35 changes: 35 additions & 0 deletions pandas/core/arrays/base.py
Expand Up @@ -64,6 +64,7 @@ class ExtensionArray(object):
* unique
* factorize / _values_for_factorize
* argsort / _values_for_argsort
* where

The remaining methods implemented on this class should be performant,
as they only compose abstract methods. Still, a more efficient
Expand Down Expand Up @@ -661,6 +662,40 @@ def take(self, indices, allow_fill=False, fill_value=None):
# pandas.api.extensions.take
raise AbstractMethodError(self)

def where(self, cond, other):
Review comment (Member):

The other implementations of where (DataFrame.where, Index.where, etc.) have other default to NA. Do we want to maintain that convention here too?

    """
    Replace values where the condition is False.

    Parameters
    ----------
    cond : ndarray or ExtensionArray
        The mask indicating which values should be kept (True)
        or replaced from `other` (False).

    other : ndarray, ExtensionArray, or scalar
        Entries where `cond` is False are replaced with
        corresponding value from `other`.

    Notes
    -----
    Note that `cond` and `other` *cannot* be a Series, Index, or callable.
    When used from, e.g., :meth:`Series.where`, pandas will unbox
    Series and Indexes, and will apply callables before they arrive here.

    Returns
    -------
    ExtensionArray
        Same dtype as the original.

    See Also
    --------
    Series.where : Similar method for Series.
    DataFrame.where : Similar method for DataFrame.
    """
    return type(self)._from_sequence(np.where(cond, self, other),
Review comment (Contributor):
hmm this turns it into an array. we have much special handling for this (e.g. see .where for DTI). i think this needs to dispatch somehow.

Review comment (Contributor):
oh I see you override things. ok then.

                                     dtype=self.dtype,
                                     copy=False)

def copy(self, deep=False):
# type: (bool) -> ExtensionArray
"""
Expand Down
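
The default implementation in this hunk boils down to an elementwise selection with `np.where`, after which the result is rebuilt as the original array type. A rough sketch of just the selection step on plain NumPy data is below; `where_sketch` is a hypothetical helper for illustration, not part of pandas.

```python
import numpy as np


def where_sketch(values, cond, other):
    """Hypothetical helper: keep `values` where `cond` is True,
    otherwise take the corresponding entry (or scalar) from `other`."""
    values = np.asarray(values)
    cond = np.asarray(cond, dtype=bool)
    return np.where(cond, values, other)


result = where_sketch([1, 2, 3], [True, False, True], 0)
print(result.tolist())  # [1, 0, 3]
```

As the review comments note, going through a plain ndarray discards array-specific handling, which is why the concrete array classes in this PR override `where`.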
28 changes: 28 additions & 0 deletions pandas/core/arrays/categorical.py
Expand Up @@ -1906,6 +1906,34 @@ def take_nd(self, indexer, allow_fill=None, fill_value=None):

take = take_nd

def where(self, cond, other):
    # n.b. this now preserves the type
    codes = self._codes

    if is_scalar(other) and isna(other):
        other = -1
    elif is_scalar(other):
        item = self.categories.get_indexer([other]).item()

        if item == -1:
            raise ValueError("The value '{}' is not present in "
                             "this Categorical's categories".format(other))
        other = item

    elif is_categorical_dtype(other):
        if not is_dtype_equal(self, other):
            raise TypeError("The type of 'other' does not match.")
        # get the codes from other that match our categories
        other = _get_codes_for_values(other, self.categories)
    else:
        other = np.where(isna(other), -1, other)

    new_codes = np.where(cond, codes, other)
    return type(self).from_codes(new_codes,
                                 categories=self.categories,
                                 ordered=self.ordered)

def _slice(self, slicer):
"""
Return a slice of myself.
Expand Down
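
The categorical override above operates on integer codes rather than on the values, which is what lets it keep the dtype. The idea can be sketched with plain NumPy as follows; `where_on_codes` is a hypothetical helper, and the assumption that missing values are encoded as `-1` is taken from the hunk above.

```python
import numpy as np


def where_on_codes(codes, categories, cond, other=None):
    """Hypothetical helper: apply `where` to integer codes, not values."""
    categories = list(categories)
    if other is None:
        # Missing values are encoded as -1 in the codes.
        other_code = -1
    elif other in categories:
        other_code = categories.index(other)
    else:
        raise ValueError(
            "The value {!r} is not present in the categories".format(other))
    cond = np.asarray(cond, dtype=bool)
    return np.where(cond, np.asarray(codes), other_code)


categories = ["d", "c", "b", "a"]
codes = [3, 2, 1]  # encodes ['a', 'b', 'c']
new_codes = where_on_codes(codes, categories, [True, True, False], other="b")
decoded = [categories[c] if c != -1 else None for c in new_codes]
print(decoded)  # ['a', 'b', 'b']
```

This mirrors the shape of `test_where_unobserved_categories` later in the diff: the replacement value only has to be a known category, not one that is currently observed.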
12 changes: 12 additions & 0 deletions pandas/core/arrays/interval.py
Expand Up @@ -777,6 +777,18 @@ def take(self, indices, allow_fill=False, fill_value=None, axis=None,

return self._shallow_copy(left_take, right_take)

def where(self, cond, other):
Review comment (Member):
Would be nice to have IntervalIndex use this implementation instead of the naive object array based implementation that it currently uses. Can certainly leave that for a follow-up PR though, and I'd be happy to do it.

    if is_scalar(other) and isna(other):
        lother = other
        rother = other
    else:
Review comment (Member):
Can you make this an elif that checks that other is interval-like (something like isinstance(other, Interval) or is_interval_dtype(other)), then have an else clause that raises a ValueError saying other must be interval-like?

As written I think this would raise a somewhat unclear AttributeError in self._check_closed_matches since it assumes other.closed exists.

        self._check_closed_matches(other, name='other')
        lother = other.left
        rother = other.right
    left = np.where(cond, self.left, lother)
    right = np.where(cond, self.right, rother)
Review comment (Member):
left/right should have a where method, so might be a bit safer to do something like:

left = self.left.where(cond, lother)
right = self.right.where(cond, rother)

np.where looks like it can cause some problems depending on what left/right are:

In [2]: left = pd.date_range('2018', periods=3); left
Out[2]: DatetimeIndex(['2018-01-01', '2018-01-02', '2018-01-03'], dtype='datetime64[ns]', freq='D')

In [3]: np.where([True, False, True], left, pd.NaT)
Out[3]: array([1514764800000000000, NaT, 1514937600000000000], dtype=object)

    return self._shallow_copy(left, right)

def value_counts(self, dropna=True):
"""
Returns a Series containing counts of each interval.
Expand Down
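
The interval override applies the same mask to the left and right bounds separately. A minimal sketch with float bounds and a scalar NA `other` is shown below; it deliberately avoids datetime bounds, where the review comment above points out that `np.where` can produce an object array.

```python
import numpy as np

# Bounds of three illustrative intervals: (0, 1], (1, 2], (2, 3]
left = np.array([0.0, 1.0, 2.0])
right = np.array([1.0, 2.0, 3.0])
cond = np.array([True, False, True])

# Scalar NA `other`: both bounds become NaN where `cond` is False.
new_left = np.where(cond, left, np.nan)
new_right = np.where(cond, right, np.nan)
print(list(zip(new_left.tolist(), new_right.tolist())))
```

For an interval-like `other`, the same two calls would take `other.left` and `other.right` in place of the scalar, as in the hunk above.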
22 changes: 22 additions & 0 deletions pandas/core/arrays/period.py
Expand Up @@ -4,6 +4,7 @@

import numpy as np

from pandas._libs import lib
from pandas._libs.tslibs import NaT, iNaT, period as libperiod
from pandas._libs.tslibs.fields import isleapyear_arr
from pandas._libs.tslibs.period import (
Expand Down Expand Up @@ -241,6 +242,11 @@ def _generate_range(cls, start, end, periods, freq, fields):

return subarr, freq

def _check_compatible_with(self, other):
    if self.freqstr != other.freqstr:
        msg = DIFFERENT_FREQ_INDEX.format(self.freqstr, other.freqstr)
        raise IncompatibleFrequency(msg)

# --------------------------------------------------------------------
# Data / Attributes

Expand Down Expand Up @@ -341,6 +347,22 @@ def to_timestamp(self, freq=None, how='start'):
# --------------------------------------------------------------------
# Array-like / EA-Interface Methods

def where(self, cond, other):
    # TODO(DatetimeArray): move to DatetimeLikeArrayMixin
    # n.b. _ndarray_values candidate.
    i8 = self.asi8
    if lib.is_scalar(other):
        if isna(other):
            other = iNaT
        elif isinstance(other, Period):
            self._check_compatible_with(other)
            other = other.ordinal
    elif isinstance(other, type(self)):
        self._check_compatible_with(other)
        other = other.asi8
    result = np.where(cond, i8, other)
    return type(self)._simple_new(result, dtype=self.dtype)

def _formatter(self, boxed=False):
if boxed:
return str
Expand Down
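
The period override works on the underlying int64 ordinals and substitutes an integer sentinel for missing values, so the data never leaves int64. A small sketch of that idea follows; the ordinal values are made up for illustration, and `INAT` is defined locally on the assumption that the sentinel is the minimum int64 value, as pandas' `iNaT` is.

```python
import numpy as np

# Assumption: the NaT sentinel is the minimum int64, mirroring pandas' iNaT.
INAT = np.iinfo(np.int64).min

# Illustrative ordinals only; real values depend on the period frequency.
ordinals = np.array([10957, 10958, 10959], dtype=np.int64)
cond = np.array([True, False, True])

result = np.where(cond, ordinals, INAT)
print(result.dtype, result.tolist())
```

Because the sentinel is itself an integer, no cast to float or object is needed to represent the missing entry.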
14 changes: 14 additions & 0 deletions pandas/core/arrays/sparse.py
Expand Up @@ -1063,6 +1063,20 @@ def take(self, indices, allow_fill=False, fill_value=None):
return type(self)(result, fill_value=self.fill_value, kind=self.kind,
**kwargs)

def where(self, cond, other):
    if is_scalar(other):
        result_dtype = np.result_type(self.dtype.subtype, other)
    elif isinstance(other, type(self)):
        result_dtype = np.result_type(self.dtype.subtype,
                                      other.dtype.subtype)
    else:
        result_dtype = np.result_type(self.dtype.subtype, other.dtype)

    dtype = self.dtype.update_dtype(result_dtype)
    # TODO: avoid converting to dense.
    values = np.where(cond, self, other)
    return type(self)(values, dtype=dtype)

def _take_with_fill(self, indices, fill_value=None):
if fill_value is None:
fill_value = self.dtype.na_value
Expand Down
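
Unlike the other overrides, the sparse version lets the subtype change, using `np.result_type` to pick a dtype that can hold both the existing values and `other`. A short sketch of how that promotion behaves for an int64 subtype (the variable names are illustrative):

```python
import numpy as np

int_subtype = np.dtype("int64")

# A float scalar `other` (such as NaN) promotes the integer subtype.
scalar_promoted = np.result_type(int_subtype, np.nan)
# An array-like `other` contributes its own dtype to the promotion.
array_promoted = np.result_type(int_subtype, np.dtype("float32"))
# A compatible dtype leaves the subtype unchanged.
unchanged = np.result_type(int_subtype, np.dtype("bool"))

print(scalar_promoted, array_promoted, unchanged)
```

This is why filling a sparse integer array with NaN would yield a float subtype rather than raising.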
5 changes: 5 additions & 0 deletions pandas/core/dtypes/base.py
Expand Up @@ -26,6 +26,11 @@ class _DtypeOpsMixin(object):
na_value = np.nan
_metadata = ()

@property
def _ndarray_na_value(self):
    """Private method internal to pandas"""
    raise AbstractMethodError(self)

def __eq__(self, other):
"""Check whether 'other' is equal to self.

Expand Down
6 changes: 1 addition & 5 deletions pandas/core/indexes/category.py
Expand Up @@ -501,11 +501,7 @@ def _can_reindex(self, indexer):

 @Appender(_index_shared_docs['where'])
 def where(self, cond, other=None):
-    if other is None:
-        other = self._na_value
-    values = np.where(cond, self.values, other)
-
-    cat = Categorical(values, dtype=self.dtype)
+    cat = self.values.where(cond, other=other)
     return self._shallow_copy(cat, **self._get_attributes_dict())

def reindex(self, target, method=None, level=None, limit=None,
Expand Down
27 changes: 26 additions & 1 deletion pandas/core/internals/blocks.py
Expand Up @@ -28,7 +28,8 @@
from pandas.core.dtypes.dtypes import (
CategoricalDtype, DatetimeTZDtype, ExtensionDtype, PandasExtensionDtype)
 from pandas.core.dtypes.generic import (
-    ABCDatetimeIndex, ABCExtensionArray, ABCIndexClass, ABCSeries)
+    ABCDataFrame, ABCDatetimeIndex, ABCExtensionArray, ABCIndexClass,
+    ABCSeries)
from pandas.core.dtypes.missing import (
_isna_compat, array_equivalent, is_null_datelike_scalar, isna, notna)

Expand Down Expand Up @@ -1967,6 +1968,30 @@ def shift(self, periods, axis=0):
placement=self.mgr_locs,
ndim=self.ndim)]

def where(self, other, cond, align=True, errors='raise',
          try_cast=False, axis=0, transpose=False):
    if isinstance(other, (ABCIndexClass, ABCSeries)):
        other = other.array

    if isinstance(cond, ABCDataFrame):
        assert cond.shape[1] == 1
        cond = cond.iloc[:, 0].array

    if isinstance(other, ABCDataFrame):
        assert other.shape[1] == 1
        other = other.iloc[:, 0].array

    if isinstance(cond, (ABCIndexClass, ABCSeries)):
        cond = cond.array

    if lib.is_scalar(other) and isna(other):
        # The default `other` for Series / Frame is np.nan
        # we want to replace that with the correct NA value
        # for the type
        other = self.dtype.na_value
    result = self.values.where(cond, other)
    return self.make_block_same_class(result, placement=self.mgr_locs)

@property
def _ftype(self):
return getattr(self.values, '_pandas_ftype', Block._ftype)
Expand Down
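
The block-level `where` above mostly unboxes its arguments before delegating to the array. A rough sketch of that unboxing step using public pandas objects is below; `unbox` is a hypothetical helper written for illustration and assumes a pandas version where `Series.array` and `Index.array` exist (0.24 or later).

```python
import numpy as np
import pandas as pd


def unbox(obj):
    """Hypothetical helper: reduce pandas containers to a 1-D array."""
    if isinstance(obj, pd.DataFrame):
        # Only single-column frames make sense for a single block.
        assert obj.shape[1] == 1
        obj = obj.iloc[:, 0]
    if isinstance(obj, (pd.Series, pd.Index)):
        obj = obj.array
    return obj


cond = unbox(pd.DataFrame({"a": [True, False, True]}))
print(np.asarray(cond).tolist())
```

Scalars and plain arrays pass through untouched, which matches the docstring's statement that the array-level method never sees a Series or Index.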
26 changes: 26 additions & 0 deletions pandas/tests/arrays/categorical/test_indexing.py
Expand Up @@ -122,6 +122,32 @@ def test_get_indexer_non_unique(self, idx_values, key_values, key_class):
tm.assert_numpy_array_equal(expected, result)
tm.assert_numpy_array_equal(exp_miss, res_miss)

def test_where_raises(self):
    arr = Categorical(['a', 'b', 'c'])
    with pytest.raises(ValueError, match="The value 'd'"):
        arr.where([True, False, True], 'd')

def test_where_unobserved_categories(self):
    arr = Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'])
    result = arr.where([True, True, False], other='b')
    expected = Categorical(['a', 'b', 'b'], categories=arr.categories)
    tm.assert_categorical_equal(result, expected)

def test_where_other_categorical(self):
    arr = Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'])
    other = Categorical(['b', 'c', 'a'], categories=['a', 'c', 'b', 'd'])
    result = arr.where([True, False, True], other)
    expected = Categorical(['a', 'c', 'c'], dtype=arr.dtype)
    tm.assert_categorical_equal(result, expected)

def test_where_ordered_differs_raises(self):
    arr = Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'],
                      ordered=True)
    other = Categorical(['b', 'c', 'a'], categories=['a', 'c', 'b', 'd'],
                        ordered=True)
    with pytest.raises(TypeError, match="The type of"):
        arr.where([True, False, True], other)


@pytest.mark.parametrize("index", [True, False])
def test_mask_with_boolean(index):
Expand Down
12 changes: 11 additions & 1 deletion pandas/tests/arrays/interval/test_interval.py
Expand Up @@ -2,7 +2,7 @@
import numpy as np
import pytest

-from pandas import Index, IntervalIndex, date_range, timedelta_range
+from pandas import Index, Interval, IntervalIndex, date_range, timedelta_range
from pandas.core.arrays import IntervalArray
import pandas.util.testing as tm

Expand Down Expand Up @@ -50,6 +50,16 @@ def test_set_closed(self, closed, new_closed):
expected = IntervalArray.from_breaks(range(10), closed=new_closed)
tm.assert_extension_array_equal(result, expected)

@pytest.mark.parametrize('other', [
    Interval(0, 1, closed='right'),
    IntervalArray.from_breaks([1, 2, 3, 4], closed='right'),
])
def test_where_raises(self, other):
    arr = IntervalArray.from_breaks([1, 2, 3, 4], closed='left')
    match = "'other.closed' is 'right', expected 'left'."
    with pytest.raises(ValueError, match=match):
        arr.where([True, False, True], other=other)


class TestSetitem(object):

Expand Down
15 changes: 15 additions & 0 deletions pandas/tests/arrays/test_period.py
Expand Up @@ -197,6 +197,21 @@ def test_sub_period():
arr - other


# ----------------------------------------------------------------------------
# Methods

@pytest.mark.parametrize('other', [
    pd.Period('2000', freq='H'),
    period_array(['2000', '2001', '2000'], freq='H')
])
def test_where_different_freq_raises(other):
    arr = period_array(['2000', '2001', '2002'], freq='D')
    cond = np.array([True, False, True])
    with pytest.raises(IncompatibleFrequency,
                       match="Input has different freq=H"):
        arr.where(cond, other)


# ----------------------------------------------------------------------------
# Printing

Expand Down
37 changes: 37 additions & 0 deletions pandas/tests/extension/base/methods.py
Expand Up @@ -198,3 +198,40 @@ def test_hash_pandas_object_works(self, data, as_frame):
a = pd.util.hash_pandas_object(data)
b = pd.util.hash_pandas_object(data)
self.assert_equal(a, b)

@pytest.mark.parametrize("as_frame", [True, False])
def test_where_series(self, data, na_value, as_frame):
    assert data[0] != data[1]
    cls = type(data)
    a, b = data[:2]

    ser = pd.Series(cls._from_sequence([a, a, b, b], dtype=data.dtype))
    cond = np.array([True, True, False, False])

    if as_frame:
        ser = ser.to_frame(name='a')
        # TODO: alignment is broken for ndarray `cond`
        cond = pd.DataFrame({"a": cond})

    result = ser.where(cond)
    expected = pd.Series(cls._from_sequence([a, a, na_value, na_value],
                                            dtype=data.dtype))

    if as_frame:
        expected = expected.to_frame(name='a')
    self.assert_equal(result, expected)

    # array other
    cond = np.array([True, False, True, True])
    other = cls._from_sequence([a, b, a, b], dtype=data.dtype)
    if as_frame:
        # TODO: alignment is broken for array `other`
        other = pd.DataFrame({"a": other})
        # TODO: alignment is broken for ndarray `cond`
        cond = pd.DataFrame({"a": cond})
    result = ser.where(cond, other)
    expected = pd.Series(cls._from_sequence([a, b, b, b],
                                            dtype=data.dtype))
    if as_frame:
        expected = expected.to_frame(name='a')
    self.assert_equal(result, expected)
7 changes: 7 additions & 0 deletions pandas/tests/extension/json/test_json.py
Expand Up @@ -221,6 +221,13 @@ def test_combine_add(self, data_repeated):
def test_hash_pandas_object_works(self, data, kind):
super().test_hash_pandas_object_works(data, kind)

@pytest.mark.skip(reason="broadcasting error")
def test_where_series(self, data, na_value):
    # Fails with
    # *** ValueError: operands could not be broadcast together
    # with shapes (4,) (4,) (0,)
    super().test_where_series(data, na_value)


class TestCasting(BaseJSON, base.BaseCastingTests):
@pytest.mark.skip(reason="failing on np.array(self, dtype=str)")
Expand Down