From 9e818b01b238c974d0e7c72907c734b10bd8494c Mon Sep 17 00:00:00 2001 From: Jeff Reback Date: Tue, 29 May 2018 06:34:04 -0400 Subject: [PATCH] make construct_array_type arg optional --- pandas/core/dtypes/base.py | 8 +++++--- pandas/core/dtypes/dtypes.py | 4 ++-- pandas/tests/extension/base/dtype.py | 8 ++++++++ pandas/tests/extension/category/test_categorical.py | 4 +++- pandas/tests/extension/decimal/array.py | 4 ++-- pandas/tests/extension/decimal/test_decimal.py | 4 +++- pandas/tests/extension/json/array.py | 4 ++-- pandas/tests/extension/json/test_json.py | 4 +++- 8 files changed, 28 insertions(+), 12 deletions(-) diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index f09151f88c2c92..c0c9a8d22ce4f4 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -162,18 +162,20 @@ def name(self): raise AbstractMethodError(self) @classmethod - def construct_array_type(cls, array): + def construct_array_type(cls, array=None): """Return the array type associated with this dtype Parameters ---------- - string : str + array : array-like, optional Returns ------- type """ - return type(array) + if array is None: + return cls + raise NotImplementedError @classmethod def construct_from_string(cls, string): diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 7b5095e6278380..7d147da661e34d 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -318,12 +318,12 @@ def _hash_categories(categories, ordered=True): return np.bitwise_xor.reduce(hashed) @classmethod - def construct_array_type(cls, array): + def construct_array_type(cls, array=None): """Return the array type associated with this dtype Parameters ---------- - array : value array + array : array-like, optional Returns ------- diff --git a/pandas/tests/extension/base/dtype.py b/pandas/tests/extension/base/dtype.py index 63d3d807c270c8..52a12816c8722c 100644 --- a/pandas/tests/extension/base/dtype.py +++ b/pandas/tests/extension/base/dtype.py @@ -1,3 +1,4 @@ +import pytest import numpy as np import pandas as pd @@ -46,3 +47,10 @@ def test_eq_with_str(self, dtype): def test_eq_with_numpy_object(self, dtype): assert dtype != np.dtype('object') + + def test_array_type(self, data, dtype): + assert dtype.construct_array_type() is type(data) + + def test_array_type_with_arg(self, data, dtype): + with pytest.raises(NotImplementedError): + dtype.construct_array_type('foo') diff --git a/pandas/tests/extension/category/test_categorical.py b/pandas/tests/extension/category/test_categorical.py index 530a4e7a22a7a3..7fc31a02ef3220 100644 --- a/pandas/tests/extension/category/test_categorical.py +++ b/pandas/tests/extension/category/test_categorical.py @@ -52,7 +52,9 @@ def data_for_grouping(): class TestDtype(base.BaseDtypeTests): - pass + + def test_array_type_with_arg(self, data, dtype): + assert dtype.construct_array_type('foo') is Categorical class TestInterface(base.BaseInterfaceTests): diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index 8c8acf58ec7578..7bdbbf77cf4d64 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -16,12 +16,12 @@ class DecimalDtype(ExtensionDtype): na_value = decimal.Decimal('NaN') @classmethod - def construct_array_type(cls, array): + def construct_array_type(cls, array=None): """Return the array type associated with this dtype Parameters ---------- - string : str + array : array-like, optional Returns ------- diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index b932bd8cb50f21..65f87942fb1919 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -92,7 +92,9 @@ def assert_frame_equal(self, left, right, *args, **kwargs): class TestDtype(BaseDecimal, base.BaseDtypeTests): - pass + + def test_array_type_with_arg(self, data, dtype): + assert dtype.construct_array_type('foo') is DecimalArray class TestInterface(BaseDecimal, base.BaseInterfaceTests): diff --git a/pandas/tests/extension/json/array.py b/pandas/tests/extension/json/array.py index 9dee52c043750d..f5d7d58277cc54 100644 --- a/pandas/tests/extension/json/array.py +++ b/pandas/tests/extension/json/array.py @@ -33,12 +33,12 @@ class JSONDtype(ExtensionDtype): na_value = {} @classmethod - def construct_array_type(cls, array): + def construct_array_type(cls, array=None): """Return the array type associated with this dtype Parameters ---------- - string : str + array : array-like, optional Returns ------- diff --git a/pandas/tests/extension/json/test_json.py b/pandas/tests/extension/json/test_json.py index ba0936a5f14a12..d41fea587131ce 100644 --- a/pandas/tests/extension/json/test_json.py +++ b/pandas/tests/extension/json/test_json.py @@ -107,7 +107,9 @@ def assert_frame_equal(self, left, right, *args, **kwargs): class TestDtype(BaseJSON, base.BaseDtypeTests): - pass + + def test_array_type_with_arg(self, data, dtype): + assert dtype.construct_array_type('foo') is JSONArray class TestInterface(BaseJSON, base.BaseInterfaceTests):