Skip to content

Commit

Permalink
perf(rust, python): improve reducing window function performance ~33% (
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 21, 2022
1 parent b62b95f commit a9d2528
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 47 deletions.
126 changes: 79 additions & 47 deletions polars/polars-lazy/src/physical_plan/expressions/window.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
use std::fmt::Write;
use std::sync::Arc;

use polars_arrow::bit_util::unset_bit_raw;
use polars_arrow::export::arrow::array::PrimitiveArray;
use polars_core::frame::groupby::{GroupBy, GroupsProxy};
use polars_core::frame::hash_join::{
default_join_ids, private_left_join_multiple_keys, JoinOptIds,
};
use polars_core::prelude::*;
use polars_core::series::IsSorted;
use polars_core::utils::_split_offsets;
use polars_core::utils::arrow::bitmap::MutableBitmap;
use polars_core::{downcast_as_macro_arg_physical, POOL};
use polars_utils::sort::perfect_sort;
use polars_utils::sync::SyncPtr;
use rayon::prelude::*;

use super::*;
use crate::physical_plan::expression_err;
Expand Down Expand Up @@ -685,86 +689,114 @@ where
}
}
}
let mut values = Vec::with_capacity(len);
let ptr = values.as_mut_ptr() as *mut T::Native;
// safety:
// we will write from different threads but we will never alias.
let sync_ptr_values = unsafe { SyncPtr::new(ptr) };

