Skip to content

Commit

Permalink
Backport PR #51854 on branch 2.0.x (ENH: Add misc pyarrow types to Ar…
Browse files Browse the repository at this point in the history
…rowDtype.type) (#51887)

Backport PR #51854: ENH: Add misc pyarrow types to ArrowDtype.type

Co-authored-by: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com>
  • Loading branch information
meeseeksmachine and mroeschke committed Mar 10, 2023
1 parent 5f2b051 commit 9a2dc47
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
11 changes: 10 additions & 1 deletion pandas/core/arrays/arrow/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
StorageExtensionDtype,
register_extension_dtype,
)
from pandas.core.dtypes.dtypes import CategoricalDtypeType

if not pa_version_under7p0:
import pyarrow as pa
Expand Down Expand Up @@ -106,7 +107,7 @@ def type(self):
return int
elif pa.types.is_floating(pa_type):
return float
elif pa.types.is_string(pa_type):
elif pa.types.is_string(pa_type) or pa.types.is_large_string(pa_type):
return str
elif (
pa.types.is_binary(pa_type)
Expand All @@ -132,6 +133,14 @@ def type(self):
return time
elif pa.types.is_decimal(pa_type):
return Decimal
elif pa.types.is_dictionary(pa_type):
# TODO: Potentially change this & CategoricalDtype.type to
# something more representative of the scalar
return CategoricalDtypeType
elif pa.types.is_list(pa_type) or pa.types.is_large_list(pa_type):
return list
elif pa.types.is_map(pa_type):
return dict
elif pa.types.is_null(pa_type):
# TODO: None? pd.NA? pa.null?
return type(pa_type)
Expand Down
23 changes: 19 additions & 4 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from pandas.errors import PerformanceWarning

from pandas.core.dtypes.common import is_any_int_dtype
from pandas.core.dtypes.dtypes import CategoricalDtypeType

import pandas as pd
import pandas._testing as tm
Expand Down Expand Up @@ -1543,9 +1544,23 @@ def test_mode_dropna_false_mode_na(data):
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize("arrow_dtype", [pa.binary(), pa.binary(16), pa.large_binary()])
def test_arrow_dtype_type(arrow_dtype):
assert ArrowDtype(arrow_dtype).type == bytes
@pytest.mark.parametrize(
"arrow_dtype, expected_type",
[
[pa.binary(), bytes],
[pa.binary(16), bytes],
[pa.large_binary(), bytes],
[pa.large_string(), str],
[pa.list_(pa.int64()), list],
[pa.large_list(pa.int64()), list],
[pa.map_(pa.string(), pa.int64()), dict],
[pa.dictionary(pa.int64(), pa.int64()), CategoricalDtypeType],
],
)
def test_arrow_dtype_type(arrow_dtype, expected_type):
# GH 51845
# TODO: Redundant with test_getitem_scalar once arrow_dtype exists in data fixture
assert ArrowDtype(arrow_dtype).type == expected_type


def test_is_bool_dtype():
Expand Down Expand Up @@ -1938,7 +1953,7 @@ def test_str_get(i, exp):

@pytest.mark.xfail(
reason="TODO: StringMethods._validate should support Arrow list types",
raises=NotImplementedError,
raises=AttributeError,
)
def test_str_join():
ser = pd.Series(ArrowExtensionArray(pa.array([list("abc"), list("123"), None])))
Expand Down

0 comments on commit 9a2dc47

Please sign in to comment.