Skip to content

Commit

Permalink
fix[rust]: don't run parititioned groupby when computing mean of logi…
Browse files Browse the repository at this point in the history
…cal column (#4412)
  • Loading branch information
ritchie46 committed Aug 14, 2022
1 parent beb802a commit cada9de
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 10 deletions.
22 changes: 18 additions & 4 deletions polars/polars-lazy/src/logical_plan/aexpr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ impl AExpr {
}
Mean(expr) => {
let mut field = arena.get(*expr).to_field(schema, ctxt, arena)?;
field.coerce(DataType::Float64);
coerce_numeric_aggregation(&mut field);
Ok(field)
}
List(expr) => {
Expand All @@ -284,12 +284,12 @@ impl AExpr {
}
Std(expr, _) => {
let mut field = arena.get(*expr).to_field(schema, ctxt, arena)?;
field.coerce(DataType::Float64);
coerce_numeric_aggregation(&mut field);
Ok(field)
}
Var(expr, _) => {
let mut field = arena.get(*expr).to_field(schema, ctxt, arena)?;
field.coerce(DataType::Float64);
coerce_numeric_aggregation(&mut field);
Ok(field)
}
NUnique(expr) => {
Expand All @@ -309,7 +309,7 @@ impl AExpr {
}
Quantile { expr, .. } => {
let mut field = arena.get(*expr).to_field(schema, ctxt, arena)?;
field.coerce(DataType::Float64);
coerce_numeric_aggregation(&mut field);
Ok(field)
}
}
Expand Down Expand Up @@ -359,3 +359,17 @@ impl AExpr {
}
}
}

fn coerce_numeric_aggregation(field: &mut Field) {
match field.dtype {
DataType::Duration(_) => {
// pass
}
DataType::Float32 => {
// pass
}
_ => {
field.coerce(DataType::Float64);
}
}
}
7 changes: 6 additions & 1 deletion polars/polars-lazy/src/physical_plan/planner/lp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,12 @@ impl DefaultPlanner {
match ae {
// struct is needed to keep both states
#[cfg(feature = "dtype-struct")]
Agg(AAggExpr::Mean(_)) => true,
Agg(AAggExpr::Mean(_)) => {
// only numeric means for now.
// logical types seem to break because of casts to float.
matches!(expr_arena.get(*agg).get_type(&input_schema, Context::Default, expr_arena).map(|dt| {
dt.is_numeric()}), Ok(true))
},
// only allowed expressions
Agg(agg_e) => {
matches!(
Expand Down
28 changes: 28 additions & 0 deletions polars/tests/it/lazy/groupby.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use polars_core::series::ops::NullBehavior;
use polars_core::SINGLE_LOCK;

use super::*;

Expand Down Expand Up @@ -153,3 +154,30 @@ fn test_groupby_agg_list_with_not_aggregated() -> Result<()> {
);
Ok(())
}

#[test]
#[cfg(all(feature = "dtype-duration", feature = "dtype-struct"))]
fn test_logical_mean_partitioned_groupby_block() -> Result<()> {
let guard = SINGLE_LOCK.lock();
let df = df![
"a" => [1, 1, 2],
"duration" => [1000, 2000, 3000]
]?;

let out = df
.lazy()
.with_column(col("duration").cast(DataType::Duration(TimeUnit::Microseconds)))
.groupby([col("a")])
.agg([col("duration").mean()])
.sort("duration", Default::default())
.collect()?;

let duration = out.column("duration")?;

assert_eq!(
duration.get(0),
AnyValue::Duration(1500, TimeUnit::Microseconds)
);

Ok(())
}
10 changes: 5 additions & 5 deletions py-polars/tests/db-benchmark/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,11 +288,6 @@
print("total took:", total_time, "s")
assert out.shape == (9999995, 8)

if not ON_STRINGS:
if total_time > 11:
print("query took longer than 11s, may be noise")
exit(1)

# Additional tests
# the code below, does not belong to the db-benchmark
# but it triggers other code paths so the checksums assertion
Expand All @@ -315,3 +310,8 @@

assert out["id6"] == 2137755425
assert np.isclose(out["v3"], 4.7040828499563754e8)

if not ON_STRINGS:
if total_time > 12:
print("query took longer than 12s, may be noise")
exit(1)

0 comments on commit cada9de

Please sign in to comment.