Skip to content

Commit

Permalink
Lazy: fix apply for expr arguments that are filtered
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 31, 2021
1 parent 82ad6f9 commit 1edaf66
Show file tree
Hide file tree
Showing 16 changed files with 87 additions and 42 deletions.
7 changes: 3 additions & 4 deletions polars/polars-arrow/src/kernels/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use arrow::bitmap::MutableBitmap;
use arrow::buffer::Buffer;
use arrow::datatypes::{DataType, PhysicalType};
use arrow::types::NativeType;
use std::iter::FromIterator;
use std::sync::Arc;

/// # Safety
Expand Down Expand Up @@ -290,7 +291,6 @@ pub unsafe fn take_no_null_bool_opt_iter_unchecked<I: IntoIterator<Item = Option

/// # Safety
/// - no bounds checks
/// - iterator must be TrustedLen
#[inline]
pub unsafe fn take_no_null_utf8_iter_unchecked<I: IntoIterator<Item = usize>>(
arr: &LargeStringArray,
Expand All @@ -300,12 +300,11 @@ pub unsafe fn take_no_null_utf8_iter_unchecked<I: IntoIterator<Item = usize>>(
debug_assert!(idx < arr.len());
arr.value_unchecked(idx)
});
Arc::new(MutableUtf8Array::<i64>::from_trusted_len_values_iter_unchecked(iter).into())
Arc::new(MutableUtf8Array::<i64>::from_iter_values(iter).into())
}

/// # Safety
/// - no bounds checks
/// - iterator must be TrustedLen
#[inline]
pub unsafe fn take_utf8_iter_unchecked<I: IntoIterator<Item = usize>>(
arr: &LargeStringArray,
Expand All @@ -321,7 +320,7 @@ pub unsafe fn take_utf8_iter_unchecked<I: IntoIterator<Item = usize>>(
}
});

Arc::new(LargeStringArray::from_trusted_len_iter_unchecked(iter))
Arc::new(LargeStringArray::from_iter(iter))
}

/// # Safety
Expand Down
14 changes: 14 additions & 0 deletions polars/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ mod series_trait;
pub mod unstable;

use crate::chunked_array::ops::rolling_window::RollingOptions;
use crate::frame::groupby::GroupTuples;
#[cfg(feature = "rank")]
use crate::prelude::unique::rank::rank;
#[cfg(feature = "groupby_list")]
Expand Down Expand Up @@ -761,6 +762,19 @@ impl Series {
};
Ok(out)
}

/// Gather the values addressed by the group tuples, one group after another.
///
/// For every entry of `groups`, the indices stored in its second field are
/// converted to `usize` and forwarded, in group order, to `take_iter_unchecked`.
///
/// # Safety
/// Every index contained in `groups` has to be in bounds for `self`.
pub unsafe fn take_group_values(&self, groups: &GroupTuples) -> Series {
    // Lazily chain the per-group index lists into one flat index iterator.
    let mut group_indices = groups
        .iter()
        .flat_map(|group| group.1.iter().map(|&idx| idx as usize));
    self.take_iter_unchecked(&mut group_indices)
}
}

