Skip to content

Commit

Permalink
Extend HolidayBase::categories to accept a single value (#1550)
Browse files Browse the repository at this point in the history
  • Loading branch information
arkid15r committed Nov 11, 2023
1 parent e96a5a1 commit 4eed261
Show file tree
Hide file tree
Showing 22 changed files with 188 additions and 147 deletions.
21 changes: 20 additions & 1 deletion docs/source/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,31 @@ To get a list of other categories holidays (for countries that support them):

.. code-block:: python
>>> for dt, name in sorted(holidays.BE(years=2023, language="en_US", categories=(BANK,)).items()):
>>> for dt, name in sorted(holidays.BE(years=2023, language="en_US", categories=BANK).items()):
>>> print(dt, name)
2023-04-07 Good Friday
2023-05-19 Friday after Ascension Day
2023-12-26 Bank Holiday
>>> for dt, name in sorted(holidays.BE(years=2023, language="en_US", categories=(BANK, PUBLIC)).items()):
>>> print(dt, name)
2023-01-01 New Year's Day
2023-04-07 Good Friday
2023-04-09 Easter
2023-04-10 Easter Monday
2023-05-01 Labor Day
2023-05-18 Ascension Day
2023-05-19 Friday after Ascension Day
2023-05-28 Whit Sunday
2023-05-29 Whit Monday
2023-07-21 National Day
2023-08-15 Assumption of Mary
2023-11-01 All Saints' Day
2023-11-11 Armistice Day
2023-12-25 Christmas Day
2023-12-26 Bank Holiday
Date from holiday name
----------------------

Expand Down
22 changes: 22 additions & 0 deletions holidays/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,31 @@
# License: MIT (see LICENSE file)


def _normalize_arguments(cls, value):
"""Normalize arguments.
:param cls:
A type of arguments to normalize.
:param value:
Either a single item or an iterable of `cls` type.
:return:
A set created from `value` argument.
"""
if isinstance(value, cls):
return {value}

return set(value) if value is not None else set()


def _normalize_tuple(data):
"""Normalize tuple.
:param data:
Either a tuple or a tuple of tuples.
:return:
An unchanged object for tuple of tuples, e.g., ((JAN, 10), (DEC, 31)).
An object put into a tuple otherwise, e.g., ((JAN, 10),).
Expand Down
20 changes: 10 additions & 10 deletions holidays/holiday_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,17 @@
_get_nth_weekday_of_month,
)
from holidays.constants import HOLIDAY_NAME_DELIMITER, ALL_CATEGORIES, PUBLIC
from holidays.helpers import _normalize_tuple
from holidays.helpers import _normalize_arguments, _normalize_tuple

CategoryArg = Union[str, Iterable[str]]
DateArg = Union[date, Tuple[int, int]]
DateLike = Union[date, datetime, str, float, int]
SpecialHoliday = Union[Tuple[int, int, str], Tuple[Tuple[int, int, str], ...]]
SubstitutedHoliday = Union[
Union[Tuple[int, int, int, int], Tuple[int, int, int, int, int]],
Tuple[Union[Tuple[int, int, int, int], Tuple[int, int, int, int, int]], ...],
]
YearArg = Union[int, Iterable[int]]


class HolidayBase(Dict[date, str]):
Expand Down Expand Up @@ -236,14 +238,14 @@ def _populate(self, year):

def __init__(
self,
years: Optional[Union[int, Iterable[int]]] = None,
years: Optional[YearArg] = None,
expand: bool = True,
observed: bool = True,
subdiv: Optional[str] = None,
prov: Optional[str] = None, # Deprecated.
state: Optional[str] = None, # Deprecated.
language: Optional[str] = None,
categories: Optional[Tuple[str]] = None,
categories: Optional[CategoryArg] = None,
) -> None:
"""
:param years:
Expand Down Expand Up @@ -282,11 +284,11 @@ def __init__(
"""
super().__init__()

self.categories = _normalize_arguments(str, categories or {PUBLIC})
self.expand = expand
self.language = language.lower() if language else None
self.observed = observed
self.subdiv = subdiv or prov or state
self.categories = set(categories) if categories else {PUBLIC}

self.tr = gettext # Default translation method.

Expand Down Expand Up @@ -317,7 +319,9 @@ def __init__(
DeprecationWarning,
)

unknown_categories = self.categories.difference(ALL_CATEGORIES)
unknown_categories = self.categories.difference( # type: ignore[union-attr]
ALL_CATEGORIES
)
if len(unknown_categories) > 0:
raise NotImplementedError(
f"Category is not supported: {', '.join(unknown_categories)}."
Expand All @@ -333,11 +337,7 @@ def __init__(
)
self.tr = translator.gettext

if isinstance(years, int):
self.years = {years}
else:
self.years = set(years) if years is not None else set()

self.years = _normalize_arguments(int, years)
for year in self.years:
self._populate(year)

Expand Down
2 changes: 1 addition & 1 deletion tests/countries/test_austria.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_2022(self):

def test_bank_2022(self):
self.assertHolidays(
Austria(categories=(BANK,), years=2022),
Austria(categories=BANK, years=2022),
("2022-04-15", "Karfreitag"),
("2022-12-24", "Heiliger Abend"),
("2022-12-31", "Silvester"),
Expand Down
2 changes: 1 addition & 1 deletion tests/countries/test_belgium.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_2022(self):

def test_bank_2022(self):
self.assertHolidays(
Belgium(categories=(BANK,), years=2022),
Belgium(categories=BANK, years=2022),
("2022-04-15", "Goede Vrijdag"),
("2022-05-27", "Vrijdag na O. L. H. Hemelvaart"),
("2022-12-26", "Banksluitingsdag"),
Expand Down
2 changes: 1 addition & 1 deletion tests/countries/test_brazil.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_christmas_day(self):
self.assertNoHolidayName("Natal", range(1890, 1922))

def test_optional_holidays(self):
holidays = Brazil(categories=(OPTIONAL,))
holidays = Brazil(categories=OPTIONAL)
dt = (
"2018-02-12",
"2018-02-13",
Expand Down
2 changes: 1 addition & 1 deletion tests/countries/test_bulgaria.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def test_national_awakening_day(self):
name = "Ден на народните будители"
self.assertHolidayName(
name,
Bulgaria(categories=(SCHOOL,), years=range(1990, 2050)),
Bulgaria(categories=SCHOOL, years=range(1990, 2050)),
(f"{year}-11-01" for year in range(1990, 2050)),
)
self.assertNoHolidayName(name)
Expand Down

0 comments on commit 4eed261

Please sign in to comment.