Skip to content

Commit

Permalink
Backport PR #52058 on branch 2.0.x (BUG: to_sql with ArrowExtensionArr…
Browse files Browse the repository at this point in the history
…ay) (#52124)

* Backport PR #52058: BUG: to_sql with ArrowExtensionArray

* _data
  • Loading branch information
mroeschke committed Mar 22, 2023
1 parent 22e7c08 commit 7d5d123
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 6 deletions.
5 changes: 4 additions & 1 deletion pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2075,7 +2075,10 @@ def _dt_round(
return self._round_temporally("round", freq, ambiguous, nonexistent)

def _dt_to_pydatetime(self):
return np.array(self._data.to_pylist(), dtype=object)
data = self._data.to_pylist()
if self._dtype.pyarrow_dtype.unit == "ns":
data = [ts.to_pydatetime(warn=False) for ts in data]
return np.array(data, dtype=object)

def _dt_tz_localize(
self,
Expand Down
12 changes: 7 additions & 5 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,14 +961,16 @@ def insert_data(self) -> tuple[list[str], list[np.ndarray]]:
data_list: list[np.ndarray] = [None] * ncols # type: ignore[list-item]

for i, (_, ser) in enumerate(temp.items()):
vals = ser._values
if vals.dtype.kind == "M":
d = vals.to_pydatetime()
elif vals.dtype.kind == "m":
if ser.dtype.kind == "M":
d = ser.dt.to_pydatetime()
elif ser.dtype.kind == "m":
vals = ser._values
if isinstance(vals, ArrowExtensionArray):
vals = vals.to_numpy(dtype=np.dtype("m8[ns]"))
# store as integers, see GH#6921, GH#7076
d = vals.view("i8").astype(object)
else:
d = vals.astype(object)
d = ser._values.astype(object)

assert isinstance(d, np.ndarray), type(d)

Expand Down
1 change: 1 addition & 0 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2270,6 +2270,7 @@ def test_dt_to_pydatetime():
result = ser.dt.to_pydatetime()
expected = np.array(data, dtype=object)
tm.assert_numpy_array_equal(result, expected)
assert all(type(res) is datetime for res in result)

expected = ser.astype("datetime64[ns]").dt.to_pydatetime()
tm.assert_numpy_array_equal(result, expected)
Expand Down
21 changes: 21 additions & 0 deletions pandas/tests/io/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
date,
datetime,
time,
timedelta,
)
from io import StringIO
from pathlib import Path
Expand Down Expand Up @@ -549,6 +550,26 @@ def test_dataframe_to_sql(conn, test_frame1, request):
test_frame1.to_sql("test", conn, if_exists="append", index=False)


@pytest.mark.db
@pytest.mark.parametrize("conn", all_connectable)
def test_dataframe_to_sql_arrow_dtypes(conn, request):
    """
    Write a one-row frame of pyarrow-backed columns with ``to_sql``.

    The frame holds an int8, a nanosecond timestamp, a nanosecond duration
    and a string column. The write is expected to emit a ``UserWarning``
    whose message matches ``the 'timedelta'``.
    """
    # GH 52046
    pytest.importorskip("pyarrow")

    int_col = pd.array([1], dtype="int8[pyarrow]")
    datetime_col = pd.array(
        [datetime(2023, 1, 1)], dtype="timestamp[ns][pyarrow]"
    )
    timedelta_col = pd.array([timedelta(1)], dtype="duration[ns][pyarrow]")
    string_col = pd.array(["a"], dtype="string[pyarrow]")

    df = DataFrame(
        {
            "int": int_col,
            "datetime": datetime_col,
            "timedelta": timedelta_col,
            "string": string_col,
        }
    )

    # The parametrized value is a fixture name; resolve it to the connectable.
    connectable = request.getfixturevalue(conn)
    with tm.assert_produces_warning(UserWarning, match="the 'timedelta'"):
        df.to_sql("test_arrow", connectable, if_exists="replace", index=False)


@pytest.mark.db
@pytest.mark.parametrize("conn", all_connectable)
@pytest.mark.parametrize("method", [None, "multi"])
Expand Down

0 comments on commit 7d5d123

Please sign in to comment.