diff --git a/.github/workflows/cicd.yaml b/.github/workflows/cicd.yaml index 641be715d..3da9d9cdc 100644 --- a/.github/workflows/cicd.yaml +++ b/.github/workflows/cicd.yaml @@ -81,6 +81,7 @@ jobs: run: python -m pytest stac_fastapi/api/tests/benchmarks.py --benchmark-only --benchmark-columns 'min, max, mean, median' --benchmark-json output.json - name: Store and benchmark result + if: github.repository == 'stac-utils/stac-fastapi' uses: benchmark-action/github-action-benchmark@v1 with: name: STAC FastAPI Benchmarks diff --git a/.github/workflows/deploy_mkdocs.yml b/.github/workflows/deploy_mkdocs.yml index 4715015e9..a3469aad8 100644 --- a/.github/workflows/deploy_mkdocs.yml +++ b/.github/workflows/deploy_mkdocs.yml @@ -20,10 +20,10 @@ jobs: - name: Checkout main uses: actions/checkout@v4 - - name: Set up Python 3.8 + - name: Set up Python 3.11 uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: 3.11 - name: Install dependencies run: | diff --git a/.gitignore b/.gitignore index 908694a3a..3b2a1fea8 100644 --- a/.gitignore +++ b/.gitignore @@ -129,6 +129,7 @@ docs/api/* # Virtualenv venv +.venv/ # IDE .vscode \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 193edc5c7..68c3b8567 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,10 +1,7 @@ repos: - - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: "v0.0.267" + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: "v0.2.2" hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] - - repo: https://github.com/psf/black - rev: 23.3.0 - hooks: - - id: black + - id: ruff-format diff --git a/CHANGES.md b/CHANGES.md index 2cde04885..b47b03aa3 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,15 @@ ## [Unreleased] +## Changes + +* Update to pydantic v2 and stac_pydantic v3 ([#625](https://github.com/stac-utils/stac-fastapi/pull/625)) +* Removed internal Search and Operator Types in favor of stac_pydantic Types ([#625](https://github.com/stac-utils/stac-fastapi/pull/625)) +* Fix response model validation ([#625](https://github.com/stac-utils/stac-fastapi/pull/625)) +* Add Response Model to OpenAPI, even if model validation is turned off ([#625](https://github.com/stac-utils/stac-fastapi/pull/625)) +* Use status code 201 for Item/Collection creation ([#625](https://github.com/stac-utils/stac-fastapi/pull/625)) +* Replace Black with Ruff Format ([#625](https://github.com/stac-utils/stac-fastapi/pull/625)) + ## [2.5.5.post1] - 2024-04-25 ### Fixed @@ -48,6 +57,7 @@ * Add `/queryables` link to the landing page ([#587](https://github.com/stac-utils/stac-fastapi/pull/587)) - `id`, `title`, `description` and `api_version` fields can be customized via env variables * Add `DeprecationWarning` for the `ContextExtension` +* Add support for Python 3.12 ### Changed diff --git a/Dockerfile b/Dockerfile index 501de7f36..9b6817182 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.8-slim as base +FROM python:3.11-slim as base # Any python libraries that require system libraries to be installed will likely # need the following packages in order to build diff --git a/Dockerfile.docs b/Dockerfile.docs index caa0f7e9f..6c7f00843 100644 --- a/Dockerfile.docs +++ b/Dockerfile.docs @@ -1,4 +1,4 @@ -FROM python:3.8-slim +FROM python:3.11-slim # build-essential is required to build a wheel for ciso8601 RUN apt update && apt install -y build-essential diff --git a/README.md b/README.md index 9a8ec78ed..02c155993 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,18 @@ Backends are hosted in their own repositories: `stac-fastapi` was initially developed by [arturo-ai](https://github.com/arturo-ai). + +## Response Model Validation + +A common question when using this package is how request and response types are validated? + +This package uses [`stac-pydantic`](https://github.com/stac-utils/stac-pydantic) to validate and document STAC objects. However, by default, validation of response types is turned off and the API will simply forward responses without validating them against the Pydantic model first. This decision was made with the assumption that responses usually come from a (typed) database and can be considered safe. Extra validation would only increase latency, in particular for large payloads. + +To turn on response validation, set `ENABLE_RESPONSE_MODELS` to `True`. Either as an environment variable or directly in the `ApiSettings`. + +With the introduction of Pydantic 2, the extra [time it takes to validate models became negatable](https://github.com/stac-utils/stac-fastapi/pull/625#issuecomment-2045824578). While `ENABLE_RESPONSE_MODELS` still defaults to `False` there should be no penalty for users to turn on this feature but users discretion is advised. + + ## Installation ```bash diff --git a/pyproject.toml b/pyproject.toml index 162a81b1e..ad2edbb00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,8 @@ [tool.ruff] +target-version = "py38" # minimum supported version line-length = 90 + +[tool.ruff.lint] select = [ "C9", "D1", @@ -9,13 +12,13 @@ select = [ "W", ] -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "**/tests/**/*.py" = ["D1"] -[tool.ruff.isort] +[tool.ruff.lint.isort] known-first-party = ["stac_fastapi"] known-third-party = ["stac_pydantic", "fastapi"] section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"] -[tool.black] -target-version = ["py38", "py39", "py310", "py311"] +[tool.ruff.format] +quote-style = "double" diff --git a/stac_fastapi/api/setup.py b/stac_fastapi/api/setup.py index a5bfd897e..af596adf0 100644 --- a/stac_fastapi/api/setup.py +++ b/stac_fastapi/api/setup.py @@ -6,9 +6,6 @@ desc = f.read() install_requires = [ - "attrs", - "pydantic[dotenv]<2", - "stac_pydantic==2.0.*", "brotli_asgi", "stac-fastapi.types", ] diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index 7ad0c96f5..194b22a00 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -1,5 +1,6 @@ """Fastapi app creation.""" + from typing import Any, Dict, List, Optional, Tuple, Type, Union import attr @@ -7,10 +8,10 @@ from fastapi import APIRouter, FastAPI from fastapi.openapi.utils import get_openapi from fastapi.params import Depends -from stac_pydantic import Collection, Item, ItemCollection -from stac_pydantic.api import ConformanceClasses, LandingPage +from stac_pydantic import api from stac_pydantic.api.collections import Collections -from stac_pydantic.version import STAC_VERSION +from stac_pydantic.api.version import STAC_API_VERSION +from stac_pydantic.shared import MimeTypes from starlette.responses import JSONResponse, Response from stac_fastapi.api.errors import DEFAULT_STATUS_CODES, add_exception_handlers @@ -94,7 +95,7 @@ class StacApi: lambda self: self.settings.stac_fastapi_version, takes_self=True ) ) - stac_version: str = attr.ib(default=STAC_VERSION) + stac_version: str = attr.ib(default=STAC_API_VERSION) description: str = attr.ib( default=attr.Factory( lambda self: self.settings.stac_fastapi_description, takes_self=True @@ -138,9 +139,17 @@ def register_landing_page(self): self.router.add_api_route( name="Landing Page", path="/", - response_model=LandingPage - if self.settings.enable_response_models - else None, + response_model=( + api.LandingPage if self.settings.enable_response_models else None + ), + responses={ + 200: { + "content": { + MimeTypes.json.value: {}, + }, + "model": api.LandingPage, + }, + }, response_class=self.response_class, response_model_exclude_unset=False, response_model_exclude_none=True, @@ -157,9 +166,17 @@ def register_conformance_classes(self): self.router.add_api_route( name="Conformance Classes", path="/conformance", - response_model=ConformanceClasses - if self.settings.enable_response_models - else None, + response_model=( + api.ConformanceClasses if self.settings.enable_response_models else None + ), + responses={ + 200: { + "content": { + MimeTypes.json.value: {}, + }, + "model": api.ConformanceClasses, + }, + }, response_class=self.response_class, response_model_exclude_unset=True, response_model_exclude_none=True, @@ -176,7 +193,15 @@ def register_get_item(self): self.router.add_api_route( name="Get Item", path="/collections/{collection_id}/items/{item_id}", - response_model=Item if self.settings.enable_response_models else None, + response_model=api.Item if self.settings.enable_response_models else None, + responses={ + 200: { + "content": { + MimeTypes.geojson.value: {}, + }, + "model": api.Item, + }, + }, response_class=GeoJSONResponse, response_model_exclude_unset=True, response_model_exclude_none=True, @@ -194,9 +219,19 @@ def register_post_search(self): self.router.add_api_route( name="Search", path="/search", - response_model=(ItemCollection if not fields_ext else None) - if self.settings.enable_response_models - else None, + response_model=( + (api.ItemCollection if not fields_ext else None) + if self.settings.enable_response_models + else None + ), + responses={ + 200: { + "content": { + MimeTypes.geojson.value: {}, + }, + "model": api.ItemCollection, + }, + }, response_class=GeoJSONResponse, response_model_exclude_unset=True, response_model_exclude_none=True, @@ -216,9 +251,19 @@ def register_get_search(self): self.router.add_api_route( name="Search", path="/search", - response_model=(ItemCollection if not fields_ext else None) - if self.settings.enable_response_models - else None, + response_model=( + (api.ItemCollection if not fields_ext else None) + if self.settings.enable_response_models + else None + ), + responses={ + 200: { + "content": { + MimeTypes.geojson.value: {}, + }, + "model": api.ItemCollection, + }, + }, response_class=GeoJSONResponse, response_model_exclude_unset=True, response_model_exclude_none=True, @@ -237,9 +282,17 @@ def register_get_collections(self): self.router.add_api_route( name="Get Collections", path="/collections", - response_model=Collections - if self.settings.enable_response_models - else None, + response_model=( + Collections if self.settings.enable_response_models else None + ), + responses={ + 200: { + "content": { + MimeTypes.json.value: {}, + }, + "model": Collections, + }, + }, response_class=self.response_class, response_model_exclude_unset=True, response_model_exclude_none=True, @@ -256,7 +309,17 @@ def register_get_collection(self): self.router.add_api_route( name="Get Collection", path="/collections/{collection_id}", - response_model=Collection if self.settings.enable_response_models else None, + response_model=api.Collection + if self.settings.enable_response_models + else None, + responses={ + 200: { + "content": { + MimeTypes.json.value: {}, + }, + "model": api.Collection, + }, + }, response_class=self.response_class, response_model_exclude_unset=True, response_model_exclude_none=True, @@ -283,9 +346,17 @@ def register_get_item_collection(self): self.router.add_api_route( name="Get ItemCollection", path="/collections/{collection_id}/items", - response_model=ItemCollection - if self.settings.enable_response_models - else None, + response_model=( + api.ItemCollection if self.settings.enable_response_models else None + ), + responses={ + 200: { + "content": { + MimeTypes.geojson.value: {}, + }, + "model": api.ItemCollection, + }, + }, response_class=GeoJSONResponse, response_model_exclude_unset=True, response_model_exclude_none=True, diff --git a/stac_fastapi/api/stac_fastapi/api/config.py b/stac_fastapi/api/stac_fastapi/api/config.py index e6e4d882a..3918421ff 100644 --- a/stac_fastapi/api/stac_fastapi/api/config.py +++ b/stac_fastapi/api/stac_fastapi/api/config.py @@ -1,4 +1,5 @@ """Application settings.""" + import enum diff --git a/stac_fastapi/api/stac_fastapi/api/errors.py b/stac_fastapi/api/stac_fastapi/api/errors.py index 3f052bd31..6d90ba63a 100644 --- a/stac_fastapi/api/stac_fastapi/api/errors.py +++ b/stac_fastapi/api/stac_fastapi/api/errors.py @@ -4,7 +4,7 @@ from typing import Callable, Dict, Type, TypedDict from fastapi import FastAPI -from fastapi.exceptions import RequestValidationError +from fastapi.exceptions import RequestValidationError, ResponseValidationError from starlette import status from starlette.requests import Request from starlette.responses import JSONResponse @@ -27,6 +27,7 @@ DatabaseError: status.HTTP_424_FAILED_DEPENDENCY, Exception: status.HTTP_500_INTERNAL_SERVER_ERROR, InvalidQueryParameter: status.HTTP_400_BAD_REQUEST, + ResponseValidationError: status.HTTP_500_INTERNAL_SERVER_ERROR, } diff --git a/stac_fastapi/api/stac_fastapi/api/middleware.py b/stac_fastapi/api/stac_fastapi/api/middleware.py index 3ed67d6c9..2ba3ef570 100644 --- a/stac_fastapi/api/stac_fastapi/api/middleware.py +++ b/stac_fastapi/api/stac_fastapi/api/middleware.py @@ -1,4 +1,5 @@ """Api middleware.""" + import re import typing from http.client import HTTP_PORT, HTTPS_PORT diff --git a/stac_fastapi/api/stac_fastapi/api/models.py b/stac_fastapi/api/stac_fastapi/api/models.py index 0721413a9..2716fe7fb 100644 --- a/stac_fastapi/api/stac_fastapi/api/models.py +++ b/stac_fastapi/api/stac_fastapi/api/models.py @@ -1,12 +1,11 @@ """Api request/response models.""" import importlib.util -from typing import Optional, Type, Union +from typing import List, Optional, Type, Union import attr -from fastapi import Body, Path +from fastapi import Path from pydantic import BaseModel, create_model -from pydantic.fields import UndefinedType from stac_pydantic.shared import BBox from stac_fastapi.types.extension import ApiExtension @@ -23,8 +22,8 @@ def create_request_model( model_name="SearchGetRequest", base_model: Union[Type[BaseModel], APIRequest] = BaseSearchGetRequest, - extensions: Optional[ApiExtension] = None, - mixins: Optional[Union[BaseModel, APIRequest]] = None, + extensions: Optional[List[ApiExtension]] = None, + mixins: Optional[Union[List[BaseModel], List[APIRequest]]] = None, request_type: Optional[str] = "GET", ) -> Union[Type[BaseModel], APIRequest]: """Create a pydantic model for validating request bodies.""" @@ -47,40 +46,19 @@ def create_request_model( # Handle POST requests elif all([issubclass(m, BaseModel) for m in models]): for model in models: - for k, v in model.__fields__.items(): - field_info = v.field_info - body = Body( - None - if isinstance(field_info.default, UndefinedType) - else field_info.default, - default_factory=field_info.default_factory, - alias=field_info.alias, - alias_priority=field_info.alias_priority, - title=field_info.title, - description=field_info.description, - const=field_info.const, - gt=field_info.gt, - ge=field_info.ge, - lt=field_info.lt, - le=field_info.le, - multiple_of=field_info.multiple_of, - min_items=field_info.min_items, - max_items=field_info.max_items, - min_length=field_info.min_length, - max_length=field_info.max_length, - regex=field_info.regex, - extra=field_info.extra, - ) - fields[k] = (v.outer_type_, body) + for k, field_info in model.model_fields.items(): + fields[k] = (field_info.annotation, field_info) return create_model(model_name, **fields, __base__=base_model) raise TypeError("Mixed Request Model types. Check extension request types.") def create_get_request_model( - extensions, base_model: BaseSearchGetRequest = BaseSearchGetRequest -): + extensions: Optional[List[ApiExtension]], + base_model: BaseSearchGetRequest = BaseSearchGetRequest, +) -> APIRequest: """Wrap create_request_model to create the GET request model.""" + return create_request_model( "SearchGetRequest", base_model=base_model, @@ -90,8 +68,9 @@ def create_get_request_model( def create_post_request_model( - extensions, base_model: BaseSearchPostRequest = BaseSearchPostRequest -): + extensions: Optional[List[ApiExtension]], + base_model: BaseSearchPostRequest = BaseSearchPostRequest, +) -> Type[BaseModel]: """Wrap create_request_model to create the POST request model.""" return create_request_model( "SearchPostRequest", diff --git a/stac_fastapi/api/stac_fastapi/api/openapi.py b/stac_fastapi/api/stac_fastapi/api/openapi.py index a38a70bae..ab90ce425 100644 --- a/stac_fastapi/api/stac_fastapi/api/openapi.py +++ b/stac_fastapi/api/stac_fastapi/api/openapi.py @@ -1,4 +1,5 @@ """openapi.""" + import warnings from fastapi import FastAPI @@ -43,9 +44,7 @@ async def patched_openapi_endpoint(req: Request) -> Response: # Get the response from the old endpoint function response: JSONResponse = await old_endpoint(req) # Update the content type header in place - response.headers[ - "content-type" - ] = "application/vnd.oai.openapi+json;version=3.0" + response.headers["content-type"] = "application/vnd.oai.openapi+json;version=3.0" # Return the updated response return response diff --git a/stac_fastapi/api/tests/benchmarks.py b/stac_fastapi/api/tests/benchmarks.py index ad73d2424..95e1c532a 100644 --- a/stac_fastapi/api/tests/benchmarks.py +++ b/stac_fastapi/api/tests/benchmarks.py @@ -160,9 +160,7 @@ def f(): benchmark.group = "Collection With Model validation" if validate else "Collection" benchmark.name = "Collection With Model validation" if validate else "Collection" - benchmark.fullname = ( - "Collection With Model validation" if validate else "Collection" - ) + benchmark.fullname = "Collection With Model validation" if validate else "Collection" response = benchmark(f) assert response.status_code == 200 diff --git a/stac_fastapi/api/tests/conftest.py b/stac_fastapi/api/tests/conftest.py index ed8c66d4d..1b89f07cd 100644 --- a/stac_fastapi/api/tests/conftest.py +++ b/stac_fastapi/api/tests/conftest.py @@ -31,12 +31,12 @@ def _collection(): @pytest.fixture def collection(_collection: Collection): - return _collection.json() + return _collection.model_dump_json() @pytest.fixture def collection_dict(_collection: Collection): - return _collection.dict() + return _collection.model_dump(mode="json") @pytest.fixture @@ -54,12 +54,12 @@ def _item(): @pytest.fixture def item(_item: Item): - return _item.json() + return _item.model_dump_json() @pytest.fixture def item_dict(_item: Item): - return _item.dict() + return _item.model_dump(mode="json") @pytest.fixture @@ -142,9 +142,7 @@ async def get_search( type="FeatureCollection", features=[stac.Item(**item_dict)] ) - async def get_item( - self, item_id: str, collection_id: str, **kwargs - ) -> stac.Item: + async def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac.Item: return stac.Item(**item_dict) async def all_collections(self, **kwargs) -> stac.Collections: diff --git a/stac_fastapi/api/tests/test_api.py b/stac_fastapi/api/tests/test_api.py index 91b50371e..f2d51f1db 100644 --- a/stac_fastapi/api/tests/test_api.py +++ b/stac_fastapi/api/tests/test_api.py @@ -41,11 +41,11 @@ def _assert_dependency_applied(api, routes): method=route["method"].lower(), url=path, auth=("bob", "dobbs"), - content='{"dummy": "payload"}', + content=route["payload"], headers={"content-type": "application/json"}, ) assert ( - response.status_code == 200 + 200 <= response.status_code < 300 ), "Authenticated requests should be accepted" assert response.json() == "dummy response" @@ -58,27 +58,59 @@ def test_openapi_content_type(self): == "application/vnd.oai.openapi+json;version=3.0" ) - def test_build_api_with_route_dependencies(self): + def test_build_api_with_route_dependencies(self, collection, item): routes = [ - {"path": "/collections", "method": "POST"}, - {"path": "/collections/{collectionId}", "method": "PUT"}, - {"path": "/collections/{collectionId}", "method": "DELETE"}, - {"path": "/collections/{collectionId}/items", "method": "POST"}, - {"path": "/collections/{collectionId}/items/{itemId}", "method": "PUT"}, - {"path": "/collections/{collectionId}/items/{itemId}", "method": "DELETE"}, + {"path": "/collections", "method": "POST", "payload": collection}, + { + "path": "/collections/{collectionId}", + "method": "PUT", + "payload": collection, + }, + {"path": "/collections/{collectionId}", "method": "DELETE", "payload": ""}, + { + "path": "/collections/{collectionId}/items", + "method": "POST", + "payload": item, + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "PUT", + "payload": item, + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "DELETE", + "payload": "", + }, ] dependencies = [Depends(must_be_bob)] api = self._build_api(route_dependencies=[(routes, dependencies)]) self._assert_dependency_applied(api, routes) - def test_add_route_dependencies_after_building_api(self): + def test_add_route_dependencies_after_building_api(self, collection, item): routes = [ - {"path": "/collections", "method": "POST"}, - {"path": "/collections/{collectionId}", "method": "PUT"}, - {"path": "/collections/{collectionId}", "method": "DELETE"}, - {"path": "/collections/{collectionId}/items", "method": "POST"}, - {"path": "/collections/{collectionId}/items/{itemId}", "method": "PUT"}, - {"path": "/collections/{collectionId}/items/{itemId}", "method": "DELETE"}, + {"path": "/collections", "method": "POST", "payload": collection}, + { + "path": "/collections/{collectionId}", + "method": "PUT", + "payload": collection, + }, + {"path": "/collections/{collectionId}", "method": "DELETE", "payload": ""}, + { + "path": "/collections/{collectionId}/items", + "method": "POST", + "payload": item, + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "PUT", + "payload": item, + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "DELETE", + "payload": "", + }, ] api = self._build_api() api.add_route_dependencies(scopes=routes, dependencies=[Depends(must_be_bob)]) diff --git a/stac_fastapi/api/tests/test_app.py b/stac_fastapi/api/tests/test_app.py new file mode 100644 index 000000000..9b4e0e828 --- /dev/null +++ b/stac_fastapi/api/tests/test_app.py @@ -0,0 +1,188 @@ +from datetime import datetime +from typing import List, Optional, Union + +import pytest +from fastapi.testclient import TestClient +from pydantic import ValidationError +from stac_pydantic import api + +from stac_fastapi.api import app +from stac_fastapi.api.models import create_get_request_model, create_post_request_model +from stac_fastapi.extensions.core.filter.filter import FilterExtension +from stac_fastapi.types import stac +from stac_fastapi.types.config import ApiSettings +from stac_fastapi.types.core import NumType +from stac_fastapi.types.search import BaseSearchPostRequest + + +def test_client_response_type(TestCoreClient): + """Test all GET endpoints. Verify that responses are valid STAC items.""" + + test_app = app.StacApi( + settings=ApiSettings(), + client=TestCoreClient(), + ) + + with TestClient(test_app.app) as client: + landing = client.get("/") + collection = client.get("/collections/test") + collections = client.get("/collections") + item = client.get("/collections/test/items/test") + item_collection = client.get( + "/collections/test/items", + params={"limit": 10}, + ) + get_search = client.get( + "/search", + params={ + "collections": ["test"], + }, + ) + post_search = client.post( + "/search", + json={ + "collections": ["test"], + }, + ) + + assert landing.status_code == 200, landing.text + api.LandingPage(**landing.json()) + + assert collection.status_code == 200, collection.text + api.Collection(**collection.json()) + + assert collections.status_code == 200, collections.text + api.collections.Collections(**collections.json()) + + assert item.status_code == 200, item.text + api.Item(**item.json()) + + assert item_collection.status_code == 200, item_collection.text + api.ItemCollection(**item_collection.json()) + + assert get_search.status_code == 200, get_search.text + api.ItemCollection(**get_search.json()) + + assert post_search.status_code == 200, post_search.text + api.ItemCollection(**post_search.json()) + + +@pytest.mark.parametrize("validate", [True, False]) +def test_client_invalid_response_type(validate, TestCoreClient, item_dict): + """Check if the build in response validation switch works.""" + + class InValidResponseClient(TestCoreClient): + def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac.Item: + item_dict.pop("bbox") + item_dict.pop("geometry") + return stac.Item(**item_dict) + + test_app = app.StacApi( + settings=ApiSettings(enable_response_models=validate), + client=InValidResponseClient(), + ) + + with TestClient(test_app.app) as client: + item = client.get("/collections/test/items/test") + + # Even if API validation passes, we should receive an invalid item + if item.status_code == 200: + with pytest.raises(ValidationError): + api.Item(**item.json()) + + # If internal validation is on, we should expect an internal error + if validate: + assert item.status_code == 500, item.text + else: + assert item.status_code == 200, item.text + + +def test_client_openapi(TestCoreClient): + """Test if response models are all documented with OpenAPI.""" + + test_app = app.StacApi( + settings=ApiSettings(), + client=TestCoreClient(), + ) + test_app.app.openapi() + components = ["LandingPage", "Collection", "Collections", "Item", "ItemCollection"] + for component in components: + assert component in test_app.app.openapi_schema["components"]["schemas"] + + +def test_filter_extension(TestCoreClient, item_dict): + """Test if Filter Parameters are passed correctly.""" + + class FilterClient(TestCoreClient): + def post_search( + self, search_request: BaseSearchPostRequest, **kwargs + ) -> stac.ItemCollection: + search_request.collections = ["test"] + search_request.filter = {} + search_request.filter_crs = "EPSG:4326" + search_request.filter_lang = "cql2-text" + + return stac.ItemCollection( + type="FeatureCollection", features=[stac.Item(**item_dict)] + ) + + def get_search( + self, + collections: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + bbox: Optional[List[NumType]] = None, + intersects: Optional[str] = None, + datetime: Optional[Union[str, datetime]] = None, + limit: Optional[int] = 10, + filter: Optional[str] = None, + filter_crs: Optional[str] = None, + filter_lang: Optional[str] = None, + **kwargs, + ) -> stac.ItemCollection: + # Check if all filter parameters are passed correctly + + assert filter == "TEST" + + # FIXME: https://github.com/stac-utils/stac-fastapi/issues/638 + # hyphen alias for filter_crs and filter_lang are currently not working + # Query parameters `filter-crs` and `filter-lang` + # should be recognized by the API + # They are present in the `request.query_params` but not in the `kwargs` + + # assert filter_crs == "EPSG:4326" + # assert filter_lang == "cql2-text" + + return stac.ItemCollection( + type="FeatureCollection", features=[stac.Item(**item_dict)] + ) + + post_request_model = create_post_request_model([FilterExtension()]) + + test_app = app.StacApi( + settings=ApiSettings(), + client=FilterClient(post_request_model=post_request_model), + search_get_request_model=create_get_request_model([FilterExtension()]), + search_post_request_model=post_request_model, + ) + + with TestClient(test_app.app) as client: + get_search = client.get( + "/search", + params={ + "filter": "TEST", + "filter-crs": "EPSG:4326", + "filter-lang": "cql2-text", + }, + ) + post_search = client.post( + "/search", + json={ + "collections": ["test"], + "filter": {}, + "filter-crs": "EPSG:4326", + "filter-lang": "cql2-text", + }, + ) + + assert get_search.status_code == 200, get_search.text + assert post_search.status_code == 200, post_search.text diff --git a/stac_fastapi/api/tests/test_models.py b/stac_fastapi/api/tests/test_models.py new file mode 100644 index 000000000..cbff0f53d --- /dev/null +++ b/stac_fastapi/api/tests/test_models.py @@ -0,0 +1,101 @@ +import json + +import pytest +from pydantic import ValidationError + +from stac_fastapi.api.models import create_get_request_model, create_post_request_model +from stac_fastapi.extensions.core.filter.filter import FilterExtension +from stac_fastapi.extensions.core.sort.sort import SortExtension +from stac_fastapi.types.search import BaseSearchGetRequest, BaseSearchPostRequest + + +def test_create_get_request_model(): + extensions = [FilterExtension()] + request_model = create_get_request_model(extensions, BaseSearchGetRequest) + + model = request_model( + collections="test1,test2", + ids="test1,test2", + bbox="0,0,1,1", + intersects=json.dumps( + { + "type": "Polygon", + "coordinates": [[[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]], + } + ), + datetime="2020-01-01T00:00:00Z", + limit=10, + filter="test==test", + # FIXME: https://github.com/stac-utils/stac-fastapi/issues/638 + # hyphen aliases are not properly working + # **{"filter-crs": "epsg:4326", "filter-lang": "cql2-text"}, + ) + + assert model.collections == ["test1", "test2"] + # assert model.filter_crs == "epsg:4326" + + +@pytest.mark.parametrize( + "filter,passes", + [(None, True), ({"test": "test"}, True), ("test==test", False), ([], False)], +) +def test_create_post_request_model(filter, passes): + extensions = [FilterExtension()] + request_model = create_post_request_model(extensions, BaseSearchPostRequest) + + if not passes: + with pytest.raises(ValidationError): + model = request_model(filter=filter) + else: + model = request_model( + collections=["test1", "test2"], + ids=["test1", "test2"], + bbox=[0, 0, 1, 1], + datetime="2020-01-01T00:00:00Z", + limit=10, + filter=filter, + **{"filter-crs": "epsg:4326", "filter-lang": "cql2-text"}, + ) + + assert model.collections == ["test1", "test2"] + assert model.filter_crs == "epsg:4326" + assert model.filter == filter + + +@pytest.mark.parametrize( + "sortby,passes", + [ + (None, True), + ( + [ + {"field": "test", "direction": "asc"}, + {"field": "test2", "direction": "desc"}, + ], + True, + ), + ({"field": "test", "direction": "desc"}, False), + ("test", False), + ], +) +def test_create_post_request_model_nested_fields(sortby, passes): + extensions = [SortExtension()] + request_model = create_post_request_model(extensions, BaseSearchPostRequest) + + if not passes: + with pytest.raises(ValidationError): + model = request_model(sortby=sortby) + else: + model = request_model( + collections=["test1", "test2"], + ids=["test1", "test2"], + bbox=[0, 0, 1, 1], + datetime="2020-01-01T00:00:00Z", + limit=10, + sortby=sortby, + ) + + assert model.collections == ["test1", "test2"] + if model.sortby is None: + assert sortby is None + else: + assert model.model_dump(mode="json")["sortby"] == sortby diff --git a/stac_fastapi/extensions/setup.py b/stac_fastapi/extensions/setup.py index af564931b..39bc59b3f 100644 --- a/stac_fastapi/extensions/setup.py +++ b/stac_fastapi/extensions/setup.py @@ -1,14 +1,12 @@ """stac_fastapi: extensions module.""" + from setuptools import find_namespace_packages, setup with open("README.md") as f: desc = f.read() install_requires = [ - "attrs", - "pydantic[dotenv]<2", - "stac_pydantic==2.0.*", "stac-fastapi.types", "stac-fastapi.api", ] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py index 96317fe4a..74f15ed0a 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py @@ -1,4 +1,5 @@ """stac_api.extensions.core module.""" + from .context import ContextExtension from .fields import FieldsExtension from .filter import FilterExtension diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/__init__.py index b9a246b63..087d01b7a 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/__init__.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/__init__.py @@ -1,6 +1,5 @@ """Fields extension module.""" - from .fields import FieldsExtension __all__ = ["FieldsExtension"] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/fields.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/fields.py index df4cd44de..25b6fe252 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/fields.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/fields.py @@ -1,4 +1,5 @@ """Fields extension.""" + from typing import List, Optional, Set import attr diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/__init__.py index 78256bfd2..256f3e06e 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/__init__.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/__init__.py @@ -1,6 +1,5 @@ """Filter extension module.""" - from .filter import FilterExtension __all__ = ["FilterExtension"] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/query/query.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/query.py index 3e85b406d..dcb162060 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/query/query.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/query.py @@ -1,4 +1,5 @@ """Query extension.""" + from typing import List, Optional import attr diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py index c19f40dba..377067ff9 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py @@ -20,4 +20,4 @@ class SortExtensionGetRequest(APIRequest): class SortExtensionPostRequest(BaseModel): """Sortby parameter for POST requests.""" - sortby: Optional[List[PostSortModel]] + sortby: Optional[List[PostSortModel]] = None diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/sort.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/sort.py index 5dd96cfa6..4b27d8d0e 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/sort.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/sort.py @@ -1,4 +1,5 @@ """Sort extension.""" + from typing import List, Optional import attr diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py index 86e1bfc52..818315e1a 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py @@ -4,12 +4,13 @@ import attr from fastapi import APIRouter, Body, FastAPI -from stac_pydantic import Collection, Item +from stac_pydantic import Collection, Item, ItemCollection +from stac_pydantic.shared import MimeTypes from starlette.responses import JSONResponse, Response from stac_fastapi.api.models import CollectionUri, ItemUri from stac_fastapi.api.routes import create_async_endpoint -from stac_fastapi.types import stac as stac_types +from stac_fastapi.types import stac from stac_fastapi.types.config import ApiSettings from stac_fastapi.types.core import AsyncBaseTransactionsClient, BaseTransactionsClient from stac_fastapi.types.extension import ApiExtension @@ -19,23 +20,21 @@ class PostItem(CollectionUri): """Create Item.""" - item: Union[stac_types.Item, stac_types.ItemCollection] = attr.ib( - default=Body(None) - ) + item: Union[Item, ItemCollection] = attr.ib(default=Body(None)) @attr.s class PutItem(ItemUri): """Update Item.""" - item: stac_types.Item = attr.ib(default=Body(None)) + item: Item = attr.ib(default=Body(None)) @attr.s class PutCollection(CollectionUri): """Update Collection.""" - collection: stac_types.Collection = attr.ib(default=Body(None)) + collection: stac.Collection = attr.ib(default=Body(None)) @attr.s @@ -73,7 +72,16 @@ def register_create_item(self): self.router.add_api_route( name="Create Item", path="/collections/{collection_id}/items", + status_code=201, response_model=Item if self.settings.enable_response_models else None, + responses={ + 201: { + "content": { + MimeTypes.geojson.value: {}, + }, + "model": Item, + } + }, response_class=self.response_class, response_model_exclude_unset=True, response_model_exclude_none=True, @@ -88,6 +96,14 @@ def register_update_item(self): name="Update Item", path="/collections/{collection_id}/items/{item_id}", response_model=Item if self.settings.enable_response_models else None, + responses={ + 200: { + "content": { + MimeTypes.geojson.value: {}, + }, + "model": Item, + } + }, response_class=self.response_class, response_model_exclude_unset=True, response_model_exclude_none=True, @@ -102,6 +118,14 @@ def register_delete_item(self): name="Delete Item", path="/collections/{collection_id}/items/{item_id}", response_model=Item if self.settings.enable_response_models else None, + responses={ + 200: { + "content": { + MimeTypes.geojson.value: {}, + }, + "model": Item, + } + }, response_class=self.response_class, response_model_exclude_unset=True, response_model_exclude_none=True, @@ -114,14 +138,21 @@ def register_create_collection(self): self.router.add_api_route( name="Create Collection", path="/collections", + status_code=201, response_model=Collection if self.settings.enable_response_models else None, + responses={ + 201: { + "content": { + MimeTypes.json.value: {}, + }, + "model": Collection, + } + }, response_class=self.response_class, response_model_exclude_unset=True, response_model_exclude_none=True, methods=["POST"], - endpoint=create_async_endpoint( - self.client.create_collection, stac_types.Collection - ), + endpoint=create_async_endpoint(self.client.create_collection, Collection), ) def register_update_collection(self): @@ -130,13 +161,19 @@ def register_update_collection(self): name="Update Collection", path="/collections/{collection_id}", response_model=Collection if self.settings.enable_response_models else None, + responses={ + 200: { + "content": { + MimeTypes.json.value: {}, + }, + "model": Collection, + } + }, response_class=self.response_class, response_model_exclude_unset=True, response_model_exclude_none=True, methods=["PUT"], - endpoint=create_async_endpoint( - self.client.update_collection, PutCollection - ), + endpoint=create_async_endpoint(self.client.update_collection, PutCollection), ) def register_delete_collection(self): @@ -145,13 +182,19 @@ def register_delete_collection(self): name="Delete Collection", path="/collections/{collection_id}", response_model=Collection if self.settings.enable_response_models else None, + responses={ + 200: { + "content": { + MimeTypes.json.value: {}, + }, + "model": Collection, + } + }, response_class=self.response_class, response_model_exclude_unset=True, response_model_exclude_none=True, methods=["DELETE"], - endpoint=create_async_endpoint( - self.client.delete_collection, CollectionUri - ), + endpoint=create_async_endpoint(self.client.delete_collection, CollectionUri), ) def register(self, app: FastAPI) -> None: diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/third_party/__init__.py index ab7349e60..d35c4c8f9 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/__init__.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/third_party/__init__.py @@ -1,4 +1,5 @@ """stac_api.extensions.third_party module.""" + from .bulk_transactions import BulkTransactionExtension __all__ = ("BulkTransactionExtension",) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py b/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py index 9fa96ff2b..d1faa5c0f 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py @@ -1,4 +1,5 @@ """Bulk transactions extension.""" + import abc from enum import Enum from typing import Any, Dict, List, Optional, Union @@ -109,9 +110,7 @@ class BulkTransactionExtension(ApiExtension): } """ - client: Union[ - AsyncBaseBulkTransactionsClient, BaseBulkTransactionsClient - ] = attr.ib() + client: Union[AsyncBaseBulkTransactionsClient, BaseBulkTransactionsClient] = attr.ib() conformance_classes: List[str] = attr.ib(default=list()) schema_href: Optional[str] = attr.ib(default=None) diff --git a/stac_fastapi/extensions/tests/test_transaction.py b/stac_fastapi/extensions/tests/test_transaction.py index e6416eaea..d686d8f91 100644 --- a/stac_fastapi/extensions/tests/test_transaction.py +++ b/stac_fastapi/extensions/tests/test_transaction.py @@ -2,13 +2,15 @@ from typing import Iterator, Union import pytest +from stac_pydantic.item import Item +from stac_pydantic.item_collection import ItemCollection from starlette.testclient import TestClient from stac_fastapi.api.app import StacApi from stac_fastapi.extensions.core import TransactionExtension from stac_fastapi.types.config import ApiSettings from stac_fastapi.types.core import BaseCoreClient, BaseTransactionsClient -from stac_fastapi.types.stac import Collection, Item, ItemCollection +from stac_fastapi.types.stac import Collection class DummyCoreClient(BaseCoreClient): @@ -34,14 +36,14 @@ def item_collection(self, *args, **kwargs): class DummyTransactionsClient(BaseTransactionsClient): """Dummy client returning parts of the request, rather than proper STAC items.""" - def create_item(self, item: Union[Item, ItemCollection], **kwargs): - return {"type": item["type"]} + def create_item(self, item: Union[Item, ItemCollection], *args, **kwargs): + return {"created": True, "type": item.type} def update_item(self, collection_id: str, item_id: str, item: Item, **kwargs): return { "path_collection_id": collection_id, "path_item_id": item_id, - "type": item["type"], + "type": item.type, } def delete_item(self, item_id: str, collection_id: str, **kwargs): @@ -51,7 +53,7 @@ def delete_item(self, item_id: str, collection_id: str, **kwargs): } def create_collection(self, collection: Collection, **kwargs): - return {"type": collection["type"]} + return {"type": collection.type} def update_collection(self, collection_id: str, collection: Collection, **kwargs): return {"path_collection_id": collection_id, "type": collection["type"]} @@ -157,7 +159,7 @@ def item() -> Item: "id": "test_item", "geometry": {"type": "Point", "coordinates": [-105, 40]}, "bbox": [-105, 40, -105, 40], - "properties": {}, + "properties": {"datetime": "2020-06-13T13:00:00Z"}, "links": [], "assets": {}, "collection": "test_collection", @@ -171,10 +173,12 @@ def collection() -> Collection: "stac_version": "1.0.0", "stac_extensions": [], "id": "test_collection", + "description": "A test collection", "extent": { "spatial": {"bbox": [[-180, -90, 180, 90]]}, - "temporal": { - "interval": [["2000-01-01T00:00:00Z", "2024-01-01T00:00:00Z"]] - }, + "temporal": {"interval": [["2000-01-01T00:00:00Z", "2024-01-01T00:00:00Z"]]}, }, + "links": [], + "assets": {}, + "license": "proprietary", } diff --git a/stac_fastapi/types/setup.py b/stac_fastapi/types/setup.py index c3905ede5..0b9448e39 100644 --- a/stac_fastapi/types/setup.py +++ b/stac_fastapi/types/setup.py @@ -6,10 +6,10 @@ desc = f.read() install_requires = [ - "fastapi>=0.73.0", - "attrs", - "pydantic[dotenv]<2", - "stac_pydantic==2.0.*", + "fastapi>=0.100.0", + "attrs>=23.2.0", + "pydantic-settings>=2", + "stac_pydantic>=3", "pystac==1.*", "iso8601>=1.0.2,<2.2.0", ] diff --git a/stac_fastapi/types/stac_fastapi/types/config.py b/stac_fastapi/types/stac_fastapi/types/config.py index 4b88c56a4..d692043cc 100644 --- a/stac_fastapi/types/stac_fastapi/types/config.py +++ b/stac_fastapi/types/stac_fastapi/types/config.py @@ -1,7 +1,8 @@ """stac_fastapi.types.config module.""" + from typing import Optional, Set -from pydantic import BaseSettings +from pydantic_settings import BaseSettings, SettingsConfigDict class ApiSettings(BaseSettings): @@ -35,11 +36,7 @@ class ApiSettings(BaseSettings): openapi_url: str = "/api" docs_url: str = "/api.html" - class Config: - """Model config (https://pydantic-docs.helpmanual.io/usage/model_config/).""" - - extra = "allow" - env_file = ".env" + model_config = SettingsConfigDict(env_file=".env", extra="allow") class Settings: diff --git a/stac_fastapi/types/stac_fastapi/types/conformance.py b/stac_fastapi/types/stac_fastapi/types/conformance.py index 13836aaf5..840584c1b 100644 --- a/stac_fastapi/types/stac_fastapi/types/conformance.py +++ b/stac_fastapi/types/stac_fastapi/types/conformance.py @@ -1,4 +1,5 @@ """Conformance Classes.""" + from enum import Enum diff --git a/stac_fastapi/types/stac_fastapi/types/core.py b/stac_fastapi/types/stac_fastapi/types/core.py index 05c3e1097..d0dc029f0 100644 --- a/stac_fastapi/types/stac_fastapi/types/core.py +++ b/stac_fastapi/types/stac_fastapi/types/core.py @@ -1,24 +1,26 @@ """Base clients.""" + import abc from typing import Any, Dict, List, Optional, Union from urllib.parse import urljoin import attr from fastapi import Request +from geojson_pydantic.geometries import Geometry +from stac_pydantic import Collection, Item, ItemCollection +from stac_pydantic.api.version import STAC_API_VERSION from stac_pydantic.links import Relations from stac_pydantic.shared import BBox, MimeTypes -from stac_pydantic.version import STAC_VERSION from starlette.responses import Response -from stac_fastapi.types import stac as stac_types +from stac_fastapi.types import stac from stac_fastapi.types.config import ApiSettings from stac_fastapi.types.conformance import BASE_CONFORMANCE_CLASSES from stac_fastapi.types.extension import ApiExtension from stac_fastapi.types.requests import get_base_url from stac_fastapi.types.rfc3339 import DateTimeType from stac_fastapi.types.search import BaseSearchPostRequest -from stac_fastapi.types.stac import Conformance NumType = Union[float, int] StacType = Dict[str, Any] @@ -34,9 +36,9 @@ class BaseTransactionsClient(abc.ABC): def create_item( self, collection_id: str, - item: Union[stac_types.Item, stac_types.ItemCollection], + item: Union[Item, ItemCollection], **kwargs, - ) -> Optional[Union[stac_types.Item, Response, None]]: + ) -> Optional[Union[Item, Response, None]]: """Create a new item. Called with `POST /collections/{collection_id}/items`. @@ -52,8 +54,8 @@ def create_item( @abc.abstractmethod def update_item( - self, collection_id: str, item_id: str, item: stac_types.Item, **kwargs - ) -> Optional[Union[stac_types.Item, Response]]: + self, collection_id: str, item_id: str, item: Item, **kwargs + ) -> Optional[Union[Item, Response]]: """Perform a complete update on an existing item. Called with `PUT /collections/{collection_id}/items`. It is expected @@ -73,7 +75,7 @@ def update_item( @abc.abstractmethod def delete_item( self, item_id: str, collection_id: str, **kwargs - ) -> Optional[Union[stac_types.Item, Response]]: + ) -> Optional[Union[Item, Response]]: """Delete an item from a collection. Called with `DELETE /collections/{collection_id}/items/{item_id}` @@ -89,8 +91,8 @@ def delete_item( @abc.abstractmethod def create_collection( - self, collection: stac_types.Collection, **kwargs - ) -> Optional[Union[stac_types.Collection, Response]]: + self, collection: Collection, **kwargs + ) -> Optional[Union[Collection, Response]]: """Create a new collection. Called with `POST /collections`. @@ -105,8 +107,8 @@ def create_collection( @abc.abstractmethod def update_collection( - self, collection_id: str, collection: stac_types.Collection, **kwargs - ) -> Optional[Union[stac_types.Collection, Response]]: + self, collection_id: str, collection: Collection, **kwargs + ) -> Optional[Union[Collection, Response]]: """Perform a complete update on an existing collection. Called with `PUT /collections/{collection_id}`. It is expected that this @@ -126,7 +128,7 @@ def update_collection( @abc.abstractmethod def delete_collection( self, collection_id: str, **kwargs - ) -> Optional[Union[stac_types.Collection, Response]]: + ) -> Optional[Union[Collection, Response]]: """Delete a collection. Called with `DELETE /collections/{collection_id}` @@ -148,9 +150,9 @@ class AsyncBaseTransactionsClient(abc.ABC): async def create_item( self, collection_id: str, - item: Union[stac_types.Item, stac_types.ItemCollection], + item: Union[Item, ItemCollection], **kwargs, - ) -> Optional[Union[stac_types.Item, Response, None]]: + ) -> Optional[Union[Item, Response, None]]: """Create a new item. Called with `POST /collections/{collection_id}/items`. @@ -166,8 +168,8 @@ async def create_item( @abc.abstractmethod async def update_item( - self, collection_id: str, item_id: str, item: stac_types.Item, **kwargs - ) -> Optional[Union[stac_types.Item, Response]]: + self, collection_id: str, item_id: str, item: Item, **kwargs + ) -> Optional[Union[Item, Response]]: """Perform a complete update on an existing item. Called with `PUT /collections/{collection_id}/items`. It is expected @@ -186,7 +188,7 @@ async def update_item( @abc.abstractmethod async def delete_item( self, item_id: str, collection_id: str, **kwargs - ) -> Optional[Union[stac_types.Item, Response]]: + ) -> Optional[Union[Item, Response]]: """Delete an item from a collection. Called with `DELETE /collections/{collection_id}/items/{item_id}` @@ -202,8 +204,8 @@ async def delete_item( @abc.abstractmethod async def create_collection( - self, collection: stac_types.Collection, **kwargs - ) -> Optional[Union[stac_types.Collection, Response]]: + self, collection: Collection, **kwargs + ) -> Optional[Union[Collection, Response]]: """Create a new collection. Called with `POST /collections`. @@ -218,8 +220,8 @@ async def create_collection( @abc.abstractmethod async def update_collection( - self, collection_id: str, collection: stac_types.Collection, **kwargs - ) -> Optional[Union[stac_types.Collection, Response]]: + self, collection_id: str, collection: Collection, **kwargs + ) -> Optional[Union[Collection, Response]]: """Perform a complete update on an existing collection. Called with `PUT /collections/{collection_id}`. It is expected that this item @@ -239,7 +241,7 @@ async def update_collection( @abc.abstractmethod async def delete_collection( self, collection_id: str, **kwargs - ) -> Optional[Union[stac_types.Collection, Response]]: + ) -> Optional[Union[Collection, Response]]: """Delete a collection. Called with `DELETE /collections/{collection_id}` @@ -257,7 +259,7 @@ async def delete_collection( class LandingPageMixin(abc.ABC): """Create a STAC landing page (GET /).""" - stac_version: str = attr.ib(default=STAC_VERSION) + stac_version: str = attr.ib(default=STAC_API_VERSION) landing_page_id: str = attr.ib(default=api_settings.stac_fastapi_landing_id) title: str = attr.ib(default=api_settings.stac_fastapi_title) description: str = attr.ib(default=api_settings.stac_fastapi_description) @@ -267,8 +269,8 @@ def _landing_page( base_url: str, conformance_classes: List[str], extension_schemas: List[str], - ) -> stac_types.LandingPage: - landing_page = stac_types.LandingPage( + ) -> stac.LandingPage: + landing_page = stac.LandingPage( type="Catalog", id=self.landing_page_id, title=self.title, @@ -278,35 +280,35 @@ def _landing_page( links=[ { "rel": Relations.self.value, - "type": MimeTypes.json, + "type": MimeTypes.json.value, "href": base_url, }, { "rel": Relations.root.value, - "type": MimeTypes.json, + "type": MimeTypes.json.value, "href": base_url, }, { - "rel": "data", - "type": MimeTypes.json, + "rel": Relations.data.value, + "type": MimeTypes.json.value, "href": urljoin(base_url, "collections"), }, { "rel": Relations.conformance.value, - "type": MimeTypes.json, + "type": MimeTypes.json.value, "title": "STAC/OGC conformance classes implemented by this server", "href": urljoin(base_url, "conformance"), }, { "rel": Relations.search.value, - "type": MimeTypes.geojson, + "type": MimeTypes.geojson.value, "title": "STAC search", "href": urljoin(base_url, "search"), "method": "GET", }, { "rel": Relations.search.value, - "type": MimeTypes.geojson, + "type": MimeTypes.geojson.value, "title": "STAC search", "href": urljoin(base_url, "search"), "method": "POST", @@ -314,6 +316,7 @@ def _landing_page( ], stac_extensions=extension_schemas, ) + return landing_page @@ -356,7 +359,7 @@ def list_conformance_classes(self): return base_conformance - def landing_page(self, **kwargs) -> stac_types.LandingPage: + def landing_page(self, **kwargs) -> stac.LandingPage: """Landing page. Called with `GET /`. @@ -366,6 +369,7 @@ def landing_page(self, **kwargs) -> stac_types.LandingPage: """ request: Request = kwargs["request"] base_url = get_base_url(request) + landing_page = self._landing_page( base_url=base_url, conformance_classes=self.conformance_classes(), @@ -388,6 +392,7 @@ def landing_page(self, **kwargs) -> stac_types.LandingPage: # Add Collections links collections = self.all_collections(request=kwargs["request"]) + for collection in collections["collections"]: landing_page["links"].append( { @@ -401,8 +406,8 @@ def landing_page(self, **kwargs) -> stac_types.LandingPage: # Add OpenAPI URL landing_page["links"].append( { - "rel": "service-desc", - "type": "application/vnd.oai.openapi+json;version=3.0", + "rel": Relations.service_desc.value, + "type": MimeTypes.openapi.value, "title": "OpenAPI service description", "href": str(request.url_for("openapi")), } @@ -411,16 +416,16 @@ def landing_page(self, **kwargs) -> stac_types.LandingPage: # Add human readable service-doc landing_page["links"].append( { - "rel": "service-doc", - "type": "text/html", + "rel": Relations.service_doc.value, + "type": MimeTypes.html.value, "title": "OpenAPI service documentation", "href": str(request.url_for("swagger_ui_html")), } ) - return landing_page + return stac.LandingPage(**landing_page) - def conformance(self, **kwargs) -> stac_types.Conformance: + def conformance(self, **kwargs) -> stac.Conformance: """Conformance classes. Called with `GET /conformance`. @@ -428,12 +433,12 @@ def conformance(self, **kwargs) -> stac_types.Conformance: Returns: Conformance classes which the server conforms to. """ - return Conformance(conformsTo=self.conformance_classes()) + return stac.Conformance(conformsTo=self.conformance_classes()) @abc.abstractmethod def post_search( self, search_request: BaseSearchPostRequest, **kwargs - ) -> stac_types.ItemCollection: + ) -> stac.ItemCollection: """Cross catalog search (POST). Called with `POST /search`. @@ -452,15 +457,11 @@ def get_search( collections: Optional[List[str]] = None, ids: Optional[List[str]] = None, bbox: Optional[BBox] = None, + intersects: Optional[Geometry] = None, datetime: Optional[DateTimeType] = None, limit: Optional[int] = 10, - query: Optional[str] = None, - token: Optional[str] = None, - fields: Optional[List[str]] = None, - sortby: Optional[str] = None, - intersects: Optional[str] = None, **kwargs, - ) -> stac_types.ItemCollection: + ) -> stac.ItemCollection: """Cross catalog search (GET). Called with `GET /search`. @@ -471,7 +472,7 @@ def get_search( ... @abc.abstractmethod - def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac_types.Item: + def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac.Item: """Get item by id. Called with `GET /collections/{collection_id}/items/{item_id}`. @@ -486,7 +487,7 @@ def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac_types.Ite ... @abc.abstractmethod - def all_collections(self, **kwargs) -> stac_types.Collections: + def all_collections(self, **kwargs) -> stac.Collections: """Get all available collections. Called with `GET /collections`. @@ -497,7 +498,7 @@ def all_collections(self, **kwargs) -> stac_types.Collections: ... @abc.abstractmethod - def get_collection(self, collection_id: str, **kwargs) -> stac_types.Collection: + def get_collection(self, collection_id: str, **kwargs) -> stac.Collection: """Get collection by id. Called with `GET /collections/{collection_id}`. @@ -519,7 +520,7 @@ def item_collection( limit: int = 10, token: str = None, **kwargs, - ) -> stac_types.ItemCollection: + ) -> stac.ItemCollection: """Get all items from a specific collection. Called with `GET /collections/{collection_id}/items` @@ -564,7 +565,7 @@ def extension_is_enabled(self, extension: str) -> bool: """Check if an api extension is enabled.""" return any([type(ext).__name__ == extension for ext in self.extensions]) - async def landing_page(self, **kwargs) -> stac_types.LandingPage: + async def landing_page(self, **kwargs) -> stac.LandingPage: """Landing page. Called with `GET /`. @@ -574,6 +575,7 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage: """ request: Request = kwargs["request"] base_url = get_base_url(request) + landing_page = self._landing_page( base_url=base_url, conformance_classes=self.conformance_classes(), @@ -596,6 +598,7 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage: # Add Collections links collections = await self.all_collections(request=kwargs["request"]) + for collection in collections["collections"]: landing_page["links"].append( { @@ -609,8 +612,8 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage: # Add OpenAPI URL landing_page["links"].append( { - "rel": "service-desc", - "type": "application/vnd.oai.openapi+json;version=3.0", + "rel": Relations.service_desc.value, + "type": MimeTypes.openapi.value, "title": "OpenAPI service description", "href": str(request.url_for("openapi")), } @@ -619,16 +622,16 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage: # Add human readable service-doc landing_page["links"].append( { - "rel": "service-doc", - "type": "text/html", + "rel": Relations.service_doc.value, + "type": MimeTypes.html.value, "title": "OpenAPI service documentation", "href": str(request.url_for("swagger_ui_html")), } ) - return landing_page + return stac.LandingPage(**landing_page) - async def conformance(self, **kwargs) -> stac_types.Conformance: + async def conformance(self, **kwargs) -> stac.Conformance: """Conformance classes. Called with `GET /conformance`. @@ -636,12 +639,12 @@ async def conformance(self, **kwargs) -> stac_types.Conformance: Returns: Conformance classes which the server conforms to. """ - return Conformance(conformsTo=self.conformance_classes()) + return stac.Conformance(conformsTo=self.conformance_classes()) @abc.abstractmethod async def post_search( self, search_request: BaseSearchPostRequest, **kwargs - ) -> stac_types.ItemCollection: + ) -> stac.ItemCollection: """Cross catalog search (POST). Called with `POST /search`. @@ -660,15 +663,11 @@ async def get_search( collections: Optional[List[str]] = None, ids: Optional[List[str]] = None, bbox: Optional[BBox] = None, + intersects: Optional[Geometry] = None, datetime: Optional[DateTimeType] = None, limit: Optional[int] = 10, - query: Optional[str] = None, - token: Optional[str] = None, - fields: Optional[List[str]] = None, - sortby: Optional[str] = None, - intersects: Optional[str] = None, **kwargs, - ) -> stac_types.ItemCollection: + ) -> stac.ItemCollection: """Cross catalog search (GET). Called with `GET /search`. @@ -679,9 +678,7 @@ async def get_search( ... @abc.abstractmethod - async def get_item( - self, item_id: str, collection_id: str, **kwargs - ) -> stac_types.Item: + async def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac.Item: """Get item by id. Called with `GET /collections/{collection_id}/items/{item_id}`. @@ -696,7 +693,7 @@ async def get_item( ... @abc.abstractmethod - async def all_collections(self, **kwargs) -> stac_types.Collections: + async def all_collections(self, **kwargs) -> stac.Collections: """Get all available collections. Called with `GET /collections`. @@ -707,9 +704,7 @@ async def all_collections(self, **kwargs) -> stac_types.Collections: ... @abc.abstractmethod - async def get_collection( - self, collection_id: str, **kwargs - ) -> stac_types.Collection: + async def get_collection(self, collection_id: str, **kwargs) -> stac.Collection: """Get collection by id. Called with `GET /collections/{collection_id}`. @@ -731,7 +726,7 @@ async def item_collection( limit: int = 10, token: str = None, **kwargs, - ) -> stac_types.ItemCollection: + ) -> stac.ItemCollection: """Get all items from a specific collection. Called with `GET /collections/{collection_id}/items` diff --git a/stac_fastapi/types/stac_fastapi/types/extension.py b/stac_fastapi/types/stac_fastapi/types/extension.py index 732a907bf..55a4a123c 100644 --- a/stac_fastapi/types/stac_fastapi/types/extension.py +++ b/stac_fastapi/types/stac_fastapi/types/extension.py @@ -1,4 +1,5 @@ """Base api extension.""" + import abc from typing import List, Optional diff --git a/stac_fastapi/types/stac_fastapi/types/requests.py b/stac_fastapi/types/stac_fastapi/types/requests.py index c9be8b6f6..4d94736a7 100644 --- a/stac_fastapi/types/stac_fastapi/types/requests.py +++ b/stac_fastapi/types/stac_fastapi/types/requests.py @@ -9,6 +9,4 @@ def get_base_url(request: Request) -> str: if not app.state.router_prefix: return str(request.base_url) else: - return "{}{}/".format( - str(request.base_url), app.state.router_prefix.lstrip("/") - ) + return "{}{}/".format(str(request.base_url), app.state.router_prefix.lstrip("/")) diff --git a/stac_fastapi/types/stac_fastapi/types/rfc3339.py b/stac_fastapi/types/stac_fastapi/types/rfc3339.py index 1277c998a..2f0a1f346 100644 --- a/stac_fastapi/types/stac_fastapi/types/rfc3339.py +++ b/stac_fastapi/types/stac_fastapi/types/rfc3339.py @@ -1,4 +1,5 @@ """rfc3339.""" + import re from datetime import datetime, timezone from typing import Optional, Tuple, Union diff --git a/stac_fastapi/types/stac_fastapi/types/search.py b/stac_fastapi/types/stac_fastapi/types/search.py index fb847349e..cf6647340 100644 --- a/stac_fastapi/types/stac_fastapi/types/search.py +++ b/stac_fastapi/types/stac_fastapi/types/search.py @@ -1,80 +1,26 @@ """stac_fastapi.types.search module. -# TODO: replace with stac-pydantic """ import abc -import operator -from datetime import datetime -from enum import auto -from types import DynamicClassAttribute -from typing import Any, Callable, Dict, Generator, List, Optional, Union +from typing import Dict, List, Optional, Union import attr -from geojson_pydantic.geometries import ( - GeometryCollection, - LineString, - MultiLineString, - MultiPoint, - MultiPolygon, - Point, - Polygon, - _GeometryBase, -) -from pydantic import BaseModel, ConstrainedInt, Field, validator -from pydantic.errors import NumberNotGtError -from pydantic.validators import int_validator +from pydantic import PositiveInt +from pydantic.functional_validators import AfterValidator +from stac_pydantic.api import Search from stac_pydantic.shared import BBox -from stac_pydantic.utils import AutoValueEnum +from typing_extensions import Annotated from stac_fastapi.types.rfc3339 import DateTimeType, str_to_interval -# Be careful: https://github.com/samuelcolvin/pydantic/issues/1423#issuecomment-642797287 -NumType = Union[float, int] - - -class Limit(ConstrainedInt): - """An positive integer that maxes out at 10,000.""" - - ge: int = 1 - le: int = 10_000 - - @classmethod - def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: - """Yield the relevant validators.""" - yield int_validator - yield cls.validate - - @classmethod - def validate(cls, value: int) -> int: - """Validate the integer value.""" - if value < cls.ge: - raise NumberNotGtError(limit_value=cls.ge) - if value > cls.le: - return cls.le - return value - -class Operator(str, AutoValueEnum): - """Defines the set of operators supported by the API.""" - - eq = auto() - ne = auto() - lt = auto() - lte = auto() - gt = auto() - gte = auto() - - # TODO: These are defined in the spec but aren't currently implemented by the api - # startsWith = auto() - # endsWith = auto() - # contains = auto() - # in = auto() - - @DynamicClassAttribute - def operator(self) -> Callable[[Any, Any], bool]: - """Return python operator.""" - return getattr(operator, self._value_) +def crop(v: PositiveInt) -> PositiveInt: + """Crop value to 10,000.""" + limit = 10_000 + if v > limit: + v = limit + return v def str2list(x: str) -> Optional[List]: @@ -91,6 +37,11 @@ def str2bbox(x: str) -> Optional[BBox]: return t +# Be careful: https://github.com/samuelcolvin/pydantic/issues/1423#issuecomment-642797287 +NumType = Union[float, int] +Limit = Annotated[PositiveInt, AfterValidator(crop)] + + @attr.s # type:ignore class APIRequest(abc.ABC): """Generic API Request base class.""" @@ -113,110 +64,7 @@ class BaseSearchGetRequest(APIRequest): limit: Optional[int] = attr.ib(default=10) -class BaseSearchPostRequest(BaseModel): - """Search model. - - Replace base model in STAC-pydantic as it includes additional fields, not in the core - model. - https://github.com/radiantearth/stac-api-spec/tree/master/item-search#query-parameter-table - - PR to fix this: - https://github.com/stac-utils/stac-pydantic/pull/100 - """ - - collections: Optional[List[str]] - ids: Optional[List[str]] - bbox: Optional[BBox] - intersects: Optional[ - Union[ - Point, - MultiPoint, - LineString, - MultiLineString, - Polygon, - MultiPolygon, - GeometryCollection, - ] - ] - datetime: Optional[DateTimeType] - limit: Optional[Limit] = Field(default=10) - - @property - def start_date(self) -> Optional[datetime]: - """Extract the start date from the datetime string.""" - return self.datetime[0] if self.datetime else None - - @property - def end_date(self) -> Optional[datetime]: - """Extract the end date from the datetime string.""" - return self.datetime[1] if self.datetime else None - - @validator("intersects") - def validate_spatial(cls, v, values): - """Check bbox and intersects are not both supplied.""" - if v and values["bbox"]: - raise ValueError("intersects and bbox parameters are mutually exclusive") - return v - - @validator("bbox", pre=True) - def validate_bbox(cls, v: Union[str, BBox]) -> BBox: - """Check order of supplied bbox coordinates.""" - if v: - if type(v) == str: - v = str2bbox(v) - # Validate order - if len(v) == 4: - xmin, ymin, xmax, ymax = v - else: - xmin, ymin, min_elev, xmax, ymax, max_elev = v - if max_elev < min_elev: - raise ValueError( - "Maximum elevation must greater than minimum elevation" - ) - - if xmax < xmin: - raise ValueError( - "Maximum longitude must be greater than minimum longitude" - ) - - if ymax < ymin: - raise ValueError( - "Maximum longitude must be greater than minimum longitude" - ) - - # Validate against WGS84 - if xmin < -180 or ymin < -90 or xmax > 180 or ymax > 90: - raise ValueError("Bounding box must be within (-180, -90, 180, 90)") - - return v - - @validator("datetime", pre=True) - def validate_datetime(cls, v: Union[str, DateTimeType]) -> DateTimeType: - """Parse datetime.""" - if type(v) == str: - v = str_to_interval(v) - return v - - @property - def spatial_filter(self) -> Optional[_GeometryBase]: - """Return a geojson-pydantic object representing the spatial filter for the search - request. - - Check for both because the ``bbox`` and ``intersects`` parameters are - mutually exclusive. - """ - if self.bbox: - return Polygon( - coordinates=[ - [ - [self.bbox[0], self.bbox[3]], - [self.bbox[2], self.bbox[3]], - [self.bbox[2], self.bbox[1]], - [self.bbox[0], self.bbox[1]], - [self.bbox[0], self.bbox[3]], - ] - ] - ) - if self.intersects: - return self.intersects - return +class BaseSearchPostRequest(Search): + """Base arguments for POST Request.""" + + limit: Optional[Limit] = 10 diff --git a/stac_fastapi/types/stac_fastapi/types/stac.py b/stac_fastapi/types/stac_fastapi/types/stac.py index 51bb6e652..b9c93fd80 100644 --- a/stac_fastapi/types/stac_fastapi/types/stac.py +++ b/stac_fastapi/types/stac_fastapi/types/stac.py @@ -1,4 +1,5 @@ """STAC types.""" + import sys from typing import Any, Dict, List, Literal, Optional, Union @@ -6,9 +7,9 @@ # Avoids a Pydantic error: # TypeError: You should use `typing_extensions.TypedDict` instead of -# `typing.TypedDict` with Python < 3.9.2. Without it, there is no way to +# `typing.TypedDict` with Python < 3.12.0. Without it, there is no way to # differentiate required and optional fields when subclassed. -if sys.version_info < (3, 9, 2): +if sys.version_info < (3, 12, 0): from typing_extensions import TypedDict else: from typing import TypedDict @@ -16,35 +17,28 @@ NumType = Union[float, int] -class LandingPage(TypedDict, total=False): - """STAC Landing Page.""" +class Catalog(TypedDict, total=False): + """STAC Catalog.""" type: str stac_version: str stac_extensions: Optional[List[str]] id: str - title: str + title: Optional[str] description: str - conformsTo: List[str] links: List[Dict[str, Any]] -class Conformance(TypedDict): - """STAC Conformance Classes.""" +class LandingPage(Catalog, total=False): + """STAC Landing Page.""" conformsTo: List[str] -class Catalog(TypedDict, total=False): - """STAC Catalog.""" +class Conformance(TypedDict): + """STAC Conformance Classes.""" - type: str - stac_version: str - stac_extensions: Optional[List[str]] - id: str - title: Optional[str] - description: str - links: List[Dict[str, Any]] + conformsTo: List[str] class Collection(Catalog, total=False): @@ -84,7 +78,6 @@ class ItemCollection(TypedDict, total=False): class Collections(TypedDict, total=False): """All collections endpoint. - https://github.com/radiantearth/stac-api-spec/tree/master/collections """