Skip to content

Commit

Permalink
Merge bc86221 into 351b619
Browse files Browse the repository at this point in the history
  • Loading branch information
skarzi committed Apr 6, 2020
2 parents 351b619 + bc86221 commit 3550492
Show file tree
Hide file tree
Showing 7 changed files with 230 additions and 16 deletions.
8 changes: 4 additions & 4 deletions django_test_migrations/contrib/pytest_plugin.py
Expand Up @@ -7,7 +7,7 @@


@pytest.fixture()
def migrator_factory(transactional_db, django_db_use_migrations):
def migrator_factory(request, transactional_db, django_db_use_migrations):
"""
Pytest fixture to create migrators inside the pytest tests.
Expand Down Expand Up @@ -38,9 +38,9 @@ def test_migration(migrator_factory):
pytest.skip('--nomigrations was specified')

def factory(database_name: Optional[str] = None) -> Migrator:
# ``Migrator.reset`` is not registered as finalizer here, because
# database is flushed by ``transactional_db`` fixture's finalizers
return Migrator(database_name)
migrator = Migrator(database_name)
request.addfinalizer(migrator.reset)
return migrator
return factory


Expand Down
5 changes: 5 additions & 0 deletions django_test_migrations/contrib/unittest_case.py
Expand Up @@ -49,3 +49,8 @@ def prepare(self) -> None:
Used to prepare some data before the migration process.
"""

def tearDown(self) -> None:
"""Used to clean mess up after each test."""
self._migrator.reset()
super().tearDown()
32 changes: 21 additions & 11 deletions django_test_migrations/migrator.py
Expand Up @@ -4,11 +4,14 @@
from typing import List, Optional, Tuple, Union

from django.core.management import call_command
from django.core.management.color import no_style
from django.db import DEFAULT_DB_ALIAS, connections
from django.db.migrations.executor import MigrationExecutor
from django.db.migrations.state import ProjectState
from django.db.models.signals import post_migrate, pre_migrate

from django_test_migrations import sql

