Skip to content

Commit

Permalink
test and run lazy functions Series
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed May 3, 2021
1 parent 4eb15e0 commit e3790c9
Show file tree
Hide file tree
Showing 6 changed files with 202 additions and 104 deletions.
147 changes: 106 additions & 41 deletions py-polars/polars/lazy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, List, Callable, Optional, Dict
from typing import Union, List, Callable, Optional, Dict, Any

from polars import Series
from polars.frame import DataFrame, wrap_df
Expand Down Expand Up @@ -790,6 +790,7 @@ def agg_groups(self) -> "Expr":
return wrap_expr(self._pyexpr.agg_groups())

def count(self) -> "Expr":
"""Count the number of values in this expression"""
return wrap_expr(self._pyexpr.count())

def slice(self, offset: int, length: int):
Expand All @@ -805,7 +806,7 @@ def slice(self, offset: int, length: int):
"""
return wrap_expr(self._pyexpr.slice(offset, length))

def cum_sum(self, reverse: bool):
def cum_sum(self, reverse: bool = False):
"""
Get an array with the cumulative sum computed at every element
Expand All @@ -816,7 +817,7 @@ def cum_sum(self, reverse: bool):
"""
return wrap_expr(self._pyexpr.cum_sum(reverse))

def cum_min(self, reverse: bool):
def cum_min(self, reverse: bool = False):
"""
Get an array with the cumulative min computed at every element
Expand All @@ -827,7 +828,7 @@ def cum_min(self, reverse: bool):
"""
return wrap_expr(self._pyexpr.cum_min(reverse))

def cum_max(self, reverse: bool):
def cum_max(self, reverse: bool = False):
"""
Get an array with the cumulative max computed at every element
Expand Down Expand Up @@ -1517,11 +1518,13 @@ def except_(name: str) -> "Expr":
return wrap_expr(pyexcept(name))


def count(name: str = "") -> "Expr":
def count(column: "Union[str, Series]" = "") -> "Union[Expr, int]":
"""
Count the number of values in this column
"""
return col(name).count()
if type(column) is Series:
return column.len()
return col(column).count()


def to_list(name: str) -> "Expr":
Expand All @@ -1531,112 +1534,164 @@ def to_list(name: str) -> "Expr":
return col(name).list()


def std(name: str) -> "Expr":
def std(column: "Union[str, Series]") -> "Union[Expr, float]":
"""
Get standard deviation
"""
return col(name).std()
if type(column) is Series:
return column.std()
return col(column).std()


def var(name: str) -> "Expr":
def var(column: "Union[str, Series]") -> "Union[Expr, float]":
"""
Get variance
"""
return col(name).var()
if type(column) is Series:
return column.var()
return col(column).var()


def max(name: "Union[str, List[Expr]]") -> "Expr":
def max(column: "Union[str, List[Expr], Series]") -> "Union[Expr, Any]":
"""
Get maximum value
Get maximum value. Can be used horizontally or vertically.
Parameters
----------
column
Column(s) to be used in aggregation. Will lead to different behavior based on the input.
input:
- Union[str, Series] -> aggregate the maximum value of that column
- List[Expr] -> aggregate the maximum value horizontally.
"""
if isinstance(name, list):
if type(column) is Series:
return column.max()
if isinstance(column, list):

def max_(acc: Series, val: Series) -> Series:
mask = acc < val
return acc.zip_with(mask, val)

return fold(lit(0), max_, name).alias("max")
return col(name).max()
return fold(lit(0), max_, column).alias("max")
return col(column).max()


def min(name: "Union[str, List[Expr]]") -> "Expr":
def min(column: "Union[str, List[Expr], Series]") -> "Union[Expr, Any]":
"""
Get minimum value
column
Column(s) to be used in aggregation. Will lead to different behavior based on the input.
input:
- Union[str, Series] -> aggregate the sum value of that column
- List[Expr] -> aggregate the sum value horizontally.
"""
if isinstance(name, list):
if type(column) is Series:
return column.min()
if isinstance(column, list):

