Skip to content

Commit

Permalink
fix any and python dispatch to proper any impl (#2680)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Feb 17, 2022
1 parent 4faa427 commit 6d16f64
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 10 deletions.
16 changes: 10 additions & 6 deletions polars/polars-lazy/src/physical_plan/expressions/alias.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ impl AliasExpr {
expr,
}
}
fn finish(&self, mut input: Series) -> Result<Series> {
fn finish(&self, mut input: Series) -> Series {
input.rename(&self.name);
Ok(input)
input
}
}

Expand All @@ -32,7 +32,7 @@ impl PhysicalExpr for AliasExpr {

fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> Result<Series> {
let series = self.physical_expr.evaluate(df, state)?;
self.finish(series)
Ok(self.finish(series))
}

#[allow(clippy::ptr_arg)]
Expand All @@ -43,10 +43,14 @@ impl PhysicalExpr for AliasExpr {
state: &ExecutionState,
) -> Result<AggregationContext<'a>> {
let mut ac = self.physical_expr.evaluate_on_groups(df, groups, state)?;
let mut s = ac.take();
s.rename(&self.name);
let s = ac.take();
let s = self.finish(s);

ac.with_series(self.finish(s)?, ac.is_aggregated());
if ac.is_literal() {
ac.with_literal(s);
} else {
ac.with_series(s, ac.is_aggregated());
}
Ok(ac)
}

Expand Down
8 changes: 7 additions & 1 deletion polars/polars-lazy/src/physical_plan/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,13 @@ impl PhysicalExpr for CastExpr {
let mut ac = self.input.evaluate_on_groups(df, groups, state)?;
let s = ac.flat_naive();
let s = self.finish(s.as_ref())?;
ac.with_series(s, false);

if ac.is_literal() {
ac.with_literal(s);
} else {
ac.with_series(s, false);
}

Ok(ac)
}

Expand Down
12 changes: 12 additions & 0 deletions polars/polars-lazy/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,9 @@ impl<'a> AggregationContext<'a> {
pub(crate) fn is_aggregated(&self) -> bool {
!self.is_not_aggregated()
}
/// Returns `true` when the current aggregation state is the
/// `AggState::Literal` variant, and `false` for any other state.
pub(crate) fn is_literal(&self) -> bool {
matches!(self.state, AggState::Literal(_))
}

pub(crate) fn combine_groups(&mut self, other: AggregationContext) -> &mut Self {
if let (Cow::Borrowed(_), Cow::Owned(a)) = (&self.groups, other.groups) {
Expand Down Expand Up @@ -332,6 +335,11 @@ impl<'a> AggregationContext<'a> {
self
}

/// Overwrites the aggregation state with `AggState::Literal(series)`.
///
/// Returns `&mut Self` so the call can be chained with other builder-style
/// methods on the context.
pub(crate) fn with_literal(&mut self, series: Series) -> &mut Self {
self.state = AggState::Literal(series);
self
}

/// Update the group tuples
pub(crate) fn with_groups(&mut self, groups: GroupsProxy) -> &mut Self {
// In case of new groups, a series always needs to be flattened
Expand All @@ -353,6 +361,10 @@ impl<'a> AggregationContext<'a> {
// because this is lazy, we must first update the groups
// by calling .groups()
self.groups();
assert!(
self.groups.len() <= s.len(),
"implementation error groups are out of bounds; please open an issue"
);

let out = s
.agg_list(&self.groups)
Expand Down
23 changes: 23 additions & 0 deletions polars/tests/it/lazy/window_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,3 +362,26 @@ fn test_window_exprs_any_all() -> Result<()> {
assert!(df.frame_equal(&expected));
Ok(())
}

/// Checks the "naive any" formulation (`sum() > 0`) when evaluated over a
/// window partitioned by `row_id`.
///
/// Both partitions contain at least one `true`, so every one of the five
/// rows in the `res` column must be `true`, giving a sum of 5.
#[test]
fn test_window_naive_any() -> Result<()> {
    let input = df![
        "row_id" => [0, 0, 1, 1, 1],
        "boolvar" => [true, false, true, false, false]
    ]?;

    // Windowed expression: per-group sum of the booleans compared against zero.
    let any_per_group = col("boolvar")
        .sum()
        .gt(lit(0))
        .over([col("row_id")])
        .alias("res");

    let out = input.lazy().with_column(any_per_group).collect()?;

    let total = out.column("res")?.sum::<usize>();
    assert_eq!(total, Some(5));
    Ok(())
}
4 changes: 2 additions & 2 deletions py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,7 @@ def any(name: Union[str, List["pli.Expr"]]) -> "pli.Expr":
"""
if isinstance(name, list):
return fold(lit(0), lambda a, b: a | b, name).alias("any")
return col(name).sum() > 0
return col(name).any()


def exclude(columns: Union[str, List[str]]) -> "pli.Expr":
Expand Down Expand Up @@ -906,7 +906,7 @@ def all(name: Optional[Union[str, List["pli.Expr"]]] = None) -> "pli.Expr":
return col("*")
if isinstance(name, list):
return fold(lit(0), lambda a, b: a & b, name).alias("all")
return col(name).cast(bool).sum() == col(name).count()
return col(name).all()


def groups(column: str) -> "pli.Expr":
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ def test_all_expr() -> None:


def test_any_expr(fruits_cars: pl.DataFrame) -> None:
assert fruits_cars.select(pl.any("A"))[0, 0]
assert fruits_cars.with_column(pl.col("A").cast(bool)).select(pl.any("A"))[0, 0]
assert fruits_cars.select(pl.any([pl.col("A"), pl.col("B")]))[0, 0]


Expand Down

0 comments on commit 6d16f64

Please sign in to comment.