Skip to content

Commit

Permalink
fix(python): remove overloads for from_arrow (#5065)
Browse files Browse the repository at this point in the history
  • Loading branch information
matteosantama committed Oct 2, 2022
1 parent c93e212 commit 74fb89e
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 38 deletions.
12 changes: 0 additions & 12 deletions py-polars/polars/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,18 +216,6 @@ def from_numpy(
return DataFrame._from_numpy(data, columns=columns, orient=orient)


@overload
def from_arrow(a: pa.Table, rechunk: bool = True) -> DataFrame:
...


@overload
def from_arrow( # type: ignore[misc]
a: pa.Array | pa.ChunkedArray, rechunk: bool = True
) -> Series:
...


def from_arrow(
a: pa.Table | pa.Array | pa.ChunkedArray, rechunk: bool = True
) -> DataFrame | Series:
Expand Down
3 changes: 2 additions & 1 deletion py-polars/polars/internals/anonymous_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pickle
from functools import partial
from typing import cast

import polars as pl
from polars import internals as pli
Expand Down Expand Up @@ -53,7 +54,7 @@ def _scan_ds_impl(
"""
if not _PYARROW_AVAILABLE: # pragma: no cover
raise ImportError("'pyarrow' is required for scanning from pyarrow datasets.")
return pl.from_arrow(ds.to_table(columns=with_columns))
return cast(pli.DataFrame, pl.from_arrow(ds.to_table(columns=with_columns)))


def _scan_ds(ds: pa.dataset.dataset) -> pli.LazyFrame:
Expand Down
32 changes: 22 additions & 10 deletions py-polars/polars/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,16 @@
import sys
from io import BytesIO, IOBase, StringIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Mapping, TextIO, overload
from typing import (
TYPE_CHECKING,
Any,
BinaryIO,
Callable,
Mapping,
TextIO,
cast,
overload,
)
from warnings import warn

if sys.version_info >= (3, 8):
Expand Down Expand Up @@ -289,7 +298,7 @@ def read_csv(
[f"column_{int(column[1:]) + 1}" for column in tbl.column_names]
)

df = from_arrow(tbl, rechunk)
df = cast(DataFrame, from_arrow(tbl, rechunk))
if new_columns:
return _update_columns(df, new_columns)
return df
Expand Down Expand Up @@ -949,13 +958,16 @@ def read_parquet(
" 'read_parquet(..., use_pyarrow=True)'."
)

return from_arrow(
pa.parquet.read_table(
source_prep,
memory_map=memory_map,
columns=columns,
**pyarrow_options,
)
return cast(
DataFrame,
from_arrow(
pa.parquet.read_table(
source_prep,
memory_map=memory_map,
columns=columns,
**pyarrow_options,
)
),
)

return DataFrame._read_parquet(
Expand Down Expand Up @@ -1101,7 +1113,7 @@ def read_sql(
partition_num=partition_num,
protocol=protocol,
)
return from_arrow(tbl)
return cast(DataFrame, from_arrow(tbl))
else:
raise ImportError(
"connectorx is not installed. Please run `pip install connectorx>=0.2.2`."
Expand Down
3 changes: 2 additions & 1 deletion py-polars/tests/unit/io/test_other.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import copy
from typing import cast

import polars as pl

Expand All @@ -22,7 +23,7 @@ def test_categorical_round_trip() -> None:
tbl = df.to_arrow()
assert "dictionary" in str(tbl["cat"].type)

df2 = pl.from_arrow(tbl)
df2 = cast(pl.DataFrame, pl.from_arrow(tbl))
assert df2.dtypes == [pl.Int64, pl.Categorical]


Expand Down
10 changes: 5 additions & 5 deletions py-polars/tests/unit/test_datelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import io
import sys
from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING, no_type_check
from typing import TYPE_CHECKING, cast, no_type_check

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -264,13 +264,13 @@ def test_datetime_consistency() -> None:
def test_timezone() -> None:
ts = pa.timestamp("s")
data = pa.array([1000, 2000], type=ts)
s: pl.Series = pl.from_arrow(data) # type: ignore[assignment]
s = cast(pl.Series, pl.from_arrow(data))

# with timezone; we do expect a warning here
tz_ts = pa.timestamp("s", tz="America/New_York")
tz_data = pa.array([1000, 2000], type=tz_ts)
# with pytest.warns(Warning):
tz_s: pl.Series = pl.from_arrow(tz_data) # type: ignore[assignment]
tz_s = cast(pl.Series, pl.from_arrow(tz_data))

# different timezones are not considered equal
# we check both `null_equal=True` and `null_equal=False`
Expand Down Expand Up @@ -648,8 +648,8 @@ def test_microseconds_accuracy() -> None:
]
),
)

assert pl.from_arrow(a)["timestamp"].to_list() == timestamps
df = cast(pl.DataFrame, pl.from_arrow(a))
assert df["timestamp"].to_list() == timestamps


def test_cast_time_units() -> None:
Expand Down
10 changes: 5 additions & 5 deletions py-polars/tests/unit/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from datetime import date, datetime, timedelta
from decimal import Decimal
from io import BytesIO
from typing import TYPE_CHECKING, Any, Iterator
from typing import TYPE_CHECKING, Any, Iterator, cast

import numpy as np
import pyarrow as pa
Expand Down Expand Up @@ -209,12 +209,12 @@ def test_from_arrow() -> None:
),
]

df = pl.from_arrow(tbl)
df = cast(pl.DataFrame, pl.from_arrow(tbl))
assert df.schema == expected_schema
assert df.rows() == expected_data

empty_tbl = tbl[:0] # no rows
df = pl.from_arrow(empty_tbl)
df = cast(pl.DataFrame, pl.from_arrow(empty_tbl))
assert df.schema == expected_schema
assert df.rows() == []

Expand Down Expand Up @@ -849,7 +849,7 @@ def test_from_arrow_table() -> None:
data = {"a": [1, 2], "b": [1, 2]}
tbl = pa.table(data)

df = pl.from_arrow(tbl)
df = cast(pl.DataFrame, pl.from_arrow(tbl))
df.frame_equal(pl.DataFrame(data))


Expand Down Expand Up @@ -937,7 +937,7 @@ def test_column_names() -> None:
}
)
for a in (tbl, tbl[:0]):
df = pl.from_arrow(a)
df = cast(pl.DataFrame, pl.from_arrow(a))
assert df.columns == ["a", "b"]


Expand Down
8 changes: 4 additions & 4 deletions py-polars/tests/unit/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,25 +352,25 @@ def test_from_empty_pandas_strings() -> None:


def test_from_empty_arrow() -> None:
    """Empty Arrow tables still round-trip column names and dtypes into polars."""
    # `pl.from_arrow` is declared to return `DataFrame | Series` (overloads were
    # removed in this commit), so `cast` narrows the static type; an Arrow Table
    # input yields a DataFrame at runtime.
    df = cast(pl.DataFrame, pl.from_arrow(pa.table(pd.DataFrame({"a": [], "b": []}))))
    assert df.columns == ["a", "b"]
    assert df.dtypes == [pl.Float64, pl.Float64]

    # 2705 — presumably a GitHub issue number; TODO confirm.
    # Round-tripping through pandas keeps the index as an extra
    # "__index_level_0__" column unless `preserve_index=False` is passed.
    df1 = pd.DataFrame(columns=["b"], dtype=float)
    tbl = pa.Table.from_pandas(df1)
    out = cast(pl.DataFrame, pl.from_arrow(tbl))
    assert out.columns == ["b", "__index_level_0__"]
    assert out.dtypes == [pl.Float64, pl.Int8]
    tbl = pa.Table.from_pandas(df1, preserve_index=False)
    out = cast(pl.DataFrame, pl.from_arrow(tbl))
    assert out.columns == ["b"]
    assert out.dtypes == [pl.Float64]

    # 4568 — presumably a GitHub issue number; TODO confirm.
    # An empty large_list(uint8) column must map to pl.List(pl.UInt8).
    tbl = pa.table({"l": []}, schema=pa.schema([("l", pa.large_list(pa.uint8()))]))

    df = cast(pl.DataFrame, pl.from_arrow(tbl))
    assert df.schema["l"] == pl.List(pl.UInt8)


Expand Down

0 comments on commit 74fb89e

Please sign in to comment.