diff --git a/docs/changelog.md b/docs/changelog.md index f8751521..b16781d0 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -10,13 +10,22 @@ A changelog: ## dev version -* Changed `Roll.hook` signature to also accept kwargs ([#5](https://github.com/pyrates/roll/pull/5)) -* `json` shorcut sets `utf-8` charset in `Content-Type` header ([#13](https://github.com/pyrates/roll/pull/13)) -* Added `static` extension to serve static files for development ([#16](https://github.com/pyrates/roll/pull/16)) -* `cors` accepts `headers` parameter to control `Access-Control-Allow-Headers` ([#12](https://github.com/pyrates/roll/pull/12)) +* Changed `Roll.hook` signature to also accept kwargs + ([#5](https://github.com/pyrates/roll/pull/5)) +* `json` shorcut sets `utf-8` charset in `Content-Type` header + ([#13](https://github.com/pyrates/roll/pull/13)) +* Added `static` extension to serve static files for development + ([#16](https://github.com/pyrates/roll/pull/16)) +* `cors` accepts `headers` parameter to control `Access-Control-Allow-Headers` + ([#12](https://github.com/pyrates/roll/pull/12)) +* Added `content_negociation` extension to reject unacceptable client requests + based on the `Accept` header + ([#21](https://github.com/pyrates/roll/pull/21)) * **Breaking changes**: - * `options` extension is no more applied by default ([#16](https://github.com/pyrates/roll/pull/16)) - * deprecated `req` pytest fixture is now removed ([#9](https://github.com/pyrates/roll/pull/9)) + * `options` extension is no more applied by default + ([#16](https://github.com/pyrates/roll/pull/16)) + * deprecated `req` pytest fixture is now removed + ([#9](https://github.com/pyrates/roll/pull/9)) ## 0.5.0 — 2017-09-21 diff --git a/docs/how-to-guides.md b/docs/how-to-guides.md index 35994c3e..9fd5613b 100644 --- a/docs/how-to-guides.md +++ b/docs/how-to-guides.md @@ -46,6 +46,27 @@ always a bonus. The `response` object is modified in place. Make sure to check these out!* +## How to deal with content negociation + +The [`content_negociation` extension](reference.md#content_negociation) +is made for this purpose, you can use it that way: + +```python +extensions.content_negociation(app) + +@app.route('/test', accepts=['text/html', 'application/json']) +async def get(req, resp): + if req.headers['Accept'] == 'text/html': + resp.headers['Content-Type'] = 'text/html' + resp.body = '

accepted

