Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Expressify dt.round #15861

Merged
merged 3 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 7 additions & 6 deletions crates/polars-plan/src/dsl/dt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,13 +241,14 @@ impl DateLikeNameSpace {
}

/// Round the Datetime/Date range into buckets.
pub fn round<S: AsRef<str>>(self, every: S, offset: S) -> Expr {
let every = every.as_ref().into();
pub fn round<S: AsRef<str>>(self, every: Expr, offset: S) -> Expr {
let offset = offset.as_ref().into();
self.0
.map_private(FunctionExpr::TemporalExpr(TemporalFunction::Round(
every, offset,
)))
self.0.map_many_private(
FunctionExpr::TemporalExpr(TemporalFunction::Round(offset)),
&[every],
false,
false,
)
}

/// Offset this `Date/Datetime` by a given offset [`Duration`].
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-plan/src/dsl/function_expr/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ pub enum TemporalFunction {
BaseUtcOffset,
#[cfg(feature = "timezones")]
DSTOffset,
Round(String, String),
Round(String),
#[cfg(feature = "timezones")]
ReplaceTimeZone(Option<TimeZone>, NonExistent),
Combine(TimeUnit),
Expand Down Expand Up @@ -465,11 +465,11 @@ pub(super) fn dst_offset(s: &Series) -> PolarsResult<Series> {
}
}

pub(super) fn round(s: &[Series], every: &str, offset: &str) -> PolarsResult<Series> {
let every = Duration::parse(every);
pub(super) fn round(s: &[Series], offset: &str) -> PolarsResult<Series> {
let offset = Duration::parse(offset);

let time_series = &s[0];
let every = s[1].str()?;

Ok(match time_series.dtype() {
DataType::Datetime(_, tz) => match tz {
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/function_expr/temporal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl From<TemporalFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
BaseUtcOffset => map!(datetime::base_utc_offset),
#[cfg(feature = "timezones")]
DSTOffset => map!(datetime::dst_offset),
Round(every, offset) => map_as_slice!(datetime::round, &every, &offset),
Round(offset) => map_as_slice!(datetime::round, &offset),
#[cfg(feature = "timezones")]
ReplaceTimeZone(tz, non_existent) => {
map_as_slice!(dispatch::replace_time_zone, tz.as_deref(), non_existent)
Expand Down
81 changes: 56 additions & 25 deletions crates/polars-time/src/round.rs
Original file line number Diff line number Diff line change
@@ -1,46 +1,77 @@
use arrow::legacy::time_zone::Tz;
use arrow::temporal_conversions::{MILLISECONDS, SECONDS_IN_DAY};
use polars_core::prelude::arity::broadcast_try_binary_elementwise;
use polars_core::prelude::*;
use polars_utils::cache::FastFixedCache;

use crate::prelude::*;

pub trait PolarsRound {
fn round(&self, every: Duration, offset: Duration, tz: Option<&Tz>) -> PolarsResult<Self>
fn round(&self, every: &StringChunked, offset: Duration, tz: Option<&Tz>) -> PolarsResult<Self>
where
Self: Sized;
}

impl PolarsRound for DatetimeChunked {
fn round(&self, every: Duration, offset: Duration, tz: Option<&Tz>) -> PolarsResult<Self> {
if every.negative {
polars_bail!(ComputeError: "cannot round a Datetime to a negative duration")
}
fn round(
&self,
every: &StringChunked,
offset: Duration,
tz: Option<&Tz>,
) -> PolarsResult<Self> {
let mut duration_cache = FastFixedCache::new((every.len() as f64).sqrt() as usize);
let out = broadcast_try_binary_elementwise(self, every, |opt_t, opt_every| {
match (opt_t, opt_every) {
(Some(timestamp), Some(every)) => {
let every =
*duration_cache.get_or_insert_with(every, |every| Duration::parse(every));

let w = Window::new(every, every, offset);
if every.negative {
polars_bail!(ComputeError: "Cannot round a Datetime to a negative duration")
}

let func = match self.time_unit() {
TimeUnit::Nanoseconds => Window::round_ns,
TimeUnit::Microseconds => Window::round_us,
TimeUnit::Milliseconds => Window::round_ms,
};
let w = Window::new(every, every, offset);

let out = { self.try_apply_nonnull_values_generic(|t| func(&w, t, tz)) };
out.map(|ok| ok.into_datetime(self.time_unit(), self.time_zone().clone()))
let func = match self.time_unit() {
TimeUnit::Nanoseconds => Window::round_ns,
TimeUnit::Microseconds => Window::round_us,
TimeUnit::Milliseconds => Window::round_ms,
};
func(&w, timestamp, tz).map(Some)
},
_ => Ok(None),
}
});
Ok(out?.into_datetime(self.time_unit(), self.time_zone().clone()))
}
}

impl PolarsRound for DateChunked {
fn round(&self, every: Duration, offset: Duration, _tz: Option<&Tz>) -> PolarsResult<Self> {
if every.negative {
polars_bail!(ComputeError: "cannot round a Date to a negative duration")
}

let w = Window::new(every, every, offset);
Ok(self
.try_apply_nonnull_values_generic(|t| {
const MSECS_IN_DAY: i64 = MILLISECONDS * SECONDS_IN_DAY;
PolarsResult::Ok((w.round_ms(MSECS_IN_DAY * t as i64, None)? / MSECS_IN_DAY) as i32)
})?
.into_date())
fn round(
&self,
every: &StringChunked,
offset: Duration,
_tz: Option<&Tz>,
) -> PolarsResult<Self> {
let mut duration_cache = FastFixedCache::new((every.len() as f64).sqrt() as usize);
const MSECS_IN_DAY: i64 = MILLISECONDS * SECONDS_IN_DAY;
let out = broadcast_try_binary_elementwise(&self.0, every, |opt_t, opt_every| {
match (opt_t, opt_every) {
(Some(t), Some(every)) => {
let every =
*duration_cache.get_or_insert_with(every, |every| Duration::parse(every));
if every.negative {
polars_bail!(ComputeError: "Cannot round a Date to a negative duration")
}

let w = Window::new(every, every, offset);
Ok(Some(
(w.round_ms(MSECS_IN_DAY * t as i64, None)? / MSECS_IN_DAY) as i32,
))
},
_ => Ok(None),
}
});
Ok(out?.into_date())
}
}
12 changes: 7 additions & 5 deletions py-polars/polars/expr/datetime.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import datetime as dt
from datetime import timedelta
from typing import TYPE_CHECKING, Iterable

import polars._reexport as pl
Expand All @@ -19,13 +20,12 @@
from polars.datatypes import DTYPE_TEMPORAL_UNITS, Date, Int32

if TYPE_CHECKING:
from datetime import timedelta

from polars import Expr
from polars.type_aliases import (
Ambiguous,
EpochTimeUnit,
IntoExpr,
IntoExprColumn,
NonExistent,
Roll,
TimeUnit,
Expand Down Expand Up @@ -344,7 +344,7 @@ def truncate(
@unstable()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MarcoGorelli This function has been marked as experimental since 2022-10; do you think it can be considered stable now?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry for the delay - I think so, yes!

def round(
self,
every: str | timedelta,
every: str | timedelta | IntoExprColumn,
offset: str | timedelta | None = None,
*,
ambiguous: Ambiguous | Expr | None = None,
Expand Down Expand Up @@ -481,10 +481,12 @@ def round(
"`ambiguous` is deprecated. It is now automatically inferred; you can safely omit this argument.",
version="0.19.13",
)

if isinstance(every, timedelta):
every = parse_as_duration_string(every)
every = parse_as_expression(every, str_as_lit=True)
return wrap_expr(
self._pyexpr.dt_round(
parse_as_duration_string(every),
every,
parse_as_duration_string(offset),
)
)
Expand Down
3 changes: 2 additions & 1 deletion py-polars/polars/series/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Ambiguous,
EpochTimeUnit,
IntoExpr,
IntoExprColumn,
NonExistent,
Roll,
TemporalLiteral,
Expand Down Expand Up @@ -1805,7 +1806,7 @@ def truncate(
@unstable()
def round(
self,
every: str | dt.timedelta,
every: str | dt.timedelta | IntoExprColumn,
offset: str | dt.timedelta | None = None,
*,
ambiguous: Ambiguous | Series | None = None,
Expand Down
4 changes: 2 additions & 2 deletions py-polars/src/expr/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ impl PyExpr {
self.inner.clone().dt().dst_offset().into()
}

fn dt_round(&self, every: &str, offset: &str) -> Self {
self.inner.clone().dt().round(every, offset).into()
fn dt_round(&self, every: Self, offset: &str) -> Self {
Copy link
Collaborator Author

@reswqa reswqa Apr 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We only expressify `every`, because `offset` has been deprecated since 0.20.19.

self.inner.clone().dt().round(every.inner, offset).into()
}

fn dt_combine(&self, time: Self, time_unit: Wrap<TimeUnit>) -> Self {
Expand Down
54 changes: 52 additions & 2 deletions py-polars/tests/unit/namespaces/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,15 +573,65 @@ def test_round(
assert out.dt[-1] == stop


def test_round_expr() -> None:
df = pl.DataFrame(
{
"date": [
datetime(2022, 11, 14),
datetime(2023, 10, 11),
datetime(2022, 3, 20, 5, 7, 18),
datetime(2022, 4, 3, 13, 30, 32),
],
"every": ["1y", "1mo", "1m", "1m"],
}
)

output = df.select(
all_expr=pl.col("date").dt.round(every=pl.col("every")),
date_lit=pl.lit(datetime(2022, 4, 3, 13, 30, 32)).dt.round(
every=pl.col("every")
),
every_lit=pl.col("date").dt.round("1d"),
)

expected = pl.DataFrame(
{
"all_expr": [
datetime(2023, 1, 1),
datetime(2023, 10, 1),
datetime(2022, 3, 20, 5, 7),
datetime(2022, 4, 3, 13, 31),
],
"date_lit": [
datetime(2022, 1, 1),
datetime(2022, 4, 1),
datetime(2022, 4, 3, 13, 31),
datetime(2022, 4, 3, 13, 31),
],
"every_lit": [
datetime(2022, 11, 14),
datetime(2023, 10, 11),
datetime(2022, 3, 20),
datetime(2022, 4, 4),
],
}
)

assert_frame_equal(output, expected)

all_lit = pl.select(all_lit=pl.lit(datetime(2022, 3, 20, 5, 7)).dt.round("1h"))
assert all_lit.to_dict(as_series=False) == {"all_lit": [datetime(2022, 3, 20, 5)]}


def test_round_negative() -> None:
"""Test that rounding to a negative duration gives a helpful error message."""
with pytest.raises(
ComputeError, match="cannot round a Date to a negative duration"
ComputeError, match="Cannot round a Date to a negative duration"
):
pl.Series([date(1895, 5, 7)]).dt.round("-1m")

with pytest.raises(
ComputeError, match="cannot round a Datetime to a negative duration"
ComputeError, match="Cannot round a Datetime to a negative duration"
):
pl.Series([datetime(1895, 5, 7)]).dt.round("-1m")

Expand Down