Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions taskbadger/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
KWARG_PREFIX = "taskbadger_"
TB_KWARGS_ARG = f"{KWARG_PREFIX}kwargs"
IGNORE_ARGS = {TB_KWARGS_ARG, f"{KWARG_PREFIX}task", f"{KWARG_PREFIX}task_id"}
TB_TASK_ID = f"{KWARG_PREFIX}task_id"

TERMINAL_STATES = {StatusEnum.SUCCESS, StatusEnum.ERROR, StatusEnum.CANCELLED, StatusEnum.STALE}

Expand Down Expand Up @@ -110,8 +111,8 @@ def apply_async(self, *args, **kwargs):
headers[TB_KWARGS_ARG] = tb_kwargs
result = super().apply_async(*args, **kwargs)

tb_task_id = result.info.get("taskbadger_task_id") if result.info else None
setattr(result, "taskbadger_task_id", tb_task_id)
tb_task_id = result.info.get(TB_TASK_ID) if result.info else None
setattr(result, TB_TASK_ID, tb_task_id)

_get_task = functools.partial(get_task, tb_task_id) if tb_task_id else lambda: None
setattr(result, "get_taskbadger_task", _get_task)
Expand All @@ -120,7 +121,7 @@ def apply_async(self, *args, **kwargs):

@property
def taskbadger_task_id(self):
return self.request and self.request.headers and self.request.headers.get("taskbadger_task_id")
return _get_taskbadger_task_id(self.request)

@property
def taskbadger_task(self):
Expand All @@ -137,8 +138,9 @@ def taskbadger_task(self):


@before_task_publish.connect
def task_publish_handler(sender=None, headers=None, **kwargs):
if sender.startswith("celery.") or not headers or not Badger.is_configured():
def task_publish_handler(sender=None, headers=None, body=None, **kwargs):
headers = headers if "task" in headers else body
if sender.startswith("celery.") or not Badger.is_configured():
return

celery_system = Badger.current.settings.get_system_by_id("celery")
Expand All @@ -162,7 +164,7 @@ def task_publish_handler(sender=None, headers=None, **kwargs):

task = create_task_safe(name, **kwargs)
if task:
meta = {"taskbadger_task_id": task.id}
meta = {TB_TASK_ID: task.id}
headers.update(meta)
ctask.update_state(task_id=headers["id"], state="PENDING", meta=meta)

Expand Down Expand Up @@ -191,11 +193,7 @@ def task_retry_handler(sender=None, einfo=None, **kwargs):


def _update_task(signal_sender, status, einfo=None):
headers = signal_sender.request.headers
if not headers:
return

task_id = headers.get("taskbadger_task_id")
task_id = _get_taskbadger_task_id(signal_sender.request)
if not task_id:
return

Expand Down Expand Up @@ -235,7 +233,7 @@ def exit_session(signal_sender):
if not headers:
return

task_id = headers.get("taskbadger_task_id")
task_id = headers.get(TB_TASK_ID)
if not task_id or not Badger.is_configured():
return

Expand All @@ -253,3 +251,15 @@ def safe_get_task(task_id: str):
return get_task(task_id)
except Exception:
log.exception("Error fetching task '%s'", task_id)


def _get_taskbadger_task_id(request):
if not request:
return

task_id = request.get(TB_TASK_ID)
if task_id:
return task_id

if request.headers:
return request.headers.get(TB_TASK_ID)
3 changes: 3 additions & 0 deletions taskbadger/systems/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ def __init__(self, auto_track_tasks=True):
`taskbadger.celery.Task` base class.
"""
self.auto_track_tasks = auto_track_tasks
if auto_track_tasks:
# Importing this here ensures that the Celery signal handlers are registered
import taskbadger.celery # noqa