8 changes: 8 additions & 0 deletions orchestrator/cli/search/index_llm.py
@@ -14,13 +14,15 @@ def subscriptions_command(
subscription_id: str | None = typer.Option(None, help="UUID (default = all)"),
dry_run: bool = typer.Option(False, help="No DB writes"),
force_index: bool = typer.Option(False, help="Force re-index (ignore hash cache)"),
show_progress: bool = typer.Option(False, help="Show per-entity progress"),
) -> None:
"""Index subscription_search_index."""
run_indexing_for_entity(
entity_kind=EntityType.SUBSCRIPTION,
entity_id=subscription_id,
dry_run=dry_run,
force_index=force_index,
show_progress=show_progress,
)


@@ -29,13 +31,15 @@ def products_command(
product_id: str | None = typer.Option(None, help="UUID (default = all)"),
dry_run: bool = typer.Option(False, help="No DB writes"),
force_index: bool = typer.Option(False, help="Force re-index (ignore hash cache)"),
show_progress: bool = typer.Option(False, help="Show per-entity progress"),
) -> None:
"""Index product_search_index."""
run_indexing_for_entity(
entity_kind=EntityType.PRODUCT,
entity_id=product_id,
dry_run=dry_run,
force_index=force_index,
show_progress=show_progress,
)


@@ -44,13 +48,15 @@ def processes_command(
process_id: str | None = typer.Option(None, help="UUID (default = all)"),
dry_run: bool = typer.Option(False, help="No DB writes"),
force_index: bool = typer.Option(False, help="Force re-index (ignore hash cache)"),
show_progress: bool = typer.Option(False, help="Show per-entity progress"),
) -> None:
"""Index process_search_index."""
run_indexing_for_entity(
entity_kind=EntityType.PROCESS,
entity_id=process_id,
dry_run=dry_run,
force_index=force_index,
show_progress=show_progress,
)


@@ -59,13 +65,15 @@ def workflows_command(
workflow_id: str | None = typer.Option(None, help="UUID (default = all)"),
dry_run: bool = typer.Option(False, help="No DB writes"),
force_index: bool = typer.Option(False, help="Force re-index (ignore hash cache)"),
show_progress: bool = typer.Option(False, help="Show per-entity progress"),
) -> None:
"""Index workflow_search_index."""
run_indexing_for_entity(
entity_kind=EntityType.WORKFLOW,
entity_id=workflow_id,
dry_run=dry_run,
force_index=force_index,
show_progress=show_progress,
)


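Taken together, the four commands simply thread the new flag through to run_indexing_for_entity. A quick way to exercise the same path without the CLI is to call the task directly; a minimal sketch, assuming EntityType is exported from orchestrator.search.core.types (an assumption, since only other names from that module appear in this diff):

from orchestrator.search.core.types import EntityType  # assumed export
from orchestrator.search.indexing.tasks import run_indexing_for_entity

run_indexing_for_entity(
    entity_kind=EntityType.SUBSCRIPTION,
    entity_id=None,      # None = index all subscriptions
    dry_run=True,        # no DB writes, safe for a trial run
    force_index=False,
    show_progress=True,  # renders the new progress bar
)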
50 changes: 43 additions & 7 deletions orchestrator/search/indexing/indexer.py
@@ -45,6 +45,23 @@ def _maybe_begin(session: Session | None) -> Iterator[None]:
yield


@contextmanager
def _maybe_progress(show_progress: bool, total_count: int | None, label: str) -> Iterator[Any]:
"""Context manager that optionally creates a progress bar."""
if show_progress:
import typer

with typer.progressbar(
length=total_count,
label=label,
show_eta=True,
show_percent=bool(total_count),
) as progress:
yield progress
else:
yield None


class Indexer:
"""Index entities into `AiSearchIndex` using streaming reads and batched writes.

@@ -89,11 +106,21 @@ class Indexer:
8) Repeat until the stream is exhausted.
"""

def __init__(
self,
config: EntityConfig,
dry_run: bool,
force_index: bool,
chunk_size: int = 1000,
show_progress: bool = False,
total_count: int | None = None,
) -> None:
self.config = config
self.dry_run = dry_run
self.force_index = force_index
self.chunk_size = chunk_size
self.show_progress = show_progress
self.total_count = total_count
self.embedding_model = llm_settings.EMBEDDING_MODEL
self.logger = logger.bind(entity_kind=config.entity_kind.value)
self._entity_titles: dict[str, str] = {}
@@ -116,13 +143,22 @@ def flush() -> None:

with write_scope as database:
session: Session | None = getattr(database, "session", None)
with _maybe_progress(
self.show_progress, self.total_count, f"Indexing {self.config.entity_kind.value}"
) as progress:
for entity in entities:
chunk.append(entity)

if len(chunk) >= self.chunk_size:
flush()
if progress:
progress.update(self.chunk_size)

if chunk:
flush()
if progress:
progress.update(len(chunk))

final_log_message = (
f"processed {total_records_processed} records and skipped {total_identical_records} identical records."
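The accounting between flush() and progress.update() is easy to get wrong, so here is a toy model of the loop above, with plain ints standing in for ORM rows: the bar advances by a full chunk_size on each intermediate flush and by the remainder at the end, so the updates always sum to the number of streamed entities.

entities = range(2_500)
chunk_size = 1_000

chunk: list[int] = []
advanced = 0
for entity in entities:
    chunk.append(entity)
    if len(chunk) >= chunk_size:
        chunk.clear()           # stands in for flush()
        advanced += chunk_size  # progress.update(self.chunk_size)
if chunk:
    advanced += len(chunk)      # progress.update(len(chunk))
    chunk.clear()

assert advanced == 2_500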
17 changes: 16 additions & 1 deletion orchestrator/search/indexing/registry.py
@@ -66,6 +66,21 @@ def get_title_from_fields(self, fields: list[ExtractedField]) -> str:
return "UNKNOWN"


@dataclass(frozen=True)
class ProcessConfig(EntityConfig[ProcessTable]):
"""Processes need to eager load workflow for workflow_name field."""

def get_all_query(self, entity_id: str | None = None) -> Query | Select:
from sqlalchemy.orm import selectinload

# Only load workflow, not subscriptions (keeps it lightweight)
query = self.table.query.options(selectinload(ProcessTable.workflow))
if entity_id:
pk_column = getattr(self.table, self.pk_name)
query = query.filter(pk_column == UUID(entity_id))
return query


@dataclass(frozen=True)
class WorkflowConfig(EntityConfig[WorkflowTable]):
"""Workflows have a custom select() function that filters out deleted workflows."""
@@ -95,7 +110,7 @@ def get_all_query(self, entity_id: str | None = None) -> Select:
root_name="product",
title_paths=["product.description", "product.name"],
),
EntityType.PROCESS: ProcessConfig(
entity_kind=EntityType.PROCESS,
table=ProcessTable,
traverser=ProcessTraverser,
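For context on why ProcessConfig overrides get_all_query: without the eager load, every access to workflow_name during traversal would lazy-load the workflow row one process at a time. A sketch of the difference, assuming the usual orchestrator-core exports (ProcessTable from orchestrator.db, with the workflow relationship used above):

from sqlalchemy.orm import selectinload

from orchestrator.db import ProcessTable  # assumed export

# With selectinload, workflow rows for each batch are fetched in one extra
# SELECT, so the attribute access below does not round-trip per process.
query = ProcessTable.query.options(selectinload(ProcessTable.workflow))
for process in query.yield_per(1000):
    print(process.workflow.name)  # no per-row lazy-load query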
23 changes: 22 additions & 1 deletion orchestrator/search/indexing/tasks.py
@@ -11,7 +11,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import structlog
from sqlalchemy import func, select
from sqlalchemy.orm import Query

from orchestrator.db import db
@@ -23,12 +26,20 @@
logger = structlog.get_logger(__name__)


def _get_entity_count(stmt: Any) -> int | None:
"""Get total count of entities from a select statement."""

count_stmt = select(func.count()).select_from(stmt.subquery())
return db.session.execute(count_stmt).scalar()


def run_indexing_for_entity(
entity_kind: EntityType,
entity_id: str | None = None,
dry_run: bool = False,
force_index: bool = False,
chunk_size: int = 1000,
show_progress: bool = False,
) -> None:
"""Stream and index entities for the given kind.

@@ -46,6 +57,7 @@ def run_indexing_for_entity(
existing hashes.
chunk_size (int): Number of rows fetched per round-trip and passed to
the indexer per batch.
show_progress (bool): When True, renders a progress bar while entities are indexed.

Returns:
None
@@ -60,10 +72,19 @@
else:
stmt = q

total_count = _get_entity_count(stmt) if show_progress else None

stmt = stmt.execution_options(stream_results=True, yield_per=chunk_size)
entities = db.session.execute(stmt).scalars()

indexer = Indexer(
config=config,
dry_run=dry_run,
force_index=force_index,
chunk_size=chunk_size,
show_progress=show_progress,
total_count=total_count,
)

with cache_subscription_models():
indexer.run(entities)
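The count helper wraps the original statement in a subquery, so whatever filters it carries (for example the entity_id filter) are reflected in the total. A self-contained sketch of the same pattern against a throwaway SQLite table:

from sqlalchemy import Column, Integer, String, create_engine, func, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Item(Base):
    __tablename__ = "item"
    id = Column(Integer, primary_key=True)
    kind = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Item(kind="a"), Item(kind="a"), Item(kind="b")])
    session.commit()

    stmt = select(Item).where(Item.kind == "a")
    # Same shape as _get_entity_count: the filter survives the wrapping.
    count_stmt = select(func.count()).select_from(stmt.subquery())
    assert session.execute(count_stmt).scalar() == 2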
38 changes: 30 additions & 8 deletions orchestrator/search/indexing/traverse.py
@@ -29,7 +29,7 @@
from orchestrator.domain.lifecycle import (
lookup_specialized_type,
)
from orchestrator.schemas.process import ProcessBaseSchema
from orchestrator.schemas.workflow import WorkflowSchema
from orchestrator.search.core.exceptions import ModelLoadError, ProductNotInRegistryError
from orchestrator.search.core.types import LTREE_SEPARATOR, ExtractedField, FieldType
@@ -307,17 +307,39 @@ def _load_model(cls, product: ProductTable) -> SubscriptionModel | None:


class ProcessTraverser(BaseTraverser):
"""Traverser for process entities using ProcessSchema model.
"""Traverser for process entities using ProcessBaseSchema.

Only indexes top-level process fields (no subscriptions or steps)
to keep the index size manageable.
"""

EXCLUDED_FIELDS = {"traceback", "failed_reason"}

@classmethod
def _load_model(cls, entity: ProcessTable) -> ProcessBaseSchema | None:
return cls._load_model_with_schema(entity, ProcessBaseSchema, "process_id")

@classmethod
def get_fields(cls, entity: ProcessTable, pk_name: str, root_name: str) -> list[ExtractedField]: # type: ignore[override]
"""Extract fields from process, excluding fields in EXCLUDED_FIELDS."""
try:
model = cls._load_model(entity)
if model is None:
return []

return sorted(
(
field
for field in cls.traverse(model, root_name)
if field.path.split(LTREE_SEPARATOR)[-1] not in cls.EXCLUDED_FIELDS
),
key=lambda f: f.path,
)

except (ProductNotInRegistryError, ModelLoadError) as e:
entity_id = getattr(entity, pk_name, "unknown")
logger.error(f"Failed to extract fields from {entity.__class__.__name__}", id=str(entity_id), error=str(e))
return []


class WorkflowTraverser(BaseTraverser):
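The exclusion filter keys on the final segment of each ltree path, so any nested field whose leaf name is traceback or failed_reason would also be dropped. A toy version with (path, value) tuples, assuming "." as the ltree separator:

EXCLUDED_FIELDS = {"traceback", "failed_reason"}
LTREE_SEPARATOR = "."  # assumed value of the real constant

fields = [
    ("process.last_status", "completed"),
    ("process.traceback", "KeyError: ..."),
    ("process.workflow_name", "create_subscription"),
]

# Keep only fields whose leaf segment is not excluded, sorted by path,
# mirroring the generator expression in get_fields above.
kept = sorted(
    (f for f in fields if f[0].split(LTREE_SEPARATOR)[-1] not in EXCLUDED_FIELDS),
    key=lambda f: f[0],
)
assert [p for p, _ in kept] == ["process.last_status", "process.workflow_name"]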
62 changes: 60 additions & 2 deletions orchestrator/search/query/results.py
@@ -139,6 +139,63 @@ def format_aggregation_response(
)


def truncate_text_with_highlights(
text: str, highlight_indices: list[tuple[int, int]] | None = None, max_length: int = 500, context_chars: int = 100
) -> tuple[str, list[tuple[int, int]] | None]:
"""Truncate text to max_length while preserving context around the first highlight.

Args:
text: The text to truncate
highlight_indices: List of (start, end) tuples indicating highlight positions, or None
max_length: Maximum length of the returned text
context_chars: Number of characters to show before and after the first highlight

Returns:
Tuple of (truncated_text, adjusted_highlight_indices)
"""
# If text is short enough, return as-is
if len(text) <= max_length:
return text, highlight_indices

# If no highlights, truncate from beginning
if not highlight_indices:
truncated_text = text[:max_length]
suffix = "..." if len(text) > max_length else ""
return truncated_text + suffix, None

# Use first highlight to determine what to show
first_highlight_start = highlight_indices[0][0]

# Calculate start position: try to center around first highlight
start = max(0, first_highlight_start - context_chars)
end = min(len(text), start + max_length)

# Adjust start if we hit the end boundary
if end == len(text) and (end - start) < max_length:
start = max(0, end - max_length)

truncated_text = text[start:end]

# Add ellipsis to indicate truncation
truncated_from_start = start > 0
truncated_from_end = end < len(text)

if truncated_from_start:
truncated_text = "..." + truncated_text
if truncated_from_end:
truncated_text = truncated_text + "..."

# Adjust highlight indices to be relative to truncated text
offset = start - (3 if truncated_from_start else 0) # Account for leading "..."
adjusted_indices = []
for hl_start, hl_end in highlight_indices:
# Only include highlights that are within the truncated range
if hl_start >= start and hl_end <= end:
adjusted_indices.append((hl_start - offset, hl_end - offset))

return truncated_text, adjusted_indices if adjusted_indices else None


def generate_highlight_indices(text: str, term: str) -> list[tuple[int, int]]:
"""Finds all occurrences of individual words from the term, including both word boundary and substring matches."""
import re
@@ -201,8 +258,9 @@ def format_search_response(
if not isinstance(path, str):
path = str(path)

highlight_indices = generate_highlight_indices(text, user_query)
truncated_text, adjusted_indices = truncate_text_with_highlights(text, highlight_indices)
matching_field = MatchingField(text=truncated_text, path=path, highlight_indices=adjusted_indices)

elif not user_query and query.filters and metadata.search_type == "structured":
# Structured search (filter-only)
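A worked example of the truncation math, using the function exactly as defined in this diff: with a highlight at (150, 155) and context_chars=100, the window starts at 50, a leading "..." is prepended, and the offset start - 3 = 47 maps the highlight to (103, 108) in the truncated text.

text = "x" * 1000
truncated, adjusted = truncate_text_with_highlights(text, [(150, 155)], max_length=500, context_chars=100)

assert truncated.startswith("...") and truncated.endswith("...")
assert len(truncated) == 3 + 500 + 3   # 500-char window plus both ellipses
assert adjusted == [(103, 108)]        # (150 - 47, 155 - 47)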
2 changes: 2 additions & 0 deletions orchestrator/search/retrieval/retrievers/__init__.py
@@ -14,13 +14,15 @@
from .base import Retriever
from .fuzzy import FuzzyRetriever
from .hybrid import RrfHybridRetriever
from .process import ProcessHybridRetriever
from .semantic import SemanticRetriever
from .structured import StructuredRetriever

__all__ = [
"Retriever",
"FuzzyRetriever",
"RrfHybridRetriever",
"ProcessHybridRetriever",
"SemanticRetriever",
"StructuredRetriever",
]
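With the re-export in place, consumers can pull the new retriever from the package root; its constructor arguments are not part of this diff, so only the import is sketched:

from orchestrator.search.retrieval.retrievers import ProcessHybridRetriever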