Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
230 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
from functools import partial | ||
from typing import Callable, Dict, List, Optional, Union | ||
|
||
from django.core.management.color import Style, no_style | ||
from django.db import DefaultConnectionProxy, connections, transaction | ||
from django.db.backends.base.base import BaseDatabaseWrapper | ||
|
||
_Connection = Union[DefaultConnectionProxy, BaseDatabaseWrapper] | ||
|
||
DJANGO_MIGRATIONS_TABLE_NAME = 'django_migrations' | ||
|
||
|
||
def drop_models_tables( | ||
database_name: str, | ||
style: Optional[Style] = None, | ||
) -> None: | ||
"""Drop all installed Django's models tables.""" | ||
style = style or no_style() | ||
connection = connections[database_name] | ||
tables = connection.introspection.django_table_names( | ||
only_existing=True, | ||
include_views=False, | ||
) | ||
sql_drop_tables = [ | ||
connection.SchemaEditorClass.sql_delete_table % { | ||
'table': style.SQL_FIELD(connection.ops.quote_name(table)), | ||
} | ||
for table in tables | ||
] | ||
if sql_drop_tables: | ||
get_execute_sql_flush_for(connection)(database_name, sql_drop_tables) | ||
|
||
|
||
def flush_django_migrations_table( | ||
database_name: str, | ||
style: Optional[Style] = None, | ||
) -> None: | ||
"""Flush `django_migrations` table.""" | ||
style = style or no_style() | ||
connection = connections[database_name] | ||
django_migrations_sequences = get_django_migrations_table_sequences( | ||
connection, | ||
) | ||
execute_sql_flush = get_execute_sql_flush_for(connection) | ||
execute_sql_flush( | ||
database_name, | ||
connection.ops.sql_flush( | ||
style, | ||
[DJANGO_MIGRATIONS_TABLE_NAME], | ||
django_migrations_sequences, | ||
allow_cascade=False, | ||
), | ||
) | ||
|
||
|
||
def get_django_migrations_table_sequences( | ||
connection: _Connection, | ||
) -> List[Dict[str, str]]: | ||
"""`django_migrations` table introspected sequences. | ||
Returns properly inspected sequences when using `Django>1.11` | ||
and static sequence for `id` column otherwise. | ||
""" | ||
if hasattr(connection.introspection, 'get_sequences'): # noqa: WPS421 | ||
with connection.cursor() as cursor: | ||
return connection.introspection.get_sequences( | ||
cursor, | ||
DJANGO_MIGRATIONS_TABLE_NAME, | ||
) | ||
# for `Django==1.11` only primary key sequence is returned | ||
return [{'table': DJANGO_MIGRATIONS_TABLE_NAME, 'column': 'id'}] | ||
|
||
|
||
def get_execute_sql_flush_for( | ||
connection: _Connection, | ||
) -> Callable[[str, List[str]], None]: | ||
"""Return ``execute_sql_flush`` callable for given connection.""" | ||
return getattr( | ||
connection.ops, | ||
'execute_sql_flush', | ||
partial(execute_sql_flush, connection), | ||
) | ||
|
||
|
||
def execute_sql_flush( | ||
connection: _Connection, | ||
using: str, | ||
sql_list: List[str], | ||
) -> None: # pragma: no cover | ||
"""Execute a list of SQL statements to flush the database. | ||
This function is copy of ``connection.ops.execute_sql_flush`` | ||
method from Django's source code: | ||
https://github.com/django/django/blob/227d0c7365cfd0a64d021cb9bdcf77bed2d3f170/django/db/backends/base/operations.py#L401 | ||
to make `django-test-migrations` compatible with `Django==1.11`. | ||
``connection.ops.execute_sql_flush()`` was introduced in `Django==2.0`. | ||
""" | ||
with transaction.atomic( | ||
using=using, | ||
savepoint=connection.features.can_rollback_ddl, | ||
): | ||
with connection.cursor() as cursor: | ||
for sql in sql_list: | ||
cursor.execute(sql) |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
from functools import partial | ||
|
||
from django_test_migrations import sql | ||
|
||
TESTING_DATABASE_NAME = 'test' | ||
|
||
|
||
def test_drop_models_table_no_tables_detected(mocker): | ||
"""Ensure any `DROP TABLE` statement executed when no tables detected.""" | ||
testing_connection_mock = mocker.MagicMock() | ||
testing_connection_mock.introspection.django_table_names.return_value = [] | ||
connections_mock = mocker.patch('django.db.connections._connections') | ||
connections_mock.test = testing_connection_mock | ||
sql.drop_models_tables(TESTING_DATABASE_NAME) | ||
testing_connection_mock.ops.execute_sql_flush.assert_not_called() | ||
|
||
|
||
def test_drop_models_table_table_detected(mocker): | ||
"""Ensure `DROP TABLE` statements are executed when any table detected.""" | ||
testing_connection_mock = mocker.MagicMock() | ||
testing_connection_mock.introspection.django_table_names.return_value = [ | ||
'foo_bar', | ||
'foo_baz', | ||
] | ||
connections_mock = mocker.patch('django.db.connections._connections') | ||
connections_mock.test = testing_connection_mock | ||
sql.drop_models_tables(TESTING_DATABASE_NAME) | ||
testing_connection_mock.ops.execute_sql_flush.assert_called_once() | ||
|
||
|
||
def test_get_django_migrations_table_sequences0(mocker): | ||
"""Ensure valid sequences are returned when using `Django>1.11`.""" | ||
connection_mock = mocker.MagicMock() | ||
sql.get_django_migrations_table_sequences(connection_mock) | ||
connection_mock.introspection.get_sequences.assert_called_once_with( | ||
connection_mock.cursor().__enter__.return_value, # noqa: WPS609 | ||
sql.DJANGO_MIGRATIONS_TABLE_NAME, | ||
) | ||
|
||
|
||
def test_get_django_migrations_table_sequences1(mocker): | ||
"""Ensure valid sequences are returned when using `Django==1.11`.""" | ||
connection_mock = mocker.Mock() | ||
del connection_mock.introspection.get_sequences # noqa: WPS420 | ||
assert ( | ||
sql.get_django_migrations_table_sequences(connection_mock) == | ||
[{'table': sql.DJANGO_MIGRATIONS_TABLE_NAME, 'column': 'id'}] | ||
) | ||
|
||
|
||
def test_get_execute_sql_flush_for_method_present(mocker): | ||
"""Ensure connections.ops method returned when it is already present.""" | ||
connection_mock = mocker.Mock() | ||
connection_mock.ops.execute_sql_flush = _fake_execute_sql_flush | ||
execute_sql_flush = sql.get_execute_sql_flush_for(connection_mock) | ||
assert execute_sql_flush == _fake_execute_sql_flush | ||
|
||
|
||
def test_get_execute_sql_flush_for_method_missing(mocker): | ||
"""Ensure custom function is returned when connection.ops miss methods.""" | ||
connection_mock = mocker.Mock() | ||
del connection_mock.ops.execute_sql_flush # noqa: WPS420 | ||
execute_sql_flush = sql.get_execute_sql_flush_for(connection_mock) | ||
assert isinstance(execute_sql_flush, partial) | ||
assert execute_sql_flush.func == sql.execute_sql_flush | ||
assert execute_sql_flush.args[0] == connection_mock | ||
|
||
|
||
def _fake_execute_sql_flush(using, sql_list): | ||
return None |