diff --git a/starlette_exporter/middleware.py b/starlette_exporter/middleware.py index 7750c98..8dfbab3 100644 --- a/starlette_exporter/middleware.py +++ b/starlette_exporter/middleware.py @@ -1,8 +1,8 @@ """ Middleware for exporting Prometheus metrics using Starlette """ import logging import time -import warnings from collections import OrderedDict +from contextlib import suppress from inspect import iscoroutine from typing import ( Any, @@ -39,6 +39,7 @@ def get_matching_route_path( Credit to https://github.com/elastic/apm-agent-python """ + for route in routes: match, child_scope = route.matches(scope) if match == Match.FULL: @@ -54,7 +55,9 @@ def get_matching_route_path( if isinstance(route, BaseRoute) and getattr(route, "routes", None): child_scope = {**scope, **child_scope} child_route_name = get_matching_route_path( - child_scope, getattr(route, "routes"), route_name + child_scope, + getattr(route, "routes"), + route_name, ) if child_route_name is None: route_name = None @@ -256,6 +259,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: method = request.method path = request.url.path + base_path = request.base_url.path.rstrip("/") + + if base_path and path.startswith(base_path): + path = path[len(base_path) :] if path in self.skip_paths or method in self.skip_methods: await self.app(scope, receive, send) @@ -318,6 +325,7 @@ async def wrapped_send(message: Message) -> None: await send(message) exception: Optional[Exception] = None + original_scope = scope.copy() try: await self.app(scope, receive, wrapped_send) except Exception as e: @@ -330,7 +338,13 @@ async def wrapped_send(message: Message) -> None: ).dec() if self.filter_unhandled_paths or self.group_paths: - grouped_path = self._get_router_path(scope) + grouped_path: Optional[str] = None + + endpoint = scope.get("endpoint", None) + router = scope.get("router", None) + if endpoint and router: + with suppress(Exception): + grouped_path = get_matching_route_path(original_scope, router.routes) # filter_unhandled_paths removes any requests without mapped endpoint from the metrics. if self.filter_unhandled_paths and grouped_path is None: @@ -390,36 +404,3 @@ async def wrapped_send(message: Message) -> None: if exception: raise exception - @staticmethod - def _get_router_path(scope: Scope) -> Optional[str]: - """Returns the original router path (with url param names) for given request.""" - if not (scope.get("endpoint", None) and scope.get("router", None)): - return None - - root_path = scope.get("root_path", "") - app = scope.get("app", {}) - - if hasattr(app, "root_path"): - app_root_path = getattr(app, "root_path") - if app_root_path and root_path.startswith(app_root_path): - root_path = root_path[len(app_root_path) :] - - base_scope = { - "root_path": root_path, - "type": scope.get("type"), - "path": root_path + scope.get("path", ""), - "path_params": scope.get("path_params", {}), - "method": scope.get("method"), - "headers": scope.get("headers", {}), - } - - try: - path = get_matching_route_path( - base_scope, getattr(scope.get("router"), "routes") - ) - return path - except: - # unhandled path - pass - - return None diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 98415c5..86659b9 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -202,34 +202,6 @@ def test_ungrouped_paths(self, testapp): in metrics ) - def test_custom_root_path(self, testapp): - """test that an unhandled exception still gets logged in the requests counter""" - - client = TestClient(testapp(), root_path="/api") - - client.get("/200") - client.get("/500") - client.get("/404") - - with pytest.raises(KeyError, match="value_error"): - client.get("/unhandled") - - metrics = client.get("/metrics").content.decode() - - assert ( - """starlette_requests_total{app_name="starlette",method="GET",path="/200",status_code="200"} 1.0""" - in metrics - ) - assert ( - """starlette_requests_total{app_name="starlette",method="GET",path="/500",status_code="500"} 1.0""" - in metrics - ) - assert "/404" not in metrics - assert ( - """starlette_requests_total{app_name="starlette",method="GET",path="/unhandled",status_code="500"} 1.0""" - in metrics - ) - def test_histogram(self, client): """test that histogram buckets appear after making requests""" @@ -492,34 +464,6 @@ def test_unhandled(self, client): in metrics ) - def test_custom_root_path(self, testapp): - """test that custom root_path does not affect the path grouping""" - - client = TestClient(testapp(), root_path="/api") - - client.get("/200/111") - client.get("/500/1111") - client.get("/404/123") - - with pytest.raises(KeyError, match="value_error"): - client.get("/unhandled/123") - - metrics = client.get("/metrics").content.decode() - - assert ( - """starlette_requests_total{app_name="starlette",method="GET",path="/200/{test_param}",status_code="200"} 1.0""" - in metrics - ) - assert ( - """starlette_requests_total{app_name="starlette",method="GET",path="/500/{test_param}",status_code="500"} 1.0""" - in metrics - ) - assert ( - """starlette_requests_total{app_name="starlette",method="GET",path="/unhandled/{test_param}",status_code="500"} 1.0""" - in metrics - ) - assert "/404" not in metrics - def test_mounted_path_404_unfiltered(self, testapp): """test an unhandled path that will be partially matched at the mounted base path (grouped paths)""" client = TestClient(testapp(group_paths=True, filter_unhandled_paths=False)) @@ -574,6 +518,63 @@ def test_histogram(self, client): in metrics ) + def test_custom_root_path(self, testapp): + """test that custom root_path does not affect the path grouping""" + + client = TestClient(testapp(skip_paths=["/health"]), root_path="/api") + + client.get("/200/111") + client.get("/500/1111") + client.get("/404/123") + + client.get("/api/200/111") + client.get("/api/500/1111") + client.get("/api/404/123") + + with pytest.raises(KeyError, match="value_error"): + client.get("/unhandled/123") + + with pytest.raises(KeyError, match="value_error"): + client.get("/api/unhandled/123") + + client.get("/mounted/test/404") + client.get("/static/404") + + client.get("/api/mounted/test/123") + client.get("/api/static/test.txt") + + client.get("/health") + client.get("/api/health") + + metrics = client.get("/metrics").content.decode() + + assert ( + """starlette_requests_total{app_name="starlette",method="GET",path="/200/{test_param}",status_code="200"} 2.0""" + in metrics + ) + assert ( + """starlette_requests_total{app_name="starlette",method="GET",path="/500/{test_param}",status_code="500"} 2.0""" + in metrics + ) + assert ( + """starlette_requests_total{app_name="starlette",method="GET",path="/unhandled/{test_param}",status_code="500"} 2.0""" + in metrics + ) + assert ( + """starlette_requests_total{app_name="starlette",method="GET",path="/mounted/test/{item}",status_code="200"} 1.0""" + in metrics + ) + assert ( + """starlette_requests_total{app_name="starlette",method="GET",path="/static",status_code="200"} 1.0""" + in metrics + ) + assert ( + """starlette_requests_total{app_name="starlette",method="GET",path="/static",status_code="404"} 1.0""" + in metrics + ) + assert "/404" not in metrics + assert "/health" not in metrics + class TestBackgroundTasks: """tests for ensuring the middleware handles requests involving background tasks"""