diff --git a/pyanalyze/annotations.py b/pyanalyze/annotations.py index f10bce44..58b77fd8 100644 --- a/pyanalyze/annotations.py +++ b/pyanalyze/annotations.py @@ -82,6 +82,7 @@ ParamSpecKwargsValue, ParameterTypeGuardExtension, SelfTVV, + SequenceValue, TypeGuardExtension, TypedValue, SequenceIncompleteValue, @@ -344,14 +345,26 @@ def value_from_ast( return val -def _type_from_ast(node: ast.AST, ctx: Context, is_typeddict: bool = False) -> Value: +def _type_from_ast( + node: ast.AST, + ctx: Context, + *, + is_typeddict: bool = False, + unpack_allowed: bool = False, +) -> Value: val = value_from_ast(node, ctx) - return _type_from_value(val, ctx, is_typeddict=is_typeddict) + return _type_from_value( + val, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed + ) -def _type_from_runtime(val: Any, ctx: Context, is_typeddict: bool = False) -> Value: +def _type_from_runtime( + val: Any, ctx: Context, *, is_typeddict: bool = False, unpack_allowed: bool = False +) -> Value: if isinstance(val, str): - return _eval_forward_ref(val, ctx, is_typeddict=is_typeddict) + return _eval_forward_ref( + val, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed + ) elif isinstance(val, tuple): # This happens under some Python versions for types # nested in tuples, e.g. on 3.6: @@ -365,13 +378,17 @@ def _type_from_runtime(val: Any, ctx: Context, is_typeddict: bool = False) -> Va args = (val[1],) else: args = val[1:] - return _value_of_origin_args(origin, args, val, ctx) + return _value_of_origin_args( + origin, args, val, ctx, unpack_allowed=unpack_allowed + ) elif GenericAlias is not None and isinstance(val, GenericAlias): origin = get_origin(val) args = get_args(val) if origin is tuple and not args: return SequenceIncompleteValue(tuple, []) - return _value_of_origin_args(origin, args, val, ctx) + return _value_of_origin_args( + origin, args, val, ctx, unpack_allowed=origin is tuple + ) elif typing_inspect.is_literal_type(val): args = typing_inspect.get_args(val) if len(args) == 0: @@ -393,7 +410,17 @@ def _type_from_runtime(val: Any, ctx: Context, is_typeddict: bool = False) -> Va elif len(args) == 1 and args[0] == (): return SequenceIncompleteValue(tuple, []) # empty tuple else: - args_vals = [_type_from_runtime(arg, ctx) for arg in args] + args_vals = [ + _type_from_runtime(arg, ctx, unpack_allowed=True) for arg in args + ] + if any(isinstance(val, UnpackedValue) for val in args_vals): + members = [] + for val in args_vals: + if isinstance(val, UnpackedValue): + members += val.elements + else: + members.append((False, val)) + return SequenceValue(tuple, members) return SequenceIncompleteValue(tuple, args_vals) elif is_instance_of_typing_name(val, "_TypedDictMeta"): required_keys = getattr(val, "__required_keys__", None) @@ -434,7 +461,14 @@ def _type_from_runtime(val: Any, ctx: Context, is_typeddict: bool = False) -> Va args = typing_inspect.get_args(val) if getattr(val, "_special", False): args = [] # distinguish List from List[T] on 3.7 and 3.8 - return _value_of_origin_args(origin, args, val, ctx, is_typeddict=is_typeddict) + return _value_of_origin_args( + origin, + args, + val, + ctx, + is_typeddict=is_typeddict, + unpack_allowed=unpack_allowed or origin is tuple or origin is Tuple, + ) elif typing_inspect.is_callable_type(val): args = typing_inspect.get_args(val) return _value_of_origin_args(Callable, args, val, ctx) @@ -535,6 +569,13 @@ def _type_from_runtime(val: Any, ctx: Context, is_typeddict: bool = False) -> Va cls = "Required" if required else "NotRequired" ctx.show_error(f"{cls}[] used in unsupported context") return AnyValue(AnySource.error) + # Also 3.6 only. + elif is_instance_of_typing_name(val, "_Unpack"): + if unpack_allowed: + return _make_unpacked_value(_type_from_runtime(val.__type__, ctx), ctx) + else: + ctx.show_error("Unpack[] used in unsupported context") + return AnyValue(AnySource.error) elif is_typing_name(val, "TypeAlias"): return AnyValue(AnySource.incomplete_annotation) elif is_typing_name(val, "TypedDict"): @@ -638,28 +679,51 @@ def _get_typeddict_value( return required, val -def _eval_forward_ref(val: str, ctx: Context, is_typeddict: bool = False) -> Value: +def _eval_forward_ref( + val: str, ctx: Context, *, is_typeddict: bool = False, unpack_allowed: bool = False +) -> Value: try: tree = ast.parse(val, mode="eval") except SyntaxError: ctx.show_error(f"Syntax error in type annotation: {val}") return AnyValue(AnySource.error) else: - return _type_from_ast(tree.body, ctx, is_typeddict=is_typeddict) + return _type_from_ast( + tree.body, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed + ) -def _type_from_value(value: Value, ctx: Context, is_typeddict: bool = False) -> Value: +def _type_from_value( + value: Value, + ctx: Context, + *, + is_typeddict: bool = False, + unpack_allowed: bool = False, +) -> Value: if isinstance(value, KnownValue): - return _type_from_runtime(value.val, ctx, is_typeddict=is_typeddict) + return _type_from_runtime( + value.val, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed + ) elif isinstance(value, TypeVarValue): return value elif isinstance(value, MultiValuedValue): - return unite_values(*[_type_from_value(val, ctx) for val in value.vals]) + return unite_values( + *[ + _type_from_value( + val, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed + ) + for val in value.vals + ] + ) elif isinstance(value, AnnotatedValue): return _type_from_value(value.value, ctx) elif isinstance(value, _SubscriptedValue): return _type_from_subscripted_value( - value.root, value.members, ctx, is_typeddict=is_typeddict + value.root, + value.members, + ctx, + is_typeddict=is_typeddict, + unpack_allowed=unpack_allowed, ) elif isinstance(value, AnyValue): return value @@ -677,7 +741,9 @@ def _type_from_subscripted_value( root: Optional[Value], members: Sequence[Value], ctx: Context, + *, is_typeddict: bool = False, + unpack_allowed: bool = False, ) -> Value: if isinstance(root, GenericValue): if len(root.args) == len(members): @@ -690,7 +756,13 @@ def _type_from_subscripted_value( elif isinstance(root, MultiValuedValue): return unite_values( *[ - _type_from_subscripted_value(subval, members, ctx, is_typeddict) + _type_from_subscripted_value( + subval, + members, + ctx, + is_typeddict=is_typeddict, + unpack_allowed=unpack_allowed, + ) for subval in root.vals ] ) @@ -729,9 +801,16 @@ def _type_from_subscripted_value( elif len(members) == 1 and members[0] == KnownValue(()): return SequenceIncompleteValue(tuple, []) else: - return SequenceIncompleteValue( - tuple, [_type_from_value(arg, ctx) for arg in members] - ) + args = [_type_from_value(arg, ctx, unpack_allowed=True) for arg in members] + if any(isinstance(val, UnpackedValue) for val in args): + tuple_members = [] + for val in args: + if isinstance(val, UnpackedValue): + tuple_members += val.elements + else: + tuple_members.append((False, val)) + return SequenceValue(tuple, tuple_members) + return SequenceIncompleteValue(tuple, args) elif root is typing.Optional: if len(members) != 1: ctx.show_error("Optional[] takes only one argument") @@ -769,6 +848,14 @@ def _type_from_subscripted_value( ctx.show_error("NotRequired[] requires a single argument") return AnyValue(AnySource.error) return Pep655Value(False, _type_from_value(members[0], ctx)) + elif is_typing_name(root, "Unpack"): + if not unpack_allowed: + ctx.show_error("Unpack[] used in unsupported context") + return AnyValue(AnySource.error) + if len(members) != 1: + ctx.show_error("Unpack requires a single argument") + return AnyValue(AnySource.error) + return _make_unpacked_value(_type_from_value(members[0], ctx), ctx) elif root is Callable or root is typing.Callable: if len(members) == 2: args, return_value = members @@ -877,6 +964,11 @@ class Pep655Value(Value): value: Value +@dataclass +class UnpackedValue(Value): + elements: Sequence[Tuple[bool, Value]] + + class _Visitor(ast.NodeVisitor): def __init__(self, ctx: Context) -> None: self.ctx = ctx @@ -892,6 +984,12 @@ def visit_Subscript(self, node: ast.Subscript) -> Value: index = self.visit(node.slice) if isinstance(index, SequenceIncompleteValue): members = index.members + elif isinstance(index, SequenceValue): + members = index.get_member_sequence() + if members is None: + # TODO support unpacking here + return AnyValue(AnySource.inference) + members = tuple(members) else: members = (index,) return _SubscriptedValue(value, members) @@ -1047,7 +1145,9 @@ def _value_of_origin_args( args: Sequence[object], val: object, ctx: Context, + *, is_typeddict: bool = False, + unpack_allowed: bool = False, ) -> Value: if origin is typing.Type or origin is type: if not args: @@ -1061,7 +1161,9 @@ def _value_of_origin_args( elif len(args) == 1 and args[0] == (): return SequenceIncompleteValue(tuple, []) else: - args_vals = [_type_from_runtime(arg, ctx) for arg in args] + args_vals = [ + _type_from_runtime(arg, ctx, unpack_allowed=True) for arg in args + ] return SequenceIncompleteValue(tuple, args_vals) elif origin is typing.Union: return unite_values(*[_type_from_runtime(arg, ctx) for arg in args]) @@ -1126,6 +1228,14 @@ def _value_of_origin_args( ctx.show_error("NotRequired[] requires a single argument") return AnyValue(AnySource.error) return Pep655Value(False, _type_from_runtime(args[0], ctx)) + elif is_typing_name(origin, "Unpack"): + if not unpack_allowed: + ctx.show_error("Invalid usage of Unpack") + return AnyValue(AnySource.error) + if len(args) != 1: + ctx.show_error("Unpack requires a single argument") + return AnyValue(AnySource.error) + return _make_unpacked_value(_type_from_runtime(args[0], ctx), ctx) elif origin is None and isinstance(val, type): # This happens for SupportsInt in 3.7. return _maybe_typed_value(val) @@ -1144,6 +1254,19 @@ def _maybe_typed_value(val: Union[type, str]) -> Value: return TypedValue(val) +def _make_unpacked_value(val: Value, ctx: Context) -> UnpackedValue: + if isinstance(val, SequenceValue) and val.typ is tuple: + return UnpackedValue(val.members) + elif isinstance(val, SequenceIncompleteValue) and val.typ is tuple: + return UnpackedValue([(False, elt) for elt in val.members]) + elif isinstance(val, GenericValue) and val.typ is tuple: + return UnpackedValue([(True, val.args[0])]) + elif isinstance(val, TypedValue) and val.typ is tuple: + return UnpackedValue([(True, AnyValue(AnySource.generic_argument))]) + ctx.show_error(f"Invalid argument for Unpack: {val}") + return UnpackedValue([]) + + def _make_callable_from_value( args: Value, return_value: Value, ctx: Context, is_asynq: bool = False ) -> Value: diff --git a/pyanalyze/boolability.py b/pyanalyze/boolability.py index ae6d0ffd..1552f07e 100644 --- a/pyanalyze/boolability.py +++ b/pyanalyze/boolability.py @@ -20,6 +20,7 @@ KnownValue, MultiValuedValue, SequenceIncompleteValue, + SequenceValue, SubclassValue, TypedDictValue, TypedValue, @@ -125,6 +126,23 @@ def _get_boolability_no_mvv(value: Value) -> Boolability: return Boolability.value_always_true_mutable else: return Boolability.value_always_false_mutable + elif isinstance(value, SequenceValue): + if not value.members: + if value.typ is tuple: + return Boolability.value_always_false + else: + return Boolability.value_always_false_mutable + may_be_empty = all(is_many for is_many, _ in value.members) + if may_be_empty: + return Boolability.boolable + if value.typ is tuple: + # We lie slightly here, since at the type level a tuple + # may be false. But tuples are a common source of boolability + # bugs and they're rarely mutated, so we put a stronger + # condition on them. + return Boolability.type_always_true + else: + return Boolability.value_always_true_mutable elif isinstance(value, DictIncompleteValue): if any(pair.is_required and not pair.is_many for pair in value.kv_pairs): return Boolability.value_always_true_mutable diff --git a/pyanalyze/format_strings.py b/pyanalyze/format_strings.py index 71c91429..84c9d821 100644 --- a/pyanalyze/format_strings.py +++ b/pyanalyze/format_strings.py @@ -30,6 +30,7 @@ KnownValue, DictIncompleteValue, SequenceIncompleteValue, + SequenceValue, TypedValue, Value, flatten_values, @@ -370,6 +371,10 @@ def accept_tuple_args_no_mvv( args = replace_known_sequence_value(args) if isinstance(args, SequenceIncompleteValue): all_args = args.members + elif isinstance(args, SequenceValue): + all_args = args.get_member_sequence() + if all_args is None: + return else: # it's a tuple but we don't know what's in it, so assume it's ok return diff --git a/pyanalyze/implementation.py b/pyanalyze/implementation.py index ab68023d..a88ed59d 100644 --- a/pyanalyze/implementation.py +++ b/pyanalyze/implementation.py @@ -35,6 +35,7 @@ HasAttrGuardExtension, KVPair, ParameterTypeGuardExtension, + SequenceValue, TypeVarValue, TypedValue, SubclassValue, @@ -363,6 +364,14 @@ def _list_append_impl(ctx: CallContext) -> ImplReturn: SequenceIncompleteValue.make_or_known(list, (*lst.members, element)), ) return ImplReturn(KnownValue(None), no_return_unless=no_return_unless) + elif isinstance(lst, SequenceValue): + no_return_unless = Constraint( + varname, + ConstraintType.is_value_object, + True, + SequenceValue.make_or_known(list, (*lst.members, (False, element))), + ) + return ImplReturn(KnownValue(None), no_return_unless=no_return_unless) elif isinstance(lst, GenericValue): return _maybe_broaden_weak_type( "list.append", @@ -390,7 +399,44 @@ def inner(key: Value) -> Value: if isinstance(key, KnownValue): if isinstance(key.val, int): - if isinstance(self_value, SequenceIncompleteValue): + if isinstance(self_value, SequenceValue): + members = self_value.get_member_sequence() + if members is not None: + if -len(members) <= key.val < len(members): + return members[key.val] + elif typ is list: + # fall back to the common type + return self_value.args[0] + else: + ctx.show_error(f"Tuple index out of range: {key}") + return AnyValue(AnySource.error) + else: + # The value contains at least one unpack. We try to find a precise + # type if everything leading up to the index we're interested in is + # a single element. For example, given a T: tuple[int, *tuple[str, ...]], + # T[0] should be int, but T[-1] should be int | str, because + # the unpacked tuple may be empty. For T[1] we could infer str, but + # we just infer int | str for simplicity. + if key.val >= 0: + for i, (is_many, member) in enumerate(self_value.members): + if is_many: + # Give up + break + if i == key.val: + return member + else: + index_from_back = -key.val + 1 + for i, (is_many, member) in enumerate( + reversed(self_value.members) + ): + if is_many: + # Give up + break + if i == index_from_back: + return member + # fall back to the common type + return self_value.args[0] + elif isinstance(self_value, SequenceIncompleteValue): if -len(self_value.members) <= key.val < len(self_value.members): return self_value.members[key.val] elif typ is list: @@ -402,7 +448,17 @@ def inner(key: Value) -> Value: else: return self_value.get_generic_arg_for_type(typ, ctx.visitor, 0) elif isinstance(key.val, slice): - if isinstance(self_value, SequenceIncompleteValue): + if isinstance(self_value, SequenceValue): + members = self_value.get_member_sequence() + if members is not None: + return SequenceValue.make_or_known( + typ, [(False, m) for m in members[key.val]] + ) + else: + # If the value contains unpacked values, we don't attempt + # to resolve the slice. + return GenericValue(typ, self_value.args) + elif isinstance(self_value, SequenceIncompleteValue): return SequenceIncompleteValue.make_or_known( list, self_value.members[key.val] ) @@ -864,7 +920,9 @@ def _list_add_impl(ctx: CallContext) -> ImplReturn: def inner(left: Value, right: Value) -> Value: left = replace_known_sequence_value(left) right = replace_known_sequence_value(right) - if isinstance(left, SequenceIncompleteValue) and isinstance( + if isinstance(left, SequenceValue) and isinstance(right, SequenceValue): + return SequenceValue.make_or_known(list, [*left.members, *right.members]) + elif isinstance(left, SequenceIncompleteValue) and isinstance( right, SequenceIncompleteValue ): return SequenceIncompleteValue.make_or_known( @@ -911,6 +969,33 @@ def inner(lst: Value, iterable: Value) -> ImplReturn: varname, ConstraintType.is_value_object, True, constrained_value ) return ImplReturn(KnownValue(None), no_return_unless=no_return_unless) + elif isinstance(cleaned_lst, SequenceValue): + if isinstance(iterable, SequenceIncompleteValue): + constrained_value = SequenceValue.make_or_known( + list, + (*cleaned_lst.members, *[(False, m) for m in iterable.members]), + ) + elif isinstance(iterable, SequenceValue): + constrained_value = SequenceValue.make_or_known( + list, (*cleaned_lst.members, *iterable.members) + ) + else: + if isinstance(iterable, TypedValue): + arg_type = iterable.get_generic_arg_for_type( + collections.abc.Iterable, ctx.visitor, 0 + ) + else: + arg_type = AnyValue(AnySource.generic_argument) + constrained_value = SequenceValue( + list, [*cleaned_lst.members, (True, arg_type)] + ) + if return_container: + return ImplReturn(constrained_value) + if varname is not None: + no_return_unless = Constraint( + varname, ConstraintType.is_value_object, True, constrained_value + ) + return ImplReturn(KnownValue(None), no_return_unless=no_return_unless) elif ( varname is not None and isinstance(cleaned_lst, GenericValue) @@ -998,6 +1083,16 @@ def _set_add_impl(ctx: CallContext) -> ImplReturn: ), ) return ImplReturn(KnownValue(None), no_return_unless=no_return_unless) + elif isinstance(set_value, SequenceValue): + no_return_unless = Constraint( + varname, + ConstraintType.is_value_object, + True, + SequenceValue.make_or_known( + set, (*set_value.members, (False, element)) + ), + ) + return ImplReturn(KnownValue(None), no_return_unless=no_return_unless) elif isinstance(set_value, GenericValue): return _maybe_broaden_weak_type( "set.add", @@ -1086,9 +1181,14 @@ def _str_format_impl(ctx: CallContext) -> Value: if not isinstance(self, KnownValue): return TypedValue(str) args_value = replace_known_sequence_value(ctx.vars["args"]) - if not isinstance(args_value, SequenceIncompleteValue): + if isinstance(args_value, SequenceIncompleteValue): + args = args_value.members + elif isinstance(args_value, SequenceValue): + args = args_value.get_member_sequence() + if args is None: + return TypedValue(str) + else: return TypedValue(str) - args = args_value.members kwargs_value = replace_known_sequence_value(ctx.vars["kwargs"]) kwargs = {} if isinstance(kwargs_value, DictIncompleteValue): @@ -1219,7 +1319,15 @@ def len_of_value(val: Value) -> Value: and not issubclass(val.typ, KNOWN_MUTABLE_TYPES) ): return KnownValue(len(val.members)) - elif isinstance(val, KnownValue): + if ( + isinstance(val, SequenceValue) + and isinstance(val.typ, type) + and not issubclass(val.typ, KNOWN_MUTABLE_TYPES) + ): + members = val.get_member_sequence() + if members is not None: + return KnownValue(len(members)) + if isinstance(val, KnownValue): try: if not isinstance(val.val, KNOWN_MUTABLE_TYPES): return KnownValue(len(val.val)) diff --git a/pyanalyze/name_check_visitor.py b/pyanalyze/name_check_visitor.py index 8a931a70..96e0771c 100644 --- a/pyanalyze/name_check_visitor.py +++ b/pyanalyze/name_check_visitor.py @@ -166,6 +166,7 @@ UNINITIALIZED_VALUE, NO_RETURN_VALUE, NoReturnConstraintExtension, + SequenceValue, annotate_value, check_hashability, flatten_values, @@ -3219,6 +3220,12 @@ def _unwrap_yield_result(self, node: ast.AST, value: Value) -> Value: self._unwrap_yield_result(node, member) for member in value.members ] return self._maybe_make_sequence(value.typ, values, node) + elif isinstance(value, SequenceValue) and isinstance(value.typ, type): + values = [ + (is_many, self._unwrap_yield_result(node, member)) + for is_many, member in value.members + ] + return SequenceValue.make_or_known(value.typ, values) elif isinstance(value, GenericValue): member_value = self._unwrap_yield_result(node, value.get_arg(0)) return GenericValue(value.typ, [member_value]) diff --git a/pyanalyze/patma.py b/pyanalyze/patma.py index c7509a65..ef6a43f5 100644 --- a/pyanalyze/patma.py +++ b/pyanalyze/patma.py @@ -51,6 +51,7 @@ DictIncompleteValue, KVPair, SequenceIncompleteValue, + SequenceValue, SubclassValue, TypedValue, Value, @@ -478,18 +479,34 @@ def get_match_args( if match_args_value is UNINITIALIZED_VALUE: return CanAssignError(f"{cls} has no attribute __match_args__") match_args_value = replace_known_sequence_value(match_args_value) - if ( - not isinstance(match_args_value, SequenceIncompleteValue) - or match_args_value.typ is not tuple - ): - return CanAssignError( - f"__match_args__ must be a literal tuple, not {match_args_value}" - ) - match_args = [] - for i, arg in enumerate(match_args_value.members): - if not isinstance(arg, KnownValue) or not isinstance(arg.val, str): + if isinstance(match_args_value, SequenceIncompleteValue): + if match_args_value.typ is not tuple: + return CanAssignError( + f"__match_args__ must be a literal tuple, not {match_args_value}" + ) + match_args = [] + for i, arg in enumerate(match_args_value.members): + if not isinstance(arg, KnownValue) or not isinstance(arg.val, str): + return CanAssignError( + f"__match_args__ element {i} is {arg}, not a string literal" + ) + match_args.append(arg.val) + return match_args + elif isinstance(match_args_value, SequenceValue): + if match_args_value.typ is not tuple: return CanAssignError( - f"__match_args__ element {i} is {arg}, not a string literal" + f"__match_args__ must be a literal tuple, not {match_args_value}" ) - match_args.append(arg.val) - return match_args + match_args = [] + for i, (is_many, arg) in enumerate(match_args_value.members): + if is_many: + return CanAssignError("Cannot use unpacking in __match_args__") + if not isinstance(arg, KnownValue) or not isinstance(arg.val, str): + return CanAssignError( + f"__match_args__ element {i} is {arg}, not a string literal" + ) + match_args.append(arg.val) + return match_args + return CanAssignError( + f"__match_args__ must be a literal tuple, not {match_args_value}" + ) diff --git a/pyanalyze/signature.py b/pyanalyze/signature.py index 22b94dce..c5e4bcc2 100644 --- a/pyanalyze/signature.py +++ b/pyanalyze/signature.py @@ -44,6 +44,7 @@ ParameterTypeGuardExtension, SequenceIncompleteValue, DictIncompleteValue, + SequenceValue, TypeGuardExtension, TypeVarValue, TypedDictValue, @@ -2351,6 +2352,26 @@ def can_assign_var_positional( ) -> Union[List[BoundsMap], CanAssignError]: bounds_maps = [] my_annotation = my_param.get_annotation() + if isinstance(args_annotation, SequenceValue): + members = args_annotation.get_member_sequence() + if members is not None: + length = len(members) + if idx >= length: + return CanAssignError( + f"parameter {my_param.name!r} is not accepted;" + f" {args_annotation} only accepts {length} values" + ) + their_annotation = members[idx] + can_assign = their_annotation.can_assign(my_annotation, ctx) + if isinstance(can_assign, CanAssignError): + return CanAssignError( + f"type of parameter {my_param.name!r} is incompatible: *args[{idx}]" + " type is incompatible", + [can_assign], + ) + bounds_maps.append(can_assign) + return bounds_maps + if isinstance(args_annotation, SequenceIncompleteValue): length = len(args_annotation.members) if idx >= length: diff --git a/pyanalyze/suggested_type.py b/pyanalyze/suggested_type.py index 64fd18e5..35a72164 100644 --- a/pyanalyze/suggested_type.py +++ b/pyanalyze/suggested_type.py @@ -22,6 +22,7 @@ GenericValue, KnownValue, SequenceIncompleteValue, + SequenceValue, SubclassValue, TypedDictValue, TypedValue, @@ -162,6 +163,14 @@ def prepare_type(value: Value) -> Value: ) else: return GenericValue(value.typ, [prepare_type(arg) for arg in value.args]) + elif isinstance(value, SequenceValue): + if value.typ is tuple: + members = value.get_member_sequence() + if members is not None: + return SequenceValue( + tuple, [(False, prepare_type(elt)) for elt in members] + ) + return GenericValue(value.typ, [prepare_type(arg) for arg in value.args]) elif isinstance(value, (TypedDictValue, CallableValue)): return value elif isinstance(value, GenericValue): diff --git a/pyanalyze/test_annotations.py b/pyanalyze/test_annotations.py index 752b8f97..7d20736c 100644 --- a/pyanalyze/test_annotations.py +++ b/pyanalyze/test_annotations.py @@ -12,6 +12,7 @@ MultiValuedValue, NewTypeValue, SequenceIncompleteValue, + SequenceValue, TypeVarValue, TypedDictValue, TypedValue, @@ -1731,3 +1732,27 @@ def capybara(x: X, y: Y, x_quoted: "X", y_quoted: "Y", z: Z) -> None: assert_is_value(x_quoted, TypedValue(int)) assert_is_value(y_quoted, TypedValue(int)) assert_is_value(z, TypedValue(int)) + + +class TestUnpack(TestNameCheckVisitorBase): + @assert_passes() + def test_in_tuple(self): + from typing_extensions import Unpack + from typing import Tuple + + def capybara( + x: Tuple[int, Unpack[Tuple[str, ...]]], + y: "Tuple[int, Unpack[Tuple[str, ...]]]", + ): + assert_is_value( + x, + SequenceValue( + tuple, [(False, TypedValue(int)), (True, TypedValue(str))] + ), + ) + assert_is_value( + y, + SequenceValue( + tuple, [(False, TypedValue(int)), (True, TypedValue(str))] + ), + ) diff --git a/pyanalyze/test_boolability.py b/pyanalyze/test_boolability.py index 2d696825..bcb4d2b4 100644 --- a/pyanalyze/test_boolability.py +++ b/pyanalyze/test_boolability.py @@ -12,6 +12,7 @@ KVPair, KnownValue, SequenceIncompleteValue, + SequenceValue, TypedDictValue, UnboundMethodValue, TypedValue, @@ -61,6 +62,31 @@ def test_get_boolability() -> None: assert Boolability.value_always_false_mutable == get_boolability( SequenceIncompleteValue(list, []) ) + assert Boolability.type_always_true == get_boolability( + SequenceValue(tuple, [(False, KnownValue(1))]) + ) + assert Boolability.value_always_false == get_boolability(SequenceValue(tuple, [])) + assert Boolability.boolable == get_boolability( + SequenceValue(tuple, [(True, KnownValue(1))]) + ) + assert Boolability.type_always_true == get_boolability( + # many 1s followed by one 2 + SequenceValue(tuple, [(True, KnownValue(1)), (False, KnownValue(2))]) + ) + assert Boolability.value_always_true_mutable == get_boolability( + SequenceValue(list, [(False, KnownValue(1))]) + ) + assert Boolability.value_always_false_mutable == get_boolability( + SequenceValue(list, []) + ) + assert Boolability.boolable == get_boolability( + SequenceValue(list, [(True, KnownValue(1))]) + ) + assert Boolability.value_always_true_mutable == get_boolability( + # many 1s followed by one 2 + SequenceValue(list, [(True, KnownValue(1)), (False, KnownValue(2))]) + ) + assert Boolability.value_always_true_mutable == get_boolability( DictIncompleteValue(dict, [KVPair(KnownValue(1), KnownValue(1))]) ) diff --git a/pyanalyze/test_name_check_visitor.py b/pyanalyze/test_name_check_visitor.py index d0638221..73e72b96 100644 --- a/pyanalyze/test_name_check_visitor.py +++ b/pyanalyze/test_name_check_visitor.py @@ -28,6 +28,7 @@ NewTypeValue, ReferencingValue, SequenceIncompleteValue, + SequenceValue, TypedValue, TypeVarValue, UnboundMethodValue, @@ -130,6 +131,7 @@ def _make_module(code_str: str) -> types.ModuleType: KnownValue=KnownValue, MultiValuedValue=MultiValuedValue, AnnotatedValue=AnnotatedValue, + SequenceValue=SequenceValue, SequenceIncompleteValue=SequenceIncompleteValue, TypedValue=TypedValue, UnboundMethodValue=UnboundMethodValue, diff --git a/pyanalyze/test_value.py b/pyanalyze/test_value.py index 82bfb097..b878400d 100644 --- a/pyanalyze/test_value.py +++ b/pyanalyze/test_value.py @@ -23,6 +23,7 @@ CallableValue, CanAssignError, KVPair, + SequenceValue, Value, GenericValue, KnownValue, @@ -33,6 +34,7 @@ TypeVarMap, concrete_values_from_iterable, unite_and_simplify, + unpack_values, ) _checker = Checker() @@ -576,3 +578,65 @@ def test_unite_and_simplify() -> None: assert unite_and_simplify(*vals, limit=2) == GenericValue( list, [TypedValue(int)] ) | GenericValue(list, [AnyValue(AnySource.unreachable)]) + + +def test_unpack_values() -> None: + t_int = SequenceValue(tuple, [(False, TypedValue(int))]) + assert unpack_values(t_int, CTX, 1, None) == [TypedValue(int)] + assert unpack_values(t_int, CTX, 1, 0) == [TypedValue(int), SequenceValue(list, [])] + assert isinstance(unpack_values(t_int, CTX, 1, 1), CanAssignError) + assert isinstance(unpack_values(t_int, CTX, 2, None), CanAssignError) + + t_int_str = SequenceValue( + tuple, [(False, TypedValue(int)), (False, TypedValue(str))] + ) + assert isinstance(unpack_values(t_int_str, CTX, 1, None), CanAssignError) + assert unpack_values(t_int_str, CTX, 2, None) == [TypedValue(int), TypedValue(str)] + assert unpack_values(t_int_str, CTX, 2, 0) == [ + TypedValue(int), + TypedValue(str), + SequenceValue(list, []), + ] + assert unpack_values(t_int_str, CTX, 1, 1) == [ + TypedValue(int), + SequenceValue(list, []), + TypedValue(str), + ] + + t_int_star_str = SequenceValue( + tuple, [(False, TypedValue(int)), (True, TypedValue(str))] + ) + assert unpack_values(t_int_star_str, CTX, 1, None) == [TypedValue(int)] + assert unpack_values(t_int_star_str, CTX, 1, 0) == [ + TypedValue(int), + SequenceValue(list, [(True, TypedValue(str))]), + ] + assert unpack_values(t_int_star_str, CTX, 1, 1) == [ + TypedValue(int), + GenericValue(list, [TypedValue(str)]), + TypedValue(str), + ] + assert unpack_values(t_int_star_str, CTX, 2, None) == [ + TypedValue(int), + TypedValue(str), + ] + + t_int_star_str_float = SequenceValue( + tuple, + [(False, TypedValue(int)), (True, TypedValue(str)), (False, TypedValue(float))], + ) + assert isinstance(unpack_values(t_int_star_str_float, CTX, 1, None), CanAssignError) + assert unpack_values(t_int_star_str_float, CTX, 2, None) == [ + TypedValue(int), + TypedValue(float), + ] + assert unpack_values(t_int_star_str_float, CTX, 1, 1) == [ + TypedValue(int), + SequenceValue(list, [(True, TypedValue(str))]), + TypedValue(float), + ] + assert unpack_values(t_int_star_str_float, CTX, 0, 2) == [ + GenericValue(list, [TypedValue(str) | TypedValue(int)]), + TypedValue(int) | TypedValue(str), + TypedValue(float), + ] diff --git a/pyanalyze/type_evaluation.py b/pyanalyze/type_evaluation.py index dca822b1..5461c0cb 100644 --- a/pyanalyze/type_evaluation.py +++ b/pyanalyze/type_evaluation.py @@ -43,6 +43,7 @@ KnownValue, MultiValuedValue, SequenceIncompleteValue, + SequenceValue, Value, flatten_values, unannotate, @@ -607,6 +608,8 @@ def evaluate_literal(self, node: ast.expr) -> Optional[KnownValue]: and all_of_type(val.members, KnownValue) ): val = KnownValue(val.typ(elt.val for elt in val.members)) + if isinstance(val, SequenceValue): + val = val.make_known_value() if isinstance(val, KnownValue): return val self.errors.append(InvalidEvaluation("Only literals supported", node)) diff --git a/pyanalyze/value.py b/pyanalyze/value.py index e19b9973..fc8c9937 100644 --- a/pyanalyze/value.py +++ b/pyanalyze/value.py @@ -885,6 +885,173 @@ def simplify(self) -> Value: return GenericValue(self.typ, [arg.simplify() for arg in self.args]) +@dataclass(unsafe_hash=True, init=False) +class SequenceValue(GenericValue): + """A :class:`TypedValue` subclass representing a sequence of known type. + + This is represented as a sequence, but each entry in the sequence may + consist of multiple values. + For example, the expression ``[int(self.foo)]`` may be typed as + ``SequenceValue(list, [(False, TypedValue(int))])``. The expression + ``["x", *some_str.split()]`` would be represented as + ``SequenceValue(list, [(False, KnownValue("x")), (True, TypedValue(str))])``. + + This is only used for ``set``, ``list``, and ``tuple``. + + """ + + members: Tuple[Tuple[bool, Value], ...] + """The elements of the sequence.""" + + def __init__( + self, typ: Union[type, str], members: Sequence[Tuple[bool, Value]] + ) -> None: + if members: + args = (unite_values(*[typ for _, typ in members]),) + else: + args = (AnyValue(AnySource.unreachable),) + super().__init__(typ, args) + self.members = tuple(members) + + def get_member_sequence(self) -> Optional[Sequence[Value]]: + """Return the :class:`Value` objects in this sequence. Return + None if there are any unpacked values in the sequence.""" + members = [] + for is_many, member in self.members: + if is_many: + return None + members.append(member) + return members + + def make_known_value(self) -> Value: + """Turn this value into a KnownValue if possible.""" + if isinstance(self.typ, str): + return self + return self.make_or_known(self.typ, self.members) + + @classmethod + def make_or_known( + cls, typ: type, members: Sequence[Tuple[bool, Value]] + ) -> Union[KnownValue, "SequenceValue"]: + known_members = [] + for is_many, member in members: + if is_many or not isinstance(member, KnownValue): + return SequenceValue(typ, members) + known_members.append(member.val) + return KnownValue(typ(known_members)) + + def can_assign(self, other: Value, ctx: CanAssignContext) -> CanAssign: + if isinstance(other, SequenceIncompleteValue): + can_assign = self.get_type_object(ctx).can_assign(self, other, ctx) + if isinstance(can_assign, CanAssignError): + return CanAssignError( + f"Cannot assign {stringify_object(other.typ)} to" + f" {stringify_object(self.typ)}" + ) + my_len = len(self.members) + their_len = len(other.members) + if my_len != their_len: + type_str = stringify_object(self.typ) + return CanAssignError( + f"Cannot assign {type_str} of length {their_len} to {type_str} of" + f" length {my_len}" + ) + if my_len == 0: + return {} # they're both empty + bounds_maps = [can_assign] + for i, ((is_many, my_member), their_member) in enumerate( + zip(self.members, other.members) + ): + if is_many: + return CanAssignError( + f"Member {i} is an unpacked type, but a non-unpacked type is" + " provided" + ) + can_assign = my_member.can_assign(their_member, ctx) + if isinstance(can_assign, CanAssignError): + return CanAssignError( + f"Types for member {i} are incompatible", [can_assign] + ) + bounds_maps.append(can_assign) + return unify_bounds_maps(bounds_maps) + elif isinstance(other, SequenceValue): + can_assign = self.get_type_object(ctx).can_assign(self, other, ctx) + if isinstance(can_assign, CanAssignError): + return CanAssignError( + f"Cannot assign {stringify_object(other.typ)} to" + f" {stringify_object(self.typ)}" + ) + my_len = len(self.members) + their_len = len(other.members) + if my_len != their_len: + type_str = stringify_object(self.typ) + return CanAssignError( + f"Cannot assign {type_str} of length {their_len} to {type_str} of" + f" length {my_len}" + ) + if my_len == 0: + return {} # they're both empty + bounds_maps = [can_assign] + for i, ( + (my_is_many, my_member), + (their_is_many, their_member), + ) in enumerate(zip(self.members, other.members)): + if my_is_many != their_is_many: + if my_is_many: + return CanAssignError( + f"Member {i} is an unpacked type, but a single element is" + " provided" + ) + else: + return CanAssignError( + f"Member {i} is a single element, but an unpacked type is" + " provided" + ) + can_assign = my_member.can_assign(their_member, ctx) + if isinstance(can_assign, CanAssignError): + return CanAssignError( + f"Types for member {i} are incompatible", [can_assign] + ) + bounds_maps.append(can_assign) + return unify_bounds_maps(bounds_maps) + return super().can_assign(other, ctx) + + def substitute_typevars(self, typevars: TypeVarMap) -> Value: + return SequenceValue( + self.typ, + [ + (is_many, member.substitute_typevars(typevars)) + for is_many, member in self.members + ], + ) + + def __str__(self) -> str: + members = ", ".join( + (f"*tuple[{m}, ...]" if is_many else str(m)) for is_many, m in self.members + ) + if self.typ is tuple: + return f"tuple[{members}]" + return f"<{stringify_object(self.typ)} containing [{members}]>" + + def walk_values(self) -> Iterable[Value]: + yield self + for _, member in self.members: + yield from member.walk_values() + + def simplify(self) -> GenericValue: + if self.typ is tuple: + return SequenceValue( + tuple, + [(is_many, member.simplify()) for is_many, member in self.members], + ) + members = [member.simplify() for _, member in self.members] + arg = unite_values(*members) + if arg is NO_RETURN_VALUE: + arg = AnyValue(AnySource.unreachable) + return GenericValue(self.typ, [arg]) + + +# TODO(jelle): Replace with SequenceValue @dataclass(unsafe_hash=True, init=False) class SequenceIncompleteValue(GenericValue): """A :class:`TypedValue` subclass representing a sequence of known type and length. @@ -894,6 +1061,8 @@ class SequenceIncompleteValue(GenericValue): This is only used for ``set``, ``list``, and ``tuple``. + This type is being phased out in favor of :class:`SequenceValue`. + """ members: Tuple[Value, ...] @@ -945,6 +1114,39 @@ def can_assign(self, other: Value, ctx: CanAssignContext) -> CanAssign: ) bounds_maps.append(can_assign) return unify_bounds_maps(bounds_maps) + elif isinstance(other, SequenceValue): + can_assign = self.get_type_object(ctx).can_assign(self, other, ctx) + if isinstance(can_assign, CanAssignError): + return CanAssignError( + f"Cannot assign {stringify_object(other.typ)} to" + f" {stringify_object(self.typ)}" + ) + my_len = len(self.members) + their_len = len(other.members) + if my_len != their_len: + type_str = stringify_object(self.typ) + return CanAssignError( + f"Cannot assign {type_str} of length {their_len} to {type_str} of" + f" length {my_len}" + ) + if my_len == 0: + return {} # they're both empty + bounds_maps = [can_assign] + for i, (my_member, (is_many, their_member)) in enumerate( + zip(self.members, other.members) + ): + if is_many: + return CanAssignError( + f"Member {i} is a single element, but an unpacked type is" + " provided" + ) + can_assign = my_member.can_assign(their_member, ctx) + if isinstance(can_assign, CanAssignError): + return CanAssignError( + f"Types for member {i} are incompatible", [can_assign] + ) + bounds_maps.append(can_assign) + return unify_bounds_maps(bounds_maps) return super().can_assign(other, ctx) def substitute_typevars(self, typevars: TypeVarMap) -> Value: @@ -2229,6 +2431,11 @@ def concrete_values_from_iterable( return unite_values(*value_subvals, *chain.from_iterable(seq_subvals)) if isinstance(value, SequenceIncompleteValue): return value.members + if isinstance(value, SequenceValue): + members = value.get_member_sequence() + if members is None: + return value.args[0] + return members elif isinstance(value, TypedDictValue): if all(required for required, _ in value.items.items()): return [KnownValue(key) for key in value.items] @@ -2435,6 +2642,13 @@ def unpack_values( ) if not isinstance(vals, CanAssignError): return vals + elif isinstance(value, SequenceValue): + if value.typ is tuple: + return _unpack_sequence_value(value, target_length, post_starred_length) + elif value.typ is list: + vals = _unpack_sequence_value(value, target_length, post_starred_length) + if not isinstance(vals, CanAssignError): + return vals tv_map = get_tv_map(IterableValue, value, ctx) if isinstance(tv_map, CanAssignError): @@ -2456,6 +2670,83 @@ def _create_unpacked_list( return [iterable_type] * target_length +def _unpack_sequence_value( + value: SequenceValue, target_length: int, post_starred_length: Optional[int] +) -> Union[Sequence[Value], CanAssignError]: + head = [] + tail = [] + while len(head) < target_length: + if len(head) >= len(value.members): + return CanAssignError( + f"{value} must have at least {target_length} elements" + ) + is_many, val = value.members[len(head)] + if is_many: + break + head.append(val) + remaining_target_length = target_length - len(head) + if post_starred_length is None: + if remaining_target_length == 0: + if all(is_many for is_many, _ in value.members[target_length:]): + return head + return CanAssignError(f"{value} must have exactly {target_length} elements") + + tail = [] + while len(tail) < remaining_target_length: + if len(tail) + len(head) >= len(value.members): + return CanAssignError( + f"{value} must have at least {target_length} elements" + ) + is_many, val = value.members[-len(tail) - 1] + if is_many: + break + tail.append(val) + + if tail: + remaining_members = value.members[len(head) : -len(tail)] + else: + remaining_members = value.members[len(head) :] + if not remaining_members: + return CanAssignError(f"{value} must have exactly {target_length} elements") + middle_length = remaining_target_length - len(tail) + fallback_value = unite_values(*[val for _, val in remaining_members]) + return [*head, *[fallback_value for _ in range(middle_length)], *tail] + else: + while len(tail) < post_starred_length: + if len(tail) >= len(value.members) - len(head): + return CanAssignError( + f"{value} must have at least" + f" {target_length + post_starred_length} elements" + ) + is_many, val = value.members[-len(tail) - 1] + if is_many: + break + tail.append(val) + remaining_post_starred_length = post_starred_length - len(tail) + + if tail: + remaining_members = value.members[len(head) : -len(tail)] + else: + remaining_members = value.members[len(head) :] + if remaining_target_length != 0 or remaining_post_starred_length != 0: + if not remaining_members: + return CanAssignError( + f"{value} must have at least" + f" {target_length + post_starred_length} elements" + ) + else: + fallback_value = unite_values(*[val for _, val in remaining_members]) + return [ + *head, + *[fallback_value for _ in range(remaining_target_length)], + GenericValue(list, [fallback_value]), + *[fallback_value for _ in range(remaining_post_starred_length)], + *tail, + ] + else: + return [*head, SequenceValue(list, remaining_members), *tail] + + def _unpack_value_sequence( value: Value, members: Sequence[Value],