diff --git a/scripts/generate_snapshots.py b/scripts/generate_snapshots.py index a8d0d75ce..e3da24873 100755 --- a/scripts/generate_snapshots.py +++ b/scripts/generate_snapshots.py @@ -13,6 +13,7 @@ # flake8: noqa: F402 +import argparse import json import sys import warnings @@ -29,6 +30,30 @@ class SnapshotGenerator: years = range(1950, 2051) + def __init__(self) -> None: + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument( + "-c", + "--country", + action="extend", + nargs="+", + default=[], + help="Country codes to use for snapshot generation", + required=False, + type=str, + ) + arg_parser.add_argument( + "-m", + "--market", + action="extend", + nargs="+", + default=[], + help="Market codes to use for snapshot generation", + required=False, + type=str, + ) + self.args = arg_parser.parse_args() + @staticmethod def save(snapshot, file_path): with open(file_path, "w") as output: @@ -48,7 +73,19 @@ def update_snapshot(snapshot, data): def generate_country_snapshots(self): """Generates country snapshots.""" - for country_code in list_supported_countries(): + if len(self.args.market) > 0: + return None + + country_list = self.args.country + supported_countries = list_supported_countries() + if country_list: + unknown_countries = set(country_list).difference(supported_countries.keys()) + if len(unknown_countries) > 0: + raise ValueError(f"Countries {', '.join(unknown_countries)} not available") + else: + country_list = supported_countries + + for country_code in country_list: country = getattr(holidays, country_code) snapshot = {} @@ -67,7 +104,19 @@ def generate_country_snapshots(self): def generate_financial_snapshots(self): """Generates financial snapshots.""" - for market_code in list_supported_financial(): + if len(self.args.country) > 0: + return None + + market_list = self.args.market + supported_markets = list_supported_financial() + if market_list: + unknown_markets = set(market_list).difference(supported_markets.keys()) + if len(unknown_markets) > 0: + raise ValueError(f"Markets {', '.join(unknown_markets)} not available") + else: + market_list = supported_markets + + for market_code in market_list: self.save( holidays.country_holidays( market_code,