Skip to content

Commit

Permalink
Refactor IO projection & columns argument handling (#2171)
Browse files Browse the repository at this point in the history
+ add test
  • Loading branch information
zundertj committed Dec 26, 2021
1 parent 84ce73b commit cc3bb1d
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 41 deletions.
34 changes: 4 additions & 30 deletions py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@
from polars.datatypes import Boolean, DataType, Datetime, UInt32, py_type_to_dtype
from polars.utils import (
_process_null_values,
handle_projection_columns,
is_int_sequence,
is_str_sequence,
range_to_slice,
)

Expand Down Expand Up @@ -471,15 +471,7 @@ def read_csv(
if isinstance(file, StringIO):
file = file.getvalue().encode()

projection: Optional[Sequence[int]] = None
if columns:
if is_int_sequence(columns):
projection = columns
columns = None
elif not is_str_sequence(columns):
raise ValueError(
"columns arg should contain a list of all integers or all strings values."
)
projection, columns = handle_projection_columns(columns)

dtype_list: Optional[List[Tuple[str, Type[DataType]]]] = None
dtype_slice: Optional[List[Type[DataType]]] = None
Expand Down Expand Up @@ -541,16 +533,7 @@ def read_parquet(
parallel
Read the parquet file in parallel. The single threaded reader consumes less memory.
"""
projection: Optional[Sequence[int]] = None
if columns:
if is_int_sequence(columns):
projection = columns
columns = None
elif not is_str_sequence(columns):
raise ValueError(
"columns arg should contain a list of all integers or all strings values."
)

projection, columns = handle_projection_columns(columns)
self = DataFrame.__new__(DataFrame)
self._df = PyDataFrame.read_parquet(file, columns, projection, n_rows, parallel)
return self
Expand All @@ -577,16 +560,7 @@ def read_ipc(
-------
DataFrame
"""
projection: Optional[Sequence[int]] = None
if columns:
if is_int_sequence(columns):
projection = columns
columns = None
elif not is_str_sequence(columns):
raise ValueError(
"columns arg should contain a list of all integers or all strings values."
)

projection, columns = handle_projection_columns(columns)
self = DataFrame.__new__(DataFrame)
self._df = PyDataFrame.read_ipc(file, columns, projection, n_rows)
return self
Expand Down
13 changes: 3 additions & 10 deletions py-polars/polars/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
)
from urllib.request import urlopen

from polars.utils import handle_projection_columns

try:
import pyarrow as pa
import pyarrow.csv
Expand Down Expand Up @@ -254,16 +256,7 @@ def read_csv(
if columns is None:
columns = kwargs.pop("projection", None)

projection: Optional[List[int]] = None
if columns:
if isinstance(columns, list):
if all(isinstance(i, int) for i in columns):
projection = columns # type: ignore
columns = None
elif not all(isinstance(i, str) for i in columns):
raise ValueError(
"columns arg should contain a list of all integers or all strings values."
)
projection, columns = handle_projection_columns(columns)

if isinstance(file, bytes) and len(file) == 0:
raise ValueError("no date in bytes")
Expand Down
15 changes: 15 additions & 0 deletions py-polars/polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,18 @@ def range_to_slice(rng: range) -> slice:
else:
step = None
return slice(rng.start, rng.stop, step)


def handle_projection_columns(
columns: Optional[Union[List[str], List[int]]]
) -> Tuple[Optional[List[int]], Optional[List[str]]]:
projection: Optional[List[int]] = None
if columns:
if is_int_sequence(columns):
projection = columns # type: ignore
columns = None
elif not is_str_sequence(columns):
raise ValueError(
"columns arg should contain a list of all integers or all strings values."
)
return projection, columns # type: ignore
18 changes: 17 additions & 1 deletion py-polars/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from datetime import date
from functools import partial
from pathlib import Path
from typing import Dict, Type
from typing import Dict, List, Type, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -206,6 +206,22 @@ def test_partial_column_rename() -> None:
assert df.columns == ["foo", "b", "c"]


@pytest.mark.parametrize(
"col_input, col_out", [([0, 1], ["a", "b"]), ([0, 2], ["a", "c"]), (["b"], ["b"])]
)
def test_read_csv_columns_argument(
col_input: Union[List[int], List[str]], col_out: List[str]
) -> None:
csv = """a,b,c
1,2,3
1,2,3
"""
f = io.StringIO(csv)
df = pl.read_csv(f, columns=col_input)
assert df.shape[0] == 2
assert df.columns == col_out


def test_column_rename_and_dtype_overwrite() -> None:
csv = """
a,b,c
Expand Down

0 comments on commit cc3bb1d

Please sign in to comment.