Skip to content

Commit

Permalink
fix(rust, python): fix asof join schema (#5686)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Nov 30, 2022
1 parent 62e8b0f commit 72b9d76
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 7 deletions.
6 changes: 5 additions & 1 deletion polars/polars-core/src/frame/hash_join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,10 @@ impl ZipOuterJoinColumn for Float64Chunked {
}
}

/// Builds the name used for a join column that needs a suffix.
///
/// The result is `name` immediately followed by `suffix`; no separator is
/// inserted between the two parts.
pub fn _join_suffix_name(name: &str, suffix: &str) -> String {
    // The final length is known up front, so reserve it once.
    let mut suffixed = String::with_capacity(name.len() + suffix.len());
    suffixed.push_str(name);
    suffixed.push_str(suffix);
    suffixed
}

/// Utility method to finish a join.
#[doc(hidden)]
pub fn _finish_join(
Expand All @@ -317,7 +321,7 @@ pub fn _finish_join(
let suffix = suffix.unwrap_or("_right");

for name in rename_strs {
df_right.rename(&name, &format!("{}{}", name, suffix))?;
df_right.rename(&name, &_join_suffix_name(&name, suffix))?;
}

drop(left_names);
Expand Down
4 changes: 4 additions & 0 deletions polars/polars-core/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ impl Schema {
self.inner.get_index(index)
}

/// Reports whether this schema has an entry for `name`.
///
/// Implemented as a lookup through `Self::get`: the result is `true`
/// exactly when `self.get(name)` yields `Some`.
pub fn contains(&self, name: &str) -> bool {
self.get(name).is_some()
}

/// Positional, mutable access to a schema entry.
///
/// Forwards `index` to `self.inner.get_index_mut` and returns its result:
/// mutable references to the stored name and data type, or `None`.
// NOTE(review): presumably `None` means `index` is out of bounds; confirm
// against the container type behind `inner`.
pub fn get_index_mut(&mut self, index: usize) -> Option<(&mut String, &mut DataType)> {
self.inner.get_index_mut(index)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,7 @@ impl ProjectionPushDown {
}
}

// if it is an alias we want to project the root column name downwards
// if it is an alias we want to project the leaf column name downwards
// but we don't want to project it at this level, otherwise we project both
// the root and the alias, hence add_local = false.
if let AExpr::Alias(expr, name) = expr_arena.get(proj).clone() {
Expand All @@ -976,16 +976,16 @@ impl ProjectionPushDown {
expr_arena,
) {
// Column name of the projection without any alias.
let root_column_name =
let leaf_column_name =
aexpr_to_leaf_names(proj, expr_arena).pop().unwrap();

let suffix = options.suffix.as_ref();
// If _right suffix exists we need to push a projection down without this
// suffix.
if root_column_name.ends_with(suffix) {
if leaf_column_name.ends_with(suffix) {
// downwards name is the name without the _right i.e. "foo".
let (downwards_name, _) = root_column_name
.split_at(root_column_name.len() - suffix.len());
let (downwards_name, _) = leaf_column_name
.split_at(leaf_column_name.len() - suffix.len());

let downwards_name_column =
expr_arena.add(AExpr::Column(Arc::from(downwards_name)));
Expand Down
10 changes: 9 additions & 1 deletion polars/polars-lazy/polars-plan/src/logical_plan/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,15 @@ pub(crate) fn det_join_schema(
let field_right =
right_on.to_field_amortized(schema_right, Context::Default, &mut arena)?;
if field_left.name != field_right.name {
new_schema.with_column(field_right.name, field_right.dtype);
if schema_left.contains(&field_right.name) {
use polars_core::frame::hash_join::_join_suffix_name;
new_schema.with_column(
_join_suffix_name(&field_right.name, options.suffix.as_ref()),
field_right.dtype,
);
} else {
new_schema.with_column(field_right.name, field_right.dtype);
}
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions polars/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1351,6 +1351,14 @@ impl JoinBuilder {
self
}

/// Sets the join keys for both tables at once.
///
/// The given expressions are copied into `left_on` and `right_on`, so both
/// sides of the join use the same key expressions.
pub fn on<E: AsRef<[Expr]>>(mut self, on: E) -> Self {
    // Borrow the slice once and take an owned copy for each side.
    let keys: &[Expr] = on.as_ref();
    self.left_on = keys.to_vec();
    self.right_on = keys.to_vec();
    self
}

/// The columns you want to join the left table on.
pub fn left_on<E: AsRef<[Expr]>>(mut self, on: E) -> Self {
self.left_on = on.as_ref().to_vec();
Expand Down
34 changes: 34 additions & 0 deletions py-polars/tests/unit/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,40 @@ def test_asof_join_schema_5211() -> None:
) == {"today": pl.Int64, "next_friday": pl.Int64}


def test_asof_join_schema_5684() -> None:
    """Check the schema of a repeated asof join whose right key clashes with a left column.

    The same asof join (``left_on="a"``, ``right_on="b"``, grouped by ``"id"``)
    is applied twice, dropping ``"b"`` after each join. The lazily resolved
    schema must agree with the schema of the collected frame, and collecting
    with and without an explicit ``select(pl.all())`` must give equal frames.
    """
    left = pl.DataFrame({"id": [1], "a": [1], "b": [1]}).lazy()
    right = pl.DataFrame({"id": [1, 1, 2], "b": [3, -3, 6]}).lazy()

    # First join + drop, then the identical join + drop on top of it.
    joined_once = left.join_asof(right, by="id", left_on="a", right_on="b").drop("b")
    q = joined_once.join_asof(right, by="id", left_on="a", right_on="b").drop("b")

    # Collect once through an explicit projection and once directly.
    projected_result = q.select(pl.all()).collect()
    result = q.collect()

    expected_schema = {"id": pl.Int64, "a": pl.Int64, "b_right": pl.Int64}

    assert projected_result.frame_equal(result)
    # The lazy schema, the materialized schema and the expectation must all agree.
    assert q.schema == projected_result.schema
    assert projected_result.schema == expected_schema


@typing.no_type_check
def test_streaming_joins() -> None:
n = 100
Expand Down

0 comments on commit 72b9d76

Please sign in to comment.