Skip to content

Commit

Permalink
feat(rust, python): Implement forward strategy in groupby join_asof (#…
Browse files Browse the repository at this point in the history
…5335)

Co-authored-by: Alec Zorab <alec@zorab.io>
  • Loading branch information
AlecZorab and AlecZorab committed Oct 26, 2022
1 parent 4a5a1b2 commit db8ced1
Show file tree
Hide file tree
Showing 2 changed files with 303 additions and 17 deletions.
244 changes: 227 additions & 17 deletions polars/polars-core/src/frame/asof_join/groups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,35 @@ pub(super) unsafe fn join_asof_backward_with_indirection_and_tolerance<
}
}

pub(super) unsafe fn join_asof_forward_with_indirection_and_tolerance<
T: PartialOrd + Copy + Sub<Output = T> + Debug,
>(
val_l: T,
right: &[T],
offsets: &[IdxSize],
tolerance: T,
) -> (Option<IdxSize>, usize) {
if offsets.is_empty() {
return (None, 0);
}
let last_offset = *offsets.get_unchecked(offsets.len() - 1);
let last_value = *right.get_unchecked(last_offset as usize);
if val_l <= last_value {
for (idx, &offset) in offsets.iter().enumerate() {
let val_r = *right.get_unchecked(offset as usize);
if val_r >= val_l {
let dist = val_r - val_l;
return if dist > tolerance {
(None, idx)
} else {
(Some(offset), idx)
};
}
}
}
(None, offsets.len())
}

pub(super) unsafe fn join_asof_backward_with_indirection<T: PartialOrd + Copy + Debug>(
val_l: T,
right: &[T],
Expand All @@ -87,6 +116,29 @@ pub(super) unsafe fn join_asof_backward_with_indirection<T: PartialOrd + Copy +
}
}

pub(super) unsafe fn join_asof_forward_with_indirection<T: PartialOrd + Copy + Debug>(
val_l: T,
right: &[T],
offsets: &[IdxSize],
// only there to have the same function signature
_: T,
) -> (Option<IdxSize>, usize) {
if offsets.is_empty() {
return (None, 0);
}
let last_offset = *offsets.get_unchecked(offsets.len() - 1);
let last_value = *right.get_unchecked(last_offset as usize);
if val_l <= last_value {
for (idx, &offset) in offsets.iter().enumerate() {
let val_r = *right.get_unchecked(offset as usize);
if val_r >= val_l {
return (Some(offset), idx);
}
}
}
(None, offsets.len())
}

