Skip to content

Commit

Permalink
add in extension dtype registry
Browse files Browse the repository at this point in the history
  • Loading branch information
jreback committed May 21, 2018
1 parent e4f7536 commit ec1c081
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 35 deletions.
39 changes: 5 additions & 34 deletions pandas/core/dtypes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
DatetimeTZDtype, DatetimeTZDtypeType,
PeriodDtype, PeriodDtypeType,
IntervalDtype, IntervalDtypeType,
ExtensionDtype, PandasExtensionDtype)
ExtensionDtype, registry)
from .generic import (ABCCategorical, ABCPeriodIndex,
ABCDatetimeIndex, ABCSeries,
ABCSparseArray, ABCSparseSeries, ABCCategoricalIndex,
Expand Down Expand Up @@ -1975,39 +1975,10 @@ def pandas_dtype(dtype):
np.dtype or a pandas dtype
"""

if isinstance(dtype, DatetimeTZDtype):
return dtype
elif isinstance(dtype, PeriodDtype):
return dtype
elif isinstance(dtype, CategoricalDtype):
return dtype
elif isinstance(dtype, IntervalDtype):
return dtype
elif isinstance(dtype, string_types):
try:
return DatetimeTZDtype.construct_from_string(dtype)
except TypeError:
pass

if dtype.startswith('period[') or dtype.startswith('Period['):
# do not parse string like U as period[U]
try:
return PeriodDtype.construct_from_string(dtype)
except TypeError:
pass

elif dtype.startswith('interval') or dtype.startswith('Interval'):
try:
return IntervalDtype.construct_from_string(dtype)
except TypeError:
pass

try:
return CategoricalDtype.construct_from_string(dtype)
except TypeError:
pass
elif isinstance(dtype, (PandasExtensionDtype, ExtensionDtype)):
return dtype
# registered extension types
result = registry.find(dtype)
if result is not None:
return result

try:
npdtype = np.dtype(dtype)
Expand Down
80 changes: 80 additions & 0 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,64 @@

import re
import numpy as np
from collections import OrderedDict
from pandas import compat
from pandas.core.dtypes.generic import ABCIndexClass, ABCCategoricalIndex

from .base import ExtensionDtype, _DtypeOpsMixin


class Registry(object):
    """
    Registry of pandas extension dtypes, used for dtype inference.

    An instance of a registered dtype class is recognised directly; the
    registry additionally allows a dtype to be inferred from its string
    alias (this is what ``pandas_dtype`` relies on).

    Registered dtypes are tried in registration order and the first one
    that matches wins.

    Notes
    -----
    ``dtypes`` is a class-level mapping, so it is shared by every
    instance of ``Registry``.
    """
    # maps dtype class -> callable that builds that dtype from a string
    dtypes = OrderedDict()

    @classmethod
    def register(cls, dtype, constructor=None):
        """
        Register a dtype class for inference.

        Parameters
        ----------
        dtype : type
            A subclass of PandasExtensionDtype.
        constructor : callable, optional
            Called with a string; it must return a dtype instance, or
            raise TypeError when the string does not describe this dtype.
            Defaults to ``dtype.construct_from_string``.

        Raises
        ------
        ValueError
            If `dtype` is not a subclass of PandasExtensionDtype.
        """
        # issubclass raises a bare TypeError for non-class arguments, so
        # check for a class first and always raise the same ValueError
        if not (isinstance(dtype, type) and
                issubclass(dtype, PandasExtensionDtype)):
            raise ValueError("can only register pandas extension dtypes")

        if constructor is None:
            constructor = dtype.construct_from_string

        cls.dtypes[dtype] = constructor

    def find(self, dtype):
        """
        Look up `dtype` among the registered dtypes.

        Parameters
        ----------
        dtype : PandasExtensionDtype or string

        Returns
        -------
        `dtype` itself if it is an instance of a registered dtype class;
        for a string, the result of the first registered constructor
        that accepts it; otherwise None.
        """
        # already an instance of one of the registered dtype classes
        if isinstance(dtype, tuple(self.dtypes)):
            return dtype

        # only strings can be handed to the constructors
        if not isinstance(dtype, compat.string_types):
            return None

        for constructor in self.dtypes.values():
            try:
                return constructor(dtype)
            except TypeError:
                # not an alias of this dtype; try the next one
                pass

        return None


# Module-level registry instance; ``pandas_dtype`` queries it through
# ``registry.find``.
registry = Registry()


class PandasExtensionDtype(_DtypeOpsMixin):
"""
A np.dtype duck-typed class, suitable for holding a custom dtype.
Expand Down Expand Up @@ -564,6 +616,17 @@ def construct_from_string(cls, string):
pass
raise TypeError("could not construct PeriodDtype")

@classmethod
def construct_from_string_strict(cls, string):
    """
    Strict construction from a string.

    Only strings starting with ``period[`` or ``Period[`` are handed to
    ``construct_from_string``; anything else is rejected.

    Parameters
    ----------
    string : str

    Returns
    -------
    The result of ``construct_from_string(string)``.

    Raises
    ------
    TypeError
        If `string` is not a string or does not carry the period prefix.
    """
    # the isinstance guard keeps the documented TypeError contract for
    # non-string input (``.startswith`` would raise AttributeError)
    if (isinstance(string, compat.string_types) and
            string.startswith(('period[', 'Period['))):
        # do not parse string like U as period[U]
        return cls.construct_from_string(string)
    raise TypeError("could not construct PeriodDtype")

def __unicode__(self):
return "period[{freq}]".format(freq=self.freq.freqstr)

Expand Down Expand Up @@ -683,6 +746,16 @@ def construct_from_string(cls, string):
msg = "a string needs to be passed, got type {typ}"
raise TypeError(msg.format(typ=type(string)))

@classmethod
def construct_from_string_strict(cls, string):
    """
    Strict construction from a string.

    Only strings starting with ``interval`` or ``Interval`` are handed
    to ``construct_from_string``; anything else is rejected.

    Parameters
    ----------
    string : str

    Returns
    -------
    The result of ``construct_from_string(string)``.

    Raises
    ------
    TypeError
        If `string` is not a string or does not carry the interval
        prefix.
    """
    # the isinstance guard keeps the documented TypeError contract for
    # non-string input (``.startswith`` would raise AttributeError)
    if (isinstance(string, compat.string_types) and
            string.startswith(('interval', 'Interval'))):
        return cls.construct_from_string(string)
    raise TypeError("cannot construct IntervalDtype")

def __unicode__(self):
if self.subtype is None:
return "interval"
Expand Down Expand Up @@ -723,3 +796,10 @@ def is_dtype(cls, dtype):
else:
return False
return super(IntervalDtype, cls).is_dtype(dtype)


# Register the built-in dtypes. Registration order is the order in which
# ``registry.find`` tries them, i.e. the inference search order.
# Period and Interval use their strict constructors, which only accept
# strings carrying the dtype's own prefix.
registry.register(DatetimeTZDtype)
registry.register(
    PeriodDtype, constructor=PeriodDtype.construct_from_string_strict)
registry.register(
    IntervalDtype, constructor=IntervalDtype.construct_from_string_strict)
registry.register(CategoricalDtype)
1 change: 1 addition & 0 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4061,6 +4061,7 @@ def _try_cast(arr, take_fast_path):
"Pass the extension array directly.".format(dtype))
raise ValueError(msg)


