Skip to content

Commit

Permalink
lazy: binary expression can combine different aggregation states in g…
Browse files Browse the repository at this point in the history
…roupby context (#1875)

* add aggstate in aggcontext

* lazy: binary expression can combine different aggregation states in groupby context
  • Loading branch information
ritchie46 committed Nov 24, 2021
1 parent 1bdf9c7 commit a8d0dd2
Show file tree
Hide file tree
Showing 4 changed files with 236 additions and 38 deletions.
12 changes: 12 additions & 0 deletions polars/polars-lazy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@ abs = ["polars-core/abs"]
# no guarantees whatsoever
private = []

test = [
"rolling_window",
"rank",
"list",
"round_series",
"csv-file",
"dtype-categorical",
"cum_agg",
"regex",
"polars-core/plain_fmt",
]

[dependencies]
ahash = "0.7"
itertools = "0.10"
Expand Down
111 changes: 107 additions & 4 deletions polars/polars-lazy/src/physical_plan/expressions/binary.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use crate::physical_plan::state::ExecutionState;
use crate::physical_plan::PhysicalAggregation;
use crate::prelude::*;
use polars_arrow::arrow::array::ArrayRef;
use polars_core::frame::groupby::GroupTuples;
use polars_core::{prelude::*, POOL};
use std::convert::TryFrom;
use std::sync::Arc;

pub struct BinaryExpr {
Expand Down Expand Up @@ -84,7 +86,7 @@ impl PhysicalExpr for BinaryExpr {
)
});
let mut ac_l = result_a?;
let ac_r = result_b?;
let mut ac_r = result_b?;

if !ac_l.can_combine(&ac_r) {
return Err(PolarsError::InvalidOperation(
Expand All @@ -93,9 +95,110 @@ impl PhysicalExpr for BinaryExpr {
.into(),
));
}
let out = apply_operator(ac_l.flat().as_ref(), ac_r.flat().as_ref(), self.op)?;
ac_l.combine_groups(ac_r).with_series(out, false);
Ok(ac_l)

match (ac_l.agg_state(), ac_r.agg_state()) {
// One of the two exprs is aggregated with flat aggregation, e.g. `e.min(), e.max(), e.first()`
(AggState::AggregatedFlat(_), AggState::NotAggregated(_)) => {
// this is a flat series of len eq to group tuples
let l = ac_l.aggregated();
let l = l.as_ref();
let arr_l = &l.chunks()[0];
assert_eq!(l.len(), groups.len());

// we create a dummy Series that is not cloned nor moved
// so we can swap the ArrayRef during the hot loop
// this prevents a series Arc alloc and a vec alloc per iteration
let dummy = Series::try_from(("dummy", vec![arr_l.clone()])).unwrap();
let chunks = unsafe {
let chunks = dummy.chunks();
let ptr = chunks.as_ptr() as *mut ArrayRef;
let len = chunks.len();
std::slice::from_raw_parts_mut(ptr, len)
};

// this is now a list
let r = ac_r.aggregated();
let r = r.list().unwrap();

let mut ca: ListChunked = r
.amortized_iter()
.enumerate()
.map(|(idx, opt_s)| {
opt_s
.map(|s| {
let r = s.as_ref();
// TODO: optimize this? It's slow and unsafe.

// Safety:
// we are in bounds
let mut arr = unsafe { Arc::from(arr_l.slice_unchecked(idx, 1)) };
std::mem::swap(&mut chunks[0], &mut arr);
let l = &dummy;

apply_operator(l, r, self.op)
})
.transpose()
})
.collect::<Result<_>>()?;
ca.rename(l.name());

ac_l.with_series(ca.into_series(), true);
Ok(ac_l)
}
(AggState::NotAggregated(_), AggState::AggregatedFlat(_)) => {
// this is now a list
let l = ac_l.aggregated();
let l = l.list().unwrap();

// this is a flat series of len eq to group tuples
let r = ac_r.aggregated();
assert_eq!(l.len(), groups.len());
let r = r.as_ref();
let arr_r = &r.chunks()[0];

// we create a dummy Series that is not cloned nor moved
// so we can swap the ArrayRef during the hot loop
// this prevents a series Arc alloc and a vec alloc per iteration
let dummy = Series::try_from(("dummy", vec![arr_r.clone()])).unwrap();
let chunks = unsafe {
let chunks = dummy.chunks();
let ptr = chunks.as_ptr() as *mut ArrayRef;
let len = chunks.len();
std::slice::from_raw_parts_mut(ptr, len)
};

let mut ca: ListChunked = l
.amortized_iter()
.enumerate()
.map(|(idx, opt_s)| {
opt_s
.map(|s| {
let l = s.as_ref();
// TODO: optimize this? It's slow.
// Safety:
// we are in bounds
let mut arr = unsafe { Arc::from(arr_r.slice_unchecked(idx, 1)) };
std::mem::swap(&mut chunks[0], &mut arr);
let r = &dummy;

apply_operator(l, r, self.op)
})
.transpose()
})
.collect::<Result<_>>()?;
ca.rename(l.name());

ac_l.with_series(ca.into_series(), true);
Ok(ac_l)
}
// Both are either a flat series or aggregated into a list,
// so we can flatten the Series and apply the operators
_ => {
let out = apply_operator(ac_l.flat().as_ref(), ac_r.flat().as_ref(), self.op)?;
ac_l.combine_groups(ac_r).with_series(out, false);
Ok(ac_l)
}
}
}

