Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SequenceValue for heterogeneous sequences #515

Merged
merged 4 commits into from
Apr 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 142 additions & 19 deletions pyanalyze/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
ParamSpecKwargsValue,
ParameterTypeGuardExtension,
SelfTVV,
SequenceValue,
TypeGuardExtension,
TypedValue,
SequenceIncompleteValue,
Expand Down Expand Up @@ -344,14 +345,26 @@ def value_from_ast(
return val


def _type_from_ast(node: ast.AST, ctx: Context, is_typeddict: bool = False) -> Value:
def _type_from_ast(
node: ast.AST,
ctx: Context,
*,
is_typeddict: bool = False,
unpack_allowed: bool = False,
) -> Value:
val = value_from_ast(node, ctx)
return _type_from_value(val, ctx, is_typeddict=is_typeddict)
return _type_from_value(
val, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed
)


def _type_from_runtime(val: Any, ctx: Context, is_typeddict: bool = False) -> Value:
def _type_from_runtime(
val: Any, ctx: Context, *, is_typeddict: bool = False, unpack_allowed: bool = False
) -> Value:
if isinstance(val, str):
return _eval_forward_ref(val, ctx, is_typeddict=is_typeddict)
return _eval_forward_ref(
val, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed
)
elif isinstance(val, tuple):
# This happens under some Python versions for types
# nested in tuples, e.g. on 3.6:
Expand All @@ -365,13 +378,17 @@ def _type_from_runtime(val: Any, ctx: Context, is_typeddict: bool = False) -> Va
args = (val[1],)
else:
args = val[1:]
return _value_of_origin_args(origin, args, val, ctx)
return _value_of_origin_args(
origin, args, val, ctx, unpack_allowed=unpack_allowed
)
elif GenericAlias is not None and isinstance(val, GenericAlias):
origin = get_origin(val)
args = get_args(val)
if origin is tuple and not args:
return SequenceIncompleteValue(tuple, [])
return _value_of_origin_args(origin, args, val, ctx)
return _value_of_origin_args(
origin, args, val, ctx, unpack_allowed=origin is tuple
)
elif typing_inspect.is_literal_type(val):
args = typing_inspect.get_args(val)
if len(args) == 0:
Expand All @@ -393,7 +410,17 @@ def _type_from_runtime(val: Any, ctx: Context, is_typeddict: bool = False) -> Va
elif len(args) == 1 and args[0] == ():
return SequenceIncompleteValue(tuple, []) # empty tuple
else:
args_vals = [_type_from_runtime(arg, ctx) for arg in args]
args_vals = [
_type_from_runtime(arg, ctx, unpack_allowed=True) for arg in args
]
if any(isinstance(val, UnpackedValue) for val in args_vals):
members = []
for val in args_vals:
if isinstance(val, UnpackedValue):
members += val.elements
else:
members.append((False, val))
return SequenceValue(tuple, members)
return SequenceIncompleteValue(tuple, args_vals)
elif is_instance_of_typing_name(val, "_TypedDictMeta"):
required_keys = getattr(val, "__required_keys__", None)
Expand Down Expand Up @@ -434,7 +461,14 @@ def _type_from_runtime(val: Any, ctx: Context, is_typeddict: bool = False) -> Va
args = typing_inspect.get_args(val)
if getattr(val, "_special", False):
args = [] # distinguish List from List[T] on 3.7 and 3.8
return _value_of_origin_args(origin, args, val, ctx, is_typeddict=is_typeddict)
return _value_of_origin_args(
origin,
args,
val,
ctx,
is_typeddict=is_typeddict,
unpack_allowed=unpack_allowed or origin is tuple or origin is Tuple,
)
elif typing_inspect.is_callable_type(val):
args = typing_inspect.get_args(val)
return _value_of_origin_args(Callable, args, val, ctx)
Expand Down Expand Up @@ -535,6 +569,13 @@ def _type_from_runtime(val: Any, ctx: Context, is_typeddict: bool = False) -> Va
cls = "Required" if required else "NotRequired"
ctx.show_error(f"{cls}[] used in unsupported context")
return AnyValue(AnySource.error)
# Also 3.6 only.
elif is_instance_of_typing_name(val, "_Unpack"):
if unpack_allowed:
return _make_unpacked_value(_type_from_runtime(val.__type__, ctx), ctx)
else:
ctx.show_error("Unpack[] used in unsupported context")
return AnyValue(AnySource.error)
elif is_typing_name(val, "TypeAlias"):
return AnyValue(AnySource.incomplete_annotation)
elif is_typing_name(val, "TypedDict"):
Expand Down Expand Up @@ -638,28 +679,51 @@ def _get_typeddict_value(
return required, val


def _eval_forward_ref(val: str, ctx: Context, is_typeddict: bool = False) -> Value:
def _eval_forward_ref(
val: str, ctx: Context, *, is_typeddict: bool = False, unpack_allowed: bool = False
) -> Value:
try:
tree = ast.parse(val, mode="eval")
except SyntaxError:
ctx.show_error(f"Syntax error in type annotation: {val}")
return AnyValue(AnySource.error)
else:
return _type_from_ast(tree.body, ctx, is_typeddict=is_typeddict)
return _type_from_ast(
tree.body, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed
)


def _type_from_value(value: Value, ctx: Context, is_typeddict: bool = False) -> Value:
def _type_from_value(
value: Value,
ctx: Context,
*,
is_typeddict: bool = False,
unpack_allowed: bool = False,
) -> Value:
if isinstance(value, KnownValue):
return _type_from_runtime(value.val, ctx, is_typeddict=is_typeddict)
return _type_from_runtime(
value.val, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed
)
elif isinstance(value, TypeVarValue):
return value
elif isinstance(value, MultiValuedValue):
return unite_values(*[_type_from_value(val, ctx) for val in value.vals])
return unite_values(
*[
_type_from_value(
val, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed
)
for val in value.vals
]
)
elif isinstance(value, AnnotatedValue):
return _type_from_value(value.value, ctx)
elif isinstance(value, _SubscriptedValue):
return _type_from_subscripted_value(
value.root, value.members, ctx, is_typeddict=is_typeddict
value.root,
value.members,
ctx,
is_typeddict=is_typeddict,
unpack_allowed=unpack_allowed,
)
elif isinstance(value, AnyValue):
return value
Expand All @@ -677,7 +741,9 @@ def _type_from_subscripted_value(
root: Optional[Value],
members: Sequence[Value],
ctx: Context,
*,
is_typeddict: bool = False,
unpack_allowed: bool = False,
) -> Value:
if isinstance(root, GenericValue):
if len(root.args) == len(members):
Expand All @@ -690,7 +756,13 @@ def _type_from_subscripted_value(
elif isinstance(root, MultiValuedValue):
return unite_values(
*[
_type_from_subscripted_value(subval, members, ctx, is_typeddict)
_type_from_subscripted_value(
subval,
members,
ctx,
is_typeddict=is_typeddict,
unpack_allowed=unpack_allowed,
)
for subval in root.vals
]
)
Expand Down Expand Up @@ -729,9 +801,16 @@ def _type_from_subscripted_value(
elif len(members) == 1 and members[0] == KnownValue(()):
return SequenceIncompleteValue(tuple, [])
else:
return SequenceIncompleteValue(
tuple, [_type_from_value(arg, ctx) for arg in members]
)
args = [_type_from_value(arg, ctx, unpack_allowed=True) for arg in members]
if any(isinstance(val, UnpackedValue) for val in args):
tuple_members = []
for val in args:
if isinstance(val, UnpackedValue):
tuple_members += val.elements
else:
tuple_members.append((False, val))
return SequenceValue(tuple, tuple_members)
return SequenceIncompleteValue(tuple, args)
elif root is typing.Optional:
if len(members) != 1:
ctx.show_error("Optional[] takes only one argument")
Expand Down Expand Up @@ -769,6 +848,14 @@ def _type_from_subscripted_value(
ctx.show_error("NotRequired[] requires a single argument")
return AnyValue(AnySource.error)
return Pep655Value(False, _type_from_value(members[0], ctx))
elif is_typing_name(root, "Unpack"):
if not unpack_allowed:
ctx.show_error("Unpack[] used in unsupported context")
return AnyValue(AnySource.error)
if len(members) != 1:
ctx.show_error("Unpack requires a single argument")
return AnyValue(AnySource.error)
return _make_unpacked_value(_type_from_value(members[0], ctx), ctx)
elif root is Callable or root is typing.Callable:
if len(members) == 2:
args, return_value = members
Expand Down Expand Up @@ -877,6 +964,11 @@ class Pep655Value(Value):
value: Value


@dataclass
class UnpackedValue(Value):
elements: Sequence[Tuple[bool, Value]]


class _Visitor(ast.NodeVisitor):
def __init__(self, ctx: Context) -> None:
self.ctx = ctx
Expand All @@ -892,6 +984,12 @@ def visit_Subscript(self, node: ast.Subscript) -> Value:
index = self.visit(node.slice)
if isinstance(index, SequenceIncompleteValue):
members = index.members
elif isinstance(index, SequenceValue):
members = index.get_member_sequence()
if members is None:
# TODO support unpacking here
return AnyValue(AnySource.inference)
members = tuple(members)
else:
members = (index,)
return _SubscriptedValue(value, members)
Expand Down Expand Up @@ -1047,7 +1145,9 @@ def _value_of_origin_args(
args: Sequence[object],
val: object,
ctx: Context,
*,
is_typeddict: bool = False,
unpack_allowed: bool = False,
) -> Value:
if origin is typing.Type or origin is type:
if not args:
Expand All @@ -1061,7 +1161,9 @@ def _value_of_origin_args(
elif len(args) == 1 and args[0] == ():
return SequenceIncompleteValue(tuple, [])
else:
args_vals = [_type_from_runtime(arg, ctx) for arg in args]
args_vals = [
_type_from_runtime(arg, ctx, unpack_allowed=True) for arg in args
]
return SequenceIncompleteValue(tuple, args_vals)
elif origin is typing.Union:
return unite_values(*[_type_from_runtime(arg, ctx) for arg in args])
Expand Down Expand Up @@ -1126,6 +1228,14 @@ def _value_of_origin_args(
ctx.show_error("NotRequired[] requires a single argument")
return AnyValue(AnySource.error)
return Pep655Value(False, _type_from_runtime(args[0], ctx))
elif is_typing_name(origin, "Unpack"):
if not unpack_allowed:
ctx.show_error("Invalid usage of Unpack")
return AnyValue(AnySource.error)
if len(args) != 1:
ctx.show_error("Unpack requires a single argument")
return AnyValue(AnySource.error)
return _make_unpacked_value(_type_from_runtime(args[0], ctx), ctx)
elif origin is None and isinstance(val, type):
# This happens for SupportsInt in 3.7.
return _maybe_typed_value(val)
Expand All @@ -1144,6 +1254,19 @@ def _maybe_typed_value(val: Union[type, str]) -> Value:
return TypedValue(val)


def _make_unpacked_value(val: Value, ctx: Context) -> UnpackedValue:
if isinstance(val, SequenceValue) and val.typ is tuple:
return UnpackedValue(val.members)
elif isinstance(val, SequenceIncompleteValue) and val.typ is tuple:
return UnpackedValue([(False, elt) for elt in val.members])
elif isinstance(val, GenericValue) and val.typ is tuple:
return UnpackedValue([(True, val.args[0])])
elif isinstance(val, TypedValue) and val.typ is tuple:
return UnpackedValue([(True, AnyValue(AnySource.generic_argument))])
ctx.show_error(f"Invalid argument for Unpack: {val}")
return UnpackedValue([])


def _make_callable_from_value(
args: Value, return_value: Value, ctx: Context, is_asynq: bool = False
) -> Value:
Expand Down
18 changes: 18 additions & 0 deletions pyanalyze/boolability.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
KnownValue,
MultiValuedValue,
SequenceIncompleteValue,
SequenceValue,
SubclassValue,
TypedDictValue,
TypedValue,
Expand Down Expand Up @@ -125,6 +126,23 @@ def _get_boolability_no_mvv(value: Value) -> Boolability:
return Boolability.value_always_true_mutable
else:
return Boolability.value_always_false_mutable
elif isinstance(value, SequenceValue):
if not value.members:
if value.typ is tuple:
return Boolability.value_always_false
else:
return Boolability.value_always_false_mutable
may_be_empty = all(is_many for is_many, _ in value.members)
if may_be_empty:
return Boolability.boolable
if value.typ is tuple:
# We lie slightly here, since at the type level a tuple
# may be false. But tuples are a common source of boolability
# bugs and they're rarely mutated, so we put a stronger
# condition on them.
return Boolability.type_always_true
else:
return Boolability.value_always_true_mutable
elif isinstance(value, DictIncompleteValue):
if any(pair.is_required and not pair.is_many for pair in value.kv_pairs):
return Boolability.value_always_true_mutable
Expand Down
5 changes: 5 additions & 0 deletions pyanalyze/format_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
KnownValue,
DictIncompleteValue,
SequenceIncompleteValue,
SequenceValue,
TypedValue,
Value,
flatten_values,
Expand Down Expand Up @@ -370,6 +371,10 @@ def accept_tuple_args_no_mvv(
args = replace_known_sequence_value(args)
if isinstance(args, SequenceIncompleteValue):
all_args = args.members
elif isinstance(args, SequenceValue):
all_args = args.get_member_sequence()
if all_args is None:
return
else:
# it's a tuple but we don't know what's in it, so assume it's ok
return
Expand Down
Loading