diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a8403d4..f2fcdb9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -48,16 +48,16 @@ jobs: id: poetry-cache with: path: ~/.local - key: key-0 + key: ${{ matrix.python-version }}-1 - uses: snok/install-poetry@v1 with: virtualenvs-create: false - version: 1.1.11 + version: 1.1.12 - uses: actions/cache@v2 id: cache-venv with: path: .venv - key: ${{ hashFiles('**/poetry.lock') }}-${{ matrix.python-version }}-0 + key: ${{ hashFiles('**/poetry.lock') }}-${{ matrix.python-version }}-1 - run: | python -m venv .venv source .venv/bin/activate diff --git a/asgi_correlation_id/log_filters.py b/asgi_correlation_id/log_filters.py index 717d105..c46e036 100644 --- a/asgi_correlation_id/log_filters.py +++ b/asgi_correlation_id/log_filters.py @@ -1,12 +1,12 @@ from logging import Filter, LogRecord -from typing import Type +from typing import Optional, Type from asgi_correlation_id.context import celery_current_id, celery_parent_id, correlation_id # Middleware -def correlation_id_filter(uuid_length: int = 32) -> Type[Filter]: +def correlation_id_filter(uuid_length: Optional[int] = None) -> Type[Filter]: class CorrelationId(Filter): def filter(self, record: LogRecord) -> bool: """ @@ -18,7 +18,10 @@ def filter(self, record: LogRecord) -> bool: metadata. """ cid = correlation_id.get() - record.correlation_id = cid[:uuid_length] if cid else cid # type: ignore[attr-defined] + if uuid_length is not None and cid: + record.correlation_id = cid[:uuid_length] # type: ignore[attr-defined] + else: + record.correlation_id = cid # type: ignore[attr-defined] return True return CorrelationId diff --git a/tests/test_log_filter.py b/tests/test_log_filter.py new file mode 100644 index 0000000..381a549 --- /dev/null +++ b/tests/test_log_filter.py @@ -0,0 +1,21 @@ +from unittest.mock import Mock +from uuid import uuid4 + +from asgi_correlation_id import correlation_id_filter +from asgi_correlation_id.context import correlation_id + + +def test_correlation_id_filter(): + mock_record = Mock() + + cid = uuid4().hex + correlation_id.set(cid) + + # Call with no uuid length + correlation_id_filter(None)().filter(mock_record) + assert mock_record.correlation_id == cid + + # Call with uuid length + for length in [0, 14, 30, 100]: + correlation_id_filter(uuid_length=length)().filter(mock_record) + assert mock_record.correlation_id == cid[:length] diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 7a97be2..0136a77 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -82,7 +82,7 @@ async def test_websocket_request(caplog): @app.get('/access-control-expose-headers') -async def access_control_view() -> dict: +async def access_control_view() -> Response: return Response(status_code=204, headers={'Access-Control-Expose-Headers': 'test1, test2'})