Skip to content

Commit

Permalink
Merge pull request #964 from tefra/fix-954
Browse files Browse the repository at this point in the history
fix: Move ruff format in the code generator
  • Loading branch information
tefra committed Mar 2, 2024
2 parents bda6150 + 24a8fa9 commit eb610d5
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 108 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Expand Up @@ -51,7 +51,7 @@ cli = [
"docformatter>=1.7.2",
"jinja2>=2.10",
"toposort>=1.5",
"ruff>=0.1.9"
"ruff>=0.3.0"
]
docs = [
"mkdocs",
Expand Down
48 changes: 1 addition & 47 deletions tests/codegen/test_writer.py
Expand Up @@ -23,15 +23,11 @@ def setUp(self):
generator = NoneGenerator(config)
self.writer = CodeWriter(generator)

@mock.patch.object(CodeWriter, "ruff_code")
@mock.patch.object(NoneGenerator, "render_header")
@mock.patch.object(NoneGenerator, "render")
@mock.patch.object(NoneGenerator, "normalize_packages")
def test_write(
self, mock_normalize_packages, mock_render, mock_render_header, mock_ruff_code
):
def test_write(self, mock_normalize_packages, mock_render, mock_render_header):
classes = ClassFactory.list(2)
mock_ruff_code.side_effect = lambda x, y: x
with TemporaryDirectory() as tmpdir:
mock_render.return_value = [
GeneratorResult(Path(f"{tmpdir}/foo/a.py"), "file", "aAa"),
Expand All @@ -46,7 +42,6 @@ def test_write(
self.assertFalse(Path(f"{tmpdir}/c.py").exists())
mock_normalize_packages.assert_called_once_with(classes)

@mock.patch.object(CodeWriter, "ruff_code")
@mock.patch.object(NoneGenerator, "render_header")
@mock.patch.object(NoneGenerator, "render")
@mock.patch.object(NoneGenerator, "normalize_packages")
Expand All @@ -57,7 +52,6 @@ def test_print(
mock_normalize_packages,
mock_render,
mock_render_header,
mock_ruff_code,
):
classes = ClassFactory.list(2)
mock_render.return_value = [
Expand All @@ -66,7 +60,6 @@ def test_print(
GeneratorResult(Path("c.py"), "file", ""),
]
mock_render_header.return_value = "# H\n"
mock_ruff_code.side_effect = lambda x, y: x

self.writer.print(classes)

Expand All @@ -87,42 +80,3 @@ def test_from_config(self):
CodeWriter.register_generator("dataclasses", DataclassGenerator)
writer = CodeWriter.from_config(config)
self.assertIsInstance(writer.generator, DataclassGenerator)

def test_ruff_code(self):
src_code = (
"\n"
"import sys\n"
"@dataclass\n"
"\n"
"class MyType:\n"
"\n"
' value: Optional[str] = field(default=None, metadata={"type": "Element", "required": True})\n'
"\n"
"\n"
" "
)

self.writer.generator.config.output.max_line_length = 55
actual = self.writer.ruff_code(src_code, Path(__file__))
expected = (
"import sys\n"
"\n"
"\n"
"@dataclass\n"
"class MyType:\n"
" value: Optional[str] = field(\n"
" default=None,\n"
' metadata={"type": "Element", "required": True},\n'
" )\n"
)
self.assertEqual(expected, actual)

def test_format_with_invalid_code(self):
src_code = """a = "1"""
file_path = Path(__file__)

self.writer.generator.config.output.max_line_length = 55
with self.assertRaises(CodeGenerationError) as cm:
self.writer.ruff_code(src_code, file_path)

self.assertIn("Ruff failed", str(cm.exception))
39 changes: 28 additions & 11 deletions tests/formats/dataclass/test_generator.py
Expand Up @@ -3,6 +3,7 @@
from unittest import mock

from xsdata.codegen.resolver import DependenciesResolver
from xsdata.exceptions import CodeGenerationError
from xsdata.formats.dataclass.generator import DataclassGenerator
from xsdata.models.config import GeneratorConfig
from xsdata.utils.testing import ClassFactory, FactoryTestCase
Expand Down Expand Up @@ -45,12 +46,16 @@ def test_render(self, mock_render_module, mock_render_package):
self.assertEqual(expected, actual)
mock_render_package.assert_has_calls(
[
mock.call([classes[0]], "foo.bar"),
mock.call([classes[1]], "bar.foo"),
mock.call([classes[2]], "thug.life"),
mock.call([classes[0]], "foo.bar", cwd.joinpath("foo/bar/__init__.py")),
mock.call([classes[1]], "bar.foo", cwd.joinpath("bar/foo/__init__.py")),
mock.call(
[classes[2]], "thug.life", cwd.joinpath("thug/life/__init__.py")
),
]
)
mock_render_module.assert_has_calls([mock.call(mock.ANY, [x]) for x in classes])
mock_render_module.assert_has_calls(
[mock.call(mock.ANY, [x], mock.ANY) for x in classes]
)

def test_render_package(self):
classes = [
Expand All @@ -62,7 +67,7 @@ def test_render_package(self):

random.shuffle(classes)

actual = self.generator.render_package(classes, "foo.tests")
actual = self.generator.render_package(classes, "foo.tests", Path.cwd())
expected = "\n".join(
[
"from foo.bar import A as BarA",
Expand Down Expand Up @@ -94,7 +99,7 @@ def test_render_module(self):

resolver = DependenciesResolver({})

actual = self.generator.render_module(resolver, classes)
actual = self.generator.render_module(resolver, classes, Path.cwd())
expected = (
"from dataclasses import dataclass, field\n"
"from enum import Enum\n"
Expand All @@ -110,6 +115,7 @@ def test_render_module(self):
" :cvar ATTR_B: I am a member\n"
" :cvar ATTR_C:\n"
' """\n'
"\n"
" ATTR_B = None\n"
" ATTR_C = None\n"
"\n"
Expand All @@ -120,6 +126,7 @@ def test_render_module(self):
" :ivar attr_d: I am a field\n"
" :ivar attr_e:\n"
' """\n'
"\n"
" class Meta:\n"
' name = "class_C"\n'
"\n"
Expand All @@ -128,14 +135,14 @@ def test_render_module(self):
" metadata={\n"
' "name": "attr_D",\n'
' "type": "Element",\n'
" }\n"
" },\n"
" )\n"
" attr_e: Optional[str] = field(\n"
" default=None,\n"
" metadata={\n"
' "name": "attr_E",\n'
' "type": "Element",\n'
" }\n"
" },\n"
" )\n"
"\n"
"\n"
Expand All @@ -153,7 +160,7 @@ def test_render_module_with_mixed_target_namespaces(self):
]
resolver = DependenciesResolver({})

actual = self.generator.render_module(resolver, classes)
actual = self.generator.render_module(resolver, classes, Path.cwd())
expected = (
"from dataclasses import dataclass, field\n"
"from typing import Optional\n"
Expand All @@ -170,7 +177,7 @@ def test_render_module_with_mixed_target_namespaces(self):
" metadata={\n"
' "name": "attr_C",\n'
' "type": "Element",\n'
" }\n"
" },\n"
" )\n"
"\n"
"\n"
Expand All @@ -185,7 +192,7 @@ def test_render_module_with_mixed_target_namespaces(self):
" metadata={\n"
' "name": "attr_B",\n'
' "type": "Element",\n'
" }\n"
" },\n"
" )\n"
)

Expand All @@ -205,3 +212,13 @@ def test_package_name(self):
)

self.assertEqual("", self.generator.package_name(""))

def test_format_with_invalid_code(self):
src_code = """a = "1"""
file_path = Path(__file__)

self.generator.config.output.max_line_length = 55
with self.assertRaises(CodeGenerationError) as cm:
self.generator.ruff_code(src_code, file_path)

self.assertIn("Ruff failed", str(cm.exception))
43 changes: 2 additions & 41 deletions xsdata/codegen/writer.py
@@ -1,6 +1,3 @@
import subprocess
from pathlib import Path
from textwrap import indent
from typing import ClassVar, Dict, List, Type

from xsdata.codegen.models import Class
Expand Down Expand Up @@ -46,7 +43,7 @@ def write(self, classes: List[Class]):
for result in self.generator.render(classes):
if result.source.strip():
logger.info("Generating package: %s", result.title)
src_code = self.ruff_code(header + result.source, result.path)
src_code = header + result.source
result.path.parent.mkdir(parents=True, exist_ok=True)
result.path.write_text(src_code, encoding="utf-8")

Expand All @@ -60,7 +57,7 @@ def print(self, classes: List[Class]):
header = self.generator.render_header()
for result in self.generator.render(classes):
if result.source.strip():
src_code = self.ruff_code(header + result.source, result.path)
src_code = header + result.source
print(src_code, end="")

@classmethod
Expand Down Expand Up @@ -101,39 +98,3 @@ def unregister_generator(cls, name: str):
name: The generator name
"""
cls.generators.pop(name)

def ruff_code(self, src_code: str, file_path: Path) -> str:
"""Run ruff format on the src code.
Args:
src_code: The output source code
file_path: The file path the source code will be written to
Returns:
The formatted output source code
"""
commands = [
[
"ruff",
"format",
"--stdin-filename",
str(file_path),
"--line-length",
str(self.generator.config.output.max_line_length),
],
]
try:
src_code_encoded = src_code.encode()
for command in commands:
result = subprocess.run(
command,
input=src_code_encoded,
capture_output=True,
check=True,
)
src_code_encoded = result.stdout

return src_code_encoded.decode()
except subprocess.CalledProcessError as e:
error = indent(e.stderr.decode(), " ")
raise CodeGenerationError(f"Ruff failed:\n{error}")

0 comments on commit eb610d5

Please sign in to comment.