Skip to content

Commit

Permalink
improve performance of asof_join by > 2 keys (#3055)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Apr 3, 2022
1 parent 5d9d9bd commit 7337b83
Show file tree
Hide file tree
Showing 3 changed files with 225 additions and 192 deletions.
211 changes: 110 additions & 101 deletions polars/polars-core/src/frame/asof_join/groups.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use super::*;
use crate::frame::groupby::hashing::HASHMAP_INIT_SIZE;
use crate::utils::{split_ca, split_df};
use crate::vector_hasher::{df_rows_to_hashes_threaded, AsU64};
use crate::POOL;
use ahash::RandomState;
use arrow::types::NativeType;
use num::Zero;
use rayon::prelude::*;
use std::fmt::Debug;
Expand All @@ -15,22 +17,6 @@ use crate::frame::hash_join::{
create_probe_table, get_hash_tbl_threaded_join_partitioned, multiple_keys as mk, prepare_strs,
};

/// Scans `subset_idx` from back to front and returns the first index `i`
/// encountered for which `right_asof[i] <= left_val`, or `None` when no
/// entry of `subset_idx` satisfies that condition (including when
/// `subset_idx` is empty).
///
/// The values in `subset_idx` are used as positions into `right_asof`.
/// They are only bounds-checked in debug builds.
fn find_latest_leq<T>(left_val: T, right_asof: &[T], subset_idx: &[IdxSize]) -> Option<IdxSize>
where
    T: Copy + PartialOrd,
{
    // Walk the candidates last-to-first so the latest qualifying one wins.
    for &candidate in subset_idx.iter().rev() {
        let pos = candidate as usize;
        debug_assert!(pos < right_asof.len());
        // SAFETY: every entry of `subset_idx` is expected to be a valid
        // position in `right_asof`; this is asserted in debug builds above.
        let right_val = unsafe { *right_asof.get_unchecked(pos) };
        if right_val <= left_val {
            return Some(candidate);
        }
    }
    None
}

pub(super) unsafe fn join_asof_backward_with_indirection_and_tolerance<
T: PartialOrd + Copy + Sub<Output = T> + Debug,
>(
Expand Down Expand Up @@ -89,6 +75,57 @@ pub(super) unsafe fn join_asof_backward_with_indirection<T: PartialOrd + Copy +
}
}

/// Resolves the asof match for a single left row inside one `by` group and
/// appends the outcome to `results`.
///
/// A group is not processed in one go. Instead, for every `idx_left` the
/// matching right index is searched, and the progress made within the group
/// is remembered in `right_tbl_offsets` (a separate hashmap keyed by `k`).
/// The next call for the same key resumes from that stored offset.
///
/// Steps:
/// * look up `(offset, previous match)` for `k`, defaulting to `(0, None)`;
/// * call `join_asof_fn` with the left value, `right_asof`, the part of
///   `indexes_right` starting at the stored offset, and `tolerance`;
/// * on a match, push it and store the advanced offset together with the
///   match under `k`;
/// * on no match, push the previously stored match, unless `tolerance` is
///   greater than zero and the left value lies more than `tolerance` above
///   the previously matched right value, in which case `None` is pushed.
///   The stored state for `k` is left untouched in this case.
#[allow(clippy::too_many_arguments)]
#[allow(clippy::type_complexity)]
fn process_group<K, T>(
    k: K,
    idx_left: IdxSize,
    tolerance: T,
    indexes_right: &[IdxSize],
    right_tbl_offsets: &mut PlHashMap<K, (usize, Option<IdxSize>)>,
    join_asof_fn: unsafe fn(T, &[T], &[IdxSize], T) -> (Option<IdxSize>, usize),
    left_asof: &[T],
    right_asof: &[T],
    results: &mut Vec<Option<IdxSize>>,
) where
    K: Hash + PartialEq + Eq,
    T: NativeType + Sub<Output = T> + PartialOrd + num::Zero,
{
    // State left behind by an earlier call for this key, if any.
    let (start, last_match) = right_tbl_offsets
        .get(&k)
        .copied()
        .unwrap_or((0usize, None));

    debug_assert!((idx_left as usize) < left_asof.len());
    // SAFETY: `idx_left` is expected to be a valid position in `left_asof`
    // (asserted in debug builds above).
    let left_val = unsafe { *left_asof.get_unchecked(idx_left as usize) };

    // Only the not-yet-consumed tail of the group is handed to the search.
    let remaining = &indexes_right[start..];
    // SAFETY: `join_asof_fn` is an unsafe fn whose contract is not visible
    // here; presumably it requires the indices in `remaining` to be in
    // bounds of `right_asof` -- TODO confirm against its definition.
    let (found, advanced) = unsafe { join_asof_fn(left_val, right_asof, remaining, tolerance) };
    let next_start = start + advanced;

    if found.is_some() {
        results.push(found);
        right_tbl_offsets.insert(k, (next_start, found));
        return;
    }

    // No new match: fall back to the previous one, dropping it when a
    // positive tolerance is in effect and the previous match is too far away.
    let fallback = match last_match {
        Some(prev) if tolerance > num::zero() => {
            debug_assert!((prev as usize) < right_asof.len());
            // SAFETY: `prev` was returned by an earlier `join_asof_fn` call
            // and is expected to index into `right_asof` (asserted in debug
            // builds above).
            let right_val = unsafe { *right_asof.get_unchecked(prev as usize) };
            if left_val - right_val > tolerance {
                None
            } else {
                last_match
            }
        }
        _ => last_match,
    };
    results.push(fallback)
}