' + elif req.headers['Accept'] == 'application/json': + resp.json = {'status': 'accepted'} +``` + +Requests with `Accept` header not matching `text/html` or +`application/json` will be honored with a `406 Not Acceptable` response. + + ## How to return an HTTP error There are many reasons to return an HTTP error, with Roll you have to diff --git a/docs/index.md b/docs/index.md index ddc1619a..a3f64c7e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -56,6 +56,7 @@ A how-to guide: * [How to install Roll](how-to-guides.md#how-to-install-roll) * [How to create an extension](how-to-guides.md#how-to-create-an-extension) +* [How to deal with content negociation](how-to-guides.md#how-to-deal-with-content-negociation) * [How to return an HTTP error](how-to-guides.md#how-to-return-an-http-error) * [How to return JSON content](how-to-guides.md#how-to-return-json-content) * [How to subclass Roll itself](how-to-guides.md#how-to-subclass-roll-itself) diff --git a/docs/reference.md b/docs/reference.md index ec823a0f..6288b469 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -26,8 +26,12 @@ The `status` can be either a `http.HTTPStatus` instance or an integer. ### Request -A container for the result of the parsing on each request made by -`httptools.HttpRequestParser`. +A container for the result of the parsing on each request. +The default parsing is made by `httptools.HttpRequestParser`. + +You can use the empty `kwargs` dict to attach whatever you want, +especially useful for extensions. + #### Properties @@ -38,6 +42,7 @@ A container for the result of the parsing on each request made by - **method** (`str`): HTTP verb - **body** (`bytes`): raw body as received by Roll - **headers** (`dict`): HTTP headers +- **route** (`Route`): a [Route instance](#Route) storing results from URL matching - **kwargs** (`dict`): store here any extra data needed in the Request lifetime @@ -105,6 +110,15 @@ parser. Default routes use [autoroutes](https://github.com/pyrates/autoroutes), please refers to that documentation for available patterns. +### Route + +A namedtuple to collect matched route data with attributes: + +* **payload** (`dict`): the data received by the `@app.route` decorator, + contains all handlers plus optionnal custom data. +* **vars** (`dict`): URL placeholders resolved for the current route. + + ## Extensions Please read @@ -152,10 +166,24 @@ Combine it with the `cors` extension to handle the preflight request. - **app**: Roll app to register the extension against +### content_negociation + +Deal with content negociation declared during routes definition. +Will return a `406 Not Acceptable` response in case of mismatch between +the `Accept` header from the client and the `accepts` parameter set in +routes. Useful to reject requests which are not expecting the available +response. + #### Parameters - **app**: Roll app to register the extension against + +#### Requirements + +- mimetype-match>=1.0.4 + + ### traceback Print the traceback on the server side if any. Handy for debugging. @@ -164,6 +192,7 @@ Print the traceback on the server side if any. Handy for debugging. - **app**: Roll app to register the extension against + ### igniter Display a BIG message when running the server. @@ -173,6 +202,7 @@ Quite useless, hence so essential! - **app**: Roll app to register the extension against + ### static Serve static files. Should not be used in production. diff --git a/requirements-dev.txt b/requirements-dev.txt index be628831..a882ad69 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,4 @@ +mimetype-match==1.0.4 mkdocs==0.16.3 pytest==3.1.2 pytest-asyncio==0.6.0 diff --git a/roll/__init__.py b/roll/__init__.py index 415141aa..28eeed11 100644 --- a/roll/__init__.py +++ b/roll/__init__.py @@ -9,6 +9,7 @@ a test failing): https://github.com/pyrates/roll/issues/new """ import asyncio +from collections import namedtuple from http import HTTPStatus from typing import TypeVar from urllib.parse import parse_qs, unquote @@ -99,10 +100,10 @@ def float(self, key: str, default=...): class Request: """A container for the result of the parsing on each request. - The parsing is made by `httptools.HttpRequestParser`. + The default parsing is made by `httptools.HttpRequestParser`. """ - __slots__ = ('url', 'path', 'query_string', 'query', 'method', 'kwargs', - 'body', 'headers') + __slots__ = ('url', 'path', 'query_string', 'query', 'method', 'body', + 'headers', 'route', 'kwargs') def __init__(self): self.kwargs = {} @@ -199,7 +200,7 @@ def write(self, *args): payload = b'HTTP/1.1 %a %b\r\n' % ( self.response.status.value, self.response.status.phrase.encode()) if not isinstance(self.response.body, bytes): - self.response.body = self.response.body.encode() + self.response.body = str(self.response.body).encode() if 'Content-Length' not in self.response.headers: length = len(self.response.body) self.response.headers['Content-Length'] = length @@ -221,6 +222,9 @@ def match(self, url: str): return payload, params +Route = namedtuple('Route', ['payload', 'vars']) + + class Roll: """Deal with routes dispatching and events listening. @@ -241,9 +245,13 @@ async def shutdown(self): async def __call__(self, request: Request, response: Response): try: + request.route = Route(*self.routes.match(request.path)) if not await self.hook('request', request, response): - params, handler = self.dispatch(request) - await handler(request, response, **params) + # Uppercased in order to only consider HTTP verbs. + if request.method.upper() not in request.route.payload: + raise HttpError(HTTPStatus.METHOD_NOT_ALLOWED) + handler = request.route.payload[request.method] + await handler(request, response, **request.route.vars) except Exception as error: await self.on_error(request, response, error) try: @@ -268,23 +276,18 @@ async def on_error(self, request: Request, response: Response, error): def factory(self): return self.Protocol(self) - def route(self, path: str, methods: list=None): + def route(self, path: str, methods: list=None, **extras: dict): if methods is None: methods = ['GET'] def wrapper(func): - self.routes.add(path, **{m: func for m in methods}) + payload = {method: func for method in methods} + payload.update(extras) + self.routes.add(path, **payload) return func return wrapper - def dispatch(self, request: Request): - handlers, params = self.routes.match(request.path) - if request.method not in handlers: - raise HttpError(HTTPStatus.METHOD_NOT_ALLOWED) - request.kwargs.update(params) - return params, handlers[request.method] - def listen(self, name: str): def wrapper(func): self.hooks.setdefault(name, []) diff --git a/roll/extensions.py b/roll/extensions.py index e65369ad..67c4bef9 100644 --- a/roll/extensions.py +++ b/roll/extensions.py @@ -1,6 +1,7 @@ import asyncio import logging import mimetypes +import sys from http import HTTPStatus from pathlib import Path from traceback import print_exc @@ -48,6 +49,22 @@ async def handle_options(request, response): return request.method == 'OPTIONS' +def content_negociation(app): + + try: + from mimetype_match import get_best_match + except ImportError: + sys.exit('Please install mimetype-match>=1.0.4 to be able to use the ' + 'content_negociation extension.') + + @app.listen('request') + async def reject_unacceptable_requests(request, response): + accept = request.headers.get('Accept') + accepts = request.route.payload['accepts'] + if accept is None or get_best_match(accept, accepts) is None: + raise HttpError(HTTPStatus.NOT_ACCEPTABLE) + + def traceback(app): @app.listen('error') @@ -108,8 +125,8 @@ async def serve(request, response, path): content_type, encoding = mimetypes.guess_type(str(abspath)) with abspath.open('rb') as source: response.body = source.read() - response.headers['Content-Type'] = (content_type - or 'application/octet-stream') + response.headers['Content-Type'] = (content_type or + 'application/octet-stream') @app.listen('startup') async def register_route(): diff --git a/roll/testing.py b/roll/testing.py index 25a356ba..218e1b9d 100644 --- a/roll/testing.py +++ b/roll/testing.py @@ -4,6 +4,15 @@ import pytest +class Transport: + + def write(self, data): + ... + + def close(self): + ... + + class Client: # Default content type for request body encoding, change it to your own @@ -36,12 +45,15 @@ async def request(self, path, method='GET', body=b'', headers=None, headers['Content-Type'] = content_type body, headers = self.encode_body(body, headers) protocol = self.app.factory() + protocol.connection_made(Transport()) protocol.on_message_begin() protocol.on_url(path.encode()) protocol.request.body = body protocol.request.method = method protocol.request.headers = headers - return await self.app(protocol.request, protocol.response) + await self.app(protocol.request, protocol.response) + protocol.write() + return protocol.response async def get(self, path, **kwargs): return await self.request(path, method='GET', **kwargs) diff --git a/tests/test_errors.py b/tests/test_errors.py index 8507be9c..f14dd279 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -14,7 +14,7 @@ async def get(req, resp): resp = await client.get('/test') assert resp.status == HTTPStatus.INTERNAL_SERVER_ERROR - assert resp.body == 'Oops.' + assert resp.body == b'Oops.' async def test_httpstatus_error(client, app): @@ -25,7 +25,7 @@ async def get(req, resp): resp = await client.get('/test') assert resp.status == HTTPStatus.BAD_REQUEST - assert resp.body == 'Really bad.' + assert resp.body == b'Really bad.' async def test_error_only_with_status(client, app): @@ -36,7 +36,7 @@ async def get(req, resp): resp = await client.get('/test') assert resp.status == HTTPStatus.INTERNAL_SERVER_ERROR - assert resp.body == 'Internal Server Error' + assert resp.body == b'Internal Server Error' async def test_error_only_with_httpstatus(client, app): @@ -47,7 +47,7 @@ async def get(req, resp): resp = await client.get('/test') assert resp.status == HTTPStatus.INTERNAL_SERVER_ERROR - assert resp.body == 'Internal Server Error' + assert resp.body == b'Internal Server Error' async def test_error_subclasses_with_super(client, app): @@ -63,7 +63,7 @@ async def get(req, resp): resp = await client.get('/test') assert resp.status == HTTPStatus.INTERNAL_SERVER_ERROR - assert resp.body == '

