Skip to content

Commit

Permalink
✨ Support replacement of Config class by model_config
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Jun 16, 2023
1 parent a6edf91 commit 31ef212
Show file tree
Hide file tree
Showing 14 changed files with 293 additions and 79 deletions.
19 changes: 11 additions & 8 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v2

- name: set up python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: set up python
uses: actions/setup-python@v4
with:
python-version: "3.10"

- uses: pre-commit/action@v3.0.0
with:
extra_args: --all-files
- name: Install hatch
run: pip install hatch

- uses: pre-commit/action@v3.0.0
with:
extra_args: --all-files
16 changes: 2 additions & 14 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,7 @@ repos:
hooks:
- id: lint
name: Lint
entry: ruff
args:
- --fix
types: [python]
language: system
- id: mypy
name: Mypy
entry: mypy
types: [python]
language: system
# pass_filenames: false
- id: pyupgrade
name: Pyupgrade
entry: pyupgrade --py37-plus
entry: hatch run lint
types: [python]
language: system
pass_filenames: false
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,25 @@ class User(BaseModel):

#### BP003: Replace `Config` class by `model_config`

- ✅ Replace `Config` class by `model_config = ConfigDict()`.

The following code will be transformed:

```py
class User(BaseModel):
name: str

class Config:
extra = 'forbid'
```

Into:

```py
class User(BaseModel):
name: str

model_config = ConfigDict(extra='forbid')
```

#### BP004: Replace `BaseModel` methods
2 changes: 2 additions & 0 deletions bump_pydantic/codemods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
from libcst.codemod.visitors import AddImportsVisitor

from bump_pydantic.codemods.add_default_none import AddDefaultNoneCommand
from bump_pydantic.codemods.replace_config import ReplaceConfigCodemod
from bump_pydantic.codemods.replace_imports import ReplaceImportsCodemod


