Skip to content

Commit

Permalink
fix edge cases & windows tests
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed Aug 8, 2023
1 parent fff3bf5 commit 84546d6
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 17 deletions.
28 changes: 22 additions & 6 deletions tests/tests_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from tqdm.utils import envwrap
from pytest import mark

from tqdm.utils import IS_WIN, envwrap


def test_envwrap(monkeypatch):
Expand All @@ -7,6 +9,7 @@ def test_envwrap(monkeypatch):
env_c = 1337
monkeypatch.setenv('FUNC_A', str(env_a))
monkeypatch.setenv('FUNC_TyPe_HiNt', str(env_c))
monkeypatch.setenv('FUNC_Unused', "x")

@envwrap("FUNC_")
def func(a=1, b=2, type_hint: int = None):
Expand All @@ -15,11 +18,24 @@ def func(a=1, b=2, type_hint: int = None):
assert (env_a, 2, 1337) == func(), "expected env override"
assert (99, 2, 1337) == func(a=99), "expected manual override"

env_liTeral = 3.14159
monkeypatch.setenv('FUNC_liTeral', str(env_liTeral))
env_literal = 3.14159
monkeypatch.setenv('FUNC_literal', str(env_literal))

@envwrap("FUNC_", literal_eval=True)
def another_func(literal="some_string"):
return literal

assert env_literal == another_func()


@mark.skipif(IS_WIN, reason="no lowercase environ on Windows")
def test_envwrap_case(monkeypatch):
"""Test envwrap case-sensitive overrides"""
env_liTeRaL = 3.14159
monkeypatch.setenv('FUNC_liTeRaL', str(env_liTeRaL))

@envwrap("FUNC_", literal_eval=True, case_sensitive=True)
def another_func(liTeral=1):
return liTeral
def func(liTeRaL="some_string"):
return liTeRaL

assert env_liTeral == another_func()
assert env_liTeRaL == func()
22 changes: 11 additions & 11 deletions tqdm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,24 +66,24 @@ def test(a=1, b=2, c=3):
```
"""
i = len(prefix)
overrides = {k[i:] if case_sensitive else k[i:].lower(): v
for k, v in os.environ.items() if k.startswith(prefix)}
env_overrides = {k[i:] if case_sensitive else k[i:].lower(): v
for k, v in os.environ.items() if k.startswith(prefix)}
part = partialmethod if is_method else partial

def wrap(func):
params = signature(func).parameters
overrides = {k: v for k, v in env_overrides.items() if k in params}
if literal_eval:
return part(func, **{k: safe_eval(v) for k, v in overrides.items() if k in params})
return part(func, **{k: safe_eval(v) for k, v in overrides.items()})
# use `func` signature to infer env override `type` (fallback to `str`)
for k in overrides:
param = params.get(k, None)
if param is not None:
if param.annotation is not param.empty:
typ = param.annotation
# TODO: parse type in {Union, Any, Optional, ...}
else:
typ = str if param.default is None else type(param.default)
overrides[k] = typ(overrides[k])
param = params[k]
if param.annotation is not param.empty:
typ = param.annotation
# TODO: parse type in {Union, Any, Optional, ...}
else:
typ = str if param.default is None else type(param.default)
overrides[k] = typ(overrides[k])
return part(func, **overrides)
return wrap

Expand Down

0 comments on commit 84546d6

Please sign in to comment.