Skip to content

Commit

Permalink
categorical keep type in comparisson (#3370)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed May 11, 2022
1 parent ca45888 commit c399b9d
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 28 deletions.
12 changes: 12 additions & 0 deletions polars/polars-lazy/src/logical_plan/optimizer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ fn use_supertype(
// do nothing
_ => {}
}
} else {
use DataType::*;
match (type_left, type_right, left, right) {
// if the we compare a categorical to a literal string we want to cast the literal to categorical
#[cfg(feature = "dtype-categorical")]
(Categorical(_), Utf8, _, AExpr::Literal(_))
| (Utf8, Categorical(_), AExpr::Literal(_), _) => {
st = DataType::Categorical(None);
}
// do nothing
_ => {}
}
}
st
}
Expand Down
30 changes: 2 additions & 28 deletions polars/polars-lazy/src/physical_plan/expressions/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,34 +134,8 @@ impl PhysicalExpr for LiteralExpr {
}

fn to_field(&self, _input_schema: &Schema) -> Result<Field> {
use LiteralValue::*;
let name = "literal";
let field = match &self.0 {
#[cfg(feature = "dtype-i8")]
Int8(_) => Field::new(name, DataType::Int8),
#[cfg(feature = "dtype-i16")]
Int16(_) => Field::new(name, DataType::Int16),
Int32(_) => Field::new(name, DataType::Int32),
Int64(_) => Field::new(name, DataType::Int64),
#[cfg(feature = "dtype-u8")]
UInt8(_) => Field::new(name, DataType::UInt8),
#[cfg(feature = "dtype-u16")]
UInt16(_) => Field::new(name, DataType::UInt16),
UInt32(_) => Field::new(name, DataType::UInt32),
UInt64(_) => Field::new(name, DataType::UInt64),
Float32(_) => Field::new(name, DataType::Float32),
Float64(_) => Field::new(name, DataType::Float64),
Boolean(_) => Field::new(name, DataType::Boolean),
Utf8(_) => Field::new(name, DataType::Utf8),
Null => Field::new(name, DataType::Null),
Range { data_type, .. } => Field::new(name, data_type.clone()),
#[cfg(all(feature = "temporal", feature = "dtype-datetime"))]
DateTime(_, tu) => Field::new(name, DataType::Datetime(*tu, None)),
#[cfg(all(feature = "temporal", feature = "dtype-duration"))]
Duration(_, tu) => Field::new(name, DataType::Duration(*tu)),
Series(s) => s.field().into_owned(),
};
Ok(field)
let dtype = self.0.get_datatype();
Ok(Field::new("literal", dtype))
}
}

Expand Down
14 changes: 14 additions & 0 deletions py-polars/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,17 @@ def test_cat_to_dummies() -> None:
"bar_b": [0, 1, 0, 0],
"bar_c": [0, 0, 0, 1],
}


def test_comp_categorical_lit_dtype() -> None:
df = pl.DataFrame(
data={"column": ["a", "b", "e"], "values": [1, 5, 9]},
columns=[("column", pl.Categorical), ("more", pl.Int32)],
)

assert df.with_column(
pl.when(pl.col("column") == "e")
.then("d")
.otherwise(pl.col("column"))
.alias("column")
).dtypes == [pl.Categorical, pl.Int32]

0 comments on commit c399b9d

Please sign in to comment.