Skip to content

Commit

Permalink
make construct_array_type arg optional
Browse files Browse the repository at this point in the history
  • Loading branch information
jreback committed Jun 3, 2018
1 parent 4c4262e commit 9e818b0
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 12 deletions.
8 changes: 5 additions & 3 deletions pandas/core/dtypes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,18 +162,20 @@ def name(self):
raise AbstractMethodError(self)

@classmethod
def construct_array_type(cls, array):
def construct_array_type(cls, array=None):
"""Return the array type associated with this dtype
Parameters
----------
string : str
array : array-like, optional
Returns
-------
type
"""
return type(array)
if array is None:
return cls
raise NotImplementedError

@classmethod
def construct_from_string(cls, string):
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,12 +318,12 @@ def _hash_categories(categories, ordered=True):
return np.bitwise_xor.reduce(hashed)

@classmethod
def construct_array_type(cls, array):
def construct_array_type(cls, array=None):
"""Return the array type associated with this dtype
Parameters
----------
array : value array
array : array-like, optional
Returns
-------
Expand Down
8 changes: 8 additions & 0 deletions pandas/tests/extension/base/dtype.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
import numpy as np
import pandas as pd

Expand Down Expand Up @@ -46,3 +47,10 @@ def test_eq_with_str(self, dtype):

def test_eq_with_numpy_object(self, dtype):
assert dtype != np.dtype('object')

def test_array_type(self, data, dtype):
assert dtype.construct_array_type() is type(data)

def test_array_type_with_arg(self, data, dtype):
with pytest.raises(NotImplementedError):
dtype.construct_array_type('foo')
4 changes: 3 additions & 1 deletion pandas/tests/extension/category/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def data_for_grouping():


class TestDtype(base.BaseDtypeTests):
pass

def test_array_type_with_arg(self, data, dtype):
assert dtype.construct_array_type('foo') is Categorical


class TestInterface(base.BaseInterfaceTests):
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/extension/decimal/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ class DecimalDtype(ExtensionDtype):
na_value = decimal.Decimal('NaN')

@classmethod
def construct_array_type(cls, array):
def construct_array_type(cls, array=None):
"""Return the array type associated with this dtype
Parameters
----------
string : str
array : array-like, optional
Returns
-------
Expand Down
4 changes: 3 additions & 1 deletion pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ def assert_frame_equal(self, left, right, *args, **kwargs):


class TestDtype(BaseDecimal, base.BaseDtypeTests):
pass

def test_array_type_with_arg(self, data, dtype):
assert dtype.construct_array_type('foo') is DecimalArray


class TestInterface(BaseDecimal, base.BaseInterfaceTests):
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/extension/json/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ class JSONDtype(ExtensionDtype):
na_value = {}

@classmethod
def construct_array_type(cls, array):
def construct_array_type(cls, array=None):
"""Return the array type associated with this dtype
Parameters
----------
string : str
array : array-like, optional
Returns
-------
Expand Down
4 changes: 3 additions & 1 deletion pandas/tests/extension/json/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ def assert_frame_equal(self, left, right, *args, **kwargs):


class TestDtype(BaseJSON, base.BaseDtypeTests):
pass

def test_array_type_with_arg(self, data, dtype):
assert dtype.construct_array_type('foo') is JSONArray


class TestInterface(BaseJSON, base.BaseInterfaceTests):
Expand Down

0 comments on commit 9e818b0

Please sign in to comment.