Skip to content

Commit

Permalink
fix[rust]: fix joinasof + by schema (#4607)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 29, 2022
1 parent eb7116d commit 8d34e96
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 55 deletions.
38 changes: 10 additions & 28 deletions polars/polars-lazy/src/logical_plan/alp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use polars_core::frame::explode::MeltArgs;
use polars_core::prelude::*;
use polars_utils::arena::{Arena, Node};

use crate::logical_plan::schema::det_join_schema;
#[cfg(feature = "ipc")]
use crate::logical_plan::IpcScanOptionsInner;
#[cfg(feature = "parquet")]
Expand Down Expand Up @@ -792,37 +793,18 @@ impl<'a> ALogicalPlanBuilder<'a> {
) -> Self {
let schema_left = self.schema();
let schema_right = self.lp_arena.get(other).schema(self.lp_arena);

// column names of left table
let mut names: PlHashSet<&str> =
PlHashSet::with_capacity(schema_left.len() + schema_right.len());
let mut new_schema = Schema::with_capacity(schema_left.len() + schema_right.len());

for (name, dtype) in schema_left.iter() {
names.insert(name.as_str());
new_schema.with_column(name.to_string(), dtype.clone())
}

let right_names: PlHashSet<_> = right_on
let right_names = right_on
.iter()
.map(|e| {
aexpr_to_root_column_name(*e, self.expr_arena)
.expect("could not determine join column names")
self.expr_arena
.get(*e)
.to_field(&schema_right, Context::Default, self.expr_arena)
.unwrap()
.name
})
.collect();

for (name, dtype) in schema_right.iter() {
if !right_names.contains(name.as_str()) {
if names.contains(name.as_str()) {
let new_name = format!("{}{}", name, options.suffix.as_ref());
new_schema.with_column(new_name, dtype.clone());
} else {
new_schema.with_column(name.to_string(), dtype.clone());
}
}
}
.collect::<Vec<_>>();

let schema = Arc::new(new_schema);
let schema = det_join_schema(&schema_left, &schema_right, &right_names, &options);

let lp = ALogicalPlan::Join {
input_left: self.root,
Expand All @@ -832,7 +814,7 @@ impl<'a> ALogicalPlanBuilder<'a> {
right_on,
options,
};
drop(names);

let root = self.lp_arena.add(lp);
Self::new(root, self.expr_arena, self.lp_arena)
}
Expand Down
39 changes: 12 additions & 27 deletions polars/polars-lazy/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use polars_io::{
};

use crate::logical_plan::projection::rewrite_projections;
use crate::logical_plan::schema::det_join_schema;
use crate::prelude::*;
use crate::utils;
use crate::utils::{combine_predicates_expr, has_expr};
Expand Down Expand Up @@ -509,35 +510,19 @@ impl LogicalPlanBuilder {
) -> Self {
let schema_left = try_delayed!(self.0.schema(), &self.0, into);
let schema_right = try_delayed!(other.schema(), &self.0, into);
let right_names = try_delayed!(
right_on
.iter()
.map(|e| e
.to_field(&schema_right, Context::Default)
.map(|field| field.name))
.collect::<Result<Vec<_>>>(),
&self.0,
into
);

// column names of left table
let mut names: PlHashSet<&str> = PlHashSet::default();
let mut new_schema = Schema::with_capacity(schema_left.len() + schema_right.len());

for (name, dtype) in schema_left.iter() {
names.insert(name);
new_schema.with_column(name.to_string(), dtype.clone())
}

let right_names: PlHashSet<_> = right_on
.iter()
.map(|e| utils::expr_output_name(e).expect("could not find name"))
.collect();

for (name, dtype) in schema_right.iter() {
if !right_names.iter().any(|s| s.as_ref() == name) {
if names.contains(&**name) {
let new_name = format!("{}{}", name, options.suffix.as_ref());
new_schema.with_column(new_name, dtype.clone())
} else {
new_schema.with_column(name.to_string(), dtype.clone())
}
}
}

let schema = Arc::new(new_schema);
let schema = det_join_schema(&schema_left, &schema_right, &right_names, &options);

drop(names);
LogicalPlan::Join {
input_left: Box::new(self.0),
input_right: Box::new(other),
Expand Down
1 change: 1 addition & 0 deletions polars/polars-lazy/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ mod lit;
pub(crate) mod optimizer;
pub(crate) mod options;
mod projection;
mod schema;

pub use anonymous_scan::*;
pub use apply::*;
Expand Down
49 changes: 49 additions & 0 deletions polars/polars-lazy/src/logical_plan/schema.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use polars_core::prelude::*;

use crate::prelude::*;

pub(crate) fn det_join_schema(
schema_left: &Schema,
schema_right: &Schema,
right_on: &[String],
options: &JoinOptions,
) -> SchemaRef {
// column names of left table
let mut names: PlHashSet<&str> =
PlHashSet::with_capacity(schema_left.len() + schema_right.len());
let mut new_schema = Schema::with_capacity(schema_left.len() + schema_right.len());

for (name, dtype) in schema_left.iter() {
names.insert(name.as_str());
new_schema.with_column(name.to_string(), dtype.clone())
}

let right_names: PlHashSet<_> = right_on.iter().map(|s| s.as_str()).collect();

for (name, dtype) in schema_right.iter() {
if !right_names.contains(name.as_str()) {
if names.contains(name.as_str()) {
#[cfg(feature = "asof_join")]
if let JoinType::AsOf(asof_options) = &options.how {
if let (Some(left_by), Some(right_by)) =
(&asof_options.left_by, &asof_options.right_by)
{
{
// Do not add suffix. The column of the left table will be used
if left_by.contains(name) && right_by.contains(name) {
continue;
}
}
}
}

let new_name = format!("{}{}", name, options.suffix.as_ref());
new_schema.with_column(new_name, dtype.clone());
} else {
new_schema.with_column(name.to_string(), dtype.clone());
}
}
}

Arc::new(new_schema)
}
7 changes: 7 additions & 0 deletions py-polars/tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,10 @@ def custom2(
assert df.lazy().map(custom2, validate_output_schema=False).collect().to_dict(
False
) == {"a": ["1", "2", "3"], "b": ["a", "b", "c"]}


def test_join_as_of_by_schema() -> None:
a = pl.DataFrame({"a": [1], "b": [2], "c": [3]}).lazy()
b = pl.DataFrame({"a": [1], "b": [2], "d": [4]}).lazy()
q = a.join_asof(b, on="a", by="b")
assert q.collect().columns == q.columns

0 comments on commit 8d34e96

Please sign in to comment.