Skip to content

Commit

Permalink
✨ Replace con* functions by Annotated (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Jul 14, 2023
1 parent 98a3666 commit 3bd8332
Show file tree
Hide file tree
Showing 6 changed files with 337 additions and 3 deletions.
61 changes: 58 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Bump Pydantic is a tool to help you migrate your code from Pydantic V1 to V2.
- [BP005: Replace `GenericModel` by `BaseModel`](#bp005-replace-genericmodel-by-basemodel)
- [BP006: Replace `__root__` by `RootModel`](#bp006-replace-__root__-by-rootmodel)
- [BP007: Replace decorators](#bp007-replace-decorators)
- [BP008: Replace `con*` functions by `Annotated` versions](#bp008-replace-con-functions-by-annotated-versions)
- [License](#license)

---
Expand Down Expand Up @@ -258,8 +259,39 @@ class User(BaseModel):
return values
```

<!--
### BP008: Replace `pydantic.parse_obj_as` by `pydantic.TypeAdapter`
### BP008: Replace `con*` functions by `Annotated` versions

- ✅ Replace `constr(*args)` by `Annotated[str, StringConstraints(*args)]`.
- ✅ Replace `conint(*args)` by `Annotated[int, Field(*args)]`.
- ✅ Replace `confloat(*args)` by `Annotated[float, Field(*args)]`.
- ✅ Replace `conbytes(*args)` by `Annotated[bytes, Field(*args)]`.
- ✅ Replace `condecimal(*args)` by `Annotated[Decimal, Field(*args)]`.
- ✅ Replace `conset(T, *args)` by `Annotated[Set[T], Field(*args)]`.
- ✅ Replace `confrozenset(T, *args)` by `Annotated[Set[T], Field(*args)]`.
- ✅ Replace `conlist(T, *args)` by `Annotated[List[T], Field(*args)]`.

The following code will be transformed:

```py
from pydantic import BaseModel, constr


class User(BaseModel):
name: constr(min_length=1)
```

Into:

```py
from pydantic import BaseModel, StringConstraints
from typing_extensions import Annotated


class User(BaseModel):
name: Annotated[str, StringConstraints(min_length=1)]
```

<!-- ### BP009: Replace `pydantic.parse_obj_as` by `pydantic.TypeAdapter`
- ✅ Replace `pydantic.parse_obj_as(T, obj)` to `pydantic.TypeAdapter(T).validate_python(obj)`.
Expand Down Expand Up @@ -300,8 +332,31 @@ class Users(BaseModel):
users = TypeAdapter(Users).validate_python({'users': [{'name': 'John'}]})
``` -->

<!-- ### BP010: Replace `PyObject` by `ImportString`
- ✅ Replace `PyObject` by `ImportString`.
The following code will be transformed:
```py
from pydantic import BaseModel, PyObject
class User(BaseModel):
name: PyObject
```
-->
Into:
```py
from pydantic import BaseModel, ImportString
class User(BaseModel):
name: ImportString
``` -->

---

Expand Down
7 changes: 7 additions & 0 deletions bump_pydantic/codemods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor

from bump_pydantic.codemods.add_default_none import AddDefaultNoneCommand
from bump_pydantic.codemods.con_func import ConFuncCallCommand
from bump_pydantic.codemods.field import FieldCodemod
from bump_pydantic.codemods.replace_config import ReplaceConfigCodemod
from bump_pydantic.codemods.replace_generic_model import ReplaceGenericModelCommand
Expand All @@ -28,6 +29,8 @@ class Rule(str, Enum):
"""Replace `BaseModel.__root__ = T` with `RootModel[T]`."""
BP007 = "BP007"
"""Replace `@validator` with `@field_validator`."""
BP008 = "BP008"
"""Replace `constr(<args>)` with `Annotated[str, StringConstraints(<args>)`."""


def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]:
Expand All @@ -39,6 +42,10 @@ def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]
if Rule.BP002 not in disabled:
codemods.append(ReplaceConfigCodemod)

# The `ConFuncCallCommand` needs to run before the `FieldCodemod`.
if Rule.BP008 not in disabled:
codemods.append(ConFuncCallCommand)

if Rule.BP003 not in disabled:
codemods.append(FieldCodemod)

Expand Down
152 changes: 152 additions & 0 deletions bump_pydantic/codemods/con_func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
from typing import cast

import libcst as cst
from libcst import matchers as m
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor

CONSTR_CALL = m.Call(func=m.Name("constr") | m.Attribute(value=m.Name("pydantic"), attr=m.Name("constr")))
ANN_ASSIGN_CONSTR_CALL = m.AnnAssign(annotation=m.Annotation(annotation=CONSTR_CALL))


CON_NUMBER_CALL = m.OneOf(
*[
m.Call(func=m.Name(name) | m.Attribute(value=m.Name("pydantic"), attr=m.Name(name)))
for name in ("conint", "confloat", "condecimal", "conbytes")
]
)
ANN_ASSIGN_CON_NUMBER_CALL = m.AnnAssign(annotation=m.Annotation(annotation=CON_NUMBER_CALL))

CON_COLLECTION_CALL = m.OneOf(
*[
m.Call(func=m.Name(name) | m.Attribute(value=m.Name("pydantic"), attr=m.Name(name)))
for name in ("conlist", "conset", "confrozenset")
]
)
ANN_ASSIGN_COLLECTION_CALL = m.AnnAssign(annotation=m.Annotation(annotation=CON_COLLECTION_CALL))

MAP_FUNC_TO_TYPE = {
"constr": "str",
"conint": "int",
"confloat": "float",
"condecimal": "Decimal",
"conbytes": "bytes",
"conlist": "List",
"conset": "Set",
"confrozenset": "FrozenSet",
}
MAP_TYPE_TO_NEEDED_IMPORT = {
"Decimal": {"module": "decimal", "obj": "Decimal"},
"List": {"module": "typing", "obj": "List"},
"Set": {"module": "typing", "obj": "Set"},
"FrozenSet": {"module": "typing", "obj": "FrozenSet"},
}
COLLECTIONS = ("List", "Set", "FrozenSet")


class ConFuncCallCommand(VisitorBasedCodemodCommand):
def __init__(self, context: CodemodContext) -> None:
super().__init__(context)

@m.leave(ANN_ASSIGN_CONSTR_CALL | ANN_ASSIGN_CON_NUMBER_CALL | ANN_ASSIGN_COLLECTION_CALL)
def leave_ann_assign_constr_call(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.AnnAssign:
annotation = cast(cst.Call, original_node.annotation.annotation)
if m.matches(annotation.func, m.Name()):
func_name = cast(str, annotation.func.value) # type: ignore
else:
func_name = cast(str, annotation.func.attr.value) # type: ignore
type_name = MAP_FUNC_TO_TYPE[func_name]

needed_import = MAP_TYPE_TO_NEEDED_IMPORT.get(type_name)
if needed_import is not None:
AddImportsVisitor.add_needed_import(context=self.context, **needed_import) # type: ignore[arg-type]

if type_name in COLLECTIONS:
slice_value = cst.Index(
value=cst.Subscript(
value=cst.Name(type_name),
slice=[cst.SubscriptElement(slice=cst.Index(value=self.inner_type))],
)
)
else:
slice_value = cst.Index(value=cst.Name(type_name))

AddImportsVisitor.add_needed_import(context=self.context, module="typing_extensions", obj="Annotated")
annotated = cst.Subscript(
value=cst.Name("Annotated"),
slice=[
cst.SubscriptElement(slice=slice_value),
cst.SubscriptElement(slice=cst.Index(value=updated_node.annotation.annotation)),
],
)
annotation = cst.Annotation(annotation=annotated) # type: ignore[assignment]
return updated_node.with_changes(annotation=annotation)

@m.leave(CONSTR_CALL)
def leave_constr_call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
self._remove_import(original_node.func)
AddImportsVisitor.add_needed_import(context=self.context, module="pydantic", obj="StringConstraints")
return updated_node.with_changes(
func=cst.Name("StringConstraints"),
args=[
arg if arg.keyword and arg.keyword.value != "regex" else arg.with_changes(keyword=cst.Name("pattern"))
for arg in updated_node.args
],
)

@m.leave(CON_NUMBER_CALL)
def leave_con_number_call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
self._remove_import(original_node.func)
AddImportsVisitor.add_needed_import(context=self.context, module="pydantic", obj="Field")
return updated_node.with_changes(func=cst.Name("Field"))

@m.leave(CON_COLLECTION_CALL)
def leave_con_collection_call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
self._remove_import(original_node.func)
AddImportsVisitor.add_needed_import(context=self.context, module="pydantic", obj="Field")
# NOTE: It's guaranteed to have at least one argument.
self.inner_type = updated_node.args[0].value
return updated_node.with_changes(func=cst.Name("Field"), args=updated_node.args[1:])

def _remove_import(self, func: cst.BaseExpression) -> None:
if m.matches(func, m.Name()):
assert isinstance(func, cst.Name)
RemoveImportsVisitor.remove_unused_import(context=self.context, module="pydantic", obj=func.value)
elif m.matches(func, m.Attribute()):
RemoveImportsVisitor.remove_unused_import(context=self.context, module="pydantic")


if __name__ == "__main__":
import textwrap

from rich.console import Console

console = Console()

source = textwrap.dedent(
"""
from pydantic import BaseModel, constr
class A(BaseModel):
a: constr(max_length=10)
b: conint(ge=0, le=100)
c: confloat(ge=0, le=100)
d: condecimal(ge=0, le=100)
e: conbytes(max_length=10)
f: conlist(int, min_items=1, max_items=10)
g: conset(float, min_items=1, max_items=10)
h: confrozenset(str, min_items=1, max_items=10)
i: conlist(int, unique_items=True)
"""
)
console.print(source)
console.print("=" * 80)

mod = cst.parse_module(source)
context = CodemodContext(filename="main.py")
wrapper = cst.MetadataWrapper(mod)
command = ConFuncCallCommand(context=context)
console.print(mod)

mod = wrapper.visit(command)
print(mod.code)
2 changes: 2 additions & 0 deletions tests/integration/cases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ..folder import Folder
from .add_none import cases as add_none_cases
from .base_settings import cases as base_settings_cases
from .con_func import cases as con_func_cases
from .config_to_model import cases as config_to_model_cases
from .field import cases as generic_model_cases
from .folder_inside_folder import cases as folder_inside_folder_cases
Expand All @@ -26,6 +27,7 @@
*generic_model_cases,
*folder_inside_folder_cases,
*unicode_cases,
*con_func_cases,
]
before = Folder("project", *[case.source for case in cases])
expected = Folder("project", *[case.expected for case in cases])
43 changes: 43 additions & 0 deletions tests/integration/cases/con_func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from ..case import Case
from ..file import File

cases = [
Case(
name="Con* functions",
source=File(
"con_func.py",
content=[
"from pydantic import BaseModel, constr, conlist, conint, conbytes, condecimal, confloat, conset",
"",
"",
"class Potato(BaseModel):",
" a: constr(regex='[a-z]+')",
" b: conlist(int, min_items=1, max_items=10)",
" c: conint(gt=0, lt=10)",
" d: conbytes(min_length=1, max_length=10)",
" e: condecimal(gt=0, lt=10)",
" f: confloat(gt=0, lt=10)",
" g: conset(int, min_items=1, max_items=10)",
],
),
expected=File(
"con_func.py",
content=[
"from pydantic import Field, StringConstraints, BaseModel",
"from decimal import Decimal",
"from typing import List, Set",
"from typing_extensions import Annotated",
"",
"",
"class Potato(BaseModel):",
" a: Annotated[str, StringConstraints(pattern='[a-z]+')]",
" b: Annotated[List[int], Field(min_length=1, max_length=10)]",
" c: Annotated[int, Field(gt=0, lt=10)]",
" d: Annotated[bytes, Field(min_length=1, max_length=10)]",
" e: Annotated[Decimal, Field(gt=0, lt=10)]",
" f: Annotated[float, Field(gt=0, lt=10)]",
" g: Annotated[Set[int], Field(min_length=1, max_length=10)]",
],
),
)
]
75 changes: 75 additions & 0 deletions tests/unit/test_con_func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from libcst.codemod import CodemodTest

from bump_pydantic.codemods.con_func import ConFuncCallCommand


class TestFieldCommand(CodemodTest):
TRANSFORM = ConFuncCallCommand

maxDiff = None

def test_constr_to_annotated(self) -> None:
before = """
from pydantic import BaseModel, constr
class Potato(BaseModel):
potato: constr(min_length=1, max_length=10)
"""
after = """
from pydantic import StringConstraints, BaseModel
from typing_extensions import Annotated
class Potato(BaseModel):
potato: Annotated[str, StringConstraints(min_length=1, max_length=10)]
"""
self.assertCodemod(before, after)

def test_pydantic_constr_to_annotated(self) -> None:
before = """
import pydantic
from pydantic import BaseModel
class Potato(BaseModel):
potato: pydantic.constr(min_length=1, max_length=10)
"""
after = """
from pydantic import StringConstraints, BaseModel
from typing_extensions import Annotated
class Potato(BaseModel):
potato: Annotated[str, StringConstraints(min_length=1, max_length=10)]
"""
self.assertCodemod(before, after)

def test_conlist_to_annotated(self) -> None:
before = """
from pydantic import BaseModel, conlist
class Potato(BaseModel):
potato: conlist(str, min_items=1, max_items=10)
"""
after = """
from pydantic import Field, BaseModel
from typing import List
from typing_extensions import Annotated
class Potato(BaseModel):
potato: Annotated[List[str], Field(min_items=1, max_items=10)]
"""
self.assertCodemod(before, after)

def test_conint_to_annotated(self) -> None:
before = """
from pydantic import BaseModel, conint
class Potato(BaseModel):
potato: conint(ge=0, le=100)
"""
after = """
from pydantic import Field, BaseModel
from typing_extensions import Annotated
class Potato(BaseModel):
potato: Annotated[int, Field(ge=0, le=100)]
"""
self.assertCodemod(before, after)

0 comments on commit 3bd8332

Please sign in to comment.