From 2f2faa2301cb3557a392b1c91c494a17491ba25c Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Thu, 4 May 2023 11:32:47 -0700 Subject: [PATCH] Add tooling to make customization of DB Connection possible (#538) * Add functionality to customize db connection retrieval * Add test for customizing the connection_getter * Cleanup * isort fixes * flake8 fix * Update typing --------- Co-authored-by: David Bitner --- .../pgstac/stac_fastapi/pgstac/core.py | 57 +++++++-------- stac_fastapi/pgstac/stac_fastapi/pgstac/db.py | 66 ++++++++++------- .../stac_fastapi/pgstac/transactions.py | 70 +++++++++---------- .../pgstac/tests/clients/test_postgres.py | 42 ++++++++++- 4 files changed, 142 insertions(+), 93 deletions(-) diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py index a8c73d9f8..568cf0995 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py @@ -8,14 +8,13 @@ import orjson from asyncpg.exceptions import InvalidDatetimeFormatError from buildpg import render -from fastapi import HTTPException +from fastapi import HTTPException, Request from pydantic import ValidationError from pygeofilter.backends.cql2_json import to_cql2 from pygeofilter.parsers.cql2_text import parse as parse_cql2_text from pypgstac.hydration import hydrate from stac_pydantic.links import Relations from stac_pydantic.shared import MimeTypes -from starlette.requests import Request from stac_fastapi.pgstac.config import Settings from stac_fastapi.pgstac.models.links import ( @@ -38,13 +37,11 @@ class CoreCrudClient(AsyncBaseCoreClient): """Client for core endpoints defined by stac.""" - async def all_collections(self, **kwargs) -> Collections: + async def all_collections(self, request: Request, **kwargs) -> Collections: """Read all collections from the database.""" - request: Request = kwargs["request"] base_url = get_base_url(request) - pool = request.app.state.readpool - async with pool.acquire() as conn: + async with request.app.state.get_connection(request, "r") as conn: collections = await conn.fetchval( """ SELECT * FROM all_collections(); @@ -80,7 +77,9 @@ async def all_collections(self, **kwargs) -> Collections: collection_list = Collections(collections=linked_collections or [], links=links) return collection_list - async def get_collection(self, collection_id: str, **kwargs) -> Collection: + async def get_collection( + self, collection_id: str, request: Request, **kwargs + ) -> Collection: """Get collection by id. Called with `GET /collections/{collection_id}`. @@ -93,9 +92,7 @@ async def get_collection(self, collection_id: str, **kwargs) -> Collection: """ collection: Optional[Dict[str, Any]] - request: Request = kwargs["request"] - pool = request.app.state.readpool - async with pool.acquire() as conn: + async with request.app.state.get_connection(request, "r") as conn: q, p = render( """ SELECT * FROM get_collection(:id::text); @@ -125,8 +122,7 @@ async def _get_base_item( """ item: Optional[Dict[str, Any]] - pool = request.app.state.readpool - async with pool.acquire() as conn: + async with request.app.state.get_connection(request, "r") as conn: q, p = render( """ SELECT * FROM collection_base_item(:collection_id::text); @@ -143,7 +139,7 @@ async def _get_base_item( async def _search_base( self, search_request: PgstacSearch, - **kwargs: Any, + request: Request, ) -> ItemCollection: """Cross catalog search (POST). @@ -157,21 +153,19 @@ async def _search_base( """ items: Dict[str, Any] - request: Request = kwargs["request"] settings: Settings = request.app.state.settings - pool = request.app.state.readpool search_request.conf = search_request.conf or {} search_request.conf["nohydrate"] = settings.use_api_hydrate - req = search_request.json(exclude_none=True, by_alias=True) + search_request_json = search_request.json(exclude_none=True, by_alias=True) try: - async with pool.acquire() as conn: + async with request.app.state.get_connection(request, "r") as conn: q, p = render( """ SELECT * FROM search(:req::text::jsonb); """, - req=req, + req=search_request_json, ) items = await conn.fetchval(q, *p) except InvalidDatetimeFormatError: @@ -253,6 +247,7 @@ async def _get_base_item(collection_id: str) -> Dict[str, Any]: async def item_collection( self, collection_id: str, + request: Request, bbox: Optional[List[NumType]] = None, datetime: Optional[Union[str, datetime]] = None, limit: Optional[int] = None, @@ -272,7 +267,7 @@ async def item_collection( An ItemCollection. """ # If collection does not exist, NotFoundError wil be raised - await self.get_collection(collection_id, **kwargs) + await self.get_collection(collection_id, request) base_args = { "collections": [collection_id], @@ -287,17 +282,19 @@ async def item_collection( if v is not None and v != []: clean[k] = v - req = self.post_request_model( + search_request = self.post_request_model( **clean, ) - item_collection = await self._search_base(req, **kwargs) + item_collection = await self._search_base(search_request, request) links = await ItemCollectionLinks( - collection_id=collection_id, request=kwargs["request"] + collection_id=collection_id, request=request ).get_links(extra_links=item_collection["links"]) item_collection["links"] = links return item_collection - async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item: + async def get_item( + self, item_id: str, collection_id: str, request: Request, **kwargs + ) -> Item: """Get item by id. Called with `GET /collections/{collection_id}/items/{item_id}`. @@ -310,12 +307,12 @@ async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item: Item. """ # If collection does not exist, NotFoundError wil be raised - await self.get_collection(collection_id, **kwargs) + await self.get_collection(collection_id, request) - req = self.post_request_model( + search_request = self.post_request_model( ids=[item_id], collections=[collection_id], limit=1 ) - item_collection = await self._search_base(req, **kwargs) + item_collection = await self._search_base(search_request, request) if not item_collection["features"]: raise NotFoundError( f"Item {item_id} in Collection {collection_id} does not exist." @@ -324,7 +321,7 @@ async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item: return Item(**item_collection["features"][0]) async def post_search( - self, search_request: PgstacSearch, **kwargs + self, search_request: PgstacSearch, request: Request, **kwargs ) -> ItemCollection: """Cross catalog search (POST). @@ -336,11 +333,12 @@ async def post_search( Returns: ItemCollection containing items which match the search criteria. """ - item_collection = await self._search_base(search_request, **kwargs) + item_collection = await self._search_base(search_request, request) return ItemCollection(**item_collection) async def get_search( self, + request: Request, collections: Optional[List[str]] = None, ids: Optional[List[str]] = None, bbox: Optional[List[NumType]] = None, @@ -362,7 +360,6 @@ async def get_search( Returns: ItemCollection containing items which match the search criteria. """ - request = kwargs["request"] query_params = str(request.query_params) # Kludgy fix because using factory does not allow alias for filter-lang @@ -432,4 +429,4 @@ async def get_search( raise HTTPException( status_code=400, detail=f"Invalid parameters provided {e}" ) - return await self.post_search(search_request, request=kwargs["request"]) + return await self.post_search(search_request, request=request) diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/db.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/db.py index 57c39c0ba..afa15abe7 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/db.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/db.py @@ -1,14 +1,14 @@ """Database connection handling.""" import json -from contextlib import contextmanager -from typing import Dict, Generator, Union +from contextlib import asynccontextmanager, contextmanager +from typing import AsyncIterator, Callable, Dict, Generator, Literal, Union import attr import orjson -from asyncpg import exceptions, pool +from asyncpg import Connection, exceptions from buildpg import V, asyncpg, render -from fastapi import FastAPI +from fastapi import FastAPI, Request from stac_fastapi.types.errors import ( ConflictError, @@ -34,8 +34,11 @@ async def con_init(conn): ) -async def connect_to_db(app: FastAPI) -> None: - """Connect to Database.""" +ConnectionGetter = Callable[[Request, Literal["r", "w"]], AsyncIterator[Connection]] + + +async def connect_to_db(app: FastAPI, get_conn: ConnectionGetter = None) -> None: + """Create connection pools & connection retriever on application.""" settings = app.state.settings if app.state.settings.testing: readpool = writepool = settings.testing_connection_string @@ -45,6 +48,7 @@ async def connect_to_db(app: FastAPI) -> None: db = DB() app.state.readpool = await db.create_pool(readpool, settings) app.state.writepool = await db.create_pool(writepool, settings) + app.state.get_connection = get_conn if get_conn else get_connection async def close_db_connection(app: FastAPI) -> None: @@ -53,7 +57,21 @@ async def close_db_connection(app: FastAPI) -> None: await app.state.writepool.close() -async def dbfunc(pool: pool, func: str, arg: Union[str, Dict]): +@asynccontextmanager +async def get_connection( + request: Request, + readwrite: Literal["r", "w"] = "r", +) -> AsyncIterator[Connection]: + """Retrieve connection from database conection pool.""" + pool = ( + request.app.state.writepool if readwrite == "w" else request.app.state.readpool + ) + with translate_pgstac_errors(): + async with pool.acquire() as conn: + yield conn + + +async def dbfunc(conn: Connection, func: str, arg: Union[str, Dict]): """Wrap PLPGSQL Functions. Keyword arguments: @@ -64,25 +82,23 @@ async def dbfunc(pool: pool, func: str, arg: Union[str, Dict]): """ with translate_pgstac_errors(): if isinstance(arg, str): - async with pool.acquire() as conn: - q, p = render( - """ - SELECT * FROM :func(:item::text); - """, - func=V(func), - item=arg, - ) - return await conn.fetchval(q, *p) + q, p = render( + """ + SELECT * FROM :func(:item::text); + """, + func=V(func), + item=arg, + ) + return await conn.fetchval(q, *p) else: - async with pool.acquire() as conn: - q, p = render( - """ - SELECT * FROM :func(:item::text::jsonb); - """, - func=V(func), - item=json.dumps(arg), - ) - return await conn.fetchval(q, *p) + q, p = render( + """ + SELECT * FROM :func(:item::text::jsonb); + """, + func=V(func), + item=json.dumps(arg), + ) + return await conn.fetchval(q, *p) @contextmanager diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/transactions.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/transactions.py index 68479aa34..91cb1fee2 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/transactions.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/transactions.py @@ -5,14 +5,14 @@ import attr from buildpg import render -from fastapi import HTTPException +from fastapi import HTTPException, Request from starlette.responses import JSONResponse, Response from stac_fastapi.extensions.third_party.bulk_transactions import ( AsyncBaseBulkTransactionsClient, Items, ) -from stac_fastapi.pgstac.db import dbfunc, translate_pgstac_errors +from stac_fastapi.pgstac.db import dbfunc from stac_fastapi.pgstac.models.links import CollectionLinks, ItemLinks from stac_fastapi.types import stac as stac_types from stac_fastapi.types.core import AsyncBaseTransactionsClient @@ -26,7 +26,7 @@ class TransactionsClient(AsyncBaseTransactionsClient): """Transactions extension specific CRUD operations.""" async def create_item( - self, collection_id: str, item: stac_types.Item, **kwargs + self, collection_id: str, item: stac_types.Item, request: Request, **kwargs ) -> Optional[Union[stac_types.Item, Response]]: """Create item.""" body_collection_id = item.get("collection") @@ -36,9 +36,8 @@ async def create_item( detail=f"Collection ID from path parameter ({collection_id}) does not match Collection ID from Item ({body_collection_id})", ) item["collection"] = collection_id - request = kwargs["request"] - pool = request.app.state.writepool - await dbfunc(pool, "create_item", item) + async with request.app.state.get_connection(request, "w") as conn: + await dbfunc(conn, "create_item", item) item["links"] = await ItemLinks( collection_id=collection_id, item_id=item["id"], @@ -47,7 +46,12 @@ async def create_item( return stac_types.Item(**item) async def update_item( - self, collection_id: str, item_id: str, item: stac_types.Item, **kwargs + self, + request: Request, + collection_id: str, + item_id: str, + item: stac_types.Item, + **kwargs, ) -> Optional[Union[stac_types.Item, Response]]: """Update item.""" body_collection_id = item.get("collection") @@ -63,9 +67,8 @@ async def update_item( status_code=400, detail=f"Item ID from path parameter ({item_id}) does not match Item ID from Item ({body_item_id})", ) - request = kwargs["request"] - pool = request.app.state.writepool - await dbfunc(pool, "update_item", item) + async with request.app.state.get_connection(request, "w") as conn: + await dbfunc(conn, "update_item", item) item["links"] = await ItemLinks( collection_id=collection_id, item_id=item["id"], @@ -74,12 +77,11 @@ async def update_item( return stac_types.Item(**item) async def create_collection( - self, collection: stac_types.Collection, **kwargs + self, collection: stac_types.Collection, request: Request, **kwargs ) -> Optional[Union[stac_types.Collection, Response]]: """Create collection.""" - request = kwargs["request"] - pool = request.app.state.writepool - await dbfunc(pool, "create_collection", collection) + async with request.app.state.get_connection(request, "w") as conn: + await dbfunc(conn, "create_collection", collection) collection["links"] = await CollectionLinks( collection_id=collection["id"], request=request ).get_links(extra_links=collection.get("links")) @@ -87,40 +89,35 @@ async def create_collection( return stac_types.Collection(**collection) async def update_collection( - self, collection: stac_types.Collection, **kwargs + self, collection: stac_types.Collection, request: Request, **kwargs ) -> Optional[Union[stac_types.Collection, Response]]: """Update collection.""" - request = kwargs["request"] - pool = request.app.state.writepool - await dbfunc(pool, "update_collection", collection) + async with request.app.state.get_connection(request, "w") as conn: + await dbfunc(conn, "update_collection", collection) collection["links"] = await CollectionLinks( collection_id=collection["id"], request=request ).get_links(extra_links=collection.get("links")) return stac_types.Collection(**collection) async def delete_item( - self, item_id: str, collection_id: str, **kwargs + self, item_id: str, collection_id: str, request: Request, **kwargs ) -> Optional[Union[stac_types.Item, Response]]: """Delete item.""" - request = kwargs["request"] - pool = request.app.state.writepool - async with pool.acquire() as conn: - q, p = render( - "SELECT * FROM delete_item(:item::text, :collection::text);", - item=item_id, - collection=collection_id, - ) - with translate_pgstac_errors(): - await conn.fetchval(q, *p) + q, p = render( + "SELECT * FROM delete_item(:item::text, :collection::text);", + item=item_id, + collection=collection_id, + ) + async with request.app.state.get_connection(request, "w") as conn: + await conn.fetchval(q, *p) return JSONResponse({"deleted item": item_id}) async def delete_collection( - self, collection_id: str, **kwargs + self, collection_id: str, request: Request, **kwargs ) -> Optional[Union[stac_types.Collection, Response]]: """Delete collection.""" - request = kwargs["request"] - pool = request.app.state.writepool - await dbfunc(pool, "delete_collection", collection_id) + async with request.app.state.get_connection(request, "w") as conn: + await dbfunc(conn, "delete_collection", collection_id) return JSONResponse({"deleted collection": collection_id}) @@ -128,12 +125,11 @@ async def delete_collection( class BulkTransactionsClient(AsyncBaseBulkTransactionsClient): """Postgres bulk transactions.""" - async def bulk_item_insert(self, items: Items, **kwargs) -> str: + async def bulk_item_insert(self, items: Items, request: Request, **kwargs) -> str: """Bulk item insertion using pgstac.""" - request = kwargs["request"] - pool = request.app.state.writepool items = list(items.items.values()) - await dbfunc(pool, "create_items", items) + async with request.app.state.get_connection(request, "w") as conn: + await dbfunc(conn, "create_items", items) return_msg = f"Successfully added {len(items)} items." return return_msg diff --git a/stac_fastapi/pgstac/tests/clients/test_postgres.py b/stac_fastapi/pgstac/tests/clients/test_postgres.py index 345dc4f9a..2d8fd6b52 100644 --- a/stac_fastapi/pgstac/tests/clients/test_postgres.py +++ b/stac_fastapi/pgstac/tests/clients/test_postgres.py @@ -1,10 +1,17 @@ +import logging import uuid +from contextlib import asynccontextmanager from copy import deepcopy -from typing import Callable +from typing import Callable, Literal +import pytest +from fastapi import Request from stac_pydantic import Collection, Item +from stac_fastapi.pgstac.db import close_db_connection, connect_to_db, get_connection + # from tests.conftest import MockStarletteRequest +logger = logging.getLogger(__name__) async def test_create_collection(app_client, load_test_data: Callable): @@ -170,3 +177,36 @@ async def test_create_bulk_items( # for item in fc.features: # assert item.collection == coll.id + + +@asynccontextmanager +async def custom_get_connection( + request: Request, + readwrite: Literal["r", "w"], +): + """An example of customizing the connection getter""" + async with get_connection(request, readwrite) as conn: + await conn.execute("SELECT set_config('api.test', 'added-config', false)") + yield conn + + +class TestDbConnect: + @pytest.fixture + async def app(self, api_client): + """ + app fixture override to setup app with a customized db connection getter + """ + logger.debug("Customizing app setup") + await connect_to_db(api_client.app, custom_get_connection) + yield api_client.app + await close_db_connection(api_client.app) + + async def test_db_setup(self, api_client, app_client): + @api_client.app.get(f"{api_client.router.prefix}/db-test") + async def example_view(request: Request): + async with request.app.state.get_connection(request, "r") as conn: + return await conn.fetchval("SELECT current_setting('api.test', true)") + + response = await app_client.get("/db-test") + assert response.status_code == 200 + assert response.json() == "added-config"