Skip to content

Commit

Permalink
ensure we always work on groups in the groupby context
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 31, 2021
1 parent b845c6f commit 4848b99
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 91 deletions.
16 changes: 0 additions & 16 deletions polars/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ mod series_trait;
pub mod unstable;

use crate::chunked_array::ops::rolling_window::RollingOptions;
use crate::frame::groupby::GroupTuples;
#[cfg(feature = "rank")]
use crate::prelude::unique::rank::rank;
#[cfg(feature = "groupby_list")]
Expand Down Expand Up @@ -762,21 +761,6 @@ impl Series {
};
Ok(out)
}

/// Take the group tuples.
///
/// # Safety
/// - Group tuples have to be in bounds.
pub unsafe fn take_group_values(&self, groups: &GroupTuples) -> Series {
let len = groups.iter().map(|g| g.1.len()).sum::<usize>();
self.take_iter_unchecked(
&mut groups
.iter()
.map(|g| g.1.iter().map(|idx| *idx as usize))
.flatten()
.trust_my_length(len),
)
}
}

impl Deref for Series {
Expand Down
49 changes: 8 additions & 41 deletions polars/polars-lazy/src/physical_plan/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use crate::prelude::*;
use polars_core::frame::groupby::GroupTuples;
use polars_core::prelude::*;
use rayon::prelude::*;
use std::convert::TryFrom;
use std::sync::Arc;

pub struct ApplyExpr {
Expand Down Expand Up @@ -239,51 +238,19 @@ impl PhysicalAggregation for ApplyExpr {
ApplyOptions::ApplyGroups => {
let mut container = vec![Default::default(); acs.len()];
let name = acs[0].series().name().to_string();
let first_len = acs[0].len();

// the arguments of an apply can be a group, but can also be the result of a separate aggregation
// in the last case we may not aggregate, but see that Series as its own input.
// this part we make sure that we get owned series in the proper state (aggregated or not aggregated)
// so that we can make iterators from them next.
let owned_series = acs
.iter_mut()
.map(|ac| {
let not_aggregated_len = ac.len();
let original_len = ac.is_original_len();

// this branch we see the argument per group, so we must aggregate
// every group will have a different argument
let s = if not_aggregated_len == first_len && original_len {
ac.aggregated()
// this branch we see the argument as a constant, that will be applied per group
} else {
ac.flat_corrected()
};
(s, not_aggregated_len, original_len)
})
.collect::<Vec<_>>();
// Don't ever try to be smart here.
// Every argument needs to be aggregated; period.
// We only work on groups in the groupby context.
// If the argument is a literal use `map`
let owned_series = acs.iter_mut().map(|ac| ac.aggregated()).collect::<Vec<_>>();

// now we make the iterators
let mut iters = owned_series
.iter()
.map(|(s, not_aggregated_len, original_len)| {
// this branch we see the arguments per group. every group has a different argument
if *not_aggregated_len == first_len && *original_len {
let ca = s.list().unwrap();
Box::new(
ca.downcast_iter()
.map(|arr| arr.iter())
.flatten()
.map(|arr| {
arr.map(|arr| Series::try_from(("", arr)).unwrap())
}),
)
as Box<dyn Iterator<Item = Option<Series>>>
// this branch we repeat the argument per group
} else {
let s = s.clone().into_owned();
Box::new(std::iter::repeat(Some(s)))
}
.map(|s| {
let ca = s.list().unwrap();
ca.into_iter()
})
.collect::<Vec<_>>();

Expand Down
38 changes: 4 additions & 34 deletions polars/polars-lazy/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ pub(crate) mod window;

use crate::physical_plan::state::ExecutionState;
use crate::prelude::*;
use polars_arrow::array::ValueSize;
use polars_core::frame::groupby::GroupTuples;
use polars_core::prelude::*;
use polars_io::PhysicalIoExpr;
Expand Down Expand Up @@ -131,7 +130,7 @@ impl<'a> AggregationContext<'a> {
match (
&self.groups,
self.sorted,
self.original_len,
self.is_original_len(),
&other.groups,
other.sorted,
other.original_len,
Expand Down Expand Up @@ -286,6 +285,9 @@ impl<'a> AggregationContext<'a> {
}

/// Get the not-aggregated version of the series.
/// Note that we call it naive, because if a previous expr
/// has filtered or sorted this, this information is in the
/// group tuples not the flattened series.
pub(crate) fn flat_naive(&self) -> Cow<'_, Series> {
match &self.series {
AggState::NotAggregated(s) => Cow::Borrowed(s),
Expand All @@ -295,38 +297,6 @@ impl<'a> AggregationContext<'a> {
}
}

/// Get the not-aggregated version of the series.
/// This corrects for filtered data. Note that the group tuples
/// can not be used on this Series.
pub(crate) fn flat_corrected(&mut self) -> Cow<'_, Series> {
match (&self.series, self.is_original_len()) {
(AggState::NotAggregated(s), false) => {
let s = s.clone();
let groups = self.groups();
Cow::Owned(unsafe { s.take_group_values(groups) })
}
(AggState::AggregatedList(s), false) => {
let s = s.explode().unwrap();
let groups = self.groups();
Cow::Owned(unsafe { s.take_group_values(groups) })
}
_ => self.flat_naive(),
}
}

/// Get the length of the Series when it is not aggregated
pub(crate) fn len(&self) -> usize {
match &self.series {
AggState::NotAggregated(s) => s.len(),
AggState::AggregatedFlat(s) => s.len(),
AggState::AggregatedList(s) => {
let list = s.list().unwrap();
list.get_values_size()
}
AggState::None => unreachable!(),
}
}

/// Take the series.
pub(crate) fn take(&mut self) -> Series {
match std::mem::take(&mut self.series) {
Expand Down

0 comments on commit 4848b99

Please sign in to comment.