Skip to content

Commit

Permalink
feat: generic query builders (#309)
Browse files Browse the repository at this point in the history
* feat: make all query builders generic

* feat: return generic request builders from client methods

* chore: use typing.List instead of builtin

* chore: use typing.List

* fix: correct type of APIResponse.data

* feat: make RPCFilterRequestBuilder

This makes sure the return types of rpc() and other
query methods are correct.
See https://gist.github.com/anand2312/93d3abf401335fd3310d9e30112303bf
for an explanation.

* chore: use typing.List

* feat: make get_origin_and_cast

This fixes the type-checker error raised while accessing
RequestBuilder[T].__origin__

* fix: use typing.List
  • Loading branch information
anand2312 committed Sep 20, 2023
1 parent 3329234 commit ba9ad8d
Show file tree
Hide file tree
Showing 6 changed files with 239 additions and 150 deletions.
22 changes: 13 additions & 9 deletions postgrest/_async/client.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Dict, Union, cast
from typing import Any, Dict, Union, cast

from deprecation import deprecated
from httpx import Headers, QueryParams, Timeout
Expand All @@ -12,7 +12,9 @@
DEFAULT_POSTGREST_CLIENT_TIMEOUT,
)
from ..utils import AsyncClient
from .request_builder import AsyncFilterRequestBuilder, AsyncRequestBuilder
from .request_builder import AsyncRequestBuilder, AsyncRPCFilterRequestBuilder

_TableT = Dict[str, Any]


class AsyncPostgrestClient(BasePostgrestClient):
Expand Down Expand Up @@ -57,17 +59,17 @@ async def aclose(self) -> None:
"""Close the underlying HTTP connections."""
await self.session.aclose()

def from_(self, table: str) -> AsyncRequestBuilder:
def from_(self, table: str) -> AsyncRequestBuilder[_TableT]:
"""Perform a table operation.
Args:
table: The name of the table
Returns:
:class:`AsyncRequestBuilder`
"""
return AsyncRequestBuilder(self.session, f"/{table}")
return AsyncRequestBuilder[_TableT](self.session, f"/{table}")

def table(self, table: str) -> AsyncRequestBuilder:
def table(self, table: str) -> AsyncRequestBuilder[_TableT]:
"""Alias to :meth:`from_`."""
return self.from_(table)

Expand All @@ -76,24 +78,26 @@ def from_table(self, table: str) -> AsyncRequestBuilder:
"""Alias to :meth:`from_`."""
return self.from_(table)

async def rpc(self, func: str, params: dict) -> AsyncFilterRequestBuilder:
async def rpc(self, func: str, params: dict) -> AsyncRPCFilterRequestBuilder[Any]:
"""Perform a stored procedure call.
Args:
func: The name of the remote procedure to run.
params: The parameters to be passed to the remote procedure.
Returns:
:class:`AsyncFilterRequestBuilder`
:class:`AsyncRPCFilterRequestBuilder`
Example:
.. code-block:: python
await client.rpc("foobar", {"arg": "value"}).execute()
.. versionchanged:: 0.11.0
.. versionchanged:: 0.10.9
This method now returns a :class:`AsyncRPCFilterRequestBuilder`.
.. versionchanged:: 0.10.2
This method now returns a :class:`AsyncFilterRequestBuilder` which allows you to
filter on the RPC's resultset.
"""
# the params here are params to be sent to the RPC and not the queryparams!
return AsyncFilterRequestBuilder(
return AsyncRPCFilterRequestBuilder[Any](
self.session, f"/rpc/{func}", "POST", Headers(), QueryParams(), json=params
)
97 changes: 62 additions & 35 deletions postgrest/_async/request_builder.py
@@ -1,7 +1,7 @@
from __future__ import annotations

from json import JSONDecodeError
from typing import Optional, Union
from typing import Any, Generic, Optional, TypeVar, Union

from httpx import Headers, QueryParams
from pydantic import ValidationError
Expand All @@ -20,10 +20,12 @@
)
from ..exceptions import APIError, generate_default_error_message
from ..types import ReturnMethod
from ..utils import AsyncClient
from ..utils import AsyncClient, get_origin_and_cast

_ReturnT = TypeVar("_ReturnT")

class AsyncQueryRequestBuilder:

