Skip to content

Commit

Permalink
Adding allow route overwrite option in blueprint (#2716)
Browse files Browse the repository at this point in the history
* Adding allow route overwrite option

* Add test case for route overwriting after bp copy

* Fix test

* Fix

* Add test case `test_bp_allow_override`

* Remove conflicted future routes when overwriting is allowed

* Improved test test_bp_copy_with_route_overwriting

* Fix type

* Fix type 2

* Add `test_bp_copy_without_route_overwriting` case

* make `allow_route_overwrite` flag to be internal

* Remove unwanted test case

---------

Co-authored-by: Adam Hopkins <adam@amhopkins.com>
  • Loading branch information
ChihweiLHBird and ahopkins committed Jul 7, 2023
1 parent 4068a0d commit e374409
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 5 deletions.
5 changes: 4 additions & 1 deletion sanic/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,11 @@ def _apply_exception_handler(
def _apply_listener(self, listener: FutureListener):
return self.register_listener(listener.listener, listener.event)

def _apply_route(self, route: FutureRoute) -> List[Route]:
def _apply_route(
self, route: FutureRoute, overwrite: bool = False
) -> List[Route]:
params = route._asdict()
params["overwrite"] = overwrite
websocket = params.pop("websocket", False)
subprotocols = params.pop("subprotocols", None)

Expand Down
10 changes: 9 additions & 1 deletion sanic/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class Blueprint(BaseSanic):
"_future_listeners",
"_future_exceptions",
"_future_signals",
"_allow_route_overwrite",
"copied_from",
"ctx",
"exceptions",
Expand All @@ -119,6 +120,7 @@ def __init__(
):
super().__init__(name=name)
self.reset()
self._allow_route_overwrite = False
self.copied_from = ""
self.ctx = SimpleNamespace()
self.host = host
Expand Down Expand Up @@ -169,6 +171,7 @@ def registered(self) -> bool:

def reset(self):
self._apps: Set[Sanic] = set()
self._allow_route_overwrite = False
self.exceptions: List[RouteHandler] = []
self.listeners: Dict[str, List[ListenerType[Any]]] = {}
self.middlewares: List[MiddlewareType] = []
Expand All @@ -182,6 +185,7 @@ def copy(
url_prefix: Optional[Union[str, Default]] = _default,
version: Optional[Union[int, str, float, Default]] = _default,
version_prefix: Union[str, Default] = _default,
allow_route_overwrite: Union[bool, Default] = _default,
strict_slashes: Optional[Union[bool, Default]] = _default,
with_registration: bool = True,
with_ctx: bool = False,
Expand Down Expand Up @@ -225,6 +229,8 @@ def copy(
new_bp.strict_slashes = strict_slashes
if not isinstance(version_prefix, Default):
new_bp.version_prefix = version_prefix
if not isinstance(allow_route_overwrite, Default):
new_bp._allow_route_overwrite = allow_route_overwrite

for key, value in attrs_backup.items():
setattr(self, key, value)
Expand Down Expand Up @@ -360,7 +366,9 @@ def register(self, app, options):
continue

registered.add(apply_route)
route = app._apply_route(apply_route)
route = app._apply_route(
apply_route, overwrite=self._allow_route_overwrite
)

# If it is a copied BP, then make sure all of the names of routes
# matchup with the new BP name
Expand Down
8 changes: 6 additions & 2 deletions sanic/mixins/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,11 @@ def decorator(handler):
error_format,
route_context,
)

overwrite = getattr(self, "_allow_route_overwrite", False)
if overwrite:
self._future_routes = set(
filter(lambda x: x.uri != uri, self._future_routes)
)
self._future_routes.add(route)

args = list(signature(handler).parameters.keys())
Expand All @@ -182,7 +186,7 @@ def decorator(handler):
handler.is_stream = stream

if apply:
self._apply_route(route)
self._apply_route(route, overwrite=overwrite)

if static:
return route, handler
Expand Down
2 changes: 2 additions & 0 deletions sanic/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def add( # type: ignore
unquote: bool = False,
static: bool = False,
version_prefix: str = "/v",
overwrite: bool = False,
error_format: Optional[str] = None,
) -> Union[Route, List[Route]]:
"""
Expand Down Expand Up @@ -122,6 +123,7 @@ def add( # type: ignore
name=name,
strict=strict_slashes,
unquote=unquote,
overwrite=overwrite,
)

if isinstance(host, str):
Expand Down
79 changes: 78 additions & 1 deletion tests/test_blueprint_copy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from sanic import Blueprint, Sanic
import pytest

from sanic_routing.exceptions import RouteExists

from sanic import Blueprint, Request, Sanic
from sanic.response import text


Expand Down Expand Up @@ -74,3 +78,76 @@ def handle_request(request):
assert "test_bp_copy.test_bp4.handle_request" in route_names
assert "test_bp_copy.test_bp5.handle_request" in route_names
assert "test_bp_copy.test_bp6.handle_request" in route_names


def test_bp_copy_without_route_overwriting(app: Sanic):
bpv1 = Blueprint("bp_v1", version=1, url_prefix="my_api")

@bpv1.route("/")
async def handler(request: Request):
return text("v1")

app.blueprint(bpv1)

bpv2 = bpv1.copy("bp_v2", version=2, allow_route_overwrite=False)
bpv3 = bpv1.copy(
"bp_v3",
version=3,
allow_route_overwrite=False,
with_registration=False,
)

with pytest.raises(RouteExists, match="Route already registered*"):

@bpv2.route("/")
async def handler(request: Request):
return text("v2")

app.blueprint(bpv2)

with pytest.raises(RouteExists, match="Route already registered*"):

@bpv3.route("/")
async def handler(request: Request):
return text("v3")

app.blueprint(bpv3)


def test_bp_copy_with_route_overwriting(app: Sanic):
bpv1 = Blueprint("bp_v1", version=1, url_prefix="my_api")

@bpv1.route("/")
async def handler(request: Request):
return text("v1")

app.blueprint(bpv1)

bpv2 = bpv1.copy("bp_v2", version=2, allow_route_overwrite=True)
bpv3 = bpv1.copy(
"bp_v3", version=3, allow_route_overwrite=True, with_registration=False
)

@bpv2.route("/")
async def handler(request: Request):
return text("v2")

app.blueprint(bpv2)

@bpv3.route("/")
async def handler(request: Request):
return text("v3")

app.blueprint(bpv3)

_, response = app.test_client.get("/v1/my_api")
assert response.status == 200
assert response.text == "v1"

_, response = app.test_client.get("/v2/my_api")
assert response.status == 200
assert response.text == "v2"

_, response = app.test_client.get("/v3/my_api")
assert response.status == 200
assert response.text == "v3"

0 comments on commit e374409

Please sign in to comment.