Skip to content

Commit

Permalink
fix(rust, python): asof join 'by', 'forward' combination (#5720)
Browse files: browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 5, 2022
1 parent 8caacab commit e6b33ce
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 15 deletions.
70 changes: 55 additions & 15 deletions polars/polars-core/src/frame/asof_join/groups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ fn process_group<K, T>(
left_asof: &[T],
right_asof: &[T],
results: &mut Vec<Option<IdxSize>>,
forward: bool,
) where
K: Hash + PartialEq + Eq,
T: NativeType + Sub<Output = T> + PartialOrd + num::Zero,
Expand All @@ -175,6 +176,9 @@ fn process_group<K, T>(
right_tbl_offsets.insert(k, (offset_slice, join_idx));
}
None => {
if forward {
previous_join_idx = None;
}
if tolerance > num::zero() {
if let Some(idx) = previous_join_idx {
debug_assert!((idx as usize) < right_asof.len());
Expand Down Expand Up @@ -204,20 +208,31 @@ where
S::Native: Hash + Eq + AsU64,
{
#[allow(clippy::type_complexity)]
let (join_asof_fn, tolerance): (
let (join_asof_fn, tolerance, forward): (
unsafe fn(T::Native, &[T::Native], &[IdxSize], T::Native) -> (Option<IdxSize>, usize),
_,
_,
) = match (tolerance, strategy) {
(Some(tolerance), AsofStrategy::Backward) => {
let tol = tolerance.extract::<T::Native>().unwrap();
(join_asof_backward_with_indirection_and_tolerance, tol)
(
join_asof_backward_with_indirection_and_tolerance,
tol,
false,
)
}
(None, AsofStrategy::Backward) => (join_asof_backward_with_indirection, T::Native::zero()),
(None, AsofStrategy::Backward) => (
join_asof_backward_with_indirection,
T::Native::zero(),
false,
),
(Some(tolerance), AsofStrategy::Forward) => {
let tol = tolerance.extract::<T::Native>().unwrap();
(join_asof_forward_with_indirection_and_tolerance, tol)
(join_asof_forward_with_indirection_and_tolerance, tol, true)
}
(None, AsofStrategy::Forward) => {
(join_asof_forward_with_indirection, T::Native::zero(), true)
}
(None, AsofStrategy::Forward) => (join_asof_forward_with_indirection, T::Native::zero()),
};

let left_asof = left_asof.rechunk();
Expand Down Expand Up @@ -297,6 +312,7 @@ where
left_asof,
right_asof,
&mut results,
forward,
);
}
// only left values, right = null
Expand All @@ -322,20 +338,31 @@ where
T: PolarsNumericType,
{
#[allow(clippy::type_complexity)]
let (join_asof_fn, tolerance): (
let (join_asof_fn, tolerance, forward): (
unsafe fn(T::Native, &[T::Native], &[IdxSize], T::Native) -> (Option<IdxSize>, usize),
_,
_,
) = match (tolerance, strategy) {
(Some(tolerance), AsofStrategy::Backward) => {
let tol = tolerance.extract::<T::Native>().unwrap();
(join_asof_backward_with_indirection_and_tolerance, tol)
(
join_asof_backward_with_indirection_and_tolerance,
tol,
false,
)
}
(None, AsofStrategy::Backward) => (join_asof_backward_with_indirection, T::Native::zero()),
(None, AsofStrategy::Backward) => (
join_asof_backward_with_indirection,
T::Native::zero(),
false,
),
(Some(tolerance), AsofStrategy::Forward) => {
let tol = tolerance.extract::<T::Native>().unwrap();
(join_asof_forward_with_indirection_and_tolerance, tol)
(join_asof_forward_with_indirection_and_tolerance, tol, true)
}
(None, AsofStrategy::Forward) => {
(join_asof_forward_with_indirection, T::Native::zero(), true)
}
(None, AsofStrategy::Forward) => (join_asof_forward_with_indirection, T::Native::zero()),
};

let left_asof = left_asof.rechunk();
Expand Down Expand Up @@ -407,6 +434,7 @@ where
left_asof,
right_asof,
&mut results,
forward,
);
}
// only left values, right = null
Expand Down Expand Up @@ -434,20 +462,31 @@ where
T: PolarsNumericType,
{
#[allow(clippy::type_complexity)]
let (join_asof_fn, tolerance): (
let (join_asof_fn, tolerance, forward): (
unsafe fn(T::Native, &[T::Native], &[IdxSize], T::Native) -> (Option<IdxSize>, usize),
_,
_,
) = match (tolerance, strategy) {
(Some(tolerance), AsofStrategy::Backward) => {
let tol = tolerance.extract::<T::Native>().unwrap();
(join_asof_backward_with_indirection_and_tolerance, tol)
(
join_asof_backward_with_indirection_and_tolerance,
tol,
false,
)
}
(None, AsofStrategy::Backward) => (join_asof_backward_with_indirection, T::Native::zero()),
(None, AsofStrategy::Backward) => (
join_asof_backward_with_indirection,
T::Native::zero(),
false,
),
(Some(tolerance), AsofStrategy::Forward) => {
let tol = tolerance.extract::<T::Native>().unwrap();
(join_asof_forward_with_indirection_and_tolerance, tol)
(join_asof_forward_with_indirection_and_tolerance, tol, true)
}
(None, AsofStrategy::Forward) => {
(join_asof_forward_with_indirection, T::Native::zero(), true)
}
(None, AsofStrategy::Forward) => (join_asof_forward_with_indirection, T::Native::zero()),
};
let left_asof = left_asof.rechunk();
let left_asof = left_asof.cont_slice().unwrap();
Expand Down Expand Up @@ -515,6 +554,7 @@ where
left_asof,
right_asof,
&mut results,
forward,
);
}
// only left values, right = null
Expand Down
20 changes: 20 additions & 0 deletions py-polars/tests/unit/test_datelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -2103,3 +2103,23 @@ def test_tz_aware_filter_lit() -> None:
),
],
}


def test_asof_join_by_forward() -> None:
    """Check a forward asof join combined with a ``by`` column.

    The right frame holds a single row (``value_two == 3``) in category "a".
    Left rows with ``value_one`` of 1, 2 and 3 are expected to match it, while
    the rows with 5 and 12 are expected to come back as null.
    """
    left = pl.DataFrame(
        {
            "category": ["a", "a", "a", "a", "a"],
            "value_one": [1, 2, 3, 5, 12],
        }
    )
    right = pl.DataFrame({"category": ["a"], "value_two": [3]})

    joined = left.join_asof(
        right,
        left_on="value_one",
        right_on="value_two",
        by="category",
        strategy="forward",
    )

    expected = {
        "category": ["a", "a", "a", "a", "a"],
        "value_one": [1, 2, 3, 5, 12],
        # Left values 5 and 12 lie past the only right key (3): expected null.
        "value_two": [3, 3, 3, None, None],
    }
    assert joined.to_dict(as_series=False) == expected

0 comments on commit e6b33ce

Please sign in to comment.