Skip to content

Commit

Permalink
added tests for CSRF referer haeder
Browse files Browse the repository at this point in the history
  • Loading branch information
dantownsend committed Nov 12, 2019
1 parent 875bc84 commit d1595e6
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 21 deletions.
39 changes: 32 additions & 7 deletions piccolo_api/csrf/middleware.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import uuid
import typing as t

from starlette.datastructures import URL
from starlette.middleware.base import (
BaseHTTPMiddleware,
RequestResponseEndpoint,
Request,
)
from starlette.types import ASGIApp
from starlette.exceptions import HTTPException


Expand All @@ -23,9 +26,8 @@ class CSRFMiddleware(BaseHTTPMiddleware):
https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#double-submit-cookie
https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#use-of-custom-request-headers
For a good explanation on how SPAs mitigate CSRF:
https://angular.io/guide/security#xsrf
This is currently only intended for use using AJAX - since the CSRF token
needs to be added to the request header.
"""

cookie_name = "csrftoken"
Expand All @@ -35,9 +37,26 @@ class CSRFMiddleware(BaseHTTPMiddleware):
def get_new_token() -> str:
return str(uuid.uuid4())

def check_referer(self, request: Request):
# Prefer the origin header if available.
pass
def __init__(
self, app: ASGIApp, allowed_hosts: t.Iterable[str] = [], **kwargs
):
if not isinstance(allowed_hosts, list):
raise ValueError("allowed_hosts must be a list")

self.allowed_hosts = allowed_hosts
super().__init__(app, **kwargs)

def is_valid_referer(self, request: Request) -> bool:
header: str = (
request.headers.get("origin")
or request.headers.get("referer")
or ""
)

url = URL(header)
hostname = url.hostname
is_valid = hostname in self.allowed_hosts if hostname else False
return is_valid

async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
Expand All @@ -57,7 +76,13 @@ async def dispatch(
if cookie_token != header_token:
raise HTTPException(403, "CSRF tokens don't match")

# Provides defence in depth:
if request.base_url.is_secure:
self.check_referer(request)
# According to this paper, the referer header is present in
# the vast majority of HTTPS requests, but not HTTP requests,
# so only check it for HTTPS.
# https://seclab.stanford.edu/websec/csrf/csrf.pdf
if not self.is_valid_referer(request):
raise HTTPException(403, "Referer or origin is incorrect")

return await call_next(request)
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ Jinja2==2.10.1
piccolo>=0.5.2
pydantic==1.0
python-multipart==0.0.5
starlette>=0.12.3
starlette>=0.12.13
PyJWT==1.7.1
56 changes: 43 additions & 13 deletions tests/test_csrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ async def app(scope, receive, send):


WRAPPED_APP = ExceptionMiddleware(CSRFMiddleware(app))
HOST_RESTRICTED_APP = ExceptionMiddleware(
CSRFMiddleware(app, allowed_hosts=["foo.com"])
)


class TestCSRFMiddleware(TestCase):
Expand Down Expand Up @@ -98,22 +101,49 @@ def test_token_mismatch_rejected(self):

def test_referer_accepted(self):
"""
Make sure a post containing a CSRF cookie and matching token are
accepted.
Make sure that a correct referer or origin header is allowed.
"""
client = TestClient(WRAPPED_APP)
response = client.post(
"https://foo.com",
cookies={CSRFMiddleware.cookie_name: self.csrf_token},
headers={
CSRFMiddleware.header_name: self.csrf_token,
"referer": "https://foo.com",
},
)
self.assertTrue(response.status_code == 200)
cookies = {CSRFMiddleware.cookie_name: self.csrf_token}
base_headers = {CSRFMiddleware.header_name: self.csrf_token}

client = TestClient(HOST_RESTRICTED_APP)
valid_domain = "https://foo.com"

kwargs = [
{"referer": valid_domain},
{"referer": f"{valid_domain}/bar/"},
{"origin": valid_domain},
{"origin": valid_domain, "referer": valid_domain},
]

for _kwargs in kwargs:
response = client.post(
valid_domain,
cookies=cookies,
headers=dict(base_headers, **_kwargs),
)
self.assertTrue(response.status_code == 200)

def test_referer_rejected(self):
pass
"""
Make sure that an incorrect or missing referer / origin header isn't
allowed.
"""
cookies = {CSRFMiddleware.cookie_name: self.csrf_token}
base_headers = {CSRFMiddleware.header_name: self.csrf_token}

client = TestClient(HOST_RESTRICTED_APP)
invalid_domain = "https://bar.com"

kwargs = [{"referer": invalid_domain}, {"origin": invalid_domain}, {}]

for _kwargs in kwargs:
response = client.post(
"https://foo.com",
cookies=cookies,
headers=dict(base_headers, **_kwargs),
)
self.assertTrue(response.status_code == 403)


if __name__ == "__main__":
Expand Down

0 comments on commit d1595e6

Please sign in to comment.