Skip to content

Commit

Permalink
Merge c036fe7 into 1927225
Browse files Browse the repository at this point in the history
  • Loading branch information
Smlep committed Dec 30, 2020
2 parents 1927225 + c036fe7 commit ecfa669
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 4 deletions.
4 changes: 3 additions & 1 deletion src/dotenv/__init__.py
@@ -1,5 +1,5 @@
from .compat import IS_TYPE_CHECKING
from .main import load_dotenv, get_key, set_key, unset_key, find_dotenv, dotenv_values
from .main import load_dotenv, get_key, get_bool, get_boolean_key, set_key, unset_key, find_dotenv, dotenv_values

if IS_TYPE_CHECKING:
from typing import Any, Optional
Expand Down Expand Up @@ -40,6 +40,8 @@ def get_cli_string(path=None, action=None, key=None, value=None, quote=None):
'load_dotenv',
'dotenv_values',
'get_key',
'get_bool',
'get_boolean_key',
'set_key',
'unset_key',
'find_dotenv',
Expand Down
37 changes: 34 additions & 3 deletions src/dotenv/main.py
Expand Up @@ -9,6 +9,7 @@
import tempfile
from collections import OrderedDict
from contextlib import contextmanager
from distutils.util import strtobool

from .compat import IS_TYPE_CHECKING, PY2, StringIO, to_env
from .parser import Binding, parse_stream
Expand Down Expand Up @@ -100,23 +101,32 @@ def set_as_environment_variables(self):

return True

def get(self, key):
# type: (Text) -> Optional[Text]
def get(self, key, boolean=False):
# type: (Text, Optional[bool]) -> Union[Text, bool, None]
"""
"""
data = self.dict()

if key in data:
if boolean:
return strtobool(str(data[key]))
return data[key]

if self.verbose:
logger.warning("Key %s not found in %s.", key, self.dotenv_path)

return None

def get_as_boolean(self, key):
# type: (Text) -> Optional[bool]
"""
Wrapper around get with boolean=True
"""
return bool(self.get(key, True))


def get_key(dotenv_path, key_to_get):
# type: (Union[Text, _PathLike], Text) -> Optional[Text]
# type: (Union[Text, _PathLike], Text) -> Union[Text, bool, None]
"""
Gets the value of a given key from the given .env
Expand All @@ -125,6 +135,16 @@ def get_key(dotenv_path, key_to_get):
return DotEnv(dotenv_path, verbose=True).get(key_to_get)


def get_boolean_key(dotenv_path, key_to_get):
# type: (Union[Text, _PathLike], Text) -> Optional[bool]
"""
Gets the value of a given key from the given .env
If the .env path given doesn't exist, fails
"""
return DotEnv(dotenv_path, verbose=True).get_as_boolean(key_to_get)


@contextmanager
def rewrite(path):
# type: (_PathLike) -> Iterator[Tuple[IO[Text], IO[Text]]]
Expand Down Expand Up @@ -312,3 +332,14 @@ def dotenv_values(dotenv_path=None, stream=None, verbose=False, interpolate=True
# type: (Union[Text, _PathLike, None], Optional[_StringIO], bool, bool, Union[None, Text]) -> Dict[Text, Optional[Text]] # noqa: E501
f = dotenv_path or stream or find_dotenv()
return DotEnv(f, verbose=verbose, interpolate=interpolate, override=True, **kwargs).dict()


def get_bool(key, default=None):
# type: (Text, Union[Text, bool, None]) -> bool
value = os.getenv(key, default)

if isinstance(value, (bool, int)):
# happens if default was a boolean
return bool(value)

return bool(strtobool(str(value)))
95 changes: 95 additions & 0 deletions tests/test_main.py
Expand Up @@ -121,6 +121,42 @@ def test_get_key_none(dotenv_file):
mock_warning.assert_not_called()


@pytest.mark.parametrize("true_value", ["Yes", "yes", "True", "true", "on", "1", "y"])
def test_get_boolean_key_ok(dotenv_file, true_value):
logger = logging.getLogger("dotenv.main")
with open(dotenv_file, "w") as f:
f.write("foo={}".format(true_value))

with mock.patch.object(logger, "warning") as mock_warning:
result = dotenv.get_boolean_key(dotenv_file, "foo")

assert result
mock_warning.assert_not_called()


@pytest.mark.parametrize("false_value", ["n", "no", "f", "False", "false", "off", "0"])
def test_get_boolean_key_ok_false(dotenv_file, false_value):
logger = logging.getLogger("dotenv.main")
with open(dotenv_file, "w") as f:
f.write("foo={}".format(false_value))

with mock.patch.object(logger, "warning") as mock_warning:
result = dotenv.get_boolean_key(dotenv_file, "foo")

assert not result
mock_warning.assert_not_called()


def test_get_boolean_key_not_found(dotenv_file):
logger = logging.getLogger("dotenv.main")

with mock.patch.object(logger, "warning") as mock_warning:
result = dotenv.get_boolean_key(dotenv_file, "foo")

assert not result
mock_warning.assert_called_once_with("Key %s not found in %s.", "foo", dotenv_file)


def test_unset_with_value(dotenv_file):
logger = logging.getLogger("dotenv.main")
with open(dotenv_file, "w") as f:
Expand Down Expand Up @@ -367,3 +403,62 @@ def test_dotenv_values_stream(env, string, interpolate, expected):
result = dotenv.dotenv_values(stream=stream, interpolate=interpolate)

assert result == expected


@pytest.mark.parametrize("true_value", ["Yes", "yes", "True", "true", "on", "1", "y"])
def test_get_bool(dotenv_file, true_value):
logger = logging.getLogger("dotenv.main")
with open(dotenv_file, "w") as f:
f.write("foo={}".format(true_value))

dotenv.load_dotenv(dotenv_file)

with mock.patch.object(logger, "warning") as mock_warning:
value = dotenv.get_bool("foo")

assert type(value) == bool
assert value

mock_warning.assert_not_called()


@pytest.mark.parametrize("false_value", ["n", "no", "f", "False", "false", "off", "0"])
def test_get_bool_false(dotenv_file, false_value):
logger = logging.getLogger("dotenv.main")
with open(dotenv_file, "w") as f:
f.write("foo={}".format(false_value))

dotenv.load_dotenv(dotenv_file, override=True)

with mock.patch.object(logger, "warning") as mock_warning:
value = dotenv.get_bool("foo")

assert type(value) == bool
assert not value

mock_warning.assert_not_called()


@pytest.mark.parametrize("default,expected_value", [(True, True), (1, True), ("true", True), ("f", False)])
def test_get_bool_default(dotenv_file, default, expected_value):
logger = logging.getLogger("dotenv.main")
dotenv.load_dotenv(dotenv_file)

with mock.patch.object(logger, "warning") as mock_warning:
value = dotenv.get_bool("bar", default)

assert type(value) == bool
assert bool(value) == expected_value

mock_warning.assert_not_called()


@pytest.mark.parametrize("value", ["b", "trrue", "fals", "foo", "bar"])
def test_get_bool_invalid(dotenv_file, value):
with open(dotenv_file, "w") as f:
f.write("foo={}".format(value))

dotenv.load_dotenv(dotenv_file, override=True)

with pytest.raises(ValueError):
value = dotenv.get_bool("foo")

0 comments on commit ecfa669

Please sign in to comment.