def min_(acc: Series, val: Series) -> Series:
mask = acc > val
return acc.zip_with(mask, val)

return fold(lit(0), min_, name).alias("min")
return col(name).min()
return fold(lit(0), min_, column).alias("min")
return col(column).min()


def sum(name: "Union[str, List[Expr]]") -> "Expr":
def sum(column: "Union[str, List[Expr], Series]") -> "Union[Expr, Any]":
"""
Get sum value
column
Column(s) to be used in aggregation. Will lead to different behavior based on the input.
input:
- Union[str, Series] -> aggregate the sum value of that column
- List[Expr] -> aggregate the sum value horizontally.
"""
if isinstance(name, list):
return fold(lit(0), lambda a, b: a + b, name).alias("sum")
return col(name).sum()
if type(column) is Series:
return column.sum()
if isinstance(column, list):
return fold(lit(0), lambda a, b: a + b, column).alias("sum")
return col(column).sum()


def mean(name: str) -> "Expr":
def mean(column: "Union[str, Series]") -> "Union[Expr, float]":
"""
Get mean value
"""
return col(name).mean()
if type(column) is Series:
return column.mean()
return col(column).mean()


def avg(name: str) -> "Expr":
def avg(column: "Union[str, Series]") -> "Union[Expr, float]":
"""
Alias for mean
"""
return col(name).mean()
return mean(column)


def median(name: str) -> "Expr":
def median(column: "Union[str, Series]") -> "Union[Expr, float, int]":
"""
Get median value
"""
return col(name).median()
if type(column) is Series:
return column.median()
return col(column).median()


def n_unique(name: str) -> "Expr":
def n_unique(column: "Union[str, Series]") -> "Union[Expr, int]":
"""Count unique values"""
return col(name).n_unique()
if type(column) is Series:
return column.n_unique()
return col(column).n_unique()


def first(name: str) -> "Expr":
def first(column: "Union[str, Series]") -> "Union[Expr, Any]":
"""
Get first value
"""
return col(name).first()
if type(column) is Series:
if column.len() > 0:
return column[0]
else:
raise IndexError("Series empty so no first value can be returned")
return col(column).first()


def last(name: str) -> "Expr":
def last(column: str) -> "Expr":
"""
Get last value
"""
return col(name).last()
if type(column) is Series:
if column.len() > 0:
return column[-1]
else:
raise IndexError("Series empty so no last value can be returned")
return col(column).last()


def head(name: str, n: "Optional[int]" = None) -> "Expr":
def head(
column: "Union[str, Series]", n: "Optional[int]" = None
) -> "Union[Expr, Series]":
"""
Get the first n rows of an Expression
Parameters
----------
name
column name
column
column name or Series
n
number of rows to take
"""
return col(name).head(n)
if type(column) is Series:
return column.head(n)
return col(column).head(n)


def tail(name: str, n: "Optional[int]" = None) -> "Expr":
def tail(
column: "Union[str, Series]", n: "Optional[int]" = None
) -> "Union[Expr, Series]":
"""
Get the last n rows of an Expression
Expand All @@ -1647,7 +1702,9 @@ def tail(name: str, n: "Optional[int]" = None) -> "Expr":
n
number of rows to take
"""
return col(name).tail(n)
if type(column) is Series:
return column.tail(n)
return col(column).tail(n)


def lit_date(dt: "datetime") -> Expr:
Expand Down Expand Up @@ -1802,6 +1859,14 @@ def all(name: "Union[str, List[Expr]]") -> "Expr":
return col(name).cast(bool).sum() == col(name).count()


def groups(column: str) -> "Expr":
return col(column).groups()


def quantile(column: str, quantile: float) -> "Expr":
return col(column).quantile(quantile)


class UDF:
def __init__(self, f: Callable[[Series], Series], output_type: "DataType"):
self.f = f
Expand Down
51 changes: 0 additions & 51 deletions py-polars/polars/lazy/agg.py

This file was deleted.

0 comments on commit e3790c9

Please sign in to comment.