diff --git a/.envrc.example b/.envrc.example new file mode 100644 index 0000000..d77901f --- /dev/null +++ b/.envrc.example @@ -0,0 +1 @@ +export ENV_DB=postgres,mysql,sqlite3 diff --git a/.github/workflows/check_constraints.yml b/.github/workflows/check_constraints.yml index fcbc89f..5462cd1 100644 --- a/.github/workflows/check_constraints.yml +++ b/.github/workflows/check_constraints.yml @@ -7,7 +7,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - os: [ubuntu-latest] +# os: ubuntu-latest # os: [ubuntu-latest, macos-latest, windows-latest] python-version: [3.5, 3.6, 3.7, 3.8, pypy3] @@ -33,7 +33,7 @@ jobs: steps: - uses: actions/checkout@v2 - uses: actions/cache@v1 - id: Linux-cache + id: Linux-pip-cache if: startsWith(runner.os, 'Linux') with: path: ~/.cache/pip @@ -42,7 +42,7 @@ jobs: ${{ runner.os }}-pip- - uses: actions/cache@v1 - id: macOS-cache + id: macOS-pip-cache if: startsWith(runner.os, 'macOS') with: path: ~/Library/Caches/pip @@ -51,7 +51,7 @@ jobs: ${{ runner.os }}-pip- - uses: actions/cache@v1 - id: Windows-cache + id: Windows-pip-cache if: startsWith(runner.os, 'Windows') with: path: ~\AppData\Local\pip\Cache @@ -92,7 +92,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies - if: steps.${{ runner.os }}.outputs.cache-hit != 'true' +# if: steps.Linux-pip-cache.outputs.cache-hit != 'true' run: make install-test - name: Test with nox diff --git a/.gitignore b/.gitignore index b352f80..801f026 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,6 @@ build/ dist/ __pycache__/ .nox/ +*.sqlite3 +status.json +.envrc diff --git a/Makefile b/Makefile index 21f66e5..bf82868 100644 --- a/Makefile +++ b/Makefile @@ -32,7 +32,7 @@ clean-build: ## Clean project build artifacts. test: @echo "Running `$(PYTHON_VERSION)` test..." - @$(MANAGE_PY) test + @$(MANAGE_PY) test -v 3 --noinput --failfast install: clean-build ## Install project dependencies. @echo "Installing project in dependencies..." diff --git a/README.md b/README.md index e7a874e..d217b5a 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,8 @@ ![Create New Release](https://github.com/jackton1/django-check-constraint/workflows/Create%20New%20Release/badge.svg) -Extends [Django's Check](https://docs.djangoproject.com/en/3.0/ref/models/options/#constraints) constraint with support for annotations and calling db functions. +Extends [Django's Check](https://docs.djangoproject.com/en/3.0/ref/models/options/#constraints) +constraint with support for UDF(User defined functions/db functions) and annotations. #### Installation @@ -58,7 +59,7 @@ non_null_count Defining a check constraint with this function -The equivalent of +The equivalent of (PostgresSQL) ```postgresql ALTER TABLE app_name_test_modoel ADD CONSTRAINT app_name_test_model_optional_field_provided @@ -132,4 +133,5 @@ TODO's ------ - [ ] Add support for schema based functions. -- [ ] Remove skipped sqlite3 test. +- [ ] Add warning about mysql lack of user defined check constraint support. +- [ ] Remove skipped sqlite3 test. \ No newline at end of file diff --git a/check_constraint/models.py b/check_constraint/models.py index c09f9ff..cc16911 100644 --- a/check_constraint/models.py +++ b/check_constraint/models.py @@ -5,8 +5,8 @@ class AnnotatedCheckConstraint(models.CheckConstraint): def __init__(self, *args, annotations=None, **kwargs): - super().__init__(*args, **kwargs) self.annotations = annotations or {} + super(AnnotatedCheckConstraint, self).__init__(*args, **kwargs) def _get_check_sql(self, model, schema_editor): query = Query(model=model) diff --git a/check_constraint/tests.py b/check_constraint/tests.py index a86fba7..65cdc97 100644 --- a/check_constraint/tests.py +++ b/check_constraint/tests.py @@ -1,16 +1,51 @@ -import os +from decimal import Decimal +from django.conf import settings +from django.contrib.auth import get_user_model +from django.db import IntegrityError, DatabaseError from django.test import TestCase -DATABASES = ["default"] +from demo.models import Book - -if "ENV_DB" in os.environ: - DATABASES += [os.environ["ENV_DB"]] +# TODO: Fix sqlite +User = get_user_model() class AnnotateCheckConstraintTestCase(TestCase): - databases = DATABASES + databases = settings.TEST_ENV_DB + + @classmethod + def setUpTestData(cls): + for db_name in cls._databases_names(include_mirrors=False): + cls.user = User.objects.db_manager(db_name).create_superuser( + username="Admin", email="admin@admin.com", password="test", + ) + + def test_create_passes_with_annotated_check_constraint(self): + for db_name in self._databases_names(include_mirrors=False): + book = Book.objects.using(db_name).create( + name="Business of the 21st Century", + created_by=self.user, + amount=Decimal("50"), + amount_off=Decimal("20.58"), + ) + + self.assertEqual(book.name, "Business of the 21st Century") + self.assertEqual(book.created_by, self.user) - def test_dummy_setup(self): - self.assertEqual(1, 1) + def test_create_is_invalid_with_annotated_check_constraint(self): + for db_name in self._databases_names(include_mirrors=False): + if db_name == "mysql": + with self.assertRaises(DatabaseError): + Book.objects.using(db_name).create( + name="Business of the 21st Century", + created_by=self.user, + amount=Decimal("50"), + ) + else: + with self.assertRaises(IntegrityError): + Book.objects.using(db_name).create( + name="Business of the 21st Century", + created_by=self.user, + amount=Decimal("50"), + ) diff --git a/demo/migrations/0001_initial.py b/demo/migrations/0001_initial.py new file mode 100644 index 0000000..26c42ef --- /dev/null +++ b/demo/migrations/0001_initial.py @@ -0,0 +1,101 @@ +# Generated by Django 2.2.10 on 2020-02-17 07:33 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="Book", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=255)), + ("archived", models.BooleanField(default=False)), + ("amount", models.DecimalField(decimal_places=2, max_digits=9)), + ( + "amount_off", + models.DecimalField( + blank=True, decimal_places=2, max_digits=7, null=True + ), + ), + ( + "percentage", + models.DecimalField( + blank=True, decimal_places=0, max_digits=3, null=True + ), + ), + ( + "created_by", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + ), + migrations.CreateModel( + name="Library", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=255)), + ], + ), + migrations.CreateModel( + name="LibraryBook", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "books", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, to="demo.Book" + ), + ), + ( + "library", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="library_books", + to="demo.Library", + ), + ), + ], + ), + migrations.AddField( + model_name="library", + name="books", + field=models.ManyToManyField(through="demo.LibraryBook", to="demo.Book"), + ), + ] diff --git a/demo/migrations/0002_auto_20200218_0733.py b/demo/migrations/0002_auto_20200218_0733.py new file mode 100644 index 0000000..0919570 --- /dev/null +++ b/demo/migrations/0002_auto_20200218_0733.py @@ -0,0 +1,117 @@ +# Generated by Django 2.2.10 on 2020-02-17 09:44 + +from django.db import migrations + + +def non_null_count(*values): + none_values = [i for i in values if i == None] + + return len(none_values) + + +DB_FUNCTIONS = { + "postgresql": { + "forward": lambda conn, cursor: cursor.execute( + """ + CREATE OR REPLACE FUNCTION public.non_null_count(VARIADIC arg_array ANYARRAY) + RETURNS BIGINT AS + $$ + SELECT COUNT(x) FROM UNNEST($1) AS x + $$ LANGUAGE SQL IMMUTABLE; + """ + ), + "reverse": lambda conn, cursor: cursor.execute( + """ + DROP FUNCTION IF EXISTS public.non_null_count(VARIADIC arg_array ANYARRAY); + """ + ), + }, + "sqlite": { + "forward": lambda conn, cursor: conn.create_function( + "non_null_count", -1, non_null_count + ), + "reverse": lambda conn, cursor: conn.create_function( + "non_null_count", -1, None + ), + }, + "mysql": { + "forward": lambda conn, cursor: cursor.execute( + """ + CREATE FUNCTION non_null_count (params JSON) + RETURNS INT + DETERMINISTIC + READS SQL DATA + BEGIN + DECLARE n INT DEFAULT JSON_LENGTH(params); + DECLARE i INT DEFAULT 0; + DECLARE current BOOLEAN DEFAULT false; + DECLARE val INT DEFAULT 0; + + WHILE i < n DO + SET current = if(JSON_TYPE(JSON_EXTRACT(params, concat('$[', i , ']'))) != 'NULL', true, false); + IF current THEN + SET val = val + 1; + END IF; + SET i = i + 1; + END WHILE; + RETURN val; + END; + CREATE TRIGGER demo_book_validate before INSERT ON demo_book + FOR each row + BEGIN + if non_null_count(JSON_ARRAY(new.amount_off, new.percentage)) = 0 + THEN + signal SQLSTATE '45000' SET message_text = 'Both amount_off and percentage cannot + be null'; + END if; + END; + + + CREATE TRIGGER demo_book_validate_2 before UPDATE ON demo_book + FOR each row + BEGIN + if non_null_count(JSON_ARRAY(new.amount_off, new.percentage)) = 0 + THEN + signal SQLSTATE '45000' SET message_text = 'Both amount_off and percentage cannot + be null'; + END if; + END; + """ + ), + "reverse": lambda conn, cursor: cursor.execute( + """ + DROP FUNCTION non_null_count; + DROP TRIGGER demo_book_validate; + DROP TRIGGER demo_book_validate_2; + """ + ), + }, +} + + +def forwards_func(apps, schema_editor): + conn = schema_editor.connection + vendor = conn.vendor + + with conn.cursor() as cursor: + func = DB_FUNCTIONS[vendor]["forward"] + + func(conn.connection, cursor) + + +def reverse_func(apps, schema_editor): + conn = schema_editor.connection + db_alias = conn.db_alias + + with conn.cursor() as cursor: + func = DB_FUNCTIONS[db_alias]["reverse"] + + func(conn, cursor) + + +class Migration(migrations.Migration): + dependencies = [ + ("demo", "0001_initial"), + ] + + operations = [migrations.RunPython(forwards_func, reverse_func)] diff --git a/demo/migrations/0003_auto_20200222_0146.py b/demo/migrations/0003_auto_20200222_0146.py new file mode 100644 index 0000000..1888ede --- /dev/null +++ b/demo/migrations/0003_auto_20200222_0146.py @@ -0,0 +1,46 @@ +# Generated by Django 3.0.3 on 2020-02-22 01:46 + +import check_constraint.models +import demo.models.function.non_null_count +from django.db import migrations, models, connection + + +class Migration(migrations.Migration): + + dependencies = [ + ("demo", "0002_auto_20200218_0733"), + ] + + operations = [ + migrations.AddConstraint( + model_name="book", + constraint=check_constraint.models.AnnotatedCheckConstraint( + annotations={ + "not_null_count": demo.models.function.non_null_count.NotNullCount( + "amount_off", "percentage" + ) + }, + check=models.Q(not_null_count=1), + name="demo_book_optional_field_provided", + ), + ), + ] + + def apply(self, project_state, schema_editor, collect_sql=False): + if schema_editor.connection.alias == "mysql": + for operation in self.operations: + # Save the state before the operation has run + operation.state_forwards(self.app_label, project_state) + return project_state + + return super(Migration, self).apply( + project_state, schema_editor, collect_sql=collect_sql + ) + + def unapply(self, project_state, schema_editor, collect_sql=False): + if schema_editor.connection.alias == "mysql": + return project_state + + return super(Migration, self).unapply( + project_state, schema_editor, collect_sql=collect_sql + ) diff --git a/demo/models.py b/demo/models.py deleted file mode 100644 index 507503e..0000000 --- a/demo/models.py +++ /dev/null @@ -1,20 +0,0 @@ -from django.contrib.auth import get_user_model -from django.db import models - - -class Books(models.Model): - name = models.CharField(max_length=255) - archived = models.BooleanField(default=False) - created_by = models.ForeignKey(get_user_model(), on_delete=models.CASCADE) - - -class Library(models.Model): - name = models.CharField(max_length=255) - books = models.ManyToManyField(Books, through="LibraryBooks") - - -class LibraryBooks(models.Model): - library = models.ForeignKey( - Library, on_delete=models.CASCADE, related_name="library_books" - ) - books = models.ForeignKey(Books, on_delete=models.PROTECT) diff --git a/demo/models/__init__.py b/demo/models/__init__.py new file mode 100644 index 0000000..d446143 --- /dev/null +++ b/demo/models/__init__.py @@ -0,0 +1,2 @@ +from demo.models.book import Book # noqa +from demo.models.library import Library # noqa diff --git a/demo/models/book.py b/demo/models/book.py new file mode 100644 index 0000000..f38cc96 --- /dev/null +++ b/demo/models/book.py @@ -0,0 +1,31 @@ +from django.contrib.auth import get_user_model +from django.db import models +from django.db.models import Q + +from check_constraint.models import AnnotatedCheckConstraint +from demo.models.function import NotNullCount + + +class Book(models.Model): + name = models.CharField(max_length=255) + archived = models.BooleanField(default=False) + created_by = models.ForeignKey(get_user_model(), on_delete=models.CASCADE) + + amount = models.DecimalField(max_digits=9, decimal_places=2,) + amount_off = models.DecimalField( + max_digits=7, decimal_places=2, null=True, blank=True, + ) + percentage = models.DecimalField( + max_digits=3, decimal_places=0, null=True, blank=True, + ) + + class Meta: + constraints = [ + AnnotatedCheckConstraint( + check=Q(not_null_count=1), + annotations={ + "not_null_count": (NotNullCount("amount_off", "percentage",)), + }, + name="%(app_label)s_%(class)s_optional_field_provided", + ), + ] diff --git a/demo/models/function/__init__.py b/demo/models/function/__init__.py new file mode 100644 index 0000000..7df83f0 --- /dev/null +++ b/demo/models/function/__init__.py @@ -0,0 +1 @@ +from demo.models.function.non_null_count import NotNullCount # noqa diff --git a/demo/models/function/non_null_count.py b/demo/models/function/non_null_count.py new file mode 100644 index 0000000..4797967 --- /dev/null +++ b/demo/models/function/non_null_count.py @@ -0,0 +1,35 @@ +from django.db.models import Func, SmallIntegerField, TextField +from django.db.models.functions import Cast + + +class NotNullCount(Func): + function = "non_null_count" + + def __init__(self, *expressions, **extra): + filter_exp = [ + Cast(exp, TextField()) for exp in expressions if isinstance(exp, str) + ] + if "output_field" not in extra: + extra["output_field"] = SmallIntegerField() + + if len(expressions) < 2: + raise ValueError("NotNullCount must take at least two expressions") + + super().__init__(*filter_exp, **extra) + + def as_sqlite(self, compiler, connection, **extra_context): + connection.ops.check_expression_support(self) + sql_parts = [] + params = [] + for arg in self.source_expressions: + arg_sql, arg_params = compiler.compile(arg) + sql_parts.append(arg_sql) + params.extend(arg_params) + data = {**self.extra, **extra_context} + data["template"] = "%(function)s(%(expressions)s)" + arg_joiner = self.arg_joiner + data["function"] = self.function + data["expressions"] = data["field"] = arg_joiner.join(sql_parts) + template = data["template"] + + return template % data, params diff --git a/demo/models/library.py b/demo/models/library.py new file mode 100644 index 0000000..2125c5d --- /dev/null +++ b/demo/models/library.py @@ -0,0 +1,13 @@ +from django.db import models + + +class Library(models.Model): + name = models.CharField(max_length=255) + books = models.ManyToManyField("Book", through="LibraryBook") + + +class LibraryBook(models.Model): + library = models.ForeignKey( + Library, on_delete=models.CASCADE, related_name="library_books" + ) + books = models.ForeignKey("Book", on_delete=models.PROTECT) diff --git a/django_check_constraint/settings.py b/django_check_constraint/settings.py index c5b15da..47fa991 100644 --- a/django_check_constraint/settings.py +++ b/django_check_constraint/settings.py @@ -37,6 +37,8 @@ "django.contrib.sessions", "django.contrib.messages", "django.contrib.staticfiles", + "check_constraint", + "demo", ] MIDDLEWARE = [ @@ -72,18 +74,24 @@ # Database # https://docs.djangoproject.com/en/1.11/ref/settings/#databases +TEST_ENV_DB = [] if "ENV_DB" not in os.environ else os.environ["ENV_DB"].split(",") DATABASES = { "default": { "ENGINE": "django.db.backends.sqlite3", "NAME": os.path.join(BASE_DIR, "db.sqlite3"), + "TEST": {"DEPENDENCIES": TEST_ENV_DB}, } } +if "sqlite3" in TEST_ENV_DB: + DATABASES["sqlite3"] = { + "ENGINE": "django.db.backends.sqlite3", + "NAME": ":memory:", + "TEST": {"DEPENDENCIES": []}, + } -_ENV_DB = os.environ.get("ENV_DB") - -if _ENV_DB == "postgres": +if "postgres" in TEST_ENV_DB: DATABASES["postgres"] = { "ENGINE": "django.db.backends.postgresql", "NAME": os.environ.get("POSTGRES_DB", "test_postgres"), @@ -91,8 +99,10 @@ "PASSWORD": os.environ.get("POSTGRES_PASSWORD", ""), "HOST": "localhost", "PORT": os.environ.get("POSTGRES_PORT", "5432"), + "TEST": {"DEPENDENCIES": []}, } -elif _ENV_DB == "mysql": + +if "mysql" in TEST_ENV_DB: DATABASES["mysql"] = { "ENGINE": "django.db.backends.mysql", "HOST": "127.0.0.1", @@ -100,6 +110,7 @@ "USER": "root", "PASSWORD": os.environ.get("MYSQL_ROOT_PASSWORD", ""), "PORT": os.environ.get("MYSQL_PORT", "3306"), + "TEST": {"DEPENDENCIES": []}, } diff --git a/noxfile.py b/noxfile.py index e1976b3..b2d4b9d 100644 --- a/noxfile.py +++ b/noxfile.py @@ -19,18 +19,23 @@ @nox.session(python=["3.5", "3.6", "3.7", "3.8"]) @nox.parametrize("django", ["2.2.10", "3.0", "3.0.1", "3.0.2", "3.0.3"]) -@nox.parametrize("database", ["postgres", "mysql"]) +@nox.parametrize("database", ["postgres", "mysql", "sqlite3"]) def tests(session, django, database): if django.split(".")[0] == "3" and session.python == "3.5": session.skip("Python: {} and django: {}".format(session.python, django)) - session.install( - *DB_PACKAGE[database][session.python], - env={ - "LDFLAGS": "-L/usr/local/opt/openssl@1.1/lib", - "CPPFLAGS": "-I/usr/local/opt/openssl@1.1/include", - } - ) + if database == "sqlite3" and session.python not in []: + # TODO: Fix me for all python versions. + session.skip("Python: {} and db: {}".format(session.python, database)) + + if database != "sqlite3": + session.install( + *DB_PACKAGE[database][session.python], + env={ + "LDFLAGS": "-L/usr/local/opt/openssl@1.1/lib", + "CPPFLAGS": "-I/usr/local/opt/openssl@1.1/include", + } + ) session.install("django=={}".format(django)) session.run("bash", "-c", "make test", external=True, env={"ENV_DB": database}) diff --git a/requirements.txt b/requirements.txt index 5f8ee80..4b2a7c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ # # pip-compile # -django==2.2.10 +asgiref==3.2.3 # via django +django==3.0.3 pytz==2019.3 # via django sqlparse==0.3.0 # via django diff --git a/setup.py b/setup.py index 4c897e7..0aacde8 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import find_packages, setup -install_requires = ["Django>=2.2.10"] +install_requires = ["Django>=2.2.10,<4.0.0"] test_requires = [ "nox==2019.11.9", @@ -28,7 +28,12 @@ "pre-commit==2.0.1", ] -local_dev_requires = ["pip-tools==4.4.1", "check-manifest==0.37"] +local_dev_requires = [ + "pip-tools==4.4.1", + "check-manifest==0.37", + "psycopg2>=2.5.4", + "mysqlclient>=1.3.13", +] extras_require = { "development": [local_dev_requires, install_requires, test_requires, lint_requires],