Skip to content

Commit

Permalink
fix(python): Correctly use read_parquet for all binary inputs (#13218)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Dec 23, 2023
1 parent 9617dc5 commit bb00812
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 32 deletions.
9 changes: 5 additions & 4 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from operator import itemgetter
from pathlib import Path
from typing import (
IO,
TYPE_CHECKING,
Any,
BinaryIO,
Expand Down Expand Up @@ -665,7 +666,7 @@ def _from_pandas(
@classmethod
def _read_csv(
cls,
source: str | Path | BinaryIO | bytes,
source: str | Path | IO[bytes] | bytes,
*,
has_header: bool = True,
columns: Sequence[int] | Sequence[str] | None = None,
Expand Down Expand Up @@ -816,7 +817,7 @@ def _read_csv(
@classmethod
def _read_parquet(
cls,
source: str | Path | BinaryIO | bytes,
source: str | Path | IO[bytes] | bytes,
*,
columns: Sequence[int] | Sequence[str] | None = None,
n_rows: int | None = None,
Expand Down Expand Up @@ -913,7 +914,7 @@ def _read_avro(
@classmethod
def _read_ipc(
cls,
source: str | Path | BinaryIO | bytes,
source: str | Path | IO[bytes] | bytes,
*,
columns: Sequence[int] | Sequence[str] | None = None,
n_rows: int | None = None,
Expand Down Expand Up @@ -995,7 +996,7 @@ def _read_ipc(
@classmethod
def _read_ipc_stream(
cls,
source: str | Path | BinaryIO | bytes,
source: str | Path | IO[bytes] | bytes,
*,
columns: Sequence[int] | Sequence[str] | None = None,
n_rows: int | None = None,
Expand Down
31 changes: 13 additions & 18 deletions py-polars/polars/io/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from contextlib import contextmanager
from io import BytesIO, StringIO
from pathlib import Path
from typing import Any, BinaryIO, ContextManager, Iterator, TextIO, overload
from typing import IO, Any, ContextManager, Iterator, overload

from polars.dependencies import _FSSPEC_AVAILABLE, fsspec
from polars.exceptions import NoDataError
Expand All @@ -31,48 +31,48 @@ def _is_local_file(file: str) -> bool:

@overload
def _prepare_file_arg(
file: str | list[str] | Path | BinaryIO | bytes,
file: str | list[str] | Path | IO[bytes] | bytes,
encoding: str | None = ...,
*,
use_pyarrow: bool = ...,
raise_if_empty: bool = ...,
storage_options: dict[str, Any] | None = ...,
) -> ContextManager[str | BinaryIO]:
) -> ContextManager[str | BytesIO]:
...


@overload
def _prepare_file_arg(
file: str | TextIO | Path | BinaryIO | bytes,
file: str | Path | IO[str] | IO[bytes] | bytes,
encoding: str | None = ...,
*,
use_pyarrow: bool = ...,
raise_if_empty: bool = ...,
storage_options: dict[str, Any] | None = ...,
) -> ContextManager[str | BinaryIO]:
) -> ContextManager[str | BytesIO]:
...


@overload
def _prepare_file_arg(
file: str | list[str] | Path | TextIO | BinaryIO | bytes,
file: str | list[str] | Path | IO[str] | IO[bytes] | bytes,
encoding: str | None = ...,
*,
use_pyarrow: bool = ...,
raise_if_empty: bool = ...,
storage_options: dict[str, Any] | None = ...,
) -> ContextManager[str | list[str] | BinaryIO | list[BinaryIO]]:
) -> ContextManager[str | list[str] | BytesIO | list[BytesIO]]:
...


def _prepare_file_arg(
file: str | list[str] | Path | TextIO | BinaryIO | bytes,
file: str | list[str] | Path | IO[str] | IO[bytes] | bytes,
encoding: str | None = None,
*,
use_pyarrow: bool = False,
raise_if_empty: bool = True,
storage_options: dict[str, Any] | None = None,
) -> ContextManager[str | list[str] | BinaryIO | list[BinaryIO]]:
) -> ContextManager[str | list[str] | BytesIO | list[BytesIO]]:
"""
Prepare file argument.
Expand Down Expand Up @@ -116,15 +116,10 @@ def managed_file(file: Any) -> Iterator[Any]:

if isinstance(file, bytes):
if not has_utf8_utf8_lossy_encoding:
return _check_empty(
BytesIO(file.decode(encoding_str).encode("utf8")),
context="bytes",
raise_if_empty=raise_if_empty,
)
if use_pyarrow:
return _check_empty(
BytesIO(file), context="bytes", raise_if_empty=raise_if_empty
)
file = file.decode(encoding_str).encode("utf8")
return _check_empty(
BytesIO(file), context="bytes", raise_if_empty=raise_if_empty
)

if isinstance(file, StringIO):
return _check_empty(
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/io/ipc/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import contextlib
from pathlib import Path
from typing import TYPE_CHECKING, Any, BinaryIO
from typing import IO, TYPE_CHECKING, Any, BinaryIO

import polars._reexport as pl
from polars.dependencies import _PYARROW_AVAILABLE
Expand Down Expand Up @@ -186,7 +186,7 @@ def read_ipc_stream(
)


def read_ipc_schema(source: str | BinaryIO | Path | bytes) -> dict[str, DataType]:
def read_ipc_schema(source: str | Path | IO[bytes] | bytes) -> dict[str, DataType]:
"""
Get the schema of an IPC file without reading data.
Expand Down
14 changes: 6 additions & 8 deletions py-polars/polars/io/parquet/functions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

import contextlib
from io import BytesIO
import io
from pathlib import Path
from typing import TYPE_CHECKING, Any, BinaryIO
from typing import IO, TYPE_CHECKING, Any

import polars._reexport as pl
from polars.convert import from_arrow
Expand All @@ -20,7 +20,7 @@


def read_parquet(
source: str | Path | list[str] | list[Path] | BinaryIO | BytesIO | bytes,
source: str | Path | list[str] | list[Path] | IO[bytes] | bytes,
*,
columns: list[int] | list[str] | None = None,
n_rows: int | None = None,
Expand Down Expand Up @@ -145,7 +145,7 @@ def read_parquet(
)

# Read binary types using `read_parquet`
elif isinstance(source, (BinaryIO, BytesIO, bytes)):
elif isinstance(source, (io.BufferedIOBase, io.RawIOBase, bytes)):
with _prepare_file_arg(source, use_pyarrow=False) as source_prep:
return pl.DataFrame._read_parquet(
source_prep,
Expand All @@ -161,7 +161,7 @@ def read_parquet(

# For other inputs, defer to `scan_parquet`
lf = scan_parquet(
source,
source, # type: ignore[arg-type]
n_rows=n_rows,
row_count_name=row_count_name,
row_count_offset=row_count_offset,
Expand All @@ -183,9 +183,7 @@ def read_parquet(
return lf.collect(no_optimization=True)


def read_parquet_schema(
source: str | BinaryIO | Path | bytes,
) -> dict[str, DataType]:
def read_parquet_schema(source: str | Path | IO[bytes] | bytes) -> dict[str, DataType]:
"""
Get the schema of a Parquet file without reading data.
Expand Down
46 changes: 46 additions & 0 deletions py-polars/tests/unit/io/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,3 +633,49 @@ def test_write_parquet_with_null_col(tmp_path: Path) -> None:
df.write_parquet(file_path, row_group_size=3)
out = pl.read_parquet(file_path)
assert_frame_equal(out, df)


@pytest.mark.write_disk()
def test_read_parquet_binary_buffered_reader(tmp_path: Path) -> None:
    """A buffered binary file handle passed to `read_parquet` round-trips the frame."""
    tmp_path.mkdir(exist_ok=True)

    expected = pl.DataFrame({"a": [1, 2, 3]})
    target = tmp_path / "test.parquet"
    expected.write_parquet(target)

    # Default `open("rb")` yields an io.BufferedReader.
    with target.open("rb") as reader:
        result = pl.read_parquet(reader)
    assert_frame_equal(result, expected)


@pytest.mark.write_disk()
def test_read_parquet_binary_file_io(tmp_path: Path) -> None:
    """An unbuffered binary file handle passed to `read_parquet` round-trips the frame."""
    tmp_path.mkdir(exist_ok=True)

    expected = pl.DataFrame({"a": [1, 2, 3]})
    target = tmp_path / "test.parquet"
    expected.write_parquet(target)

    # `buffering=0` yields a raw io.FileIO instead of a BufferedReader.
    with target.open("rb", buffering=0) as raw:
        result = pl.read_parquet(raw)
    assert_frame_equal(result, expected)


def test_read_parquet_binary_bytes_io() -> None:
    """An in-memory BytesIO buffer passed to `read_parquet` round-trips the frame."""
    expected = pl.DataFrame({"a": [1, 2, 3]})
    buffer = io.BytesIO()
    expected.write_parquet(buffer)
    # Rewind so read_parquet starts at the beginning of the written data.
    buffer.seek(0)

    result = pl.read_parquet(buffer)
    assert_frame_equal(result, expected)


def test_read_parquet_binary_bytes() -> None:
    """A raw `bytes` object passed to `read_parquet` round-trips the frame."""
    df = pl.DataFrame({"a": [1, 2, 3]})
    f = io.BytesIO()
    df.write_parquet(f)
    # Renamed from `bytes` to `data`: the original shadowed the `bytes` builtin.
    data = f.getvalue()

    out = pl.read_parquet(data)
    assert_frame_equal(out, df)

0 comments on commit bb00812

Please sign in to comment.