diff --git a/tartiflette_asgi/endpoints.py b/tartiflette_asgi/endpoints.py index 93f46b4..7966b4a 100644 --- a/tartiflette_asgi/endpoints.py +++ b/tartiflette_asgi/endpoints.py @@ -1,4 +1,6 @@ import typing +import json + from starlette.background import BackgroundTasks from starlette.datastructures import QueryParams @@ -37,7 +39,10 @@ async def post(self, request: Request) -> Response: content_type = request.headers.get("Content-Type", "") if "application/json" in content_type: - data = await request.json() + try: + data = await request.json() + except json.JSONDecodeError: + return JSONResponse({"error": "Invalid JSON."}, 400) elif "application/graphql" in content_type: body = await request.body() data = {"query": body.decode()} diff --git a/tests/test_graphql_api.py b/tests/test_graphql_api.py index a85b413..5cb46a7 100644 --- a/tests/test_graphql_api.py +++ b/tests/test_graphql_api.py @@ -27,6 +27,14 @@ def test_post_json(client: TestClient): assert response.json() == {"data": {"hello": "Hello stranger"}} +def test_post_invalid_json(client: TestClient): + response = client.post( + "/", data="{test", headers={"content-type": "application/json"} + ) + assert response.status_code == 400 + assert response.json() == {"error": "Invalid JSON."} + + def test_post_graphql(client: TestClient): response = client.post( "/", data="{ hello }", headers={"content-type": "application/graphql"}