Skip to content

Commit

Permalink
Merge pull request #139 from pytest-dev/type-annotations
Browse files Browse the repository at this point in the history
Type annotations
  • Loading branch information
youtux committed May 1, 2022
2 parents 02974de + 2191a41 commit e056a84
Show file tree
Hide file tree
Showing 11 changed files with 190 additions and 113 deletions.
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Unreleased
----------
- Drop support for Python 3.6. We now support only python >= 3.7.
- Improve "debuggability". Internal pytest-factoryboy calls are now visible when using a debugger like PDB or PyCharm.
- Add type annotations. Now `register` and `LazyFixture` are type annotated.


2.1.0
Expand Down
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
include *.rst
include pytest_factoryboy/py.typed
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[tool.black]
line-length = 120
target-version = ['py36']
target-version = ['py37', 'py38', 'py39', 'py310', 'py310']
5 changes: 1 addition & 4 deletions pytest_factoryboy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,4 @@
__version__ = "2.1.0"


__all__ = [
register.__name__,
LazyFixture.__name__,
]
__all__ = ("register", "LazyFixture")
100 changes: 71 additions & 29 deletions pytest_factoryboy/fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import sys
from dataclasses import dataclass
from inspect import getmodule, signature

import factory
Expand All @@ -12,11 +13,36 @@

from .codegen import make_fixture_model_module, FixtureDef
from .compat import PostGenerationContext
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Any, Callable, TypeVar
from _pytest.fixtures import FixtureRequest
from factory.builder import BuildStep
from factory.declarations import PostGeneration
from factory.declarations import PostGenerationContext
from types import ModuleType

FactoryType = type[factory.Factory]
T = TypeVar("T")
F = TypeVar("F", bound=FactoryType)


SEPARATOR = "__"


def register(factory_class, _name=None, **kwargs):
@dataclass(eq=False)
class DeferredFunction:
name: str
factory: FactoryType
is_related: bool
function: Callable[[FixtureRequest], Any]

def __call__(self, request: FixtureRequest) -> Any:
return self.function(request)


