Skip to content

Commit

Permalink
add join_asof tolerances (#2758)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Feb 24, 2022
1 parent bd424be commit 4044d75
Show file tree
Hide file tree
Showing 19 changed files with 402 additions and 91 deletions.
29 changes: 29 additions & 0 deletions polars/polars-core/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,35 @@ pub enum AnyValue<'a> {
Object(&'a dyn PolarsObjectSafe),
}

impl<'a> AnyValue<'a> {
/// Extract a numerical value from the AnyValue
#[doc(hidden)]
#[cfg(feature = "private")]
pub fn extract<T: NumCast>(&self) -> Option<T> {
use AnyValue::*;
match self {
Null => None,
Int8(v) => NumCast::from(*v),
Int16(v) => NumCast::from(*v),
Int32(v) => NumCast::from(*v),
Int64(v) => NumCast::from(*v),
UInt8(v) => NumCast::from(*v),
UInt16(v) => NumCast::from(*v),
UInt32(v) => NumCast::from(*v),
UInt64(v) => NumCast::from(*v),
Float32(v) => NumCast::from(*v),
Float64(v) => NumCast::from(*v),
#[cfg(feature = "dtype-date")]
Date(v) => NumCast::from(*v),
#[cfg(feature = "dtype-datetime")]
Datetime(v, _, _) => NumCast::from(*v),
#[cfg(feature = "dtype-duration")]
Duration(v, _) => NumCast::from(*v),
_ => unimplemented!(),
}
}
}

impl<'a> Hash for AnyValue<'a> {
fn hash<H: Hasher>(&self, state: &mut H) {
use AnyValue::*;
Expand Down
155 changes: 152 additions & 3 deletions polars/polars-core/src/frame/asof_join/asof.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,126 @@
use polars_arrow::index::IdxSize;
use std::fmt::Debug;
use std::ops::Sub;

pub(super) fn join_asof_forward_with_tolerance<T: PartialOrd + Copy + Debug + Sub<Output = T>>(
left: &[T],
right: &[T],
tolerance: T,
) -> Vec<Option<IdxSize>> {
if right.is_empty() {
return vec![None; left.len()];
}
if left.is_empty() {
return vec![];
}

let mut out = Vec::with_capacity(left.len());
let mut offset = 0 as IdxSize;

for &val_l in left {
loop {
match right.get(offset as usize) {
Some(&val_r) => {
if val_r >= val_l {
let dist = val_r - val_l;
let value = if dist > tolerance { None } else { Some(offset) };

out.push(value);
break;
}
offset += 1;
}
None => {
out.extend(std::iter::repeat(None).take(left.len() - out.len()));
return out;
}
}
}
}
out
}

pub(super) fn join_asof_backward_with_tolerance<T>(
left: &[T],
right: &[T],
tolerance: T,
) -> Vec<Option<IdxSize>>
where
T: PartialOrd + Copy + Debug + Sub<Output = T>,
{
if right.is_empty() {
return vec![None; left.len()];
}
if left.is_empty() {
return vec![];
}
let mut out = Vec::with_capacity(left.len());

let mut offset = 0 as IdxSize;
// left array could start lower than right;
// left: [-1, 0, 1, 2],
// right: [1, 2, 3]
// first values should be None, until left has catched up
let mut left_catched_up = false;

// init with left so that the distance starts at 0
let mut previous_right = left[0];
let mut dist;

for &val_l in left {
loop {
dist = val_l - previous_right;

match right.get(offset as usize) {
Some(&val_r) => {
// we fill nulls until left value is larger than right
if !left_catched_up {
if val_l < val_r {
out.push(None);
break;
} else {
left_catched_up = true;
}
}

// right is larger than left.
// we take the last value before that
if val_r > val_l {
let value = if dist > tolerance {
None
} else {
Some(offset - 1)
};

out.push(value);
break;
}
// right still smaller or equal to left
// continue looping the right side
else {
previous_right = val_r;
offset += 1;
}
}
// we depleted the right array
// we cannot fill the remainder of the value, because we need to check tolerances
None => {
// if we have previous value, continue with that one
let val = if left_catched_up && dist <= tolerance {
Some(offset - 1)
}
// else null
else {
None
};
out.push(val);
break;
}
}
}
}
out
}

pub(super) fn join_asof_backward<T: PartialOrd + Copy + Debug>(
left: &[T],
Expand All @@ -18,6 +139,7 @@ pub(super) fn join_asof_backward<T: PartialOrd + Copy + Debug>(
loop {
match right.get(offset as usize) {
Some(&val_r) => {
// we fill nulls until left value is larger than right
if !left_catched_up {
if val_l < val_r {
out.push(None);
Expand All @@ -27,18 +149,26 @@ pub(super) fn join_asof_backward<T: PartialOrd + Copy + Debug>(
}
}

// the branch where
// right is larger than left.
// we take the last value before that
if val_r > val_l {
out.push(Some(offset - 1));
break;
} else {
}
// right still smaller or equal to left
// continue looping the right side
else {
offset += 1;
}
}
// we depleted the right array
None => {
// if we have previous value, continue with that one
let val = if left_catched_up {
Some(offset - 1)
} else {
}
// else all null
else {
None
};
out.extend(std::iter::repeat(val).take(left.len() - out.len()));
Expand Down Expand Up @@ -100,6 +230,25 @@ mod test {
assert_eq!(tuples, &[Some(1), Some(3), Some(3), Some(3)]);
}

#[test]
fn test_asof_backward_tolerance() {
let a = [-1, 20, 25, 30, 30, 40];
let b = [10, 20, 30, 30];
let tuples = join_asof_backward_with_tolerance(&a, &b, 4);
assert_eq!(tuples, &[None, Some(1), None, Some(3), Some(3), None]);
}

#[test]
fn test_asof_forward_tolerance() {
let a = [-1, 20, 25, 30, 30, 40, 52];
let b = [10, 20, 33, 55];
let tuples = join_asof_forward_with_tolerance(&a, &b, 4);
assert_eq!(
tuples,
&[None, Some(1), None, Some(2), Some(2), None, Some(3)]
);
}

#[test]
fn test_asof_forward() {
let a = [-1, 1, 2, 4, 6];
Expand Down
38 changes: 16 additions & 22 deletions polars/polars-core/src/frame/asof_join/groups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,15 +370,16 @@ impl DataFrame {
let right_asof = other.column(right_on)?;
let right_asof_name = right_asof.name();

check_asof_columns(left_asof, right_asof)?;

let left_by = self.select(left_by)?;
let right_by = other.select(right_by)?;

let left_by_s = &left_by.get_columns()[0];
let right_by_s = &right_by.get_columns()[0];

let right_join_tuples = if left_asof.bit_repr_is_large() {
let left_asof = left_asof.bit_repr_large();
let right_asof = right_asof.bit_repr_large();
// we cannot use bit repr as that loses ordering
let left_asof = left_asof.cast(&DataType::Int64)?;
let right_asof = right_asof.cast(&DataType::Int64)?;
let left_asof = left_asof.i64().unwrap();
Expand All @@ -394,24 +395,21 @@ impl DataFrame {
),
_ => {
if left_by_s.bit_repr_is_large() {
let left_by = left_by_s.cast(&DataType::Int64).unwrap();
let left_by = left_by.i64().unwrap();
let right_by = right_by_s.cast(&DataType::Int64).unwrap();
let right_by = right_by.i64().unwrap();
asof_join_by_numeric(left_by, right_by, left_asof, right_asof)
let left_by = left_by_s.bit_repr_large();
let right_by = right_by_s.bit_repr_large();
asof_join_by_numeric(&left_by, &right_by, left_asof, right_asof)
} else {
let left_by = left_by_s.cast(&DataType::Int32).unwrap();
let left_by = left_by.i32().unwrap();
let right_by = right_by_s.cast(&DataType::Int32).unwrap();
let right_by = right_by.i32().unwrap();
asof_join_by_numeric(left_by, right_by, left_asof, right_asof)
let left_by = left_by_s.bit_repr_small();
let right_by = right_by_s.bit_repr_small();
asof_join_by_numeric(&left_by, &right_by, left_asof, right_asof)
}
}
}
} else {
asof_join_by_multiple(&left_by, &right_by, left_asof, right_asof)
}
} else {
// we cannot use bit repr as that loses ordering
let left_asof = left_asof.cast(&DataType::Int32)?;
let right_asof = right_asof.cast(&DataType::Int32)?;
let left_asof = left_asof.i32().unwrap();
Expand All @@ -427,17 +425,13 @@ impl DataFrame {
),
_ => {
if left_by_s.bit_repr_is_large() {
let left_by = left_by_s.cast(&DataType::Int64).unwrap();
let left_by = left_by.i64().unwrap();
let right_by = right_by_s.cast(&DataType::Int64).unwrap();
let right_by = right_by.i64().unwrap();
asof_join_by_numeric(left_by, right_by, left_asof, right_asof)
let left_by = left_by_s.bit_repr_large();
let right_by = right_by_s.bit_repr_large();
asof_join_by_numeric(&left_by, &right_by, left_asof, right_asof)
} else {
let left_by = left_by_s.cast(&DataType::Int32).unwrap();
let left_by = left_by.i32().unwrap();
let right_by = right_by_s.cast(&DataType::Int32).unwrap();
let right_by = right_by.i32().unwrap();
asof_join_by_numeric(left_by, right_by, left_asof, right_asof)
let left_by = left_by_s.bit_repr_small();
let right_by = right_by_s.bit_repr_small();
asof_join_by_numeric(&left_by, &right_by, left_asof, right_asof)
}
}
}
Expand Down

0 comments on commit 4044d75

Please sign in to comment.