elif dtype is not None and raise_cast_failure:
raise
else:
Expand Down
23 changes: 22 additions & 1 deletion pandas/tests/dtypes/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pandas.compat import string_types
from pandas.core.dtypes.dtypes import (
DatetimeTZDtype, PeriodDtype,
IntervalDtype, CategoricalDtype)
IntervalDtype, CategoricalDtype, registry)
from pandas.core.dtypes.common import (
is_categorical_dtype, is_categorical,
is_datetime64tz_dtype, is_datetimetz,
Expand Down Expand Up @@ -767,3 +767,24 @@ def test_update_dtype_errors(self, bad_dtype):
msg = 'a CategoricalDtype must be passed to perform an update, '
with tm.assert_raises_regex(ValueError, msg):
dtype.update_dtype(bad_dtype)


@pytest.mark.parametrize(
    'dtype',
    [DatetimeTZDtype, CategoricalDtype, PeriodDtype, IntervalDtype])
def test_registry(dtype):
    """Each built-in pandas dtype class is a key of the registry."""
    registered = registry.dtypes
    assert dtype in registered


@pytest.mark.parametrize(
    'dtype, expected',
    [('int64', None),
     ('interval', IntervalDtype()),
     # NOTE(review): a bare IntervalDtype() looks like it compares equal
     # to any interval dtype (confirm), so spell out the subtype here to
     # make sure the parsed subtype is actually checked.
     ('interval[int64]', IntervalDtype('int64')),
     ('category', CategoricalDtype()),
     ('period[D]', PeriodDtype('D')),
     ('datetime64[ns, US/Eastern]', DatetimeTZDtype('ns', 'US/Eastern'))])
def test_registry_find(dtype, expected):
    """
    ``registry.find`` maps a string alias to the matching dtype, and
    returns None for a string no registered dtype accepts.
    """
    assert registry.find(dtype) == expected
12 changes: 12 additions & 0 deletions pandas/tests/extension/base/constructors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

import numpy as np
import pandas as pd
import pandas.util.testing as tm
from pandas.core.internals import ExtensionBlock
Expand Down Expand Up @@ -45,3 +46,14 @@ def test_series_given_mismatched_index_raises(self, data):
msg = 'Length of passed values is 3, index implies 5'
with tm.assert_raises_regex(ValueError, msg):
pd.Series(data[:3], index=[0, 1, 2, 3, 4])

def test_from_dtype(self, data):
    """
    Building a Series from the raw values plus either the dtype object
    or its string form gives the same result as ``pd.Series(data)``.
    """
    expected = pd.Series(data)
    ext_dtype = data.dtype

    # construct from our dtype & string dtype
    for dtype_arg in (ext_dtype, str(ext_dtype)):
        result = pd.Series(np.array(data), dtype=dtype_arg)
        self.assert_series_equal(result, expected)

0 comments on commit ec1c081

Please sign in to comment.