diff --git a/src/ploomber/config.py b/src/ploomber/config.py new file mode 100644 index 000000000..bd87ffbb6 --- /dev/null +++ b/src/ploomber/config.py @@ -0,0 +1,59 @@ +import abc +from collections.abc import Mapping + +import yaml + + +class Config(abc.ABC): + def __init__(self): + # resolve home directory + path = self.path() + + if not path.exists(): + defaults = self._get_data() + path.write_text(yaml.dump(defaults)) + self._set_data(defaults) + else: + text = path.read_text() + + try: + content = yaml.safe_load(text) + loaded = True + except Exception: + # NOTE: show warning? + loaded = False + content = self._get_data() + + if loaded and not isinstance(content, Mapping): + # NOTE: show warning? + content = self._get_data() + + self._set_data(content) + + # TODO: delete, only here for compatibility + def read(self): + return self._get_data() + + def _get_data(self): + return {key: getattr(self, key) for key in self.__annotations__} + + def _set_data(self, data): + for key in self.__annotations__: + if key in data: + setattr(self, key, data[key]) + + def _write(self): + path = self.path() + data = self._get_data() + path.write_text(yaml.dump(data)) + + def __setattr__(self, name, value): + if name not in self.__annotations__: + raise ValueError(f'{name} not a valid field') + else: + super().__setattr__(name, value) + self._write() + + @abc.abstractclassmethod + def path(cls): + pass diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 000000000..735cc1fc2 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,80 @@ +from pathlib import Path + +import pytest +import yaml + +from ploomber.config import Config + + +class MyConfig(Config): + number: int = 42 + string: str = 'value' + + def path(self): + return Path('myconfig.yaml') + + +def test_stores_defaults(tmp_directory): + MyConfig() + + content = yaml.safe_load(Path('myconfig.yaml').read_text()) + + assert content == {'number': 42, 'string': 'value'} + + +def test_reads_existing(tmp_directory): + path = Path('myconfig.yaml') + path.write_text(yaml.dump({'number': 100})) + + cfg = MyConfig() + + assert cfg.number == 100 + assert cfg.string == 'value' + + +def test_ignores_extra(tmp_directory): + path = Path('myconfig.yaml') + path.write_text( + yaml.dump({ + 'number': 200, + 'string': 'another', + 'unknown': 2000, + })) + + cfg = MyConfig() + + assert cfg.number == 200 + assert cfg.string == 'another' + + +def test_stores_on_update(tmp_directory): + cfg = MyConfig() + + cfg.number = 500 + + content = yaml.safe_load(Path('myconfig.yaml').read_text()) + + assert content == {'number': 500, 'string': 'value'} + + +def test_uses_default_if_missing(tmp_directory): + path = Path('myconfig.yaml') + path.write_text(yaml.dump({'number': 100})) + + cfg = MyConfig() + + assert cfg.number == 100 + assert cfg.string == 'value' + + +@pytest.mark.parametrize('content', ['not yaml', '[[]']) +def test_uses_defaults_if_corrupted(tmp_directory, content): + path = Path('myconfig.yaml') + path.write_text(content) + + cfg = MyConfig() + content = yaml.safe_load(Path('myconfig.yaml').read_text()) + + assert cfg.number == 42 + assert cfg.string == 'value' + assert content == {'number': 42, 'string': 'value'}