From adaffb077596fad38b52d807f9a7989c1e303b54 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 10 Jul 2023 10:57:06 +0200 Subject: [PATCH] Refactor a bit --- bump_pydantic/codemods/field.py | 17 ++++++++++------- tests/unit/test_field.py | 2 +- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/bump_pydantic/codemods/field.py b/bump_pydantic/codemods/field.py index c20f881..e6d647f 100644 --- a/bump_pydantic/codemods/field.py +++ b/bump_pydantic/codemods/field.py @@ -99,13 +99,16 @@ def leave_field_call(self, original_node: cst.Call, updated_node: cst.Call) -> c new_args: List[cst.Arg] = [] for arg in updated_node.args: if m.matches(arg, m.Arg(keyword=m.Name())): - args_dict = dict() - keyword = RENAMED_KEYWORDS.get(arg.keyword.value, arg.keyword.value) - args_dict["keyword"] = arg.keyword.with_changes(value=keyword) # type: ignore - # Check if keyword is `allow_mutation` and if so, invert the value. - if arg.keyword.value == "allow_mutation": - args_dict["value"] = arg.value.with_changes(value=str(not(arg.value.value == "True"))) - new_args.append(arg.with_changes(**args_dict)) + keyword = RENAMED_KEYWORDS.get(arg.keyword.value, arg.keyword.value) # type: ignore + value = arg.value + # The `allow_mutation` keyword argument is a special case. It's the negative of `frozen`. + if arg.keyword and arg.keyword.value == "allow_mutation": + if m.matches(arg.value, m.Name(value="False")): + value = cst.Name("True") + elif m.matches(arg.value, m.Name(value="True")): + value = cst.Name("False") + new_arg = arg.with_changes(keyword=arg.keyword.with_changes(value=keyword), value=value) # type: ignore + new_args.append(new_arg) # type: ignore else: new_args.append(arg) diff --git a/tests/unit/test_field.py b/tests/unit/test_field.py index cf09017..62e02b5 100644 --- a/tests/unit/test_field.py +++ b/tests/unit/test_field.py @@ -118,4 +118,4 @@ class Settings(BaseSettings): potato: int = Field(..., frozen=True) strawberry: int = Field(..., frozen=False) """ - self.assertCodemod(before, after) \ No newline at end of file + self.assertCodemod(before, after)