Skip to content

Commit

Permalink
Overload pl.from_arrow type hints (#4236)
Browse files — browse the repository at this point in the history
  • Loading branch information
matteosantama committed Aug 3, 2022
1 parent c4fc26c commit 20032d1
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 18 deletions.
13 changes: 12 additions & 1 deletion py-polars/polars/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,18 @@ def from_numpy(
return DataFrame._from_numpy(data, columns=columns, orient=orient)


# Note that we cannot overload because pyarrow has no stubs :(
@overload
def from_arrow(a: pa.Table, rechunk: bool = True) -> DataFrame:
...


@overload
def from_arrow( # type: ignore[misc]
a: pa.Array | pa.ChunkedArray, rechunk: bool = True
) -> Series:
...


def from_arrow(
a: pa.Table | pa.Array | pa.ChunkedArray, rechunk: bool = True
) -> DataFrame | Series:
Expand Down
4 changes: 1 addition & 3 deletions py-polars/polars/internals/anonymous_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ def _scan_ds_impl(
"""
if not _PYARROW_AVAILABLE: # pragma: no cover
raise ImportError("'pyarrow' is required for scanning from pyarrow datasets.")
return pl.from_arrow( # type: ignore[return-value]
ds.to_table(columns=with_columns)
)
return pl.from_arrow(ds.to_table(columns=with_columns))


def _scan_ds(ds: pa.dataset.dataset) -> pli.LazyFrame:
Expand Down
8 changes: 4 additions & 4 deletions py-polars/polars/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from io import BytesIO, IOBase, StringIO
from pathlib import Path
from typing import Any, BinaryIO, Callable, Mapping, TextIO, cast
from typing import Any, BinaryIO, Callable, Mapping, TextIO

from polars.utils import format_path, handle_projection_columns

Expand Down Expand Up @@ -271,7 +271,7 @@ def read_csv(
[f"column_{int(column[1:]) + 1}" for column in tbl.column_names]
)

df = cast(DataFrame, from_arrow(tbl, rechunk))
df = from_arrow(tbl, rechunk)
if new_columns:
return _update_columns(df, new_columns)
return df
Expand Down Expand Up @@ -909,7 +909,7 @@ def read_parquet(
" 'read_parquet(..., use_pyarrow=True)'."
)

return from_arrow( # type: ignore[return-value]
return from_arrow(
pa.parquet.read_table(
source_prep,
memory_map=memory_map,
Expand Down Expand Up @@ -1029,7 +1029,7 @@ def read_sql(
partition_num=partition_num,
protocol=protocol,
)
return cast(DataFrame, from_arrow(tbl))
return from_arrow(tbl)
else:
raise ImportError(
"connectorx is not installed. Please run `pip install connectorx>=0.2.2`."
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/io/test_other.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_categorical_round_trip() -> None:
tbl = df.to_arrow()
assert "dictionary" in str(tbl["cat"].type)

df2: pl.DataFrame = pl.from_arrow(tbl) # type: ignore[assignment]
df2 = pl.from_arrow(tbl)
assert df2.dtypes == [pl.Int64, pl.Categorical]


Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/test_datelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ def test_microseconds_accuracy() -> None:
),
)

assert pl.from_arrow(a)["timestamp"].to_list() == timestamps # type: ignore[index]
assert pl.from_arrow(a)["timestamp"].to_list() == timestamps


def test_cast_time_units() -> None:
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ def test_from_arrow_table() -> None:
data = {"a": [1, 2], "b": [1, 2]}
tbl = pa.table(data)

df: pl.DataFrame = pl.from_arrow(tbl) # type: ignore[assignment]
df = pl.from_arrow(tbl)
df.frame_equal(pl.DataFrame(data))


Expand Down Expand Up @@ -800,7 +800,7 @@ def test_column_names() -> None:
"b": pa.array([1, 2, 3, 4, 5], pa.int64()),
}
)
df: pl.DataFrame = pl.from_arrow(tbl) # type: ignore[assignment]
df = pl.from_arrow(tbl)
assert df.columns == ["a", "b"]


Expand Down
12 changes: 6 additions & 6 deletions py-polars/tests/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,19 +363,19 @@ def test_from_empty_pandas_strings() -> None:

def test_from_empty_arrow() -> None:
df = pl.from_arrow(pa.table(pd.DataFrame({"a": [], "b": []})))
assert df.columns == ["a", "b"] # type: ignore[union-attr]
assert df.dtypes == [pl.Float64, pl.Float64] # type: ignore[union-attr]
assert df.columns == ["a", "b"]
assert df.dtypes == [pl.Float64, pl.Float64]

# 2705
df1 = pd.DataFrame(columns=["b"], dtype=float)
tbl = pa.Table.from_pandas(df1)
out = pl.from_arrow(tbl)
assert out.columns == ["b", "__index_level_0__"] # type: ignore[union-attr]
assert out.dtypes == [pl.Float64, pl.Utf8] # type: ignore[union-attr]
assert out.columns == ["b", "__index_level_0__"]
assert out.dtypes == [pl.Float64, pl.Utf8]
tbl = pa.Table.from_pandas(df1, preserve_index=False)
out = pl.from_arrow(tbl)
assert out.columns == ["b"] # type: ignore[union-attr]
assert out.dtypes == [pl.Float64] # type: ignore[union-attr]
assert out.columns == ["b"]
assert out.dtypes == [pl.Float64]


def test_from_null_column() -> None:
Expand Down

0 comments on commit 20032d1

Please sign in to comment.