Skip to content

Commit

Permalink
fix: Fix invalid partitionable query (#14966)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Mar 10, 2024
1 parent 419b891 commit e7be629
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 28 deletions.
9 changes: 6 additions & 3 deletions crates/polars-arrow/src/array/binview/mod.rs
Expand Up @@ -26,6 +26,7 @@ mod private {
}
pub use iterator::BinaryViewValueIter;
pub use mutable::MutableBinaryViewArray;
use polars_utils::slice::GetSaferUnchecked;
use private::Sealed;

use crate::array::binview::view::{validate_binary_view, validate_utf8_only, validate_utf8_view};
Expand Down Expand Up @@ -273,7 +274,7 @@ impl<T: ViewType + ?Sized> BinaryViewArrayGeneric<T> {
/// Assumes that the `i < self.len`.
#[inline]
pub unsafe fn value_unchecked(&self, i: usize) -> &T {
let v = *self.views.get_unchecked(i);
let v = *self.views.get_unchecked_release(i);
let len = v.length;

// view layout:
Expand All @@ -290,10 +291,12 @@ impl<T: ViewType + ?Sized> BinaryViewArrayGeneric<T> {
let ptr = self.views.as_ptr() as *const u8;
std::slice::from_raw_parts(ptr.add(i * 16 + 4), len as usize)
} else {
let (data_ptr, data_len) = *self.raw_buffers.get_unchecked(v.buffer_idx as usize);
let (data_ptr, data_len) = *self
.raw_buffers
.get_unchecked_release(v.buffer_idx as usize);
let data = std::slice::from_raw_parts(data_ptr, data_len);
let offset = v.offset as usize;
data.get_unchecked(offset..offset + len as usize)
data.get_unchecked_release(offset..offset + len as usize)
};
T::from_bytes_unchecked(bytes)
}
Expand Down
@@ -1,3 +1,4 @@
use polars_core::series::IsSorted;
use polars_core::utils::{accumulate_dataframes_vertical, split_df};
use rayon::prelude::*;

Expand Down Expand Up @@ -155,6 +156,12 @@ fn estimate_unique_count(keys: &[Series], mut sample_size: usize) -> PolarsResul
}
}

// Lower this at debug builds so that we hit this in the test suite.
#[cfg(debug_assertions)]
const PARTITION_LIMIT: usize = 15;
#[cfg(not(debug_assertions))]
const PARTITION_LIMIT: usize = 1000;

