Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DEPR: try_cast kwarg in mask, where #38836

Merged
merged 4 commits into from
Dec 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.3.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ Deprecations
- Deprecating allowing scalars passed to the :class:`Categorical` constructor (:issue:`38433`)
- Deprecated allowing subclass-specific keyword arguments in the :class:`Index` constructor, use the specific subclass directly instead (:issue:`14093`,:issue:`21311`,:issue:`22315`,:issue:`26974`)
- Deprecated ``astype`` of datetimelike (``timedelta64[ns]``, ``datetime64[ns]``, ``Datetime64TZDtype``, ``PeriodDtype``) to integer dtypes, use ``values.view(...)`` instead (:issue:`38544`)
-
- Deprecated keyword ``try_cast`` in :meth:`Series.where`, :meth:`Series.mask`, :meth:`DataFrame.where`, :meth:`DataFrame.mask`; cast results manually if desired (:issue:`38836`)
-

.. ---------------------------------------------------------------------------
Expand Down
33 changes: 24 additions & 9 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8781,7 +8781,6 @@ def _where(
axis=None,
level=None,
errors="raise",
try_cast=False,
):
"""
Equivalent to public method `where`, except that `other` is not
Expand Down Expand Up @@ -8932,7 +8931,6 @@ def _where(
cond=cond,
align=align,
errors=errors,
try_cast=try_cast,
axis=block_axis,
)
result = self._constructor(new_data)
Expand All @@ -8954,7 +8952,7 @@ def where(
axis=None,
level=None,
errors="raise",
try_cast=False,
try_cast=lib.no_default,
):
"""
Replace values where the condition is {cond_rev}.
Expand Down Expand Up @@ -8986,9 +8984,12 @@ def where(
- 'raise' : allow exceptions to be raised.
- 'ignore' : suppress exceptions. On error return original object.

try_cast : bool, default False
try_cast : bool, default None
Try to cast the result back to the input type (if possible).

.. deprecated:: 1.3.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

default should be None now

Manually cast back if necessary.

Returns
-------
Same type as caller or None if ``inplace=True``.
Expand Down Expand Up @@ -9077,9 +9078,16 @@ def where(
4 True True
"""
other = com.apply_if_callable(other, self)
return self._where(
cond, other, inplace, axis, level, errors=errors, try_cast=try_cast
)

if try_cast is not lib.no_default:
warnings.warn(
"try_cast keyword is deprecated and will be removed in a "
"future version",
FutureWarning,
stacklevel=2,
)

return self._where(cond, other, inplace, axis, level, errors=errors)

@final
@doc(
Expand All @@ -9098,12 +9106,20 @@ def mask(
axis=None,
level=None,
errors="raise",
try_cast=False,
try_cast=lib.no_default,
):

inplace = validate_bool_kwarg(inplace, "inplace")
cond = com.apply_if_callable(cond, self)

if try_cast is not lib.no_default:
warnings.warn(
"try_cast keyword is deprecated and will be removed in a "
"future version",
FutureWarning,
stacklevel=2,
)

# see gh-21891
if not hasattr(cond, "__invert__"):
cond = np.array(cond)
Expand All @@ -9114,7 +9130,6 @@ def mask(
inplace=inplace,
axis=axis,
level=level,
try_cast=try_cast,
errors=errors,
)

Expand Down
21 changes: 5 additions & 16 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1290,9 +1290,7 @@ def _maybe_reshape_where_args(self, values, other, cond, axis):

return other, cond

def where(
self, other, cond, errors="raise", try_cast: bool = False, axis: int = 0
) -> List["Block"]:
def where(self, other, cond, errors="raise", axis: int = 0) -> List["Block"]:
"""
evaluate the block; return result block(s) from the result

Expand All @@ -1303,7 +1301,6 @@ def where(
errors : str, {'raise', 'ignore'}, default 'raise'
- ``raise`` : allow exceptions to be raised
- ``ignore`` : suppress exceptions. On error return original object
try_cast: bool, default False
axis : int, default 0

Returns
Expand Down Expand Up @@ -1342,9 +1339,7 @@ def where(
# we cannot coerce, return a compat dtype
# we are explicitly ignoring errors
block = self.coerce_to_target_dtype(other)
blocks = block.where(
orig_other, cond, errors=errors, try_cast=try_cast, axis=axis
)
blocks = block.where(orig_other, cond, errors=errors, axis=axis)
return self._maybe_downcast(blocks, "infer")

if not (
Expand Down Expand Up @@ -1825,9 +1820,7 @@ def shift(
)
]

def where(
self, other, cond, errors="raise", try_cast: bool = False, axis: int = 0
) -> List["Block"]:
def where(self, other, cond, errors="raise", axis: int = 0) -> List["Block"]:

cond = _extract_bool_array(cond)
assert not isinstance(other, (ABCIndex, ABCSeries, ABCDataFrame))
Expand Down Expand Up @@ -2075,9 +2068,7 @@ def to_native_types(self, na_rep="NaT", **kwargs):
result = arr._format_native_types(na_rep=na_rep, **kwargs)
return self.make_block(result)

def where(
self, other, cond, errors="raise", try_cast: bool = False, axis: int = 0
) -> List["Block"]:
def where(self, other, cond, errors="raise", axis: int = 0) -> List["Block"]:
# TODO(EA2D): reshape unnecessary with 2D EAs
arr = self.array_values().reshape(self.shape)

Expand All @@ -2086,9 +2077,7 @@ def where(
try:
res_values = arr.T.where(cond, other).T
except (ValueError, TypeError):
return super().where(
other, cond, errors=errors, try_cast=try_cast, axis=axis
)
return super().where(other, cond, errors=errors, axis=axis)

# TODO(EA2D): reshape not needed with 2D EAs
res_values = res_values.reshape(self.values.shape)
Expand Down
5 changes: 1 addition & 4 deletions pandas/core/internals/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,9 +542,7 @@ def get_axe(block, qs, axes):
def isna(self, func) -> "BlockManager":
return self.apply("apply", func=func)

def where(
self, other, cond, align: bool, errors: str, try_cast: bool, axis: int
) -> "BlockManager":
def where(self, other, cond, align: bool, errors: str, axis: int) -> "BlockManager":
if align:
align_keys = ["other", "cond"]
else:
Expand All @@ -557,7 +555,6 @@ def where(
other=other,
cond=cond,
errors=errors,
try_cast=try_cast,
axis=axis,
)

Expand Down
13 changes: 13 additions & 0 deletions pandas/tests/frame/indexing/test_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,16 @@ def test_mask_dtype_conversion(self):
expected = bools.astype(float).mask(mask)
result = bools.mask(mask)
tm.assert_frame_equal(result, expected)


def test_mask_try_cast_deprecated(frame_or_series):

obj = DataFrame(np.random.randn(4, 3))
if frame_or_series is not DataFrame:
obj = obj[0]

mask = obj > 0

with tm.assert_produces_warning(FutureWarning):
# try_cast keyword deprecated
obj.mask(mask, -1, try_cast=True)
12 changes: 12 additions & 0 deletions pandas/tests/frame/indexing/test_where.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,3 +672,15 @@ def test_where_ea_other(self):
expected["B"] = expected["B"].astype(object)
result = df.where(mask, ser2, axis=1)
tm.assert_frame_equal(result, expected)


def test_where_try_cast_deprecated(frame_or_series):
obj = DataFrame(np.random.randn(4, 3))
if frame_or_series is not DataFrame:
obj = obj[0]

mask = obj > 0

with tm.assert_produces_warning(FutureWarning):
# try_cast keyword deprecated
obj.where(mask, -1, try_cast=False)