Skip to content

Commit

Permalink
Merge pull request #74 from stephenhillier/response-labels
Browse files Browse the repository at this point in the history
from_response_label helper function
  • Loading branch information
stephenhillier committed May 29, 2024
2 parents 12d33ca + acfbf8e commit dc88ee9
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 38 deletions.
12 changes: 11 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,20 +135,30 @@ values are constrained (see [this writeup from Grafana on cardinality](https://g

### Label helpers

**`from_header(key: string, allowed_values: Optional[Iterable])`**: a convenience function for using a header value as a label.
**`from_header(key: str, allowed_values: Optional[Iterable] = None, default: str = "")`**: a convenience function for using a header value as a label.

`allowed_values` allows you to supply a list of allowed values. If supplied, header values not in the list will result in
an empty string being returned. This allows you to constrain the label values, reducing the risk of excessive cardinality.

`default`: the default value if the header does not exist.

Do not use headers that could contain unconstrained values (e.g. user id) or user-supplied values.


**`from_response_header(key: str, allowed_values: Optional[Iterable] = None, default: str = "")`**: a helper
function that extracts a value from a response header. This may be useful if you are using a middleware
or decorator that populates a header.

The same options (and warnings) apply as for the `from_header` function.

```python
from starlette_exporter import PrometheusMiddleware, from_header, from_response_header

app.add_middleware(
    PrometheusMiddleware,
    labels={
        "host": from_header("X-Internal-Org", allowed_values=("accounting", "marketing", "product")),
        "cache": from_response_header("X-FastAPI-Cache", allowed_values=("hit", "miss")),
    },
)
```

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="starlette_exporter",
version="0.21.0",
version="0.22.0",
author="Stephen Hillier",
author_email="stephenhillier@gmail.com",
packages=["starlette_exporter"],
Expand Down
3 changes: 2 additions & 1 deletion starlette_exporter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Public API of the package: the names exported by
# `from starlette_exporter import *`.
__all__ = [
'PrometheusMiddleware',
'from_header',
'from_response_header',
'handle_metrics',
]

Expand All @@ -20,7 +21,7 @@
from starlette.responses import Response

from .middleware import PrometheusMiddleware
from .labels import from_header
from .labels import from_header, from_response_header


def handle_metrics(request: Request) -> Response:
Expand Down
31 changes: 30 additions & 1 deletion starlette_exporter/labels.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,27 @@
"""utilities for working with labels"""
from typing import Callable, Iterable, Optional
from typing import Any, Callable, Iterable, Optional, Dict

from starlette.requests import Request
from starlette.types import Message


class ResponseHeaderLabel:
    """Callable label whose value is taken from the response headers.

    starlette_exporter will recognize that this should not be called until
    response headers are written.

    Args:
        key: name of the response header to read. Matching is
            case-insensitive, because HTTP header names are case-insensitive
            and are commonly emitted lower-cased.
        allowed_values: optional collection of permitted header values. When
            supplied, any value outside it is replaced by ``default``.
        default: value returned when the header is absent or not allowed.
    """

    def __init__(
        self, key: str, allowed_values: Optional[Iterable] = None, default: str = ""
    ) -> None:
        self.key = key
        self.default = default
        self.allowed_values = allowed_values

    def __call__(self, headers: Dict) -> Any:
        """Return the label value for the given response headers.

        Args:
            headers: mapping of header name to header value.

        Returns:
            The header value, or ``self.default`` if the header is missing or
            its value is not in ``self.allowed_values``.
        """
        missing = object()

        # An exact-case hit always wins, preserving the previous behaviour.
        v = headers.get(self.key, missing)
        if v is missing:
            # Fall back to a case-insensitive match so that a key such as
            # "X-Cache" still finds a header stored as "x-cache".
            wanted = self.key.lower()
            v = next(
                (
                    value
                    for name, value in headers.items()
                    if isinstance(name, str) and name.lower() == wanted
                ),
                self.default,
            )

        if self.allowed_values is not None and v not in self.allowed_values:
            return self.default
        return v


def from_header(key: str, allowed_values: Optional[Iterable] = None) -> Callable:
Expand Down Expand Up @@ -36,3 +56,12 @@ def inner(r: Request):
return v

return inner


def from_response_header(
    key: str, allowed_values: Optional[Iterable] = None, default: str = ""
):
    """Build a label callable that reads its value from a response header.

    The arguments are forwarded unchanged to ``ResponseHeaderLabel``.
    starlette_exporter will automatically populate this label value when
    response headers are written.

    Args:
        key: name of the response header to use as the label value.
        allowed_values: optional collection of permitted header values.
        default: fallback label value.

    Returns:
        A ``ResponseHeaderLabel`` instance configured with the given arguments.
    """
    label = ResponseHeaderLabel(key, allowed_values, default)
    return label
116 changes: 85 additions & 31 deletions starlette_exporter/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from starlette.routing import BaseRoute, Match
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from starlette_exporter.labels import ResponseHeaderLabel

from . import optional_metrics

logger = logging.getLogger("starlette_exporter")
Expand Down Expand Up @@ -114,7 +116,6 @@ def __init__(
self.optional_metrics_list = optional_metrics
self.always_use_int_status = always_use_int_status

self.labels = OrderedDict(labels) if labels is not None else None
self.exemplars = exemplars
self._exemplars_req_kw = ""

Expand All @@ -132,7 +133,22 @@ def __init__(
if "request" in exemplar_sig.parameters:
self._exemplars_req_kw = "request"


# split labels into request and response labels.
# response labels will be evaluated while the response is
# written.
self.request_labels = OrderedDict({})
self.response_labels: OrderedDict[str, ResponseHeaderLabel] = OrderedDict({})

if labels is not None:
for k, v in labels.items():
if isinstance(v, ResponseHeaderLabel):
self.response_labels[k] = v
else:
self.request_labels[k] = v

# Default metrics
# Starlette initialises middleware multiple times, so store metrics on the class

@property
def request_count(self):
Expand All @@ -146,7 +162,8 @@ def request_count(self):
"path",
"status_code",
"app_name",
*self._default_label_keys(),
*self.request_labels.keys(),
*self.response_labels.keys(),
),
)
return PrometheusMiddleware._metrics[metric_name]
Expand Down Expand Up @@ -174,7 +191,8 @@ def response_body_size_count(self):
"path",
"status_code",
"app_name",
*self._default_label_keys(),
*self.request_labels.keys(),
*self.response_labels.keys(),
),
)
return PrometheusMiddleware._metrics[metric_name]
Expand All @@ -200,7 +218,8 @@ def request_body_size_count(self):
"path",
"status_code",
"app_name",
*self._default_label_keys(),
*self.request_labels.keys(),
*self.response_labels.keys(),
),
)
return PrometheusMiddleware._metrics[metric_name]
Expand All @@ -219,7 +238,8 @@ def request_time(self):
"path",
"status_code",
"app_name",
*self._default_label_keys(),
*self.request_labels.keys(),
*self.response_labels.keys(),
),
**self.kwargs,
)
Expand All @@ -232,23 +252,15 @@ def requests_in_progress(self):
PrometheusMiddleware._metrics[metric_name] = Gauge(
metric_name,
"Total HTTP requests currently in progress",
("method", "app_name", *self._default_label_keys()),
("method", "app_name", *self.request_labels.keys()),
multiprocess_mode="livesum",
)
return PrometheusMiddleware._metrics[metric_name]

