Skip to content

Commit

Permalink
Add tooling to make customization of DB Connection possible (#538)
Browse files Browse the repository at this point in the history
* Add functionality to customize db connection retrieval

* Add test for customizing the connection_getter

* Cleanup

* isort fixes

* flake8 fix

* Update typing

---------

Co-authored-by: David Bitner <bitner@dbspatial.com>
  • Loading branch information
alukach and bitner committed May 4, 2023
1 parent 42b5588 commit 2f2faa2
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 93 deletions.
57 changes: 27 additions & 30 deletions stac_fastapi/pgstac/stac_fastapi/pgstac/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@
import orjson
from asyncpg.exceptions import InvalidDatetimeFormatError
from buildpg import render
from fastapi import HTTPException
from fastapi import HTTPException, Request
from pydantic import ValidationError
from pygeofilter.backends.cql2_json import to_cql2
from pygeofilter.parsers.cql2_text import parse as parse_cql2_text
from pypgstac.hydration import hydrate
from stac_pydantic.links import Relations
from stac_pydantic.shared import MimeTypes
from starlette.requests import Request

from stac_fastapi.pgstac.config import Settings
from stac_fastapi.pgstac.models.links import (
Expand All @@ -38,13 +37,11 @@
class CoreCrudClient(AsyncBaseCoreClient):
"""Client for core endpoints defined by stac."""

async def all_collections(self, **kwargs) -> Collections:
async def all_collections(self, request: Request, **kwargs) -> Collections:
"""Read all collections from the database."""
request: Request = kwargs["request"]
base_url = get_base_url(request)
pool = request.app.state.readpool

async with pool.acquire() as conn:
async with request.app.state.get_connection(request, "r") as conn:
collections = await conn.fetchval(
"""
SELECT * FROM all_collections();
Expand Down Expand Up @@ -80,7 +77,9 @@ async def all_collections(self, **kwargs) -> Collections:
collection_list = Collections(collections=linked_collections or [], links=links)
return collection_list

async def get_collection(self, collection_id: str, **kwargs) -> Collection:
async def get_collection(
self, collection_id: str, request: Request, **kwargs
) -> Collection:
"""Get collection by id.
Called with `GET /collections/{collection_id}`.
Expand All @@ -93,9 +92,7 @@ async def get_collection(self, collection_id: str, **kwargs) -> Collection:
"""
collection: Optional[Dict[str, Any]]

request: Request = kwargs["request"]
pool = request.app.state.readpool
async with pool.acquire() as conn:
async with request.app.state.get_connection(request, "r") as conn:
q, p = render(
"""
SELECT * FROM get_collection(:id::text);
Expand Down Expand Up @@ -125,8 +122,7 @@ async def _get_base_item(
"""
item: Optional[Dict[str, Any]]

pool = request.app.state.readpool
async with pool.acquire() as conn:
async with request.app.state.get_connection(request, "r") as conn:
q, p = render(
"""
SELECT * FROM collection_base_item(:collection_id::text);
Expand All @@ -143,7 +139,7 @@ async def _get_base_item(
async def _search_base(
self,
search_request: PgstacSearch,
**kwargs: Any,
request: Request,
) -> ItemCollection:
"""Cross catalog search (POST).
Expand All @@ -157,21 +153,19 @@ async def _search_base(
"""
items: Dict[str, Any]

request: Request = kwargs["request"]
settings: Settings = request.app.state.settings
pool = request.app.state.readpool

search_request.conf = search_request.conf or {}
search_request.conf["nohydrate"] = settings.use_api_hydrate
req = search_request.json(exclude_none=True, by_alias=True)
search_request_json = search_request.json(exclude_none=True, by_alias=True)

try:
async with pool.acquire() as conn:
async with request.app.state.get_connection(request, "r") as conn:
q, p = render(
"""
SELECT * FROM search(:req::text::jsonb);
""",
req=req,
req=search_request_json,
)
items = await conn.fetchval(q, *p)
except InvalidDatetimeFormatError:
Expand Down Expand Up @@ -253,6 +247,7 @@ async def _get_base_item(collection_id: str) -> Dict[str, Any]:
async def item_collection(
self,
collection_id: str,
request: Request,
bbox: Optional[List[NumType]] = None,
datetime: Optional[Union[str, datetime]] = None,
limit: Optional[int] = None,
Expand All @@ -272,7 +267,7 @@ async def item_collection(
An ItemCollection.
"""
# If collection does not exist, NotFoundError wil be raised
await self.get_collection(collection_id, **kwargs)
await self.get_collection(collection_id, request)

base_args = {
"collections": [collection_id],
Expand All @@ -287,17 +282,19 @@ async def item_collection(
if v is not None and v != []:
clean[k] = v

req = self.post_request_model(
search_request = self.post_request_model(
**clean,
)
item_collection = await self._search_base(req, **kwargs)
item_collection = await self._search_base(search_request, request)
links = await ItemCollectionLinks(
collection_id=collection_id, request=kwargs["request"]
collection_id=collection_id, request=request
).get_links(extra_links=item_collection["links"])
item_collection["links"] = links
return item_collection

async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item:
async def get_item(
self, item_id: str, collection_id: str, request: Request, **kwargs
) -> Item:
"""Get item by id.
Called with `GET /collections/{collection_id}/items/{item_id}`.
Expand All @@ -310,12 +307,12 @@ async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item:
Item.
"""
# If collection does not exist, NotFoundError wil be raised
await self.get_collection(collection_id, **kwargs)
await self.get_collection(collection_id, request)

req = self.post_request_model(
search_request = self.post_request_model(
ids=[item_id], collections=[collection_id], limit=1
)
item_collection = await self._search_base(req, **kwargs)
item_collection = await self._search_base(search_request, request)
if not item_collection["features"]:
raise NotFoundError(
f"Item {item_id} in Collection {collection_id} does not exist."
Expand All @@ -324,7 +321,7 @@ async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item:
return Item(**item_collection["features"][0])

async def post_search(
self, search_request: PgstacSearch, **kwargs
self, search_request: PgstacSearch, request: Request, **kwargs
) -> ItemCollection:
"""Cross catalog search (POST).
Expand All @@ -336,11 +333,12 @@ async def post_search(
Returns:
ItemCollection containing items which match the search criteria.
"""
item_collection = await self._search_base(search_request, **kwargs)
item_collection = await self._search_base(search_request, request)
return ItemCollection(**item_collection)

async def get_search(
self,
request: Request,
collections: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
bbox: Optional[List[NumType]] = None,
Expand All @@ -362,7 +360,6 @@ async def get_search(
Returns:
ItemCollection containing items which match the search criteria.
"""
request = kwargs["request"]
query_params = str(request.query_params)

# Kludgy fix because using factory does not allow alias for filter-lang
Expand Down Expand Up @@ -432,4 +429,4 @@ async def get_search(
raise HTTPException(
status_code=400, detail=f"Invalid parameters provided {e}"
)
return await self.post_search(search_request, request=kwargs["request"])
return await self.post_search(search_request, request=request)
66 changes: 41 additions & 25 deletions stac_fastapi/pgstac/stac_fastapi/pgstac/db.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Database connection handling."""

import json
from contextlib import contextmanager
from typing import Dict, Generator, Union
from contextlib import asynccontextmanager, contextmanager
from typing import AsyncIterator, Callable, Dict, Generator, Literal, Union

import attr
import orjson
from asyncpg import exceptions, pool
from asyncpg import Connection, exceptions
from buildpg import V, asyncpg, render
from fastapi import FastAPI
from fastapi import FastAPI, Request

from stac_fastapi.types.errors import (
ConflictError,
Expand All @@ -34,8 +34,11 @@ async def con_init(conn):
)


async def connect_to_db(app: FastAPI) -> None:
"""Connect to Database."""
ConnectionGetter = Callable[[Request, Literal["r", "w"]], AsyncIterator[Connection]]


async def connect_to_db(app: FastAPI, get_conn: ConnectionGetter = None) -> None:
"""Create connection pools & connection retriever on application."""
settings = app.state.settings
if app.state.settings.testing:
readpool = writepool = settings.testing_connection_string
Expand All @@ -45,6 +48,7 @@ async def connect_to_db(app: FastAPI) -> None:
db = DB()
app.state.readpool = await db.create_pool(readpool, settings)
app.state.writepool = await db.create_pool(writepool, settings)
app.state.get_connection = get_conn if get_conn else get_connection


async def close_db_connection(app: FastAPI) -> None:
Expand All @@ -53,7 +57,21 @@ async def close_db_connection(app: FastAPI) -> None:
await app.state.writepool.close()


async def dbfunc(pool: pool, func: str, arg: Union[str, Dict]):
@asynccontextmanager
async def get_connection(
request: Request,
readwrite: Literal["r", "w"] = "r",
) -> AsyncIterator[Connection]:
"""Retrieve connection from database conection pool."""
pool = (
request.app.state.writepool if readwrite == "w" else request.app.state.readpool
)
with translate_pgstac_errors():
async with pool.acquire() as conn:
yield conn


async def dbfunc(conn: Connection, func: str, arg: Union[str, Dict]):
"""Wrap PLPGSQL Functions.
Keyword arguments:
Expand All @@ -64,25 +82,23 @@ async def dbfunc(pool: pool, func: str, arg: Union[str, Dict]):
"""
with translate_pgstac_errors():
if isinstance(arg, str):
async with pool.acquire() as conn:
q, p = render(
"""
SELECT * FROM :func(:item::text);
""",
func=V(func),
item=arg,
)
return await conn.fetchval(q, *p)
q, p = render(
"""
SELECT * FROM :func(:item::text);
""",
func=V(func),
item=arg,
)
return await conn.fetchval(q, *p)
else:
async with pool.acquire() as conn:
q, p = render(
"""
SELECT * FROM :func(:item::text::jsonb);
""",
func=V(func),
item=json.dumps(arg),
)
return await conn.fetchval(q, *p)
q, p = render(
"""
SELECT * FROM :func(:item::text::jsonb);
""",
func=V(func),
item=json.dumps(arg),
)
return await conn.fetchval(q, *p)


@contextmanager
Expand Down
Loading

0 comments on commit 2f2faa2

Please sign in to comment.