From aed83b3abf41615afeece431864bc9d36d97125b Mon Sep 17 00:00:00 2001 From: ritchie Date: Tue, 2 Jan 2024 12:28:29 +0100 Subject: [PATCH 1/2] feat: support negative 'gather' in 'group_by' context --- .../src/physical_plan/expressions/mod.rs | 8 + .../src/physical_plan/expressions/take.rs | 312 ++++++++++-------- crates/polars-ops/src/series/ops/index.rs | 59 +++- .../tests/unit/operations/test_gather.py | 9 +- 4 files changed, 250 insertions(+), 138 deletions(-) diff --git a/crates/polars-lazy/src/physical_plan/expressions/mod.rs b/crates/polars-lazy/src/physical_plan/expressions/mod.rs index fcfbbe424c8f..e13c81d01727 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/mod.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/mod.rs @@ -133,6 +133,14 @@ pub struct AggregationContext<'a> { } impl<'a> AggregationContext<'a> { + pub(crate) fn dtype(&self) -> DataType { + match &self.state { + AggState::Literal(s) => s.dtype().clone(), + AggState::AggregatedList(s) => s.list().unwrap().inner_dtype(), + AggState::AggregatedScalar(s) => s.dtype().clone(), + AggState::NotAggregated(s) => s.dtype().clone(), + } + } pub(crate) fn groups(&mut self) -> &Cow<'a, GroupsProxy> { match self.update_groups { UpdateGroups::No => {}, diff --git a/crates/polars-lazy/src/physical_plan/expressions/take.rs b/crates/polars-lazy/src/physical_plan/expressions/take.rs index 9507195ab7fe..b6b20ff5830e 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/take.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/take.rs @@ -1,10 +1,11 @@ use std::sync::Arc; use arrow::legacy::utils::CustomIterTools; +use polars_core::chunked_array::builder::get_list_builder; use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_core::utils::NoNull; -use polars_ops::prelude::convert_to_unsigned_index; +use polars_ops::prelude::{convert_to_unsigned_index, is_positive_idx_uncertain}; use crate::physical_plan::state::ExecutionState; use crate::prelude::*; @@ -16,23 +17,6 @@ pub struct TakeExpr { pub(crate) returns_scalar: bool, } -impl TakeExpr { - fn finish( - &self, - df: &DataFrame, - state: &ExecutionState, - series: Series, - ) -> PolarsResult { - let idx = self.idx.evaluate(df, state)?; - let idx = convert_to_unsigned_index(&idx, series.len())?; - series.take(&idx) - } - - fn oob_err(&self) -> PolarsResult<()> { - polars_bail!(expr = self.expr, OutOfBounds: "index out of bounds"); - } -} - impl PhysicalExpr for TakeExpr { fn as_expression(&self) -> Option<&Expr> { Some(&self.expr) @@ -52,63 +36,24 @@ impl PhysicalExpr for TakeExpr { let mut ac = self.phys_expr.evaluate_on_groups(df, groups, state)?; let mut idx = self.idx.evaluate_on_groups(df, groups, state)?; + let s_idx = idx.series(); + match s_idx.dtype() { + DataType::List(inner) => { + polars_ensure!(inner.is_integer(), InvalidOperation: "expected numeric dtype as index, got {:?}", inner) + }, + dt if dt.is_integer() => { + // Unsigned integers will fall through and will use faster paths. + if !is_positive_idx_uncertain(s_idx) { + return self.process_negative_indices_agg(ac, idx, groups); + } + }, + dt => polars_bail!(InvalidOperation: "expected numeric dtype as index, got {:?}", dt), + } + let idx = match idx.state { AggState::AggregatedScalar(s) => { let idx = s.cast(&IDX_DTYPE)?; - if s.null_count() != idx.null_count() { - polars_warn!("negative indexing not yet supported in group-by context") - } - let idx = idx.idx().unwrap(); - - // The indexes are AggregatedScalar, meaning they are a single values pointing into - // a group. If we zip this with the first of each group -> `idx + firs` then we can - // simply use a take operation on the whole array instead of per group. - - // The groups maybe scattered all over the place, so we sort by group. - ac.sort_by_groups(); - - // A previous aggregation may have updated the groups. - let groups = ac.groups(); - - // Determine the gather indices. - let idx: IdxCa = match groups.as_ref() { - GroupsProxy::Idx(groups) => { - if groups.all().iter().zip(idx).any(|(g, idx)| match idx { - None => true, - Some(idx) => idx >= g.len() as IdxSize, - }) { - self.oob_err()?; - } - - idx.into_iter() - .zip(groups.first().iter()) - .map(|(idx, first)| idx.map(|idx| idx + first)) - .collect_trusted() - }, - GroupsProxy::Slice { groups, .. } => { - if groups.iter().zip(idx).any(|(g, idx)| match idx { - None => true, - Some(idx) => idx >= g[1], - }) { - self.oob_err()?; - } - - idx.into_iter() - .zip(groups.iter()) - .map(|(idx, g)| idx.map(|idx| idx + g[0])) - .collect_trusted() - }, - }; - let taken = ac.flat_naive().take(&idx)?; - - let taken = if self.returns_scalar { - taken - } else { - taken.as_list().into_series() - }; - - ac.with_series(taken, true, Some(&self.expr))?; - return Ok(ac); + return self.process_positive_indices_agg_scalar(ac, idx.idx().unwrap()); }, AggState::AggregatedList(s) => { polars_ensure!(!self.returns_scalar, ComputeError: "expected single index"); @@ -122,64 +67,7 @@ impl PhysicalExpr for TakeExpr { }, AggState::Literal(s) => { let idx = s.cast(&IDX_DTYPE)?; - if s.null_count() != idx.null_count() { - polars_warn!("negative indexing not yet supported in group-by context") - } - let idx = idx.idx().unwrap(); - - return if idx.len() == 1 { - match idx.get(0) { - None => polars_bail!(ComputeError: "cannot take by a null"), - Some(idx) => { - if idx != 0 { - // We must make sure that the column we take from is sorted by - // groups otherwise we might point into the wrong group. - ac.sort_by_groups() - } - // Make sure that we look at the updated groups. - let groups = ac.groups(); - - // We offset the groups first by idx. - let idx: NoNull = match groups.as_ref() { - GroupsProxy::Idx(groups) => { - if groups.all().iter().any(|g| idx >= g.len() as IdxSize) { - self.oob_err()?; - } - - groups.first().iter().map(|f| *f + idx).collect_trusted() - }, - GroupsProxy::Slice { groups, .. } => { - if groups.iter().any(|g| idx >= g[1]) { - self.oob_err()?; - } - - groups.iter().map(|g| g[0] + idx).collect_trusted() - }, - }; - let taken = ac.flat_naive().take(&idx.into_inner())?; - - let taken = if self.returns_scalar { - taken - } else { - taken.as_list().into_series() - }; - - ac.with_series(taken, true, Some(&self.expr))?; - ac.with_update_groups(UpdateGroups::WithGroupsLen); - Ok(ac) - }, - } - } else { - let out = ac - .aggregated() - .list() - .unwrap() - .try_apply_amortized(|s| s.as_ref().take(idx))?; - - ac.with_series(out.into_series(), true, Some(&self.expr))?; - ac.with_update_groups(UpdateGroups::WithGroupsLen); - Ok(ac) - }; + return self.process_positive_indices_agg_literal(ac, idx.idx().unwrap()); }, }; @@ -207,3 +95,167 @@ impl PhysicalExpr for TakeExpr { self.phys_expr.to_field(input_schema) } } + +impl TakeExpr { + fn finish( + &self, + df: &DataFrame, + state: &ExecutionState, + series: Series, + ) -> PolarsResult { + let idx = self.idx.evaluate(df, state)?; + let idx = convert_to_unsigned_index(&idx, series.len())?; + series.take(&idx) + } + + fn oob_err(&self) -> PolarsResult<()> { + polars_bail!(expr = self.expr, OutOfBounds: "index out of bounds"); + } + + fn process_positive_indices_agg_scalar<'b>( + &self, + mut ac: AggregationContext<'b>, + idx: &IdxCa, + ) -> PolarsResult> { + // The indexes are AggregatedScalar, meaning they are a single values pointing into + // a group. If we zip this with the first of each group -> `idx + first` then we can + // simply use a take operation on the whole array instead of per group. + + // The groups maybe scattered all over the place, so we sort by group. + ac.sort_by_groups(); + + // A previous aggregation may have updated the groups. + let groups = ac.groups(); + + // Determine the gather indices. + let idx: IdxCa = match groups.as_ref() { + GroupsProxy::Idx(groups) => { + if groups.all().iter().zip(idx).any(|(g, idx)| match idx { + None => true, + Some(idx) => idx >= g.len() as IdxSize, + }) { + self.oob_err()?; + } + + idx.into_iter() + .zip(groups.first().iter()) + .map(|(idx, first)| idx.map(|idx| idx + first)) + .collect_trusted() + }, + GroupsProxy::Slice { groups, .. } => { + if groups.iter().zip(idx).any(|(g, idx)| match idx { + None => true, + Some(idx) => idx >= g[1], + }) { + self.oob_err()?; + } + + idx.into_iter() + .zip(groups.iter()) + .map(|(idx, g)| idx.map(|idx| idx + g[0])) + .collect_trusted() + }, + }; + + let taken = ac.flat_naive().take(&idx)?; + let taken = if self.returns_scalar { + taken + } else { + taken.as_list().into_series() + }; + + ac.with_series(taken, true, Some(&self.expr))?; + Ok(ac) + } + + fn process_positive_indices_agg_literal<'b>( + &self, + mut ac: AggregationContext<'b>, + idx: &IdxCa, + ) -> PolarsResult> { + if idx.len() == 1 { + match idx.get(0) { + None => polars_bail!(ComputeError: "cannot take by a null"), + Some(idx) => { + if idx != 0 { + // We must make sure that the column we take from is sorted by + // groups otherwise we might point into the wrong group. + ac.sort_by_groups() + } + // Make sure that we look at the updated groups. + let groups = ac.groups(); + + // We offset the groups first by idx. + let idx: NoNull = match groups.as_ref() { + GroupsProxy::Idx(groups) => { + if groups.all().iter().any(|g| idx >= g.len() as IdxSize) { + self.oob_err()?; + } + + groups.first().iter().map(|f| *f + idx).collect_trusted() + }, + GroupsProxy::Slice { groups, .. } => { + if groups.iter().any(|g| idx >= g[1]) { + self.oob_err()?; + } + + groups.iter().map(|g| g[0] + idx).collect_trusted() + }, + }; + let taken = ac.flat_naive().take(&idx.into_inner())?; + + let taken = if self.returns_scalar { + taken + } else { + taken.as_list().into_series() + }; + + ac.with_series(taken, true, Some(&self.expr))?; + ac.with_update_groups(UpdateGroups::WithGroupsLen); + Ok(ac) + }, + } + } else { + let out = ac + .aggregated() + .list() + .unwrap() + .try_apply_amortized(|s| s.as_ref().take(idx))?; + + ac.with_series(out.into_series(), true, Some(&self.expr))?; + ac.with_update_groups(UpdateGroups::WithGroupsLen); + Ok(ac) + } + } + + fn process_negative_indices_agg<'b>( + &self, + mut ac: AggregationContext<'b>, + mut idx: AggregationContext<'b>, + groups: &'b GroupsProxy, + ) -> PolarsResult> { + let mut builder = get_list_builder( + &ac.dtype(), + idx.series().len(), + groups.len(), + ac.series().name(), + )?; + + unsafe { + let iter = ac.iter_groups(false).zip(idx.iter_groups(false)); + for (s, idx) in iter { + match (s, idx) { + (Some(s), Some(idx)) => { + let idx = convert_to_unsigned_index(idx.as_ref(), s.as_ref().len())?; + let out = s.as_ref().take(&idx)?; + builder.append_series(&out)?; + }, + _ => builder.append_null(), + }; + } + let out = builder.finish().into_series(); + ac.with_agg_state(AggState::AggregatedList(out)); + } + Ok(ac) + } +} diff --git a/crates/polars-ops/src/series/ops/index.rs b/crates/polars-ops/src/series/ops/index.rs index ba0b41a98907..b35444c1e4c8 100644 --- a/crates/polars-ops/src/series/ops/index.rs +++ b/crates/polars-ops/src/series/ops/index.rs @@ -1,15 +1,11 @@ -use std::fmt::Debug; - +use num_traits::Signed; use polars_core::error::{polars_bail, polars_ensure, PolarsResult}; use polars_core::prelude::{ChunkedArray, DataType, IdxCa, PolarsIntegerType, Series, IDX_DTYPE}; use polars_utils::index::ToIdx; -use polars_utils::IdxSize; fn convert(ca: &ChunkedArray, target_len: usize) -> PolarsResult where T: PolarsIntegerType, - IdxSize: TryFrom, - >::Error: Debug, T::Native: ToIdx, { let target_len = target_len as u64; @@ -47,3 +43,56 @@ pub fn convert_to_unsigned_index(s: &Series, target_len: usize) -> PolarsResult< _ => unreachable!(), } } + +/// May give false negatives because it ignores the null values. +fn is_positive_idx_uncertain_impl(ca: &ChunkedArray) -> bool +where + T: PolarsIntegerType, + T::Native: Signed, +{ + ca.downcast_iter().all(|v| { + let values = v.values(); + let mut all_positive = true; + + // process chunks to autovec but still have early return + for chunk in values.chunks(1024) { + for v in chunk.iter() { + all_positive &= v.is_positive() + } + if !all_positive { + return all_positive; + } + } + all_positive + }) +} + +/// May give false negatives because it ignores the null values. +pub fn is_positive_idx_uncertain(s: &Series) -> bool { + let dtype = s.dtype(); + debug_assert!(dtype.is_integer(), "expected integers as index"); + if dtype.is_unsigned_integer() { + return true; + } + match dtype { + DataType::Int64 => { + let ca = s.i64().unwrap(); + is_positive_idx_uncertain_impl(ca) + }, + DataType::Int32 => { + let ca = s.i32().unwrap(); + is_positive_idx_uncertain_impl(ca) + }, + #[cfg(feature = "dtype-i16")] + DataType::Int16 => { + let ca = s.i16().unwrap(); + is_positive_idx_uncertain_impl(ca) + }, + #[cfg(feature = "dtype-i8")] + DataType::Int8 => { + let ca = s.i8().unwrap(); + is_positive_idx_uncertain_impl(ca) + }, + _ => unreachable!(), + } +} diff --git a/py-polars/tests/unit/operations/test_gather.py b/py-polars/tests/unit/operations/test_gather.py index c62f3746a5ff..eec986c5cbe6 100644 --- a/py-polars/tests/unit/operations/test_gather.py +++ b/py-polars/tests/unit/operations/test_gather.py @@ -1,8 +1,11 @@ import polars as pl -def test_negative_index_select() -> None: - df = pl.DataFrame({"a": [[1, 2, 3], [4, 5, 6]]}) +def test_negative_index() -> None: + df = pl.DataFrame({"a": [1, 2, 3, 4, 5, 6]}) assert df.select(pl.col("a").gather([0, -1])).to_dict(as_series=False) == { - "a": [[1, 2, 3], [4, 5, 6]] + "a": [1, 6] } + assert df.group_by(pl.col("a") % 2).agg(b=pl.col("a").gather([0, -1])).sort( + "a" + ).to_dict(as_series=False) == {"a": [0, 1], "b": [[2, 6], [1, 5]]} From 68a55ba11c248e62e1f6e34569cd28c8f1ce336d Mon Sep 17 00:00:00 2001 From: ritchie Date: Tue, 2 Jan 2024 13:07:18 +0100 Subject: [PATCH 2/2] include zero --- crates/polars-ops/src/series/ops/index.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/polars-ops/src/series/ops/index.rs b/crates/polars-ops/src/series/ops/index.rs index b35444c1e4c8..fc2823e81fc0 100644 --- a/crates/polars-ops/src/series/ops/index.rs +++ b/crates/polars-ops/src/series/ops/index.rs @@ -1,4 +1,4 @@ -use num_traits::Signed; +use num_traits::{Signed, Zero}; use polars_core::error::{polars_bail, polars_ensure, PolarsResult}; use polars_core::prelude::{ChunkedArray, DataType, IdxCa, PolarsIntegerType, Series, IDX_DTYPE}; use polars_utils::index::ToIdx; @@ -57,7 +57,7 @@ where // process chunks to autovec but still have early return for chunk in values.chunks(1024) { for v in chunk.iter() { - all_positive &= v.is_positive() + all_positive &= v.is_positive() | v.is_zero() } if !all_positive { return all_positive;