diff --git a/clearest/core.py b/clearest/core.py index e2929fc..eb8bb0b 100644 --- a/clearest/core.py +++ b/clearest/core.py @@ -21,12 +21,14 @@ HttpNotFound, NotRootError, HttpUnsupportedMediaType, HttpBadRequest, HttpNotImplemented from clearest.http import HTTP_GET, HTTP_POST, CONTENT_TYPE, MIME_TEXT_PLAIN, HTTP_OK, MIME_WWW_FORM_URLENCODED, \ MIME_FORM_DATA, CONTENT_DISPOSITION, MIME_JSON, MIME_XML -from clearest.wsgi import REQUEST_METHOD, PATH_INFO, QUERY_STRING, WSGI_INPUT, WSGI_CONTENT_TYPE, WSGI_CONTENT_LENGTH +from clearest.wsgi import REQUEST_METHOD, PATH_INFO, QUERY_STRING, WSGI_INPUT, WSGI_CONTENT_TYPE, WSGI_CONTENT_LENGTH, \ + HTTP_ACCEPT KEY_PATTERN = re.compile("\{(.*)\}") STATUS_FMT = "{0} {1}" CALLABLE = 0 DEFAULT = 1 +ACCEPT_MIMES = "accept_mimes" _content_types = {} @@ -105,14 +107,16 @@ def is_matching(signature, args, path, query): return True -def parse_args(args, path, query): +def parse_args(args, path, query, specials): def one_or_many(fn_, dict_, key): result = [fn_(value) for value in dict_[key]] return result[0] if len(result) == 1 else result kwargs = {} for arg, parse_fn in six.iteritems(args): - if parse_fn is None: + if arg in specials: + kwargs[arg] = specials[arg]() + elif parse_fn is None: kwargs[arg] = one_or_many(lambda x: x, query, arg) elif isinstance(parse_fn, tuple): kwargs[arg] = parse_fn[DEFAULT] if arg not in query else one_or_many(parse_fn[CALLABLE], query, arg) @@ -124,7 +128,7 @@ def one_or_many(fn_, dict_, key): fn = closures[0].cell_contents else: fn = eval(".".join(_code.co_names), six.get_function_globals(parse_fn)) - kwargs[arg] = fn(**parse_args(get_function_args(parse_fn), path, query)) + kwargs[arg] = fn(**parse_args(get_function_args(parse_fn), path, query, specials)) else: kwargs[arg] = one_or_many(parse_fn, query, arg) return kwargs @@ -161,14 +165,26 @@ def parse_json(input_file, n, extras): encoding = "utf-8" if "encoding" not in extras else extras["encoding"] return {k: [v] for k, v in six.iteritems(json.loads(input_file.read(n), encoding=encoding))} + def parse_accept(): + result = [] + parts = (x.split(";") for x in environ[HTTP_ACCEPT].split(",")) + for part in parts: + if len(part) == 1: + result.append((1.0, part[0])) + else: + mime, q = part + result.append((float(q[2:]), mime)) + return tuple(value for (weight, value) in sorted(result, key=lambda x: x[0], reverse=True)) + content_types = {MIME_WWW_FORM_URLENCODED: parse_www_form, MIME_FORM_DATA: parse_form_data, MIME_JSON: parse_json} + specials = {ACCEPT_MIMES: parse_accept} try: if environ[REQUEST_METHOD] in all_registered(): path = tuple(environ[PATH_INFO][1:].split("/")) query = parse_qs(environ[QUERY_STRING]) if QUERY_STRING in environ else {} - if WSGI_CONTENT_TYPE in environ: + if WSGI_CONTENT_TYPE in environ and environ[WSGI_CONTENT_TYPE] != MIME_TEXT_PLAIN: content_type, extras_ = parse_content_type(environ[WSGI_CONTENT_TYPE]) if content_type not in content_types: raise HttpUnsupportedMediaType() @@ -182,7 +198,7 @@ def parse_json(input_file, n, extras): updated_query.update({key.name: [value] for key, value in zip(signature, path) if isinstance(key, Key)}) - parsed_args = parse_args(args, path, updated_query) + parsed_args = parse_args(args, path, updated_query, specials) except Exception as e: logging.exception(e) raise HttpBadRequest() diff --git a/clearest/wsgi.py b/clearest/wsgi.py index e827aa2..2448556 100644 --- a/clearest/wsgi.py +++ b/clearest/wsgi.py @@ -8,3 +8,4 @@ SERVER_PORT = "SERVER_PORT" SERVER_PROTOCOL = "SERVER_PROTOCOL" WSGI_INPUT = "wsgi.input" +HTTP_ACCEPT = "HTTP_ACCEPT" diff --git a/tests/test_application.py b/tests/test_application.py index de7c772..3525093 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -399,3 +399,33 @@ def asd(user_id=lambda user, password: g_login): self.get("/asd?user=guest&password=secret") self.assertEqual(HTTP_OK, self.status) self.assertCalledWith(asd, 1) + + def test_application_http_accept_1(self): + @GET("/asd") + @called_with + def asd(accept_mimes): + return {} + + self.get("/asd", accept="text/html") + self.assertEqual(HTTP_OK, self.status) + self.assertCalledWith(asd, ("text/html",)) + + def test_application_http_accept_2(self): + @GET("/asd") + @called_with + def asd(accept_mimes): + return {} + + self.get("/asd", accept="text/html,application/xml;q=0.9") + self.assertEqual(HTTP_OK, self.status) + self.assertCalledWith(asd, ("text/html", "application/xml")) + + def test_application_http_accept_3(self): + @GET("/asd") + @called_with + def asd(accept_mimes): + return {} + + self.get("/asd", accept="text/html;q=0.5,application/xml") + self.assertEqual(HTTP_OK, self.status) + self.assertCalledWith(asd, ("application/xml", "text/html")) diff --git a/tests/wsgi.py b/tests/wsgi.py index bf4d4dd..03291c4 100644 --- a/tests/wsgi.py +++ b/tests/wsgi.py @@ -18,7 +18,7 @@ def _start_response(self, status, headers): self.status = HttpStatus(code=int(code), msg=msg) self.headers = dict(headers) - def request(self, app, method, query, input_, content_type=None, content_len=0): + def request(self, app, method, query, input_, content_type=None, content_len=0, accept=None): assert method in HTTP_METHODS env = {REQUEST_METHOD: method, SERVER_PROTOCOL: HTTP_1_1, @@ -35,14 +35,16 @@ def request(self, app, method, query, input_, content_type=None, content_len=0): env[WSGI_CONTENT_TYPE] = content_type if content_len > 0: env[WSGI_CONTENT_LENGTH] = content_len + if accept: + env[HTTP_ACCEPT] = accept result = None for data in app(env, self._start_response): # TODO: PY2 vs PY3 string/bytes result = data break return result - def get(self, query, app=application): - return self.request(app, HTTP_GET, query, None) + def get(self, query, app=application, **kwargs): + return self.request(app, HTTP_GET, query, None, **kwargs) def post(self, query, app=application, input_=None, content_type=None, content_len=0): return self.request(app, HTTP_POST, query, input_, content_type, content_len)