Skip to content

Commit

Permalink
Add categories and subdivisions support to substituted holidays (#1558)
Browse files Browse the repository at this point in the history
  • Loading branch information
KJhellico committed Nov 30, 2023
1 parent 5eabe3a commit 59dde59
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 16 deletions.
6 changes: 5 additions & 1 deletion holidays/groups/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,9 @@ class StaticHolidays:

def __init__(self, cls) -> None:
for attribute_name in cls.__dict__.keys():
if attribute_name.startswith("special_") or attribute_name.startswith("substituted_"):
if attribute_name.startswith("special_"):
setattr(self, attribute_name, getattr(cls, attribute_name))
self._has_special = True
elif attribute_name.startswith("substituted_"):
setattr(self, attribute_name, getattr(cls, attribute_name))
self._has_substituted = True
33 changes: 20 additions & 13 deletions holidays/holiday_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ def _add_subdiv_holidays(self):

def _add_substituted_holidays(self):
"""Populate substituted holidays."""
if len(self.substituted_holidays) == 0:
if not hasattr(self, "_has_substituted"):
return None
if not hasattr(self, "substituted_label") or not hasattr(self, "substituted_date_format"):
raise ValueError(
Expand All @@ -675,11 +675,15 @@ def _add_substituted_holidays(self):
)
substituted_label = self.tr(self.substituted_label)
substituted_date_format = self.tr(self.substituted_date_format)
for hol in _normalize_tuple(self.substituted_holidays.get(self._year, ())):
from_year = hol[0] if len(hol) == 5 else self._year
from_month, from_day, to_month, to_day = hol[-4:]
from_date = date(from_year, from_month, from_day).strftime(substituted_date_format)
self._add_holiday(substituted_label % from_date, to_month, to_day)

for mapping_name in self._get_static_holiday_mapping_names():
for hol in _normalize_tuple(
getattr(self, f"substituted_{mapping_name}", {}).get(self._year, ())
):
from_year = hol[0] if len(hol) == 5 else self._year
from_month, from_day, to_month, to_day = hol[-4:]
from_date = date(from_year, from_month, from_day).strftime(substituted_date_format)
self._add_holiday(substituted_label % from_date, to_month, to_day)

def _check_weekday(self, weekday: int, *args) -> bool:
"""
Expand Down Expand Up @@ -751,27 +755,30 @@ def _populate(self, year: int) -> None:
# Populate substituted holidays.
self._add_substituted_holidays()

def _get_special_holiday_mapping_names(self):
def _get_static_holiday_mapping_names(self):
# Check for general special holidays.
mapping_names = ["special_holidays"]
mapping_names = ["holidays"]

# Check subdivision specific special holidays.
if self.subdiv is not None:
subdiv = self.subdiv.replace("-", "_").replace(" ", "_").lower()
mapping_names.append(f"special_{subdiv}_holidays")
mapping_names.append(f"{subdiv}_holidays")

# Check category specific special holidays (both general and per subdivision).
for category in sorted(self.categories):
mapping_names.append(f"special_{category}_holidays")
mapping_names.append(f"{category}_holidays")
if self.subdiv is not None:
mapping_names.append(f"special_{subdiv}_{category}_holidays")
mapping_names.append(f"{subdiv}_{category}_holidays")

return mapping_names

def _add_special_holidays(self):
for mapping_name in self._get_special_holiday_mapping_names():
if not hasattr(self, "_has_special"):
return None

for mapping_name in self._get_static_holiday_mapping_names():
for month, day, name in _normalize_tuple(
getattr(self, mapping_name, {}).get(self._year, ())
getattr(self, f"special_{mapping_name}", {}).get(self._year, ())
):
self._add_holiday(name, date(self._year, month, day))

Expand Down
4 changes: 2 additions & 2 deletions holidays/observed_holiday_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ def _add_special_holidays(self):
if not self.observed:
return None

for mapping_name in self._get_special_holiday_mapping_names():
for mapping_name in self._get_static_holiday_mapping_names():
for month, day, name in _normalize_tuple(
getattr(self, f"{mapping_name}_observed", {}).get(self._year, ())
getattr(self, f"special_{mapping_name}_observed", {}).get(self._year, ())
):
self._add_holiday(
self.tr(self.observed_label) % self.tr(name), date(self._year, month, day)
Expand Down
3 changes: 3 additions & 0 deletions tests/test_holiday_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@


class EntityStub(HolidayBase):
_has_special = True
_has_substituted = True
special_holidays = {
1111: (JAN, 1, "Test holiday"),
2222: (FEB, 2, "Test holiday"),
Expand Down Expand Up @@ -931,6 +933,7 @@ def test_market(self):

class TestSubstitutedHolidays(unittest.TestCase):
class SubstitutedHolidays(HolidayBase):
_has_substituted = True
country = "HB"
substituted_holidays = {
1991: (
Expand Down

0 comments on commit 59dde59

Please sign in to comment.