Skip to content

Commit

Permalink
Backport PR #53391: BUG: read_csv with dtype=bool[pyarrow] (#53472)
Browse files: browse the repository at this point in the history
* Backport PR #53391: BUG: read_csv with dtype=bool[pyarrow]

* Add xfail
  • Loading branch information
mroeschke committed May 31, 2023
1 parent 8bc5245 commit 54e7fe9
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 10 deletions.
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v2.0.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Fixed regressions

Bug fixes
~~~~~~~~~
-
- Bug in :func:`read_csv` when defining ``dtype`` with ``bool[pyarrow]`` for the ``"c"`` and ``"python"`` engines (:issue:`53390`)

.. ---------------------------------------------------------------------------
.. _whatsnew_203.other:
Expand Down
6 changes: 5 additions & 1 deletion pandas/_libs/parsers.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ from pandas.core.dtypes.common import (
from pandas.core.dtypes.dtypes import CategoricalDtype
from pandas.core.dtypes.inference import is_dict_like

from pandas.core.arrays.boolean import BooleanDtype

cdef:
float64_t INF = <float64_t>np.inf
float64_t NEGINF = -INF
Expand Down Expand Up @@ -1167,7 +1169,9 @@ cdef class TextReader:
array_type = dtype.construct_array_type()
try:
# use _from_sequence_of_strings if the class defines it
if is_bool_dtype(dtype):
if isinstance(dtype, BooleanDtype):
# xref GH 47534: BooleanArray._from_sequence_of_strings accepts extra
# true_values/false_values kwargs that the base ExtensionArray method lacks
true_values = [x.decode() for x in self.true_values]
false_values = [x.decode() for x in self.false_values]
result = array_type._from_sequence_of_strings(
Expand Down
3 changes: 2 additions & 1 deletion pandas/io/parsers/base_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
FloatingArray,
IntegerArray,
)
from pandas.core.arrays.boolean import BooleanDtype
from pandas.core.indexes.api import (
Index,
MultiIndex,
Expand Down Expand Up @@ -800,7 +801,7 @@ def _cast_types(self, values: ArrayLike, cast_type: DtypeObj, column) -> ArrayLi
elif isinstance(cast_type, ExtensionDtype):
array_type = cast_type.construct_array_type()
try:
if is_bool_dtype(cast_type):
if isinstance(cast_type, BooleanDtype):
# error: Unexpected keyword argument "true_values" for
# "_from_sequence_of_strings" of "ExtensionArray"
return array_type._from_sequence_of_strings( # type: ignore[call-arg] # noqa:E501
Expand Down
26 changes: 19 additions & 7 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@

import pandas as pd
import pandas._testing as tm
from pandas.api.extensions import no_default
from pandas.api.types import (
is_bool_dtype,
is_float_dtype,
Expand Down Expand Up @@ -738,14 +739,11 @@ def test_setitem_preserves_views(self, data):


class TestBaseParsing(base.BaseParsingTests):
@pytest.mark.parametrize("dtype_backend", ["pyarrow", no_default])
@pytest.mark.parametrize("engine", ["c", "python"])
def test_EA_types(self, engine, data, request):
def test_EA_types(self, engine, data, dtype_backend, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_boolean(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(raises=TypeError, reason="GH 47534")
)
elif pa.types.is_decimal(pa_dtype):
if pa.types.is_decimal(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=NotImplementedError,
Expand All @@ -763,14 +761,28 @@ def test_EA_types(self, engine, data, request):
request.node.add_marker(
pytest.mark.xfail(reason="CSV parsers don't correctly handle binary")
)
elif (
pa.types.is_duration(pa_dtype)
and dtype_backend == "pyarrow"
and engine == "python"
):
request.node.add_marker(
pytest.mark.xfail(
raises=TypeError,
reason="Invalid type for timedelta scalar: NAType",
)
)
df = pd.DataFrame({"with_dtype": pd.Series(data, dtype=str(data.dtype))})
csv_output = df.to_csv(index=False, na_rep=np.nan)
if pa.types.is_binary(pa_dtype):
csv_output = BytesIO(csv_output)
else:
csv_output = StringIO(csv_output)
result = pd.read_csv(
csv_output, dtype={"with_dtype": str(data.dtype)}, engine=engine
csv_output,
dtype={"with_dtype": str(data.dtype)},
engine=engine,
dtype_backend=dtype_backend,
)
expected = df
self.assert_frame_equal(result, expected)
Expand Down

0 comments on commit 54e7fe9

Please sign in to comment.