35 changes: 33 additions & 2 deletions tests/unit/test_logging.py
@@ -119,8 +119,9 @@ def test_includeme(monkeypatch, settings, expected_level):
structlog.stdlib.filter_by_level,
structlog.stdlib.add_logger_name,
structlog.stdlib.add_log_level,
mock.ANY,
mock.ANY,
mock.ANY, # PositionalArgumentsFormatter
mock.ANY, # TimeStamper
mock.ANY, # StackInfoRenderer
structlog.processors.format_exc_info,
wlogging.RENDERER,
],
@@ -135,6 +136,10 @@ def test_includeme(monkeypatch, settings, expected_level):
)
assert isinstance(
configure.calls[0].kwargs["processors"][4],
structlog.processors.TimeStamper,
)
assert isinstance(
configure.calls[0].kwargs["processors"][5],
structlog.processors.StackInfoRenderer,
)
assert isinstance(
@@ -144,3 +149,29 @@ def test_includeme(monkeypatch, settings, expected_level):
pretend.call(wlogging._create_id, name="id", reify=True),
pretend.call(wlogging._create_logger, name="log", reify=True),
]


def test_configure_celery_logging(monkeypatch):
configure = pretend.call_recorder(lambda **kw: None)
monkeypatch.setattr(structlog, "configure", configure)

mock_handler = pretend.stub(setFormatter=pretend.call_recorder(lambda f: None))
mock_logger = pretend.stub(
handlers=pretend.stub(clear=pretend.call_recorder(lambda: None)),
setLevel=pretend.call_recorder(lambda level: None),
addHandler=pretend.call_recorder(lambda add_handler: None),
removeHandler=pretend.call_recorder(lambda remove_handler: None),
)
monkeypatch.setattr(logging, "getLogger", lambda: mock_logger)
monkeypatch.setattr(logging, "StreamHandler", lambda: mock_handler)

wlogging.configure_celery_logging()

# Verify handlers cleared and new one added
assert mock_logger.handlers.clear.calls == [pretend.call()]
assert len(mock_logger.addHandler.calls) == 1
assert mock_logger.setLevel.calls == [pretend.call(logging.INFO)]

# Verify processors
processors = configure.calls[0].kwargs["processors"]
assert structlog.contextvars.merge_contextvars in processors
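
The new test exercises configure_celery_logging entirely through pretend doubles: call_recorder captures invocations, and stub fabricates just the logger/handler surface the function touches. A minimal sketch of those two pretend primitives, for readers unfamiliar with the library (the names below are illustrative, not part of the diff):

import pretend

# call_recorder wraps a callable and records every invocation in .calls
recorder = pretend.call_recorder(lambda **kw: None)
recorder(level="INFO")
assert recorder.calls == [pretend.call(level="INFO")]

# stub builds an ad-hoc object whose only attributes are the kwargs given
fake_logger = pretend.stub(setLevel=pretend.call_recorder(lambda level: None))
fake_logger.setLevel(20)
assert fake_logger.setLevel.calls == [pretend.call(20)]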
23 changes: 23 additions & 0 deletions tests/unit/test_tasks.py
@@ -524,3 +524,26 @@ def test_includeme(env, ssl, broker_redis_url, expected_url, transport_options):
assert config.add_request_method.calls == [
pretend.call(tasks._get_task_from_request, name="task", reify=True)
]


def test_on_after_setup_logger(monkeypatch):
configure_celery_logging = pretend.call_recorder(lambda logfile, loglevel: None)
monkeypatch.setattr(
"warehouse.logging.configure_celery_logging", configure_celery_logging
)

tasks.on_after_setup_logger("logger", "loglevel", "logfile")

assert configure_celery_logging.calls == [pretend.call("logfile", "loglevel")]


def test_on_task_prerun(monkeypatch):
bind_contextvars = pretend.call_recorder(lambda **kw: None)
monkeypatch.setattr("structlog.contextvars.bind_contextvars", bind_contextvars)

task = pretend.stub(name="test.task")
tasks.on_task_prerun(None, "task-123", task)

assert bind_contextvars.calls == [
pretend.call(task_id="task-123", task_name="test.task")
]
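
Both handlers can be driven positionally here because Celery dispatches its signals with keyword arguments whose names match the handler parameters. A hedged sketch of the same wiring against a bare Celery signal (the handler name is illustrative):

from celery import signals

@signals.task_prerun.connect
def announce(sender=None, task_id=None, task=None, **_):
    # Celery passes sender, task_id, task, args, kwargs by keyword;
    # **_ absorbs the rest, mirroring on_task_prerun in warehouse/tasks.py.
    print(f"starting {task.name} ({task_id})")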
34 changes: 34 additions & 0 deletions warehouse/logging.py
@@ -36,6 +36,39 @@ def _create_id(request):
return str(uuid.uuid4())


def configure_celery_logging(logfile: str | None = None, loglevel: int = logging.INFO):
"""Configure unified structlog logging for Celery that handles all log types."""
processors = [
structlog.contextvars.merge_contextvars,
structlog.processors.TimeStamper(fmt="iso"),
structlog.stdlib.add_log_level,
structlog.processors.StackInfoRenderer(),
structlog.processors.format_exc_info,
]
formatter = structlog.stdlib.ProcessorFormatter(
processor=RENDERER,
foreign_pre_chain=processors, # type: ignore[arg-type]
)

handler = logging.FileHandler(logfile) if logfile else logging.StreamHandler()
handler.setFormatter(formatter)

root = logging.getLogger()
root.handlers.clear()
root.addHandler(handler)
root.setLevel(loglevel)

structlog.configure(
processors=processors # type: ignore[arg-type]
+ [
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
],
logger_factory=structlog.stdlib.LoggerFactory(),
wrapper_class=structlog.make_filtering_bound_logger(logging.INFO),
cache_logger_on_first_use=True,
)


def _create_logger(request):
# This has to use **{} instead of just a kwarg because request.id is not
# an allowed kwarg name.
@@ -88,6 +121,7 @@ def includeme(config):
structlog.stdlib.add_logger_name,
structlog.stdlib.add_log_level,
structlog.stdlib.PositionalArgumentsFormatter(),
structlog.processors.TimeStamper(fmt="iso"),
structlog.processors.StackInfoRenderer(),
structlog.processors.format_exc_info,
RENDERER,
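
With this hunk applied, structlog-native events and foreign stdlib records (e.g. from Celery's own loggers) share one root handler and one renderer: ProcessorFormatter runs foreign records through foreign_pre_chain, so both kinds reach RENDERER with the same timestamp, level, and exception fields. A minimal usage sketch, assuming warehouse is importable (the logger names are illustrative):

import logging

import structlog

from warehouse.logging import configure_celery_logging

configure_celery_logging()  # no logfile given, so a StreamHandler at INFO

# structlog event: runs through the configured processor chain
structlog.get_logger("warehouse.example").info("task finished", task="demo")

# foreign stdlib record: routed through foreign_pre_chain by ProcessorFormatter
logging.getLogger("celery.worker").info("worker ready")

Note that the root handler level follows the loglevel argument, while the wrapper_class passed to structlog.configure is pinned to filter at INFO regardless.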
26 changes: 26 additions & 0 deletions warehouse/tasks.py
@@ -8,19 +8,23 @@
import time
import typing
import urllib.parse
import uuid

import celery
import celery.app.backends
import celery.backends.redis
import pyramid.scripting
import pyramid_retry
import structlog
import transaction
import venusian

from celery import signals
from kombu import Queue
from pyramid.threadlocal import get_current_request

from warehouse.config import Environment
from warehouse.logging import configure_celery_logging
from warehouse.metrics import IMetricsService

if typing.TYPE_CHECKING:
@@
logger = logging.getLogger(__name__)


# Celery signal handlers for unified structlog configuration
@signals.after_setup_logger.connect
def on_after_setup_logger(logger, loglevel, logfile, *args, **kwargs):
"""Override Celery's default logging behavior
with unified structlog configuration."""
configure_celery_logging(logfile, loglevel)


@signals.task_prerun.connect
def on_task_prerun(sender, task_id, task, **_):
"""Bind task metadata to contextvars for all logs within the task."""
structlog.contextvars.bind_contextvars(task_id=task_id, task_name=task.name)


class TLSRedisBackend(celery.backends.redis.RedisBackend):
def _params_from_url(self, url, defaults):
params = super()._params_from_url(url, defaults)
@@ -122,6 +140,10 @@ def get_request(self) -> Request:
env["request"].remote_addr_hashed = hashlib.sha256(
("127.0.0.1" + registry.settings["warehouse.ip_salt"]).encode("utf8")
).hexdigest()
request_id = str(uuid.uuid4())
env["request"].id = request_id
structlog.contextvars.bind_contextvars(**{"request.id": request_id})
env["request"].log = structlog.get_logger("warehouse.request")
self.request.update(pyramid_env=env)

return self.request.pyramid_env["request"] # type: ignore[attr-defined]
@@ -302,6 +324,10 @@ def includeme(config: Configurator) -> None:
REDBEAT_REDIS_URL=s["celery.scheduler_url"],
# Silence deprecation warning on startup
broker_connection_retry_on_startup=False,
# Disable Celery's logger hijacking for unified structlog control
worker_hijack_root_logger=False,
worker_log_format="%(message)s",
worker_task_log_format="%(message)s",
)
config.registry["celery.app"].Task = WarehouseTask
config.registry["celery.app"].pyramid_config = config
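
The three new worker_* settings keep Celery from hijacking the root logger and strip its log formats down to the bare message, so the ProcessorFormatter installed by the after_setup_logger handler receives clean input. A hedged, standalone sketch of that interplay (the app name and handler name are placeholders):

from celery import Celery, signals

from warehouse.logging import configure_celery_logging

app = Celery("sketch", set_as_current=False)
app.conf.update(
    worker_hijack_root_logger=False,  # leave the root logger to structlog
    worker_log_format="%(message)s",
    worker_task_log_format="%(message)s",
)

@signals.after_setup_logger.connect
def _setup(logger, loglevel, logfile, *args, **kwargs):
    # Delegate to the unified configuration, as warehouse.tasks does.
    configure_celery_logging(logfile, loglevel)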