Skip to content

Commit

Permalink
expand IsIn::is_in functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed May 7, 2021
1 parent 47539ad commit 2403987
Show file tree
Hide file tree
Showing 17 changed files with 222 additions and 89 deletions.
2 changes: 2 additions & 0 deletions polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ pivot = ["polars-core/pivot"]
downsample = ["polars-core/downsample"]
# sort by multiple columns
sort_multiple = ["polars-core/sort_multiple"]
# is_in operation
is_in = ["polars-core/is_in", "polars-lazy/is_in"]

# all opt-in datatypes
dtype-full = [
Expand Down
2 changes: 2 additions & 0 deletions polars/polars-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ pivot = []
downsample = ["temporal", "dtype-date64"]
# sort by multiple columns
sort_multiple = []
# is_in operation
is_in = []

# opt-in datatypes for Series
dtype-time64-ns = []
Expand Down
5 changes: 3 additions & 2 deletions polars/polars-core/src/chunked_array/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::chunked_array::kernels::{cast_numeric_from_dtype, transmute_array_fro
use crate::prelude::*;
use arrow::array::{make_array, Array, ArrayDataBuilder};
use arrow::compute::cast;
use num::{NumCast, ToPrimitive};
use num::NumCast;

fn cast_ca<N, T>(ca: &ChunkedArray<T>) -> Result<ChunkedArray<N>>
where
Expand Down Expand Up @@ -127,7 +127,7 @@ impl ChunkCast for CategoricalChunked {
impl<T> ChunkCast for ChunkedArray<T>
where
T: PolarsNumericType,
T::Native: NumCast + ToPrimitive,
T::Native: NumCast,
{
fn cast<N>(&self) -> Result<ChunkedArray<N>>
where
Expand Down Expand Up @@ -219,6 +219,7 @@ impl ChunkCast for ListChunked {
N: PolarsDataType,
{
match N::get_dtype() {
// Cast list inner type
DataType::List(child_type) => {
let chunks = self
.downcast_iter()
Expand Down
159 changes: 137 additions & 22 deletions polars/polars-core/src/chunked_array/ops/is_in.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,62 @@
use crate::prelude::*;
use crate::utils::get_supertype;
use hashbrown::hash_set::HashSet;
use num::NumCast;
use std::hash::Hash;

unsafe fn is_in_helper<T, P>(ca: &ChunkedArray<T>, other: &Series) -> Result<BooleanChunked>
where
T: PolarsNumericType,
T::Native: NumCast,
P: Eq + Hash + Copy,
{
let mut set = HashSet::with_capacity(other.len());

let other = ca.unpack_series_matching_type(other)?;
other.downcast_iter().for_each(|iter| {
iter.into_iter().for_each(|opt_val| {
// Safety
// bit sizes are/ should be equal
let ptr = &opt_val as *const Option<T::Native> as *const Option<P>;
let opt_val = *ptr;
set.insert(opt_val);
})
});

let name = ca.name();
let mut ca: BooleanChunked = ca
.into_iter()
.map(|opt_val| {
// Safety
// bit sizes are/ should be equal
let ptr = &opt_val as *const Option<T::Native> as *const Option<P>;
let opt_val = *ptr;
set.contains(&opt_val)
})
.collect();
ca.rename(name);
Ok(ca)
}

impl<T> IsIn for ChunkedArray<T>
where
T: PolarsNumericType,
T::Native: NumCast + Copy,
{
fn is_in(&self, list_array: &ListChunked) -> Result<BooleanChunked> {
match list_array.dtype() {
DataType::List(dt) if self.dtype() == dt => {
fn is_in(&self, other: &Series) -> Result<BooleanChunked> {
// We check implicitly cast to supertype here
match other.dtype() {
DataType::List(dt) => {
let st = get_supertype(self.dtype(), &dt.into())?;
if &st != self.dtype() {
let left = self.cast_with_dtype(&st)?;
let right = other.cast_with_dtype(&DataType::List(st.to_arrow()))?;
return left.is_in(&right);
}

let ca: BooleanChunked = self
.into_iter()
.zip(list_array.into_iter())
.zip(other.list()?.into_iter())
.map(|(value, series)| match (value, series) {
(val, Some(series)) => {
let ca = series.unpack::<T>().unwrap();
Expand All @@ -20,24 +67,49 @@ where
.collect();
Ok(ca)
}
_ => Err(PolarsError::DataTypeMisMatch(
format!(
"cannot do is_in operation with left a dtype: {:?} and right a dtype {:?}",
self.dtype(),
list_array.dtype()
)
.into(),
)),
_ => {
// first make sure that the types are equal
let st = get_supertype(self.dtype(), other.dtype())?;
if &st != self.dtype() {
let left = self.cast_with_dtype(&st)?;
let right = other.cast_with_dtype(&st)?;
return left.is_in(&right);
}
// now that the types are equal, we coerce every 32 bit array to u32
// and every 64 bit array to u64 (including floats)
// this allows hashing them and greatly reduces the number of code paths.
match self.dtype() {
DataType::UInt64 | DataType::Int64 | DataType::Float64 => unsafe {
is_in_helper::<T, u64>(self, other)
},
DataType::UInt32 | DataType::Int32 | DataType::Float32 => unsafe {
is_in_helper::<T, u32>(self, other)
},
DataType::UInt8 | DataType::Int8 => unsafe {
is_in_helper::<T, u8>(self, other)
},
DataType::UInt16 | DataType::Int16 => unsafe {
is_in_helper::<T, u16>(self, other)
},
_ => Err(PolarsError::Other(
format!(
"Data type {:?} not supported in is_in operation",
self.dtype()
)
.into(),
)),
}
}
}
}
}
impl IsIn for Utf8Chunked {
fn is_in(&self, list_array: &ListChunked) -> Result<BooleanChunked> {
match list_array.dtype() {
fn is_in(&self, other: &Series) -> Result<BooleanChunked> {
match other.dtype() {
DataType::List(dt) if self.dtype() == dt => {
let ca: BooleanChunked = self
.into_iter()
.zip(list_array.into_iter())
.zip(other.list()?.into_iter())
.map(|(value, series)| match (value, series) {
(val, Some(series)) => {
let ca = series.unpack::<Utf8Type>().unwrap();
Expand All @@ -48,11 +120,27 @@ impl IsIn for Utf8Chunked {
.collect();
Ok(ca)
}
DataType::Utf8 => {
let mut set = HashSet::with_capacity(other.len());

let other = other.utf8()?;
other.downcast_iter().for_each(|iter| {
iter.into_iter().for_each(|opt_val| {
set.insert(opt_val);
})
});
let mut ca: BooleanChunked = self
.into_iter()
.map(|opt_val| set.contains(&opt_val))
.collect();
ca.rename(self.name());
Ok(ca)
}
_ => Err(PolarsError::DataTypeMisMatch(
format!(
"cannot do is_in operation with left a dtype: {:?} and right a dtype {:?}",
self.dtype(),
list_array.dtype()
other.dtype()
)
.into(),
)),
Expand All @@ -61,12 +149,12 @@ impl IsIn for Utf8Chunked {
}

impl IsIn for BooleanChunked {
fn is_in(&self, list_array: &ListChunked) -> Result<BooleanChunked> {
match list_array.dtype() {
fn is_in(&self, other: &Series) -> Result<BooleanChunked> {
match other.dtype() {
DataType::List(dt) if self.dtype() == dt => {
let ca: BooleanChunked = self
.into_iter()
.zip(list_array.into_iter())
.zip(other.list()?.into_iter())
.map(|(value, series)| match (value, series) {
(val, Some(series)) => {
let ca = series.unpack::<BooleanType>().unwrap();
Expand All @@ -81,7 +169,7 @@ impl IsIn for BooleanChunked {
format!(
"cannot do is_in operation with left a dtype: {:?} and right a dtype {:?}",
self.dtype(),
list_array.dtype()
other.dtype()
)
.into(),
)),
Expand All @@ -90,9 +178,36 @@ impl IsIn for BooleanChunked {
}

impl IsIn for CategoricalChunked {
fn is_in(&self, list_array: &ListChunked) -> Result<BooleanChunked> {
self.cast::<UInt32Type>().unwrap().is_in(list_array)
fn is_in(&self, other: &Series) -> Result<BooleanChunked> {
self.cast::<UInt32Type>().unwrap().is_in(other)
}
}

impl IsIn for ListChunked {}

#[cfg(test)]
mod test {
use crate::prelude::*;

#[test]
fn test_is_in() -> Result<()> {
let a = Int32Chunked::new_from_slice("a", &[1, 2, 3, 4]);
let b = Int64Chunked::new_from_slice("b", &[4, 5, 1]);

let out = a.is_in(&b.into_series())?;
assert_eq!(
Vec::from(&out),
[Some(true), Some(false), Some(false), Some(true)]
);

let a = Utf8Chunked::new_from_slice("a", &["a", "b", "c", "d"]);
let b = Utf8Chunked::new_from_slice("b", &["d", "e", "c"]);

let out = a.is_in(&b.into_series())?;
assert_eq!(
Vec::from(&out),
[Some(false), Some(false), Some(true), Some(true)]
);
Ok(())
}
}
8 changes: 6 additions & 2 deletions polars/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ pub(crate) mod downcast;
pub(crate) mod explode;
pub(crate) mod fill_none;
pub(crate) mod filter;
#[cfg(feature = "is_in")]
#[cfg_attr(docsrs, doc(cfg(feature = "is_in")))]
pub(crate) mod is_in;
pub(crate) mod peaks;
pub(crate) mod set;
Expand Down Expand Up @@ -908,9 +910,11 @@ pub trait ChunkPeaks {
}

/// Check if element is member of list array
#[cfg(feature = "is_in")]
#[cfg_attr(docsrs, doc(cfg(feature = "is_in")))]
pub trait IsIn {
/// Check if the element of this array is in the elements of the list array
fn is_in(&self, _list_array: &ListChunked) -> Result<BooleanChunked> {
/// Check if elements of this array are in the right Series, or List values of the right Series.
fn is_in(&self, _other: &Series) -> Result<BooleanChunked> {
unimplemented!()
}
}
Expand Down
8 changes: 4 additions & 4 deletions polars/polars-core/src/frame/hash_join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1324,9 +1324,9 @@ mod test {

let (mut df_a, mut df_b) = get_dfs();

df_a.may_apply("b", |s| s.cast_with_datatype(&DataType::Categorical))
df_a.may_apply("b", |s| s.cast_with_dtype(&DataType::Categorical))
.unwrap();
df_b.may_apply("bar", |s| s.cast_with_datatype(&DataType::Categorical))
df_b.may_apply("bar", |s| s.cast_with_dtype(&DataType::Categorical))
.unwrap();

let out = df_a.join(&df_b, "b", "bar", JoinType::Left).unwrap();
Expand All @@ -1346,13 +1346,13 @@ mod test {

// Test an error when joining on different string cache
let (mut df_a, mut df_b) = get_dfs();
df_a.may_apply("b", |s| s.cast_with_datatype(&DataType::Categorical))
df_a.may_apply("b", |s| s.cast_with_dtype(&DataType::Categorical))
.unwrap();
// create a new cache
toggle_string_cache(false);
toggle_string_cache(true);

df_b.may_apply("bar", |s| s.cast_with_datatype(&DataType::Categorical))
df_b.may_apply("bar", |s| s.cast_with_dtype(&DataType::Categorical))
.unwrap();
let out = df_a.join(&df_b, "b", "bar", JoinType::Left);
assert!(out.is_err())
Expand Down
8 changes: 4 additions & 4 deletions polars/polars-core/src/series/implementations/dates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,6 @@ macro_rules! impl_dyn_series {
try_physical_dispatch!(self, zip_with_same_type, mask, other)
}

fn is_in_same_type(&self, list_array: &ListChunked) -> Result<BooleanChunked> {
cast_and_apply!(self, is_in_same_type, list_array)
}

fn vec_hash(&self, random_state: RandomState) -> UInt64Chunked {
cast_and_apply!(self, vec_hash, random_state)
}
Expand Down Expand Up @@ -661,6 +657,10 @@ macro_rules! impl_dyn_series {
fn peak_min(&self) -> BooleanChunked {
cast_and_apply!(self, peak_min,)
}
#[cfg(feature = "is_in")]
fn is_in(&self, other: &Series) -> Result<BooleanChunked> {
IsIn::is_in(&self.0, other)
}
}
};
}
Expand Down
9 changes: 5 additions & 4 deletions polars/polars-core/src/series/implementations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,6 @@ macro_rules! impl_dyn_series {
.map(|ca| ca.into_series())
}

fn is_in_same_type(&self, list_array: &ListChunked) -> Result<BooleanChunked> {
IsIn::is_in(&self.0, list_array)
}

fn vec_hash(&self, random_state: RandomState) -> UInt64Chunked {
self.0.vec_hash(random_state)
}
Expand Down Expand Up @@ -814,6 +810,11 @@ macro_rules! impl_dyn_series {
fn peak_min(&self) -> BooleanChunked {
self.0.peak_min()
}

#[cfg(feature = "is_in")]
fn is_in(&self, other: &Series) -> Result<BooleanChunked> {
IsIn::is_in(&self.0, other)
}
}
};
}
Expand Down

0 comments on commit 2403987

Please sign in to comment.