if ca.null_count() == 0 {
let mut values = Vec::with_capacity(len);
let ptr = values.as_mut_ptr() as *mut T::Native;

match groups {
GroupsProxy::Idx(groups) => {
// this should always succeed as we don't expect any chunks after an aggregation
let agg_vals = ca.cont_slice().ok()?;
agg_vals.iter().zip(groups.all().iter()).for_each(|(v, g)| {
for idx in g {
debug_assert!((*idx as usize) < len);
unsafe { *ptr.add(*idx as usize) = *v }
}
POOL.install(|| {
agg_vals
.par_iter()
.zip(groups.all().par_iter())
.for_each(|(v, g)| {
let ptr = sync_ptr_values.get();
for idx in g {
debug_assert!((*idx as usize) < len);
unsafe { *ptr.add(*idx as usize) = *v }
}
})
})
}
GroupsProxy::Slice { groups, .. } => {
// this should always succeed as we don't expect any chunks after an aggregation
let agg_vals = ca.cont_slice().ok()?;
for (v, [start, g_len]) in agg_vals.iter().zip(groups.iter()) {
let start = *start as usize;
let end = start + *g_len as usize;
for idx in start..end {
debug_assert!(idx < len);
unsafe { *ptr.add(idx) = *v }
}
}
POOL.install(|| {
agg_vals
.par_iter()
.zip(groups.par_iter())
.for_each(|(v, [start, g_len])| {
let ptr = sync_ptr_values.get();
let start = *start as usize;
let end = start + *g_len as usize;
for idx in start..end {
debug_assert!(idx < len);
unsafe { *ptr.add(idx) = *v }
}
})
});
}
}

// safety: we have written all slots
unsafe { values.set_len(len) }
Some(ChunkedArray::new_vec(ca.name(), values).into_series())
} else {
let mut values = Vec::with_capacity(len);
let mut validity = MutableBitmap::with_capacity(len);
validity.extend_constant(len, true);
let ptr = values.as_mut_ptr() as *mut T::Native;
let validity_ptr = validity.as_slice_mut().as_mut_ptr();
let sync_ptr_validity = unsafe { SyncPtr::new(validity_ptr) };

let n_threads = POOL.current_num_threads();
let offsets = _split_offsets(ca.len(), n_threads);
match groups {
GroupsProxy::Idx(groups) => {
ca.into_iter()
.zip(groups.all().iter())
.for_each(|(opt_v, g)| {
for idx in g {
let idx = *idx as usize;
GroupsProxy::Idx(groups) => offsets.par_iter().for_each(|(offset, offset_len)| {
let offset = *offset;
let offset_len = *offset_len;
let ca = ca.slice(offset as i64, offset_len);
let groups = &groups.all()[offset..offset + offset_len];
let values_ptr = sync_ptr_values.get();

ca.into_iter().zip(groups.iter()).for_each(|(opt_v, g)| {
for idx in g {
let idx = *idx as usize;
debug_assert!(idx < len);
unsafe {
match opt_v {
Some(v) => {
*values_ptr.add(idx) = v;
}
None => {
*values_ptr.add(idx) = T::Native::default();
unset_bit_raw(sync_ptr_validity.get(), idx)
}
};
}
}
})
}),
GroupsProxy::Slice { groups, .. } => {
offsets.par_iter().for_each(|(offset, offset_len)| {
let offset = *offset;
let offset_len = *offset_len;
let ca = ca.slice(offset as i64, offset_len);
let groups = &groups[offset..offset + offset_len];
let values_ptr = sync_ptr_values.get();

for (opt_v, [start, g_len]) in ca.into_iter().zip(groups.iter()) {
let start = *start as usize;
let end = start + *g_len as usize;
for idx in start..end {
debug_assert!(idx < len);
unsafe {
match opt_v {
Some(v) => {
*ptr.add(idx) = v;
*values_ptr.add(idx) = v;
}
None => {
*ptr.add(idx) = T::Native::default();
validity.set_unchecked(idx, false);
*values_ptr.add(idx) = T::Native::default();
unset_bit_raw(sync_ptr_validity.get(), idx)
}
};
}
}
})
}
GroupsProxy::Slice { groups, .. } => {
for (opt_v, [start, g_len]) in ca.into_iter().zip(groups.iter()) {
let start = *start as usize;
let end = start + *g_len as usize;
for idx in start..end {
debug_assert!(idx < len);
unsafe {
match opt_v {
Some(v) => {
*ptr.add(idx) = v;
}
None => {
*ptr.add(idx) = T::Native::default();
validity.set_unchecked(idx, false);
}
};
}
}
}
})
}
}
// safety: we have written all slots
Expand Down
1 change: 1 addition & 0 deletions polars/polars-utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ mod hash;
pub mod mem;
pub mod slice;
pub mod sort;
pub mod sync;
pub mod unwrap;

pub use functions::*;
Expand Down
22 changes: 22 additions & 0 deletions polars/polars-utils/src/sync.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/// Utility that allows use to send pointers to another thread.
/// This is better than going through `usize` as MIRI can follow these.
#[derive(Copy, Clone)]
pub struct SyncPtr<T>(*mut T);

impl<T> SyncPtr<T> {
/// # Safety
///
/// This will make a pointer sync and send.
/// Ensure that you don't break aliasing rules.
pub unsafe fn new(ptr: *mut T) -> Self {
Self(ptr)
}

#[inline(always)]
pub fn get(self) -> *mut T {
self.0
}
}

unsafe impl<T> Sync for SyncPtr<T> {}
unsafe impl<T> Send for SyncPtr<T> {}
29 changes: 29 additions & 0 deletions py-polars/tests/unit/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,32 @@ def test_window_5868() -> None:
"value": [None, 2],
"id": [None, 1],
}

df = pl.DataFrame({"a": [None, 1, 2, 3, 3, 3, 4, 4]})

assert df.select(pl.col("a").sum().over("a"))["a"].to_list() == [
None,
1,
2,
9,
9,
9,
8,
8,
]
assert df.with_column(pl.col("a").set_sorted()).select(pl.col("a").sum().over("a"))[
"a"
].to_list() == [None, 1, 2, 9, 9, 9, 8, 8]

assert df.drop_nulls().select(pl.col("a").sum().over("a"))["a"].to_list() == [
1,
2,
9,
9,
9,
8,
8,
]
assert df.drop_nulls().with_column(pl.col("a").set_sorted()).select(
pl.col("a").sum().over("a")
)["a"].to_list() == [1, 2, 9, 9, 9, 8, 8]

0 comments on commit a9d2528

Please sign in to comment.