Skip to content

Commit

Permalink
refactor[python]: prefer "Sequence" to "list" in various type signatures (#4746)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Sep 6, 2022
1 parent 66b125a commit 096cd78
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 64 deletions.
9 changes: 4 additions & 5 deletions py-polars/polars/internals/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,10 @@ def sequence_to_pyseries(

value = _get_first_non_none(values)
if value is not None:
# this branch is for dtypes set with python types.
# eg. 'datetime.date/datetime.datetime'
# and values that are integers
# if this holds we take the physical branch
# if the values are also python types we take the temporal branch
# for temporal dtypes:
# * if the values are integer, we take the physical branch.
# * if the values are python types, take the temporal branch.
# * if the values are ISO-8601 strings, init then convert via strptime.
if dtype in py_temporal_types and isinstance(value, int):
dtype = py_type_to_dtype(dtype) # construct from integer
elif (
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def exclude(
self,
columns: (
str
| list[str]
| Sequence[str]
| DataType
| type[DataType]
| DataType
Expand Down
68 changes: 36 additions & 32 deletions py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@


def col(
name: str | list[str] | Sequence[PolarsDataType] | pli.Series | PolarsDataType,
name: str | Sequence[str] | Sequence[PolarsDataType] | pli.Series | PolarsDataType,
) -> pli.Expr:
"""
Return an expression representing a column in a DataFrame.
Expand Down Expand Up @@ -169,7 +169,7 @@ def col(
if isinstance(name, DataType):
return pli.wrap_expr(_dtype_cols([name]))

if isinstance(name, list):
elif not isinstance(name, str) and isinstance(name, Sequence):
if len(name) == 0 or isinstance(name[0], str):
return pli.wrap_expr(pycols(name))
elif is_polars_dtype(name[0]):
Expand Down Expand Up @@ -309,7 +309,7 @@ def var(column: str | pli.Series, ddof: int = 1) -> pli.Expr | float | None:


@overload
def max(column: str | list[pli.Expr | str]) -> pli.Expr:
def max(column: str | Sequence[pli.Expr | str]) -> pli.Expr:
...


Expand All @@ -318,7 +318,7 @@ def max(column: pli.Series) -> int | float:
...


def max(column: str | list[pli.Expr | str] | pli.Series) -> pli.Expr | Any:
def max(column: str | Sequence[pli.Expr | str] | pli.Series) -> pli.Expr | Any:
"""
Get the maximum value. Can be used horizontally or vertically.
Expand All @@ -333,15 +333,15 @@ def max(column: str | list[pli.Expr | str] | pli.Series) -> pli.Expr | Any:
"""
if isinstance(column, pli.Series):
return column.max()
elif isinstance(column, list):
elif isinstance(column, str):
return col(column).max()
else:
exprs = pli.selection_to_pyexpr_list(column)
return pli.wrap_expr(_max_exprs(exprs))
else:
return col(column).max()


@overload
def min(column: str | list[pli.Expr | str]) -> pli.Expr:
def min(column: str | Sequence[pli.Expr | str]) -> pli.Expr:
...


Expand All @@ -350,7 +350,7 @@ def min(column: pli.Series) -> int | float:
...


def min(column: str | list[pli.Expr | str] | pli.Series) -> pli.Expr | Any:
def min(column: str | Sequence[pli.Expr | str] | pli.Series) -> pli.Expr | Any:
"""
Get the minimum value.
Expand All @@ -363,15 +363,15 @@ def min(column: str | list[pli.Expr | str] | pli.Series) -> pli.Expr | Any:
"""
if isinstance(column, pli.Series):
return column.min()
elif isinstance(column, list):
elif isinstance(column, str):
return col(column).min()
else:
exprs = pli.selection_to_pyexpr_list(column)
return pli.wrap_expr(_min_exprs(exprs))
else:
return col(column).min()


@overload
def sum(column: str | list[pli.Expr | str] | pli.Expr) -> pli.Expr:
def sum(column: str | Sequence[pli.Expr | str] | pli.Expr) -> pli.Expr:
...


Expand All @@ -380,7 +380,9 @@ def sum(column: pli.Series) -> int | float:
...


def sum(column: str | list[pli.Expr | str] | pli.Series | pli.Expr) -> pli.Expr | Any:
def sum(
column: str | Sequence[pli.Expr | str] | pli.Series | pli.Expr,
) -> pli.Expr | Any:
"""
Sum values in a column/Series, or horizontally across list of columns/expressions.
Expand Down Expand Up @@ -469,14 +471,14 @@ def sum(column: str | list[pli.Expr | str] | pli.Series | pli.Expr) -> pli.Expr
"""
if isinstance(column, pli.Series):
return column.sum()
elif isinstance(column, list):
elif isinstance(column, str):
return col(column).sum()
elif isinstance(column, Sequence):
exprs = pli.selection_to_pyexpr_list(column)
return pli.wrap_expr(_sum_exprs(exprs))
elif isinstance(column, pli.Expr):
# use u32 as that is not cast to float as eagerly
return fold(lit(0).cast(UInt32), lambda a, b: a + b, column).alias("sum")
else:
return col(column).sum()
# (Expr): use u32 as that will not cast to float as eagerly
return fold(lit(0).cast(UInt32), lambda a, b: a + b, column).alias("sum")


@overload
Expand Down Expand Up @@ -830,8 +832,8 @@ def cov(


def map(
exprs: list[str] | list[pli.Expr],
f: Callable[[list[pli.Series]], pli.Series],
exprs: Sequence[str] | Sequence[pli.Expr],
f: Callable[[Sequence[pli.Series]], pli.Series],
return_dtype: type[DataType] | None = None,
) -> pli.Expr:
"""
Expand All @@ -858,8 +860,8 @@ def map(


def apply(
exprs: list[str | pli.Expr],
f: Callable[[list[pli.Series]], pli.Series | Any],
exprs: Sequence[str | pli.Expr],
f: Callable[[Sequence[pli.Series]], pli.Series | Any],
return_dtype: type[DataType] | None = None,
) -> pli.Expr:
"""
Expand Down Expand Up @@ -920,19 +922,20 @@ def fold(
return pli.wrap_expr(pyfold(acc._pyexpr, f, exprs))


def any(name: str | list[str] | list[pli.Expr] | pli.Expr) -> pli.Expr:
def any(name: str | Sequence[str] | Sequence[pli.Expr] | pli.Expr) -> pli.Expr:
"""Evaluate columnwise or elementwise with a bitwise OR operation."""
if isinstance(name, (list, pli.Expr)):
if isinstance(name, str):
return col(name).any()
else:
return fold(lit(False), lambda a, b: a.cast(bool) | b.cast(bool), name).alias(
"any"
)
return col(name).any()


def exclude(
columns: (
str
| list[str]
| Sequence[str]
| DataType
| type[DataType]
| DataType
Expand Down Expand Up @@ -1031,7 +1034,7 @@ def exclude(
return col("*").exclude(columns)


def all(name: str | list[pli.Expr] | pli.Expr | None = None) -> pli.Expr:
def all(name: str | Sequence[pli.Expr] | pli.Expr | None = None) -> pli.Expr:
"""
Do one of two things.
Expand Down Expand Up @@ -1063,11 +1066,12 @@ def all(name: str | list[pli.Expr] | pli.Expr | None = None) -> pli.Expr:
"""
if name is None:
return col("*")
if isinstance(name, (list, pli.Expr)):
elif isinstance(name, str):
return col(name).all()
else:
return fold(lit(True), lambda a, b: a.cast(bool) & b.cast(bool), name).alias(
"all"
)
return col(name).all()


def groups(column: str) -> pli.Expr:
Expand Down Expand Up @@ -1169,7 +1173,7 @@ def arange(

def argsort_by(
exprs: pli.Expr | str | Sequence[pli.Expr | str],
reverse: list[bool] | bool = False,
reverse: Sequence[bool] | bool = False,
) -> pli.Expr:
"""
Find the indexes that would sort the columns.
Expand Down Expand Up @@ -1507,7 +1511,7 @@ def concat_list(exprs: Sequence[str | pli.Expr | pli.Series] | pli.Expr) -> pli.


def collect_all(
lazy_frames: list[pli.LazyFrame],
lazy_frames: Sequence[pli.LazyFrame],
type_coercion: bool = True,
predicate_pushdown: bool = True,
projection_pushdown: bool = True,
Expand Down
36 changes: 13 additions & 23 deletions py-polars/polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import TYPE_CHECKING, Callable, Iterable, Sequence, TypeVar

import polars.internals as pli
from polars.datatypes import DataType, Date, Datetime
from polars.datatypes import DataType, Date, Datetime, PolarsDataType, is_polars_dtype

try:
from polars.polars import PyExpr
Expand Down Expand Up @@ -76,41 +76,34 @@ def _date_to_pl_date(d: date) -> int:
return int(dt.timestamp()) // (3600 * 24)


def _is_iterable_of(val: Iterable[object], eltype: type) -> bool:
"""Check whether the given iterable is of a certain type."""
def _is_iterable_of(val: Iterable[object], eltype: type | tuple[type, ...]) -> bool:
"""Check whether the given iterable is of the given type(s)."""
return all(isinstance(x, eltype) for x in val)


def is_bool_sequence(val: object) -> TypeGuard[Sequence[bool]]:
"""Check whether the given sequence is a sequence of booleans."""
if isinstance(val, Sequence):
return _is_iterable_of(val, bool)
else:
return False
return isinstance(val, Sequence) and _is_iterable_of(val, bool)


def is_dtype_sequence(val: object) -> TypeGuard[Sequence[PolarsDataType]]:
"""Check whether the given object is a sequence of polars DataTypes."""
return isinstance(val, Sequence) and all(is_polars_dtype(x) for x in val)


def is_int_sequence(val: object) -> TypeGuard[Sequence[int]]:
"""Check whether the given sequence is a sequence of integers."""
if isinstance(val, Sequence):
return _is_iterable_of(val, int)
else:
return False
return isinstance(val, Sequence) and _is_iterable_of(val, int)


def is_expr_sequence(val: object) -> TypeGuard[Sequence[pli.Expr]]:
"""Check whether the given object is a sequence of Exprs."""
if isinstance(val, Sequence):
return _is_iterable_of(val, pli.Expr)
else:
return False
return isinstance(val, Sequence) and _is_iterable_of(val, pli.Expr)


def is_pyexpr_sequence(val: object) -> TypeGuard[Sequence[PyExpr]]:
"""Check whether the given object is a sequence of PyExprs."""
if isinstance(val, Sequence):
return _is_iterable_of(val, PyExpr)
else:
return False
return isinstance(val, Sequence) and _is_iterable_of(val, PyExpr)


def is_str_sequence(
Expand All @@ -124,10 +117,7 @@ def is_str_sequence(
"""
if allow_str is False and isinstance(val, str):
return False
if isinstance(val, Sequence):
return _is_iterable_of(val, str)
else:
return False
return isinstance(val, Sequence) and _is_iterable_of(val, str)


def range_to_slice(rng: range) -> slice:
Expand Down
6 changes: 3 additions & 3 deletions py-polars/tests/unit/test_apply.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

import typing
from datetime import date, datetime, timedelta
from functools import reduce
from typing import Sequence, no_type_check

import numpy as np

Expand Down Expand Up @@ -33,7 +33,7 @@ def test_apply_none() -> None:
assert out_df["a"].to_list() == (df["a"] * df["b"]).to_list()

# check if we can return None
def func(s: list[pl.Series]) -> pl.Series | None:
def func(s: Sequence[pl.Series]) -> pl.Series | None:
if s[0][0] == 190:
return None
else:
Expand All @@ -57,7 +57,7 @@ def test_apply_return_py_object() -> None:
assert out.shape == (1, 2)


@typing.no_type_check
@no_type_check
def test_agg_objects() -> None:
df = pl.DataFrame(
{
Expand Down

0 comments on commit 096cd78

Please sign in to comment.