Skip to content

Commit

Permalink
Provide a better way to override request handling.
Browse files Browse the repository at this point in the history
This replaces the get_response_status() API which never made it into a
release (so there's no backwards incompatibility).

Remove a test that depends on get_response_status() being called after
check_request(). The extension point must be before check_request() so
it can handle regular HTTP requests.

Fix #116.

Supersedes #202 #154, #137.
  • Loading branch information
aaugustin committed Aug 20, 2017
1 parent 3242b20 commit 50fd62e
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 65 deletions.
2 changes: 1 addition & 1 deletion docs/api.rst
Expand Up @@ -42,8 +42,8 @@ Server
.. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, subprotocols=None, extra_headers=None)

.. automethod:: handshake(origins=None, subprotocols=None, extra_headers=None)
.. automethod:: process_request(path, request_headers)
.. automethod:: select_subprotocol(client_protos, server_protos)
.. automethod:: get_response_status()

Client
......
Expand Down
4 changes: 2 additions & 2 deletions docs/changelog.rst
Expand Up @@ -14,8 +14,8 @@ Changelog
* :func:`~websockets.server.serve` can be used as an asynchronous context
manager on Python ≥ 3.5.

* Added support for rejecting incoming connections by customizing
:meth:`~websockets.server.WebSocketServerProtocol.get_response_status()`.
* Added support for customizing handling of incoming connections with
:meth:`~websockets.server.WebSocketServerProtocol.process_request()`.

* Made read and write buffer sizes configurable.

Expand Down
61 changes: 33 additions & 28 deletions websockets/server.py
Expand Up @@ -81,8 +81,8 @@ def handler(self):
self.writer.write(response.encode())
raise

# Subclasses can customize get_response_status() or handshake() to
# reject the handshake, typically after checking authentication.
# Subclasses can customize process_request() to reject the
# handshake, typically after checking authentication.
if path is None:
return

Expand Down Expand Up @@ -211,13 +211,22 @@ def select_subprotocol(client_protos, server_protos):
return sorted(common_protos, key=priority)[0]

@asyncio.coroutine
def get_response_status(self, set_header):
def process_request(self, path, request_headers):
"""
Return a :class:`~http.HTTPStatus` for the HTTP response.
Intercept the HTTP request and return a HTTP response if needed.
(:class:`~http.HTTPStatus` was added in Python 3.5. On earlier
versions, a compatible object must be returned. Check the definition
of ``SWITCHING_PROTOCOLS`` for an example.)
``request_headers`` are a :class:`~http.client.HTTPMessage`.
If this coroutine returns ``None``, the WebSocket handshake continues.
If it returns a HTTP status code and HTTP headers, that HTTP response
is sent and the connection is closed immediately.
The HTTP status must be a :class:`~http.HTTPStatus` and HTTP headers
must be an iterable of ``(name, value)`` pairs.
(:class:`~http.HTTPStatus` was added in Python 3.5. Use a compatible
object on earlier versions. Look at ``SWITCHING_PROTOCOLS`` in
``websockets.compatibility`` for an example.)
This method may be overridden to check the request headers and set a
different status, for example to authenticate the request and return
Expand All @@ -226,13 +235,7 @@ def get_response_status(self, set_header):
It is declared as a coroutine because such authentication checks are
likely to require network requests.
The connection is closed immediately after sending the response when
the status is not ``HTTPStatus.SWITCHING_PROTOCOLS``.
Call ``set_header(key, value)`` to set additional response headers.
"""
return SWITCHING_PROTOCOLS

@asyncio.coroutine
def handshake(self, origins=None, subprotocols=None, extra_headers=None):
Expand All @@ -255,29 +258,30 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None):
Return the URI of the request.
"""
path, headers = yield from self.read_http_request()
get_header = lambda k: headers.get(k, '')
path, request_headers = yield from self.read_http_request()

# Hook for customizing request handling, for example checking
# authentication or treating some paths as plain HTTP endpoints.

early_response = yield from self.process_request(path, request_headers)
if early_response is not None:
yield from self.write_http_response(*early_response)
self.opening_handshake.set_result(False)
yield from self.close_connection(force=True)
return

get_header = lambda k: request_headers.get(k, '')

key = check_request(get_header)

self.origin = self.process_origin(get_header, origins)
self.subprotocol = self.process_subprotocol(get_header, subprotocols)

headers = []
set_header = lambda k, v: headers.append((k, v))
response_headers = []
set_header = lambda k, v: response_headers.append((k, v))

set_header('Server', USER_AGENT)

status = yield from self.get_response_status(set_header)

# Abort the connection if the status code isn't 101.
if status.value != SWITCHING_PROTOCOLS.value:
yield from self.write_http_response(status, headers)
self.opening_handshake.set_result(False)
yield from self.close_connection(force=True)
return

# Status code is 101, establish the connection.
if self.subprotocol:
set_header('Sec-WebSocket-Protocol', self.subprotocol)
if extra_headers is not None:
Expand All @@ -289,7 +293,8 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None):
set_header(name, value)
build_response(set_header, key)

yield from self.write_http_response(status, headers)
yield from self.write_http_response(
SWITCHING_PROTOCOLS, response_headers)

assert self.state == CONNECTING
self.state = OPEN
Expand Down
38 changes: 4 additions & 34 deletions websockets/test_client_server.py
@@ -1,7 +1,5 @@
import asyncio
import functools
import http
import http.client
import logging
import os
import ssl
Expand Down Expand Up @@ -90,15 +88,15 @@ def with_client(*args, **kwds):
class UnauthorizedServerProtocol(WebSocketServerProtocol):

@asyncio.coroutine
def get_response_status(self, set_header):
return UNAUTHORIZED
def process_request(self, path, request_headers):
return UNAUTHORIZED, []


class ForbiddenServerProtocol(WebSocketServerProtocol):

@asyncio.coroutine
def get_response_status(self, set_header):
return FORBIDDEN
def process_request(self, path, request_headers):
return FORBIDDEN, []


class FooClientProtocol(WebSocketClientProtocol):
Expand Down Expand Up @@ -260,34 +258,6 @@ def test_protocol_custom_response_headers_list(self):
resp_headers = self.loop.run_until_complete(self.client.recv())
self.assertIn("('X-Spam', 'Eggs')", resp_headers)

def test_get_response_status_attributes_available(self):
# Save the attribute values to a dict instead of asserting inside
# get_response_status() because assertion errors there do not
# currently bubble up for easy viewing.
attrs = {}

class SaveAttributesProtocol(WebSocketServerProtocol):
@asyncio.coroutine
def get_response_status(self, set_header):
attrs['origin'] = self.origin
attrs['path'] = self.path
attrs['raw_request_headers'] = self.raw_request_headers.copy()
attrs['request_headers'] = self.request_headers
status = yield from super().get_response_status(set_header)
return status

with self.temp_server(create_protocol=SaveAttributesProtocol):
self.start_client(path='foo/bar', origin='http://otherhost')
self.assertEqual(attrs['origin'], 'http://otherhost')
self.assertEqual(attrs['path'], '/foo/bar')
# To reduce test brittleness, only check one nontrivial aspect
# of the request headers.
self.assertIn(('Origin', 'http://otherhost'),
attrs['raw_request_headers'])
request_headers = attrs['request_headers']
self.assertIsInstance(request_headers, http.client.HTTPMessage)
self.assertEqual(request_headers.get('origin'), 'http://otherhost')

def assert_client_raises_code(self, status_code):
with self.assertRaises(InvalidStatusCode) as raised:
self.start_client()
Expand Down

0 comments on commit 50fd62e

Please sign in to comment.