Skip to content

Commit

Permalink
fix take
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Feb 2, 2022
1 parent 0f8d526 commit 78b991b
Show file tree
Hide file tree
Showing 15 changed files with 353 additions and 193 deletions.
9 changes: 9 additions & 0 deletions polars/polars-core/src/frame/groupby/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,18 @@ impl GroupsIdx {
self.into_iter()
}

/// Borrows the `all` field as a slice of `Vec<u32>` index vectors.
///
/// NOTE(review): presumably one vector of row indices per group, parallel to
/// the slice returned by `first` — confirm against the struct definition.
pub fn all(&self) -> &[Vec<u32>] {
&self.all
}

/// Borrows the `first` field as a slice of `u32` values.
///
/// NOTE(review): presumably the first row index of every group — confirm
/// against the struct definition.
pub fn first(&self) -> &[u32] {
&self.first
}

/// Returns the number of entries in `self.first`.
///
/// NOTE(review): presumably this equals the number of groups, assuming `first`
/// and `all` always have the same length — confirm.
pub(crate) fn len(&self) -> usize {
self.first.len()
}

pub(crate) unsafe fn get_unchecked(&self, index: usize) -> BorrowIdxItem {
let first = *self.first.get_unchecked(index);
let all = self.all.get_unchecked(index);
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,7 @@ impl Expr {
pub fn arg_sort(self, reverse: bool) -> Self {
assert!(
!has_expr(&self, |e| matches!(e, Expr::Wildcard)),
"wildcard not supported in unique expr"
"wildcard not supported in argsort expr"
);
let options = FunctionOptions {
collect_groups: ApplyOptions::ApplyGroups,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl PhysicalAggregation for AggregationExpr {
Ok(opt_agg)
}
GroupByMethod::List => {
let agg = ac.aggregated().into_owned();
let agg = ac.aggregated();
Ok(rename_option_series(Some(agg), &new_name))
}
GroupByMethod::Groups => {
Expand Down
11 changes: 3 additions & 8 deletions polars/polars-lazy/src/physical_plan/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,7 @@ impl PhysicalExpr for ApplyExpr {
Ok(ac)
}
ApplyOptions::ApplyList => {
let s = self
.function
.call_udf(&mut [ac.aggregated().into_owned()])?;
let s = self.function.call_udf(&mut [ac.aggregated()])?;
ac.with_series(s, true);
Ok(ac)
}
Expand Down Expand Up @@ -188,10 +186,7 @@ impl PhysicalExpr for ApplyExpr {
Ok(ac)
}
ApplyOptions::ApplyList => {
let mut s = acs
.iter_mut()
.map(|ac| ac.aggregated().into_owned())
.collect::<Vec<_>>();
let mut s = acs.iter_mut().map(|ac| ac.aggregated()).collect::<Vec<_>>();
let s = self.function.call_udf(&mut s)?;
let mut ac = acs.pop().unwrap();
ac.with_update_groups(UpdateGroups::WithGroupsLen);
Expand All @@ -218,7 +213,7 @@ impl PhysicalAggregation for ApplyExpr {
state: &ExecutionState,
) -> Result<Option<Series>> {
let mut ac = self.evaluate_on_groups(df, groups, state)?;
let s = ac.aggregated().into_owned();
let s = ac.aggregated();
Ok(Some(s))
}
}
12 changes: 8 additions & 4 deletions polars/polars-lazy/src/physical_plan/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ impl PhysicalExpr for BinaryExpr {
// One of the two exprs is aggregated with flat aggregation, e.g. `e.min(), e.max(), e.first()`

// if the groups_len == df.len we can just apply all flat.
(AggState::AggregatedFlat(s), AggState::NotAggregated(_)) if s.len() != df.height() => {
(AggState::AggregatedFlat(s), AggState::NotAggregated(_) | AggState::Literal(_))
if s.len() != df.height() =>
{
// this is a flat series of len eq to group tuples
let l = ac_l.aggregated();
let l = l.as_ref();
Expand Down Expand Up @@ -146,7 +148,9 @@ impl PhysicalExpr for BinaryExpr {
Ok(ac_l)
}
// if the groups_len == df.len we can just apply all flat.
(AggState::NotAggregated(_), AggState::AggregatedFlat(s)) if s.len() != df.height() => {
(AggState::NotAggregated(_) | AggState::Literal(_), AggState::AggregatedFlat(s))
if s.len() != df.height() =>
{
// this is now a list
let l = ac_l.aggregated();
let l = l.list().unwrap();
Expand Down Expand Up @@ -189,8 +193,8 @@ impl PhysicalExpr for BinaryExpr {
ac_l.with_series(ca.into_series(), true);
Ok(ac_l)
}
(AggState::AggregatedList(_), AggState::NotAggregated(_))
| (AggState::NotAggregated(_), AggState::AggregatedList(_)) => {
(AggState::AggregatedList(_), AggState::NotAggregated(_) | AggState::Literal(_))
| (AggState::NotAggregated(_) | AggState::Literal(_), AggState::AggregatedList(_)) => {
ac_l.sort_by_groups();
ac_r.sort_by_groups();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ impl PhysicalExpr for LiteralExpr {
state: &ExecutionState,
) -> Result<AggregationContext<'a>> {
let s = self.evaluate(df, state)?;
Ok(AggregationContext::new(s, Cow::Borrowed(groups), false))
Ok(AggregationContext::from_literal(s, Cow::Borrowed(groups)))
}

fn to_field(&self, _input_schema: &Schema) -> Result<Field> {
Expand Down
84 changes: 43 additions & 41 deletions polars/polars-lazy/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub(crate) enum AggState {
AggregatedFlat(Series),
/// Not yet aggregated: `agg_list` still has to be called.
NotAggregated(Series),
None,
Literal(Series),
}

impl AggState {
Expand Down Expand Up @@ -73,7 +73,7 @@ pub(crate) enum UpdateGroups {

impl Default for AggState {
fn default() -> Self {
AggState::None
AggState::Literal(Series::default())
}
}

Expand All @@ -82,7 +82,7 @@ pub struct AggregationContext<'a> {
/// Can be in one of two states
/// 1. already aggregated as list
/// 2. flat (still needs the grouptuples to aggregate)
series: AggState,
state: AggState,
/// group tuples for AggState
groups: Cow<'a, GroupsProxy>,
/// if the group tuples are already used in a level above
Expand Down Expand Up @@ -210,20 +210,20 @@ impl<'a> AggregationContext<'a> {
}

pub(crate) fn series(&self) -> &Series {
match &self.series {
match &self.state {
AggState::NotAggregated(s)
| AggState::AggregatedFlat(s)
| AggState::AggregatedList(s) => s,
AggState::None => unreachable!(),
AggState::Literal(s) => s,
}
}

pub(crate) fn agg_state(&self) -> &AggState {
&self.series
&self.state
}

pub(crate) fn is_not_aggregated(&self) -> bool {
matches!(&self.series, AggState::NotAggregated(_))
matches!(&self.state, AggState::NotAggregated(_))
}

pub(crate) fn is_aggregated(&self) -> bool {
Expand All @@ -240,7 +240,7 @@ impl<'a> AggregationContext<'a> {
/// # Arguments
/// - `aggregated` sets whether the Series is a list due to aggregation (it could also be a list
/// because that is the column's dtype)
pub(crate) fn new(
fn new(
series: Series,
groups: Cow<'a, GroupsProxy>,
aggregated: bool,
Expand All @@ -258,7 +258,18 @@ impl<'a> AggregationContext<'a> {
};

Self {
series,
state: series,
groups,
sorted: false,
update_groups: UpdateGroups::No,
original_len: true,
all_unit_len: false,
}
}

fn from_literal(lit: Series, groups: Cow<'a, GroupsProxy>) -> AggregationContext<'a> {
Self {
state: AggState::Literal(lit),
groups,
sorted: false,
update_groups: UpdateGroups::No,
Expand Down Expand Up @@ -286,26 +297,26 @@ impl<'a> AggregationContext<'a> {
/// Calling this function will ensure both are sorted. This will be a no-op
/// if already aggregated.
pub(crate) fn sort_by_groups(&mut self) {
match &self.series {
match &self.state {
AggState::NotAggregated(s) => {
// We should not aggregate literals!!
if self.series.safe_to_agg(&self.groups) {
if self.state.safe_to_agg(&self.groups) {
let agg = s.agg_list(&self.groups).unwrap();
self.update_groups = UpdateGroups::WithGroupsLen;
self.series = AggState::AggregatedList(agg);
self.state = AggState::AggregatedList(agg);
}
}
AggState::AggregatedFlat(_) => {}
AggState::AggregatedList(_) => {}
AggState::None => {}
AggState::Literal(_) => {}
}
}

/// # Arguments
/// - `aggregated` sets whether the Series is a list due to aggregation (it could also be a list
/// because that is the column's dtype)
pub(crate) fn with_series(&mut self, series: Series, aggregated: bool) -> &mut Self {
self.series = match (aggregated, series.dtype()) {
self.state = match (aggregated, series.dtype()) {
(true, &DataType::List(_)) => {
assert_eq!(series.len(), self.groups.len());
AggState::AggregatedList(series)
Expand All @@ -314,7 +325,7 @@ impl<'a> AggregationContext<'a> {
_ => {
// already aggregated to sum, min, etc.: even if this series was flattened, it could never
// retrieve the length it had before grouping, so it stays in this state.
if let AggState::AggregatedFlat(_) = self.series {
if let AggState::AggregatedFlat(_) = self.state {
AggState::AggregatedFlat(series)
} else {
AggState::NotAggregated(series)
Expand All @@ -335,43 +346,34 @@ impl<'a> AggregationContext<'a> {
}

/// Get the aggregated version of the series.
pub(crate) fn aggregated(&mut self) -> Cow<'_, Series> {
pub(crate) fn aggregated(&mut self) -> Series {
// we clone, because we only want to call `self.groups()` if needed.
// self groups may instantiate new groups and thus can be expensive.
match self.series.clone() {
AggState::NotAggregated(mut s) => {
match self.state.clone() {
AggState::NotAggregated(s) => {
// The groups are determined lazily and in case of a flat/non-aggregated
// series we use the groups to aggregate the list
// because this is lazy, we must first update the groups
// by calling .groups()
self.groups();

// literal series
// the literal series needs to be expanded to the number of indices in the groups
if s.len() == 1
// or more then one group
&& (self.groups.len() > 1
// or single groups with more than one index
|| !self.groups.as_ref().is_empty()
&& self.groups.get(0).len() > 1)
{
// todo! optimize this, we don't have to call agg_list, create the list directly.
s = s.expand_at_index(0, self.groups.iter().map(|g| g.len()).sum())
};

let out = Cow::Owned(
s.agg_list(&self.groups)
.expect("should be able to aggregate this to list"),
);
let out = s
.agg_list(&self.groups)
.expect("should be able to aggregate this to list");

if !self.sorted {
self.sorted = true;
self.update_groups = UpdateGroups::WithGroupsLen;
};
out
}
AggState::AggregatedList(s) | AggState::AggregatedFlat(s) => Cow::Owned(s),
AggState::None => unreachable!(),
AggState::AggregatedList(s) | AggState::AggregatedFlat(s) => s,
AggState::Literal(s) => {
self.groups();
// todo! optimize this, we don't have to call agg_list, create the list directly.
let s = s.expand_at_index(0, self.groups.iter().map(|g| g.len()).sum());
s.agg_list(&self.groups).unwrap()
}
}
}

Expand All @@ -380,21 +382,21 @@ impl<'a> AggregationContext<'a> {
/// has filtered or sorted this, this information is in the
/// group tuples not the flattened series.
pub(crate) fn flat_naive(&self) -> Cow<'_, Series> {
match &self.series {
match &self.state {
AggState::NotAggregated(s) => Cow::Borrowed(s),
AggState::AggregatedList(s) => Cow::Owned(s.explode().unwrap()),
AggState::AggregatedFlat(s) => Cow::Borrowed(s),
AggState::None => unreachable!(),
AggState::Literal(s) => Cow::Borrowed(s),
}
}

/// Take the series.
pub(crate) fn take(&mut self) -> Series {
match std::mem::take(&mut self.series) {
match std::mem::take(&mut self.state) {
AggState::NotAggregated(s)
| AggState::AggregatedFlat(s)
| AggState::AggregatedList(s) => s,
AggState::None => panic!("implementation error"),
AggState::Literal(s) => s,
}
}
}
Expand Down
15 changes: 8 additions & 7 deletions polars/polars-lazy/src/physical_plan/expressions/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,20 @@ impl PhysicalExpr for SliceExpr {
GroupsProxy::Idx(groups) => {
let groups = groups
.iter()
.map(|(_, idx)| {
let (offset, len) = slice_offsets(self.offset, self.len, idx.len());
(offset as u32, idx[offset..offset + len].to_vec())
.map(|(first, idx)| {
let (offset, len) = slice_offsets(self.offset as i64, self.len, idx.len());
(first + offset as u32, idx[offset..offset + len].to_vec())
})
.collect();
GroupsProxy::Idx(groups)
}
GroupsProxy::Slice(groups) => {
let groups = groups
.iter()
.map(|&[_first, len]| {
let (offset, len) = slice_offsets(self.offset, self.len, len as usize);
[offset as u32, len as u32]
.map(|&[first, len]| {
let (offset, len) =
slice_offsets(self.offset as i64, self.len, len as usize);
[first + offset as u32, len as u32]
})
.collect_trusted();
GroupsProxy::Slice(groups)
Expand Down Expand Up @@ -76,7 +77,7 @@ impl PhysicalAggregation for SliceExpr {
state: &ExecutionState,
) -> Result<Option<Series>> {
let mut ac = self.evaluate_on_groups(df, groups, state)?;
let s = ac.aggregated().into_owned();
let s = ac.aggregated();
Ok(Some(s))
}
}

0 comments on commit 78b991b

Please sign in to comment.