diff --git a/taskbadger/celery.py b/taskbadger/celery.py index c3b4fda..565b1b7 100644 --- a/taskbadger/celery.py +++ b/taskbadger/celery.py @@ -13,6 +13,7 @@ KWARG_PREFIX = "taskbadger_" TB_KWARGS_ARG = f"{KWARG_PREFIX}kwargs" IGNORE_ARGS = {TB_KWARGS_ARG, f"{KWARG_PREFIX}task", f"{KWARG_PREFIX}task_id"} +TB_TASK_ID = f"{KWARG_PREFIX}task_id" TERMINAL_STATES = {StatusEnum.SUCCESS, StatusEnum.ERROR, StatusEnum.CANCELLED, StatusEnum.STALE} @@ -110,8 +111,8 @@ def apply_async(self, *args, **kwargs): headers[TB_KWARGS_ARG] = tb_kwargs result = super().apply_async(*args, **kwargs) - tb_task_id = result.info.get("taskbadger_task_id") if result.info else None - setattr(result, "taskbadger_task_id", tb_task_id) + tb_task_id = result.info.get(TB_TASK_ID) if result.info else None + setattr(result, TB_TASK_ID, tb_task_id) _get_task = functools.partial(get_task, tb_task_id) if tb_task_id else lambda: None setattr(result, "get_taskbadger_task", _get_task) @@ -120,7 +121,7 @@ def apply_async(self, *args, **kwargs): @property def taskbadger_task_id(self): - return self.request and self.request.headers and self.request.headers.get("taskbadger_task_id") + return _get_taskbadger_task_id(self.request) @property def taskbadger_task(self): @@ -137,8 +138,9 @@ def taskbadger_task(self): @before_task_publish.connect -def task_publish_handler(sender=None, headers=None, **kwargs): - if sender.startswith("celery.") or not headers or not Badger.is_configured(): +def task_publish_handler(sender=None, headers=None, body=None, **kwargs): + headers = headers if "task" in headers else body + if sender.startswith("celery.") or not Badger.is_configured(): return celery_system = Badger.current.settings.get_system_by_id("celery") @@ -162,7 +164,7 @@ def task_publish_handler(sender=None, headers=None, **kwargs): task = create_task_safe(name, **kwargs) if task: - meta = {"taskbadger_task_id": task.id} + meta = {TB_TASK_ID: task.id} headers.update(meta) ctask.update_state(task_id=headers["id"], state="PENDING", meta=meta) @@ -191,11 +193,7 @@ def task_retry_handler(sender=None, einfo=None, **kwargs): def _update_task(signal_sender, status, einfo=None): - headers = signal_sender.request.headers - if not headers: - return - - task_id = headers.get("taskbadger_task_id") + task_id = _get_taskbadger_task_id(signal_sender.request) if not task_id: return @@ -235,7 +233,7 @@ def exit_session(signal_sender): if not headers: return - task_id = headers.get("taskbadger_task_id") + task_id = headers.get(TB_TASK_ID) if not task_id or not Badger.is_configured(): return @@ -253,3 +251,15 @@ def safe_get_task(task_id: str): return get_task(task_id) except Exception: log.exception("Error fetching task '%s'", task_id) + + +def _get_taskbadger_task_id(request): + if not request: + return + + task_id = request.get(TB_TASK_ID) + if task_id: + return task_id + + if request.headers: + return request.headers.get(TB_TASK_ID) diff --git a/taskbadger/systems/celery.py b/taskbadger/systems/celery.py index 1dd6807..966fdec 100644 --- a/taskbadger/systems/celery.py +++ b/taskbadger/systems/celery.py @@ -11,3 +11,6 @@ def __init__(self, auto_track_tasks=True): `taskbadger.celery.Task` base class. """ self.auto_track_tasks = auto_track_tasks + if auto_track_tasks: + # Importing this here ensures that the Celery signal handlers are registered + import taskbadger.celery # noqa