Skip to content

Commit

Permalink
Update snapshots generator (#1530)
Browse files Browse the repository at this point in the history
  • Loading branch information
KJhellico committed Oct 26, 2023
1 parent 1ba2709 commit 68867a9
Showing 1 changed file with 51 additions and 2 deletions.
53 changes: 51 additions & 2 deletions scripts/generate_snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

# flake8: noqa: F402

import argparse
import json
import sys
import warnings
Expand All @@ -29,6 +30,30 @@ class SnapshotGenerator:

years = range(1950, 2051)

def __init__(self) -> None:
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument(
"-c",
"--country",
action="extend",
nargs="+",
default=[],
help="Country codes to use for snapshot generation",
required=False,
type=str,
)
arg_parser.add_argument(
"-m",
"--market",
action="extend",
nargs="+",
default=[],
help="Market codes to use for snapshot generation",
required=False,
type=str,
)
self.args = arg_parser.parse_args()

@staticmethod
def save(snapshot, file_path):
with open(file_path, "w") as output:
Expand All @@ -48,7 +73,19 @@ def update_snapshot(snapshot, data):

def generate_country_snapshots(self):
"""Generates country snapshots."""
for country_code in list_supported_countries():
if len(self.args.market) > 0:
return None

country_list = self.args.country
supported_countries = list_supported_countries()
if country_list:
unknown_countries = set(country_list).difference(supported_countries.keys())
if len(unknown_countries) > 0:
raise ValueError(f"Countries {', '.join(unknown_countries)} not available")
else:
country_list = supported_countries

for country_code in country_list:
country = getattr(holidays, country_code)
snapshot = {}

Expand All @@ -67,7 +104,19 @@ def generate_country_snapshots(self):

def generate_financial_snapshots(self):
"""Generates financial snapshots."""
for market_code in list_supported_financial():
if len(self.args.country) > 0:
return None

market_list = self.args.market
supported_markets = list_supported_financial()
if market_list:
unknown_markets = set(market_list).difference(supported_markets.keys())
if len(unknown_markets) > 0:
raise ValueError(f"Markets {', '.join(unknown_markets)} not available")
else:
market_list = supported_markets

for market_code in market_list:
self.save(
holidays.country_holidays(
market_code,
Expand Down

0 comments on commit 68867a9

Please sign in to comment.