Skip to content

Commit

Permalink
Avoid modifying path in scope
Browse files Browse the repository at this point in the history
  • Loading branch information
dolfinus committed Jan 19, 2024
1 parent e3ac54d commit ecf05aa
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 92 deletions.
53 changes: 17 additions & 36 deletions starlette_exporter/middleware.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
""" Middleware for exporting Prometheus metrics using Starlette """
import logging
import time
import warnings
from collections import OrderedDict
from contextlib import suppress
from inspect import iscoroutine
from typing import (
Any,
Expand Down Expand Up @@ -39,6 +39,7 @@ def get_matching_route_path(
Credit to https://github.com/elastic/apm-agent-python
"""

for route in routes:
match, child_scope = route.matches(scope)
if match == Match.FULL:
Expand All @@ -54,7 +55,9 @@ def get_matching_route_path(
if isinstance(route, BaseRoute) and getattr(route, "routes", None):
child_scope = {**scope, **child_scope}
child_route_name = get_matching_route_path(
child_scope, getattr(route, "routes"), route_name
child_scope,
getattr(route, "routes"),
route_name,
)
if child_route_name is None:
route_name = None
Expand Down Expand Up @@ -256,6 +259,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

method = request.method
path = request.url.path
base_path = request.base_url.path.rstrip("/")

if base_path and path.startswith(base_path):
path = path[len(base_path) :]

if path in self.skip_paths or method in self.skip_methods:
await self.app(scope, receive, send)
Expand Down Expand Up @@ -318,6 +325,7 @@ async def wrapped_send(message: Message) -> None:
await send(message)

exception: Optional[Exception] = None
original_scope = scope.copy()
try:
await self.app(scope, receive, wrapped_send)
except Exception as e:
Expand All @@ -330,7 +338,13 @@ async def wrapped_send(message: Message) -> None:
).dec()

if self.filter_unhandled_paths or self.group_paths:
grouped_path = self._get_router_path(scope)
grouped_path: Optional[str] = None

endpoint = scope.get("endpoint", None)
router = scope.get("router", None)
if endpoint and router:
with suppress(Exception):
grouped_path = get_matching_route_path(original_scope, router.routes)

# filter_unhandled_paths removes any requests without mapped endpoint from the metrics.
if self.filter_unhandled_paths and grouped_path is None:
Expand Down Expand Up @@ -390,36 +404,3 @@ async def wrapped_send(message: Message) -> None:
if exception:
raise exception

@staticmethod
def _get_router_path(scope: Scope) -> Optional[str]:
"""Returns the original router path (with url param names) for given request."""
if not (scope.get("endpoint", None) and scope.get("router", None)):
return None

root_path = scope.get("root_path", "")
app = scope.get("app", {})

if hasattr(app, "root_path"):
app_root_path = getattr(app, "root_path")
if app_root_path and root_path.startswith(app_root_path):
root_path = root_path[len(app_root_path) :]

base_scope = {
"root_path": root_path,
"type": scope.get("type"),
"path": root_path + scope.get("path", ""),
"path_params": scope.get("path_params", {}),
"method": scope.get("method"),
"headers": scope.get("headers", {}),
}

try:
path = get_matching_route_path(
base_scope, getattr(scope.get("router"), "routes")
)
return path
except:
# unhandled path
pass

return None
113 changes: 57 additions & 56 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,34 +202,6 @@ def test_ungrouped_paths(self, testapp):
in metrics
)

def test_custom_root_path(self, testapp):
"""test that an unhandled exception still gets logged in the requests counter"""

client = TestClient(testapp(), root_path="/api")

client.get("/200")
client.get("/500")
client.get("/404")

with pytest.raises(KeyError, match="value_error"):
client.get("/unhandled")

metrics = client.get("/metrics").content.decode()

assert (
"""starlette_requests_total{app_name="starlette",method="GET",path="/200",status_code="200"} 1.0"""
in metrics
)
assert (
"""starlette_requests_total{app_name="starlette",method="GET",path="/500",status_code="500"} 1.0"""
in metrics
)
assert "/404" not in metrics
assert (
"""starlette_requests_total{app_name="starlette",method="GET",path="/unhandled",status_code="500"} 1.0"""
in metrics
)

def test_histogram(self, client):
"""test that histogram buckets appear after making requests"""

Expand Down Expand Up @@ -492,34 +464,6 @@ def test_unhandled(self, client):
in metrics
)

def test_custom_root_path(self, testapp):
"""test that custom root_path does not affect the path grouping"""

client = TestClient(testapp(), root_path="/api")

client.get("/200/111")
client.get("/500/1111")
client.get("/404/123")

with pytest.raises(KeyError, match="value_error"):
client.get("/unhandled/123")

metrics = client.get("/metrics").content.decode()

assert (
"""starlette_requests_total{app_name="starlette",method="GET",path="/200/{test_param}",status_code="200"} 1.0"""
in metrics
)
assert (
"""starlette_requests_total{app_name="starlette",method="GET",path="/500/{test_param}",status_code="500"} 1.0"""
in metrics
)
assert (
"""starlette_requests_total{app_name="starlette",method="GET",path="/unhandled/{test_param}",status_code="500"} 1.0"""
in metrics
)
assert "/404" not in metrics

def test_mounted_path_404_unfiltered(self, testapp):
"""test an unhandled path that will be partially matched at the mounted base path (grouped paths)"""
client = TestClient(testapp(group_paths=True, filter_unhandled_paths=False))
Expand Down Expand Up @@ -574,6 +518,63 @@ def test_histogram(self, client):
in metrics
)

def test_custom_root_path(self, testapp):
"""test that custom root_path does not affect the path grouping"""

client = TestClient(testapp(skip_paths=["/health"]), root_path="/api")

client.get("/200/111")
client.get("/500/1111")
client.get("/404/123")

client.get("/api/200/111")
client.get("/api/500/1111")
client.get("/api/404/123")

with pytest.raises(KeyError, match="value_error"):
client.get("/unhandled/123")

with pytest.raises(KeyError, match="value_error"):
client.get("/api/unhandled/123")

client.get("/mounted/test/404")
client.get("/static/404")

client.get("/api/mounted/test/123")
client.get("/api/static/test.txt")

client.get("/health")
client.get("/api/health")

metrics = client.get("/metrics").content.decode()

assert (
"""starlette_requests_total{app_name="starlette",method="GET",path="/200/{test_param}",status_code="200"} 2.0"""
in metrics
)
assert (
"""starlette_requests_total{app_name="starlette",method="GET",path="/500/{test_param}",status_code="500"} 2.0"""
in metrics
)
assert (
"""starlette_requests_total{app_name="starlette",method="GET",path="/unhandled/{test_param}",status_code="500"} 2.0"""
in metrics
)
assert (
"""starlette_requests_total{app_name="starlette",method="GET",path="/mounted/test/{item}",status_code="200"} 1.0"""
in metrics
)
assert (
"""starlette_requests_total{app_name="starlette",method="GET",path="/static",status_code="200"} 1.0"""
in metrics
)
assert (
"""starlette_requests_total{app_name="starlette",method="GET",path="/static",status_code="404"} 1.0"""
in metrics
)
assert "/404" not in metrics
assert "/health" not in metrics


class TestBackgroundTasks:
"""tests for ensuring the middleware handles requests involving background tasks"""
Expand Down

0 comments on commit ecf05aa

Please sign in to comment.