Skip to content
21 changes: 7 additions & 14 deletions pandas/core/dtypes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from pandas.core.dtypes.dtypes import (
registry, CategoricalDtype, CategoricalDtypeType, DatetimeTZDtype,
DatetimeTZDtypeType, PeriodDtype, PeriodDtypeType, IntervalDtype,
IntervalDtypeType, ExtensionDtype)
IntervalDtypeType, PandasExtensionDtype, ExtensionDtype,
_pandas_registry)
from pandas.core.dtypes.generic import (
ABCCategorical, ABCPeriodIndex, ABCDatetimeIndex, ABCSeries,
ABCSparseArray, ABCSparseSeries, ABCCategoricalIndex, ABCIndexClass,
Expand Down Expand Up @@ -1709,17 +1710,9 @@ def is_extension_array_dtype(arr_or_dtype):
Third-party libraries may implement arrays or types satisfying
this interface as well.
"""
from pandas.core.arrays import ExtensionArray

if isinstance(arr_or_dtype, (ABCIndexClass, ABCSeries)):
arr_or_dtype = arr_or_dtype._values

try:
arr_or_dtype = pandas_dtype(arr_or_dtype)
except TypeError:
pass

return isinstance(arr_or_dtype, (ExtensionDtype, ExtensionArray))
dtype = getattr(arr_or_dtype, 'dtype', arr_or_dtype)
return (isinstance(dtype, ExtensionDtype) or
registry.find(dtype) is not None)


def is_complex_dtype(arr_or_dtype):
Expand Down Expand Up @@ -1999,12 +1992,12 @@ def pandas_dtype(dtype):
return dtype

# registered extension types
result = registry.find(dtype)
result = _pandas_registry.find(dtype) or registry.find(dtype)
if result is not None:
return result

# un-registered extension types
elif isinstance(dtype, ExtensionDtype):
elif isinstance(dtype, (PandasExtensionDtype, ExtensionDtype)):
return dtype

# try a numpy dtype
Expand Down
13 changes: 8 additions & 5 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ class Registry(object):
--------
registry.register(MyExtensionDtype)
"""
dtypes = []
def __init__(self):
self.dtypes = []

@classmethod
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This wasn't actually a classmethod.

def register(self, dtype):
"""
Parameters
Expand All @@ -50,7 +50,7 @@ def find(self, dtype):
dtype_type = dtype
if not isinstance(dtype, type):
dtype_type = type(dtype)
if issubclass(dtype_type, (PandasExtensionDtype, ExtensionDtype)):
if issubclass(dtype_type, ExtensionDtype):
return dtype

return None
Expand All @@ -65,6 +65,9 @@ def find(self, dtype):


registry = Registry()
# TODO(Extension): remove the second registry once all internal extension
# dtypes are real extension dtypes.
_pandas_registry = Registry()


class PandasExtensionDtype(_DtypeOpsMixin):
Expand Down Expand Up @@ -822,7 +825,7 @@ def is_dtype(cls, dtype):


# register the dtypes in search order
registry.register(DatetimeTZDtype)
registry.register(PeriodDtype)
registry.register(IntervalDtype)
registry.register(CategoricalDtype)
_pandas_registry.register(DatetimeTZDtype)
_pandas_registry.register(PeriodDtype)
24 changes: 17 additions & 7 deletions pandas/tests/dtypes/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from pandas.core.dtypes.dtypes import (
DatetimeTZDtype, PeriodDtype,
IntervalDtype, CategoricalDtype, registry)
IntervalDtype, CategoricalDtype, registry, _pandas_registry)
from pandas.core.dtypes.common import (
is_categorical_dtype, is_categorical,
is_datetime64tz_dtype, is_datetimetz,
Expand Down Expand Up @@ -775,21 +775,31 @@ def test_update_dtype_errors(self, bad_dtype):

@pytest.mark.parametrize(
'dtype',
[DatetimeTZDtype, CategoricalDtype,
PeriodDtype, IntervalDtype])
[CategoricalDtype, IntervalDtype])
def test_registry(dtype):
assert dtype in registry.dtypes


@pytest.mark.parametrize('dtype', [DatetimeTZDtype, PeriodDtype])
def test_pandas_registry(dtype):
assert dtype not in registry.dtypes
assert dtype in _pandas_registry.dtypes


@pytest.mark.parametrize(
'dtype, expected',
[('int64', None),
('interval', IntervalDtype()),
('interval[int64]', IntervalDtype()),
('interval[datetime64[ns]]', IntervalDtype('datetime64[ns]')),
('category', CategoricalDtype()),
('period[D]', PeriodDtype('D')),
('datetime64[ns, US/Eastern]', DatetimeTZDtype('ns', 'US/Eastern'))])
('category', CategoricalDtype())])
def test_registry_find(dtype, expected):

assert registry.find(dtype) == expected


@pytest.mark.parametrize(
'dtype, expected',
[('period[D]', PeriodDtype('D')),
('datetime64[ns, US/Eastern]', DatetimeTZDtype('ns', 'US/Eastern'))])
def test_pandas_registry_find(dtype, expected):
assert _pandas_registry.find(dtype) == expected