Skip to content

Commit

Permalink
ENH: add __from_pyarrow__ support to DatetimeTZDtype
Browse files Browse the repository at this point in the history
  • Loading branch information
tswast committed Mar 25, 2023
1 parent 6c50f70 commit 4d46462
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 0 deletions.
39 changes: 39 additions & 0 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,45 @@ def __eq__(self, other: Any) -> bool:
and tz_compare(self.tz, other.tz)
)

def __from_arrow__(
self, array: pyarrow.Array | pyarrow.ChunkedArray
) -> DatetimeArray:
"""
Construct DatetimeArray from pyarrow Array/ChunkedArray.
"""
import pyarrow
import pyarrow.types

from pandas.core.arrays import DatetimeArray
from pandas.core.arrays.arrow._arrow_utils import (
pyarrow_array_to_numpy_and_mask,
)

pa_type = array.type
pa_unit = "ns"
if pyarrow.types.is_timestamp(pa_type):
pa_unit = pa_type.unit

if isinstance(array, pyarrow.Array):
chunks = [array]
else:
chunks = array.chunks

results = []
for arr in chunks:
data, mask = pyarrow_array_to_numpy_and_mask(
arr, dtype=np.dtype(f"datetime64[{pa_unit}]")
)
data = data.astype(f"datetime64[{self._unit}]")
darr = DatetimeArray(data.copy(), copy=False)
darr[~mask] = NaT
darr = darr.tz_localize(self._tz)
results.append(darr)

if not results:
return DatetimeArray(np.array([], dtype="int64"), copy=False)
return DatetimeArray._concat_same_type(results)

def __setstate__(self, state) -> None:
# for pickle compat. __get_state__ is defined in the
# PandasExtensionDtype superclass and uses the public properties to
Expand Down
64 changes: 64 additions & 0 deletions pandas/tests/arrays/datetimes/test_constructors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import numpy as np
import pytest

import pandas.util._test_decorators as td

from pandas.core.dtypes.dtypes import DatetimeTZDtype

import pandas as pd
Expand Down Expand Up @@ -161,3 +163,65 @@ def test_2d(self, order):
res = DatetimeArray._from_sequence(arr)
expected = DatetimeArray._from_sequence(arr.ravel()).reshape(arr.shape)
tm.assert_datetime_array_equal(res, expected)


# ----------------------------------------------------------------------------
# Arrow interaction


pyarrow_skip = td.skip_if_no("pyarrow")


@pytest.mark.parametrize(
("pa_unit", "pd_unit", "pa_tz", "pd_tz"),
[
("s", "s", "UTC", "UTC"),
("ms", "ms", "UTC", "Europe/Berlin"),
("us", "us", "US/Eastern", "UTC"),
("ns", "ns", "US/Central", "Asia/Kolkata"),
("ns", "s", "UTC", "UTC"),
("us", "ms", "UTC", "Europe/Berlin"),
("ms", "us", "US/Eastern", "UTC"),
("s", "ns", "US/Central", "Asia/Kolkata"),
],
)
@pyarrow_skip
def test_from_arrow_with_different_units_and_timezones(pa_unit, pd_unit, pa_tz, pd_tz):
# in case pyarrow lost the Interval extension type (eg on parquet roundtrip
# with datetime64[ns] subtype, see GH-45881), still allow conversion
# from arrow to IntervalArray
import pyarrow as pa

data = [0, 123456789, None, 2**63 - 1, -123456789]
pa_type = pa.timestamp(pa_unit, tz=pa_tz)
arr = pa.array(data, type=pa_type)
dtype = DatetimeTZDtype(unit=pd_unit, tz=pd_tz)

result = dtype.__from_arrow__(arr)
expected = DatetimeArray(
np.array(data, dtype=f"datetime64[{pa_unit}]").astype(f"datetime64[{pd_unit}]")
)
expected = expected.tz_localize(pd_tz)
tm.assert_extension_array_equal(result, expected)

result = dtype.__from_arrow__(pa.chunked_array([arr]))
tm.assert_extension_array_equal(result, expected)


@pyarrow_skip
def test_from_arrow_from_integers():
# in case pyarrow lost the Interval extension type (eg on parquet roundtrip
# with datetime64[ns] subtype, see GH-45881), still allow conversion
# from arrow to IntervalArray
import pyarrow as pa

data = [0, 123456789, None, 2**63 - 1, -123456789]
arr = pa.array(data)
dtype = DatetimeTZDtype(unit="ns", tz="UTC")

result = dtype.__from_arrow__(arr)
expected = DatetimeArray(np.array(data, dtype="datetime64[ns]"))
tm.assert_extension_array_equal(result, expected)

result = dtype.__from_arrow__(pa.chunked_array([arr]))
tm.assert_extension_array_equal(result, expected)

0 comments on commit 4d46462

Please sign in to comment.