diff --git a/src/config/__init__.py b/src/config/__init__.py index 7baa55a..c493d61 100644 --- a/src/config/__init__.py +++ b/src/config/__init__.py @@ -4,6 +4,7 @@ import os import sys from importlib.abc import InspectLoader +from pathlib import Path from types import ModuleType from typing import Any, Dict, Iterable, List, Mapping, Optional, TextIO, Union, cast @@ -349,7 +350,7 @@ class FileConfiguration(Configuration): def __init__( self, - data: Union[str, TextIO], + data: Union[str, Path, TextIO], read_from_file: bool = False, *, lowercase_keys: bool = False, @@ -370,13 +371,15 @@ def __init__( interpolate=interpolate, interpolate_type=interpolate_type, ) - self._filename = data if read_from_file and isinstance(data, str) else None + self._filename = ( + data if read_from_file and isinstance(data, (str, Path)) else None + ) self._ignore_missing_paths = ignore_missing_paths self._reload_with_check(data, read_from_file) def _reload_with_check( self, - data: Union[str, TextIO], + data: Union[str, Path, TextIO], read_from_file: bool = False, ) -> None: # pragma: no cover try: @@ -388,7 +391,7 @@ def _reload_with_check( def _reload( self, - data: Union[str, TextIO], + data: Union[str, Path, TextIO], read_from_file: bool = False, ) -> None: # pragma: no cover raise NotImplementedError() @@ -404,12 +407,12 @@ class JSONConfiguration(FileConfiguration): def _reload( self, - data: Union[str, TextIO], + data: Union[str, Path, TextIO], read_from_file: bool = False, ) -> None: """Reload the JSON data.""" if read_from_file: - if isinstance(data, str): + if isinstance(data, (str, Path)): with open(data, "rt") as f: result = json.load(f) else: @@ -420,7 +423,7 @@ def _reload( def config_from_json( - data: Union[str, TextIO], + data: Union[str, Path, TextIO], read_from_file: bool = False, *, lowercase_keys: bool = False, @@ -456,7 +459,7 @@ class INIConfiguration(FileConfiguration): def __init__( self, - data: Union[str, TextIO], + data: Union[str, Path, TextIO], read_from_file: bool = False, *, section_prefix: str = "", @@ -476,7 +479,11 @@ def __init__( ignore_missing_paths=ignore_missing_paths, ) - def _reload(self, data: Union[str, TextIO], read_from_file: bool = False) -> None: + def _reload( + self, + data: Union[str, Path, TextIO], + read_from_file: bool = False, + ) -> None: """Reload the INI data.""" import configparser @@ -487,7 +494,7 @@ def optionxform(self, optionstr: str) -> str: return super().optionxform(optionstr) if lowercase else optionstr if read_from_file: - if isinstance(data, str): + if isinstance(data, (str, Path)): with open(data, "rt") as f: data = f.read() else: @@ -505,7 +512,7 @@ def optionxform(self, optionstr: str) -> str: def config_from_ini( - data: Union[str, TextIO], + data: Union[str, Path, TextIO], read_from_file: bool = False, *, section_prefix: str = "", @@ -543,7 +550,7 @@ class DotEnvConfiguration(FileConfiguration): def __init__( self, - data: Union[str, TextIO], + data: Union[str, Path, TextIO], read_from_file: bool = False, prefix: str = "", separator: str = "__", @@ -567,12 +574,12 @@ def __init__( def _reload( self, - data: Union[str, TextIO], + data: Union[str, Path, TextIO], read_from_file: bool = False, ) -> None: """Reload the .env data.""" if read_from_file: - if isinstance(data, str): + if isinstance(data, (str, Path)): with open(data, "rt") as f: data = f.read() else: @@ -594,7 +601,7 @@ def _reload( def config_from_dotenv( - data: Union[str, TextIO], + data: Union[str, Path, TextIO], read_from_file: bool = False, prefix: str = "", separator: str = "__", @@ -634,7 +641,7 @@ class PythonConfiguration(Configuration): def __init__( self, - module: Union[str, ModuleType], + module: Union[str, Path, ModuleType], prefix: str = "", separator: str = "_", *, @@ -651,7 +658,8 @@ def __init__( lowercase_keys: whether to convert every key to lower case. """ try: - if isinstance(module, str): + if isinstance(module, (str, Path)): + module = str(module) if module.endswith(".py"): import importlib.util from importlib import machinery @@ -708,7 +716,7 @@ def reload(self) -> None: def config_from_python( - module: Union[str, ModuleType], + module: Union[str, Path, ModuleType], prefix: str = "", separator: str = "_", *, @@ -796,7 +804,7 @@ class YAMLConfiguration(FileConfiguration): def __init__( self, - data: Union[str, TextIO], + data: Union[str, Path, TextIO], read_from_file: bool = False, *, lowercase_keys: bool = False, @@ -818,9 +826,13 @@ def __init__( ignore_missing_paths=ignore_missing_paths, ) - def _reload(self, data: Union[str, TextIO], read_from_file: bool = False) -> None: + def _reload( + self, + data: Union[str, Path, TextIO], + read_from_file: bool = False, + ) -> None: """Reload the YAML data.""" - if read_from_file and isinstance(data, str): + if read_from_file and isinstance(data, (str, Path)): with open(data, "rt") as f: loaded = yaml.load(f, Loader=yaml.FullLoader) else: @@ -831,7 +843,7 @@ def _reload(self, data: Union[str, TextIO], read_from_file: bool = False) -> Non def config_from_yaml( - data: Union[str, TextIO], + data: Union[str, Path, TextIO], read_from_file: bool = False, *, lowercase_keys: bool = False, @@ -866,7 +878,7 @@ class TOMLConfiguration(FileConfiguration): def __init__( self, - data: Union[str, TextIO], + data: Union[str, Path, TextIO], read_from_file: bool = False, *, section_prefix: str = "", @@ -891,10 +903,14 @@ def __init__( ignore_missing_paths=ignore_missing_paths, ) - def _reload(self, data: Union[str, TextIO], read_from_file: bool = False) -> None: + def _reload( + self, + data: Union[str, Path, TextIO], + read_from_file: bool = False, + ) -> None: """Reload the TOML data.""" if read_from_file: - if isinstance(data, str): + if isinstance(data, (str, Path)): with open(data, "rb") as f: loaded = toml.load(f) else: @@ -914,7 +930,7 @@ def _reload(self, data: Union[str, TextIO], read_from_file: bool = False) -> Non def config_from_toml( - data: Union[str, TextIO], + data: Union[str, Path, TextIO], read_from_file: bool = False, *, section_prefix: str = "", diff --git a/tests/test_json.py b/tests/test_json.py index decd945..08432c5 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -1,7 +1,7 @@ from config import config_from_dict, config_from_json import tempfile import json - +from pathlib import Path DICT = { "a1.b1.c1": 1, @@ -49,6 +49,17 @@ def test_load_json_filename(): # type: ignore assert cfg == config_from_dict(DICT) +def test_load_json_filename_2(): # type: ignore + with tempfile.NamedTemporaryFile() as f: + f.file.write(JSON.encode()) + f.file.flush() + cfg = config_from_json(Path(f.name), read_from_file=True) + assert cfg["a1.b1.c1"] == 1 + assert cfg["a1.b1"].as_dict() == {"c1": 1, "c2": 2, "c3": 3} + assert cfg["a1.b2"].as_dict() == {"c1": "a", "c2": True, "c3": 1.1} + assert cfg == config_from_dict(DICT) + + def test_equality(): # type: ignore cfg = config_from_json(JSON) assert cfg == config_from_dict(DICT) diff --git a/tests/test_toml.py b/tests/test_toml.py index 44839cd..665587f 100644 --- a/tests/test_toml.py +++ b/tests/test_toml.py @@ -1,3 +1,4 @@ +from pathlib import Path import tempfile import pytest @@ -97,6 +98,18 @@ def test_load_toml_filename(): # type: ignore assert cfg == config_from_dict(DICT) +@pytest.mark.skipif("toml is None") +def test_load_toml_filename_2(): # type: ignore + with tempfile.NamedTemporaryFile() as f: + f.file.write(TOML.encode()) + f.file.flush() + cfg = config_from_toml(Path(f.name), read_from_file=True) + assert cfg["a1.b1.c1"] == 1 + assert cfg["a1.b1"].as_dict() == {"c1": 1, "c2": 2, "c3": 3} + assert cfg["a1.b2"].as_dict() == {"c1": "a", "c2": True, "c3": 1.1} + assert cfg == config_from_dict(DICT) + + @pytest.mark.skipif("toml is None") def test_equality(): # type: ignore cfg = config_from_toml(TOML) diff --git a/tests/test_yaml.py b/tests/test_yaml.py index 0642f48..f0a782f 100644 --- a/tests/test_yaml.py +++ b/tests/test_yaml.py @@ -1,6 +1,7 @@ import pytest from config import config_from_dict from pytest import raises +from pathlib import Path import tempfile try: @@ -107,6 +108,18 @@ def test_load_yaml_filename(): # type: ignore assert cfg == config_from_dict(DICT) +@pytest.mark.skipif("yaml is None") +def test_load_yaml_filename_2(): # type: ignore + with tempfile.NamedTemporaryFile() as f: + f.file.write(YAML.encode()) + f.file.flush() + cfg = config_from_yaml(Path(f.name), read_from_file=True) + assert cfg["a1.b1.c1"] == 1 + assert cfg["a1.b1"].as_dict() == {"c1": 1, "c2": 2, "c3": 3} + assert cfg["a1.b2"].as_dict() == {"c1": "a", "c2": True, "c3": 1.1} + assert cfg == config_from_dict(DICT) + + @pytest.mark.skipif("yaml is None") def test_equality(): # type: ignore cfg = config_from_yaml(YAML)