Skip to content

Commit

Permalink
fill_nan preserve name (#4119)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 22, 2022
1 parent fe7af78 commit 74730bb
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
10 changes: 5 additions & 5 deletions polars/polars-arrow/src/kernels/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ pub fn is_not_nan<T>(arr: &PrimitiveArray<T>) -> ArrayRef
where
T: NativeType + Float,
{
let values = Bitmap::from_trusted_len_iter(arr.values().iter().map(|v| !v.is_nan()));
let mut values = Bitmap::from_trusted_len_iter(arr.values().iter().map(|v| !v.is_nan()));
if let Some(validity) = arr.validity() {
values = &values | &!validity
}

Box::new(BooleanArray::from_data_default(
values,
arr.validity().cloned(),
))
Box::new(BooleanArray::from_data_default(values, None))
}

pub fn is_finite<T>(arr: &PrimitiveArray<T>) -> ArrayRef
Expand Down
12 changes: 7 additions & 5 deletions polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,12 +306,12 @@ impl Expr {
|s| match s.dtype() {
DataType::Float32 => {
let ca = s.f32()?;
let mask = ca.is_not_nan().fill_null(FillNullStrategy::One)?;
let mask = ca.is_not_nan();
ca.filter(&mask).map(|ca| ca.into_series())
}
DataType::Float64 => {
let ca = s.f64()?;
let mask = ca.is_not_nan().fill_null(FillNullStrategy::One)?;
let mask = ca.is_not_nan();
ca.filter(&mask).map(|ca| ca.into_series())
}
_ => Ok(s),
Expand Down Expand Up @@ -1137,9 +1137,11 @@ impl Expr {

/// Replace the floating point `NaN` values by a value.
pub fn fill_nan<E: Into<Expr>>(self, fill_value: E) -> Self {
when(self.clone().is_nan())
.then(fill_value.into())
.otherwise(self)
// we take the not branch so that self is truthy value of `when -> then -> otherwise`
// and that ensure we keep the name of `self`
when(self.clone().is_not_nan())
.then(self)
.otherwise(fill_value.into())
}
/// Count the values of the Series
/// or
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,8 +613,8 @@ def test_fill_nan() -> None:
.collect()["a"]
.series_equal(pl.Series("a", [1.0, 2.0, 3.0]))
)
assert df.select(pl.col("a").fill_nan(2))["literal"].series_equal(
pl.Series("literal", [1.0, 2.0, 3.0])
assert df.select(pl.col("a").fill_nan(2))["a"].series_equal(
pl.Series("a", [1.0, 2.0, 3.0])
)


Expand Down

0 comments on commit 74730bb

Please sign in to comment.