From b5885ffa1647e55491e16022d2f6f3ac1175a3c2 Mon Sep 17 00:00:00 2001
From: Simon Kelly
Date: Wed, 1 Nov 2023 14:39:29 +0200
Subject: [PATCH 1/2] auto track celery tasks

---
 integration_tests/__init__.py           |  2 +
 integration_tests/tasks.py              |  6 ++
 integration_tests/test_celery.py        | 18 ++++-
 taskbadger/celery.py                    | 96 +++++++++++++++++++++----
 taskbadger/mug.py                       |  7 +-
 taskbadger/sdk.py                       | 18 +++--
 taskbadger/systems/__init__.py          |  6 ++
 taskbadger/systems/celery.py            | 13 ++++
 tests/test_celery_error.py              |  4 +-
 tests/test_celery_system_integration.py | 62 ++++++++++++++++
 10 files changed, 213 insertions(+), 19 deletions(-)
 create mode 100644 taskbadger/systems/__init__.py
 create mode 100644 taskbadger/systems/celery.py
 create mode 100644 tests/test_celery_system_integration.py

diff --git a/integration_tests/__init__.py b/integration_tests/__init__.py
index a8ff2fb..a4863a7 100644
--- a/integration_tests/__init__.py
+++ b/integration_tests/__init__.py
@@ -4,6 +4,7 @@
 import pytest

 import taskbadger as badger
+from taskbadger.systems.celery import CelerySystemIntegration


 def _load_config():
@@ -30,5 +31,6 @@ def _load_config():
     os.environ.get("TASKBADGER_ORG", ""),
     os.environ.get("TASKBADGER_PROJECT", ""),
     os.environ.get("TASKBADGER_API_KEY", ""),
+    systems=[CelerySystemIntegration()],
 )
 print(f"\nIntegration tests configuration:\n {badger.mug.Badger.current.settings}\n")
diff --git a/integration_tests/tasks.py b/integration_tests/tasks.py
index 9ae69b1..c1cd9e0 100644
--- a/integration_tests/tasks.py
+++ b/integration_tests/tasks.py
@@ -8,3 +8,9 @@ def add(self, x, y):
     assert self.taskbadger_task is not None, "missing task on self"
     self.taskbadger_task.update(value=100, data={"result": x + y})
     return x + y
+
+
+@shared_task(bind=True)
+def add_auto_track(self, x, y):
+    assert self.request.taskbadger_task_id is not None, "missing task ID on self.request"
+    return x + y
diff --git a/integration_tests/test_celery.py b/integration_tests/test_celery.py
index 628fd81..23776ab 100644
--- a/integration_tests/test_celery.py
+++ b/integration_tests/test_celery.py
@@ -1,10 +1,20 @@
+import logging
 import random

 import pytest

 from taskbadger import StatusEnum

-from .tasks import add
+from .tasks import add, add_auto_track
+
+
+@pytest.fixture(autouse=True)
+def check_log_errors(caplog):
+    yield
+    for when in ("call", "setup", "teardown"):
+        errors = [r.getMessage() for r in caplog.get_records(when) if r.levelno == logging.ERROR]
+        if errors:
+            pytest.fail(f"log errors during '{when}': {errors}")


 @pytest.fixture(scope="session", autouse=True)
@@ -24,3 +34,9 @@ def test_celery(celery_worker):
     assert tb_task.status == StatusEnum.SUCCESS
     assert tb_task.value == 100
     assert tb_task.data == {"result": a + b}
+
+
+def test_celery_auto_track(celery_worker):
+    a, b = random.randint(1, 1000), random.randint(1, 1000)
+    result = add_auto_track.delay(a, b)
+    assert result.get(timeout=10, propagate=True) == a + b
diff --git a/taskbadger/celery.py b/taskbadger/celery.py
index cf5ddf4..d4dbff7 100644
--- a/taskbadger/celery.py
+++ b/taskbadger/celery.py
@@ -1,3 +1,4 @@
+import collections
 import functools
 import logging

@@ -18,6 +19,46 @@
 log = logging.getLogger("taskbadger")


+class Cache:
+    def __init__(self, maxsize=128):
+        self.cache = collections.OrderedDict()
+        self.maxsize = maxsize
+
+    def set(self, key, value):
+        self.cache[key] = value
+
+    def unset(self, key):
+        self.cache.pop(key, None)
+
+    def get(self, key):
+        return self.cache.get(key)
+
+    def prune(self):
+        if len(self.cache) > self.maxsize:
+            self.cache.popitem(last=False)
+
+
+def cached(cache_none=True, maxsize=128):
+    cache = Cache(maxsize=maxsize)
+
+    def _wrapper(func):
+        @functools.wraps(func)
+        def _inner(*args, **kwargs):
+            key = args + tuple(sorted(kwargs.items()))
+            if key in cache.cache:
+                return cache.get(key)
+
+            result = func(*args, **kwargs)
+            if result is not None or cache_none:
+                cache.set(key, result)
+            return result
+
+        _inner.cache = cache
+        return _inner
+
+    return _wrapper
+
+
 class Task(celery.Task):
     """A Celery Task that tracks itself with TaskBadger.

@@ -89,18 +130,21 @@ def taskbadger_task(self):
         task = self.request.get("taskbadger_task")
         if not task:
             log.debug("Fetching task '%s'", self.taskbadger_task_id)
-            try:
-                task = get_task(self.taskbadger_task_id)
+            task = safe_get_task(self.taskbadger_task_id)
+            if task:
                 self.request.update({"taskbadger_task": task})
-            except Exception:
-                log.exception("Error fetching task '%s'", self.taskbadger_task_id)
-                task = None
         return task


 @before_task_publish.connect
 def task_publish_handler(sender=None, headers=None, **kwargs):
-    if not headers.get("taskbadger_track") or not Badger.is_configured():
+    if sender.startswith("celery.") or not Badger.is_configured():
+        return
+
+    celery_system = Badger.current.settings.get_system_by_id("celery")
+    auto_track = celery_system and celery_system.auto_track_tasks
+    manual_track = headers.get("taskbadger_track")
+    if not manual_track and not auto_track:
         return

     ctask = celery.current_app.tasks.get(sender)
@@ -112,7 +156,7 @@ def task_publish_handler(sender=None, headers=None, **kwargs):
         kwargs[attr.removeprefix(KWARG_PREFIX)] = getattr(ctask, attr)

     # get kwargs from the task headers (set via apply_async)
-    kwargs.update(headers[TB_KWARGS_ARG])
+    kwargs.update(headers.get(TB_KWARGS_ARG, {}))

     kwargs["status"] = StatusEnum.PENDING
     name = kwargs.pop("name", headers["task"])
@@ -147,11 +191,20 @@ def task_retry_handler(sender=None, einfo=None, **kwargs):


 def _update_task(signal_sender, status, einfo=None):
-    log.debug("celery_task_update %s %s", signal_sender, status)
-    if not hasattr(signal_sender, "taskbadger_task"):
+    headers = signal_sender.request.headers
+    if not headers:
+        return
+
+    task_id = headers.get("taskbadger_task_id")
+    if not task_id:
         return

-    task = signal_sender.taskbadger_task
+    log.debug("celery_task_update %s %s", signal_sender, status)
+    if hasattr(signal_sender, "taskbadger_task"):
+        task = signal_sender.taskbadger_task
+    else:
+        task = safe_get_task(task_id)
+
     if task is None:
         return

@@ -164,7 +217,9 @@ def _update_task(signal_sender, status, einfo=None):
     data = None
     if einfo:
         data = DefaultMergeStrategy().merge(task.data, {"exception": str(einfo)})
-    update_task_safe(task.id, status=status, data=data)
+    task = update_task_safe(task.id, status=status, data=data)
+    if task:
+        safe_get_task.cache.set((task_id,), task)


 def enter_session():
@@ -176,8 +231,25 @@ def enter_session():


 def exit_session(signal_sender):
-    if not hasattr(signal_sender, "taskbadger_task") or not Badger.is_configured():
+    headers = signal_sender.request.headers
+    if not headers:
         return
+
+    task_id = headers.get("taskbadger_task_id")
+    if not task_id or not Badger.is_configured():
+        return
+
+    safe_get_task.cache.unset((task_id,))
+    safe_get_task.cache.prune()
+
     session = Badger.current.session()
     if session.client:
         session.__exit__()
+
+
+@cached(cache_none=False)
+def safe_get_task(task_id: str):
+    try:
+        return get_task(task_id)
+    except Exception:
+        log.exception("Error fetching task '%s'", task_id)
diff --git a/taskbadger/mug.py b/taskbadger/mug.py
index 4ab22a2..d965016 100644
--- a/taskbadger/mug.py
+++ b/taskbadger/mug.py
@@ -1,9 +1,10 @@
 import dataclasses
 from contextlib import ContextDecorator
 from contextvars import ContextVar
-from typing import Union
+from typing import Dict, Union

 from taskbadger.internal import AuthenticatedClient
+from taskbadger.systems import System

 _local = ContextVar("taskbadger_client")

@@ -14,6 +15,7 @@ class Settings:
     token: str
     organization_slug: str
     project_slug: str
+    systems: Dict[str, System] = dataclasses.field(default_factory=dict)

     def get_client(self):
         return AuthenticatedClient(self.base_url, self.token)
@@ -24,6 +26,9 @@ def as_kwargs(self):
             "project_slug": self.project_slug,
         }

+    def get_system_by_id(self, identifier: str) -> System:
+        return self.systems.get(identifier)
+
     def __str__(self):
         return (
             f"Settings(base_url='{self.base_url}',"
diff --git a/taskbadger/sdk.py b/taskbadger/sdk.py
index 5cd5754..e264651 100644
--- a/taskbadger/sdk.py
+++ b/taskbadger/sdk.py
@@ -13,26 +13,36 @@
 )
 from taskbadger.internal.types import UNSET
 from taskbadger.mug import Badger, Session, Settings
+from taskbadger.systems import System


 _TB_HOST = "https://taskbadger.net"


-def init(organization_slug: str = None, project_slug: str = None, token: str = None):
+def init(organization_slug: str = None, project_slug: str = None, token: str = None, systems: List[System] = None):
     """Initialize Task Badger client

     Call this function once per thread
     """
-    _init(_TB_HOST, organization_slug, project_slug, token)
+    _init(_TB_HOST, organization_slug, project_slug, token, systems)


-def _init(host: str = None, organization_slug: str = None, project_slug: str = None, token: str = None):
+def _init(
+    host: str = None,
+    organization_slug: str = None,
+    project_slug: str = None,
+    token: str = None,
+    systems: List[System] = None,
+):
     host = host or os.environ.get("TASKBADGER_HOST", "https://taskbadger.net")
     organization_slug = organization_slug or os.environ.get("TASKBADGER_ORG")
     project_slug = project_slug or os.environ.get("TASKBADGER_PROJECT")
     token = token or os.environ.get("TASKBADGER_API_KEY")
     if host and organization_slug and project_slug and token:
-        settings = Settings(host, token, organization_slug, project_slug)
+        systems = systems or []
+        settings = Settings(
+            host, token, organization_slug, project_slug, systems={system.identifier: system for system in systems}
+        )
         Badger.current.bind(settings)
     else:
         raise ConfigurationError(
diff --git a/taskbadger/systems/__init__.py b/taskbadger/systems/__init__.py
new file mode 100644
index 0000000..3cb3b54
--- /dev/null
+++ b/taskbadger/systems/__init__.py
@@ -0,0 +1,6 @@
+class System(object):
+    """
+    Baseclass for all systems.
+    """
+
+    identifier: str = None
diff --git a/taskbadger/systems/celery.py b/taskbadger/systems/celery.py
new file mode 100644
index 0000000..1dd6807
--- /dev/null
+++ b/taskbadger/systems/celery.py
@@ -0,0 +1,13 @@
+from taskbadger.systems import System
+
+
+class CelerySystemIntegration(System):
+    identifier = "celery"
+
+    def __init__(self, auto_track_tasks=True):
+        """
+        Args:
+            auto_track_tasks: Automatically track all Celery tasks regardless of whether they are using the
+                `taskbadger.celery.Task` base class.
+ """ + self.auto_track_tasks = auto_track_tasks diff --git a/tests/test_celery_error.py b/tests/test_celery_error.py index ce72de7..906d136 100644 --- a/tests/test_celery_error.py +++ b/tests/test_celery_error.py @@ -20,7 +20,9 @@ def add_error(self, a, b): with mock.patch("taskbadger.celery.create_task_safe") as create, mock.patch( "taskbadger.celery.update_task_safe" ) as update, mock.patch("taskbadger.celery.get_task") as get_task: - get_task.return_value = task_for_test() + task = task_for_test() + get_task.return_value = task + update.return_value = task result = add_error.delay(2, 2) with pytest.raises(Exception): result.get(timeout=10, propagate=True) diff --git a/tests/test_celery_system_integration.py b/tests/test_celery_system_integration.py new file mode 100644 index 0000000..8e43df5 --- /dev/null +++ b/tests/test_celery_system_integration.py @@ -0,0 +1,62 @@ +""" +Note +==== + +As part of the Celery fixture setup a 'ping' task is run which executes +before the `bind_settings` fixture is executed. This means that if any code +calls `Badger.is_configured()` (or similar), the `_local` ContextVar in the +Celery runner thread will not have the configuration set. +""" +import logging +from unittest import mock + +import pytest + +from taskbadger.mug import Badger, Settings +from taskbadger.systems.celery import CelerySystemIntegration +from tests.utils import task_for_test + + +@pytest.fixture +def bind_settings_with_system(): + systems = [CelerySystemIntegration()] + Badger.current.bind( + Settings( + "https://taskbadger.net", "token", "org", "proj", systems={system.identifier: system for system in systems} + ) + ) + yield + Badger.current.bind(None) + + +@pytest.fixture(autouse=True) +def check_log_errors(caplog): + yield + errors = [r.getMessage() for r in caplog.get_records("call") if r.levelno == logging.ERROR] + if errors: + pytest.fail(f"log errors during tests: {errors}") + + +def test_celery_auto_track_task(celery_session_app, celery_session_worker, bind_settings_with_system): + @celery_session_app.task(bind=True) + def add_normal(self, a, b): + assert self.request.get("taskbadger_task_id") is not None, "missing task in request" + assert not hasattr(self, "taskbadger_task") + assert Badger.current.session().client is not None, "missing client" + return a + b + + celery_session_worker.reload() + + with mock.patch("taskbadger.celery.create_task_safe") as create, mock.patch( + "taskbadger.celery.update_task_safe" + ) as update, mock.patch("taskbadger.celery.get_task") as get_task: + tb_task = task_for_test() + create.return_value = tb_task + result = add_normal.delay(2, 2) + assert result.info.get("taskbadger_task_id") == tb_task.id + assert result.get(timeout=10, propagate=True) == 4 + + create.assert_called_once() + assert get_task.call_count == 1 + assert update.call_count == 2 + assert Badger.current.session().client is None From 8303ca87de78de8779e7ea45e28bfd2502268cc1 Mon Sep 17 00:00:00 2001 From: Simon Kelly Date: Wed, 1 Nov 2023 14:42:40 +0200 Subject: [PATCH 2/2] check headers --- taskbadger/celery.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taskbadger/celery.py b/taskbadger/celery.py index d4dbff7..c3b4fda 100644 --- a/taskbadger/celery.py +++ b/taskbadger/celery.py @@ -138,7 +138,7 @@ def taskbadger_task(self): @before_task_publish.connect def task_publish_handler(sender=None, headers=None, **kwargs): - if sender.startswith("celery.") or not Badger.is_configured(): + if sender.startswith("celery.") or not headers or not 
Badger.is_configured(): return celery_system = Badger.current.settings.get_system_by_id("celery")
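
Usage note (not part of the patches above): a minimal sketch of how a project might enable the new auto-tracking once these patches are applied, mirroring integration_tests/__init__.py and integration_tests/tasks.py. The organization slug, project slug, and API key below are placeholders, and the top-level badger.init call is assumed to be re-exported by the taskbadger package as in the integration tests.

import taskbadger as badger
from celery import shared_task
from taskbadger.systems.celery import CelerySystemIntegration

# Register the Celery system integration; auto_track_tasks defaults to True,
# so plain Celery tasks are tracked without subclassing taskbadger.celery.Task.
badger.init(
    organization_slug="my-org",    # placeholder
    project_slug="my-project",     # placeholder
    token="my-api-key",            # placeholder
    systems=[CelerySystemIntegration()],
)


@shared_task(bind=True)
def add(self, x, y):
    # The before_task_publish handler stores the Task Badger task id in the
    # message headers, so a plain task can read it from its request.
    print(self.request.taskbadger_task_id)
    return x + y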
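
A second sketch (again not part of the patches, assuming PATCH 1/2 is applied so that `cached` is importable from taskbadger.celery) illustrating the memoisation semantics behind safe_get_task: results are keyed by the argument tuple, failures are not cached when cache_none=False, and entries can be unset and pruned, which is what exit_session() does per task id. The fetch function and the calls counter are hypothetical stand-ins.

from taskbadger.celery import cached

calls = {"n": 0}


@cached(cache_none=False, maxsize=2)
def fetch(task_id):
    calls["n"] += 1
    return {"id": task_id}


fetch("a")
fetch("a")                  # second call is served from the cache
assert calls["n"] == 1

fetch.cache.unset(("a",))   # forget a single entry, as exit_session() does
fetch("a")
assert calls["n"] == 2

fetch("b")
fetch("c")
fetch.cache.prune()         # over maxsize, so the oldest entry is dropped
assert len(fetch.cache.cache) == 2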