Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Replace constr by StringConstraints #90

Merged
merged 3 commits into from
Jul 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 58 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Bump Pydantic is a tool to help you migrate your code from Pydantic V1 to V2.
- [BP005: Replace `GenericModel` by `BaseModel`](#bp005-replace-genericmodel-by-basemodel)
- [BP006: Replace `__root__` by `RootModel`](#bp006-replace-__root__-by-rootmodel)
- [BP007: Replace decorators](#bp007-replace-decorators)
- [BP008: Replace `con*` functions by `Annotated` versions](#bp008-replace-con-functions-by-annotated-versions)
- [License](#license)

---
Expand Down Expand Up @@ -258,8 +259,39 @@ class User(BaseModel):
return values
```

<!--
### BP008: Replace `pydantic.parse_obj_as` by `pydantic.TypeAdapter`
### BP008: Replace `con*` functions by `Annotated` versions

- ✅ Replace `constr(*args)` by `Annotated[str, StringConstraints(*args)]`.
- ✅ Replace `conint(*args)` by `Annotated[int, Field(*args)]`.
- ✅ Replace `confloat(*args)` by `Annotated[float, Field(*args)]`.
- ✅ Replace `conbytes(*args)` by `Annotated[bytes, Field(*args)]`.
- ✅ Replace `condecimal(*args)` by `Annotated[Decimal, Field(*args)]`.
- ✅ Replace `conset(T, *args)` by `Annotated[Set[T], Field(*args)]`.
- ✅ Replace `confrozenset(T, *args)` by `Annotated[Set[T], Field(*args)]`.
- ✅ Replace `conlist(T, *args)` by `Annotated[List[T], Field(*args)]`.

The following code will be transformed:

```py
from pydantic import BaseModel, constr


class User(BaseModel):
name: constr(min_length=1)
```

Into:

```py
from pydantic import BaseModel, StringConstraints
from typing_extensions import Annotated


class User(BaseModel):
name: Annotated[str, StringConstraints(min_length=1)]
```

<!-- ### BP009: Replace `pydantic.parse_obj_as` by `pydantic.TypeAdapter`

- ✅ Replace `pydantic.parse_obj_as(T, obj)` to `pydantic.TypeAdapter(T).validate_python(obj)`.

Expand Down Expand Up @@ -300,8 +332,31 @@ class Users(BaseModel):


users = TypeAdapter(Users).validate_python({'users': [{'name': 'John'}]})
``` -->

<!-- ### BP010: Replace `PyObject` by `ImportString`

- ✅ Replace `PyObject` by `ImportString`.

The following code will be transformed:

```py
from pydantic import BaseModel, PyObject


class User(BaseModel):
name: PyObject
```
-->

Into:

```py
from pydantic import BaseModel, ImportString


class User(BaseModel):
name: ImportString
``` -->

---

Expand Down
7 changes: 7 additions & 0 deletions bump_pydantic/codemods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor

from bump_pydantic.codemods.add_default_none import AddDefaultNoneCommand
from bump_pydantic.codemods.con_func import ConFuncCallCommand
from bump_pydantic.codemods.field import FieldCodemod
from bump_pydantic.codemods.replace_config import ReplaceConfigCodemod
from bump_pydantic.codemods.replace_generic_model import ReplaceGenericModelCommand
Expand All @@ -28,6 +29,8 @@ class Rule(str, Enum):
"""Replace `BaseModel.__root__ = T` with `RootModel[T]`."""
BP007 = "BP007"
"""Replace `@validator` with `@field_validator`."""
BP008 = "BP008"
"""Replace `constr(<args>)` with `Annotated[str, StringConstraints(<args>)`."""


def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]:
Expand All @@ -39,6 +42,10 @@ def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]
if Rule.BP002 not in disabled:
codemods.append(ReplaceConfigCodemod)

# The `ConFuncCallCommand` needs to run before the `FieldCodemod`.
if Rule.BP008 not in disabled:
codemods.append(ConFuncCallCommand)

if Rule.BP003 not in disabled:
codemods.append(FieldCodemod)

Expand Down
152 changes: 152 additions & 0 deletions bump_pydantic/codemods/con_func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
from typing import cast

import libcst as cst
from libcst import matchers as m
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor

CONSTR_CALL = m.Call(func=m.Name("constr") | m.Attribute(value=m.Name("pydantic"), attr=m.Name("constr")))
ANN_ASSIGN_CONSTR_CALL = m.AnnAssign(annotation=m.Annotation(annotation=CONSTR_CALL))


CON_NUMBER_CALL = m.OneOf(
*[
m.Call(func=m.Name(name) | m.Attribute(value=m.Name("pydantic"), attr=m.Name(name)))
for name in ("conint", "confloat", "condecimal", "conbytes")
]
)
ANN_ASSIGN_CON_NUMBER_CALL = m.AnnAssign(annotation=m.Annotation(annotation=CON_NUMBER_CALL))

CON_COLLECTION_CALL = m.OneOf(
*[
m.Call(func=m.Name(name) | m.Attribute(value=m.Name("pydantic"), attr=m.Name(name)))
for name in ("conlist", "conset", "confrozenset")
]
)
ANN_ASSIGN_COLLECTION_CALL = m.AnnAssign(annotation=m.Annotation(annotation=CON_COLLECTION_CALL))

MAP_FUNC_TO_TYPE = {
"constr": "str",
"conint": "int",
"confloat": "float",
"condecimal": "Decimal",
"conbytes": "bytes",
"conlist": "List",
"conset": "Set",
"confrozenset": "FrozenSet",
}
MAP_TYPE_TO_NEEDED_IMPORT = {
"Decimal": {"module": "decimal", "obj": "Decimal"},
"List": {"module": "typing", "obj": "List"},
"Set": {"module": "typing", "obj": "Set"},
"FrozenSet": {"module": "typing", "obj": "FrozenSet"},
}
COLLECTIONS = ("List", "Set", "FrozenSet")


class ConFuncCallCommand(VisitorBasedCodemodCommand):
def __init__(self, context: CodemodContext) -> None:
super().__init__(context)

@m.leave(ANN_ASSIGN_CONSTR_CALL | ANN_ASSIGN_CON_NUMBER_CALL | ANN_ASSIGN_COLLECTION_CALL)
def leave_ann_assign_constr_call(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.AnnAssign:
annotation = cast(cst.Call, original_node.annotation.annotation)
if m.matches(annotation.func, m.Name()):
func_name = cast(str, annotation.func.value) # type: ignore
else:
func_name = cast(str, annotation.func.attr.value) # type: ignore
type_name = MAP_FUNC_TO_TYPE[func_name]

needed_import = MAP_TYPE_TO_NEEDED_IMPORT.get(type_name)
if needed_import is not None:
AddImportsVisitor.add_needed_import(context=self.context, **needed_import) # type: ignore[arg-type]

if type_name in COLLECTIONS:
slice_value = cst.Index(
value=cst.Subscript(
value=cst.Name(type_name),
slice=[cst.SubscriptElement(slice=cst.Index(value=self.inner_type))],
)
)
else:
slice_value = cst.Index(value=cst.Name(type_name))

AddImportsVisitor.add_needed_import(context=self.context, module="typing_extensions", obj="Annotated")
annotated = cst.Subscript(
value=cst.Name("Annotated"),
slice=[
cst.SubscriptElement(slice=slice_value),
cst.SubscriptElement(slice=cst.Index(value=updated_node.annotation.annotation)),
],
)
annotation = cst.Annotation(annotation=annotated) # type: ignore[assignment]
return updated_node.with_changes(annotation=annotation)

@m.leave(CONSTR_CALL)
def leave_constr_call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
self._remove_import(original_node.func)
AddImportsVisitor.add_needed_import(context=self.context, module="pydantic", obj="StringConstraints")
return updated_node.with_changes(
func=cst.Name("StringConstraints"),
args=[
arg if arg.keyword and arg.keyword.value != "regex" else arg.with_changes(keyword=cst.Name("pattern"))
for arg in updated_node.args
],
)

@m.leave(CON_NUMBER_CALL)
def leave_con_number_call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
self._remove_import(original_node.func)
AddImportsVisitor.add_needed_import(context=self.context, module="pydantic", obj="Field")
return updated_node.with_changes(func=cst.Name("Field"))

@m.leave(CON_COLLECTION_CALL)
def leave_con_collection_call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
self._remove_import(original_node.func)
AddImportsVisitor.add_needed_import(context=self.context, module="pydantic", obj="Field")
# NOTE: It's guaranteed to have at least one argument.
self.inner_type = updated_node.args[0].value
return updated_node.with_changes(func=cst.Name("Field"), args=updated_node.args[1:])

def _remove_import(self, func: cst.BaseExpression) -> None:
if m.matches(func, m.Name()):
assert isinstance(func, cst.Name)
RemoveImportsVisitor.remove_unused_import(context=self.context, module="pydantic", obj=func.value)
elif m.matches(func, m.Attribute()):
RemoveImportsVisitor.remove_unused_import(context=self.context, module="pydantic")


if __name__ == "__main__":
import textwrap

from rich.console import Console

console = Console()

source = textwrap.dedent(
"""
from pydantic import BaseModel, constr

class A(BaseModel):
a: constr(max_length=10)
b: conint(ge=0, le=100)
c: confloat(ge=0, le=100)
d: condecimal(ge=0, le=100)
e: conbytes(max_length=10)
f: conlist(int, min_items=1, max_items=10)
g: conset(float, min_items=1, max_items=10)
h: confrozenset(str, min_items=1, max_items=10)
i: conlist(int, unique_items=True)
"""
)
console.print(source)
console.print("=" * 80)

mod = cst.parse_module(source)
context = CodemodContext(filename="main.py")
wrapper = cst.MetadataWrapper(mod)
command = ConFuncCallCommand(context=context)
console.print(mod)

mod = wrapper.visit(command)
print(mod.code)
2 changes: 2 additions & 0 deletions tests/integration/cases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ..folder import Folder
from .add_none import cases as add_none_cases
from .base_settings import cases as base_settings_cases
from .con_func import cases as con_func_cases
from .config_to_model import cases as config_to_model_cases
from .field import cases as generic_model_cases
from .folder_inside_folder import cases as folder_inside_folder_cases
Expand All @@ -26,6 +27,7 @@
*generic_model_cases,
*folder_inside_folder_cases,
*unicode_cases,
*con_func_cases,
]
before = Folder("project", *[case.source for case in cases])
expected = Folder("project", *[case.expected for case in cases])
43 changes: 43 additions & 0 deletions tests/integration/cases/con_func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from ..case import Case
from ..file import File

cases = [
Case(
name="Con* functions",
source=File(
"con_func.py",
content=[
"from pydantic import BaseModel, constr, conlist, conint, conbytes, condecimal, confloat, conset",
"",
"",
"class Potato(BaseModel):",
" a: constr(regex='[a-z]+')",
" b: conlist(int, min_items=1, max_items=10)",
" c: conint(gt=0, lt=10)",
" d: conbytes(min_length=1, max_length=10)",
" e: condecimal(gt=0, lt=10)",
" f: confloat(gt=0, lt=10)",
" g: conset(int, min_items=1, max_items=10)",
],
),
expected=File(
"con_func.py",
content=[
"from pydantic import Field, StringConstraints, BaseModel",
"from decimal import Decimal",
"from typing import List, Set",
"from typing_extensions import Annotated",
"",
"",
"class Potato(BaseModel):",
" a: Annotated[str, StringConstraints(pattern='[a-z]+')]",
" b: Annotated[List[int], Field(min_length=1, max_length=10)]",
" c: Annotated[int, Field(gt=0, lt=10)]",
" d: Annotated[bytes, Field(min_length=1, max_length=10)]",
" e: Annotated[Decimal, Field(gt=0, lt=10)]",
" f: Annotated[float, Field(gt=0, lt=10)]",
" g: Annotated[Set[int], Field(min_length=1, max_length=10)]",
],
),
)
]
75 changes: 75 additions & 0 deletions tests/unit/test_con_func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from libcst.codemod import CodemodTest

from bump_pydantic.codemods.con_func import ConFuncCallCommand


class TestFieldCommand(CodemodTest):
TRANSFORM = ConFuncCallCommand

maxDiff = None

def test_constr_to_annotated(self) -> None:
before = """
from pydantic import BaseModel, constr

class Potato(BaseModel):
potato: constr(min_length=1, max_length=10)
"""
after = """
from pydantic import StringConstraints, BaseModel
from typing_extensions import Annotated

class Potato(BaseModel):
potato: Annotated[str, StringConstraints(min_length=1, max_length=10)]
"""
self.assertCodemod(before, after)

def test_pydantic_constr_to_annotated(self) -> None:
before = """
import pydantic
from pydantic import BaseModel

class Potato(BaseModel):
potato: pydantic.constr(min_length=1, max_length=10)
"""
after = """
from pydantic import StringConstraints, BaseModel
from typing_extensions import Annotated

class Potato(BaseModel):
potato: Annotated[str, StringConstraints(min_length=1, max_length=10)]
"""
self.assertCodemod(before, after)

def test_conlist_to_annotated(self) -> None:
before = """
from pydantic import BaseModel, conlist

class Potato(BaseModel):
potato: conlist(str, min_items=1, max_items=10)
"""
after = """
from pydantic import Field, BaseModel
from typing import List
from typing_extensions import Annotated

class Potato(BaseModel):
potato: Annotated[List[str], Field(min_items=1, max_items=10)]
"""
self.assertCodemod(before, after)

def test_conint_to_annotated(self) -> None:
before = """
from pydantic import BaseModel, conint

class Potato(BaseModel):
potato: conint(ge=0, le=100)
"""
after = """
from pydantic import Field, BaseModel
from typing_extensions import Annotated

class Potato(BaseModel):
potato: Annotated[int, Field(ge=0, le=100)]
"""
self.assertCodemod(before, after)