From 14bdbc54651ea46d4ab3cb20c668064e9122e1f0 Mon Sep 17 00:00:00 2001 From: Vlad Emelianov Date: Sat, 29 Jul 2023 00:43:34 +0300 Subject: [PATCH] Fix method arguments for mypy --- mypy_boto3_builder/parsers/service_package.py | 9 ++++ mypy_boto3_builder/parsers/shape_parser.py | 45 +++++++++++++++++++ .../type_annotations/type_subscript.py | 28 ++++++++++++ tests/type_annotations/test_type_subscript.py | 15 +++++++ 4 files changed, 97 insertions(+) diff --git a/mypy_boto3_builder/parsers/service_package.py b/mypy_boto3_builder/parsers/service_package.py index 2b7a496b..07e96040 100644 --- a/mypy_boto3_builder/parsers/service_package.py +++ b/mypy_boto3_builder/parsers/service_package.py @@ -95,6 +95,15 @@ def parse_service_package( result.client.methods.append(method) shape_parser.fix_typed_dict_names() + + methods = [ + *result.client.methods, + *(result.service_resource.methods if result.service_resource else []), + *[method for paginator in result.paginators for method in paginator.methods], + *[method for waiter in result.waiters for method in waiter.methods], + ] + shape_parser.fix_method_arguments_for_mypy(methods) + result.typed_dicts = result.extract_typed_dicts() result.literals = result.extract_literals() result.validate() diff --git a/mypy_boto3_builder/parsers/shape_parser.py b/mypy_boto3_builder/parsers/shape_parser.py index 207246a4..f4823d97 100644 --- a/mypy_boto3_builder/parsers/shape_parser.py +++ b/mypy_boto3_builder/parsers/shape_parser.py @@ -80,6 +80,7 @@ def __init__(self, session: Session, service_name: ServiceName): self._typed_dict_map: dict[str, TypeTypedDict] = {} self._output_typed_dict_map: dict[str, TypeTypedDict] = {} self._response_typed_dict_map: dict[str, TypeTypedDict] = {} + self._fixed_typed_dict_map: dict[TypeTypedDict, TypeTypedDict] = {} self._waiters_shape: Mapping[str, Any] | None = None with contextlib.suppress(UnknownServiceError): @@ -874,6 +875,7 @@ def fix_typed_dict_names(self) -> None: old_typed_dict_name = typed_dict.name new_typed_dict_name = self._get_non_clashing_typed_dict_name(typed_dict, "Output") + self._fixed_typed_dict_map[typed_dict] = output_typed_dict self.logger.debug( f"Fixing TypedDict name clash {old_typed_dict_name} -> {new_typed_dict_name}" ) @@ -922,3 +924,46 @@ def fix_typed_dict_names(self) -> None: response_typed_dict.name = new_typed_dict_name del self._response_typed_dict_map[old_typed_dict_name] self._response_typed_dict_map[response_typed_dict.name] = response_typed_dict + + def fix_method_arguments_for_mypy(self, methods: Sequence[Method]) -> None: + """ + Accept both input and output shapes in method arguments. + + mypy does not compare TypedDicts, so we need to accept both input and output shapes. + https://github.com/youtype/mypy_boto3_builder/issues/209 + """ + for input_typed_dict, output_typed_dict in self._fixed_typed_dict_map.items(): + for method in methods: + for argument in method.arguments: + if not argument.type_annotation: + continue + if ( + argument.type_annotation.is_typed_dict() + and argument.type_annotation == input_typed_dict + ): + self.logger.debug( + f"Adding output shape to {method.name} {argument.name} type:" + f" {input_typed_dict.name} | {output_typed_dict.name}" + ) + argument.type_annotation = TypeSubscript( + Type.Union, + [input_typed_dict, output_typed_dict], + ) + continue + if isinstance(argument.type_annotation, TypeSubscript): + parent = argument.type_annotation.find_type_annotation_parent( + input_typed_dict + ) + if parent: + self.logger.debug( + f"Adding output shape to {method.name} {argument.name} type:" + f" {input_typed_dict.name} | {output_typed_dict.name}" + ) + parent.replace_child( + input_typed_dict, + TypeSubscript( + Type.Union, + [input_typed_dict, output_typed_dict], + ), + ) + continue diff --git a/mypy_boto3_builder/type_annotations/type_subscript.py b/mypy_boto3_builder/type_annotations/type_subscript.py index dc9b570a..0c31f2a5 100644 --- a/mypy_boto3_builder/type_annotations/type_subscript.py +++ b/mypy_boto3_builder/type_annotations/type_subscript.py @@ -99,3 +99,31 @@ def get_local_types(self) -> list[FakeAnnotation]: for child in self.children: result.extend(child.get_local_types()) return result + + def find_type_annotation_parent( + self, type_annotation: FakeAnnotation + ) -> "TypeSubscript | None": + """ + Check recursively if child is present in subscript. + """ + if type_annotation in self.children: + return self + + type_subscript_children = [i for i in self.children if isinstance(i, TypeSubscript)] + for child in type_subscript_children: + result = child.find_type_annotation_parent(type_annotation) + if result: + return result + + return None + + def replace_child(self: _R, child: FakeAnnotation, new_child: FakeAnnotation) -> _R: + """ + Replace child type annotation with a new one. + """ + if child not in self.children: + raise ValueError(f"Child not found: {child}") + + index = self.children.index(child) + self.children[index] = new_child + return self diff --git a/tests/type_annotations/test_type_subscript.py b/tests/type_annotations/test_type_subscript.py index 573a32a4..a863f0a6 100644 --- a/tests/type_annotations/test_type_subscript.py +++ b/tests/type_annotations/test_type_subscript.py @@ -42,3 +42,18 @@ def test_is_list(self) -> None: def test_copy(self) -> None: assert self.result.copy().parent == Type.Dict + + def test_find_type_annotation_parent(self) -> None: + inner = TypeSubscript(Type.List, [Type.int]) + outer = TypeSubscript(Type.Dict, [Type.str, inner]) + assert outer.find_type_annotation_parent(Type.int) == inner + assert outer.find_type_annotation_parent(Type.str) == outer + assert outer.find_type_annotation_parent(Type.List) is None + + def test_replace_child(self) -> None: + inner = TypeSubscript(Type.List, [Type.int]) + outer = TypeSubscript(Type.Dict, [Type.str, inner]) + assert outer.copy().replace_child(Type.str, Type.bool).render() == "Dict[bool, List[int]]" + assert outer.copy().replace_child(inner, Type.bool).render() == "Dict[str, bool]" + with pytest.raises(ValueError): + outer.copy().replace_child(Type.int, Type.bool)