Skip to content

Commit

Permalink
fix(rust, python): block streaming on literal series/range (#6058)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 5, 2023
1 parent b3a2c9c commit 09ed10a
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 7 deletions.
34 changes: 27 additions & 7 deletions polars/polars-lazy/src/physical_plan/streaming/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,37 @@ fn to_physical_piped_expr(
}

/// Returns `true` when the expression tree rooted at `node` is considered streamable.
///
/// Every expression yielded by `expr_arena.iter(node)` must be one of:
/// - a `Function` / `AnonymousFunction` whose `collect_groups` option is `ApplyFlat`,
/// - a `Column`, `BinaryExpr`, `Alias` or `Cast`,
/// - a `Literal`.
///
/// Any other expression kind makes the result `false`.
///
/// On top of that, a literal `Series` or `Range` is only accepted when a `Column` was also
/// seen during the same traversal.
///
/// NOTE(review): this listing is a diff rendering whose +/- markers were lost. Going by the
/// hunk header (`-41,17 +41,37`), the bare `expr_arena.iter(node).all(...)` line right below,
/// the multi-line `AExpr::Column(_) | AExpr::Literal(_) | ... => true,` arm and the lone `})`
/// are the removed pre-change lines; the description above covers the added lines.
fn is_streamable(node: Node, expr_arena: &Arena<AExpr>) -> bool {
expr_arena.iter(node).all(|(_, ae)| match ae {
// check whether leaf column is Col or Lit
let mut seen_column = false;
let mut seen_lit_range = false;
let all = expr_arena.iter(node).all(|(_, ae)| match ae {
AExpr::Function { options, .. } | AExpr::AnonymousFunction { options, .. } => {
matches!(options.collect_groups, ApplyOptions::ApplyFlat)
}
AExpr::Column(_)
| AExpr::Literal(_)
| AExpr::BinaryExpr { .. }
| AExpr::Alias(_, _)
| AExpr::Cast { .. } => true,
AExpr::Column(_) => {
// remember that a column was referenced somewhere in the expression
seen_column = true;
true
}
AExpr::BinaryExpr { .. } | AExpr::Alias(_, _) | AExpr::Cast { .. } => true,
AExpr::Literal(lv) => match lv {
// accepted for now; the final decision is made after the traversal
LiteralValue::Series(_) | LiteralValue::Range { .. } => {
seen_lit_range = true;
true
}
_ => true,
},
_ => false,
})
});

if all {
// adding a range or literal series to chunks will fail because the sizes don't match
// if a column is a leaf column then it is ok
// - so we want to block `with_column(lit(Series))`
// - but we want to allow `with_column(col("foo").is_in(Series))`
// that means that if we have seen a lit_range, we only allow it if we have also seen a `column`.
return if seen_lit_range { seen_column } else { true };
}
false
}

fn all_streamable(exprs: &[Node], expr_arena: &Arena<AExpr>) -> bool {
Expand Down
9 changes: 9 additions & 0 deletions py-polars/tests/unit/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,12 @@ def test_streaming_categoricals_5921() -> None:
for out in [out_eager, out_lazy]:
assert out.dtypes == [pl.Categorical, pl.Int64]
assert out.to_dict(False) == {"X": ["a", "b"], "Y": [2, 1]}


def test_streaming_block_on_literals_6054() -> None:
    """Check a streaming collect on a lazy frame that had a literal Series added.

    The frame has two groups of five rows in ``col_1``; the added Series
    ``col_2`` holds 0..9, so the first ``col_2`` value per group is 0 and 5.
    """
    frame = pl.DataFrame({"col_1": [0] * 5 + [1] * 5})
    literal_column = pl.Series("col_2", list(range(10)))

    # Build the lazy query step by step, then collect with streaming requested.
    query = frame.lazy().with_column(literal_column).groupby("col_1").agg(pl.all().first())
    result = query.collect(streaming=True).sort("col_1")

    assert result.to_dict(False) == {"col_1": [0, 1], "col_2": [0, 5]}

0 comments on commit 09ed10a

Please sign in to comment.