Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,13 @@ urls.homepage = "https://github.com/unicef/django-celery-boost"

dev = [
"celery[pytest]>=5.4",
"celery-types>=0.23",
"django-webtest>=1.9.13",
"factory-boy>=3.3.1",
"flower>=2.0.1",
"mypy>=1.11.2",
"pdbpp>=0.10.3",
"pre-commit>=4.3",
"psycopg>=2.9.9",
"psycopg-binary>=3.2.2",
"pytest>=8.3.3",
Expand Down
42 changes: 36 additions & 6 deletions src/django_celery_boost/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import TYPE_CHECKING, Any, Callable, Generator

import sentry_sdk
from celery import states
from celery import states, Signature
from celery.app.base import Celery
from concurrency.api import concurrency_disable_increment
from concurrency.fields import AutoIncVersionField
Expand All @@ -19,6 +19,7 @@
from django.utils.translation import gettext as _

from django_celery_boost.signals import task_queued, task_revoked, task_terminated
from django_celery_boost.task import TaskRunFromSignature

if TYPE_CHECKING:
import celery.app.control
Expand All @@ -29,6 +30,17 @@
logger = logging.getLogger(__name__)


APP_LABEL = "app_label"
MODEL_NAME = "model_name"


class InvalidTaskBase(TypeError):
def __init__(self, task_handler_name: str):
super().__init__(
f"{task_handler_name} must be a TaskRunFromSignature instance. Use base argument with shared_task or app.task decorator."
)


class CeleryManager:
pass

Expand Down Expand Up @@ -335,18 +347,21 @@ def task_status(self) -> str:
except Exception as e: # noqa
return str(e)

def set_queued(self, result: AsyncResult) -> None:
with concurrency_disable_increment(self):
self.curr_async_result_id = result.id
self.datetime_queued = timezone.now()
self.save(update_fields=["curr_async_result_id", "datetime_queued"])
task_queued.send(sender=self.__class__, task=self)

def queue(self, use_version: bool = True) -> str | None:
"""Queue the record processing.

use_version: if True the task fails if the record is changed after it has been queued.
"""
if self.task_status not in self.ACTIVE_STATUSES:
res = self.task_handler.delay(self.pk, self.version if use_version else None)
with concurrency_disable_increment(self):
self.curr_async_result_id = res.id
self.datetime_queued = timezone.now()
self.save(update_fields=["curr_async_result_id", "datetime_queued"])
task_queued.send(sender=self.__class__, task=self)
self.set_queued(res)
return self.curr_async_result_id
return None

Expand Down Expand Up @@ -388,6 +403,21 @@ def terminate(self, wait=False, timeout=None) -> str:
task_terminated.send(sender=self.__class__, task=self)
return st

def signature(self) -> Signature:
if not isinstance(self.task_handler, TaskRunFromSignature):
raise InvalidTaskBase(self.celery_task_name)

return self.task_handler.signature(
(self.pk, self.version),
{
APP_LABEL: self._meta.app_label,
MODEL_NAME: self._meta.model_name,
},
)

def s(self) -> Signature:
return self.signature()

@classmethod
def discard_all(cls: "type[CeleryTaskModel]") -> None:
cls.celery_app.control.discard_all()
Expand Down
42 changes: 42 additions & 0 deletions src/django_celery_boost/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import Any, Protocol, cast

from celery import Task
from celery.result import EagerResult, AsyncResult
from django.apps import apps


class ApplyCallable(Protocol):
def __call__(self, *args: Any, **kwargs: Any) -> AsyncResult: ...


def _apply(apply_method: ApplyCallable, *args: Any, **kwargs: Any) -> AsyncResult:
from django_celery_boost.models import APP_LABEL, MODEL_NAME, CeleryTaskModel

task_args = args[0]
pk, version = task_args[-2], task_args[-1]
task_kwargs = args[1]
app_label, model_name = task_kwargs[APP_LABEL], task_kwargs[MODEL_NAME]

model_class = cast(type[CeleryTaskModel], apps.get_model(app_label, model_name))
model = model_class.objects.get(pk=pk, version=version)

# we want pk and version to always come first
new_task_args = task_args[-2:] + task_args[:-2]
# We can possibly check whether the task is in an active state, if it's
# required, it can be done by passing some flag in kwargs and raising an
# exception in the overridden run method. Other options could require much
# more work
new_task_kwargs: dict[str, Any] = {}

new_args: tuple[Any, ...] = (new_task_args, new_task_kwargs) + args[2:]
result = apply_method(*new_args, **kwargs)
model.set_queued(result)
return result


class TaskRunFromSignature(Task):
def apply(self, *args: Any, **kwargs: Any) -> EagerResult:
return cast(EagerResult, _apply(super().apply, *args, **kwargs))

def apply_async(self, *args: Any, **kwargs: Any) -> AsyncResult:
return _apply(super().apply_async, *args, **kwargs)
26 changes: 25 additions & 1 deletion tests/demoapp/demo/factories.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Optional

import factory
from demo.models import Job, MultipleJob
from demo.models import Job, MultipleJob, AddToJob, SumJob, ValueJob
from django.contrib.auth.models import Group, Permission, User
from factory.django import DjangoModelFactory
from factory.faker import Faker
Expand Down Expand Up @@ -82,3 +82,27 @@ def start(self):
def stop(self):
"""Stop an active patch."""
return self.__exit__()


class ValueJobFactory(DjangoModelFactory):
curr_async_result_id = None
last_async_result_id = None

class Meta:
model = ValueJob


class AddToJobFactory(DjangoModelFactory):
curr_async_result_id = None
last_async_result_id = None

class Meta:
model = AddToJob


class SumJobFactory(DjangoModelFactory):
curr_async_result_id = None
last_async_result_id = None

class Meta:
model = SumJob
207 changes: 207 additions & 0 deletions tests/demoapp/demo/migrations/0002_addtojob_sumjob_valuejob.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
# Generated by Django 5.2.8 on 2025-11-14 11:06

import concurrency.fields
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("demo", "0001_initial"),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]