def gather_codemods() -> List[Type[ContextAwareTransformer]]:
return [
AddDefaultNoneCommand,
ReplaceConfigCodemod,
ReplaceImportsCodemod,
# AddImportsVisitor needs to be the last.
AddImportsVisitor,
Expand Down
18 changes: 4 additions & 14 deletions bump_pydantic/codemods/add_default_none.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,14 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None:
if fqn.name in self.context.scratch[BASE_MODEL_CONTEXT_KEY]:
self.inside_base_model = True

def leave_ClassDef(
self, original_node: cst.ClassDef, updated_node: cst.ClassDef
) -> cst.ClassDef:
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
self.inside_base_model = False
return updated_node

def visit_AnnAssign(self, node: cst.AnnAssign) -> bool | None:
if m.matches(
node.annotation.annotation,
m.Subscript(
m.Name("Optional") | m.Attribute(m.Name("typing"), m.Name("Optional"))
)
m.Subscript(m.Name("Optional") | m.Attribute(m.Name("typing"), m.Name("Optional")))
| m.Subscript(
m.Name("Union") | m.Attribute(m.Name("typing"), m.Name("Union")),
slice=[
Expand All @@ -85,14 +81,8 @@ def visit_AnnAssign(self, node: cst.AnnAssign) -> bool | None:
self.should_add_none = True
return super().visit_AnnAssign(node)

def leave_AnnAssign(
self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign
) -> cst.AnnAssign:
if (
self.inside_base_model
and self.should_add_none
and updated_node.value is None
):
def leave_AnnAssign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.AnnAssign:
if self.inside_base_model and self.should_add_none and updated_node.value is None:
updated_node = updated_node.with_changes(value=cst.Name("None"))
self.inside_an_assign = False
self.should_add_none = False
Expand Down
181 changes: 181 additions & 0 deletions bump_pydantic/codemods/replace_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
from typing import List

import libcst as cst
from libcst import matchers as m
from libcst._nodes.module import Module
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.codemod.visitors import AddImportsVisitor
from libcst.metadata import PositionProvider

REMOVED_KEYS = [
"allow_mutation",
"error_msg_templates",
"fields",
"getter_dict",
"smart_union",
"underscore_attrs_are_private",
"json_loads",
"json_dumps",
"json_encoders",
"copy_on_model_validation",
"post_init_call",
]
RENAMED_KEYS = {
"allow_population_by_field_name": "populate_by_name",
"anystr_lower": "str_to_lower",
"anystr_strip_whitespace": "str_strip_whitespace",
"anystr_upper": "str_to_upper",
"keep_untouched": "ignored_types",
"max_anystr_length": "str_max_length",
"min_anystr_length": "str_min_length",
"orm_mode": "from_attributes",
"schema_extra": "json_schema_extra",
"validate_all": "validate_default",
}
# TODO: The codemod should not replace `Config` in case of removed keys, right?

base_model_with_config = m.ClassDef(
bases=[
m.ZeroOrMore(),
m.Arg(),
m.ZeroOrMore(),
],
body=m.IndentedBlock(
body=[
m.ZeroOrMore(),
m.ClassDef(name=m.Name(value="Config"), bases=[]),
m.ZeroOrMore(),
]
),
)
base_model_with_config_child = m.ClassDef(
bases=[
m.ZeroOrMore(),
m.Arg(),
m.ZeroOrMore(),
],
body=m.IndentedBlock(
body=[
m.ZeroOrMore(),
m.ClassDef(name=m.Name(value="Config"), bases=[m.AtLeastN(n=1)]),
m.ZeroOrMore(),
]
),
)


class ReplaceConfigCodemod(VisitorBasedCodemodCommand):
"""Replace `Config` class by `ConfigDict` call."""

METADATA_DEPENDENCIES = (PositionProvider,)

def __init__(self, context: CodemodContext) -> None:
super().__init__(context)

self.inside_config_class = False

self.config_args: List[cst.Arg] = []

@m.visit(m.ClassDef(name=m.Name(value="Config")))
def visit_config_class(self, node: cst.ClassDef) -> None:
self.inside_config_class = True

@m.leave(m.ClassDef(name=m.Name(value="Config")))
def leave_config_class(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
self.inside_config_class = False
return updated_node

def visit_Assign(self, node: cst.Assign) -> None:
# NOTE: There's no need for the `leave_Assign`.
self.assign_value = node.value

def visit_AssignTarget(self, node: cst.AssignTarget) -> None:
self.config_args.append(
cst.Arg(
keyword=node.target, # type: ignore[arg-type]
value=self.assign_value,
equal=cst.AssignEqual(
whitespace_before=cst.SimpleWhitespace(""),
whitespace_after=cst.SimpleWhitespace(""),
),
)
)

def leave_Module(self, original_node: Module, updated_node: Module) -> Module:
return updated_node

@m.leave(base_model_with_config_child)
def leave_config_class_child(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
position = self.get_metadata(PositionProvider, original_node)
print("You'll need to manually replace the `Config` class to the `model_config` attribute.")
print(
"File: {filename}:-{start_line},{start_column}:{end_line},{end_column}".format(
filename=self.context.filename,
start_line=position.start.line,
start_column=position.start.column,
end_line=position.end.line,
end_column=position.end.column,
)
)
return updated_node

@m.leave(base_model_with_config)
def leave_config_class_childless(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
"""Replace the `Config` class with a `model_config` attribute.
Any class that contains a `Config` class will have that class replaced
with a `model_config` attribute. The `model_config` attribute will be
assigned a `ConfigDict` object with the same arguments as the attributes
from `Config` class.
"""
AddImportsVisitor.add_needed_import(context=self.context, module="pydantic", obj="ConfigDict")
block = cst.ensure_type(original_node.body, cst.IndentedBlock)
body = [
cst.SimpleStatementLine(
body=[
cst.Assign(
targets=[cst.AssignTarget(target=cst.Name("model_config"))],
value=cst.Call(
func=cst.Name("ConfigDict"),
args=self.config_args,
),
)
],
)
if m.matches(statement, m.ClassDef(name=m.Name(value="Config")))
else statement
for statement in block.body
]
self.config_args = []
return updated_node.with_changes(body=updated_node.body.with_changes(body=body))


if __name__ == "__main__":
import textwrap

from rich.console import Console

console = Console()

source = textwrap.dedent(
"""
from pydantic import BaseModel
class A(BaseModel):
class Config:
arbitrary_types_allowed = True
"""
)
console.print(source)
console.print("=" * 80)

mod = cst.parse_module(source)
context = CodemodContext(filename="main.py")
wrapper = cst.MetadataWrapper(mod)
command = ReplaceConfigCodemod(context=context)

mod = wrapper.visit(command)
wrapper = cst.MetadataWrapper(mod)
command = AddImportsVisitor(context=context) # type: ignore[assignment]
mod = wrapper.visit(command)
console.print(mod.code)
14 changes: 3 additions & 11 deletions bump_pydantic/codemods/replace_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,23 +96,15 @@ class ImportInfo:

class ReplaceImportsCodemod(VisitorBasedCodemodCommand):
@m.leave(IMPORT_MATCH)
def leave_replace_import(
self, _: cst.ImportFrom, updated_node: cst.ImportFrom
) -> cst.ImportFrom:
def leave_replace_import(self, _: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
for import_info in IMPORT_INFOS:
if m.matches(updated_node, import_info.import_from):
aliases: Sequence[cst.ImportAlias] = updated_node.names # type: ignore
# If multiple objects are imported in a single import statement,
# we need to remove only the one we're replacing.
AddImportsVisitor.add_needed_import(
self.context, *import_info.to_import_str
)
AddImportsVisitor.add_needed_import(self.context, *import_info.to_import_str)
if len(updated_node.names) > 1: # type: ignore
names = [
alias
for alias in aliases
if alias.name.value != import_info.to_import_str[-1]
]
names = [alias for alias in aliases if alias.name.value != import_info.to_import_str[-1]]
updated_node = updated_node.with_changes(names=names)
else:
return cst.RemoveFromParent() # type: ignore[return-value]
Expand Down
8 changes: 2 additions & 6 deletions bump_pydantic/codemods/replace_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ def func() -> A:
}

MATCH_DEPRECATED_METHODS = m.Call(
func=m.Attribute(
attr=m.Name(value=m.MatchIfTrue(lambda value: value in DEPRECATED_METHODS))
)
func=m.Attribute(attr=m.Name(value=m.MatchIfTrue(lambda value: value in DEPRECATED_METHODS)))
)


Expand All @@ -96,9 +94,7 @@ def visit_Assign(self, node: cst.Assign) -> bool | None:
# TODO: Add a warning in case you find a method that matches the rules, but it's not
# identified as a BaseModel instance.
@m.leave(MATCH_DEPRECATED_METHODS)
def leave_deprecated_methods(
self, original_node: cst.Call, updated_node: cst.Call
) -> cst.Call:
def leave_deprecated_methods(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
self.get_metadata(QualifiedNameProvider, original_node)
# print("hi")
self.get_metadata(ScopeProvider, original_node)
Expand Down
9 changes: 5 additions & 4 deletions bump_pydantic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import time
from pathlib import Path
from typing import Any, Dict

import libcst as cst
from libcst.codemod import CodemodContext
Expand Down Expand Up @@ -36,14 +37,14 @@ def main(
version: bool = Option(None, "--version", callback=version_callback, is_eager=True),
):
cwd = os.getcwd()
files = [path.absolute() for path in package.glob("**/*.py")]
files = [str(file.relative_to(cwd)) for file in files]
files_str = [path.absolute() for path in package.glob("**/*.py")]
files = [str(file.relative_to(cwd)) for file in files_str]

providers = {ScopeProvider, PositionProvider, FullyQualifiedNameProvider}
metadata_manager = FullRepoManager(cwd, files, providers=providers)
metadata_manager = FullRepoManager(cwd, files, providers=providers) # type: ignore[arg-type]
metadata_manager.resolve_cache()

scratch = {}
scratch: Dict[str, Any] = {}
for filename in files:
code = Path(filename).read_text()
module = cst.parse_module(code)
Expand Down
4 changes: 1 addition & 3 deletions bump_pydantic/markers/find_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
CONTEXT_KEY = "find_base_model"


def revert_dictionary(
classes: defaultdict[str, set[str]]
) -> defaultdict[str, set[str]]:
def revert_dictionary(classes: defaultdict[str, set[str]]) -> defaultdict[str, set[str]]:
revert_classes: defaultdict[str, set[str]] = defaultdict(set)
for cls, bases in classes.copy().items():
for base in bases:
Expand Down
Loading

0 comments on commit 31ef212

Please sign in to comment.