class AsyncQueryRequestBuilder(Generic[_ReturnT]):
def __init__(
self,
session: AsyncClient,
Expand All @@ -40,7 +42,7 @@ def __init__(
self.params = params
self.json = json

async def execute(self) -> APIResponse:
async def execute(self) -> APIResponse[_ReturnT]:
"""Execute the query.
.. tip::
Expand All @@ -63,7 +65,7 @@ async def execute(self) -> APIResponse:
if (
200 <= r.status_code <= 299
): # Response.ok from JS (https://developer.mozilla.org/en-US/docs/Web/API/Response/ok)
return APIResponse.from_http_request_response(r)
return APIResponse[_ReturnT].from_http_request_response(r)
else:
raise APIError(r.json())
except ValidationError as e:
Expand All @@ -72,7 +74,7 @@ async def execute(self) -> APIResponse:
raise APIError(generate_default_error_message(r))


class AsyncSingleRequestBuilder:
class AsyncSingleRequestBuilder(Generic[_ReturnT]):
def __init__(
self,
session: AsyncClient,
Expand All @@ -89,7 +91,7 @@ def __init__(
self.params = params
self.json = json

async def execute(self) -> SingleAPIResponse:
async def execute(self) -> SingleAPIResponse[_ReturnT]:
"""Execute the query.
.. tip::
Expand All @@ -112,7 +114,7 @@ async def execute(self) -> SingleAPIResponse:
if (
200 <= r.status_code <= 299
): # Response.ok from JS (https://developer.mozilla.org/en-US/docs/Web/API/Response/ok)
return SingleAPIResponse.from_http_request_response(r)
return SingleAPIResponse[_ReturnT].from_http_request_response(r)
else:
raise APIError(r.json())
except ValidationError as e:
Expand All @@ -121,11 +123,11 @@ async def execute(self) -> SingleAPIResponse:
raise APIError(generate_default_error_message(r))


class AsyncMaybeSingleRequestBuilder(AsyncSingleRequestBuilder):
async def execute(self) -> Optional[SingleAPIResponse]:
class AsyncMaybeSingleRequestBuilder(AsyncSingleRequestBuilder[_ReturnT]):
async def execute(self) -> Optional[SingleAPIResponse[_ReturnT]]:
r = None
try:
r = await super().execute()
r = await AsyncSingleRequestBuilder[_ReturnT].execute(self)
except APIError as e:
if e.details and "The result contains 0 rows" in e.details:
return None
Expand All @@ -142,7 +144,7 @@ async def execute(self) -> Optional[SingleAPIResponse]:


# ignoring type checking as a workaround for https://github.com/python/mypy/issues/9319
class AsyncFilterRequestBuilder(BaseFilterRequestBuilder, AsyncQueryRequestBuilder): # type: ignore
class AsyncFilterRequestBuilder(BaseFilterRequestBuilder[_ReturnT], AsyncQueryRequestBuilder[_ReturnT]): # type: ignore
def __init__(
self,
session: AsyncClient,
Expand All @@ -152,14 +154,37 @@ def __init__(
params: QueryParams,
json: dict,
) -> None:
BaseFilterRequestBuilder.__init__(self, session, headers, params)
AsyncQueryRequestBuilder.__init__(
get_origin_and_cast(BaseFilterRequestBuilder[_ReturnT]).__init__(
self, session, headers, params
)
get_origin_and_cast(AsyncQueryRequestBuilder[_ReturnT]).__init__(
self, session, path, http_method, headers, params, json
)


# this exists for type-safety. see https://gist.github.com/anand2312/93d3abf401335fd3310d9e30112303bf
class AsyncRPCFilterRequestBuilder(
BaseFilterRequestBuilder[_ReturnT], AsyncSingleRequestBuilder[_ReturnT]
):
def __init__(
self,
session: AsyncClient,
path: str,
http_method: str,
headers: Headers,
params: QueryParams,
json: dict,
) -> None:
get_origin_and_cast(BaseFilterRequestBuilder[_ReturnT]).__init__(
self, session, headers, params
)
get_origin_and_cast(AsyncSingleRequestBuilder[_ReturnT]).__init__(
self, session, path, http_method, headers, params, json
)


# ignoring type checking as a workaround for https://github.com/python/mypy/issues/9319
class AsyncSelectRequestBuilder(BaseSelectRequestBuilder, AsyncQueryRequestBuilder): # type: ignore
class AsyncSelectRequestBuilder(BaseSelectRequestBuilder[_ReturnT], AsyncQueryRequestBuilder[_ReturnT]): # type: ignore
def __init__(
self,
session: AsyncClient,
Expand All @@ -169,19 +194,21 @@ def __init__(
params: QueryParams,
json: dict,
) -> None:
BaseSelectRequestBuilder.__init__(self, session, headers, params)
AsyncQueryRequestBuilder.__init__(
get_origin_and_cast(BaseSelectRequestBuilder[_ReturnT]).__init__(
self, session, headers, params
)
get_origin_and_cast(AsyncQueryRequestBuilder[_ReturnT]).__init__(
self, session, path, http_method, headers, params, json
)

def single(self) -> AsyncSingleRequestBuilder:
def single(self) -> AsyncSingleRequestBuilder[_ReturnT]:
"""Specify that the query will only return a single row in response.
.. caution::
The API will raise an error if the query returned more than one row.
"""
self.headers["Accept"] = "application/vnd.pgrst.object+json"
return AsyncSingleRequestBuilder(
return AsyncSingleRequestBuilder[_ReturnT](
headers=self.headers,
http_method=self.http_method,
json=self.json,
Expand All @@ -190,10 +217,10 @@ def single(self) -> AsyncSingleRequestBuilder:
session=self.session, # type: ignore
)

def maybe_single(self) -> AsyncMaybeSingleRequestBuilder:
def maybe_single(self) -> AsyncMaybeSingleRequestBuilder[_ReturnT]:
"""Retrieves at most one row from the result. Result must be at most one row (e.g. using `eq` on a UNIQUE column), otherwise this will result in an error."""
self.headers["Accept"] = "application/vnd.pgrst.object+json"
return AsyncMaybeSingleRequestBuilder(
return AsyncMaybeSingleRequestBuilder[_ReturnT](
headers=self.headers,
http_method=self.http_method,
json=self.json,
Expand All @@ -203,8 +230,8 @@ def maybe_single(self) -> AsyncMaybeSingleRequestBuilder:
)

def text_search(
self, column: str, query: str, options: Dict[str, any] = {}
) -> AsyncFilterRequestBuilder:
self, column: str, query: str, options: dict[str, Any] = {}
) -> AsyncFilterRequestBuilder[_ReturnT]:
type_ = options.get("type")
type_part = ""
if type_ == "plain":
Expand All @@ -216,7 +243,7 @@ def text_search(
config_part = f"({options.get('config')})" if options.get("config") else ""
self.params = self.params.add(column, f"{type_part}fts{config_part}.{query}")

return AsyncQueryRequestBuilder(
return AsyncQueryRequestBuilder[_ReturnT](
headers=self.headers,
http_method=self.http_method,
json=self.json,
Expand All @@ -226,7 +253,7 @@ def text_search(
)


class AsyncRequestBuilder:
class AsyncRequestBuilder(Generic[_ReturnT]):
def __init__(self, session: AsyncClient, path: str) -> None:
self.session = session
self.path = path
Expand All @@ -235,7 +262,7 @@ def select(
self,
*columns: str,
count: Optional[CountMethod] = None,
) -> AsyncSelectRequestBuilder:
) -> AsyncSelectRequestBuilder[_ReturnT]:
"""Run a SELECT query.
Args:
Expand All @@ -245,7 +272,7 @@ def select(
:class:`AsyncSelectRequestBuilder`
"""
method, params, headers, json = pre_select(*columns, count=count)
return AsyncSelectRequestBuilder(
return AsyncSelectRequestBuilder[_ReturnT](
self.session, self.path, method, headers, params, json
)

Expand All @@ -256,7 +283,7 @@ def insert(
count: Optional[CountMethod] = None,
returning: ReturnMethod = ReturnMethod.representation,
upsert: bool = False,
) -> AsyncQueryRequestBuilder:
) -> AsyncQueryRequestBuilder[_ReturnT]:
"""Run an INSERT query.
Args:
Expand All @@ -273,7 +300,7 @@ def insert(
returning=returning,
upsert=upsert,
)
return AsyncQueryRequestBuilder(
return AsyncQueryRequestBuilder[_ReturnT](
self.session, self.path, method, headers, params, json
)

Expand All @@ -285,7 +312,7 @@ def upsert(
returning: ReturnMethod = ReturnMethod.representation,
ignore_duplicates: bool = False,
on_conflict: str = "",
) -> AsyncQueryRequestBuilder:
) -> AsyncQueryRequestBuilder[_ReturnT]:
"""Run an upsert (INSERT ... ON CONFLICT DO UPDATE) query.
Args:
Expand All @@ -304,7 +331,7 @@ def upsert(
ignore_duplicates=ignore_duplicates,
on_conflict=on_conflict,
)
return AsyncQueryRequestBuilder(
return AsyncQueryRequestBuilder[_ReturnT](
self.session, self.path, method, headers, params, json
)

Expand All @@ -314,7 +341,7 @@ def update(
*,
count: Optional[CountMethod] = None,
returning: ReturnMethod = ReturnMethod.representation,
) -> AsyncFilterRequestBuilder:
) -> AsyncFilterRequestBuilder[_ReturnT]:
"""Run an UPDATE query.
Args:
Expand All @@ -329,7 +356,7 @@ def update(
count=count,
returning=returning,
)
return AsyncFilterRequestBuilder(
return AsyncFilterRequestBuilder[_ReturnT](
self.session, self.path, method, headers, params, json
)

Expand All @@ -338,7 +365,7 @@ def delete(
*,
count: Optional[CountMethod] = None,
returning: ReturnMethod = ReturnMethod.representation,
) -> AsyncFilterRequestBuilder:
) -> AsyncFilterRequestBuilder[_ReturnT]:
"""Run a DELETE query.
Args:
Expand All @@ -351,7 +378,7 @@ def delete(
count=count,
returning=returning,
)
return AsyncFilterRequestBuilder(
return AsyncFilterRequestBuilder[_ReturnT](
self.session, self.path, method, headers, params, json
)

Expand Down

0 comments on commit ba9ad8d

Please sign in to comment.