Skip to content

Commit

Permalink
Merge pull request #548 from justinpark715/eval-node-collection
Browse files Browse the repository at this point in the history
typecheck: Adding helper function to better handle annotations of collection types
  • Loading branch information
david-yz-liu committed Aug 20, 2018
2 parents 623bd9a + b125a23 commit b96537c
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 18 deletions.
29 changes: 17 additions & 12 deletions python_ta/typecheck/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,20 +955,11 @@ def _node_to_type(node: NodeNG, locals: Dict[str, type] = None) -> type:
if node is None:
return Any
elif isinstance(node, str):
try:
return eval(node, globals(), locals)
except:
return ForwardRef(node)
return _eval_node(node, globals(), locals)
elif isinstance(node, astroid.Name):
try:
return eval(node.name, globals(), locals)
except:
return ForwardRef(node.name)
return _eval_node(node.name, globals(), locals)
elif isinstance(node, astroid.Attribute):
try:
return eval(node.attrname, globals(), locals)
except:
return ForwardRef(node.attrname)
return _eval_node(node.attrname, globals(), locals)
elif isinstance(node, astroid.Subscript):
v = _node_to_type(node.value)
s = _node_to_type(node.slice)
Expand All @@ -987,6 +978,20 @@ def _node_to_type(node: NodeNG, locals: Dict[str, type] = None) -> type:
return node


def _eval_node(node_name: str, _globals: Dict[str, type], _locals: Dict[str, type]):
"""Return a type represented by node_name."""
try:
eval_type = eval(node_name, _globals, _locals)
except:
eval_type = ForwardRef(node_name)

if eval_type in (list, dict, tuple, set):
# Annotation set as class type (ie. list) instead of typing generic (ie. List[Any])
return eval(f"typing.{node_name.capitalize()}", _globals, _locals)
else:
return eval_type


def _collect_tvars(type: type) -> List[type]:
if isinstance(type, TypeVar):
return [type]
Expand Down
9 changes: 6 additions & 3 deletions tests/test_type_inference/test_annassign.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import tests.custom_hypothesis_support as cs
from tests.custom_hypothesis_support import lookup_type
import hypothesis.strategies as hs
from python_ta.typecheck.base import _node_to_type, TypeFail, TypeFailAnnotationInvalid, TypeFailUnify, NoType
from typing import List, Set, Dict, Any, Tuple
from python_ta.typecheck.base import _node_to_type, TypeFail, TypeFailAnnotationInvalid, TypeFailUnify, NoType, _gorg
from typing import List, Set, Dict, Any, Tuple, _GenericAlias
from nose import SkipTest
from nose.tools import eq_
settings.load_profile("pyta")
Expand Down Expand Up @@ -42,7 +42,10 @@ def test_annassign(variables_annotations_dict):
for node in module.nodes_of_class(astroid.AnnAssign):
variable_type = lookup_type(inferer, node, node.target.name)
annotated_type = variables_annotations_dict[node.target.name]
assert variable_type == annotated_type
if isinstance(variable_type, _GenericAlias):
assert _gorg(variable_type) == annotated_type
else:
assert variable_type == annotated_type


def test_annassign_subscript_list():
Expand Down
13 changes: 10 additions & 3 deletions tests/test_type_inference/test_function_def_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import tests.custom_hypothesis_support as cs
from tests.custom_hypothesis_support import lookup_type, types_in_callable
import hypothesis.strategies as hs
from typing import Callable, ForwardRef, Type
from typing import Callable, ForwardRef, Type, _GenericAlias
from nose.tools import eq_
from python_ta.typecheck.base import _gorg
from python_ta.transforms.type_inference_visitor import TypeFail
settings.load_profile("pyta")

Expand Down Expand Up @@ -76,11 +77,17 @@ def test_functiondef_annotated_simple_return(functiondef_node):
arg_name = functiondef_node.args.args[i].name
expected_type = inferer.type_constraints.resolve(functiondef_node.type_environment.lookup_in_env(arg_name)).getValue()
# need to do by name because annotations must be name nodes.
assert expected_type.__name__ == functiondef_node.args.annotations[i].name
if isinstance(expected_type, _GenericAlias):
assert _gorg(expected_type).__name__ == functiondef_node.args.annotations[i].name
else:
assert expected_type.__name__ == functiondef_node.args.annotations[i].name
# test return type
return_node = functiondef_node.body[0].value
expected_rtype = inferer.type_constraints.resolve(functiondef_node.type_environment.lookup_in_env(return_node.name)).getValue()
assert expected_rtype.__name__ == functiondef_node.returns.name
if isinstance(expected_rtype, _GenericAlias):
assert _gorg(expected_rtype).__name__ == functiondef_node.returns.name
else:
assert expected_rtype.__name__ == functiondef_node.returns.name


def test_functiondef_method():
Expand Down

0 comments on commit b96537c

Please sign in to comment.