Skip to content

Commit

Permalink
cast string to categorical in 'is_in' (#3606)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 7, 2022
1 parent a929f24 commit 47b3b20
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ impl Default for RevMapping {

#[allow(clippy::len_without_is_empty)]
impl RevMapping {
/// Returns `true` when this [`RevMapping`] is the `Global` variant.
pub fn is_global(&self) -> bool {
    // Only the variant matters here; its fields are ignored.
    matches!(self, Self::Global(..))
}

/// Get the length of the [`RevMapping`]
pub fn len(&self) -> usize {
match self {
Expand Down
8 changes: 8 additions & 0 deletions polars/polars-lazy/src/dsl/function_expr/is_in.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
use super::*;

/// Function-expression entry point for `is_in`.
///
/// Calls `is_in` on `s[0]` with `s[1]` as the argument and converts the
/// result into a [`Series`]. An error from the underlying `is_in` call is
/// returned unchanged.
///
/// # Panics
///
/// Panics if `s` holds fewer than two series.
pub(super) fn is_in(s: &mut [Series]) -> Result<Series> {
    let mask = s[0].is_in(&s[1])?;
    Ok(mask.into_series())
}
10 changes: 10 additions & 0 deletions polars/polars-lazy/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#[cfg(feature = "is_in")]
mod is_in;
mod pow;

use super::*;
Expand All @@ -12,6 +14,8 @@ pub enum FunctionExpr {
Pow,
#[cfg(feature = "row_hash")]
Hash(usize),
#[cfg(feature = "is_in")]
IsIn,
}

impl FunctionExpr {
Expand Down Expand Up @@ -40,6 +44,8 @@ impl FunctionExpr {
Pow => float_dtype(),
#[cfg(feature = "row_hash")]
Hash(_) => with_dtype(DataType::UInt64),
#[cfg(feature = "is_in")]
IsIn => with_dtype(DataType::Boolean),
}
}
}
Expand Down Expand Up @@ -72,6 +78,10 @@ impl From<FunctionExpr> for NoEq<Arc<dyn SeriesUdf>> {
};
wrap!(f)
}
#[cfg(feature = "is_in")]
IsIn => {
wrap!(is_in::is_in)
}
}
}
}
58 changes: 46 additions & 12 deletions polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,50 @@ impl Expr {
}
}

/// Build an [`Expr::Function`] node for `function_expr` whose inputs are
/// `self` followed by clones of `arguments`, in that order.
///
/// The node is created with `collect_groups` set to
/// [`ApplyOptions::ApplyGroups`], wildcard expansion of the inputs disabled
/// and `auto_explode` enabled. `fmt_str` is stored in the options as-is.
pub fn apply_many_private(
    self,
    function_expr: FunctionExpr,
    arguments: &[Expr],
    fmt_str: &'static str,
) -> Self {
    let options = FunctionOptions {
        collect_groups: ApplyOptions::ApplyGroups,
        input_wildcard_expansion: false,
        auto_explode: true,
        fmt_str,
    };

    // `self` always comes first, followed by the extra arguments.
    let input: Vec<Expr> = std::iter::once(self)
        .chain(arguments.iter().cloned())
        .collect();

    Expr::Function {
        input,
        function: function_expr,
        options,
    }
}

/// Build an [`Expr::Function`] node for `function_expr` whose inputs are
/// `self` followed by clones of `arguments`, in that order.
///
/// The node is created with `collect_groups` set to
/// [`ApplyOptions::ApplyFlat`], wildcard expansion of the inputs disabled
/// and `auto_explode` enabled. `fmt_str` is stored in the options as-is.
// NOTE(review): `ApplyFlat` presumably means the function is evaluated
// without iterating over groups — confirm against `ApplyOptions`.
pub fn map_many_private(
    self,
    function_expr: FunctionExpr,
    arguments: &[Expr],
    fmt_str: &'static str,
) -> Self {
    let options = FunctionOptions {
        collect_groups: ApplyOptions::ApplyFlat,
        input_wildcard_expansion: false,
        auto_explode: true,
        fmt_str,
    };

    // `self` always comes first, followed by the extra arguments.
    let input: Vec<Expr> = std::iter::once(self)
        .chain(arguments.iter().cloned())
        .collect();

    Expr::Function {
        input,
        function: function_expr,
        options,
    }
}

/// Get mask of finite values if dtype is Float
#[allow(clippy::wrong_self_convention)]
pub fn is_finite(self) -> Self {
Expand Down Expand Up @@ -1175,23 +1219,13 @@ impl Expr {
}
}
}

let f = |s: &mut [Series]| {
let left = &s[0];
let other = &s[1];

left.is_in(other).map(|ca| ca.into_series())
};
let arguments = &[other];
let output_type = GetOutput::from_type(DataType::Boolean);

