Skip to content

Commit

Permalink
fix bug in outer join on categorical
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 17, 2022
1 parent da7a4bb commit 76a2cbd
Show file tree
Hide file tree
Showing 10 changed files with 31 additions and 56 deletions.
32 changes: 29 additions & 3 deletions polars/polars-core/src/frame/hash_join/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub(crate) mod multiple_keys;

use polars_arrow::utils::CustomIterTools;

use crate::frame::hash_join::multiple_keys::{
Expand Down Expand Up @@ -1398,8 +1399,26 @@ impl DataFrame {
)
},
);
let mut s = s_left.zip_outer_join_column(s_right, &opt_join_tuples);
let mut s = s_left
.to_physical_repr()
.zip_outer_join_column(&s_right.to_physical_repr(), &opt_join_tuples);
s.rename(s_left.name());
let s = match s_left.dtype() {
#[cfg(feature = "dtype-categorical")]
DataType::Categorical => {
let ca_or = s_left.categorical().unwrap();
let ca_new = s.cast(&DataType::Categorical).unwrap();
let mut ca_new = ca_new.categorical().unwrap().clone();
ca_new.categorical_map = ca_or.categorical_map.clone();
ca_new.into_series()
}
dt @ DataType::Datetime(_, _)
| dt @ DataType::Time
| dt @ DataType::Date
| dt @ DataType::Duration(_) => s.cast(dt).unwrap(),
_ => s,
};

