Skip to content

Commit

Permalink
fix[rust]: fix and test sortedness propagation in joins (#4737)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 6, 2022
1 parent bd7b8df commit c1829c8
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 27 deletions.
30 changes: 21 additions & 9 deletions polars/polars-core/src/frame/hash_join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,12 @@ impl DataFrame {
self.clone()
} else {
// left join keys are in ascending order
self.take_chunked_unchecked(chunk_ids, IsSorted::Ascending)
let sorted = if left_join {
IsSorted::Ascending
} else {
IsSorted::Not
};
self.take_chunked_unchecked(chunk_ids, sorted)
}
}

Expand All @@ -314,12 +319,19 @@ impl DataFrame {
&self,
join_tuples: &[IdxSize],
left_join: bool,
sorted: bool,
) -> DataFrame {
if left_join && join_tuples.len() == self.height() {
self.clone()
} else {
let sorted = if left_join || sorted {
IsSorted::Ascending
} else {
IsSorted::Not
};

// left join tuples are always in ascending order
self._take_unchecked_slice2(join_tuples, true, IsSorted::Ascending)
self._take_unchecked_slice2(join_tuples, true, sorted)
}
}

Expand Down Expand Up @@ -486,7 +498,7 @@ impl DataFrame {

let (df_left, df_right) = POOL.join(
// safety: join indices are known to be in bounds
|| unsafe { self.create_left_df_from_slice(join_idx_left, false) },
|| unsafe { self.create_left_df_from_slice(join_idx_left, false, !swap) },
|| unsafe {
// remove join columns
remove_selected(other, &selected_right)
Expand Down Expand Up @@ -659,10 +671,10 @@ impl DataFrame {
#[cfg(feature = "dtype-categorical")]
check_categorical_src(s_left.dtype(), s_right.dtype())?;

let (join_tuples_left, join_tuples_right) = if use_sort_merge(s_left, s_right) {
let ((join_tuples_left, join_tuples_right), sorted) = if use_sort_merge(s_left, s_right) {
#[cfg(feature = "performant")]
{
par_sorted_merge_inner(s_left, s_right)
(par_sorted_merge_inner(s_left, s_right), true)
}
#[cfg(not(feature = "performant"))]
{
Expand All @@ -682,7 +694,7 @@ impl DataFrame {

let (df_left, df_right) = POOL.join(
// safety: join indices are known to be in bounds
|| unsafe { self.create_left_df_from_slice(join_tuples_left, false) },
|| unsafe { self.create_left_df_from_slice(join_tuples_left, false, sorted) },
|| unsafe {
other
.drop(s_right.name())
Expand Down Expand Up @@ -749,7 +761,7 @@ impl DataFrame {
if let Some((offset, len)) = slice {
left_idx = slice_slice(left_idx, offset, len);
}
unsafe { self.create_left_df_from_slice(left_idx, true) }
unsafe { self.create_left_df_from_slice(left_idx, true, true) }
};

let materialize_right = || {
Expand Down Expand Up @@ -783,7 +795,7 @@ impl DataFrame {
if let Some((offset, len)) = slice {
left_idx = slice_slice(left_idx, offset, len);
}
unsafe { self.create_left_df_from_slice(left_idx, true) }
unsafe { self.create_left_df_from_slice(left_idx, true, true) }
}
JoinIds::Right(left_idx) => {
let mut left_idx = &*left_idx;
Expand Down Expand Up @@ -866,7 +878,7 @@ impl DataFrame {
if let Some((offset, len)) = slice {
idx = slice_slice(idx, offset, len);
}
// idx from anti-semi join should alwasy be sorted
// idx from anti-semi join should always be sorted
self._take_unchecked_slice2(idx, true, IsSorted::Ascending)
}

Expand Down
20 changes: 13 additions & 7 deletions polars/polars-core/src/frame/hash_join/single_keys_dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ impl Series {
}
}

pub(super) fn hash_join_inner(&self, other: &Series) -> (Vec<IdxSize>, Vec<IdxSize>) {
// returns the join tuples and whether or not the lhs tuples are sorted
pub(super) fn hash_join_inner(&self, other: &Series) -> ((Vec<IdxSize>, Vec<IdxSize>), bool) {
let (lhs, rhs) = (self.to_physical_repr(), other.to_physical_repr());

use DataType::*;
Expand Down Expand Up @@ -133,10 +134,11 @@ where
.collect()
}

// returns the join tuples and whether or not the lhs tuples are sorted
fn num_group_join_inner<T>(
left: &ChunkedArray<T>,
right: &ChunkedArray<T>,
) -> (Vec<IdxSize>, Vec<IdxSize>)
) -> ((Vec<IdxSize>, Vec<IdxSize>), bool)
where
T: PolarsIntegerType,
T::Native: Hash + Eq + Send + AsU64 + Copy,
Expand All @@ -155,17 +157,17 @@ where
(true, true, 1, 1) => {
let keys_a = splitted_to_slice(&splitted_a);
let keys_b = splitted_to_slice(&splitted_b);
hash_join_tuples_inner(keys_a, keys_b, swap)
(hash_join_tuples_inner(keys_a, keys_b, swap), !swap)
}
(true, true, _, _) => {
let keys_a = splitted_by_chunks(&splitted_a);
let keys_b = splitted_by_chunks(&splitted_b);
hash_join_tuples_inner(keys_a, keys_b, swap)
(hash_join_tuples_inner(keys_a, keys_b, swap), !swap)
}
_ => {
let keys_a = splitted_to_opt_vec(&splitted_a);
let keys_b = splitted_to_opt_vec(&splitted_b);
hash_join_tuples_inner(keys_a, keys_b, swap)
(hash_join_tuples_inner(keys_a, keys_b, swap), !swap)
}
}
}
Expand Down Expand Up @@ -335,11 +337,15 @@ impl Utf8Chunked {
(splitted_a, splitted_b, swap, hb)
}

fn hash_join_inner(&self, other: &Utf8Chunked) -> (Vec<IdxSize>, Vec<IdxSize>) {
// returns the join tuples and whether or not the lhs tuples are sorted
fn hash_join_inner(&self, other: &Utf8Chunked) -> ((Vec<IdxSize>, Vec<IdxSize>), bool) {
let (splitted_a, splitted_b, swap, hb) = self.prepare(other, true);
let str_hashes_a = prepare_strs(&splitted_a, &hb);
let str_hashes_b = prepare_strs(&splitted_b, &hb);
hash_join_tuples_inner(str_hashes_a, str_hashes_b, swap)
(
hash_join_tuples_inner(str_hashes_a, str_hashes_b, swap),
!swap,
)
}

fn hash_join_left(&self, other: &Utf8Chunked) -> LeftJoinIds {
Expand Down
20 changes: 15 additions & 5 deletions polars/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3114,16 +3114,26 @@ impl DataFrame {
allow_threads: bool,
sorted: IsSorted,
) -> Self {
#[cfg(debug_assertions)]
{
if idx.len() > 2 {
match sorted {
IsSorted::Ascending => {
assert!(idx[0] <= idx[idx.len() - 1]);
}
IsSorted::Descending => {
assert!(idx[0] >= idx[idx.len() - 1]);
}
_ => {}
}
}
}
let ptr = idx.as_ptr() as *mut IdxSize;
let len = idx.len();

// create a temporary vec. we will not drop it.
let mut ca = IdxCa::from_vec("", Vec::from_raw_parts(ptr, len, len));
match sorted {
IsSorted::Not => {}
IsSorted::Ascending => ca.set_sorted(false),
IsSorted::Descending => ca.set_sorted(true),
}
ca.set_sorted2(sorted);
let out = self.take_unchecked_impl(&ca, allow_threads);

// ref count of buffers should be one because we dropped all allocations
Expand Down
63 changes: 57 additions & 6 deletions py-polars/tests/unit/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
import pytest

import polars as pl
Expand Down Expand Up @@ -421,11 +422,61 @@ def test_join_inline_alias_4694() -> None:


def test_sorted_flag_after_joins() -> None:
a = pl.DataFrame({"a": [1, 2, 3, 4], "b": [2, 2, 1, 4]}).sort("a")
np.random.seed(1)
dfa = pl.DataFrame(
{
"a": np.random.randint(0, 13, 20),
"b": np.random.randint(0, 13, 20),
}
).sort("a")

b = pl.DataFrame({"a": [1, 2, 3, 4], "b": [2, 4, 1, 4]})
dfb = pl.DataFrame(
{
"a": np.random.randint(0, 13, 10),
"b": np.random.randint(0, 13, 10),
}
)

for how in ["inner", "left"]:
assert a.join(b, how=how, on="b")["a"].flags[ # type: ignore[arg-type]
"SORTED_ASC"
]
dfapd = dfa.to_pandas()
dfbpd = dfb.to_pandas()

def test_with_pd(
dfa: pd.DataFrame, dfb: pd.DataFrame, on: str, how: str, joined: pl.DataFrame
) -> None:
a = (
dfa.merge(
dfb,
on=on,
how=how, # type: ignore[arg-type]
suffixes=("", "_right"),
)
.sort_values(["a", "b"])
.reset_index(drop=True)
)
b = joined.sort(["a", "b"]).to_pandas()
pd.testing.assert_frame_equal(a, b)

joined = dfa.join(dfb, on="b", how="left")
assert joined["a"].flags["SORTED_ASC"]
test_with_pd(dfapd, dfbpd, "b", "left", joined)

joined = dfa.join(dfb, on="b", how="inner")
assert joined["a"].flags["SORTED_ASC"]
test_with_pd(dfapd, dfbpd, "b", "inner", joined)

joined = dfa.join(dfb, on="b", how="semi")
assert joined["a"].flags["SORTED_ASC"]
joined = dfa.join(dfb, on="b", how="semi")
assert joined["a"].flags["SORTED_ASC"]

joined = dfb.join(dfa, on="b", how="left")
assert not joined["a"].flags["SORTED_ASC"]
test_with_pd(dfbpd, dfapd, "b", "left", joined)

joined = dfb.join(dfa, on="b", how="inner")
assert not joined["a"].flags["SORTED_ASC"]

joined = dfb.join(dfa, on="b", how="semi")
assert not joined["a"].flags["SORTED_ASC"]
joined = dfb.join(dfa, on="b", how="anti")
assert not joined["a"].flags["SORTED_ASC"]

0 comments on commit c1829c8

Please sign in to comment.