Skip to content

Commit

Permalink
PERF: dtype checks (#52279)
Browse files (browse the repository at this point in the history)
Co-authored-by: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com>
  • Loading branch information
jbrockmendel and mroeschke committed Mar 29, 2023
1 parent d95fc0b commit 5a65a73
Show file tree
Hide file tree
Showing 28 changed files with 147 additions and 134 deletions.
15 changes: 12 additions & 3 deletions pandas/_testing/asserters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import (
TYPE_CHECKING,
Literal,
cast,
)
Expand All @@ -21,6 +22,7 @@
)
from pandas.core.dtypes.dtypes import (
CategoricalDtype,
ExtensionDtype,
PandasDtype,
)
from pandas.core.dtypes.missing import array_equivalent
Expand Down Expand Up @@ -53,6 +55,9 @@

from pandas.io.formats.printing import pprint_thing

if TYPE_CHECKING:
from pandas._typing import DtypeObj


def assert_almost_equal(
left,
Expand Down Expand Up @@ -965,7 +970,9 @@ def assert_series_equal(
obj=str(obj),
index_values=np.asarray(left.index),
)
elif is_extension_array_dtype(left.dtype) and is_extension_array_dtype(right.dtype):
elif isinstance(left.dtype, ExtensionDtype) and isinstance(
right.dtype, ExtensionDtype
):
assert_extension_array_equal(
left._values,
right._values,
Expand Down Expand Up @@ -1320,7 +1327,9 @@ def assert_copy(iter1, iter2, **eql_kwargs) -> None:
assert elem1 is not elem2, msg


def is_extension_array_dtype_and_needs_i8_conversion(left_dtype, right_dtype) -> bool:
def is_extension_array_dtype_and_needs_i8_conversion(
left_dtype: DtypeObj, right_dtype: DtypeObj
) -> bool:
"""
Checks that we have the combination of an ExtensionArraydtype and
a dtype that should be converted to int64
Expand All @@ -1331,7 +1340,7 @@ def is_extension_array_dtype_and_needs_i8_conversion(left_dtype, right_dtype) ->
Related to issue #37609
"""
return is_extension_array_dtype(left_dtype) and needs_i8_conversion(right_dtype)
return isinstance(left_dtype, ExtensionDtype) and needs_i8_conversion(right_dtype)


def assert_indexing_slices_equivalent(ser: Series, l_slc: slice, i_slc: slice) -> None:
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@
from pandas.core.dtypes.cast import is_nested_object
from pandas.core.dtypes.common import (
is_dict_like,
is_extension_array_dtype,
is_list_like,
is_sequence,
)
from pandas.core.dtypes.dtypes import ExtensionDtype
from pandas.core.dtypes.generic import (
ABCDataFrame,
ABCNDFrame,
Expand Down Expand Up @@ -940,7 +940,7 @@ def series_generator(self):
ser = self.obj._ixs(0, axis=0)
mgr = ser._mgr

if is_extension_array_dtype(ser.dtype):
if isinstance(ser.dtype, ExtensionDtype):
# values will be incorrect for this block
# TODO(EA2D): special case would be unnecessary with 2D EAs
obj = self.obj
Expand Down
8 changes: 3 additions & 5 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def __init__(
# we're inferring from values
dtype = CategoricalDtype(categories, dtype.ordered)

elif is_categorical_dtype(values.dtype):
elif isinstance(values.dtype, CategoricalDtype):
old_codes = extract_array(values)._codes
codes = recode_for_categories(
old_codes, values.dtype.categories, dtype.categories, copy=copy
Expand Down Expand Up @@ -504,9 +504,7 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
if self.dtype is dtype:
result = self.copy() if copy else self

elif is_categorical_dtype(dtype):
dtype = cast(CategoricalDtype, dtype)

elif isinstance(dtype, CategoricalDtype):
# GH 10696/18593/18630
dtype = self.dtype.update_dtype(dtype)
self = self.copy() if copy else self
Expand Down Expand Up @@ -2497,7 +2495,7 @@ def __init__(self, data) -> None:

@staticmethod
def _validate(data):
if not is_categorical_dtype(data.dtype):
if not isinstance(data.dtype, CategoricalDtype):
raise AttributeError("Can only use .cat accessor with a 'category' dtype")

# error: Signature of "_delegate_property_get" incompatible with supertype
Expand Down
3 changes: 1 addition & 2 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@
is_integer_dtype,
is_list_like,
is_object_dtype,
is_period_dtype,
is_string_dtype,
is_timedelta64_dtype,
pandas_dtype,
Expand Down Expand Up @@ -1405,7 +1404,7 @@ def __sub__(self, other):
):
# DatetimeIndex, ndarray[datetime64]
result = self._sub_datetime_arraylike(other)
elif is_period_dtype(other_dtype):
elif isinstance(other_dtype, PeriodDtype):
# PeriodIndex
result = self._sub_periodlike(other)
elif is_integer_dtype(other_dtype):
Expand Down
22 changes: 11 additions & 11 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
Literal,
Sequence,
Union,
cast,
overload,
)

Expand Down Expand Up @@ -55,15 +54,17 @@
is_dtype_equal,
is_float_dtype,
is_integer_dtype,
is_interval_dtype,
is_list_like,
is_object_dtype,
is_scalar,
is_string_dtype,
needs_i8_conversion,
pandas_dtype,
)
from pandas.core.dtypes.dtypes import IntervalDtype
from pandas.core.dtypes.dtypes import (
CategoricalDtype,
IntervalDtype,
)
from pandas.core.dtypes.generic import (
ABCDataFrame,
ABCDatetimeIndex,
Expand Down Expand Up @@ -317,8 +318,7 @@ def _ensure_simple_new_inputs(
if dtype is not None:
# GH 19262: dtype must be an IntervalDtype to override inferred
dtype = pandas_dtype(dtype)
if is_interval_dtype(dtype):
dtype = cast(IntervalDtype, dtype)
if isinstance(dtype, IntervalDtype):
if dtype.subtype is not None:
left = left.astype(dtype.subtype)
right = right.astype(dtype.subtype)
Expand All @@ -344,7 +344,7 @@ def _ensure_simple_new_inputs(
f"right [{type(right).__name__}] types"
)
raise ValueError(msg)
if is_categorical_dtype(left.dtype) or is_string_dtype(left.dtype):
if isinstance(left.dtype, CategoricalDtype) or is_string_dtype(left.dtype):
# GH 19016
msg = (
"category, object, and string subtypes are not supported "
Expand Down Expand Up @@ -752,14 +752,14 @@ def _cmp_method(self, other, op):
# determine the dtype of the elements we want to compare
if isinstance(other, Interval):
other_dtype = pandas_dtype("interval")
elif not is_categorical_dtype(other.dtype):
elif not isinstance(other.dtype, CategoricalDtype):
other_dtype = other.dtype
else:
# for categorical defer to categories for dtype
other_dtype = other.categories.dtype

# extract intervals if we have interval categories with matching closed
if is_interval_dtype(other_dtype):
if isinstance(other_dtype, IntervalDtype):
if self.closed != other.categories.closed:
return invalid_comparison(self, other, op)

Expand All @@ -768,7 +768,7 @@ def _cmp_method(self, other, op):
)

# interval-like -> need same closed and matching endpoints
if is_interval_dtype(other_dtype):
if isinstance(other_dtype, IntervalDtype):
if self.closed != other.closed:
return invalid_comparison(self, other, op)
elif not isinstance(other, Interval):
Expand Down Expand Up @@ -951,7 +951,7 @@ def astype(self, dtype, copy: bool = True):
if dtype is not None:
dtype = pandas_dtype(dtype)

if is_interval_dtype(dtype):
if isinstance(dtype, IntervalDtype):
if dtype == self.dtype:
return self.copy() if copy else self

Expand Down Expand Up @@ -1683,7 +1683,7 @@ def isin(self, values) -> npt.NDArray[np.bool_]:
values = np.array(values)
values = extract_array(values, extract_numpy=True)

if is_interval_dtype(values.dtype):
if isinstance(values.dtype, IntervalDtype):
if self.closed != values.closed:
# not comparable -> no overlap
return np.zeros(self.shape, dtype=bool)
Expand Down
6 changes: 3 additions & 3 deletions pandas/core/arrays/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ def astype(self, dtype, copy: bool = True):
return self
else:
return self.copy()
if is_period_dtype(dtype):
if isinstance(dtype, PeriodDtype):
return self.asfreq(dtype.freq)

if is_datetime64_any_dtype(dtype):
Expand Down Expand Up @@ -897,7 +897,7 @@ def period_array(

if is_datetime64_dtype(data_dtype):
return PeriodArray._from_datetime64(data, freq)
if is_period_dtype(data_dtype):
if isinstance(data_dtype, PeriodDtype):
return PeriodArray(data, freq=freq)

# other iterable of some kind
Expand Down Expand Up @@ -966,7 +966,7 @@ def validate_dtype_freq(

if dtype is not None:
dtype = pandas_dtype(dtype)
if not is_period_dtype(dtype):
if not isinstance(dtype, PeriodDtype):
raise ValueError("dtype must be PeriodDtype")
if freq is None:
freq = dtype.freq
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/arrays/sparse/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
is_array_like,
is_bool_dtype,
is_datetime64_any_dtype,
is_datetime64tz_dtype,
is_dtype_equal,
is_integer,
is_list_like,
Expand All @@ -54,6 +53,7 @@
is_string_dtype,
pandas_dtype,
)
from pandas.core.dtypes.dtypes import DatetimeTZDtype
from pandas.core.dtypes.generic import (
ABCIndex,
ABCSeries,
Expand Down Expand Up @@ -458,7 +458,7 @@ def __init__(
data = extract_array(data, extract_numpy=True)
if not isinstance(data, np.ndarray):
# EA
if is_datetime64tz_dtype(data.dtype):
if isinstance(data.dtype, DatetimeTZDtype):
warnings.warn(
f"Creating SparseArray from {data.dtype} data "
"loses timezone information. Cast to object before "
Expand Down
6 changes: 3 additions & 3 deletions pandas/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@

from pandas.core.dtypes.cast import can_hold_element
from pandas.core.dtypes.common import (
is_extension_array_dtype,
is_object_dtype,
is_scalar,
)
from pandas.core.dtypes.dtypes import ExtensionDtype
from pandas.core.dtypes.generic import (
ABCDataFrame,
ABCIndex,
Expand Down Expand Up @@ -565,7 +565,7 @@ def to_numpy(
array(['1999-12-31T23:00:00.000000000', '2000-01-01T23:00:00...'],
dtype='datetime64[ns]')
"""
if is_extension_array_dtype(self.dtype):
if isinstance(self.dtype, ExtensionDtype):
return self.array.to_numpy(dtype, copy=copy, na_value=na_value, **kwargs)
elif kwargs:
bad_keys = list(kwargs.keys())[0]
Expand Down Expand Up @@ -1132,7 +1132,7 @@ def _memory_usage(self, deep: bool = False) -> int:
)

v = self.array.nbytes
if deep and is_object_dtype(self) and not PYPY:
if deep and is_object_dtype(self.dtype) and not PYPY:
values = cast(np.ndarray, self._values)
v += lib.memory_usage_of_objects(values)
return v
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
from pandas.core.dtypes.common import (
is_array_like,
is_bool_dtype,
is_extension_array_dtype,
is_integer,
)
from pandas.core.dtypes.dtypes import ExtensionDtype
from pandas.core.dtypes.generic import (
ABCExtensionArray,
ABCIndex,
Expand Down Expand Up @@ -122,7 +122,7 @@ def is_bool_indexer(key: Any) -> bool:
and convert to an ndarray.
"""
if isinstance(key, (ABCSeries, np.ndarray, ABCIndex)) or (
is_array_like(key) and is_extension_array_dtype(key.dtype)
is_array_like(key) and isinstance(key.dtype, ExtensionDtype)
):
if key.dtype == np.object_:
key_array = np.asarray(key)
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@
is_bool,
is_bool_dtype,
is_datetime64_any_dtype,
is_datetime64tz_dtype,
is_dict_like,
is_dtype_equal,
is_extension_array_dtype,
Expand All @@ -123,6 +122,7 @@
is_timedelta64_dtype,
pandas_dtype,
)
from pandas.core.dtypes.dtypes import DatetimeTZDtype
from pandas.core.dtypes.generic import (
ABCDataFrame,
ABCSeries,
Expand Down Expand Up @@ -9623,7 +9623,7 @@ def align(
if self.ndim == 1 or axis == 0:
# If we are aligning timezone-aware DatetimeIndexes and the timezones
# do not match, convert both to UTC.
if is_datetime64tz_dtype(left.index.dtype):
if isinstance(left.index.dtype, DatetimeTZDtype):
if left.index.tz != right.index.tz:
if join_index is not None:
# GH#33671 copy to ensure we don't change the index on
Expand Down
10 changes: 6 additions & 4 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,15 @@
from pandas.core.dtypes.common import (
ensure_int64,
is_bool,
is_categorical_dtype,
is_dict_like,
is_integer_dtype,
is_interval_dtype,
is_numeric_dtype,
is_scalar,
)
from pandas.core.dtypes.dtypes import (
CategoricalDtype,
IntervalDtype,
)
from pandas.core.dtypes.missing import (
isna,
notna,
Expand Down Expand Up @@ -681,7 +683,7 @@ def value_counts(

index_names = self.grouper.names + [self.obj.name]

if is_categorical_dtype(val.dtype) or (
if isinstance(val.dtype, CategoricalDtype) or (
bins is not None and not np.iterable(bins)
):
# scalar bins cannot be done at top level
Expand Down Expand Up @@ -717,7 +719,7 @@ def value_counts(
)
llab = lambda lab, inc: lab[inc]._multiindex.codes[-1]

if is_interval_dtype(lab.dtype):
if isinstance(lab.dtype, IntervalDtype):
# TODO: should we do this inside II?
lab_interval = cast(Interval, lab)

Expand Down

0 comments on commit 5a65a73

Please sign in to comment.