Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG pd.NA not treated correctly in where and mask operations #53124

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
89c0f3d
make NA propagate where and mask operations
Charlie-XIAO May 6, 2023
321147f
changelog added
Charlie-XIAO May 6, 2023
36bbe16
fix when using boolean arrays
Charlie-XIAO May 7, 2023
e2216cb
added tests, reword NA propagates -> if cond=NA then element propagates
Charlie-XIAO May 8, 2023
5a45a29
Merge branch 'main' into na-masked-unexp
Charlie-XIAO May 8, 2023
9875669
avoid multiple fillna when unnecessary
Charlie-XIAO May 8, 2023
8381aba
Merge branch 'main' into na-masked-unexp
Charlie-XIAO May 19, 2023
5a41560
Merge branch 'main' into na-masked-unexp
Charlie-XIAO Jun 4, 2023
8af09df
Merge branch 'main' into na-masked-unexp
Charlie-XIAO Jun 11, 2023
3859bff
Merge branch 'main' into na-masked-unexp
Charlie-XIAO Jun 12, 2023
c1d43c8
Merge remote-tracking branch 'upstream/main' into na-masked-unexp
Charlie-XIAO Jul 14, 2023
c542727
Merge branch 'na-masked-unexp' of https://github.com/Charlie-XIAO/pan…
Charlie-XIAO Jul 14, 2023
8140c5b
Merge branch 'main' into na-masked-unexp
Charlie-XIAO Jul 16, 2023
a2151be
Merge branch 'main' into na-masked-unexp
Charlie-XIAO Aug 1, 2023
6f90c1c
Merge remote-tracking branch 'upstream/main' into na-masked-unexp
Charlie-XIAO Aug 29, 2023
394d4bb
Merge branch 'na-masked-unexp' of https://github.com/Charlie-XIAO/pan…
Charlie-XIAO Aug 29, 2023
1cc6208
Merge remote-tracking branch 'upstream/main' into na-masked-unexp
Charlie-XIAO Aug 29, 2023
09f62bc
raise in where and mask if cond is nullable bool with NAs
Charlie-XIAO Aug 29, 2023
b55f411
Merge remote-tracking branch 'upstream/main' into na-masked-unexp
Charlie-XIAO Aug 29, 2023
cbbd866
remove conflicting (?) test and improve message
Charlie-XIAO Aug 30, 2023
3a34a85
Merge remote-tracking branch 'upstream/main' into na-masked-unexp
Charlie-XIAO Aug 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v2.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ Interval

Indexing
^^^^^^^^
-
- Bug in :meth:`DataFrame.where`, :meth:`DataFrame.mask`, :meth:`Series.where`, and :meth:`Series.mask`, when ``cond`` for an element is ``pd.NA``; the corresponding element now propagates through (:issue:`52955`)
-