def register(factory_class: F, _name: str | None = None, **kwargs: Any) -> F:
r"""Register fixtures for the factory class.
:param factory_class: Factory class to register.
Expand Down Expand Up @@ -119,7 +145,7 @@ def register(factory_class, _name=None, **kwargs):
return factory_class


def get_model_name(factory_class):
def get_model_name(factory_class: FactoryType) -> str:
"""Get model fixture name by factory."""
return (
inflection.underscore(factory_class._meta.model.__name__)
Expand All @@ -128,20 +154,24 @@ def get_model_name(factory_class):
)


def get_factory_name(factory_class):
def get_factory_name(factory_class: FactoryType) -> str:
"""Get factory fixture name by factory."""
return inflection.underscore(factory_class.__name__)


def get_deps(factory_class, parent_factory_class=None, model_name=None):
def get_deps(
factory_class: FactoryType,
parent_factory_class: FactoryType | None = None,
model_name: str | None = None,
) -> list[str]:
"""Get factory dependencies.
:return: List of the fixture argument names for dependency injection.
"""
model_name = get_model_name(factory_class) if model_name is None else model_name
parent_model_name = get_model_name(parent_factory_class) if parent_factory_class is not None else None

def is_dep(value):
def is_dep(value: Any) -> bool:
if isinstance(value, factory.RelatedFactory):
return False
if isinstance(value, factory.SubFactory) and get_model_name(value.get_factory()) == parent_model_name:
Expand All @@ -157,19 +187,19 @@ def is_dep(value):
]


def evaluate(request, value):
def evaluate(request: FixtureRequest, value: LazyFixture | Any) -> Any:
"""Evaluate the declaration (lazy fixtures, etc)."""
return value.evaluate(request) if isinstance(value, LazyFixture) else value


def model_fixture(request, factory_name):
def model_fixture(request: FixtureRequest, factory_name: str) -> Any:
"""Model fixture implementation."""
factoryboy_request = request.getfixturevalue("factoryboy_request")

# Try to evaluate as much post-generation dependencies as possible
factoryboy_request.evaluate(request)

factory_class = request.getfixturevalue(factory_name)
factory_class: FactoryType = request.getfixturevalue(factory_name)
prefix = "".join((request.fixturename, SEPARATOR))

# Create model fixture instance
Expand Down Expand Up @@ -201,7 +231,7 @@ class Factory(factory_class):
request._fixture_defs[request.fixturename] = request._fixturedef

# Defer post-generation declarations
deferred = []
deferred: list[DeferredFunction] = []

for attr in factory_class._meta.post_declarations.sorted():

Expand Down Expand Up @@ -237,7 +267,7 @@ class Factory(factory_class):
return instance


def make_deferred_related(factory, fixture, attr):
def make_deferred_related(factory: FactoryType, fixture: str, attr: str) -> DeferredFunction:
"""Make deferred function for the related factory declaration.
:param factory: Factory class.
Expand All @@ -248,17 +278,27 @@ def make_deferred_related(factory, fixture, attr):
"""
name = SEPARATOR.join((fixture, attr))

def deferred(request):
def deferred_impl(request: FixtureRequest) -> None:
# TODO: Shouldn't we return this result?
request.getfixturevalue(name)

deferred.__name__ = name
deferred._factory = factory
deferred._fixture = fixture
deferred._is_related = True
return deferred
return DeferredFunction(
name=name,
factory=factory,
is_related=True,
function=deferred_impl,
)


def make_deferred_postgen(step, factory_class, fixture, instance, attr, declaration, context):
def make_deferred_postgen(
step: BuildStep,
factory_class: FactoryType,
fixture: str,
instance: Any,
attr: str,
declaration: PostGeneration,
context: PostGenerationContext,
) -> DeferredFunction:
"""Make deferred function for the post-generation declaration.
:param step: factory_boy builder step.
Expand All @@ -272,33 +312,35 @@ def make_deferred_postgen(step, factory_class, fixture, instance, attr, declarat
"""
name = SEPARATOR.join((fixture, attr))

def deferred(request):
def deferred_impl(request: FixtureRequest) -> None:
# TODO: Shouldn't we return this result?
declaration.call(instance, step, context)

deferred.__name__ = name
deferred._factory = factory_class
deferred._fixture = fixture
deferred._is_related = False
return deferred
return DeferredFunction(
name=name,
factory=factory_class,
is_related=False,
function=deferred_impl,
)


def factory_fixture(request, factory_class):
def factory_fixture(request: FixtureRequest, factory_class: F) -> F:
"""Factory fixture implementation."""
return factory_class


def attr_fixture(request, value):
def attr_fixture(request: FixtureRequest, value: T) -> T:
"""Attribute fixture implementation."""
return value


def subfactory_fixture(request, factory_class):
def subfactory_fixture(request: FixtureRequest, factory_class: FactoryType) -> Any:
"""SubFactory/RelatedFactory fixture implementation."""
fixture = inflection.underscore(factory_class._meta.model.__name__)
return request.getfixturevalue(fixture)


def get_caller_module(depth=2):
def get_caller_module(depth: int = 2) -> ModuleType:
"""Get the module of the caller."""
frame = sys._getframe(depth)
module = getmodule(frame)
Expand All @@ -311,7 +353,7 @@ def get_caller_module(depth=2):
class LazyFixture:
"""Lazy fixture."""

def __init__(self, fixture):
def __init__(self, fixture: Callable | str) -> None:
"""Lazy pytest fixture wrapper.
:param fixture: Fixture name or callable with dependencies.
Expand All @@ -323,7 +365,7 @@ def __init__(self, fixture):
else:
self.args = [self.fixture]

def evaluate(self, request):
def evaluate(self, request: FixtureRequest) -> Any:
"""Evaluate the lazy fixture.
:param request: pytest request object.
Expand Down
52 changes: 33 additions & 19 deletions pytest_factoryboy/plugin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
"""pytest-factoryboy plugin."""
from __future__ import annotations

from collections import defaultdict
import pytest
from typing import TYPE_CHECKING


if TYPE_CHECKING:
from typing import Any
from factory import Factory
from _pytest.fixtures import FixtureRequest
from _pytest.config import PytestPluginManager
from _pytest.python import Metafunc
from _pytest.nodes import Item

from .fixture import DeferredFunction


class CycleDetected(Exception):
Expand All @@ -11,22 +24,22 @@ class CycleDetected(Exception):
class Request:
"""PyTest FactoryBoy request."""

def __init__(self):
def __init__(self) -> None:
"""Create pytest_factoryboy request."""
self.deferred = []
self.results = defaultdict(dict)
self.model_factories = {}
self.in_progress = set()
self.deferred: list[list[DeferredFunction]] = []
self.results: dict[str, dict[str, Any]] = defaultdict(dict)
self.model_factories: dict[str, type[Factory]] = {}
self.in_progress: set = set()

def defer(self, functions):
def defer(self, functions: list[DeferredFunction]) -> None:
"""Defer post-generation declaration execution until the end of the test setup.
:param functions: Functions to be deferred.
:note: Once already finalized all following defer calls will execute the function directly.
"""
self.deferred.append(functions)

def get_deps(self, request, fixture, deps=None):
def get_deps(self, request: FixtureRequest, fixture: str, deps: set[str] | None = None) -> set[str]:
request = request.getfixturevalue("request")

if deps is None:
Expand All @@ -41,40 +54,40 @@ def get_deps(self, request, fixture, deps=None):
deps.update(self.get_deps(request, argname, deps))
return deps

def get_current_deps(self, request):
def get_current_deps(self, request: FixtureRequest) -> set[str]:
deps = set()
while hasattr(request, "_parent_request"):
if request.fixturename and request.fixturename not in getattr(request, "_fixturedefs", {}):
deps.add(request.fixturename)
request = request._parent_request
return deps

def execute(self, request, function, deferred):
def execute(self, request: FixtureRequest, function: DeferredFunction, deferred: list[DeferredFunction]) -> None:
"""Execute deferred function and store the result."""
if function in self.in_progress:
raise CycleDetected()
fixture = function.__name__
fixture = function.name
model, attr = fixture.split("__", 1)
if function._is_related:
if function.is_related:
deps = self.get_deps(request, fixture)
if deps.intersection(self.get_current_deps(request)):
raise CycleDetected()
self.model_factories[model] = function._factory
self.model_factories[model] = function.factory

self.in_progress.add(function)
self.results[model][attr] = function(request)
deferred.remove(function)
self.in_progress.remove(function)

def after_postgeneration(self, request):
def after_postgeneration(self, request: FixtureRequest) -> None:
"""Call _after_postgeneration hooks."""
for model in list(self.results.keys()):
results = self.results.pop(model)
obj = request.getfixturevalue(model)
factory = self.model_factories[model]
factory._after_postgeneration(obj, create=True, results=results)

def evaluate(self, request):
def evaluate(self, request: FixtureRequest) -> None:
"""Finalize, run deferred post-generation actions, etc."""
while self.deferred:
try:
Expand All @@ -91,14 +104,15 @@ def evaluate(self, request):


@pytest.fixture
def factoryboy_request():
def factoryboy_request() -> Request:
"""PyTest FactoryBoy request fixture."""
return Request()


@pytest.mark.tryfirst
def pytest_runtest_call(item):
def pytest_runtest_call(item: Item) -> None:
"""Before the test item is called."""
# TODO: We should instead do an `if isinstance(item, Function)`.
try:
request = item._request
except AttributeError:
Expand All @@ -110,15 +124,15 @@ def pytest_runtest_call(item):
request.config.hook.pytest_factoryboy_done(request=request)


def pytest_addhooks(pluginmanager):
def pytest_addhooks(pluginmanager: PytestPluginManager) -> None:
"""Register plugin hooks."""
from pytest_factoryboy import hooks

pluginmanager.add_hookspecs(hooks)


def pytest_generate_tests(metafunc):
related = []
def pytest_generate_tests(metafunc: Metafunc) -> None:
related: list[str] = []
for arg2fixturedef in metafunc._arg2fixturedefs.values():
fixturedef = arg2fixturedef[-1]
related_fixtures = getattr(fixturedef.func, "_factoryboy_related", [])
Expand Down
Empty file added pytest_factoryboy/py.typed
Empty file.

0 comments on commit e056a84

Please sign in to comment.