Oops.

' + assert resp.body == b'

Oops.

' async def test_error_subclasses_without_super(client, app): @@ -79,4 +79,4 @@ async def get(req, resp): resp = await client.get('/test') assert resp.status == HTTPStatus.INTERNAL_SERVER_ERROR - assert resp.body == '

Oops.

' + assert resp.body == b'

Oops.

' diff --git a/tests/test_extensions.py b/tests/test_extensions.py index d1f3c2e6..204a9a3f 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -18,7 +18,7 @@ async def get(req, resp): resp = await client.get('/test') assert resp.status == HTTPStatus.OK - assert resp.body == 'test response' + assert resp.body == b'test response' assert resp.headers['Access-Control-Allow-Origin'] == '*' @@ -85,7 +85,7 @@ async def get(req, resp): resp = await client.get('/test') assert resp.headers['Content-Type'] == 'application/json; charset=utf-8' - assert json.loads(resp.body) == {'key': 'value'} + assert json.loads(resp.body.decode()) == {'key': 'value'} assert resp.status == HTTPStatus.OK @@ -98,7 +98,7 @@ async def get(req, resp): resp = await client.get('/test') assert resp.headers['Content-Type'] == 'application/json; charset=utf-8' - assert json.loads(resp.body) == {'key': 'value'} + assert json.loads(resp.body.decode()) == {'key': 'value'} assert resp.status == HTTPStatus.BAD_REQUEST @@ -167,3 +167,106 @@ async def test_can_change_static_prefix(client, app): resp = await client.get('/foo/index.html') assert resp.status == HTTPStatus.OK assert b'Test' in resp.body + + +async def test_get_accept_content_negociation(client, app): + + extensions.content_negociation(app) + + @app.route('/test', accepts=['text/html']) + async def get(req, resp): + resp.headers['Content-Type'] = 'text/html' + resp.body = 'accepted' + + resp = await client.get('/test', headers={'Accept': 'text/html'}) + assert resp.status == HTTPStatus.OK + assert resp.body == b'accepted' + assert resp.headers['Content-Type'] == 'text/html' + + +async def test_get_accept_content_negociation_if_many(client, app): + + extensions.content_negociation(app) + + @app.route('/test', accepts=['text/html', 'application/json']) + async def get(req, resp): + if req.headers['Accept'] == 'text/html': + resp.headers['Content-Type'] = 'text/html' + resp.body = '

