Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ID generator for celery extension #51

Merged
merged 9 commits into from Sep 29, 2022
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Expand Up @@ -68,4 +68,4 @@ jobs:
with:
file: ./coverage.xml
fail_ci_if_error: true
if: matrix.python-version == '3.10'
if: matrix.python-version == '3.10.6'
9 changes: 9 additions & 0 deletions README.md
Expand Up @@ -484,6 +484,15 @@ load_correlation_ids()
+ load_celery_current_and_parent_ids()
```

If you wish to correlate Celery task IDs with the IDs found in your broker (i.e., the Celery `task_id`), pass `use_internal_celery_task_id=True` to `load_celery_current_and_parent_ids`:
```diff
from asgi_correlation_id.extensions.celery import load_correlation_ids, load_celery_current_and_parent_ids

load_correlation_ids()
+ load_celery_current_and_parent_ids(use_internal_celery_task_id=True)
```
Note: `load_celery_current_and_parent_ids` ignores the `generator` argument when `use_internal_celery_task_id` is set to `True`.

To set up the additional log filters, update your log config like this:

```diff
Expand Down
20 changes: 13 additions & 7 deletions asgi_correlation_id/extensions/celery.py
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Callable, Dict
from uuid import uuid4

from celery.signals import before_task_publish, task_postrun, task_prerun
Expand All @@ -8,8 +8,10 @@
if TYPE_CHECKING:
from celery import Task

def uuid_hex_generator() -> str:
    """Return a freshly generated random UUID4 as a 32-character hex string."""
    return uuid4().hex

def load_correlation_ids() -> None:

def load_correlation_ids(header_key: str = 'CORRELATION_ID', generator: Callable[[], str] = uuid_hex_generator) -> None:
"""
Transfer correlation IDs from a HTTP request to a Celery worker,
when spawned from a request.
Expand All @@ -18,7 +20,6 @@ def load_correlation_ids() -> None:
"""
from asgi_correlation_id.context import correlation_id

header_key = 'CORRELATION_ID'
sentry_extension = get_sentry_extension()

@before_task_publish.connect(weak=False)
Expand Down Expand Up @@ -46,7 +47,7 @@ def load_correlation_id(task: 'Task', **kwargs: Any) -> None:
correlation_id.set(id_value)
sentry_extension(id_value)
else:
generated_correlation_id = uuid4().hex
generated_correlation_id = generator()
correlation_id.set(generated_correlation_id)
sentry_extension(generated_correlation_id)

Expand All @@ -61,7 +62,11 @@ def cleanup(**kwargs: Any) -> None:
correlation_id.set(None)


def load_celery_current_and_parent_ids(header_key: str = 'CELERY_PARENT_ID') -> None:
def load_celery_current_and_parent_ids(
header_key: str = 'CELERY_PARENT_ID',
generator: Callable[[], str] = uuid_hex_generator,
use_internal_celery_task_id: bool = False,
) -> None:
"""
Configure Celery event hooks for generating tracing IDs with depth.

Expand All @@ -83,15 +88,16 @@ def publish_task_from_worker_or_request(headers: Dict[str, str], **kwargs: Any)
headers[header_key] = current

@task_prerun.connect(weak=False)
def worker_prerun(task: 'Task', **kwargs: Any) -> None:
def worker_prerun(task_id: str, task: 'Task', **kwargs: Any) -> None:
"""
Set current ID, and parent ID if it exists.
"""
parent_id = task.request.get(header_key)
if parent_id:
celery_parent_id.set(parent_id)

celery_current_id.set(uuid4().hex)
celery_id = task_id if use_internal_celery_task_id else generator()
celery_current_id.set(celery_id)

@task_postrun.connect(weak=False)
def clean_up(**kwargs: Any) -> None:
Expand Down
15 changes: 8 additions & 7 deletions asgi_correlation_id/log_filters.py
Expand Up @@ -7,6 +7,10 @@
from logging import LogRecord


def _trim_string(string: Optional[str], string_length: Optional[int]) -> Optional[str]:
return string[:string_length] if string_length is not None and string else string


# Middleware


Expand All @@ -27,18 +31,15 @@ def filter(self, record: 'LogRecord') -> bool:
metadata.
"""
cid = correlation_id.get()
if self.uuid_length is not None and cid:
record.correlation_id = cid[: self.uuid_length]
else:
record.correlation_id = cid
record.correlation_id = _trim_string(cid, self.uuid_length)
return True


# Celery extension


class CeleryTracingIdsFilter(Filter):
def __init__(self, name: str = '', uuid_length: int = 32):
def __init__(self, name: str = '', uuid_length: Optional[int] = None):
super().__init__(name=name)
self.uuid_length = uuid_length

