perf(rust, python): improve streaming primitive groupby (#5575)
ritchie46 committed Nov 21, 2022
1 parent d3d58a3 commit 17249f2
Showing 2 changed files with 73 additions and 53 deletions.
First file — the primitive groupby sink (PrimitiveGroupbySink):
@@ -16,6 +16,7 @@ use rayon::prelude::*;

use super::aggregates::AggregateFn;
use crate::executors::sinks::groupby::aggregates::AggregateFunction;
+use crate::executors::sinks::groupby::string::{apply_aggregate, write_agg_idx};
use crate::executors::sinks::groupby::utils::compute_slices;
use crate::executors::sinks::utils::load_vec;
use crate::executors::sinks::HASHMAP_INIT_SIZE;
@@ -48,10 +49,9 @@ pub struct PrimitiveGroupbySink<K: PolarsNumericType> {
// the aggregations are all tightly packed
// the aggregation function of a group can be found
// by:
-// first get the correct vec by the partition index
// * offset = (idx)
// * end = (offset + n_aggs)
-aggregators: Vec<Vec<AggregateFunction>>,
+aggregators: Vec<AggregateFunction>,
key: Arc<dyn PhysicalPipedExpr>,
// the columns that will be aggregated
aggregation_columns: Arc<Vec<Arc<dyn PhysicalPipedExpr>>>,
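The aggregators of all partitions now live in one flat Vec. For readers following the indexing comment above, a minimal sketch of the layout — `FlatGroups`, `offsets`, and `n_aggs` are illustrative names, not polars types:

```rust
// One flat Vec holds the aggregation functions of every group, tightly
// packed; the hash map stores only the start offset of a group's slot.
struct FlatGroups<A> {
    // group key -> start offset into `aggregators`
    offsets: std::collections::HashMap<i64, usize>,
    // group 0's n_aggs functions, then group 1's, ...
    aggregators: Vec<A>,
    n_aggs: usize,
}

impl<A> FlatGroups<A> {
    // offset = idx, end = offset + n_aggs
    fn group_mut(&mut self, key: i64) -> Option<&mut [A]> {
        let idx = *self.offsets.get(&key)?;
        Some(&mut self.aggregators[idx..idx + self.n_aggs])
    }
}
```

Keeping all aggregation state in one contiguous allocation is what later lets `pre_finalize` hand the buffer to the partition threads as non-overlapping regions.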
@@ -82,9 +82,8 @@ where
let partitions = _set_partition_size();

let pre_agg = load_vec(partitions, || PlIdHashMap::with_capacity(HASHMAP_INIT_SIZE));
-let aggregators = load_vec(partitions, || {
-Vec::with_capacity(HASHMAP_INIT_SIZE * aggregation_columns.len())
-});
+let aggregators =
+Vec::with_capacity(HASHMAP_INIT_SIZE * aggregation_columns.len() * partitions);

Self {
thread_no: 0,
@@ -107,20 +106,28 @@ where
}

fn pre_finalize(&mut self) -> PolarsResult<Vec<DataFrame>> {
-let mut aggregators = std::mem::take(&mut self.aggregators);
+// we create a pointer to the aggregation functions buffer
+// we will deref *mut on every partition thread
+// this will be safe, as the partitions guarantee that accesses don't alias.
+let aggregators = self.aggregators.as_ptr() as usize;
+let aggregators_len = self.aggregators.len();
let slices = compute_slices(&self.pre_agg_partitions, self.slice);

POOL.install(|| {
let dfs =
self.pre_agg_partitions
.par_iter()
-.zip(aggregators.par_iter_mut())
.zip(slices.par_iter())
-.filter_map(|((agg_map, agg_fns), slice)| {
+.filter_map(|(agg_map, slice)| {
let (offset, slice_len) = (*slice)?;
if agg_map.is_empty() {
return None;
}
+// safety:
+// we will not alias.
+let ptr = aggregators as *mut AggregateFunction;
+let agg_fns =
+unsafe { std::slice::from_raw_parts_mut(ptr, aggregators_len) };
let mut key_builder = PrimitiveChunkedBuilder::<K>::new(
self.output_schema.get_index(0).unwrap().0,
agg_map.len(),
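The `std::mem::take` of the per-partition buffers is replaced by sharing one buffer across the partition threads through a raw pointer. A self-contained sketch of that pattern, assuming illustrative sizes and index math — soundness rests entirely on each partition writing a disjoint range:

```rust
use rayon::prelude::*;

const PARTITIONS: usize = 4;
const SLOTS: usize = 8;

fn main() {
    let mut buf = vec![0u64; PARTITIONS * SLOTS];
    let len = buf.len();
    // smuggle the pointer across the closure boundary as a plain usize
    let ptr = buf.as_mut_ptr() as usize;

    (0..PARTITIONS).into_par_iter().for_each(|part| {
        // Safety: every partition writes only its own disjoint range,
        // so the mutable accesses never alias.
        let slice = unsafe { std::slice::from_raw_parts_mut(ptr as *mut u64, len) };
        for i in part * SLOTS..(part + 1) * SLOTS {
            slice[i] = part as u64;
        }
    });

    assert_eq!(buf[SLOTS + 1], 1); // written by partition 1
}
```

This avoids moving the aggregators out of `self` and back, and it rests on the same argument as the diff's safety comment: the hash maps are partitioned by `hash_to_partition`, so no two threads ever touch the same group's slot.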
@@ -191,26 +198,28 @@ where
// cow -> &series -> &dyn series_trait -> &chunkedarray
let ca: &ChunkedArray<K> = s.as_ref().as_ref();

+// write the hashes to self.hashes buffer
+// s.vec_hash(self.hb.clone(), &mut self.hashes).unwrap();
+// now we have written hashes, we take the pointer to this buffer
+// we will write the aggregation_function indexes in the same buffer
+// this is unsafe and we must check that we only write to hashes that were
+// already read/taken. So we write on the slots we just read
+let agg_idx_ptr = self.hashes.as_ptr() as *mut i64 as *mut IdxSize;

// todo! amortize allocation
for phys_e in self.aggregation_columns.iter() {
let s = phys_e.evaluate(&chunk, context.execution_state.as_ref())?;
let s = s.to_physical_repr();
self.aggregation_series.push(s.rechunk());
}

-let mut agg_iters = self
-.aggregation_series
-.iter()
-.map(|s| s.phys_iter())
-.collect::<Vec<_>>();

let arr = ca.downcast_iter().next().unwrap();
-for (opt_v, &h) in arr.iter().zip(self.hashes.iter()) {
+for (iteration_idx, (opt_v, &h)) in arr.iter().zip(self.hashes.iter()).enumerate() {
let opt_v = opt_v.copied();
let part = hash_to_partition(h, self.pre_agg_partitions.len());
let current_partition =
unsafe { self.pre_agg_partitions.get_unchecked_release_mut(part) };
-let current_aggregators = unsafe { self.aggregators.get_unchecked_release_mut(part) };
+let current_aggregators = &mut self.aggregators;

let entry = current_partition
.raw_entry_mut()
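The other allocation win in this hunk is recycling `self.hashes`: once a row's hash has been consumed, its slot is overwritten with the row's aggregation offset. A minimal sketch of the trailing-write pattern, where the hypothetical `group_of` stands in for the hash-map lookup (polars additionally reinterprets the `u64` slots as `IdxSize`):

```rust
fn reuse_buffer(hashes: &mut [u64], group_of: impl Fn(u64) -> u64) {
    let ptr = hashes.as_mut_ptr();
    for i in 0..hashes.len() {
        // read slot i first ...
        let h = unsafe { ptr.add(i).read() };
        // ... then overwrite the slot we just consumed: the write at i
        // trails the read at i, so no unread hash is clobbered
        unsafe { ptr.add(i).write(group_of(h)) };
    }
}
```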
@@ -233,14 +242,27 @@ where
}
RawEntryMut::Occupied(entry) => *entry.get(),
};
-for (i, agg_iter) in (0 as IdxSize..num_aggs as IdxSize).zip(agg_iters.iter_mut()) {
-let i = (agg_idx + i) as usize;
-let agg_fn = unsafe { current_aggregators.get_unchecked_release_mut(i) };
+// # Safety
+// we write to the hashes buffer we iterate over at the moment.
+// this is sound because the writes trail the reads of the iteration
+unsafe { write_agg_idx(agg_idx_ptr, iteration_idx, agg_idx) };
}

+// note that this slice looks into the self.hashes buffer
+let agg_idxs = unsafe { std::slice::from_raw_parts(agg_idx_ptr, ca.len()) };

-agg_fn.pre_agg(chunk.chunk_index, agg_iter.as_mut())
-}
+for (agg_i, aggregation_s) in (0..num_aggs).zip(&self.aggregation_series) {
+let has_physical_agg = self.agg_fns[agg_i].has_physical_agg();
+apply_aggregate(
+agg_i,
+chunk.chunk_index,
+agg_idxs,
+aggregation_s,
+has_physical_agg,
+&mut self.aggregators,
+);
+}
-drop(agg_iters);

self.aggregation_series.clear();
Ok(SinkResult::CanHaveMoreInput)
}
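This restructured hot loop is the heart of the speedup: rather than advancing one boxed iterator per aggregation column per row, pass 1 records every row's group offset into `agg_idxs`, and pass 2 then feeds each column to the aggregators in one sweep. A simplified sketch — `Agg`, `pre_agg`, and the `f64` columns are stand-ins for polars' `AggregateFunction` machinery:

```rust
trait Agg {
    fn pre_agg(&mut self, value: f64);
}

fn sink_chunk<A: Agg>(
    agg_idxs: &[usize],    // pass 1 output: one group offset per row
    columns: &[Vec<f64>],  // the aggregation columns of this chunk
    aggregators: &mut [A], // flat, n_aggs per group
) {
    for (agg_i, column) in columns.iter().enumerate() {
        // column-at-a-time: one dispatch per column instead of one per
        // row per column, and a cache-friendly walk over the values
        for (&idx, &v) in agg_idxs.iter().zip(column) {
            aggregators[idx + agg_i].pre_agg(v);
        }
    }
}
```

In the real code, `has_physical_agg` additionally lets `apply_aggregate` pick a kernel specialized for the column's physical type.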
@@ -252,37 +274,36 @@ where
self.pre_agg_partitions
.iter_mut()
.zip(other.pre_agg_partitions.iter())
-.zip(self.aggregators.iter_mut())
-.zip(other.aggregators.iter())
-.for_each(
-|(((map_self, map_other), aggregators_self), aggregators_other)| {
-for (key, &agg_idx_other) in map_other.iter() {
-unsafe {
-let entry = map_self.raw_entry_mut().from_key(key);
-
-let agg_idx_self = match entry {
-RawEntryMut::Vacant(entry) => {
-let offset = NumCast::from(aggregators_self.len()).unwrap();
-entry.insert(*key, offset);
-// initialize the aggregators
-for agg_fn in &self.agg_fns {
-aggregators_self.push(agg_fn.split2())
-}
-offset
-}
-RawEntryMut::Occupied(entry) => *entry.get(),
-};
-for i in 0..self.aggregation_columns.len() {
-let agg_fn_other = aggregators_other
-.get_unchecked_release(agg_idx_other as usize + i);
-let agg_fn_self = aggregators_self
-.get_unchecked_release_mut(agg_idx_self as usize + i);
-agg_fn_self.combine(agg_fn_other.as_any())
-}
-}
-}
+.for_each(|(map_self, map_other)| {
+for (key, &agg_idx_other) in map_other.iter() {
+let entry = map_self.raw_entry_mut().from_key(key);
+
+let agg_idx_self = match entry {
+RawEntryMut::Vacant(entry) => {
+let offset = NumCast::from(self.aggregators.len()).unwrap();
+entry.insert(*key, offset);
+// initialize the aggregators
+for agg_fn in &self.agg_fns {
+self.aggregators.push(agg_fn.split2())
+}
+offset
+}
+RawEntryMut::Occupied(entry) => *entry.get(),
+};
+// combine the aggregation functions
+for i in 0..self.aggregation_columns.len() {
+unsafe {
+let agg_fn_other = other
+.aggregators
+.get_unchecked_release(agg_idx_other as usize + i);
+let agg_fn_self = self
+.aggregators
+.get_unchecked_release_mut(agg_idx_self as usize + i);
+agg_fn_self.combine(agg_fn_other.as_any())
+}
+}
-},
-);
+}
+});
}

fn finalize(&mut self, _context: &PExecutionContext) -> PolarsResult<FinalizedSink> {
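With the aggregators flattened, `combine` no longer zips per-partition aggregator buffers: the partition maps are still walked pairwise, but all groups share the sink-wide Vec, so merging another sink reduces to find-or-insert plus a pairwise combine. A compact sketch with simplified types (`Sum`, `merge`, and the `i64` keys stand in for the generic key and `AggregateFunction` machinery):

```rust
use std::collections::HashMap;

#[derive(Clone, Default)]
struct Sum(f64);

impl Sum {
    fn combine(&mut self, other: &Sum) {
        self.0 += other.0;
    }
}

fn merge(
    map_self: &mut HashMap<i64, usize>,
    aggs_self: &mut Vec<Sum>,
    map_other: &HashMap<i64, usize>,
    aggs_other: &[Sum],
    n_aggs: usize,
) {
    for (key, &idx_other) in map_other {
        let idx_self = *map_self.entry(*key).or_insert_with(|| {
            // unseen group: append fresh aggregators at the end
            let offset = aggs_self.len();
            aggs_self.extend(std::iter::repeat(Sum::default()).take(n_aggs));
            offset
        });
        // combine the two groups' aggregators pairwise
        for i in 0..n_aggs {
            aggs_self[idx_self + i].combine(&aggs_other[idx_other + i]);
        }
    }
}
```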
Second file — the utf8 groupby sink (the `groupby::string` module that now exports `apply_aggregate` and `write_agg_idx`):
@@ -34,7 +34,6 @@ pub struct Utf8GroupbySink {
// the aggregations/keys are all tightly packed
// the aggregation function of a group can be found
// by:
-// first get the correct vec by the partition index
// * offset = (idx)
// * end = (offset + 1)
keys: Vec<Option<smartstring::alias::String>>,
@@ -361,11 +360,11 @@ impl Sink for Utf8GroupbySink {
}

// write agg_idx to the hashes buffer.
-unsafe fn write_agg_idx(h: *mut IdxSize, i: usize, agg_idx: IdxSize) {
+pub(super) unsafe fn write_agg_idx(h: *mut IdxSize, i: usize, agg_idx: IdxSize) {
h.add(i).write(agg_idx)
}

-fn apply_aggregate(
+pub(super) fn apply_aggregate(
agg_i: usize,
chunk_idx: IdxSize,
agg_idxs: &[IdxSize],
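The only change to these two helpers is visibility: `pub(super)` lets the primitive sink (a sibling module, per the new import in the first file) reuse them without exposing them outside the groupby sinks. A minimal sketch of that relationship, with hypothetical module and function names:

```rust
mod groupby {
    mod string {
        // pub(super) = visible throughout the parent `groupby` module tree
        pub(super) fn helper() {}
    }

    mod primitive {
        pub(super) fn caller() {
            // a sibling module reaches the helper via the shared parent
            super::string::helper();
        }
    }
}
```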
