Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions integration_tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

import taskbadger as badger
from taskbadger.systems.celery import CelerySystemIntegration


def _load_config():
Expand All @@ -30,5 +31,6 @@ def _load_config():
os.environ.get("TASKBADGER_ORG", ""),
os.environ.get("TASKBADGER_PROJECT", ""),
os.environ.get("TASKBADGER_API_KEY", ""),
systems=[CelerySystemIntegration()],
)
print(f"\nIntegration tests configuration:\n {badger.mug.Badger.current.settings}\n")
6 changes: 6 additions & 0 deletions integration_tests/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,9 @@ def add(self, x, y):
assert self.taskbadger_task is not None, "missing task on self"
self.taskbadger_task.update(value=100, data={"result": x + y})
return x + y


@shared_task(bind=True)
def add_auto_track(self, x, y):
    # Plain Celery task (not using the taskbadger.celery.Task base class) used
    # to verify that auto-tracking still stamps a Task Badger task ID onto the
    # request when the CelerySystemIntegration is configured.
    assert self.request.taskbadger_task_id is not None, "missing task ID on self.request"
    return x + y
18 changes: 17 additions & 1 deletion integration_tests/test_celery.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
import logging
import random

import pytest

from taskbadger import StatusEnum

from .tasks import add
from .tasks import add, add_auto_track


@pytest.fixture(autouse=True)
def check_log_errors(caplog):
    """Fail the test if any phase (call/setup/teardown) emitted ERROR-level logs."""
    yield
    for when in ("call", "setup", "teardown"):
        errors = [
            record.getMessage()
            for record in caplog.get_records(when)
            if record.levelno == logging.ERROR
        ]
        if errors:
            pytest.fail(f"log errors during '{when}': {errors}")


@pytest.fixture(scope="session", autouse=True)
Expand All @@ -24,3 +34,9 @@ def test_celery(celery_worker):
assert tb_task.status == StatusEnum.SUCCESS
assert tb_task.value == 100
assert tb_task.data == {"result": a + b}


def test_celery_auto_track(celery_worker):
    """Integration test: a plain Celery task still runs correctly when
    auto-tracking is enabled (see add_auto_track's own assertion)."""
    a, b = random.randint(1, 1000), random.randint(1, 1000)
    result = add_auto_track.delay(a, b)
    assert result.get(timeout=10, propagate=True) == a + b
96 changes: 84 additions & 12 deletions taskbadger/celery.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections
import functools
import logging

Expand All @@ -18,6 +19,46 @@
log = logging.getLogger("taskbadger")


class Cache:
    """A tiny FIFO key/value store with explicit, caller-driven eviction.

    Backed by an ``OrderedDict`` so insertion order is preserved. ``set``
    never evicts on its own; callers are expected to invoke ``prune``
    periodically, which discards at most one (the oldest) entry when the
    store has grown past ``maxsize``.
    """

    def __init__(self, maxsize=128):
        self.cache = collections.OrderedDict()
        self.maxsize = maxsize

    def set(self, key, value):
        self.cache[key] = value

    def unset(self, key):
        # Tolerates missing keys.
        if key in self.cache:
            del self.cache[key]

    def get(self, key):
        # Returns None when the key is absent.
        try:
            return self.cache[key]
        except KeyError:
            return None

    def prune(self):
        # Evict a single entry, oldest first, only once over the size limit.
        if len(self.cache) > self.maxsize:
            self.cache.popitem(last=False)


def cached(cache_none=True, maxsize=128):
    """Memoizing decorator backed by a ``Cache`` exposed as ``func.cache``.

    Args:
        cache_none: When False, a ``None`` result is not stored, so the
            wrapped function is retried on the next call with the same args.
        maxsize: Threshold used by ``Cache.prune``. Pruning is triggered
            externally via ``func.cache.prune()``, not by this decorator.
    """
    cache = Cache(maxsize=maxsize)
    _missing = object()  # sentinel so a cached None still counts as a hit

    def _wrapper(func):
        @functools.wraps(func)
        def _inner(*args, **kwargs):
            # Positional args plus sorted kwarg pairs uniquely identify the call.
            key = args + tuple(sorted(kwargs.items()))
            hit = cache.cache.get(key, _missing)
            if hit is not _missing:
                return hit

            result = func(*args, **kwargs)
            if cache_none or result is not None:
                cache.set(key, result)
            return result

        _inner.cache = cache
        return _inner

    return _wrapper


