Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
### Added

- GET `/collections` collection search structured filter extension with support for both cql2-json and cql2-text formats. [#475](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/475)
- GET `/collections` collection search query extension. [#476](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/476)
- GET `/collections` collections search datetime filtering support. [#476](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/476)

### Changed
Expand Down
27 changes: 18 additions & 9 deletions stac_fastapi/core/stac_fastapi/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ async def all_collections(
filter_expr: Optional[str] = None,
filter_lang: Optional[str] = None,
q: Optional[Union[str, List[str]]] = None,
query: Optional[str] = None,
**kwargs,
) -> stac_types.Collections:
"""Read all collections from the database.
Expand All @@ -241,6 +242,7 @@ async def all_collections(
fields (Optional[List[str]]): Fields to include or exclude from the results.
sortby (Optional[str]): Sorting options for the results.
filter_expr (Optional[str]): Structured filter expression in CQL2 JSON or CQL2-text format.
query (Optional[str]): Legacy query parameter (deprecated).
filter_lang (Optional[str]): Must be 'cql2-json' or 'cql2-text' if specified, other values will result in an error.
q (Optional[Union[str, List[str]]]): Free text search terms.
**kwargs: Keyword arguments from the request.
Expand All @@ -255,7 +257,7 @@ async def all_collections(

# Process fields parameter for filtering collection properties
includes, excludes = set(), set()
if fields and self.extension_is_enabled("FieldsExtension"):
if fields:
for field in fields:
if field[0] == "-":
excludes.add(field[1:])
Expand All @@ -282,19 +284,29 @@ async def all_collections(
if q is not None:
q_list = [q] if isinstance(q, str) else q

# Parse the query parameter if provided
parsed_query = None
if query is not None:
try:
parsed_query = orjson.loads(query)
except Exception as e:
raise HTTPException(
status_code=400, detail=f"Invalid query parameter: {e}"
)

# Parse the filter parameter if provided
parsed_filter = None
if filter_expr is not None:
try:
# Check if filter_lang is specified and not one of the supported formats
# Only raise an error for explicitly unsupported filter languages
if filter_lang is not None and filter_lang not in [
"cql2-json",
"cql2-text",
]:
# Raise an error for unsupported filter languages
raise HTTPException(
status_code=400,
detail=f"Input should be 'cql2-json' or 'cql2-text' for collections. Got '{filter_lang}'.",
detail=f"Only 'cql2-json' and 'cql2-text' filter languages are supported for collections. Got '{filter_lang}'.",
)

# Handle different filter formats
Expand Down Expand Up @@ -341,11 +353,12 @@ async def all_collections(
sort=sort,
q=q_list,
filter=parsed_filter,
query=parsed_query,
datetime=parsed_datetime,
)

# Apply field filtering if fields parameter was provided
if fields and self.extension_is_enabled("FieldsExtension"):
if fields:
filtered_collections = [
filter_fields(collection, includes, excludes)
for collection in collections
Expand Down Expand Up @@ -685,11 +698,7 @@ async def post_search(
datetime_search=datetime_search,
)

fields = (
getattr(search_request, "fields", None)
if self.extension_is_enabled("FieldsExtension")
else None
)
fields = getattr(search_request, "fields", None)
include: Set[str] = fields.include if fields and fields.include else set()
exclude: Set[str] = fields.exclude if fields and fields.exclude else set()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@
if ENABLE_COLLECTIONS_SEARCH:
# Create collection search extensions
collection_search_extensions = [
# QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]),
QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]),
SortExtension(conformance_classes=[SortConformanceClasses.COLLECTIONS]),
FieldsExtension(conformance_classes=[FieldsConformanceClasses.COLLECTIONS]),
CollectionSearchFilterExtension(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ async def get_all_collections(
sort: Optional[List[Dict[str, Any]]] = None,
q: Optional[List[str]] = None,
filter: Optional[Dict[str, Any]] = None,
query: Optional[Dict[str, Dict[str, Any]]] = None,
datetime: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
"""Retrieve a list of collections from Elasticsearch, supporting pagination.
Expand All @@ -187,6 +188,7 @@ async def get_all_collections(
request (Request): The FastAPI request object.
sort (Optional[List[Dict[str, Any]]]): Optional sort parameter from the request.
q (Optional[List[str]]): Free text search terms.
query (Optional[Dict[str, Dict[str, Any]]]): Query extension parameters.
filter (Optional[Dict[str, Any]]): Structured query in CQL2 format.
datetime (Optional[str]): Temporal filter.

Expand Down Expand Up @@ -272,6 +274,35 @@ async def get_all_collections(
es_query = filter_module.to_es(await self.get_queryables_mapping(), filter)
query_parts.append(es_query)

# Apply query extension if provided
if query:
try:
# First create a search object to apply filters
search = Search(index=COLLECTIONS_INDEX)

# Process each field and operator in the query
for field_name, expr in query.items():
for op, value in expr.items():
# For collections, we don't need to prefix with 'properties__'
field = field_name
# Apply the filter using apply_stacql_filter
search = self.apply_stacql_filter(
search=search, op=op, field=field, value=value
)

# Convert the search object to a query dict and add it to query_parts
search_dict = search.to_dict()
if "query" in search_dict:
query_parts.append(search_dict["query"])

except Exception as e:
logger = logging.getLogger(__name__)
logger.error(f"Error converting query to Elasticsearch: {e}")
# If there's an error, add a query that matches nothing
query_parts.append({"bool": {"must_not": {"match_all": {}}}})
raise

# Combine all query parts with AND logic if there are multiple
datetime_filter = None
if datetime:
datetime_filter = self._apply_collection_datetime_filter(datetime)
Expand Down Expand Up @@ -605,18 +636,31 @@ def apply_stacql_filter(search: Search, op: str, field: str, value: float):

Args:
search (Search): The search object to apply the filter to.
op (str): The comparison operator to use. Can be 'eq' (equal), 'gt' (greater than), 'gte' (greater than or equal),
'lt' (less than), or 'lte' (less than or equal).
op (str): The comparison operator to use. Can be 'eq' (equal), 'ne'/'neq' (not equal), 'gt' (greater than),
'gte' (greater than or equal), 'lt' (less than), or 'lte' (less than or equal).
field (str): The field to perform the comparison on.
value (float): The value to compare the field against.

Returns:
search (Search): The search object with the specified filter applied.
"""
if op != "eq":
if op == "eq":
search = search.filter("term", **{field: value})
elif op == "ne" or op == "neq":
# For not equal, use a bool query with must_not
search = search.exclude("term", **{field: value})
elif op in ["gt", "gte", "lt", "lte"]:
# For range operators
key_filter = {field: {op: value}}
search = search.filter(Q("range", **key_filter))
else:
elif op == "in":
# For in operator (value should be a list)
if isinstance(value, list):
search = search.filter("terms", **{field: value})
else:
search = search.filter("term", **{field: value})
elif op == "contains":
# For contains operator (for arrays)
search = search.filter("term", **{field: value})

return search
Expand Down
3 changes: 2 additions & 1 deletion stac_fastapi/opensearch/stac_fastapi/opensearch/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@
if ENABLE_COLLECTIONS_SEARCH:
# Create collection search extensions
collection_search_extensions = [
# QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]),
QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]),
SortExtension(conformance_classes=[SortConformanceClasses.COLLECTIONS]),
FieldsExtension(conformance_classes=[FieldsConformanceClasses.COLLECTIONS]),
CollectionSearchFilterExtension(
Expand Down Expand Up @@ -170,6 +170,7 @@
post_request_model=post_request_model,
landing_page_id=os.getenv("STAC_FASTAPI_LANDING_PAGE_ID", "stac-fastapi"),
),
"collections_get_request_model": collections_get_request_model,
"search_get_request_model": create_get_request_model(search_extensions),
"search_post_request_model": post_request_model,
"items_get_request_model": items_get_request_model,
Expand Down
60 changes: 51 additions & 9 deletions stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,17 +161,19 @@ async def get_all_collections(
sort: Optional[List[Dict[str, Any]]] = None,
q: Optional[List[str]] = None,
filter: Optional[Dict[str, Any]] = None,
query: Optional[Dict[str, Dict[str, Any]]] = None,
datetime: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
"""Retrieve a list of collections from Opensearch, supporting pagination.
"""Retrieve a list of collections from OpenSearch, supporting pagination.

Args:
token (Optional[str]): The pagination token.
limit (int): The number of results to return.
request (Request): The FastAPI request object.
sort (Optional[List[Dict[str, Any]]]): Optional sort parameter from the request.
q (Optional[List[str]]): Free text search terms.
filter (Optional[Dict[str, Any]]): Structured query in CQL2 format.
filter (Optional[Dict[str, Any]]): Structured filter in CQL2 format.
query (Optional[Dict[str, Dict[str, Any]]]): Query extension parameters.
datetime (Optional[str]): Temporal filter.

Returns:
Expand All @@ -195,7 +197,7 @@ async def get_all_collections(
raise HTTPException(
status_code=400,
detail=f"Field '{field}' is not sortable. Sortable fields are: {', '.join(sortable_fields)}. "
+ "Text fields are not sortable by default in Opensearch. "
+ "Text fields are not sortable by default in OpenSearch. "
+ "To make a field sortable, update the mapping to use 'keyword' type or add a '.keyword' subfield. ",
)
formatted_sort.append({field: {"order": direction}})
Expand Down Expand Up @@ -252,10 +254,37 @@ async def get_all_collections(
# Convert string filter to dict if needed
if isinstance(filter, str):
filter = orjson.loads(filter)
# Convert the filter to an Opensearch query using the filter module
# Convert the filter to an OpenSearch query using the filter module
es_query = filter_module.to_es(await self.get_queryables_mapping(), filter)
query_parts.append(es_query)

# Apply query extension if provided
if query:
try:
# First create a search object to apply filters
search = Search(index=COLLECTIONS_INDEX)

# Process each field and operator in the query
for field_name, expr in query.items():
for op, value in expr.items():
# For collections, we don't need to prefix with 'properties__'
field = field_name
# Apply the filter using apply_stacql_filter
search = self.apply_stacql_filter(
search=search, op=op, field=field, value=value
)

# Convert the search object to a query dict and add it to query_parts
search_dict = search.to_dict()
if "query" in search_dict:
query_parts.append(search_dict["query"])
except Exception as e:
logger = logging.getLogger(__name__)
logger.error(f"Error converting query to OpenSearch: {e}")
# If there's an error, add a query that matches nothing
query_parts.append({"bool": {"must_not": {"match_all": {}}}})
raise

datetime_filter = None
if datetime:
datetime_filter = self._apply_collection_datetime_filter(datetime)
Expand Down Expand Up @@ -607,18 +636,31 @@ def apply_stacql_filter(search: Search, op: str, field: str, value: float):

Args:
search (Search): The search object to apply the filter to.
op (str): The comparison operator to use. Can be 'eq' (equal), 'gt' (greater than), 'gte' (greater than or equal),
'lt' (less than), or 'lte' (less than or equal).
op (str): The comparison operator to use. Can be 'eq' (equal), 'ne'/'neq' (not equal), 'gt' (greater than),
'gte' (greater than or equal), 'lt' (less than), or 'lte' (less than or equal).
field (str): The field to perform the comparison on.
value (float): The value to compare the field against.

Returns:
search (Search): The search object with the specified filter applied.
"""
if op != "eq":
key_filter = {field: {f"{op}": value}}
if op == "eq":
search = search.filter("term", **{field: value})
elif op == "ne" or op == "neq":
# For not equal, use a bool query with must_not
search = search.exclude("term", **{field: value})
elif op in ["gt", "gte", "lt", "lte"]:
# For range operators
key_filter = {field: {op: value}}
search = search.filter(Q("range", **key_filter))
else:
elif op == "in":
# For in operator (value should be a list)
if isinstance(value, list):
search = search.filter("terms", **{field: value})
else:
search = search.filter("term", **{field: value})
elif op == "contains":
# For contains operator (for arrays)
search = search.filter("term", **{field: value})

return search
Expand Down
Loading