Skip to content

Commit

Permalink
✨ Add codemod and rule to add type annotations or TODO comment for fi…
Browse files Browse the repository at this point in the history
…elds with a default and no type annotations (#163)

* ✨ Implement add_annotations codemod

* ✨ Add AddAnnotations to CLI default codemods and new rule BP010

* ✅ Add test for AddAnnotationsCommand

* 📝 Add docs for new rule BP010

* ✅ Add test for model_config field

* 🐛 Fix add_annotations codemod to respect model_config

* ✅ Fix tests for CLI, const_to_literal with Pydantic model and Enum with the same class name
  • Loading branch information
tiangolo committed Apr 18, 2024
1 parent d959573 commit 5e4a43b
Show file tree
Hide file tree
Showing 5 changed files with 377 additions and 4 deletions.
38 changes: 38 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,44 @@ class SomeThing:
field_schema['example'] = "Weird example"
```

### BP010: Add type annotations or TODO comments to fields without them

- ✅ Add type annotations based on the default value for a few types that can be inferred, like `bool`, `str`, `int`, `float`.
- ✅ Add `# TODO[pydantic]: add type annotation` comments to fields that can't be inferred.

The following code will be transformed:

```py
from pydantic import BaseModel, Field

class Potato(BaseModel):
name: str
is_sale = True
tags = ["tag1", "tag2"]
price = 10.5
description = "Some item"
active = Field(default=True)
ready = Field(True)
age = Field(10, title="Age")
```

Into:

```py
from pydantic import BaseModel, Field

class Potato(BaseModel):
name: str
is_sale: bool = True
# TODO[pydantic]: add type annotation
tags = ["tag1", "tag2"]
price: float = 10.5
description: str = "Some item"
active: bool = Field(default=True)
ready: bool = Field(True)
age: int = Field(10, title="Age")
```

<!-- ### BP010: Replace `pydantic.parse_obj_as` by `pydantic.TypeAdapter`
- ✅ Replace `pydantic.parse_obj_as(T, obj)` to `pydantic.TypeAdapter(T).validate_python(obj)`.
Expand Down
6 changes: 6 additions & 0 deletions bump_pydantic/codemods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from libcst.codemod import ContextAwareTransformer
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor

from bump_pydantic.codemods.add_annotations import AddAnnotationsCommand
from bump_pydantic.codemods.add_default_none import AddDefaultNoneCommand
from bump_pydantic.codemods.con_func import ConFuncCallCommand
from bump_pydantic.codemods.custom_types import CustomTypeCodemod
Expand Down Expand Up @@ -34,6 +35,8 @@ class Rule(str, Enum):
"""Replace `con*` functions by `Annotated` versions."""
BP009 = "BP009"
"""Mark Pydantic "protocol" functions in custom types with proper TODOs."""
BP010 = "BP010"
"""Add type annotations or TODOs to fields without them."""


def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]:
Expand Down Expand Up @@ -67,6 +70,9 @@ def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]
if Rule.BP009 not in disabled:
codemods.append(CustomTypeCodemod)

if Rule.BP010 not in disabled:
codemods.append(AddAnnotationsCommand)

# Those codemods need to be the last ones.
codemods.extend([RemoveImportsVisitor, AddImportsVisitor])
return codemods
158 changes: 158 additions & 0 deletions bump_pydantic/codemods/add_annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
from __future__ import annotations

import libcst as cst
import libcst.matchers as m
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.metadata import FullyQualifiedNameProvider, QualifiedName

from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor

COMMENT = "# TODO[pydantic]: add type annotation"


class AddAnnotationsCommand(VisitorBasedCodemodCommand):
"""This codemod adds a type annotation or TODO comment to pydantic fields without
a type annotation.
Example::
# Before
```py
from pydantic import BaseModel, Field
class Foo(BaseModel):
name: str
is_sale = True
tags = ["tag1", "tag2"]
price = 10.5
description = "Some item"
active = Field(default=True)
ready = Field(True)
age = Field(10, title="Age")
```
# After
```py
from pydantic import BaseModel, Field
class Foo(BaseModel):
name: str
is_sale: bool = True
# TODO[pydantic]: add type annotation
tags = ["tag1", "tag2"]
price: float = 10.5
description: str = "Some item"
active: bool = Field(default=True)
ready: bool = Field(True)
age: int = Field(10, title="Age")
```
"""

