Skip to content

Commit

Permalink
Add header accessors (#2696)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahopkins committed Feb 27, 2023
1 parent dfc0704 commit cb49c2b
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 1 deletion.
6 changes: 6 additions & 0 deletions sanic/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ class Header(CIMultiDict):
very similar to a regular dictionary.
"""

def __getattr__(self, key: str) -> str:
if key.startswith("_"):
return self.__getattribute__(key)
key = key.rstrip("_").replace("_", "-")
return ",".join(self.getall(key, default=[]))

def get_all(self, key: str):
"""
Convenience method mapped to ``getall()``.
Expand Down
49 changes: 48 additions & 1 deletion tests/test_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@

import pytest

from sanic import headers, text
from sanic import Sanic, headers, json, text
from sanic.exceptions import InvalidHeader, PayloadTooLarge
from sanic.http import Http
from sanic.request import Request


def make_request(headers) -> Request:
return Request(b"/", headers, "1.1", "GET", None, None)


@pytest.fixture
def raised_ceiling():
Http.HEADER_CEILING = 32_768
Expand Down Expand Up @@ -435,3 +439,46 @@ def test_accept_misc():
a = headers.parse_accept("")
assert a == []
assert not a.match("foo/bar")


@pytest.mark.parametrize(
"headers,expected",
(
({"foo": "bar"}, "bar"),
((("foo", "bar"), ("foo", "baz")), "bar,baz"),
({}, ""),
),
)
def test_field_simple_accessor(headers, expected):
request = make_request(headers)
assert request.headers.foo == request.headers.foo_ == expected


@pytest.mark.parametrize(
"headers,expected",
(
({"foo-bar": "bar"}, "bar"),
((("foo-bar", "bar"), ("foo-bar", "baz")), "bar,baz"),
),
)
def test_field_hyphenated_accessor(headers, expected):
request = make_request(headers)
assert request.headers.foo_bar == request.headers.foo_bar_ == expected


def test_bad_accessor():
request = make_request({})
msg = "'Header' object has no attribute '_foo'"
with pytest.raises(AttributeError, match=msg):
request.headers._foo


def test_multiple_fields_accessor(app: Sanic):
@app.get("")
async def handler(request: Request):
return json({"field": request.headers.example_field})

_, response = app.test_client.get(
"/", headers=(("Example-Field", "Foo, Bar"), ("Example-Field", "Baz"))
)
assert response.json["field"] == "Foo, Bar,Baz"

0 comments on commit cb49c2b

Please sign in to comment.