refactor(python): Organize utils for I/O functionality (#15529)
stinodego committed Apr 8, 2024
1 parent 49e4244 commit afe04a5
Showing 15 changed files with 242 additions and 176 deletions.
10 changes: 0 additions & 10 deletions py-polars/polars/_utils/various.py
@@ -158,16 +158,6 @@ def range_to_slice(rng: range) -> slice:
    return slice(rng.start, rng.stop, rng.step)


def _prepare_row_index_args(
    row_index_name: str | None = None,
    row_index_offset: int = 0,
) -> tuple[str, int] | None:
    if row_index_name is not None:
        return (row_index_name, row_index_offset)
    else:
        return None


def _in_notebook() -> bool:
    try:
        from IPython import get_ipython
163 changes: 77 additions & 86 deletions py-polars/polars/io/_utils.py
@@ -5,58 +5,73 @@
from contextlib import contextmanager
from io import BytesIO, StringIO
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import IO, Any, ContextManager, Iterator, Sequence, cast, overload
from typing import IO, Any, ContextManager, Iterator, Sequence, overload

from polars._utils.various import is_int_sequence, is_str_sequence, normalize_filepath
from polars.dependencies import _FSSPEC_AVAILABLE, fsspec
from polars.exceptions import NoDataError


def handle_projection_columns(
    columns: Sequence[str] | Sequence[int] | str | None,
) -> tuple[list[int] | None, Sequence[str] | None]:
    """Disambiguates between columns specified as integers vs. strings."""
    projection: list[int] | None = None
    new_columns: Sequence[str] | None = None
    if columns is not None:
        if isinstance(columns, str):
            new_columns = [columns]
        elif is_int_sequence(columns):
            projection = list(columns)
        elif not is_str_sequence(columns):
            msg = "`columns` arg should contain a list of all integers or all strings values"
            raise TypeError(msg)
        else:
            new_columns = columns
        if columns and len(set(columns)) != len(columns):
            msg = f"`columns` arg should only have unique values, got {columns!r}"
            raise ValueError(msg)
        if projection and len(set(projection)) != len(projection):
            msg = f"`columns` arg should only have unique values, got {projection!r}"
            raise ValueError(msg)
    return projection, new_columns


def _is_glob_pattern(file: str) -> bool:
    return any(char in file for char in ["*", "?", "["])


def _is_supported_cloud(file: str) -> bool:
    return bool(re.match("^(s3a?|gs|gcs|file|abfss?|azure|az|adl|https?)://", file))


def _is_local_file(file: str) -> bool:
    try:
        next(glob.iglob(file, recursive=True))  # noqa: PTH207
    except StopIteration:
        return False
    else:
        return True


def parse_columns_arg(
    columns: Sequence[str] | Sequence[int] | str | int | None,
) -> tuple[Sequence[int] | None, Sequence[str] | None]:
    """
    Parse the `columns` argument of an I/O function.

    Disambiguates between column names and column indices input.

    Returns
    -------
    tuple
        A tuple containing the columns as a projection and a list of column names.
        Only one will be specified, the other will be `None`.
    """
    if columns is None:
        return None, None

    projection: Sequence[int] | None = None
    column_names: Sequence[str] | None = None

    if isinstance(columns, str):
        column_names = [columns]
    elif isinstance(columns, int):
        projection = [columns]
    elif is_str_sequence(columns):
        _ensure_columns_are_unique(columns)
        column_names = columns
    elif is_int_sequence(columns):
        _ensure_columns_are_unique(columns)
        projection = columns
    else:
        msg = "the `columns` argument should contain a list of all integers or all string values"
        raise TypeError(msg)

    return projection, column_names


def _ensure_columns_are_unique(columns: Sequence[str] | Sequence[int]) -> None:
    if len(columns) != len(set(columns)):
        msg = f"`columns` arg should only have unique values, got {columns!r}"
        raise ValueError(msg)


def parse_row_index_args(
    row_index_name: str | None = None,
    row_index_offset: int = 0,
) -> tuple[str, int] | None:
    """
    Parse the `row_index_name` and `row_index_offset` arguments of an I/O function.

    The Rust functions take a single tuple rather than two separate arguments.
    """
    if row_index_name is None:
        return None
    else:
        return (row_index_name, row_index_offset)
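For orientation, a minimal sketch of how the two renamed helpers behave, using the definitions reconstructed above; the example values are illustrative and not part of the commit:

```python
from polars.io._utils import parse_columns_arg, parse_row_index_args

# Exactly one slot of the (projection, column_names) tuple is populated.
assert parse_columns_arg("foo") == (None, ["foo"])
assert parse_columns_arg(2) == ([2], None)
assert parse_columns_arg(["a", "b"]) == (None, ["a", "b"])
assert parse_columns_arg([0, 1]) == ([0, 1], None)

# Duplicate entries are rejected via _ensure_columns_are_unique:
# parse_columns_arg([0, 0])  # raises ValueError

# The row-index arguments are packed into the single tuple the Rust
# readers expect, or None when no row index was requested.
assert parse_row_index_args("idx", 10) == ("idx", 10)
assert parse_row_index_args(None, 10) is None
```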


@overload
def _prepare_file_arg(
def prepare_file_arg(
    file: str | Path | list[str] | IO[bytes] | bytes,
    encoding: str | None = ...,
    *,
@@ -67,7 +82,7 @@ def _prepare_file_arg(


@overload
def _prepare_file_arg(
def prepare_file_arg(
    file: str | Path | IO[str] | IO[bytes] | bytes,
    encoding: str | None = ...,
    *,
@@ -78,7 +93,7 @@ def _prepare_file_arg(


@overload
def _prepare_file_arg(
def prepare_file_arg(
    file: str | Path | list[str] | IO[str] | IO[bytes] | bytes,
    encoding: str | None = ...,
    *,
@@ -88,7 +103,7 @@ def _prepare_file_arg(
) -> ContextManager[str | list[str] | BytesIO | list[BytesIO]]: ...


def _prepare_file_arg(
def prepare_file_arg(
    file: str | Path | list[str] | IO[str] | IO[bytes] | bytes,
    encoding: str | None = None,
    *,
@@ -102,15 +117,15 @@ def _prepare_file_arg(
    Utility for read_[csv, parquet]. (not to be used by scan_[csv, parquet]).
    Returned value is always usable as a context.

    A :class:`StringIO`, :class:`BytesIO` file is returned as a :class:`BytesIO`.
    A `StringIO`, `BytesIO` file is returned as a `BytesIO`.
    A local path is returned as a string.
    An http URL is read into a buffer and returned as a :class:`BytesIO`.
    An http URL is read into a buffer and returned as a `BytesIO`.

    When `encoding` is not `utf8` or `utf8-lossy`, the whole file is
    first read in python and decoded using the specified encoding and
    returned as a :class:`BytesIO` (for usage with `read_csv`).
    first read in Python and decoded using the specified encoding and
    returned as a `BytesIO` (for usage with `read_csv`).

    A `bytes` file is returned as a :class:`BytesIO` if `use_pyarrow=True`.
    A `bytes` file is returned as a `BytesIO` if `use_pyarrow=True`.

    When fsspec is installed, remote file(s) is (are) opened with
    `fsspec.open(file, **kwargs)` or `fsspec.open_files(file, **kwargs)`.
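A minimal usage sketch of the renamed function, assuming only the behavior stated in the docstring above (the sample bytes are illustrative):

```python
from io import BytesIO

from polars.io._utils import prepare_file_arg

data = b"a,b\n1,2\n"

# The returned value is always usable as a context manager; an in-memory
# BytesIO is passed through as a BytesIO.
with prepare_file_arg(BytesIO(data)) as f:
    payload = f.read()

assert payload == data
```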
@@ -181,8 +196,8 @@ def managed_file(file: Any) -> Iterator[Any]:
        # make sure that this is before fsspec
        # as fsspec needs requests to be installed
        # to read from http
        if _looks_like_url(file):
            return _process_file_url(file, encoding_str)
        if looks_like_url(file):
            return process_file_url(file, encoding_str)
        if _FSSPEC_AVAILABLE:
            from fsspec.utils import infer_storage_options

@@ -234,7 +249,7 @@ def managed_file(file: Any) -> Iterator[Any]:
def _check_empty(
    b: BytesIO, *, context: str, raise_if_empty: bool, read_position: int | None = None
) -> BytesIO:
    if raise_if_empty and not b.getbuffer().nbytes:
    if raise_if_empty and b.getbuffer().nbytes == 0:
        hint = (
            f" (buffer position = {read_position}; try seek(0) before reading?)"
            if context in ("StringIO", "BytesIO") and read_position
@@ -245,11 +260,11 @@ def _check_empty(
    return b


def _looks_like_url(path: str) -> bool:
def looks_like_url(path: str) -> bool:
    return re.match("^(ht|f)tps?://", path, re.IGNORECASE) is not None


def _process_file_url(path: str, encoding: str | None = None) -> BytesIO:
def process_file_url(path: str, encoding: str | None = None) -> BytesIO:
    from urllib.request import urlopen

    with urlopen(path) as f:
@@ -259,42 +274,18 @@ def _process_file_url(path: str, encoding: str | None = None) -> BytesIO:
        return BytesIO(f.read().decode(encoding).encode("utf8"))
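The renamed URL helpers keep their old behavior; a quick sketch of the regex shown above (example URLs are hypothetical):

```python
from polars.io._utils import looks_like_url

# Matches http(s) and ftp(s) schemes, case-insensitively.
assert looks_like_url("https://example.com/data.csv")
assert looks_like_url("FTP://example.com/data.csv")

# Cloud-storage schemes are handled by is_supported_cloud instead.
assert not looks_like_url("s3://bucket/data.csv")
```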


@contextmanager
def PortableTemporaryFile(
    mode: str = "w+b",
    *,
    buffering: int = -1,
    encoding: str | None = None,
    newline: str | None = None,
    suffix: str | None = None,
    prefix: str | None = None,
    dir: str | Path | None = None,
    delete: bool = True,
    errors: str | None = None,
) -> Iterator[Any]:
    """
    Slightly more resilient version of the standard `NamedTemporaryFile`.

    Plays better with Windows when using the 'delete' option.
    """
    params = cast(
        Any,
        {
            "mode": mode,
            "buffering": buffering,
            "encoding": encoding,
            "newline": newline,
            "suffix": suffix,
            "prefix": prefix,
            "dir": dir,
            "delete": False,
            "errors": errors,
        },
    )
    tmp = NamedTemporaryFile(**params)
    try:
        yield tmp
    finally:
        tmp.close()
        if delete:
            Path(tmp.name).unlink(missing_ok=True)


def is_glob_pattern(file: str) -> bool:
    return any(char in file for char in ["*", "?", "["])


def is_supported_cloud(file: str) -> bool:
    return bool(re.match("^(s3a?|gs|gcs|file|abfss?|azure|az|adl|https?)://", file))


def is_local_file(file: str) -> bool:
    try:
        next(glob.iglob(file, recursive=True))  # noqa: PTH207
    except StopIteration:
        return False
    else:
        return True
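A sketch of the three path helpers made public here, under the definitions reconstructed above (the example paths are hypothetical):

```python
from polars.io._utils import is_glob_pattern, is_local_file, is_supported_cloud

# Any of *, ? or [ marks the path as a glob pattern.
assert is_glob_pattern("data/*.parquet")
assert not is_glob_pattern("data/file.parquet")

# Cloud schemes recognized by the regex above.
assert is_supported_cloud("s3://bucket/key.parquet")
assert is_supported_cloud("gs://bucket/key.parquet")
assert not is_supported_cloud("ssh://host/key.parquet")

# is_local_file resolves globs, so it returns True only when the
# pattern matches at least one existing local file.
assert not is_local_file("definitely/missing/*.csv")
```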
6 changes: 3 additions & 3 deletions py-polars/polars/io/avro.py
@@ -6,7 +6,7 @@

from polars._utils.various import normalize_filepath
from polars._utils.wrap import wrap_df
from polars.io._utils import handle_projection_columns
from polars.io._utils import parse_columns_arg

with contextlib.suppress(ImportError): # Module not available when building docs
from polars.polars import PyDataFrame
@@ -42,7 +42,7 @@ def read_avro(
"""
if isinstance(source, (str, Path)):
source = normalize_filepath(source)
projection, parsed_columns = handle_projection_columns(columns)
projection, column_names = parse_columns_arg(columns)

pydf = PyDataFrame.read_avro(source, parsed_columns, projection, n_rows)
pydf = PyDataFrame.read_avro(source, column_names, projection, n_rows)
return wrap_df(pydf)
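Call-site behavior is unchanged; a hedged example of the `columns` argument (the file name is hypothetical):

```python
import polars as pl

# Column names and column indices are both accepted; parse_columns_arg
# routes them to the `columns` and `projection` slots respectively.
df_by_name = pl.read_avro("data.avro", columns=["a", "b"])
df_by_index = pl.read_avro("data.avro", columns=[0, 1])
```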
7 changes: 3 additions & 4 deletions py-polars/polars/io/csv/batched_reader.py
@@ -4,13 +4,12 @@
from typing import TYPE_CHECKING, Sequence

from polars._utils.various import (
    _prepare_row_index_args,
    _process_null_values,
    normalize_filepath,
)
from polars._utils.wrap import wrap_df
from polars.datatypes import N_INFER_DEFAULT, py_type_to_dtype
from polars.io._utils import handle_projection_columns
from polars.io._utils import parse_columns_arg, parse_row_index_args
from polars.io.csv._utils import _update_columns

with contextlib.suppress(ImportError): # Module not available when building docs
@@ -73,7 +72,7 @@ def __init__(
            raise TypeError(msg)

        processed_null_values = _process_null_values(null_values)
        projection, columns = handle_projection_columns(columns)
        projection, columns = parse_columns_arg(columns)

        self._reader = PyBatchedCsv.new(
            infer_schema_length=infer_schema_length,
Expand All @@ -98,7 +97,7 @@ def __init__(
            missing_utf8_is_empty_string=missing_utf8_is_empty_string,
            try_parse_dates=try_parse_dates,
            skip_rows_after_header=skip_rows_after_header,
            row_index=_prepare_row_index_args(row_index_name, row_index_offset),
            row_index=parse_row_index_args(row_index_name, row_index_offset),
            sample_size=sample_size,
            eol_char=eol_char,
            raise_if_empty=raise_if_empty,
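For completeness, a sketch of how the batched CSV reader surfaces the row-index arguments that now flow through `parse_row_index_args` (the file name is hypothetical):

```python
import polars as pl

# row_index_name/row_index_offset are combined into a single tuple
# before being handed to the Rust-side PyBatchedCsv reader.
reader = pl.read_csv_batched("data.csv", row_index_name="idx", row_index_offset=100)
batches = reader.next_batches(5)  # returns up to 5 DataFrames, or None
```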
