Skip to content

Commit

Permalink
Arrange db fixtures in a more extensible way
Browse files Browse the repository at this point in the history
The current scheme is not conducive to adding additional modifier
fixtures similar to ``django_db_reset_sequence``, such as a fixture for
``serialized_rollback`` or for specifying databases.

Instead, arrange it such that there is a base helper fixture
`_django_db_helper` which does all the work, and the other fixtures
merely exist to modify it.
  • Loading branch information
bluetech committed Nov 28, 2021
1 parent eeeb163 commit d6b9563
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 84 deletions.
100 changes: 50 additions & 50 deletions pytest_django/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,24 +138,30 @@ def teardown_database() -> None:
request.addfinalizer(teardown_database)


def _django_db_fixture_helper(
@pytest.fixture()
def _django_db_helper(
request,
django_db_setup: None,
django_db_blocker,
transactional: bool = False,
reset_sequences: bool = False,
) -> None:
from django import VERSION

if is_django_unittest(request):
return

if not transactional and "live_server" in request.fixturenames:
# Do nothing, we get called with transactional=True, too.
return
marker = request.node.get_closest_marker("django_db")
if marker:
transactional, reset_sequences, _databases = validate_django_db(marker)
else:
transactional, reset_sequences, _databases = False, False, None

_databases = getattr(
request.node, "_pytest_django_databases", None,
) # type: Optional[_DjangoDbDatabases]
transactional = transactional or (
"transactional_db" in request.fixturenames
or "live_server" in request.fixturenames
)
reset_sequences = reset_sequences or (
"django_db_reset_sequences" in request.fixturenames
)

django_db_blocker.unblock()
request.addfinalizer(django_db_blocker.restore)
Expand Down Expand Up @@ -186,6 +192,26 @@ class PytestDjangoTestCase(test_case_class): # type: ignore[misc,valid-type]
request.addfinalizer(test_case._post_teardown)


def validate_django_db(marker) -> "_DjangoDb":
"""Validate the django_db marker.
It checks the signature and creates the ``transaction``,
``reset_sequences`` and ``databases`` attributes on the marker
which will have the correct values.
A sequence reset is only allowed when combined with a transaction.
"""

def apifun(
transaction: bool = False,
reset_sequences: bool = False,
databases: "_DjangoDbDatabases" = None,
) -> "_DjangoDb":
return transaction, reset_sequences, databases

return apifun(*marker.args, **marker.kwargs)


def _disable_migrations() -> None:
from django.conf import settings
from django.core.management.commands import migrate
Expand Down Expand Up @@ -229,41 +255,24 @@ def _set_suffix_to_test_databases(suffix: str) -> None:


@pytest.fixture(scope="function")
def db(
request,
django_db_setup: None,
django_db_blocker,
) -> None:
def db(_django_db_helper: None) -> None:
"""Require a django test database.
This database will be setup with the default fixtures and will have
the transaction management disabled. At the end of the test the outer
transaction that wraps the test itself will be rolled back to undo any
changes to the database (in case the backend supports transactions).
This is more limited than the ``transactional_db`` resource but
This is more limited than the ``transactional_db`` fixture but
faster.
If multiple database fixtures are requested, they take precedence
over each other in the following order (the last one wins): ``db``,
``transactional_db``, ``django_db_reset_sequences``.
If both ``db`` and ``transactional_db`` are requested,
``transactional_db`` takes precedence.
"""
if "django_db_reset_sequences" in request.fixturenames:
request.getfixturevalue("django_db_reset_sequences")
if (
"transactional_db" in request.fixturenames
or "live_server" in request.fixturenames
):
request.getfixturevalue("transactional_db")
else:
_django_db_fixture_helper(request, django_db_blocker, transactional=False)
# The `_django_db_helper` fixture checks if `db` is requested.


@pytest.fixture(scope="function")
def transactional_db(
request,
django_db_setup: None,
django_db_blocker,
) -> None:
def transactional_db(_django_db_helper: None) -> None:
"""Require a django test database with transaction support.
This will re-initialise the django database for each test and is
Expand All @@ -272,35 +281,26 @@ def transactional_db(
If you want to use the database with transactions you must request
this resource.
If multiple database fixtures are requested, they take precedence
over each other in the following order (the last one wins): ``db``,
``transactional_db``, ``django_db_reset_sequences``.
If both ``db`` and ``transactional_db`` are requested,
``transactional_db`` takes precedence.
"""
if "django_db_reset_sequences" in request.fixturenames:
request.getfixturevalue("django_db_reset_sequences")
_django_db_fixture_helper(request, django_db_blocker, transactional=True)
# The `_django_db_helper` fixture checks if `transactional_db` is requested.


@pytest.fixture(scope="function")
def django_db_reset_sequences(
request,
django_db_setup: None,
django_db_blocker,
_django_db_helper: None,
transactional_db: None,
) -> None:
"""Require a transactional test database with sequence reset support.
This behaves like the ``transactional_db`` fixture, with the addition
of enforcing a reset of all auto increment sequences. If the enquiring
This requests the ``transactional_db`` fixture, and additionally
enforces a reset of all auto increment sequences. If the enquiring
test relies on such values (e.g. ids as primary keys), you should
request this resource to ensure they are consistent across tests.
If multiple database fixtures are requested, they take precedence
over each other in the following order (the last one wins): ``db``,
``transactional_db``, ``django_db_reset_sequences``.
"""
_django_db_fixture_helper(
request, django_db_blocker, transactional=True, reset_sequences=True
)
# The `_django_db_helper` fixture checks if `django_db_reset_sequences`
# is requested.


@pytest.fixture()
Expand Down
36 changes: 3 additions & 33 deletions pytest_django/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pytest

from .django_compat import is_django_unittest # noqa
from .fixtures import _django_db_helper # noqa
from .fixtures import _live_server_helper # noqa
from .fixtures import admin_client # noqa
from .fixtures import admin_user # noqa
Expand All @@ -40,6 +41,7 @@
from .fixtures import rf # noqa
from .fixtures import settings # noqa
from .fixtures import transactional_db # noqa
from .fixtures import validate_django_db
from .lazy_django import django_settings_is_configured, skip_if_no_django


Expand All @@ -49,8 +51,6 @@

import django

from .fixtures import _DjangoDb, _DjangoDbDatabases


SETTINGS_MODULE_ENV = "DJANGO_SETTINGS_MODULE"
CONFIGURATION_ENV = "DJANGO_CONFIGURATION"
Expand Down Expand Up @@ -464,17 +464,7 @@ def _django_db_marker(request) -> None:
"""
marker = request.node.get_closest_marker("django_db")
if marker:
transaction, reset_sequences, databases = validate_django_db(marker)

# TODO: Use pytest Stash (item.stash) once that's stable.
request.node._pytest_django_databases = databases

if reset_sequences:
request.getfixturevalue("django_db_reset_sequences")
elif transaction:
request.getfixturevalue("transactional_db")
else:
request.getfixturevalue("db")
request.getfixturevalue("_django_db_helper")


@pytest.fixture(autouse=True, scope="class")
Expand Down Expand Up @@ -743,26 +733,6 @@ def restore(self) -> None:
_blocking_manager = _DatabaseBlocker()


def validate_django_db(marker) -> "_DjangoDb":
"""Validate the django_db marker.
It checks the signature and creates the ``transaction``,
``reset_sequences`` and ``databases`` attributes on the marker
which will have the correct values.
A sequence reset is only allowed when combined with a transaction.
"""

def apifun(
transaction: bool = False,
reset_sequences: bool = False,
databases: "_DjangoDbDatabases" = None,
) -> "_DjangoDb":
return transaction, reset_sequences, databases

return apifun(*marker.args, **marker.kwargs)


def validate_urls(marker) -> List[str]:
"""Validate the urls marker.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_db_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def test_run_second_transaction_test_case(self):
"*test_run_first_django_test_case*",
"*test_run_second_decorator*",
"*test_run_second_fixture*",
"*test_run_second_reset_sequences_fixture*",
"*test_run_second_reset_sequences_decorator*",
"*test_run_second_transaction_test_case*",
"*test_run_second_reset_sequences_fixture*",
"*test_run_last_test_case*",
"*test_run_last_simple_test_case*",
])
Expand Down

0 comments on commit d6b9563

Please sign in to comment.