PERF: dtype checks (#52506)
jbrockmendel committed Apr 7, 2023
1 parent 92f837f commit 8e19396
Showing 9 changed files with 49 additions and 54 deletions.
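The change is mechanical throughout: helper calls such as needs_i8_conversion, is_float_dtype, is_integer_dtype, is_numeric_dtype and is_extension_array_dtype are swapped for direct np.dtype.kind checks or isinstance checks against ExtensionDtype, at call sites where the dtype is already known to be a concrete dtype object. A minimal sketch of the equivalence being relied on (illustrative only, not part of the diff):

import numpy as np
from pandas.core.dtypes.common import is_float_dtype, needs_i8_conversion

dt64 = np.dtype("datetime64[ns]")
f64 = np.dtype("float64")

# Helper-based checks (the style being removed at these call sites):
assert needs_i8_conversion(dt64)
assert is_float_dtype(f64)

# Direct kind checks (the style being added): "M" = datetime64, "m" = timedelta64,
# "f" = float, "c" = complex, "i"/"u" = signed/unsigned int, "b" = bool.
assert dt64.kind in "mM"
assert f64.kind == "f"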
29 changes: 13 additions & 16 deletions pandas/core/dtypes/missing.py
@@ -26,15 +26,12 @@
TD64NS_DTYPE,
ensure_object,
is_bool_dtype,
is_complex_dtype,
is_dtype_equal,
is_extension_array_dtype,
is_float_dtype,
is_integer_dtype,
is_object_dtype,
is_scalar,
is_string_or_object_np_dtype,
needs_i8_conversion,
)
from pandas.core.dtypes.dtypes import (
CategoricalDtype,
@@ -291,7 +288,7 @@ def _isna_array(values: ArrayLike, inf_as_na: bool = False):
result = values.isna() # type: ignore[assignment]
elif is_string_or_object_np_dtype(values.dtype):
result = _isna_string_dtype(values, inf_as_na=inf_as_na)
elif needs_i8_conversion(dtype):
elif dtype.kind in "mM":
# this is the NaT pattern
result = values.view("i8") == iNaT
else:
@@ -502,7 +499,7 @@ def array_equivalent(
# fastpath when we require that the dtypes match (Block.equals)
if left.dtype.kind in "fc":
return _array_equivalent_float(left, right)
elif needs_i8_conversion(left.dtype):
elif left.dtype.kind in "mM":
return _array_equivalent_datetimelike(left, right)
elif is_string_or_object_np_dtype(left.dtype):
# TODO: fastpath for pandas' StringDtype
@@ -519,14 +516,14 @@ def array_equivalent(
return _array_equivalent_object(left, right, strict_nan)

# NaNs can occur in float and complex arrays.
if is_float_dtype(left.dtype) or is_complex_dtype(left.dtype):
if left.dtype.kind in "fc":
if not (left.size and right.size):
return True
return ((left == right) | (isna(left) & isna(right))).all()

elif needs_i8_conversion(left.dtype) or needs_i8_conversion(right.dtype):
elif left.dtype.kind in "mM" or right.dtype.kind in "mM":
# datetime64, timedelta64, Period
if not is_dtype_equal(left.dtype, right.dtype):
if left.dtype != right.dtype:
return False

left = left.view("i8")
@@ -541,11 +538,11 @@ def array_equivalent(
return np.array_equal(left, right)


def _array_equivalent_float(left, right) -> bool:
def _array_equivalent_float(left: np.ndarray, right: np.ndarray) -> bool:
return bool(((left == right) | (np.isnan(left) & np.isnan(right))).all())


def _array_equivalent_datetimelike(left, right):
def _array_equivalent_datetimelike(left: np.ndarray, right: np.ndarray):
return np.array_equal(left.view("i8"), right.view("i8"))


@@ -601,7 +598,7 @@ def infer_fill_value(val):
if not is_list_like(val):
val = [val]
val = np.array(val, copy=False)
if needs_i8_conversion(val.dtype):
if val.dtype.kind in "mM":
return np.array("NaT", dtype=val.dtype)
elif is_object_dtype(val.dtype):
dtype = lib.infer_dtype(ensure_object(val), skipna=False)
@@ -616,7 +613,7 @@ def maybe_fill(arr: np.ndarray) -> np.ndarray:
"""
Fill numpy.ndarray with NaN, unless we have a integer or boolean dtype.
"""
if arr.dtype.kind not in ("u", "i", "b"):
if arr.dtype.kind not in "iub":
arr.fill(np.nan)
return arr

@@ -650,15 +647,15 @@ def na_value_for_dtype(dtype: DtypeObj, compat: bool = True):

if isinstance(dtype, ExtensionDtype):
return dtype.na_value
elif needs_i8_conversion(dtype):
elif dtype.kind in "mM":
return dtype.type("NaT", "ns")
elif is_float_dtype(dtype):
elif dtype.kind == "f":
return np.nan
elif is_integer_dtype(dtype):
elif dtype.kind in "iu":
if compat:
return 0
return np.nan
elif is_bool_dtype(dtype):
elif dtype.kind == "b":
if compat:
return False
return np.nan
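For the na_value_for_dtype branches above, the kind checks pick the same NA values as the helpers they replace. A small usage sketch under that assumption (illustrative; uses the internal pandas.core.dtypes.missing module):

import numpy as np
from pandas.core.dtypes.missing import na_value_for_dtype

nat = na_value_for_dtype(np.dtype("datetime64[ns]"))            # kind "M" -> NaT
assert isinstance(nat, np.datetime64) and np.isnat(nat)

assert np.isnan(na_value_for_dtype(np.dtype("float64")))        # kind "f" -> NaN
assert na_value_for_dtype(np.dtype("int64"), compat=True) == 0  # kind "i"/"u", compat -> 0
assert na_value_for_dtype(np.dtype("bool"), compat=True) is False  # kind "b", compat -> False
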
7 changes: 4 additions & 3 deletions pandas/core/frame.py
@@ -94,7 +94,6 @@
is_dataclass,
is_dict_like,
is_dtype_equal,
is_extension_array_dtype,
is_float,
is_float_dtype,
is_hashable,
@@ -3597,7 +3596,9 @@ def transpose(self, *args, copy: bool = False) -> DataFrame:
result._mgr.add_references(self._mgr) # type: ignore[arg-type]

elif (
self._is_homogeneous_type and dtypes and is_extension_array_dtype(dtypes[0])
self._is_homogeneous_type
and dtypes
and isinstance(dtypes[0], ExtensionDtype)
):
# We have EAs with the same dtype. We can preserve that dtype in transpose.
dtype = dtypes[0]
@@ -4178,7 +4179,7 @@ def _set_item(self, key, value) -> None:
if (
key in self.columns
and value.ndim == 1
and not is_extension_array_dtype(value)
and not isinstance(value.dtype, ExtensionDtype)
):
# broadcast across multiple columns if necessary
if not self.columns.is_unique or isinstance(self.columns, MultiIndex):
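The frame.py hunks (and the generic.py ones below) replace is_extension_array_dtype with a plain isinstance check on the dtype object, skipping the helper's input-type dispatch. An illustrative sketch of the relationship, using the public API names:

import pandas as pd
from pandas.api.extensions import ExtensionDtype
from pandas.api.types import is_extension_array_dtype

ser = pd.Series([1, 2, None], dtype="Int64")   # nullable extension dtype

assert is_extension_array_dtype(ser)           # helper accepts a Series/array...
assert is_extension_array_dtype(ser.dtype)     # ...or a dtype
assert isinstance(ser.dtype, ExtensionDtype)   # direct check on the dtype object
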
9 changes: 6 additions & 3 deletions pandas/core/generic.py
@@ -123,7 +123,10 @@
is_timedelta64_dtype,
pandas_dtype,
)
from pandas.core.dtypes.dtypes import DatetimeTZDtype
from pandas.core.dtypes.dtypes import (
DatetimeTZDtype,
ExtensionDtype,
)
from pandas.core.dtypes.generic import (
ABCDataFrame,
ABCSeries,
@@ -4670,7 +4673,7 @@ def _drop_axis(
if errors == "raise" and labels_missing:
raise KeyError(f"{labels} not found in axis")

if is_extension_array_dtype(mask.dtype):
if isinstance(mask.dtype, ExtensionDtype):
# GH#45860
mask = mask.to_numpy(dtype=bool)

@@ -5458,7 +5461,7 @@ def _needs_reindex_multi(self, axes, method, level: Level | None) -> bool_t:
and not (
self.ndim == 2
and len(self.dtypes) == 1
and is_extension_array_dtype(self.dtypes.iloc[0])
and isinstance(self.dtypes.iloc[0], ExtensionDtype)
)
)

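The _drop_axis change uses that same isinstance check to detect extension-dtype masks before densifying them. A sketch of the conversion it guards (illustrative; the mask variable is a stand-in):

import pandas as pd
from pandas.api.extensions import ExtensionDtype

mask = pd.array([True, False, True], dtype="boolean")   # BooleanDtype, no NAs

if isinstance(mask.dtype, ExtensionDtype):
    # GH#45860: materialise as a plain bool ndarray before indexing with it
    mask = mask.to_numpy(dtype=bool)
assert mask.dtype == bool
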
20 changes: 7 additions & 13 deletions pandas/core/groupby/ops.py
@@ -47,12 +47,6 @@
ensure_platform_int,
ensure_uint64,
is_1d_only_ea_dtype,
is_bool_dtype,
is_complex_dtype,
is_float_dtype,
is_integer_dtype,
is_numeric_dtype,
needs_i8_conversion,
)
from pandas.core.dtypes.missing import (
isna,
@@ -248,7 +242,7 @@ def _get_out_dtype(self, dtype: np.dtype) -> np.dtype:
if how == "rank":
out_dtype = "float64"
else:
if is_numeric_dtype(dtype):
if dtype.kind in "iufcb":
out_dtype = f"{dtype.kind}{dtype.itemsize}"
else:
out_dtype = "object"
@@ -274,9 +268,9 @@ def _get_result_dtype(self, dtype: np.dtype) -> np.dtype:
if dtype == np.dtype(bool):
return np.dtype(np.int64)
elif how in ["mean", "median", "var", "std", "sem"]:
if is_float_dtype(dtype) or is_complex_dtype(dtype):
if dtype.kind in "fc":
return dtype
elif is_numeric_dtype(dtype):
elif dtype.kind in "iub":
return np.dtype(np.float64)
return dtype

@@ -339,14 +333,14 @@ def _call_cython_op(
orig_values = values

dtype = values.dtype
is_numeric = is_numeric_dtype(dtype)
is_numeric = dtype.kind in "iufcb"

is_datetimelike = needs_i8_conversion(dtype)
is_datetimelike = dtype.kind in "mM"

if is_datetimelike:
values = values.view("int64")
is_numeric = True
elif is_bool_dtype(dtype):
elif dtype.kind == "b":
values = values.view("uint8")
if values.dtype == "float16":
values = values.astype(np.float32)
@@ -446,7 +440,7 @@ def _call_cython_op(
# i.e. counts is defined. Locations where count<min_count
# need to have the result set to np.nan, which may require casting,
# see GH#40767
if is_integer_dtype(result.dtype) and not is_datetimelike:
if result.dtype.kind in "iu" and not is_datetimelike:
# if the op keeps the int dtypes, we have to use 0
cutoff = max(0 if self.how in ["sum", "prod"] else 1, min_count)
empty_groups = counts < cutoff
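In _call_cython_op the values are plain NumPy arrays, so the kind strings stand in directly for the removed helpers ("iufcb" for is_numeric_dtype, "mM" for needs_i8_conversion, "b" for is_bool_dtype, "iu" for is_integer_dtype). A sketch of that correspondence for NumPy dtypes (illustrative):

import numpy as np
from pandas.api.types import is_numeric_dtype

for name in ("int32", "uint8", "float64", "complex128", "bool"):
    dt = np.dtype(name)
    assert is_numeric_dtype(dt) == (dt.kind in "iufcb")

assert np.dtype("datetime64[ns]").kind == "M"
assert np.dtype("timedelta64[s]").kind == "m"
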
8 changes: 4 additions & 4 deletions pandas/core/indexing.py
@@ -35,7 +35,6 @@
from pandas.core.dtypes.common import (
is_array_like,
is_bool_dtype,
is_extension_array_dtype,
is_hashable,
is_integer,
is_iterator,
@@ -46,6 +45,7 @@
is_sequence,
)
from pandas.core.dtypes.concat import concat_compat
from pandas.core.dtypes.dtypes import ExtensionDtype
from pandas.core.dtypes.generic import (
ABCDataFrame,
ABCSeries,
@@ -1128,10 +1128,10 @@ def _validate_key(self, key, axis: Axis):
# boolean not in slice and with boolean index
ax = self.obj._get_axis(axis)
if isinstance(key, bool) and not (
is_bool_dtype(ax)
is_bool_dtype(ax.dtype)
or ax.dtype.name == "boolean"
or isinstance(ax, MultiIndex)
and is_bool_dtype(ax.get_level_values(0))
and is_bool_dtype(ax.get_level_values(0).dtype)
):
raise KeyError(
f"{key}: boolean label can not be used without a boolean index"
@@ -2490,7 +2490,7 @@ def check_bool_indexer(index: Index, key) -> np.ndarray:
result = result.take(indexer)

# fall through for boolean
if not is_extension_array_dtype(result.dtype):
if not isinstance(result.dtype, ExtensionDtype):
return result.astype(bool)._values

if is_object_dtype(key):
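The _validate_key change passes ax.dtype rather than the Index itself to is_bool_dtype, avoiding the helper's array-inference path. A small sketch (illustrative, assuming a bool-dtype Index):

import pandas as pd
from pandas.api.types import is_bool_dtype

idx = pd.Index([True, False, True])   # bool dtype

assert is_bool_dtype(idx)         # helper inspects the Index, then its dtype
assert is_bool_dtype(idx.dtype)   # checking the dtype directly is the cheaper form
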
2 changes: 1 addition & 1 deletion pandas/core/missing.py
@@ -455,7 +455,7 @@ def _interpolate_1d(
# sort preserve_nans and convert to list
preserve_nans = sorted(preserve_nans)

is_datetimelike = needs_i8_conversion(yvalues.dtype)
is_datetimelike = yvalues.dtype.kind in "mM"

if is_datetimelike:
yvalues = yvalues.view("i8")
17 changes: 8 additions & 9 deletions pandas/core/nanops.py
@@ -34,7 +34,6 @@
from pandas.util._exceptions import find_stack_level

from pandas.core.dtypes.common import (
is_any_int_dtype,
is_complex,
is_float,
is_float_dtype,
@@ -247,7 +246,7 @@ def _maybe_get_mask(
# Boolean data cannot contain nulls, so signal via mask being None
return None

if skipna or needs_i8_conversion(values.dtype):
if skipna or values.dtype.kind in "mM":
mask = isna(values)

return mask
@@ -300,7 +299,7 @@ def _get_values(
dtype = values.dtype

datetimelike = False
if needs_i8_conversion(values.dtype):
if values.dtype.kind in "mM":
# changing timedelta64/datetime64 to int64 needs to happen after
# finding `mask` above
values = np.asarray(values.view("i8"))
@@ -433,7 +432,7 @@ def _na_for_min_count(values: np.ndarray, axis: AxisInt | None) -> Scalar | np.ndarray:
For 2-D values, returns a 1-D array where each element is missing.
"""
# we either return np.nan or pd.NaT
if is_numeric_dtype(values.dtype):
if values.dtype.kind in "iufcb":
values = values.astype("float64")
fill_value = na_value_for_dtype(values.dtype)

@@ -521,7 +520,7 @@ def nanany(
# expected "bool")
return values.any(axis) # type: ignore[return-value]

if needs_i8_conversion(values.dtype) and values.dtype.kind != "m":
if values.dtype.kind == "M":
# GH#34479
warnings.warn(
"'any' with datetime64 dtypes is deprecated and will raise in a "
@@ -582,7 +581,7 @@ def nanall(
# expected "bool")
return values.all(axis) # type: ignore[return-value]

if needs_i8_conversion(values.dtype) and values.dtype.kind != "m":
if values.dtype.kind == "M":
# GH#34479
warnings.warn(
"'all' with datetime64 dtypes is deprecated and will raise in a "
@@ -976,12 +975,12 @@ def nanvar(
"""
dtype = values.dtype
mask = _maybe_get_mask(values, skipna, mask)
if is_any_int_dtype(dtype):
if dtype.kind in "iu":
values = values.astype("f8")
if mask is not None:
values[mask] = np.nan

if is_float_dtype(values.dtype):
if values.dtype.kind == "f":
count, d = _get_counts_nanvar(values.shape, mask, axis, ddof, values.dtype)
else:
count, d = _get_counts_nanvar(values.shape, mask, axis, ddof)
@@ -1007,7 +1006,7 @@ def nanvar(
# Return variance as np.float64 (the datatype used in the accumulator),
# unless we were dealing with a float array, in which case use the same
# precision as the original values array.
if is_float_dtype(dtype):
if dtype.kind == "f":
result = result.astype(dtype, copy=False)
return result

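The nanops.py hunks follow the same pattern: is_any_int_dtype(dtype) becomes dtype.kind in "iu", and needs_i8_conversion(values.dtype) and values.dtype.kind != "m" collapses to values.dtype.kind == "M". A minimal sketch of the kind codes involved (illustrative only):

import numpy as np

# Integer kinds that trigger the float64 upcast in nanvar:
assert np.dtype("int64").kind == "i" and np.dtype("uint8").kind == "u"

# Only datetime64 (kind "M") hits the nanany/nanall deprecation warning;
# timedelta64 (kind "m") is still allowed through.
assert np.dtype("datetime64[ns]").kind == "M"
assert np.dtype("timedelta64[ns]").kind == "m"
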
8 changes: 4 additions & 4 deletions pandas/core/series.py
@@ -64,7 +64,6 @@
)
from pandas.core.dtypes.common import (
is_dict_like,
is_extension_array_dtype,
is_integer,
is_iterator,
is_list_like,
@@ -73,6 +72,7 @@
pandas_dtype,
validate_all_hashable,
)
from pandas.core.dtypes.dtypes import ExtensionDtype
from pandas.core.dtypes.generic import ABCDataFrame
from pandas.core.dtypes.inference import is_hashable
from pandas.core.dtypes.missing import (
@@ -1861,7 +1861,7 @@ def to_dict(self, into: type[dict] = dict) -> dict:
# GH16122
into_c = com.standardize_mapping(into)

if is_object_dtype(self) or is_extension_array_dtype(self):
if is_object_dtype(self.dtype) or isinstance(self.dtype, ExtensionDtype):
return into_c((k, maybe_box_native(v)) for k, v in self.items())
else:
# Not an object dtype => all types will be the same so let the default
@@ -4164,7 +4164,7 @@ def explode(self, ignore_index: bool = False) -> Series:
3 4
dtype: object
"""
if not len(self) or not is_object_dtype(self):
if not len(self) or not is_object_dtype(self.dtype):
result = self.copy()
return result.reset_index(drop=True) if ignore_index else result

@@ -5220,7 +5220,7 @@ def _convert_dtypes(
input_series = self
if infer_objects:
input_series = input_series.infer_objects()
if is_object_dtype(input_series):
if is_object_dtype(input_series.dtype):
input_series = input_series.copy(deep=None)

if convert_string or convert_integer or convert_boolean or convert_floating:
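The series.py hunks pass self.dtype, rather than the Series itself, to is_object_dtype, and check isinstance(self.dtype, ExtensionDtype) in to_dict; both forms are equivalent for a Series but avoid re-extracting the dtype inside the helper. A small sketch (illustrative):

import pandas as pd
from pandas.api.types import is_object_dtype

ser = pd.Series(["a", [1, 2], None])   # object dtype

assert is_object_dtype(ser)         # helper pulls .dtype off the Series first
assert is_object_dtype(ser.dtype)   # passing the dtype directly does less work
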