df_left.hstack_mut(&[s])?;
self.finish_join(df_left, df_right, suffix)
}
Expand Down Expand Up @@ -1657,7 +1676,14 @@ mod test {

assert_eq!(Vec::from(ca), correct_ham);

// Test an error when joining on different string cache
// test dispatch
for jt in [JoinType::Left, JoinType::Inner, JoinType::Outer] {
let out = df_a.join(&df_b, "b", "bar", jt, None).unwrap();
let out = out.column("b").unwrap();
assert_eq!(out.dtype(), &DataType::Categorical);
}

// Test error when joining on different string cache
let (mut df_a, mut df_b) = get_dfs();
df_a.try_apply("b", |s| s.cast(&DataType::Categorical))
.unwrap();
Expand All @@ -1668,7 +1694,7 @@ mod test {
df_b.try_apply("bar", |s| s.cast(&DataType::Categorical))
.unwrap();
let out = df_a.join(&df_b, "b", "bar", JoinType::Left, None);
assert!(out.is_err())
assert!(out.is_err());
}

#[test]
Expand Down
6 changes: 0 additions & 6 deletions polars/polars-core/src/series/implementations/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ use crate::series::implementations::SeriesWrap;
use ahash::RandomState;
use arrow::array::ArrayRef;
use polars_arrow::prelude::QuantileInterpolOptions;
#[cfg(feature = "object")]
use std::any::Any;
use std::borrow::Cow;
use std::ops::{BitAnd, BitOr, BitXor};

Expand Down Expand Up @@ -417,10 +415,6 @@ impl SeriesTrait for SeriesWrap<BooleanChunked> {
RepeatBy::repeat_by(&self.0, by)
}

#[cfg(feature = "object")]
fn as_any(&self) -> &dyn Any {
&self.0
}
#[cfg(feature = "mode")]
fn mode(&self) -> Result<Series> {
Ok(self.0.mode()?.into_series())
Expand Down
6 changes: 0 additions & 6 deletions polars/polars-core/src/series/implementations/categorical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ use crate::series::implementations::SeriesWrap;
use ahash::RandomState;
use arrow::array::ArrayRef;
use polars_arrow::prelude::QuantileInterpolOptions;
#[cfg(feature = "object")]
use std::any::Any;
use std::borrow::Cow;
use std::ops::Deref;

Expand Down Expand Up @@ -389,10 +387,6 @@ impl SeriesTrait for SeriesWrap<CategoricalChunked> {
self.0.is_first()
}

#[cfg(feature = "object")]
fn as_any(&self) -> &dyn Any {
&self.0
}
#[cfg(feature = "mode")]
fn mode(&self) -> Result<Series> {
Ok(self.0.mode()?.into_series())
Expand Down
7 changes: 0 additions & 7 deletions polars/polars-core/src/series/implementations/dates_time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ use crate::frame::{groupby::*, hash_join::*};
use crate::prelude::*;
use ahash::RandomState;
use polars_arrow::prelude::QuantileInterpolOptions;
#[cfg(feature = "object")]
use std::any::Any;
use std::borrow::Cow;
use std::ops::{Deref, DerefMut};

Expand Down Expand Up @@ -618,11 +616,6 @@ macro_rules! impl_dyn_series {
self.0.is_first()
}

#[cfg(feature = "object")]
fn as_any(&self) -> &dyn Any {
&self.0
}

#[cfg(feature = "mode")]
fn mode(&self) -> Result<Series> {
self.0.mode().map(|ca| ca.$into_logical().into_series())
Expand Down
7 changes: 0 additions & 7 deletions polars/polars-core/src/series/implementations/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ use crate::frame::groupby::pivot::*;
use crate::frame::{groupby::*, hash_join::*};
use crate::prelude::*;
use ahash::RandomState;
#[cfg(feature = "object")]
use std::any::Any;
use std::borrow::Cow;
use std::ops::{Deref, DerefMut};

Expand Down Expand Up @@ -646,11 +644,6 @@ impl SeriesTrait for SeriesWrap<DatetimeChunked> {
self.0.is_first()
}

#[cfg(feature = "object")]
fn as_any(&self) -> &dyn Any {
&self.0
}

#[cfg(feature = "mode")]
fn mode(&self) -> Result<Series> {
self.0.mode().map(|ca| {
Expand Down
7 changes: 0 additions & 7 deletions polars/polars-core/src/series/implementations/duration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ use crate::frame::groupby::pivot::*;
use crate::frame::{groupby::*, hash_join::*};
use crate::prelude::*;
use ahash::RandomState;
#[cfg(feature = "object")]
use std::any::Any;
use std::borrow::Cow;
use std::ops::{Deref, DerefMut};

Expand Down Expand Up @@ -604,11 +602,6 @@ impl SeriesTrait for SeriesWrap<DurationChunked> {
self.0.is_first()
}

#[cfg(feature = "object")]
fn as_any(&self) -> &dyn Any {
&self.0
}

#[cfg(feature = "mode")]
fn mode(&self) -> Result<Series> {
self.0
Expand Down
7 changes: 0 additions & 7 deletions polars/polars-core/src/series/implementations/floats.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
#[cfg(feature = "object")]
use std::any::Any;

use super::private;
use super::IntoSeries;
use super::SeriesTrait;
Expand Down Expand Up @@ -597,10 +594,6 @@ macro_rules! impl_dyn_series {
self.0.is_first()
}

#[cfg(feature = "object")]
fn as_any(&self) -> &dyn Any {
&self.0
}
#[cfg(feature = "mode")]
fn mode(&self) -> Result<Series> {
Ok(self.0.mode()?.into_series())
Expand Down
7 changes: 0 additions & 7 deletions polars/polars-core/src/series/implementations/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ use crate::series::implementations::SeriesWrap;
use ahash::RandomState;
use arrow::array::ArrayRef;
use polars_arrow::prelude::QuantileInterpolOptions;
#[cfg(feature = "object")]
use std::any::Any;
use std::borrow::Cow;

impl IntoSeries for ListChunked {
Expand Down Expand Up @@ -295,9 +293,4 @@ impl SeriesTrait for SeriesWrap<ListChunked> {
fn clone_inner(&self) -> Arc<dyn SeriesTrait> {
    // Clone the wrapped chunked array, re-wrap it, and hand it back as a
    // new reference-counted trait object.
    let wrapped = SeriesWrap(Clone::clone(&self.0));
    Arc::new(wrapped)
}

#[cfg(feature = "object")]
fn as_any(&self) -> &dyn Any {
&self.0
}
}
6 changes: 0 additions & 6 deletions polars/polars-core/src/series/implementations/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ use crate::series::implementations::SeriesWrap;
use ahash::RandomState;
use arrow::array::ArrayRef;
use polars_arrow::prelude::QuantileInterpolOptions;
#[cfg(feature = "object")]
use std::any::Any;
use std::borrow::Cow;

impl IntoSeries for Utf8Chunked {
Expand Down Expand Up @@ -402,10 +400,6 @@ impl SeriesTrait for SeriesWrap<Utf8Chunked> {
self.0.is_first()
}

#[cfg(feature = "object")]
fn as_any(&self) -> &dyn Any {
&self.0
}
#[cfg(feature = "mode")]
fn mode(&self) -> Result<Series> {
Ok(self.0.mode()?.into_series())
Expand Down
2 changes: 2 additions & 0 deletions polars/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,12 +334,14 @@ impl Series {
/// * Date -> Int32
/// * Datetime-> Int64
/// * Time -> Int64
/// * Duration -> Int64
/// * Categorical -> UInt32
///
pub fn to_physical_repr(&self) -> Cow<Series> {
use DataType::*;
match self.dtype() {
Date => Cow::Owned(self.cast(&DataType::Int32).unwrap()),
Datetime(_, _) | Duration(_) | Time => Cow::Owned(self.cast(&DataType::Int64).unwrap()),
Categorical => Cow::Owned(self.cast(&DataType::UInt32).unwrap()),
_ => Cow::Borrowed(self),
}
}
Expand Down

0 comments on commit 76a2cbd

Please sign in to comment.