Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 42 additions & 42 deletions vortex-array/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -54,35 +54,33 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::fmt(&self, f: &mut core::fmt::

impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::sum::Sum

pub type vortex_array::aggregate_fn::fns::sum::Sum::GroupState = vortex_array::aggregate_fn::fns::sum::SumGroupState

pub type vortex_array::aggregate_fn::fns::sum::Sum::Options = vortex_array::aggregate_fn::fns::sum::SumOptions

pub fn vortex_array::aggregate_fn::fns::sum::Sum::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>
pub type vortex_array::aggregate_fn::fns::sum::Sum::Partial = vortex_array::aggregate_fn::fns::sum::SumPartial

pub fn vortex_array::aggregate_fn::fns::sum::Sum::finalize(&self, states: vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>
pub fn vortex_array::aggregate_fn::fns::sum::Sum::accumulate(&self, partial: &mut Self::Partial, batch: &vortex_array::Canonical, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()>

pub fn vortex_array::aggregate_fn::fns::sum::Sum::finalize_scalar(&self, state: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
pub fn vortex_array::aggregate_fn::fns::sum::Sum::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()>

pub fn vortex_array::aggregate_fn::fns::sum::Sum::id(&self) -> vortex_array::aggregate_fn::AggregateFnId
pub fn vortex_array::aggregate_fn::fns::sum::Sum::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>

pub fn vortex_array::aggregate_fn::fns::sum::Sum::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>
pub fn vortex_array::aggregate_fn::fns::sum::Sum::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::Partial>

pub fn vortex_array::aggregate_fn::fns::sum::Sum::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
pub fn vortex_array::aggregate_fn::fns::sum::Sum::finalize(&self, partials: vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>

pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_accumulate(&self, state: &mut Self::GroupState, batch: &vortex_array::Canonical, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()>
pub fn vortex_array::aggregate_fn::fns::sum::Sum::finalize_scalar(&self, partial: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>

pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>
pub fn vortex_array::aggregate_fn::fns::sum::Sum::flush(&self, partial: &mut Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>

pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_flush(&self, state: &mut Self::GroupState) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
pub fn vortex_array::aggregate_fn::fns::sum::Sum::id(&self) -> vortex_array::aggregate_fn::AggregateFnId

pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_is_saturated(&self, state: &Self::GroupState) -> bool
pub fn vortex_array::aggregate_fn::fns::sum::Sum::is_saturated(&self, partial: &Self::Partial) -> bool

pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_merge(&self, state: &mut Self::GroupState, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()>
pub fn vortex_array::aggregate_fn::fns::sum::Sum::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>

pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_new(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::GroupState>
pub fn vortex_array::aggregate_fn::fns::sum::Sum::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>

pub struct vortex_array::aggregate_fn::fns::sum::SumGroupState
pub fn vortex_array::aggregate_fn::fns::sum::Sum::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>

pub struct vortex_array::aggregate_fn::fns::sum::SumOptions

Expand Down Expand Up @@ -110,6 +108,8 @@ pub fn vortex_array::aggregate_fn::fns::sum::SumOptions::hash<__H: core::hash::H

impl core::marker::StructuralPartialEq for vortex_array::aggregate_fn::fns::sum::SumOptions

pub struct vortex_array::aggregate_fn::fns::sum::SumPartial

pub mod vortex_array::aggregate_fn::session

pub struct vortex_array::aggregate_fn::session::AggregateFnSession
Expand Down Expand Up @@ -294,64 +294,64 @@ pub fn V::id(&self) -> arcref::ArcRef<str>

pub trait vortex_array::aggregate_fn::AggregateFnVTable: 'static + core::marker::Sized + core::clone::Clone + core::marker::Send + core::marker::Sync

pub type vortex_array::aggregate_fn::AggregateFnVTable::GroupState: 'static + core::marker::Send

pub type vortex_array::aggregate_fn::AggregateFnVTable::Options: 'static + core::marker::Send + core::marker::Sync + core::clone::Clone + core::fmt::Debug + core::fmt::Display + core::cmp::PartialEq + core::cmp::Eq + core::hash::Hash

pub type vortex_array::aggregate_fn::AggregateFnVTable::Partial: 'static + core::marker::Send

pub fn vortex_array::aggregate_fn::AggregateFnVTable::accumulate(&self, state: &mut Self::Partial, batch: &vortex_array::Canonical, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()>

pub fn vortex_array::aggregate_fn::AggregateFnVTable::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()>

pub fn vortex_array::aggregate_fn::AggregateFnVTable::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>

pub fn vortex_array::aggregate_fn::AggregateFnVTable::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::Partial>

pub fn vortex_array::aggregate_fn::AggregateFnVTable::finalize(&self, states: vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>

pub fn vortex_array::aggregate_fn::AggregateFnVTable::finalize_scalar(&self, state: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>

pub fn vortex_array::aggregate_fn::AggregateFnVTable::flush(&self, partial: &mut Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>

pub fn vortex_array::aggregate_fn::AggregateFnVTable::id(&self) -> vortex_array::aggregate_fn::AggregateFnId

pub fn vortex_array::aggregate_fn::AggregateFnVTable::is_saturated(&self, state: &Self::Partial) -> bool

pub fn vortex_array::aggregate_fn::AggregateFnVTable::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>

pub fn vortex_array::aggregate_fn::AggregateFnVTable::return_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>

pub fn vortex_array::aggregate_fn::AggregateFnVTable::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>

pub fn vortex_array::aggregate_fn::AggregateFnVTable::state_accumulate(&self, state: &mut Self::GroupState, batch: &vortex_array::Canonical, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()>
impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::sum::Sum

pub fn vortex_array::aggregate_fn::AggregateFnVTable::state_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>
pub type vortex_array::aggregate_fn::fns::sum::Sum::Options = vortex_array::aggregate_fn::fns::sum::SumOptions

pub fn vortex_array::aggregate_fn::AggregateFnVTable::state_flush(&self, state: &mut Self::GroupState) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
pub type vortex_array::aggregate_fn::fns::sum::Sum::Partial = vortex_array::aggregate_fn::fns::sum::SumPartial

pub fn vortex_array::aggregate_fn::AggregateFnVTable::state_is_saturated(&self, state: &Self::GroupState) -> bool
pub fn vortex_array::aggregate_fn::fns::sum::Sum::accumulate(&self, partial: &mut Self::Partial, batch: &vortex_array::Canonical, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()>

pub fn vortex_array::aggregate_fn::AggregateFnVTable::state_merge(&self, state: &mut Self::GroupState, partial: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()>
pub fn vortex_array::aggregate_fn::fns::sum::Sum::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()>

pub fn vortex_array::aggregate_fn::AggregateFnVTable::state_new(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::GroupState>
pub fn vortex_array::aggregate_fn::fns::sum::Sum::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>

impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::sum::Sum
pub fn vortex_array::aggregate_fn::fns::sum::Sum::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::Partial>

pub type vortex_array::aggregate_fn::fns::sum::Sum::GroupState = vortex_array::aggregate_fn::fns::sum::SumGroupState
pub fn vortex_array::aggregate_fn::fns::sum::Sum::finalize(&self, partials: vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>

pub type vortex_array::aggregate_fn::fns::sum::Sum::Options = vortex_array::aggregate_fn::fns::sum::SumOptions
pub fn vortex_array::aggregate_fn::fns::sum::Sum::finalize_scalar(&self, partial: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>

pub fn vortex_array::aggregate_fn::fns::sum::Sum::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>
pub fn vortex_array::aggregate_fn::fns::sum::Sum::flush(&self, partial: &mut Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>

pub fn vortex_array::aggregate_fn::fns::sum::Sum::finalize(&self, states: vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>
pub fn vortex_array::aggregate_fn::fns::sum::Sum::id(&self) -> vortex_array::aggregate_fn::AggregateFnId

pub fn vortex_array::aggregate_fn::fns::sum::Sum::finalize_scalar(&self, state: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
pub fn vortex_array::aggregate_fn::fns::sum::Sum::is_saturated(&self, partial: &Self::Partial) -> bool

pub fn vortex_array::aggregate_fn::fns::sum::Sum::id(&self) -> vortex_array::aggregate_fn::AggregateFnId
pub fn vortex_array::aggregate_fn::fns::sum::Sum::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>

pub fn vortex_array::aggregate_fn::fns::sum::Sum::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>

pub fn vortex_array::aggregate_fn::fns::sum::Sum::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>

pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_accumulate(&self, state: &mut Self::GroupState, batch: &vortex_array::Canonical, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()>

pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>

pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_flush(&self, state: &mut Self::GroupState) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>

pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_is_saturated(&self, state: &Self::GroupState) -> bool

pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_merge(&self, state: &mut Self::GroupState, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()>

pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_new(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::GroupState>

pub trait vortex_array::aggregate_fn::AggregateFnVTableExt: vortex_array::aggregate_fn::AggregateFnVTable

pub fn vortex_array::aggregate_fn::AggregateFnVTableExt::bind(&self, options: Self::Options) -> vortex_array::aggregate_fn::AggregateFnRef
Expand Down
30 changes: 15 additions & 15 deletions vortex-array/src/aggregate_fn/accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ pub struct Accumulator<V: AggregateFnVTable> {
/// The DType of the aggregate.
return_dtype: DType,
/// The DType of the accumulator state.
state_dtype: DType,
/// The current state of the accumulator, updated after each accumulate/merge call.
current_state: V::GroupState,
partial_dtype: DType,
/// The partial state of the accumulator, updated after each accumulate/merge call.
partial: V::Partial,
/// A session used to lookup custom aggregate kernels.
session: VortexSession,
}
Expand All @@ -47,17 +47,17 @@ impl<V: AggregateFnVTable> Accumulator<V> {
session: VortexSession,
) -> VortexResult<Self> {
let return_dtype = vtable.return_dtype(&options, &dtype)?;
let state_dtype = vtable.state_dtype(&options, &dtype)?;
let current_state = vtable.state_new(&options, &dtype)?;
let partial_dtype = vtable.partial_dtype(&options, &dtype)?;
let partial = vtable.empty_partial(&options, &dtype)?;
let aggregate_fn = AggregateFn::new(vtable.clone(), options).erased();

Ok(Self {
vtable,
aggregate_fn,
dtype,
return_dtype,
state_dtype,
current_state,
partial_dtype,
partial,
session,
})
}
Expand Down Expand Up @@ -110,12 +110,12 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
&& let Some(result) = kernel.aggregate(&self.aggregate_fn, &batch)?
{
vortex_ensure!(
result.dtype() == &self.state_dtype,
result.dtype() == &self.partial_dtype,
"Aggregate kernel returned {}, expected {}",
result.dtype(),
self.state_dtype,
self.partial_dtype,
);
self.vtable.state_merge(&mut self.current_state, result)?;
self.vtable.combine_partials(&mut self.partial, result)?;
return Ok(());
}

Expand All @@ -127,22 +127,22 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
let canonical = batch.execute::<Canonical>(&mut ctx)?;

self.vtable
.state_accumulate(&mut self.current_state, &canonical, &mut ctx)
.accumulate(&mut self.partial, &canonical, &mut ctx)
}

fn is_saturated(&self) -> bool {
self.vtable.state_is_saturated(&self.current_state)
self.vtable.is_saturated(&self.partial)
}

fn flush(&mut self) -> VortexResult<Scalar> {
let partial = self.vtable.state_flush(&mut self.current_state)?;
let partial = self.vtable.flush(&mut self.partial)?;

#[cfg(debug_assertions)]
{
vortex_ensure!(
partial.dtype() == &self.state_dtype,
partial.dtype() == &self.partial_dtype,
"Aggregate kernel returned incorrect DType on flush: expected {}, got {}",
self.state_dtype,
self.partial_dtype,
partial.dtype(),
);
}
Expand Down
26 changes: 13 additions & 13 deletions vortex-array/src/aggregate_fn/accumulator_grouped.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ pub struct GroupedAccumulator<V: AggregateFnVTable> {
dtype: DType,
/// The DType of the aggregate.
return_dtype: DType,
/// The DType of the accumulator state.
state_dtype: DType,
/// The DType of the partial accumulator state.
partial_dtype: DType,
/// The accumulated state for prior batches of groups.
states: Vec<ArrayRef>,
partials: Vec<ArrayRef>,
/// A session used to lookup custom aggregate kernels.
session: VortexSession,
}
Expand All @@ -70,16 +70,16 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
) -> VortexResult<Self> {
let aggregate_fn = AggregateFn::new(vtable.clone(), options.clone()).erased();
let return_dtype = vtable.return_dtype(&options, &dtype)?;
let state_dtype = vtable.state_dtype(&options, &dtype)?;
let partial_dtype = vtable.partial_dtype(&options, &dtype)?;

Ok(Self {
vtable,
options,
aggregate_fn,
dtype,
return_dtype,
state_dtype,
states: vec![],
partial_dtype,
partials: vec![],
session,
})
}
Expand Down Expand Up @@ -129,8 +129,8 @@ impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
}

fn flush(&mut self) -> VortexResult<ArrayRef> {
let states = std::mem::take(&mut self.states);
Ok(ChunkedArray::try_new(states, self.state_dtype.clone())?.into_array())
let states = std::mem::take(&mut self.partials);
Ok(ChunkedArray::try_new(states, self.partial_dtype.clone())?.into_array())
}

fn finish(&mut self) -> VortexResult<ArrayRef> {
Expand Down Expand Up @@ -211,7 +211,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
self.dtype.clone(),
self.session.clone(),
)?;
let mut states = builder_with_capacity(&self.state_dtype, offsets.len());
let mut states = builder_with_capacity(&self.partial_dtype, offsets.len());

for (offset, size) in offsets.iter().zip(sizes.iter()) {
let offset = offset.to_usize().vortex_expect("Offset value is not usize");
Expand Down Expand Up @@ -277,7 +277,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
self.dtype.clone(),
self.session.clone(),
)?;
let mut states = builder_with_capacity(&self.state_dtype, groups.len());
let mut states = builder_with_capacity(&self.partial_dtype, groups.len());

let mut offset = 0;
let size = groups
Expand All @@ -301,12 +301,12 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {

fn push_result(&mut self, state: ArrayRef) -> VortexResult<()> {
vortex_ensure!(
state.dtype() == &self.state_dtype,
state.dtype() == &self.partial_dtype,
"State DType mismatch: expected {}, got {}",
self.state_dtype,
self.partial_dtype,
state.dtype()
);
self.states.push(state);
self.partials.push(state);
Ok(())
}
}
Loading
Loading