Skip to content

Commit

Permalink
sorted_merge_join (#3505)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 4, 2022
1 parent 1f10a2d commit 9460e46
Show file tree
Hide file tree
Showing 9 changed files with 515 additions and 16 deletions.
42 changes: 30 additions & 12 deletions polars/polars-arrow/src/data_types.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
pub trait IsFloat {
/// # Safety
/// unsafe code downstream relies on the correct is_float call
pub unsafe trait IsFloat: private::Sealed {
fn is_float() -> bool {
false
}
Expand All @@ -12,20 +14,36 @@ pub trait IsFloat {
}
}

impl IsFloat for i8 {}
impl IsFloat for i16 {}
impl IsFloat for i32 {}
impl IsFloat for i64 {}
impl IsFloat for u8 {}
impl IsFloat for u16 {}
impl IsFloat for u32 {}
impl IsFloat for u64 {}
impl IsFloat for &str {}
impl<T: IsFloat> IsFloat for Option<T> {}
unsafe impl IsFloat for i8 {}
unsafe impl IsFloat for i16 {}
unsafe impl IsFloat for i32 {}
unsafe impl IsFloat for i64 {}
unsafe impl IsFloat for u8 {}
unsafe impl IsFloat for u16 {}
unsafe impl IsFloat for u32 {}
unsafe impl IsFloat for u64 {}
unsafe impl IsFloat for &str {}
unsafe impl<T: IsFloat> IsFloat for Option<T> {}

mod private {
pub trait Sealed {}
impl Sealed for i8 {}
impl Sealed for i16 {}
impl Sealed for i32 {}
impl Sealed for i64 {}
impl Sealed for u8 {}
impl Sealed for u16 {}
impl Sealed for u32 {}
impl Sealed for u64 {}
impl Sealed for f32 {}
impl Sealed for f64 {}
impl Sealed for &str {}
impl<T: Sealed> Sealed for Option<T> {}
}

macro_rules! impl_is_float {
($tp:ty) => {
impl IsFloat for $tp {
unsafe impl IsFloat for $tp {
fn is_float() -> bool {
true
}
Expand Down
1 change: 1 addition & 0 deletions polars/polars-arrow/src/kernels/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub mod list;
pub mod rolling;
pub mod set;
pub mod sort_partition;
pub mod sorted_join;
#[cfg(feature = "strings")]
pub mod string;
pub mod take;
Expand Down
141 changes: 141 additions & 0 deletions polars/polars-arrow/src/kernels/sorted_join/inner.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
use super::*;

pub fn join<T: PartialOrd + Copy + Debug>(
left: &[T],
right: &[T],
left_offset: IdxSize,
) -> InnerJoinIds {
if left.is_empty() || right.is_empty() {
return (vec![], vec![]);
}

// * 1.5 because of possible duplicates
let cap = (std::cmp::min(left.len(), right.len()) as f32 * 1.5) as usize;
let mut out_rhs = Vec::with_capacity(cap);
let mut out_lhs = Vec::with_capacity(cap);

let mut right_idx = 0 as IdxSize;
// left array could start lower than right;
// left: [-1, 0, 1, 2],
// right: [1, 2, 3]
let first_right = right[0];
let mut left_idx = left.partition_point(|v| v < &first_right) as IdxSize;

for &val_l in &left[left_idx as usize..] {
while let Some(&val_r) = right.get(right_idx as usize) {
// matching join key
if val_l == val_r {
out_lhs.push(left_idx + left_offset);
out_rhs.push(right_idx);
let current_idx = right_idx;

loop {
right_idx += 1;
match right.get(right_idx as usize) {
// rhs depleted
None => {
// reset right index because the next lhs value can be the same
right_idx = current_idx;
break;
}
Some(&val_r) => {
if val_l == val_r {
out_lhs.push(left_idx + left_offset);
out_rhs.push(right_idx);
} else {
// reset right index because the next lhs value can be the same
right_idx = current_idx;
break;
}
}
}
}
break;
}

// right is larger than left.
if val_r > val_l {
break;
}
// continue looping the right side
right_idx += 1;
}
// loop {
// match right.get(right_idx as usize) {
// Some(&val_r) => {
// // matching join key
// if val_l == val_r {
// out_lhs.push(left_idx + left_offset);
// out_rhs.push(right_idx);
// let current_idx = right_idx;
//
// loop {
// right_idx += 1;
// match right.get(right_idx as usize) {
// // rhs depleted
// None => {
// // reset right index because the next lhs value can be the same
// right_idx = current_idx;
// break;
// }
// Some(&val_r) => {
// if val_l == val_r {
// out_lhs.push(left_idx + left_offset);
// out_rhs.push(right_idx);
// } else {
// // reset right index because the next lhs value can be the same
// right_idx = current_idx;
// break;
// }
// }
// }
// }
// break;
// }
//
// // right is larger than left.
// if val_r > val_l {
// break;
// }
// // continue looping the right side
// right_idx += 1;
// }
// // we depleted the right array
// None => {
// break;
// }
// }
// }
left_idx += 1;
}
(out_lhs, out_rhs)
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_inner_join() {
let lhs = &[0, 1, 1, 2, 3, 5];
let rhs = &[0, 1, 1, 3, 4];

let (l_idx, r_idx) = join(lhs, rhs, 0);

assert_eq!(&l_idx, &[0, 1, 1, 2, 2, 4]);
assert_eq!(&r_idx, &[0, 1, 2, 1, 2, 3]);

let lhs = &[4, 4, 4, 4, 5, 6, 6, 7, 7, 7];
let rhs = &[0, 1, 2, 3, 4, 4, 4, 6, 7, 7];
let (l_idx, r_idx) = join(lhs, rhs, 0);

assert_eq!(
&l_idx,
&[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 5, 6, 7, 7, 8, 8, 9, 9]
);
assert_eq!(
&r_idx,
&[4, 5, 6, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 7, 8, 9, 8, 9, 8, 9]
);
}
}

0 comments on commit 9460e46

Please sign in to comment.