diff --git a/crates/polars-plan/src/dsl/dt.rs b/crates/polars-plan/src/dsl/dt.rs index f11c9a003979..e71fd2b284de 100644 --- a/crates/polars-plan/src/dsl/dt.rs +++ b/crates/polars-plan/src/dsl/dt.rs @@ -241,13 +241,14 @@ impl DateLikeNameSpace { } /// Round the Datetime/Date range into buckets. - pub fn round>(self, every: S, offset: S) -> Expr { - let every = every.as_ref().into(); + pub fn round>(self, every: Expr, offset: S) -> Expr { let offset = offset.as_ref().into(); - self.0 - .map_private(FunctionExpr::TemporalExpr(TemporalFunction::Round( - every, offset, - ))) + self.0.map_many_private( + FunctionExpr::TemporalExpr(TemporalFunction::Round(offset)), + &[every], + false, + false, + ) } /// Offset this `Date/Datetime` by a given offset [`Duration`]. diff --git a/crates/polars-plan/src/dsl/function_expr/datetime.rs b/crates/polars-plan/src/dsl/function_expr/datetime.rs index 90a95bdd916e..311bc6529834 100644 --- a/crates/polars-plan/src/dsl/function_expr/datetime.rs +++ b/crates/polars-plan/src/dsl/function_expr/datetime.rs @@ -56,7 +56,7 @@ pub enum TemporalFunction { BaseUtcOffset, #[cfg(feature = "timezones")] DSTOffset, - Round(String, String), + Round(String), #[cfg(feature = "timezones")] ReplaceTimeZone(Option, NonExistent), Combine(TimeUnit), @@ -465,11 +465,11 @@ pub(super) fn dst_offset(s: &Series) -> PolarsResult { } } -pub(super) fn round(s: &[Series], every: &str, offset: &str) -> PolarsResult { - let every = Duration::parse(every); +pub(super) fn round(s: &[Series], offset: &str) -> PolarsResult { let offset = Duration::parse(offset); let time_series = &s[0]; + let every = s[1].str()?; Ok(match time_series.dtype() { DataType::Datetime(_, tz) => match tz { diff --git a/crates/polars-plan/src/dsl/function_expr/temporal.rs b/crates/polars-plan/src/dsl/function_expr/temporal.rs index be16af9fedae..ddcd2a9820a1 100644 --- a/crates/polars-plan/src/dsl/function_expr/temporal.rs +++ b/crates/polars-plan/src/dsl/function_expr/temporal.rs @@ -57,7 +57,7 @@ impl From for SpecialEq> { BaseUtcOffset => map!(datetime::base_utc_offset), #[cfg(feature = "timezones")] DSTOffset => map!(datetime::dst_offset), - Round(every, offset) => map_as_slice!(datetime::round, &every, &offset), + Round(offset) => map_as_slice!(datetime::round, &offset), #[cfg(feature = "timezones")] ReplaceTimeZone(tz, non_existent) => { map_as_slice!(dispatch::replace_time_zone, tz.as_deref(), non_existent) diff --git a/crates/polars-time/src/round.rs b/crates/polars-time/src/round.rs index fffbeec35c23..3ec146bd487c 100644 --- a/crates/polars-time/src/round.rs +++ b/crates/polars-time/src/round.rs @@ -1,46 +1,77 @@ use arrow::legacy::time_zone::Tz; use arrow::temporal_conversions::{MILLISECONDS, SECONDS_IN_DAY}; +use polars_core::prelude::arity::broadcast_try_binary_elementwise; use polars_core::prelude::*; +use polars_utils::cache::FastFixedCache; use crate::prelude::*; pub trait PolarsRound { - fn round(&self, every: Duration, offset: Duration, tz: Option<&Tz>) -> PolarsResult + fn round(&self, every: &StringChunked, offset: Duration, tz: Option<&Tz>) -> PolarsResult where Self: Sized; } impl PolarsRound for DatetimeChunked { - fn round(&self, every: Duration, offset: Duration, tz: Option<&Tz>) -> PolarsResult { - if every.negative { - polars_bail!(ComputeError: "cannot round a Datetime to a negative duration") - } + fn round( + &self, + every: &StringChunked, + offset: Duration, + tz: Option<&Tz>, + ) -> PolarsResult { + let mut duration_cache = FastFixedCache::new((every.len() as f64).sqrt() as usize); + let out = broadcast_try_binary_elementwise(self, every, |opt_t, opt_every| { + match (opt_t, opt_every) { + (Some(timestamp), Some(every)) => { + let every = + *duration_cache.get_or_insert_with(every, |every| Duration::parse(every)); - let w = Window::new(every, every, offset); + if every.negative { + polars_bail!(ComputeError: "Cannot round a Datetime to a negative duration") + } - let func = match self.time_unit() { - TimeUnit::Nanoseconds => Window::round_ns, - TimeUnit::Microseconds => Window::round_us, - TimeUnit::Milliseconds => Window::round_ms, - }; + let w = Window::new(every, every, offset); - let out = { self.try_apply_nonnull_values_generic(|t| func(&w, t, tz)) }; - out.map(|ok| ok.into_datetime(self.time_unit(), self.time_zone().clone())) + let func = match self.time_unit() { + TimeUnit::Nanoseconds => Window::round_ns, + TimeUnit::Microseconds => Window::round_us, + TimeUnit::Milliseconds => Window::round_ms, + }; + func(&w, timestamp, tz).map(Some) + }, + _ => Ok(None), + } + }); + Ok(out?.into_datetime(self.time_unit(), self.time_zone().clone())) } } impl PolarsRound for DateChunked { - fn round(&self, every: Duration, offset: Duration, _tz: Option<&Tz>) -> PolarsResult { - if every.negative { - polars_bail!(ComputeError: "cannot round a Date to a negative duration") - } - - let w = Window::new(every, every, offset); - Ok(self - .try_apply_nonnull_values_generic(|t| { - const MSECS_IN_DAY: i64 = MILLISECONDS * SECONDS_IN_DAY; - PolarsResult::Ok((w.round_ms(MSECS_IN_DAY * t as i64, None)? / MSECS_IN_DAY) as i32) - })? - .into_date()) + fn round( + &self, + every: &StringChunked, + offset: Duration, + _tz: Option<&Tz>, + ) -> PolarsResult { + let mut duration_cache = FastFixedCache::new((every.len() as f64).sqrt() as usize); + const MSECS_IN_DAY: i64 = MILLISECONDS * SECONDS_IN_DAY; + let out = broadcast_try_binary_elementwise(&self.0, every, |opt_t, opt_every| { + match (opt_t, opt_every) { + (Some(t), Some(every)) => { + let every = + *duration_cache.get_or_insert_with(every, |every| Duration::parse(every)); + if every.negative { + polars_bail!(ComputeError: "Cannot round a Date to a negative duration") + } + + let w = Window::new(every, every, offset); + Ok(Some( + (w.round_ms(MSECS_IN_DAY * t as i64, None)? / MSECS_IN_DAY) as i32, + )) + }, + _ => Ok(None), + } + }); + Ok(out?.into_date()) } } diff --git a/py-polars/polars/expr/datetime.py b/py-polars/polars/expr/datetime.py index dd2306a7a551..eff357898684 100644 --- a/py-polars/polars/expr/datetime.py +++ b/py-polars/polars/expr/datetime.py @@ -1,6 +1,7 @@ from __future__ import annotations import datetime as dt +from datetime import timedelta from typing import TYPE_CHECKING, Iterable import polars._reexport as pl @@ -19,13 +20,12 @@ from polars.datatypes import DTYPE_TEMPORAL_UNITS, Date, Int32 if TYPE_CHECKING: - from datetime import timedelta - from polars import Expr from polars.type_aliases import ( Ambiguous, EpochTimeUnit, IntoExpr, + IntoExprColumn, NonExistent, Roll, TimeUnit, @@ -344,7 +344,7 @@ def truncate( @unstable() def round( self, - every: str | timedelta, + every: str | timedelta | IntoExprColumn, offset: str | timedelta | None = None, *, ambiguous: Ambiguous | Expr | None = None, @@ -481,10 +481,12 @@ def round( "`ambiguous` is deprecated. It is now automatically inferred; you can safely omit this argument.", version="0.19.13", ) - + if isinstance(every, timedelta): + every = parse_as_duration_string(every) + every = parse_as_expression(every, str_as_lit=True) return wrap_expr( self._pyexpr.dt_round( - parse_as_duration_string(every), + every, parse_as_duration_string(offset), ) ) diff --git a/py-polars/polars/series/datetime.py b/py-polars/polars/series/datetime.py index b6690ce55783..188e45b3c7b2 100644 --- a/py-polars/polars/series/datetime.py +++ b/py-polars/polars/series/datetime.py @@ -18,6 +18,7 @@ Ambiguous, EpochTimeUnit, IntoExpr, + IntoExprColumn, NonExistent, Roll, TemporalLiteral, @@ -1805,7 +1806,7 @@ def truncate( @unstable() def round( self, - every: str | dt.timedelta, + every: str | dt.timedelta | IntoExprColumn, offset: str | dt.timedelta | None = None, *, ambiguous: Ambiguous | Series | None = None, diff --git a/py-polars/src/expr/datetime.rs b/py-polars/src/expr/datetime.rs index 5b81497a57f7..d352382abcb1 100644 --- a/py-polars/src/expr/datetime.rs +++ b/py-polars/src/expr/datetime.rs @@ -90,8 +90,8 @@ impl PyExpr { self.inner.clone().dt().dst_offset().into() } - fn dt_round(&self, every: &str, offset: &str) -> Self { - self.inner.clone().dt().round(every, offset).into() + fn dt_round(&self, every: Self, offset: &str) -> Self { + self.inner.clone().dt().round(every.inner, offset).into() } fn dt_combine(&self, time: Self, time_unit: Wrap) -> Self { diff --git a/py-polars/tests/unit/namespaces/test_datetime.py b/py-polars/tests/unit/namespaces/test_datetime.py index 19ae43f01767..5307a2e7e129 100644 --- a/py-polars/tests/unit/namespaces/test_datetime.py +++ b/py-polars/tests/unit/namespaces/test_datetime.py @@ -573,15 +573,73 @@ def test_round( assert out.dt[-1] == stop +def test_round_expr() -> None: + df = pl.DataFrame( + { + "date": [ + datetime(2022, 11, 14), + datetime(2023, 10, 11), + datetime(2022, 3, 20, 5, 7, 18), + datetime(2022, 4, 3, 13, 30, 32), + None, + datetime(2022, 12, 1), + ], + "every": ["1y", "1mo", "1m", "1m", "1mo", None], + } + ) + + output = df.select( + all_expr=pl.col("date").dt.round(every=pl.col("every")), + date_lit=pl.lit(datetime(2022, 4, 3, 13, 30, 32)).dt.round( + every=pl.col("every") + ), + every_lit=pl.col("date").dt.round("1d"), + ) + + expected = pl.DataFrame( + { + "all_expr": [ + datetime(2023, 1, 1), + datetime(2023, 10, 1), + datetime(2022, 3, 20, 5, 7), + datetime(2022, 4, 3, 13, 31), + None, + None, + ], + "date_lit": [ + datetime(2022, 1, 1), + datetime(2022, 4, 1), + datetime(2022, 4, 3, 13, 31), + datetime(2022, 4, 3, 13, 31), + datetime(2022, 4, 1), + None, + ], + "every_lit": [ + datetime(2022, 11, 14), + datetime(2023, 10, 11), + datetime(2022, 3, 20), + datetime(2022, 4, 4), + None, + datetime(2022, 12, 1), + ], + } + ) + + assert_frame_equal(output, expected) + + all_lit = pl.select(all_lit=pl.lit(datetime(2022, 3, 20, 5, 7)).dt.round("1h")) + assert all_lit.to_dict(as_series=False) == {"all_lit": [datetime(2022, 3, 20, 5)]} + + def test_round_negative() -> None: """Test that rounding to a negative duration gives a helpful error message.""" with pytest.raises( - ComputeError, match="cannot round a Date to a negative duration" + ComputeError, match="Cannot round a Date to a negative duration" ): pl.Series([date(1895, 5, 7)]).dt.round("-1m") with pytest.raises( - ComputeError, match="cannot round a Datetime to a negative duration" + ComputeError, match="Cannot round a Datetime to a negative duration" ): pl.Series([datetime(1895, 5, 7)]).dt.round("-1m")