Skip to content

Commit

Permalink
prevent expensive type coercion in expression and fix when->then->oth…
Browse files Browse the repository at this point in the history
…eriwise fill null for list (#3579)
  • Loading branch information
ritchie46 committed Jun 5, 2022
1 parent fe771f5 commit fc5bcf0
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 17 deletions.
18 changes: 18 additions & 0 deletions polars/polars-core/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,24 @@ impl PartialEq for DataType {
impl Eq for DataType {}

impl DataType {
pub fn value_within_range(&self, other: AnyValue) -> bool {
use DataType::*;
match self {
UInt8 => other.extract::<u8>().is_some(),
#[cfg(feature = "dtype-u16")]
UInt16 => other.extract::<u16>().is_some(),
UInt32 => other.extract::<u32>().is_some(),
UInt64 => other.extract::<u64>().is_some(),
#[cfg(feature = "dtype-i8")]
Int8 => other.extract::<i8>().is_some(),
#[cfg(feature = "dtype-i16")]
Int16 => other.extract::<i16>().is_some(),
Int32 => other.extract::<i32>().is_some(),
Int64 => other.extract::<i64>().is_some(),
_ => false,
}
}

pub fn inner_dtype(&self) -> Option<&DataType> {
if let DataType::List(inner) = self {
Some(&*inner)
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ impl ListNameSpace {
Ok(ca.lst_lengths().into_series())
};
self.0
.map(function, GetOutput::from_type(DataType::UInt32))
.map(function, GetOutput::from_type(IDX_DTYPE))
.with_fmt("arr.len")
}

Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/logical_plan/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ impl fmt::Debug for Expr {
falsy,
} => write!(
f,
"\nWHEN {:?}\n\t{:?}\nOTHERWISE\n\t{:?}",
"\nWHEN {:?}\nTHEN\n\t{:?}\nOTHERWISE\n\t{:?}",
predicate, truthy, falsy
),
AnonymousFunction { input, options, .. } | Function { input, options, .. } => {
Expand Down
24 changes: 24 additions & 0 deletions polars/polars-lazy/src/logical_plan/lit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,30 @@ pub enum LiteralValue {
}

impl LiteralValue {
pub fn to_anyvalue(&self) -> Option<AnyValue> {
use LiteralValue::*;
let av = match self {
Null => AnyValue::Null,
Boolean(v) => AnyValue::Boolean(*v),
#[cfg(feature = "dtype-u8")]
UInt8(v) => AnyValue::UInt8(*v),
#[cfg(feature = "dtype-u16")]
UInt16(v) => AnyValue::UInt16(*v),
UInt32(v) => AnyValue::UInt32(*v),
UInt64(v) => AnyValue::UInt64(*v),
#[cfg(feature = "dtype-i16")]
Int8(v) => AnyValue::Int8(*v),
#[cfg(feature = "dtype-i16")]
Int16(v) => AnyValue::Int16(*v),
Int32(v) => AnyValue::Int32(*v),
Int64(v) => AnyValue::Int64(*v),
Float32(v) => AnyValue::Float32(*v),
Float64(v) => AnyValue::Float64(*v),
_ => return None,
};
Some(av)
}

/// Getter for the `DataType` of the value
pub fn get_datatype(&self) -> DataType {
match self {
Expand Down
21 changes: 12 additions & 9 deletions polars/polars-lazy/src/logical_plan/optimizer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,21 @@ fn use_supertype(
|(_, AExpr::Literal(LiteralValue::Float32(_) | LiteralValue::Float64(_)))
=> {}

// cast literal to right type
(AExpr::Literal(_), _) => {
// never cast signed to unsigned
if type_right.is_signed() {
st = type_right.clone();
// cast literal to right type if they fit in the range
(AExpr::Literal(value), _) => {
if let Some(lit_val) = value.to_anyvalue() {
if type_right.value_within_range(lit_val) {
st = type_right.clone();
}
}
}
// cast literal to left type
(_, AExpr::Literal(_)) => {
// never cast signed to unsigned
if type_left.is_signed() {
st = type_left.clone();
(_, AExpr::Literal(value)) => {

if let Some(lit_val) = value.to_anyvalue() {
if type_left.value_within_range(lit_val) {
st = type_left.clone();
}
}
}
// do nothing
Expand Down
7 changes: 5 additions & 2 deletions polars/polars-lazy/src/physical_plan/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ impl CastExpr {

if input.bool().is_ok() && input.null_count() == input.len() {
match &self.data_type {
DataType::List(_) => {
return Ok(ListChunked::full_null(input.name(), input.len()).into_series())
DataType::List(inner) => {
return Ok(
ListChunked::full_null_with_dtype(input.name(), input.len(), inner)
.into_series(),
)
}
#[cfg(feature = "dtype-date")]
DataType::Date => {
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5173,7 +5173,7 @@ def nanoseconds(self) -> Expr:


def expr_to_lit_or_expr(
expr: Union[Expr, bool, int, float, str, "pli.Series"],
expr: Union[Expr, bool, int, float, str, "pli.Series", None],
str_to_lit: bool = True,
) -> Expr:
"""
Expand Down
6 changes: 3 additions & 3 deletions py-polars/polars/internals/whenthen.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def when(self, predicate: pli.Expr) -> "WhenThenThen":
"""
return WhenThenThen(self.pywhenthenthen.when(predicate._pyexpr))

def then(self, expr: Union[pli.Expr, int, float, str]) -> "WhenThenThen":
def then(self, expr: Union[pli.Expr, int, float, str, None]) -> "WhenThenThen":
"""
Values to return in case of the predicate being `True`.
Expand All @@ -33,7 +33,7 @@ def then(self, expr: Union[pli.Expr, int, float, str]) -> "WhenThenThen":
expr_ = pli.expr_to_lit_or_expr(expr)
return WhenThenThen(self.pywhenthenthen.then(expr_._pyexpr))

def otherwise(self, expr: Union[pli.Expr, int, float, str]) -> pli.Expr:
def otherwise(self, expr: Union[pli.Expr, int, float, str, None]) -> pli.Expr:
"""
Values to return in case of the predicate being `False`.
Expand Down Expand Up @@ -75,7 +75,7 @@ class When:
def __init__(self, pywhen: "pywhen"):
self._pywhen = pywhen

def then(self, expr: Union[pli.Expr, int, float, str]) -> WhenThen:
def then(self, expr: Union[pli.Expr, int, float, str, None]) -> WhenThen:
"""
Values to return in case of the predicate being `True`.
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/test_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,15 @@ def test_list_empty_groupby_result_3521() -> None:
.groupby("groupby_column")
.agg(pl.col("n_unique_column").drop_nulls())
).to_dict(False) == {"groupby_column": [1], "n_unique_column": [[]]}


def test_list_fill_null() -> None:
df = pl.DataFrame({"C": [["a", "b", "c"], [], [], ["d", "e"]]})
assert df.with_columns(
[
pl.when(pl.col("C").arr.lengths() == 0)
.then(None)
.otherwise(pl.col("C"))
.alias("C")
]
).to_series().to_list() == [["a", "b", "c"], None, None, ["d", "e"]]

0 comments on commit fc5bcf0

Please sign in to comment.