Skip to content

Commit

Permalink
feat: add holidays argument to business_day_count (#15580)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Apr 12, 2024
1 parent 23791bd commit 0b84b14
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 45 deletions.
50 changes: 45 additions & 5 deletions crates/polars-ops/src/series/ops/business.rs
Expand Up @@ -7,14 +7,19 @@ use polars_core::prelude::*;
/// - `start`: Series holding start dates.
/// - `end`: Series holding end dates.
/// - `week_mask`: A boolean array of length 7, where `true` indicates that the day is a business day.
/// - `holidays`: timestamps that are holidays. Must be provided as i32, i.e. the number of
/// days since the UNIX epoch.
pub fn business_day_count(
start: &Series,
end: &Series,
week_mask: [bool; 7],
holidays: &[i32],
) -> PolarsResult<Series> {
if !week_mask.iter().any(|&x| x) {
polars_bail!(ComputeError:"`week_mask` must have at least one business day");
}

let holidays = normalise_holidays(holidays, &week_mask);
let start_dates = start.date()?;
let end_dates = end.date()?;
let n_business_days_in_week_mask = week_mask.iter().filter(|&x| *x).count() as i32;
Expand All @@ -28,6 +33,7 @@ pub fn business_day_count(
end_date,
&week_mask,
n_business_days_in_week_mask,
&holidays,
)
})
} else {
Expand All @@ -42,6 +48,7 @@ pub fn business_day_count(
end_date,
&week_mask,
n_business_days_in_week_mask,
&holidays,
)
})
} else {
Expand All @@ -54,6 +61,7 @@ pub fn business_day_count(
end_date,
&week_mask,
n_business_days_in_week_mask,
&holidays,
)
}),
};
Expand All @@ -67,6 +75,7 @@ fn business_day_count_impl(
mut end_date: i32,
week_mask: &[bool; 7],
n_business_days_in_week_mask: i32,
holidays: &[i32],
) -> i32 {
let swapped = start_date > end_date;
if swapped {
Expand All @@ -75,21 +84,28 @@ fn business_day_count_impl(
end_date += 1;
}

let holidays_begin = match holidays.binary_search(&start_date) {
Ok(x) => x,
Err(x) => x,
} as i32;
let holidays_end = match holidays[(holidays_begin as usize)..].binary_search(&end_date) {
Ok(x) => x as i32 + holidays_begin,
Err(x) => x as i32 + holidays_begin,
};

let mut start_weekday = weekday(start_date);
let diff = end_date - start_date;
let whole_weeks = diff / 7;
let mut count = 0;
let mut count = -(holidays_end - holidays_begin);
count += whole_weeks * n_business_days_in_week_mask;
start_date += whole_weeks * 7;
while start_date < end_date {
// SAFETY: week_mask is length 7, start_weekday is between 0 and 6
if unsafe { *week_mask.get_unchecked(start_weekday) } {
count += 1;
}
start_date += 1;
start_weekday += 1;
if start_weekday >= 7 {
start_weekday = 0;
}
start_weekday = increment_weekday(start_weekday);
}
if swapped {
-count
Expand All @@ -98,9 +114,33 @@ fn business_day_count_impl(
}
}

/// Sort and deduplicate holidays and remove holidays that are not business days.
fn normalise_holidays(holidays: &[i32], week_mask: &[bool; 7]) -> Vec<i32> {
let mut holidays: Vec<i32> = holidays.to_vec();
holidays.sort_unstable();
let mut previous_holiday: Option<i32> = None;
holidays.retain(|&x| {
// SAFETY: week_mask is length 7, start_weekday is between 0 and 6
if (Some(x) == previous_holiday) || !unsafe { *week_mask.get_unchecked(weekday(x)) } {
return false;
}
previous_holiday = Some(x);
true
});
holidays
}

fn weekday(x: i32) -> usize {
// the first modulo might return a negative number, so we add 7 and take
// the modulo again so we're sure we have something between 0 (Monday)
// and 6 (Sunday)
(((x - 4) % 7 + 7) % 7) as usize
}

fn increment_weekday(x: usize) -> usize {
if x == 6 {
0
} else {
x + 1
}
}
20 changes: 15 additions & 5 deletions crates/polars-plan/src/dsl/function_expr/business.rs
Expand Up @@ -12,7 +12,10 @@ use crate::prelude::SeriesUdf;
#[derive(Clone, PartialEq, Debug, Eq, Hash)]
pub enum BusinessFunction {
#[cfg(feature = "business")]
BusinessDayCount { week_mask: [bool; 7] },
BusinessDayCount {
week_mask: [bool; 7],
holidays: Vec<i32>,
},
}

impl Display for BusinessFunction {
Expand All @@ -30,16 +33,23 @@ impl From<BusinessFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
use BusinessFunction::*;
match func {
#[cfg(feature = "business")]
BusinessDayCount { week_mask } => {
map_as_slice!(business_day_count, week_mask)
BusinessDayCount {
week_mask,
holidays,
} => {
map_as_slice!(business_day_count, week_mask, &holidays)
},
}
}
}

#[cfg(feature = "business")]
pub(super) fn business_day_count(s: &[Series], week_mask: [bool; 7]) -> PolarsResult<Series> {
pub(super) fn business_day_count(
s: &[Series],
week_mask: [bool; 7],
holidays: &[i32],
) -> PolarsResult<Series> {
let start = &s[0];
let end = &s[1];
polars_ops::prelude::business_day_count(start, end, week_mask)
polars_ops::prelude::business_day_count(start, end, week_mask, holidays)
}
12 changes: 10 additions & 2 deletions crates/polars-plan/src/dsl/functions/business.rs
@@ -1,12 +1,20 @@
use super::*;

#[cfg(feature = "dtype-date")]
pub fn business_day_count(start: Expr, end: Expr, week_mask: [bool; 7]) -> Expr {
pub fn business_day_count(
start: Expr,
end: Expr,
week_mask: [bool; 7],
holidays: Vec<i32>,
) -> Expr {
let input = vec![start, end];

Expr::Function {
input,
function: FunctionExpr::Business(BusinessFunction::BusinessDayCount { week_mask }),
function: FunctionExpr::Business(BusinessFunction::BusinessDayCount {
week_mask,
holidays,
}),
options: FunctionOptions {
allow_rename: true,
..Default::default()
Expand Down
88 changes: 62 additions & 26 deletions py-polars/polars/functions/business.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
from datetime import date
from typing import TYPE_CHECKING, Iterable

from polars._utils.parse_expr_input import parse_as_expression
Expand All @@ -10,8 +11,6 @@
import polars.polars as plr

if TYPE_CHECKING:
from datetime import date

from polars import Expr
from polars.type_aliases import IntoExprColumn

Expand All @@ -20,6 +19,7 @@ def business_day_count(
start: date | IntoExprColumn,
end: date | IntoExprColumn,
week_mask: Iterable[bool] = (True, True, True, True, True, False, False),
holidays: Iterable[date] = (),
) -> Expr:
"""
Count the number of business days between `start` and `end` (not including `end`).
Expand All @@ -34,6 +34,19 @@ def business_day_count(
Which days of the week to count. The default is Monday to Friday.
If you wanted to count only Monday to Thursday, you would pass
`(True, True, True, True, False, False, False)`.
holidays
Holidays to exclude from the count. The Python package
`python-holidays <https://github.com/vacanza/python-holidays>`_
may come in handy here. You can install it with ``pip install holidays``,
and then, to get all Dutch holidays for years 2020-2024:
.. code-block:: python
import holidays
my_holidays = holidays.country_holidays("NL", years=range(2020, 2025))
and pass `holidays=my_holidays` when you call `business_day_count`.
Returns
-------
Expand All @@ -49,39 +62,62 @@ def business_day_count(
... }
... )
>>> df.with_columns(
... total_day_count=(pl.col("end") - pl.col("start")).dt.total_days(),
... business_day_count=pl.business_day_count("start", "end"),
... )
shape: (2, 4)
┌────────────┬────────────┬─────────────────┬────────────────────┐
│ start ┆ end ┆ total_day_count ┆ business_day_count │
│ --- ┆ --- ┆ --- ┆ ---
│ date ┆ date ┆ i64 ┆ i32 │
╞════════════╪════════════╪═════════════════╪════════════════════╡
│ 2020-01-01 ┆ 2020-01-02 ┆ 1 ┆ 1
│ 2020-01-02 ┆ 2020-01-10 ┆ 8 ┆ 6 │
└────────────┴────────────┴─────────────────┴────────────────────┘
Note how the two "count" columns differ due to the weekend (2020-01-04 - 2020-01-05)
not being counted by `business_day_count`.
shape: (2, 3)
┌────────────┬────────────┬────────────────────┐
│ start ┆ end ┆ business_day_count │
│ --- ┆ --- ┆ --- │
│ date ┆ date ┆ i32 │
╞════════════╪════════════╪════════════════════╡
│ 2020-01-01 ┆ 2020-01-02 ┆ 1 │
│ 2020-01-02 ┆ 2020-01-10 ┆ 6 │
└────────────┴────────────┴────────────────────┘
Note how the business day count is 6 (as opposed a regular day count of 8)
due to the weekend (2020-01-04 - 2020-01-05) not being counted.
You can pass a custom weekend - for example, if you only take Sunday off:
>>> week_mask = (True, True, True, True, True, True, False)
>>> df.with_columns(
... total_day_count=(pl.col("end") - pl.col("start")).dt.total_days(),
... business_day_count=pl.business_day_count("start", "end", week_mask),
... )
shape: (2, 4)
┌────────────┬────────────┬─────────────────┬────────────────────┐
│ start ┆ end ┆ total_day_count ┆ business_day_count │
│ --- ┆ --- ┆ --- ┆ --- │
│ date ┆ date ┆ i64 ┆ i32 │
╞════════════╪════════════╪═════════════════╪════════════════════╡
│ 2020-01-01 ┆ 2020-01-02 ┆ 1 ┆ 1 │
│ 2020-01-02 ┆ 2020-01-10 ┆ 8 ┆ 7 │
└────────────┴────────────┴─────────────────┴────────────────────┘
shape: (2, 3)
┌────────────┬────────────┬────────────────────┐
│ start ┆ end ┆ business_day_count │
│ --- ┆ --- ┆ --- │
│ date ┆ date ┆ i32 │
╞════════════╪════════════╪════════════════════╡
│ 2020-01-01 ┆ 2020-01-02 ┆ 1 │
│ 2020-01-02 ┆ 2020-01-10 ┆ 7 │
└────────────┴────────────┴────────────────────┘
You can also pass a list of holidays to exclude from the count:
>>> from datetime import date
>>> holidays = [date(2020, 1, 1), date(2020, 1, 2)]
>>> df.with_columns(
... business_day_count=pl.business_day_count("start", "end", holidays=holidays)
... )
shape: (2, 3)
┌────────────┬────────────┬────────────────────┐
│ start ┆ end ┆ business_day_count │
│ --- ┆ --- ┆ --- │
│ date ┆ date ┆ i32 │
╞════════════╪════════════╪════════════════════╡
│ 2020-01-01 ┆ 2020-01-02 ┆ 0 │
│ 2020-01-02 ┆ 2020-01-10 ┆ 5 │
└────────────┴────────────┴────────────────────┘
"""
start_pyexpr = parse_as_expression(start)
end_pyexpr = parse_as_expression(end)
return wrap_expr(plr.business_day_count(start_pyexpr, end_pyexpr, week_mask))
unix_epoch = date(1970, 1, 1)
return wrap_expr(
plr.business_day_count(
start_pyexpr,
end_pyexpr,
week_mask,
[(holiday - unix_epoch).days for holiday in holidays],
)
)
9 changes: 7 additions & 2 deletions py-polars/src/functions/business.rs
Expand Up @@ -4,8 +4,13 @@ use pyo3::prelude::*;
use crate::PyExpr;

#[pyfunction]
pub fn business_day_count(start: PyExpr, end: PyExpr, week_mask: [bool; 7]) -> PyExpr {
pub fn business_day_count(
start: PyExpr,
end: PyExpr,
week_mask: [bool; 7],
holidays: Vec<i32>,
) -> PyExpr {
let start = start.inner;
let end = end.inner;
dsl::business_day_count(start, end, week_mask).into()
dsl::business_day_count(start, end, week_mask, holidays).into()
}
17 changes: 12 additions & 5 deletions py-polars/tests/parametric/time_series/test_business_day_count.py
Expand Up @@ -18,19 +18,26 @@
min_size=7,
max_size=7,
),
holidays=st.lists(
st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)),
min_size=0,
max_size=100,
),
)
def test_against_np_busday_count(
start: dt.date,
end: dt.date,
week_mask: tuple[bool, ...],
start: dt.date, end: dt.date, week_mask: tuple[bool, ...], holidays: list[dt.date]
) -> None:
assume(any(week_mask))
result = (
pl.DataFrame({"start": [start], "end": [end]})
.select(n=pl.business_day_count("start", "end", week_mask=week_mask))["n"]
.select(
n=pl.business_day_count(
"start", "end", week_mask=week_mask, holidays=holidays
)
)["n"]
.item()
)
expected = np.busday_count(start, end, weekmask=week_mask)
expected = np.busday_count(start, end, weekmask=week_mask, holidays=holidays)
if start > end and parse_version(np.__version__) < parse_version("1.25"):
# Bug in old versions of numpy
reject()
Expand Down
16 changes: 16 additions & 0 deletions py-polars/tests/unit/functions/business/test_business_day_count.py
Expand Up @@ -104,3 +104,19 @@ def test_business_day_count_schema() -> None:
assert result.schema["business_day_count"] == pl.Int32
assert result.collect().schema["business_day_count"] == pl.Int32
assert 'col("start").business_day_count([col("end")])' in result.explain()


def test_business_day_count_w_holidays() -> None:
df = pl.DataFrame(
{
"start": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 2)],
"end": [date(2020, 1, 2), date(2020, 1, 10), date(2020, 1, 9)],
}
)
result = df.select(
business_day_count=pl.business_day_count(
"start", "end", holidays=[date(2020, 1, 1), date(2020, 1, 9)]
),
)["business_day_count"]
expected = pl.Series("business_day_count", [0, 5, 5], pl.Int32)
assert_series_equal(result, expected)

0 comments on commit 0b84b14

Please sign in to comment.