fn asof_join_by_numeric<T, S>(
by_left: &ChunkedArray<S>,
by_right: &ChunkedArray<S>,
Expand Down Expand Up @@ -161,7 +198,7 @@ where
// assume the result tuples equal length of the no. of hashes processed by this thread.
let mut results = Vec::with_capacity(vals_left.len());

let mut right_tbl_offsets = PlHashMap::with_capacity(64);
let mut right_tbl_offsets = PlHashMap::with_capacity(HASHMAP_INIT_SIZE);

vals_left.iter().enumerate().for_each(|(idx_a, k)| {
let idx_a = (idx_a + offset) as IdxSize;
Expand All @@ -176,42 +213,17 @@ where
match value {
// left and right matches
Some(indexes_b) => {
let (offset_slice, mut previous_join_idx) =
*right_tbl_offsets.get(k).unwrap_or(&(0usize, None));
debug_assert!((idx_a as usize) < left_asof.len());
let val_l = unsafe { *left_asof.get_unchecked(idx_a as usize) };
// Safety;
// elide bound checks
let (join_idx, offset_slice_add) = unsafe {
join_asof_fn(
val_l,
right_asof,
&indexes_b[offset_slice..],
tolerance,
)
};
let offset_slice = offset_slice + offset_slice_add;

match join_idx {
Some(_) => {
results.push(join_idx);
right_tbl_offsets.insert(k, (offset_slice, join_idx));
}
None => {
if tolerance > num::zero() {
if let Some(idx) = previous_join_idx {
debug_assert!((idx as usize) < right_asof.len());
let val_r =
unsafe { *right_asof.get_unchecked(idx as usize) };
let dist = val_l - val_r;
if dist > tolerance {
previous_join_idx = None;
}
}
}
results.push(previous_join_idx)
}
}
process_group(
*k,
idx_a,
tolerance,
indexes_b,
&mut right_tbl_offsets,
join_asof_fn,
left_asof,
right_asof,
&mut results,
);
}
// only left values, right = null
None => results.push(None),
Expand Down Expand Up @@ -287,10 +299,10 @@ where
// local reference
let hash_tbls = &hash_tbls;

// assume the result tuples equal lenght of the no. of hashes processed by this thread.
// assume the result tuples equal length of the no. of hashes processed by this thread.
let mut results = Vec::with_capacity(vals_left.len());

let mut right_tbl_offsets = PlHashMap::with_capacity(64);
let mut right_tbl_offsets = PlHashMap::with_capacity(HASHMAP_INIT_SIZE);

vals_left.iter().enumerate().for_each(|(idx_a, k)| {
let idx_a = (idx_a + offset) as IdxSize;
Expand All @@ -305,42 +317,17 @@ where
match value {
// left and right matches
Some(indexes_b) => {
let (offset_slice, mut previous_join_idx) =
*right_tbl_offsets.get(k).unwrap_or(&(0usize, None));
debug_assert!((idx_a as usize) < left_asof.len());
let val_l = unsafe { *left_asof.get_unchecked(idx_a as usize) };
// Safety;
// elide bound checks
let (join_idx, offset_slice_add) = unsafe {
join_asof_fn(
val_l,
right_asof,
&indexes_b[offset_slice..],
tolerance,
)
};
let offset_slice = offset_slice + offset_slice_add;

match join_idx {
Some(_) => {
results.push(join_idx);
right_tbl_offsets.insert(k, (offset_slice, join_idx));
}
None => {
if tolerance > num::zero() {
if let Some(idx) = previous_join_idx {
debug_assert!((idx as usize) < right_asof.len());
let val_r =
unsafe { *right_asof.get_unchecked(idx as usize) };
let dist = val_l - val_r;
if dist > tolerance {
previous_join_idx = None;
}
}
}
results.push(previous_join_idx)
}
}
process_group(
*k,
idx_a,
tolerance,
indexes_b,
&mut right_tbl_offsets,
join_asof_fn,
left_asof,
right_asof,
&mut results,
);
}
// only left values, right = null
None => results.push(None),
Expand All @@ -360,10 +347,22 @@ fn asof_join_by_multiple<T>(
b: &DataFrame,
left_asof: &ChunkedArray<T>,
right_asof: &ChunkedArray<T>,
tolerance: Option<AnyValue<'static>>,
) -> Vec<Option<IdxSize>>
where
T: PolarsNumericType,
{
#[allow(clippy::type_complexity)]
let (join_asof_fn, tolerance): (
unsafe fn(T::Native, &[T::Native], &[IdxSize], T::Native) -> (Option<IdxSize>, usize),
_,
) = match tolerance {
Some(tolerance) => {
let tol = tolerance.extract::<T::Native>().unwrap();
(join_asof_backward_with_indirection_and_tolerance, tol)
}
None => (join_asof_backward_with_indirection, T::Native::zero()),
};
let left_asof = left_asof.rechunk();
let left_asof = left_asof.cont_slice().unwrap();

Expand Down Expand Up @@ -393,18 +392,17 @@ where
.map(|(probe_hashes, offset)| {
// local reference
let hash_tbls = &hash_tbls;
let mut results =
Vec::with_capacity(probe_hashes.len() / POOL.current_num_threads());

// assume the result tuples equal length of the no. of hashes processed by this thread.
let mut results = Vec::with_capacity(probe_hashes.len());
let mut right_tbl_offsets = PlHashMap::with_capacity(HASHMAP_INIT_SIZE);

let local_offset = offset;

let mut idx_a = local_offset as IdxSize;
for probe_hashes in probe_hashes.data_views() {
for (idx, &h) in probe_hashes.iter().enumerate() {
debug_assert!(idx + offset < left_asof.len());
// Safety:
// idx are in bounds
let left_val = unsafe { *left_asof.get_unchecked(idx + offset) };

// probe table that contains the hashed value
let current_probe_table = unsafe {
get_hash_tbl_threaded_join_partitioned(h, hash_tbls, n_tables)
Expand All @@ -419,8 +417,19 @@ where

match entry {
// left and right matches
Some((_, indexes_b)) => {
results.push(find_latest_leq(left_val, right_asof, indexes_b))
Some((k, indexes_b)) => {
process_group(
// take the first idx as unique identifier of that group.
k.idx,
idx_a,
tolerance,
indexes_b,
&mut right_tbl_offsets,
join_asof_fn,
left_asof,
right_asof,
&mut results,
);
}
// only left values, right = null
None => results.push(None),
Expand Down Expand Up @@ -511,7 +520,7 @@ impl DataFrame {
#[cfg(feature = "dtype-categorical")]
check_categorical_src(lhs.dtype(), rhs.dtype())?;
}
asof_join_by_multiple(&left_by, &right_by, left_asof, right_asof)
asof_join_by_multiple(&left_by, &right_by, left_asof, right_asof, tolerance)
}
} else {
// we cannot use bit repr as that loses ordering
Expand Down Expand Up @@ -546,7 +555,7 @@ impl DataFrame {
}
}
} else {
asof_join_by_multiple(&left_by, &right_by, left_asof, right_asof)
asof_join_by_multiple(&left_by, &right_by, left_asof, right_asof, tolerance)
}
};

Expand Down

0 comments on commit 7337b83

Please sign in to comment.