Skip to content

Commit

Permalink
fix cross join (#4045)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 17, 2022
1 parent 34a2f2d commit 8776d7b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
2 changes: 1 addition & 1 deletion polars/polars-core/src/frame/cross_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl DataFrame {
(0..total_rows).map(|i| i % n_rows_right).collect_trusted();
// Safety:
// take right is in bounds
unsafe { self.take_unchecked(&take_right.into_inner()) }
unsafe { other.take_unchecked(&take_right.into_inner()) }
} else {
let iter = (0..n_rows_left).map(|_| other);
concat_df_unchecked(iter)
Expand Down
11 changes: 11 additions & 0 deletions py-polars/tests/db-benchmark/various.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,16 @@ def test_windows_not_cached() -> None:
ldf.collect()


def test_cross_join() -> None:
# triggers > 100 rows implementation
# https://github.com/pola-rs/polars/blob/5f5acb2a523ce01bc710768b396762b8e69a9e07/polars/polars-core/src/frame/cross_join.rs#L34
df1 = pl.DataFrame({"col1": ["a"], "col2": ["d"]})
df2 = pl.DataFrame({"frame2": pl.arange(0, 100, eager=True)})
out = df2.join(df1, how="cross")
df2 = pl.DataFrame({"frame2": pl.arange(0, 101, eager=True)})
assert df2.join(df1, how="cross").slice(0, 100).frame_equal(out)


if __name__ == "__main__":
test_windows_not_cached()
test_cross_join()

0 comments on commit 8776d7b

Please sign in to comment.