Skip to content

Commit

Permalink
API: register_extension_dtype class decorator (#22666)
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAugspurger committed Sep 13, 2018
1 parent c040353 commit 857515f
Show file tree
Hide file tree
Showing 12 changed files with 47 additions and 37 deletions.
1 change: 1 addition & 0 deletions doc/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2559,6 +2559,7 @@ objects.
.. autosummary::
:toctree: generated/

api.extensions.register_extension_dtype
api.extensions.register_dataframe_accessor
api.extensions.register_series_accessor
api.extensions.register_index_accessor
Expand Down
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.24.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ ExtensionType Changes
- :meth:`Series.combine()` works correctly with :class:`~pandas.api.extensions.ExtensionArray` inside of :class:`Series` (:issue:`20825`)
- :meth:`Series.combine()` with scalar argument now works for any function type (:issue:`21248`)
- :meth:`Series.astype` and :meth:`DataFrame.astype` now dispatch to :meth:`ExtensionArray.astype` (:issue:`21185:`).
- Added :meth:`pandas.api.types.register_extension_dtype` to register an extension type with pandas (:issue:`22664`)

.. _whatsnew_0240.api.incompatibilities:

Expand Down
4 changes: 3 additions & 1 deletion pandas/api/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@
from pandas.core.algorithms import take # noqa
from pandas.core.arrays.base import (ExtensionArray, # noqa
ExtensionScalarOpsMixin)
from pandas.core.dtypes.dtypes import ExtensionDtype # noqa
from pandas.core.dtypes.dtypes import ( # noqa
ExtensionDtype, register_extension_dtype
)
8 changes: 4 additions & 4 deletions pandas/core/arrays/integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
is_list_like)
from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin
from pandas.core.dtypes.base import ExtensionDtype
from pandas.core.dtypes.dtypes import registry
from pandas.core.dtypes.dtypes import register_extension_dtype
from pandas.core.dtypes.missing import isna, notna

