From 940a66a5ed7b62da329c7fd525fe823df9d2f84c Mon Sep 17 00:00:00 2001 From: Sergey Misuk Date: Mon, 17 Nov 2025 08:54:36 +0400 Subject: [PATCH 1/7] Add possibility to create signatures from CeleryTaskModel instance --- pyproject.toml | 2 + src/django_celery_boost/models.py | 42 +++- src/django_celery_boost/task.py | 48 ++++ tests/demoapp/demo/factories.py | 26 ++- .../0002_addtojob_sumandaddtojob_valuejob.py | 208 ++++++++++++++++++ tests/demoapp/demo/models.py | 18 ++ tests/demoapp/demo/tasks.py | 26 +++ tests/test_canvas.py | 26 +++ tests/test_signature.py | 28 +++ 9 files changed, 417 insertions(+), 7 deletions(-) create mode 100644 src/django_celery_boost/task.py create mode 100644 tests/demoapp/demo/migrations/0002_addtojob_sumandaddtojob_valuejob.py create mode 100644 tests/test_canvas.py create mode 100644 tests/test_signature.py diff --git a/pyproject.toml b/pyproject.toml index 8c73ee1..aa9ec18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,8 +27,10 @@ classifiers = [ dynamic = [ "version" ] dependencies = [ "celery>=5.4", + "celery-types>=0.23", "django-admin-extra-buttons", "django-concurrency", + "pre-commit>=4.3", "sentry-sdk", ] diff --git a/src/django_celery_boost/models.py b/src/django_celery_boost/models.py index e345510..795aab1 100644 --- a/src/django_celery_boost/models.py +++ b/src/django_celery_boost/models.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Callable, Generator import sentry_sdk -from celery import states +from celery import states, Signature from celery.app.base import Celery from concurrency.api import concurrency_disable_increment from concurrency.fields import AutoIncVersionField @@ -19,6 +19,7 @@ from django.utils.translation import gettext as _ from django_celery_boost.signals import task_queued, task_revoked, task_terminated +from django_celery_boost.task import TaskRunFromSignature if TYPE_CHECKING: import celery.app.control @@ -29,6 +30,17 @@ logger = logging.getLogger(__name__) +APP_LABEL = "app_label" +MODEL_NAME = "model_name" + + +class InvalidTaskBase(TypeError): + def __init__(self, task_handler_name: str): + super().__init__( + f"{task_handler_name} must be a TaskRunFromSignature instance. Use base argument with shared_task or app.task decorator." + ) + + class CeleryManager: pass @@ -335,6 +347,13 @@ def task_status(self) -> str: except Exception as e: # noqa return str(e) + def set_queued(self, result: AsyncResult) -> None: + with concurrency_disable_increment(self): + self.curr_async_result_id = result.id + self.datetime_queued = timezone.now() + self.save(update_fields=["curr_async_result_id", "datetime_queued"]) + task_queued.send(sender=self.__class__, task=self) + def queue(self, use_version: bool = True) -> str | None: """Queue the record processing. @@ -342,11 +361,7 @@ def queue(self, use_version: bool = True) -> str | None: """ if self.task_status not in self.ACTIVE_STATUSES: res = self.task_handler.delay(self.pk, self.version if use_version else None) - with concurrency_disable_increment(self): - self.curr_async_result_id = res.id - self.datetime_queued = timezone.now() - self.save(update_fields=["curr_async_result_id", "datetime_queued"]) - task_queued.send(sender=self.__class__, task=self) + self.set_queued(res) return self.curr_async_result_id return None @@ -388,6 +403,21 @@ def terminate(self, wait=False, timeout=None) -> str: task_terminated.send(sender=self.__class__, task=self) return st + def signature(self) -> Signature: + if not isinstance(self.task_handler, TaskRunFromSignature): + raise InvalidTaskBase(self.celery_task_name) + + return self.task_handler.signature( + (self.pk, self.version), + { + APP_LABEL: self._meta.app_label, + MODEL_NAME: self._meta.model_name, + }, + ) + + def s(self) -> Signature: + return self.signature() + @classmethod def discard_all(cls: "type[CeleryTaskModel]") -> None: cls.celery_app.control.discard_all() diff --git a/src/django_celery_boost/task.py b/src/django_celery_boost/task.py new file mode 100644 index 0000000..e4480b7 --- /dev/null +++ b/src/django_celery_boost/task.py @@ -0,0 +1,48 @@ +from typing import Any, Protocol, cast + +from celery import Task +from celery.result import EagerResult, AsyncResult +from django.apps import apps + + +class ApplyCallable[T: AsyncResult](Protocol): + def __call__(self, *args: Any, **kwargs: Any) -> T: ... + + +def _apply[T: AsyncResult](apply_method: ApplyCallable[T], *args: Any, **kwargs: Any) -> T: + from django_celery_boost.models import APP_LABEL, MODEL_NAME, CeleryTaskModel + + task_args = args[0] + pk, version = task_args[-2], task_args[-1] + task_kwargs = args[1] + app_label, model_name = task_kwargs[APP_LABEL], task_kwargs[MODEL_NAME] + + model_class = cast(type[CeleryTaskModel], apps.get_model(app_label, model_name)) + model = model_class.objects.get(pk=pk, version=version) + + # we want pk and version to always come first + new_task_args = task_args[-2:] + task_args[:-2] + # We can possibly check whether the task is in an active state, if it's + # required, it can be done by passing some flag in kwargs and raising an + # exception in the overridden run method. Other options could require much + # more work + new_task_kwargs: dict[str, Any] = {} + + new_args: tuple[Any, ...] = (new_task_args, new_task_kwargs) + args[2:] + result = apply_method(*new_args, **kwargs) + model.set_queued(result) + return result + + +class TaskRunFromSignature(Task): + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return super().__call__(*args, **kwargs) + + def apply(self, *args: Any, **kwargs: Any) -> EagerResult: + return _apply(super().apply, *args, **kwargs) + + def apply_async(self, *args: Any, **kwargs: Any) -> AsyncResult: + return _apply(super().apply_async, *args, **kwargs) + + def run(self, *args: Any, **kwargs: Any) -> Any: + return super().run(*args, **kwargs) diff --git a/tests/demoapp/demo/factories.py b/tests/demoapp/demo/factories.py index 3a60a34..d67d0f9 100644 --- a/tests/demoapp/demo/factories.py +++ b/tests/demoapp/demo/factories.py @@ -1,7 +1,7 @@ from typing import Any, Optional import factory -from demo.models import Job, MultipleJob +from demo.models import Job, MultipleJob, AddToJob, SumAndAddToJob, ValueJob from django.contrib.auth.models import Group, Permission, User from factory.django import DjangoModelFactory from factory.faker import Faker @@ -82,3 +82,27 @@ def start(self): def stop(self): """Stop an active patch.""" return self.__exit__() + + +class ValueJobFactory(DjangoModelFactory): + curr_async_result_id = None + last_async_result_id = None + + class Meta: + model = ValueJob + + +class AddToJobFactory(DjangoModelFactory): + curr_async_result_id = None + last_async_result_id = None + + class Meta: + model = AddToJob + + +class SumAndAddToJobFactory(DjangoModelFactory): + curr_async_result_id = None + last_async_result_id = None + + class Meta: + model = SumAndAddToJob diff --git a/tests/demoapp/demo/migrations/0002_addtojob_sumandaddtojob_valuejob.py b/tests/demoapp/demo/migrations/0002_addtojob_sumandaddtojob_valuejob.py new file mode 100644 index 0000000..afa9872 --- /dev/null +++ b/tests/demoapp/demo/migrations/0002_addtojob_sumandaddtojob_valuejob.py @@ -0,0 +1,208 @@ +# Generated by Django 5.2.8 on 2025-11-14 11:06 + +import concurrency.fields +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("demo", "0001_initial"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="AddToJob", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("version", concurrency.fields.AutoIncVersionField(default=0, help_text="record revision number")), + ("description", models.CharField(blank=True, max_length=255, null=True)), + ( + "curr_async_result_id", + models.CharField( + blank=True, + editable=False, + help_text="Current (active) AsyncResult is", + max_length=36, + null=True, + ), + ), + ( + "last_async_result_id", + models.CharField( + blank=True, editable=False, help_text="Latest executed AsyncResult is", max_length=36, null=True + ), + ), + ("datetime_created", models.DateTimeField(auto_now_add=True, help_text="Creation date and time")), + ( + "datetime_queued", + models.DateTimeField( + blank=True, help_text="Queueing date and time", null=True, verbose_name="Queued At" + ), + ), + ( + "repeatable", + models.BooleanField( + blank=True, default=False, help_text="Indicate if the job can be repeated as-is" + ), + ), + ("celery_history", models.JSONField(blank=True, default=dict, editable=False)), + ("local_status", models.CharField(blank=True, default="", editable=False, max_length=100, null=True)), + ( + "group_key", + models.CharField( + blank=True, + editable=False, + help_text="Tasks with the same group key will not run in parallel", + max_length=255, + null=True, + ), + ), + ("value", models.IntegerField(default=0)), + ( + "owner", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="%(app_label)s_%(class)s_jobs", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "abstract": False, + "default_permissions": ("add", "change", "delete", "view", "queue", "terminate", "inspect", "revoke"), + }, + ), + migrations.CreateModel( + name="SumAndAddToJob", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("version", concurrency.fields.AutoIncVersionField(default=0, help_text="record revision number")), + ("description", models.CharField(blank=True, max_length=255, null=True)), + ( + "curr_async_result_id", + models.CharField( + blank=True, + editable=False, + help_text="Current (active) AsyncResult is", + max_length=36, + null=True, + ), + ), + ( + "last_async_result_id", + models.CharField( + blank=True, editable=False, help_text="Latest executed AsyncResult is", max_length=36, null=True + ), + ), + ("datetime_created", models.DateTimeField(auto_now_add=True, help_text="Creation date and time")), + ( + "datetime_queued", + models.DateTimeField( + blank=True, help_text="Queueing date and time", null=True, verbose_name="Queued At" + ), + ), + ( + "repeatable", + models.BooleanField( + blank=True, default=False, help_text="Indicate if the job can be repeated as-is" + ), + ), + ("celery_history", models.JSONField(blank=True, default=dict, editable=False)), + ("local_status", models.CharField(blank=True, default="", editable=False, max_length=100, null=True)), + ( + "group_key", + models.CharField( + blank=True, + editable=False, + help_text="Tasks with the same group key will not run in parallel", + max_length=255, + null=True, + ), + ), + ("value", models.IntegerField(default=0)), + ( + "owner", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="%(app_label)s_%(class)s_jobs", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "abstract": False, + "default_permissions": ("add", "change", "delete", "view", "queue", "terminate", "inspect", "revoke"), + }, + ), + migrations.CreateModel( + name="ValueJob", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("version", concurrency.fields.AutoIncVersionField(default=0, help_text="record revision number")), + ("description", models.CharField(blank=True, max_length=255, null=True)), + ( + "curr_async_result_id", + models.CharField( + blank=True, + editable=False, + help_text="Current (active) AsyncResult is", + max_length=36, + null=True, + ), + ), + ( + "last_async_result_id", + models.CharField( + blank=True, editable=False, help_text="Latest executed AsyncResult is", max_length=36, null=True + ), + ), + ("datetime_created", models.DateTimeField(auto_now_add=True, help_text="Creation date and time")), + ( + "datetime_queued", + models.DateTimeField( + blank=True, help_text="Queueing date and time", null=True, verbose_name="Queued At" + ), + ), + ( + "repeatable", + models.BooleanField( + blank=True, default=False, help_text="Indicate if the job can be repeated as-is" + ), + ), + ("celery_history", models.JSONField(blank=True, default=dict, editable=False)), + ("local_status", models.CharField(blank=True, default="", editable=False, max_length=100, null=True)), + ( + "group_key", + models.CharField( + blank=True, + editable=False, + help_text="Tasks with the same group key will not run in parallel", + max_length=255, + null=True, + ), + ), + ("value", models.IntegerField(default=0)), + ( + "owner", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="%(app_label)s_%(class)s_jobs", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "abstract": False, + "default_permissions": ("add", "change", "delete", "view", "queue", "terminate", "inspect", "revoke"), + }, + ), + ] diff --git a/tests/demoapp/demo/models.py b/tests/demoapp/demo/models.py index 050edb9..f3a602b 100644 --- a/tests/demoapp/demo/models.py +++ b/tests/demoapp/demo/models.py @@ -44,3 +44,21 @@ class Meta(AsyncJobModel.Meta): permissions = (("test_multiplejob", "Can test MultipleJob"),) celery_task_name = "demo.tasks.process_job" + + +class ValueJob(CeleryTaskModel, models.Model): + value = models.IntegerField(default=0) + + celery_task_name = "demo.tasks.value" + + +class AddToJob(CeleryTaskModel, models.Model): + value = models.IntegerField(default=0) + + celery_task_name = "demo.tasks.add_to" + + +class SumAndAddToJob(CeleryTaskModel, models.Model): + value = models.IntegerField(default=0) + + celery_task_name = "demo.tasks.sum_and_add_to" diff --git a/tests/demoapp/demo/tasks.py b/tests/demoapp/demo/tasks.py index bf37278..3d1a869 100644 --- a/tests/demoapp/demo/tasks.py +++ b/tests/demoapp/demo/tasks.py @@ -4,6 +4,8 @@ from concurrency.exceptions import RecordModifiedError from django.core.cache import cache +from django_celery_boost.task import TaskRunFromSignature + @shared_task(bind=True) def process_job(self, pk, version=None): @@ -52,3 +54,27 @@ def cache_store(key, value): @shared_task() def raise_task(param): raise Exception("Boom!") + + +@shared_task(base=TaskRunFromSignature) +def value(pk: int, version: int) -> int: + from .models import ValueJob + + job = ValueJob.objects.get(pk=pk, version=version) + return job.value + + +@shared_task(base=TaskRunFromSignature) +def add_to(pk: int, version: int, value: int) -> int: + from .models import AddToJob + + job = AddToJob.objects.get(pk=pk, version=version) + return job.value + value + + +@shared_task(base=TaskRunFromSignature) +def sum_and_add_to(pk: int, version: int, values: list[int]) -> int: + from .models import SumAndAddToJob + + job = SumAndAddToJob.objects.get(pk=pk, version=version) + return job.value + sum(values) diff --git a/tests/test_canvas.py b/tests/test_canvas.py new file mode 100644 index 0000000..34d6db4 --- /dev/null +++ b/tests/test_canvas.py @@ -0,0 +1,26 @@ +from typing import cast + +from celery import chain, group, chord +from celery.worker import WorkController + +from demo.factories import AddToJobFactory, SumAndAddToJobFactory, ValueJobFactory +from demo.models import AddToJob, SumAndAddToJob, ValueJob + +pytest_plugins = ("celery.contrib.pytest",) + + +def test_chain(transactional_db: None, celery_worker: WorkController) -> None: + value_job = cast(ValueJob, ValueJobFactory(value=5)) + add_to_job = cast(AddToJob, AddToJobFactory(value=15)) + assert chain(value_job.s(), add_to_job.s())().get() == value_job.value + add_to_job.value + + +def test_group(transactional_db: None, celery_worker: WorkController) -> None: + value_jobs = [cast(ValueJob, ValueJobFactory(value=i)) for i in range(1, 4)] + assert group([job.s() for job in value_jobs])().get() == [1, 2, 3] + + +def test_chord(transactional_db: None, celery_worker: WorkController) -> None: + value_jobs = [cast(ValueJob, ValueJobFactory(value=i)) for i in range(1, 4)] + sum_job = cast(SumAndAddToJob, SumAndAddToJobFactory(value=5)) + assert chord([job.s() for job in value_jobs])(sum_job.s()).get() == 11 diff --git a/tests/test_signature.py b/tests/test_signature.py new file mode 100644 index 0000000..6f7e61d --- /dev/null +++ b/tests/test_signature.py @@ -0,0 +1,28 @@ +from typing import cast + +import pytest + +from demo.factories import AddToJobFactory, JobFactory +from demo.models import AddToJob, Job +from django_celery_boost.models import InvalidTaskBase, APP_LABEL, MODEL_NAME + +pytestmark = [pytest.mark.django_db] + + +def test_can_create_signature_using_different_methods() -> None: + job = cast(AddToJob, AddToJobFactory()) + assert job.signature() == job.s() + + +def test_signature_parameters() -> None: + job = cast(AddToJob, AddToJobFactory()) + signature = job.signature() + assert signature.task == job.celery_task_name + assert signature.args == (job.id, job.version) + assert signature.kwargs == {APP_LABEL: AddToJob._meta.app_label, MODEL_NAME: AddToJob._meta.model_name} + + +def test_cannot_create_signature_without_correct_base_task() -> None: + with pytest.raises(InvalidTaskBase): + job = cast(Job, JobFactory()) + job.s() From 5031b15431255e8ec634f57069dcaf7aa5daed58 Mon Sep 17 00:00:00 2001 From: Sergey Misuk Date: Mon, 17 Nov 2025 09:03:27 +0400 Subject: [PATCH 2/7] Cleanup code --- pyproject.toml | 4 ++-- src/django_celery_boost/task.py | 6 ------ 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index aa9ec18..adc0450 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,10 +27,8 @@ classifiers = [ dynamic = [ "version" ] dependencies = [ "celery>=5.4", - "celery-types>=0.23", "django-admin-extra-buttons", "django-concurrency", - "pre-commit>=4.3", "sentry-sdk", ] @@ -41,11 +39,13 @@ urls.homepage = "https://github.com/unicef/django-celery-boost" dev = [ "celery[pytest]>=5.4", + "celery-types>=0.23", "django-webtest>=1.9.13", "factory-boy>=3.3.1", "flower>=2.0.1", "mypy>=1.11.2", "pdbpp>=0.10.3", + "pre-commit>=4.3", "psycopg>=2.9.9", "psycopg-binary>=3.2.2", "pytest>=8.3.3", diff --git a/src/django_celery_boost/task.py b/src/django_celery_boost/task.py index e4480b7..5c00719 100644 --- a/src/django_celery_boost/task.py +++ b/src/django_celery_boost/task.py @@ -35,14 +35,8 @@ def _apply[T: AsyncResult](apply_method: ApplyCallable[T], *args: Any, **kwargs: class TaskRunFromSignature(Task): - def __call__(self, *args: Any, **kwargs: Any) -> Any: - return super().__call__(*args, **kwargs) - def apply(self, *args: Any, **kwargs: Any) -> EagerResult: return _apply(super().apply, *args, **kwargs) def apply_async(self, *args: Any, **kwargs: Any) -> AsyncResult: return _apply(super().apply_async, *args, **kwargs) - - def run(self, *args: Any, **kwargs: Any) -> Any: - return super().run(*args, **kwargs) From 4dad5eaa894a835fd5317515e268bc20f7d7f154 Mon Sep 17 00:00:00 2001 From: Sergey Misuk Date: Mon, 17 Nov 2025 09:11:40 +0400 Subject: [PATCH 3/7] Use simpler syntax --- src/django_celery_boost/task.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/django_celery_boost/task.py b/src/django_celery_boost/task.py index 5c00719..287ca94 100644 --- a/src/django_celery_boost/task.py +++ b/src/django_celery_boost/task.py @@ -5,11 +5,11 @@ from django.apps import apps -class ApplyCallable[T: AsyncResult](Protocol): - def __call__(self, *args: Any, **kwargs: Any) -> T: ... +class ApplyCallable(Protocol): + def __call__(self, *args: Any, **kwargs: Any) -> AsyncResult: ... -def _apply[T: AsyncResult](apply_method: ApplyCallable[T], *args: Any, **kwargs: Any) -> T: +def _apply[T: AsyncResult](apply_method: ApplyCallable, *args: Any, **kwargs: Any) -> AsyncResult: from django_celery_boost.models import APP_LABEL, MODEL_NAME, CeleryTaskModel task_args = args[0] @@ -36,7 +36,7 @@ def _apply[T: AsyncResult](apply_method: ApplyCallable[T], *args: Any, **kwargs: class TaskRunFromSignature(Task): def apply(self, *args: Any, **kwargs: Any) -> EagerResult: - return _apply(super().apply, *args, **kwargs) + return cast(EagerResult, _apply(super().apply, *args, **kwargs)) def apply_async(self, *args: Any, **kwargs: Any) -> AsyncResult: return _apply(super().apply_async, *args, **kwargs) From c8a031bfc79bb1673e9a533a67cd5f1e885081b3 Mon Sep 17 00:00:00 2001 From: Sergey Misuk Date: Mon, 17 Nov 2025 09:13:42 +0400 Subject: [PATCH 4/7] Fix syntax error --- src/django_celery_boost/task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/django_celery_boost/task.py b/src/django_celery_boost/task.py index 287ca94..240487d 100644 --- a/src/django_celery_boost/task.py +++ b/src/django_celery_boost/task.py @@ -9,7 +9,7 @@ class ApplyCallable(Protocol): def __call__(self, *args: Any, **kwargs: Any) -> AsyncResult: ... -def _apply[T: AsyncResult](apply_method: ApplyCallable, *args: Any, **kwargs: Any) -> AsyncResult: +def _apply(apply_method: ApplyCallable, *args: Any, **kwargs: Any) -> AsyncResult: from django_celery_boost.models import APP_LABEL, MODEL_NAME, CeleryTaskModel task_args = args[0] From c1313b55fcec8f11f1fd5d374b387af0cfc56862 Mon Sep 17 00:00:00 2001 From: Sergey Misuk Date: Tue, 18 Nov 2025 12:13:51 +0400 Subject: [PATCH 5/7] Simplify tests --- tests/demoapp/demo/factories.py | 6 +++--- ...tojob_valuejob.py => 0002_addtojob_sumjob_valuejob.py} | 3 +-- tests/demoapp/demo/models.py | 6 ++---- tests/demoapp/demo/tasks.py | 7 ++----- tests/test_canvas.py | 8 ++++---- 5 files changed, 12 insertions(+), 18 deletions(-) rename tests/demoapp/demo/migrations/{0002_addtojob_sumandaddtojob_valuejob.py => 0002_addtojob_sumjob_valuejob.py} (98%) diff --git a/tests/demoapp/demo/factories.py b/tests/demoapp/demo/factories.py index d67d0f9..1b96330 100644 --- a/tests/demoapp/demo/factories.py +++ b/tests/demoapp/demo/factories.py @@ -1,7 +1,7 @@ from typing import Any, Optional import factory -from demo.models import Job, MultipleJob, AddToJob, SumAndAddToJob, ValueJob +from demo.models import Job, MultipleJob, AddToJob, SumJob, ValueJob from django.contrib.auth.models import Group, Permission, User from factory.django import DjangoModelFactory from factory.faker import Faker @@ -100,9 +100,9 @@ class Meta: model = AddToJob -class SumAndAddToJobFactory(DjangoModelFactory): +class SumJobFactory(DjangoModelFactory): curr_async_result_id = None last_async_result_id = None class Meta: - model = SumAndAddToJob + model = SumJob diff --git a/tests/demoapp/demo/migrations/0002_addtojob_sumandaddtojob_valuejob.py b/tests/demoapp/demo/migrations/0002_addtojob_sumjob_valuejob.py similarity index 98% rename from tests/demoapp/demo/migrations/0002_addtojob_sumandaddtojob_valuejob.py rename to tests/demoapp/demo/migrations/0002_addtojob_sumjob_valuejob.py index afa9872..e835f58 100644 --- a/tests/demoapp/demo/migrations/0002_addtojob_sumandaddtojob_valuejob.py +++ b/tests/demoapp/demo/migrations/0002_addtojob_sumjob_valuejob.py @@ -78,7 +78,7 @@ class Migration(migrations.Migration): }, ), migrations.CreateModel( - name="SumAndAddToJob", + name="SumJob", fields=[ ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), ("version", concurrency.fields.AutoIncVersionField(default=0, help_text="record revision number")), @@ -124,7 +124,6 @@ class Migration(migrations.Migration): null=True, ), ), - ("value", models.IntegerField(default=0)), ( "owner", models.ForeignKey( diff --git a/tests/demoapp/demo/models.py b/tests/demoapp/demo/models.py index f3a602b..56ee476 100644 --- a/tests/demoapp/demo/models.py +++ b/tests/demoapp/demo/models.py @@ -58,7 +58,5 @@ class AddToJob(CeleryTaskModel, models.Model): celery_task_name = "demo.tasks.add_to" -class SumAndAddToJob(CeleryTaskModel, models.Model): - value = models.IntegerField(default=0) - - celery_task_name = "demo.tasks.sum_and_add_to" +class SumJob(CeleryTaskModel, models.Model): + celery_task_name = "demo.tasks.sum_" diff --git a/tests/demoapp/demo/tasks.py b/tests/demoapp/demo/tasks.py index 3d1a869..f7cd43c 100644 --- a/tests/demoapp/demo/tasks.py +++ b/tests/demoapp/demo/tasks.py @@ -73,8 +73,5 @@ def add_to(pk: int, version: int, value: int) -> int: @shared_task(base=TaskRunFromSignature) -def sum_and_add_to(pk: int, version: int, values: list[int]) -> int: - from .models import SumAndAddToJob - - job = SumAndAddToJob.objects.get(pk=pk, version=version) - return job.value + sum(values) +def sum_(_: int, __: int, values: list[int]) -> int: + return sum(values) diff --git a/tests/test_canvas.py b/tests/test_canvas.py index 34d6db4..acb8843 100644 --- a/tests/test_canvas.py +++ b/tests/test_canvas.py @@ -3,8 +3,8 @@ from celery import chain, group, chord from celery.worker import WorkController -from demo.factories import AddToJobFactory, SumAndAddToJobFactory, ValueJobFactory -from demo.models import AddToJob, SumAndAddToJob, ValueJob +from demo.factories import AddToJobFactory, SumJobFactory, ValueJobFactory +from demo.models import AddToJob, SumJob, ValueJob pytest_plugins = ("celery.contrib.pytest",) @@ -22,5 +22,5 @@ def test_group(transactional_db: None, celery_worker: WorkController) -> None: def test_chord(transactional_db: None, celery_worker: WorkController) -> None: value_jobs = [cast(ValueJob, ValueJobFactory(value=i)) for i in range(1, 4)] - sum_job = cast(SumAndAddToJob, SumAndAddToJobFactory(value=5)) - assert chord([job.s() for job in value_jobs])(sum_job.s()).get() == 11 + sum_job = cast(SumJob, SumJobFactory()) + assert chord([job.s() for job in value_jobs])(sum_job.s()).get() == 6 From 001480ce4f8a5bf46557bccbdcef962035e49c48 Mon Sep 17 00:00:00 2001 From: Sergey Misuk Date: Tue, 18 Nov 2025 16:35:37 +0400 Subject: [PATCH 6/7] Test canvas related code in both eager and async mode --- tests/test_canvas.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/tests/test_canvas.py b/tests/test_canvas.py index acb8843..3a51bd2 100644 --- a/tests/test_canvas.py +++ b/tests/test_canvas.py @@ -3,24 +3,38 @@ from celery import chain, group, chord from celery.worker import WorkController +from _pytest.fixtures import SubRequest +import pytest +from pytest_django.fixtures import SettingsWrapper + from demo.factories import AddToJobFactory, SumJobFactory, ValueJobFactory from demo.models import AddToJob, SumJob, ValueJob pytest_plugins = ("celery.contrib.pytest",) -def test_chain(transactional_db: None, celery_worker: WorkController) -> None: +@pytest.fixture(params=[True, False], ids=["eager", "async"]) +def execution_mode(request: SubRequest, settings: SettingsWrapper) -> None: + if request.param: + settings.CELERY_TASK_ALWAYS_EAGER = True + settings.CELERY_TASK_STORE_EAGER_RESULT = True + else: + settings.CELERY_TASK_ALWAYS_EAGER = False + settings.CELERY_TASK_STORE_EAGER_RESULT = False + + +def test_chain(execution_mode: None, transactional_db: None, celery_worker: WorkController) -> None: value_job = cast(ValueJob, ValueJobFactory(value=5)) add_to_job = cast(AddToJob, AddToJobFactory(value=15)) assert chain(value_job.s(), add_to_job.s())().get() == value_job.value + add_to_job.value -def test_group(transactional_db: None, celery_worker: WorkController) -> None: +def test_group(execution_mode: None, transactional_db: None, celery_worker: WorkController) -> None: value_jobs = [cast(ValueJob, ValueJobFactory(value=i)) for i in range(1, 4)] assert group([job.s() for job in value_jobs])().get() == [1, 2, 3] -def test_chord(transactional_db: None, celery_worker: WorkController) -> None: +def test_chord(execution_mode: None, transactional_db: None, celery_worker: WorkController) -> None: value_jobs = [cast(ValueJob, ValueJobFactory(value=i)) for i in range(1, 4)] sum_job = cast(SumJob, SumJobFactory()) assert chord([job.s() for job in value_jobs])(sum_job.s()).get() == 6 From a8d55918e9af81df6ef39abb73a87a6bd195cd99 Mon Sep 17 00:00:00 2001 From: Sergey Misuk Date: Tue, 18 Nov 2025 20:59:10 +0400 Subject: [PATCH 7/7] Assert curr_async_result_id and datetime_queued updated --- tests/test_canvas.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/test_canvas.py b/tests/test_canvas.py index 3a51bd2..e50573f 100644 --- a/tests/test_canvas.py +++ b/tests/test_canvas.py @@ -1,14 +1,14 @@ from typing import cast +import pytest +from _pytest.fixtures import SubRequest from celery import chain, group, chord from celery.worker import WorkController - -from _pytest.fixtures import SubRequest -import pytest from pytest_django.fixtures import SettingsWrapper from demo.factories import AddToJobFactory, SumJobFactory, ValueJobFactory from demo.models import AddToJob, SumJob, ValueJob +from django_celery_boost.models import CeleryTaskModel pytest_plugins = ("celery.contrib.pytest",) @@ -18,23 +18,30 @@ def execution_mode(request: SubRequest, settings: SettingsWrapper) -> None: if request.param: settings.CELERY_TASK_ALWAYS_EAGER = True settings.CELERY_TASK_STORE_EAGER_RESULT = True - else: - settings.CELERY_TASK_ALWAYS_EAGER = False - settings.CELERY_TASK_STORE_EAGER_RESULT = False + + +def assert_job_fields_updated(*jobs: CeleryTaskModel) -> None: + for job in jobs: + job.refresh_from_db() + if job.curr_async_result_id is None or job.datetime_queued is None: + pytest.fail(f"Job {job} was not updated.") def test_chain(execution_mode: None, transactional_db: None, celery_worker: WorkController) -> None: value_job = cast(ValueJob, ValueJobFactory(value=5)) add_to_job = cast(AddToJob, AddToJobFactory(value=15)) assert chain(value_job.s(), add_to_job.s())().get() == value_job.value + add_to_job.value + assert_job_fields_updated(value_job, add_to_job) def test_group(execution_mode: None, transactional_db: None, celery_worker: WorkController) -> None: value_jobs = [cast(ValueJob, ValueJobFactory(value=i)) for i in range(1, 4)] assert group([job.s() for job in value_jobs])().get() == [1, 2, 3] + assert_job_fields_updated(*value_jobs) def test_chord(execution_mode: None, transactional_db: None, celery_worker: WorkController) -> None: value_jobs = [cast(ValueJob, ValueJobFactory(value=i)) for i in range(1, 4)] sum_job = cast(SumJob, SumJobFactory()) assert chord([job.s() for job in value_jobs])(sum_job.s()).get() == 6 + assert_job_fields_updated(*value_jobs, sum_job)