Skip to content

Commit e580597

Browse files
committed
Pytest plugin
1 parent 17d33e3 commit e580597

File tree

9 files changed

+128
-30
lines changed

9 files changed

+128
-30
lines changed

dbdiff/fixture.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def dump(self, out):
119119
traceback=True,
120120
indent=self.indent,
121121
stdout=out,
122-
use_natural_foreign_keys=True,
122+
use_natural_foreign_keys=True
123123
)
124124

125125
def assertNoDiff(self, exclude=None): # noqa

dbdiff/plugin.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""Pytest plugin for django-dbdiff.
2+
3+
The marker enables the smarter sequence reset feature previously available in
4+
the DbdiffTestMixin in pytest, example usage::
5+
6+
@dbdiff(models=[YourModel])
7+
def your_test():
8+
assert YourModel.objects.create().pk == 1
9+
"""
10+
import pytest
11+
12+
from pytest_django.pytest_compat import getfixturevalue
13+
14+
from .sequence import sequence_reset
15+
16+
17+
@pytest.fixture(autouse=True)
18+
def _dbdiff_marker(request):
19+
marker = request.keywords.get('dbdiff', None)
20+
if not marker:
21+
return
22+
23+
# Enable transactional db
24+
getfixturevalue(request, 'transactional_db')
25+
26+
for model in marker.kwargs['models']:
27+
sequence_reset(model)
28+
29+
30+
def pytest_load_initial_conftests(early_config, parser, args):
31+
"""Register the dbdiff mark."""
32+
early_config.addinivalue_line(
33+
'markers',
34+
'dbdiff(models, reset_sequences=True): Mark the test as using '
35+
'the django test database. The *transaction* argument marks will '
36+
"allow you to use real transactions in the test like Django's "
37+
'TransactionTestCase.')

dbdiff/sequence.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""Smarter model pk sequence reset."""
2+
from django.db import connection, models
3+
4+
5+
def pk_sequence_get(model):
6+
"""Return a list of table, column tuples which should have sequences."""
7+
for field in model._meta.get_fields():
8+
if not getattr(field, 'primary_key', False):
9+
continue
10+
if not isinstance(field, models.AutoField):
11+
continue
12+
return field.db_column or field.column
13+
14+
15+
def sequence_reset(model):
16+
"""
17+
Better sequence reset than TransactionTestCase.
18+
19+
The difference with using TransactionTestCase with reset_sequences=True is
20+
that this will reset sequences for the given models to their higher value,
21+
supporting pre-existing models which could have been created by a
22+
migration.
23+
"""
24+
pk_field = pk_sequence_get(model)
25+
if not pk_field:
26+
return
27+
28+
if connection.vendor == 'postgresql':
29+
reset = """
30+
SELECT
31+
setval(
32+
pg_get_serial_sequence('{table}', '{column}'),
33+
coalesce(max({column}),0) + 1,
34+
false
35+
)
36+
FROM {table}
37+
"""
38+
elif connection.vendor == 'sqlite':
39+
reset = """
40+
UPDATE sqlite_sequence
41+
SET seq=(SELECT max({column}) from {table})
42+
WHERE name='{table}'
43+
"""
44+
elif connection.vendor == 'mysql':
45+
cursor = connection.cursor()
46+
cursor.execute(
47+
'SELECT MAX({column}) + 1 FROM {table}'.format(
48+
column=pk_field, table=model._meta.db_table
49+
)
50+
)
51+
result = cursor.fetchone()[0] or 0
52+
reset = 'ALTER TABLE {table} AUTO_INCREMENT = %s' % result
53+
54+
connection.cursor().execute(
55+
reset.format(column=pk_field, table=model._meta.db_table)
56+
)

dbdiff/test.py

Lines changed: 2 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""Convenience test mixin."""
22
from django.core.management import call_command
3-
from django.db import connection, models
43

54
from .fixture import Fixture
5+
from .sequence import sequence_reset
66

77

88
class DbdiffTestMixin(object):
@@ -45,34 +45,7 @@ def test_db_import(self):
4545
call_command('loaddata', fixture)
4646

4747
for model in self.dbdiff_models:
48-
if connection.vendor == 'postgresql':
49-
pk_field = None
50-
for field in model._meta.get_fields():
51-
if getattr(field, 'primary_key', False):
52-
pk_field = field
53-
break
54-
55-
if not isinstance(pk_field, models.AutoField):
56-
continue
57-
58-
if not pk_field.db_column:
59-
continue
60-
61-
reset = """
62-
SELECT
63-
setval(
64-
pg_get_serial_sequence('%(table)s', '%(column)s'),
65-
coalesce(max(%(column)s),0) + 1,
66-
false
67-
)
68-
FROM %(table)s
69-
""" % dict(
70-
table=model._meta.db_table,
71-
column=pk_field.db_column,
72-
)
73-
else:
74-
raise NotImplemented()
75-
connection.cursor().execute(reset)
48+
sequence_reset(model)
7649

7750
self.dbdiff_test()
7851

dbdiff/tests/nonintpk/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Test that we don't crash with non sequence pks."""

dbdiff/tests/nonintpk/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import uuid
2+
3+
from django.db import models
4+
5+
6+
class Nonintpk(models.Model):
7+
# dbdiff should not try to reset this sequence
8+
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)

dbdiff/tests/project/settings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
'dbdiff',
4141

4242
'dbdiff.tests.decimal_test',
43+
'dbdiff.tests.nonintpk',
4344
)
4445

4546
MIDDLEWARE_CLASSES = (

dbdiff/tests/test_plugin.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from dbdiff.tests.decimal_test.models import TestModel as DecimalModel
2+
from dbdiff.tests.nonintpk.models import Nonintpk
3+
4+
import pytest
5+
6+
7+
@pytest.mark.dbdiff(models=[DecimalModel])
8+
def test_insert_first():
9+
assert DecimalModel.objects.count() == 0
10+
assert DecimalModel.objects.create(test_field=1).pk == 1
11+
12+
13+
@pytest.mark.dbdiff(models=[DecimalModel])
14+
def test_still_first_pk():
15+
assert DecimalModel.objects.count() == 0
16+
assert DecimalModel.objects.create(test_field=1).pk == 1
17+
18+
19+
@pytest.mark.dbdiff(models=[DecimalModel, Nonintpk])
20+
def test_doesnt_reset_nonintpk_which_would_fail():
21+
assert DecimalModel.objects.count() == 0

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def read(fname):
2323
license='MIT',
2424
keywords='django test database fixture diff',
2525
install_requires=['ijson', 'json_delta', 'six'],
26+
entry_points={'pytest11': ['dbdiff = dbdiff.plugin']},
2627
classifiers=[
2728
'Development Status :: 4 - Beta',
2829
'Environment :: Web Environment',

0 commit comments

Comments
 (0)