diff --git a/docs/changelog.md b/docs/changelog.md
index f8751521..b16781d0 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -10,13 +10,22 @@ A changelog:
## dev version
-* Changed `Roll.hook` signature to also accept kwargs ([#5](https://github.com/pyrates/roll/pull/5))
-* `json` shorcut sets `utf-8` charset in `Content-Type` header ([#13](https://github.com/pyrates/roll/pull/13))
-* Added `static` extension to serve static files for development ([#16](https://github.com/pyrates/roll/pull/16))
-* `cors` accepts `headers` parameter to control `Access-Control-Allow-Headers` ([#12](https://github.com/pyrates/roll/pull/12))
+* Changed `Roll.hook` signature to also accept kwargs
+ ([#5](https://github.com/pyrates/roll/pull/5))
+* `json` shorcut sets `utf-8` charset in `Content-Type` header
+ ([#13](https://github.com/pyrates/roll/pull/13))
+* Added `static` extension to serve static files for development
+ ([#16](https://github.com/pyrates/roll/pull/16))
+* `cors` accepts `headers` parameter to control `Access-Control-Allow-Headers`
+ ([#12](https://github.com/pyrates/roll/pull/12))
+* Added `content_negociation` extension to reject unacceptable client requests
+ based on the `Accept` header
+ ([#21](https://github.com/pyrates/roll/pull/21))
* **Breaking changes**:
- * `options` extension is no more applied by default ([#16](https://github.com/pyrates/roll/pull/16))
- * deprecated `req` pytest fixture is now removed ([#9](https://github.com/pyrates/roll/pull/9))
+ * `options` extension is no more applied by default
+ ([#16](https://github.com/pyrates/roll/pull/16))
+ * deprecated `req` pytest fixture is now removed
+ ([#9](https://github.com/pyrates/roll/pull/9))
## 0.5.0 — 2017-09-21
diff --git a/docs/how-to-guides.md b/docs/how-to-guides.md
index 35994c3e..9fd5613b 100644
--- a/docs/how-to-guides.md
+++ b/docs/how-to-guides.md
@@ -46,6 +46,27 @@ always a bonus. The `response` object is modified in place.
Make sure to check these out!*
+## How to deal with content negociation
+
+The [`content_negociation` extension](reference.md#content_negociation)
+is made for this purpose, you can use it that way:
+
+```python
+extensions.content_negociation(app)
+
+@app.route('/test', accepts=['text/html', 'application/json'])
+async def get(req, resp):
+ if req.headers['Accept'] == 'text/html':
+ resp.headers['Content-Type'] = 'text/html'
+ resp.body = '
accepted
'
+ elif req.headers['Accept'] == 'application/json':
+ resp.json = {'status': 'accepted'}
+```
+
+Requests with `Accept` header not matching `text/html` or
+`application/json` will be honored with a `406 Not Acceptable` response.
+
+
## How to return an HTTP error
There are many reasons to return an HTTP error, with Roll you have to
diff --git a/docs/index.md b/docs/index.md
index ddc1619a..a3f64c7e 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -56,6 +56,7 @@ A how-to guide:
* [How to install Roll](how-to-guides.md#how-to-install-roll)
* [How to create an extension](how-to-guides.md#how-to-create-an-extension)
+* [How to deal with content negociation](how-to-guides.md#how-to-deal-with-content-negociation)
* [How to return an HTTP error](how-to-guides.md#how-to-return-an-http-error)
* [How to return JSON content](how-to-guides.md#how-to-return-json-content)
* [How to subclass Roll itself](how-to-guides.md#how-to-subclass-roll-itself)
diff --git a/docs/reference.md b/docs/reference.md
index ec823a0f..6288b469 100644
--- a/docs/reference.md
+++ b/docs/reference.md
@@ -26,8 +26,12 @@ The `status` can be either a `http.HTTPStatus` instance or an integer.
### Request
-A container for the result of the parsing on each request made by
-`httptools.HttpRequestParser`.
+A container for the result of the parsing on each request.
+The default parsing is made by `httptools.HttpRequestParser`.
+
+You can use the empty `kwargs` dict to attach whatever you want,
+especially useful for extensions.
+
#### Properties
@@ -38,6 +42,7 @@ A container for the result of the parsing on each request made by
- **method** (`str`): HTTP verb
- **body** (`bytes`): raw body as received by Roll
- **headers** (`dict`): HTTP headers
+- **route** (`Route`): a [Route instance](#Route) storing results from URL matching
- **kwargs** (`dict`): store here any extra data needed in the Request lifetime
@@ -105,6 +110,15 @@ parser. Default routes use [autoroutes](https://github.com/pyrates/autoroutes),
please refers to that documentation for available patterns.
+### Route
+
+A namedtuple to collect matched route data with attributes:
+
+* **payload** (`dict`): the data received by the `@app.route` decorator,
+ contains all handlers plus optionnal custom data.
+* **vars** (`dict`): URL placeholders resolved for the current route.
+
+
## Extensions
Please read
@@ -152,10 +166,24 @@ Combine it with the `cors` extension to handle the preflight request.
- **app**: Roll app to register the extension against
+### content_negociation
+
+Deal with content negociation declared during routes definition.
+Will return a `406 Not Acceptable` response in case of mismatch between
+the `Accept` header from the client and the `accepts` parameter set in
+routes. Useful to reject requests which are not expecting the available
+response.
+
#### Parameters
- **app**: Roll app to register the extension against
+
+#### Requirements
+
+- mimetype-match>=1.0.4
+
+
### traceback
Print the traceback on the server side if any. Handy for debugging.
@@ -164,6 +192,7 @@ Print the traceback on the server side if any. Handy for debugging.
- **app**: Roll app to register the extension against
+
### igniter
Display a BIG message when running the server.
@@ -173,6 +202,7 @@ Quite useless, hence so essential!
- **app**: Roll app to register the extension against
+
### static
Serve static files. Should not be used in production.
diff --git a/requirements-dev.txt b/requirements-dev.txt
index be628831..a882ad69 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,3 +1,4 @@
+mimetype-match==1.0.4
mkdocs==0.16.3
pytest==3.1.2
pytest-asyncio==0.6.0
diff --git a/roll/__init__.py b/roll/__init__.py
index 415141aa..28eeed11 100644
--- a/roll/__init__.py
+++ b/roll/__init__.py
@@ -9,6 +9,7 @@
a test failing): https://github.com/pyrates/roll/issues/new
"""
import asyncio
+from collections import namedtuple
from http import HTTPStatus
from typing import TypeVar
from urllib.parse import parse_qs, unquote
@@ -99,10 +100,10 @@ def float(self, key: str, default=...):
class Request:
"""A container for the result of the parsing on each request.
- The parsing is made by `httptools.HttpRequestParser`.
+ The default parsing is made by `httptools.HttpRequestParser`.
"""
- __slots__ = ('url', 'path', 'query_string', 'query', 'method', 'kwargs',
- 'body', 'headers')
+ __slots__ = ('url', 'path', 'query_string', 'query', 'method', 'body',
+ 'headers', 'route', 'kwargs')
def __init__(self):
self.kwargs = {}
@@ -199,7 +200,7 @@ def write(self, *args):
payload = b'HTTP/1.1 %a %b\r\n' % (
self.response.status.value, self.response.status.phrase.encode())
if not isinstance(self.response.body, bytes):
- self.response.body = self.response.body.encode()
+ self.response.body = str(self.response.body).encode()
if 'Content-Length' not in self.response.headers:
length = len(self.response.body)
self.response.headers['Content-Length'] = length
@@ -221,6 +222,9 @@ def match(self, url: str):
return payload, params
+Route = namedtuple('Route', ['payload', 'vars'])
+
+
class Roll:
"""Deal with routes dispatching and events listening.
@@ -241,9 +245,13 @@ async def shutdown(self):
async def __call__(self, request: Request, response: Response):
try:
+ request.route = Route(*self.routes.match(request.path))
if not await self.hook('request', request, response):
- params, handler = self.dispatch(request)
- await handler(request, response, **params)
+ # Uppercased in order to only consider HTTP verbs.
+ if request.method.upper() not in request.route.payload:
+ raise HttpError(HTTPStatus.METHOD_NOT_ALLOWED)
+ handler = request.route.payload[request.method]
+ await handler(request, response, **request.route.vars)
except Exception as error:
await self.on_error(request, response, error)
try:
@@ -268,23 +276,18 @@ async def on_error(self, request: Request, response: Response, error):
def factory(self):
return self.Protocol(self)
- def route(self, path: str, methods: list=None):
+ def route(self, path: str, methods: list=None, **extras: dict):
if methods is None:
methods = ['GET']
def wrapper(func):
- self.routes.add(path, **{m: func for m in methods})
+ payload = {method: func for method in methods}
+ payload.update(extras)
+ self.routes.add(path, **payload)
return func
return wrapper
- def dispatch(self, request: Request):
- handlers, params = self.routes.match(request.path)
- if request.method not in handlers:
- raise HttpError(HTTPStatus.METHOD_NOT_ALLOWED)
- request.kwargs.update(params)
- return params, handlers[request.method]
-
def listen(self, name: str):
def wrapper(func):
self.hooks.setdefault(name, [])
diff --git a/roll/extensions.py b/roll/extensions.py
index e65369ad..67c4bef9 100644
--- a/roll/extensions.py
+++ b/roll/extensions.py
@@ -1,6 +1,7 @@
import asyncio
import logging
import mimetypes
+import sys
from http import HTTPStatus
from pathlib import Path
from traceback import print_exc
@@ -48,6 +49,22 @@ async def handle_options(request, response):
return request.method == 'OPTIONS'
+def content_negociation(app):
+
+ try:
+ from mimetype_match import get_best_match
+ except ImportError:
+ sys.exit('Please install mimetype-match>=1.0.4 to be able to use the '
+ 'content_negociation extension.')
+
+ @app.listen('request')
+ async def reject_unacceptable_requests(request, response):
+ accept = request.headers.get('Accept')
+ accepts = request.route.payload['accepts']
+ if accept is None or get_best_match(accept, accepts) is None:
+ raise HttpError(HTTPStatus.NOT_ACCEPTABLE)
+
+
def traceback(app):
@app.listen('error')
@@ -108,8 +125,8 @@ async def serve(request, response, path):
content_type, encoding = mimetypes.guess_type(str(abspath))
with abspath.open('rb') as source:
response.body = source.read()
- response.headers['Content-Type'] = (content_type
- or 'application/octet-stream')
+ response.headers['Content-Type'] = (content_type or
+ 'application/octet-stream')
@app.listen('startup')
async def register_route():
diff --git a/roll/testing.py b/roll/testing.py
index 25a356ba..218e1b9d 100644
--- a/roll/testing.py
+++ b/roll/testing.py
@@ -4,6 +4,15 @@
import pytest
+class Transport:
+
+ def write(self, data):
+ ...
+
+ def close(self):
+ ...
+
+
class Client:
# Default content type for request body encoding, change it to your own
@@ -36,12 +45,15 @@ async def request(self, path, method='GET', body=b'', headers=None,
headers['Content-Type'] = content_type
body, headers = self.encode_body(body, headers)
protocol = self.app.factory()
+ protocol.connection_made(Transport())
protocol.on_message_begin()
protocol.on_url(path.encode())
protocol.request.body = body
protocol.request.method = method
protocol.request.headers = headers
- return await self.app(protocol.request, protocol.response)
+ await self.app(protocol.request, protocol.response)
+ protocol.write()
+ return protocol.response
async def get(self, path, **kwargs):
return await self.request(path, method='GET', **kwargs)
diff --git a/tests/test_errors.py b/tests/test_errors.py
index 8507be9c..f14dd279 100644
--- a/tests/test_errors.py
+++ b/tests/test_errors.py
@@ -14,7 +14,7 @@ async def get(req, resp):
resp = await client.get('/test')
assert resp.status == HTTPStatus.INTERNAL_SERVER_ERROR
- assert resp.body == 'Oops.'
+ assert resp.body == b'Oops.'
async def test_httpstatus_error(client, app):
@@ -25,7 +25,7 @@ async def get(req, resp):
resp = await client.get('/test')
assert resp.status == HTTPStatus.BAD_REQUEST
- assert resp.body == 'Really bad.'
+ assert resp.body == b'Really bad.'
async def test_error_only_with_status(client, app):
@@ -36,7 +36,7 @@ async def get(req, resp):
resp = await client.get('/test')
assert resp.status == HTTPStatus.INTERNAL_SERVER_ERROR
- assert resp.body == 'Internal Server Error'
+ assert resp.body == b'Internal Server Error'
async def test_error_only_with_httpstatus(client, app):
@@ -47,7 +47,7 @@ async def get(req, resp):
resp = await client.get('/test')
assert resp.status == HTTPStatus.INTERNAL_SERVER_ERROR
- assert resp.body == 'Internal Server Error'
+ assert resp.body == b'Internal Server Error'
async def test_error_subclasses_with_super(client, app):
@@ -63,7 +63,7 @@ async def get(req, resp):
resp = await client.get('/test')
assert resp.status == HTTPStatus.INTERNAL_SERVER_ERROR
- assert resp.body == 'Oops.
'
+ assert resp.body == b'Oops.
'
async def test_error_subclasses_without_super(client, app):
@@ -79,4 +79,4 @@ async def get(req, resp):
resp = await client.get('/test')
assert resp.status == HTTPStatus.INTERNAL_SERVER_ERROR
- assert resp.body == 'Oops.
'
+ assert resp.body == b'Oops.
'
diff --git a/tests/test_extensions.py b/tests/test_extensions.py
index d1f3c2e6..204a9a3f 100644
--- a/tests/test_extensions.py
+++ b/tests/test_extensions.py
@@ -18,7 +18,7 @@ async def get(req, resp):
resp = await client.get('/test')
assert resp.status == HTTPStatus.OK
- assert resp.body == 'test response'
+ assert resp.body == b'test response'
assert resp.headers['Access-Control-Allow-Origin'] == '*'
@@ -85,7 +85,7 @@ async def get(req, resp):
resp = await client.get('/test')
assert resp.headers['Content-Type'] == 'application/json; charset=utf-8'
- assert json.loads(resp.body) == {'key': 'value'}
+ assert json.loads(resp.body.decode()) == {'key': 'value'}
assert resp.status == HTTPStatus.OK
@@ -98,7 +98,7 @@ async def get(req, resp):
resp = await client.get('/test')
assert resp.headers['Content-Type'] == 'application/json; charset=utf-8'
- assert json.loads(resp.body) == {'key': 'value'}
+ assert json.loads(resp.body.decode()) == {'key': 'value'}
assert resp.status == HTTPStatus.BAD_REQUEST
@@ -167,3 +167,106 @@ async def test_can_change_static_prefix(client, app):
resp = await client.get('/foo/index.html')
assert resp.status == HTTPStatus.OK
assert b'Test' in resp.body
+
+
+async def test_get_accept_content_negociation(client, app):
+
+ extensions.content_negociation(app)
+
+ @app.route('/test', accepts=['text/html'])
+ async def get(req, resp):
+ resp.headers['Content-Type'] = 'text/html'
+ resp.body = 'accepted'
+
+ resp = await client.get('/test', headers={'Accept': 'text/html'})
+ assert resp.status == HTTPStatus.OK
+ assert resp.body == b'accepted'
+ assert resp.headers['Content-Type'] == 'text/html'
+
+
+async def test_get_accept_content_negociation_if_many(client, app):
+
+ extensions.content_negociation(app)
+
+ @app.route('/test', accepts=['text/html', 'application/json'])
+ async def get(req, resp):
+ if req.headers['Accept'] == 'text/html':
+ resp.headers['Content-Type'] = 'text/html'
+ resp.body = 'accepted
'
+ elif req.headers['Accept'] == 'application/json':
+ resp.json = {'status': 'accepted'}
+
+ resp = await client.get('/test', headers={'Accept': 'text/html'})
+ assert resp.status == HTTPStatus.OK
+ assert resp.body == b'accepted
'
+ assert resp.headers['Content-Type'] == 'text/html'
+ resp = await client.get('/test', headers={'Accept': 'application/json'})
+ assert resp.status == HTTPStatus.OK
+ assert json.loads(resp.body.decode()) == {'status': 'accepted'}
+ assert resp.headers['Content-Type'] == 'application/json; charset=utf-8'
+
+
+async def test_get_reject_content_negociation(client, app):
+
+ extensions.content_negociation(app)
+
+ @app.route('/test', accepts=['text/html'])
+ async def get(req, resp):
+ resp.body = 'rejected'
+
+ resp = await client.get('/test', headers={'Accept': 'text/css'})
+ assert resp.status == HTTPStatus.NOT_ACCEPTABLE
+
+
+async def test_get_reject_content_negociation_if_no_accept_header(client, app):
+
+ extensions.content_negociation(app)
+
+ @app.route('/test', accepts=['*/*'])
+ async def get(req, resp):
+ resp.body = 'rejected'
+
+ resp = await client.get('/test')
+ assert resp.status == HTTPStatus.NOT_ACCEPTABLE
+
+
+async def test_get_accept_star_content_negociation(client, app):
+
+ extensions.content_negociation(app)
+
+ @app.route('/test', accepts=['text/css'])
+ async def get(req, resp):
+ resp.body = 'accepted'
+
+ resp = await client.get('/test', headers={'Accept': 'text/*'})
+ assert resp.status == HTTPStatus.OK
+
+
+async def test_post_accept_content_negociation(client, app):
+
+ extensions.content_negociation(app)
+
+ @app.route('/test', methods=['POST'], accepts=['application/json'])
+ async def get(req, resp):
+ resp.json = {'status': 'accepted'}
+
+ client.content_type = 'application/x-www-form-urlencoded'
+ resp = await client.post('/test', body={'key': 'value'},
+ headers={'Accept': 'application/json'})
+ assert resp.status == HTTPStatus.OK
+ assert resp.headers['Content-Type'] == 'application/json; charset=utf-8'
+ assert json.loads(resp.body.decode()) == {'status': 'accepted'}
+
+
+async def test_post_reject_content_negociation(client, app):
+
+ extensions.content_negociation(app)
+
+ @app.route('/test', methods=['POST'], accepts=['text/html'])
+ async def get(req, resp):
+ resp.json = {'status': 'accepted'}
+
+ client.content_type = 'application/x-www-form-urlencoded'
+ resp = await client.post('/test', body={'key': 'value'},
+ headers={'Accept': 'application/json'})
+ assert resp.status == HTTPStatus.NOT_ACCEPTABLE
diff --git a/tests/test_hooks.py b/tests/test_hooks.py
index db62f684..a602f203 100644
--- a/tests/test_hooks.py
+++ b/tests/test_hooks.py
@@ -13,7 +13,7 @@ async def test_request_hook_can_alter_response(client, app):
@app.listen('request')
async def listener(request, response):
response.status = 400
- response.body = 'another response'
+ response.body = b'another response'
return True # Shortcut the response process.
@app.route('/test')
@@ -22,7 +22,7 @@ async def get(req, resp):
resp = await client.get('/test')
assert resp.status == HTTPStatus.BAD_REQUEST
- assert resp.body == 'another response'
+ assert resp.body == b'another response'
async def test_response_hook_can_alter_response(client, app):
@@ -39,7 +39,7 @@ async def get(req, resp):
resp = await client.get('/test')
assert resp.status == HTTPStatus.BAD_REQUEST
- assert resp.body == 'another response'
+ assert resp.body == b'another response'
async def test_error_with_json_format(client, app):
@@ -55,7 +55,7 @@ async def get(req, resp):
resp = await client.get('/test')
assert resp.status == HTTPStatus.INTERNAL_SERVER_ERROR
- error = json.loads(resp.body)
+ error = json.loads(resp.body.decode())
assert error == {"status": 500, "message": "JSON error"}
diff --git a/tests/test_request.py b/tests/test_request.py
index 3a2c1b70..0cb106d0 100644
--- a/tests/test_request.py
+++ b/tests/test_request.py
@@ -2,19 +2,11 @@
import pytest
from roll import Protocol, HttpError
+from roll.testing import Transport
pytestmark = pytest.mark.asyncio
-class Transport:
-
- def write(self, data):
- ...
-
- def close(self):
- ...
-
-
@pytest.fixture
def protocol(app):
protocol = Protocol(app)
diff --git a/tests/test_views.py b/tests/test_views.py
index 81396066..f489cecc 100644
--- a/tests/test_views.py
+++ b/tests/test_views.py
@@ -13,7 +13,7 @@ async def get(req, resp):
resp = await client.get('/test')
assert resp.status == HTTPStatus.OK
- assert resp.body == 'test response'
+ assert resp.body == b'test response'
async def test_simple_non_200_response(client, app):
@@ -53,9 +53,9 @@ async def test_post_json(client, app):
async def get(req, resp):
resp.body = req.body
- resp = await client.post('/test', body={"key": "value"})
+ resp = await client.post('/test', body={'key': 'value'})
assert resp.status == HTTPStatus.OK
- assert resp.body == '{"key": "value"}'
+ assert resp.body == b'{"key": "value"}'
async def test_post_urlencoded(client, app):
@@ -65,9 +65,9 @@ async def get(req, resp):
resp.body = req.body
client.content_type = 'application/x-www-form-urlencoded'
- resp = await client.post('/test', body={"key": "value"})
+ resp = await client.post('/test', body={'key': 'value'})
assert resp.status == HTTPStatus.OK
- assert resp.body == 'key=value'
+ assert resp.body == b'key=value'
async def test_can_define_twice_a_route_with_different_payloads(client, app):