diff --git a/orchestrator/cli/search/index_llm.py b/orchestrator/cli/search/index_llm.py index 98d7563ca..afd1065a3 100644 --- a/orchestrator/cli/search/index_llm.py +++ b/orchestrator/cli/search/index_llm.py @@ -14,6 +14,7 @@ def subscriptions_command( subscription_id: str | None = typer.Option(None, help="UUID (default = all)"), dry_run: bool = typer.Option(False, help="No DB writes"), force_index: bool = typer.Option(False, help="Force re-index (ignore hash cache)"), + show_progress: bool = typer.Option(False, help="Show per-entity progress"), ) -> None: """Index subscription_search_index.""" run_indexing_for_entity( @@ -21,6 +22,7 @@ def subscriptions_command( entity_id=subscription_id, dry_run=dry_run, force_index=force_index, + show_progress=show_progress, ) @@ -29,6 +31,7 @@ def products_command( product_id: str | None = typer.Option(None, help="UUID (default = all)"), dry_run: bool = typer.Option(False, help="No DB writes"), force_index: bool = typer.Option(False, help="Force re-index (ignore hash cache)"), + show_progress: bool = typer.Option(False, help="Show per-entity progress"), ) -> None: """Index product_search_index.""" run_indexing_for_entity( @@ -36,6 +39,7 @@ def products_command( entity_id=product_id, dry_run=dry_run, force_index=force_index, + show_progress=show_progress, ) @@ -44,6 +48,7 @@ def processes_command( process_id: str | None = typer.Option(None, help="UUID (default = all)"), dry_run: bool = typer.Option(False, help="No DB writes"), force_index: bool = typer.Option(False, help="Force re-index (ignore hash cache)"), + show_progress: bool = typer.Option(False, help="Show per-entity progress"), ) -> None: """Index process_search_index.""" run_indexing_for_entity( @@ -51,6 +56,7 @@ def processes_command( entity_id=process_id, dry_run=dry_run, force_index=force_index, + show_progress=show_progress, ) @@ -59,6 +65,7 @@ def workflows_command( workflow_id: str | None = typer.Option(None, help="UUID (default = all)"), dry_run: bool = typer.Option(False, help="No DB writes"), force_index: bool = typer.Option(False, help="Force re-index (ignore hash cache)"), + show_progress: bool = typer.Option(False, help="Show per-entity progress"), ) -> None: """Index workflow_search_index.""" run_indexing_for_entity( @@ -66,6 +73,7 @@ def workflows_command( entity_id=workflow_id, dry_run=dry_run, force_index=force_index, + show_progress=show_progress, ) diff --git a/orchestrator/search/indexing/indexer.py b/orchestrator/search/indexing/indexer.py index d2906f767..d519677ce 100644 --- a/orchestrator/search/indexing/indexer.py +++ b/orchestrator/search/indexing/indexer.py @@ -45,6 +45,23 @@ def _maybe_begin(session: Session | None) -> Iterator[None]: yield +@contextmanager +def _maybe_progress(show_progress: bool, total_count: int | None, label: str) -> Iterator[Any]: + """Context manager that optionally creates a progress bar.""" + if show_progress: + import typer + + with typer.progressbar( + length=total_count, + label=label, + show_eta=True, + show_percent=bool(total_count), + ) as progress: + yield progress + else: + yield None + + class Indexer: """Index entities into `AiSearchIndex` using streaming reads and batched writes. @@ -89,11 +106,21 @@ class Indexer: 8) Repeat until the stream is exhausted. 
""" - def __init__(self, config: EntityConfig, dry_run: bool, force_index: bool, chunk_size: int = 1000) -> None: + def __init__( + self, + config: EntityConfig, + dry_run: bool, + force_index: bool, + chunk_size: int = 1000, + show_progress: bool = False, + total_count: int | None = None, + ) -> None: self.config = config self.dry_run = dry_run self.force_index = force_index self.chunk_size = chunk_size + self.show_progress = show_progress + self.total_count = total_count self.embedding_model = llm_settings.EMBEDDING_MODEL self.logger = logger.bind(entity_kind=config.entity_kind.value) self._entity_titles: dict[str, str] = {} @@ -116,13 +143,22 @@ def flush() -> None: with write_scope as database: session: Session | None = getattr(database, "session", None) - for entity in entities: - chunk.append(entity) - if len(chunk) >= self.chunk_size: - flush() - if chunk: - flush() + with _maybe_progress( + self.show_progress, self.total_count, f"Indexing {self.config.entity_kind.value}" + ) as progress: + for entity in entities: + chunk.append(entity) + + if len(chunk) >= self.chunk_size: + flush() + if progress: + progress.update(self.chunk_size) + + if chunk: + flush() + if progress: + progress.update(len(chunk)) final_log_message = ( f"processed {total_records_processed} records and skipped {total_identical_records} identical records." diff --git a/orchestrator/search/indexing/registry.py b/orchestrator/search/indexing/registry.py index acf10f676..c2404bb6b 100644 --- a/orchestrator/search/indexing/registry.py +++ b/orchestrator/search/indexing/registry.py @@ -66,6 +66,21 @@ def get_title_from_fields(self, fields: list[ExtractedField]) -> str: return "UNKNOWN" +@dataclass(frozen=True) +class ProcessConfig(EntityConfig[ProcessTable]): + """Processes need to eager load workflow for workflow_name field.""" + + def get_all_query(self, entity_id: str | None = None) -> Query | Select: + from sqlalchemy.orm import selectinload + + # Only load workflow, not subscriptions (keeps it lightweight) + query = self.table.query.options(selectinload(ProcessTable.workflow)) + if entity_id: + pk_column = getattr(self.table, self.pk_name) + query = query.filter(pk_column == UUID(entity_id)) + return query + + @dataclass(frozen=True) class WorkflowConfig(EntityConfig[WorkflowTable]): """Workflows have a custom select() function that filters out deleted workflows.""" @@ -95,7 +110,7 @@ def get_all_query(self, entity_id: str | None = None) -> Select: root_name="product", title_paths=["product.description", "product.name"], ), - EntityType.PROCESS: EntityConfig( + EntityType.PROCESS: ProcessConfig( entity_kind=EntityType.PROCESS, table=ProcessTable, traverser=ProcessTraverser, diff --git a/orchestrator/search/indexing/tasks.py b/orchestrator/search/indexing/tasks.py index f9eef1e5b..eb140f8a5 100644 --- a/orchestrator/search/indexing/tasks.py +++ b/orchestrator/search/indexing/tasks.py @@ -11,7 +11,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Any + import structlog +from sqlalchemy import func, select from sqlalchemy.orm import Query from orchestrator.db import db @@ -23,12 +26,20 @@ logger = structlog.get_logger(__name__) +def _get_entity_count(stmt: Any) -> int | None: + """Get total count of entities from a select statement.""" + + count_stmt = select(func.count()).select_from(stmt.subquery()) + return db.session.execute(count_stmt).scalar() + + def run_indexing_for_entity( entity_kind: EntityType, entity_id: str | None = None, dry_run: bool = False, force_index: bool = False, chunk_size: int = 1000, + show_progress: bool = False, ) -> None: """Stream and index entities for the given kind. @@ -46,6 +57,7 @@ def run_indexing_for_entity( existing hashes. chunk_size (int): Number of rows fetched per round-trip and passed to the indexer per batch. + show_progress (bool): When True, logs progress for each processed entity. Returns: None @@ -60,10 +72,19 @@ def run_indexing_for_entity( else: stmt = q + total_count = _get_entity_count(stmt) if show_progress else None + stmt = stmt.execution_options(stream_results=True, yield_per=chunk_size) entities = db.session.execute(stmt).scalars() - indexer = Indexer(config=config, dry_run=dry_run, force_index=force_index, chunk_size=chunk_size) + indexer = Indexer( + config=config, + dry_run=dry_run, + force_index=force_index, + chunk_size=chunk_size, + show_progress=show_progress, + total_count=total_count, + ) with cache_subscription_models(): indexer.run(entities) diff --git a/orchestrator/search/indexing/traverse.py b/orchestrator/search/indexing/traverse.py index cf440da83..a0dd0e5b8 100644 --- a/orchestrator/search/indexing/traverse.py +++ b/orchestrator/search/indexing/traverse.py @@ -29,7 +29,7 @@ from orchestrator.domain.lifecycle import ( lookup_specialized_type, ) -from orchestrator.schemas.process import ProcessSchema +from orchestrator.schemas.process import ProcessBaseSchema from orchestrator.schemas.workflow import WorkflowSchema from orchestrator.search.core.exceptions import ModelLoadError, ProductNotInRegistryError from orchestrator.search.core.types import LTREE_SEPARATOR, ExtractedField, FieldType @@ -307,17 +307,39 @@ def _load_model(cls, product: ProductTable) -> SubscriptionModel | None: class ProcessTraverser(BaseTraverser): - """Traverser for process entities using ProcessSchema model. + """Traverser for process entities using ProcessBaseSchema. - Note: Currently extracts only top-level process fields. Could be extended to include: - - Related subscriptions (entity.subscriptions) - - Related workflow information beyond workflow_name + Only indexes top-level process fields (no subscriptions or steps) + to keep the index size manageable. 
""" + EXCLUDED_FIELDS = {"traceback", "failed_reason"} + + @classmethod + def _load_model(cls, entity: ProcessTable) -> ProcessBaseSchema | None: + return cls._load_model_with_schema(entity, ProcessBaseSchema, "process_id") + @classmethod - def _load_model(cls, process: ProcessTable) -> ProcessSchema: - """Load process model using ProcessSchema.""" - return cls._load_model_with_schema(process, ProcessSchema, "process_id") + def get_fields(cls, entity: ProcessTable, pk_name: str, root_name: str) -> list[ExtractedField]: # type: ignore[override] + """Extract fields from process, excluding fields in EXCLUDED_FIELDS.""" + try: + model = cls._load_model(entity) + if model is None: + return [] + + return sorted( + ( + field + for field in cls.traverse(model, root_name) + if field.path.split(LTREE_SEPARATOR)[-1] not in cls.EXCLUDED_FIELDS + ), + key=lambda f: f.path, + ) + + except (ProductNotInRegistryError, ModelLoadError) as e: + entity_id = getattr(entity, pk_name, "unknown") + logger.error(f"Failed to extract fields from {entity.__class__.__name__}", id=str(entity_id), error=str(e)) + return [] class WorkflowTraverser(BaseTraverser): diff --git a/orchestrator/search/query/results.py b/orchestrator/search/query/results.py index 0d743e53a..d1165cfc6 100644 --- a/orchestrator/search/query/results.py +++ b/orchestrator/search/query/results.py @@ -139,6 +139,63 @@ def format_aggregation_response( ) +def truncate_text_with_highlights( + text: str, highlight_indices: list[tuple[int, int]] | None = None, max_length: int = 500, context_chars: int = 100 +) -> tuple[str, list[tuple[int, int]] | None]: + """Truncate text to max_length while preserving context around the first highlight. + + Args: + text: The text to truncate + highlight_indices: List of (start, end) tuples indicating highlight positions, or None + max_length: Maximum length of the returned text + context_chars: Number of characters to show before and after the first highlight + + Returns: + Tuple of (truncated_text, adjusted_highlight_indices) + """ + # If text is short enough, return as-is + if len(text) <= max_length: + return text, highlight_indices + + # If no highlights, truncate from beginning + if not highlight_indices: + truncated_text = text[:max_length] + suffix = "..." if len(text) > max_length else "" + return truncated_text + suffix, None + + # Use first highlight to determine what to show + first_highlight_start = highlight_indices[0][0] + + # Calculate start position: try to center around first highlight + start = max(0, first_highlight_start - context_chars) + end = min(len(text), start + max_length) + + # Adjust start if we hit the end boundary + if end == len(text) and (end - start) < max_length: + start = max(0, end - max_length) + + truncated_text = text[start:end] + + # Add ellipsis to indicate truncation + truncated_from_start = start > 0 + truncated_from_end = end < len(text) + + if truncated_from_start: + truncated_text = "..." + truncated_text + if truncated_from_end: + truncated_text = truncated_text + "..." + + # Adjust highlight indices to be relative to truncated text + offset = start - (3 if truncated_from_start else 0) # Account for leading "..." 
+ adjusted_indices = [] + for hl_start, hl_end in highlight_indices: + # Only include highlights that are within the truncated range + if hl_start >= start and hl_end <= end: + adjusted_indices.append((hl_start - offset, hl_end - offset)) + + return truncated_text, adjusted_indices if adjusted_indices else None + + def generate_highlight_indices(text: str, term: str) -> list[tuple[int, int]]: """Finds all occurrences of individual words from the term, including both word boundary and substring matches.""" import re @@ -201,8 +258,9 @@ def format_search_response( if not isinstance(path, str): path = str(path) - highlight_indices = generate_highlight_indices(text, user_query) or None - matching_field = MatchingField(text=text, path=path, highlight_indices=highlight_indices) + highlight_indices = generate_highlight_indices(text, user_query) + truncated_text, adjusted_indices = truncate_text_with_highlights(text, highlight_indices) + matching_field = MatchingField(text=truncated_text, path=path, highlight_indices=adjusted_indices) elif not user_query and query.filters and metadata.search_type == "structured": # Structured search (filter-only) diff --git a/orchestrator/search/retrieval/retrievers/__init__.py b/orchestrator/search/retrieval/retrievers/__init__.py index 8bedd22be..01aa78bb6 100644 --- a/orchestrator/search/retrieval/retrievers/__init__.py +++ b/orchestrator/search/retrieval/retrievers/__init__.py @@ -14,6 +14,7 @@ from .base import Retriever from .fuzzy import FuzzyRetriever from .hybrid import RrfHybridRetriever +from .process import ProcessHybridRetriever from .semantic import SemanticRetriever from .structured import StructuredRetriever @@ -21,6 +22,7 @@ "Retriever", "FuzzyRetriever", "RrfHybridRetriever", + "ProcessHybridRetriever", "SemanticRetriever", "StructuredRetriever", ] diff --git a/orchestrator/search/retrieval/retrievers/base.py b/orchestrator/search/retrieval/retrievers/base.py index 58e433503..be6df5691 100644 --- a/orchestrator/search/retrieval/retrievers/base.py +++ b/orchestrator/search/retrieval/retrievers/base.py @@ -17,7 +17,7 @@ import structlog from sqlalchemy import BindParameter, Numeric, Select, literal -from orchestrator.search.core.types import FieldType, SearchMetadata +from orchestrator.search.core.types import EntityType, FieldType, SearchMetadata from orchestrator.search.query.queries import ExportQuery, SelectQuery from ..pagination import PageCursor @@ -63,12 +63,15 @@ def route( Returns: A concrete retriever instance based on available search criteria """ + from .fuzzy import FuzzyRetriever from .hybrid import RrfHybridRetriever + from .process import ProcessHybridRetriever from .semantic import SemanticRetriever from .structured import StructuredRetriever fuzzy_term = query.fuzzy_term + is_process = query.entity_type == EntityType.PROCESS # If vector_query exists but embedding generation failed, fall back to fuzzy search with full query text if query_embedding is None and query.vector_query is not None and query.query_text is not None: @@ -76,10 +79,14 @@ def route( # Select retriever based on available search criteria if query_embedding is not None and fuzzy_term is not None: + if is_process: + return ProcessHybridRetriever(query_embedding, fuzzy_term, cursor) return RrfHybridRetriever(query_embedding, fuzzy_term, cursor) if query_embedding is not None: return SemanticRetriever(query_embedding, cursor) if fuzzy_term is not None: + if is_process: + return ProcessHybridRetriever(None, fuzzy_term, cursor) return FuzzyRetriever(fuzzy_term, 
cursor) return StructuredRetriever(cursor) diff --git a/orchestrator/search/retrieval/retrievers/process.py b/orchestrator/search/retrieval/retrievers/process.py new file mode 100644 index 000000000..699ce5b04 --- /dev/null +++ b/orchestrator/search/retrieval/retrievers/process.py @@ -0,0 +1,225 @@ +# Copyright 2019-2025 SURF, GÉANT. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +from sqlalchemy import BindParameter, Select, String, and_, case, cast, func, literal, select +from sqlalchemy.sql.expression import ColumnElement, Label +from sqlalchemy_utils import LtreeType + +from orchestrator.db.models import AiSearchIndex, ProcessStepTable +from orchestrator.search.core.types import SearchMetadata + +from .hybrid import RrfHybridRetriever, compute_rrf_hybrid_score_sql + + +class ProcessHybridRetriever(RrfHybridRetriever): + """Process-specific hybrid retriever with process.last_step JSONB search. + + Extends RrfHybridRetriever to include fuzzy search over the process.last_step + (JSONB) column. For process searches: + - Indexed fields (from AiSearchIndex): semantic + fuzzy search + - Last step JSONB field: fuzzy search only (no embeddings for dynamic data) + + The retriever: + 1. Gets field candidates from AiSearchIndex + 2. Uses process.last_step JSONB column directly for fuzzy matching + 3. 
Combines both sources (indexed + JSONB) in unified ranking + """ + + q_vec: list[float] | None # type: ignore[assignment] # Override parent's type to allow None for fuzzy-only search + + def __init__(self, q_vec: list[float] | None, *args: Any, **kwargs: Any) -> None: + # ProcessHybridRetriever accepts None for q_vec (fuzzy-only search) + # We pass empty list to parent to satisfy type requirements, but override behavior in _get_semantic_distance_expr + super().__init__(q_vec or [], *args, **kwargs) + self.q_vec = q_vec + + def _get_semantic_distance_expr(self) -> Label[Any]: + """Get semantic distance expression, handling optional q_vec.""" + if self.q_vec is None: + return literal(1.0).label("semantic_distance") + + from sqlalchemy import bindparam + + q_param: BindParameter[list[float]] = bindparam("q_vec", type_=AiSearchIndex.embedding.type) + sem_expr = case( + (AiSearchIndex.embedding.is_(None), None), + else_=AiSearchIndex.embedding.op("<->")(q_param), + ) + return func.coalesce(sem_expr, literal(1.0)).label("semantic_distance") + + def _build_indexed_candidates( + self, cand: Any, sem_val: Label[Any], best_similarity: ColumnElement[Any], filter_condition: ColumnElement[Any] + ) -> Select: + """Build candidates from indexed fields in AiSearchIndex.""" + return ( + select( + AiSearchIndex.entity_id, + AiSearchIndex.entity_title, + AiSearchIndex.path, + AiSearchIndex.value, + sem_val, + best_similarity.label("fuzzy_score"), + ) + .select_from(AiSearchIndex) + .join(cand, cand.c.entity_id == AiSearchIndex.entity_id) + .where( + and_( + AiSearchIndex.value_type.in_(self.SEARCHABLE_FIELD_TYPES), + filter_condition, + ) + ) + .limit(self.field_candidates_limit) + ) + + def _build_jsonb_candidates(self, cand: Any) -> Select: + """Build candidates from last process_step.state JSONB column.""" + # Get the last step per process using LATERAL subquery + last_step_subq = ( + select(ProcessStepTable.process_id, ProcessStepTable.state) + .where(ProcessStepTable.process_id == cand.c.entity_id) + .order_by(ProcessStepTable.completed_at.desc()) + .limit(1) + .lateral("last_step") + ) + + # Cast JSONB to text for substring search + state_text = cast(last_step_subq.c.state, String) + jsonb_fuzzy_score = func.word_similarity(self.fuzzy_term, state_text) + jsonb_filter = state_text.ilike(f"%{self.fuzzy_term}%") + + return ( + select( + cand.c.entity_id, + cand.c.entity_title, + cast(literal("process.last_step.state"), LtreeType).label("path"), + state_text.label("value"), + literal(1.0).label("semantic_distance"), + jsonb_fuzzy_score.label("fuzzy_score"), + ) + .select_from(cand) + .join(last_step_subq, literal(True)) + .where(and_(last_step_subq.c.state.isnot(None), jsonb_filter)) + .limit(self.field_candidates_limit) + ) + + def apply(self, candidate_query: Select) -> Select: + """Apply process-specific hybrid search with process.last_step JSONB. 
+ + Args: + candidate_query: Base query returning process entity_id candidates + + Returns: + Select statement with RRF scoring including last step JSONB fields + """ + cand = candidate_query.subquery() + + best_similarity = func.word_similarity(self.fuzzy_term, AiSearchIndex.value) + sem_val = self._get_semantic_distance_expr() + filter_condition = literal(self.fuzzy_term).op("<%")(AiSearchIndex.value) + + indexed_candidates = self._build_indexed_candidates(cand, sem_val, best_similarity, filter_condition) + jsonb_candidates = self._build_jsonb_candidates(cand) + + field_candidates = indexed_candidates.union_all(jsonb_candidates).cte("field_candidates") + + entity_scores = ( + select( + field_candidates.c.entity_id, + field_candidates.c.entity_title, + func.avg(field_candidates.c.semantic_distance).label("avg_semantic_distance"), + func.avg(field_candidates.c.fuzzy_score).label("avg_fuzzy_score"), + ).group_by(field_candidates.c.entity_id, field_candidates.c.entity_title) + ).cte("entity_scores") + + entity_highlights = ( + select( + field_candidates.c.entity_id, + func.first_value(field_candidates.c.value) + .over( + partition_by=field_candidates.c.entity_id, + order_by=[field_candidates.c.fuzzy_score.desc(), field_candidates.c.path.asc()], + ) + .label(self.HIGHLIGHT_TEXT_LABEL), + func.first_value(field_candidates.c.path) + .over( + partition_by=field_candidates.c.entity_id, + order_by=[field_candidates.c.fuzzy_score.desc(), field_candidates.c.path.asc()], + ) + .label(self.HIGHLIGHT_PATH_LABEL), + ).distinct(field_candidates.c.entity_id) + ).cte("entity_highlights") + + ranked = ( + select( + entity_scores.c.entity_id, + entity_scores.c.entity_title, + entity_scores.c.avg_semantic_distance, + entity_scores.c.avg_fuzzy_score, + entity_highlights.c.highlight_text, + entity_highlights.c.highlight_path, + func.dense_rank() + .over( + order_by=[entity_scores.c.avg_semantic_distance.asc().nulls_last(), entity_scores.c.entity_id.asc()] + ) + .label("sem_rank"), + func.dense_rank() + .over(order_by=[entity_scores.c.avg_fuzzy_score.desc().nulls_last(), entity_scores.c.entity_id.asc()]) + .label("fuzzy_rank"), + ).select_from( + entity_scores.join(entity_highlights, entity_scores.c.entity_id == entity_highlights.c.entity_id) + ) + ).cte("ranked_results") + + score_components = compute_rrf_hybrid_score_sql( + sem_rank_col=ranked.c.sem_rank, + fuzzy_rank_col=ranked.c.fuzzy_rank, + avg_fuzzy_score_col=ranked.c.avg_fuzzy_score, + k=self.k, + perfect_threshold=0.9, + score_numeric_type=self.SCORE_NUMERIC_TYPE, + ) + + perfect = score_components["perfect"] + normalized_score = score_components["normalized_score"] + + score = cast( + func.round(cast(normalized_score, self.SCORE_NUMERIC_TYPE), self.SCORE_PRECISION), + self.SCORE_NUMERIC_TYPE, + ).label(self.SCORE_LABEL) + + stmt = select( + ranked.c.entity_id, + ranked.c.entity_title, + score, + ranked.c.highlight_text, + ranked.c.highlight_path, + perfect.label("perfect_match"), + ).select_from(ranked) + + stmt = self._apply_fused_pagination(stmt, score, ranked.c.entity_id) + + stmt = stmt.order_by( + score.desc().nulls_last(), + ranked.c.entity_id.asc(), + ) + + if self.q_vec is not None: + stmt = stmt.params(q_vec=self.q_vec) + + return stmt + + @property + def metadata(self) -> SearchMetadata: + return SearchMetadata.hybrid() diff --git a/test/unit_tests/search/retrieval/test_utils.py b/test/unit_tests/search/retrieval/test_utils.py index d52994472..bc6f1e3cb 100644 --- a/test/unit_tests/search/retrieval/test_utils.py +++ 
b/test/unit_tests/search/retrieval/test_utils.py @@ -11,7 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from orchestrator.search.query.results import generate_highlight_indices +from orchestrator.search.query.results import generate_highlight_indices, truncate_text_with_highlights class TestGenerateHighlightIndices: @@ -61,3 +61,87 @@ def test_word_and_substring_matches_included(self): term = "cat" result = generate_highlight_indices(text, term) assert result == [(4, 7), (19, 22)] + + +class TestTruncateTextWithHighlights: + def test_text_shorter_than_max_length(self): + """Text shorter than max_length should not be truncated.""" + text = "Short text" + highlights = [(0, 5)] + result_text, result_highlights = truncate_text_with_highlights(text, highlights, max_length=100) + assert result_text == "Short text" + assert result_highlights == [(0, 5)] + + def test_no_highlights_truncates_from_start(self): + """Text with no highlights should truncate from beginning and add ellipsis.""" + text = "a" * 600 + result_text, result_highlights = truncate_text_with_highlights(text, None, max_length=500) + assert result_text == ("a" * 500) + "..." + assert result_highlights is None + + def test_truncate_with_highlight_at_start(self): + """Highlight at the start should not add leading ellipsis.""" + text = "match" + ("x" * 600) + highlights = [(0, 5)] + result_text, result_highlights = truncate_text_with_highlights(text, highlights, max_length=500) + assert result_text == ("match" + ("x" * 495)) + "..." + assert result_highlights == [(0, 5)] + + def test_truncate_with_highlight_in_middle(self): + """Highlight in middle should add ellipsis on both sides.""" + text = ("x" * 300) + "match" + ("y" * 300) + highlights = [(300, 305)] + result_text, result_highlights = truncate_text_with_highlights( + text, highlights, max_length=200, context_chars=50 + ) + # Should center around position 300 with 50 chars context + # start = max(0, 300 - 50) = 250 + # end = min(605, 250 + 200) = 450 + expected_text = "..." + ("x" * 50) + "match" + ("y" * 145) + "..." + assert result_text == expected_text + # Highlight at position 300 becomes position 53 (accounting for leading "...") + assert result_highlights == [(53, 58)] + + def test_truncate_with_highlight_at_end(self): + """Highlight near the end should not add trailing ellipsis.""" + text = ("x" * 600) + "match" + highlights = [(600, 605)] + result_text, result_highlights = truncate_text_with_highlights( + text, highlights, max_length=200, context_chars=50 + ) + # Should show context before match + # First attempt: start = max(0, 600 - 50) = 550, end = min(605, 550 + 200) = 605 + # Since end == len(text), adjust: start = max(0, 605 - 200) = 405 + expected_text = "..." + ("x" * 195) + "match" + assert result_text == expected_text + assert result_highlights == [(198, 203)] + + def test_multiple_highlights_uses_first(self): + """Multiple highlights should center around the first one, excluding highlights outside range.""" + text = ("x" * 100) + "first" + ("y" * 100) + "second" + ("z" * 100) + highlights = [(100, 105), (205, 211)] + result_text, result_highlights = truncate_text_with_highlights( + text, highlights, max_length=150, context_chars=50 + ) + # Should center around first highlight at position 100 + # start = max(0, 100 - 50) = 50 + # end = min(312, 50 + 150) = 200 + expected_text = "..." + ("x" * 50) + "first" + ("y" * 95) + "..." 
+        assert result_text == expected_text
+        # First highlight at 100 becomes 53 (with leading "...")
+        # Second highlight at 205 is outside truncated range, so not included
+        assert result_highlights == [(53, 58)]
+
+    def test_no_highlights_remain_after_truncation(self):
+        """If no highlight fits inside the truncated range, highlights should be None."""
+        text = "x" * 1000
+        # A highlight wider than the whole window can never fit in the truncated range
+        highlights = [(0, 600)]
+        result_text, result_highlights = truncate_text_with_highlights(
+            text, highlights, max_length=100, context_chars=10
+        )
+        # Window starts at the first highlight: start = max(0, 0 - 10) = 0
+        # end = min(1000, 0 + 100) = 100, so the (0, 600) highlight is dropped
+        expected_text = ("x" * 100) + "..."
+        assert result_text == expected_text
+        assert result_highlights is None
diff --git a/test/unit_tests/search/test_traverser_exceptions.py b/test/unit_tests/search/test_traverser_exceptions.py
index f14a386bc..69d1512ba 100644
--- a/test/unit_tests/search/test_traverser_exceptions.py
+++ b/test/unit_tests/search/test_traverser_exceptions.py
@@ -56,12 +56,12 @@ def test_subscription_traverser_model_load_error(self):
         SubscriptionTraverser._load_model(mock_subscription)

     def test_process_traverser_model_load_error(self):
-        """Test ProcessTraverser raises ModelLoadError when ProcessSchema validation fails."""
+        """Test ProcessTraverser raises ModelLoadError when ProcessBaseSchema validation fails."""
         # Create an invalid process that will fail validation
         mock_process = MagicMock(spec=ProcessTable)
         mock_process.process_id = "invalid-uuid"

-        with pytest.raises(ModelLoadError, match="Failed to load ProcessSchema for process_id"):
+        with pytest.raises(ModelLoadError, match="Failed to load ProcessBaseSchema for process_id"):
             ProcessTraverser._load_model(mock_process)

     def test_get_fields_handles_product_not_in_registry(self, caplog):
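
Illustrative usage (not part of the patch): a minimal sketch of how the new truncate_text_with_highlights helper behaves with its defaults (max_length=500, context_chars=100); the sample text and values below are made up for demonstration.

    from orchestrator.search.query.results import generate_highlight_indices, truncate_text_with_highlights

    text = ("x" * 300) + "fiber" + ("y" * 400)              # 705 characters, match in the middle
    highlights = generate_highlight_indices(text, "fiber")  # expected to include (300, 305)
    snippet, adjusted = truncate_text_with_highlights(text, highlights)
    # The window is centered on the first highlight (text[200:700]) with "..." added on both
    # sides, and the highlight indices are shifted relative to the snippet, e.g. (103, 108).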