Skip to content

Commit

Permalink
Add support for WebSocket rules in the routing
Browse files Browse the repository at this point in the history
This allows for Rules to be marked as a WebSocket route and only
matched if the binding is websocket. It also ensures that when a
websocket rule is built with a scheme it defaults to the `ws` or `wss`
scheme.
  • Loading branch information
pgjones committed Feb 3, 2020
1 parent 49cf35b commit e932a1f
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ Unreleased
quality tags. Instead the initial order is preserved. :issue:`1686`
- Added ``Map.lock_class`` attribute for alternative
implementations. :pr:`1702`
- Support WebSocket rules (binding to WebSocket requests) in the
routing systems. :pr:`1709`


Version 0.16.1
Expand Down
29 changes: 29 additions & 0 deletions docs/routing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,32 @@ Variable parts are of course also possible in the host section::
Rule('/', endpoint='www_index', host='www.example.com'),
Rule('/', endpoint='user_index', host='<user>.example.com')
], host_matching=True)


WebSockets
==========

.. versionadded:: 1.0

With Werkzeug 1.0 onwards it is possible to mark a Rule as a websocket
and only match it if the MapAdapter is created with a websocket
bind. This functionality can be used as so::

url_map = Map([
Rule("/", endpoint="index", websocket=True),
])
adapter = map.bind("example.org", "/", url_scheme="ws")
assert adapter.match("/") == ("index", {})

If the only match is a WebSocket rule and the bind is http (or the
only match is http and the bind is websocket) a
:class:`WebsocketMismatch` (derives from :class:`BadRequest`)
exception is raised.

As WebSocket urls have a different scheme, WebSocket Rules are always
built with a scheme and host i.e. as if ``force_external = True``.

.. note::

Werkzeug has no further WebSocket support (beyond routing). This
functionality is mostly of use to ASGI projects.
83 changes: 75 additions & 8 deletions src/werkzeug/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@
from .datastructures import ImmutableDict
from .datastructures import MultiDict
from .exceptions import BadHost
from .exceptions import BadRequest
from .exceptions import HTTPException
from .exceptions import MethodNotAllowed
from .exceptions import NotFound
Expand Down Expand Up @@ -329,7 +330,15 @@ def __str__(self):
return u"".join(message)


class WebsocketMismatch(BadRequest):
"""The only matched rule is either a websocket and the request is http
or the rule is http and the request is a websocket."""

pass


class ValidationError(ValueError):