// Checks if we should run normal or default aggregation
// by sampling data.
fn can_run_partitioned(
Expand All @@ -163,7 +170,16 @@ fn can_run_partitioned(
state: &ExecutionState,
from_partitioned_ds: bool,
) -> PolarsResult<bool> {
if std::env::var("POLARS_NO_PARTITION").is_ok() {
if !keys
.iter()
.take(1)
.all(|s| matches!(s.is_sorted_flag(), IsSorted::Not))
{
if state.verbose() {
eprintln!("FOUND SORTED KEY: running default HASH AGGREGATION")
}
Ok(false)
} else if std::env::var("POLARS_NO_PARTITION").is_ok() {
if state.verbose() {
eprintln!("POLARS_NO_PARTITION set: running default HASH AGGREGATION")
}
Expand All @@ -173,9 +189,9 @@ fn can_run_partitioned(
eprintln!("POLARS_FORCE_PARTITION set: running partitioned HASH AGGREGATION")
}
Ok(true)
} else if original_df.height() < 1000 && !cfg!(test) {
} else if original_df.height() < PARTITION_LIMIT && !cfg!(test) {
if state.verbose() {
eprintln!("DATAFRAME < 1000 rows: running default HASH AGGREGATION")
eprintln!("DATAFRAME < {PARTITION_LIMIT} rows: running default HASH AGGREGATION")
}
Ok(false)
} else {
Expand Down Expand Up @@ -297,7 +313,7 @@ impl PartitionGroupByExec {
}

#[cfg(feature = "streaming")]
if !self.maintain_order {
if !self.maintain_order && std::env::var("POLARS_NO_STREAMING_GROUPBY").is_err() {
if let Some(out) = self.run_streaming(state, original_df.clone()) {
return out;
}
Expand Down
33 changes: 12 additions & 21 deletions crates/polars-lazy/src/physical_plan/planner/lp.rs
Expand Up @@ -34,8 +34,15 @@ fn partitionable_gb(

if partitionable {
for agg in aggs {
let aexpr = expr_arena.get(*agg);
let depth = (expr_arena).iter(*agg).count();
let mut agg = *agg;
let mut aexpr = expr_arena.get(agg);
// It should end with an aggregation
if let AExpr::Alias(input, _) = aexpr {
agg = *input;
aexpr = expr_arena.get(agg);
}

let depth = (expr_arena).iter(agg).count();

// These single expressions are partitionable
if matches!(aexpr, AExpr::Len) {
Expand All @@ -48,37 +55,21 @@ fn partitionable_gb(
break;
}

// it should end with an aggregation
if let AExpr::Alias(input, _) = aexpr {
// col().agg().alias() is allowed: count of 3
// col().alias() is not allowed: count of 2
// count().alias() is allowed: count of 2
if depth <= 2 {
match expr_arena.get(*input) {
AExpr::Len => {},
_ => {
partitionable = false;
break;
},
}
}
}

let has_aggregation =
|node: Node| has_aexpr(node, expr_arena, |ae| matches!(ae, AExpr::Agg(_)));

// check if the aggregation type is partitionable
// only simple aggregation like col().sum
// that can be divided in to the aggregation of their partitions are allowed
if !((expr_arena).iter(*agg).all(|(_, ae)| {
if !((expr_arena).iter(agg).all(|(_, ae)| {
use AExpr::*;
match ae {
// struct is needed to keep both states
#[cfg(feature = "dtype-struct")]
Agg(AAggExpr::Mean(_)) => {
// only numeric means for now.
// logical types seem to break because of casts to float.
matches!(expr_arena.get(*agg).get_type(_input_schema, Context::Default, expr_arena).map(|dt| {
matches!(expr_arena.get(agg).get_type(_input_schema, Context::Default, expr_arena).map(|dt| {
dt.is_numeric()}), Ok(true))
},
// only allowed expressions
Expand Down Expand Up @@ -120,7 +111,7 @@ fn partitionable_gb(

#[cfg(feature = "object")]
{
for name in aexpr_to_leaf_names(*agg, expr_arena) {
for name in aexpr_to_leaf_names(agg, expr_arena) {
let dtype = _input_schema.get(&name).unwrap();

if let DataType::Object(_, _) = dtype {
Expand Down
18 changes: 18 additions & 0 deletions py-polars/tests/unit/operations/test_group_by.py
Expand Up @@ -949,3 +949,21 @@ def test_group_by_with_null() -> None:
)
output = df.group_by(["a", "b"], maintain_order=True).agg(pl.col("c"))
assert_frame_equal(expected, output)


def test_partitioned_group_by_14954(monkeypatch: Any) -> None:
    # Force the partitioned group-by code path so the alias-wrapped
    # aggregation (the regression from issue 14954) is actually exercised.
    monkeypatch.setenv("POLARS_FORCE_PARTITION", "1")

    df = pl.DataFrame({"a": range(20)}).select(pl.col("a") % 2)
    result = (
        df.group_by("a")
        .agg((pl.col("a") > 1000).alias("a > 1000"))
        .sort("a")
        .to_dict(as_series=False)
    )

    # 20 rows taken mod 2 give two groups (0 and 1) of ten rows each;
    # no value exceeds 1000, so every comparison is False.
    expected = {
        "a": [0, 1],
        "a > 1000": [[False] * 10, [False] * 10],
    }
    assert result == expected

0 comments on commit e7be629

Please sign in to comment.