From 4eef5b5daf8a2c014cc40fb7dce78662c3738f9e Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Sat, 13 Apr 2024 16:10:28 +0100 Subject: [PATCH] feat: add Expr.dt.add_business_days and Series.dt.add_business_days (#15595) --- crates/polars-ops/Cargo.toml | 2 +- crates/polars-ops/src/series/ops/business.rs | 281 ++++++++++++++++-- crates/polars-plan/src/dsl/dt.rs | 21 ++ .../src/dsl/function_expr/business.rs | 29 ++ .../src/dsl/function_expr/schema.rs | 5 +- crates/polars-utils/src/binary_search.rs | 25 ++ crates/polars-utils/src/lib.rs | 1 + .../source/reference/expressions/temporal.rst | 1 + .../docs/source/reference/series/temporal.rst | 1 + py-polars/polars/expr/datetime.py | 128 +++++++- py-polars/polars/series/datetime.py | 95 +++++- py-polars/polars/type_aliases.py | 1 + py-polars/src/conversion/mod.rs | 16 + py-polars/src/expr/datetime.rs | 14 + .../time_series/test_add_business_days.py | 51 ++++ .../business/test_add_business_days.py | 236 +++++++++++++++ 16 files changed, 883 insertions(+), 24 deletions(-) create mode 100644 crates/polars-utils/src/binary_search.rs create mode 100644 py-polars/tests/parametric/time_series/test_add_business_days.py create mode 100644 py-polars/tests/unit/functions/business/test_add_business_days.py diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index bf67349c7cd8..f132a2800c0a 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -75,7 +75,7 @@ is_unique = [] unique_counts = [] is_between = [] approx_unique = [] -business = ["dtype-date"] +business = ["dtype-date", "chrono"] fused = [] cutqcut = ["dtype-categorical", "dtype-struct"] rle = ["dtype-struct"] diff --git a/crates/polars-ops/src/series/ops/business.rs b/crates/polars-ops/src/series/ops/business.rs index b8a7fe2efbde..51b8ddcd0777 100644 --- a/crates/polars-ops/src/series/ops/business.rs +++ b/crates/polars-ops/src/series/ops/business.rs @@ -1,5 +1,23 @@ -use polars_core::prelude::arity::binary_elementwise_values; +#[cfg(feature = "dtype-date")] +use chrono::DateTime; +use polars_core::prelude::arity::{binary_elementwise_values, try_binary_elementwise}; use polars_core::prelude::*; +#[cfg(feature = "dtype-date")] +use polars_core::utils::arrow::temporal_conversions::SECONDS_IN_DAY; +use polars_utils::binary_search::{find_first_ge_index, find_first_gt_index}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +#[cfg(feature = "timezones")] +use crate::prelude::replace_time_zone; + +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum Roll { + Forward, + Backward, + Raise, +} /// Count the number of business days between `start` and `end`, excluding `end`. /// @@ -84,28 +102,21 @@ fn business_day_count_impl( end_date += 1; } - let holidays_begin = match holidays.binary_search(&start_date) { - Ok(x) => x, - Err(x) => x, - } as i32; - let holidays_end = match holidays[(holidays_begin as usize)..].binary_search(&end_date) { - Ok(x) => x as i32 + holidays_begin, - Err(x) => x as i32 + holidays_begin, - }; - - let mut start_weekday = weekday(start_date); + let holidays_begin = find_first_ge_index(holidays, start_date); + let holidays_end = find_first_ge_index(&holidays[holidays_begin..], end_date) + holidays_begin; + let mut start_day_of_week = get_day_of_week(start_date); let diff = end_date - start_date; let whole_weeks = diff / 7; - let mut count = -(holidays_end - holidays_begin); + let mut count = -((holidays_end - holidays_begin) as i32); count += whole_weeks * n_business_days_in_week_mask; start_date += whole_weeks * 7; while start_date < end_date { - // SAFETY: week_mask is length 7, start_weekday is between 0 and 6 - if unsafe { *week_mask.get_unchecked(start_weekday) } { + // SAFETY: week_mask is length 7, start_day_of_week is between 0 and 6 + if unsafe { *week_mask.get_unchecked(start_day_of_week) } { count += 1; } start_date += 1; - start_weekday = increment_weekday(start_weekday); + start_day_of_week = increment_day_of_week(start_day_of_week); } if swapped { -count @@ -114,14 +125,238 @@ fn business_day_count_impl( } } +/// Add a given number of business days. +/// +/// # Arguments +/// - `start`: Series holding start dates. +/// - `n`: Number of business days to add. +/// - `week_mask`: A boolean array of length 7, where `true` indicates that the day is a business day. +/// - `holidays`: timestamps that are holidays. Must be provided as i32, i.e. the number of +/// days since the UNIX epoch. +/// - `roll`: what to do when the start date doesn't land on a business day: +/// - `Roll::Forward`: roll forward to the next business day. +/// - `Roll::Backward`: roll backward to the previous business day. +/// - `Roll::Raise`: raise an error. +pub fn add_business_days( + start: &Series, + n: &Series, + week_mask: [bool; 7], + holidays: &[i32], + roll: Roll, +) -> PolarsResult { + if !week_mask.iter().any(|&x| x) { + polars_bail!(ComputeError:"`week_mask` must have at least one business day"); + } + + match start.dtype() { + DataType::Date => {}, + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(time_unit, None) => { + let result_date = + add_business_days(&start.cast(&DataType::Date)?, n, week_mask, holidays, roll)?; + let start_time = start + .cast(&DataType::Time)? + .cast(&DataType::Duration(*time_unit))?; + return Ok(result_date.cast(&DataType::Datetime(*time_unit, None))? + start_time); + }, + #[cfg(feature = "timezones")] + DataType::Datetime(time_unit, Some(time_zone)) => { + let start_naive = replace_time_zone( + start.datetime().unwrap(), + None, + &StringChunked::from_iter(std::iter::once("raise")), + NonExistent::Raise, + )?; + let result_date = add_business_days( + &start_naive.cast(&DataType::Date)?, + n, + week_mask, + holidays, + roll, + )?; + let start_time = start_naive + .cast(&DataType::Time)? + .cast(&DataType::Duration(*time_unit))?; + let result_naive = + result_date.cast(&DataType::Datetime(*time_unit, None))? + start_time; + let result_tz_aware = replace_time_zone( + result_naive.datetime().unwrap(), + Some(time_zone), + &StringChunked::from_iter(std::iter::once("raise")), + NonExistent::Raise, + )?; + return Ok(result_tz_aware.into_series()); + }, + _ => polars_bail!(InvalidOperation: "expected date or datetime, got {}", start.dtype()), + } + + let holidays = normalise_holidays(holidays, &week_mask); + let start_dates = start.date()?; + let n = match &n.dtype() { + DataType::Int64 | DataType::UInt64 | DataType::UInt32 => n.cast(&DataType::Int32)?, + DataType::Int32 => n.clone(), + _ => { + polars_bail!(InvalidOperation: "expected Int64, Int32, UInt64, or UInt32, got {}", n.dtype()) + }, + }; + let n = n.i32()?; + let n_business_days_in_week_mask = week_mask.iter().filter(|&x| *x).count() as i32; + + let out: Int32Chunked = match (start_dates.len(), n.len()) { + (_, 1) => { + if let Some(n) = n.get(0) { + start_dates.try_apply_nonnull_values_generic(|start_date| { + let (start_date, day_of_week) = + roll_start_date(start_date, roll, &week_mask, &holidays)?; + Ok::(add_business_days_impl( + start_date, + day_of_week, + n, + &week_mask, + n_business_days_in_week_mask, + &holidays, + )) + })? + } else { + Int32Chunked::full_null(start_dates.name(), start_dates.len()) + } + }, + (1, _) => { + if let Some(start_date) = start_dates.get(0) { + let (start_date, day_of_week) = + roll_start_date(start_date, roll, &week_mask, &holidays)?; + n.apply_values(|n| { + add_business_days_impl( + start_date, + day_of_week, + n, + &week_mask, + n_business_days_in_week_mask, + &holidays, + ) + }) + } else { + Int32Chunked::full_null(start_dates.name(), n.len()) + } + }, + _ => try_binary_elementwise(start_dates, n, |opt_start_date, opt_n| { + match (opt_start_date, opt_n) { + (Some(start_date), Some(n)) => { + let (start_date, day_of_week) = + roll_start_date(start_date, roll, &week_mask, &holidays)?; + Ok::, PolarsError>(Some(add_business_days_impl( + start_date, + day_of_week, + n, + &week_mask, + n_business_days_in_week_mask, + &holidays, + ))) + }, + _ => Ok(None), + } + })?, + }; + Ok(out.into_date().into_series()) +} + +/// Ported from: +/// https://github.com/numpy/numpy/blob/e59c074842e3f73483afa5ddef031e856b9fd313/numpy/_core/src/multiarray/datetime_busday.c#L265-L353 +fn add_business_days_impl( + mut date: i32, + mut day_of_week: usize, + mut n: i32, + week_mask: &[bool; 7], + n_business_days_in_week_mask: i32, + holidays: &[i32], +) -> i32 { + if n > 0 { + let holidays_begin = find_first_ge_index(holidays, date); + date += (n / n_business_days_in_week_mask) * 7; + n %= n_business_days_in_week_mask; + let holidays_temp = find_first_gt_index(&holidays[holidays_begin..], date) + holidays_begin; + n += (holidays_temp - holidays_begin) as i32; + let holidays_begin = holidays_temp; + while n > 0 { + date += 1; + day_of_week = increment_day_of_week(day_of_week); + // SAFETY: week_mask is length 7, day_of_week is between 0 and 6 + if unsafe { + (*week_mask.get_unchecked(day_of_week)) + && (!holidays[holidays_begin..].contains(&date)) + } { + n -= 1; + } + } + date + } else { + let holidays_end = find_first_gt_index(holidays, date); + date += (n / n_business_days_in_week_mask) * 7; + n %= n_business_days_in_week_mask; + let holidays_temp = find_first_ge_index(&holidays[..holidays_end], date); + n -= (holidays_end - holidays_temp) as i32; + let holidays_end = holidays_temp; + while n < 0 { + date -= 1; + day_of_week = decrement_day_of_week(day_of_week); + // SAFETY: week_mask is length 7, day_of_week is between 0 and 6 + if unsafe { + (*week_mask.get_unchecked(day_of_week)) + && (!holidays[..holidays_end].contains(&date)) + } { + n += 1; + } + } + date + } +} + +fn roll_start_date( + mut date: i32, + roll: Roll, + week_mask: &[bool; 7], + holidays: &[i32], +) -> PolarsResult<(i32, usize)> { + let mut day_of_week = get_day_of_week(date); + match roll { + Roll::Raise => { + // SAFETY: week_mask is length 7, day_of_week is between 0 and 6 + if holidays.contains(&date) | unsafe { !*week_mask.get_unchecked(day_of_week) } { + let date = DateTime::from_timestamp(date as i64 * SECONDS_IN_DAY, 0) + .unwrap() + .format("%Y-%m-%d"); + polars_bail!(ComputeError: + "date {} is not a business date; use `roll` to roll forwards (or backwards) to the next (or previous) valid date.", date + ) + }; + }, + Roll::Forward => { + // SAFETY: week_mask is length 7, day_of_week is between 0 and 6 + while holidays.contains(&date) | unsafe { !*week_mask.get_unchecked(day_of_week) } { + date += 1; + day_of_week = increment_day_of_week(day_of_week); + } + }, + Roll::Backward => { + // SAFETY: week_mask is length 7, day_of_week is between 0 and 6 + while holidays.contains(&date) | unsafe { !*week_mask.get_unchecked(day_of_week) } { + date -= 1; + day_of_week = decrement_day_of_week(day_of_week); + } + }, + } + Ok((date, day_of_week)) +} + /// Sort and deduplicate holidays and remove holidays that are not business days. fn normalise_holidays(holidays: &[i32], week_mask: &[bool; 7]) -> Vec { let mut holidays: Vec = holidays.to_vec(); holidays.sort_unstable(); let mut previous_holiday: Option = None; holidays.retain(|&x| { - // SAFETY: week_mask is length 7, start_weekday is between 0 and 6 - if (Some(x) == previous_holiday) || !unsafe { *week_mask.get_unchecked(weekday(x)) } { + // SAFETY: week_mask is length 7, get_day_of_week result is between 0 and 6 + if (Some(x) == previous_holiday) || !unsafe { *week_mask.get_unchecked(get_day_of_week(x)) } + { return false; } previous_holiday = Some(x); @@ -130,17 +365,25 @@ fn normalise_holidays(holidays: &[i32], week_mask: &[bool; 7]) -> Vec { holidays } -fn weekday(x: i32) -> usize { +fn get_day_of_week(x: i32) -> usize { // the first modulo might return a negative number, so we add 7 and take // the modulo again so we're sure we have something between 0 (Monday) // and 6 (Sunday) (((x - 4) % 7 + 7) % 7) as usize } -fn increment_weekday(x: usize) -> usize { +fn increment_day_of_week(x: usize) -> usize { if x == 6 { 0 } else { x + 1 } } + +fn decrement_day_of_week(x: usize) -> usize { + if x == 0 { + 6 + } else { + x - 1 + } +} diff --git a/crates/polars-plan/src/dsl/dt.rs b/crates/polars-plan/src/dsl/dt.rs index a81446c28f6c..f11c9a003979 100644 --- a/crates/polars-plan/src/dsl/dt.rs +++ b/crates/polars-plan/src/dsl/dt.rs @@ -4,6 +4,27 @@ use super::*; pub struct DateLikeNameSpace(pub(crate) Expr); impl DateLikeNameSpace { + /// Add a given number of business days. + #[cfg(feature = "business")] + pub fn add_business_days( + self, + n: Expr, + week_mask: [bool; 7], + holidays: Vec, + roll: Roll, + ) -> Expr { + self.0.map_many_private( + FunctionExpr::Business(BusinessFunction::AddBusinessDay { + week_mask, + holidays, + roll, + }), + &[n], + false, + false, + ) + } + /// Convert from Date/Time/Datetime into String with the given format. /// See [chrono strftime/strptime](https://docs.rs/chrono/0.4.19/chrono/format/strftime/index.html). pub fn to_string(self, format: &str) -> Expr { diff --git a/crates/polars-plan/src/dsl/function_expr/business.rs b/crates/polars-plan/src/dsl/function_expr/business.rs index 2740aa856ac5..0d4fc2939d98 100644 --- a/crates/polars-plan/src/dsl/function_expr/business.rs +++ b/crates/polars-plan/src/dsl/function_expr/business.rs @@ -1,6 +1,7 @@ use std::fmt::{Display, Formatter}; use polars_core::prelude::*; +use polars_ops::prelude::Roll; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -16,6 +17,12 @@ pub enum BusinessFunction { week_mask: [bool; 7], holidays: Vec, }, + #[cfg(feature = "business")] + AddBusinessDay { + week_mask: [bool; 7], + holidays: Vec, + roll: Roll, + }, } impl Display for BusinessFunction { @@ -24,6 +31,8 @@ impl Display for BusinessFunction { let s = match self { #[cfg(feature = "business")] &BusinessDayCount { .. } => "business_day_count", + #[cfg(feature = "business")] + &AddBusinessDay { .. } => "add_business_days", }; write!(f, "{s}") } @@ -39,6 +48,14 @@ impl From for SpecialEq> { } => { map_as_slice!(business_day_count, week_mask, &holidays) }, + #[cfg(feature = "business")] + AddBusinessDay { + week_mask, + holidays, + roll, + } => { + map_as_slice!(add_business_days, week_mask, &holidays, roll) + }, } } } @@ -53,3 +70,15 @@ pub(super) fn business_day_count( let end = &s[1]; polars_ops::prelude::business_day_count(start, end, week_mask, holidays) } + +#[cfg(feature = "business")] +pub(super) fn add_business_days( + s: &[Series], + week_mask: [bool; 7], + holidays: &[i32], + roll: Roll, +) -> PolarsResult { + let start = &s[0]; + let n = &s[1]; + polars_ops::prelude::add_business_days(start, n, week_mask, holidays, roll) +} diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 3082527e829d..98ce87676eb0 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -28,7 +28,10 @@ impl FunctionExpr { // Other expressions Boolean(func) => func.get_field(mapper), #[cfg(feature = "business")] - Business(_) => mapper.with_dtype(DataType::Int32), + Business(func) => match func { + BusinessFunction::BusinessDayCount { .. } => mapper.with_dtype(DataType::Int32), + BusinessFunction::AddBusinessDay { .. } => mapper.with_same_dtype(), + }, #[cfg(feature = "abs")] Abs => mapper.with_same_dtype(), Negate => mapper.with_same_dtype(), diff --git a/crates/polars-utils/src/binary_search.rs b/crates/polars-utils/src/binary_search.rs new file mode 100644 index 000000000000..b24aa3e33877 --- /dev/null +++ b/crates/polars-utils/src/binary_search.rs @@ -0,0 +1,25 @@ +/// Find the index of the first element of `arr` that is greater +/// or equal to `val`. +/// Assumes that `arr` is sorted. +pub fn find_first_ge_index(arr: &[T], val: T) -> usize +where + T: Ord, +{ + match arr.binary_search(&val) { + Ok(x) => x, + Err(x) => x, + } +} + +/// Find the index of the first element of `arr` that is greater +/// than `val`. +/// Assumes that `arr` is sorted. +pub fn find_first_gt_index(arr: &[T], val: T) -> usize +where + T: Ord, +{ + match arr.binary_search(&val) { + Ok(x) => x + 1, + Err(x) => x, + } +} diff --git a/crates/polars-utils/src/lib.rs b/crates/polars-utils/src/lib.rs index 575571b62985..842ea031d32f 100644 --- a/crates/polars-utils/src/lib.rs +++ b/crates/polars-utils/src/lib.rs @@ -2,6 +2,7 @@ pub mod abs_diff; pub mod arena; pub mod atomic; +pub mod binary_search; pub mod cache; pub mod cell; pub mod clmul; diff --git a/py-polars/docs/source/reference/expressions/temporal.rst b/py-polars/docs/source/reference/expressions/temporal.rst index a8c33bb7cfc9..f5720ea9bb0d 100644 --- a/py-polars/docs/source/reference/expressions/temporal.rst +++ b/py-polars/docs/source/reference/expressions/temporal.rst @@ -9,6 +9,7 @@ The following methods are available under the `expr.dt` attribute. :toctree: api/ :template: autosummary/accessor_method.rst + Expr.dt.add_business_days Expr.dt.base_utc_offset Expr.dt.cast_time_unit Expr.dt.century diff --git a/py-polars/docs/source/reference/series/temporal.rst b/py-polars/docs/source/reference/series/temporal.rst index 97e7f7751337..c9864ebd5961 100644 --- a/py-polars/docs/source/reference/series/temporal.rst +++ b/py-polars/docs/source/reference/series/temporal.rst @@ -9,6 +9,7 @@ The following methods are available under the `Series.dt` attribute. :toctree: api/ :template: autosummary/accessor_method.rst + Series.dt.add_business_days Series.dt.base_utc_offset Series.dt.cast_time_unit Series.dt.century diff --git a/py-polars/polars/expr/datetime.py b/py-polars/polars/expr/datetime.py index 83a8967bdb0e..e681780b51cb 100644 --- a/py-polars/polars/expr/datetime.py +++ b/py-polars/polars/expr/datetime.py @@ -1,7 +1,7 @@ from __future__ import annotations import datetime as dt -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Iterable import polars._reexport as pl from polars import functions as F @@ -22,7 +22,14 @@ from datetime import timedelta from polars import Expr - from polars.type_aliases import Ambiguous, EpochTimeUnit, NonExistent, TimeUnit + from polars.type_aliases import ( + Ambiguous, + EpochTimeUnit, + IntoExpr, + NonExistent, + Roll, + TimeUnit, + ) class ExprDateTimeNameSpace: @@ -33,6 +40,123 @@ class ExprDateTimeNameSpace: def __init__(self, expr: Expr): self._pyexpr = expr._pyexpr + def add_business_days( + self, + n: int | IntoExpr, + week_mask: Iterable[bool] = (True, True, True, True, True, False, False), + holidays: Iterable[dt.date] = (), + roll: Roll = "raise", + ) -> Expr: + """ + Offset by `n` business days. + + Parameters + ---------- + n + Number of business days to offset by. Can be a single number of an + expression. + week_mask + Which days of the week to count. The default is Monday to Friday. + If you wanted to count only Monday to Thursday, you would pass + `(True, True, True, True, False, False, False)`. + holidays + Holidays to exclude from the count. The Python package + `python-holidays `_ + may come in handy here. You can install it with ``pip install holidays``, + and then, to get all Dutch holidays for years 2020-2024: + + .. code-block:: python + + import holidays + + my_holidays = holidays.country_holidays("NL", years=range(2020, 2025)) + + and pass `holidays=my_holidays` when you call `business_day_count`. + roll + What to do when the start date lands on a non-business day. Options are: + + - `'raise'`: raise an error + - `'forward'`: move to the next business day + - `'backward'`: move to the previous business day + + Returns + ------- + Expr + Data type is preserved. + + Examples + -------- + >>> from datetime import date + >>> df = pl.DataFrame({"start": [date(2020, 1, 1), date(2020, 1, 2)]}) + >>> df.with_columns(result=pl.col("start").dt.add_business_days(5)) + shape: (2, 2) + ┌────────────┬────────────┐ + │ start ┆ result │ + │ --- ┆ --- │ + │ date ┆ date │ + ╞════════════╪════════════╡ + │ 2020-01-01 ┆ 2020-01-08 │ + │ 2020-01-02 ┆ 2020-01-09 │ + └────────────┴────────────┘ + + You can pass a custom weekend - for example, if you only take Sunday off: + + >>> week_mask = (True, True, True, True, True, True, False) + >>> df.with_columns(result=pl.col("start").dt.add_business_days(5, week_mask)) + shape: (2, 2) + ┌────────────┬────────────┐ + │ start ┆ result │ + │ --- ┆ --- │ + │ date ┆ date │ + ╞════════════╪════════════╡ + │ 2020-01-01 ┆ 2020-01-07 │ + │ 2020-01-02 ┆ 2020-01-08 │ + └────────────┴────────────┘ + + You can also pass a list of holidays: + + >>> from datetime import date + >>> holidays = [date(2020, 1, 3), date(2020, 1, 6)] + >>> df.with_columns( + ... result=pl.col("start").dt.add_business_days(5, holidays=holidays) + ... ) + shape: (2, 2) + ┌────────────┬────────────┐ + │ start ┆ result │ + │ --- ┆ --- │ + │ date ┆ date │ + ╞════════════╪════════════╡ + │ 2020-01-01 ┆ 2020-01-10 │ + │ 2020-01-02 ┆ 2020-01-13 │ + └────────────┴────────────┘ + + Roll all dates forwards to the next business day: + + >>> df = pl.DataFrame({"start": [date(2020, 1, 5), date(2020, 1, 6)]}) + >>> df.with_columns( + ... rolled_forwards=pl.col("start").dt.add_business_days(0, roll="forward") + ... ) + shape: (2, 2) + ┌────────────┬─────────────────┐ + │ start ┆ rolled_forwards │ + │ --- ┆ --- │ + │ date ┆ date │ + ╞════════════╪═════════════════╡ + │ 2020-01-05 ┆ 2020-01-06 │ + │ 2020-01-06 ┆ 2020-01-06 │ + └────────────┴─────────────────┘ + """ + n_pyexpr = parse_as_expression(n) + unix_epoch = dt.date(1970, 1, 1) + return wrap_expr( + self._pyexpr.dt_add_business_days( + n_pyexpr, + week_mask, + [(holiday - unix_epoch).days for holiday in holidays], + roll, + ) + ) + def truncate( self, every: str | timedelta | Expr, diff --git a/py-polars/polars/series/datetime.py b/py-polars/polars/series/datetime.py index 20b16231df9c..59b12e7b65c3 100644 --- a/py-polars/polars/series/datetime.py +++ b/py-polars/polars/series/datetime.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Iterable from polars._utils.convert import to_py_date, to_py_datetime from polars._utils.deprecation import deprecate_function, deprecate_renamed_function @@ -17,7 +17,9 @@ from polars.type_aliases import ( Ambiguous, EpochTimeUnit, + IntoExpr, NonExistent, + Roll, TemporalLiteral, TimeUnit, ) @@ -36,6 +38,97 @@ def __getitem__(self, item: int) -> dt.date | dt.datetime | dt.timedelta: s = wrap_s(self._s) return s[item] + def add_business_days( + self, + n: int | IntoExpr, + week_mask: Iterable[bool] = (True, True, True, True, True, False, False), + holidays: Iterable[dt.date] = (), + roll: Roll = "raise", + ) -> Expr: + """ + Offset by `n` business days. + + Parameters + ---------- + n + Number of business days to offset by. Can be a single number of an + expression. + week_mask + Which days of the week to count. The default is Monday to Friday. + If you wanted to count only Monday to Thursday, you would pass + `(True, True, True, True, False, False, False)`. + holidays + Holidays to exclude from the count. The Python package + `python-holidays `_ + may come in handy here. You can install it with ``pip install holidays``, + and then, to get all Dutch holidays for years 2020-2024: + + .. code-block:: python + + import holidays + + my_holidays = holidays.country_holidays("NL", years=range(2020, 2025)) + + and pass `holidays=my_holidays` when you call `business_day_count`. + roll + What to do when the start date lands on a non-business day. Options are: + + - `'raise'`: raise an error + - `'forward'`: move to the next business day + - `'backward'`: move to the previous business day + + Returns + ------- + Expr + Data type is preserved. + + Examples + -------- + >>> from datetime import date + >>> s = pl.Series("start", [date(2020, 1, 1), date(2020, 1, 2)]) + >>> s.dt.add_business_days(5) + shape: (2,) + Series: 'start' [date] + [ + 2020-01-08 + 2020-01-09 + ] + + You can pass a custom weekend - for example, if you only take Sunday off: + + >>> week_mask = (True, True, True, True, True, True, False) + >>> s.dt.add_business_days(5, week_mask) + shape: (2,) + Series: 'start' [date] + [ + 2020-01-07 + 2020-01-08 + ] + + You can also pass a list of holidays: + + >>> from datetime import date + >>> holidays = [date(2020, 1, 3), date(2020, 1, 6)] + >>> s.dt.add_business_days(5, holidays=holidays) + shape: (2,) + Series: 'start' [date] + [ + 2020-01-10 + 2020-01-13 + ] + + Roll all dates forwards to the next business day: + + >>> s = pl.Series("start", [date(2020, 1, 5), date(2020, 1, 6)]) + >>> s.dt.add_business_days(0, roll="forward") + shape: (2,) + Series: 'start' [date] + [ + 2020-01-06 + 2020-01-06 + ] + """ + def min(self) -> dt.date | dt.datetime | dt.timedelta | None: """ Return minimum as Python datetime. diff --git a/py-polars/polars/type_aliases.py b/py-polars/polars/type_aliases.py index 2443ce4f64ad..86f6eba25c83 100644 --- a/py-polars/polars/type_aliases.py +++ b/py-polars/polars/type_aliases.py @@ -107,6 +107,7 @@ "min", "max", "first", "last", "sum", "mean", "median", "len" ] RankMethod: TypeAlias = Literal["average", "min", "max", "dense", "ordinal", "random"] +Roll: TypeAlias = Literal["raise", "forward", "backward"] SizeUnit: TypeAlias = Literal[ "b", "kb", diff --git a/py-polars/src/conversion/mod.rs b/py-polars/src/conversion/mod.rs index 88b89f27a26e..d281576843bc 100644 --- a/py-polars/src/conversion/mod.rs +++ b/py-polars/src/conversion/mod.rs @@ -859,6 +859,22 @@ impl FromPyObject<'_> for Wrap { } } +impl FromPyObject<'_> for Wrap { + fn extract(ob: &PyAny) -> PyResult { + let parsed = match ob.extract::<&str>()? { + "raise" => Roll::Raise, + "forward" => Roll::Forward, + "backward" => Roll::Backward, + v => { + return Err(PyValueError::new_err(format!( + "`roll` must be one of {{'raise', 'forward', 'backward'}}, got {v}", + ))) + }, + }; + Ok(Wrap(parsed)) + } +} + impl FromPyObject<'_> for Wrap { fn extract(ob: &PyAny) -> PyResult { let parsed = match ob.extract::<&str>()? { diff --git a/py-polars/src/expr/datetime.rs b/py-polars/src/expr/datetime.rs index b28fc326e60e..5b81497a57f7 100644 --- a/py-polars/src/expr/datetime.rs +++ b/py-polars/src/expr/datetime.rs @@ -6,6 +6,20 @@ use crate::PyExpr; #[pymethods] impl PyExpr { + fn dt_add_business_days( + &self, + n: PyExpr, + week_mask: [bool; 7], + holidays: Vec, + roll: Wrap, + ) -> Self { + self.inner + .clone() + .dt() + .add_business_days(n.inner, week_mask, holidays, roll.0) + .into() + } + fn dt_to_string(&self, format: &str) -> Self { self.inner.clone().dt().to_string(format).into() } diff --git a/py-polars/tests/parametric/time_series/test_add_business_days.py b/py-polars/tests/parametric/time_series/test_add_business_days.py new file mode 100644 index 000000000000..a4328c4efdd1 --- /dev/null +++ b/py-polars/tests/parametric/time_series/test_add_business_days.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import datetime as dt +from typing import TYPE_CHECKING + +import hypothesis.strategies as st +import numpy as np +from hypothesis import assume, given + +import polars as pl + +if TYPE_CHECKING: + from polars.type_aliases import Roll + + +@given( + start=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)), + n=st.integers(min_value=-100, max_value=100), + week_mask=st.lists( + st.sampled_from([True, False]), + min_size=7, + max_size=7, + ), + holidays=st.lists( + st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)), + min_size=0, + max_size=100, + ), + roll=st.sampled_from(["forward", "backward"]), +) +def test_against_np_busday_offset( + start: dt.date, + n: int, + week_mask: tuple[bool, ...], + holidays: list[dt.date], + roll: Roll, +) -> None: + assume(any(week_mask)) + result = ( + pl.DataFrame({"start": [start]}) + .select( + res=pl.col("start").dt.add_business_days( + n, week_mask=week_mask, holidays=holidays, roll=roll + ) + )["res"] + .item() + ) + expected = np.busday_offset( + start, n, weekmask=week_mask, holidays=holidays, roll=roll + ) + assert result == expected diff --git a/py-polars/tests/unit/functions/business/test_add_business_days.py b/py-polars/tests/unit/functions/business/test_add_business_days.py new file mode 100644 index 000000000000..9e082a9f5109 --- /dev/null +++ b/py-polars/tests/unit/functions/business/test_add_business_days.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +from datetime import date, datetime, timedelta +from typing import TYPE_CHECKING + +import pytest + +import polars as pl +from polars.testing import assert_series_equal + +if TYPE_CHECKING: + from polars.type_aliases import TimeUnit + + +def test_add_business_days() -> None: + # (Expression, expression) + df = pl.DataFrame( + { + "start": [date(2020, 1, 1), date(2020, 1, 2)], + "n": [-1, 5], + } + ) + result = df.select(result=pl.col("start").dt.add_business_days("n"))["result"] + expected = pl.Series("result", [date(2019, 12, 31), date(2020, 1, 9)], pl.Date) + assert_series_equal(result, expected) + + # (Expression, scalar) + result = df.select(result=pl.col("start").dt.add_business_days(5))["result"] + expected = pl.Series("result", [date(2020, 1, 8), date(2020, 1, 9)], pl.Date) + assert_series_equal(result, expected) + + # (Scalar, expression) + result = df.select( + result=pl.lit(date(2020, 1, 1), dtype=pl.Date).dt.add_business_days(pl.col("n")) + )["result"] + expected = pl.Series("result", [date(2019, 12, 31), date(2020, 1, 8)], pl.Date) + assert_series_equal(result, expected) + + # (Scalar, scalar) + result = df.select( + result=pl.lit(date(2020, 1, 1), dtype=pl.Date).dt.add_business_days(5) + )["result"] + expected = pl.Series("result", [date(2020, 1, 8)], pl.Date) + assert_series_equal(result, expected) + + +def test_add_business_day_w_week_mask() -> None: + df = pl.DataFrame( + { + "start": [date(2020, 1, 1), date(2020, 1, 2)], + "n": [1, 5], + } + ) + result = df.select( + result=pl.col("start").dt.add_business_days( + "n", week_mask=(True, True, True, True, True, True, False) + ) + )["result"] + expected = pl.Series("result", [date(2020, 1, 2), date(2020, 1, 8)]) + assert_series_equal(result, expected) + + result = df.select( + result=pl.col("start").dt.add_business_days( + "n", week_mask=(True, True, True, True, False, False, True) + ) + )["result"] + expected = pl.Series("result", [date(2020, 1, 2), date(2020, 1, 9)]) + assert_series_equal(result, expected) + + +def test_add_business_day_w_week_mask_invalid() -> None: + with pytest.raises(ValueError, match=r"expected a sequence of length 7 \(got 2\)"): + pl.col("start").dt.add_business_days("n", week_mask=(False, 0)) # type: ignore[arg-type] + df = pl.DataFrame( + { + "start": [date(2020, 1, 1), date(2020, 1, 2)], + "n": [1, 5], + } + ) + with pytest.raises( + pl.ComputeError, match="`week_mask` must have at least one business day" + ): + df.select(pl.col("start").dt.add_business_days("n", week_mask=[False] * 7)) + + +def test_add_business_days_schema() -> None: + lf = pl.LazyFrame( + { + "start": [date(2020, 1, 1), date(2020, 1, 2)], + "n": [1, 5], + } + ) + result = lf.select( + result=pl.col("start").dt.add_business_days("n"), + ) + assert result.schema["result"] == pl.Date + assert result.collect().schema["result"] == pl.Date + assert 'col("start").add_business_days([col("n")])' in result.explain() + + +def test_add_business_days_w_holidays() -> None: + df = pl.DataFrame( + { + "start": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 2)], + "n": [1, 5, 7], + } + ) + result = df.select( + result=pl.col("start").dt.add_business_days( + "n", holidays=[date(2020, 1, 3), date(2020, 1, 9)] + ), + )["result"] + expected = pl.Series( + "result", [date(2020, 1, 2), date(2020, 1, 13), date(2020, 1, 15)] + ) + assert_series_equal(result, expected) + result = df.select( + result=pl.col("start").dt.add_business_days( + "n", holidays=[date(2020, 1, 1), date(2020, 1, 2)], roll="backward" + ), + )["result"] + expected = pl.Series( + "result", [date(2020, 1, 3), date(2020, 1, 9), date(2020, 1, 13)] + ) + assert_series_equal(result, expected) + result = df.select( + result=pl.col("start").dt.add_business_days( + "n", + holidays=[ + date(2019, 1, 1), + date(2020, 1, 1), + date(2020, 1, 2), + date(2021, 1, 1), + ], + roll="backward", + ), + )["result"] + expected = pl.Series( + "result", [date(2020, 1, 3), date(2020, 1, 9), date(2020, 1, 13)] + ) + assert_series_equal(result, expected) + + +def test_add_business_days_w_roll() -> None: + df = pl.DataFrame( + { + "start": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 4)], + "n": [1, 5, 7], + } + ) + with pytest.raises(pl.ComputeError, match="is not a business date"): + df.select(result=pl.col("start").dt.add_business_days("n")) + result = df.select( + result=pl.col("start").dt.add_business_days("n", roll="forward") + )["result"] + expected = pl.Series( + "result", [date(2020, 1, 2), date(2020, 1, 9), date(2020, 1, 15)] + ) + assert_series_equal(result, expected) + result = df.select( + result=pl.col("start").dt.add_business_days("n", roll="backward") + )["result"] + expected = pl.Series( + "result", [date(2020, 1, 2), date(2020, 1, 9), date(2020, 1, 14)] + ) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize("time_zone", [None, "Europe/London", "Asia/Kathmandu"]) +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_add_business_days_datetime(time_zone: str | None, time_unit: TimeUnit) -> None: + df = pl.DataFrame( + {"start": [datetime(2020, 3, 28, 1), datetime(2020, 1, 10, 4)]}, + schema={"start": pl.Datetime(time_unit, time_zone)}, + ) + result = df.select( + result=pl.col("start").dt.add_business_days(2, week_mask=[True] * 7) + )["result"] + expected = pl.Series( + "result", + [datetime(2020, 3, 30, 1), datetime(2020, 1, 12, 4)], + pl.Datetime(time_unit, time_zone), + ) + assert_series_equal(result, expected) + + with pytest.raises(pl.ComputeError, match="is not a business date"): + df.select(result=pl.col("start").dt.add_business_days(2)) + + +def test_add_business_days_invalid() -> None: + df = pl.DataFrame({"start": [timedelta(1)]}) + with pytest.raises(pl.InvalidOperationError, match="expected date or datetime"): + df.select(result=pl.col("start").dt.add_business_days(2, week_mask=[True] * 7)) + df = pl.DataFrame({"start": [date(2020, 1, 1)]}) + with pytest.raises( + pl.InvalidOperationError, + match="expected Int64, Int32, UInt64, or UInt32, got f64", + ): + df.select( + result=pl.col("start").dt.add_business_days(1.5, week_mask=[True] * 7) + ) + with pytest.raises( + ValueError, + match="`roll` must be one of {'raise', 'forward', 'backward'}, got cabbage", + ): + df.select(result=pl.col("start").dt.add_business_days(1, roll="cabbage")) # type: ignore[arg-type] + + +def test_add_business_days_w_nulls() -> None: + df = pl.DataFrame( + { + "start": [date(2020, 3, 28), None], + "n": [None, 2], + }, + ) + result = df.select(result=pl.col("start").dt.add_business_days("n"))["result"] + expected = pl.Series("result", [None, None], dtype=pl.Date) + assert_series_equal(result, expected) + + result = df.select( + result=pl.col("start").dt.add_business_days(pl.lit(None, dtype=pl.Int32)) + )["result"] + assert_series_equal(result, expected) + + result = df.select(result=pl.lit(None, dtype=pl.Date).dt.add_business_days("n"))[ + "result" + ] + assert_series_equal(result, expected) + + result = df.select( + result=pl.lit(None, dtype=pl.Date).dt.add_business_days( + pl.lit(None, dtype=pl.Int32) + ) + )["result"] + expected = pl.Series("result", [None], dtype=pl.Date) + assert_series_equal(result, expected)