Skip to content

Commit

Permalink
websocket: Improve subprotocol support
Browse files Browse the repository at this point in the history
- Add client-side subprotocol option
- Add selected_subprotocol attribute to client and server objects
- Call select_subprotocol exactly once instead of only on non-empty
- Fix bug in previous select_subprotocol change when multiple
  subprotocols are offered
- Add tests

Updates #2281
  • Loading branch information
bdarnell committed May 12, 2018
1 parent 8afac1f commit fac04e0
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 17 deletions.
9 changes: 6 additions & 3 deletions docs/releases/v5.1.0.rst
Expand Up @@ -135,9 +135,12 @@ Deprecation notice
`tornado.websocket`
~~~~~~~~~~~~~~~~~~~

- The `.WebSocketHandler.select_subprotocol` method is now called only
when a subprotocol header is provided (previously it would be called
with a list containing an empty string).
- `.websocket_connect` now supports subprotocols.
- `.WebSocketHandler` and `.WebSocketClientConnection` now have
``selected_subprotocol`` attributes to see the subprotocol in use.
- The `.WebSocketHandler.select_subprotocol` method is now called with
an empty list instead of a list containing an empty string if no
subprotocols were requested by the client.
- The ``data`` argument to `.WebSocketHandler.ping` is now optional.
- Client-side websocket connections no longer buffer more than one
message in memory at a time.
Expand Down
1 change: 1 addition & 0 deletions docs/websocket.rst
Expand Up @@ -16,6 +16,7 @@
.. automethod:: WebSocketHandler.on_message
.. automethod:: WebSocketHandler.on_close
.. automethod:: WebSocketHandler.select_subprotocol
.. autoattribute:: WebSocketHandler.selected_subprotocol
.. automethod:: WebSocketHandler.on_ping

Output
Expand Down
37 changes: 37 additions & 0 deletions tornado/test/websocket_test.py
Expand Up @@ -143,6 +143,25 @@ def on_message(self, message):
self.write_message(self.render_string('message.html', message=message))


class SubprotocolHandler(TestWebSocketHandler):
def initialize(self, **kwargs):
super(SubprotocolHandler, self).initialize(**kwargs)
self.select_subprotocol_called = False

def select_subprotocol(self, subprotocols):
if self.select_subprotocol_called:
raise Exception("select_subprotocol called twice")
self.select_subprotocol_called = True
if 'goodproto' in subprotocols:
return 'goodproto'
return None

def open(self):
if not self.select_subprotocol_called:
raise Exception("select_subprotocol not called")
self.write_message("subprotocol=%s" % self.selected_subprotocol)


class WebSocketBaseTestCase(AsyncHTTPTestCase):
@gen.coroutine
def ws_connect(self, path, **kwargs):
Expand Down Expand Up @@ -183,6 +202,8 @@ def get_app(self):
dict(close_future=self.close_future)),
('/render', RenderMessageHandler,
dict(close_future=self.close_future)),
('/subprotocol', SubprotocolHandler,
dict(close_future=self.close_future)),
], template_loader=DictLoader({
'message.html': '<b>{{ message }}</b>',
}))
Expand Down Expand Up @@ -443,6 +464,22 @@ def test_check_origin_invalid_subdomains(self):

self.assertEqual(cm.exception.code, 403)

@gen_test
def test_subprotocols(self):
ws = yield self.ws_connect('/subprotocol', subprotocols=['badproto', 'goodproto'])
self.assertEqual(ws.selected_subprotocol, 'goodproto')
res = yield ws.read_message()
self.assertEqual(res, 'subprotocol=goodproto')
yield self.close(ws)

@gen_test
def test_subprotocols_not_offered(self):
ws = yield self.ws_connect('/subprotocol')
self.assertIs(ws.selected_subprotocol, None)
res = yield ws.read_message()
self.assertEqual(res, 'subprotocol=None')
yield self.close(ws)


