Skip to content

Commit

Permalink
Merge 9d81490 into c1aad05
Browse files Browse the repository at this point in the history
  • Loading branch information
peldszus committed Mar 8, 2018
2 parents c1aad05 + 9d81490 commit 6a152e9
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 18 deletions.
10 changes: 9 additions & 1 deletion supercell/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from supercell.mediatypes import ContentType


def provides(content_type, vendor=None, version=None, default=False):
def provides(content_type, vendor=None, version=None, default=False,
partial=False):
"""Class decorator for mapping HTTP GET responses to content types and
their representation.
Expand All @@ -51,6 +52,9 @@ class MyHandler(s.RequestHandler):
:param float version: The vendor version
:param bool default: If **True** and no **Accept** header is present, this
content type is provided
:param bool partial: If **True**, the provider can return partial
representations, i.e. the underlying model validates
even though required fields are missing.
"""

def wrapper(cls):
Expand All @@ -60,14 +64,18 @@ def wrapper(cls):

if not hasattr(cls, '_PROD_CONTENT_TYPES'):
cls._PROD_CONTENT_TYPES = defaultdict(list)
if not hasattr(cls, '_PROD_CONFIGURATION'):
cls._PROD_CONFIGURATION = defaultdict(dict)

ctype = ContentType(content_type,
vendor,
version)
cls._PROD_CONTENT_TYPES[content_type].append(ctype)
cls._PROD_CONFIGURATION[content_type]['partial'] = partial
if default:
assert 'default' not in cls._PROD_CONTENT_TYPES, 'TODO: nice msg'
cls._PROD_CONTENT_TYPES['default'] = ctype
cls._PROD_CONFIGURATION['default']['partial'] = partial

return cls