Expand All @@ -52,7 +53,7 @@ def filter(self, record: 'LogRecord') -> bool:
or from an endpoint, the parent ID will be None.
"""
pid = celery_parent_id.get()
record.celery_parent_id = pid[: self.uuid_length] if pid else pid
record.celery_parent_id = _trim_string(pid, self.uuid_length)
cid = celery_current_id.get()
record.celery_current_id = cid[: self.uuid_length] if cid else cid
record.celery_current_id = _trim_string(cid, self.uuid_length)
return True
61 changes: 55 additions & 6 deletions tests/test_log_filter.py
Expand Up @@ -18,24 +18,23 @@ def cid():
@pytest.fixture()
def log_record():
"""Create and return an INFO-level log record"""
record = LogRecord(name='', level=INFO, pathname='', lineno=0, msg='Hello, world!', args=(), exc_info=None)
return record
return LogRecord(name='', level=INFO, pathname='', lineno=0, msg='Hello, world!', args=(), exc_info=None)


def test_filter_has_uuid_length_attributes():
    """The filter should expose the ``uuid_length`` it was constructed with."""
    expected_length = 8
    filter_ = CorrelationIdFilter(uuid_length=expected_length)
    assert filter_.uuid_length == expected_length


def test_filter_adds_correlation_id(cid, log_record):
def test_filter_adds_correlation_id(cid: str, log_record: LogRecord):
filter_ = CorrelationIdFilter()

assert not hasattr(log_record, 'correlation_id')
filter_.filter(log_record)
assert log_record.correlation_id == cid


def test_filter_truncates_correlation_id(cid, log_record):
def test_filter_truncates_correlation_id(cid: str, log_record: LogRecord):
filter_ = CorrelationIdFilter(uuid_length=8)

assert not hasattr(log_record, 'correlation_id')
Expand All @@ -49,7 +48,7 @@ def test_celery_filter_has_uuid_length_attributes():
assert filter_.uuid_length == 8


def test_celery_filter_adds_parent_id(cid, log_record):
def test_celery_filter_adds_parent_id(cid: str, log_record: LogRecord):
filter_ = CeleryTracingIdsFilter()
celery_parent_id.set('a')

Expand All @@ -58,10 +57,60 @@ def test_celery_filter_adds_parent_id(cid, log_record):
assert log_record.celery_parent_id == 'a'


def test_celery_filter_adds_current_id(cid, log_record):
def test_celery_filter_adds_current_id(cid: str, log_record: LogRecord):
filter_ = CeleryTracingIdsFilter()
celery_current_id.set('b')

assert not hasattr(log_record, 'celery_current_id')
filter_.filter(log_record)
assert log_record.celery_current_id == 'b'


@pytest.mark.parametrize(
    ('uuid_length', 'expected'),
    [
        (6, 6),
        (16, 16),
        (None, 36),
        (38, 36),
    ],
)
def test_celery_filter_truncates_current_id_correctly(cid: str, log_record: LogRecord, uuid_length, expected):
    """
    The current ID should be truncated to ``uuid_length`` characters.

    When ``uuid_length`` is ``None`` or exceeds the length of the ID, the
    full 36-character ``str(uuid4())`` value should be left untouched.
    """
    filter_ = CeleryTracingIdsFilter(uuid_length=uuid_length)
    celery_id = str(uuid4())
    celery_current_id.set(celery_id)

    assert not hasattr(log_record, 'celery_current_id')
    filter_.filter(log_record)
    assert log_record.celery_current_id == celery_id[:expected]


def test_celery_filter_maintains_current_behavior(cid: str, log_record: LogRecord):
    """
    Check that the changed ``uuid_length`` default does not trim a hex UUID.

    A filter built with default arguments and a filter built with an explicit
    ``uuid_length=32`` must both record the complete hex ID.
    """
    hex_id = uuid4().hex
    celery_current_id.set(hex_id)

    # Filter constructed with the new default arguments
    default_filter = CeleryTracingIdsFilter()
    assert not hasattr(log_record, 'celery_current_id')
    default_filter.filter(log_record)
    id_from_default = log_record.celery_current_id
    assert id_from_default == hex_id

    # Reset the record so the second filter starts from a clean slate
    del log_record.celery_current_id
    assert not hasattr(log_record, 'celery_current_id')

    # Filter constructed with an explicit length of 32
    explicit_filter = CeleryTracingIdsFilter(uuid_length=32)
    explicit_filter.filter(log_record)
    id_from_explicit = log_record.celery_current_id
    assert id_from_explicit == hex_id

    assert id_from_explicit == id_from_default