class Task(celery.Task):
"""A Celery Task that tracks itself with TaskBadger.

Expand Down Expand Up @@ -89,18 +130,21 @@ def taskbadger_task(self):
task = self.request.get("taskbadger_task")
if not task:
log.debug("Fetching task '%s'", self.taskbadger_task_id)
try:
task = get_task(self.taskbadger_task_id)
task = safe_get_task(self.taskbadger_task_id)
if task:
self.request.update({"taskbadger_task": task})
except Exception:
log.exception("Error fetching task '%s'", self.taskbadger_task_id)
task = None
return task


@before_task_publish.connect
def task_publish_handler(sender=None, headers=None, **kwargs):
if not headers.get("taskbadger_track") or not Badger.is_configured():
if sender.startswith("celery.") or not headers or not Badger.is_configured():
return

celery_system = Badger.current.settings.get_system_by_id("celery")
auto_track = celery_system and celery_system.auto_track_tasks
manual_track = headers.get("taskbadger_track")
if not manual_track and not auto_track:
return

ctask = celery.current_app.tasks.get(sender)
Expand All @@ -112,7 +156,7 @@ def task_publish_handler(sender=None, headers=None, **kwargs):
kwargs[attr.removeprefix(KWARG_PREFIX)] = getattr(ctask, attr)

# get kwargs from the task headers (set via apply_async)
kwargs.update(headers[TB_KWARGS_ARG])
kwargs.update(headers.get(TB_KWARGS_ARG, {}))
kwargs["status"] = StatusEnum.PENDING
name = kwargs.pop("name", headers["task"])

Expand Down Expand Up @@ -147,11 +191,20 @@ def task_retry_handler(sender=None, einfo=None, **kwargs):


def _update_task(signal_sender, status, einfo=None):
log.debug("celery_task_update %s %s", signal_sender, status)
if not hasattr(signal_sender, "taskbadger_task"):
headers = signal_sender.request.headers
if not headers:
return

task_id = headers.get("taskbadger_task_id")
if not task_id:
return

task = signal_sender.taskbadger_task
log.debug("celery_task_update %s %s", signal_sender, status)
if hasattr(signal_sender, "taskbadger_task"):
task = signal_sender.taskbadger_task
else:
task = safe_get_task(task_id)

if task is None:
return

Expand All @@ -164,7 +217,9 @@ def _update_task(signal_sender, status, einfo=None):
data = None
if einfo:
data = DefaultMergeStrategy().merge(task.data, {"exception": str(einfo)})
update_task_safe(task.id, status=status, data=data)
task = update_task_safe(task.id, status=status, data=data)
if task:
safe_get_task.cache.set((task_id,), task)


def enter_session():
Expand All @@ -176,8 +231,25 @@ def enter_session():


def exit_session(signal_sender):
    """Close the Badger session opened for a tracked Celery task.

    Drops the finished task's entry from the ``safe_get_task`` cache, prunes
    the cache, and exits the client session if one was started.

    Args:
        signal_sender: the Celery task instance that fired the signal.
    """
    # Fix: the dump retained the stale pre-diff guard
    # (`if not hasattr(signal_sender, "taskbadger_task") ...`) on top of the
    # new header-based guards; only the header-based version belongs here.
    headers = signal_sender.request.headers
    if not headers:
        return

    task_id = headers.get("taskbadger_task_id")
    if not task_id or not Badger.is_configured():
        return

    # The task is finished; its cached API object is no longer needed.
    safe_get_task.cache.unset((task_id,))
    safe_get_task.cache.prune()

    session = Badger.current.session()
    if session.client:
        session.__exit__()


@cached(cache_none=False)
def safe_get_task(task_id: str):
    """Fetch a Task Badger task by ID, returning None instead of raising.

    Results are memoized via ``cached``; because ``cache_none=False`` a
    failed lookup (None) is not cached and is retried on the next call.
    """
    try:
        return get_task(task_id)
    except Exception:
        # Best-effort: tracking must never break the Celery task itself.
        log.exception("Error fetching task '%s'", task_id)
7 changes: 6 additions & 1 deletion taskbadger/mug.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import dataclasses
from contextlib import ContextDecorator
from contextvars import ContextVar
from typing import Union
from typing import Dict, Union

from taskbadger.internal import AuthenticatedClient
from taskbadger.systems import System

_local = ContextVar("taskbadger_client")

Expand All @@ -14,6 +15,7 @@ class Settings:
token: str
organization_slug: str
project_slug: str
systems: Dict[str, System] = dataclasses.field(default_factory=dict)

def get_client(self):
return AuthenticatedClient(self.base_url, self.token)
Expand All @@ -24,6 +26,9 @@ def as_kwargs(self):
"project_slug": self.project_slug,
}

def get_system_by_id(self, identifier: str) -> System:
return self.systems.get(identifier)

def __str__(self):
return (
f"Settings(base_url='{self.base_url}',"
Expand Down
18 changes: 14 additions & 4 deletions taskbadger/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,36 @@
)
from taskbadger.internal.types import UNSET
from taskbadger.mug import Badger, Session, Settings
from taskbadger.systems import System

_TB_HOST = "https://taskbadger.net"


def init(organization_slug: str = None, project_slug: str = None, token: str = None, systems: List[System] = None):
    """Initialize Task Badger client.

    Call this function once per thread.

    Args:
        organization_slug: falls back to the TASKBADGER_ORG env var (in _init).
        project_slug: falls back to the TASKBADGER_PROJECT env var (in _init).
        token: falls back to the TASKBADGER_API_KEY env var (in _init).
        systems: optional system integrations (e.g. Celery) to register.
    """
    # Fix: the dump retained both the pre-diff signature/call (without
    # `systems`) and the post-diff ones; only the post-diff version belongs.
    _init(_TB_HOST, organization_slug, project_slug, token, systems)


def _init(host: str = None, organization_slug: str = None, project_slug: str = None, token: str = None):
def _init(
host: str = None,
organization_slug: str = None,
project_slug: str = None,
token: str = None,
systems: List[System] = None,
):
host = host or os.environ.get("TASKBADGER_HOST", "https://taskbadger.net")
organization_slug = organization_slug or os.environ.get("TASKBADGER_ORG")
project_slug = project_slug or os.environ.get("TASKBADGER_PROJECT")
token = token or os.environ.get("TASKBADGER_API_KEY")

if host and organization_slug and project_slug and token:
settings = Settings(host, token, organization_slug, project_slug)
systems = systems or []
settings = Settings(
host, token, organization_slug, project_slug, systems={system.identifier: system for system in systems}
)
Badger.current.bind(settings)
else:
raise ConfigurationError(
Expand Down
6 changes: 6 additions & 0 deletions taskbadger/systems/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class System:
    """
    Baseclass for all systems.
    """

    # Unique lookup key for the integration; subclasses must override this.
    identifier: str = None
13 changes: 13 additions & 0 deletions taskbadger/systems/celery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from taskbadger.systems import System


class CelerySystemIntegration(System):
    """Celery system integration.

    Args:
        auto_track_tasks: Automatically track all Celery tasks regardless of whether they are using the
            `taskbadger.celery.Task` base class.
    """

    identifier = "celery"

    def __init__(self, auto_track_tasks=True):
        self.auto_track_tasks = auto_track_tasks
4 changes: 3 additions & 1 deletion tests/test_celery_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ def add_error(self, a, b):
with mock.patch("taskbadger.celery.create_task_safe") as create, mock.patch(
"taskbadger.celery.update_task_safe"
) as update, mock.patch("taskbadger.celery.get_task") as get_task:
get_task.return_value = task_for_test()
task = task_for_test()
get_task.return_value = task
update.return_value = task
result = add_error.delay(2, 2)
with pytest.raises(Exception):
result.get(timeout=10, propagate=True)
Expand Down
62 changes: 62 additions & 0 deletions tests/test_celery_system_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
Note
====

As part of the Celery fixture setup a 'ping' task is run which executes
before the `bind_settings` fixture is executed. This means that if any code
calls `Badger.is_configured()` (or similar), the `_local` ContextVar in the
Celery runner thread will not have the configuration set.
"""
import logging
from unittest import mock