// we don't have to apply on groups, so this is faster
if has_literal {
self.map_many(f, arguments, output_type)
self.map_many_private(FunctionExpr::IsIn, arguments, "is_in_map")
} else {
self.apply_many(f, arguments, output_type)
self.apply_many_private(FunctionExpr::IsIn, arguments, "is_in_apply")
}
.with_fmt("is_in")
}

/// Sort this column by the ordering of another column.
Expand Down
84 changes: 63 additions & 21 deletions polars/polars-lazy/src/logical_plan/optimizer/type_coercion.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::dsl::function_expr::FunctionExpr;
use polars_core::prelude::*;
use polars_core::utils::get_supertype;

Expand Down Expand Up @@ -73,6 +74,19 @@ fn use_supertype(
st
}

/// Collect up to two input nodes of the logical plan stored at `lp_node`.
///
/// The first returned slot is used by callers to look up the input schema.
/// When `is_scan` reports the plan as a scan, the node itself is returned
/// as the first slot; otherwise the slots are filled by `copy_inputs`.
fn get_input(lp_arena: &Arena<ALogicalPlan>, lp_node: Node) -> [Option<Node>; 2] {
    let plan = lp_arena.get(lp_node);

    // For a scan, the node itself is used to get the schema of the input.
    if is_scan(plan) {
        return [Some(lp_node), None];
    }

    let mut inputs: [Option<Node>; 2] = [None, None];
    plan.copy_inputs(&mut inputs);
    inputs
}

impl OptimizationRule for TypeCoercionRule {
fn optimize_expr(
&self,
Expand All @@ -88,17 +102,7 @@ impl OptimizationRule for TypeCoercionRule {
falsy: falsy_node,
predicate,
} => {
let plan = lp_arena.get(lp_node);
let mut inputs = [None, None];

// Used to get the schema of the input.
if is_scan(plan) {
inputs[0] = Some(lp_node);
} else {
plan.copy_inputs(&mut inputs);
};

if let Some(input) = inputs[0] {
if let Some(input) = get_input(lp_arena, lp_node)[0] {
let input_schema = lp_arena.get(input).schema(lp_arena);
let truthy = expr_arena.get(truthy_node);
let falsy = expr_arena.get(falsy_node);
Expand Down Expand Up @@ -154,16 +158,7 @@ impl OptimizationRule for TypeCoercionRule {
op,
right: node_right,
} => {
let plan = lp_arena.get(lp_node);
let mut inputs = [None, None];

if is_scan(plan) {
inputs[0] = Some(lp_node);
} else {
plan.copy_inputs(&mut inputs);
};

if let Some(input) = inputs[0] {
if let Some(input) = get_input(lp_arena, lp_node)[0] {
let input_schema = lp_arena.get(input).schema(lp_arena);

let left = expr_arena.get(node_left);
Expand Down Expand Up @@ -325,6 +320,53 @@ impl OptimizationRule for TypeCoercionRule {
None
}
}
#[cfg(feature = "is_in")]
AExpr::Function {
function: FunctionExpr::IsIn,
ref input,
options,
} => {
if let Some(input_node) = get_input(lp_arena, lp_node)[0] {
let input_schema = lp_arena.get(input_node).schema(lp_arena);
let left_node = input[0];
let other_node = input[1];
let left = expr_arena.get(left_node);
let other = expr_arena.get(other_node);

let type_left = left
.get_type(input_schema, Context::Default, expr_arena)
.ok()?;
let type_other = other
.get_type(input_schema, Context::Default, expr_arena)
.ok()?;

match (&type_left, type_other) {
(DataType::Categorical(Some(rev_map)), DataType::Utf8)
if rev_map.is_global() =>
{
let mut input = input.clone();

let casted_expr = AExpr::Cast {
expr: other_node,
data_type: DataType::Categorical(None),
// does not matter
strict: false,
};
let other_input = expr_arena.add(casted_expr);
input[1] = other_input;

Some(AExpr::Function {
function: FunctionExpr::IsIn,
input,
options,
})
}
_ => None,
}
} else {
None
}
}
_ => None,
}
}
Expand Down
16 changes: 16 additions & 0 deletions py-polars/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,19 @@ def test_categorical_describe_3487() -> None:
df = pl.DataFrame({"cats": ["a", "b"]})
df = df.with_column(pl.col("cats").cast(pl.Categorical))
df.describe()


def test_categorical_is_in_list() -> None:
# `is_in` on a categorical column with a list of strings relies on type
# coercion to cast the strings.
# We should not cast within the function itself, as that would be expensive
# in a groupby context: it would mean one cast per group.
with pl.StringCache():
df = pl.DataFrame(
{"a": [1, 2, 3, 1, 2], "b": ["a", "b", "c", "d", "e"]}
).with_column(pl.col("b").cast(pl.Categorical))

cat_list = ["a", "b", "c"]
assert df.filter(pl.col("b").is_in(cat_list)).to_dict(False) == {
"a": [1, 2, 3],
"b": ["a", "b", "c"],
}

0 comments on commit 47b3b20

Please sign in to comment.