Expand Down
28 changes: 21 additions & 7 deletions supercell/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,14 @@ def map_provider(accept_header, handler, allow_default=False):
:param accept_header: HTTP Accept header value
:type accept_header: str
:param handler: supercell request handler
:param allow_default: allow usage of default provider if no accept header is set, default is False
:param allow_default: allow usage of default provider if no accept
header is set, default is False
:type allow_default: bool
:raises: :exc:`NoProviderFound`
:return: A tuple of the matching provider implementation class and
the provide()-kwargs
:rtype: (supercell.api.provider.ProviderBase, dict)
"""
if not hasattr(handler, '_PROD_CONTENT_TYPES'):
raise NoProviderFound()
Expand All @@ -110,27 +115,31 @@ def map_provider(accept_header, handler, allow_default=False):
ProviderMeta.KNOWN_CONTENT_TYPES[ctype]
if t[0] == c]

configuration = handler._PROD_CONFIGURATION[ctype]
if len(known_types) == 1:
return known_types[0][1]
return (known_types[0][1], configuration)

if allow_default and 'default' in handler._PROD_CONTENT_TYPES:
content_type = handler._PROD_CONTENT_TYPES['default']
configuration = handler._PROD_CONFIGURATION['default']
ctype = content_type.content_type
default_type = [t for t in
ProviderMeta.KNOWN_CONTENT_TYPES[ctype]
if t[0] == content_type]

if len(default_type) == 1:
return default_type[0][1]
return (default_type[0][1], configuration)

raise NoProviderFound()

def provide(self, model, handler):
def provide(self, model, handler, **kwargs):
"""This method should return the correct representation as a simple
string (i.e. byte buffer) that will be used as return value.
:param model: the model to convert to a certain content type
:type model: supercell.schematics.Model
:param handler: the handler to write the return
:type handler: supercell.requesthandler.RequestHandler
"""
raise NotImplementedError

Expand All @@ -153,13 +162,18 @@ class JsonProvider(ProviderBase):

CONTENT_TYPE = ContentType(MediaType.ApplicationJson)

def provide(self, model, handler):
def provide(self, model, handler, **kwargs):
"""Simply return the json via `json.dumps`.
Keyword arguments:
:param partial: if **True** the model will be validate as a partial.
:type partial: bool
.. seealso:: :py:mod:`supercell.api.provider.ProviderBase.provide`
"""
try:
model.validate()
partial = kwargs.get("partial", False)
model.validate(partial=partial)
handler.write(model.to_primitive())
except ModelValidationError as e:
e.messages = {"result_model": e.messages}
Expand All @@ -186,7 +200,7 @@ class TornadoTemplateProvider(ProviderBase):

CONTENT_TYPE = ContentType(MediaType.TextHtml)

def provide(self, model, handler):
def provide(self, model, handler, **kwargs):
"""Render a template with the given model into HTML.
By default we will use the tornado built in template language."""
Expand Down
8 changes: 4 additions & 4 deletions supercell/requesthandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def environment(self):

@property
def config(self):
"""Convinience method for accessing the environment."""
"""Convinience method for accessing the configuration."""
return self.application.config

@property
Expand Down Expand Up @@ -259,21 +259,21 @@ def _provide_result(self, verb, headers, result):

else:
try:
provider_class = ProviderBase.map_provider(
provider_class, provider_config = ProviderBase.map_provider(
headers.get('Accept', ''), self, allow_default=True)
except NoProviderFound:
raise HTTPError(406)

provider = provider_class()
if isinstance(result, Model):
provider.provide(result, self)
provider.provide(result, self, **provider_config)

if not self._finished:
self.finish()

def write_error(self, status_code, **kwargs):
try:
provider_class = ProviderBase.map_provider(
provider_class, _ = ProviderBase.map_provider(
self.request.headers.get('Accept', ''), self,
allow_default=True)
provider_class().error(status_code, self._reason, self)
Expand Down
16 changes: 16 additions & 0 deletions test/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,19 @@ def update_stuff(self):
MediaType.ApplicationJson)
self.assertIsNone(content_type.vendor)
self.assertIsNone(content_type.version)

def test_provides_decorator_with_partial(self):

@provides(MediaType.ApplicationJson, partial=True)
class MyHandler(RequestHandler):

def update_stuff(self):
pass

self.assertTrue(hasattr(MyHandler, '_PROD_CONFIGURATION'))
self.assertEqual(len(MyHandler._PROD_CONFIGURATION), 1)
self.assertTrue(MediaType.ApplicationJson in
MyHandler._PROD_CONFIGURATION)
configuration = MyHandler._PROD_CONFIGURATION[
MediaType.ApplicationJson]
self.assertIs(configuration["partial"], True)
23 changes: 17 additions & 6 deletions test/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def test_default_json_provider(self):
class MyHandler(RequestHandler):
pass

provider = ProviderBase.map_provider(MediaType.ApplicationJson,
handler=MyHandler)
provider, _ = ProviderBase.map_provider(MediaType.ApplicationJson,
handler=MyHandler)
self.assertIs(provider, JsonProvider)

with self.assertRaises(NoProviderFound):
Expand All @@ -63,8 +63,8 @@ def test_specific_json_provider(self):
class MyHandler(RequestHandler):
pass

provider = ProviderBase.map_provider('application/vnd.supercell+json',
handler=MyHandler)
provider, _ = ProviderBase.map_provider('application/vnd.supercell+json',
handler=MyHandler)
self.assertIs(provider, MoreDetailedJsonProvider)

def test_json_provider_with_version(self):
Expand All @@ -77,11 +77,22 @@ def __init__(self, *args, **kwargs):
# of this class
pass

provider = ProviderBase.map_provider(
provider, _ = ProviderBase.map_provider(
'application/vnd.supercell-v1.0+json', handler=MyHandler)
self.assertIs(provider, JsonProviderWithVendorAndVersion)

handler = MyHandler()
provider = ProviderBase.map_provider(
provider, _ = ProviderBase.map_provider(
'application/vnd.supercell-v1.0+json', handler=handler)
self.assertIs(provider, JsonProviderWithVendorAndVersion)

def test_json_provider_with_configuration(self):

@provides(MediaType.ApplicationJson, partial=True)
class MyHandler(RequestHandler):
pass

provider, configuration = ProviderBase.map_provider(
MediaType.ApplicationJson, handler=MyHandler)
self.assertIs(provider, JsonProvider)
self.assertIs(configuration["partial"], True)
87 changes: 87 additions & 0 deletions test/test_requesthandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,17 @@
from __future__ import (absolute_import, division, print_function,
with_statement)

import pytest

import json
import os.path as op

import schematics
from schematics.models import Model
from schematics.types import StringType
from schematics.types import IntType
from schematics.types.compound import ModelType
from schematics.types.compound import ListType

from tornado.ioloop import IOLoop
from tornado.testing import AsyncHTTPTestCase
Expand Down Expand Up @@ -266,3 +271,85 @@ def get_new_ioloop(self):
def test_simple_html(self):
response = self.fetch('/test_html/')
self.assertEqual(500, response.code)


class StricterMessage(Model):
doc_id = StringType(required=True)
message = StringType(required=True)
number = IntType()

class Options:
serialize_when_none = False


class StricterMessageCollection(Model):
messages = ListType(ModelType(StricterMessage))


class TestHandlerProvidingPartialModels(AsyncHTTPTestCase):

def get_app(self):

@provides(s.MediaType.ApplicationJson)
class MyHandlerWithoutPartial(RequestHandler):

@s.async
def get(self, *args, **kwargs):
raise s.Return(StricterMessage({"doc_id": 'test123'}))

@provides(s.MediaType.ApplicationJson, partial=True)
class MyHandlerWithPartial(RequestHandler):

@s.async
def get(self, *args, **kwargs):
raise s.Return(StricterMessage({"doc_id": 'test123'}))

@provides(s.MediaType.ApplicationJson, partial=True)
class MyHandlerWithPartialComplex(RequestHandler):

@s.async
def get(self, *args, **kwargs):
raise s.Return(StricterMessageCollection(
{"messages": [{"doc_id": 'test123'}]}))

env = Environment()
env.add_handler('/test_no_partial', MyHandlerWithoutPartial)
env.add_handler('/test_partial', MyHandlerWithPartial)
env.add_handler('/test_partial_complex', MyHandlerWithPartialComplex)
return env.get_application()

def get_new_ioloop(self):
return IOLoop.instance()

def test_provide_partial_model_with_partial_false(self):
response = self.fetch(
'/test_no_partial',
headers={'Accept': s.MediaType.ApplicationJson})
self.assertEqual(response.code, 500)
self.assertEqual(
'{"error": true, "message": {"result_model": ' +
'{"message": ["This field is required."]}}}',
json.dumps(json.loads(response.body.decode('utf8')),
sort_keys=True))

def test_provide_partial_model_with_partial_true(self):
response = self.fetch(
'/test_partial',
headers={'Accept': s.MediaType.ApplicationJson})
self.assertEqual(response.code, 200)
self.assertEqual('{"doc_id": "test123"}',
json.dumps(json.loads(response.body.decode('utf8')),
sort_keys=True))

@pytest.mark.skipif(
schematics.__version__ < "2.0.1",
reason="Partial validation of complex models is broken in schematics" +
" version < 2.0.1.")
def test_provide_partial_model_with_complex_partial_true(self):
response = self.fetch(
'/test_partial_complex',
headers={'Accept': s.MediaType.ApplicationJson})
self.assertEqual(response.code, 200)
self.assertEqual('{"messages": [{"doc_id": "test123"}]}',
json.dumps(json.loads(response.body.decode('utf8')),
sort_keys=True))

0 comments on commit 6a152e9

Please sign in to comment.