-
Notifications
You must be signed in to change notification settings - Fork 153
Reorder agg kernel dispatch, and have Combined use inner accumulators #7889
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,14 +5,14 @@ use vortex_error::VortexResult; | |
| use vortex_error::vortex_ensure; | ||
| use vortex_error::vortex_err; | ||
|
|
||
| use crate::AnyCanonical; | ||
| use crate::ArrayRef; | ||
| use crate::Columnar; | ||
| use crate::ExecutionCtx; | ||
| use crate::aggregate_fn::AggregateFn; | ||
| use crate::aggregate_fn::AggregateFnRef; | ||
| use crate::aggregate_fn::AggregateFnVTable; | ||
| use crate::aggregate_fn::session::AggregateFnSessionExt; | ||
| use crate::columnar::AnyColumnar; | ||
| use crate::dtype::DType; | ||
| use crate::executor::max_iterations; | ||
| use crate::scalar::Scalar; | ||
|
|
@@ -72,9 +72,26 @@ pub trait DynAccumulator: 'static + Send { | |
| /// Accumulate a new array into the accumulator's state. | ||
| fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>; | ||
|
|
||
| /// Fold an external partial-state scalar into this accumulator's state. | ||
| /// | ||
| /// The scalar must have the dtype reported by the vtable's `partial_dtype` for the | ||
| /// options and input dtype used to construct this accumulator. | ||
| fn combine_partials(&mut self, other: Scalar) -> VortexResult<()>; | ||
|
|
||
| /// Whether the accumulator's result is fully determined. | ||
| fn is_saturated(&self) -> bool; | ||
|
|
||
| /// Reset the accumulator's state to the empty group. | ||
| fn reset(&mut self); | ||
|
|
||
| /// Read the current partial state as a scalar without resetting it. | ||
| /// | ||
| /// The returned scalar has the dtype reported by the vtable's `partial_dtype`. | ||
| fn partial_scalar(&self) -> VortexResult<Scalar>; | ||
|
|
||
| /// Compute the final aggregate result as a scalar without resetting state. | ||
| fn final_scalar(&self) -> VortexResult<Scalar>; | ||
|
|
||
| /// Flush the accumulation state and return the partial aggregate result as a scalar. | ||
| /// | ||
| /// Resets the accumulator state back to the initial state. | ||
|
|
@@ -99,31 +116,75 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> { | |
| batch.dtype() | ||
| ); | ||
|
|
||
| // Allow the vtable to short-circuit on the raw array before decompression. | ||
| if self.vtable.try_accumulate(&mut self.partial, batch, ctx)? { | ||
| // 0. Stats-driven shortcut: if the aggregate can be derived directly from the batch's | ||
| // cached statistics, use that and skip both kernel dispatch and decode. This is the | ||
| // only layer that consults `batch.statistics()`; encoding kernels must not. | ||
| if let Some(result) = self.vtable.try_partial_from_stats(batch)? { | ||
| vortex_ensure!( | ||
| result.dtype() == &self.partial_dtype, | ||
| "Aggregate try_partial_from_stats returned {}, expected {}", | ||
| result.dtype(), | ||
| self.partial_dtype, | ||
| ); | ||
| self.vtable.combine_partials(&mut self.partial, result)?; | ||
| return Ok(()); | ||
| } | ||
|
|
||
| let session = ctx.session().clone(); | ||
| let kernels = &session.aggregate_fns().kernels; | ||
|
|
||
| // 1. Kernel registry first: a registered `(encoding, aggregate_fn)` kernel is strictly | ||
| // more specific than the vtable's `try_accumulate` short-circuit. Checking the | ||
| // registry first gives kernels for `Combined<V>` aggregates a chance to fire — | ||
| // `Combined::try_accumulate` always returns true, so a later kernel check would be | ||
| // unreachable. | ||
| { | ||
| let kernels_r = kernels.read(); | ||
| let batch_id = batch.encoding_id(); | ||
| let kernel = kernels_r | ||
| .get(&(batch_id, Some(self.aggregate_fn.id()))) | ||
| .or_else(|| kernels_r.get(&(batch_id, None))) | ||
| .copied(); | ||
| drop(kernels_r); | ||
| if let Some(kernel) = kernel | ||
| && let Some(result) = kernel.aggregate(&self.aggregate_fn, batch, ctx)? | ||
| { | ||
| vortex_ensure!( | ||
| result.dtype() == &self.partial_dtype, | ||
| "Aggregate kernel returned {}, expected {}", | ||
| result.dtype(), | ||
| self.partial_dtype, | ||
| ); | ||
| self.vtable.combine_partials(&mut self.partial, result)?; | ||
| return Ok(()); | ||
| } | ||
| } | ||
|
|
||
| // 2. Allow the vtable to short-circuit on the raw array before decompression. | ||
| if self.vtable.try_accumulate(&mut self.partial, batch, ctx)? { | ||
| return Ok(()); | ||
| } | ||
|
Comment on lines 135 to 166
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

why is this whole block different from the block below? |
||
|
|
||
| // 3. Iteratively check the registry against each intermediate encoding, executing one | ||
| // step between checks. Mirrors the loop in `GroupedAccumulator::accumulate_list_view`. | ||
| // Iteration 0 re-checks the initial encoding — a redundant HashMap miss, the price of | ||
| // keeping the loop body uniform. Terminates on `AnyColumnar` (Canonical or Constant) | ||
| // since the vtable's `accumulate(&Columnar)` handles both cases directly. | ||
| let mut batch = batch.clone(); | ||
| for _ in 0..max_iterations() { | ||
| if batch.is::<AnyCanonical>() { | ||
| if batch.is::<AnyColumnar>() { | ||
| break; | ||
| } | ||
|
|
||
| let kernels_r = kernels.read(); | ||
| let batch_id = batch.encoding_id(); | ||
| if let Some(result) = kernels_r | ||
| let kernel = kernels_r | ||
| .get(&(batch_id, Some(self.aggregate_fn.id()))) | ||
| .or_else(|| kernels_r.get(&(batch_id, None))) | ||
| .and_then(|kernel| { | ||
| kernel | ||
| .aggregate(&self.aggregate_fn, &batch, ctx) | ||
| .transpose() | ||
| }) | ||
| .transpose()? | ||
| .copied(); | ||
| drop(kernels_r); | ||
| if let Some(kernel) = kernel | ||
| && let Some(result) = kernel.aggregate(&self.aggregate_fn, &batch, ctx)? | ||
| { | ||
| vortex_ensure!( | ||
| result.dtype() == &self.partial_dtype, | ||
|
|
@@ -135,29 +196,35 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> { | |
| return Ok(()); | ||
| } | ||
|
|
||
| // Execute one step and try again | ||
| batch = batch.execute(ctx)?; | ||
| } | ||
|
|
||
| // Otherwise, execute the batch until it is columnar and accumulate it into the state. | ||
| // 4. Otherwise, execute the batch until it is columnar and accumulate it into the state. | ||
| let columnar = batch.execute::<Columnar>(ctx)?; | ||
|
|
||
| self.vtable.accumulate(&mut self.partial, &columnar, ctx) | ||
| } | ||
|
|
||
| fn combine_partials(&mut self, other: Scalar) -> VortexResult<()> { | ||
| self.vtable.combine_partials(&mut self.partial, other) | ||
| } | ||
|
|
||
| fn is_saturated(&self) -> bool { | ||
| self.vtable.is_saturated(&self.partial) | ||
| } | ||
|
|
||
| fn flush(&mut self) -> VortexResult<Scalar> { | ||
| let partial = self.vtable.to_scalar(&self.partial)?; | ||
| fn reset(&mut self) { | ||
| self.vtable.reset(&mut self.partial); | ||
| } | ||
|
|
||
| fn partial_scalar(&self) -> VortexResult<Scalar> { | ||
| let partial = self.vtable.to_scalar(&self.partial)?; | ||
|
|
||
| #[cfg(debug_assertions)] | ||
| { | ||
| vortex_ensure!( | ||
| partial.dtype() == &self.partial_dtype, | ||
| "Aggregate kernel returned incorrect DType on flush: expected {}, got {}", | ||
| "Aggregate returned incorrect DType on partial_scalar: expected {}, got {}", | ||
| self.partial_dtype, | ||
| partial.dtype(), | ||
| ); | ||
|
|
@@ -166,17 +233,216 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> { | |
| Ok(partial) | ||
| } | ||
|
|
||
| fn finish(&mut self) -> VortexResult<Scalar> { | ||
| fn final_scalar(&self) -> VortexResult<Scalar> { | ||
| let result = self.vtable.finalize_scalar(&self.partial)?; | ||
| self.vtable.reset(&mut self.partial); | ||
|
|
||
| vortex_ensure!( | ||
| result.dtype() == &self.return_dtype, | ||
| "Aggregate kernel returned incorrect DType on finalize: expected {}, got {}", | ||
| "Aggregate returned incorrect DType on final_scalar: expected {}, got {}", | ||
| self.return_dtype, | ||
| result.dtype(), | ||
| ); | ||
|
|
||
| Ok(result) | ||
| } | ||
|
|
||
| fn flush(&mut self) -> VortexResult<Scalar> { | ||
| let partial = self.partial_scalar()?; | ||
| self.reset(); | ||
| Ok(partial) | ||
| } | ||
|
|
||
| fn finish(&mut self) -> VortexResult<Scalar> { | ||
| let result = self.final_scalar()?; | ||
| self.reset(); | ||
| Ok(result) | ||
| } | ||
| } | ||
|
|
||
| #[cfg(test)] | ||
| mod tests { | ||
| use vortex_buffer::buffer; | ||
| use vortex_error::VortexResult; | ||
| use vortex_session::SessionExt; | ||
| use vortex_session::VortexSession; | ||
|
|
||
| use crate::ArrayRef; | ||
| use crate::ExecutionCtx; | ||
| use crate::IntoArray; | ||
| use crate::VortexSessionExecute; | ||
| use crate::aggregate_fn::Accumulator; | ||
| use crate::aggregate_fn::AggregateFnRef; | ||
| use crate::aggregate_fn::AggregateFnVTable; | ||
| use crate::aggregate_fn::DynAccumulator; | ||
| use crate::aggregate_fn::EmptyOptions; | ||
| use crate::aggregate_fn::combined::Combined; | ||
| use crate::aggregate_fn::combined::PairOptions; | ||
| use crate::aggregate_fn::fns::mean::Mean; | ||
| use crate::aggregate_fn::fns::sum::Sum; | ||
| use crate::aggregate_fn::kernels::DynAggregateKernel; | ||
| use crate::aggregate_fn::session::AggregateFnSession; | ||
| use crate::array::VTable; | ||
| use crate::arrays::Dict; | ||
| use crate::arrays::DictArray; | ||
| use crate::dtype::DType; | ||
| use crate::dtype::Nullability; | ||
| use crate::dtype::PType; | ||
| use crate::scalar::Scalar; | ||
| use crate::session::ArraySession; | ||
|
|
||
| /// Mean partial sentinel `{sum: 42.0, count: 1}` — distinguishable from the | ||
| /// natural fan-out result `{sum: 7.0, count: 1}` that `Combined::try_accumulate` | ||
| /// would produce for `dict_of_seven()`. | ||
| #[derive(Debug)] | ||
| struct SentinelMeanPartialKernel; | ||
| impl DynAggregateKernel for SentinelMeanPartialKernel { | ||
| fn aggregate( | ||
| &self, | ||
| _aggregate_fn: &AggregateFnRef, | ||
| _batch: &ArrayRef, | ||
| _ctx: &mut ExecutionCtx, | ||
| ) -> VortexResult<Option<Scalar>> { | ||
| Ok(Some(sentinel_partial())) | ||
| } | ||
| } | ||
|
|
||
| /// Returns `Ok(None)` => kernel declined, dispatch falls through. | ||
| #[derive(Debug)] | ||
| struct DeclineKernel; | ||
| impl DynAggregateKernel for DeclineKernel { | ||
| fn aggregate( | ||
| &self, | ||
| _aggregate_fn: &AggregateFnRef, | ||
| _batch: &ArrayRef, | ||
| _ctx: &mut ExecutionCtx, | ||
| ) -> VortexResult<Option<Scalar>> { | ||
| Ok(None) | ||
| } | ||
| } | ||
|
|
||
| /// Sum partial sentinel `42.0` — distinguishable from the natural Sum of | ||
| /// `dict_of_seven()` which is `7.0`. | ||
| #[derive(Debug)] | ||
| struct SentinelSumPartialKernel; | ||
| impl DynAggregateKernel for SentinelSumPartialKernel { | ||
| fn aggregate( | ||
| &self, | ||
| _aggregate_fn: &AggregateFnRef, | ||
| _batch: &ArrayRef, | ||
| _ctx: &mut ExecutionCtx, | ||
| ) -> VortexResult<Option<Scalar>> { | ||
| Ok(Some(Scalar::primitive(42.0f64, Nullability::Nullable))) | ||
| } | ||
| } | ||
|
|
||
| fn fresh_session() -> VortexSession { | ||
| VortexSession::empty().with::<ArraySession>() | ||
| } | ||
|
|
||
| fn dict_of_seven() -> ArrayRef { | ||
| DictArray::try_new(buffer![0u32].into_array(), buffer![7.0f64].into_array()) | ||
| .expect("valid dictionary") | ||
| .into_array() | ||
| } | ||
|
|
||
| fn mean_f64_accumulator() -> VortexResult<Accumulator<Combined<Mean>>> { | ||
| let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); | ||
| Accumulator::try_new( | ||
| Mean::combined(), | ||
| PairOptions(EmptyOptions, EmptyOptions), | ||
| dtype, | ||
| ) | ||
| } | ||
|
|
||
| fn sentinel_partial() -> Scalar { | ||
| let acc = mean_f64_accumulator().expect("build accumulator"); | ||
| let sum = Scalar::primitive(42.0f64, Nullability::Nullable); | ||
| let count = Scalar::primitive(1u64, Nullability::NonNullable); | ||
| Scalar::struct_(acc.partial_dtype, vec![sum, count]) | ||
| } | ||
|
|
||
| /// Kernel registered for `(Dict, Combined<Mean>)` fires in preference to | ||
| /// `Combined::try_accumulate`'s fan-out path — proves the dispatch reorder. | ||
| #[test] | ||
| fn combined_kernel_fires() -> VortexResult<()> { | ||
| static KERNEL: SentinelMeanPartialKernel = SentinelMeanPartialKernel; | ||
| let session = fresh_session(); | ||
| session | ||
| .get::<AggregateFnSession>() | ||
| .register_aggregate_kernel(Dict.id(), Some(Mean::combined().id()), &KERNEL); | ||
| let mut ctx = session.create_execution_ctx(); | ||
|
|
||
| let mut acc = mean_f64_accumulator()?; | ||
| acc.accumulate(&dict_of_seven(), &mut ctx)?; | ||
| let partial = acc.flush()?; | ||
|
|
||
| let s = partial.as_struct(); | ||
| assert_eq!( | ||
| s.field("sum").unwrap().as_primitive().as_::<f64>(), | ||
| Some(42.0) | ||
| ); | ||
| assert_eq!( | ||
| s.field("count").unwrap().as_primitive().as_::<u64>(), | ||
| Some(1) | ||
| ); | ||
| Ok(()) | ||
| } | ||
|
|
||
| /// Kernel returns `Ok(None)` => dispatch falls through to `Combined::try_accumulate`'s | ||
| /// natural fan-out. The natural partial is `{sum: 7.0, count: 1}`. | ||
| #[test] | ||
| fn fallback_when_kernel_declines() -> VortexResult<()> { | ||
| static KERNEL: DeclineKernel = DeclineKernel; | ||
| let session = fresh_session(); | ||
| session | ||
| .get::<AggregateFnSession>() | ||
| .register_aggregate_kernel(Dict.id(), Some(Mean::combined().id()), &KERNEL); | ||
| let mut ctx = session.create_execution_ctx(); | ||
|
|
||
| let mut acc = mean_f64_accumulator()?; | ||
| acc.accumulate(&dict_of_seven(), &mut ctx)?; | ||
| let partial = acc.flush()?; | ||
|
|
||
| let s = partial.as_struct(); | ||
| assert_eq!( | ||
| s.field("sum").unwrap().as_primitive().as_::<f64>(), | ||
| Some(7.0) | ||
| ); | ||
| assert_eq!( | ||
| s.field("count").unwrap().as_primitive().as_::<u64>(), | ||
| Some(1) | ||
| ); | ||
| Ok(()) | ||
| } | ||
|
|
||
| /// A kernel registered for the inner `(Dict, Sum)` child fires when accumulating a | ||
| /// Dict batch through `Combined<Mean>`. This is the reusable-primitive case the | ||
| /// refactor enables: no `(Dict, Combined<Mean>)` kernel is needed. | ||
| #[test] | ||
| fn child_kernel_fires_through_combined() -> VortexResult<()> { | ||
| static KERNEL: SentinelSumPartialKernel = SentinelSumPartialKernel; | ||
| let session = fresh_session(); | ||
| session | ||
| .get::<AggregateFnSession>() | ||
| .register_aggregate_kernel(Dict.id(), Some(Sum.id()), &KERNEL); | ||
| let mut ctx = session.create_execution_ctx(); | ||
|
|
||
| let mut acc = mean_f64_accumulator()?; | ||
| acc.accumulate(&dict_of_seven(), &mut ctx)?; | ||
| let partial = acc.flush()?; | ||
|
|
||
| let s = partial.as_struct(); | ||
| // `Sum` child returned the sentinel 42.0 — proves the (Dict, Sum) kernel fired | ||
| // via `Combined<Mean>`'s fan-out. `Count`'s native `try_accumulate` reads the | ||
| // batch's valid_count, so count is the real 1. | ||
| assert_eq!( | ||
| s.field("sum").unwrap().as_primitive().as_::<f64>(), | ||
| Some(42.0) | ||
| ); | ||
| assert_eq!( | ||
| s.field("count").unwrap().as_primitive().as_::<u64>(), | ||
| Some(1) | ||
| ); | ||
| Ok(()) | ||
| } | ||
| } | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
do we later want to move this to the session as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As in, we want even canonical dispatch to be via the registry?