Missing
Expand Down
4 changes: 4 additions & 0 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9869,6 +9869,8 @@ def _where(
# align the cond to same shape as myself
cond = common.apply_if_callable(cond, self)
if isinstance(cond, NDFrame):
# GH #52955: if cond is NA, element propagates in mask and where
cond = cond.fillna(True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

has the option of just raising on NAs been discussed? seems ambiguous and a general PITA.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you are saying raising in where and mask, no we haven't discussed yet. If you are saying raising in _where, I think this is not desired since then, the following will not work:

>>> df = pd.DataFrame(np.random.random((3, 3)), dtype=pd.Float64Dtype())
>>> df[0][0] = pd.NA
>>> df
          0         1         2
0      <NA>  0.609241  0.419094
1  0.274784  0.342904  0.026101
2  0.670259  0.218889  0.177126
>>> df[df >= 0.5] = 0  # This will raise an error, which I assume is undesired
>>> df
          0         1         2
0      <NA>       0.0  0.419094
1  0.274784  0.342904  0.026101
2       0.0  0.218889  0.177126

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i would just have that raise too, yes.

Copy link
Contributor Author

@Charlie-XIAO Charlie-XIAO May 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jbrockmendel I think the above code snippet actually works for versions v2.0.x, do we really want to change its behavior? @topper-123 I think we may need further discussion about the desired behavior of _where, i.e., propagate or raise. I will postpone the rewording mentioned in #53124 (comment) until maintainers reach an agreement.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO we should accept BooleanArrays (and Series/DataFrame containing BooleanArrays/ArrowArray[bool]) as conditional here. I think it will be surprising if those data structure work in loc and not here.

Do similar functionality raise in any other methods? I don't recall any.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @jbrockmendel any updates on this?

# CoW: Make sure reference is not kept alive
if cond.ndim == 1 and self.ndim == 2:
cond = cond._constructor_expanddim(
Expand All @@ -9883,6 +9885,8 @@ def _where(
if cond.shape != self.shape:
raise ValueError("Array conditional must be same shape as self")
cond = self._constructor(cond, **self._construct_axes_dict(), copy=False)
# GH #52955: if cond is NA, element propagates in mask and where
cond = cond.fillna(True)

Charlie-XIAO marked this conversation as resolved.
Show resolved Hide resolved
# make sure we are boolean
fill_value = bool(inplace)
Expand Down
14 changes: 14 additions & 0 deletions pandas/tests/frame/indexing/test_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
NA,
DataFrame,
Float64Dtype,
Int64Dtype,
Series,
StringDtype,
Timedelta,
Expand Down Expand Up @@ -150,3 +151,16 @@ def test_mask_inplace_no_other():
df.mask(cond, inplace=True)
expected = DataFrame({"a": [np.nan, 2], "b": ["x", np.nan]})
tm.assert_frame_equal(df, expected)


def test_mask_with_na():
# See GH #52955, if cond is NA, propagate in mask
df = DataFrame([[1, NA], [NA, 2]], dtype=Int64Dtype())

result1 = df.mask(df % 2 == 1, 0)
expected1 = DataFrame([[0, NA], [NA, 2]], dtype=Int64Dtype())
tm.assert_frame_equal(result1, expected1)

result2 = df.mask(df[0] % 2 == 1, 0)
expected2 = DataFrame([[0, 0], [NA, 2]], dtype=Int64Dtype())
tm.assert_frame_equal(result2, expected2)
14 changes: 14 additions & 0 deletions pandas/tests/frame/indexing/test_where.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
DataFrame,
DatetimeIndex,
Index,
Int64Dtype,
Series,
StringDtype,
Timestamp,
Expand Down Expand Up @@ -1032,3 +1033,16 @@ def test_where_inplace_no_other():
df.where(cond, inplace=True)
expected = DataFrame({"a": [1, np.nan], "b": [np.nan, "y"]})
tm.assert_frame_equal(df, expected)


def test_where_with_na():
# See GH #52955, if cond is NA, propagate in where
df = DataFrame([[1, pd.NA], [pd.NA, 2]], dtype=Int64Dtype())

result1 = df.where(df % 2 == 1, 0)
expected1 = DataFrame([[1, pd.NA], [pd.NA, 0]], dtype=Int64Dtype())
tm.assert_frame_equal(result1, expected1)

result2 = df.where(df[0] % 2 == 1, 0)
expected2 = DataFrame([[1, pd.NA], [pd.NA, 2]], dtype=Int64Dtype())
tm.assert_frame_equal(result2, expected2)
17 changes: 16 additions & 1 deletion pandas/tests/series/indexing/test_mask.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import numpy as np
import pytest

from pandas import Series
from pandas import (
NA,
Int64Dtype,
Series,
)
import pandas._testing as tm


Expand Down Expand Up @@ -67,3 +71,14 @@ def test_mask_inplace():
rs = s.copy()
rs.mask(cond, -s, inplace=True)
tm.assert_series_equal(rs, s.mask(cond, -s))


def test_mask_with_na():
# See GH #52955, if cond is NA, propagate in mask
s = Series([1, 2, NA], dtype=Int64Dtype())
res1 = s.mask(s % 2 == 1, 0)
res2 = s.mask(s.array % 2 == 1, 0)

exp = Series([0, 2, NA], dtype=Int64Dtype())
tm.assert_series_equal(res1, exp)
tm.assert_series_equal(res2, exp)
13 changes: 13 additions & 0 deletions pandas/tests/series/indexing/test_where.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import pandas as pd
from pandas import (
NA,
Int64Dtype,
Series,
Timestamp,
date_range,
Expand Down Expand Up @@ -464,3 +466,14 @@ def test_where_datetimelike_categorical(tz_naive_fixture):
res = pd.DataFrame(lvals).where(mask[:, None], pd.DataFrame(rvals))

tm.assert_frame_equal(res, pd.DataFrame(dr))


def test_where_with_na():
# See GH #52955, if cond is NA, propagate in where
s = Series([1, 2, NA], dtype=Int64Dtype())
res1 = s.where(s % 2 == 1, 0)
res2 = s.where(s.array % 2 == 1, 0)

exp = Series([1, 0, NA], dtype=Int64Dtype())
tm.assert_series_equal(res1, exp)
tm.assert_series_equal(res2, exp)