From 038db368399b04c80c3cf155da43fd7ee617d208 Mon Sep 17 00:00:00 2001 From: jonhealy1 Date: Fri, 19 Sep 2025 16:18:46 +0800 Subject: [PATCH 1/9] query ext scratch --- stac_fastapi/core/stac_fastapi/core/core.py | 20 ++- .../elasticsearch/database_logic.py | 49 +++++- .../tests/api/test_api_search_collections.py | 148 ++++++++++++++++-- 3 files changed, 203 insertions(+), 14 deletions(-) diff --git a/stac_fastapi/core/stac_fastapi/core/core.py b/stac_fastapi/core/stac_fastapi/core/core.py index 7c6fdf2f2..e26914d7f 100644 --- a/stac_fastapi/core/stac_fastapi/core/core.py +++ b/stac_fastapi/core/stac_fastapi/core/core.py @@ -231,6 +231,7 @@ async def all_collections( filter_expr: Optional[str] = None, filter_lang: Optional[str] = None, q: Optional[Union[str, List[str]]] = None, + query: Optional[str] = None, **kwargs, ) -> stac_types.Collections: """Read all collections from the database. @@ -239,6 +240,7 @@ async def all_collections( fields (Optional[List[str]]): Fields to include or exclude from the results. sortby (Optional[str]): Sorting options for the results. filter_expr (Optional[str]): Structured filter expression in CQL2 JSON or CQL2-text format. + query (Optional[str]): Legacy query parameter (deprecated). filter_lang (Optional[str]): Must be 'cql2-json' or 'cql2-text' if specified, other values will result in an error. q (Optional[Union[str, List[str]]]): Free text search terms. **kwargs: Keyword arguments from the request. @@ -280,11 +282,24 @@ async def all_collections( if q is not None: q_list = [q] if isinstance(q, str) else q + # Parse the query parameter if provided + parsed_query = None + if query is not None: + try: + import orjson + + parsed_query = orjson.loads(query) + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Invalid query parameter: {e}" + ) + # Parse the filter parameter if provided parsed_filter = None if filter_expr is not None: try: - # Check if filter_lang is specified and not one of the supported formats + # Only raise an error for explicitly unsupported filter languages + # Allow None, cql2-json, and cql2-text (we'll treat it as JSON) if filter_lang is not None and filter_lang not in [ "cql2-json", "cql2-text", @@ -292,7 +307,7 @@ async def all_collections( # Raise an error for unsupported filter languages raise HTTPException( status_code=400, - detail=f"Input should be 'cql2-json' or 'cql2-text' for collections. Got '{filter_lang}'.", + detail=f"Only 'cql2-json' and 'cql2-text' filter languages are supported for collections. Got '{filter_lang}'.", ) # Handle different filter formats @@ -335,6 +350,7 @@ async def all_collections( sort=sort, q=q_list, filter=parsed_filter, + query=parsed_query, ) # Apply field filtering if fields parameter was provided diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py index b3907c8e9..cef2b2491 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py @@ -177,6 +177,7 @@ async def get_all_collections( sort: Optional[List[Dict[str, Any]]] = None, q: Optional[List[str]] = None, filter: Optional[Dict[str, Any]] = None, + query: Optional[Dict[str, Dict[str, Any]]] = None, ) -> Tuple[List[Dict[str, Any]], Optional[str]]: """Retrieve a list of collections from Elasticsearch, supporting pagination. @@ -186,7 +187,8 @@ async def get_all_collections( request (Request): The FastAPI request object. sort (Optional[List[Dict[str, Any]]]): Optional sort parameter from the request. q (Optional[List[str]]): Free text search terms. - filter (Optional[Dict[str, Any]]): Structured query in CQL2 format. + filter (Optional[Dict[str, Any]]): Structured filter in CQL2 format. + query (Optional[Dict[str, Dict[str, Any]]]): Query extension parameters. Returns: A tuple of (collections, next pagination token if any). @@ -270,7 +272,50 @@ async def get_all_collections( es_query = filter_module.to_es(await self.get_queryables_mapping(), filter) query_parts.append(es_query) - # Combine all query parts with AND logic + # Apply query extension if provided + if query: + try: + # Process each field and operator in the query + for field_name, expr in query.items(): + for op, value in expr.items(): + # Handle different operators + if op == "eq": + # Equality operator + # Use different query types based on field name + if field_name in ["title", "description"]: + # For text fields, use match_phrase for exact phrase matching + query_part = {"match_phrase": {field_name: value}} + else: + # For other fields, use term query for exact matching + query_part = {"term": {field_name: value}} + query_parts.append(query_part) + elif op == "neq": + # Not equal operator + query_part = { + "bool": {"must_not": [{"term": {field_name: value}}]} + } + print(f"Adding neq query part: {query_part}") + query_parts.append(query_part) + elif op in ["lt", "lte", "gt", "gte"]: + # Range operators + query_parts.append({"range": {field_name: {op: value}}}) + elif op == "in": + # In operator (value should be a list) + if isinstance(value, list): + query_parts.append({"terms": {field_name: value}}) + else: + query_parts.append({"term": {field_name: value}}) + elif op == "contains": + # Contains operator for arrays + query_parts.append({"term": {field_name: value}}) + except Exception as e: + logger = logging.getLogger(__name__) + logger.error(f"Error converting query to Elasticsearch: {e}") + # If there's an error, add a query that matches nothing + query_parts.append({"bool": {"must_not": {"match_all": {}}}}) + raise + + # Combine all query parts with AND logic if there are multiple if query_parts: body["query"] = ( query_parts[0] diff --git a/stac_fastapi/tests/api/test_api_search_collections.py b/stac_fastapi/tests/api/test_api_search_collections.py index 85a393fc4..57921ee4d 100644 --- a/stac_fastapi/tests/api/test_api_search_collections.py +++ b/stac_fastapi/tests/api/test_api_search_collections.py @@ -7,10 +7,10 @@ @pytest.mark.asyncio -async def test_collections_sort_id_asc(app_client, txn_client, load_test_data): +async def test_collections_sort_id_asc(app_client, txn_client, ctx): """Verify GET /collections honors ascending sort on id.""" # Create multiple collections with different ids - base_collection = load_test_data("test_collection.json") + base_collection = ctx.collection # Create collections with ids in a specific order to test sorting # Use unique prefixes to avoid conflicts between tests @@ -23,6 +23,8 @@ async def test_collections_sort_id_asc(app_client, txn_client, load_test_data): test_collection["title"] = f"Test Collection {i}" await create_collection(txn_client, test_collection) + await refresh_indices(txn_client) + # Test ascending sort by id resp = await app_client.get( "/collections", @@ -44,10 +46,10 @@ async def test_collections_sort_id_asc(app_client, txn_client, load_test_data): @pytest.mark.asyncio -async def test_collections_sort_id_desc(app_client, txn_client, load_test_data): +async def test_collections_sort_id_desc(app_client, txn_client, ctx): """Verify GET /collections honors descending sort on id.""" # Create multiple collections with different ids - base_collection = load_test_data("test_collection.json") + base_collection = ctx.collection # Create collections with ids in a specific order to test sorting # Use unique prefixes to avoid conflicts between tests @@ -60,6 +62,8 @@ async def test_collections_sort_id_desc(app_client, txn_client, load_test_data): test_collection["title"] = f"Test Collection {i}" await create_collection(txn_client, test_collection) + await refresh_indices(txn_client) + # Test descending sort by id resp = await app_client.get( "/collections", @@ -81,10 +85,10 @@ async def test_collections_sort_id_desc(app_client, txn_client, load_test_data): @pytest.mark.asyncio -async def test_collections_fields(app_client, txn_client, load_test_data): +async def test_collections_fields(app_client, txn_client, ctx): """Verify GET /collections honors the fields parameter.""" # Create multiple collections with different ids - base_collection = load_test_data("test_collection.json") + base_collection = ctx.collection # Create collections with ids in a specific order to test fields # Use unique prefixes to avoid conflicts between tests @@ -98,6 +102,8 @@ async def test_collections_fields(app_client, txn_client, load_test_data): test_collection["description"] = f"Description for collection {i}" await create_collection(txn_client, test_collection) + await refresh_indices(txn_client) + # Test include fields parameter resp = await app_client.get( "/collections", @@ -156,10 +162,10 @@ async def test_collections_fields(app_client, txn_client, load_test_data): @pytest.mark.asyncio -async def test_collections_free_text_search_get(app_client, txn_client, load_test_data): +async def test_collections_free_text_search_get(app_client, txn_client, ctx): """Verify GET /collections honors the q parameter for free text search.""" # Create multiple collections with different content - base_collection = load_test_data("test_collection.json") + base_collection = ctx.collection # Use unique prefixes to avoid conflicts between tests test_prefix = f"q-get-{uuid.uuid4().hex[:8]}" @@ -193,6 +199,8 @@ async def test_collections_free_text_search_get(app_client, txn_client, load_tes test_collection["summaries"] = coll["summaries"] await create_collection(txn_client, test_collection) + await refresh_indices(txn_client) + # Test free text search for "sentinel" resp = await app_client.get( "/collections", @@ -229,10 +237,10 @@ async def test_collections_free_text_search_get(app_client, txn_client, load_tes @pytest.mark.asyncio -async def test_collections_filter_search(app_client, txn_client, load_test_data): +async def test_collections_filter_search(app_client, txn_client, ctx): """Verify GET /collections honors the filter parameter for structured search.""" # Create multiple collections with different content - base_collection = load_test_data("test_collection.json") + base_collection = ctx.collection # Use unique prefixes to avoid conflicts between tests test_prefix = f"filter-{uuid.uuid4().hex[:8]}" @@ -313,3 +321,123 @@ async def test_collections_filter_search(app_client, txn_client, load_test_data) assert ( len(found_collections) >= 1 ), f"Expected at least 1 collection with ID {test_collection_id} using LIKE filter" + + +@pytest.mark.asyncio +async def test_collections_query_extension(app_client, txn_client, ctx): + """Verify GET /collections honors the query extension.""" + # Create multiple collections with different content + base_collection = ctx.collection + # Use unique prefixes to avoid conflicts between tests + test_prefix = f"query-ext-{uuid.uuid4().hex[:8]}" + + # Create collections with different content to test query extension + test_collections = [ + { + "id": f"{test_prefix}-sentinel", + "title": "Sentinel-2 Collection", + "description": "Collection of Sentinel-2 data", + "summaries": {"platform": ["sentinel-2a", "sentinel-2b"]}, + }, + { + "id": f"{test_prefix}-landsat", + "title": "Landsat Collection", + "description": "Collection of Landsat data", + "summaries": {"platform": ["landsat-8", "landsat-9"]}, + }, + { + "id": f"{test_prefix}-modis", + "title": "MODIS Collection", + "description": "Collection of MODIS data", + "summaries": {"platform": ["terra", "aqua"]}, + }, + ] + + for i, coll in enumerate(test_collections): + test_collection = base_collection.copy() + test_collection["id"] = coll["id"] + test_collection["title"] = coll["title"] + test_collection["description"] = coll["description"] + test_collection["summaries"] = coll["summaries"] + await create_collection(txn_client, test_collection) + + await refresh_indices(txn_client) + + # Test query extension for exact ID match + import json + + # Use the exact ID that was created + sentinel_id = f"{test_prefix}-sentinel" + print(f"Searching for ID: {sentinel_id}") + + query = {"id": {"eq": sentinel_id}} + + resp = await app_client.get( + "/collections", + params=[("query", json.dumps(query))], + ) + assert resp.status_code == 200 + resp_json = resp.json() + + # Filter collections to only include the ones we created for this test + found_collections = [ + c for c in resp_json["collections"] if c["id"].startswith(test_prefix) + ] + + # Should only find the sentinel collection + assert len(found_collections) == 1 + assert found_collections[0]["id"] == f"{test_prefix}-sentinel" + + # Test query extension with equal operator on ID + query = {"id": {"eq": f"{test_prefix}-sentinel"}} + + resp = await app_client.get( + "/collections", + params=[("query", json.dumps(query))], + ) + assert resp.status_code == 200 + resp_json = resp.json() + + # Filter collections to only include the ones we created for this test + found_collections = [ + c for c in resp_json["collections"] if c["id"].startswith(test_prefix) + ] + found_ids = [c["id"] for c in found_collections] + + # Should find landsat and modis collections but not sentinel + assert len(found_collections) == 1 + assert f"{test_prefix}-sentinel" in found_ids + assert f"{test_prefix}-landsat" not in found_ids + assert f"{test_prefix}-modis" not in found_ids + + # Test query extension with not-equal operator on ID + query = {"id": {"neq": f"{test_prefix}-sentinel"}} + + print(f"\nTesting neq query: {query}") + print(f"JSON query: {json.dumps(query)}") + + resp = await app_client.get( + "/collections", + params=[("query", json.dumps(query))], + ) + print(f"Response status: {resp.status_code}") + assert resp.status_code == 200 + resp_json = resp.json() + print(f"Response JSON keys: {resp_json.keys()}") + print(f"Number of collections in response: {len(resp_json.get('collections', []))}") + + # Print all collection IDs in the response + all_ids = [c["id"] for c in resp_json.get("collections", [])] + print(f"All collection IDs in response: {all_ids}") + + # Filter collections to only include the ones we created for this test + found_collections = [ + c for c in resp_json["collections"] if c["id"].startswith(test_prefix) + ] + found_ids = [c["id"] for c in found_collections] + + # Should find landsat and modis collections but not sentinel + assert len(found_collections) == 2 + assert f"{test_prefix}-sentinel" not in found_ids + assert f"{test_prefix}-landsat" in found_ids + assert f"{test_prefix}-modis" in found_ids From 9583f5abd5acb2d0942e479190b7be63f0667c11 Mon Sep 17 00:00:00 2001 From: jonhealy1 Date: Fri, 19 Sep 2025 16:23:32 +0800 Subject: [PATCH 2/9] test desc ids --- .../tests/api/test_api_search_collections.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/stac_fastapi/tests/api/test_api_search_collections.py b/stac_fastapi/tests/api/test_api_search_collections.py index 57921ee4d..adff9e0d1 100644 --- a/stac_fastapi/tests/api/test_api_search_collections.py +++ b/stac_fastapi/tests/api/test_api_search_collections.py @@ -44,6 +44,25 @@ async def test_collections_sort_id_asc(app_client, txn_client, ctx): for i, expected_id in enumerate(sorted_ids): assert test_collections[i]["id"] == expected_id + # Test descending sort by id + resp = await app_client.get( + "/collections", + params=[("sortby", "-id")], + ) + assert resp.status_code == 200 + resp_json = resp.json() + + # Filter collections to only include the ones we created for this test + test_collections = [ + c for c in resp_json["collections"] if c["id"].startswith(test_prefix) + ] + + # Collections should be sorted in reverse alphabetical order by id + sorted_ids = sorted(collection_ids, reverse=True) + assert len(test_collections) == len(collection_ids) + for i, expected_id in enumerate(sorted_ids): + assert test_collections[i]["id"] == expected_id + @pytest.mark.asyncio async def test_collections_sort_id_desc(app_client, txn_client, ctx): From fad578ef108637ebf151f082e618b4c317e94462 Mon Sep 17 00:00:00 2001 From: jonhealy1 Date: Fri, 19 Sep 2025 17:45:03 +0800 Subject: [PATCH 3/9] opensearch update --- CHANGELOG.md | 1 + .../stac_fastapi/elasticsearch/app.py | 2 +- .../elasticsearch/database_logic.py | 1 - .../opensearch/stac_fastapi/opensearch/app.py | 3 +- .../stac_fastapi/opensearch/database_logic.py | 54 +++++++++++++++++-- 5 files changed, 53 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 99e3ddd32..a7dc0293a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ### Added - GET `/collections` collection search structured filter extension with support for both cql2-json and cql2-text formats. [#475](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/475) +- GET `/collections` collection search query extension. [#476](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/476) ### Changed diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py index e3292cbf8..8cc32088c 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py @@ -121,7 +121,7 @@ if ENABLE_COLLECTIONS_SEARCH: # Create collection search extensions collection_search_extensions = [ - # QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]), + QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]), SortExtension(conformance_classes=[SortConformanceClasses.COLLECTIONS]), FieldsExtension(conformance_classes=[FieldsConformanceClasses.COLLECTIONS]), CollectionSearchFilterExtension( diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py index cef2b2491..8f4824088 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py @@ -294,7 +294,6 @@ async def get_all_collections( query_part = { "bool": {"must_not": [{"term": {field_name: value}}]} } - print(f"Adding neq query part: {query_part}") query_parts.append(query_part) elif op in ["lt", "lte", "gt", "gte"]: # Range operators diff --git a/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py b/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py index b842d929c..56f717a34 100644 --- a/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py +++ b/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py @@ -121,7 +121,7 @@ if ENABLE_COLLECTIONS_SEARCH: # Create collection search extensions collection_search_extensions = [ - # QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]), + QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]), SortExtension(conformance_classes=[SortConformanceClasses.COLLECTIONS]), FieldsExtension(conformance_classes=[FieldsConformanceClasses.COLLECTIONS]), CollectionSearchFilterExtension( @@ -170,6 +170,7 @@ post_request_model=post_request_model, landing_page_id=os.getenv("STAC_FASTAPI_LANDING_PAGE_ID", "stac-fastapi"), ), + "collections_get_request_model": collections_get_request_model, "search_get_request_model": create_get_request_model(search_extensions), "search_post_request_model": post_request_model, "items_get_request_model": items_get_request_model, diff --git a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py index e94dee254..899d11dec 100644 --- a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py +++ b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py @@ -161,8 +161,9 @@ async def get_all_collections( sort: Optional[List[Dict[str, Any]]] = None, q: Optional[List[str]] = None, filter: Optional[Dict[str, Any]] = None, + query: Optional[Dict[str, Dict[str, Any]]] = None, ) -> Tuple[List[Dict[str, Any]], Optional[str]]: - """Retrieve a list of collections from Opensearch, supporting pagination. + """Retrieve a list of collections from Elasticsearch, supporting pagination. Args: token (Optional[str]): The pagination token. @@ -170,7 +171,8 @@ async def get_all_collections( request (Request): The FastAPI request object. sort (Optional[List[Dict[str, Any]]]): Optional sort parameter from the request. q (Optional[List[str]]): Free text search terms. - filter (Optional[Dict[str, Any]]): Structured query in CQL2 format. + filter (Optional[Dict[str, Any]]): Structured filter in CQL2 format. + query (Optional[Dict[str, Dict[str, Any]]]): Query extension parameters. Returns: A tuple of (collections, next pagination token if any). @@ -193,7 +195,7 @@ async def get_all_collections( raise HTTPException( status_code=400, detail=f"Field '{field}' is not sortable. Sortable fields are: {', '.join(sortable_fields)}. " - + "Text fields are not sortable by default in Opensearch. " + + "Text fields are not sortable by default in Elasticsearch. " + "To make a field sortable, update the mapping to use 'keyword' type or add a '.keyword' subfield. ", ) formatted_sort.append({field: {"order": direction}}) @@ -250,11 +252,53 @@ async def get_all_collections( # Convert string filter to dict if needed if isinstance(filter, str): filter = orjson.loads(filter) - # Convert the filter to an Opensearch query using the filter module + # Convert the filter to an Elasticsearch query using the filter module es_query = filter_module.to_es(await self.get_queryables_mapping(), filter) query_parts.append(es_query) - # Combine all query parts with AND logic + # Apply query extension if provided + if query: + try: + # Process each field and operator in the query + for field_name, expr in query.items(): + for op, value in expr.items(): + # Handle different operators + if op == "eq": + # Equality operator + # Use different query types based on field name + if field_name in ["title", "description"]: + # For text fields, use match_phrase for exact phrase matching + query_part = {"match_phrase": {field_name: value}} + else: + # For other fields, use term query for exact matching + query_part = {"term": {field_name: value}} + query_parts.append(query_part) + elif op == "neq": + # Not equal operator + query_part = { + "bool": {"must_not": [{"term": {field_name: value}}]} + } + query_parts.append(query_part) + elif op in ["lt", "lte", "gt", "gte"]: + # Range operators + query_parts.append({"range": {field_name: {op: value}}}) + elif op == "in": + # In operator (value should be a list) + if isinstance(value, list): + query_parts.append({"terms": {field_name: value}}) + else: + query_parts.append({"term": {field_name: value}}) + elif op == "contains": + # Contains operator for arrays + query_parts.append({"term": {field_name: value}}) + except Exception as e: + logger = logging.getLogger(__name__) + logger.error(f"Error converting query to Elasticsearch: {e}") + # If there's an error, add a query that matches nothing + query_parts.append({"bool": {"must_not": {"match_all": {}}}}) + raise + + # Combine all query parts with AND logic if there are multiple if query_parts: body["query"] = ( query_parts[0] From a70356b98d41619b62733d35ab8b0ebf933d7845 Mon Sep 17 00:00:00 2001 From: jonhealy1 Date: Sat, 27 Sep 2025 20:59:27 +0800 Subject: [PATCH 4/9] clean up --- stac_fastapi/core/stac_fastapi/core/core.py | 3 --- .../elasticsearch/stac_fastapi/elasticsearch/database_logic.py | 1 + 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/stac_fastapi/core/stac_fastapi/core/core.py b/stac_fastapi/core/stac_fastapi/core/core.py index e26914d7f..3da8a8e40 100644 --- a/stac_fastapi/core/stac_fastapi/core/core.py +++ b/stac_fastapi/core/stac_fastapi/core/core.py @@ -286,8 +286,6 @@ async def all_collections( parsed_query = None if query is not None: try: - import orjson - parsed_query = orjson.loads(query) except Exception as e: raise HTTPException( @@ -299,7 +297,6 @@ async def all_collections( if filter_expr is not None: try: # Only raise an error for explicitly unsupported filter languages - # Allow None, cql2-json, and cql2-text (we'll treat it as JSON) if filter_lang is not None and filter_lang not in [ "cql2-json", "cql2-text", diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py index 8f4824088..e52c9de24 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py @@ -307,6 +307,7 @@ async def get_all_collections( elif op == "contains": # Contains operator for arrays query_parts.append({"term": {field_name: value}}) + except Exception as e: logger = logging.getLogger(__name__) logger.error(f"Error converting query to Elasticsearch: {e}") From 4fbbd149ae058d178e302005e21975a176334920 Mon Sep 17 00:00:00 2001 From: jonhealy1 Date: Sat, 27 Sep 2025 23:55:31 +0800 Subject: [PATCH 5/9] clean up tests --- .../tests/api/test_api_search_collections.py | 23 --- .../tests/api/test_collections_search_env.py | 167 ------------------ 2 files changed, 190 deletions(-) delete mode 100644 stac_fastapi/tests/api/test_collections_search_env.py diff --git a/stac_fastapi/tests/api/test_api_search_collections.py b/stac_fastapi/tests/api/test_api_search_collections.py index adff9e0d1..390546ecf 100644 --- a/stac_fastapi/tests/api/test_api_search_collections.py +++ b/stac_fastapi/tests/api/test_api_search_collections.py @@ -44,25 +44,6 @@ async def test_collections_sort_id_asc(app_client, txn_client, ctx): for i, expected_id in enumerate(sorted_ids): assert test_collections[i]["id"] == expected_id - # Test descending sort by id - resp = await app_client.get( - "/collections", - params=[("sortby", "-id")], - ) - assert resp.status_code == 200 - resp_json = resp.json() - - # Filter collections to only include the ones we created for this test - test_collections = [ - c for c in resp_json["collections"] if c["id"].startswith(test_prefix) - ] - - # Collections should be sorted in reverse alphabetical order by id - sorted_ids = sorted(collection_ids, reverse=True) - assert len(test_collections) == len(collection_ids) - for i, expected_id in enumerate(sorted_ids): - assert test_collections[i]["id"] == expected_id - @pytest.mark.asyncio async def test_collections_sort_id_desc(app_client, txn_client, ctx): @@ -382,12 +363,8 @@ async def test_collections_query_extension(app_client, txn_client, ctx): await refresh_indices(txn_client) - # Test query extension for exact ID match - import json - # Use the exact ID that was created sentinel_id = f"{test_prefix}-sentinel" - print(f"Searching for ID: {sentinel_id}") query = {"id": {"eq": sentinel_id}} diff --git a/stac_fastapi/tests/api/test_collections_search_env.py b/stac_fastapi/tests/api/test_collections_search_env.py deleted file mode 100644 index 5358faf98..000000000 --- a/stac_fastapi/tests/api/test_collections_search_env.py +++ /dev/null @@ -1,167 +0,0 @@ -"""Test the ENABLE_COLLECTIONS_SEARCH environment variable.""" - -import os -import uuid -from unittest import mock - -import pytest - -from ..conftest import create_collection, refresh_indices - - -@pytest.mark.asyncio -@mock.patch.dict(os.environ, {"ENABLE_COLLECTIONS_SEARCH": "false"}) -async def test_collections_search_disabled(app_client, txn_client, load_test_data): - """Test that collection search extensions are disabled when ENABLE_COLLECTIONS_SEARCH=false.""" - # Create multiple collections with different ids to test sorting - base_collection = load_test_data("test_collection.json") - - # Use unique prefixes to avoid conflicts between tests - test_prefix = f"disabled-{uuid.uuid4().hex[:8]}" - collection_ids = [f"{test_prefix}-c", f"{test_prefix}-a", f"{test_prefix}-b"] - - for i, coll_id in enumerate(collection_ids): - test_collection = base_collection.copy() - test_collection["id"] = coll_id - test_collection["title"] = f"Test Collection {i}" - await create_collection(txn_client, test_collection) - - # Refresh indices to ensure collections are searchable - await refresh_indices(txn_client) - - # When collection search is disabled, sortby parameter should be ignored - resp = await app_client.get( - "/collections", - params=[("sortby", "+id")], - ) - assert resp.status_code == 200 - - # Verify that results are NOT sorted by id (should be in insertion order or default order) - resp_json = resp.json() - collections = [ - c for c in resp_json["collections"] if c["id"].startswith(test_prefix) - ] - - # Extract the ids in the order they were returned - returned_ids = [c["id"] for c in collections] - - # If sorting was working, they would be in alphabetical order: a, b, c - # But since sorting is disabled, they should be in a different order - # We can't guarantee the exact order, but we can check they're not in alphabetical order - sorted_ids = sorted(returned_ids) - assert ( - returned_ids != sorted_ids or len(collections) < 2 - ), "Collections appear to be sorted despite ENABLE_COLLECTIONS_SEARCH=false" - - # Fields parameter should also be ignored - resp = await app_client.get( - "/collections", - params=[("fields", "id")], # Request only id field - ) - assert resp.status_code == 200 - - # Verify that all fields are still returned, not just id - resp_json = resp.json() - for collection in resp_json["collections"]: - if collection["id"].startswith(test_prefix): - # If fields filtering was working, only id would be present - # Since it's disabled, other fields like title should still be present - assert ( - "title" in collection - ), "Fields filtering appears to be working despite ENABLE_COLLECTIONS_SEARCH=false" - - -@pytest.mark.asyncio -@mock.patch.dict(os.environ, {"ENABLE_COLLECTIONS_SEARCH": "true"}) -async def test_collections_search_enabled(app_client, txn_client, load_test_data): - """Test that collection search extensions work when ENABLE_COLLECTIONS_SEARCH=true.""" - # Create multiple collections with different ids to test sorting - base_collection = load_test_data("test_collection.json") - - # Use unique prefixes to avoid conflicts between tests - test_prefix = f"enabled-{uuid.uuid4().hex[:8]}" - collection_ids = [f"{test_prefix}-c", f"{test_prefix}-a", f"{test_prefix}-b"] - - for i, coll_id in enumerate(collection_ids): - test_collection = base_collection.copy() - test_collection["id"] = coll_id - test_collection["title"] = f"Test Collection {i}" - await create_collection(txn_client, test_collection) - - # Refresh indices to ensure collections are searchable - await refresh_indices(txn_client) - - # Test that sortby parameter works - sort by id ascending - resp = await app_client.get( - "/collections", - params=[("sortby", "+id")], - ) - assert resp.status_code == 200 - - # Verify that results are sorted by id in ascending order - resp_json = resp.json() - collections = [ - c for c in resp_json["collections"] if c["id"].startswith(test_prefix) - ] - - # Extract the ids in the order they were returned - returned_ids = [c["id"] for c in collections] - - # Verify they're in ascending order - assert returned_ids == sorted( - returned_ids - ), "Collections are not sorted by id ascending" - - # Test that sortby parameter works - sort by id descending - resp = await app_client.get( - "/collections", - params=[("sortby", "-id")], - ) - assert resp.status_code == 200 - - # Verify that results are sorted by id in descending order - resp_json = resp.json() - collections = [ - c for c in resp_json["collections"] if c["id"].startswith(test_prefix) - ] - - # Extract the ids in the order they were returned - returned_ids = [c["id"] for c in collections] - - # Verify they're in descending order - assert returned_ids == sorted( - returned_ids, reverse=True - ), "Collections are not sorted by id descending" - - # Test that fields parameter works - request only id field - resp = await app_client.get( - "/collections", - params=[("fields", "id")], - ) - assert resp.status_code == 200 - resp_json = resp.json() - - # When fields=id is specified, collections should only have id field - for collection in resp_json["collections"]: - if collection["id"].startswith(test_prefix): - assert "id" in collection, "id field is missing" - assert ( - "title" not in collection - ), "title field should be excluded when fields=id" - - # Test that fields parameter works - request multiple fields - resp = await app_client.get( - "/collections", - params=[("fields", "id,title")], - ) - assert resp.status_code == 200 - resp_json = resp.json() - - # When fields=id,title is specified, collections should have both fields but not others - for collection in resp_json["collections"]: - if collection["id"].startswith(test_prefix): - assert "id" in collection, "id field is missing" - assert "title" in collection, "title field is missing" - assert ( - "description" not in collection - ), "description field should be excluded when fields=id,title" From 24142a6fc41302c96d14240710e3cebf16758ac5 Mon Sep 17 00:00:00 2001 From: jonhealy1 Date: Sun, 28 Sep 2025 00:24:13 +0800 Subject: [PATCH 6/9] remove fields extension check --- stac_fastapi/core/stac_fastapi/core/core.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/stac_fastapi/core/stac_fastapi/core/core.py b/stac_fastapi/core/stac_fastapi/core/core.py index 3da8a8e40..376ababb8 100644 --- a/stac_fastapi/core/stac_fastapi/core/core.py +++ b/stac_fastapi/core/stac_fastapi/core/core.py @@ -255,7 +255,7 @@ async def all_collections( # Process fields parameter for filtering collection properties includes, excludes = set(), set() - if fields and self.extension_is_enabled("FieldsExtension"): + if fields: for field in fields: if field[0] == "-": excludes.add(field[1:]) @@ -351,7 +351,7 @@ async def all_collections( ) # Apply field filtering if fields parameter was provided - if fields and self.extension_is_enabled("FieldsExtension"): + if fields: filtered_collections = [ filter_fields(collection, includes, excludes) for collection in collections @@ -691,11 +691,7 @@ async def post_search( datetime_search=datetime_search, ) - fields = ( - getattr(search_request, "fields", None) - if self.extension_is_enabled("FieldsExtension") - else None - ) + fields = getattr(search_request, "fields", None) include: Set[str] = fields.include if fields and fields.include else set() exclude: Set[str] = fields.exclude if fields and fields.exclude else set() From c8902d6b5369ccaea9a37c566cc210fdc1120f45 Mon Sep 17 00:00:00 2001 From: jonhealy1 Date: Sun, 28 Sep 2025 00:24:41 +0800 Subject: [PATCH 7/9] update stacql filter --- .../elasticsearch/database_logic.py | 64 +++++++++--------- .../stac_fastapi/opensearch/database_logic.py | 66 +++++++++---------- 2 files changed, 63 insertions(+), 67 deletions(-) diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py index e52c9de24..0e6536dcf 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py @@ -275,38 +275,23 @@ async def get_all_collections( # Apply query extension if provided if query: try: + # First create a search object to apply filters + search = Search(index=COLLECTIONS_INDEX) + # Process each field and operator in the query for field_name, expr in query.items(): for op, value in expr.items(): - # Handle different operators - if op == "eq": - # Equality operator - # Use different query types based on field name - if field_name in ["title", "description"]: - # For text fields, use match_phrase for exact phrase matching - query_part = {"match_phrase": {field_name: value}} - else: - # For other fields, use term query for exact matching - query_part = {"term": {field_name: value}} - query_parts.append(query_part) - elif op == "neq": - # Not equal operator - query_part = { - "bool": {"must_not": [{"term": {field_name: value}}]} - } - query_parts.append(query_part) - elif op in ["lt", "lte", "gt", "gte"]: - # Range operators - query_parts.append({"range": {field_name: {op: value}}}) - elif op == "in": - # In operator (value should be a list) - if isinstance(value, list): - query_parts.append({"terms": {field_name: value}}) - else: - query_parts.append({"term": {field_name: value}}) - elif op == "contains": - # Contains operator for arrays - query_parts.append({"term": {field_name: value}}) + # For collections, we don't need to prefix with 'properties__' + field = field_name + # Apply the filter using apply_stacql_filter + search = self.apply_stacql_filter( + search=search, op=op, field=field, value=value + ) + + # Convert the search object to a query dict and add it to query_parts + search_dict = search.to_dict() + if "query" in search_dict: + query_parts.append(search_dict["query"]) except Exception as e: logger = logging.getLogger(__name__) @@ -607,18 +592,31 @@ def apply_stacql_filter(search: Search, op: str, field: str, value: float): Args: search (Search): The search object to apply the filter to. - op (str): The comparison operator to use. Can be 'eq' (equal), 'gt' (greater than), 'gte' (greater than or equal), - 'lt' (less than), or 'lte' (less than or equal). + op (str): The comparison operator to use. Can be 'eq' (equal), 'ne'/'neq' (not equal), 'gt' (greater than), + 'gte' (greater than or equal), 'lt' (less than), or 'lte' (less than or equal). field (str): The field to perform the comparison on. value (float): The value to compare the field against. Returns: search (Search): The search object with the specified filter applied. """ - if op != "eq": + if op == "eq": + search = search.filter("term", **{field: value}) + elif op == "ne" or op == "neq": + # For not equal, use a bool query with must_not + search = search.exclude("term", **{field: value}) + elif op in ["gt", "gte", "lt", "lte"]: + # For range operators key_filter = {field: {op: value}} search = search.filter(Q("range", **key_filter)) - else: + elif op == "in": + # For in operator (value should be a list) + if isinstance(value, list): + search = search.filter("terms", **{field: value}) + else: + search = search.filter("term", **{field: value}) + elif op == "contains": + # For contains operator (for arrays) search = search.filter("term", **{field: value}) return search diff --git a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py index 899d11dec..6170e122d 100644 --- a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py +++ b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py @@ -259,38 +259,23 @@ async def get_all_collections( # Apply query extension if provided if query: try: + # First create a search object to apply filters + search = Search(index=COLLECTIONS_INDEX) + # Process each field and operator in the query for field_name, expr in query.items(): for op, value in expr.items(): - # Handle different operators - if op == "eq": - # Equality operator - # Use different query types based on field name - if field_name in ["title", "description"]: - # For text fields, use match_phrase for exact phrase matching - query_part = {"match_phrase": {field_name: value}} - else: - # For other fields, use term query for exact matching - query_part = {"term": {field_name: value}} - query_parts.append(query_part) - elif op == "neq": - # Not equal operator - query_part = { - "bool": {"must_not": [{"term": {field_name: value}}]} - } - query_parts.append(query_part) - elif op in ["lt", "lte", "gt", "gte"]: - # Range operators - query_parts.append({"range": {field_name: {op: value}}}) - elif op == "in": - # In operator (value should be a list) - if isinstance(value, list): - query_parts.append({"terms": {field_name: value}}) - else: - query_parts.append({"term": {field_name: value}}) - elif op == "contains": - # Contains operator for arrays - query_parts.append({"term": {field_name: value}}) + # For collections, we don't need to prefix with 'properties__' + field = field_name + # Apply the filter using apply_stacql_filter + search = self.apply_stacql_filter( + search=search, op=op, field=field, value=value + ) + + # Convert the search object to a query dict and add it to query_parts + search_dict = search.to_dict() + if "query" in search_dict: + query_parts.append(search_dict["query"]) except Exception as e: logger = logging.getLogger(__name__) logger.error(f"Error converting query to Elasticsearch: {e}") @@ -608,18 +593,31 @@ def apply_stacql_filter(search: Search, op: str, field: str, value: float): Args: search (Search): The search object to apply the filter to. - op (str): The comparison operator to use. Can be 'eq' (equal), 'gt' (greater than), 'gte' (greater than or equal), - 'lt' (less than), or 'lte' (less than or equal). + op (str): The comparison operator to use. Can be 'eq' (equal), 'ne'/'neq' (not equal), 'gt' (greater than), + 'gte' (greater than or equal), 'lt' (less than), or 'lte' (less than or equal). field (str): The field to perform the comparison on. value (float): The value to compare the field against. Returns: search (Search): The search object with the specified filter applied. """ - if op != "eq": - key_filter = {field: {f"{op}": value}} + if op == "eq": + search = search.filter("term", **{field: value}) + elif op == "ne" or op == "neq": + # For not equal, use a bool query with must_not + search = search.exclude("term", **{field: value}) + elif op in ["gt", "gte", "lt", "lte"]: + # For range operators + key_filter = {field: {op: value}} search = search.filter(Q("range", **key_filter)) - else: + elif op == "in": + # For in operator (value should be a list) + if isinstance(value, list): + search = search.filter("terms", **{field: value}) + else: + search = search.filter("term", **{field: value}) + elif op == "contains": + # For contains operator (for arrays) search = search.filter("term", **{field: value}) return search From ebfc05314a63689e6f39de451c07beb1e17759db Mon Sep 17 00:00:00 2001 From: jonhealy1 Date: Sun, 28 Sep 2025 00:27:20 +0800 Subject: [PATCH 8/9] es to os --- .../opensearch/stac_fastapi/opensearch/database_logic.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py index 6170e122d..e86d03a30 100644 --- a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py +++ b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py @@ -163,7 +163,7 @@ async def get_all_collections( filter: Optional[Dict[str, Any]] = None, query: Optional[Dict[str, Dict[str, Any]]] = None, ) -> Tuple[List[Dict[str, Any]], Optional[str]]: - """Retrieve a list of collections from Elasticsearch, supporting pagination. + """Retrieve a list of collections from OpenSearch, supporting pagination. Args: token (Optional[str]): The pagination token. @@ -195,7 +195,7 @@ async def get_all_collections( raise HTTPException( status_code=400, detail=f"Field '{field}' is not sortable. Sortable fields are: {', '.join(sortable_fields)}. " - + "Text fields are not sortable by default in Elasticsearch. " + + "Text fields are not sortable by default in OpenSearch. " + "To make a field sortable, update the mapping to use 'keyword' type or add a '.keyword' subfield. ", ) formatted_sort.append({field: {"order": direction}}) @@ -252,7 +252,7 @@ async def get_all_collections( # Convert string filter to dict if needed if isinstance(filter, str): filter = orjson.loads(filter) - # Convert the filter to an Elasticsearch query using the filter module + # Convert the filter to an OpenSearch query using the filter module es_query = filter_module.to_es(await self.get_queryables_mapping(), filter) query_parts.append(es_query) @@ -278,7 +278,7 @@ async def get_all_collections( query_parts.append(search_dict["query"]) except Exception as e: logger = logging.getLogger(__name__) - logger.error(f"Error converting query to Elasticsearch: {e}") + logger.error(f"Error converting query to OpenSearch: {e}") # If there's an error, add a query that matches nothing query_parts.append({"bool": {"must_not": {"match_all": {}}}}) raise From 5891dda825a9c7573d42e6fa835ee184de1b2ba0 Mon Sep 17 00:00:00 2001 From: jonhealy1 Date: Sun, 28 Sep 2025 00:41:53 +0800 Subject: [PATCH 9/9] lint --- stac_fastapi/tests/api/test_api_search_collections.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stac_fastapi/tests/api/test_api_search_collections.py b/stac_fastapi/tests/api/test_api_search_collections.py index d8b7a1a60..668ba0603 100644 --- a/stac_fastapi/tests/api/test_api_search_collections.py +++ b/stac_fastapi/tests/api/test_api_search_collections.py @@ -437,7 +437,7 @@ async def test_collections_query_extension(app_client, txn_client, ctx): assert f"{test_prefix}-sentinel" not in found_ids assert f"{test_prefix}-landsat" in found_ids assert f"{test_prefix}-modis" in found_ids - + async def test_collections_datetime_filter(app_client, load_test_data, txn_client): """Test filtering collections by datetime."""