Skip to content

Commit

Permalink
fix[rust]: fix alignment in joins (#4723)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 4, 2022
1 parent d4f0211 commit f114d1d
Show file tree
Hide file tree
Showing 10 changed files with 112 additions and 44 deletions.
24 changes: 18 additions & 6 deletions polars/polars-core/src/frame/asof_join/groups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,8 @@ where
// TODO! optimize this. This does a full scan backwards. Use the same strategy as in the single `by`
// implementations
fn asof_join_by_multiple<T>(
a: &DataFrame,
b: &DataFrame,
a: &mut DataFrame,
b: &mut DataFrame,
left_asof: &ChunkedArray<T>,
right_asof: &ChunkedArray<T>,
tolerance: Option<AnyValue<'static>>,
Expand Down Expand Up @@ -487,8 +487,8 @@ impl DataFrame {

check_asof_columns(left_asof, right_asof)?;

let left_by = self.select(left_by)?;
let right_by = other.select(right_by)?;
let mut left_by = self.select(left_by)?;
let mut right_by = other.select(right_by)?;

let left_by_s = &left_by.get_columns()[0];
let right_by_s = &right_by.get_columns()[0];
Expand Down Expand Up @@ -531,7 +531,13 @@ impl DataFrame {
#[cfg(feature = "dtype-categorical")]
check_categorical_src(lhs.dtype(), rhs.dtype())?;
}
asof_join_by_multiple(&left_by, &right_by, left_asof, right_asof, tolerance)
asof_join_by_multiple(
&mut left_by,
&mut right_by,
left_asof,
right_asof,
tolerance,
)
}
} else {
// we cannot use bit repr as that loses ordering
Expand Down Expand Up @@ -566,7 +572,13 @@ impl DataFrame {
}
}
} else {
asof_join_by_multiple(&left_by, &right_by, left_asof, right_asof, tolerance)
asof_join_by_multiple(
&mut left_by,
&mut right_by,
left_asof,
right_asof,
tolerance,
)
}
};

Expand Down
4 changes: 2 additions & 2 deletions polars/polars-core/src/frame/groupby/hashing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,11 +274,11 @@ pub(crate) fn populate_multiple_key_hashmap2<'a, V, H, F, G>(
}