if sys.version_info >= (3, 5):
NativeCoroutineOnMessageHandler = exec_test(globals(), locals(), """
Expand Down
72 changes: 58 additions & 14 deletions tornado/websocket.py
Expand Up @@ -256,18 +256,38 @@ def write_message(self, message, binary=False):
return self.ws_connection.write_message(message, binary=binary)

def select_subprotocol(self, subprotocols):
"""Invoked when a new WebSocket requests specific subprotocols.
"""Override to implement subprotocol negotiation.
``subprotocols`` is a list of strings identifying the
subprotocols proposed by the client. This method may be
overridden to return one of those strings to select it, or
``None`` to not select a subprotocol. Failure to select a
subprotocol does not automatically abort the connection,
although clients may close the connection if none of their
proposed subprotocols was selected.
``None`` to not select a subprotocol.
Failure to select a subprotocol does not automatically abort
the connection, although clients may close the connection if
none of their proposed subprotocols was selected.
The list may be empty, in which case this method must return
None. This method is always called exactly once even if no
subprotocols were proposed so that the handler can be advised
of this fact.
.. versionchanged:: 5.1
Previously, this method was called with a list containing
an empty string instead of an empty list if no subprotocols
were proposed by the client.
"""
return None

@property
def selected_subprotocol(self):
"""The subprotocol returned by `select_subprotocol`.
.. versionadded:: 5.1
"""
return self.ws_connection.selected_subprotocol

def get_compression_options(self):
"""Override to return compression options for the connection.
Expand Down Expand Up @@ -675,12 +695,15 @@ def _challenge_response(self):
self.request.headers.get("Sec-Websocket-Key"))

def _accept_connection(self):
subprotocols = [s.strip() for s in self.request.headers.get_list("Sec-WebSocket-Protocol")]
if subprotocols:
selected = self.handler.select_subprotocol(subprotocols)
if selected:
assert selected in subprotocols
self.handler.set_header("Sec-WebSocket-Protocol", selected)
subprotocol_header = self.request.headers.get("Sec-WebSocket-Protocol")
if subprotocol_header:
subprotocols = [s.strip() for s in subprotocol_header.split(',')]
else:
subprotocols = []
self.selected_subprotocol = self.handler.select_subprotocol(subprotocols)
if self.selected_subprotocol:
assert self.selected_subprotocol in subprotocols
self.handler.set_header("Sec-WebSocket-Protocol", self.selected_subprotocol)

extensions = self._parse_extensions_header(self.request.headers)
for ext in extensions:
Expand Down Expand Up @@ -739,6 +762,8 @@ def _process_server_headers(self, key, headers):
else:
raise ValueError("unsupported extension %r", ext)

self.selected_subprotocol = headers.get('Sec-WebSocket-Protocol', None)

def _get_compressor_options(self, side, agreed_parameters, compression_options=None):
"""Converts a websocket agreed_parameters set to keyword arguments
for our compressor objects.
Expand Down Expand Up @@ -1056,7 +1081,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
"""
def __init__(self, request, on_message_callback=None,
compression_options=None, ping_interval=None, ping_timeout=None,
max_message_size=None):
max_message_size=None, subprotocols=[]):
self.compression_options = compression_options
self.connect_future = Future()
self.protocol = None
Expand All @@ -1077,6 +1102,8 @@ def __init__(self, request, on_message_callback=None,
'Sec-WebSocket-Key': self.key,
'Sec-WebSocket-Version': '13',
})
if subprotocols is not None:
request.headers['Sec-WebSocket-Protocol'] = ','.join(subprotocols)
if self.compression_options is not None:
# Always offer to let the server set our max_wbits (and even though
# we don't offer it, we will accept a client_no_context_takeover
Expand Down Expand Up @@ -1211,11 +1238,19 @@ def get_websocket_protocol(self):
return WebSocketProtocol13(self, mask_outgoing=True,
compression_options=self.compression_options)

@property
def selected_subprotocol(self):
"""The subprotocol selected by the server.
.. versionadded:: 5.1
"""
return self.protocol.selected_subprotocol


def websocket_connect(url, callback=None, connect_timeout=None,
on_message_callback=None, compression_options=None,
ping_interval=None, ping_timeout=None,
max_message_size=None):
max_message_size=None, subprotocols=None):
"""Client-side websocket support.
Takes a url and returns a Future whose result is a
Expand All @@ -1238,6 +1273,11 @@ def websocket_connect(url, callback=None, connect_timeout=None,
``websocket_connect``. In both styles, a message of ``None``
indicates that the connection has been closed.
``subprotocols`` may be a list of strings specifying proposed
subprotocols. The selected protocol may be found on the
``selected_subprotocol`` attribute of the connection object
when the connection is complete.
.. versionchanged:: 3.2
Also accepts ``HTTPRequest`` objects in place of urls.
Expand All @@ -1250,6 +1290,9 @@ def websocket_connect(url, callback=None, connect_timeout=None,
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
.. versionchanged:: 5.1
Added the ``subprotocols`` argument.
"""
if isinstance(url, httpclient.HTTPRequest):
assert connect_timeout is None
Expand All @@ -1266,7 +1309,8 @@ def websocket_connect(url, callback=None, connect_timeout=None,
compression_options=compression_options,
ping_interval=ping_interval,
ping_timeout=ping_timeout,
max_message_size=max_message_size)
max_message_size=max_message_size,
subprotocols=subprotocols)
if callback is not None:
IOLoop.current().add_future(conn.connect_future, callback)
return conn.connect_future

0 comments on commit fac04e0

Please sign in to comment.