operations = [
migrations.CreateModel(
name="AddToJob",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("version", concurrency.fields.AutoIncVersionField(default=0, help_text="record revision number")),
("description", models.CharField(blank=True, max_length=255, null=True)),
(
"curr_async_result_id",
models.CharField(
blank=True,
editable=False,
help_text="Current (active) AsyncResult is",
max_length=36,
null=True,
),
),
(
"last_async_result_id",
models.CharField(
blank=True, editable=False, help_text="Latest executed AsyncResult is", max_length=36, null=True
),
),
("datetime_created", models.DateTimeField(auto_now_add=True, help_text="Creation date and time")),
(
"datetime_queued",
models.DateTimeField(
blank=True, help_text="Queueing date and time", null=True, verbose_name="Queued At"
),
),
(
"repeatable",
models.BooleanField(
blank=True, default=False, help_text="Indicate if the job can be repeated as-is"
),
),
("celery_history", models.JSONField(blank=True, default=dict, editable=False)),
("local_status", models.CharField(blank=True, default="", editable=False, max_length=100, null=True)),
(
"group_key",
models.CharField(
blank=True,
editable=False,
help_text="Tasks with the same group key will not run in parallel",
max_length=255,
null=True,
),
),
("value", models.IntegerField(default=0)),
(
"owner",
models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="%(app_label)s_%(class)s_jobs",
to=settings.AUTH_USER_MODEL,
),
),
],
options={
"abstract": False,
"default_permissions": ("add", "change", "delete", "view", "queue", "terminate", "inspect", "revoke"),
},
),
migrations.CreateModel(
name="SumJob",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("version", concurrency.fields.AutoIncVersionField(default=0, help_text="record revision number")),
("description", models.CharField(blank=True, max_length=255, null=True)),
(
"curr_async_result_id",
models.CharField(
blank=True,
editable=False,
help_text="Current (active) AsyncResult is",
max_length=36,
null=True,
),
),
(
"last_async_result_id",
models.CharField(
blank=True, editable=False, help_text="Latest executed AsyncResult is", max_length=36, null=True
),
),
("datetime_created", models.DateTimeField(auto_now_add=True, help_text="Creation date and time")),
(
"datetime_queued",
models.DateTimeField(
blank=True, help_text="Queueing date and time", null=True, verbose_name="Queued At"
),
),
(
"repeatable",
models.BooleanField(
blank=True, default=False, help_text="Indicate if the job can be repeated as-is"
),
),
("celery_history", models.JSONField(blank=True, default=dict, editable=False)),
("local_status", models.CharField(blank=True, default="", editable=False, max_length=100, null=True)),
(
"group_key",
models.CharField(
blank=True,
editable=False,
help_text="Tasks with the same group key will not run in parallel",
max_length=255,
null=True,
),
),
(
"owner",
models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="%(app_label)s_%(class)s_jobs",
to=settings.AUTH_USER_MODEL,
),
),
],
options={
"abstract": False,
"default_permissions": ("add", "change", "delete", "view", "queue", "terminate", "inspect", "revoke"),
},
),
migrations.CreateModel(
name="ValueJob",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("version", concurrency.fields.AutoIncVersionField(default=0, help_text="record revision number")),
("description", models.CharField(blank=True, max_length=255, null=True)),
(
"curr_async_result_id",
models.CharField(
blank=True,
editable=False,
help_text="Current (active) AsyncResult is",
max_length=36,
null=True,
),
),
(
"last_async_result_id",
models.CharField(
blank=True, editable=False, help_text="Latest executed AsyncResult is", max_length=36, null=True
),
),
("datetime_created", models.DateTimeField(auto_now_add=True, help_text="Creation date and time")),
(
"datetime_queued",
models.DateTimeField(
blank=True, help_text="Queueing date and time", null=True, verbose_name="Queued At"
),
),
(
"repeatable",
models.BooleanField(
blank=True, default=False, help_text="Indicate if the job can be repeated as-is"
),
),
("celery_history", models.JSONField(blank=True, default=dict, editable=False)),
("local_status", models.CharField(blank=True, default="", editable=False, max_length=100, null=True)),
(
"group_key",
models.CharField(
blank=True,
editable=False,
help_text="Tasks with the same group key will not run in parallel",
max_length=255,
null=True,
),
),
("value", models.IntegerField(default=0)),
(
"owner",
models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="%(app_label)s_%(class)s_jobs",
to=settings.AUTH_USER_MODEL,
),
),
],
options={
"abstract": False,
"default_permissions": ("add", "change", "delete", "view", "queue", "terminate", "inspect", "revoke"),
},
),
]
Loading
Loading