Skip to content

Commit

Permalink
fix(rust, python): fix sort-merge dispatch of utf8 (#5315)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 24, 2022
1 parent 55643f6 commit d3397fa
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 12 deletions.
16 changes: 14 additions & 2 deletions polars/polars-core/src/frame/hash_join/sort_merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ pub(super) fn par_sorted_merge_left(
DataType::Int64 => {
par_sorted_merge_left_impl(s_left.i64().unwrap(), s_right.i64().unwrap())
}
DataType::Float32 => {
par_sorted_merge_left_impl(s_left.f32().unwrap(), s_right.f32().unwrap())
}
DataType::Float64 => {
par_sorted_merge_left_impl(s_left.f64().unwrap(), s_right.f64().unwrap())
}
_ => unreachable!(),
}
}
Expand Down Expand Up @@ -138,6 +144,12 @@ pub(super) fn par_sorted_merge_inner(
DataType::Int64 => {
par_sorted_merge_inner_impl(s_left.i64().unwrap(), s_right.i64().unwrap())
}
DataType::Float32 => {
par_sorted_merge_inner_impl(s_left.f32().unwrap(), s_right.f32().unwrap())
}
DataType::Float64 => {
par_sorted_merge_inner_impl(s_left.f64().unwrap(), s_right.f64().unwrap())
}
_ => unreachable!(),
}
}
Expand Down Expand Up @@ -183,7 +195,7 @@ pub(super) fn sort_or_hash_inner(
.unwrap_or(1.0);
let is_numeric = s_left.dtype().to_physical().is_numeric();
match (s_left.is_sorted(), s_right.is_sorted()) {
(IsSorted::Ascending, IsSorted::Ascending) => {
(IsSorted::Ascending, IsSorted::Ascending) if is_numeric => {
if verbose {
eprintln!("inner join: keys are sorted: use sorted merge join");
}
Expand Down Expand Up @@ -253,7 +265,7 @@ pub(super) fn sort_or_hash_left(s_left: &Series, s_right: &Series, verbose: bool
let is_numeric = s_left.dtype().to_physical().is_numeric();

match (s_left.is_sorted(), s_right.is_sorted()) {
(IsSorted::Ascending, IsSorted::Ascending) => {
(IsSorted::Ascending, IsSorted::Ascending) if is_numeric => {
if verbose {
eprintln!("left join: keys are sorted: use sorted merge join");
}
Expand Down
26 changes: 16 additions & 10 deletions py-polars/tests/unit/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,22 @@ def test_sorted_merge_joins() -> None:
df_b = df_b.select(pl.all().reverse())

join_strategies: list[JoinStrategy] = ["left", "inner"]
for how in join_strategies:
# hash join
out_hash_join = df_a.join(df_b, on="a", how=how)

# sorted merge join
out_sorted_merge_join = df_a.with_column(
pl.col("a").set_sorted(reverse)
).join(df_b.with_column(pl.col("a").set_sorted(reverse)), on="a", how=how)

assert out_hash_join.frame_equal(out_sorted_merge_join)
for cast_to in [int, str, float]:
for how in join_strategies:
df_a_ = df_a.with_column(pl.col("a").cast(cast_to))
df_b_ = df_b.with_column(pl.col("a").cast(cast_to))

# hash join
out_hash_join = df_a_.join(df_b_, on="a", how=how)

# sorted merge join
out_sorted_merge_join = df_a_.with_column(
pl.col("a").set_sorted(reverse)
).join(
df_b_.with_column(pl.col("a").set_sorted(reverse)), on="a", how=how
)

assert out_hash_join.frame_equal(out_sorted_merge_join)


def test_join_negative_integers() -> None:
Expand Down

0 comments on commit d3397fa

Please sign in to comment.