from pandas.io.formats.printing import (
Expand Down Expand Up @@ -614,9 +614,9 @@ def integer_arithmetic_method(self, other):
classname = "{}Dtype".format(name)
attributes_dict = {'type': getattr(np, dtype),
'name': name}
dtype_type = type(classname, (_IntegerDtype, ), attributes_dict)
dtype_type = register_extension_dtype(
type(classname, (_IntegerDtype, ), attributes_dict)
)
setattr(module, classname, dtype_type)

# register
registry.register(dtype_type)
_dtypes[dtype] = dtype_type()
8 changes: 7 additions & 1 deletion pandas/core/dtypes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ class ExtensionDtype(_DtypeOpsMixin):
* _is_numeric
Optionally one can override construct_array_type for construction
with the name of this dtype via the Registry
with the name of this dtype via the Registry. See
:meth:`pandas.api.extensions.register_extension_dtype`.
* construct_array_type
Expand All @@ -138,6 +139,11 @@ class ExtensionDtype(_DtypeOpsMixin):
Methods and properties required by the interface raise
``pandas.errors.AbstractMethodError`` and no ``register`` method is
provided for registering virtual subclasses.
See Also
--------
pandas.api.extensions.register_extension_dtype
pandas.api.extensions.ExtensionArray
"""

def __str__(self):
Expand Down
36 changes: 26 additions & 10 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,26 @@
from .base import ExtensionDtype, _DtypeOpsMixin


def register_extension_dtype(cls):
"""Class decorator to register an ExtensionType with pandas.
.. versionadded:: 0.24.0
This enables operations like ``.astype(name)`` for the name
of the ExtensionDtype.
Examples
--------
>>> from pandas.api.extensions import register_extension_dtype
>>> from pandas.api.extensions import ExtensionDtype
>>> @register_extension_dtype
... class MyExtensionDtype(ExtensionDtype):
... pass
"""
registry.register(cls)
return cls


class Registry(object):
"""
Registry for dtype inference
Expand All @@ -17,10 +37,6 @@ class Registry(object):
Multiple extension types can be registered.
These are tried in order.
Examples
--------
registry.register(MyExtensionDtype)
"""
def __init__(self):
self.dtypes = []
Expand Down Expand Up @@ -65,9 +81,6 @@ def find(self, dtype):


registry = Registry()
# TODO(Extension): remove the second registry once all internal extension
# dtypes are real extension dtypes.
_pandas_registry = Registry()


class PandasExtensionDtype(_DtypeOpsMixin):
Expand Down Expand Up @@ -145,6 +158,7 @@ class CategoricalDtypeType(type):
pass


@register_extension_dtype
class CategoricalDtype(PandasExtensionDtype, ExtensionDtype):
"""
Type for categorical data with the categories and orderedness
Expand Down Expand Up @@ -692,6 +706,7 @@ class IntervalDtypeType(type):
pass


@register_extension_dtype
class IntervalDtype(PandasExtensionDtype, ExtensionDtype):
"""
A Interval duck-typed class, suitable for holding an interval
Expand Down Expand Up @@ -824,8 +839,9 @@ def is_dtype(cls, dtype):
return super(IntervalDtype, cls).is_dtype(dtype)


# register the dtypes in search order
registry.register(IntervalDtype)
registry.register(CategoricalDtype)
# TODO(Extension): remove the second registry once all internal extension
# dtypes are real extension dtypes.
_pandas_registry = Registry()

_pandas_registry.register(DatetimeTZDtype)
_pandas_registry.register(PeriodDtype)
5 changes: 0 additions & 5 deletions pandas/tests/extension/base/dtype.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
import numpy as np
import pandas as pd

Expand Down Expand Up @@ -51,10 +50,6 @@ def test_eq_with_numpy_object(self, dtype):
def test_array_type(self, data, dtype):
assert dtype.construct_array_type() is type(data)

def test_array_type_with_arg(self, data, dtype):
with pytest.raises(NotImplementedError):
dtype.construct_array_type('foo')

def test_check_dtype(self, data):
dtype = data.dtype

Expand Down
4 changes: 1 addition & 3 deletions pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,7 @@ def assert_frame_equal(self, left, right, *args, **kwargs):


class TestDtype(BaseDecimal, base.BaseDtypeTests):

def test_array_type_with_arg(self, data, dtype):
assert dtype.construct_array_type() is DecimalArray
pass


class TestInterface(BaseDecimal, base.BaseInterfaceTests):
Expand Down
4 changes: 1 addition & 3 deletions pandas/tests/extension/json/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,7 @@ def assert_frame_equal(self, left, right, *args, **kwargs):


class TestDtype(BaseJSON, base.BaseDtypeTests):

def test_array_type_with_arg(self, data, dtype):
assert dtype.construct_array_type() is JSONArray
pass


class TestInterface(BaseJSON, base.BaseInterfaceTests):
Expand Down
4 changes: 1 addition & 3 deletions pandas/tests/extension/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@ def data_for_grouping():


class TestDtype(base.BaseDtypeTests):

def test_array_type_with_arg(self, data, dtype):
assert dtype.construct_array_type() is Categorical
pass


class TestInterface(base.BaseInterfaceTests):
Expand Down
5 changes: 1 addition & 4 deletions pandas/tests/extension/test_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pandas.tests.extension import base
from pandas.core.dtypes.common import is_extension_array_dtype

from pandas.core.arrays import IntegerArray, integer_array
from pandas.core.arrays import integer_array
from pandas.core.arrays.integer import (
Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype,
UInt8Dtype, UInt16Dtype, UInt32Dtype, UInt64Dtype)
Expand Down Expand Up @@ -92,9 +92,6 @@ def test_is_dtype_unboxes_dtype(self):
# we have multiple dtypes, so skip
pass

def test_array_type_with_arg(self, data, dtype):
assert dtype.construct_array_type() is IntegerArray


class TestArithmeticOps(base.BaseArithmeticOpsTests):

Expand Down
4 changes: 1 addition & 3 deletions pandas/tests/extension/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ class BaseInterval(object):


class TestDtype(BaseInterval, base.BaseDtypeTests):

def test_array_type_with_arg(self, data, dtype):
assert dtype.construct_array_type() is IntervalArray
pass


class TestCasting(BaseInterval, base.BaseCastingTests):
Expand Down

0 comments on commit 857515f

Please sign in to comment.