Skip to content

Commit

Permalink
Merge pull request #9808 from sinhrks/mask_values
Browse files Browse the repository at this point in the history
ENH: NDFrame.mask supports same kwds as where
  • Loading branch information
jreback committed Apr 4, 2015
2 parents 75fce78 + 2ccf1cb commit 30dd866
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 25 deletions.
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v0.16.1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Enhancements

- Added ``StringMethods.capitalize()`` and ``swapcase`` which behave as the same as standard ``str`` (:issue:`9766`)


- ``DataFrame.mask()`` and ``Series.mask()`` now support same keywords as ``where`` (:issue:`8801`)



Expand Down
35 changes: 14 additions & 21 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3250,16 +3250,14 @@ def _align_series(self, other, join='outer', axis=None, level=None,
return (left_result.__finalize__(self),
right_result.__finalize__(other))

def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
try_cast=False, raise_on_error=True):
"""
_shared_docs['where'] = ("""
Return an object of same shape as self and whose corresponding
entries are from self where cond is True and otherwise are from other.
entries are from self where cond is %(cond)s and otherwise are from other.
Parameters
----------
cond : boolean NDFrame or array
other : scalar or NDFrame
cond : boolean %(klass)s or array
other : scalar or %(klass)s
inplace : boolean, default False
Whether to perform the operation in place on the data
axis : alignment axis if needed, default None
Expand All @@ -3273,7 +3271,11 @@ def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
Returns
-------
wh : same type as caller
"""
""")
@Appender(_shared_docs['where'] % dict(_shared_doc_kwargs, cond="True"))
def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
try_cast=False, raise_on_error=True):

if isinstance(cond, NDFrame):
cond = cond.reindex(**self._construct_axes_dict())
else:
Expand Down Expand Up @@ -3400,20 +3402,11 @@ def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,

return self._constructor(new_data).__finalize__(self)

def mask(self, cond):
"""
Returns copy whose values are replaced with nan if the
inverted condition is True
Parameters
----------
cond : boolean NDFrame or array
Returns
-------
wh: same as input
"""
return self.where(~cond, np.nan)
@Appender(_shared_docs['where'] % dict(_shared_doc_kwargs, cond="False"))
def mask(self, cond, other=np.nan, inplace=False, axis=None, level=None,
try_cast=False, raise_on_error=True):
return self.where(~cond, other=other, inplace=inplace, axis=axis,
level=level, try_cast=try_cast, raise_on_error=raise_on_error)

def shift(self, periods=1, freq=None, axis=0, **kwargs):
"""
Expand Down
21 changes: 21 additions & 0 deletions pandas/tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -9775,6 +9775,27 @@ def test_mask(self):
assert_frame_equal(rs, df.mask(df <= 0))
assert_frame_equal(rs, df.mask(~cond))

other = DataFrame(np.random.randn(5, 3))
rs = df.where(cond, other)
assert_frame_equal(rs, df.mask(df <= 0, other))
assert_frame_equal(rs, df.mask(~cond, other))

def test_mask_inplace(self):
# GH8801
df = DataFrame(np.random.randn(5, 3))
cond = df > 0

rdf = df.copy()

rdf.where(cond, inplace=True)
assert_frame_equal(rdf, df.where(cond))
assert_frame_equal(rdf, df.mask(~cond))

rdf = df.copy()
rdf.where(cond, -df, inplace=True)
assert_frame_equal(rdf, df.where(cond, -df))
assert_frame_equal(rdf, df.mask(~cond, -df))

def test_mask_edge_case_1xN_frame(self):
# GH4071
df = DataFrame([[1, 2]])
Expand Down
68 changes: 65 additions & 3 deletions pandas/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1821,6 +1821,10 @@ def test_where_broadcast(self):
for i, use_item in enumerate(selection)])
assert_series_equal(s, expected)

s = Series(data)
result = s.where(~selection, arr)
assert_series_equal(result, expected)

def test_where_inplace(self):
s = Series(np.random.randn(5))
cond = s > 0
Expand Down Expand Up @@ -1856,11 +1860,69 @@ def test_where_dups(self):
assert_series_equal(comb, expected)

def test_mask(self):
# compare with tested results in test_where
s = Series(np.random.randn(5))
cond = s > 0

rs = s.where(~cond, np.nan)
assert_series_equal(rs, s.mask(cond))

rs = s.where(~cond)
rs2 = s.mask(cond)
assert_series_equal(rs, rs2)

rs = s.where(~cond, -s)
rs2 = s.mask(cond, -s)
assert_series_equal(rs, rs2)

cond = Series([True, False, False, True, False], index=s.index)
s2 = -(s.abs())
rs = s2.where(~cond[:3])
rs2 = s2.mask(cond[:3])
assert_series_equal(rs, rs2)

rs = s2.where(~cond[:3], -s2)
rs2 = s2.mask(cond[:3], -s2)
assert_series_equal(rs, rs2)

self.assertRaises(ValueError, s.mask, 1)
self.assertRaises(ValueError, s.mask, cond[:3].values, -s)

# dtype changes
s = Series([1,2,3,4])
result = s.mask(s>2, np.nan)
expected = Series([1, 2, np.nan, np.nan])
assert_series_equal(result, expected)

def test_mask_broadcast(self):
# GH 8801
# copied from test_where_broadcast
for size in range(2, 6):
for selection in [np.resize([True, False, False, False, False], size), # First element should be set
# Set alternating elements]
np.resize([True, False], size),
np.resize([False], size)]: # No element should be set
for item in [2.0, np.nan, np.finfo(np.float).max, np.finfo(np.float).min]:
for arr in [np.array([item]), [item], (item,)]:
data = np.arange(size, dtype=float)
s = Series(data)
result = s.mask(selection, arr)
expected = Series([item if use_item else data[i]
for i, use_item in enumerate(selection)])
assert_series_equal(result, expected)

def test_mask_inplace(self):
s = Series(np.random.randn(5))
cond = s > 0

rs = s.where(cond, np.nan)
assert_series_equal(rs, s.mask(~cond))
rs = s.copy()
rs.mask(cond, inplace=True)
assert_series_equal(rs.dropna(), s[~cond])
assert_series_equal(rs, s.mask(cond))

rs = s.copy()
rs.mask(cond, -s, inplace=True)
assert_series_equal(rs, s.mask(cond, -s))

def test_drop(self):

Expand Down Expand Up @@ -6845,7 +6907,7 @@ def test_repeat(self):
def test_unique_data_ownership(self):
# it works! #1807
Series(Series(["a", "c", "b"]).unique()).sort()

def test_datetime_timedelta_quantiles(self):
# covers #9694
self.assertTrue(pd.isnull(Series([],dtype='M8[ns]').quantile(.5)))
Expand Down

0 comments on commit 30dd866

Please sign in to comment.