diff --git a/starlette_exporter/labels.py b/starlette_exporter/labels.py index 1701638..4bdcda3 100644 --- a/starlette_exporter/labels.py +++ b/starlette_exporter/labels.py @@ -2,7 +2,6 @@ from typing import Any, Callable, Iterable, Optional, Dict from starlette.requests import Request -from starlette.types import Message class ResponseHeaderLabel: @@ -13,7 +12,7 @@ class ResponseHeaderLabel: def __init__( self, key: str, allowed_values: Optional[Iterable] = None, default: str = "" ) -> None: - self.key = key + self.key = key.lower() self.default = default self.allowed_values = allowed_values diff --git a/starlette_exporter/middleware.py b/starlette_exporter/middleware.py index 5a6ad01..5a68761 100644 --- a/starlette_exporter/middleware.py +++ b/starlette_exporter/middleware.py @@ -300,7 +300,7 @@ def _response_label_values(self, message: Message) -> List[str]: # create a dict of headers to make it easy to find keys headers = { - k.decode("utf-8"): v.decode("utf-8") + k.decode("utf-8").lower(): v.decode("utf-8") for (k, v) in message.get("headers", ()) } @@ -405,6 +405,10 @@ async def wrapped_send(message: Message) -> None: await self.app(scope, receive, wrapped_send) except Exception as e: status_code = 500 + + # during an unhandled exception, populate response labels with empty strings. + response_labels = self._response_label_values({}) + exception = e finally: # Decrement 'requests_in_progress' gauge after response sent diff --git a/tests/test_middleware.py b/tests/test_middleware.py index bc25401..8364c6e 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -67,7 +67,7 @@ def httpstatus_response(request): ) async def error(request): - raise HTTPException(status_code=500, detail="this is a test error") + raise HTTPException(status_code=500, detail="this is a test error", headers={"foo":"baz"}) app.add_route("/500", error) app.add_route("/500/{test_param}", error) @@ -804,6 +804,46 @@ def test_from_response_header(self, testapp): in metrics ), metrics + def test_from_response_header_case_insensitive(self, testapp): + """test with the library-provided from_response_header function with a capitalized header key.""" + labels = {"foo": from_response_header("Foo"), "hello": "world"} + client = TestClient(testapp(labels=labels)) + client.get("/200") + metrics = client.get("/metrics").content.decode() + + assert ( + """starlette_requests_total{app_name="starlette",foo="baz",hello="world",method="GET",path="/200",status_code="200"} 1.0""" + in metrics + ), metrics + + def test_from_response_header_http_exception(self, testapp): + """test from_response_header against an endpoint that raises an HTTPException""" + labels = {"foo": from_response_header("foo"), "hello": "world"} + client = TestClient(testapp(labels=labels)) + client.get("/500") + metrics = client.get("/metrics").content.decode() + + assert ( + """starlette_requests_total{app_name="starlette",foo="baz",hello="world",method="GET",path="/500",status_code="500"} 1.0""" + in metrics + ), metrics + + def test_from_response_header_unhandled_exception(self, testapp): + """test from_response_header function against an endpoint that raises an unhandled exception""" + labels = {"foo": from_response_header("foo"), "hello": "world"} + client = TestClient(testapp(labels=labels)) + + # make the test call. This raises an error but will still populate metrics. + with pytest.raises(KeyError, match="value_error"): + client.get("/unhandled") + + metrics = client.get("/metrics").content.decode() + + assert ( + """starlette_requests_total{app_name="starlette",foo="",hello="world",method="GET",path="/unhandled",status_code="500"} 1.0""" + in metrics + ), metrics + def test_from_response_header_default(self, testapp): """test with the library-provided from_response_header function, with a missing header (testing the default value)"""