Skip to content

Commit

Permalink
Merge pull request #160 from pytest-dev/extract-attibute-processing
Browse files Browse the repository at this point in the history
Refactor FixtureDef generation logic
  • Loading branch information
youtux committed May 26, 2022
2 parents 3110773 + 53ed432 commit 4886339
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 71 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ disallow_untyped_decorators = true
disallow_any_explicit = false
disallow_any_generics = true
disallow_untyped_calls = true
disallow_untyped_defs = true
ignore_errors = false
ignore_missing_imports = true
implicit_reexport = false
Expand All @@ -27,9 +28,11 @@ warn_redundant_casts = true
warn_unused_configs = true
warn_unreachable = true
warn_no_return = true
warn_return_any = true
pretty = true
show_error_codes = true

[[tool.mypy.overrides]]
module = ["tests.*"]
disallow_untyped_decorators = false
disallow_untyped_defs = false
6 changes: 3 additions & 3 deletions pytest_factoryboy/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class FixtureDef:
related: list[str] = field(default_factory=list)

@property
def kwargs_var_name(self):
def kwargs_var_name(self) -> str:
return f"_{self.name}__kwargs"


Expand Down Expand Up @@ -96,7 +96,7 @@ def make_temp_folder(package_name: str) -> pathlib.Path:


@lru_cache() # This way we reuse the same folder for the whole execution of the program
def create_package(package_name: str, init_py_content=init_py_content) -> pathlib.Path:
def create_package(package_name: str, init_py_content: str = init_py_content) -> pathlib.Path:
path = cache_dir / package_name
try:
if path.exists():
Expand Down Expand Up @@ -134,7 +134,7 @@ def make_module(code: str, module_name: str, package_name: str) -> ModuleType:
return mod


def make_fixture_model_module(model_name, fixture_defs: list[FixtureDef]):
def make_fixture_model_module(model_name: str, fixture_defs: list[FixtureDef]) -> ModuleType:
code = module_template.render(fixture_defs=fixture_defs)
generated_module = make_module(code, module_name=model_name, package_name="_pytest_factoryboy_generated_fixtures")
for fixture_def in fixture_defs:
Expand Down
161 changes: 94 additions & 67 deletions pytest_factoryboy/fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,30 @@
import sys
from dataclasses import dataclass
from inspect import signature
from typing import TYPE_CHECKING, overload
from typing import TYPE_CHECKING, Type, cast, overload

import factory
import factory.builder
import factory.declarations
import factory.enums
import inflection
from factory.declarations import NotProvided
from typing_extensions import TypeAlias

from .codegen import FixtureDef, make_fixture_model_module
from .compat import PostGenerationContext

FactoryType: TypeAlias = Type[factory.Factory]

if TYPE_CHECKING:
from typing import Any, Callable, TypeVar
from typing import Any, Callable, Iterable, Mapping, TypeVar

from _pytest.fixtures import FixtureFunction, SubRequest
from factory.builder import BuildStep
from factory.declarations import PostGeneration, PostGenerationContext

FactoryType = type[factory.Factory]
from .plugin import Request as FactoryboyRequest

T = TypeVar("T")
F = TypeVar("F", bound=FactoryType)

Expand Down Expand Up @@ -85,75 +89,54 @@ def register_(factory_class: F) -> F:
assert not factory_class._meta.abstract, "Can't register abstract factories."
assert factory_class._meta.model is not None, "Factory model class is not specified."

fixture_defs: list[FixtureDef] = []

model_name = get_model_name(factory_class) if _name is None else _name
factory_name = get_factory_name(factory_class)

deps = get_deps(factory_class, model_name=model_name)
related: list[str] = []
fixture_defs = list(
generate_fixturedefs(
factory_class=factory_class, model_name=model_name, overrides=kwargs, caller_locals=_caller_locals
)
)

for attr, value in factory_class._meta.declarations.items():
args = []
attr_name = SEPARATOR.join((model_name, attr))
value = kwargs.get(attr, value)

if isinstance(value, (factory.SubFactory, factory.RelatedFactory)):
subfactory_class = value.get_factory()
subfactory_deps = get_deps(subfactory_class, factory_class)

args = list(subfactory_deps)
if isinstance(value, factory.RelatedFactory):
related_model = get_model_name(subfactory_class)
args.append(related_model)
related.append(related_model)
related.append(attr_name)
related.extend(subfactory_deps)

if isinstance(value, factory.SubFactory):
args.append(inflection.underscore(subfactory_class._meta.model.__name__))

fixture_defs.append(
FixtureDef(
name=attr_name,
function_name="subfactory_fixture",
function_kwargs={"factory_class": subfactory_class},
deps=args,
)
)
continue
generated_module = make_fixture_model_module(model_name, fixture_defs)

if isinstance(value, factory.PostGeneration):
default_value = None
elif isinstance(value, factory.PostGenerationMethodCall):
default_value = value.method_arg
else:
default_value = value
for fixture_def in fixture_defs:
exported_name = fixture_def.name
fixture_function = getattr(generated_module, exported_name)
inject_into_caller(exported_name, fixture_function, _caller_locals)

return factory_class

value = kwargs.get(attr, default_value)

