Skip to content

Commit

Permalink
fix(rust, python): keep f32 dtype in fill_null by int (#5834)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 17, 2022
1 parent 00977c5 commit 25507d1
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use super::*;
pub(super) fn fill_null(s: &[Series], super_type: &DataType) -> PolarsResult<Series> {
let array = &s[0];
let fill_value = &s[1];

if matches!(super_type, DataType::Unknown) {
return Err(PolarsError::SchemaMisMatch(
format!(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ fn modify_supertype(
if type_left.is_numeric() && type_right.is_numeric() {
match (left, right) {
// don't let the literal f64 coerce the f32 column
(AExpr::Literal(LiteralValue::Float64(_)), _) if matches!(type_right, DataType::Float32) => {
(AExpr::Literal(LiteralValue::Float64(_) | LiteralValue::Int32(_) | LiteralValue::Int64(_)), _) if matches!(type_right, DataType::Float32) => {
st = DataType::Float32
}
(_, AExpr::Literal(LiteralValue::Float64(_))) if matches!(type_left, DataType::Float32) => {
(_, AExpr::Literal(LiteralValue::Float64(_) | LiteralValue::Int32(_) | LiteralValue::Int64(_))) if matches!(type_left, DataType::Float32) => {
st = DataType::Float32
}

Expand Down Expand Up @@ -66,7 +66,7 @@ fn modify_supertype(
#[cfg(feature = "dtype-categorical")]
(Categorical(_), Utf8, _, AExpr::Literal(_))
| (Utf8, Categorical(_), AExpr::Literal(_), _) => {
st = DataType::Categorical(None);
st = Categorical(None);
}
// when then expression literals can have a different list type.
// so we cast the literal to the other hand side.
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5606,7 +5606,7 @@ def select(
┌─────────┐
│ literal │
│ --- │
i64
i32
╞═════════╡
│ 0 │
├╌╌╌╌╌╌╌╌╌┤
Expand Down
1 change: 0 additions & 1 deletion py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,7 +1164,6 @@ def lit(

except AttributeError:
item = value

return pli.wrap_expr(pylit(item, allow_object))


Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1586,7 +1586,7 @@ def select(
┌─────────┐
│ literal │
│ --- │
i64
i32
╞═════════╡
│ 0 │
├╌╌╌╌╌╌╌╌╌┤
Expand Down
6 changes: 3 additions & 3 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1760,14 +1760,14 @@ pub fn cumfold(acc: PyExpr, lambda: PyObject, exprs: Vec<PyExpr>, include_init:
let exprs = py_exprs_to_exprs(exprs);

let func = move |a: Series, b: Series| binary_lambda(&lambda, a, b);
polars::lazy::dsl::cumfold_exprs(acc.inner, func, exprs, include_init).into()
cumfold_exprs(acc.inner, func, exprs, include_init).into()
}

pub fn cumreduce(lambda: PyObject, exprs: Vec<PyExpr>) -> PyExpr {
let exprs = py_exprs_to_exprs(exprs);

let func = move |a: Series, b: Series| binary_lambda(&lambda, a, b);
polars::lazy::dsl::cumreduce_exprs(func, exprs).into()
cumreduce_exprs(func, exprs).into()
}

pub fn lit(value: &PyAny, allow_object: bool) -> PyResult<PyExpr> {
Expand All @@ -1777,7 +1777,7 @@ pub fn lit(value: &PyAny, allow_object: bool) -> PyResult<PyExpr> {
} else if let Ok(int) = value.downcast::<PyInt>() {
match int.extract::<i64>() {
Ok(val) => {
if val > 0 && val < i32::MAX as i64 || val < 0 && val > i32::MIN as i64 {
if val >= 0 && val < i32::MAX as i64 || val <= 0 && val > i32::MIN as i64 {
Ok(dsl::lit(val as i32).into())
} else {
Ok(dsl::lit(val).into())
Expand Down
6 changes: 6 additions & 0 deletions py-polars/tests/unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,9 @@ def test_schema_owned_arithmetic_5669() -> None:
.collect()
)
assert df.columns == ["A", "literal"], df.columns


def test_fill_null_f32_with_lit() -> None:
# ensure the literal integer does not upcast the f32 to an f64
df = pl.DataFrame({"a": [1.1, 1.2]}, columns=[("a", pl.Float32)])
assert df.fill_null(value=0).dtypes == [pl.Float32]

0 comments on commit 25507d1

Please sign in to comment.