Skip to content

Commit

Permalink
fix read of inherited attribute with annotation (close #8)
Browse files Browse the repository at this point in the history
  • Loading branch information
vanyakosmos committed Oct 29, 2019
1 parent 7dfed2e commit 5fc7e59
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 16 deletions.
5 changes: 0 additions & 5 deletions Makefile

This file was deleted.

23 changes: 13 additions & 10 deletions argser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,20 @@ def _get_nargs(typ, default):
return typ, None


def _get_fields(cls: Type[Args], ann: dict):
fields_with_value = cls.__dict__
def _get_fields(cls: Type[Args]):
ann = getattr(cls, '__annotations__', {})
fields_with_value = {
key: value
for key, value in cls.__dict__.items()
if not key.startswith('__') and not isinstance(value, type) # skip built-ins and inner classes
}
fields = {k: None for k in ann if k not in fields_with_value}
for key, value in fields_with_value.items():
# skip built-ins and inner classes
if key.startswith('__') or isinstance(value, type):
continue
fields[key] = value
fields.update(**fields_with_value)
# get fields from bases classes
for base in cls.__bases__:
for name, value in _get_fields(base, ann).items():
if base is object:
continue
for name, value in _get_fields(base).items():
# update without touching redefined values in inherited classes
if name not in fields:
fields[name] = value
Expand All @@ -64,7 +67,7 @@ def _get_type_and_nargs(ann: dict, field_name: str, default):


def _collect_annotations(cls: type):
ann = getattr(cls, '__annotations__', {})
ann = getattr(cls, '__annotations__', {}).copy() # don't modify class annotation
for base in cls.__bases__:
for name, typ in _collect_annotations(base).items():
# update without touching redefined values in inherited classes
Expand All @@ -83,7 +86,7 @@ def _read_args(
args = []
sub_commands = {}
ann = _collect_annotations(args_cls)
fields = _get_fields(args_cls, ann)
fields = _get_fields(args_cls)
for key, value in fields.items(): # type: str, Any
logger.log(VERBOSE, f"reading {key!r}")
if hasattr(value, SUB_COMMAND_MARK):
Expand Down
14 changes: 14 additions & 0 deletions cli.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#!/usr/bin/env bash

set -e

function test() {
pytest --cov=argser --no-cov-on-fail --cov-report html --cov-report term-missing
}

function docs() {
make -C docs clean
make -C docs html
}

$@
6 changes: 5 additions & 1 deletion tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ class CommonArgs:
a: int
b: float
c = 'c'
g: bool = Opt(default=True, help="gg")

class Args1(CommonArgs):
a: str
Expand All @@ -451,15 +452,18 @@ class Args2(CommonArgs):
assert args.a == 1
assert args.b == 2.2
assert args.c == 'cc'
assert args.g is True

args = parse_args(Args1, '-a 1 -b 4.4 -c 5 -d 2.2')
assert args.a == '1'
assert args.b == 4.4
assert args.c == 5
assert args.d == 2.2
assert args.g is True

args = parse_args(Args2, '-a 1 -b 5.5 -e "foo bar"')
args = parse_args(Args2, '-a 1 -b 5.5 -e "foo bar" --no-g')
assert args.a == 1
assert args.b == 5.5
assert args.c == 'c'
assert args.e == 'foo bar'
assert args.g is False

0 comments on commit 5fc7e59

Please sign in to comment.