146 changes: 144 additions & 2 deletions taskiq/middlewares/opentelemetry_middleware.py
@@ -1,8 +1,11 @@
import logging
from collections.abc import Generator
from contextlib import AbstractContextManager
from datetime import datetime, timezone
from importlib.metadata import version
from typing import Any, TypeVar

import psutil
from packaging.version import Version, parse

try:
@@ -16,7 +19,7 @@

from opentelemetry import context as context_api
from opentelemetry import trace
from opentelemetry.metrics import Meter, MeterProvider, get_meter
from opentelemetry.metrics import Meter, MeterProvider, Observation, get_meter
from opentelemetry.propagate import extract, inject
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import Span, Tracer, TracerProvider
@@ -59,6 +62,9 @@
_TASK_RETRY_REASON_KEY = "taskiq.retry.reason"
_TASK_NAME_KEY = "taskiq.task_name"

_TASK_QUEUE_TIME_KEY = "_taskiq_queue_time"
_TASK_RECEIVED_TIME_KEY = "_taskiq_broker_receive_time"


def set_attributes_from_context(span: Span, context: dict[str, Any]) -> None:
"""Helper to extract meta values from a Taskiq Context."""
@@ -170,6 +176,74 @@ def __init__(
if meter is None
else meter
)
# Create metrics
# 1- Number of tasks sent. producer (Counter)
self.n_tasks_sent_counter = self._meter.create_counter(
name="tasks_sent",
unit="1",
description="Number of tasks sent from the producer side",
)
# 2- Number of errors by task name. consumer (Counter)
self.n_errors_counter = self._meter.create_counter(
name="task_errors",
unit="1",
description="Number of errors raised",
)
# 3- Number of task successes. consumer (Counter)
self.n_success_counter = self._meter.create_counter(
name="task_success",
unit="1",
description="Number of tasks completed successfully",
)
# 4- Task execution time. consumer (Histogram)
self.execution_time_hist = self._meter.create_histogram(
"task_execution_time",
unit="s",
description="Time to finish executing tasks",
)
# 5- Task wait time. both (Histogram)
self.task_wait_time = self._meter.create_histogram(
"task_wait_time",
unit="s",
description="Time the tasks waited before executing",
)
# current metrics to watch for in workers: CPU and memory utilization
self._process = psutil.Process()
# 6- CPU utilization
self.worker_cpu_utilization = self._meter.create_observable_gauge(
"worker_cpu_utilization",
callbacks=[self._observe_cpu],
unit="%",
description="Worker CPU utilization percentage. Only for worker processes",
)
# 7- Memory utilization
self.worker_memory_utilization = self._meter.create_observable_gauge(
"worker_memory_utilization",
callbacks=[self._observe_memory],
unit="By",
description="Worker memory utilization in bytes. Only for worker processes",
)

# 8- Number of tasks executing
self.number_of_broker_active_tasks = self._meter.create_up_down_counter(
"worker_active_tasks",
unit="1",
description="Number of tasks currently executing in the worker.",
)
# 9- Number of tasks prefetched
self.number_of_broker_prefetched_tasks = self._meter.create_up_down_counter(
"worker_prefetched_tasks",
unit="1",
description="Number of tasks currently prefetched in the worker.",
)

def _observe_memory(self, options: Any) -> Generator[Observation, None, None]:
if self.broker and self.broker.is_worker_process:
yield Observation(self._process.memory_info().rss)

def _observe_cpu(self, options: Any) -> Generator[Observation, None, None]:
if self.broker and self.broker.is_worker_process:
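# Note: psutil's cpu_percent() returns 0.0 on its first call and thereafter
# reports usage since the previous call, so early observations may read zero.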
yield Observation(self._process.cpu_percent())

def pre_send(self, message: TaskiqMessage) -> TaskiqMessage:
"""
@@ -193,7 +267,7 @@ def pre_send(self, message: TaskiqMessage) -> TaskiqMessage:
activation.__enter__()
attach_context(message, span, activation, None, is_publish=True)
inject(message.labels)

message.labels[_TASK_QUEUE_TIME_KEY] = datetime.now(timezone.utc).timestamp()
return message

def post_send(self, message: TaskiqMessage) -> None:
@@ -214,6 +288,7 @@ def post_send(self, message: TaskiqMessage) -> None:

activation.__exit__(None, None, None)
detach_context(message, is_publish=True)
self.n_tasks_sent_counter.add(1, attributes={"task_name": message.task_name})

def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage:
"""
@@ -236,6 +311,11 @@ def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage:
activation = trace.use_span(span, end_on_exit=True)
activation.__enter__() # pylint: disable=E1101
attach_context(message, span, activation, token)
message.labels[_TASK_RECEIVED_TIME_KEY] = datetime.now(timezone.utc).timestamp()
self.number_of_broker_active_tasks.add(
1,
attributes={"task_name": message.task_name},
)
return message

def post_save( # pylint: disable=R6301
@@ -313,3 +393,65 @@ def on_error(
}
span.record_exception(exception)
span.set_status(Status(**status_kwargs)) # type: ignore[arg-type]

def post_execute(
self,
message: "TaskiqMessage",
result: "TaskiqResult[Any]",
) -> None:
"""
This function tracks the number of failed and successful task executions.

:param message: received message.
:param result: result of the execution.
"""
if result.is_err:
retry_on_error = message.labels.get("retry_on_error")
if isinstance(retry_on_error, str):
retry_on_error = retry_on_error.lower() == "true"
# Count the error, tagging whether the task will be retried.
self.n_errors_counter.add(
1,
attributes={
"retry_error": bool(retry_on_error),
"task_name": message.task_name,
},
)
else:
self.n_success_counter.add(
1,
attributes={"task_name": message.task_name},
)
self.execution_time_hist.record(
result.execution_time,
attributes={
"task_name": message.task_name,
},
)
task_receive_time = message.labels.get(_TASK_RECEIVED_TIME_KEY)
task_send_time = message.labels.get(_TASK_QUEUE_TIME_KEY)
if task_receive_time is not None and task_send_time is not None:
# Labels may round-trip through the broker as strings, so coerce to float.
self.task_wait_time.record(
amount=float(task_receive_time) - float(task_send_time),
attributes={"task_name": message.task_name},
)

self.number_of_broker_active_tasks.add(
-1,
attributes={"task_name": message.task_name},
)

def on_prefetch_queue_add(self) -> None:
"""This hook is called after task is added to the worker prefetch queue."""
self.number_of_broker_prefetched_tasks.add(1)

def on_prefetch_queue_remove(self) -> None:
"""This hook is called after task is removed from the worker prefetch queue."""
self.number_of_broker_prefetched_tasks.add(-1)
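
A minimal usage sketch for reviewers who want to see the new instruments exported locally. The middleware class name `OpenTelemetryMiddleware`, its `meter_provider` argument, and the `with_middlewares` wiring are assumptions inferred from this diff and taskiq's usual API, not confirmed by it:

# Usage sketch (assumed names, not part of this diff).
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import (
    ConsoleMetricExporter,
    PeriodicExportingMetricReader,
)

from taskiq import InMemoryBroker
from taskiq.middlewares.opentelemetry_middleware import OpenTelemetryMiddleware

# Print metrics to stdout every 10 seconds for local inspection.
reader = PeriodicExportingMetricReader(
    ConsoleMetricExporter(),
    export_interval_millis=10_000,
)
broker = InMemoryBroker().with_middlewares(
    OpenTelemetryMiddleware(meter_provider=MeterProvider(metric_readers=[reader])),
)

With this wiring, kicking any task should increment `tasks_sent` on the producer side and, in a worker process, populate the execution-time and wait-time histograms.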
13 changes: 13 additions & 0 deletions taskiq/receiver/receiver.py
@@ -383,6 +383,12 @@ async def prefetcher(
current_message = asyncio.create_task(iterator.__anext__()) # type: ignore
fetched_tasks += 1
await queue.put(message)
# Custom hooks for OTel and any future instrumentations
for middleware in reversed(self.broker.middlewares):
if hasattr(middleware, "on_prefetch_queue_add"):
await maybe_awaitable(
middleware.on_prefetch_queue_add(), # type: ignore
)
except (asyncio.CancelledError, StopAsyncIteration):
break
# We don't want to fetch new messages if we are shutting down.
@@ -434,6 +440,13 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
logger.info("No more tasks to wait for. Shutting down.")
break

# Custom hooks for OTel and any future instrumentations
for middleware in reversed(self.broker.middlewares):
if hasattr(middleware, "on_prefetch_queue_remove"):
await maybe_awaitable(
middleware.on_prefetch_queue_remove(), # type: ignore
)

task = asyncio.create_task(
self.callback(message=message, raise_err=False),
)
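
Because the receiver discovers these hooks with `hasattr` and wraps each call in `maybe_awaitable`, any middleware (sync or async) can opt in without base-class changes. An illustrative sketch with hypothetical names:

from taskiq import TaskiqMiddleware


class PrefetchDepthMiddleware(TaskiqMiddleware):
    """Illustrative middleware using the new duck-typed prefetch hooks."""

    def __init__(self) -> None:
        super().__init__()
        self.depth = 0

    def on_prefetch_queue_add(self) -> None:
        # Invoked by the receiver right after a message enters the prefetch queue.
        self.depth += 1

    async def on_prefetch_queue_remove(self) -> None:
        # May also be async: the receiver wraps the call in maybe_awaitable.
        self.depth -= 1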
6 changes: 6 additions & 0 deletions tests/opentelemetry/taskiq_test_tasks.py
@@ -1,3 +1,4 @@
import asyncio
from typing import Any

from opentelemetry import baggage
@@ -26,3 +27,8 @@ async def task_raises() -> None:
@broker.task
async def task_returns_baggage() -> Any:
return dict(baggage.get_all())


@broker.task
async def task_does_processing(wait_time: float) -> None:
await asyncio.sleep(wait_time)
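
A test built on this task could assert that the new histograms actually record, using the SDK's in-memory reader. The sketch below assumes the suite's fixtures attach this reader's MeterProvider to the middleware backing `broker`:

from opentelemetry.sdk.metrics.export import InMemoryMetricReader

reader = InMemoryMetricReader()  # assumed to back the MeterProvider used by `broker`


async def test_wait_and_execution_time() -> None:
    task = await task_does_processing.kiq(0.1)
    await task.wait_result(timeout=5)
    metric_names = {
        metric.name
        for resource_metrics in reader.get_metrics_data().resource_metrics
        for scope_metrics in resource_metrics.scope_metrics
        for metric in scope_metrics.metrics
    }
    assert "task_execution_time" in metric_names
    assert "task_wait_time" in metric_names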