def _default_label_keys(self) -> List[str]:
if self.labels is None:
return []
return list(self.labels.keys())

async def _default_label_values(self, request: Request):
if self.labels is None:
return []

async def _request_label_values(self, request: Request) -> List[str]:
values: List[str] = []

for k, v in self.labels.items():
for k, v in self.request_labels.items():
if callable(v):
parsed_value = ""
# if provided a callable, try to use it on the request.
Expand All @@ -267,6 +279,34 @@ async def _default_label_values(self, request: Request):

return values

def _response_label_values(self, message: Message) -> List[str]:
    """Evaluate the user-defined response labels against a response message.

    Args:
        message: the message being sent; its ``"headers"`` entry (if
            present) is read as an iterable of ``(name, value)`` byte pairs.

    Returns:
        One string per ``ResponseHeaderLabel`` in ``self.response_labels``,
        in the same order. A label whose callable raises is logged and
        reported as an empty string.
    """
    values: List[str] = []

    # bail if no response labels were defined by the user
    if not self.response_labels:
        return values

    def _decode(raw: bytes) -> str:
        # A header that is not valid UTF-8 must not raise here, since that
        # would abort the response being sent. Fall back to latin-1, which
        # can decode any byte sequence.
        try:
            return raw.decode("utf-8")
        except UnicodeDecodeError:
            return raw.decode("latin-1")

    # create a dict of headers to make it easy to find keys
    headers = {_decode(k): _decode(v) for (k, v) in message.get("headers", ())}

    for k, v in self.response_labels.items():
        # currently only ResponseHeaderLabel supported
        if isinstance(v, ResponseHeaderLabel):
            parsed_value = ""
            try:
                result = v(headers)
            except Exception:
                # logger.warn is a deprecated alias; use logger.warning like
                # the rest of this module.
                logger.warning(f"label function for {k} failed", exc_info=True)
            else:
                parsed_value = str(result)
            values.append(parsed_value)

    return values

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ["http"]:
await self.app(scope, receive, send)
Expand All @@ -288,13 +328,16 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
begin = time.perf_counter()
end = None

