diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d6f1c12..10c035e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 + rev: v4.1.0 hooks: - id: trailing-whitespace - id: check-added-large-files @@ -35,18 +35,19 @@ repos: ] - repo: https://github.com/ambv/black - rev: 21.11b1 + rev: 21.12b0 hooks: - id: black + args: [--line-length, "90"] - repo: https://github.com/asottile/pyupgrade - rev: v2.29.1 + rev: v2.31.0 hooks: - id: pyupgrade args: ["--py37-plus", "--keep-runtime-typing"] - repo: https://github.com/commitizen-tools/commitizen - rev: v2.20.0 + rev: v2.20.3 hooks: - id: commitizen stages: [commit-msg] diff --git a/poetry.lock b/poetry.lock index 496ee83..32c3b75 100644 --- a/poetry.lock +++ b/poetry.lock @@ -606,6 +606,21 @@ category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +[[package]] +name = "pydantic" +version = "1.9.0" +description = "Data validation and settings management using python 3.6 type hinting" +category = "main" +optional = false +python-versions = ">=3.6.1" + +[package.dependencies] +typing-extensions = ">=3.7.4.3" + +[package.extras] +dotenv = ["python-dotenv (>=0.10.4)"] +email = ["email-validator (>=1.0.3)"] + [[package]] name = "pyflakes" version = "2.4.0" @@ -1546,6 +1561,43 @@ pycparser = [ {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"}, {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, ] +pydantic = [ + {file = "pydantic-1.9.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:cb23bcc093697cdea2708baae4f9ba0e972960a835af22560f6ae4e7e47d33f5"}, + {file = "pydantic-1.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1d5278bd9f0eee04a44c712982343103bba63507480bfd2fc2790fa70cd64cf4"}, + {file = "pydantic-1.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab624700dc145aa809e6f3ec93fb8e7d0f99d9023b713f6a953637429b437d37"}, + {file = "pydantic-1.9.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c8d7da6f1c1049eefb718d43d99ad73100c958a5367d30b9321b092771e96c25"}, + {file = "pydantic-1.9.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:3c3b035103bd4e2e4a28da9da7ef2fa47b00ee4a9cf4f1a735214c1bcd05e0f6"}, + {file = "pydantic-1.9.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:3011b975c973819883842c5ab925a4e4298dffccf7782c55ec3580ed17dc464c"}, + {file = "pydantic-1.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:086254884d10d3ba16da0588604ffdc5aab3f7f09557b998373e885c690dd398"}, + {file = "pydantic-1.9.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:0fe476769acaa7fcddd17cadd172b156b53546ec3614a4d880e5d29ea5fbce65"}, + {file = "pydantic-1.9.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8e9dcf1ac499679aceedac7e7ca6d8641f0193c591a2d090282aaf8e9445a46"}, + {file = "pydantic-1.9.0-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1e4c28f30e767fd07f2ddc6f74f41f034d1dd6bc526cd59e63a82fe8bb9ef4c"}, + {file = "pydantic-1.9.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:c86229333cabaaa8c51cf971496f10318c4734cf7b641f08af0a6fbf17ca3054"}, + {file = "pydantic-1.9.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:c0727bda6e38144d464daec31dff936a82917f431d9c39c39c60a26567eae3ed"}, + {file = "pydantic-1.9.0-cp36-cp36m-win_amd64.whl", hash = "sha256:dee5ef83a76ac31ab0c78c10bd7d5437bfdb6358c95b91f1ba7ff7b76f9996a1"}, + {file = "pydantic-1.9.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:d9c9bdb3af48e242838f9f6e6127de9be7063aad17b32215ccc36a09c5cf1070"}, + {file = "pydantic-1.9.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ee7e3209db1e468341ef41fe263eb655f67f5c5a76c924044314e139a1103a2"}, + {file = "pydantic-1.9.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0b6037175234850ffd094ca77bf60fb54b08b5b22bc85865331dd3bda7a02fa1"}, + {file = "pydantic-1.9.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b2571db88c636d862b35090ccf92bf24004393f85c8870a37f42d9f23d13e032"}, + {file = "pydantic-1.9.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8b5ac0f1c83d31b324e57a273da59197c83d1bb18171e512908fe5dc7278a1d6"}, + {file = "pydantic-1.9.0-cp37-cp37m-win_amd64.whl", hash = "sha256:bbbc94d0c94dd80b3340fc4f04fd4d701f4b038ebad72c39693c794fd3bc2d9d"}, + {file = "pydantic-1.9.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e0896200b6a40197405af18828da49f067c2fa1f821491bc8f5bde241ef3f7d7"}, + {file = "pydantic-1.9.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7bdfdadb5994b44bd5579cfa7c9b0e1b0e540c952d56f627eb227851cda9db77"}, + {file = "pydantic-1.9.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:574936363cd4b9eed8acdd6b80d0143162f2eb654d96cb3a8ee91d3e64bf4cf9"}, + {file = "pydantic-1.9.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c556695b699f648c58373b542534308922c46a1cda06ea47bc9ca45ef5b39ae6"}, + {file = "pydantic-1.9.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:f947352c3434e8b937e3aa8f96f47bdfe6d92779e44bb3f41e4c213ba6a32145"}, + {file = "pydantic-1.9.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5e48ef4a8b8c066c4a31409d91d7ca372a774d0212da2787c0d32f8045b1e034"}, + {file = "pydantic-1.9.0-cp38-cp38-win_amd64.whl", hash = "sha256:96f240bce182ca7fe045c76bcebfa0b0534a1bf402ed05914a6f1dadff91877f"}, + {file = "pydantic-1.9.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:815ddebb2792efd4bba5488bc8fde09c29e8ca3227d27cf1c6990fc830fd292b"}, + {file = "pydantic-1.9.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6c5b77947b9e85a54848343928b597b4f74fc364b70926b3c4441ff52620640c"}, + {file = "pydantic-1.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c68c3bc88dbda2a6805e9a142ce84782d3930f8fdd9655430d8576315ad97ce"}, + {file = "pydantic-1.9.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5a79330f8571faf71bf93667d3ee054609816f10a259a109a0738dac983b23c3"}, + {file = "pydantic-1.9.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f5a64b64ddf4c99fe201ac2724daada8595ada0d102ab96d019c1555c2d6441d"}, + {file = "pydantic-1.9.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a733965f1a2b4090a5238d40d983dcd78f3ecea221c7af1497b845a9709c1721"}, + {file = "pydantic-1.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:2cc6a4cb8a118ffec2ca5fcb47afbacb4f16d0ab8b7350ddea5e8ef7bcc53a16"}, + {file = "pydantic-1.9.0-py3-none-any.whl", hash = "sha256:085ca1de245782e9b46cefcf99deecc67d418737a1fd3f6a4f511344b613a5b3"}, + {file = "pydantic-1.9.0.tar.gz", hash = "sha256:742645059757a56ecd886faf4ed2441b9c0cd406079c2b4bee51bcc3fbcd510a"}, +] pyflakes = [ {file = "pyflakes-2.4.0-py2.py3-none-any.whl", hash = "sha256:3bb3a3f256f4b7968c9c788781e4ff07dce46bdf12339dcda61053375426ee2e"}, {file = "pyflakes-2.4.0.tar.gz", hash = "sha256:05a85c2872edf37a4ed30b0cce2f6093e1d0581f8c19d7393122da7e25b2b24c"}, diff --git a/postgrest_py/_async/client.py b/postgrest_py/_async/client.py index 6cc61d3..99cf4ad 100644 --- a/postgrest_py/_async/client.py +++ b/postgrest_py/_async/client.py @@ -47,7 +47,7 @@ def create_session( timeout=timeout, ) - async def __aenter__(self) -> "AsyncPostgrestClient": + async def __aenter__(self) -> AsyncPostgrestClient: return self async def __aexit__(self, exc_type, exc, tb) -> None: diff --git a/postgrest_py/_async/request_builder.py b/postgrest_py/_async/request_builder.py index 1bc88e5..9470260 100644 --- a/postgrest_py/_async/request_builder.py +++ b/postgrest_py/_async/request_builder.py @@ -1,8 +1,11 @@ from __future__ import annotations -from typing import Any, Optional, Tuple +from typing import Optional + +from pydantic import ValidationError from ..base_request_builder import ( + APIResponse, BaseFilterRequestBuilder, BaseSelectRequestBuilder, CountMethod, @@ -11,8 +14,8 @@ pre_select, pre_update, pre_upsert, - process_response, ) +from ..exceptions import APIError from ..types import ReturnMethod from ..utils import AsyncClient @@ -30,16 +33,20 @@ def __init__( self.http_method = http_method self.json = json - async def execute(self) -> Tuple[Any, Optional[int]]: + async def execute(self) -> APIResponse: r = await self.session.request( self.http_method, self.path, json=self.json, ) - return process_response(self.session, r) + try: + return APIResponse.from_http_request_response(r) + except ValidationError as e: + raise APIError(r.json()) from e -class AsyncFilterRequestBuilder(BaseFilterRequestBuilder, AsyncQueryRequestBuilder): +# ignoring type checking as a workaround for https://github.com/python/mypy/issues/9319 +class AsyncFilterRequestBuilder(BaseFilterRequestBuilder, AsyncQueryRequestBuilder): # type: ignore def __init__( self, session: AsyncClient, @@ -51,7 +58,8 @@ def __init__( AsyncQueryRequestBuilder.__init__(self, session, path, http_method, json) -class AsyncSelectRequestBuilder(BaseSelectRequestBuilder, AsyncQueryRequestBuilder): +# ignoring type checking as a workaround for https://github.com/python/mypy/issues/9319 +class AsyncSelectRequestBuilder(BaseSelectRequestBuilder, AsyncQueryRequestBuilder): # type: ignore def __init__( self, session: AsyncClient, @@ -73,7 +81,7 @@ def select( *columns: str, count: Optional[CountMethod] = None, ) -> AsyncSelectRequestBuilder: - method, json = pre_select(self.session, self.path, *columns, count=count) + method, json = pre_select(self.session, *columns, count=count) return AsyncSelectRequestBuilder(self.session, self.path, method, json) def insert( @@ -86,7 +94,6 @@ def insert( ) -> AsyncQueryRequestBuilder: method, json = pre_insert( self.session, - self.path, json, count=count, returning=returning, @@ -104,7 +111,6 @@ def upsert( ) -> AsyncQueryRequestBuilder: method, json = pre_upsert( self.session, - self.path, json, count=count, returning=returning, @@ -121,7 +127,6 @@ def update( ) -> AsyncFilterRequestBuilder: method, json = pre_update( self.session, - self.path, json, count=count, returning=returning, @@ -136,7 +141,6 @@ def delete( ) -> AsyncFilterRequestBuilder: method, json = pre_delete( self.session, - self.path, count=count, returning=returning, ) diff --git a/postgrest_py/_sync/request_builder.py b/postgrest_py/_sync/request_builder.py index 8343ed3..1055366 100644 --- a/postgrest_py/_sync/request_builder.py +++ b/postgrest_py/_sync/request_builder.py @@ -1,8 +1,9 @@ from __future__ import annotations -from typing import Any, Optional, Tuple +from typing import Optional from ..base_request_builder import ( + APIResponse, BaseFilterRequestBuilder, BaseSelectRequestBuilder, CountMethod, @@ -11,8 +12,8 @@ pre_select, pre_update, pre_upsert, - process_response, ) +from ..exceptions import APIError from ..types import ReturnMethod from ..utils import SyncClient @@ -30,16 +31,20 @@ def __init__( self.http_method = http_method self.json = json - def execute(self) -> Tuple[Any, Optional[int]]: + def execute(self) -> APIResponse: r = self.session.request( self.http_method, self.path, json=self.json, ) - return process_response(self.session, r) + try: + return APIResponse.from_http_request_response(r) + except ValueError as e: + raise APIError(r.json()) from e -class SyncFilterRequestBuilder(BaseFilterRequestBuilder, SyncQueryRequestBuilder): +# ignoring type checking as a workaround for https://github.com/python/mypy/issues/9319 +class SyncFilterRequestBuilder(BaseFilterRequestBuilder, SyncQueryRequestBuilder): # type: ignore def __init__( self, session: SyncClient, @@ -51,7 +56,8 @@ def __init__( SyncQueryRequestBuilder.__init__(self, session, path, http_method, json) -class SyncSelectRequestBuilder(BaseSelectRequestBuilder, SyncQueryRequestBuilder): +# ignoring type checking as a workaround for https://github.com/python/mypy/issues/9319 +class SyncSelectRequestBuilder(BaseSelectRequestBuilder, SyncQueryRequestBuilder): # type: ignore def __init__( self, session: SyncClient, @@ -73,7 +79,7 @@ def select( *columns: str, count: Optional[CountMethod] = None, ) -> SyncSelectRequestBuilder: - method, json = pre_select(self.session, self.path, *columns, count=count) + method, json = pre_select(self.session, *columns, count=count) return SyncSelectRequestBuilder(self.session, self.path, method, json) def insert( @@ -86,7 +92,6 @@ def insert( ) -> SyncQueryRequestBuilder: method, json = pre_insert( self.session, - self.path, json, count=count, returning=returning, @@ -104,7 +109,6 @@ def upsert( ) -> SyncQueryRequestBuilder: method, json = pre_upsert( self.session, - self.path, json, count=count, returning=returning, @@ -121,7 +125,6 @@ def update( ) -> SyncFilterRequestBuilder: method, json = pre_update( self.session, - self.path, json, count=count, returning=returning, @@ -136,7 +139,6 @@ def delete( ) -> SyncFilterRequestBuilder: method, json = pre_delete( self.session, - self.path, count=count, returning=returning, ) diff --git a/postgrest_py/base_request_builder.py b/postgrest_py/base_request_builder.py index 47c65cb..5853c0c 100644 --- a/postgrest_py/base_request_builder.py +++ b/postgrest_py/base_request_builder.py @@ -1,9 +1,10 @@ from __future__ import annotations from re import search -from typing import Any, Dict, Iterable, Optional, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Tuple, Type, Union -from httpx import Response +from httpx import Response as RequestResponse +from pydantic import BaseModel, validator from .types import CountMethod, Filters, RequestMethod, ReturnMethod from .utils import AsyncClient, SyncClient, sanitize_param, sanitize_pattern_param @@ -11,7 +12,6 @@ def pre_select( session: Union[AsyncClient, SyncClient], - path: str, *columns: str, count: Optional[CountMethod] = None, ) -> Tuple[RequestMethod, dict]: @@ -27,7 +27,6 @@ def pre_select( def pre_insert( session: Union[AsyncClient, SyncClient], - path: str, json: dict, *, count: Optional[CountMethod], @@ -45,14 +44,13 @@ def pre_insert( def pre_upsert( session: Union[AsyncClient, SyncClient], - path: str, json: dict, *, count: Optional[CountMethod], returning: ReturnMethod, ignore_duplicates: bool, ) -> Tuple[RequestMethod, dict]: - prefer_headers = ["return=representation"] + prefer_headers = [f"return={returning}"] if count: prefer_headers.append(f"count={count}") resolution = "ignore" if ignore_duplicates else "merge" @@ -63,7 +61,6 @@ def pre_upsert( def pre_update( session: Union[AsyncClient, SyncClient], - path: str, json: dict, *, count: Optional[CountMethod], @@ -78,7 +75,6 @@ def pre_update( def pre_delete( session: Union[AsyncClient, SyncClient], - path: str, *, count: Optional[CountMethod], returning: ReturnMethod, @@ -90,21 +86,54 @@ def pre_delete( return RequestMethod.DELETE, {} -def process_response( - session: Union[AsyncClient, SyncClient], - r: Response, -) -> Tuple[Any, Optional[int]]: - count = None - prefer_header = session.headers.get("prefer") - if prefer_header: +class APIResponse(BaseModel): + data: Any + count: Optional[int] = None + + @validator("data") + @classmethod + def raise_when_api_error(cls: Type[APIResponse], value: Any) -> Any: + if isinstance(value, dict) and value.get("message"): + raise ValueError("You are passing an API error to the data field.") + return value + + @staticmethod + def _get_count_from_content_range_header( + content_range_header: str, + ) -> Optional[int]: + content_range = content_range_header.split("/") + if len(content_range) < 2: + return None + return int(content_range[1]) + + @staticmethod + def _is_count_in_prefer_header(prefer_header: str) -> bool: pattern = f"count=({'|'.join([cm.value for cm in CountMethod])})" - count_header_match = search(pattern, prefer_header) - content_range_header = r.headers.get("content-range") - if count_header_match and content_range_header: - content_range = content_range_header.split("/") - if len(content_range) >= 2: - count = int(content_range[1]) - return r.json(), count + return bool(search(pattern, prefer_header)) + + @classmethod + def _get_count_from_http_request_response( + cls: Type[APIResponse], + request_response: RequestResponse, + ) -> Optional[int]: + prefer_header: Optional[str] = request_response.request.headers.get("prefer") + if not prefer_header: + return None + is_count_in_prefer_header = cls._is_count_in_prefer_header(prefer_header) + content_range_header: Optional[str] = request_response.headers.get( + "content-range" + ) + if not (is_count_in_prefer_header and content_range_header): + return None + return cls._get_count_from_content_range_header(content_range_header) + + @classmethod + def from_http_request_response( + cls: Type[APIResponse], request_response: RequestResponse + ) -> APIResponse: + data = request_response.json() + count = cls._get_count_from_http_request_response(request_response) + return cls(data=data, count=count) class BaseFilterRequestBuilder: diff --git a/postgrest_py/exceptions.py b/postgrest_py/exceptions.py new file mode 100644 index 0000000..db774f8 --- /dev/null +++ b/postgrest_py/exceptions.py @@ -0,0 +1,32 @@ +from typing import Dict + + +class APIError(Exception): + """ + Base exception for all API errors. + """ + + _raw_error: Dict[str, str] + message: str + code: str + hint: str + details: str + + def __init__(self, error: Dict[str, str]) -> None: + self._raw_error = error + self.message = error["message"] + self.code = error["code"] + self.hint = error["hint"] + self.details = error["details"] + Exception.__init__(self, str(self)) + + def __repr__(self) -> str: + error_text = f"Error {self.code}:" if self.code else "" + message_text = f"\nMessage: {self.message}" if self.message else "" + hint_text = f"\nHint: {self.hint}" if self.hint else "" + details_text = f"\nDetails: {self.details}" if self.details else "" + complete_error_text = f"{error_text}{message_text}{hint_text}{details_text}" + return complete_error_text or "Empty error" + + def json(self) -> Dict[str, str]: + return self._raw_error diff --git a/pyproject.toml b/pyproject.toml index 79a22f9..8a0e49f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ classifiers = [ python = "^3.7" httpx = ">=0.20,<0.23" deprecation = "^2.1.0" +pydantic = "^1.9.0" [tool.poetry.dev-dependencies] pytest = "^6.2.5" diff --git a/tests/_async/test_request_builder.py b/tests/_async/test_request_builder.py index 85339b0..78c4f8a 100644 --- a/tests/_async/test_request_builder.py +++ b/tests/_async/test_request_builder.py @@ -1,6 +1,10 @@ +from typing import Any, Dict, List + import pytest +from httpx import Request, Response from postgrest_py import AsyncRequestBuilder +from postgrest_py.base_request_builder import APIResponse from postgrest_py.types import CountMethod from postgrest_py.utils import AsyncClient @@ -114,3 +118,191 @@ def test_delete_with_count(self, request_builder: AsyncRequestBuilder): ] assert builder.http_method == "DELETE" assert builder.json == {} + + +@pytest.fixture +def api_response_with_error() -> Dict[str, Any]: + return { + "message": "Route GET:/countries?select=%2A not found", + "error": "Not Found", + "statusCode": 404, + } + + +@pytest.fixture +def api_response() -> List[Dict[str, Any]]: + return [ + { + "id": 1, + "name": "Bonaire, Sint Eustatius and Saba", + "iso2": "BQ", + "iso3": "BES", + "local_name": None, + "continent": None, + }, + { + "id": 2, + "name": "Curaçao", + "iso2": "CW", + "iso3": "CUW", + "local_name": None, + "continent": None, + }, + ] + + +@pytest.fixture +def content_range_header_with_count() -> str: + return "0-1/2" + + +@pytest.fixture +def content_range_header_without_count() -> str: + return "0-1" + + +@pytest.fixture +def prefer_header_with_count() -> str: + return "count=exact" + + +@pytest.fixture +def prefer_header_without_count() -> str: + return "random prefer header" + + +@pytest.fixture +def request_response_without_prefer_header() -> Response: + return Response( + status_code=200, request=Request(method="GET", url="http://example.com") + ) + + +@pytest.fixture +def request_response_with_prefer_header_without_count( + prefer_header_without_count: str, +) -> Response: + return Response( + status_code=200, + request=Request( + method="GET", + url="http://example.com", + headers={"prefer": prefer_header_without_count}, + ), + ) + + +@pytest.fixture +def request_response_with_prefer_header_with_count_and_content_range( + prefer_header_with_count: str, content_range_header_with_count: str +) -> Response: + return Response( + status_code=200, + headers={"content-range": content_range_header_with_count}, + request=Request( + method="GET", + url="http://example.com", + headers={"prefer": prefer_header_with_count}, + ), + ) + + +@pytest.fixture +def request_response_with_data( + prefer_header_with_count: str, + content_range_header_with_count: str, + api_response: List[Dict[str, Any]], +) -> Response: + return Response( + status_code=200, + headers={"content-range": content_range_header_with_count}, + json=api_response, + request=Request( + method="GET", + url="http://example.com", + headers={"prefer": prefer_header_with_count}, + ), + ) + + +class TestApiResponse: + def test_response_raises_when_api_error( + self, api_response_with_error: Dict[str, Any] + ): + with pytest.raises(ValueError): + APIResponse(data=api_response_with_error) + + def test_parses_valid_response_only_data(self, api_response: List[Dict[str, Any]]): + result = APIResponse(data=api_response) + assert result.data == api_response + + def test_parses_valid_response_data_and_count( + self, api_response: List[Dict[str, Any]] + ): + count = len(api_response) + result = APIResponse(data=api_response, count=count) + assert result.data == api_response + assert result.count == count + + def test_get_count_from_content_range_header_with_count( + self, content_range_header_with_count: str + ): + assert ( + APIResponse._get_count_from_content_range_header( + content_range_header_with_count + ) + == 2 + ) + + def test_get_count_from_content_range_header_without_count( + self, content_range_header_without_count: str + ): + assert ( + APIResponse._get_count_from_content_range_header( + content_range_header_without_count + ) + is None + ) + + def test_is_count_in_prefer_header_true(self, prefer_header_with_count: str): + assert APIResponse._is_count_in_prefer_header(prefer_header_with_count) + + def test_is_count_in_prefer_header_false(self, prefer_header_without_count: str): + assert not APIResponse._is_count_in_prefer_header(prefer_header_without_count) + + def test_get_count_from_http_request_response_without_prefer_header( + self, request_response_without_prefer_header: Response + ): + assert ( + APIResponse._get_count_from_http_request_response( + request_response_without_prefer_header + ) + is None + ) + + def test_get_count_from_http_request_response_with_prefer_header_without_count( + self, request_response_with_prefer_header_without_count: Response + ): + assert ( + APIResponse._get_count_from_http_request_response( + request_response_with_prefer_header_without_count + ) + is None + ) + + def test_get_count_from_http_request_response_with_count_and_content_range( + self, request_response_with_prefer_header_with_count_and_content_range: Response + ): + assert ( + APIResponse._get_count_from_http_request_response( + request_response_with_prefer_header_with_count_and_content_range + ) + == 2 + ) + + def test_from_http_request_response_constructor( + self, request_response_with_data: Response, api_response: List[Dict[str, Any]] + ): + result = APIResponse.from_http_request_response(request_response_with_data) + assert result.data == api_response + assert result.count == 2 diff --git a/tests/_sync/test_request_builder.py b/tests/_sync/test_request_builder.py index 4fb3850..c57fed4 100644 --- a/tests/_sync/test_request_builder.py +++ b/tests/_sync/test_request_builder.py @@ -1,6 +1,10 @@ +from typing import Any, Dict, List + import pytest +from httpx import Request, Response from postgrest_py import SyncRequestBuilder +from postgrest_py.base_request_builder import APIResponse from postgrest_py.types import CountMethod from postgrest_py.utils import SyncClient @@ -114,3 +118,191 @@ def test_delete_with_count(self, request_builder: SyncRequestBuilder): ] assert builder.http_method == "DELETE" assert builder.json == {} + + +@pytest.fixture +def api_response_with_error() -> Dict[str, Any]: + return { + "message": "Route GET:/countries?select=%2A not found", + "error": "Not Found", + "statusCode": 404, + } + + +@pytest.fixture +def api_response() -> List[Dict[str, Any]]: + return [ + { + "id": 1, + "name": "Bonaire, Sint Eustatius and Saba", + "iso2": "BQ", + "iso3": "BES", + "local_name": None, + "continent": None, + }, + { + "id": 2, + "name": "Curaçao", + "iso2": "CW", + "iso3": "CUW", + "local_name": None, + "continent": None, + }, + ] + + +@pytest.fixture +def content_range_header_with_count() -> str: + return "0-1/2" + + +@pytest.fixture +def content_range_header_without_count() -> str: + return "0-1" + + +@pytest.fixture +def prefer_header_with_count() -> str: + return "count=exact" + + +@pytest.fixture +def prefer_header_without_count() -> str: + return "random prefer header" + + +@pytest.fixture +def request_response_without_prefer_header() -> Response: + return Response( + status_code=200, request=Request(method="GET", url="http://example.com") + ) + + +@pytest.fixture +def request_response_with_prefer_header_without_count( + prefer_header_without_count: str, +) -> Response: + return Response( + status_code=200, + request=Request( + method="GET", + url="http://example.com", + headers={"prefer": prefer_header_without_count}, + ), + ) + + +@pytest.fixture +def request_response_with_prefer_header_with_count_and_content_range( + prefer_header_with_count: str, content_range_header_with_count: str +) -> Response: + return Response( + status_code=200, + headers={"content-range": content_range_header_with_count}, + request=Request( + method="GET", + url="http://example.com", + headers={"prefer": prefer_header_with_count}, + ), + ) + + +@pytest.fixture +def request_response_with_data( + prefer_header_with_count: str, + content_range_header_with_count: str, + api_response: List[Dict[str, Any]], +) -> Response: + return Response( + status_code=200, + headers={"content-range": content_range_header_with_count}, + json=api_response, + request=Request( + method="GET", + url="http://example.com", + headers={"prefer": prefer_header_with_count}, + ), + ) + + +class TestApiResponse: + def test_response_raises_when_api_error( + self, api_response_with_error: Dict[str, Any] + ): + with pytest.raises(ValueError): + APIResponse(data=api_response_with_error) + + def test_parses_valid_response_only_data(self, api_response: List[Dict[str, Any]]): + result = APIResponse(data=api_response) + assert result.data == api_response + + def test_parses_valid_response_data_and_count( + self, api_response: List[Dict[str, Any]] + ): + count = len(api_response) + result = APIResponse(data=api_response, count=count) + assert result.data == api_response + assert result.count == count + + def test_get_count_from_content_range_header_with_count( + self, content_range_header_with_count: str + ): + assert ( + APIResponse._get_count_from_content_range_header( + content_range_header_with_count + ) + == 2 + ) + + def test_get_count_from_content_range_header_without_count( + self, content_range_header_without_count: str + ): + assert ( + APIResponse._get_count_from_content_range_header( + content_range_header_without_count + ) + is None + ) + + def test_is_count_in_prefer_header_true(self, prefer_header_with_count: str): + assert APIResponse._is_count_in_prefer_header(prefer_header_with_count) + + def test_is_count_in_prefer_header_false(self, prefer_header_without_count: str): + assert not APIResponse._is_count_in_prefer_header(prefer_header_without_count) + + def test_get_count_from_http_request_response_without_prefer_header( + self, request_response_without_prefer_header: Response + ): + assert ( + APIResponse._get_count_from_http_request_response( + request_response_without_prefer_header + ) + is None + ) + + def test_get_count_from_http_request_response_with_prefer_header_without_count( + self, request_response_with_prefer_header_without_count: Response + ): + assert ( + APIResponse._get_count_from_http_request_response( + request_response_with_prefer_header_without_count + ) + is None + ) + + def test_get_count_from_http_request_response_with_count_and_content_range( + self, request_response_with_prefer_header_with_count_and_content_range: Response + ): + assert ( + APIResponse._get_count_from_http_request_response( + request_response_with_prefer_header_with_count_and_content_range + ) + == 2 + ) + + def test_from_http_request_response_constructor( + self, request_response_with_data: Response, api_response: List[Dict[str, Any]] + ): + result = APIResponse.from_http_request_response(request_response_with_data) + assert result.data == api_response + assert result.count == 2