Skip to content

Commit

Permalink
fix(rust, python): fix group order in binary aggregation (#5744)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 8, 2022
1 parent f05e4da commit 9c4203c
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 32 deletions.
4 changes: 2 additions & 2 deletions polars/polars-lazy/src/physical_plan/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,13 +268,13 @@ impl PhysicalExpr for BinaryExpr {
// when groups overlap, step 2 creates more values than rows
// and the original group lengths will be incorrect
(
AggState::AggregatedList(_),
AggState::AggregatedList(_) | AggState::AggregatedFlat(_),
AggState::NotAggregated(_) | AggState::Literal(_),
false,
)
| (
AggState::NotAggregated(_) | AggState::Literal(_),
AggState::AggregatedList(_),
AggState::AggregatedList(_) | AggState::AggregatedFlat(_),
false,
) => {
ac_l.sort_by_groups();
Expand Down
30 changes: 30 additions & 0 deletions polars/polars-lazy/src/tests/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,33 @@ fn test_cache_with_partial_projection() -> PolarsResult<()> {

Ok(())
}

/// Cross-joins a single-column frame with a projection of `rhs`, then
/// left-joins the result against `rhs` (with `B` renamed to `C`) on the
/// keys `A` and `C`, and checks the names of the resulting columns.
///
/// NOTE(review): judging by the test name this presumably guards the
/// common-subexpression handling of projections when the same source frame
/// feeds two branches of the plan — confirm against the optimizer.
#[test]
#[cfg(feature = "cross_join")]
fn test_cse_columns_projections() -> PolarsResult<()> {
    // Source frame that is consumed by two different branches of the query.
    let rhs = df![
        "A" => [1, 2],
        "B" => [3, 4],
        "D" => [5, 6]
    ]?
    .lazy();

    let lhs = df![
        "C" => [3, 4],
    ]?
    .lazy();

    // Branch 1: only column "A" of `rhs`, cross-joined onto `lhs`.
    let only_a = rhs.clone().select([col("A")]);
    let crossed = lhs.cross_join(only_a);

    // Branch 2: the full `rhs` with "B" renamed so it can act as join key "C".
    let renamed = rhs.rename(["B"], ["C"]);

    let left_on = [col("A"), col("C")];
    let right_on = [col("A"), col("C")];
    let result = crossed
        .join(renamed, left_on, right_on, JoinType::Left)
        .collect()?;

    assert_eq!(result.get_column_names(), &["C", "A", "D"]);

    Ok(())
}
30 changes: 0 additions & 30 deletions polars/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2047,33 +2047,3 @@ fn test_partitioned_gb_ternary() -> PolarsResult<()> {

Ok(())
}

/// Builds a query in which the same source frame is used twice — once
/// projected to column `A` for a cross join, once with `B` renamed to `C`
/// for a left join on `A` and `C` — and asserts the output column names.
#[test]
#[cfg(feature = "cross_join")]
fn test_cse_columns_projections() -> PolarsResult<()> {
    let shared = df![
        "A" => [1, 2],
        "B" => [3, 4],
        "D" => [5, 6]
    ]?
    .lazy();

    let base = df![
        "C" => [3, 4],
    ]?
    .lazy();

    // The projection of `shared` is built first; `shared` itself is then
    // moved into the renamed right-hand side of the left join.
    let projected = shared.clone().select([col("A")]);
    let plan = base.cross_join(projected).join(
        shared.rename(["B"], ["C"]),
        [col("A"), col("C")],
        [col("A"), col("C")],
        JoinType::Left,
    );

    let collected = plan.collect()?;
    let column_names = collected.get_column_names();

    assert_eq!(column_names, &["C", "A", "D"]);

    Ok(())
}
28 changes: 28 additions & 0 deletions polars/tests/it/lazy/expressions/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,3 +349,31 @@ fn test_ternary_aggregation_set_literals() -> PolarsResult<()> {

Ok(())
}

/// Groups by `category`, keeps the `name` values whose `score` equals the
/// maximum `score` of the group, and checks that the aggregated column is a
/// list of strings whose exploded values come out as `a, b, c, d` once the
/// result is sorted by `category`.
///
/// Every category in the input holds exactly one row, so each group's
/// filter keeps that row's own name.
#[test]
fn test_binary_group_consistency() -> PolarsResult<()> {
    let frame = df![
        "name" => ["a", "b", "c", "d"],
        "category" => [1, 2, 3, 4],
        "score" => [3, 5, 1, 2],
    ]?
    .lazy();

    // Binary expression mixing a plain column with an aggregated one.
    let is_group_max = col("score").eq(col("score").max());

    let aggregated = frame
        .groupby([col("category")])
        .agg([col("name").filter(is_group_max)])
        .sort("category", Default::default())
        .collect()?;
    let names = aggregated.column("name")?;

    let expected_dtype = DataType::List(Box::new(DataType::Utf8));
    assert_eq!(names.dtype(), &expected_dtype);

    // Flatten the per-group lists to compare against the expected order.
    let exploded = names.explode()?;
    let flattened = exploded.utf8()?.into_no_null_iter().collect::<Vec<_>>();
    assert_eq!(flattened, &["a", "b", "c", "d"]);

    Ok(())
}

0 comments on commit 9c4203c

Please sign in to comment.