diff --git a/bump_pydantic/main.py b/bump_pydantic/main.py index 63ee085..dc6a7f8 100644 --- a/bump_pydantic/main.py +++ b/bump_pydantic/main.py @@ -1,4 +1,5 @@ import difflib +import multiprocessing import os import time from pathlib import Path @@ -69,47 +70,51 @@ def main( codemods = gather_codemods() # TODO: We can run this in parallel - batch it into files / cores. - # with ProcessPoolExecutor(): - # cpu_count = multiprocessing.cpu_count() - # batch_size = len(files) // cpu_count + 1 - - # batches = [files[i : i + batch_size] for i in range(0, len(files), batch_size)] - # print(batches) - - for codemod in codemods: - for filename in files: - module_and_package = calculate_module_and_package(str(package), filename) - context = CodemodContext( - metadata_manager=metadata_manager, - filename=filename, - full_module_name=module_and_package.name, - full_package_name=module_and_package.package, - ) - context.scratch.update(scratch) - - transformer = codemod(context=context) - - old_code = Path(filename).read_text() - input_tree = cst.parse_module(old_code) - output_tree = transformer.transform_module(input_tree) - - input_code = input_tree.code - output_code = output_tree.code - - if input_code != output_code: + with multiprocessing.Pool(): + cpu_count = multiprocessing.cpu_count() + batch_size = len(files) // cpu_count + 1 + + [files[i : i + batch_size] for i in range(0, len(files), batch_size)] + + for filename in files: + module_and_package = calculate_module_and_package(str(package), filename) + context = CodemodContext( + metadata_manager=metadata_manager, + filename=filename, + full_module_name=module_and_package.name, + full_package_name=module_and_package.package, + ) + context.scratch.update(scratch) + + file_path = Path(filename) + with file_path.open("r+") as fp: + code = fp.read() + fp.seek(0) + + input_code = str(code) + + for codemod in codemods: + transformer = codemod(context=context) + + input_tree = cst.parse_module(input_code) + output_tree = transformer.transform_module(input_tree) + + input_code = output_tree.code + + if code != input_code: if diff: color_diff( console=console, lines=difflib.unified_diff( + code.splitlines(keepends=True), input_code.splitlines(keepends=True), - output_code.splitlines(keepends=True), fromfile=filename, tofile=filename, ), ) else: - with open(filename, "w") as fp: - fp.write(output_tree.code) + fp.write(input_code) + fp.truncate() modified = [Path(f) for f in files if os.stat(f).st_mtime > start_time] if modified: