diff --git a/CHANGELOG.md b/CHANGELOG.md index b0e1ea3a7..87410d878 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ### Added - GET `/collections` collection search structured filter extension with support for both cql2-json and cql2-text formats. [#475](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/475) +- GET `/collections` collection search query extension. [#476](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/476) - GET `/collections` collections search datetime filtering support. [#476](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/476) ### Changed diff --git a/stac_fastapi/core/stac_fastapi/core/core.py b/stac_fastapi/core/stac_fastapi/core/core.py index 220219e9f..a6862cf25 100644 --- a/stac_fastapi/core/stac_fastapi/core/core.py +++ b/stac_fastapi/core/stac_fastapi/core/core.py @@ -232,6 +232,7 @@ async def all_collections( filter_expr: Optional[str] = None, filter_lang: Optional[str] = None, q: Optional[Union[str, List[str]]] = None, + query: Optional[str] = None, **kwargs, ) -> stac_types.Collections: """Read all collections from the database. @@ -241,6 +242,7 @@ async def all_collections( fields (Optional[List[str]]): Fields to include or exclude from the results. sortby (Optional[str]): Sorting options for the results. filter_expr (Optional[str]): Structured filter expression in CQL2 JSON or CQL2-text format. + query (Optional[str]): Legacy query parameter (deprecated). filter_lang (Optional[str]): Must be 'cql2-json' or 'cql2-text' if specified, other values will result in an error. q (Optional[Union[str, List[str]]]): Free text search terms. **kwargs: Keyword arguments from the request. 
@@ -255,7 +257,7 @@ async def all_collections( # Process fields parameter for filtering collection properties includes, excludes = set(), set() - if fields and self.extension_is_enabled("FieldsExtension"): + if fields: for field in fields: if field[0] == "-": excludes.add(field[1:]) @@ -282,11 +284,21 @@ async def all_collections( if q is not None: q_list = [q] if isinstance(q, str) else q + # Parse the query parameter if provided + parsed_query = None + if query is not None: + try: + parsed_query = orjson.loads(query) + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Invalid query parameter: {e}" + ) + # Parse the filter parameter if provided parsed_filter = None if filter_expr is not None: try: - # Check if filter_lang is specified and not one of the supported formats + # Only raise an error for explicitly unsupported filter languages if filter_lang is not None and filter_lang not in [ "cql2-json", "cql2-text", @@ -294,7 +306,7 @@ async def all_collections( # Raise an error for unsupported filter languages raise HTTPException( status_code=400, - detail=f"Input should be 'cql2-json' or 'cql2-text' for collections. Got '{filter_lang}'.", + detail=f"Only 'cql2-json' and 'cql2-text' filter languages are supported for collections. 
Got '{filter_lang}'.", ) # Handle different filter formats @@ -341,11 +353,12 @@ async def all_collections( sort=sort, q=q_list, filter=parsed_filter, + query=parsed_query, datetime=parsed_datetime, ) # Apply field filtering if fields parameter was provided - if fields and self.extension_is_enabled("FieldsExtension"): + if fields: filtered_collections = [ filter_fields(collection, includes, excludes) for collection in collections @@ -685,11 +698,7 @@ async def post_search( datetime_search=datetime_search, ) - fields = ( - getattr(search_request, "fields", None) - if self.extension_is_enabled("FieldsExtension") - else None - ) + fields = getattr(search_request, "fields", None) include: Set[str] = fields.include if fields and fields.include else set() exclude: Set[str] = fields.exclude if fields and fields.exclude else set() diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py index e3292cbf8..8cc32088c 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py @@ -121,7 +121,7 @@ if ENABLE_COLLECTIONS_SEARCH: # Create collection search extensions collection_search_extensions = [ - # QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]), + QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]), SortExtension(conformance_classes=[SortConformanceClasses.COLLECTIONS]), FieldsExtension(conformance_classes=[FieldsConformanceClasses.COLLECTIONS]), CollectionSearchFilterExtension( diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py index 08669a280..a7893dc8b 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py @@ -177,6 +177,7 @@ async def get_all_collections( 
sort: Optional[List[Dict[str, Any]]] = None, q: Optional[List[str]] = None, filter: Optional[Dict[str, Any]] = None, + query: Optional[Dict[str, Dict[str, Any]]] = None, datetime: Optional[str] = None, ) -> Tuple[List[Dict[str, Any]], Optional[str]]: """Retrieve a list of collections from Elasticsearch, supporting pagination. @@ -187,6 +188,7 @@ async def get_all_collections( request (Request): The FastAPI request object. sort (Optional[List[Dict[str, Any]]]): Optional sort parameter from the request. q (Optional[List[str]]): Free text search terms. + query (Optional[Dict[str, Dict[str, Any]]]): Query extension parameters. filter (Optional[Dict[str, Any]]): Structured query in CQL2 format. datetime (Optional[str]): Temporal filter. @@ -272,6 +274,35 @@ async def get_all_collections( es_query = filter_module.to_es(await self.get_queryables_mapping(), filter) query_parts.append(es_query) + # Apply query extension if provided + if query: + try: + # First create a search object to apply filters + search = Search(index=COLLECTIONS_INDEX) + + # Process each field and operator in the query + for field_name, expr in query.items(): + for op, value in expr.items(): + # For collections, we don't need to prefix with 'properties__' + field = field_name + # Apply the filter using apply_stacql_filter + search = self.apply_stacql_filter( + search=search, op=op, field=field, value=value + ) + + # Convert the search object to a query dict and add it to query_parts + search_dict = search.to_dict() + if "query" in search_dict: + query_parts.append(search_dict["query"]) + + except Exception as e: + logger = logging.getLogger(__name__) + logger.error(f"Error converting query to Elasticsearch: {e}") + # If there's an error, add a query that matches nothing + query_parts.append({"bool": {"must_not": {"match_all": {}}}}) + raise + + # Combine all query parts with AND logic if there are multiple datetime_filter = None if datetime: datetime_filter = 
self._apply_collection_datetime_filter(datetime) @@ -605,18 +636,31 @@ def apply_stacql_filter(search: Search, op: str, field: str, value: float): Args: search (Search): The search object to apply the filter to. - op (str): The comparison operator to use. Can be 'eq' (equal), 'gt' (greater than), 'gte' (greater than or equal), - 'lt' (less than), or 'lte' (less than or equal). + op (str): The comparison operator to use. Can be 'eq' (equal), 'ne'/'neq' (not equal), 'gt' (greater than), + 'gte' (greater than or equal), 'lt' (less than), or 'lte' (less than or equal). field (str): The field to perform the comparison on. value (float): The value to compare the field against. Returns: search (Search): The search object with the specified filter applied. """ - if op != "eq": + if op == "eq": + search = search.filter("term", **{field: value}) + elif op == "ne" or op == "neq": + # For not equal, use a bool query with must_not + search = search.exclude("term", **{field: value}) + elif op in ["gt", "gte", "lt", "lte"]: + # For range operators key_filter = {field: {op: value}} search = search.filter(Q("range", **key_filter)) - else: + elif op == "in": + # For in operator (value should be a list) + if isinstance(value, list): + search = search.filter("terms", **{field: value}) + else: + search = search.filter("term", **{field: value}) + elif op == "contains": + # For contains operator (for arrays) search = search.filter("term", **{field: value}) return search diff --git a/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py b/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py index b842d929c..56f717a34 100644 --- a/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py +++ b/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py @@ -121,7 +121,7 @@ if ENABLE_COLLECTIONS_SEARCH: # Create collection search extensions collection_search_extensions = [ - # QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]), + 
QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]), SortExtension(conformance_classes=[SortConformanceClasses.COLLECTIONS]), FieldsExtension(conformance_classes=[FieldsConformanceClasses.COLLECTIONS]), CollectionSearchFilterExtension( @@ -170,6 +170,7 @@ post_request_model=post_request_model, landing_page_id=os.getenv("STAC_FASTAPI_LANDING_PAGE_ID", "stac-fastapi"), ), + "collections_get_request_model": collections_get_request_model, "search_get_request_model": create_get_request_model(search_extensions), "search_post_request_model": post_request_model, "items_get_request_model": items_get_request_model, diff --git a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py index 54deb36bb..694d6cfae 100644 --- a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py +++ b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py @@ -161,9 +161,10 @@ async def get_all_collections( sort: Optional[List[Dict[str, Any]]] = None, q: Optional[List[str]] = None, filter: Optional[Dict[str, Any]] = None, + query: Optional[Dict[str, Dict[str, Any]]] = None, datetime: Optional[str] = None, ) -> Tuple[List[Dict[str, Any]], Optional[str]]: - """Retrieve a list of collections from Opensearch, supporting pagination. + """Retrieve a list of collections from OpenSearch, supporting pagination. Args: token (Optional[str]): The pagination token. @@ -171,7 +172,8 @@ async def get_all_collections( request (Request): The FastAPI request object. sort (Optional[List[Dict[str, Any]]]): Optional sort parameter from the request. q (Optional[List[str]]): Free text search terms. - filter (Optional[Dict[str, Any]]): Structured query in CQL2 format. + filter (Optional[Dict[str, Any]]): Structured filter in CQL2 format. + query (Optional[Dict[str, Dict[str, Any]]]): Query extension parameters. datetime (Optional[str]): Temporal filter. 
Returns: @@ -195,7 +197,7 @@ async def get_all_collections( raise HTTPException( status_code=400, detail=f"Field '{field}' is not sortable. Sortable fields are: {', '.join(sortable_fields)}. " - + "Text fields are not sortable by default in Opensearch. " + + "Text fields are not sortable by default in OpenSearch. " + "To make a field sortable, update the mapping to use 'keyword' type or add a '.keyword' subfield. ", ) formatted_sort.append({field: {"order": direction}}) @@ -252,10 +254,37 @@ async def get_all_collections( # Convert string filter to dict if needed if isinstance(filter, str): filter = orjson.loads(filter) - # Convert the filter to an Opensearch query using the filter module + # Convert the filter to an OpenSearch query using the filter module es_query = filter_module.to_es(await self.get_queryables_mapping(), filter) query_parts.append(es_query) + # Apply query extension if provided + if query: + try: + # First create a search object to apply filters + search = Search(index=COLLECTIONS_INDEX) + + # Process each field and operator in the query + for field_name, expr in query.items(): + for op, value in expr.items(): + # For collections, we don't need to prefix with 'properties__' + field = field_name + # Apply the filter using apply_stacql_filter + search = self.apply_stacql_filter( + search=search, op=op, field=field, value=value + ) + + # Convert the search object to a query dict and add it to query_parts + search_dict = search.to_dict() + if "query" in search_dict: + query_parts.append(search_dict["query"]) + except Exception as e: + logger = logging.getLogger(__name__) + logger.error(f"Error converting query to OpenSearch: {e}") + # If there's an error, add a query that matches nothing + query_parts.append({"bool": {"must_not": {"match_all": {}}}}) + raise + datetime_filter = None if datetime: datetime_filter = self._apply_collection_datetime_filter(datetime) @@ -607,18 +636,31 @@ def apply_stacql_filter(search: Search, op: str, field: str, value: 
float): Args: search (Search): The search object to apply the filter to. - op (str): The comparison operator to use. Can be 'eq' (equal), 'gt' (greater than), 'gte' (greater than or equal), - 'lt' (less than), or 'lte' (less than or equal). + op (str): The comparison operator to use. Can be 'eq' (equal), 'ne'/'neq' (not equal), 'gt' (greater than), + 'gte' (greater than or equal), 'lt' (less than), or 'lte' (less than or equal). field (str): The field to perform the comparison on. value (float): The value to compare the field against. Returns: search (Search): The search object with the specified filter applied. """ - if op != "eq": - key_filter = {field: {f"{op}": value}} + if op == "eq": + search = search.filter("term", **{field: value}) + elif op == "ne" or op == "neq": + # For not equal, use a bool query with must_not + search = search.exclude("term", **{field: value}) + elif op in ["gt", "gte", "lt", "lte"]: + # For range operators + key_filter = {field: {op: value}} search = search.filter(Q("range", **key_filter)) - else: + elif op == "in": + # For in operator (value should be a list) + if isinstance(value, list): + search = search.filter("terms", **{field: value}) + else: + search = search.filter("term", **{field: value}) + elif op == "contains": + # For contains operator (for arrays) search = search.filter("term", **{field: value}) return search diff --git a/stac_fastapi/tests/api/test_api_search_collections.py b/stac_fastapi/tests/api/test_api_search_collections.py index fb739f903..668ba0603 100644 --- a/stac_fastapi/tests/api/test_api_search_collections.py +++ b/stac_fastapi/tests/api/test_api_search_collections.py @@ -7,10 +7,10 @@ @pytest.mark.asyncio -async def test_collections_sort_id_asc(app_client, txn_client, load_test_data): +async def test_collections_sort_id_asc(app_client, txn_client, ctx): """Verify GET /collections honors ascending sort on id.""" # Create multiple collections with different ids - base_collection = 
load_test_data("test_collection.json") + base_collection = ctx.collection # Create collections with ids in a specific order to test sorting # Use unique prefixes to avoid conflicts between tests @@ -23,6 +23,8 @@ async def test_collections_sort_id_asc(app_client, txn_client, load_test_data): test_collection["title"] = f"Test Collection {i}" await create_collection(txn_client, test_collection) + await refresh_indices(txn_client) + # Test ascending sort by id resp = await app_client.get( "/collections", @@ -44,10 +46,10 @@ async def test_collections_sort_id_asc(app_client, txn_client, load_test_data): @pytest.mark.asyncio -async def test_collections_sort_id_desc(app_client, txn_client, load_test_data): +async def test_collections_sort_id_desc(app_client, txn_client, ctx): """Verify GET /collections honors descending sort on id.""" # Create multiple collections with different ids - base_collection = load_test_data("test_collection.json") + base_collection = ctx.collection # Create collections with ids in a specific order to test sorting # Use unique prefixes to avoid conflicts between tests @@ -60,6 +62,8 @@ async def test_collections_sort_id_desc(app_client, txn_client, load_test_data): test_collection["title"] = f"Test Collection {i}" await create_collection(txn_client, test_collection) + await refresh_indices(txn_client) + # Test descending sort by id resp = await app_client.get( "/collections", @@ -81,10 +85,10 @@ async def test_collections_sort_id_desc(app_client, txn_client, load_test_data): @pytest.mark.asyncio -async def test_collections_fields(app_client, txn_client, load_test_data): +async def test_collections_fields(app_client, txn_client, ctx): """Verify GET /collections honors the fields parameter.""" # Create multiple collections with different ids - base_collection = load_test_data("test_collection.json") + base_collection = ctx.collection # Create collections with ids in a specific order to test fields # Use unique prefixes to avoid conflicts between 
tests @@ -98,6 +102,8 @@ async def test_collections_fields(app_client, txn_client, load_test_data): test_collection["description"] = f"Description for collection {i}" await create_collection(txn_client, test_collection) + await refresh_indices(txn_client) + # Test include fields parameter resp = await app_client.get( "/collections", @@ -156,10 +162,10 @@ async def test_collections_fields(app_client, txn_client, load_test_data): @pytest.mark.asyncio -async def test_collections_free_text_search_get(app_client, txn_client, load_test_data): +async def test_collections_free_text_search_get(app_client, txn_client, ctx): """Verify GET /collections honors the q parameter for free text search.""" # Create multiple collections with different content - base_collection = load_test_data("test_collection.json") + base_collection = ctx.collection # Use unique prefixes to avoid conflicts between tests test_prefix = f"q-get-{uuid.uuid4().hex[:8]}" @@ -193,6 +199,8 @@ async def test_collections_free_text_search_get(app_client, txn_client, load_tes test_collection["summaries"] = coll["summaries"] await create_collection(txn_client, test_collection) + await refresh_indices(txn_client) + # Test free text search for "sentinel" resp = await app_client.get( "/collections", @@ -229,10 +237,10 @@ async def test_collections_free_text_search_get(app_client, txn_client, load_tes @pytest.mark.asyncio -async def test_collections_filter_search(app_client, txn_client, load_test_data): +async def test_collections_filter_search(app_client, txn_client, ctx): """Verify GET /collections honors the filter parameter for structured search.""" # Create multiple collections with different content - base_collection = load_test_data("test_collection.json") + base_collection = ctx.collection # Use unique prefixes to avoid conflicts between tests test_prefix = f"filter-{uuid.uuid4().hex[:8]}" @@ -316,6 +324,121 @@ async def test_collections_filter_search(app_client, txn_client, load_test_data) 
@pytest.mark.asyncio +async def test_collections_query_extension(app_client, txn_client, ctx): + """Verify GET /collections honors the query extension.""" + # Create multiple collections with different content + base_collection = ctx.collection + # Use unique prefixes to avoid conflicts between tests + test_prefix = f"query-ext-{uuid.uuid4().hex[:8]}" + + # Create collections with different content to test query extension + test_collections = [ + { + "id": f"{test_prefix}-sentinel", + "title": "Sentinel-2 Collection", + "description": "Collection of Sentinel-2 data", + "summaries": {"platform": ["sentinel-2a", "sentinel-2b"]}, + }, + { + "id": f"{test_prefix}-landsat", + "title": "Landsat Collection", + "description": "Collection of Landsat data", + "summaries": {"platform": ["landsat-8", "landsat-9"]}, + }, + { + "id": f"{test_prefix}-modis", + "title": "MODIS Collection", + "description": "Collection of MODIS data", + "summaries": {"platform": ["terra", "aqua"]}, + }, + ] + + for i, coll in enumerate(test_collections): + test_collection = base_collection.copy() + test_collection["id"] = coll["id"] + test_collection["title"] = coll["title"] + test_collection["description"] = coll["description"] + test_collection["summaries"] = coll["summaries"] + await create_collection(txn_client, test_collection) + + await refresh_indices(txn_client) + + # Use the exact ID that was created + sentinel_id = f"{test_prefix}-sentinel" + + query = {"id": {"eq": sentinel_id}} + + resp = await app_client.get( + "/collections", + params=[("query", json.dumps(query))], + ) + assert resp.status_code == 200 + resp_json = resp.json() + + # Filter collections to only include the ones we created for this test + found_collections = [ + c for c in resp_json["collections"] if c["id"].startswith(test_prefix) + ] + + # Should only find the sentinel collection + assert len(found_collections) == 1 + assert found_collections[0]["id"] == f"{test_prefix}-sentinel" + + # Test query extension with equal 
operator on ID + query = {"id": {"eq": f"{test_prefix}-sentinel"}} + + resp = await app_client.get( + "/collections", + params=[("query", json.dumps(query))], + ) + assert resp.status_code == 200 + resp_json = resp.json() + + # Filter collections to only include the ones we created for this test + found_collections = [ + c for c in resp_json["collections"] if c["id"].startswith(test_prefix) + ] + found_ids = [c["id"] for c in found_collections] + + # Should find only the sentinel collection, not landsat or modis + assert len(found_collections) == 1 + assert f"{test_prefix}-sentinel" in found_ids + assert f"{test_prefix}-landsat" not in found_ids + assert f"{test_prefix}-modis" not in found_ids + + # Test query extension with not-equal operator on ID + query = {"id": {"neq": f"{test_prefix}-sentinel"}} + + print(f"\nTesting neq query: {query}") + print(f"JSON query: {json.dumps(query)}") + + resp = await app_client.get( + "/collections", + params=[("query", json.dumps(query))], + ) + print(f"Response status: {resp.status_code}") + assert resp.status_code == 200 + resp_json = resp.json() + print(f"Response JSON keys: {resp_json.keys()}") + print(f"Number of collections in response: {len(resp_json.get('collections', []))}") + + # Print all collection IDs in the response + all_ids = [c["id"] for c in resp_json.get("collections", [])] + print(f"All collection IDs in response: {all_ids}") + + # Filter collections to only include the ones we created for this test + found_collections = [ + c for c in resp_json["collections"] if c["id"].startswith(test_prefix) + ] + found_ids = [c["id"] for c in found_collections] + + # Should find landsat and modis collections but not sentinel + assert len(found_collections) == 2 + assert f"{test_prefix}-sentinel" not in found_ids + assert f"{test_prefix}-landsat" in found_ids + assert f"{test_prefix}-modis" in found_ids + + + async def test_collections_datetime_filter(app_client, load_test_data, txn_client): + """Test filtering collections by
datetime.""" # Create a test collection with a specific temporal extent diff --git a/stac_fastapi/tests/api/test_collections_search_env.py b/stac_fastapi/tests/api/test_collections_search_env.py deleted file mode 100644 index 5358faf98..000000000 --- a/stac_fastapi/tests/api/test_collections_search_env.py +++ /dev/null @@ -1,167 +0,0 @@ -"""Test the ENABLE_COLLECTIONS_SEARCH environment variable.""" - -import os -import uuid -from unittest import mock - -import pytest - -from ..conftest import create_collection, refresh_indices - - -@pytest.mark.asyncio -@mock.patch.dict(os.environ, {"ENABLE_COLLECTIONS_SEARCH": "false"}) -async def test_collections_search_disabled(app_client, txn_client, load_test_data): - """Test that collection search extensions are disabled when ENABLE_COLLECTIONS_SEARCH=false.""" - # Create multiple collections with different ids to test sorting - base_collection = load_test_data("test_collection.json") - - # Use unique prefixes to avoid conflicts between tests - test_prefix = f"disabled-{uuid.uuid4().hex[:8]}" - collection_ids = [f"{test_prefix}-c", f"{test_prefix}-a", f"{test_prefix}-b"] - - for i, coll_id in enumerate(collection_ids): - test_collection = base_collection.copy() - test_collection["id"] = coll_id - test_collection["title"] = f"Test Collection {i}" - await create_collection(txn_client, test_collection) - - # Refresh indices to ensure collections are searchable - await refresh_indices(txn_client) - - # When collection search is disabled, sortby parameter should be ignored - resp = await app_client.get( - "/collections", - params=[("sortby", "+id")], - ) - assert resp.status_code == 200 - - # Verify that results are NOT sorted by id (should be in insertion order or default order) - resp_json = resp.json() - collections = [ - c for c in resp_json["collections"] if c["id"].startswith(test_prefix) - ] - - # Extract the ids in the order they were returned - returned_ids = [c["id"] for c in collections] - - # If sorting was working, 
they would be in alphabetical order: a, b, c - # But since sorting is disabled, they should be in a different order - # We can't guarantee the exact order, but we can check they're not in alphabetical order - sorted_ids = sorted(returned_ids) - assert ( - returned_ids != sorted_ids or len(collections) < 2 - ), "Collections appear to be sorted despite ENABLE_COLLECTIONS_SEARCH=false" - - # Fields parameter should also be ignored - resp = await app_client.get( - "/collections", - params=[("fields", "id")], # Request only id field - ) - assert resp.status_code == 200 - - # Verify that all fields are still returned, not just id - resp_json = resp.json() - for collection in resp_json["collections"]: - if collection["id"].startswith(test_prefix): - # If fields filtering was working, only id would be present - # Since it's disabled, other fields like title should still be present - assert ( - "title" in collection - ), "Fields filtering appears to be working despite ENABLE_COLLECTIONS_SEARCH=false" - - -@pytest.mark.asyncio -@mock.patch.dict(os.environ, {"ENABLE_COLLECTIONS_SEARCH": "true"}) -async def test_collections_search_enabled(app_client, txn_client, load_test_data): - """Test that collection search extensions work when ENABLE_COLLECTIONS_SEARCH=true.""" - # Create multiple collections with different ids to test sorting - base_collection = load_test_data("test_collection.json") - - # Use unique prefixes to avoid conflicts between tests - test_prefix = f"enabled-{uuid.uuid4().hex[:8]}" - collection_ids = [f"{test_prefix}-c", f"{test_prefix}-a", f"{test_prefix}-b"] - - for i, coll_id in enumerate(collection_ids): - test_collection = base_collection.copy() - test_collection["id"] = coll_id - test_collection["title"] = f"Test Collection {i}" - await create_collection(txn_client, test_collection) - - # Refresh indices to ensure collections are searchable - await refresh_indices(txn_client) - - # Test that sortby parameter works - sort by id ascending - resp = await 
app_client.get( - "/collections", - params=[("sortby", "+id")], - ) - assert resp.status_code == 200 - - # Verify that results are sorted by id in ascending order - resp_json = resp.json() - collections = [ - c for c in resp_json["collections"] if c["id"].startswith(test_prefix) - ] - - # Extract the ids in the order they were returned - returned_ids = [c["id"] for c in collections] - - # Verify they're in ascending order - assert returned_ids == sorted( - returned_ids - ), "Collections are not sorted by id ascending" - - # Test that sortby parameter works - sort by id descending - resp = await app_client.get( - "/collections", - params=[("sortby", "-id")], - ) - assert resp.status_code == 200 - - # Verify that results are sorted by id in descending order - resp_json = resp.json() - collections = [ - c for c in resp_json["collections"] if c["id"].startswith(test_prefix) - ] - - # Extract the ids in the order they were returned - returned_ids = [c["id"] for c in collections] - - # Verify they're in descending order - assert returned_ids == sorted( - returned_ids, reverse=True - ), "Collections are not sorted by id descending" - - # Test that fields parameter works - request only id field - resp = await app_client.get( - "/collections", - params=[("fields", "id")], - ) - assert resp.status_code == 200 - resp_json = resp.json() - - # When fields=id is specified, collections should only have id field - for collection in resp_json["collections"]: - if collection["id"].startswith(test_prefix): - assert "id" in collection, "id field is missing" - assert ( - "title" not in collection - ), "title field should be excluded when fields=id" - - # Test that fields parameter works - request multiple fields - resp = await app_client.get( - "/collections", - params=[("fields", "id,title")], - ) - assert resp.status_code == 200 - resp_json = resp.json() - - # When fields=id,title is specified, collections should have both fields but not others - for collection in 
resp_json["collections"]: - if collection["id"].startswith(test_prefix): - assert "id" in collection, "id field is missing" - assert "title" in collection, "title field is missing" - assert ( - "description" not in collection - ), "description field should be excluded when fields=id,title"