Skip to content

Commit

Permalink
Fix method arguments for mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
vemel committed Jul 28, 2023
1 parent 3bc1d7e commit 14bdbc5
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 0 deletions.
9 changes: 9 additions & 0 deletions mypy_boto3_builder/parsers/service_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,15 @@ def parse_service_package(
result.client.methods.append(method)

shape_parser.fix_typed_dict_names()

methods = [
*result.client.methods,
*(result.service_resource.methods if result.service_resource else []),
*[method for paginator in result.paginators for method in paginator.methods],
*[method for waiter in result.waiters for method in waiter.methods],
]
shape_parser.fix_method_arguments_for_mypy(methods)

result.typed_dicts = result.extract_typed_dicts()
result.literals = result.extract_literals()
result.validate()
Expand Down
45 changes: 45 additions & 0 deletions mypy_boto3_builder/parsers/shape_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(self, session: Session, service_name: ServiceName):
self._typed_dict_map: dict[str, TypeTypedDict] = {}
self._output_typed_dict_map: dict[str, TypeTypedDict] = {}
self._response_typed_dict_map: dict[str, TypeTypedDict] = {}
self._fixed_typed_dict_map: dict[TypeTypedDict, TypeTypedDict] = {}

self._waiters_shape: Mapping[str, Any] | None = None
with contextlib.suppress(UnknownServiceError):
Expand Down Expand Up @@ -874,6 +875,7 @@ def fix_typed_dict_names(self) -> None:

old_typed_dict_name = typed_dict.name
new_typed_dict_name = self._get_non_clashing_typed_dict_name(typed_dict, "Output")
self._fixed_typed_dict_map[typed_dict] = output_typed_dict
self.logger.debug(
f"Fixing TypedDict name clash {old_typed_dict_name} -> {new_typed_dict_name}"
)
Expand Down Expand Up @@ -922,3 +924,46 @@ def fix_typed_dict_names(self) -> None:
response_typed_dict.name = new_typed_dict_name
del self._response_typed_dict_map[old_typed_dict_name]
self._response_typed_dict_map[response_typed_dict.name] = response_typed_dict

def fix_method_arguments_for_mypy(self, methods: Sequence[Method]) -> None:
"""
Accept both input and output shapes in method arguments.
mypy does not compare TypedDicts, so we need to accept both input and output shapes.
https://github.com/youtype/mypy_boto3_builder/issues/209
"""
for input_typed_dict, output_typed_dict in self._fixed_typed_dict_map.items():
for method in methods:
for argument in method.arguments:
if not argument.type_annotation:
continue
if (
argument.type_annotation.is_typed_dict()
and argument.type_annotation == input_typed_dict
):
self.logger.debug(
f"Adding output shape to {method.name} {argument.name} type:"
f" {input_typed_dict.name} | {output_typed_dict.name}"
)
argument.type_annotation = TypeSubscript(
Type.Union,
[input_typed_dict, output_typed_dict],
)
continue
if isinstance(argument.type_annotation, TypeSubscript):
parent = argument.type_annotation.find_type_annotation_parent(
input_typed_dict
)
if parent:
self.logger.debug(
f"Adding output shape to {method.name} {argument.name} type:"
f" {input_typed_dict.name} | {output_typed_dict.name}"
)
parent.replace_child(
input_typed_dict,
TypeSubscript(
Type.Union,
[input_typed_dict, output_typed_dict],
),
)
continue
28 changes: 28 additions & 0 deletions mypy_boto3_builder/type_annotations/type_subscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,31 @@ def get_local_types(self) -> list[FakeAnnotation]:
for child in self.children:
result.extend(child.get_local_types())
return result

def find_type_annotation_parent(
self, type_annotation: FakeAnnotation
) -> "TypeSubscript | None":
"""
Check recursively if child is present in subscript.
"""
if type_annotation in self.children:
return self

type_subscript_children = [i for i in self.children if isinstance(i, TypeSubscript)]
for child in type_subscript_children:
result = child.find_type_annotation_parent(type_annotation)
if result:
return result

return None

def replace_child(self: _R, child: FakeAnnotation, new_child: FakeAnnotation) -> _R:
"""
Replace child type annotation with a new one.
"""
if child not in self.children:
raise ValueError(f"Child not found: {child}")

index = self.children.index(child)
self.children[index] = new_child
return self
15 changes: 15 additions & 0 deletions tests/type_annotations/test_type_subscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,18 @@ def test_is_list(self) -> None:

def test_copy(self) -> None:
assert self.result.copy().parent == Type.Dict

def test_find_type_annotation_parent(self) -> None:
inner = TypeSubscript(Type.List, [Type.int])
outer = TypeSubscript(Type.Dict, [Type.str, inner])
assert outer.find_type_annotation_parent(Type.int) == inner
assert outer.find_type_annotation_parent(Type.str) == outer
assert outer.find_type_annotation_parent(Type.List) is None

def test_replace_child(self) -> None:
inner = TypeSubscript(Type.List, [Type.int])
outer = TypeSubscript(Type.Dict, [Type.str, inner])
assert outer.copy().replace_child(Type.str, Type.bool).render() == "Dict[bool, List[int]]"
assert outer.copy().replace_child(inner, Type.bool).render() == "Dict[str, bool]"
with pytest.raises(ValueError):
outer.copy().replace_child(Type.int, Type.bool)

0 comments on commit 14bdbc5

Please sign in to comment.