Skip to content

Commit

Permalink
fix[rust]: fix filter(is_in) in groupby context (#4577)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 26, 2022
1 parent 067045c commit 4ac12e4
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 49 deletions.
6 changes: 6 additions & 0 deletions polars/polars-core/src/frame/groupby/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,12 @@ impl<'a> GroupsIndicator<'a> {
GroupsIndicator::Slice([_, len]) => *len as usize,
}
}
/// Returns the first index of this group.
///
/// For a `Slice` group this is the leading element of the two-element array;
/// for an `Idx` group it is field `0` of the wrapped item.
pub fn first(&self) -> IdxSize {
    match self {
        GroupsIndicator::Slice([start, _]) => *start,
        GroupsIndicator::Idx(item) => item.0,
    }
}
/// Returns `true` when this group holds no elements, i.e. `len()` is zero.
pub fn is_empty(&self) -> bool {
    matches!(self.len(), 0)
}
Expand Down
110 changes: 63 additions & 47 deletions polars/polars-lazy/src/physical_plan/expressions/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,54 +49,70 @@ impl PhysicalExpr for FilterExpr {
let groups = ac_s.groups();
let predicate_s = ac_predicate.flat_naive();
let predicate = predicate_s.bool()?.rechunk();
let predicate = predicate.downcast_iter().next().unwrap();

let groups = POOL.install(|| {
match groups.as_ref() {
GroupsProxy::Idx(groups) => {
let groups = groups
.par_iter()
.map(|(first, idx)| unsafe {
let idx: Vec<IdxSize> = idx
.iter()
// Safety:
// just checked bounds in short circuited lhs
.filter_map(|i| {
match predicate.value(*i as usize)
&& predicate.is_valid_unchecked(*i as usize)
{
true => Some(*i),
_ => None,
}
})
.collect();

(*idx.first().unwrap_or(&first), idx)
})
.collect();

GroupsProxy::Idx(groups)
}
GroupsProxy::Slice { groups, .. } => {
let groups = groups
.par_iter()
.map(|&[first, len]| unsafe {
let idx: Vec<IdxSize> = (first..first + len)
// Safety:
// just checked bounds in short circuited lhs
.filter(|&i| {
predicate.value(i as usize)
&& predicate.is_valid_unchecked(i as usize)
})
.collect();

(*idx.first().unwrap_or(&first), idx)
})
.collect();
GroupsProxy::Idx(groups)
}

// all values true don't do anything
if predicate.all() {
return Ok(ac_s);
}
// all values false
// create empty groups
let groups = if !predicate.any() {
let groups = groups.iter().map(|gi| [gi.first(), 0]).collect::<Vec<_>>();
GroupsProxy::Slice {
groups,
rolling: false,
}
});
}
// filter the indexes that are true
else {
let predicate = predicate.downcast_iter().next().unwrap();
POOL.install(|| {
match groups.as_ref() {
GroupsProxy::Idx(groups) => {
let groups = groups
.par_iter()
.map(|(first, idx)| unsafe {
let idx: Vec<IdxSize> = idx
.iter()
// Safety:
// just checked bounds in short circuited lhs
.filter_map(|i| {
match predicate.value(*i as usize)
&& predicate.is_valid_unchecked(*i as usize)
{
true => Some(*i),
_ => None,
}
})
.collect();

(*idx.first().unwrap_or(&first), idx)
})
.collect();

GroupsProxy::Idx(groups)
}
GroupsProxy::Slice { groups, .. } => {
let groups = groups
.par_iter()
.map(|&[first, len]| unsafe {
let idx: Vec<IdxSize> = (first..first + len)
// Safety:
// just checked bounds in short circuited lhs
.filter(|&i| {
predicate.value(i as usize)
&& predicate.is_valid_unchecked(i as usize)
})
.collect();

(*idx.first().unwrap_or(&first), idx)
})
.collect();
GroupsProxy::Idx(groups)
}
}
})
};

ac_s.with_groups(groups).set_original_len(false);
Ok(ac_s)
Expand Down
1 change: 1 addition & 0 deletions polars/polars-lazy/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,7 @@ pub trait PhysicalExpr: Send + Sync {
None
}

//
fn is_valid_aggregation(&self) -> bool;

fn is_literal(&self) -> bool {
Expand Down
9 changes: 7 additions & 2 deletions polars/polars-lazy/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,13 @@ where
/// Check if root expression is a literal
///
/// Returns `true` when `e` is itself an `Expr::Literal`, or when any of the
/// root expressions reported by `expr_to_root_column_exprs(e)` is a literal.
#[cfg(feature = "is_in")]
pub(crate) fn has_root_literal_expr(e: &Expr) -> bool {
    // Test `e` directly before inspecting its roots, so a bare literal is
    // recognised regardless of what the root collection yields for it.
    matches!(e, Expr::Literal(_))
        || expr_to_root_column_exprs(e)
            .iter()
            .any(|root| matches!(root, Expr::Literal(_)))
}

// this one is used so much that it has its own function, to reduce inlining
Expand Down
19 changes: 19 additions & 0 deletions py-polars/tests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,22 @@ def test_melt_values_predicate_pushdown() -> None:
.filter(pl.col("value") == pl.lit("123"))
.collect()
).to_dict(False) == {"id": [1], "variable": ["asset_key_1"], "value": ["123"]}


def test_filter_is_in_4572() -> None:
    """Filtering with ``is_in`` inside a groupby aggregation must give the same
    result as the equivalent ``==`` filter (test name refers to issue 4572)."""
    df = pl.DataFrame({"id": [1, 2, 1, 2], "k": ["a"] * 2 + ["b"] * 2})

    # Reference result built with a plain equality predicate.
    eq_agg = pl.col("k").filter(pl.col("k") == "a").list()
    expected = df.groupby("id").agg(eq_agg).sort("id")

    # The aggregation under test: same filter expressed through `is_in`.
    is_in_agg = pl.col("k").filter(pl.col("k").is_in(["a"])).list()

    # Case 1: group the frame as-is, then sort the aggregated output.
    grouped_then_sorted = df.groupby("id").agg(is_in_agg).sort("id")
    assert grouped_then_sorted.frame_equal(expected)

    # Case 2: sort the frame first, then group (no sort after aggregation).
    sorted_then_grouped = df.sort("id").groupby("id").agg(is_in_agg)
    assert sorted_then_grouped.frame_equal(expected)

0 comments on commit 4ac12e4

Please sign in to comment.