Skip to content

Commit

Permalink
fix[rust]: fix semi/anti join schema (#4793)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 9, 2022
1 parent 5dae949 commit fa419a5
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 41 deletions.
2 changes: 1 addition & 1 deletion polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ product = ["polars-core/product"]
unique_counts = ["polars-core/unique_counts", "polars-lazy/unique_counts"]
log = ["polars-ops/log", "polars-lazy/log"]
partition_by = ["polars-core/partition_by"]
semi_anti_join = ["polars-core/semi_anti_join"]
semi_anti_join = ["polars-core/semi_anti_join", "polars-lazy/semi_anti_join"]
list_eval = ["polars-lazy/list_eval"]
cumulative_eval = ["polars-lazy/cumulative_eval"]
chunked_ids = ["polars-core/chunked_ids", "polars-lazy/chunked_ids"]
Expand Down
1 change: 1 addition & 0 deletions polars/polars-lazy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ search_sorted = ["polars-ops/search_sorted"]
meta = []
pivot = ["polars-core/rows", "polars-ops/pivot"]
top_k = ["polars-ops/top_k"]
semi_anti_join = ["polars-core/semi_anti_join"]

# no guarantees whatsoever
private = ["polars-time/private"]
Expand Down
88 changes: 48 additions & 40 deletions polars/polars-lazy/src/logical_plan/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,57 +3,65 @@ use polars_core::prelude::*;
use crate::prelude::*;

pub(crate) fn det_join_schema(
schema_left: &Schema,
schema_right: &Schema,
schema_left: &SchemaRef,
schema_right: &SchemaRef,
left_on: &[Expr],
right_on: &[String],
options: &JoinOptions,
) -> Result<SchemaRef> {
// column names of left table
let mut names: PlHashSet<&str> =
PlHashSet::with_capacity(schema_left.len() + schema_right.len());
let mut new_schema = Schema::with_capacity(schema_left.len() + schema_right.len());

for (name, dtype) in schema_left.iter() {
names.insert(name.as_str());
new_schema.with_column(name.to_string(), dtype.clone())
}
match options.how {
// semi and anti joins are just filtering operations
// the schema will never change.
#[cfg(feature = "semi_anti_join")]
JoinType::Semi | JoinType::Anti => Ok(schema_left.clone()),
_ => {
// column names of left table
let mut names: PlHashSet<&str> =
PlHashSet::with_capacity(schema_left.len() + schema_right.len());
let mut new_schema = Schema::with_capacity(schema_left.len() + schema_right.len());

// make sure that expression are assigned to the schema
// an expression can have an alias, and change a dtype.
// we only do this for the left hand side as the right hand side
// is dropped.
for e in left_on {
let field = e.to_field(schema_left, Context::Default)?;
new_schema.with_column(field.name, field.dtype)
}
for (name, dtype) in schema_left.iter() {
names.insert(name.as_str());
new_schema.with_column(name.to_string(), dtype.clone())
}

// make sure that expression are assigned to the schema
// an expression can have an alias, and change a dtype.
// we only do this for the left hand side as the right hand side
// is dropped.
for e in left_on {
let field = e.to_field(schema_left, Context::Default)?;
new_schema.with_column(field.name, field.dtype)
}

let right_names: PlHashSet<_> = right_on.iter().map(|s| s.as_str()).collect();

for (name, dtype) in schema_right.iter() {
if !right_names.contains(name.as_str()) {
if names.contains(name.as_str()) {
#[cfg(feature = "asof_join")]
if let JoinType::AsOf(asof_options) = &options.how {
if let (Some(left_by), Some(right_by)) =
(&asof_options.left_by, &asof_options.right_by)
{
{
// Do not add suffix. The column of the left table will be used
if left_by.contains(name) && right_by.contains(name) {
continue;
let right_names: PlHashSet<_> = right_on.iter().map(|s| s.as_str()).collect();

for (name, dtype) in schema_right.iter() {
if !right_names.contains(name.as_str()) {
if names.contains(name.as_str()) {
#[cfg(feature = "asof_join")]
if let JoinType::AsOf(asof_options) = &options.how {
if let (Some(left_by), Some(right_by)) =
(&asof_options.left_by, &asof_options.right_by)
{
{
// Do not add suffix. The column of the left table will be used
if left_by.contains(name) && right_by.contains(name) {
continue;
}
}
}
}

let new_name = format!("{}{}", name, options.suffix.as_ref());
new_schema.with_column(new_name, dtype.clone());
} else {
new_schema.with_column(name.to_string(), dtype.clone());
}
}

let new_name = format!("{}{}", name, options.suffix.as_ref());
new_schema.with_column(new_name, dtype.clone());
} else {
new_schema.with_column(name.to_string(), dtype.clone());
}

Ok(Arc::new(new_schema))
}
}

Ok(Arc::new(new_schema))
}
17 changes: 17 additions & 0 deletions py-polars/tests/unit/test_projections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import polars as pl


def test_projection_on_semi_join_4789() -> None:
lfa = pl.DataFrame({"a": [1], "p": [1]}).lazy()

lfb = pl.DataFrame({"seq": [1], "p": [1]}).lazy()

ab = lfa.join(lfb, on="p", how="semi").inspect()

intermediate_agg = (ab.groupby("a").agg([pl.col("a").list().alias("seq")])).select(
["a", "seq"]
)

q = ab.join(intermediate_agg, on="a")

assert q.collect().to_dict(False) == {"a": [1], "p": [1], "seq": [[1]]}

0 comments on commit fa419a5

Please sign in to comment.