Skip to content

Commit

Permalink
Add type annotation for projection variable (#2100)
Browse files Browse the repository at this point in the history
  • Loading branch information
zundertj committed Dec 21, 2021
1 parent 7e036f8 commit 7021980
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 29 deletions.
53 changes: 25 additions & 28 deletions py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@

from polars._html import NotebookFormatter
from polars.datatypes import Boolean, DataType, Datetime, UInt32, py_type_to_dtype
from polars.utils import _process_null_values
from polars.utils import _process_null_values, is_int_sequence, is_str_sequence

try:
import pandas as pd
Expand Down Expand Up @@ -463,16 +463,15 @@ def read_csv(
if isinstance(file, StringIO):
file = file.getvalue().encode()

projection: Optional[List[int]] = None
projection: Optional[Sequence[int]] = None
if columns:
if isinstance(columns, list):
if all(isinstance(i, int) for i in columns):
projection = columns # type: ignore
columns = None
elif not all(isinstance(i, str) for i in columns):
raise ValueError(
"columns arg should contain a list of all integers or all strings values."
)
if is_int_sequence(columns):
projection = columns
columns = None
elif not is_str_sequence(columns):
raise ValueError(
"columns arg should contain a list of all integers or all strings values."
)

dtype_list: Optional[List[Tuple[str, Type[DataType]]]] = None
dtype_slice: Optional[List[Type[DataType]]] = None
Expand Down Expand Up @@ -531,16 +530,15 @@ def read_parquet(
n_rows
Stop reading from parquet file after reading ``n_rows``.
"""
projection: Optional[List[int]] = None
projection: Optional[Sequence[int]] = None
if columns:
if isinstance(columns, list):
if all(isinstance(i, int) for i in columns):
projection = columns # type: ignore
columns = None
elif not all(isinstance(i, str) for i in columns):
raise ValueError(
"columns arg should contain a list of all integers or all strings values."
)
if is_int_sequence(columns):
projection = columns
columns = None
elif not is_str_sequence(columns):
raise ValueError(
"columns arg should contain a list of all integers or all strings values."
)

self = DataFrame.__new__(DataFrame)
self._df = PyDataFrame.read_parquet(file, columns, projection, n_rows)
Expand Down Expand Up @@ -568,16 +566,15 @@ def read_ipc(
-------
DataFrame
"""
projection: Optional[List[int]] = None
projection: Optional[Sequence[int]] = None
if columns:
if isinstance(columns, list):
if all(isinstance(i, int) for i in columns):
projection = columns # type: ignore
columns = None
elif not all(isinstance(i, str) for i in columns):
raise ValueError(
"columns arg should contain a list of all integers or all strings values."
)
if is_int_sequence(columns):
projection = columns
columns = None
elif not is_str_sequence(columns):
raise ValueError(
"columns arg should contain a list of all integers or all strings values."
)

self = DataFrame.__new__(DataFrame)
self._df = PyDataFrame.read_ipc(file, columns, projection, n_rows)
Expand Down
28 changes: 27 additions & 1 deletion py-polars/polars/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import ctypes
import sys
from datetime import date, datetime, timedelta, timezone
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, Iterable, List, Sequence, Tuple, Type, Union

import numpy as np

if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard


def _process_null_values(
null_values: Union[None, str, List[str], Dict[str, str]] = None,
Expand Down Expand Up @@ -51,3 +57,23 @@ def _datetime_to_pl_timestamp(dt: datetime) -> int:
def _date_to_pl_date(d: date) -> int:
dt = datetime.combine(d, datetime.min.time()).replace(tzinfo=timezone.utc)
return int(dt.timestamp()) // (3600 * 24)


def is_str_sequence(
val: Sequence[object], allow_str: bool = False
) -> TypeGuard[Sequence[str]]:
"""
Checks that `val` is a sequence of strings. Note that a single string is a sequence of strings
by definition, use `allow_str=False` to return False on a single string
"""
if (not allow_str) and isinstance(val, str):
return False
return _is_iterable_of(val, Sequence, str)


def is_int_sequence(val: Sequence[object]) -> TypeGuard[Sequence[int]]:
return _is_iterable_of(val, Sequence, int)


def _is_iterable_of(val: Iterable, itertype: Type, eltype: Type) -> bool:
return isinstance(val, itertype) and all(isinstance(x, eltype) for x in val)

0 comments on commit 7021980

Please sign in to comment.