Skip to content

Commit

Permalink
fix outer join dispatch of categorical types (#2590)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Feb 9, 2022
1 parent 7ef3cb8 commit 7ab0723
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
16 changes: 12 additions & 4 deletions polars/polars-core/src/series/implementations/categorical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ use ahash::RandomState;
use arrow::array::ArrayRef;
use polars_arrow::prelude::QuantileInterpolOptions;
use std::borrow::Cow;
use std::ops::Deref;

impl IntoSeries for CategoricalChunked {
fn into_series(self) -> Series {
Expand Down Expand Up @@ -86,11 +85,20 @@ impl private::PrivateSeries for SeriesWrap<CategoricalChunked> {
right_column: &Series,
opt_join_tuples: &[(Option<u32>, Option<u32>)],
) -> Series {
let ca = self.0.deref();
let categorical_map_out = Some(
self.0
.merge_categorical_map(right_column.categorical().unwrap()),
);
let s_left = self.0.cast(&DataType::UInt32).unwrap();
let ca = s_left.u32().unwrap();

let right = right_column.cast(&DataType::UInt32).unwrap();
ZipOuterJoinColumn::zip_outer_join_column(ca, &right, opt_join_tuples)
let out = ZipOuterJoinColumn::zip_outer_join_column(ca, &right, opt_join_tuples)
.cast(&DataType::Categorical)
.unwrap()
.unwrap();
let mut out = out.categorical().unwrap().clone();
out.categorical_map = categorical_map_out;
out.into_series()
}
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> GroupsProxy {
IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted)
Expand Down
24 changes: 24 additions & 0 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1868,3 +1868,27 @@ def test_first_last_expression(fruits_cars: pl.DataFrame) -> None:

out = df.select(pl.last())
assert out.columns == ["cars"]


# Regression test for outer joins whose join keys include a Categorical column.
# Two single-row lazy frames with identical keys (key1=42, key2="bar") are
# outer-joined on both keys and the collected result is compared to a
# hand-built expected frame.
# NOTE(review): leading indentation is missing in this rendering, so which
# statements belong to the `with pl.StringCache()` block cannot be determined
# from here -- confirm against the upstream file before re-indenting.
def test_categorical_outer_join() -> None:
with pl.StringCache():
# Left frame: contributes the payload column "val1".
df1 = pl.DataFrame(
[
pl.Series("key1", [42]),
pl.Series("key2", ["bar"], dtype=pl.Categorical),
pl.Series("val1", [1]),
]
).lazy()

# Right frame: same key values as the left frame, payload column "val2".
df2 = pl.DataFrame(
[
pl.Series("key1", [42]),
pl.Series("key2", ["bar"], dtype=pl.Categorical),
pl.Series("val2", [2]),
]
).lazy()

# Outer join on both keys; the keys match, so a single joined row is expected.
out = df1.join(df2, on=["key1", "key2"], how="outer").collect()
# Expected frame is built from plain Python values (no explicit dtype is given
# for "key2" here); column order is val1, key1, key2, val2.
expected = pl.DataFrame({"val1": [1], "key1": [42], "key2": ["bar"], "val2": [2]})

assert out.frame_equal(expected)

0 comments on commit 7ab0723

Please sign in to comment.