Skip to content

Commit

Permalink
add tolerance to asof + by (#2937)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Mar 20, 2022
1 parent ffe5162 commit df27f82
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 16 deletions.
133 changes: 119 additions & 14 deletions polars/polars-core/src/frame/asof_join/groups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ use crate::utils::{split_ca, split_df};
use crate::vector_hasher::{df_rows_to_hashes_threaded, AsU64};
use crate::POOL;
use ahash::RandomState;
use num::Zero;
use rayon::prelude::*;
use std::fmt::Debug;
use std::hash::Hash;
use std::ops::Sub;

use crate::frame::hash_join::{
create_probe_table, get_hash_tbl_threaded_join_partitioned, multiple_keys as mk, prepare_strs,
Expand All @@ -27,10 +29,44 @@ where
.copied()
}

/// Backward asof lookup through an indirection table, limited by a maximum distance.
///
/// `offsets` holds indexes into `right`. They are visited in order, and the scan stops at the
/// first referenced value that is strictly greater than `val_l`. The candidate match is the
/// offset visited just before that point, i.e. the last one whose value is not greater than
/// `val_l`. It is only accepted when `val_l - right[candidate] <= tolerance`.
///
/// # Returns
/// A tuple `(match, consumed)`:
/// * `(None, 0)` when `offsets` is empty or `val_l` is smaller than the first referenced value.
/// * `(Some(candidate), idx)` when the scan stopped at position `idx` and the candidate lies
///   within `tolerance`.
/// * `(None, idx)` when the scan stopped at position `idx` but the candidate is too far away.
/// * `(None, offsets.len())` when no referenced value is greater than `val_l`; the callers in
///   this file then fall back on the match they remembered for the group.
///
/// # Safety
/// Every element of `offsets` must be a valid index into `right`; no bounds checks are done.
pub(super) unsafe fn join_asof_backward_with_indirection_and_tolerance<
    T: PartialOrd + Copy + Sub<Output = T> + Debug,
>(
    val_l: T,
    right: &[T],
    offsets: &[IdxSize],
    tolerance: T,
) -> (Option<IdxSize>, usize) {
    if offsets.is_empty() {
        return (None, 0);
    }
    // SAFETY: `offsets` is non-empty (checked above) and, per the function contract, all of its
    // elements are in bounds of `right`.
    let mut previous = *offsets.get_unchecked(0);
    let first = *right.get_unchecked(previous as usize);
    if val_l < first {
        (None, 0)
    } else {
        for (idx, &offset) in offsets.iter().enumerate() {
            // SAFETY: `offset` comes from `offsets`, see the function contract.
            let val_r = *right.get_unchecked(offset as usize);
            if val_r > val_l {
                // Measure the gap to the candidate that would be returned (the last value that
                // is not greater than `val_l`), not to `val_r`, which already lies beyond
                // `val_l`. Subtracting `val_r` can never exceed a non-negative tolerance and
                // underflows for unsigned types.
                // SAFETY: `previous` was read from `offsets`, see the function contract.
                let val_previous = *right.get_unchecked(previous as usize);
                let dist = val_l - val_previous;
                return if dist > tolerance {
                    (None, idx)
                } else {
                    (Some(previous), idx)
                };
            }
            previous = offset
        }
        (None, offsets.len())
    }
}

pub(super) unsafe fn join_asof_backward_with_indirection<T: PartialOrd + Copy + Debug>(
val_l: T,
right: &[T],
offsets: &[IdxSize],
// only there to have the same function signature
_: T,
) -> (Option<IdxSize>, usize) {
if offsets.is_empty() {
return (None, 0);
Expand All @@ -56,12 +92,24 @@ fn asof_join_by_numeric<T, S>(
by_right: &ChunkedArray<S>,
left_asof: &ChunkedArray<T>,
right_asof: &ChunkedArray<T>,
tolerance: Option<AnyValue<'static>>,
) -> Vec<Option<IdxSize>>
where
T: PolarsNumericType,
S: PolarsNumericType,
S::Native: Hash + Eq + AsU64,
{
#[allow(clippy::type_complexity)]
let (join_asof_fn, tolerance): (
unsafe fn(T::Native, &[T::Native], &[IdxSize], T::Native) -> (Option<IdxSize>, usize),
_,
) = match tolerance {
Some(tolerance) => {
let tol = tolerance.extract::<T::Native>().unwrap();
(join_asof_backward_with_indirection_and_tolerance, tol)
}
None => (join_asof_backward_with_indirection, T::Native::zero()),
};
let left_asof = left_asof.rechunk();
let left_asof = left_asof.cont_slice().unwrap();

Expand Down Expand Up @@ -108,7 +156,7 @@ where
// local reference
let hash_tbls = &hash_tbls;

// assume the result tuples equal lenght of the no. of hashes processed by this thread.
// assume the result tuples equal length of the no. of hashes processed by this thread.
let mut results = Vec::with_capacity(vals_left.len());

let mut right_tbl_offsets = PlHashMap::with_capacity(64);
Expand All @@ -126,16 +174,18 @@ where
match value {
// left and right matches
Some(indexes_b) => {
let (offset_slice, previous_join_idx) =
let (offset_slice, mut previous_join_idx) =
*right_tbl_offsets.get(k).unwrap_or(&(0usize, None));
let val_l = left_asof[idx_a as usize];
debug_assert!((idx_a as usize) < left_asof.len());
let val_l = unsafe { *left_asof.get_unchecked(idx_a as usize) };
// Safety;
// elide bound checks
let (join_idx, offset_slice_add) = unsafe {
join_asof_backward_with_indirection(
join_asof_fn(
val_l,
right_asof,
&indexes_b[offset_slice..],
tolerance,
)
};
let offset_slice = offset_slice + offset_slice_add;
Expand All @@ -145,7 +195,20 @@ where
results.push(join_idx);
right_tbl_offsets.insert(k, (offset_slice, join_idx));
}
None => results.push(previous_join_idx),
None => {
if tolerance > num::zero() {
if let Some(idx) = previous_join_idx {
debug_assert!((idx as usize) < right_asof.len());
let val_r =
unsafe { *right_asof.get_unchecked(idx as usize) };
let dist = val_l - val_r;
if dist > tolerance {
previous_join_idx = None;
}
}
}
results.push(previous_join_idx)
}
}
}
// only left values, right = null
Expand All @@ -164,10 +227,23 @@ fn asof_join_by_utf8<T>(
by_right: &Utf8Chunked,
left_asof: &ChunkedArray<T>,
right_asof: &ChunkedArray<T>,
tolerance: Option<AnyValue<'static>>,
) -> Vec<Option<IdxSize>>
where
T: PolarsNumericType,
{
#[allow(clippy::type_complexity)]
let (join_asof_fn, tolerance): (
unsafe fn(T::Native, &[T::Native], &[IdxSize], T::Native) -> (Option<IdxSize>, usize),
_,
) = match tolerance {
Some(tolerance) => {
let tol = tolerance.extract::<T::Native>().unwrap();
(join_asof_backward_with_indirection_and_tolerance, tol)
}
None => (join_asof_backward_with_indirection, T::Native::zero()),
};

let left_asof = left_asof.rechunk();
let left_asof = left_asof.cont_slice().unwrap();

Expand Down Expand Up @@ -227,16 +303,18 @@ where
match value {
// left and right matches
Some(indexes_b) => {
let (offset_slice, previous_join_idx) =
let (offset_slice, mut previous_join_idx) =
*right_tbl_offsets.get(k).unwrap_or(&(0usize, None));
let val_l = left_asof[idx_a as usize];
debug_assert!((idx_a as usize) < left_asof.len());
let val_l = unsafe { *left_asof.get_unchecked(idx_a as usize) };
// Safety;
// elide bound checks
let (join_idx, offset_slice_add) = unsafe {
join_asof_backward_with_indirection(
join_asof_fn(
val_l,
right_asof,
&indexes_b[offset_slice..],
tolerance,
)
};
let offset_slice = offset_slice + offset_slice_add;
Expand All @@ -246,7 +324,20 @@ where
results.push(join_idx);
right_tbl_offsets.insert(k, (offset_slice, join_idx));
}
None => results.push(previous_join_idx),
None => {
if tolerance > num::zero() {
if let Some(idx) = previous_join_idx {
debug_assert!((idx as usize) < right_asof.len());
let val_r =
unsafe { *right_asof.get_unchecked(idx as usize) };
let dist = val_l - val_r;
if dist > tolerance {
previous_join_idx = None;
}
}
}
results.push(previous_join_idx)
}
}
}
// only left values, right = null
Expand Down Expand Up @@ -348,6 +439,7 @@ impl DataFrame {
/// The keys must be sorted to perform an asof join. This is a special implementation of an asof join
/// that searches for the nearest keys within a subgroup set by `by`.
#[cfg_attr(docsrs, doc(cfg(feature = "asof_join")))]
#[allow(clippy::too_many_arguments)]
pub fn join_asof_by<I, S>(
&self,
other: &DataFrame,
Expand All @@ -356,6 +448,7 @@ impl DataFrame {
left_by: I,
right_by: I,
strategy: AsofStrategy,
tolerance: Option<AnyValue<'static>>,
) -> Result<DataFrame>
where
I: IntoIterator<Item = S>,
Expand Down Expand Up @@ -392,16 +485,21 @@ impl DataFrame {
right_by_s.utf8().unwrap(),
left_asof,
right_asof,
tolerance,
),
_ => {
if left_by_s.bit_repr_is_large() {
let left_by = left_by_s.bit_repr_large();
let right_by = right_by_s.bit_repr_large();
asof_join_by_numeric(&left_by, &right_by, left_asof, right_asof)
asof_join_by_numeric(
&left_by, &right_by, left_asof, right_asof, tolerance,
)
} else {
let left_by = left_by_s.bit_repr_small();
let right_by = right_by_s.bit_repr_small();
asof_join_by_numeric(&left_by, &right_by, left_asof, right_asof)
asof_join_by_numeric(
&left_by, &right_by, left_asof, right_asof, tolerance,
)
}
}
}
Expand All @@ -422,16 +520,21 @@ impl DataFrame {
right_by_s.utf8().unwrap(),
left_asof,
right_asof,
tolerance,
),
_ => {
if left_by_s.bit_repr_is_large() {
let left_by = left_by_s.bit_repr_large();
let right_by = right_by_s.bit_repr_large();
asof_join_by_numeric(&left_by, &right_by, left_asof, right_asof)
asof_join_by_numeric(
&left_by, &right_by, left_asof, right_asof, tolerance,
)
} else {
let left_by = left_by_s.bit_repr_small();
let right_by = right_by_s.bit_repr_small();
asof_join_by_numeric(&left_by, &right_by, left_asof, right_asof)
asof_join_by_numeric(
&left_by, &right_by, left_asof, right_asof, tolerance,
)
}
}
}
Expand Down Expand Up @@ -487,7 +590,7 @@ mod test {
"right_vals" => [1, 2, 3, 4]
]?;

let out = a.join_asof_by(&b, "a", "a", ["b"], ["b"], AsofStrategy::Backward)?;
let out = a.join_asof_by(&b, "a", "a", ["b"], ["b"], AsofStrategy::Backward, None)?;
assert_eq!(out.get_column_names(), &["a", "b", "right_vals"]);
let out = out.column("right_vals").unwrap();
let out = out.i32().unwrap();
Expand Down Expand Up @@ -529,6 +632,7 @@ mod test {
["ticker"],
["ticker"],
AsofStrategy::Backward,
None,
)?;
let a = out.column("bid_right").unwrap();
let a = a.f64().unwrap();
Expand All @@ -543,6 +647,7 @@ mod test {
["groups_numeric"],
["groups_numeric"],
AsofStrategy::Backward,
None,
)?;
let a = out.column("bid_right").unwrap();
let a = a.f64().unwrap();
Expand Down
1 change: 1 addition & 0 deletions polars/polars-core/src/frame/hash_join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ impl DataFrame {
left_by,
right_by,
options.strategy,
options.tolerance,
),
(None, None) => self.join_asof(
other,
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3205,7 +3205,7 @@ def join_asof(
Perform an asof join. This is similar to a left-join except that we
match on nearest key rather than equal keys.
Both DataFrames must be sorted by the key.
Both DataFrames must be sorted by the join_asof key.
For each row in the left DataFrame:
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/lazy_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,7 +940,7 @@ def join_asof(
Perform an asof join. This is similar to a left-join except that we
match on nearest key rather than equal keys.
Both DataFrames must be sorted by the key.
Both DataFrames must be sorted by the join_asof key.
For each row in the left DataFrame:
Expand Down
25 changes: 25 additions & 0 deletions py-polars/tests/test_datelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,3 +573,28 @@ def test_strptime_dates_datetimes() -> None:
datetime(2021, 4, 22, 0, 0),
datetime(2022, 1, 4, 0, 0),
]


def test_asof_join_tolerance_grouper() -> None:
    """Check that `join_asof` with a `by` group respects `tolerance`.

    The left row dated 2020-01-10 is more than three days after every right
    row, so its joined `values` entry is expected to be null.
    """
    from datetime import date

    left = pl.DataFrame(
        {
            "date": [date(2020, 1, 5), date(2020, 1, 10)],
            "by": [1, 1],
        }
    )
    right = pl.DataFrame(
        {
            "date": [date(2020, 1, 5), date(2020, 1, 6)],
            "by": [1, 1],
            "values": [100, 200],
        }
    )
    expected = pl.DataFrame(
        {
            "date": [date(2020, 1, 5), date(2020, 1, 10)],
            "by": [1, 1],
            "values": [100, None],
        }
    )

    joined = left.join_asof(right, by="by", on="date", tolerance="3d")

    assert joined.frame_equal(expected)

0 comments on commit df27f82

Please sign in to comment.