Skip to content

Commit 8821c6d

Browse files
authored
GET /collections query extension (#477)
**Related Issue(s):**
- #462

**Description:**

**PR Checklist:**
- [x] Code is formatted and linted (run `pre-commit run --all-files`)
- [x] Tests pass (run `make test`)
- [x] Documentation has been updated to reflect changes, if applicable
- [x] Changes are added to the changelog
1 parent 859e456 commit 8821c6d

File tree

8 files changed

+254
-201
lines changed

8 files changed

+254
-201
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
1010
### Added
1111

1212
- GET `/collections` collection search structured filter extension with support for both cql2-json and cql2-text formats. [#475](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/475)
13+
- GET `/collections` collection search query extension. [#477](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/477)
1314
- GET `/collections` collections search datetime filtering support. [#476](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/476)
1415

1516
### Changed

stac_fastapi/core/stac_fastapi/core/core.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ async def all_collections(
232232
filter_expr: Optional[str] = None,
233233
filter_lang: Optional[str] = None,
234234
q: Optional[Union[str, List[str]]] = None,
235+
query: Optional[str] = None,
235236
**kwargs,
236237
) -> stac_types.Collections:
237238
"""Read all collections from the database.
@@ -241,6 +242,7 @@ async def all_collections(
241242
fields (Optional[List[str]]): Fields to include or exclude from the results.
242243
sortby (Optional[str]): Sorting options for the results.
243244
filter_expr (Optional[str]): Structured filter expression in CQL2 JSON or CQL2-text format.
245+
query (Optional[str]): Query extension parameters as a JSON-encoded string; invalid JSON results in a 400 error.
244246
filter_lang (Optional[str]): Must be 'cql2-json' or 'cql2-text' if specified, other values will result in an error.
245247
q (Optional[Union[str, List[str]]]): Free text search terms.
246248
**kwargs: Keyword arguments from the request.
@@ -255,7 +257,7 @@ async def all_collections(
255257

256258
# Process fields parameter for filtering collection properties
257259
includes, excludes = set(), set()
258-
if fields and self.extension_is_enabled("FieldsExtension"):
260+
if fields:
259261
for field in fields:
260262
if field[0] == "-":
261263
excludes.add(field[1:])
@@ -282,19 +284,29 @@ async def all_collections(
282284
if q is not None:
283285
q_list = [q] if isinstance(q, str) else q
284286

287+
# Parse the query parameter if provided
288+
parsed_query = None
289+
if query is not None:
290+
try:
291+
parsed_query = orjson.loads(query)
292+
except Exception as e:
293+
raise HTTPException(
294+
status_code=400, detail=f"Invalid query parameter: {e}"
295+
)
296+
285297
# Parse the filter parameter if provided
286298
parsed_filter = None
287299
if filter_expr is not None:
288300
try:
289-
# Check if filter_lang is specified and not one of the supported formats
301+
# Only raise an error for explicitly unsupported filter languages
290302
if filter_lang is not None and filter_lang not in [
291303
"cql2-json",
292304
"cql2-text",
293305
]:
294306
# Raise an error for unsupported filter languages
295307
raise HTTPException(
296308
status_code=400,
297-
detail=f"Input should be 'cql2-json' or 'cql2-text' for collections. Got '{filter_lang}'.",
309+
detail=f"Only 'cql2-json' and 'cql2-text' filter languages are supported for collections. Got '{filter_lang}'.",
298310
)
299311

300312
# Handle different filter formats
@@ -341,11 +353,12 @@ async def all_collections(
341353
sort=sort,
342354
q=q_list,
343355
filter=parsed_filter,
356+
query=parsed_query,
344357
datetime=parsed_datetime,
345358
)
346359

347360
# Apply field filtering if fields parameter was provided
348-
if fields and self.extension_is_enabled("FieldsExtension"):
361+
if fields:
349362
filtered_collections = [
350363
filter_fields(collection, includes, excludes)
351364
for collection in collections
@@ -685,11 +698,7 @@ async def post_search(
685698
datetime_search=datetime_search,
686699
)
687700

688-
fields = (
689-
getattr(search_request, "fields", None)
690-
if self.extension_is_enabled("FieldsExtension")
691-
else None
692-
)
701+
fields = getattr(search_request, "fields", None)
693702
include: Set[str] = fields.include if fields and fields.include else set()
694703
exclude: Set[str] = fields.exclude if fields and fields.exclude else set()
695704

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@
121121
if ENABLE_COLLECTIONS_SEARCH:
122122
# Create collection search extensions
123123
collection_search_extensions = [
124-
# QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]),
124+
QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]),
125125
SortExtension(conformance_classes=[SortConformanceClasses.COLLECTIONS]),
126126
FieldsExtension(conformance_classes=[FieldsConformanceClasses.COLLECTIONS]),
127127
CollectionSearchFilterExtension(

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ async def get_all_collections(
177177
sort: Optional[List[Dict[str, Any]]] = None,
178178
q: Optional[List[str]] = None,
179179
filter: Optional[Dict[str, Any]] = None,
180+
query: Optional[Dict[str, Dict[str, Any]]] = None,
180181
datetime: Optional[str] = None,
181182
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
182183
"""Retrieve a list of collections from Elasticsearch, supporting pagination.
@@ -187,6 +188,7 @@ async def get_all_collections(
187188
request (Request): The FastAPI request object.
188189
sort (Optional[List[Dict[str, Any]]]): Optional sort parameter from the request.
189190
q (Optional[List[str]]): Free text search terms.
191+
query (Optional[Dict[str, Dict[str, Any]]]): Query extension parameters.
190192
filter (Optional[Dict[str, Any]]): Structured query in CQL2 format.
191193
datetime (Optional[str]): Temporal filter.
192194
@@ -272,6 +274,35 @@ async def get_all_collections(
272274
es_query = filter_module.to_es(await self.get_queryables_mapping(), filter)
273275
query_parts.append(es_query)
274276

277+
# Apply query extension if provided
278+
if query:
279+
try:
280+
# First create a search object to apply filters
281+
search = Search(index=COLLECTIONS_INDEX)
282+
283+
# Process each field and operator in the query
284+
for field_name, expr in query.items():
285+
for op, value in expr.items():
286+
# For collections, we don't need to prefix with 'properties__'
287+
field = field_name
288+
# Apply the filter using apply_stacql_filter
289+
search = self.apply_stacql_filter(
290+
search=search, op=op, field=field, value=value
291+
)
292+
293+
# Convert the search object to a query dict and add it to query_parts
294+
search_dict = search.to_dict()
295+
if "query" in search_dict:
296+
query_parts.append(search_dict["query"])
297+
298+
except Exception as e:
299+
logger = logging.getLogger(__name__)
300+
logger.error(f"Error converting query to Elasticsearch: {e}")
301+
# If there's an error, add a query that matches nothing
302+
query_parts.append({"bool": {"must_not": {"match_all": {}}}})
303+
raise
304+
305+
# Combine all query parts with AND logic if there are multiple
275306
datetime_filter = None
276307
if datetime:
277308
datetime_filter = self._apply_collection_datetime_filter(datetime)
@@ -605,18 +636,31 @@ def apply_stacql_filter(search: Search, op: str, field: str, value: float):
605636
606637
Args:
607638
search (Search): The search object to apply the filter to.
608-
op (str): The comparison operator to use. Can be 'eq' (equal), 'gt' (greater than), 'gte' (greater than or equal),
609-
'lt' (less than), or 'lte' (less than or equal).
639+
op (str): The comparison operator to use. Can be 'eq' (equal), 'ne'/'neq' (not equal), 'gt' (greater than),
640+
'gte' (greater than or equal), 'lt' (less than), 'lte' (less than or equal), 'in' (field matches any value in a list), or 'contains' (array field contains the value).
610641
field (str): The field to perform the comparison on.
611642
value (float): The value to compare the field against.
612643
613644
Returns:
614645
search (Search): The search object with the specified filter applied.
615646
"""
616-
if op != "eq":
647+
if op == "eq":
648+
search = search.filter("term", **{field: value})
649+
elif op == "ne" or op == "neq":
650+
# For not equal, use a bool query with must_not
651+
search = search.exclude("term", **{field: value})
652+
elif op in ["gt", "gte", "lt", "lte"]:
653+
# For range operators
617654
key_filter = {field: {op: value}}
618655
search = search.filter(Q("range", **key_filter))
619-
else:
656+
elif op == "in":
657+
# For in operator (value should be a list)
658+
if isinstance(value, list):
659+
search = search.filter("terms", **{field: value})
660+
else:
661+
search = search.filter("term", **{field: value})
662+
elif op == "contains":
663+
# For contains operator (for arrays)
620664
search = search.filter("term", **{field: value})
621665

622666
return search

stac_fastapi/opensearch/stac_fastapi/opensearch/app.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@
121121
if ENABLE_COLLECTIONS_SEARCH:
122122
# Create collection search extensions
123123
collection_search_extensions = [
124-
# QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]),
124+
QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]),
125125
SortExtension(conformance_classes=[SortConformanceClasses.COLLECTIONS]),
126126
FieldsExtension(conformance_classes=[FieldsConformanceClasses.COLLECTIONS]),
127127
CollectionSearchFilterExtension(
@@ -170,6 +170,7 @@
170170
post_request_model=post_request_model,
171171
landing_page_id=os.getenv("STAC_FASTAPI_LANDING_PAGE_ID", "stac-fastapi"),
172172
),
173+
"collections_get_request_model": collections_get_request_model,
173174
"search_get_request_model": create_get_request_model(search_extensions),
174175
"search_post_request_model": post_request_model,
175176
"items_get_request_model": items_get_request_model,

stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -161,17 +161,19 @@ async def get_all_collections(
161161
sort: Optional[List[Dict[str, Any]]] = None,
162162
q: Optional[List[str]] = None,
163163
filter: Optional[Dict[str, Any]] = None,
164+
query: Optional[Dict[str, Dict[str, Any]]] = None,
164165
datetime: Optional[str] = None,
165166
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
166-
"""Retrieve a list of collections from Opensearch, supporting pagination.
167+
"""Retrieve a list of collections from OpenSearch, supporting pagination.
167168
168169
Args:
169170
token (Optional[str]): The pagination token.
170171
limit (int): The number of results to return.
171172
request (Request): The FastAPI request object.
172173
sort (Optional[List[Dict[str, Any]]]): Optional sort parameter from the request.
173174
q (Optional[List[str]]): Free text search terms.
174-
filter (Optional[Dict[str, Any]]): Structured query in CQL2 format.
175+
filter (Optional[Dict[str, Any]]): Structured filter in CQL2 format.
176+
query (Optional[Dict[str, Dict[str, Any]]]): Query extension parameters.
175177
datetime (Optional[str]): Temporal filter.
176178
177179
Returns:
@@ -195,7 +197,7 @@ async def get_all_collections(
195197
raise HTTPException(
196198
status_code=400,
197199
detail=f"Field '{field}' is not sortable. Sortable fields are: {', '.join(sortable_fields)}. "
198-
+ "Text fields are not sortable by default in Opensearch. "
200+
+ "Text fields are not sortable by default in OpenSearch. "
199201
+ "To make a field sortable, update the mapping to use 'keyword' type or add a '.keyword' subfield. ",
200202
)
201203
formatted_sort.append({field: {"order": direction}})
@@ -252,10 +254,37 @@ async def get_all_collections(
252254
# Convert string filter to dict if needed
253255
if isinstance(filter, str):
254256
filter = orjson.loads(filter)
255-
# Convert the filter to an Opensearch query using the filter module
257+
# Convert the filter to an OpenSearch query using the filter module
256258
es_query = filter_module.to_es(await self.get_queryables_mapping(), filter)
257259
query_parts.append(es_query)
258260

261+
# Apply query extension if provided
262+
if query:
263+
try:
264+
# First create a search object to apply filters
265+
search = Search(index=COLLECTIONS_INDEX)
266+
267+
# Process each field and operator in the query
268+
for field_name, expr in query.items():
269+
for op, value in expr.items():
270+
# For collections, we don't need to prefix with 'properties__'
271+
field = field_name
272+
# Apply the filter using apply_stacql_filter
273+
search = self.apply_stacql_filter(
274+
search=search, op=op, field=field, value=value
275+
)
276+
277+
# Convert the search object to a query dict and add it to query_parts
278+
search_dict = search.to_dict()
279+
if "query" in search_dict:
280+
query_parts.append(search_dict["query"])
281+
except Exception as e:
282+
logger = logging.getLogger(__name__)
283+
logger.error(f"Error converting query to OpenSearch: {e}")
284+
# If there's an error, add a query that matches nothing
285+
query_parts.append({"bool": {"must_not": {"match_all": {}}}})
286+
raise
287+
259288
datetime_filter = None
260289
if datetime:
261290
datetime_filter = self._apply_collection_datetime_filter(datetime)
@@ -607,18 +636,31 @@ def apply_stacql_filter(search: Search, op: str, field: str, value: float):
607636
608637
Args:
609638
search (Search): The search object to apply the filter to.
610-
op (str): The comparison operator to use. Can be 'eq' (equal), 'gt' (greater than), 'gte' (greater than or equal),
611-
'lt' (less than), or 'lte' (less than or equal).
639+
op (str): The comparison operator to use. Can be 'eq' (equal), 'ne'/'neq' (not equal), 'gt' (greater than),
640+
'gte' (greater than or equal), 'lt' (less than), 'lte' (less than or equal), 'in' (field matches any value in a list), or 'contains' (array field contains the value).
612641
field (str): The field to perform the comparison on.
613642
value (float): The value to compare the field against.
614643
615644
Returns:
616645
search (Search): The search object with the specified filter applied.
617646
"""
618-
if op != "eq":
619-
key_filter = {field: {f"{op}": value}}
647+
if op == "eq":
648+
search = search.filter("term", **{field: value})
649+
elif op == "ne" or op == "neq":
650+
# For not equal, use a bool query with must_not
651+
search = search.exclude("term", **{field: value})
652+
elif op in ["gt", "gte", "lt", "lte"]:
653+
# For range operators
654+
key_filter = {field: {op: value}}
620655
search = search.filter(Q("range", **key_filter))
621-
else:
656+
elif op == "in":
657+
# For in operator (value should be a list)
658+
if isinstance(value, list):
659+
search = search.filter("terms", **{field: value})
660+
else:
661+
search = search.filter("term", **{field: value})
662+
elif op == "contains":
663+
# For contains operator (for arrays)
622664
search = search.filter("term", **{field: value})
623665

624666
return search

0 commit comments

Comments
 (0)