"""Validation error. If a rule converter raises this exception the rule
does not match the current URL and the next URL is tried.
"""
Expand Down Expand Up @@ -631,8 +640,15 @@ def foo_with_slug(adapter, id):
used to provide a match rule for the whole host. This also means
that the subdomain feature is disabled.
`websocket`
If True (defaults to False) this represents a WebSocket, rather than
a http route.
.. versionadded:: 0.7
The `alias` and `host` parameters were added.
.. versionadded:: 1.0
The `websocket` parameter was added.
"""

def __init__(
Expand All @@ -648,6 +664,7 @@ def __init__(
redirect_to=None,
alias=False,
host=None,
websocket=False,
):
if not string.startswith("/"):
raise ValueError("urls must start with a leading slash")
Expand All @@ -662,14 +679,27 @@ def __init__(
self.defaults = defaults
self.build_only = build_only
self.alias = alias
self.websocket = websocket
if methods is not None:
if isinstance(methods, str):
raise TypeError("param `methods` should be `Iterable[str]`, not `str`")
methods = set([x.upper() for x in methods])
if "HEAD" not in methods and "GET" in methods:
methods.add("HEAD")

if (
websocket
and methods is not None
and len(methods - {"GET", "HEAD", "OPTIONS"}) > 0
):
raise ValueError(
"WebSocket Rules can only use 'GET', 'HEAD', or 'OPTIONS' methods"
)

if methods is None:
self.methods = None
else:
if isinstance(methods, str):
raise TypeError("param `methods` should be `Iterable[str]`, not `str`")
self.methods = set([x.upper() for x in methods])
if "HEAD" not in self.methods and "GET" in self.methods:
self.methods.add("HEAD")
self.methods = methods
self.endpoint = endpoint
self.redirect_to = redirect_to

Expand Down Expand Up @@ -1490,8 +1520,12 @@ def bind(
.. versionadded:: 0.8
`query_args` can now also be a string.
.. versionadded:: 1.0
`websocket` added
.. versionchanged:: 0.15
``path_info`` defaults to ``'/'`` if ``None``.
"""
server_name = server_name.lower()
if self.host_matching:
Expand Down Expand Up @@ -1663,6 +1697,7 @@ def __init__(
self.path_info = to_unicode(path_info)
self.default_method = to_unicode(default_method)
self.query_args = query_args
self.websocket = self.url_scheme in {"ws", "wss"}

def dispatch(
self, view_func, path_info=None, method=None, catch_http_exceptions=False
Expand Down Expand Up @@ -1720,7 +1755,14 @@ def application(environ, start_response):
return e
raise

def match(self, path_info=None, method=None, return_rule=False, query_args=None):
def match(
self,
path_info=None,
method=None,
return_rule=False,
query_args=None,
websocket=None,
):
"""The usage is simple: you just pass the match method the current
path info as well as the method (which defaults to `GET`). The
following things can then happen:
Expand All @@ -1741,6 +1783,10 @@ def match(self, path_info=None, method=None, return_rule=False, query_args=None)
You can use the `RequestRedirect` instance as response-like object
similar to all other subclasses of `HTTPException`.
- you receive a ``WebsocketMismatch`` exception if the only match is
a websocket rule and the bind is to a http request, or if the match
is a http rule and the bind is to a websocket request.
- you get a tuple in the form ``(endpoint, arguments)`` if there is
a match (unless `return_rule` is True, in which case you get a tuple
in the form ``(rule, arguments)``)
Expand Down Expand Up @@ -1805,6 +1851,8 @@ def match(self, path_info=None, method=None, return_rule=False, query_args=None)
if query_args is None:
query_args = self.query_args
method = (method or self.default_method).upper()
if websocket is None:
websocket = self.websocket

require_redirect = False

Expand All @@ -1814,6 +1862,7 @@ def match(self, path_info=None, method=None, return_rule=False, query_args=None)
)

have_match_for = set()
websocket_mismatch = False
for rule in self.map._rules:
try:
rv = rule.match(path, method)
Expand All @@ -1835,6 +1884,9 @@ def match(self, path_info=None, method=None, return_rule=False, query_args=None)
if rule.methods is not None and method not in rule.methods:
have_match_for.update(rule.methods)
continue
if rule.websocket != websocket:
websocket_mismatch = True
continue

if self.map.redirect_defaults:
redirect_url = self.get_default_redirect(rule, method, rv, query_args)
Expand Down Expand Up @@ -1880,6 +1932,8 @@ def _handle_match(match):

if have_match_for:
raise MethodNotAllowed(valid_methods=list(have_match_for))
if websocket_mismatch:
raise WebsocketMismatch()
raise NotFound()

def test(self, path_info=None, method=None):
Expand Down Expand Up @@ -2005,6 +2059,7 @@ def _partial_build(self, endpoint, values, method, append_unknown):
rv = rule.build(values, append_unknown)

if rv is not None:
rv = (rv[0], rv[1], rule.websocket)
if self.map.host_matching:
if rv[0] == self.server_name:
return rv
Expand Down Expand Up @@ -2114,10 +2169,22 @@ def build(
rv = self._partial_build(endpoint, values, method, append_unknown)
if rv is None:
raise BuildError(endpoint, values, method, self)
domain_part, path = rv
domain_part, path, websocket = rv

host = self.get_host(domain_part)

# Only build WebSocket routes with the scheme (as relative
# WebSocket paths aren't useful and are misleading). In
# addition if bound to a WebSocket ensure that http routes are
# built with a http scheme (if required).
url_scheme = self.url_scheme
secure = url_scheme in {"https", "wss"}
if websocket:
force_external = True
url_scheme = "wss" if secure else "ws"
elif url_scheme:
url_scheme = "https" if secure else "http"

# shortcut this.
if not force_external and (
(self.map.host_matching and host == self.server_name)
Expand All @@ -2127,7 +2194,7 @@ def build(
return str(
"%s//%s%s/%s"
% (
self.url_scheme + ":" if self.url_scheme else "",
url_scheme + ":" if url_scheme else "",
host,
self.script_name[:-1],
path.lstrip("/"),
Expand Down
31 changes: 31 additions & 0 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def test_basic_routing():
r.Rule("/", endpoint="index"),
r.Rule("/foo", endpoint="foo"),
r.Rule("/bar/", endpoint="bar"),
r.Rule("/ws", endpoint="ws", websocket=True),
r.Rule("/", endpoint="indexws", websocket=True),
]
)
adapter = map.bind("example.org", "/")
Expand All @@ -36,6 +38,9 @@ def test_basic_routing():
pytest.raises(r.RequestRedirect, lambda: adapter.match("/bar"))
pytest.raises(r.NotFound, lambda: adapter.match("/blub"))

adapter = map.bind("example.org", "/", url_scheme="ws")
assert adapter.match("/") == ("indexws", {})

adapter = map.bind("example.org", "/test")
with pytest.raises(r.RequestRedirect) as excinfo:
adapter.match("/bar")
Expand All @@ -61,6 +66,13 @@ def test_basic_routing():
adapter.match()
assert excinfo.value.new_url == "http://example.org/bar/?foo=bar"

adapter = map.bind("example.org", "/ws", url_scheme="wss")
assert adapter.match("/ws", websocket=True) == ("ws", {})
with pytest.raises(r.WebsocketMismatch):
adapter.match("/ws", websocket=False)
with pytest.raises(r.WebsocketMismatch):
adapter.match("/foo", websocket=True)


def test_merge_slashes_match():
url_map = r.Map(
Expand Down Expand Up @@ -192,6 +204,7 @@ def test_basic_building():
r.Rule("/bar/<float:bazf>", endpoint="barf"),
r.Rule("/bar/<path:bazp>", endpoint="barp"),
r.Rule("/hehe", endpoint="blah", subdomain="blah"),
r.Rule("/ws", endpoint="ws", websocket=True),
]
)
adapter = map.bind("example.org", "/", subdomain="blah")
Expand Down Expand Up @@ -223,6 +236,11 @@ def test_basic_building():
assert adapter.build("foo", {}) == "/foo"
assert adapter.build("foo", {}, force_external=True) == "//example.org/foo"

adapter = map.bind("example.org", url_scheme="ws")
assert adapter.build("ws", {}) == "ws://example.org/ws"
assert adapter.build("foo", {}, force_external=True) == "http://example.org/foo"
assert adapter.build("foo", {}) == "/foo"


def test_long_build():
long_args = dict(("v%d" % x, x) for x in range(10000))
Expand Down Expand Up @@ -1205,3 +1223,16 @@ def test_build_url_same_endpoint_multiple_hosts():

beta_case = m.bind("BeTa.ExAmPlE.CoM")
assert beta_case.build("index") == "/"


def test_rule_websocket_methods():
with pytest.raises(ValueError):
r.Rule("/ws", endpoint="ws", websocket=True, methods=["post"])
with pytest.raises(ValueError):
r.Rule(
"/ws",
endpoint="ws",
websocket=True,
methods=["get", "head", "options", "post"],
)
r.Rule("/ws", endpoint="ws", websocket=True, methods=["get", "head", "options"])

0 comments on commit e932a1f

Please sign in to comment.