pub(crate) fn groupby_threaded_multiple_keys_flat(
keys: DataFrame,
mut keys: DataFrame,
n_partitions: usize,
sorted: bool,
) -> GroupsProxy {
let dfs = split_df(&keys, n_partitions).unwrap();
let dfs = split_df(&mut keys, n_partitions).unwrap();
let (hashes, _random_state) = df_rows_to_hashes_threaded(&dfs, None);
let n_partitions = n_partitions as u64;

Expand Down
23 changes: 12 additions & 11 deletions polars/polars-core/src/frame/hash_join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,9 @@ impl DataFrame {
JoinType::Inner => {
let left = DataFrame::new_no_checks(selected_left_physical);
let right = DataFrame::new_no_checks(selected_right_physical);
let (left, right, swap) = det_hash_prone_order!(left, right);
let (join_idx_left, join_idx_right) = inner_join_multiple_keys(&left, &right, swap);
let (mut left, mut right, swap) = det_hash_prone_order!(left, right);
let (join_idx_left, join_idx_right) =
inner_join_multiple_keys(&mut left, &mut right, swap);
let mut join_idx_left = &*join_idx_left;
let mut join_idx_right = &*join_idx_right;

Expand All @@ -494,18 +495,18 @@ impl DataFrame {
self.finish_join(df_left, df_right, suffix)
}
JoinType::Left => {
let left = DataFrame::new_no_checks(selected_left_physical);
let right = DataFrame::new_no_checks(selected_right_physical);
let ids = left_join_multiple_keys(&left, &right, None, None);
let mut left = DataFrame::new_no_checks(selected_left_physical);
let mut right = DataFrame::new_no_checks(selected_right_physical);
let ids = left_join_multiple_keys(&mut left, &mut right, None, None);

self.finish_left_join(ids, &remove_selected(other, &selected_right), suffix, slice)
}
JoinType::Outer => {
let left = DataFrame::new_no_checks(selected_left_physical);
let right = DataFrame::new_no_checks(selected_right_physical);

let (left, right, swap) = det_hash_prone_order!(left, right);
let opt_join_tuples = outer_join_multiple_keys(&left, &right, swap);
let (mut left, mut right, swap) = det_hash_prone_order!(left, right);
let opt_join_tuples = outer_join_multiple_keys(&mut left, &mut right, swap);

let mut opt_join_tuples = &*opt_join_tuples;

Expand Down Expand Up @@ -547,13 +548,13 @@ impl DataFrame {
)),
#[cfg(feature = "semi_anti_join")]
JoinType::Anti | JoinType::Semi => {
let left = DataFrame::new_no_checks(selected_left_physical);
let right = DataFrame::new_no_checks(selected_right_physical);
let mut left = DataFrame::new_no_checks(selected_left_physical);
let mut right = DataFrame::new_no_checks(selected_right_physical);

let idx = if matches!(how, JoinType::Anti) {
left_anti_multiple_keys(&left, &right)
left_anti_multiple_keys(&mut left, &mut right)
} else {
left_semi_multiple_keys(&left, &right)
left_semi_multiple_keys(&mut left, &mut right)
};
// Safety:
// indices are in bounds
Expand Down
26 changes: 13 additions & 13 deletions polars/polars-core/src/frame/hash_join/multiple_keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ pub(crate) fn get_offsets(probe_hashes: &[UInt64Chunked]) -> Vec<usize> {
}

pub(crate) fn inner_join_multiple_keys(
a: &DataFrame,
b: &DataFrame,
a: &mut DataFrame,
b: &mut DataFrame,
swap: bool,
) -> (Vec<IdxSize>, Vec<IdxSize>) {
// we assume that the b DataFrame is the shorter relation.
Expand Down Expand Up @@ -248,14 +248,14 @@ pub fn private_left_join_multiple_keys(
chunk_mapping_left: Option<&[ChunkId]>,
chunk_mapping_right: Option<&[ChunkId]>,
) -> LeftJoinIds {
let a = DataFrame::new_no_checks(to_physical_and_bit_repr(a.get_columns()));
let b = DataFrame::new_no_checks(to_physical_and_bit_repr(b.get_columns()));
left_join_multiple_keys(&a, &b, chunk_mapping_left, chunk_mapping_right)
let mut a = DataFrame::new_no_checks(to_physical_and_bit_repr(a.get_columns()));
let mut b = DataFrame::new_no_checks(to_physical_and_bit_repr(b.get_columns()));
left_join_multiple_keys(&mut a, &mut b, chunk_mapping_left, chunk_mapping_right)
}

pub(crate) fn left_join_multiple_keys(
a: &DataFrame,
b: &DataFrame,
a: &mut DataFrame,
b: &mut DataFrame,
// map the global indices to [chunk_idx, array_idx]
// only needed if we have non contiguous memory
chunk_mapping_left: Option<&[ChunkId]>,
Expand Down Expand Up @@ -388,8 +388,8 @@ pub(crate) fn create_build_table_semi_anti(

#[cfg(feature = "semi_anti_join")]
pub(crate) fn semi_anti_join_multiple_keys_impl<'a>(
a: &'a DataFrame,
b: &'a DataFrame,
a: &'a mut DataFrame,
b: &'a mut DataFrame,
) -> impl ParallelIterator<Item = (IdxSize, bool)> + 'a {
// we should not join on logical types
debug_assert!(!a.iter().any(|s| s.is_logical()));
Expand Down Expand Up @@ -454,15 +454,15 @@ pub(crate) fn semi_anti_join_multiple_keys_impl<'a>(
}

#[cfg(feature = "semi_anti_join")]
pub(super) fn left_anti_multiple_keys(a: &DataFrame, b: &DataFrame) -> Vec<IdxSize> {
pub(super) fn left_anti_multiple_keys(a: &mut DataFrame, b: &mut DataFrame) -> Vec<IdxSize> {
semi_anti_join_multiple_keys_impl(a, b)
.filter(|tpls| !tpls.1)
.map(|tpls| tpls.0)
.collect()
}

#[cfg(feature = "semi_anti_join")]
pub(super) fn left_semi_multiple_keys(a: &DataFrame, b: &DataFrame) -> Vec<IdxSize> {
pub(super) fn left_semi_multiple_keys(a: &mut DataFrame, b: &mut DataFrame) -> Vec<IdxSize> {
semi_anti_join_multiple_keys_impl(a, b)
.filter(|tpls| tpls.1)
.map(|tpls| tpls.0)
Expand Down Expand Up @@ -540,8 +540,8 @@ fn probe_outer<F, G, H>(
}

pub(crate) fn outer_join_multiple_keys(
a: &DataFrame,
b: &DataFrame,
a: &mut DataFrame,
b: &mut DataFrame,
swap: bool,
) -> Vec<(Option<IdxSize>, Option<IdxSize>)> {
// we assume that the b DataFrame is the shorter relation.
Expand Down
6 changes: 3 additions & 3 deletions polars/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1462,7 +1462,7 @@ impl DataFrame {

/// Does a filter but splits thread chunks vertically instead of horizontally
/// This yields a DataFrame with `n_chunks == n_threads`.
fn filter_vertical(&self, mask: &BooleanChunked) -> Result<Self> {
fn filter_vertical(&mut self, mask: &BooleanChunked) -> Result<Self> {
let n_threads = POOL.current_num_threads();

let masks = split_ca(mask, n_threads).unwrap();
Expand Down Expand Up @@ -1503,7 +1503,7 @@ impl DataFrame {
/// ```
pub fn filter(&self, mask: &BooleanChunked) -> Result<Self> {
if std::env::var("POLARS_VERT_PAR").is_ok() {
return self.filter_vertical(mask);
return self.clone().filter_vertical(mask);
}
let new_col = self.try_apply_columns_par(&|s| match s.dtype() {
DataType::Utf8 => s.filter_threaded(mask, true),
Expand Down Expand Up @@ -3060,7 +3060,7 @@ impl DataFrame {

/// Hash and combine the row values
#[cfg(feature = "row_hash")]
pub fn hash_rows(&self, hasher_builder: Option<RandomState>) -> Result<UInt64Chunked> {
pub fn hash_rows(&mut self, hasher_builder: Option<RandomState>) -> Result<UInt64Chunked> {
let dfs = split_df(self, POOL.current_num_threads())?;
let (cas, _) = df_rows_to_hashes_threaded(&dfs, hasher_builder);

Expand Down
5 changes: 4 additions & 1 deletion polars/polars-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,13 @@ fn flatten_df(df: &DataFrame) -> impl Iterator<Item = DataFrame> + '_ {

#[cfg(feature = "private")]
#[doc(hidden)]
pub fn split_df(df: &DataFrame, n: usize) -> Result<Vec<DataFrame>> {
/// Split a [`DataFrame`] into `n` parts. We take a `&mut` to be able to repartition/align chunks.
pub fn split_df(df: &mut DataFrame, n: usize) -> Result<Vec<DataFrame>> {
if n == 0 {
return Ok(vec![df.clone()]);
}
// make sure that chunks are aligned.
df.rechunk();
let total_len = df.height();
let chunk_size = total_len / n;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl PartitionGroupByExec {
}

fn run_partitions(
df: &DataFrame,
df: &mut DataFrame,
exec: &PartitionGroupByExec,
state: &ExecutionState,
n_threads: usize,
Expand Down Expand Up @@ -201,7 +201,7 @@ impl Executor for PartitionGroupByExec {
}
}
let dfs = {
let original_df = self.input.execute(state)?;
let mut original_df = self.input.execute(state)?;

// already get the keys. This is the very last minute decision which groupby method we choose.
// If the column is a categorical, we know the number of groups we have and can decide to continue
Expand Down Expand Up @@ -229,7 +229,13 @@ impl Executor for PartitionGroupByExec {

// set it here, because `self.input.execute` will clear the schema cache.
state.set_schema(self.input_schema.clone());
run_partitions(&original_df, self, state, n_threads, self.maintain_order)?
run_partitions(
&mut original_df,
self,
state,
n_threads,
self.maintain_order,
)?
};
state.clear_schema_cache();

Expand Down
8 changes: 4 additions & 4 deletions polars/tests/it/core/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@ use super::*;

#[test]
fn test_chunked_left_join() -> Result<()> {
let band_members = df![
let mut band_members = df![
"name" => ["john", "paul", "mick", "bob"],
"band" => ["beatles", "beatles", "stones", "wailers"],
]?;

let band_instruments = df![
let mut band_instruments = df![
"name" => ["john", "paul", "keith"],
"plays" => ["guitar", "bass", "guitar"]
]?;

let band_instruments = accumulate_dataframes_vertical(split_df(&band_instruments, 2)?)?;
let band_members = accumulate_dataframes_vertical(split_df(&band_members, 2)?)?;
let band_instruments = accumulate_dataframes_vertical(split_df(&mut band_instruments, 2)?)?;
let band_members = accumulate_dataframes_vertical(split_df(&mut band_members, 2)?)?;
assert_eq!(band_instruments.n_chunks()?, 2);
assert_eq!(band_members.n_chunks()?, 2);

Expand Down
2 changes: 1 addition & 1 deletion py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1285,7 +1285,7 @@ impl PyDataFrame {
self.df.shrink_to_fit();
}

pub fn hash_rows(&self, k0: u64, k1: u64, k2: u64, k3: u64) -> PyResult<PySeries> {
pub fn hash_rows(&mut self, k0: u64, k1: u64, k2: u64, k3: u64) -> PyResult<PySeries> {
let hb = ahash::RandomState::with_seeds(k0, k1, k2, k3);
let hash = self.df.hash_rows(Some(hb)).map_err(PyPolarsErr::from)?;
Ok(hash.into_series().into())
Expand Down
46 changes: 46 additions & 0 deletions py-polars/tests/unit/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,3 +314,49 @@ def test_asof_join_projection_resolution_4606() -> None:
assert joined_tbl.groupby("a").agg(
[pl.col("c").sum().alias("c")]
).collect().columns == ["a", "c"]


def test_join_chunks_alignment_4720() -> None:
df1 = pl.DataFrame(
{
"index1": pl.arange(0, 2, eager=True),
"index2": pl.arange(10, 12, eager=True),
}
)

df2 = pl.DataFrame(
{
"index3": pl.arange(100, 102, eager=True),
}
)

df3 = pl.DataFrame(
{
"index1": pl.arange(0, 2, eager=True),
"index2": pl.arange(10, 12, eager=True),
"index3": pl.arange(100, 102, eager=True),
}
)
assert (
df1.join(df2, how="cross").join(
df3,
on=["index1", "index2", "index3"],
how="left",
)
).to_dict(False) == {
"index1": [0, 0, 1, 1],
"index2": [10, 10, 11, 11],
"index3": [100, 101, 100, 101],
}

assert (
df1.join(df2, how="cross").join(
df3,
on=["index3", "index1", "index2"],
how="left",
)
).to_dict(False) == {
"index1": [0, 0, 1, 1],
"index2": [10, 10, 11, 11],
"index3": [100, 101, 100, 101],
}

0 comments on commit f114d1d

Please sign in to comment.