Skip to content

Commit

Permalink
Small refactor of expr_to_lit et al (#2145)
Browse files Browse the repository at this point in the history
  • Loading branch information
zundertj committed Dec 24, 2021
1 parent 8c88270 commit 9fe155e
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 30 deletions.
2 changes: 1 addition & 1 deletion py-polars/polars/internals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
The modules within `polars.internals` are interdependent. To prevent cyclical imports, they all import from each other
via this __init__ file using `import polars.internals as pli`. The imports below are being shared across this module.
"""
from .expr import Expr, _selection_to_pyexpr_list, expr_to_lit_or_expr, wrap_expr
from .expr import Expr, expr_to_lit_or_expr, selection_to_pyexpr_list, wrap_expr
from .frame import DataFrame, wrap_df
from .functions import concat, date_range # DataFrame.describe() & DataFrame.upsample()
from .lazy_frame import LazyFrame, wrap_ldf
Expand Down
30 changes: 12 additions & 18 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,13 @@
)


def _selection_to_pyexpr_list(
exprs: Union[str, "Expr", Sequence[Union[str, "Expr"]], "pli.Series"]
def selection_to_pyexpr_list(
exprs: Union[str, "Expr", Sequence[Union[str, "Expr", "pli.Series"]], "pli.Series"]
) -> List["PyExpr"]:
pyexpr_list: List[PyExpr]
if isinstance(exprs, Sequence) and not isinstance(exprs, str):
pyexpr_list = []
for expr in exprs:
pyexpr_list.append(expr_to_lit_or_expr(expr, str_to_lit=False)._pyexpr)
else:
pyexpr_list = [expr_to_lit_or_expr(exprs, str_to_lit=False)._pyexpr]
return pyexpr_list
if isinstance(exprs, (str, Expr, pli.Series)):
exprs = [exprs]

return [expr_to_lit_or_expr(e, str_to_lit=False)._pyexpr for e in exprs]


def wrap_expr(pyexpr: "PyExpr") -> "Expr":
Expand Down Expand Up @@ -825,7 +821,7 @@ def sort_by(
by = [by]
if not isinstance(reverse, list):
reverse = [reverse]
by = _selection_to_pyexpr_list(by)
by = selection_to_pyexpr_list(by)

return wrap_expr(self._pyexpr.sort_by(by, reverse))

Expand All @@ -844,8 +840,6 @@ def take(self, index: Union[List[int], "Expr", "pli.Series", np.ndarray]) -> "Ex
"""
if isinstance(index, (list, np.ndarray)):
index_lit = pli.lit(pli.Series("", index, dtype=UInt32))
elif isinstance(index, pli.Series):
index_lit = pli.lit(index)
else:
index_lit = pli.expr_to_lit_or_expr(index, str_to_lit=False)
return pli.wrap_expr(self._pyexpr.take(index_lit._pyexpr))
Expand Down Expand Up @@ -1069,7 +1063,7 @@ def over(self, expr: Union[str, "Expr", List[Union["Expr", str]]]) -> "Expr":
"""

pyexprs = _selection_to_pyexpr_list(expr)
pyexprs = selection_to_pyexpr_list(expr)

return wrap_expr(self._pyexpr.over(pyexprs))

Expand Down Expand Up @@ -2745,7 +2739,7 @@ def timestamp(self) -> Expr:


def expr_to_lit_or_expr(
expr: Union[Expr, bool, int, float, str, List[Expr], List[str], "pli.Series"],
expr: Union[Expr, bool, int, float, str, "pli.Series"],
str_to_lit: bool = True,
) -> Expr:
"""
Expand All @@ -2769,7 +2763,7 @@ def expr_to_lit_or_expr(
isinstance(expr, (int, float, str, pli.Series, datetime, date)) or expr is None
):
return pli.lit(expr)
elif isinstance(expr, list):
return [expr_to_lit_or_expr(e, str_to_lit=str_to_lit) for e in expr] # type: ignore[return-value]
else:
elif isinstance(expr, Expr):
return expr
else:
raise Exception
9 changes: 4 additions & 5 deletions py-polars/polars/internals/lazy_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,7 @@ def sort(
if type(reverse) is bool:
reverse = [reverse]

by = pli.expr_to_lit_or_expr(by, str_to_lit=False)
by = pli._selection_to_pyexpr_list(by)
by = pli.selection_to_pyexpr_list(by)
return wrap_ldf(self._ldf.sort_by_exprs(by, reverse))

def collect(
Expand Down Expand Up @@ -478,7 +477,7 @@ def select(
exprs
Column or columns to select.
"""
exprs = pli._selection_to_pyexpr_list(exprs)
exprs = pli.selection_to_pyexpr_list(exprs)
return wrap_ldf(self._ldf.select(exprs))

def groupby(
Expand Down Expand Up @@ -1037,7 +1036,7 @@ def explode(
└─────────┴─────┘
"""
columns = pli._selection_to_pyexpr_list(columns)
columns = pli.selection_to_pyexpr_list(columns)
return wrap_ldf(self._ldf.explode(columns))

def drop_duplicates(
Expand Down Expand Up @@ -1217,7 +1216,7 @@ def agg(self, aggs: Union[List["pli.Expr"], "pli.Expr"]) -> "LazyFrame":
... ) # doctest: +SKIP
"""
aggs = pli._selection_to_pyexpr_list(aggs)
aggs = pli.selection_to_pyexpr_list(aggs)
return wrap_ldf(self.lgb.agg(aggs))

def head(self, n: int = 5) -> "LazyFrame":
Expand Down
12 changes: 6 additions & 6 deletions py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ def map(
-------
Expr
"""
exprs = pli._selection_to_pyexpr_list(exprs)
exprs = pli.selection_to_pyexpr_list(exprs)
return pli.wrap_expr(_map_mul(exprs, f, return_dtype, apply_groups=False))


Expand Down Expand Up @@ -666,7 +666,7 @@ def apply(
-------
Expr
"""
exprs = pli._selection_to_pyexpr_list(exprs)
exprs = pli.selection_to_pyexpr_list(exprs)
return pli.wrap_expr(_map_mul(exprs, f, return_dtype, apply_groups=True))


Expand Down Expand Up @@ -724,7 +724,7 @@ def fold(
if isinstance(exprs, pli.Expr):
exprs = [exprs]

exprs = pli._selection_to_pyexpr_list(exprs)
exprs = pli.selection_to_pyexpr_list(exprs)
return pli.wrap_expr(pyfold(acc._pyexpr, f, exprs))


Expand Down Expand Up @@ -902,7 +902,7 @@ def argsort_by(
"""
if not isinstance(reverse, list):
reverse = [reverse] * len(exprs)
exprs = pli._selection_to_pyexpr_list(exprs)
exprs = pli.selection_to_pyexpr_list(exprs)
return pli.wrap_expr(pyargsort_by(exprs, reverse))


Expand Down Expand Up @@ -1000,7 +1000,7 @@ def concat_str(exprs: Sequence[Union["pli.Expr", str]], sep: str = "") -> "pli.E
sep
String value that will be used to separate the values.
"""
exprs = pli._selection_to_pyexpr_list(exprs)
exprs = pli.selection_to_pyexpr_list(exprs)
return pli.wrap_expr(_concat_str(exprs, sep))


Expand Down Expand Up @@ -1109,7 +1109,7 @@ def concat_list(exprs: Sequence[Union[str, "pli.Expr"]]) -> "pli.Expr":
└─────────────────┘
"""
exprs = pli._selection_to_pyexpr_list(exprs)
exprs = pli.selection_to_pyexpr_list(exprs)
return pli.wrap_expr(_concat_lst(exprs))


Expand Down

0 comments on commit 9fe155e

Please sign in to comment.