impl Deref for Series {
Expand Down
20 changes: 10 additions & 10 deletions polars/polars-lazy/src/physical_plan/expressions/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,23 +70,23 @@ impl PhysicalAggregation for AggregationExpr {

match self.agg_type {
GroupByMethod::Min => {
let agg_s = ac.flat().into_owned().agg_min(ac.groups());
let agg_s = ac.flat_naive().into_owned().agg_min(ac.groups());
Ok(rename_option_series(agg_s, &new_name))
}
GroupByMethod::Max => {
let agg_s = ac.flat().into_owned().agg_max(ac.groups());
let agg_s = ac.flat_naive().into_owned().agg_max(ac.groups());
Ok(rename_option_series(agg_s, &new_name))
}
GroupByMethod::Median => {
let agg_s = ac.flat().into_owned().agg_median(ac.groups());
let agg_s = ac.flat_naive().into_owned().agg_median(ac.groups());
Ok(rename_option_series(agg_s, &new_name))
}
GroupByMethod::Mean => {
let agg_s = ac.flat().into_owned().agg_mean(ac.groups());
let agg_s = ac.flat_naive().into_owned().agg_mean(ac.groups());
Ok(rename_option_series(agg_s, &new_name))
}
GroupByMethod::Sum => {
let agg_s = ac.flat().into_owned().agg_sum(ac.groups());
let agg_s = ac.flat_naive().into_owned().agg_sum(ac.groups());
Ok(rename_option_series(agg_s, &new_name))
}
GroupByMethod::Count => {
Expand All @@ -96,17 +96,17 @@ impl PhysicalAggregation for AggregationExpr {
Ok(Some(ca.into_inner().into_series()))
}
GroupByMethod::First => {
let mut agg_s = ac.flat().into_owned().agg_first(ac.groups());
let mut agg_s = ac.flat_naive().into_owned().agg_first(ac.groups());
agg_s.rename(&new_name);
Ok(Some(agg_s))
}
GroupByMethod::Last => {
let mut agg_s = ac.flat().into_owned().agg_last(ac.groups());
let mut agg_s = ac.flat_naive().into_owned().agg_last(ac.groups());
agg_s.rename(&new_name);
Ok(Some(agg_s))
}
GroupByMethod::NUnique => {
let opt_agg = ac.flat().into_owned().agg_n_unique(ac.groups());
let opt_agg = ac.flat_naive().into_owned().agg_n_unique(ac.groups());
let opt_agg = opt_agg.map(|mut agg| {
agg.rename(&new_name);
agg.into_series()
Expand All @@ -131,11 +131,11 @@ impl PhysicalAggregation for AggregationExpr {
Ok(Some(column.into_series()))
}
GroupByMethod::Std => {
let agg_s = ac.flat().into_owned().agg_std(ac.groups());
let agg_s = ac.flat_naive().into_owned().agg_std(ac.groups());
Ok(rename_option_series(agg_s, &new_name))
}
GroupByMethod::Var => {
let agg_s = ac.flat().into_owned().agg_var(ac.groups());
let agg_s = ac.flat_naive().into_owned().agg_var(ac.groups());
Ok(rename_option_series(agg_s, &new_name))
}
GroupByMethod::Quantile(_, _) => {
Expand Down
14 changes: 7 additions & 7 deletions polars/polars-lazy/src/physical_plan/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use crate::prelude::*;
use polars_core::frame::groupby::GroupTuples;
use polars_core::prelude::*;
use rayon::prelude::*;
use std::borrow::Cow;
use std::convert::TryFrom;
use std::sync::Arc;

Expand Down Expand Up @@ -91,7 +90,9 @@ impl PhysicalExpr for ApplyExpr {
Ok(ac)
}
ApplyOptions::ApplyFlat => {
let s = self.function.call_udf(&mut [ac.flat().into_owned()])?;
let s = self
.function
.call_udf(&mut [ac.flat_naive().into_owned()])?;
if ac.is_aggregated() {
ac.with_update_groups(UpdateGroups::WithGroupsLen);
}
Expand Down Expand Up @@ -148,7 +149,7 @@ impl PhysicalExpr for ApplyExpr {
ApplyOptions::ApplyFlat => {
let mut s = acs
.iter()
.map(|ac| ac.flat().into_owned())
.map(|ac| ac.flat_naive().into_owned())
.collect::<Vec<_>>();

let s = self.function.call_udf(&mut s)?;
Expand Down Expand Up @@ -216,7 +217,7 @@ impl PhysicalAggregation for ApplyExpr {
// if it's flat, we just apply and return
// if not flat, the flattening sorts by group, so we must create new group tuples
// and again aggregate.
let out = self.function.call_udf(&mut [ac.flat().into_owned()]);
let out = self.function.call_udf(&mut [ac.flat_naive().into_owned()]);

if ac.is_not_aggregated() || !matches!(ac.series().dtype(), DataType::List(_)) {
out.map(Some)
Expand Down Expand Up @@ -256,7 +257,7 @@ impl PhysicalAggregation for ApplyExpr {
ac.aggregated()
// this branch we see the argument as a constant, that will be applied per group
} else {
Cow::Borrowed(ac.series())
ac.flat_corrected()
};
(s, not_aggregated_len, original_len)
})
Expand All @@ -280,7 +281,6 @@ impl PhysicalAggregation for ApplyExpr {
as Box<dyn Iterator<Item = Option<Series>>>
// this branch we repeat the argument per group
} else {
dbg!("here");
let s = s.clone().into_owned();
Box::new(std::iter::repeat(Some(s)))
}
Expand Down Expand Up @@ -323,7 +323,7 @@ impl PhysicalAggregation for ApplyExpr {
"flat apply on any expression that is already \
in aggregated state is not yet suported"
);
ac.flat().into_owned()
ac.flat_naive().into_owned()
})
.collect::<Vec<_>>();

Expand Down
6 changes: 5 additions & 1 deletion polars/polars-lazy/src/physical_plan/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,11 @@ impl PhysicalExpr for BinaryExpr {
// Both are either a flat series or aggregated into a list
// so we can flatten the Series and apply the operators
_ => {
let out = apply_operator(ac_l.flat().as_ref(), ac_r.flat().as_ref(), self.op)?;
let out = apply_operator(
ac_l.flat_naive().as_ref(),
ac_r.flat_naive().as_ref(),
self.op,
)?;
ac_l.combine_groups(ac_r).with_series(out, false);
Ok(ac_l)
}
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/physical_plan/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl PhysicalExpr for CastExpr {
state: &ExecutionState,
) -> Result<AggregationContext<'a>> {
let mut ac = self.input.evaluate_on_groups(df, groups, state)?;
let s = ac.flat();
let s = ac.flat_naive();
let s = self.finish(s.as_ref())?;
ac.with_series(s, false);
Ok(ac)
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/physical_plan/expressions/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl PhysicalExpr for FilterExpr {
let mut ac_s = self.input.evaluate_on_groups(df, groups, state)?;
let ac_predicate = self.by.evaluate_on_groups(df, groups, state)?;
let groups = ac_s.groups();
let predicate_s = ac_predicate.flat();
let predicate_s = ac_predicate.flat_naive();
let predicate = predicate_s.bool()?;

let groups = POOL.install(|| {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl PhysicalExpr for IsNotNullExpr {
state: &ExecutionState,
) -> Result<AggregationContext<'a>> {
let mut ac = self.physical_expr.evaluate_on_groups(df, groups, state)?;
let s = ac.flat();
let s = ac.flat_naive();
let s = s.is_not_null().into_series();
ac.with_series(s, false);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl PhysicalExpr for IsNullExpr {
state: &ExecutionState,
) -> Result<AggregationContext<'a>> {
let mut ac = self.physical_expr.evaluate_on_groups(df, groups, state)?;
let s = ac.flat();
let s = ac.flat_naive();
let s = s.is_null().into_series();
ac.with_series(s, false);

Expand Down
23 changes: 21 additions & 2 deletions polars/polars-lazy/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ impl<'a> AggregationContext<'a> {
/// Update the group tuples
pub(crate) fn with_groups(&mut self, groups: GroupTuples) -> &mut Self {
// In case of new groups, a series always needs to be flattened
self.with_series(self.flat().into_owned(), false);
self.with_series(self.flat_naive().into_owned(), false);
self.groups = Cow::Owned(groups);
// make sure that previous setting is not used
self.update_groups = UpdateGroups::No;
Expand Down Expand Up @@ -286,7 +286,7 @@ impl<'a> AggregationContext<'a> {
}

/// Get the not-aggregated version of the series.
pub(crate) fn flat(&self) -> Cow<'_, Series> {
pub(crate) fn flat_naive(&self) -> Cow<'_, Series> {
match &self.series {
AggState::NotAggregated(s) => Cow::Borrowed(s),
AggState::AggregatedList(s) => Cow::Owned(s.explode().unwrap()),
Expand All @@ -295,6 +295,25 @@ impl<'a> AggregationContext<'a> {
}
}

/// Get the not-aggregated version of the series.
/// This corrects for filtered data. Note that the group tuples
/// cannot be used on the returned Series.
///
/// When `is_original_len()` reports `false`, the values are gathered through
/// the current group tuples with `take_group_values` (an aggregated list is
/// exploded first) and an owned Series is returned. In every other case this
/// delegates to `flat_naive`.
///
/// # Panics
/// Panics if exploding an aggregated list fails (`explode().unwrap()`).
pub(crate) fn flat_corrected(&mut self) -> Cow<'_, Series> {
match (&self.series, self.is_original_len()) {
(AggState::NotAggregated(s), false) => {
// NOTE(review): the clone presumably ends the borrow of `self.series`
// before `self.groups()` is called — confirm against `groups()`'s receiver.
let s = s.clone();
let groups = self.groups();
// SAFETY: NOTE(review) — `take_group_values` requires all group indices to
// be in bounds for `s`; assumes the groups of this context satisfy that — confirm.
Cow::Owned(unsafe { s.take_group_values(groups) })
}
(AggState::AggregatedList(s), false) => {
// Flatten the list column before gathering by group.
let s = s.explode().unwrap();
let groups = self.groups();
// SAFETY: NOTE(review) — same in-bounds requirement as above, here against
// the exploded Series — confirm the groups index into the exploded values.
Cow::Owned(unsafe { s.take_group_values(groups) })
}
// Every remaining state/length combination: no correction, use the naive flattening.
_ => self.flat_naive(),
}
}

/// Get the length of the Series when it is not aggregated
pub(crate) fn len(&self) -> usize {
match &self.series {
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/physical_plan/expressions/not.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl PhysicalExpr for NotExpr {
state: &ExecutionState,
) -> Result<AggregationContext<'a>> {
let mut ac = self.0.evaluate_on_groups(df, groups, state)?;
let s = ac.flat().into_owned();
let s = ac.flat_naive().into_owned();
ac.with_series(self.finish(s)?, false);

Ok(ac)
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/physical_plan/expressions/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl PhysicalExpr for SortExpr {
state: &ExecutionState,
) -> Result<AggregationContext<'a>> {
let mut ac = self.physical_expr.evaluate_on_groups(df, groups, state)?;
let series = ac.flat().into_owned();
let series = ac.flat_naive().into_owned();

let groups = ac
.groups()
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-lazy/src/physical_plan/expressions/sortby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ impl PhysicalExpr for SortByExpr {

let groups = if self.by.len() == 1 {
let mut ac_sort_by = self.by[0].evaluate_on_groups(df, groups, state)?;
let sort_by_s = ac_sort_by.flat().into_owned();
let sort_by_s = ac_sort_by.flat_naive().into_owned();
let groups = ac_sort_by.groups();

groups
Expand Down Expand Up @@ -112,7 +112,7 @@ impl PhysicalExpr for SortByExpr {
.collect::<Result<Vec<_>>>()?;
let sort_by_s = ac_sort_by
.iter()
.map(|s| s.flat().into_owned())
.map(|s| s.flat_naive().into_owned())
.collect::<Vec<_>>();
let groups = ac_sort_by[0].groups();

Expand Down
6 changes: 4 additions & 2 deletions polars/polars-lazy/src/physical_plan/expressions/ternary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl PhysicalExpr for TernaryExpr {
let mut ac_truthy = ac_truthy?;
let mut ac_falsy = ac_falsy?;

let mask_s = ac_mask.flat();
let mask_s = ac_mask.flat_naive();

assert!(
!(mask_s.len() != required_height),
Expand Down Expand Up @@ -182,7 +182,9 @@ The predicate produced {} values. Where the original DataFrame has {} values",
// so we can flatten the Series and apply the operators
_ => {
let mask = mask_s.bool()?;
let out = ac_truthy.flat().zip_with(mask, ac_falsy.flat().as_ref())?;
let out = ac_truthy
.flat_naive()
.zip_with(mask, ac_falsy.flat_naive().as_ref())?;

assert!(!(out.len() != required_height), "The output of the `when -> then -> otherwise-expr` is of a different length than the groups.\
The expr produced {} values. Where the original DataFrame has {} values",
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/physical_plan/expressions/unique.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ impl PhysicalExpr for UniqueExpr {
state: &ExecutionState,
) -> Result<AggregationContext<'a>> {
let mut ac = self.physical_expr.evaluate_on_groups(df, groups, state)?;
let series = ac.flat().into_owned();
let series = ac.flat_naive().into_owned();

let groups = ac
.groups
Expand Down
21 changes: 14 additions & 7 deletions polars/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2507,13 +2507,20 @@ fn test_parquet_exec() -> Result<()> {
fn test_is_in() -> Result<()> {
let df = fruits_cars();

// TODO! fix this
// // this will be executed by apply (still incorrect)
// let out = df
// .lazy()
// .groupby_stable([col("fruits")])
// .agg([col("cars").is_in(col("cars").filter(col("cars").eq(lit("beetle"))))])
// .collect()?;
// this will be executed by apply
let out = df
.clone()
.lazy()
.groupby_stable([col("fruits")])
.agg([col("cars").is_in(col("cars").filter(col("cars").eq(lit("beetle"))))])
.collect()?;
let out = out.column("cars").unwrap();
let out = out.explode()?;
let out = out.bool().unwrap();
assert_eq!(
Vec::from(out),
&[Some(true), Some(false), Some(true), Some(true), Some(true)]
);

// this will be executed by map
let out = df
Expand Down

0 comments on commit 1edaf66

Please sign in to comment.