diff --git a/sanic_routing/route.py b/sanic_routing/route.py index 46e86b4..a5088d2 100644 --- a/sanic_routing/route.py +++ b/sanic_routing/route.py @@ -74,7 +74,7 @@ def __init__( name: str, handler: t.Callable[..., t.Any], methods: t.Union[t.Sequence[str], t.FrozenSet[str]], - requirements: t.Dict[str, t.Any] = None, + requirements: t.Optional[t.Dict[str, t.Any]] = None, strict: bool = False, unquote: bool = False, static: bool = False, diff --git a/sanic_routing/tree.py b/sanic_routing/tree.py index 4825152..311e3c1 100644 --- a/sanic_routing/tree.py +++ b/sanic_routing/tree.py @@ -3,7 +3,7 @@ from .group import RouteGroup from .line import Line -from .patterns import REGEX_PARAM_NAME, REGEX_PARAM_NAME_EXT +from .patterns import REGEX_PARAM_NAME, REGEX_PARAM_NAME_EXT, alpha, ext, slug logger = getLogger("sanic.root") @@ -16,6 +16,7 @@ def __init__( parent=None, router=None, param=None, + unquote=False, ) -> None: self.root = root self.part = part @@ -34,7 +35,7 @@ def __init__( self.children_param_injected = False self.has_deferred = False self.equality_check = False - self.unquote = False + self.unquote = unquote self.router = router def __str__(self) -> str: @@ -268,7 +269,7 @@ def _inject_param_check(self, location, indent, idx): Line("pass", indent + 1), Line("else:", indent), ] - if self.unquote: + if self.unquote and self._cast_as_str(self.param.cast): lines.append( Line( f"basket['__matches__'][{idx}] = " @@ -280,6 +281,11 @@ def _inject_param_check(self, location, indent, idx): location.extend(lines) + @staticmethod + def _cast_as_str(cast) -> bool: + return_type_hint = t.get_type_hints(cast).get("return") + return cast in (str, ext, slug, alpha) or return_type_hint is str + @staticmethod def _inject_method_check(location, indent, group): """ @@ -436,6 +442,7 @@ def generate(self, groups: t.Iterable[RouteGroup]) -> None: """ for group in groups: current = self.root + current.unquote = current.unquote or group.unquote for level, part in enumerate(group.parts): param = None dynamic = part.startswith("<") @@ -452,6 +459,7 @@ def generate(self, groups: t.Iterable[RouteGroup]) -> None: parent=current, router=self.router, param=param, + unquote=current.unquote, ) child.dynamic = dynamic current.add_child(child) @@ -459,7 +467,6 @@ def generate(self, groups: t.Iterable[RouteGroup]) -> None: current.level = level + 1 current.groups.append(group) - current.unquote = current.unquote or group.unquote def display(self) -> None: """ diff --git a/tests/test_builtin_param_types.py b/tests/test_builtin_param_types.py index 43c5218..c1507e7 100644 --- a/tests/test_builtin_param_types.py +++ b/tests/test_builtin_param_types.py @@ -1,6 +1,7 @@ from unittest.mock import Mock import pytest + from sanic_routing import BaseRouter from sanic_routing.exceptions import InvalidUsage, NotFound diff --git a/tests/test_unquote.py b/tests/test_unquote.py new file mode 100644 index 0000000..2374e09 --- /dev/null +++ b/tests/test_unquote.py @@ -0,0 +1,52 @@ +from unittest.mock import Mock + +from sanic_routing import BaseRouter + + +class Router(BaseRouter): + def get(self, path, method, extra=None): + return self.resolve(path=path, method=method, extra=extra) + + +def test_no_unquote(): + handler = Mock(return_value=123) + + router = Router() + router.add("//", methods=["GET"], handler=handler, unquote=False) + router.finalize() + + _, handler, params = router.get("/%F0%9F%98%8E/sunglasses", "GET") + assert params == {"bar": "sunglasses", "foo": "%F0%9F%98%8E"} + + _, handler, params = router.get("/😎/sunglasses", "GET") + assert params == {"bar": "sunglasses", "foo": "😎"} + + +def test_unquote(): + handler = Mock(return_value=123) + + router = Router() + router.add("//", methods=["GET"], handler=handler, unquote=True) + router.finalize() + + _, handler, params = router.get("/%F0%9F%98%8E/sunglasses", "GET") + assert params == {"bar": "sunglasses", "foo": "😎"} + + _, handler, params = router.get("/😎/sunglasses", "GET") + assert params == {"bar": "sunglasses", "foo": "😎"} + + +def test_unquote_non_string(): + handler = Mock(return_value=123) + + router = Router() + router.add( + "//", methods=["GET"], handler=handler, unquote=True + ) + router.finalize() + + _, handler, params = router.get("/%F0%9F%98%8E/123", "GET") + assert params == {"bar": 123, "foo": "😎"} + + _, handler, params = router.get("/😎/123", "GET") + assert params == {"bar": 123, "foo": "😎"}