Skip to content
156 changes: 95 additions & 61 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
)
from .dtypes import DatetimeTZDtype, ExtensionDtype, PeriodDtype
from .generic import (
ABCDataFrame,
ABCDatetimeArray,
ABCDatetimeIndex,
ABCPeriodArray,
Expand Down Expand Up @@ -95,12 +96,13 @@ def maybe_downcast_to_dtype(result, dtype):
""" try to cast to the specified dtype (e.g. convert back to bool/int
or could be an astype of float64->float32
"""
do_round = False

if is_scalar(result):
return result

def trans(x):
return x
elif isinstance(result, ABCDataFrame):
# occurs in pivot_table doctest
return result

if isinstance(dtype, str):
if dtype == "infer":
Expand All @@ -118,83 +120,115 @@ def trans(x):
elif inferred_type == "floating":
dtype = "int64"
if issubclass(result.dtype.type, np.number):

def trans(x): # noqa
return x.round()
do_round = True

else:
dtype = "object"

if isinstance(dtype, str):
dtype = np.dtype(dtype)

try:
converted = maybe_downcast_numeric(result, dtype, do_round)
if converted is not result:
return converted

# a datetimelike
# GH12821, iNaT is casted to float
if dtype.kind in ["M", "m"] and result.dtype.kind in ["i", "f"]:
try:
result = result.astype(dtype)
except Exception:
if dtype.tz:
# convert to datetime and change timezone
from pandas import to_datetime

result = to_datetime(result).tz_localize("utc")
result = result.tz_convert(dtype.tz)

elif dtype.type is Period:
# TODO(DatetimeArray): merge with previous elif
from pandas.core.arrays import PeriodArray

try:
return PeriodArray(result, freq=dtype.freq)
except TypeError:
# e.g. TypeError: int() argument must be a string, a
# bytes-like object or a number, not 'Period
pass

return result


def maybe_downcast_numeric(result, dtype, do_round: bool = False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is very similar to to_numeric; would plan as a followup to move to_numeric logic here and call this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good idea

"""
Subset of maybe_downcast_to_dtype restricted to numeric dtypes.

Parameters
----------
result : ndarray or ExtensionArray
dtype : np.dtype or ExtensionDtype
do_round : bool

Returns
-------
ndarray or ExtensionArray
"""
if not isinstance(dtype, np.dtype):
# e.g. SparseDtype has no itemsize attr
return result

if isinstance(result, list):
# reached via groupoby.agg _ohlc; really this should be handled
# earlier
result = np.array(result)

def trans(x):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rather than doing this, I would pass in a callable directly

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that gives a lot more degrees of freedom to the caller, I'd rather it just be a bool kwarg until/unless we need something more

if do_round:
return x.round()
return x

if dtype.kind == result.dtype.kind:
# don't allow upcasts here (except if empty)
if dtype.kind == result.dtype.kind:
if result.dtype.itemsize <= dtype.itemsize and np.prod(result.shape):
return result
if result.dtype.itemsize <= dtype.itemsize and result.size:
return result

if is_bool_dtype(dtype) or is_integer_dtype(dtype):
if is_bool_dtype(dtype) or is_integer_dtype(dtype):

if not result.size:
# if we don't have any elements, just astype it
if not np.prod(result.shape):
return trans(result).astype(dtype)
return trans(result).astype(dtype)

# do a test on the first element, if it fails then we are done
r = result.ravel()
arr = np.array([r[0]])
# do a test on the first element, if it fails then we are done
r = result.ravel()
arr = np.array([r[0]])

if isna(arr).any() or not np.allclose(arr, trans(arr).astype(dtype), rtol=0):
# if we have any nulls, then we are done
if isna(arr).any() or not np.allclose(
arr, trans(arr).astype(dtype), rtol=0
):
return result
return result

elif not isinstance(r[0], (np.integer, np.floating, np.bool, int, float, bool)):
# a comparable, e.g. a Decimal may slip in here
elif not isinstance(
r[0], (np.integer, np.floating, np.bool, int, float, bool)
):
return result
return result

if (
issubclass(result.dtype.type, (np.object_, np.number))
and notna(result).all()
):
new_result = trans(result).astype(dtype)
try:
if np.allclose(new_result, result, rtol=0):
return new_result
except Exception:

# comparison of an object dtype with a number type could
# hit here
if (new_result == result).all():
return new_result
elif issubclass(dtype.type, np.floating) and not is_bool_dtype(result.dtype):
return result.astype(dtype)

# a datetimelike
# GH12821, iNaT is casted to float
elif dtype.kind in ["M", "m"] and result.dtype.kind in ["i", "f"]:
if (
issubclass(result.dtype.type, (np.object_, np.number))
and notna(result).all()
):
new_result = trans(result).astype(dtype)
try:
result = result.astype(dtype)
if np.allclose(new_result, result, rtol=0):
return new_result
except Exception:
if dtype.tz:
# convert to datetime and change timezone
from pandas import to_datetime

result = to_datetime(result).tz_localize("utc")
result = result.tz_convert(dtype.tz)

elif dtype.type == Period:
# TODO(DatetimeArray): merge with previous elif
from pandas.core.arrays import PeriodArray

return PeriodArray(result, freq=dtype.freq)

except Exception:
pass
# comparison of an object dtype with a number type could
# hit here
if (new_result == result).all():
return new_result

elif (
issubclass(dtype.type, np.floating)
and not is_bool_dtype(result.dtype)
and not is_string_dtype(result.dtype)
):
return result.astype(dtype)

return result

Expand Down