fn to_field(&self, _input_schema: &Schema) -> Result<Field> {
Expand Down
68 changes: 41 additions & 27 deletions polars/polars-lazy/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,15 @@ use polars_core::prelude::*;
use polars_io::PhysicalIoExpr;
use std::borrow::Cow;

#[cfg_attr(debug_assertions, derive(Debug))]
pub(crate) enum AggState {
Aggregated(Series),
/// Already aggregated: `.agg_list(group_tuples)` is called
/// and produced a `Series` of dtype `List`
AggregatedList(Series),
/// Already aggregated: a flat aggregation (e.g. `min`, `max`, `first`) was applied
/// and produced a `Series` of any dtype that is not nested.
AggregatedFlat(Series),
/// Not yet aggregated: `agg_list` still has to be called.
NotAggregated(Series),
None,
}
Expand Down Expand Up @@ -141,17 +148,22 @@ impl<'a> AggregationContext<'a> {

pub(crate) fn series(&self) -> &Series {
match &self.series {
AggState::Aggregated(s) => s,
AggState::NotAggregated(s) => s,
_ => unreachable!(),
AggState::NotAggregated(s)
| AggState::AggregatedFlat(s)
| AggState::AggregatedList(s) => s,
AggState::None => unreachable!(),
}
}

/// Returns a reference to the aggregation state currently stored in `self.series`.
pub(crate) fn agg_state(&self) -> &AggState {
&self.series
}

/// Returns `true` when the stored state is `AggState::NotAggregated`.
pub(crate) fn is_not_aggregated(&self) -> bool {
matches!(&self.series, AggState::NotAggregated(_))
}

pub fn is_aggregated(&self) -> bool {
pub(crate) fn is_aggregated(&self) -> bool {
!self.is_not_aggregated()
}

Expand All @@ -170,10 +182,16 @@ impl<'a> AggregationContext<'a> {
groups: Cow<'a, GroupTuples>,
aggregated: bool,
) -> AggregationContext<'a> {
let series = if aggregated {
AggState::Aggregated(series)
} else {
AggState::NotAggregated(series)
let series = match (aggregated, series.dtype()) {
(true, &DataType::List(_)) => {
assert_eq!(series.len(), groups.len());
AggState::AggregatedList(series)
}
(true, _) => {
assert_eq!(series.len(), groups.len());
AggState::AggregatedFlat(series)
}
_ => AggState::NotAggregated(series),
};

Self {
Expand All @@ -200,7 +218,10 @@ impl<'a> AggregationContext<'a> {
/// the columns dtype)
pub(crate) fn with_series(&mut self, series: Series, aggregated: bool) -> &mut Self {
self.series = match (aggregated, series.dtype()) {
(true, &DataType::List(_)) => AggState::Aggregated(series),
(true, &DataType::List(_)) => {
assert_eq!(series.len(), self.groups.len());
AggState::AggregatedList(series)
}
_ => AggState::NotAggregated(series),
};
self
Expand All @@ -214,14 +235,13 @@ impl<'a> AggregationContext<'a> {
}

pub(crate) fn aggregated(&mut self) -> Cow<'_, Series> {
// we do this here because of mutable borrow overlaps.
// The groups are determined lazily and in case of a flat
// we do this here instead of the pattern match because of mutable borrow overlaps.
//
// The groups are determined lazily and in case of a flat/non-aggregated
// series we use the groups to aggregate the list
// because this is lazy, we must first update the groups
// by calling .groups()
if let AggState::NotAggregated(_) = self.series {
self.groups();
}
self.groups();
match &self.series {
AggState::NotAggregated(s) => {
// literal series
Expand Down Expand Up @@ -250,31 +270,25 @@ impl<'a> AggregationContext<'a> {
};
out
}
AggState::Aggregated(s) => Cow::Borrowed(s),
AggState::AggregatedList(s) | AggState::AggregatedFlat(s) => Cow::Borrowed(s),
AggState::None => unreachable!(),
}
}

pub(crate) fn flat(&self) -> Cow<'_, Series> {
match &self.series {
AggState::NotAggregated(s) => Cow::Borrowed(s),
AggState::Aggregated(s) => {
// it is not always aggregated as list
// could for instance also be f64 by mean aggregation
if let DataType::List(_) = s.dtype() {
Cow::Owned(s.explode().unwrap())
} else {
Cow::Borrowed(s)
}
}
AggState::AggregatedList(s) => Cow::Owned(s.explode().unwrap()),
AggState::AggregatedFlat(s) => Cow::Borrowed(s),
AggState::None => unreachable!(),
}
}

pub(crate) fn take(&mut self) -> Series {
match std::mem::take(&mut self.series) {
AggState::NotAggregated(s) => s,
AggState::Aggregated(s) => s,
AggState::NotAggregated(s)
| AggState::AggregatedFlat(s)
| AggState::AggregatedList(s) => s,
AggState::None => panic!("implementation error"),
}
}
Expand Down
83 changes: 76 additions & 7 deletions polars/polars-lazy/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1789,6 +1789,7 @@ fn test_filter_count() -> Result<()> {
}

#[test]
#[cfg(feature = "dtype-i16")]
fn test_groupby_small_ints() -> Result<()> {
let df = df![
"id_32" => [1i32, 2],
Expand Down Expand Up @@ -2244,22 +2245,41 @@ fn test_literal_window_fn() -> Result<()> {
}

#[test]
#[should_panic]
fn test_invalid_ternary_in_agg1() {
fn test_binary_agg_context_1() -> Result<()> {
let df = df![
"groups" => [1, 1, 2, 2, 3, 3],
"vals" => [1, 2, 3, 4, 5, 6]
]
.unwrap();

df.lazy()
.groupby([col("groups")])
let out = df
.lazy()
.stable_groupby([col("groups")])
.agg([when(col("vals").first().neq(lit(1)))
.then(lit("a"))
.otherwise(lit("b"))])
.collect();
.otherwise(lit("b"))
.alias("foo")])
.collect()
.unwrap();

let out = out.column("foo")?;
let out = out.explode()?;
let out = out.utf8()?;
assert_eq!(
Vec::from(out),
&[
Some("b"),
Some("b"),
Some("a"),
Some("a"),
Some("a"),
Some("a")
]
);
Ok(())
}

// Just like the binary expression, this must be changed: this can work.
#[test]
#[should_panic]
fn test_invalid_ternary_in_agg2() {
Expand All @@ -2274,5 +2294,54 @@ fn test_invalid_ternary_in_agg2() {
.agg([when(col("vals").neq(lit(1)))
.then(col("vals").first())
.otherwise(lit("b"))])
.collect();
.collect()
.unwrap();
}

// Checks a binary expression inside a groupby aggregation where one operand is
// already aggregated to one value per group (`col("vals").first()`) and the other
// is the not-yet-aggregated column (`col("vals")`), in both operand orders.
#[test]
fn test_binary_agg_context_2() -> Result<()> {
let df = df![
"groups" => [1, 1, 2, 2, 3, 3],
"vals" => [1, 2, 3, 4, 5, 6]
]?;

// this is complex because we first aggregate one expression of the binary operation.

let out = df
.clone()
.lazy()
.stable_groupby([col("groups")])
.agg([((col("vals").first() - col("vals")).list()).alias("vals")])
.collect()?;

// first value of each group minus every value of that group:
// 1 - [1, 2] = [0, -1]
// 3 - [3, 4] = [0, -1]
// 5 - [5, 6] = [0, -1]
let out = out.column("vals")?;
// explode the per-group lists so the result can be compared as one flat vector
let out = out.explode()?;
let out = out.i32()?;
assert_eq!(
Vec::from(out),
&[Some(0), Some(-1), Some(0), Some(-1), Some(0), Some(-1)]
);

// Same, but now we reverse the lhs / rhs.
let out = df
.lazy()
.stable_groupby([col("groups")])
.agg([((col("vals")) - col("vals").first()).list().alias("vals")])
.collect()?;

// every value of each group minus the first value of that group:
// [1, 2] - 1 = [0, 1]
// [3, 4] - 3 = [0, 1]
// [5, 6] - 5 = [0, 1]
let out = out.column("vals")?;
let out = out.explode()?;
let out = out.i32()?;
assert_eq!(
Vec::from(out),
&[Some(0), Some(1), Some(0), Some(1), Some(0), Some(1)]
);

Ok(())
}

0 comments on commit a8d0dd2

Please sign in to comment.