Skip to content

Commit

Permalink
Home directory support (#2940)
Browse files Browse the repository at this point in the history
  • Loading branch information
cjermain committed Mar 28, 2022
1 parent 5fabae0 commit 7faa3dc
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 27 deletions.
41 changes: 26 additions & 15 deletions py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
from polars.utils import (
_prepare_row_count_args,
_process_null_values,
format_path,
handle_projection_columns,
is_int_sequence,
is_str_sequence,
Expand Down Expand Up @@ -459,7 +460,7 @@ def _from_pandas(
@classmethod
def _read_csv(
cls: Type[DF],
file: Union[str, BinaryIO, bytes],
file: Union[str, Path, BinaryIO, bytes],
has_header: bool = True,
columns: Optional[Union[List[int], List[str]]] = None,
sep: str = ",",
Expand Down Expand Up @@ -489,8 +490,8 @@ def _read_csv(
self = cls.__new__(cls)

path: Optional[str]
if isinstance(file, str):
path = file
if isinstance(file, (str, Path)):
path = format_path(file)
else:
path = None
if isinstance(file, BytesIO):
Expand Down Expand Up @@ -581,7 +582,7 @@ def _read_csv(
@classmethod
def _read_parquet(
cls: Type[DF],
file: Union[str, BinaryIO],
file: Union[str, Path, BinaryIO],
columns: Optional[Union[List[int], List[str]]] = None,
n_rows: Optional[int] = None,
parallel: bool = True,
Expand All @@ -602,6 +603,8 @@ def _read_parquet(
parallel
Read the parquet file in parallel. The single threaded reader consumes less memory.
"""
if isinstance(file, (str, Path)):
file = format_path(file)
if isinstance(file, str) and "*" in file:
from polars import scan_parquet

Expand Down Expand Up @@ -638,7 +641,7 @@ def _read_parquet(
@classmethod
def _read_avro(
cls: Type[DF],
file: Union[str, BinaryIO],
file: Union[str, Path, BinaryIO],
columns: Optional[Union[List[int], List[str]]] = None,
n_rows: Optional[int] = None,
) -> DF:
Expand All @@ -656,6 +659,8 @@ def _read_avro(
-------
DataFrame
"""
if isinstance(file, (str, Path)):
file = format_path(file)
projection, columns = handle_projection_columns(columns)
self = cls.__new__(cls)
self._df = PyDataFrame.read_avro(file, columns, projection, n_rows)
Expand All @@ -664,7 +669,7 @@ def _read_avro(
@classmethod
def _read_ipc(
cls: Type[DF],
file: Union[str, BinaryIO],
file: Union[str, Path, BinaryIO],
columns: Optional[Union[List[int], List[str]]] = None,
n_rows: Optional[int] = None,
row_count_name: Optional[str] = None,
Expand All @@ -687,6 +692,8 @@ def _read_ipc(
DataFrame
"""

if isinstance(file, (str, Path)):
file = format_path(file)
if isinstance(file, str) and "*" in file:
from polars import scan_ipc

Expand Down Expand Up @@ -720,14 +727,16 @@ def _read_ipc(
@classmethod
def _read_json(
cls: Type[DF],
file: Union[str, IOBase],
file: Union[str, Path, IOBase],
json_lines: bool = False,
) -> DF:
"""
See Also pl.read_json
"""
if isinstance(file, StringIO):
file = BytesIO(file.getvalue().encode())
elif isinstance(file, (str, Path)):
file = format_path(file)

self = cls.__new__(cls)
self._df = PyDataFrame.read_json(file, json_lines)
Expand Down Expand Up @@ -973,6 +982,8 @@ def write_json(
to_string
Ignore file argument and return a string.
"""
if isinstance(file, (str, Path)):
file = format_path(file)
to_string_io = (file is not None) and isinstance(file, StringIO)
if to_string or file is None or to_string_io:
with BytesIO() as buf:
Expand Down Expand Up @@ -1067,8 +1078,8 @@ def write_csv(
self._df.to_csv(buffer, has_header, ord(sep), ord(quote))
return str(buffer.getvalue(), encoding="utf-8")

if isinstance(file, Path):
file = str(file)
if isinstance(file, (str, Path)):
file = format_path(file)

self._df.to_csv(file, has_header, ord(sep), ord(quote))
return None
Expand Down Expand Up @@ -1104,8 +1115,8 @@ def write_avro(
- "snappy"
- "deflate"
"""
if isinstance(file, Path):
file = str(file)
if isinstance(file, (str, Path)):
file = format_path(file)

self._df.to_avro(file, compression)

Expand Down Expand Up @@ -1141,8 +1152,8 @@ def write_ipc(
"""
if compression is None:
compression = "uncompressed"
if isinstance(file, Path):
file = str(file)
if isinstance(file, (str, Path)):
file = format_path(file)

self._df.to_ipc(file, compression)

Expand Down Expand Up @@ -1316,8 +1327,8 @@ def write_parquet(
"""
if compression is None:
compression = "uncompressed"
if isinstance(file, Path):
file = str(file)
if isinstance(file, (str, Path)):
file = format_path(file)

if use_pyarrow:
if not _PYARROW_AVAILABLE:
Expand Down
26 changes: 14 additions & 12 deletions py-polars/polars/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)
from urllib.request import urlopen

from polars.utils import handle_projection_columns
from polars.utils import format_path, handle_projection_columns

try:
import pyarrow as pa
Expand Down Expand Up @@ -110,19 +110,21 @@ def managed_file(file: Any) -> Iterator[Any]:
if isinstance(file, BytesIO):
return managed_file(file)
if isinstance(file, Path):
return managed_file(str(file))
return managed_file(format_path(file))
if isinstance(file, str):
if _WITH_FSSPEC:
if infer_storage_options(file)["protocol"] == "file":
return managed_file(file)
return managed_file(format_path(file))
return fsspec.open(file, **kwargs)
if file.startswith("http"):
return _process_http_file(file)
if isinstance(file, list) and bool(file) and all(isinstance(f, str) for f in file):
if _WITH_FSSPEC:
if all(infer_storage_options(f)["protocol"] == "file" for f in file):
return managed_file(file)
return managed_file([format_path(f) for f in file])
return fsspec.open_files(file, **kwargs)
if isinstance(file, str):
file = format_path(file)
return managed_file(file)


Expand Down Expand Up @@ -558,8 +560,8 @@ def scan_csv(
dtypes = kwargs.pop("dtype", dtypes)
n_rows = kwargs.pop("stop_after_n_rows", n_rows)

if isinstance(file, Path):
file = str(file)
if isinstance(file, (str, Path)):
file = format_path(file)

return LazyFrame.scan_csv(
file=file,
Expand Down Expand Up @@ -619,8 +621,8 @@ def scan_ipc(
# Map legacy arguments to current ones and remove them from kwargs.
n_rows = kwargs.pop("stop_after_n_rows", n_rows)

if isinstance(file, Path):
file = str(file)
if isinstance(file, (str, Path)):
file = format_path(file)

return LazyFrame.scan_ipc(
file=file,
Expand Down Expand Up @@ -669,8 +671,8 @@ def scan_parquet(
# Map legacy arguments to current ones and remove them from kwargs.
n_rows = kwargs.pop("stop_after_n_rows", n_rows)

if isinstance(file, Path):
file = str(file)
if isinstance(file, (str, Path)):
file = format_path(file)

return LazyFrame.scan_parquet(
file=file,
Expand Down Expand Up @@ -723,8 +725,8 @@ def read_avro(
-------
DataFrame
"""
if isinstance(file, Path):
file = str(file)
if isinstance(file, (str, Path)):
file = format_path(file)
if columns is None:
columns = kwargs.pop("projection", None)

Expand Down
9 changes: 9 additions & 0 deletions py-polars/polars/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import ctypes
import os
import sys
from datetime import date, datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union

import numpy as np
Expand Down Expand Up @@ -201,3 +203,10 @@ def _in_notebook() -> bool:
except AttributeError:
return False
return True


def format_path(path: Union[str, Path]) -> str:
    """
    Returns a string path, expanding the home directory if present.

    Parameters
    ----------
    path
        A string or ``pathlib.Path``. A leading ``~`` or ``~user``
        component is expanded via ``os.path.expanduser``; any other
        path is returned unchanged (as a string).
    """
    # os.path.expanduser accepts any os.PathLike and always returns str here.
    return os.path.expanduser(path)
10 changes: 10 additions & 0 deletions py-polars/tests/db-benchmark/various.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# may contain many things that seemed to go wrong at scale

import os
import time

import numpy as np
Expand Down Expand Up @@ -52,3 +53,12 @@
)
assert computed[0, "min"] == minimum
assert computed[0, "max"] == maximum

# test home directory support
# https://github.com/pola-rs/polars/pull/2940
filename = "~/test.parquet"

df.to_parquet(filename)
df = pl.read_parquet(filename)

# NOTE: os.remove does NOT expand "~" itself — polars wrote the file to the
# real home directory, so expand explicitly or the cleanup raises
# FileNotFoundError and leaves test.parquet behind.
os.remove(os.path.expanduser(filename))

0 comments on commit 7faa3dc

Please sign in to comment.