Skip to content

Commit

Permalink
fix(rust, python); fix projection pushdown in asof joins (#5542)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Nov 18, 2022
1 parent 9377116 commit 244b3fb
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 1 deletion.
8 changes: 7 additions & 1 deletion polars/polars-core/src/frame/asof_join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize};
use crate::prelude::*;
use crate::utils::slice_slice;

#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct AsOfOptions {
pub strategy: AsofStrategy,
Expand Down Expand Up @@ -51,6 +51,12 @@ pub enum AsofStrategy {
Forward,
}

impl Default for AsofStrategy {
fn default() -> Self {
AsofStrategy::Backward
}
}

impl<T> ChunkedArray<T>
where
T: PolarsNumericType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,24 @@ impl ProjectionPushDown {
for proj in acc_projections {
let mut add_local = true;

// Asof joins don't replace
// the right column name with the left one
// so the two join columns remain
#[cfg(feature = "asof_join")]
if matches!(options.how, JoinType::AsOf(_)) {
let names = aexpr_to_leaf_names(proj, expr_arena);
if names.len() == 1
// we only add to local projection
// if the right join column differs from the left
&& names_right.contains(&names[0])
&& !names_left.contains(&names[0])
&& !local_projection.contains(&proj)
{
local_projection.push(proj);
continue;
}
}

// if it is an alias we want to project the root column name downwards
// but we don't want to project it a this level, otherwise we project both
// the root and the alias, hence add_local = false.
Expand Down
25 changes: 25 additions & 0 deletions py-polars/tests/unit/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,3 +670,28 @@ def test_streaming_joins() -> None:
# we cast to integer because pandas joins creates floats
a = pl.from_pandas(pd_result).with_column(pl.all().cast(int)).sort(["a", "b"])
pl.testing.assert_frame_equal(a, pl_result, check_dtype=False)


def test_join_asof_projection() -> None:
df1 = pl.DataFrame(
{
"df1_date": [20221011, 20221012, 20221013, 20221014, 20221016],
"df1_col1": ["foo", "bar", "foo", "bar", "foo"],
}
)

df2 = pl.DataFrame(
{
"df2_date": [20221012, 20221015, 20221018],
"df2_col1": ["1", "2", "3"],
}
)

assert (
(
df1.lazy().join_asof(df2.lazy(), left_on="df1_date", right_on="df2_date")
).select([pl.col("df2_date"), "df1_date"])
).collect().to_dict(False) == {
"df2_date": [None, 20221012, 20221012, 20221012, 20221015],
"df1_date": [20221011, 20221012, 20221013, 20221014, 20221016],
}

0 comments on commit 244b3fb

Please sign in to comment.