forked from GoogleCloudPlatform/python-docs-samples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconvert-types.py
322 lines (266 loc) · 9.95 KB
/
convert-types.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
#!/usr/bin/env python
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script to convert type hints to follow PEP-0585
For more information, see https://peps.python.org/pep-0585
To run from the repository's root directory:
python convert-types.py
"""
from __future__ import annotations
from collections.abc import Callable, Iterator
import difflib
from glob import glob
import logging
import re
import sys
from typing import NamedTuple, TypeVar
# TODO:
# - False positives with lambdas and dict comprehensions due to the `:` misinterpreted as a type hint
# - Sort imports case insensitive (e.g. PIL, Flask)
# - Type hint arguments can be lists like `Callable[[a, b], c]`
# Cases not covered:
# - Multi-line imports like `from M import (\nA,\nB,\n)`
# - Importing `typing` directly like `import typing` and `x: typing.Any`
# - Parsing types with `|` syntax like Union or Optional
# - typing.re.Match --> re.Match
# - typing.re.Pattern --> re.Pattern
# Names importable from `typing` that PEP 585 replaces, grouped by the module
# that provides the replacement.

# Replaced by builtins, so no import is needed afterwards.
BUILTIN_TYPES = {"Dict", "FrozenSet", "List", "Set", "Tuple", "Type"}

# Replaced by classes in `collections`.
COLLECTIONS_TYPES = {"ChainMap", "Counter", "DefaultDict", "Deque", "OrderedDict"}

# Replaced by classes in `collections.abc`.
COLLECTIONS_ABC_TYPES = {
    "AbstractSet",
    "AsyncGenerator",
    "AsyncIterable",
    "AsyncIterator",
    "Awaitable",
    "ByteString",
    "Callable",
    "Collection",
    "Container",
    "Coroutine",
    "Generator",
    "ItemsView",
    "Iterable",
    "Iterator",
    "KeysView",
    "Mapping",
    "MappingView",
    "MutableMapping",
    "MutableSequence",
    "MutableSet",
    "Reversible",
    "Sequence",
    "ValuesView",
}

# Replaced by classes in `contextlib`.
CONTEXTLIB_TYPES = {"AsyncContextManager", "ContextManager"}

# Replaced by classes in `re`.
RE_TYPES = {"Match", "Pattern"}

# Old `typing` name -> new name, for the names whose spelling changes.
# Names missing from this table keep their spelling.
RENAME_TYPES = {
    # builtins
    "Dict": "dict",
    "FrozenSet": "frozenset",
    "List": "list",
    "Set": "set",
    "Tuple": "tuple",
    "Type": "type",
    # collections
    "DefaultDict": "defaultdict",
    "Deque": "deque",
    # collections.abc
    "AbstractSet": "Set",
    # contextlib
    "AsyncContextManager": "AbstractAsyncContextManager",
    "ContextManager": "AbstractContextManager",
}
# Parser a = String -> (a, String)
#
# A parser is a callable that takes the source text and returns a pair of
# (parsed value, remaining text). The parsers in this file signal failure by
# raising SyntaxError.
a = TypeVar("a")
Parser = Callable[[str], tuple[a, str]]
class TypeHint(NamedTuple):
name: str
args: list[TypeHint]
def __repr__(self) -> str:
match (self.name, self.args):
case ("Optional", [x]):
return f"{x} | None"
case ("Union", args):
return " | ".join(map(str, args))
case (name, []):
return name
case (name, args):
return f"{name}[{', '.join(map(str, args))}]"
def patch(self, types: set[str]) -> TypeHint:
if self.name in types:
name = RENAME_TYPES.get(self.name, self.name)
else:
name = self.name
return TypeHint(name, [arg.patch(types) for arg in self.args])
def patch_file(file_path: str, dry_run: bool = False, quiet: bool = False) -> None:
    """Rewrite the type hints of one file, or preview the rewrite.

    Files that do not contain a ``from typing import ...`` line, and files
    whose patched text equals the original, are left untouched.

    Args:
        file_path: Path of the Python file to process.
        dry_run: When true, do not write; show a context diff and wait for
            ENTER instead (unless `quiet` is set).
        quiet: When true together with `dry_run`, print nothing.

    Any exception, including a failure to read or decode the file, is logged
    and swallowed so that one bad file does not stop a batch run.
    """
    try:
        # Python source is UTF-8 by default, so do not depend on the locale.
        with open(file_path, encoding="utf-8") as f:
            before = f.read()

        lines = [line.rstrip() for line in before.splitlines()]
        types = find_typing_imports(lines)
        if not types:
            return

        lines = insert_import_annotations(lines)
        lines = [patched for line in lines for patched in patch_imports(line)]
        lines = sort_imports(lines)
        after = patch_type_hints("\n".join(lines), types) + "\n"

        if before == after:
            return

        if not dry_run:
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(after)
            print(file_path)
        elif not quiet:
            print(f"| {file_path}")
            print(f"+--{'-' * len(file_path)}")
            diffs = difflib.context_diff(
                before.splitlines(keepends=True),
                after.splitlines(keepends=True),
                fromfile="Before changes",
                tofile="After changes",
                n=1,
            )
            sys.stdout.writelines(diffs)
            print(f"+{'=' * 100}")
            print("| Press [ENTER] to continue to the next file")
            input()
    except Exception:
        logging.exception("Could not process file: %s", file_path)
def insert_import_annotations(lines: list[str]) -> list[str]:
    """Ensure `from __future__ import annotations` appears in `lines`.

    The statement is inserted just before the first import line. When that
    line is itself a `from __future__ import ...`, the new statement joins
    its group; otherwise a blank line is added after it. `lines` is returned
    unchanged when the statement is already present or no import is found.
    """
    new_import = "from __future__ import annotations"
    if new_import in lines:
        return lines

    index = find_import(lines)
    if index is None:
        return lines

    if lines[index].startswith("from __future__ import "):
        inserted = [new_import]
    else:
        inserted = [new_import, ""]
    return [*lines[:index], *inserted, *lines[index:]]
def find_typing_imports(lines: list[str]) -> set[str]:
    """Collect the names imported by single-line `from typing import ...` statements.

    Args:
        lines: Source lines; only lines starting exactly with
            ``from typing import `` are considered.

    Returns:
        The imported names with surrounding whitespace removed. A trailing
        ``# comment`` is ignored and empty entries are skipped.
    """
    prefix = "from typing import "
    names: set[str] = set()
    for line in lines:
        if not line.startswith(prefix):
            continue
        # Drop a trailing comment so it is not glued onto the last name.
        imported = line[len(prefix):].split("#", 1)[0]
        for name in imported.split(","):
            name = name.strip()
            if name:
                names.add(name)
    return names
def find_import(lines: list[str]) -> int | None:
    """Return the index of the first line starting with `import ` or `from `.

    Returns None when no line starts that way.
    """
    candidates = (
        index
        for index, line in enumerate(lines)
        if line.startswith(("import ", "from "))
    )
    return next(candidates, None)
def get_imports_group(lines: list[str]) -> tuple[list[str], list[str]]:
    """Split `lines` at the first blank line or line starting with `#`.

    Returns:
        ``(group, rest)`` where `group` holds the lines before that boundary
        and `rest` starts at it. When there is no such boundary the result
        is ``([], lines)``.
    """
    boundary = next(
        (
            index
            for index, line in enumerate(lines)
            if line.startswith("#") or not line.strip()
        ),
        None,
    )
    if boundary is None:
        return ([], lines)
    return (lines[:boundary], lines[boundary:])
def import_name(line: str) -> str:
    """Return the module named by an `import X` or `from X import ...` line.

    Raises:
        ValueError: If the whitespace-separated words of `line` do not have
            either of those shapes.
    """
    words = line.split()
    if len(words) >= 2 and words[0] == "import":
        return words[1]
    if len(words) >= 3 and words[0] == "from" and words[2] == "import":
        return words[1]
    raise ValueError(f"not an import: {line}")
def sort_imports(lines: list[str]) -> list[str]:
    """Sort every import group in `lines` by imported module name.

    A group starts at a line found by ``find_import`` and runs up to the
    boundary reported by ``get_imports_group``; the text after each group is
    processed recursively.

    Returns:
        The lines with each group sorted. When no import is found, or no
        group can be delimited, the lines are returned in their original
        order.

    Raises:
        ValueError: Propagated from ``import_name`` when a group contains a
            line that is not an import statement.
    """
    first = find_import(lines)
    if first is None:
        return lines

    (imports, left) = get_imports_group(lines[first:])
    if not imports:
        # No group could be delimited. Keep the lines before the first
        # import (licence header, docstring) rather than returning only the
        # remainder and silently dropping them.
        return lines[:first] + left
    return lines[:first] + sorted(imports, key=import_name) + sort_imports(left)
def patch_imports(line: str) -> Iterator[str]:
    """Yield the import lines that replace one `from typing import ...` line.

    Any other line is yielded unchanged. For a typing import, each name is
    moved to the module that provides it (`collections`, `collections.abc`,
    `contextlib`, `re`), renamed through ``RENAME_TYPES`` and emitted sorted
    within its statement. Builtin replacements, `Optional` and `Union` need
    no import and are dropped; what remains stays in `typing`.
    """
    if not line.startswith("from typing import "):
        yield line
        return

    types = find_typing_imports([line])
    no_longer_typing = (
        BUILTIN_TYPES
        | COLLECTIONS_TYPES
        | COLLECTIONS_ABC_TYPES
        | CONTEXTLIB_TYPES
        | RE_TYPES
        | {"Optional", "Union"}
    )
    # (statement prefix, names for that statement), in emission order.
    statements = [
        ("from collections import ", types & COLLECTIONS_TYPES),
        ("from collections.abc import ", types & COLLECTIONS_ABC_TYPES),
        ("from contextlib import ", types & CONTEXTLIB_TYPES),
        ("from re import ", types & RE_TYPES),
        ("from typing import ", types - no_longer_typing),
    ]
    for prefix, members in statements:
        if not members:
            continue
        names = sorted(RENAME_TYPES.get(member, member) for member in members)
        yield prefix + ", ".join(names)
def patch_type_hints(txt: str, types: set[str]) -> str:
    """Rewrite every type hint found in `txt`.

    A hint is assumed to start at any word that follows `:` or `->`
    (optionally separated by spaces), so text that merely looks like an
    annotation is matched too.

    Args:
        txt: Source text to patch.
        types: Names imported from `typing`; only these are renamed.

    Returns:
        The text with each matched hint replaced by its patched rendering.
    """
    hint_start = re.compile(r"(?:->|:) *(\w+)")
    pieces: list[str] = []
    remaining = txt
    # Iterate instead of recursing once per match, so files with many
    # matches cannot exhaust the interpreter's recursion limit.
    while m := hint_start.search(remaining):
        start = m.start(1)
        (typ, left) = parse_type_hint(remaining[start:])
        pieces.append(remaining[:start])
        pieces.append(str(typ.patch(types)))
        remaining = left
    pieces.append(remaining)
    return "".join(pieces)
# Parser combinators
def parse_text(src: str, txt: str) -> tuple[str, str]:
    """Consume the literal `txt` from the start of `src`.

    Returns:
        ``(consumed, rest)``.

    Raises:
        SyntaxError: If `src` does not start with `txt`.
    """
    if not src.startswith(txt):
        raise SyntaxError("text")
    size = len(txt)
    consumed, rest = src[:size], src[size:]
    return (consumed, rest)
def parse_identifier(src: str) -> tuple[str, str]:
    """Consume a (possibly dotted) identifier from the start of `src`.

    Returns:
        ``(identifier, rest)``, where the identifier is the longest leading
        run of word characters and dots.

    Raises:
        SyntaxError: If `src` does not start with such a character.
    """
    # Anchor at the start: searching anywhere would skip over leading
    # characters such as `[`, quotes or newlines and drop them from the text.
    if m := re.match(r"[\w.]+", src):
        return (m.group(), src[m.end() :])
    raise SyntaxError("name")
def parse_zero_or_more(src: str, parser: Parser[a]) -> tuple[list[a], str]:
    """Apply `parser` repeatedly until it raises SyntaxError.

    Args:
        src: Text to parse.
        parser: Parser to apply; it must consume input on success, otherwise
            this function does not terminate.

    Returns:
        ``(values, rest)`` with the values in parse order. `rest` is the text
        left after the last successful application, so input touched by the
        failing attempt is not consumed.
    """
    values: list[a] = []
    # Loop rather than recurse so long runs cannot hit the recursion limit
    # and the result list is not rebuilt on every step.
    while True:
        try:
            (value, src) = parser(src)
        except SyntaxError:
            return (values, src)
        values.append(value)
def parse_comma_separated(src: str, parser: Parser[a]) -> tuple[list[a], str]:
    """Parse items separated by `,` and optional spaces.

    Returns:
        ``(items, rest)``. When not even the first item parses, the result is
        ``([], src)``. A comma that is not followed by a parsable item is left
        unconsumed in `rest`.
    """

    def parse_following(text: str) -> tuple[a, str]:
        # One separator (a comma, then any number of spaces) and the item
        # after it.
        (_, text) = parse_text(text, ",")
        return parser(text.lstrip(" "))

    try:
        (first, src) = parser(src)
    except SyntaxError:
        return ([], src)
    (others, src) = parse_zero_or_more(src, parse_following)
    return ([first, *others], src)
def parse_type_hint(src: str) -> tuple[TypeHint, str]:
    """Parse `name` or `name[arg, ...]` from the start of `src`.

    Returns:
        ``(hint, rest)``. When the bracketed part is missing or cannot be
        parsed up to its closing `]`, the hint has no arguments and `rest`
        starts right after the name, so the unparsed text is kept as is.

    Raises:
        SyntaxError: Propagated from ``parse_identifier`` when `src` does not
            start with a name.
    """
    (name, after_name) = parse_identifier(src)
    try:
        (_, rest) = parse_text(after_name, "[")
        (args, rest) = parse_comma_separated(rest, parse_type_hint)
        (_, rest) = parse_text(rest, "]")
    except SyntaxError:
        # Backtrack to just after the name: returning the advanced position
        # would discard the `[...` text that was consumed before the failure.
        return (TypeHint(name, []), after_name)
    return (TypeHint(name, args), rest)
def run(patterns: list[str], dry_run: bool = False, quiet: bool = False):
    """Patch every file matched by the glob `patterns`.

    Patterns are expanded recursively, one after another, and each match is
    passed to ``patch_file`` together with `dry_run` and `quiet`.
    """
    matches = (
        filename
        for pattern in patterns
        for filename in glob(pattern, recursive=True)
    )
    for filename in matches:
        patch_file(filename, dry_run, quiet)
if __name__ == "__main__":
    import argparse

    # Check explicitly instead of with `assert`, which is removed under -O.
    if sys.version_info < (3, 10):
        sys.exit("Requires Python >= 3.10 for pattern matching")

    # Positional glob patterns default to every Python file below the
    # current directory.
    parser = argparse.ArgumentParser()
    parser.add_argument("patterns", nargs="*", default=["**/*.py"])
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--quiet", action="store_true")
    args = parser.parse_args()

    # The parsed attributes (patterns, dry_run, quiet) match run()'s parameters.
    run(**vars(args))