accepted

' + elif req.headers['Accept'] == 'application/json': + resp.json = {'status': 'accepted'} + + resp = await client.get('/test', headers={'Accept': 'text/html'}) + assert resp.status == HTTPStatus.OK + assert resp.body == b'

accepted

' + assert resp.headers['Content-Type'] == 'text/html' + resp = await client.get('/test', headers={'Accept': 'application/json'}) + assert resp.status == HTTPStatus.OK + assert json.loads(resp.body.decode()) == {'status': 'accepted'} + assert resp.headers['Content-Type'] == 'application/json; charset=utf-8' + + +async def test_get_reject_content_negociation(client, app): + + extensions.content_negociation(app) + + @app.route('/test', accepts=['text/html']) + async def get(req, resp): + resp.body = 'rejected' + + resp = await client.get('/test', headers={'Accept': 'text/css'}) + assert resp.status == HTTPStatus.NOT_ACCEPTABLE + + +async def test_get_reject_content_negociation_if_no_accept_header(client, app): + + extensions.content_negociation(app) + + @app.route('/test', accepts=['*/*']) + async def get(req, resp): + resp.body = 'rejected' + + resp = await client.get('/test') + assert resp.status == HTTPStatus.NOT_ACCEPTABLE + + +async def test_get_accept_star_content_negociation(client, app): + + extensions.content_negociation(app) + + @app.route('/test', accepts=['text/css']) + async def get(req, resp): + resp.body = 'accepted' + + resp = await client.get('/test', headers={'Accept': 'text/*'}) + assert resp.status == HTTPStatus.OK + + +async def test_post_accept_content_negociation(client, app): + + extensions.content_negociation(app) + + @app.route('/test', methods=['POST'], accepts=['application/json']) + async def get(req, resp): + resp.json = {'status': 'accepted'} + + client.content_type = 'application/x-www-form-urlencoded' + resp = await client.post('/test', body={'key': 'value'}, + headers={'Accept': 'application/json'}) + assert resp.status == HTTPStatus.OK + assert resp.headers['Content-Type'] == 'application/json; charset=utf-8' + assert json.loads(resp.body.decode()) == {'status': 'accepted'} + + +async def test_post_reject_content_negociation(client, app): + + extensions.content_negociation(app) + + @app.route('/test', methods=['POST'], accepts=['text/html']) + async def get(req, resp): + resp.json = {'status': 'accepted'} + + client.content_type = 'application/x-www-form-urlencoded' + resp = await client.post('/test', body={'key': 'value'}, + headers={'Accept': 'application/json'}) + assert resp.status == HTTPStatus.NOT_ACCEPTABLE diff --git a/tests/test_hooks.py b/tests/test_hooks.py index db62f684..a602f203 100644 --- a/tests/test_hooks.py +++ b/tests/test_hooks.py @@ -13,7 +13,7 @@ async def test_request_hook_can_alter_response(client, app): @app.listen('request') async def listener(request, response): response.status = 400 - response.body = 'another response' + response.body = b'another response' return True # Shortcut the response process. @app.route('/test') @@ -22,7 +22,7 @@ async def get(req, resp): resp = await client.get('/test') assert resp.status == HTTPStatus.BAD_REQUEST - assert resp.body == 'another response' + assert resp.body == b'another response' async def test_response_hook_can_alter_response(client, app): @@ -39,7 +39,7 @@ async def get(req, resp): resp = await client.get('/test') assert resp.status == HTTPStatus.BAD_REQUEST - assert resp.body == 'another response' + assert resp.body == b'another response' async def test_error_with_json_format(client, app): @@ -55,7 +55,7 @@ async def get(req, resp): resp = await client.get('/test') assert resp.status == HTTPStatus.INTERNAL_SERVER_ERROR - error = json.loads(resp.body) + error = json.loads(resp.body.decode()) assert error == {"status": 500, "message": "JSON error"} diff --git a/tests/test_request.py b/tests/test_request.py index 3a2c1b70..0cb106d0 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -2,19 +2,11 @@ import pytest from roll import Protocol, HttpError +from roll.testing import Transport pytestmark = pytest.mark.asyncio -class Transport: - - def write(self, data): - ... - - def close(self): - ... - - @pytest.fixture def protocol(app): protocol = Protocol(app) diff --git a/tests/test_views.py b/tests/test_views.py index 81396066..f489cecc 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -13,7 +13,7 @@ async def get(req, resp): resp = await client.get('/test') assert resp.status == HTTPStatus.OK - assert resp.body == 'test response' + assert resp.body == b'test response' async def test_simple_non_200_response(client, app): @@ -53,9 +53,9 @@ async def test_post_json(client, app): async def get(req, resp): resp.body = req.body - resp = await client.post('/test', body={"key": "value"}) + resp = await client.post('/test', body={'key': 'value'}) assert resp.status == HTTPStatus.OK - assert resp.body == '{"key": "value"}' + assert resp.body == b'{"key": "value"}' async def test_post_urlencoded(client, app): @@ -65,9 +65,9 @@ async def get(req, resp): resp.body = req.body client.content_type = 'application/x-www-form-urlencoded' - resp = await client.post('/test', body={"key": "value"}) + resp = await client.post('/test', body={'key': 'value'}) assert resp.status == HTTPStatus.OK - assert resp.body == 'key=value' + assert resp.body == b'key=value' async def test_can_define_twice_a_route_with_different_payloads(client, app):