diff --git a/README.md b/README.md index 728b721..96a1f38 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,8 @@ Your application code should include that value as a hidden form field in any PO Note that `request.scope["csrftoken"]` is a function that returns a string. Calling that function also lets the middleware know that the cookie should be set by that page, if the user does not already have that cookie. +If the cookie needs to be set, the middleware will add a `Vary: Cookie` header to the response to ensure it is not incorrectly cached by any CDNs or intermediary proxies. + The middleware will return a 403 forbidden error for any POST requests that do not include the matching `csrftoken` - either in the POST data or in a `x-csrftoken` HTTP header (useful for JavaScript `fetch()` calls). The `signing_secret` is used to sign the tokens, to protect against subdomain vulnerabilities. diff --git a/asgi_csrf.py b/asgi_csrf.py index 6729ccb..b754d6a 100644 --- a/asgi_csrf.py +++ b/asgi_csrf.py @@ -58,18 +58,31 @@ async def wrapped_send(event): if event["type"] == "http.response.start": if should_set_cookie: original_headers = event.get("headers") or [] - set_cookie_headers = [ + new_headers = [] + # Loop through original headers in case we need to modify "vary" + found_vary = False + for key, value in original_headers: + if key == b"vary": + found_vary = True + vary_bits = [v.strip() for v in value.split(b",")] + if b"Cookie" not in vary_bits: + vary_bits.append(b"Cookie") + value = b", ".join(vary_bits) + new_headers.append((key, value)) + if not found_vary: + new_headers.append((b"vary", b"Cookie")) + new_headers.append( ( b"set-cookie", "{}={}; Path=/".format(cookie_name, csrftoken).encode( "utf-8" ), ) - ] + ) event = { "type": "http.response.start", "status": event["status"], - "headers": original_headers + set_cookie_headers, + "headers": new_headers, } await send(event) diff --git a/test_asgi_csrf.py b/test_asgi_csrf.py index 73fe4ae..514fae6 100644 --- a/test_asgi_csrf.py +++ b/test_asgi_csrf.py @@ -16,7 +16,10 @@ async def hello_world(request): if request.method == "POST": data = await request.form() return JSONResponse(dict(await request.form())) - return JSONResponse({"hello": "world"}) + headers = {} + if "_vary" in request.query_params: + headers["Vary"] = request.query_params["_vary"] + return JSONResponse({"hello": "world"}, headers=headers) async def hello_world_static(request): @@ -55,6 +58,17 @@ async def test_asgi_csrf_sets_cookie(app_csrf): assert b'{"hello":"world"}' == response.content assert "csrftoken" in response.cookies assert response.headers["set-cookie"].endswith("; Path=/") + assert "Cookie" == response.headers["vary"] + + +@pytest.mark.asyncio +async def test_asgi_csrf_modifies_existing_vary_header(app_csrf): + async with httpx.AsyncClient(app=app_csrf) as client: + response = await client.get("http://localhost/?_vary=User-Agent") + assert b'{"hello":"world"}' == response.content + assert "csrftoken" in response.cookies + assert response.headers["set-cookie"].endswith("; Path=/") + assert "User-Agent, Cookie" == response.headers["vary"] @pytest.mark.asyncio