Skip to content

Commit

Permalink
fix(rust, python): correct expr::diff dtype for temporal columns (#5416)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Nov 3, 2022
1 parent 386af6a commit 13c1092
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 6 deletions.
5 changes: 4 additions & 1 deletion polars/polars-core/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ pub mod pct_change;
mod round;
mod to_list;
mod unique;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

#[derive(Copy, Clone)]
#[derive(Copy, Clone, Hash, Eq, PartialEq, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum NullBehavior {
/// drop nulls
Drop,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,8 @@ pub(super) fn is_unique(s: &Series) -> PolarsResult<Series> {
/// Dispatch helper for the `IsDuplicated` function expression.
///
/// Calls `Series::is_duplicated` on `s`, propagates any error it reports, and
/// converts the successful result into a `Series`.
pub(super) fn is_duplicated(s: &Series) -> PolarsResult<Series> {
    let mask = s.is_duplicated()?;
    Ok(mask.into_series())
}

/// Dispatch helper for the `Diff` function expression.
///
/// Forwards `n` and `null_behavior` unchanged to `Series::diff` and wraps the
/// resulting series in `Ok`, so it fits the fallible signature shared by the
/// other dispatch functions in this module.
#[cfg(feature = "diff")]
pub(super) fn diff(s: &Series, n: usize, null_behavior: NullBehavior) -> PolarsResult<Series> {
    let differenced = s.diff(n, null_behavior);
    Ok(differenced)
}
6 changes: 6 additions & 0 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ pub enum FunctionExpr {
IsDuplicated,
Coalesce,
ShrinkType,
#[cfg(feature = "diff")]
Diff(usize, NullBehavior),
}

impl Display for FunctionExpr {
Expand Down Expand Up @@ -160,6 +162,8 @@ impl Display for FunctionExpr {
IsDuplicated => "is_duplicated",
Coalesce => "coalesce",
ShrinkType => "shrink_dtype",
#[cfg(feature = "diff")]
Diff(_, _) => "diff",
};
write!(f, "{}", s)
}
Expand Down Expand Up @@ -331,6 +335,8 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
IsDuplicated => map!(dispatch::is_duplicated),
Coalesce => map_as_slice!(fill_null::coalesce),
ShrinkType => map_owned!(shrink_type::shrink),
#[cfg(feature = "diff")]
Diff(n, null_behavior) => map!(dispatch::diff, n, null_behavior),
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,16 @@ impl FunctionExpr {
TopK { .. } => same_type(),
Shift(..) | Reverse => same_type(),
IsNotNull | IsNull | Not | IsUnique | IsDuplicated => with_dtype(DataType::Boolean),
#[cfg(feature = "diff")]
Diff(_, _) => map_dtype(&|dt| match dt {
#[cfg(feature = "dtype-datetime")]
DataType::Datetime(tu, _) => DataType::Duration(*tu),
#[cfg(feature = "dtype-date")]
DataType::Date => DataType::Duration(TimeUnit::Milliseconds),
#[cfg(feature = "dtype-time")]
DataType::Time => DataType::Duration(TimeUnit::Nanoseconds),
dt => dt.clone(),
}),
ShrinkType => {
// we return the smallest type this can return
// this might not be correct once the actual data
Expand Down
6 changes: 1 addition & 5 deletions polars/polars-lazy/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1814,11 +1814,7 @@ impl Expr {
/// Build a `diff` expression on top of `self`, passing `n` and `null_behavior`
/// through unchanged.
///
/// Only compiled when the `diff` feature is enabled.
#[cfg(feature = "diff")]
#[cfg_attr(docsrs, doc(cfg(feature = "diff")))]
pub fn diff(self, n: usize, null_behavior: NullBehavior) -> Expr {
// NOTE(review): this body comes from a rendered commit diff without +/- markers.
// The `self.apply(..).with_fmt("diff")` chain below is presumably the REMOVED
// implementation (a closure declared with `GetOutput::same_type()`), and the
// `apply_private(FunctionExpr::Diff(..))` line is presumably its replacement.
// Both cannot coexist in a compiling file — confirm against the real source.
self.apply(
move |s| Ok(s.diff(n, null_behavior)),
GetOutput::same_type(),
)
.with_fmt("diff")
self.apply_private(FunctionExpr::Diff(n, null_behavior))
}

#[cfg(feature = "pct_change")]
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,15 @@ def test_shrink_dtype() -> None:
"g": [0.10000000149011612, 1.3200000524520874, 0.11999999731779099],
"h": [True, None, False],
}


def test_diff_duration_dtype() -> None:
    """Diffing a Date column must give values comparable with `pl.duration`."""
    raw_dates = ["2022-01-01", "2022-01-02", "2022-01-03", "2022-01-03"]
    parsed = pl.Series(raw_dates).str.strptime(pl.Date, "%Y-%m-%d")
    df = pl.DataFrame({"date": parsed})

    # First row has no predecessor (null); only the repeated last date differs
    # from its predecessor by less than one day.
    below_one_day = df.select(pl.col("date").diff() < pl.duration(days=1))
    expected = [None, False, False, True]

    assert below_one_day["date"].to_list() == expected

0 comments on commit 13c1092

Please sign in to comment.