feat: support negative indices in gather in group_by context #13373

Merged · 2 commits · Jan 2, 2024
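This PR enables negative indices for `gather` inside a `group_by` aggregation; previously such indices triggered the "negative indexing not yet supported in group-by context" warning (see the removed code below). As a rough illustration of the user-facing behavior — a minimal sketch, not code from this PR, assuming a recent polars release with the `lazy` feature enabled:

use polars::prelude::*;

fn main() -> PolarsResult<()> {
    let df = df!(
        "group" => ["a", "a", "b", "b", "b"],
        "value" => [1, 2, 3, 4, 5],
    )?;

    // gather(-1) now selects the last element of each group in the
    // group_by context instead of warning and failing.
    let out = df
        .lazy()
        .group_by([col("group")])
        .agg([col("value").gather(lit(-1))])
        .collect()?;

    println!("{out}");
    Ok(())
}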
8 changes: 8 additions & 0 deletions crates/polars-lazy/src/physical_plan/expressions/mod.rs
@@ -133,6 +133,14 @@ pub struct AggregationContext<'a> {
}

impl<'a> AggregationContext<'a> {
pub(crate) fn dtype(&self) -> DataType {
match &self.state {
AggState::Literal(s) => s.dtype().clone(),
AggState::AggregatedList(s) => s.list().unwrap().inner_dtype(),
AggState::AggregatedScalar(s) => s.dtype().clone(),
AggState::NotAggregated(s) => s.dtype().clone(),
}
}
pub(crate) fn groups(&mut self) -> &Cow<'a, GroupsProxy> {
match self.update_groups {
UpdateGroups::No => {},
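The new `dtype` helper returns the element dtype for every aggregation state, unwrapping the inner dtype for aggregated lists; the negative-index path in take.rs below uses it to construct a list builder of the correct type.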
312 changes: 182 additions & 130 deletions crates/polars-lazy/src/physical_plan/expressions/take.rs
@@ -1,10 +1,11 @@
use std::sync::Arc;

use arrow::legacy::utils::CustomIterTools;
use polars_core::chunked_array::builder::get_list_builder;
use polars_core::frame::group_by::GroupsProxy;
use polars_core::prelude::*;
use polars_core::utils::NoNull;
use polars_ops::prelude::convert_to_unsigned_index;
use polars_ops::prelude::{convert_to_unsigned_index, is_positive_idx_uncertain};

use crate::physical_plan::state::ExecutionState;
use crate::prelude::*;
@@ -16,23 +17,6 @@ pub struct TakeExpr {
pub(crate) returns_scalar: bool,
}

impl TakeExpr {
fn finish(
&self,
df: &DataFrame,
state: &ExecutionState,
series: Series,
) -> PolarsResult<Series> {
let idx = self.idx.evaluate(df, state)?;
let idx = convert_to_unsigned_index(&idx, series.len())?;
series.take(&idx)
}

fn oob_err(&self) -> PolarsResult<()> {
polars_bail!(expr = self.expr, OutOfBounds: "index out of bounds");
}
}

impl PhysicalExpr for TakeExpr {
fn as_expression(&self) -> Option<&Expr> {
Some(&self.expr)
@@ -52,63 +36,24 @@ impl PhysicalExpr for TakeExpr {
let mut ac = self.phys_expr.evaluate_on_groups(df, groups, state)?;
let mut idx = self.idx.evaluate_on_groups(df, groups, state)?;

let s_idx = idx.series();
match s_idx.dtype() {
DataType::List(inner) => {
polars_ensure!(inner.is_integer(), InvalidOperation: "expected numeric dtype as index, got {:?}", inner)
},
dt if dt.is_integer() => {
// Unsigned integers will fall through and will use faster paths.
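// If we cannot rule out negative indices, convert them per group via
// `process_negative_indices_agg` instead.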
if !is_positive_idx_uncertain(s_idx) {
return self.process_negative_indices_agg(ac, idx, groups);
}
},
dt => polars_bail!(InvalidOperation: "expected numeric dtype as index, got {:?}", dt),
}

let idx = match idx.state {
AggState::AggregatedScalar(s) => {
let idx = s.cast(&IDX_DTYPE)?;
if s.null_count() != idx.null_count() {
polars_warn!("negative indexing not yet supported in group-by context")
}
let idx = idx.idx().unwrap();

// The indexes are AggregatedScalar, meaning they are single values, each pointing into
// a group. If we zip this with the first of each group -> `idx + first` then we can
// simply use a take operation on the whole array instead of per group.

// The groups may be scattered all over the place, so we sort by group.
ac.sort_by_groups();

// A previous aggregation may have updated the groups.
let groups = ac.groups();

// Determine the gather indices.
let idx: IdxCa = match groups.as_ref() {
GroupsProxy::Idx(groups) => {
if groups.all().iter().zip(idx).any(|(g, idx)| match idx {
None => true,
Some(idx) => idx >= g.len() as IdxSize,
}) {
self.oob_err()?;
}

idx.into_iter()
.zip(groups.first().iter())
.map(|(idx, first)| idx.map(|idx| idx + first))
.collect_trusted()
},
GroupsProxy::Slice { groups, .. } => {
if groups.iter().zip(idx).any(|(g, idx)| match idx {
None => true,
Some(idx) => idx >= g[1],
}) {
self.oob_err()?;
}

idx.into_iter()
.zip(groups.iter())
.map(|(idx, g)| idx.map(|idx| idx + g[0]))
.collect_trusted()
},
};
let taken = ac.flat_naive().take(&idx)?;

let taken = if self.returns_scalar {
taken
} else {
taken.as_list().into_series()
};

ac.with_series(taken, true, Some(&self.expr))?;
return Ok(ac);
return self.process_positive_indices_agg_scalar(ac, idx.idx().unwrap());
},
AggState::AggregatedList(s) => {
polars_ensure!(!self.returns_scalar, ComputeError: "expected single index");
@@ -122,64 +67,7 @@ impl PhysicalExpr for TakeExpr {
},
AggState::Literal(s) => {
let idx = s.cast(&IDX_DTYPE)?;
if s.null_count() != idx.null_count() {
polars_warn!("negative indexing not yet supported in group-by context")
}
let idx = idx.idx().unwrap();

return if idx.len() == 1 {
match idx.get(0) {
None => polars_bail!(ComputeError: "cannot take by a null"),
Some(idx) => {
if idx != 0 {
// We must make sure that the column we take from is sorted by
// groups otherwise we might point into the wrong group.
ac.sort_by_groups()
}
// Make sure that we look at the updated groups.
let groups = ac.groups();

// We offset each group's `first` by `idx`.
let idx: NoNull<IdxCa> = match groups.as_ref() {
GroupsProxy::Idx(groups) => {
if groups.all().iter().any(|g| idx >= g.len() as IdxSize) {
self.oob_err()?;
}

groups.first().iter().map(|f| *f + idx).collect_trusted()
},
GroupsProxy::Slice { groups, .. } => {
if groups.iter().any(|g| idx >= g[1]) {
self.oob_err()?;
}

groups.iter().map(|g| g[0] + idx).collect_trusted()
},
};
let taken = ac.flat_naive().take(&idx.into_inner())?;

let taken = if self.returns_scalar {
taken
} else {
taken.as_list().into_series()
};

ac.with_series(taken, true, Some(&self.expr))?;
ac.with_update_groups(UpdateGroups::WithGroupsLen);
Ok(ac)
},
}
} else {
let out = ac
.aggregated()
.list()
.unwrap()
.try_apply_amortized(|s| s.as_ref().take(idx))?;

ac.with_series(out.into_series(), true, Some(&self.expr))?;
ac.with_update_groups(UpdateGroups::WithGroupsLen);
Ok(ac)
};
return self.process_positive_indices_agg_literal(ac, idx.idx().unwrap());
},
};

@@ -207,3 +95,167 @@ impl PhysicalExpr for TakeExpr {
self.phys_expr.to_field(input_schema)
}
}

impl TakeExpr {
fn finish(
&self,
df: &DataFrame,
state: &ExecutionState,
series: Series,
) -> PolarsResult<Series> {
let idx = self.idx.evaluate(df, state)?;
let idx = convert_to_unsigned_index(&idx, series.len())?;
series.take(&idx)
}

fn oob_err(&self) -> PolarsResult<()> {
polars_bail!(expr = self.expr, OutOfBounds: "index out of bounds");
}

fn process_positive_indices_agg_scalar<'b>(
&self,
mut ac: AggregationContext<'b>,
idx: &IdxCa,
) -> PolarsResult<AggregationContext<'b>> {
// The indexes are AggregatedScalar, meaning they are single values, each pointing into
// a group. If we zip this with the first of each group -> `idx + first` then we can
// simply use a take operation on the whole array instead of per group.
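// For example, with group starts [0, 3, 7] and per-group indices [2, 0, 1],
// the global gather indices become [0 + 2, 3 + 0, 7 + 1] = [2, 3, 8].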

// The groups may be scattered all over the place, so we sort by group.
ac.sort_by_groups();

// A previous aggregation may have updated the groups.
let groups = ac.groups();

// Determine the gather indices.
let idx: IdxCa = match groups.as_ref() {
GroupsProxy::Idx(groups) => {
if groups.all().iter().zip(idx).any(|(g, idx)| match idx {
None => true,
Some(idx) => idx >= g.len() as IdxSize,
}) {
self.oob_err()?;
}

idx.into_iter()
.zip(groups.first().iter())
.map(|(idx, first)| idx.map(|idx| idx + first))
.collect_trusted()
},
GroupsProxy::Slice { groups, .. } => {
if groups.iter().zip(idx).any(|(g, idx)| match idx {
None => true,
Some(idx) => idx >= g[1],
}) {
self.oob_err()?;
}

idx.into_iter()
.zip(groups.iter())
.map(|(idx, g)| idx.map(|idx| idx + g[0]))
.collect_trusted()
},
};

let taken = ac.flat_naive().take(&idx)?;
let taken = if self.returns_scalar {
taken
} else {
taken.as_list().into_series()
};

ac.with_series(taken, true, Some(&self.expr))?;
Ok(ac)
}

fn process_positive_indices_agg_literal<'b>(
&self,
mut ac: AggregationContext<'b>,
idx: &IdxCa,
) -> PolarsResult<AggregationContext<'b>> {
if idx.len() == 1 {
match idx.get(0) {
None => polars_bail!(ComputeError: "cannot take by a null"),
Some(idx) => {
if idx != 0 {
// We must make sure that the column we take from is sorted by
// groups otherwise we might point into the wrong group.
ac.sort_by_groups()
}
// Make sure that we look at the updated groups.
let groups = ac.groups();

// We offset each group's `first` by `idx`.
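// E.g. `idx = 1` with group starts [0, 3, 7] yields gather indices [1, 4, 8].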
let idx: NoNull<IdxCa> = match groups.as_ref() {
GroupsProxy::Idx(groups) => {
if groups.all().iter().any(|g| idx >= g.len() as IdxSize) {
self.oob_err()?;
}

groups.first().iter().map(|f| *f + idx).collect_trusted()
},
GroupsProxy::Slice { groups, .. } => {
if groups.iter().any(|g| idx >= g[1]) {
self.oob_err()?;
}

groups.iter().map(|g| g[0] + idx).collect_trusted()
},
};
let taken = ac.flat_naive().take(&idx.into_inner())?;

let taken = if self.returns_scalar {
taken
} else {
taken.as_list().into_series()
};

ac.with_series(taken, true, Some(&self.expr))?;
ac.with_update_groups(UpdateGroups::WithGroupsLen);
Ok(ac)
},
}
} else {
let out = ac
.aggregated()
.list()
.unwrap()
.try_apply_amortized(|s| s.as_ref().take(idx))?;

ac.with_series(out.into_series(), true, Some(&self.expr))?;
ac.with_update_groups(UpdateGroups::WithGroupsLen);
Ok(ac)
}
}

fn process_negative_indices_agg<'b>(
&self,
mut ac: AggregationContext<'b>,
mut idx: AggregationContext<'b>,
groups: &'b GroupsProxy,
) -> PolarsResult<AggregationContext<'b>> {
let mut builder = get_list_builder(
&ac.dtype(),
idx.series().len(),
groups.len(),
ac.series().name(),
)?;

unsafe {
let iter = ac.iter_groups(false).zip(idx.iter_groups(false));
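// A group missing on either side (values or indices) produces a null entry.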
for (s, idx) in iter {
match (s, idx) {
(Some(s), Some(idx)) => {
let idx = convert_to_unsigned_index(idx.as_ref(), s.as_ref().len())?;
let out = s.as_ref().take(&idx)?;
builder.append_series(&out)?;
},
_ => builder.append_null(),
};
}
let out = builder.finish().into_series();
ac.with_agg_state(AggState::AggregatedList(out));
}
Ok(ac)
}
}
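For reference, here is a minimal sketch (an assumption, not code from this PR) of the per-group index normalization that `convert_to_unsigned_index` is expected to perform: a negative index counts from the end of the group, and anything still outside the group is out of bounds. The helper name `normalize_index` is hypothetical; the real function operates on a Series and returns an IdxCa.

// Hypothetical scalar version of the per-group index conversion.
fn normalize_index(idx: i64, group_len: usize) -> Option<usize> {
    if idx < 0 {
        // -1 maps to group_len - 1, -group_len maps to 0.
        let wrapped = idx + group_len as i64;
        (wrapped >= 0).then(|| wrapped as usize)
    } else if (idx as usize) < group_len {
        Some(idx as usize)
    } else {
        None // still out of bounds; the caller raises an error
    }
}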