Skip to content

Commit

Permalink
implement string comparisson for categorical
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 21, 2021
1 parent e41ec2d commit 704a237
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 57 deletions.
19 changes: 19 additions & 0 deletions polars/polars-core/src/chunked_array/builder/categorical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ impl RevMapping {
}
}

/// Categorical to str
pub fn get(&self, idx: u32) -> &str {
match self {
Self::Global(map, a, _) => {
Expand All @@ -70,6 +71,24 @@ impl RevMapping {
_ => false,
}
}

/// str to Categorical
pub fn find(&self, value: &str) -> Option<u32> {
match self {
Self::Global(map, a, _) => {
map.iter()
// Safety:
// value is always within bounds
.find(|(_k, &v)| (unsafe { a.value_unchecked(v as usize) } == value))
.map(|(k, _v)| *k)
}
Self::Local(a) => {
// Safety: within bounds
unsafe { (0..a.len()).find(|idx| a.value_unchecked(*idx) == value) }
.map(|idx| idx as u32)
}
}
}
}

pub struct CategoricalChunkedBuilder {
Expand Down
189 changes: 137 additions & 52 deletions polars/polars-core/src/series/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,79 +7,158 @@ use crate::series::arithmetic::coerce_lhs_rhs;

macro_rules! impl_compare {
($self:expr, $rhs:expr, $method:ident) => {{
match $self.dtype() {
DataType::Boolean => $self.bool().unwrap().$method($rhs.bool().unwrap()),
DataType::Utf8 => $self.utf8().unwrap().$method($rhs.utf8().unwrap()),
DataType::UInt8 => $self.u8().unwrap().$method($rhs.u8().unwrap()),
DataType::UInt16 => $self.u16().unwrap().$method($rhs.u16().unwrap()),
DataType::UInt32 => $self.u32().unwrap().$method($rhs.u32().unwrap()),
DataType::UInt64 => $self.u64().unwrap().$method($rhs.u64().unwrap()),
DataType::Int8 => $self.i8().unwrap().$method($rhs.i8().unwrap()),
DataType::Int16 => $self.i16().unwrap().$method($rhs.i16().unwrap()),
DataType::Int32 => $self.i32().unwrap().$method($rhs.i32().unwrap()),
DataType::Int64 => $self.i64().unwrap().$method($rhs.i64().unwrap()),
DataType::Float32 => $self.f32().unwrap().$method($rhs.f32().unwrap()),
DataType::Float64 => $self.f64().unwrap().$method($rhs.f64().unwrap()),
DataType::Date32 => $self.date32().unwrap().$method($rhs.date32().unwrap()),
DataType::Date64 => $self.date64().unwrap().$method($rhs.date64().unwrap()),
DataType::Time64(TimeUnit::Nanosecond) => $self
let (lhs, rhs) = coerce_lhs_rhs($self, $rhs).expect("cannot coerce datatypes");
let lhs = lhs.as_ref();
let rhs = rhs.as_ref();
match lhs.dtype() {
DataType::Boolean => lhs.bool().unwrap().$method(rhs.bool().unwrap()),
DataType::Utf8 => lhs.utf8().unwrap().$method(rhs.utf8().unwrap()),
DataType::UInt8 => lhs.u8().unwrap().$method(rhs.u8().unwrap()),
DataType::UInt16 => lhs.u16().unwrap().$method(rhs.u16().unwrap()),
DataType::UInt32 => lhs.u32().unwrap().$method(rhs.u32().unwrap()),
DataType::UInt64 => lhs.u64().unwrap().$method(rhs.u64().unwrap()),
DataType::Int8 => lhs.i8().unwrap().$method(rhs.i8().unwrap()),
DataType::Int16 => lhs.i16().unwrap().$method(rhs.i16().unwrap()),
DataType::Int32 => lhs.i32().unwrap().$method(rhs.i32().unwrap()),
DataType::Int64 => lhs.i64().unwrap().$method(rhs.i64().unwrap()),
DataType::Float32 => lhs.f32().unwrap().$method(rhs.f32().unwrap()),
DataType::Float64 => lhs.f64().unwrap().$method(rhs.f64().unwrap()),
DataType::Date32 => lhs.date32().unwrap().$method(rhs.date32().unwrap()),
DataType::Date64 => lhs.date64().unwrap().$method(rhs.date64().unwrap()),
DataType::Time64(TimeUnit::Nanosecond) => lhs
.time64_nanosecond()
.unwrap()
.$method($rhs.time64_nanosecond().unwrap()),
DataType::Duration(TimeUnit::Nanosecond) => $self
.$method(rhs.time64_nanosecond().unwrap()),
DataType::Duration(TimeUnit::Nanosecond) => lhs
.duration_nanosecond()
.unwrap()
.$method($rhs.duration_nanosecond().unwrap()),
DataType::Duration(TimeUnit::Millisecond) => $self
.$method(rhs.duration_nanosecond().unwrap()),
DataType::Duration(TimeUnit::Millisecond) => lhs
.duration_millisecond()
.unwrap()
.$method($rhs.duration_millisecond().unwrap()),
DataType::List(_) => $self.list().unwrap().$method($rhs.list().unwrap()),
.$method(rhs.duration_millisecond().unwrap()),
DataType::List(_) => lhs.list().unwrap().$method(rhs.list().unwrap()),
_ => unimplemented!(),
}
}};
}

fn compare_cat_to_str_value<Compare>(
cat: &Series,
value: &str,
name: &str,
compare: Compare,
fill_value: bool,
) -> BooleanChunked
where
Compare: Fn(&Series, u32) -> BooleanChunked,
{
let cat = cat.categorical().expect("should be categorical");
let cat_map = cat.get_categorical_map().unwrap();
match cat_map.find(value) {
None => BooleanChunked::full(name, fill_value, cat.len()),
Some(cat_idx) => {
let cat = cat.cast_with_dtype(&DataType::UInt32).unwrap();
compare(&cat, cat_idx)
}
}
}

fn compare_cat_to_str_series<Compare>(
cat: &Series,
string: &Series,
name: &str,
compare: Compare,
fill_value: bool,
) -> BooleanChunked
where
Compare: Fn(&Series, u32) -> BooleanChunked,
{
match string.utf8().expect("should be utf8 column").get(0) {
None => cat.is_null(),
Some(value) => compare_cat_to_str_value(cat, value, name, compare, fill_value),
}
}

impl ChunkCompare<&Series> for Series {
fn eq_missing(&self, rhs: &Series) -> BooleanChunked {
let (lhs, rhs) = coerce_lhs_rhs(self, rhs).expect("cannot coerce datatypes");
impl_compare!(lhs.as_ref(), rhs.as_ref(), eq_missing)
impl_compare!(self, rhs, eq_missing)
}

/// Create a boolean mask by checking for equality.
fn eq(&self, rhs: &Series) -> BooleanChunked {
let (lhs, rhs) = coerce_lhs_rhs(self, rhs).expect("cannot coerce datatypes");
impl_compare!(lhs.as_ref(), rhs.as_ref(), eq)
use DataType::*;
match (self.dtype(), rhs.dtype(), self.len(), rhs.len()) {
(Categorical, Utf8, _, 1) => {
return compare_cat_to_str_series(
self,
rhs,
self.name(),
|s, idx| s.eq(idx),
false,
);
}
(Utf8, Categorical, 1, _) => {
return compare_cat_to_str_series(
rhs,
self,
self.name(),
|s, idx| s.eq(idx),
false,
);
}
_ => {
impl_compare!(self, rhs, eq)
}
}
}

/// Create a boolean mask by checking for inequality.
fn neq(&self, rhs: &Series) -> BooleanChunked {
let (lhs, rhs) = coerce_lhs_rhs(self, rhs).expect("cannot coerce datatypes");
impl_compare!(lhs.as_ref(), rhs.as_ref(), neq)
use DataType::*;
match (self.dtype(), rhs.dtype(), self.len(), rhs.len()) {
(Categorical, Utf8, _, 1) => {
return compare_cat_to_str_series(
self,
rhs,
self.name(),
|s, idx| s.neq(idx),
true,
);
}
(Utf8, Categorical, 1, _) => {
return compare_cat_to_str_series(
rhs,
self,
self.name(),
|s, idx| s.neq(idx),
true,
);
}
_ => {
impl_compare!(self, rhs, neq)
}
}
}

/// Create a boolean mask by checking if lhs > rhs.
/// Create a boolean mask by checking if self > rhs.
fn gt(&self, rhs: &Series) -> BooleanChunked {
let (lhs, rhs) = coerce_lhs_rhs(self, rhs).expect("cannot coerce datatypes");
impl_compare!(lhs.as_ref(), rhs.as_ref(), gt)
impl_compare!(self, rhs, gt)
}

/// Create a boolean mask by checking if lhs >= rhs.
/// Create a boolean mask by checking if self >= rhs.
fn gt_eq(&self, rhs: &Series) -> BooleanChunked {
let (lhs, rhs) = coerce_lhs_rhs(self, rhs).expect("cannot coerce datatypes");
impl_compare!(lhs.as_ref(), rhs.as_ref(), gt_eq)
impl_compare!(self, rhs, gt_eq)
}

/// Create a boolean mask by checking if lhs < rhs.
/// Create a boolean mask by checking if self < rhs.
fn lt(&self, rhs: &Series) -> BooleanChunked {
let (lhs, rhs) = coerce_lhs_rhs(self, rhs).expect("cannot coerce datatypes");
impl_compare!(lhs.as_ref(), rhs.as_ref(), lt)
impl_compare!(self, rhs, lt)
}

/// Create a boolean mask by checking if lhs <= rhs.
/// Create a boolean mask by checking if self <= rhs.
fn lt_eq(&self, rhs: &Series) -> BooleanChunked {
let (lhs, rhs) = coerce_lhs_rhs(self, rhs).expect("cannot coerce datatypes");
impl_compare!(lhs.as_ref(), rhs.as_ref(), lt_eq)
impl_compare!(self, rhs, lt_eq)
}
}

Expand Down Expand Up @@ -122,50 +201,56 @@ impl ChunkCompare<&str> for Series {
}

fn eq(&self, rhs: &str) -> BooleanChunked {
if let Ok(a) = self.utf8() {
a.eq(rhs)
} else {
std::iter::repeat(false).take(self.len()).collect()
use DataType::*;
match self.dtype() {
Utf8 => self.utf8().unwrap().eq(rhs),
Categorical => {
compare_cat_to_str_value(self, rhs, self.name(), |lhs, idx| lhs.eq(idx), false)
}
_ => BooleanChunked::full(self.name(), false, self.len()),
}
}

fn neq(&self, rhs: &str) -> BooleanChunked {
if let Ok(a) = self.utf8() {
a.neq(rhs)
} else {
std::iter::repeat(false).take(self.len()).collect()
use DataType::*;
match self.dtype() {
Utf8 => self.utf8().unwrap().neq(rhs),
Categorical => {
compare_cat_to_str_value(self, rhs, self.name(), |lhs, idx| lhs.neq(idx), true)
}
_ => BooleanChunked::full(self.name(), false, self.len()),
}
}

fn gt(&self, rhs: &str) -> BooleanChunked {
if let Ok(a) = self.utf8() {
a.gt(rhs)
} else {
std::iter::repeat(false).take(self.len()).collect()
BooleanChunked::full(self.name(), false, self.len())
}
}

fn gt_eq(&self, rhs: &str) -> BooleanChunked {
if let Ok(a) = self.utf8() {
a.gt_eq(rhs)
} else {
std::iter::repeat(false).take(self.len()).collect()
BooleanChunked::full(self.name(), false, self.len())
}
}

fn lt(&self, rhs: &str) -> BooleanChunked {
if let Ok(a) = self.utf8() {
a.lt(rhs)
} else {
std::iter::repeat(false).take(self.len()).collect()
BooleanChunked::full(self.name(), false, self.len())
}
}

fn lt_eq(&self, rhs: &str) -> BooleanChunked {
if let Ok(a) = self.utf8() {
a.lt_eq(rhs)
} else {
std::iter::repeat(false).take(self.len()).collect()
BooleanChunked::full(self.name(), false, self.len())
}
}
}
8 changes: 6 additions & 2 deletions polars/polars-lazy/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2292,10 +2292,14 @@ mod test {
.lazy()
.select(vec![when(col("a").eq(lit("c")))
.then(Null {}.lit())
.otherwise(col("a"))])
.otherwise(col("a"))
.alias("foo")])
.collect()?;

dbg!(out);
assert_eq!(
out.column("foo")?.is_null().into_iter().collect::<Vec<_>>(),
&[Some(false), Some(false), Some(true)]
);
Ok(())
}
}
32 changes: 29 additions & 3 deletions polars/polars-lazy/src/logical_plan/optimizer/type_coercion.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use polars_core::prelude::*;
use polars_core::utils::get_supertype;

use crate::logical_plan::optimizer::stack_opt::OptimizationRule;
use crate::logical_plan::Context;
use crate::prelude::*;
use crate::utils::is_scan;

pub struct TypeCoercionRule {}

Expand All @@ -23,7 +25,13 @@ impl OptimizationRule for TypeCoercionRule {
} => {
let plan = lp_arena.get(lp_node);
let mut inputs = [None, None];
plan.copy_inputs(&mut inputs);

// Used to get the schema of the input.
if is_scan(plan) {
inputs[0] = Some(lp_node);
} else {
plan.copy_inputs(&mut inputs);
};

if let Some(input) = inputs[0] {
let input_schema = lp_arena.get(input).schema(lp_arena);
Expand Down Expand Up @@ -65,7 +73,12 @@ impl OptimizationRule for TypeCoercionRule {
} => {
let plan = lp_arena.get(lp_node);
let mut inputs = [None, None];
plan.copy_inputs(&mut inputs);

if is_scan(plan) {
inputs[0] = Some(lp_node);
} else {
plan.copy_inputs(&mut inputs);
};

if let Some(input) = inputs[0] {
let input_schema = lp_arena.get(input).schema(lp_arena);
Expand All @@ -79,7 +92,20 @@ impl OptimizationRule for TypeCoercionRule {
let type_right = right
.get_type(input_schema, Context::Default, expr_arena)
.expect("could not get dtype");
if type_left == type_right {

let compare_cat_to_string = matches!(
op,
Operator::Eq
| Operator::NotEq
| Operator::Gt
| Operator::Lt
| Operator::GtEq
| Operator::LtEq
) && ((type_left == DataType::Categorical
&& type_right == DataType::Utf8)
|| (type_left == DataType::Utf8 && type_right == DataType::Categorical));

if type_left == type_right || compare_cat_to_string {
None
} else {
let st = get_supertype(&type_left, &type_right)
Expand Down
11 changes: 11 additions & 0 deletions polars/polars-lazy/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ impl PushNode for [Option<Node>; 1] {
}
}

pub(crate) fn is_scan(plan: &ALogicalPlan) -> bool {
match plan {
#[cfg(feature = "csv-file")]
ALogicalPlan::CsvScan { .. } => true,
ALogicalPlan::DataFrameScan { .. } => true,
#[cfg(feature = "parquet")]
ALogicalPlan::ParquetScan { .. } => true,
_ => false,
}
}

impl PushNode for &mut [Option<Node>] {
fn push_node(&mut self, value: Node) {
if self[0].is_some() {
Expand Down

0 comments on commit 704a237

Please sign in to comment.