if isinstance(value, LazyFixture):
args = value.args
def generate_fixturedefs(
factory_class: FactoryType, model_name: str, overrides: Mapping[str, Any], caller_locals: Mapping[str, Any]
) -> Iterable[FixtureDef]:
"""Generate all the FixtureDefs for the given factory class."""
factory_name = get_factory_name(factory_class)

fixture_defs.append(
FixtureDef(
name=attr_name,
function_name="attr_fixture",
function_kwargs={"value": value},
deps=args,
related: list[str] = []
for attr, value in factory_class._meta.declarations.items():
value = overrides.get(attr, value)
attr_name = SEPARATOR.join((model_name, attr))
yield (
make_declaration_fixturedef(
attr_name=attr_name,
value=value,
factory_class=factory_class,
related=related,
)
)

if factory_name not in _caller_locals:
fixture_defs.append(
if factory_name not in caller_locals:
yield (
FixtureDef(
name=factory_name,
function_name="factory_fixture",
function_kwargs={"factory_class": factory_class},
)
)

fixture_defs.append(
deps = get_deps(factory_class, model_name=model_name)
yield (
FixtureDef(
name=model_name,
function_name="model_fixture",
Expand All @@ -163,14 +146,56 @@ def register_(factory_class: F) -> F:
)
)

generated_module = make_fixture_model_module(model_name, fixture_defs)

for fixture_def in fixture_defs:
exported_name = fixture_def.name
fixture_function = getattr(generated_module, exported_name)
inject_into_caller(exported_name, fixture_function, _caller_locals)
def make_declaration_fixturedef(
attr_name: str,
value: Any,
factory_class: FactoryType,
related: list[str],
) -> FixtureDef:
"""Create the FixtureDef for a factory declaration."""
if isinstance(value, (factory.SubFactory, factory.RelatedFactory)):
subfactory_class = value.get_factory()
subfactory_deps = get_deps(subfactory_class, factory_class)

args = list(subfactory_deps)
if isinstance(value, factory.RelatedFactory):
related_model = get_model_name(subfactory_class)
args.append(related_model)
related.append(related_model)
related.append(attr_name)
related.extend(subfactory_deps)

if isinstance(value, factory.SubFactory):
args.append(inflection.underscore(subfactory_class._meta.model.__name__))

return FixtureDef(
name=attr_name,
function_name="subfactory_fixture",
function_kwargs={"factory_class": subfactory_class},
deps=args,
)

return factory_class
deps: list[str] # makes mypy happy
if isinstance(value, factory.PostGeneration):
value = None
deps = []
elif isinstance(value, factory.PostGenerationMethodCall):
value = value.method_arg
deps = []
elif isinstance(value, LazyFixture):
value = value
deps = value.args
else:
value = value
deps = []

return FixtureDef(
name=attr_name,
function_name="attr_fixture",
function_kwargs={"value": value},
deps=deps,
)


def inject_into_caller(name: str, function: Callable[..., Any], locals_: dict[str, Any]) -> None:
Expand Down Expand Up @@ -241,21 +266,23 @@ def evaluate(request: SubRequest, value: LazyFixture | Any) -> Any:

def model_fixture(request: SubRequest, factory_name: str) -> Any:
"""Model fixture implementation."""
factoryboy_request = request.getfixturevalue("factoryboy_request")
factoryboy_request: FactoryboyRequest = request.getfixturevalue("factoryboy_request")

# Try to evaluate as much post-generation dependencies as possible
factoryboy_request.evaluate(request)

assert request.fixturename # NOTE: satisfy mypy
fixture_name = request.fixturename
prefix = "".join((fixture_name, SEPARATOR))
# NOTE: following type hinting is required, because of `mypy` bug.
# Reference: https://github.com/python/mypy/issues/2477
factory_class: factory.base.FactoryMetaClass = request.getfixturevalue(factory_name)

factory_class: FactoryType = request.getfixturevalue(factory_name)

# Create model fixture instance
class Factory(factory_class):
pass
Factory: FactoryType = cast(FactoryType, type("Factory", (factory_class,), {}))
# equivalent to:
# class Factory(factory_class):
# pass
# it just makes mypy understand it.

Factory._meta.base_declarations = {
k: v
Expand Down
19 changes: 19 additions & 0 deletions tests/test_lazy_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,22 @@ def test_lazy_attribute(user: User):
def test_lazy_attribute_partial(partial_user: User):
"""Test LazyFixture value is extracted before the LazyAttribute is called. Partial."""
assert partial_user.is_active


class TestLazyFixtureDeclaration:
@pytest.fixture
def name(self):
return "from fixture name"

@register
class UserFactory(factory.Factory):
class Meta:
model = User

username = LazyFixture("name")
password = "foo"
is_active = False

def test_lazy_fixture_declaration(self, user):
"""Test that we can use the LazyFixture declaration in the factory itself."""
assert user.username == "from fixture name"
2 changes: 1 addition & 1 deletion tests/test_postgen_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class BarFactory(factory.Factory):
@classmethod
def _create(cls, model_class: type[Bar], foo: Foo) -> Bar:
assert foo.value == foo.expected
bar = super()._create(model_class, foo=foo)
bar: Bar = super()._create(model_class, foo=foo)
foo.bar = bar
return bar

Expand Down

0 comments on commit 4886339

Please sign in to comment.