diff --git a/pyproject.toml b/pyproject.toml index 8c73ee1..adc0450 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,11 +39,13 @@ urls.homepage = "https://github.com/unicef/django-celery-boost" dev = [ "celery[pytest]>=5.4", + "celery-types>=0.23", "django-webtest>=1.9.13", "factory-boy>=3.3.1", "flower>=2.0.1", "mypy>=1.11.2", "pdbpp>=0.10.3", + "pre-commit>=4.3", "psycopg>=2.9.9", "psycopg-binary>=3.2.2", "pytest>=8.3.3", diff --git a/src/django_celery_boost/models.py b/src/django_celery_boost/models.py index e345510..795aab1 100644 --- a/src/django_celery_boost/models.py +++ b/src/django_celery_boost/models.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Callable, Generator import sentry_sdk -from celery import states +from celery import states, Signature from celery.app.base import Celery from concurrency.api import concurrency_disable_increment from concurrency.fields import AutoIncVersionField @@ -19,6 +19,7 @@ from django.utils.translation import gettext as _ from django_celery_boost.signals import task_queued, task_revoked, task_terminated +from django_celery_boost.task import TaskRunFromSignature if TYPE_CHECKING: import celery.app.control @@ -29,6 +30,17 @@ logger = logging.getLogger(__name__) +APP_LABEL = "app_label" +MODEL_NAME = "model_name" + + +class InvalidTaskBase(TypeError): + def __init__(self, task_handler_name: str): + super().__init__( + f"{task_handler_name} must be a TaskRunFromSignature instance. Use base argument with shared_task or app.task decorator." + ) + + class CeleryManager: pass @@ -335,6 +347,13 @@ def task_status(self) -> str: except Exception as e: # noqa return str(e) + def set_queued(self, result: AsyncResult) -> None: + with concurrency_disable_increment(self): + self.curr_async_result_id = result.id + self.datetime_queued = timezone.now() + self.save(update_fields=["curr_async_result_id", "datetime_queued"]) + task_queued.send(sender=self.__class__, task=self) + def queue(self, use_version: bool = True) -> str | None: """Queue the record processing. @@ -342,11 +361,7 @@ def queue(self, use_version: bool = True) -> str | None: """ if self.task_status not in self.ACTIVE_STATUSES: res = self.task_handler.delay(self.pk, self.version if use_version else None) - with concurrency_disable_increment(self): - self.curr_async_result_id = res.id - self.datetime_queued = timezone.now() - self.save(update_fields=["curr_async_result_id", "datetime_queued"]) - task_queued.send(sender=self.__class__, task=self) + self.set_queued(res) return self.curr_async_result_id return None @@ -388,6 +403,21 @@ def terminate(self, wait=False, timeout=None) -> str: task_terminated.send(sender=self.__class__, task=self) return st + def signature(self) -> Signature: + if not isinstance(self.task_handler, TaskRunFromSignature): + raise InvalidTaskBase(self.celery_task_name) + + return self.task_handler.signature( + (self.pk, self.version), + { + APP_LABEL: self._meta.app_label, + MODEL_NAME: self._meta.model_name, + }, + ) + + def s(self) -> Signature: + return self.signature() + @classmethod def discard_all(cls: "type[CeleryTaskModel]") -> None: cls.celery_app.control.discard_all() diff --git a/src/django_celery_boost/task.py b/src/django_celery_boost/task.py new file mode 100644 index 0000000..240487d --- /dev/null +++ b/src/django_celery_boost/task.py @@ -0,0 +1,42 @@ +from typing import Any, Protocol, cast + +from celery import Task +from celery.result import EagerResult, AsyncResult +from django.apps import apps + + +class ApplyCallable(Protocol): + def __call__(self, *args: Any, **kwargs: Any) -> AsyncResult: ... + + +def _apply(apply_method: ApplyCallable, *args: Any, **kwargs: Any) -> AsyncResult: + from django_celery_boost.models import APP_LABEL, MODEL_NAME, CeleryTaskModel + + task_args = args[0] + pk, version = task_args[-2], task_args[-1] + task_kwargs = args[1] + app_label, model_name = task_kwargs[APP_LABEL], task_kwargs[MODEL_NAME] + + model_class = cast(type[CeleryTaskModel], apps.get_model(app_label, model_name)) + model = model_class.objects.get(pk=pk, version=version) + + # we want pk and version to always come first + new_task_args = task_args[-2:] + task_args[:-2] + # We can possibly check whether the task is in an active state, if it's + # required, it can be done by passing some flag in kwargs and raising an + # exception in the overridden run method. Other options could require much + # more work + new_task_kwargs: dict[str, Any] = {} + + new_args: tuple[Any, ...] = (new_task_args, new_task_kwargs) + args[2:] + result = apply_method(*new_args, **kwargs) + model.set_queued(result) + return result + + +class TaskRunFromSignature(Task): + def apply(self, *args: Any, **kwargs: Any) -> EagerResult: + return cast(EagerResult, _apply(super().apply, *args, **kwargs)) + + def apply_async(self, *args: Any, **kwargs: Any) -> AsyncResult: + return _apply(super().apply_async, *args, **kwargs) diff --git a/tests/demoapp/demo/factories.py b/tests/demoapp/demo/factories.py index 3a60a34..1b96330 100644 --- a/tests/demoapp/demo/factories.py +++ b/tests/demoapp/demo/factories.py @@ -1,7 +1,7 @@ from typing import Any, Optional import factory -from demo.models import Job, MultipleJob +from demo.models import Job, MultipleJob, AddToJob, SumJob, ValueJob from django.contrib.auth.models import Group, Permission, User from factory.django import DjangoModelFactory from factory.faker import Faker @@ -82,3 +82,27 @@ def start(self): def stop(self): """Stop an active patch.""" return self.__exit__() + + +class ValueJobFactory(DjangoModelFactory): + curr_async_result_id = None + last_async_result_id = None + + class Meta: + model = ValueJob + + +class AddToJobFactory(DjangoModelFactory): + curr_async_result_id = None + last_async_result_id = None + + class Meta: + model = AddToJob + + +class SumJobFactory(DjangoModelFactory): + curr_async_result_id = None + last_async_result_id = None + + class Meta: + model = SumJob diff --git a/tests/demoapp/demo/migrations/0002_addtojob_sumjob_valuejob.py b/tests/demoapp/demo/migrations/0002_addtojob_sumjob_valuejob.py new file mode 100644 index 0000000..e835f58 --- /dev/null +++ b/tests/demoapp/demo/migrations/0002_addtojob_sumjob_valuejob.py @@ -0,0 +1,207 @@ +# Generated by Django 5.2.8 on 2025-11-14 11:06 + +import concurrency.fields +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("demo", "0001_initial"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="AddToJob", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("version", concurrency.fields.AutoIncVersionField(default=0, help_text="record revision number")), + ("description", models.CharField(blank=True, max_length=255, null=True)), + ( + "curr_async_result_id", + models.CharField( + blank=True, + editable=False, + help_text="Current (active) AsyncResult is", + max_length=36, + null=True, + ), + ), + ( + "last_async_result_id", + models.CharField( + blank=True, editable=False, help_text="Latest executed AsyncResult is", max_length=36, null=True + ), + ), + ("datetime_created", models.DateTimeField(auto_now_add=True, help_text="Creation date and time")), + ( + "datetime_queued", + models.DateTimeField( + blank=True, help_text="Queueing date and time", null=True, verbose_name="Queued At" + ), + ), + ( + "repeatable", + models.BooleanField( + blank=True, default=False, help_text="Indicate if the job can be repeated as-is" + ), + ), + ("celery_history", models.JSONField(blank=True, default=dict, editable=False)), + ("local_status", models.CharField(blank=True, default="", editable=False, max_length=100, null=True)), + ( + "group_key", + models.CharField( + blank=True, + editable=False, + help_text="Tasks with the same group key will not run in parallel", + max_length=255, + null=True, + ), + ), + ("value", models.IntegerField(default=0)), + ( + "owner", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="%(app_label)s_%(class)s_jobs", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "abstract": False, + "default_permissions": ("add", "change", "delete", "view", "queue", "terminate", "inspect", "revoke"), + }, + ), + migrations.CreateModel( + name="SumJob", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("version", concurrency.fields.AutoIncVersionField(default=0, help_text="record revision number")), + ("description", models.CharField(blank=True, max_length=255, null=True)), + ( + "curr_async_result_id", + models.CharField( + blank=True, + editable=False, + help_text="Current (active) AsyncResult is", + max_length=36, + null=True, + ), + ), + ( + "last_async_result_id", + models.CharField( + blank=True, editable=False, help_text="Latest executed AsyncResult is", max_length=36, null=True + ), + ), + ("datetime_created", models.DateTimeField(auto_now_add=True, help_text="Creation date and time")), + ( + "datetime_queued", + models.DateTimeField( + blank=True, help_text="Queueing date and time", null=True, verbose_name="Queued At" + ), + ), + ( + "repeatable", + models.BooleanField( + blank=True, default=False, help_text="Indicate if the job can be repeated as-is" + ), + ), + ("celery_history", models.JSONField(blank=True, default=dict, editable=False)), + ("local_status", models.CharField(blank=True, default="", editable=False, max_length=100, null=True)), + ( + "group_key", + models.CharField( + blank=True, + editable=False, + help_text="Tasks with the same group key will not run in parallel", + max_length=255, + null=True, + ), + ), + ( + "owner", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="%(app_label)s_%(class)s_jobs", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "abstract": False, + "default_permissions": ("add", "change", "delete", "view", "queue", "terminate", "inspect", "revoke"), + }, + ), + migrations.CreateModel( + name="ValueJob", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("version", concurrency.fields.AutoIncVersionField(default=0, help_text="record revision number")), + ("description", models.CharField(blank=True, max_length=255, null=True)), + ( + "curr_async_result_id", + models.CharField( + blank=True, + editable=False, + help_text="Current (active) AsyncResult is", + max_length=36, + null=True, + ), + ), + ( + "last_async_result_id", + models.CharField( + blank=True, editable=False, help_text="Latest executed AsyncResult is", max_length=36, null=True + ), + ), + ("datetime_created", models.DateTimeField(auto_now_add=True, help_text="Creation date and time")), + ( + "datetime_queued", + models.DateTimeField( + blank=True, help_text="Queueing date and time", null=True, verbose_name="Queued At" + ), + ), + ( + "repeatable", + models.BooleanField( + blank=True, default=False, help_text="Indicate if the job can be repeated as-is" + ), + ), + ("celery_history", models.JSONField(blank=True, default=dict, editable=False)), + ("local_status", models.CharField(blank=True, default="", editable=False, max_length=100, null=True)), + ( + "group_key", + models.CharField( + blank=True, + editable=False, + help_text="Tasks with the same group key will not run in parallel", + max_length=255, + null=True, + ), + ), + ("value", models.IntegerField(default=0)), + ( + "owner", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="%(app_label)s_%(class)s_jobs", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "abstract": False, + "default_permissions": ("add", "change", "delete", "view", "queue", "terminate", "inspect", "revoke"), + }, + ), + ] diff --git a/tests/demoapp/demo/models.py b/tests/demoapp/demo/models.py index 050edb9..56ee476 100644 --- a/tests/demoapp/demo/models.py +++ b/tests/demoapp/demo/models.py @@ -44,3 +44,19 @@ class Meta(AsyncJobModel.Meta): permissions = (("test_multiplejob", "Can test MultipleJob"),) celery_task_name = "demo.tasks.process_job" + + +class ValueJob(CeleryTaskModel, models.Model): + value = models.IntegerField(default=0) + + celery_task_name = "demo.tasks.value" + + +class AddToJob(CeleryTaskModel, models.Model): + value = models.IntegerField(default=0) + + celery_task_name = "demo.tasks.add_to" + + +class SumJob(CeleryTaskModel, models.Model): + celery_task_name = "demo.tasks.sum_" diff --git a/tests/demoapp/demo/tasks.py b/tests/demoapp/demo/tasks.py index bf37278..f7cd43c 100644 --- a/tests/demoapp/demo/tasks.py +++ b/tests/demoapp/demo/tasks.py @@ -4,6 +4,8 @@ from concurrency.exceptions import RecordModifiedError from django.core.cache import cache +from django_celery_boost.task import TaskRunFromSignature + @shared_task(bind=True) def process_job(self, pk, version=None): @@ -52,3 +54,24 @@ def cache_store(key, value): @shared_task() def raise_task(param): raise Exception("Boom!") + + +@shared_task(base=TaskRunFromSignature) +def value(pk: int, version: int) -> int: + from .models import ValueJob + + job = ValueJob.objects.get(pk=pk, version=version) + return job.value + + +@shared_task(base=TaskRunFromSignature) +def add_to(pk: int, version: int, value: int) -> int: + from .models import AddToJob + + job = AddToJob.objects.get(pk=pk, version=version) + return job.value + value + + +@shared_task(base=TaskRunFromSignature) +def sum_(_: int, __: int, values: list[int]) -> int: + return sum(values) diff --git a/tests/test_canvas.py b/tests/test_canvas.py new file mode 100644 index 0000000..e50573f --- /dev/null +++ b/tests/test_canvas.py @@ -0,0 +1,47 @@ +from typing import cast + +import pytest +from _pytest.fixtures import SubRequest +from celery import chain, group, chord +from celery.worker import WorkController +from pytest_django.fixtures import SettingsWrapper + +from demo.factories import AddToJobFactory, SumJobFactory, ValueJobFactory +from demo.models import AddToJob, SumJob, ValueJob +from django_celery_boost.models import CeleryTaskModel + +pytest_plugins = ("celery.contrib.pytest",) + + +@pytest.fixture(params=[True, False], ids=["eager", "async"]) +def execution_mode(request: SubRequest, settings: SettingsWrapper) -> None: + if request.param: + settings.CELERY_TASK_ALWAYS_EAGER = True + settings.CELERY_TASK_STORE_EAGER_RESULT = True + + +def assert_job_fields_updated(*jobs: CeleryTaskModel) -> None: + for job in jobs: + job.refresh_from_db() + if job.curr_async_result_id is None or job.datetime_queued is None: + pytest.fail(f"Job {job} was not updated.") + + +def test_chain(execution_mode: None, transactional_db: None, celery_worker: WorkController) -> None: + value_job = cast(ValueJob, ValueJobFactory(value=5)) + add_to_job = cast(AddToJob, AddToJobFactory(value=15)) + assert chain(value_job.s(), add_to_job.s())().get() == value_job.value + add_to_job.value + assert_job_fields_updated(value_job, add_to_job) + + +def test_group(execution_mode: None, transactional_db: None, celery_worker: WorkController) -> None: + value_jobs = [cast(ValueJob, ValueJobFactory(value=i)) for i in range(1, 4)] + assert group([job.s() for job in value_jobs])().get() == [1, 2, 3] + assert_job_fields_updated(*value_jobs) + + +def test_chord(execution_mode: None, transactional_db: None, celery_worker: WorkController) -> None: + value_jobs = [cast(ValueJob, ValueJobFactory(value=i)) for i in range(1, 4)] + sum_job = cast(SumJob, SumJobFactory()) + assert chord([job.s() for job in value_jobs])(sum_job.s()).get() == 6 + assert_job_fields_updated(*value_jobs, sum_job) diff --git a/tests/test_signature.py b/tests/test_signature.py new file mode 100644 index 0000000..6f7e61d --- /dev/null +++ b/tests/test_signature.py @@ -0,0 +1,28 @@ +from typing import cast + +import pytest + +from demo.factories import AddToJobFactory, JobFactory +from demo.models import AddToJob, Job +from django_celery_boost.models import InvalidTaskBase, APP_LABEL, MODEL_NAME + +pytestmark = [pytest.mark.django_db] + + +def test_can_create_signature_using_different_methods() -> None: + job = cast(AddToJob, AddToJobFactory()) + assert job.signature() == job.s() + + +def test_signature_parameters() -> None: + job = cast(AddToJob, AddToJobFactory()) + signature = job.signature() + assert signature.task == job.celery_task_name + assert signature.args == (job.id, job.version) + assert signature.kwargs == {APP_LABEL: AddToJob._meta.app_label, MODEL_NAME: AddToJob._meta.model_name} + + +def test_cannot_create_signature_without_correct_base_task() -> None: + with pytest.raises(InvalidTaskBase): + job = cast(Job, JobFactory()) + job.s()