default_labels = await self._default_label_values(request)
request_labels = await self._request_label_values(request)

# Increment requests_in_progress gauge when request comes in
self.requests_in_progress.labels(method, self.app_name, *default_labels).inc()
self.requests_in_progress.labels(method, self.app_name, *request_labels).inc()

status_code = None

# custom response label values, to be populated when response is written.
response_labels = []

# optional request and response body size metrics
response_body_size: int = 0

Expand All @@ -311,10 +354,13 @@ async def wrapped_send(message: Message) -> None:
nonlocal status_code
status_code = message["status"]

nonlocal response_labels
response_labels = self._response_label_values(message)

if self.always_use_int_status:
try:
status_code = int(message["status"])
except ValueError as e:
except ValueError:
logger.warning(
f"always_use_int_status flag selected but failed to convert status_code to int for value: {status_code}"
)
Expand Down Expand Up @@ -348,11 +394,18 @@ async def wrapped_send(message: Message) -> None:
except Exception as e:
status_code = 500
exception = e
finally:
# Decrement 'requests_in_progress' gauge after response sent
self.requests_in_progress.labels(
method, self.app_name, *request_labels
).dec()

# Decrement 'requests_in_progress' gauge after response sent
self.requests_in_progress.labels(
method, self.app_name, *default_labels
).dec()
if status_code is None:
if await request.is_disconnected():
# In case no response was returned and the client is disconnected, 499 is reported as status code.
status_code = 499
else:
status_code = 500

if self.filter_unhandled_paths or self.group_paths:
grouped_path: Optional[str] = None
Expand All @@ -369,20 +422,21 @@ async def wrapped_send(message: Message) -> None:
raise exception
return


# group_paths enables returning the original router path (with url param names)
# for example, when using this option, requests to /api/product/1 and /api/product/3
# will both be grouped under /api/product/{product_id}. See the README for more info.
if self.group_paths and grouped_path is not None:
path = grouped_path

if status_code is None:
if await request.is_disconnected():
# In case no response was returned and the client is disconnected, 499 is reported as status code.
status_code = 499
else:
status_code = 500

labels = [method, path, status_code, self.app_name, *default_labels]
labels = [
method,
path,
status_code,
self.app_name,
*request_labels,
*response_labels,
]

# optional extra arguments to be passed as kwargs to observations
# note: only used for histogram observations and counters to support exemplars
Expand Down
Loading

0 comments on commit dc88ee9

Please sign in to comment.