METADATA_DEPENDENCIES = (FullyQualifiedNameProvider,)

def __init__(self, context: CodemodContext) -> None:
super().__init__(context)

self.inside_base_model = False
self.base_model_fields: set[cst.Assign | cst.AnnAssign | cst.SimpleStatementLine] = set()
self.statement: cst.SimpleStatementLine | None = None
self.needs_comment = False
self.has_comment = False
self.in_field = False

def visit_ClassDef(self, node: cst.ClassDef) -> None:
fqn_set = self.get_metadata(FullyQualifiedNameProvider, node)

if not fqn_set:
return None

fqn: QualifiedName = next(iter(fqn_set)) # type: ignore
if fqn.name in self.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY]:
self.inside_base_model = True
self.base_model_fields = {
child for child in node.body.children if isinstance(child, cst.SimpleStatementLine)
}
return

def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
self.base_model_fields = set()
return updated_node

def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
if node not in self.base_model_fields:
return
if not self.inside_base_model:
return
self.statement = node
self.in_field = True
for line in node.leading_lines:
if m.matches(line, m.EmptyLine(comment=m.Comment(value=COMMENT))):
self.has_comment = True

def leave_SimpleStatementLine(
self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine
) -> cst.SimpleStatementLine:
if original_node not in self.base_model_fields:
return updated_node
if self.needs_comment and not self.has_comment:
updated_node = updated_node.with_changes(
leading_lines=[
*updated_node.leading_lines,
cst.EmptyLine(comment=cst.Comment(value=(COMMENT))),
],
body=[
*updated_node.body,
],
)
self.statement = None
self.needs_comment = False
self.has_comment = False
self.in_field = False
return updated_node

def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> cst.Assign | cst.AnnAssign:
if not self.in_field:
return updated_node
if self.inside_base_model:
if m.matches(updated_node, m.Assign(targets=[m.AssignTarget(target=m.Name("model_config"))])):
return updated_node
Undefined = object()
value: cst.BaseExpression | object = Undefined
if m.matches(updated_node.value, m.Call(func=m.Name("Field"))):
assert isinstance(updated_node.value, cst.Call)
args = updated_node.value.args
if args:
default_keywords = [arg.value for arg in args if arg.keyword and arg.keyword.value == "default"]
# NOTE: It has a "default" value as positional argument.
if args[0].keyword is None:
value = args[0].value
# NOTE: It has a "default" keyword argument.
elif default_keywords:
value = default_keywords[0]
else:
value = updated_node.value
if value is Undefined:
self.needs_comment = True
return updated_node

# Infer simple type annotations
ann_type = None
assert isinstance(value, cst.BaseExpression)
if m.matches(value, m.Name("True") | m.Name("False")):
ann_type = "bool"
elif m.matches(value, m.SimpleString()):
ann_type = "str"
elif m.matches(value, m.Integer()):
ann_type = "int"
elif m.matches(value, m.Float()):
ann_type = "float"

# If there's a simple inferred type annotation, return that
if ann_type:
return cst.AnnAssign(
target=updated_node.targets[0].target,
annotation=cst.Annotation(cst.Name(ann_type)),
value=updated_node.value,
)
else:
self.needs_comment = True
return updated_node
8 changes: 4 additions & 4 deletions tests/integration/cases/replace_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@
"from pydantic import BaseModel, Field",
"",
"",
"class A(str, Enum):",
"class E(str, Enum):",
" a = 'a'",
" b = 'b'",
"",
"class A(BaseModel):",
" a: A = Field(A.a, const=True)",
" a: E = Field(E.a, const=True)",
],
),
expected=File(
Expand All @@ -70,12 +70,12 @@
"from typing import Literal",
"",
"",
"class A(str, Enum):",
"class E(str, Enum):",
" a = 'a'",
" b = 'b'",
"",
"class A(BaseModel):",
" a: Literal[A.a] = A.a",
" a: Literal[E.a] = E.a",
],
),
),
Expand Down
Loading

0 comments on commit 5e4a43b

Please sign in to comment.