From 5d17f81054d9b753c117b342528ab41cc8b7f9f7 Mon Sep 17 00:00:00 2001 From: Bariq Nurlis Date: Sun, 8 Jan 2023 16:38:32 +0800 Subject: [PATCH] Implementation of `maybe_single` (#118) * add initial implementation on maybe_single Signed-off-by: Bariq * add sync maybe_single and fix error implementation Signed-off-by: Bariq * use relative import Signed-off-by: Bariq * implement new design for sync method Signed-off-by: Bariq * remove error from APIResponse Signed-off-by: Bariq * shift changes to async part Signed-off-by: Bariq * change class design to factory pattern Signed-off-by: Bariq * black and isort Signed-off-by: Bariq * fix: CI errors Signed-off-by: Bariq * fix tests and add additional test Signed-off-by: Bariq * fix new test Signed-off-by: Bariq * revamp class design Signed-off-by: Bariq * fix CI test Signed-off-by: Bariq * fix CI test 2 Signed-off-by: Bariq * fix unasync error and add typing Signed-off-by: Bariq * make tests for new methods Signed-off-by: Bariq * generate code and test for sync Signed-off-by: Bariq * fix docstring Signed-off-by: Bariq * fix docstring and remove unwanted changes Signed-off-by: Bariq * fix tests on CI Signed-off-by: Bariq * remove single ok tests Signed-off-by: Bariq Signed-off-by: Bariq --- postgrest/_async/request_builder.py | 108 ++++++++++++++++++++++++++- postgrest/_sync/client.py | 3 +- postgrest/_sync/request_builder.py | 108 ++++++++++++++++++++++++++- postgrest/base_request_builder.py | 43 ++++++++--- tests/_async/test_client.py | 43 ++++++++++- tests/_async/test_request_builder.py | 44 ++++++++++- tests/_sync/test_client.py | 43 ++++++++++- tests/_sync/test_request_builder.py | 44 ++++++++++- 8 files changed, 419 insertions(+), 17 deletions(-) diff --git a/postgrest/_async/request_builder.py b/postgrest/_async/request_builder.py index 67541f6..024e96f 100644 --- a/postgrest/_async/request_builder.py +++ b/postgrest/_async/request_builder.py @@ -10,6 +10,7 @@ BaseFilterRequestBuilder, BaseSelectRequestBuilder, CountMethod, + SingleAPIResponse, pre_delete, pre_insert, pre_select, @@ -57,13 +58,90 @@ async def execute(self) -> APIResponse: params=self.params, headers=self.headers, ) + try: + if ( + 200 <= r.status_code <= 299 + ): # Response.ok from JS (https://developer.mozilla.org/en-US/docs/Web/API/Response/ok) + return APIResponse.from_http_request_response(r) + else: + raise APIError(r.json()) + except ValidationError as e: + raise APIError(r.json()) from e + + +class AsyncSingleRequestBuilder: + def __init__( + self, + session: AsyncClient, + path: str, + http_method: str, + headers: Headers, + params: QueryParams, + json: dict, + ) -> None: + self.session = session + self.path = path + self.http_method = http_method + self.headers = headers + self.params = params + self.json = json + + async def execute(self) -> SingleAPIResponse: + """Execute the query. + + .. tip:: + This is the last method called, after the query is built. + + Returns: + :class:`SingleAPIResponse` + Raises: + :class:`APIError` If the API raised an error. + """ + r = await self.session.request( + self.http_method, + self.path, + json=self.json, + params=self.params, + headers=self.headers, + ) try: - return APIResponse.from_http_request_response(r) + if ( + 200 <= r.status_code <= 299 + ): # Response.ok from JS (https://developer.mozilla.org/en-US/docs/Web/API/Response/ok) + return SingleAPIResponse.from_http_request_response(r) + else: + raise APIError(r.json()) except ValidationError as e: raise APIError(r.json()) from e +class AsyncMaybeSingleRequestBuilder(AsyncSingleRequestBuilder): + async def execute(self) -> SingleAPIResponse: + r = None + try: + r = await super().execute() + except APIError as e: + if e.details and "Results contain 0 rows" in e.details: + return SingleAPIResponse.from_dict( + { + "data": None, + "error": None, + "count": 0, # NOTE: needs to take value from res.count + } + ) + if not r: + raise APIError( + { + "message": "Missing response", + "code": "204", + "hint": "Please check traceback of the code", + "details": "Postgrest couldn't retrieve response, please check traceback of the code. Please create an issue in `supabase-community/postgrest-py` if needed.", + } + ) + return r + + # ignoring type checking as a workaround for https://github.com/python/mypy/issues/9319 class AsyncFilterRequestBuilder(BaseFilterRequestBuilder, AsyncQueryRequestBuilder): # type: ignore def __init__( @@ -97,6 +175,34 @@ def __init__( self, session, path, http_method, headers, params, json ) + def single(self) -> AsyncSingleRequestBuilder: + """Specify that the query will only return a single row in response. + + .. caution:: + The API will raise an error if the query returned more than one row. + """ + self.headers["Accept"] = "application/vnd.pgrst.object+json" + return AsyncSingleRequestBuilder( + headers=self.headers, + http_method=self.http_method, + json=self.json, + params=self.params, + path=self.path, + session=self.session, # type: ignore + ) + + def maybe_single(self) -> AsyncMaybeSingleRequestBuilder: + """Retrieves at most one row from the result. Result must be at most one row (e.g. using `eq` on a UNIQUE column), otherwise this will result in an error.""" + self.headers["Accept"] = "application/vnd.pgrst.object+json" + return AsyncMaybeSingleRequestBuilder( + headers=self.headers, + http_method=self.http_method, + json=self.json, + params=self.params, + path=self.path, + session=self.session, # type: ignore + ) + class AsyncRequestBuilder: def __init__(self, session: AsyncClient, path: str) -> None: diff --git a/postgrest/_sync/client.py b/postgrest/_sync/client.py index 20906bb..3ab67cd 100644 --- a/postgrest/_sync/client.py +++ b/postgrest/_sync/client.py @@ -68,7 +68,7 @@ def from_(self, table: str) -> SyncRequestBuilder: return SyncRequestBuilder(self.session, f"/{table}") def table(self, table: str) -> SyncRequestBuilder: - """Alias to :meth:`from_`.""" + """Alias to self.from_().""" return self.from_(table) @deprecated("0.2.0", "1.0.0", __version__, "Use self.from_() instead") @@ -86,6 +86,7 @@ def rpc(self, func: str, params: dict) -> SyncFilterRequestBuilder: :class:`SyncFilterRequestBuilder` Example: :: + await client.rpc("foobar", {"arg": "value"}).execute() .. versionchanged:: 0.11.0 diff --git a/postgrest/_sync/request_builder.py b/postgrest/_sync/request_builder.py index b578765..93d4a3f 100644 --- a/postgrest/_sync/request_builder.py +++ b/postgrest/_sync/request_builder.py @@ -10,6 +10,7 @@ BaseFilterRequestBuilder, BaseSelectRequestBuilder, CountMethod, + SingleAPIResponse, pre_delete, pre_insert, pre_select, @@ -57,13 +58,90 @@ def execute(self) -> APIResponse: params=self.params, headers=self.headers, ) + try: + if ( + 200 <= r.status_code <= 299 + ): # Response.ok from JS (https://developer.mozilla.org/en-US/docs/Web/API/Response/ok) + return APIResponse.from_http_request_response(r) + else: + raise APIError(r.json()) + except ValidationError as e: + raise APIError(r.json()) from e + + +class SyncSingleRequestBuilder: + def __init__( + self, + session: SyncClient, + path: str, + http_method: str, + headers: Headers, + params: QueryParams, + json: dict, + ) -> None: + self.session = session + self.path = path + self.http_method = http_method + self.headers = headers + self.params = params + self.json = json + + def execute(self) -> SingleAPIResponse: + """Execute the query. + + .. tip:: + This is the last method called, after the query is built. + + Returns: + :class:`SingleAPIResponse` + Raises: + :class:`APIError` If the API raised an error. + """ + r = self.session.request( + self.http_method, + self.path, + json=self.json, + params=self.params, + headers=self.headers, + ) try: - return APIResponse.from_http_request_response(r) + if ( + 200 <= r.status_code <= 299 + ): # Response.ok from JS (https://developer.mozilla.org/en-US/docs/Web/API/Response/ok) + return SingleAPIResponse.from_http_request_response(r) + else: + raise APIError(r.json()) except ValidationError as e: raise APIError(r.json()) from e +class SyncMaybeSingleRequestBuilder(SyncSingleRequestBuilder): + def execute(self) -> SingleAPIResponse: + r = None + try: + r = super().execute() + except APIError as e: + if e.details and "Results contain 0 rows" in e.details: + return SingleAPIResponse.from_dict( + { + "data": None, + "error": None, + "count": 0, # NOTE: needs to take value from res.count + } + ) + if not r: + raise APIError( + { + "message": "Missing response", + "code": "204", + "hint": "Please check traceback of the code", + "details": "Postgrest couldn't retrieve response, please check traceback of the code. Please create an issue in `supabase-community/postgrest-py` if needed.", + } + ) + return r + + # ignoring type checking as a workaround for https://github.com/python/mypy/issues/9319 class SyncFilterRequestBuilder(BaseFilterRequestBuilder, SyncQueryRequestBuilder): # type: ignore def __init__( @@ -97,6 +175,34 @@ def __init__( self, session, path, http_method, headers, params, json ) + def single(self) -> SyncSingleRequestBuilder: + """Specify that the query will only return a single row in response. + + .. caution:: + The API will raise an error if the query returned more than one row. + """ + self.headers["Accept"] = "application/vnd.pgrst.object+json" + return SyncSingleRequestBuilder( + headers=self.headers, + http_method=self.http_method, + json=self.json, + params=self.params, + path=self.path, + session=self.session, # type: ignore + ) + + def maybe_single(self) -> SyncMaybeSingleRequestBuilder: + """Retrieves at most one row from the result. Result must be at most one row (e.g. using `eq` on a UNIQUE column), otherwise this will result in an error.""" + self.headers["Accept"] = "application/vnd.pgrst.object+json" + return SyncMaybeSingleRequestBuilder( + headers=self.headers, + http_method=self.http_method, + json=self.json, + params=self.params, + path=self.path, + session=self.session, # type: ignore + ) + class SyncRequestBuilder: def __init__(self, session: SyncClient, path: str) -> None: diff --git a/postgrest/base_request_builder.py b/postgrest/base_request_builder.py index 277cb9e..5fa5fae 100644 --- a/postgrest/base_request_builder.py +++ b/postgrest/base_request_builder.py @@ -6,6 +6,7 @@ Any, Dict, Iterable, + List, NamedTuple, Optional, Tuple, @@ -105,7 +106,7 @@ def pre_delete( class APIResponse(BaseModel): - data: Any + data: List[Dict[str, Any]] """The data returned by the query.""" count: Optional[int] = None """The number of rows returned.""" @@ -155,6 +156,37 @@ def from_http_request_response( count = cls._get_count_from_http_request_response(request_response) return cls(data=data, count=count) + @classmethod + def from_dict(cls: Type[APIResponse], dict: Dict[str, Any]) -> APIResponse: + keys = dict.keys() + assert len(keys) == 3 and "data" in keys and "count" in keys and "error" in keys + return cls( + data=dict.get("data"), count=dict.get("count"), error=dict.get("error") + ) + + +class SingleAPIResponse(APIResponse): + data: Dict[str, Any] # type: ignore + """The data returned by the query.""" + + @classmethod + def from_http_request_response( + cls: Type[SingleAPIResponse], request_response: RequestResponse + ) -> SingleAPIResponse: + data = request_response.json() + count = cls._get_count_from_http_request_response(request_response) + return cls(data=data, count=count) + + @classmethod + def from_dict( + cls: Type[SingleAPIResponse], dict: Dict[str, Any] + ) -> SingleAPIResponse: + keys = dict.keys() + assert len(keys) == 3 and "data" in keys and "count" in keys and "error" in keys + return cls( + data=dict.get("data"), count=dict.get("count"), error=dict.get("error") + ) + _FilterT = TypeVar("_FilterT", bound="BaseFilterRequestBuilder") @@ -410,12 +442,3 @@ def range(self: _FilterT, start: int, end: int) -> _FilterT: self.headers["Range-Unit"] = "items" self.headers["Range"] = f"{start}-{end - 1}" return self - - def single(self: _FilterT) -> _FilterT: - """Specify that the query will only return a single row in response. - - .. caution:: - The API will raise an error if the query returned more than one row. - """ - self.headers["Accept"] = "application/vnd.pgrst.object+json" - return self diff --git a/tests/_async/test_client.py b/tests/_async/test_client.py index 4e06ee8..33b59a2 100644 --- a/tests/_async/test_client.py +++ b/tests/_async/test_client.py @@ -1,7 +1,10 @@ +from unittest.mock import patch + import pytest from httpx import BasicAuth, Headers from postgrest import AsyncPostgrestClient +from postgrest.exceptions import APIError @pytest.fixture @@ -72,5 +75,43 @@ def test_schema(postgrest_client: AsyncPostgrestClient): @pytest.mark.asyncio async def test_params_purged_after_execute(postgrest_client: AsyncPostgrestClient): assert len(postgrest_client.session.params) == 0 - await postgrest_client.from_("test").select("a", "b").eq("c", "d").execute() + with pytest.raises(APIError): + await postgrest_client.from_("test").select("a", "b").eq("c", "d").execute() assert len(postgrest_client.session.params) == 0 + + +@pytest.mark.asyncio +async def test_response_status_code_outside_ok(postgrest_client: AsyncPostgrestClient): + with pytest.raises(APIError) as exc_info: + await postgrest_client.from_("test").select("a", "b").eq( + "c", "d" + ).execute() # gives status_code = 400 + exc_response = exc_info.value.json() + assert not exc_response.get("success") + assert isinstance(exc_response.get("errors"), list) + assert ( + isinstance(exc_response["errors"][0], dict) + and "code" in exc_response["errors"][0] + ) + assert exc_response["errors"][0].get("code") == 400 + + +@pytest.mark.asyncio +async def test_response_maybe_single(postgrest_client: AsyncPostgrestClient): + with patch( + "postgrest._async.request_builder.AsyncSingleRequestBuilder.execute", + side_effect=APIError( + {"message": "mock error", "code": "400", "hint": "mock", "details": "mock"} + ), + ): + client = ( + postgrest_client.from_("test").select("a", "b").eq("c", "d").maybe_single() + ) + assert "Accept" in client.headers + assert client.headers.get("Accept") == "application/vnd.pgrst.object+json" + with pytest.raises(APIError) as exc_info: + await client.execute() + assert isinstance(exc_info, pytest.ExceptionInfo) + exc_response = exc_info.value.json() + assert isinstance(exc_response.get("message"), str) + assert "code" in exc_response and int(exc_response["code"]) == 204 diff --git a/tests/_async/test_request_builder.py b/tests/_async/test_request_builder.py index 0c6434f..5aa67ab 100644 --- a/tests/_async/test_request_builder.py +++ b/tests/_async/test_request_builder.py @@ -4,7 +4,7 @@ from httpx import Request, Response from postgrest import AsyncRequestBuilder -from postgrest.base_request_builder import APIResponse +from postgrest.base_request_builder import APIResponse, SingleAPIResponse from postgrest.types import CountMethod from postgrest.utils import AsyncClient @@ -145,6 +145,18 @@ def api_response() -> List[Dict[str, Any]]: ] +@pytest.fixture +def single_api_response() -> Dict[str, Any]: + return { + "id": 1, + "name": "Bonaire, Sint Eustatius and Saba", + "iso2": "BQ", + "iso3": "BES", + "local_name": None, + "continent": None, + } + + @pytest.fixture def content_range_header_with_count() -> str: return "0-1/2" @@ -219,6 +231,24 @@ def request_response_with_data( ) +@pytest.fixture +def request_response_with_single_data( + prefer_header_with_count: str, + content_range_header_with_count: str, + single_api_response: Dict[str, Any], +) -> Response: + return Response( + status_code=200, + headers={"content-range": content_range_header_with_count}, + json=single_api_response, + request=Request( + method="GET", + url="http://example.com", + headers={"prefer": prefer_header_with_count}, + ), + ) + + class TestApiResponse: def test_response_raises_when_api_error( self, api_response_with_error: Dict[str, Any] @@ -300,3 +330,15 @@ def test_from_http_request_response_constructor( result = APIResponse.from_http_request_response(request_response_with_data) assert result.data == api_response assert result.count == 2 + + def test_single_from_http_request_response_constructor( + self, + request_response_with_single_data: Response, + single_api_response: Dict[str, Any], + ): + result = SingleAPIResponse.from_http_request_response( + request_response_with_single_data + ) + assert isinstance(result.data, dict) + assert result.data == single_api_response + assert result.count == 2 diff --git a/tests/_sync/test_client.py b/tests/_sync/test_client.py index bfd2c3e..daa0c81 100644 --- a/tests/_sync/test_client.py +++ b/tests/_sync/test_client.py @@ -1,7 +1,10 @@ +from unittest.mock import patch + import pytest from httpx import BasicAuth, Headers from postgrest import SyncPostgrestClient +from postgrest.exceptions import APIError @pytest.fixture @@ -72,5 +75,43 @@ def test_schema(postgrest_client: SyncPostgrestClient): @pytest.mark.asyncio def test_params_purged_after_execute(postgrest_client: SyncPostgrestClient): assert len(postgrest_client.session.params) == 0 - postgrest_client.from_("test").select("a", "b").eq("c", "d").execute() + with pytest.raises(APIError): + postgrest_client.from_("test").select("a", "b").eq("c", "d").execute() assert len(postgrest_client.session.params) == 0 + + +@pytest.mark.asyncio +def test_response_status_code_outside_ok(postgrest_client: SyncPostgrestClient): + with pytest.raises(APIError) as exc_info: + postgrest_client.from_("test").select("a", "b").eq( + "c", "d" + ).execute() # gives status_code = 400 + exc_response = exc_info.value.json() + assert not exc_response.get("success") + assert isinstance(exc_response.get("errors"), list) + assert ( + isinstance(exc_response["errors"][0], dict) + and "code" in exc_response["errors"][0] + ) + assert exc_response["errors"][0].get("code") == 400 + + +@pytest.mark.asyncio +def test_response_maybe_single(postgrest_client: SyncPostgrestClient): + with patch( + "postgrest._sync.request_builder.SyncSingleRequestBuilder.execute", + side_effect=APIError( + {"message": "mock error", "code": "400", "hint": "mock", "details": "mock"} + ), + ): + client = ( + postgrest_client.from_("test").select("a", "b").eq("c", "d").maybe_single() + ) + assert "Accept" in client.headers + assert client.headers.get("Accept") == "application/vnd.pgrst.object+json" + with pytest.raises(APIError) as exc_info: + client.execute() + assert isinstance(exc_info, pytest.ExceptionInfo) + exc_response = exc_info.value.json() + assert isinstance(exc_response.get("message"), str) + assert "code" in exc_response and int(exc_response["code"]) == 204 diff --git a/tests/_sync/test_request_builder.py b/tests/_sync/test_request_builder.py index 0114588..96299ea 100644 --- a/tests/_sync/test_request_builder.py +++ b/tests/_sync/test_request_builder.py @@ -4,7 +4,7 @@ from httpx import Request, Response from postgrest import SyncRequestBuilder -from postgrest.base_request_builder import APIResponse +from postgrest.base_request_builder import APIResponse, SingleAPIResponse from postgrest.types import CountMethod from postgrest.utils import SyncClient @@ -145,6 +145,18 @@ def api_response() -> List[Dict[str, Any]]: ] +@pytest.fixture +def single_api_response() -> Dict[str, Any]: + return { + "id": 1, + "name": "Bonaire, Sint Eustatius and Saba", + "iso2": "BQ", + "iso3": "BES", + "local_name": None, + "continent": None, + } + + @pytest.fixture def content_range_header_with_count() -> str: return "0-1/2" @@ -219,6 +231,24 @@ def request_response_with_data( ) +@pytest.fixture +def request_response_with_single_data( + prefer_header_with_count: str, + content_range_header_with_count: str, + single_api_response: Dict[str, Any], +) -> Response: + return Response( + status_code=200, + headers={"content-range": content_range_header_with_count}, + json=single_api_response, + request=Request( + method="GET", + url="http://example.com", + headers={"prefer": prefer_header_with_count}, + ), + ) + + class TestApiResponse: def test_response_raises_when_api_error( self, api_response_with_error: Dict[str, Any] @@ -300,3 +330,15 @@ def test_from_http_request_response_constructor( result = APIResponse.from_http_request_response(request_response_with_data) assert result.data == api_response assert result.count == 2 + + def test_single_from_http_request_response_constructor( + self, + request_response_with_single_data: Response, + single_api_response: Dict[str, Any], + ): + result = SingleAPIResponse.from_http_request_response( + request_response_with_single_data + ) + assert isinstance(result.data, dict) + assert result.data == single_api_response + assert result.count == 2