// process the group taken by the `by` operation and keep track of the offset.
// we don't process a group at once but per `index_left` we find the `right_index` and keep track
// of the offsets we have already processed in a separate hashmap. Then on a next iteration we can
Expand Down Expand Up @@ -144,6 +196,7 @@ fn asof_join_by_numeric<T, S>(
left_asof: &ChunkedArray<T>,
right_asof: &ChunkedArray<T>,
tolerance: Option<AnyValue<'static>>,
strategy: AsofStrategy,
) -> PolarsResult<Vec<Option<IdxSize>>>
where
T: PolarsNumericType,
Expand All @@ -154,13 +207,19 @@ where
let (join_asof_fn, tolerance): (
unsafe fn(T::Native, &[T::Native], &[IdxSize], T::Native) -> (Option<IdxSize>, usize),
_,
) = match tolerance {
Some(tolerance) => {
) = match (tolerance, strategy) {
(Some(tolerance), AsofStrategy::Backward) => {
let tol = tolerance.extract::<T::Native>().unwrap();
(join_asof_backward_with_indirection_and_tolerance, tol)
}
None => (join_asof_backward_with_indirection, T::Native::zero()),
(None, AsofStrategy::Backward) => (join_asof_backward_with_indirection, T::Native::zero()),
(Some(tolerance), AsofStrategy::Forward) => {
let tol = tolerance.extract::<T::Native>().unwrap();
(join_asof_forward_with_indirection_and_tolerance, tol)
}
(None, AsofStrategy::Forward) => (join_asof_forward_with_indirection, T::Native::zero()),
};

let left_asof = left_asof.rechunk();
let err = |_: PolarsError| {
PolarsError::ComputeError("Keys are not allowed to have null values in asof join.".into())
Expand Down Expand Up @@ -257,6 +316,7 @@ fn asof_join_by_utf8<T>(
left_asof: &ChunkedArray<T>,
right_asof: &ChunkedArray<T>,
tolerance: Option<AnyValue<'static>>,
strategy: AsofStrategy,
) -> Vec<Option<IdxSize>>
where
T: PolarsNumericType,
Expand All @@ -265,12 +325,17 @@ where
let (join_asof_fn, tolerance): (
unsafe fn(T::Native, &[T::Native], &[IdxSize], T::Native) -> (Option<IdxSize>, usize),
_,
) = match tolerance {
Some(tolerance) => {
) = match (tolerance, strategy) {
(Some(tolerance), AsofStrategy::Backward) => {
let tol = tolerance.extract::<T::Native>().unwrap();
(join_asof_backward_with_indirection_and_tolerance, tol)
}
None => (join_asof_backward_with_indirection, T::Native::zero()),
(None, AsofStrategy::Backward) => (join_asof_backward_with_indirection, T::Native::zero()),
(Some(tolerance), AsofStrategy::Forward) => {
let tol = tolerance.extract::<T::Native>().unwrap();
(join_asof_forward_with_indirection_and_tolerance, tol)
}
(None, AsofStrategy::Forward) => (join_asof_forward_with_indirection, T::Native::zero()),
};

let left_asof = left_asof.rechunk();
Expand Down Expand Up @@ -363,6 +428,7 @@ fn asof_join_by_multiple<T>(
left_asof: &ChunkedArray<T>,
right_asof: &ChunkedArray<T>,
tolerance: Option<AnyValue<'static>>,
strategy: AsofStrategy,
) -> Vec<Option<IdxSize>>
where
T: PolarsNumericType,
Expand All @@ -371,12 +437,17 @@ where
let (join_asof_fn, tolerance): (
unsafe fn(T::Native, &[T::Native], &[IdxSize], T::Native) -> (Option<IdxSize>, usize),
_,
) = match tolerance {
Some(tolerance) => {
) = match (tolerance, strategy) {
(Some(tolerance), AsofStrategy::Backward) => {
let tol = tolerance.extract::<T::Native>().unwrap();
(join_asof_backward_with_indirection_and_tolerance, tol)
}
None => (join_asof_backward_with_indirection, T::Native::zero()),
(None, AsofStrategy::Backward) => (join_asof_backward_with_indirection, T::Native::zero()),
(Some(tolerance), AsofStrategy::Forward) => {
let tol = tolerance.extract::<T::Native>().unwrap();
(join_asof_forward_with_indirection_and_tolerance, tol)
}
(None, AsofStrategy::Forward) => (join_asof_forward_with_indirection, T::Native::zero()),
};
let left_asof = left_asof.rechunk();
let left_asof = left_asof.cont_slice().unwrap();
Expand Down Expand Up @@ -479,10 +550,6 @@ impl DataFrame {
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
if let AsofStrategy::Forward = strategy {
panic!("forward strategy + groupby not yet implemented");
}

use DataType::*;
let left_asof = self.column(left_on)?;
let right_asof = other.column(right_on)?;
Expand Down Expand Up @@ -511,19 +578,20 @@ impl DataFrame {
left_asof,
right_asof,
tolerance,
strategy,
),
_ => {
if left_by_s.bit_repr_is_large() {
let left_by = left_by_s.bit_repr_large();
let right_by = right_by_s.bit_repr_large();
asof_join_by_numeric(
&left_by, &right_by, left_asof, right_asof, tolerance,
&left_by, &right_by, left_asof, right_asof, tolerance, strategy,
)?
} else {
let left_by = left_by_s.bit_repr_small();
let right_by = right_by_s.bit_repr_small();
asof_join_by_numeric(
&left_by, &right_by, left_asof, right_asof, tolerance,
&left_by, &right_by, left_asof, right_asof, tolerance, strategy,
)?
}
}
Expand All @@ -540,6 +608,7 @@ impl DataFrame {
left_asof,
right_asof,
tolerance,
strategy,
)
}
} else {
Expand All @@ -557,19 +626,20 @@ impl DataFrame {
left_asof,
right_asof,
tolerance,
strategy,
),
_ => {
if left_by_s.bit_repr_is_large() {
let left_by = left_by_s.bit_repr_large();
let right_by = right_by_s.bit_repr_large();
asof_join_by_numeric(
&left_by, &right_by, left_asof, right_asof, tolerance,
&left_by, &right_by, left_asof, right_asof, tolerance, strategy,
)?
} else {
let left_by = left_by_s.bit_repr_small();
let right_by = right_by_s.bit_repr_small();
asof_join_by_numeric(
&left_by, &right_by, left_asof, right_asof, tolerance,
&left_by, &right_by, left_asof, right_asof, tolerance, strategy,
)?
}
}
Expand All @@ -581,6 +651,7 @@ impl DataFrame {
left_asof,
right_asof,
tolerance,
strategy,
)
}
};
Expand Down Expand Up @@ -730,4 +801,143 @@ mod test {

Ok(())
}

#[test]
fn test_asof_by3() -> PolarsResult<()> {
let a = df![
"a" => [ -1, 2, 2, 3, 3, 3, 4],
"b" => ["a", "a", "b", "c", "d", "e", "f"]
]?;

let b = df![
"a" => [ 1, 3, 2, 3, 2],
"b" => ["a", "a", "b", "c", "d"],
"right_vals" => [ 1, 3, 2, 3, 4]
]?;

let out = a.join_asof_by(&b, "a", "a", ["b"], ["b"], AsofStrategy::Forward, None)?;
assert_eq!(out.get_column_names(), &["a", "b", "right_vals"]);
let out = out.column("right_vals").unwrap();
let out = out.i32().unwrap();
assert_eq!(
Vec::from(out),
&[Some(1), Some(3), Some(2), Some(3), None, None, None]
);

let out = a.join_asof_by(
&b,
"a",
"a",
["b"],
["b"],
AsofStrategy::Forward,
Some(AnyValue::Int32(1)),
)?;
assert_eq!(out.get_column_names(), &["a", "b", "right_vals"]);
let out = out.column("right_vals").unwrap();
let out = out.i32().unwrap();
assert_eq!(
Vec::from(out),
&[None, Some(3), Some(2), Some(3), None, None, None]
);

Ok(())
}

#[test]
fn test_asof_by4() -> PolarsResult<()> {
let trades = df![
"time" => [23i64, 38, 48, 48, 48],
"ticker" => ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"],
"groups_numeric" => [1, 1, 2, 2, 3],
"bid" => [51.95, 51.95, 720.77, 720.92, 98.0]
]?;

let quotes = df![
"time" => [23i64, 23, 30, 41, 48, 49, 72, 75],
"ticker" => ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"],
"bid" => [720.5, 51.95, 51.97, 51.99, 720.5, 97.99, 720.5, 52.01],
"groups_numeric" => [2, 1, 1, 1, 2, 3, 2, 1],

]?;
/*
trades:
shape: (5, 4)
┌──────┬────────┬────────────────┬────────┐
│ time ┆ ticker ┆ groups_numeric ┆ bid │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ str ┆ i32 ┆ f64 │
╞══════╪════════╪════════════════╪════════╡
│ 23 ┆ MSFT ┆ 1 ┆ 51.95 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 38 ┆ MSFT ┆ 1 ┆ 51.95 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 48 ┆ GOOG ┆ 2 ┆ 720.77 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 48 ┆ GOOG ┆ 2 ┆ 720.92 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 48 ┆ AAPL ┆ 3 ┆ 98.0 │
└──────┴────────┴────────────────┴────────┘
quotes:
shape: (8, 4)
┌──────┬────────┬───────┬────────────────┐
│ time ┆ ticker ┆ bid ┆ groups_numeric │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ str ┆ f64 ┆ i32 │
╞══════╪════════╪═══════╪════════════════╡
│ 23 ┆ GOOG ┆ 720.5 ┆ 2 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 23 ┆ MSFT ┆ 51.95 ┆ 1 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 30 ┆ MSFT ┆ 51.97 ┆ 1 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 41 ┆ MSFT ┆ 51.99 ┆ 1 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 48 ┆ GOOG ┆ 720.5 ┆ 2 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 49 ┆ AAPL ┆ 97.99 ┆ 3 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 72 ┆ GOOG ┆ 720.5 ┆ 2 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 75 ┆ MSFT ┆ 52.01 ┆ 1 │
└──────┴────────┴───────┴────────────────┘
*/

let out = trades.join_asof_by(
&quotes,
"time",
"time",
["ticker"],
["ticker"],
AsofStrategy::Forward,
None,
)?;
let a = out.column("bid_right").unwrap();
let a = a.f64().unwrap();
let expected = &[
Some(51.95),
Some(51.99),
Some(720.5),
Some(720.5),
Some(97.99),
];

assert_eq!(Vec::from(a), expected);

let out = trades.join_asof_by(
&quotes,
"time",
"time",
["groups_numeric"],
["groups_numeric"],
AsofStrategy::Forward,
None,
)?;
let a = out.column("bid_right").unwrap();
let a = a.f64().unwrap();

assert_eq!(Vec::from(a), expected);

Ok(())
}
}

0 comments on commit db8ced1

Please sign in to comment.