Skip to content

Commit

Permalink
utils.envwrap: fix types
Browse files Browse the repository at this point in the history
- fixes #966
  • Loading branch information
casperdcl committed Aug 9, 2023
1 parent 55dc0fb commit af44eeb
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 45 deletions.
2 changes: 1 addition & 1 deletion .meta/.readme.rst
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ Documentation
class tqdm():
"""{DOC_tqdm}"""
@envwrap("TQDM_", is_method=True) # override defaults via env vars
@envwrap("TQDM_") # override defaults via env vars
def __init__(self, iterable=None, desc=None, total=None, leave=True,
file=None, ncols=None, mininterval=0.1,
maxinterval=10.0, miniters=None, ascii=None, disable=False,
Expand Down
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ Documentation
progressbar every time a value is requested.
"""
@envwrap("TQDM_", is_method=True) # override defaults via env vars
@envwrap("TQDM_") # override defaults via env vars
def __init__(self, iterable=None, desc=None, total=None, leave=True,
file=None, ncols=None, mininterval=0.1,
maxinterval=10.0, miniters=None, ascii=None, disable=False,
Expand Down
57 changes: 33 additions & 24 deletions tests/tests_utils.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,50 @@
from pytest import mark
from ast import literal_eval
from collections import defaultdict

from tqdm.utils import IS_WIN, envwrap
from tqdm.utils import envwrap


def test_envwrap(monkeypatch):
"""Test envwrap overrides"""
env_a = 42
env_c = 1337
monkeypatch.setenv('FUNC_A', str(env_a))
monkeypatch.setenv('FUNC_TyPe_HiNt', str(env_c))
"""Test @envwrap (basic)"""
monkeypatch.setenv('FUNC_A', "42")
monkeypatch.setenv('FUNC_TyPe_HiNt', "1337")
monkeypatch.setenv('FUNC_Unused', "x")

@envwrap("FUNC_")
def func(a=1, b=2, type_hint: int = None):
return a, b, type_hint

assert (env_a, 2, 1337) == func(), "expected env override"
assert (99, 2, 1337) == func(a=99), "expected manual override"
assert (42, 2, 1337) == func()
assert (99, 2, 1337) == func(a=99)

env_literal = 3.14159
monkeypatch.setenv('FUNC_literal', str(env_literal))

@envwrap("FUNC_", literal_eval=True)
def another_func(literal="some_string"):
return literal
def test_envwrap_types(monkeypatch):
"""Test @envwrap(types)"""
monkeypatch.setenv('FUNC_notype', "3.14159")

assert env_literal == another_func()
@envwrap("FUNC_", types=defaultdict(lambda: literal_eval))
def func(notype=None):
return notype

assert 3.14159 == func()

@mark.skipif(IS_WIN, reason="no lowercase environ on Windows")
def test_envwrap_case(monkeypatch):
"""Test envwrap case-sensitive overrides"""
env_liTeRaL = 3.14159
monkeypatch.setenv('FUNC_liTeRaL', str(env_liTeRaL))
monkeypatch.setenv('FUNC_number', "1")
monkeypatch.setenv('string', "1")

@envwrap("FUNC_", literal_eval=True, case_sensitive=True)
def func(liTeRaL="some_string"):
return liTeRaL
@envwrap("FUNC_", types={'number': int})
def nofallback(number=None, string=None):
return number, string

assert env_liTeRaL == func()
assert 1, "1" == nofallback()


def test_envwrap_annotations(monkeypatch):
"""Test @envwrap with typehints"""
monkeypatch.setenv('FUNC_number', "1.1")
monkeypatch.setenv('FUNC_string', "1.1")

@envwrap("FUNC_")
def annotated(number: int | float = None, string: int = None):
return number, string

assert 1.1, "1.1" == annotated()
6 changes: 4 additions & 2 deletions tqdm/std.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,13 +949,15 @@ def wrapper(*args, **kwargs):
elif _Rolling_and_Expanding is not None:
_Rolling_and_Expanding.progress_apply = inner_generator()

@envwrap("TQDM_", is_method=True) # override defaults via env vars
# override defaults via env vars
@envwrap("TQDM_", is_method=True, types={'total': float, 'ncols': int, 'miniters': float,
'position': int, 'nrows': int})
def __init__(self, iterable=None, desc=None, total=None, leave=True, file=None,
ncols=None, mininterval=0.1, maxinterval=10.0, miniters=None,
ascii=None, disable=False, unit='it', unit_scale=False,
dynamic_ncols=False, smoothing=0.3, bar_format=None, initial=0,
position=None, postfix=None, unit_divisor=1000, write_bytes=False,
lock_args=None, nrows=None, colour=None, delay=0, gui=False,
lock_args=None, nrows=None, colour=None, delay=0.0, gui=False,
**kwargs):
"""see tqdm.tqdm for arguments"""
if file is None:
Expand Down
44 changes: 27 additions & 17 deletions tqdm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import os
import re
import sys
from ast import literal_eval as safe_eval
from functools import partial, partialmethod, wraps
from inspect import signature
# TODO consider using wcswidth third-party package for 0-width characters
Expand Down Expand Up @@ -32,9 +31,12 @@
colorama.init()


def envwrap(prefix, case_sensitive=False, literal_eval=False, is_method=False):
def envwrap(prefix, types=None, is_method=False):
"""
Override parameter defaults via `os.environ[prefix + param_name]`.
Maps UPPER_CASE env vars map to lower_case param names.
camelCase isn't supported (because Windows ignores case).
Precedence (highest first):
- call (`foo(a=3)`)
- environ (`FOO_A=2`)
Expand All @@ -44,11 +46,10 @@ def envwrap(prefix, case_sensitive=False, literal_eval=False, is_method=False):
----------
prefix : str
Env var prefix, e.g. "FOO_"
case_sensitive : bool, optional
If (default: False), treat env var "FOO_Some_ARG" as "FOO_some_arg".
literal_eval : bool, optional
Whether to `ast.literal_eval` the detected env var overrides.
Otherwise if (default: False), infer types from function signature.
types : dict, optional
Fallback mappings `{'param_name': type, ...}` if types cannot be
inferred from function signature.
Consider using `types=collections.defaultdict(lambda: ast.literal_eval)`.
is_method : bool, optional
Whether to use `functools.partialmethod`. If (default: False) use `functools.partial`.
Expand All @@ -65,25 +66,34 @@ def test(a=1, b=2, c=3):
received: a=42, b=2, c=99
```
"""
if types is None:
types = {}
i = len(prefix)
env_overrides = {k[i:] if case_sensitive else k[i:].lower(): v
for k, v in os.environ.items() if k.startswith(prefix)}
env_overrides = {k[i:].lower(): v for k, v in os.environ.items() if k.startswith(prefix)}
part = partialmethod if is_method else partial

def wrap(func):
params = signature(func).parameters
# ignore unknown env vars
overrides = {k: v for k, v in env_overrides.items() if k in params}
if literal_eval:
return part(func, **{k: safe_eval(v) for k, v in overrides.items()})
# use `func` signature to infer env override `type` (fallback to `str`)
# infer overrides' `type`s
for k in overrides:
param = params[k]
if param.annotation is not param.empty:
typ = param.annotation
# TODO: parse type in {Union, Any, Optional, ...}
if param.annotation is not param.empty: # typehints
for typ in getattr(param.annotation, '__args__', (param.annotation,)):
try:
overrides[k] = typ(overrides[k])
except Exception:
pass
else:
break
elif param.default is not None: # type of default value
overrides[k] = type(param.default)(overrides[k])
else:
typ = str if param.default is None else type(param.default)
overrides[k] = typ(overrides[k])
try: # `types` fallback
overrides[k] = types[k](overrides[k])
except KeyError: # keep unconverted (`str`)
pass
return part(func, **overrides)
return wrap

Expand Down

0 comments on commit af44eeb

Please sign in to comment.