# Regular or rollback migration: 0001 -> 0002, or 0002 -> 0001
# Rollback migration to initial state: 0001 -> None
_Migration = Tuple[str, Optional[str]]
Expand Down Expand Up @@ -65,21 +68,28 @@ def __init__(

def before(self, migrate_from: _MigrationSpec) -> ProjectState:
"""Reverse back to the original migration."""
if not isinstance(migrate_from, list):
migrate_from = [migrate_from]
with _mute_migrate_signals():
return self._executor.migrate(migrate_from)
style = no_style()
# start from clean database state
sql.drop_models_tables(self._database, style)
sql.flush_django_migrations_table(self._database, style)
# apply all necessary migrations on clean database
# (only forward, so any unexpected migration won't be applied)
# to restore database state before tested migration
self._executor.loader.build_graph() # reload.
return self._migrate(migrate_from)

def after(self, migrate_to: _MigrationSpec) -> ProjectState:
"""Apply the next migration."""
self._executor.loader.build_graph() # reload.
return self.before(migrate_to)
return self._migrate(migrate_to)

def reset(self) -> None:
"""Reset the state to the most recent one."""
call_command(
'flush',
database=self._database,
interactive=False,
verbosity=0,
)
call_command('migrate', verbosity=0, database=self._database)

def _migrate(self, migration: _MigrationSpec) -> ProjectState:
"""Migrate to given ``migration``."""
if not isinstance(migration, list):
migration = [migration]
with _mute_migrate_signals():
return self._executor.migrate(migration)
108 changes: 108 additions & 0 deletions django_test_migrations/sql.py
@@ -0,0 +1,108 @@
# -*- coding: utf-8 -*-

from functools import partial
from typing import Callable, Dict, List, Optional, Union

from django.core.management.color import Style, no_style
from django.db import DefaultConnectionProxy, connections, transaction
from django.db.backends.base.base import BaseDatabaseWrapper

_Connection = Union[DefaultConnectionProxy, BaseDatabaseWrapper]

DJANGO_MIGRATIONS_TABLE_NAME = 'django_migrations'


def drop_models_tables(
database_name: str,
style: Optional[Style] = None,
) -> None:
"""Drop all installed Django's models tables."""
style = style or no_style()
connection = connections[database_name]
tables = connection.introspection.django_table_names(
only_existing=True,
include_views=False,
)
sql_drop_tables = [
connection.SchemaEditorClass.sql_delete_table % {
'table': style.SQL_FIELD(connection.ops.quote_name(table)),
}
for table in tables
]
if sql_drop_tables:
get_execute_sql_flush_for(connection)(database_name, sql_drop_tables)


def flush_django_migrations_table(
database_name: str,
style: Optional[Style] = None,
) -> None:
"""Flush `django_migrations` table."""
style = style or no_style()
connection = connections[database_name]
django_migrations_sequences = get_django_migrations_table_sequences(
connection,
)
execute_sql_flush = get_execute_sql_flush_for(connection)
execute_sql_flush(
database_name,
connection.ops.sql_flush(
style,
[DJANGO_MIGRATIONS_TABLE_NAME],
django_migrations_sequences,
allow_cascade=False,
),
)


def get_django_migrations_table_sequences(
connection: _Connection,
) -> List[Dict[str, str]]:
"""`django_migrations` table introspected sequences.
Returns properly inspected sequences when using `Django>1.11`
and static sequence for `id` column otherwise.
"""
if hasattr(connection.introspection, 'get_sequences'): # noqa: WPS421
with connection.cursor() as cursor:
return connection.introspection.get_sequences(
cursor,
DJANGO_MIGRATIONS_TABLE_NAME,
)
# for `Django==1.11` only primary key sequence is returned
return [{'table': DJANGO_MIGRATIONS_TABLE_NAME, 'column': 'id'}]


def get_execute_sql_flush_for(
connection: _Connection,
) -> Callable[[str, List[str]], None]:
"""Return ``execute_sql_flush`` callable for given connection."""
return getattr(
connection.ops,
'execute_sql_flush',
partial(execute_sql_flush, connection),
)


def execute_sql_flush(
connection: _Connection,
using: str,
sql_list: List[str],
) -> None: # pragma: no cover
"""Execute a list of SQL statements to flush the database.
This function is copy of ``connection.ops.execute_sql_flush``
method from Django's source code:
https://github.com/django/django/blob/227d0c7365cfd0a64d021cb9bdcf77bed2d3f170/django/db/backends/base/operations.py#L401
to make `django-test-migrations` compatible with `Django==1.11`.
``connection.ops.execute_sql_flush()`` was introduced in `Django==2.0`.
"""
with transaction.atomic(
using=using,
savepoint=connection.features.can_rollback_ddl,
):
with connection.cursor() as cursor:
for sql in sql_list:
cursor.execute(sql)
20 changes: 19 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Expand Up @@ -71,3 +71,4 @@ pytest-cov = "^2.7"
pytest-randomly = "^3.2"
pytest-django = "^3.9"
pytest-pythonpath = "^0.7.3"
pytest-mock = "^2.0.0"
72 changes: 72 additions & 0 deletions tests/test_sql/test_sql.py
@@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-

from functools import partial

from django_test_migrations import sql

TESTING_DATABASE_NAME = 'test'


def test_drop_models_table_no_tables_detected(mocker):
"""Ensure any `DROP TABLE` statement executed when no tables detected."""
testing_connection_mock = mocker.MagicMock()
testing_connection_mock.introspection.django_table_names.return_value = []
connections_mock = mocker.patch('django.db.connections._connections')
connections_mock.test = testing_connection_mock
sql.drop_models_tables(TESTING_DATABASE_NAME)
testing_connection_mock.ops.execute_sql_flush.assert_not_called()


def test_drop_models_table_table_detected(mocker):
"""Ensure `DROP TABLE` statements are executed when any table detected."""
testing_connection_mock = mocker.MagicMock()
testing_connection_mock.introspection.django_table_names.return_value = [
'foo_bar',
'foo_baz',
]
connections_mock = mocker.patch('django.db.connections._connections')
connections_mock.test = testing_connection_mock
sql.drop_models_tables(TESTING_DATABASE_NAME)
testing_connection_mock.ops.execute_sql_flush.assert_called_once()


def test_get_django_migrations_table_sequences0(mocker):
"""Ensure valid sequences are returned when using `Django>1.11`."""
connection_mock = mocker.MagicMock()
sql.get_django_migrations_table_sequences(connection_mock)
connection_mock.introspection.get_sequences.assert_called_once_with(
connection_mock.cursor().__enter__.return_value, # noqa: WPS609
sql.DJANGO_MIGRATIONS_TABLE_NAME,
)


def test_get_django_migrations_table_sequences1(mocker):
"""Ensure valid sequences are returned when using `Django==1.11`."""
connection_mock = mocker.Mock()
del connection_mock.introspection.get_sequences # noqa: WPS420
assert (
sql.get_django_migrations_table_sequences(connection_mock) ==
[{'table': sql.DJANGO_MIGRATIONS_TABLE_NAME, 'column': 'id'}]
)


def test_get_execute_sql_flush_for_method_present(mocker):
"""Ensure connections.ops method returned when it is already present."""
connection_mock = mocker.Mock()
connection_mock.ops.execute_sql_flush = _fake_execute_sql_flush
execute_sql_flush = sql.get_execute_sql_flush_for(connection_mock)
assert execute_sql_flush == _fake_execute_sql_flush


def test_get_execute_sql_flush_for_method_missing(mocker):
"""Ensure custom function is returned when connection.ops miss methods."""
connection_mock = mocker.Mock()
del connection_mock.ops.execute_sql_flush # noqa: WPS420
execute_sql_flush = sql.get_execute_sql_flush_for(connection_mock)
assert isinstance(execute_sql_flush, partial)
assert execute_sql_flush.func == sql.execute_sql_flush
assert execute_sql_flush.args[0] == connection_mock


def _fake_execute_sql_flush(using, sql_list):
return None

0 comments on commit 3550492

Please sign in to comment.