Skip to content

Commit

Permalink
Replace other con functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Jul 14, 2023
1 parent 37e2dab commit 4de73c6
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 14 deletions.
39 changes: 34 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Bump Pydantic is a tool to help you migrate your code from Pydantic V1 to V2.
- [BP005: Replace `GenericModel` by `BaseModel`](#bp005-replace-genericmodel-by-basemodel)
- [BP006: Replace `__root__` by `RootModel`](#bp006-replace-__root__-by-rootmodel)
- [BP007: Replace decorators](#bp007-replace-decorators)
- [BP008: Replace `constr(*args)` by `Annotated[str, StringConstraints(*args)]`](#bp008-replace-constrargs-by-annotatedstr-stringconstraintsargs)
- [BP008: Replace `con*` functions by `Annotated` versions](#bp008-replace-con-functions-by-annotated-versions)
- [License](#license)

---
Expand Down Expand Up @@ -259,9 +259,16 @@ class User(BaseModel):
return values
```

### BP008: Replace `constr(*args)` by `Annotated[str, StringConstraints(*args)]`
### BP008: Replace `con*` functions by `Annotated` versions

- ✅ Replace `constr(*args)` by `Annotated[str, StringConstraints(*args)]`.
- ✅ Replace `conint(*args)` by `Annotated[int, Field(*args)]`.
- ✅ Replace `confloat(*args)` by `Annotated[float, Field(*args)]`.
- ✅ Replace `conbytes(*args)` by `Annotated[bytes, Field(*args)]`.
- ✅ Replace `condecimal(*args)` by `Annotated[Decimal, Field(*args)]`.
- ✅ Replace `conset(T, *args)` by `Annotated[Set[T], Field(*args)]`.
- ✅ Replace `confrozenset(T, *args)` by `Annotated[Set[T], Field(*args)]`.
- ✅ Replace `conlist(T, *args)` by `Annotated[List[T], Field(*args)]`.

The following code will be transformed:

Expand All @@ -284,8 +291,7 @@ class User(BaseModel):
name: Annotated[str, StringConstraints(min_length=1)]
```

<!--
### BP009: Replace `pydantic.parse_obj_as` by `pydantic.TypeAdapter`
<!-- ### BP009: Replace `pydantic.parse_obj_as` by `pydantic.TypeAdapter`
- ✅ Replace `pydantic.parse_obj_as(T, obj)` to `pydantic.TypeAdapter(T).validate_python(obj)`.
Expand Down Expand Up @@ -326,8 +332,31 @@ class Users(BaseModel):
users = TypeAdapter(Users).validate_python({'users': [{'name': 'John'}]})
``` -->

<!-- ### BP010: Replace `PyObject` by `ImportString`
- ✅ Replace `PyObject` by `ImportString`.
The following code will be transformed:
```py
from pydantic import BaseModel, PyObject
class User(BaseModel):
name: PyObject
```
-->
Into:
```py
from pydantic import BaseModel, ImportString
class User(BaseModel):
name: ImportString
``` -->

---

Expand Down
7 changes: 4 additions & 3 deletions bump_pydantic/codemods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]
if Rule.BP002 not in disabled:
codemods.append(ReplaceConfigCodemod)

# The `ConFuncCallCommand` needs to run before the `FieldCodemod`.
if Rule.BP008 not in disabled:
codemods.append(ConFuncCallCommand)

if Rule.BP003 not in disabled:
codemods.append(FieldCodemod)

Expand All @@ -57,9 +61,6 @@ def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]
if Rule.BP007 not in disabled:
codemods.append(ValidatorCodemod)

if Rule.BP008 not in disabled:
codemods.append(ConFuncCallCommand)

# Those codemods need to be the last ones.
codemods.extend([RemoveImportsVisitor, AddImportsVisitor])
return codemods
96 changes: 91 additions & 5 deletions bump_pydantic/codemods/con_func.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import cast

import libcst as cst
from libcst import matchers as m
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
Expand All @@ -7,29 +9,106 @@
ANN_ASSIGN_CONSTR_CALL = m.AnnAssign(annotation=m.Annotation(annotation=CONSTR_CALL))


CON_NUMBER_CALL = m.OneOf(
*[
m.Call(func=m.Name(name) | m.Attribute(value=m.Name("pydantic"), attr=m.Name(name)))
for name in ("conint", "confloat", "condecimal", "conbytes")
]
)
ANN_ASSIGN_CON_NUMBER_CALL = m.AnnAssign(annotation=m.Annotation(annotation=CON_NUMBER_CALL))

CON_COLLECTION_CALL = m.OneOf(
*[
m.Call(func=m.Name(name) | m.Attribute(value=m.Name("pydantic"), attr=m.Name(name)))
for name in ("conlist", "conset", "confrozenset")
]
)
ANN_ASSIGN_COLLECTION_CALL = m.AnnAssign(annotation=m.Annotation(annotation=CON_COLLECTION_CALL))

MAP_FUNC_TO_TYPE = {
"constr": "str",
"conint": "int",
"confloat": "float",
"condecimal": "Decimal",
"conbytes": "bytes",
"conlist": "List",
"conset": "Set",
"confrozenset": "FrozenSet",
}
MAP_TYPE_TO_NEEDED_IMPORT = {
"Decimal": {"module": "decimal", "obj": "Decimal"},
"List": {"module": "typing", "obj": "List"},
"Set": {"module": "typing", "obj": "Set"},
"FrozenSet": {"module": "typing", "obj": "FrozenSet"},
}
COLLECTIONS = ("List", "Set", "FrozenSet")


class ConFuncCallCommand(VisitorBasedCodemodCommand):
def __init__(self, context: CodemodContext) -> None:
super().__init__(context)

@m.leave(ANN_ASSIGN_CONSTR_CALL)
@m.leave(ANN_ASSIGN_CONSTR_CALL | ANN_ASSIGN_CON_NUMBER_CALL | ANN_ASSIGN_COLLECTION_CALL)
def leave_ann_assign_constr_call(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.AnnAssign:
annotation = cast(cst.Call, original_node.annotation.annotation)
if m.matches(annotation.func, m.Name()):
func_name = cast(str, annotation.func.value) # type: ignore
else:
func_name = cast(str, annotation.func.attr.value) # type: ignore
type_name = MAP_FUNC_TO_TYPE[func_name]

needed_import = MAP_TYPE_TO_NEEDED_IMPORT.get(type_name)
if needed_import is not None:
AddImportsVisitor.add_needed_import(context=self.context, **needed_import) # type: ignore[arg-type]

if type_name in COLLECTIONS:
slice_value = cst.Index(
value=cst.Subscript(
value=cst.Name(type_name),
slice=[cst.SubscriptElement(slice=cst.Index(value=self.inner_type))],
)
)
else:
slice_value = cst.Index(value=cst.Name(type_name))

AddImportsVisitor.add_needed_import(context=self.context, module="typing_extensions", obj="Annotated")
annotated = cst.Subscript(
value=cst.Name("Annotated"),
slice=[
cst.SubscriptElement(slice=cst.Index(value=cst.Name("str"))),
cst.SubscriptElement(slice=slice_value),
cst.SubscriptElement(slice=cst.Index(value=updated_node.annotation.annotation)),
],
)
annotation = cst.Annotation(annotation=annotated)
annotation = cst.Annotation(annotation=annotated) # type: ignore[assignment]
return updated_node.with_changes(annotation=annotation)

@m.leave(CONSTR_CALL)
def leave_constr_call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
RemoveImportsVisitor.remove_unused_import(context=self.context, module="pydantic", obj="constr")
self._remove_import(original_node.func)
AddImportsVisitor.add_needed_import(context=self.context, module="pydantic", obj="StringConstraints")
return updated_node.with_changes(func=cst.Name("StringConstraints"))

@m.leave(CON_NUMBER_CALL)
def leave_con_number_call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
self._remove_import(original_node.func)
AddImportsVisitor.add_needed_import(context=self.context, module="pydantic", obj="Field")
return updated_node.with_changes(func=cst.Name("Field"))

@m.leave(CON_COLLECTION_CALL)
def leave_con_collection_call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
self._remove_import(original_node.func)
AddImportsVisitor.add_needed_import(context=self.context, module="pydantic", obj="Field")
# NOTE: It's guaranteed to have at least one argument.
self.inner_type = updated_node.args[0].value
return updated_node.with_changes(func=cst.Name("Field"), args=updated_node.args[1:])

def _remove_import(self, func: cst.BaseExpression) -> None:
if m.matches(func, m.Name()):
assert isinstance(func, cst.Name)
RemoveImportsVisitor.remove_unused_import(context=self.context, module="pydantic", obj=func.value)
elif m.matches(func, m.Attribute()):
RemoveImportsVisitor.remove_unused_import(context=self.context, module="pydantic")


if __name__ == "__main__":
import textwrap
Expand All @@ -44,7 +123,14 @@ def leave_constr_call(self, original_node: cst.Call, updated_node: cst.Call) ->
class A(BaseModel):
a: constr(max_length=10)
b: Annotated[str, StringConstraints(max_length=10)]
b: conint(ge=0, le=100)
c: confloat(ge=0, le=100)
d: condecimal(ge=0, le=100)
e: conbytes(max_length=10)
f: conlist(int, min_items=1, max_items=10)
g: conset(float, min_items=1, max_items=10)
h: confrozenset(str, min_items=1, max_items=10)
i: conlist(int, unique_items=True)
"""
)
console.print(source)
Expand Down
34 changes: 33 additions & 1 deletion tests/unit/test_con_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,43 @@ class Potato(BaseModel):
potato: pydantic.constr(min_length=1, max_length=10)
"""
after = """
import pydantic
from pydantic import StringConstraints, BaseModel
from typing_extensions import Annotated
class Potato(BaseModel):
potato: Annotated[str, StringConstraints(min_length=1, max_length=10)]
"""
self.assertCodemod(before, after)

def test_conlist_to_annotated(self) -> None:
before = """
from pydantic import BaseModel, conlist
class Potato(BaseModel):
potato: conlist(str, min_items=1, max_items=10)
"""
after = """
from pydantic import Field, BaseModel
from typing import List
from typing_extensions import Annotated
class Potato(BaseModel):
potato: Annotated[List[str], Field(min_items=1, max_items=10)]
"""
self.assertCodemod(before, after)

def test_conint_to_annotated(self) -> None:
before = """
from pydantic import BaseModel, conint
class Potato(BaseModel):
potato: conint(ge=0, le=100)
"""
after = """
from pydantic import Field, BaseModel
from typing_extensions import Annotated
class Potato(BaseModel):
potato: Annotated[int, Field(ge=0, le=100)]
"""
self.assertCodemod(before, after)

0 comments on commit 4de73c6

Please sign in to comment.