Skip to content

Commit

Permalink
fix invalid grouptuples in lazy
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 16, 2021
1 parent 64d5416 commit fc56f25
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 16 deletions.
13 changes: 12 additions & 1 deletion polars/polars-lazy/src/physical_plan/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ impl PhysicalExpr for ApplyExpr {
if self.inputs.len() == 1 {
let mut ac = self.inputs[0].evaluate_on_groups(df, groups, state)?;

// a unique or a sort
let mut update_group_tuples = false;

match self.collect_groups {
ApplyOptions::ApplyGroups => {
let mut container = [Default::default()];
Expand All @@ -68,13 +71,21 @@ impl PhysicalExpr for ApplyExpr {
.into_iter()
.map(|opt_s| {
opt_s.and_then(|s| {
let in_len = s.len();
container[0] = s;
self.function.call_udf(&mut container).ok()
self.function.call_udf(&mut container).ok().map(|s| {
if s.len() != in_len {
update_group_tuples = true;
}
s
})
})
})
.collect();

ca.rename(&name);
ac.with_series(ca.into_series(), true);
ac.with_update_groups(UpdateGroups::WithSeriesLen);
Ok(ac)
}
ApplyOptions::ApplyFlat => {
Expand Down
31 changes: 17 additions & 14 deletions polars/polars-lazy/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub(crate) mod take;
pub(crate) mod ternary;
pub(crate) mod utils;
pub(crate) mod window;
// pub(crate) mod unique;

use crate::physical_plan::state::ExecutionState;
use crate::prelude::*;
Expand All @@ -27,6 +28,7 @@ use polars_io::PhysicalIoExpr;
use std::borrow::Cow;

#[cfg_attr(debug_assertions, derive(Debug))]
#[derive(Clone)]
pub(crate) enum AggState {
/// Already aggregated: `.agg_list(group_tuples)` is called
/// and produced a `Series` of dtype `List`
Expand Down Expand Up @@ -231,32 +233,33 @@ impl<'a> AggregationContext<'a> {
// In case of new groups, a series always needs to be flattened
self.with_series(self.flat().into_owned(), false);
self.groups = Cow::Owned(groups);
// make sure that previous setting is not used
self.update_groups = UpdateGroups::No;
self
}

pub(crate) fn aggregated(&mut self) -> Cow<'_, Series> {
// we do this here instead of the pattern match because of mutable borrow overlaps.
//
// The groups are determined lazily and in case of a flat/non-aggregated
// series we use the groups to aggregate the list
// because this is lazy, we first must to update the groups
// by calling .groups()
self.groups();
match &self.series {
AggState::NotAggregated(s) => {
// we clone, because we only want to call `self.groups()` if needed.
// self groups may instantiate new groups and thus can be expensive.
match self.series.clone() {
AggState::NotAggregated(mut s) => {
// The groups are determined lazily and in case of a flat/non-aggregated
// series we use the groups to aggregate the list
// because this is lazy, we must first update the groups
// by calling .groups()
self.groups();

// literal series
// the literal series needs to be expanded to the number of indices in the groups
let s = if s.len() == 1
if s.len() == 1
// or more than one group
&& (self.groups.len() > 1
// or single groups with more than one index
|| !self.groups.as_ref().is_empty()
&& self.groups[0].1.len() > 1)
{
// todo! optimize this, we don't have to call agg_list, create the list directly.
Cow::Owned(s.expand_at_index(0, self.groups.iter().map(|g| g.1.len()).sum()))
} else {
Cow::Borrowed(s)
s = s.expand_at_index(0, self.groups.iter().map(|g| g.1.len()).sum())
};

let out = Cow::Owned(
Expand All @@ -270,7 +273,7 @@ impl<'a> AggregationContext<'a> {
};
out
}
AggState::AggregatedList(s) | AggState::AggregatedFlat(s) => Cow::Borrowed(s),
AggState::AggregatedList(s) | AggState::AggregatedFlat(s) => Cow::Owned(s),
AggState::None => unreachable!(),
}
}
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/physical_plan/expressions/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl PhysicalExpr for SortExpr {
let series = ac.flat().into_owned();

let groups = ac
.groups
.groups()
.iter()
.map(|(_first, idx)| {
// Safety:
Expand Down
86 changes: 86 additions & 0 deletions polars/polars-lazy/src/physical_plan/expressions/unique.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
use crate::physical_plan::state::ExecutionState;
use crate::prelude::*;
use polars_core::frame::groupby::GroupTuples;
use polars_core::prelude::*;
use std::sync::Arc;

pub struct UniqueExpr {
pub(crate) physical_expr: Arc<dyn PhysicalExpr>,
expr: Expr,
}

impl PhysicalExpr for UniqueExpr {
fn as_expression(&self) -> &Expr {
&self.expr
}

fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> Result<Series> {
let series = self.physical_expr.evaluate(df, state)?;
series.unique()
}

#[allow(clippy::ptr_arg)]
fn evaluate_on_groups<'a>(
&self,
df: &DataFrame,
groups: &'a GroupTuples,
state: &ExecutionState,
) -> Result<AggregationContext<'a>> {
let mut ac = self.physical_expr.evaluate_on_groups(df, groups, state)?;
let series = ac.flat().into_owned();

let groups = ac
.groups
.iter()
.map(|(_first, idx)| {
// Safety:
// Group tuples are always in bounds
let group =
unsafe { series.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize)) };

let unique_idx = group.arg_unique()?;

let new_idx: Vec<_> = unique_idx
.cont_slice()
.unwrap()
.iter()
.map(|&i| {
debug_assert!(idx.get(i as usize).is_some());
unsafe { *idx.get_unchecked(i as usize) }
})
.collect();
(new_idx[0], new_idx)
})
.collect();

ac.with_groups(groups);

Ok(ac)
}

fn to_field(&self, input_schema: &Schema) -> Result<Field> {
self.physical_expr.to_field(input_schema)
}

fn as_agg_expr(&self) -> Result<&dyn PhysicalAggregation> {
Ok(self)
}
}
impl PhysicalAggregation for UniqueExpr {
// As a final aggregation a Unique returns a list array.
fn aggregate(
&self,
df: &DataFrame,
groups: &GroupTuples,
state: &ExecutionState,
) -> Result<Option<Series>> {
let mut ac = self.physical_expr.evaluate_on_groups(df, groups, state)?;
let agg_s = ac.aggregated();
let agg_s = agg_s
.list()
.unwrap()
.apply_amortized(|s| s.as_ref().unique())
.into_series();
Ok(Some(agg_s))
}
}
30 changes: 30 additions & 0 deletions polars/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2438,3 +2438,33 @@ fn test_apply_flatten() -> Result<()> {

Ok(())
}

#[test]
fn test_agg_unique_first() -> Result<()> {
let df = df![
"g"=> [1, 1, 2, 2, 3, 4, 1],
"v"=> [1, 2, 2, 2, 3, 4, 1],
]?;

let out = df
.lazy()
.groupby_stable([col("g")])
.agg([
col("v").unique().first(),
col("v").unique().sort(false).first().alias("true_first"),
col("v").unique().list(),
])
.collect()?;

let a = out.column("v_first").unwrap();
let a = a.sum::<i32>().unwrap();
// can be both because unique does not guarantee order
assert!(a == 10 || a == 11);

let a = out.column("true_first").unwrap();
let a = a.sum::<i32>().unwrap();
// can be both because unique does not guarantee order
assert_eq!(a, 10);

Ok(())
}

0 comments on commit fc56f25

Please sign in to comment.