import pytest

from taskbadger.mug import Badger, Settings
from taskbadger.systems.celery import CelerySystemIntegration
from tests.utils import task_for_test


@pytest.fixture
def bind_settings_with_system():
    """Bind global Badger settings that include the Celery system integration.

    Tears down by re-binding None so subsequent tests start unconfigured.
    """
    systems = [CelerySystemIntegration()]
    Badger.current.bind(
        Settings(
            "https://taskbadger.net", "token", "org", "proj", systems={system.identifier: system for system in systems}
        )
    )
    yield
    Badger.current.bind(None)


@pytest.fixture(autouse=True)
def check_log_errors(caplog):
    """Fail the test when ERROR-level records were logged during the call phase."""
    yield
    errors = [
        record.getMessage()
        for record in caplog.get_records("call")
        if record.levelno == logging.ERROR
    ]
    if errors:
        pytest.fail(f"log errors during tests: {errors}")


def test_celery_auto_track_task(celery_session_app, celery_session_worker, bind_settings_with_system):
    """A plain Celery task (no taskbadger Task base class) gets auto-tracked
    when CelerySystemIntegration is bound via settings."""

    @celery_session_app.task(bind=True)
    def add_normal(self, a, b):
        # The publish handler should have stamped a task ID into the request.
        assert self.request.get("taskbadger_task_id") is not None, "missing task in request"
        # No Task base class, so the convenience attribute must be absent.
        assert not hasattr(self, "taskbadger_task")
        assert Badger.current.session().client is not None, "missing client"
        return a + b

    celery_session_worker.reload()

    with mock.patch("taskbadger.celery.create_task_safe") as create, mock.patch(
        "taskbadger.celery.update_task_safe"
    ) as update, mock.patch("taskbadger.celery.get_task") as get_task:
        tb_task = task_for_test()
        create.return_value = tb_task
        result = add_normal.delay(2, 2)
        assert result.info.get("taskbadger_task_id") == tb_task.id
        assert result.get(timeout=10, propagate=True) == 4

        create.assert_called_once()
        assert get_task.call_count == 1
        # NOTE(review): presumably one update at task start and one at
        # completion — confirm against taskbadger.celery's signal handlers.
        assert update.call_count == 2
        # Session client must be cleared once the task has finished.
        assert Badger.current.session().client is None