From cf2cb47809cadbcd76fb6539e58bf39bcc551cdf Mon Sep 17 00:00:00 2001 From: Gabor Szabo Date: Mon, 1 Jun 2026 09:47:18 +0200 Subject: [PATCH 1/3] =?UTF-8?q?feat(api,db):=20forecast=20champion=20selec?= =?UTF-8?q?tor=20slice=20B=20=E2=80=94=20async=20comparison=20&=20results?= =?UTF-8?q?=20(#360)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.example | 10 + ..._model_selection_candidate_and_progress.py | 185 +++++++ app/core/config.py | 11 + app/core/tests/test_config.py | 9 + app/features/model_selection/models.py | 87 +++- app/features/model_selection/routes.py | 75 ++- app/features/model_selection/runner.py | 312 +++++++++++ app/features/model_selection/schemas.py | 53 +- app/features/model_selection/service.py | 492 +++++++++++++++++- .../tests/test_async_routes.py | 180 +++++++ .../model_selection/tests/test_models.py | 41 +- .../tests/test_routes_integration.py | 134 +++++ .../model_selection/tests/test_runner.py | 238 +++++++++ .../model_selection/tests/test_schemas.py | 82 +++ .../model_selection/tests/test_service.py | 147 +++++- .../results/cancel-run-dialog.test.tsx | 33 ++ .../results/cancel-run-dialog.tsx | 62 +++ .../results/comparison-charts.test.tsx | 36 ++ .../results/comparison-charts.tsx | 105 ++++ .../champion-selector/results/constants.ts | 17 + .../results/model-detail-drawer.test.tsx | 43 ++ .../results/model-detail-drawer.tsx | 79 +++ .../results/ranking-table.test.tsx | 50 ++ .../results/ranking-table.tsx | 90 ++++ .../results/run-progress-panel.test.tsx | 57 ++ .../results/run-progress-panel.tsx | 87 ++++ .../results/winner-card.test.tsx | 40 ++ .../champion-selector/results/winner-card.tsx | 100 ++++ .../src/hooks/use-model-selection.test.ts | 153 +++++- frontend/src/hooks/use-model-selection.ts | 73 ++- .../src/pages/visualize/champion.test.tsx | 4 + frontend/src/pages/visualize/champion.tsx | 128 ++++- frontend/src/types/api.ts | 38 ++ 33 files changed, 3206 insertions(+), 45 deletions(-) create mode 100644 alembic/versions/d3e4f5a6b7c8_add_model_selection_candidate_and_progress.py create mode 100644 app/features/model_selection/runner.py create mode 100644 app/features/model_selection/tests/test_async_routes.py create mode 100644 app/features/model_selection/tests/test_runner.py create mode 100644 frontend/src/components/champion-selector/results/cancel-run-dialog.test.tsx create mode 100644 frontend/src/components/champion-selector/results/cancel-run-dialog.tsx create mode 100644 frontend/src/components/champion-selector/results/comparison-charts.test.tsx create mode 100644 frontend/src/components/champion-selector/results/comparison-charts.tsx create mode 100644 frontend/src/components/champion-selector/results/constants.ts create mode 100644 frontend/src/components/champion-selector/results/model-detail-drawer.test.tsx create mode 100644 frontend/src/components/champion-selector/results/model-detail-drawer.tsx create mode 100644 frontend/src/components/champion-selector/results/ranking-table.test.tsx create mode 100644 frontend/src/components/champion-selector/results/ranking-table.tsx create mode 100644 frontend/src/components/champion-selector/results/run-progress-panel.test.tsx create mode 100644 frontend/src/components/champion-selector/results/run-progress-panel.tsx create mode 100644 frontend/src/components/champion-selector/results/winner-card.test.tsx create mode 100644 frontend/src/components/champion-selector/results/winner-card.tsx diff --git a/.env.example b/.env.example index 7d49f5b9..38ef75b4 100644 --- a/.env.example +++ b/.env.example @@ -126,5 +126,15 @@ BATCH_GLOBAL_MAX_PARALLEL=4 # mid-call, so a long fit can stall the drain. BATCH_CANCEL_DRAIN_TIMEOUT_SECONDS=30 +# Model selection (champion selector) async runner (Slice B) +# Hard upper bound on concurrent candidate backtests across all active selection +# runs on this host. Effective parallelism per run is min(this, candidates). +# Set to 1 for sequential execution. Requires uvicorn restart to apply. +MODEL_SELECTION_GLOBAL_MAX_PARALLEL=4 +# Max seconds DELETE /model-selection/{id} waits for in-flight candidates to +# drain before returning RFC 7807 504. sklearn / LightGBM fits are uncancellable +# mid-call, so a long fit can stall the drain. +MODEL_SELECTION_CANCEL_DRAIN_TIMEOUT_SECONDS=30 + # Frontend (Vite) VITE_API_BASE_URL=http://localhost:8123 diff --git a/alembic/versions/d3e4f5a6b7c8_add_model_selection_candidate_and_progress.py b/alembic/versions/d3e4f5a6b7c8_add_model_selection_candidate_and_progress.py new file mode 100644 index 00000000..c510c5ef --- /dev/null +++ b/alembic/versions/d3e4f5a6b7c8_add_model_selection_candidate_and_progress.py @@ -0,0 +1,185 @@ +"""add model_selection_candidate and async progress columns + +Revision ID: d3e4f5a6b7c8 +Revises: b667d321603c +Create Date: 2026-06-01 09:30:00.000000 + +Slice B of the Forecast Champion Selector (issue #360). Converts the selection +run into a DB-backed async LRO: + +- creates ``model_selection_candidate`` (one row per candidate, FK CASCADE to + ``model_selection_run.selection_id``) carrying per-candidate status, result + JSONB, error, and timing — the live-progress + audit surface; +- adds ``started_at`` + the four final count columns to ``model_selection_run``; +- widens the run status CheckConstraint to include ``'cancelled'`` (forward-only + drop + recreate of the named constraint). + +Mirrors ``c1d2e3f40512_create_batch_tables`` for JSONB / index / FK style. +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "d3e4f5a6b7c8" +down_revision: str | None = "b667d321603c" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +_OLD_RUN_STATUS = "status IN ('pending', 'running', 'completed', 'partial', 'failed')" +_NEW_RUN_STATUS = ( + "status IN ('pending', 'running', 'completed', 'partial', 'failed', 'cancelled')" +) + + +def upgrade() -> None: + """Apply migration.""" + # ------------------------------------------------------------------ + # 1. Widen the run status CheckConstraint to include 'cancelled'. + # ------------------------------------------------------------------ + op.drop_constraint( + "ck_model_selection_run_valid_status", + "model_selection_run", + type_="check", + ) + op.create_check_constraint( + "ck_model_selection_run_valid_status", + "model_selection_run", + _NEW_RUN_STATUS, + ) + + # ------------------------------------------------------------------ + # 2. Additive progress columns on the parent run. + # ------------------------------------------------------------------ + op.add_column( + "model_selection_run", + sa.Column("started_at", sa.DateTime(timezone=True), nullable=True), + ) + op.add_column( + "model_selection_run", + sa.Column("total_candidates", sa.Integer(), nullable=False, server_default="0"), + ) + op.add_column( + "model_selection_run", + sa.Column( + "completed_candidates", sa.Integer(), nullable=False, server_default="0" + ), + ) + op.add_column( + "model_selection_run", + sa.Column("failed_candidates", sa.Integer(), nullable=False, server_default="0"), + ) + op.add_column( + "model_selection_run", + sa.Column( + "cancelled_candidates", sa.Integer(), nullable=False, server_default="0" + ), + ) + + # ------------------------------------------------------------------ + # 3. Per-candidate execution child table (FK CASCADE on selection_id). + # ------------------------------------------------------------------ + op.create_table( + "model_selection_candidate", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("candidate_id", sa.String(length=32), nullable=False), + sa.Column("selection_id", sa.String(length=32), nullable=False), + sa.Column("ordinal", sa.Integer(), nullable=False), + sa.Column("model_type", sa.String(length=40), nullable=False), + sa.Column("params", postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column("status", sa.String(length=20), nullable=False), + sa.Column("result", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("error_message", sa.String(length=2000), nullable=True), + sa.Column("error_type", sa.String(length=100), nullable=True), + sa.Column("started_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("duration_ms", sa.Integer(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.CheckConstraint( + "status IN ('pending', 'running', 'completed', 'failed', 'cancelled')", + name="ck_model_selection_candidate_valid_status", + ), + sa.ForeignKeyConstraint( + ["selection_id"], + ["model_selection_run.selection_id"], + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_model_selection_candidate_candidate_id"), + "model_selection_candidate", + ["candidate_id"], + unique=True, + ) + op.create_index( + op.f("ix_model_selection_candidate_selection_id"), + "model_selection_candidate", + ["selection_id"], + unique=False, + ) + op.create_index( + op.f("ix_model_selection_candidate_status"), + "model_selection_candidate", + ["status"], + unique=False, + ) + op.create_index( + "ix_model_selection_candidate_selection_status", + "model_selection_candidate", + ["selection_id", "status"], + unique=False, + ) + + +def downgrade() -> None: + """Revert migration.""" + op.drop_index( + "ix_model_selection_candidate_selection_status", + table_name="model_selection_candidate", + ) + op.drop_index( + op.f("ix_model_selection_candidate_status"), + table_name="model_selection_candidate", + ) + op.drop_index( + op.f("ix_model_selection_candidate_selection_id"), + table_name="model_selection_candidate", + ) + op.drop_index( + op.f("ix_model_selection_candidate_candidate_id"), + table_name="model_selection_candidate", + ) + op.drop_table("model_selection_candidate") + + op.drop_column("model_selection_run", "cancelled_candidates") + op.drop_column("model_selection_run", "failed_candidates") + op.drop_column("model_selection_run", "completed_candidates") + op.drop_column("model_selection_run", "total_candidates") + op.drop_column("model_selection_run", "started_at") + + op.drop_constraint( + "ck_model_selection_run_valid_status", + "model_selection_run", + type_="check", + ) + op.create_check_constraint( + "ck_model_selection_run_valid_status", + "model_selection_run", + _OLD_RUN_STATUS, + ) diff --git a/app/core/config.py b/app/core/config.py index 09a30cfc..e2d76a85 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -134,6 +134,17 @@ class Settings(BaseSettings): # are uncancellable mid-call, so a long fit can stall the drain. batch_cancel_drain_timeout_seconds: int = 30 + # Model selection (champion selector) async runner (Slice B) — mirrors the + # batch runner. Hard upper bound on concurrent candidate backtests across + # all active selection runs on this host; sized for the same Postgres pool + # (pool_size=5, max_overflow=10). Setting this to 1 makes the runner + # sequential. Env override: MODEL_SELECTION_GLOBAL_MAX_PARALLEL=8 (restart). + model_selection_global_max_parallel: int = 4 + # Max seconds DELETE /model-selection/{id} waits for in-flight candidates to + # settle before returning RFC 7807 504. In-flight sklearn/LightGBM fits are + # uncancellable mid-call, so a long fit can stall the drain. + model_selection_cancel_drain_timeout_seconds: int = 30 + # RAG Embedding Configuration rag_embedding_provider: Literal["openai", "ollama"] = "openai" openai_api_key: str = "" diff --git a/app/core/tests/test_config.py b/app/core/tests/test_config.py index 0dc96733..496c29bb 100644 --- a/app/core/tests/test_config.py +++ b/app/core/tests/test_config.py @@ -23,6 +23,15 @@ def test_settings_has_defaults(monkeypatch): assert settings.api_port == 8123 +def test_model_selection_runner_defaults(monkeypatch): + """Slice B async-runner settings default to the batch-mirrored values.""" + monkeypatch.delenv("MODEL_SELECTION_GLOBAL_MAX_PARALLEL", raising=False) + monkeypatch.delenv("MODEL_SELECTION_CANCEL_DRAIN_TIMEOUT_SECONDS", raising=False) + settings = Settings(_env_file=None) + assert settings.model_selection_global_max_parallel == 4 + assert settings.model_selection_cancel_drain_timeout_seconds == 30 + + def test_settings_is_development_property(): """is_development should return True for development env.""" settings = Settings(app_env="development") diff --git a/app/features/model_selection/models.py b/app/features/model_selection/models.py index ce7c6e20..a39d5763 100644 --- a/app/features/model_selection/models.py +++ b/app/features/model_selection/models.py @@ -15,7 +15,7 @@ from enum import Enum from typing import Any -from sqlalchemy import CheckConstraint, Date, DateTime, Index, Integer, String +from sqlalchemy import CheckConstraint, Date, DateTime, ForeignKey, Index, Integer, String from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column @@ -27,10 +27,12 @@ class ModelSelectionStatus(str, Enum): """Lifecycle states of a selection run. Transitions: - - PENDING -> RUNNING -> {COMPLETED, PARTIAL, FAILED} - - PARTIAL fires when >=1 candidate succeeded AND >=1 candidate failed. + - PENDING -> RUNNING -> {COMPLETED, PARTIAL, FAILED, CANCELLED} + - PARTIAL fires when >=1 candidate succeeded AND >=1 candidate failed/cancelled. - FAILED fires when availability is unusable (fail-fast) OR every candidate's backtest errored (no valid winner). + - CANCELLED (Slice B) fires when a cancel drained before any candidate + reached a non-cancelled terminal state. """ PENDING = "pending" @@ -38,6 +40,29 @@ class ModelSelectionStatus(str, Enum): COMPLETED = "completed" PARTIAL = "partial" FAILED = "failed" + CANCELLED = "cancelled" + + +# Statuses a selection run cannot transition out of — the DELETE-route 409 set +# (Slice B). Mirrors ``batch.models.TERMINAL_BATCH_STATES``. +TERMINAL_SELECTION_STATES: frozenset[str] = frozenset( + { + ModelSelectionStatus.COMPLETED.value, + ModelSelectionStatus.PARTIAL.value, + ModelSelectionStatus.FAILED.value, + ModelSelectionStatus.CANCELLED.value, + } +) + + +class CandidateStatus(str, Enum): + """Per-candidate execution states inside an async selection run (Slice B).""" + + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" class ModelSelectionRun(TimestampMixin, Base): @@ -74,13 +99,21 @@ class ModelSelectionRun(TimestampMixin, Base): forecast_result: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) business_summary: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) error_message: Mapped[str | None] = mapped_column(String(2000), nullable=True) + # Slice B (async) — set when the run starts executing; the four count + # columns cache the FINAL per-status candidate tally written once at settle + # (live progress is derived from a GROUP BY over the child rows). + started_at: Mapped[_dt.datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + total_candidates: Mapped[int] = mapped_column(Integer, default=0, server_default="0") + completed_candidates: Mapped[int] = mapped_column(Integer, default=0, server_default="0") + failed_candidates: Mapped[int] = mapped_column(Integer, default=0, server_default="0") + cancelled_candidates: Mapped[int] = mapped_column(Integer, default=0, server_default="0") completed_at: Mapped[_dt.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True ) __table_args__ = ( CheckConstraint( - "status IN ('pending', 'running', 'completed', 'partial', 'failed')", + "status IN ('pending', 'running', 'completed', 'partial', 'failed', 'cancelled')", name="ck_model_selection_run_valid_status", ), Index( @@ -91,3 +124,49 @@ class ModelSelectionRun(TimestampMixin, Base): ), Index("ix_model_selection_run_status_created", "status", "created_at"), ) + + +class ModelSelectionCandidate(TimestampMixin, Base): + """One candidate's async execution record inside a selection run (Slice B). + + Concurrent candidate tasks each write their OWN row in their OWN session — + no shared-row write race. ``result`` carries the full ``CandidateResult`` + JSONB (incl. folds) on success; failed/cancelled candidates keep their row + so they stay visible in the results UI. Mirrors ``batch.BatchJobItem``. + """ + + __tablename__ = "model_selection_candidate" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + candidate_id: Mapped[str] = mapped_column(String(32), unique=True, index=True) + selection_id: Mapped[str] = mapped_column( + String(32), + ForeignKey("model_selection_run.selection_id", ondelete="CASCADE"), + index=True, + ) + ordinal: Mapped[int] = mapped_column(Integer) # submit order — stable display + model_type: Mapped[str] = mapped_column(String(40)) + params: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False) + status: Mapped[str] = mapped_column( + String(20), default=CandidateStatus.PENDING.value, index=True + ) + result: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + error_message: Mapped[str | None] = mapped_column(String(2000), nullable=True) + error_type: Mapped[str | None] = mapped_column(String(100), nullable=True) + started_at: Mapped[_dt.datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + completed_at: Mapped[_dt.datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) + duration_ms: Mapped[int | None] = mapped_column(Integer, nullable=True) + + __table_args__ = ( + CheckConstraint( + "status IN ('pending', 'running', 'completed', 'failed', 'cancelled')", + name="ck_model_selection_candidate_valid_status", + ), + Index( + "ix_model_selection_candidate_selection_status", + "selection_id", + "status", + ), + ) diff --git a/app/features/model_selection/routes.py b/app/features/model_selection/routes.py index f4f833c7..7597464e 100644 --- a/app/features/model_selection/routes.py +++ b/app/features/model_selection/routes.py @@ -16,7 +16,7 @@ from __future__ import annotations -from fastapi import APIRouter, Depends, Query, status +from fastapi import APIRouter, Depends, Query, Response, status from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession @@ -30,6 +30,7 @@ PairAvailabilityResponse, PredictWinnerResponse, RankingResult, + SubmitRunResponse, TrainWinnerResponse, ) from app.features.model_selection.service import ModelSelectionService @@ -79,6 +80,43 @@ async def get_model_catalog() -> ModelCatalogResponse: return service.get_model_catalog() +@router.post( + "/runs", + response_model=SubmitRunResponse, + status_code=status.HTTP_202_ACCEPTED, + summary="Submit an async candidate comparison (fire-and-forget LRO)", +) +async def submit_run( + request: ModelSelectionRunRequest, + response: Response, + db: AsyncSession = Depends(get_db), +) -> SubmitRunResponse: + """Submit an async selection run — returns 202 with monitor/cancel pointers. + + The candidate backtests run in a detached task; poll + ``GET /model-selection/{selection_id}`` for live progress, terminal ranking, + and the winner. + """ + logger.info( + "model_selection.runs_request_received", + store_id=request.store_id, + product_id=request.product_id, + n_candidates=len(request.candidate_models), + ) + service = ModelSelectionService() + try: + result = await service.submit_run(db, request) + response.headers["Location"] = result.monitor_url + response.headers["Retry-After"] = "2" + return result + except ValueError as exc: + raise BadRequestError(message=str(exc)) from exc + except SQLAlchemyError as exc: + raise DatabaseError( + message="Failed to submit selection run", details={"error": str(exc)} + ) from exc + + @router.post( "/run", response_model=ModelSelectionRunResponse, @@ -128,6 +166,41 @@ async def get_selection( ) from exc +@router.delete( + "/{selection_id}", + response_model=ModelSelectionRunResponse, + status_code=status.HTTP_200_OK, + summary="Cancel an in-flight selection run (cooperative drain)", + description=( + "Cooperatively cancel an async selection run (Slice B). Pending " + "candidates skip; running candidates observe ``asyncio.CancelledError`` " + "at the next safe yield — sklearn / LightGBM fits are uncancellable " + "mid-call, so an in-flight fit may finish first. Returns:\n\n" + "- ``200`` settled run on a clean drain\n" + "- ``404`` RFC 7807 if the run does not exist\n" + "- ``409`` RFC 7807 if the run is already terminal\n" + "- ``504`` RFC 7807 if the drain exceeds " + "``Settings.model_selection_cancel_drain_timeout_seconds``" + ), +) +async def cancel_run( + selection_id: str, + db: AsyncSession = Depends(get_db), +) -> ModelSelectionRunResponse: + """Cancel an in-flight selection run and return its settled record. + + ``NotFoundError`` (404) / ``ConflictError`` (409) / ``GatewayTimeoutError`` + (504) raised in-service bubble to the global RFC 7807 handler. + """ + service = ModelSelectionService() + try: + return await service.cancel_run(db, selection_id) + except SQLAlchemyError as exc: + raise DatabaseError( + message="Failed to cancel selection run", details={"error": str(exc)} + ) from exc + + @router.get( "/{selection_id}/ranking", response_model=RankingResult, diff --git a/app/features/model_selection/runner.py b/app/features/model_selection/runner.py new file mode 100644 index 00000000..7320ea03 --- /dev/null +++ b/app/features/model_selection/runner.py @@ -0,0 +1,312 @@ +"""Bounded-concurrency candidate runner for the champion selector (Slice B). + +A slice-local mirror of ``app/features/batch/runner.py``: one +:class:`asyncio.Semaphore` inside an :class:`asyncio.TaskGroup` fans out one +task per ``model_selection_candidate``; each child opens its own +``AsyncSession`` and observes a cooperative :class:`asyncio.Event` so +``DELETE /model-selection/{selection_id}`` cancels what hasn't started and +gracefully drains what has. + +The asyncio mechanics (the three cancel mechanisms, the +``except* asyncio.CancelledError`` PEP-654 catch shape, the per-task cancel + +cooperative event) are documented in +``PRPs/ai_docs/asyncio-taskgroup-cancellation.md``. + +Cross-slice rule: this module imports from ``app.features.model_selection.models`` +(same slice) and ``app.core.*`` only — it does NOT import the batch runner +(vertical-slice rule). The per-child ``execute_candidate`` callable supplied by +``ModelSelectionService`` is the seam that keeps the heavy backtest work out of +this module. +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field +from datetime import UTC, datetime +from typing import TYPE_CHECKING + +from sqlalchemy import select, update + +from app.core.logging import get_logger +from app.features.model_selection.models import ( + CandidateStatus, + ModelSelectionCandidate, +) + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + + +logger = get_logger(__name__) + + +@dataclass +class CancelHandle: + """Cancel signal + Task refs + completion event for an in-flight selection. + + Created by :func:`run_selection_candidates`, looked up by + :func:`cancel_selection`, removed from :data:`_ACTIVE_SELECTIONS` and + signalled by the runner's caller via :func:`mark_completed` *after* the + parent's settle has committed — so ``DELETE`` never observes the parent + mid-settle. + """ + + cancel_event: asyncio.Event = field(default_factory=asyncio.Event) + completed_event: asyncio.Event = field(default_factory=asyncio.Event) + tasks: list[asyncio.Task[None]] = field(default_factory=list) + + +# Module-level registry — single-process scope (matches the single-host vision). +_ACTIVE_SELECTIONS: dict[str, CancelHandle] = {} + + +def register_selection(selection_id: str) -> CancelHandle: + """Eagerly create (or reuse) the cancel handle for a selection. + + Called by the service the moment ``POST /runs`` commits — BEFORE the + detached worker starts — so a ``DELETE`` arriving in the gap between the 202 + response and the worker's first ``run_selection_candidates`` call still + finds a handle (and is not misreported as "already settled"). The worker's + ``setdefault`` reuses this same handle. + """ + return _ACTIVE_SELECTIONS.setdefault(selection_id, CancelHandle()) + + +async def run_selection_candidates( + *, + selection_id: str, + candidate_ids: list[str], + max_parallel: int, + global_max_parallel: int, + session_maker: async_sessionmaker[AsyncSession], + execute_candidate: Callable[[str], Awaitable[None]], +) -> int: + """Execute one selection's candidates through a bounded TaskGroup. + + Args: + selection_id: ``model_selection_run.selection_id`` — registry key + log + correlator. + candidate_ids: ``model_selection_candidate.candidate_id`` values, in + submit order. + max_parallel: per-run cap (Slice B passes the global setting — there is + no per-run field). + global_max_parallel: host-wide cap from + :attr:`Settings.model_selection_global_max_parallel`. + session_maker: shared ``async_sessionmaker``; each child opens one + ``AsyncSession`` from it for the state-transition writes the runner + emits. The caller-supplied ``execute_candidate`` opens its OWN + session from the same maker. + execute_candidate: one-arg coroutine; runs one candidate's backtest + + persists its result/failure in its own session. + + Returns: + ``effective = min(max_parallel, global_max_parallel)``. + + Notes: + - Caller MUST call :func:`mark_completed` after the parent settle + commits (even on the exception path). + - Cancellation does NOT propagate out: ``except* asyncio.CancelledError`` + absorbs the ``ExceptionGroup`` so the caller can settle the parent. + """ + effective = min(max_parallel, global_max_parallel) + sem = asyncio.Semaphore(effective) + handle = _ACTIVE_SELECTIONS.setdefault(selection_id, CancelHandle()) + + logger.info( + "model_selection.runner_start", + selection_id=selection_id, + total_candidates=len(candidate_ids), + max_parallel=max_parallel, + effective_max_parallel=effective, + ) + + async def _child(candidate_id: str) -> None: + # One ``AsyncSession`` per child for the runner's own state writes. + async with session_maker() as session: + # FAST-CANCEL before the semaphore acquire — skips not-yet-started + # work cleanly (sync check; no await window). + if handle.cancel_event.is_set(): + await _mark_cancelled_skipped(session, candidate_id) + return + + acquired = False + try: + async with sem: + acquired = True + # Re-check after acquire — a sibling may have signalled + # cancel while we waited on the semaphore. + if handle.cancel_event.is_set(): + await _mark_cancelled_skipped(session, candidate_id) + return + try: + await execute_candidate(candidate_id) + except asyncio.CancelledError: + # Persist the cancelled terminal state before re-raising + # so the TaskGroup absorbs the cancel. + await _mark_cancelled_running(session, candidate_id) + raise + except Exception: + # Defensive: ``execute_candidate`` should persist its own + # failure; if it didn't, mark FAILED so settle aggregates + # correctly. Do NOT re-raise — that would tear down siblings. + logger.exception( + "model_selection.runner_unexpected_child_error", + selection_id=selection_id, + candidate_id=candidate_id, + ) + await _mark_failed_unexpected(session, candidate_id) + except asyncio.CancelledError: + if not acquired: + await _mark_cancelled_skipped(session, candidate_id) + raise + + try: + async with asyncio.TaskGroup() as tg: + for cid in candidate_ids: + task = tg.create_task(_child(cid), name=f"model_selection:{selection_id}:{cid}") + handle.tasks.append(task) + except* asyncio.CancelledError: + # Clean ``task.cancel()`` calls are absorbed here; the per-child blocks + # already wrote the terminal state. The caller settles the parent. + logger.info( + "model_selection.runner_cancelled_exception_group", + selection_id=selection_id, + ) + + logger.info( + "model_selection.runner_complete", + selection_id=selection_id, + cancel_requested=handle.cancel_event.is_set(), + ) + return effective + + +def cancel_selection(selection_id: str) -> bool: + """Signal cooperative cancel for an in-flight selection. + + Sets ``cancel_event`` (skips pending children) and ``task.cancel()`` on + every tracked child (interrupts running children at the next yield). + + Returns: + ``True`` if the selection was registered; ``False`` if no handle exists + (race: the selection settled before cancel). + """ + handle = _ACTIVE_SELECTIONS.get(selection_id) + if handle is None: + return False + handle.cancel_event.set() + cancelled_count = 0 + for task in handle.tasks: + if not task.done(): + task.cancel() + cancelled_count += 1 + logger.info( + "model_selection.cancel_requested", + selection_id=selection_id, + n_tasks_tracked=len(handle.tasks), + n_tasks_cancelled=cancelled_count, + ) + return True + + +async def await_drain(selection_id: str, timeout_seconds: float) -> bool: + """Block until the selection's parent settle commits, or timeout elapses. + + Returns: + ``True`` on clean drain (or if never registered); ``False`` on timeout. + """ + handle = _ACTIVE_SELECTIONS.get(selection_id) + if handle is None: + return True + try: + await asyncio.wait_for(handle.completed_event.wait(), timeout=timeout_seconds) + return True + except TimeoutError: + # asyncio.wait_for raises the built-in TimeoutError since Python 3.11. + logger.warning( + "model_selection.cancel_drain_timeout", + selection_id=selection_id, + timeout_seconds=timeout_seconds, + ) + return False + + +def mark_completed(selection_id: str) -> None: + """Signal that the selection's parent settle has committed. + + Must be called after ``_settle`` commits (including the failure path) so any + concurrent ``DELETE`` drain unblocks. Idempotent: a missing handle is a no-op. + """ + handle = _ACTIVE_SELECTIONS.pop(selection_id, None) + if handle is None: + return + handle.completed_event.set() + + +# --------------------------------------------------------------------- helpers +# Each helper accepts an already-open ``AsyncSession`` (one per child) and +# commits its single UPDATE. They never raise on a missing row (a deleted-parent +# race is survivable — log + move on). + + +async def _mark_cancelled_skipped(session: AsyncSession, candidate_id: str) -> None: + """Mark a not-yet-started candidate as cancelled (pending → cancelled).""" + now = datetime.now(UTC) + await session.execute( + update(ModelSelectionCandidate) + .where(ModelSelectionCandidate.candidate_id == candidate_id) + .values(status=CandidateStatus.CANCELLED.value, completed_at=now) + ) + await session.commit() + + +async def _mark_cancelled_running(session: AsyncSession, candidate_id: str) -> None: + """Mark a running candidate as cancelled (running → cancelled).""" + now = datetime.now(UTC) + row = ( + await session.execute( + select(ModelSelectionCandidate.started_at).where( + ModelSelectionCandidate.candidate_id == candidate_id + ) + ) + ).first() + started_at = row[0] if row is not None else None + duration_ms = int((now - started_at).total_seconds() * 1000) if started_at is not None else None + await session.execute( + update(ModelSelectionCandidate) + .where(ModelSelectionCandidate.candidate_id == candidate_id) + .values( + status=CandidateStatus.CANCELLED.value, + completed_at=now, + duration_ms=duration_ms, + ) + ) + await session.commit() + + +async def _mark_failed_unexpected(session: AsyncSession, candidate_id: str) -> None: + """Defensive: mark a candidate ``failed`` when ``execute_candidate`` raised.""" + now = datetime.now(UTC) + await session.execute( + update(ModelSelectionCandidate) + .where(ModelSelectionCandidate.candidate_id == candidate_id) + .values( + status=CandidateStatus.FAILED.value, + completed_at=now, + error_message="Runner caught unexpected exception (see structlog)", + error_type="UnexpectedRunnerError", + ) + ) + await session.commit() + + +__all__ = [ + "_ACTIVE_SELECTIONS", + "CancelHandle", + "await_drain", + "cancel_selection", + "mark_completed", + "run_selection_candidates", +] diff --git a/app/features/model_selection/schemas.py b/app/features/model_selection/schemas.py index d3bc45dd..050d3ead 100644 --- a/app/features/model_selection/schemas.py +++ b/app/features/model_selection/schemas.py @@ -46,7 +46,10 @@ ] RankingMetric = Literal["wape", "smape", "mae", "bias"] -SelectionStatusLiteral = Literal["pending", "running", "completed", "partial", "failed"] +SelectionStatusLiteral = Literal[ + "pending", "running", "completed", "partial", "failed", "cancelled" +] +CandidateStatusLiteral = Literal["pending", "running", "completed", "failed", "cancelled"] ConfidenceLevel = Literal["high", "medium", "low"] AvailabilityStatus = Literal["ready", "limited", "unusable"] @@ -264,8 +267,40 @@ class ForecastSummary(BaseModel): horizon: int +class CandidateProgress(BaseModel): + """One candidate's live execution state (Slice B async run). + + Output-only. Empty list on a legacy synchronous ``/run`` row (no children). + """ + + candidate_id: str + ordinal: int + model_type: str + status: CandidateStatusLiteral + error: str | None = None + started_at: datetime | None = None + completed_at: datetime | None = None + duration_ms: int | None = None + + +class SelectionProgress(BaseModel): + """Per-status candidate counts for an async selection run (Slice B).""" + + total: int + pending: int + running: int + completed: int + failed: int + cancelled: int + + class ModelSelectionRunResponse(BaseModel): - """``POST /model-selection/run`` and ``GET /model-selection/{id}`` contract.""" + """``POST /model-selection/run`` and ``GET /model-selection/{id}`` contract. + + Slice B adds ``started_at`` / ``progress`` / ``candidate_progress`` as + ADDITIVE fields with safe defaults — a legacy synchronous ``/run`` row has + ``progress=None`` and ``candidate_progress=[]``. + """ selection_id: str store_id: int @@ -285,7 +320,21 @@ class ModelSelectionRunResponse(BaseModel): business_summary: dict[str, Any] | None error_message: str | None created_at: datetime + started_at: datetime | None = None completed_at: datetime | None + progress: SelectionProgress | None = None + candidate_progress: list[CandidateProgress] = Field(default_factory=list) + + +class SubmitRunResponse(ModelSelectionRunResponse): + """``POST /model-selection/runs`` 202 response — an additive superset. + + Carries the LRO status-monitor pointers (the frontend drives the UI from + these body fields, not the ``Location``/``Retry-After`` headers). + """ + + monitor_url: str + cancel_url: str class CandidateModelInfo(BaseModel): diff --git a/app/features/model_selection/service.py b/app/features/model_selection/service.py index b8536068..743e647e 100644 --- a/app/features/model_selection/service.py +++ b/app/features/model_selection/service.py @@ -15,24 +15,41 @@ from __future__ import annotations +import asyncio import uuid +from collections.abc import Sequence from datetime import UTC, datetime from typing import TYPE_CHECKING from sqlalchemy import and_, func, or_, select -from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker -from app.core.exceptions import BadRequestError, NotFoundError +from app.core.config import get_settings +from app.core.database import get_session_maker +from app.core.exceptions import ( + BadRequestError, + ConflictError, + GatewayTimeoutError, + NotFoundError, +) from app.core.logging import get_logger from app.features.backtesting.schemas import SplitConfig from app.features.data_platform.models import Product, Promotion, SalesDaily, Store +from app.features.model_selection import runner from app.features.model_selection.capabilities import build_model_catalog from app.features.model_selection.explanations import explain_winner -from app.features.model_selection.models import ModelSelectionRun, ModelSelectionStatus +from app.features.model_selection.models import ( + TERMINAL_SELECTION_STATES, + CandidateStatus, + ModelSelectionCandidate, + ModelSelectionRun, + ModelSelectionStatus, +) from app.features.model_selection.ranking import build_chart_data, rank_candidates from app.features.model_selection.schemas import ( AvailabilityStatus, CandidateModelConfig, + CandidateProgress, CandidateResult, ChartData, FoldChart, @@ -42,7 +59,9 @@ ModelSelectionRunResponse, PairAvailabilityResponse, RankingResult, + SelectionProgress, SelectionWindow, + SubmitRunResponse, TrainWinnerResponse, WinnerSummary, ) @@ -53,6 +72,11 @@ logger = get_logger(__name__) +# Strong refs to detached background workers — asyncio holds only a WEAK ref to +# a bare ``create_task`` result, so without this set a worker can be GC'd +# mid-run (https://docs.python.org/3.12/library/asyncio-task.html#asyncio.create_task). +_BACKGROUND_TASKS: set[asyncio.Task[None]] = set() + # Availability policy constants (module-level; not operator-configurable in v1). MIN_COVERAGE_RATIO = 0.8 DEFAULT_MIN_TRAIN_SIZE = 30 @@ -402,14 +426,471 @@ async def run_selection( ) return self._response(row, ranking) + # ------------------------------------------------------------------------- + # Async orchestration (Slice B) — fire-and-forget LRO + # ------------------------------------------------------------------------- + + async def submit_run( + self, db: AsyncSession, request: ModelSelectionRunRequest + ) -> SubmitRunResponse: + """Submit an async selection run: insert parent + children, detach worker. + + Returns 202-shaped ``SubmitRunResponse`` (status=running) IMMEDIATELY — + the candidate backtests run in a detached :func:`asyncio.create_task` + that uses its OWN sessions (never this request ``db``). + """ + availability = await self.get_availability( + db, + request.store_id, + request.product_id, + request.forecast_horizon, + request.split_config, + ) + + selection_id = uuid.uuid4().hex + now = datetime.now(UTC) + row = ModelSelectionRun( + selection_id=selection_id, + status=ModelSelectionStatus.RUNNING.value, + store_id=request.store_id, + product_id=request.product_id, + start_date=request.selection_window.start_date, + end_date=request.selection_window.end_date, + forecast_horizon=request.forecast_horizon, + ranking_metric=request.ranking_metric, + candidate_models=[c.model_dump() for c in request.candidate_models], + policy_snapshot=request.ranking_policy.model_dump(mode="json"), + availability_snapshot=availability.model_dump(mode="json"), + started_at=now, + total_candidates=len(request.candidate_models), + ) + db.add(row) + # Flush the parent INSERT before the children — there is no ORM + # ``relationship`` and the FK targets the non-PK ``selection_id``, so the + # unit-of-work would not otherwise order parent-before-child. + await db.flush() + + # Fail fast on unusable availability (LOCKED #2 parity with the sync path) + # — persist a failed parent (no children, no worker) and raise 400. + if availability.status == "unusable": + message = "Insufficient data for model selection (availability unusable)." + row.status = ModelSelectionStatus.FAILED.value + row.error_message = message + row.completed_at = now + await db.commit() + logger.warning( + "model_selection.run_failed", + selection_id=selection_id, + reason="unusable_availability", + ) + raise BadRequestError(message=message) + + candidates: list[ModelSelectionCandidate] = [] + for ordinal, candidate in enumerate(request.candidate_models): + cand = ModelSelectionCandidate( + candidate_id=uuid.uuid4().hex, + selection_id=selection_id, + ordinal=ordinal, + model_type=candidate.model_type, + params=candidate.params, + status=CandidateStatus.PENDING.value, + ) + db.add(cand) + candidates.append(cand) + await db.commit() + await db.refresh(row) # populate server-default created_at for the 202 body + + logger.info( + "model_selection.run_submitted", + selection_id=selection_id, + store_id=request.store_id, + product_id=request.product_id, + n_candidates=len(candidates), + ) + + # Eagerly register the cancel handle so a DELETE arriving before the + # detached worker starts still finds it (avoids a false "already settled" + # 409). The worker's setdefault reuses this same handle. + runner.register_selection(selection_id) + + # Detach the worker — hold a strong ref so it cannot be GC'd mid-run. + task = asyncio.create_task( + self._run_in_background(selection_id, request), + name=f"model_selection_worker:{selection_id}", + ) + _BACKGROUND_TASKS.add(task) + task.add_done_callback(_BACKGROUND_TASKS.discard) + + candidate_progress = [ + CandidateProgress( + candidate_id=c.candidate_id, + ordinal=c.ordinal, + model_type=c.model_type, + status="pending", + ) + for c in candidates + ] + progress = SelectionProgress( + total=len(candidates), + pending=len(candidates), + running=0, + completed=0, + failed=0, + cancelled=0, + ) + return SubmitRunResponse( + selection_id=selection_id, + store_id=request.store_id, + product_id=request.product_id, + status="running", + selection_window=request.selection_window, + forecast_horizon=request.forecast_horizon, + ranking_metric=request.ranking_metric, + availability=availability, + ranking=[], + winner=None, + recommendation_confidence=None, + confidence_reasons=[], + chart_data=None, + final_model=None, + forecast=None, + business_summary=None, + error_message=None, + created_at=row.created_at, + started_at=now, + completed_at=None, + progress=progress, + candidate_progress=candidate_progress, + monitor_url=f"/model-selection/{selection_id}", + cancel_url=f"/model-selection/{selection_id}", + ) + + async def _run_in_background( + self, selection_id: str, request: ModelSelectionRunRequest + ) -> None: + """Detached worker — runs candidate backtests, then settles the parent. + + Uses ONLY sessions from ``get_session_maker()`` (the request session is + long gone). Never raises out — settles the parent to its observed state. + """ + session_maker = get_session_maker() + settings = get_settings() + + async def _exec(candidate_id: str) -> None: + from pydantic import TypeAdapter # lazy + + from app.features.backtesting.schemas import BacktestConfig # lazy + from app.features.backtesting.service import BacktestingService # lazy + from app.features.forecasting.schemas import ModelConfig # lazy + + async with session_maker() as session: + cand = await session.scalar( + select(ModelSelectionCandidate).where( + ModelSelectionCandidate.candidate_id == candidate_id + ) + ) + if cand is None: # deleted-parent race — survivable + return + started = datetime.now(UTC) + cand.status = CandidateStatus.RUNNING.value + cand.started_at = started + await session.commit() + logger.info( + "model_selection.candidate_started", + selection_id=selection_id, + model_type=cand.model_type, + ) + try: + adapter: TypeAdapter[object] = TypeAdapter(ModelConfig) + cfg = adapter.validate_python({"model_type": cand.model_type, **cand.params}) + backtest = await BacktestingService().run_backtest( + session, + request.store_id, + request.product_id, + request.selection_window.start_date, + request.selection_window.end_date, + BacktestConfig( + split_config=request.split_config, + model_config_main=cfg, # type: ignore[arg-type] + include_baselines=False, + store_fold_details=True, + ), + ) + result = self._shape_candidate( + CandidateModelConfig.model_validate( + {"model_type": cand.model_type, "params": cand.params} + ), + backtest, + ) + cand.result = result.model_dump(mode="json") + cand.status = CandidateStatus.COMPLETED.value + logger.info( + "model_selection.candidate_completed", + selection_id=selection_id, + model_type=cand.model_type, + ) + except Exception as exc: # never hide a failed candidate + cand.status = CandidateStatus.FAILED.value + cand.error_message = str(exc)[:2000] + cand.error_type = type(exc).__name__ + logger.warning( + "model_selection.candidate_failed", + selection_id=selection_id, + model_type=cand.model_type, + error=str(exc), + ) + finished = datetime.now(UTC) + cand.completed_at = finished + cand.duration_ms = int((finished - started).total_seconds() * 1000) + await session.commit() + + try: + candidate_ids = await self._candidate_ids(session_maker, selection_id) + await runner.run_selection_candidates( + selection_id=selection_id, + candidate_ids=candidate_ids, + max_parallel=settings.model_selection_global_max_parallel, + global_max_parallel=settings.model_selection_global_max_parallel, + session_maker=session_maker, + execute_candidate=_exec, + ) + finally: + # Always settle + unblock any DELETE drain, even if loading the + # candidate ids or the runner itself raised unexpectedly. + await self._settle(selection_id, request, session_maker) + runner.mark_completed(selection_id) + + async def _candidate_ids( + self, session_maker: async_sessionmaker[AsyncSession], selection_id: str + ) -> list[str]: + """Load this run's candidate ids in submit (ordinal) order.""" + async with session_maker() as session: + rows = ( + await session.execute( + select(ModelSelectionCandidate.candidate_id) + .where(ModelSelectionCandidate.selection_id == selection_id) + .order_by(ModelSelectionCandidate.ordinal) + ) + ).all() + return [r[0] for r in rows] + + async def _settle( + self, + selection_id: str, + request: ModelSelectionRunRequest, + session_maker: async_sessionmaker[AsyncSession], + ) -> None: + """Aggregate terminal children → ranking/chart/business + final status. + + REUSES the pure ``rank_candidates`` / ``build_chart_data`` / + ``explain_winner`` so the terminal GET output is byte-compatible with + the synchronous ``/run`` path (LOCKED #7). + """ + async with session_maker() as session: + row = await session.scalar( + select(ModelSelectionRun).where(ModelSelectionRun.selection_id == selection_id) + ) + if row is None: # deleted-parent race + return + children = ( + ( + await session.execute( + select(ModelSelectionCandidate) + .where(ModelSelectionCandidate.selection_id == selection_id) + .order_by(ModelSelectionCandidate.ordinal) + ) + ) + .scalars() + .all() + ) + + results: list[CandidateResult] = [] + for child in children: + if child.status == CandidateStatus.COMPLETED.value and child.result: + results.append(CandidateResult.model_validate(child.result)) + elif child.status == CandidateStatus.CANCELLED.value: + results.append( + CandidateResult( + model_type=child.model_type, + params=child.params, + failed=True, + error="cancelled", + aggregated_metrics=None, + sample_size=0, + folds=[], + ) + ) + else: # failed (or any non-completed leftover) + results.append( + CandidateResult( + model_type=child.model_type, + params=child.params, + failed=True, + error=child.error_message or "candidate failed", + aggregated_metrics=None, + sample_size=0, + folds=[], + ) + ) + + availability = ( + PairAvailabilityResponse.model_validate(row.availability_snapshot) + if row.availability_snapshot + else None + ) + availability_status: AvailabilityStatus = ( + availability.status if availability is not None else "ready" + ) + ranking = rank_candidates( + results, request.ranking_policy, row.ranking_metric, availability_status + ) + row.candidate_results = [r.model_dump(mode="json") for r in results] + row.ranking_result = ranking.model_dump(mode="json") + if ranking.winner is not None: + row.winner_model_type = ranking.winner.model_type + row.winner_metrics = ranking.winner.metrics + row.chart_data = build_chart_data(results, ranking).model_dump(mode="json") + if availability is not None: + row.business_summary = explain_winner(ranking, availability) + + counts = self._status_counts(children) + row.completed_candidates = counts["completed"] + row.failed_candidates = counts["failed"] + row.cancelled_candidates = counts["cancelled"] + row.status = self._terminal_status(counts).value + row.completed_at = datetime.now(UTC) + await session.commit() + logger.info( + "model_selection.run_settled", + selection_id=selection_id, + status=row.status, + winner=row.winner_model_type, + ) + + async def cancel_run(self, db: AsyncSession, selection_id: str) -> ModelSelectionRunResponse: + """Cooperatively cancel + drain an in-flight selection run.""" + row = await self._load(db, selection_id) + if row.status in TERMINAL_SELECTION_STATES: + raise ConflictError( + message=f"Selection run already terminal: {row.status}", + details={"selection_id": selection_id, "status": row.status}, + ) + logger.info("model_selection.run_cancel_requested", selection_id=selection_id) + fired = runner.cancel_selection(selection_id) + if not fired: + # Race: the worker settled between our load and the cancel. + raise ConflictError( + message="Selection run settled before cancel could fire", + details={"selection_id": selection_id}, + ) + settings = get_settings() + drained = await runner.await_drain( + selection_id, + timeout_seconds=float(settings.model_selection_cancel_drain_timeout_seconds), + ) + if not drained: + raise GatewayTimeoutError( + message=( + f"Drain exceeded {settings.model_selection_cancel_drain_timeout_seconds}s; " + "in-flight sklearn / LightGBM fits are uncancellable mid-call — " + "retry once the fit completes." + ), + details={"selection_id": selection_id}, + ) + # Re-load through a fresh read so the settled state is visible. + await db.commit() + refreshed = await self._load(db, selection_id) + logger.info( + "model_selection.run_cancel_drained", + selection_id=selection_id, + status=refreshed.status, + ) + response = self._response(refreshed, self._load_ranking(refreshed)) + await self._attach_progress(db, selection_id, response) + return response + + @staticmethod + def _status_counts(children: Sequence[ModelSelectionCandidate]) -> dict[str, int]: + """Tally child statuses into the five count buckets.""" + counts = {"pending": 0, "running": 0, "completed": 0, "failed": 0, "cancelled": 0} + for child in children: + counts[child.status] = counts.get(child.status, 0) + 1 + return counts + + @staticmethod + def _terminal_status(counts: dict[str, int]) -> ModelSelectionStatus: + """Terminal-status rule at settle (mirror ``batch.service._settle``).""" + completed = counts.get("completed", 0) + failed = counts.get("failed", 0) + cancelled = counts.get("cancelled", 0) + if cancelled > 0 and completed == 0 and failed == 0: + return ModelSelectionStatus.CANCELLED + if completed > 0 and failed == 0 and cancelled == 0: + return ModelSelectionStatus.COMPLETED + if failed > 0 and completed == 0 and cancelled == 0: + return ModelSelectionStatus.FAILED + if completed > 0 or failed > 0: + return ModelSelectionStatus.PARTIAL + return ModelSelectionStatus.FAILED + + async def _attach_progress( + self, db: AsyncSession, selection_id: str, response: ModelSelectionRunResponse + ) -> None: + """Attach live ``progress`` + ``candidate_progress`` to a response. + + A legacy synchronous ``/run`` row has no children → ``progress`` stays + ``None`` and ``candidate_progress`` stays ``[]``. + """ + children = ( + ( + await db.execute( + select(ModelSelectionCandidate) + .where(ModelSelectionCandidate.selection_id == selection_id) + .order_by(ModelSelectionCandidate.ordinal) + ) + ) + .scalars() + .all() + ) + if not children: + return + counts = self._status_counts(children) + response.progress = SelectionProgress( + total=len(children), + pending=counts["pending"], + running=counts["running"], + completed=counts["completed"], + failed=counts["failed"], + cancelled=counts["cancelled"], + ) + response.candidate_progress = [ + CandidateProgress( + candidate_id=child.candidate_id, + ordinal=child.ordinal, + model_type=child.model_type, + status=child.status, # type: ignore[arg-type] + error=child.error_message, + started_at=child.started_at, + completed_at=child.completed_at, + duration_ms=child.duration_ms, + ) + for child in children + ] + # ------------------------------------------------------------------------- # Read / re-run helpers # ------------------------------------------------------------------------- async def get_selection(self, db: AsyncSession, selection_id: str) -> ModelSelectionRunResponse: - """Return a persisted selection run by id (404 when missing).""" + """Return a persisted selection run by id (404 when missing). + + Attaches live async progress (Slice B) when the run has child rows; a + legacy synchronous ``/run`` row has none and reads as before. + """ row = await self._load(db, selection_id) - return self._response(row, self._load_ranking(row)) + response = self._response(row, self._load_ranking(row)) + await self._attach_progress(db, selection_id, response) + return response async def get_ranking(self, db: AsyncSession, selection_id: str) -> RankingResult: """Return just the ranking block for a selection run.""" @@ -578,5 +1059,6 @@ def _response( business_summary=row.business_summary, error_message=row.error_message, created_at=row.created_at, + started_at=row.started_at, completed_at=row.completed_at, ) diff --git a/app/features/model_selection/tests/test_async_routes.py b/app/features/model_selection/tests/test_async_routes.py new file mode 100644 index 00000000..6d0f3532 --- /dev/null +++ b/app/features/model_selection/tests/test_async_routes.py @@ -0,0 +1,180 @@ +"""Unit route tests for the Slice B async endpoints (service mocked). + +Mirrors ``test_routes.py``: ``get_db`` overridden with a mock session, the +service patched at the class level. Asserts the 202 shape + headers and the +DELETE 404/409 mapping over the HTTP boundary. +""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from datetime import UTC, datetime +from typing import Any +from unittest.mock import AsyncMock + +import pytest +from httpx import ASGITransport, AsyncClient + +from app.core.database import get_db +from app.core.exceptions import ConflictError, NotFoundError +from app.features.model_selection.schemas import ( + CandidateProgress, + ModelSelectionRunResponse, + SelectionProgress, + SelectionWindow, + SubmitRunResponse, +) +from app.features.model_selection.service import ModelSelectionService +from app.main import app + + +@asynccontextmanager +async def _client() -> AsyncGenerator[AsyncClient, None]: + async def override_get_db() -> AsyncGenerator[AsyncMock, None]: + yield AsyncMock() + + app.dependency_overrides[get_db] = override_get_db + try: + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: + yield ac + finally: + app.dependency_overrides.pop(get_db, None) + + +def _assert_problem_detail(body: dict[str, Any], expected_status: int) -> None: + for key in ("type", "title", "status", "detail"): + assert key in body, f"missing RFC 7807 field: {key}" + assert body["status"] == expected_status + + +def _valid_run_body(**overrides: Any) -> dict[str, Any]: + body: dict[str, Any] = { + "store_id": 5, + "product_id": 8, + "selection_window": {"start_date": "2026-01-01", "end_date": "2026-05-31"}, + "forecast_horizon": 14, + "split_config": { + "strategy": "expanding", + "n_splits": 5, + "min_train_size": 30, + "gap": 0, + "horizon": 14, + }, + "candidate_models": [ + {"model_type": "naive", "params": {}}, + {"model_type": "seasonal_naive", "params": {"season_length": 7}}, + ], + } + body.update(overrides) + return body + + +def _running_submit_response(selection_id: str = "sel_async") -> SubmitRunResponse: + return SubmitRunResponse( + selection_id=selection_id, + store_id=5, + product_id=8, + status="running", + selection_window=SelectionWindow(start_date="2026-01-01", end_date="2026-05-31"), # type: ignore[arg-type] + forecast_horizon=14, + ranking_metric="wape", + availability=None, + ranking=[], + winner=None, + recommendation_confidence=None, + confidence_reasons=[], + chart_data=None, + final_model=None, + forecast=None, + business_summary=None, + error_message=None, + created_at=datetime.now(UTC), + started_at=datetime.now(UTC), + completed_at=None, + progress=SelectionProgress( + total=2, pending=2, running=0, completed=0, failed=0, cancelled=0 + ), + candidate_progress=[ + CandidateProgress(candidate_id="c0", ordinal=0, model_type="naive", status="pending"), + CandidateProgress( + candidate_id="c1", ordinal=1, model_type="seasonal_naive", status="pending" + ), + ], + monitor_url=f"/model-selection/{selection_id}", + cancel_url=f"/model-selection/{selection_id}", + ) + + +async def test_submit_runs_returns_202_with_headers_and_running_body( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr( + ModelSelectionService, + "submit_run", + AsyncMock(return_value=_running_submit_response()), + ) + async with _client() as ac: + response = await ac.post("/model-selection/runs", json=_valid_run_body()) + assert response.status_code == 202 + body = response.json() + assert body["status"] == "running" + assert body["monitor_url"] == "/model-selection/sel_async" + assert body["cancel_url"] == "/model-selection/sel_async" + assert body["progress"]["pending"] == 2 + assert len(body["candidate_progress"]) == 2 + # LRO status-monitor headers. + assert response.headers.get("location") == "/model-selection/sel_async" + assert response.headers.get("retry-after") == "2" + + +async def test_submit_runs_validation_error_returns_problem_json() -> None: + """A horizon mismatch is rejected by the request validator (422).""" + bad = _valid_run_body(forecast_horizon=14) + bad["split_config"] = { + "strategy": "expanding", + "n_splits": 5, + "min_train_size": 30, + "gap": 0, + "horizon": 7, + } + async with _client() as ac: + response = await ac.post("/model-selection/runs", json=bad) + assert response.status_code == 422 + _assert_problem_detail(response.json(), 422) + + +async def test_delete_run_404_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + ModelSelectionService, + "cancel_run", + AsyncMock(side_effect=NotFoundError(message="Selection run missing not found")), + ) + async with _client() as ac: + response = await ac.delete("/model-selection/missing") + assert response.status_code == 404 + _assert_problem_detail(response.json(), 404) + + +async def test_delete_run_409_when_terminal(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + ModelSelectionService, + "cancel_run", + AsyncMock(side_effect=ConflictError(message="Selection run already terminal: completed")), + ) + async with _client() as ac: + response = await ac.delete("/model-selection/sel_done") + assert response.status_code == 409 + _assert_problem_detail(response.json(), 409) + + +async def test_delete_run_returns_settled_200(monkeypatch: pytest.MonkeyPatch) -> None: + settled = _running_submit_response("sel_cancel") + settled_resp = ModelSelectionRunResponse.model_validate( + {**settled.model_dump(), "status": "cancelled"} + ) + monkeypatch.setattr(ModelSelectionService, "cancel_run", AsyncMock(return_value=settled_resp)) + async with _client() as ac: + response = await ac.delete("/model-selection/sel_cancel") + assert response.status_code == 200 + assert response.json()["status"] == "cancelled" diff --git a/app/features/model_selection/tests/test_models.py b/app/features/model_selection/tests/test_models.py index 4f69d9e9..7264aec9 100644 --- a/app/features/model_selection/tests/test_models.py +++ b/app/features/model_selection/tests/test_models.py @@ -9,7 +9,13 @@ from datetime import date -from app.features.model_selection.models import ModelSelectionRun, ModelSelectionStatus +from app.features.model_selection.models import ( + TERMINAL_SELECTION_STATES, + CandidateStatus, + ModelSelectionCandidate, + ModelSelectionRun, + ModelSelectionStatus, +) def test_status_enum_values() -> None: @@ -19,9 +25,26 @@ def test_status_enum_values() -> None: "completed", "partial", "failed", + "cancelled", } +def test_candidate_status_enum_values() -> None: + assert {s.value for s in CandidateStatus} == { + "pending", + "running", + "completed", + "failed", + "cancelled", + } + + +def test_terminal_selection_states() -> None: + assert TERMINAL_SELECTION_STATES == {"completed", "partial", "failed", "cancelled"} + assert "running" not in TERMINAL_SELECTION_STATES + assert "pending" not in TERMINAL_SELECTION_STATES + + def test_model_selection_run_construction_defaults() -> None: row = ModelSelectionRun( selection_id="abc123", @@ -39,3 +62,19 @@ def test_model_selection_run_construction_defaults() -> None: assert row.status == "running" assert row.winner_model_type is None assert row.final_model_path is None + + +def test_model_selection_candidate_construction() -> None: + cand = ModelSelectionCandidate( + candidate_id="cand1", + selection_id="abc123", + ordinal=0, + model_type="naive", + params={}, + status=CandidateStatus.PENDING.value, + ) + assert cand.candidate_id == "cand1" + assert cand.selection_id == "abc123" + assert cand.status == "pending" + assert cand.result is None + assert cand.error_message is None diff --git a/app/features/model_selection/tests/test_routes_integration.py b/app/features/model_selection/tests/test_routes_integration.py index a6440f71..b74b98c2 100644 --- a/app/features/model_selection/tests/test_routes_integration.py +++ b/app/features/model_selection/tests/test_routes_integration.py @@ -6,6 +6,7 @@ from __future__ import annotations +import asyncio from typing import Any import pytest @@ -15,6 +16,23 @@ pytestmark = pytest.mark.integration +_TERMINAL = {"completed", "partial", "failed", "cancelled"} + + +async def _poll_until_terminal( + client: AsyncClient, selection_id: str, *, attempts: int = 60, delay: float = 0.5 +) -> dict[str, Any]: + """Poll GET /{id} until the run reaches a terminal status (or attempts run out).""" + body: dict[str, Any] = {} + for _ in range(attempts): + response = await client.get(f"/model-selection/{selection_id}") + assert response.status_code == 200 + body = response.json() + if body["status"] in _TERMINAL: + return body + await asyncio.sleep(delay) + raise AssertionError(f"run {selection_id} did not settle: last status {body.get('status')}") + def _run_body( pair: dict[str, Any], extra_candidates: list[dict[str, Any]] | None = None @@ -136,3 +154,119 @@ async def test_get_missing_selection_returns_404(client: AsyncClient) -> None: response = await client.get("/model-selection/does-not-exist") assert response.status_code == 404 assert response.json()["status"] == 404 + + +# --------------------------------------------------------------------- Slice B + + +async def test_async_runs_submits_202_and_polls_to_terminal_with_winner( + client: AsyncClient, ready_pair: dict[str, Any] +) -> None: + """POST /runs returns 202 running immediately; polling settles with a winner.""" + submit = await client.post("/model-selection/runs", json=_run_body(ready_pair)) + assert submit.status_code == 202 + body = submit.json() + assert body["status"] == "running" + selection_id = body["selection_id"] + assert body["monitor_url"] == f"/model-selection/{selection_id}" + assert body["cancel_url"] == f"/model-selection/{selection_id}" + assert body["progress"]["total"] == 3 + assert submit.headers.get("location") == f"/model-selection/{selection_id}" + assert submit.headers.get("retry-after") == "2" + + terminal = await _poll_until_terminal(client, selection_id) + assert terminal["status"] in {"completed", "partial"} + assert terminal["winner"] is not None + assert terminal["chart_data"] is not None + assert terminal["ranking"] + assert terminal["progress"]["total"] == 3 + # Terminal GET output is byte-compatible with the sync /run shape. + assert terminal["recommendation_confidence"] in {"high", "medium", "low"} + + +async def test_async_runs_failed_candidate_stays_visible( + client: AsyncClient, ready_pair: dict[str, Any] +) -> None: + """An invalid candidate surfaces as a failed/excluded entry, not a 500.""" + body = _run_body( + ready_pair, + extra_candidates=[{"model_type": "moving_average", "params": {"window_size": 0}}], + ) + submit = await client.post("/model-selection/runs", json=body) + assert submit.status_code == 202 + selection_id = submit.json()["selection_id"] + + terminal = await _poll_until_terminal(client, selection_id) + assert terminal["status"] == "partial" + excluded = [e for e in terminal["ranking"] if not e["included"]] + assert excluded + assert terminal["winner"] is not None + # The failed candidate is visible in candidate_progress too. + failed = [c for c in terminal["candidate_progress"] if c["status"] == "failed"] + assert failed + + +async def test_cancel_leaves_no_candidate_running( + client: AsyncClient, ready_pair: dict[str, Any], db_session: AsyncSession +) -> None: + """DELETE cooperatively cancels + drains — no candidate left 'running'.""" + submit = await client.post("/model-selection/runs", json=_run_body(ready_pair)) + assert submit.status_code == 202 + selection_id = submit.json()["selection_id"] + + # Cancel almost immediately. Fast baseline fits are uncancellable mid-call + # and may settle the whole run before the DELETE arrives — an HONEST race: + # 200 = the cancel fired and drained; + # 409 = the run had already settled (so nothing was left to cancel). + # Either way the LOAD-BEARING invariant below must hold. + cancel = await client.delete(f"/model-selection/{selection_id}") + assert cancel.status_code in {200, 409} + + # Ensure the run is terminal before asserting the invariant (covers the 200 + # path where the worker just settled, and the 409 already-settled path). + await _poll_until_terminal(client, selection_id) + + # The load-bearing invariant: after the drain, no candidate row is 'running'. + rows = await db_session.execute( + text( + "SELECT count(*) FROM model_selection_candidate " + "WHERE selection_id = :sid AND status = 'running'" + ), + {"sid": selection_id}, + ) + assert rows.scalar() == 0 + + +async def test_cancel_terminal_run_returns_409( + client: AsyncClient, ready_pair: dict[str, Any] +) -> None: + """Cancelling an already-settled run returns 409.""" + submit = await client.post("/model-selection/runs", json=_run_body(ready_pair)) + selection_id = submit.json()["selection_id"] + await _poll_until_terminal(client, selection_id) + + cancel = await client.delete(f"/model-selection/{selection_id}") + assert cancel.status_code == 409 + assert cancel.json()["status"] == 409 + + +async def test_candidate_table_has_named_indexes(db_session: AsyncSession) -> None: + rows = await db_session.execute( + text("SELECT indexname FROM pg_indexes WHERE tablename = 'model_selection_candidate'") + ) + names = {row[0] for row in rows} + assert "ix_model_selection_candidate_candidate_id" in names + assert "ix_model_selection_candidate_selection_status" in names + + +async def test_legacy_sync_run_has_no_progress_children( + client: AsyncClient, ready_pair: dict[str, Any] +) -> None: + """A legacy synchronous /run row carries no async progress.""" + run = await client.post("/model-selection/run", json=_run_body(ready_pair)) + assert run.status_code == 200 + selection_id = run.json()["selection_id"] + fetched = await client.get(f"/model-selection/{selection_id}") + body = fetched.json() + assert body["progress"] is None + assert body["candidate_progress"] == [] diff --git a/app/features/model_selection/tests/test_runner.py b/app/features/model_selection/tests/test_runner.py new file mode 100644 index 00000000..9421d303 --- /dev/null +++ b/app/features/model_selection/tests/test_runner.py @@ -0,0 +1,238 @@ +"""Unit tests for the Slice B bounded-concurrency candidate runner. + +The runner's DB helpers are monkeypatched to awaitable no-ops so the asyncio +orchestration is exercised without docker-compose. The DB invariants (no +candidate left ``running`` after a cancel drain) are covered in the integration +suite. Mirrors ``app/features/batch/tests/test_runner.py``. +""" + +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager +from typing import Any, cast +from unittest.mock import AsyncMock + +import pytest + +from app.features.model_selection import runner + + +@pytest.fixture(autouse=True) +def _clear_registry() -> Any: + runner._ACTIVE_SELECTIONS.clear() + yield + runner._ACTIVE_SELECTIONS.clear() + + +@pytest.fixture +def patch_db_helpers(monkeypatch: pytest.MonkeyPatch) -> dict[str, list[Any]]: + """Replace runner DB helpers with awaitable no-ops + a call tracker.""" + calls: dict[str, list[Any]] = { + "mark_cancelled_skipped": [], + "mark_cancelled_running": [], + "mark_failed_unexpected": [], + } + + async def _mark_cancelled_skipped(_session: Any, candidate_id: str) -> None: + calls["mark_cancelled_skipped"].append(candidate_id) + + async def _mark_cancelled_running(_session: Any, candidate_id: str) -> None: + calls["mark_cancelled_running"].append(candidate_id) + + async def _mark_failed_unexpected(_session: Any, candidate_id: str) -> None: + calls["mark_failed_unexpected"].append(candidate_id) + + monkeypatch.setattr(runner, "_mark_cancelled_skipped", _mark_cancelled_skipped) + monkeypatch.setattr(runner, "_mark_cancelled_running", _mark_cancelled_running) + monkeypatch.setattr(runner, "_mark_failed_unexpected", _mark_failed_unexpected) + return calls + + +def _fake_session_maker() -> Any: + @asynccontextmanager + async def _ctx() -> Any: + yield AsyncMock() + + def _maker() -> Any: + return _ctx() + + return cast(Any, _maker) + + +# ---------------------------------------------------------------- semaphore + + +async def test_runner_semaphore_caps_concurrency( + patch_db_helpers: dict[str, list[Any]], +) -> None: + """5 candidates with max_parallel=2 — observed concurrent peak == 2.""" + in_flight = 0 + peak = 0 + + async def child(_cid: str) -> None: + nonlocal in_flight, peak + in_flight += 1 + peak = max(peak, in_flight) + try: + await asyncio.sleep(0.02) + finally: + in_flight -= 1 + + effective = await runner.run_selection_candidates( + selection_id="s_sem", + candidate_ids=[f"c{i}" for i in range(5)], + max_parallel=2, + global_max_parallel=10, + session_maker=_fake_session_maker(), + execute_candidate=child, + ) + runner.mark_completed("s_sem") + assert effective == 2 + assert peak == 2, f"observed peak {peak}, expected exactly 2" + + +async def test_runner_global_cap_clamps_max_parallel( + patch_db_helpers: dict[str, list[Any]], +) -> None: + """max_parallel=32 clamped by global_max_parallel=1 → sequential (peak 1).""" + in_flight = 0 + peak = 0 + + async def child(_cid: str) -> None: + nonlocal in_flight, peak + in_flight += 1 + peak = max(peak, in_flight) + try: + await asyncio.sleep(0.01) + finally: + in_flight -= 1 + + effective = await runner.run_selection_candidates( + selection_id="s_seq", + candidate_ids=[f"c{i}" for i in range(4)], + max_parallel=32, + global_max_parallel=1, + session_maker=_fake_session_maker(), + execute_candidate=child, + ) + runner.mark_completed("s_seq") + assert effective == 1 + assert peak == 1, f"global cap of 1 must serialize; observed peak {peak}" + + +# ---------------------------------------------------- per-child failure isolation + + +async def test_runner_child_failure_does_not_abort_siblings( + patch_db_helpers: dict[str, list[Any]], +) -> None: + completed: list[str] = [] + + async def child(cid: str) -> None: + if cid == "c2": + raise RuntimeError("synthetic failure") + await asyncio.sleep(0.01) + completed.append(cid) + + await runner.run_selection_candidates( + selection_id="s_fail", + candidate_ids=[f"c{i}" for i in range(5)], + max_parallel=5, + global_max_parallel=10, + session_maker=_fake_session_maker(), + execute_candidate=child, + ) + runner.mark_completed("s_fail") + assert sorted(completed) == ["c0", "c1", "c3", "c4"] + assert patch_db_helpers["mark_failed_unexpected"] == ["c2"] + + +# --------------------------------------------------------------- cancel paths + + +async def test_runner_cancel_before_start_skips( + patch_db_helpers: dict[str, list[Any]], +) -> None: + """max_parallel=1, 3 candidates. Cancel after c0 starts → c1/c2 skip.""" + started: list[str] = [] + + async def child(cid: str) -> None: + started.append(cid) + await asyncio.sleep(0.5) + + task = asyncio.create_task( + runner.run_selection_candidates( + selection_id="s_pending", + candidate_ids=["c0", "c1", "c2"], + max_parallel=1, + global_max_parallel=10, + session_maker=_fake_session_maker(), + execute_candidate=child, + ) + ) + await asyncio.sleep(0.05) + fired = runner.cancel_selection("s_pending") + await task + runner.mark_completed("s_pending") + + assert fired is True + assert patch_db_helpers["mark_cancelled_running"] == ["c0"] + assert set(patch_db_helpers["mark_cancelled_skipped"]) == {"c1", "c2"} + assert started == ["c0"] + + +async def test_runner_cancel_mid_flight_marks_cancelled( + patch_db_helpers: dict[str, list[Any]], +) -> None: + cancelled_in_child: list[str] = [] + + async def child(cid: str) -> None: + try: + await asyncio.sleep(1.0) + except asyncio.CancelledError: + cancelled_in_child.append(cid) + raise + + task = asyncio.create_task( + runner.run_selection_candidates( + selection_id="s_running", + candidate_ids=["c0"], + max_parallel=1, + global_max_parallel=10, + session_maker=_fake_session_maker(), + execute_candidate=child, + ) + ) + await asyncio.sleep(0.05) + runner.cancel_selection("s_running") + await task + runner.mark_completed("s_running") + assert cancelled_in_child == ["c0"] + assert patch_db_helpers["mark_cancelled_running"] == ["c0"] + + +# ------------------------------------------------------------- registry hygiene + + +async def test_mark_completed_unblocks_await_drain() -> None: + runner._ACTIVE_SELECTIONS["sx"] = runner.CancelHandle() + drain_task = asyncio.create_task(runner.await_drain("sx", timeout_seconds=1.0)) + await asyncio.sleep(0.01) + runner.mark_completed("sx") + drained = await drain_task + assert drained is True + assert "sx" not in runner._ACTIVE_SELECTIONS + + +async def test_cancel_selection_returns_false_when_unregistered() -> None: + assert runner.cancel_selection("nope") is False + + +async def test_await_drain_returns_true_when_unregistered() -> None: + assert await runner.await_drain("nope", timeout_seconds=0.0) is True + + +async def test_await_drain_times_out_on_stuck_handle() -> None: + runner._ACTIVE_SELECTIONS["s_stuck"] = runner.CancelHandle() + assert await runner.await_drain("s_stuck", timeout_seconds=0.05) is False diff --git a/app/features/model_selection/tests/test_schemas.py b/app/features/model_selection/tests/test_schemas.py index 3d34c510..87fb093d 100644 --- a/app/features/model_selection/tests/test_schemas.py +++ b/app/features/model_selection/tests/test_schemas.py @@ -2,12 +2,18 @@ from __future__ import annotations +from datetime import UTC, datetime + import pytest from pydantic import ValidationError from app.features.model_selection.schemas import ( + CandidateProgress, ModelSelectionRunRequest, + ModelSelectionRunResponse, + SelectionProgress, SelectionWindow, + SubmitRunResponse, ) @@ -79,3 +85,79 @@ def test_candidate_models_min_length_enforced() -> None: """At least one candidate is required.""" with pytest.raises(ValidationError): ModelSelectionRunRequest.model_validate(_base_request_dict(candidate_models=[])) + + +# --------------------------------------------------------------------- Slice B + + +def _base_response_dict(**overrides: object) -> dict[str, object]: + payload: dict[str, object] = { + "selection_id": "sel1", + "store_id": 1, + "product_id": 2, + "status": "running", + "selection_window": {"start_date": "2026-01-01", "end_date": "2026-05-31"}, + "forecast_horizon": 14, + "ranking_metric": "wape", + "availability": None, + "ranking": [], + "winner": None, + "recommendation_confidence": None, + "confidence_reasons": [], + "chart_data": None, + "final_model": None, + "forecast": None, + "business_summary": None, + "error_message": None, + "created_at": datetime(2026, 6, 1, 12, 0, 0, tzinfo=UTC), + "completed_at": None, + } + payload.update(overrides) + return payload + + +def test_response_progress_fields_default_safely() -> None: + """Legacy sync-run rows validate without progress fields (additive defaults).""" + resp = ModelSelectionRunResponse.model_validate(_base_response_dict()) + assert resp.started_at is None + assert resp.progress is None + assert resp.candidate_progress == [] + + +def test_status_literal_accepts_cancelled() -> None: + """The 'cancelled' status (Slice B) is accepted by the response literal.""" + resp = ModelSelectionRunResponse.model_validate(_base_response_dict(status="cancelled")) + assert resp.status == "cancelled" + + +def test_selection_and_candidate_progress_models() -> None: + progress = SelectionProgress(total=5, pending=3, running=1, completed=1, failed=0, cancelled=0) + assert progress.total == 5 + cand = CandidateProgress(candidate_id="c1", ordinal=0, model_type="naive", status="running") + assert cand.status == "running" + assert cand.error is None + + +def test_submit_run_response_carries_monitor_and_cancel_urls() -> None: + submit = SubmitRunResponse.model_validate( + _base_response_dict( + monitor_url="/model-selection/sel1", + cancel_url="/model-selection/sel1", + progress={ + "total": 1, + "pending": 1, + "running": 0, + "completed": 0, + "failed": 0, + "cancelled": 0, + }, + candidate_progress=[ + {"candidate_id": "c1", "ordinal": 0, "model_type": "naive", "status": "pending"} + ], + ) + ) + assert submit.monitor_url == "/model-selection/sel1" + assert submit.cancel_url == "/model-selection/sel1" + assert submit.progress is not None + assert submit.progress.pending == 1 + assert submit.candidate_progress[0].model_type == "naive" diff --git a/app/features/model_selection/tests/test_service.py b/app/features/model_selection/tests/test_service.py index 7d3da5f1..2fd3002e 100644 --- a/app/features/model_selection/tests/test_service.py +++ b/app/features/model_selection/tests/test_service.py @@ -5,7 +5,7 @@ from datetime import date, timedelta from types import SimpleNamespace from typing import Any -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, MagicMock from uuid import uuid4 import pytest @@ -218,5 +218,150 @@ async def test_response_uses_recommendation_confidence_key( async def test_get_selection_missing_raises_not_found() -> None: db = AsyncMock() db.scalar = AsyncMock(return_value=None) + db.execute = AsyncMock() with pytest.raises(NotFoundError): await ModelSelectionService().get_selection(db, uuid4().hex) + + +# ----------------------------------------------------------------------------- +# Slice B — async submit / settle / cancel (worker mocked or DB-free units) +# ----------------------------------------------------------------------------- + +from datetime import UTC, datetime # noqa: E402 + +from app.core.exceptions import ConflictError # noqa: E402 +from app.features.model_selection import runner as _runner # noqa: E402 +from app.features.model_selection.models import ( # noqa: E402 + ModelSelectionCandidate, + ModelSelectionRun, + ModelSelectionStatus, +) + + +def _submit_mock_db() -> AsyncMock: + """Mock ``AsyncSession`` whose ``refresh`` stamps ``created_at`` on the run.""" + db = AsyncMock() + added: list[Any] = [] + + def _add(obj: Any) -> None: + added.append(obj) + + async def _refresh(obj: Any) -> None: + if isinstance(obj, ModelSelectionRun) and obj.created_at is None: + obj.created_at = datetime.now(UTC) + + db.add = MagicMock(side_effect=_add) + db.commit = AsyncMock() + db.refresh = AsyncMock(side_effect=_refresh) + db._added = added # expose for assertions + return db + + +async def test_submit_run_inserts_running_parent_and_pending_candidates( + monkeypatch: pytest.MonkeyPatch, +) -> None: + _patch_availability(monkeypatch, "ready") + # Stub the detached worker so create_task schedules a harmless no-op. + monkeypatch.setattr(ModelSelectionService, "_run_in_background", AsyncMock()) + + request = _request( + candidate_models=[ + {"model_type": "naive", "params": {}}, + {"model_type": "seasonal_naive", "params": {"season_length": 7}}, + ] + ) + db = _submit_mock_db() + response = await ModelSelectionService().submit_run(db, request) + + assert response.status == "running" + assert response.monitor_url == f"/model-selection/{response.selection_id}" + assert response.cancel_url == f"/model-selection/{response.selection_id}" + assert response.progress is not None + assert response.progress.total == 2 + assert response.progress.pending == 2 + assert len(response.candidate_progress) == 2 + assert {c.status for c in response.candidate_progress} == {"pending"} + + parents = [o for o in db._added if isinstance(o, ModelSelectionRun)] + children = [o for o in db._added if isinstance(o, ModelSelectionCandidate)] + assert len(parents) == 1 + assert parents[0].status == ModelSelectionStatus.RUNNING.value + assert parents[0].started_at is not None + assert parents[0].total_candidates == 2 + assert len(children) == 2 + assert {c.status for c in children} == {"pending"} + assert [c.ordinal for c in children] == [0, 1] + + +async def test_submit_run_unusable_availability_raises_400( + monkeypatch: pytest.MonkeyPatch, +) -> None: + _patch_availability(monkeypatch, "unusable") + monkeypatch.setattr(ModelSelectionService, "_run_in_background", AsyncMock()) + db = _submit_mock_db() + with pytest.raises(BadRequestError): + await ModelSelectionService().submit_run(db, _request()) + # The parent was persisted as failed; no children were inserted. + parents = [o for o in db._added if isinstance(o, ModelSelectionRun)] + children = [o for o in db._added if isinstance(o, ModelSelectionCandidate)] + assert parents[0].status == ModelSelectionStatus.FAILED.value + assert children == [] + + +def test_terminal_status_rule() -> None: + svc = ModelSelectionService() + f = svc._terminal_status + assert f({"completed": 3, "failed": 0, "cancelled": 0}) is ModelSelectionStatus.COMPLETED + assert f({"completed": 0, "failed": 3, "cancelled": 0}) is ModelSelectionStatus.FAILED + assert f({"completed": 0, "failed": 0, "cancelled": 3}) is ModelSelectionStatus.CANCELLED + assert f({"completed": 2, "failed": 1, "cancelled": 0}) is ModelSelectionStatus.PARTIAL + assert f({"completed": 1, "failed": 0, "cancelled": 1}) is ModelSelectionStatus.PARTIAL + + +async def test_cancel_run_404_when_missing() -> None: + db = AsyncMock() + db.scalar = AsyncMock(return_value=None) + with pytest.raises(NotFoundError): + await ModelSelectionService().cancel_run(db, uuid4().hex) + + +async def test_cancel_run_409_when_terminal() -> None: + row = ModelSelectionRun( + selection_id="sel_terminal", + status=ModelSelectionStatus.COMPLETED.value, + store_id=1, + product_id=1, + start_date=date(2026, 1, 1), + end_date=date(2026, 5, 31), + forecast_horizon=14, + ranking_metric="wape", + candidate_models=[], + policy_snapshot={}, + ) + db = AsyncMock() + db.scalar = AsyncMock(return_value=row) + with pytest.raises(ConflictError): + await ModelSelectionService().cancel_run(db, "sel_terminal") + + +async def test_cancel_run_409_when_settle_races_cancel( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """If the worker settled (no handle) between load and cancel → 409.""" + row = ModelSelectionRun( + selection_id="sel_race", + status=ModelSelectionStatus.RUNNING.value, + store_id=1, + product_id=1, + start_date=date(2026, 1, 1), + end_date=date(2026, 5, 31), + forecast_horizon=14, + ranking_metric="wape", + candidate_models=[], + policy_snapshot={}, + ) + db = AsyncMock() + db.scalar = AsyncMock(return_value=row) + monkeypatch.setattr(_runner, "cancel_selection", lambda _sid: False) + with pytest.raises(ConflictError): + await ModelSelectionService().cancel_run(db, "sel_race") diff --git a/frontend/src/components/champion-selector/results/cancel-run-dialog.test.tsx b/frontend/src/components/champion-selector/results/cancel-run-dialog.test.tsx new file mode 100644 index 00000000..c5d53231 --- /dev/null +++ b/frontend/src/components/champion-selector/results/cancel-run-dialog.test.tsx @@ -0,0 +1,33 @@ +import { afterEach, beforeAll, describe, expect, it, vi } from 'vitest' +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { CancelRunDialog } from './cancel-run-dialog' + +beforeAll(() => { + class ResizeObserverStub { + observe() {} + unobserve() {} + disconnect() {} + } + vi.stubGlobal('ResizeObserver', ResizeObserverStub) + if (!Element.prototype.hasPointerCapture) { + Element.prototype.hasPointerCapture = () => false + } +}) + +afterEach(cleanup) + +describe('CancelRunDialog', () => { + it('confirms cancellation via the AlertDialog', () => { + const onConfirm = vi.fn() + render() + fireEvent.click(screen.getByTestId('cancel-run-trigger')) + fireEvent.click(screen.getByTestId('cancel-run-confirm')) + expect(onConfirm).toHaveBeenCalledTimes(1) + }) + + it('disables the trigger while cancelling', () => { + render( {}} isCancelling />) + const trigger = screen.getByTestId('cancel-run-trigger') as HTMLButtonElement + expect(trigger.disabled).toBe(true) + }) +}) diff --git a/frontend/src/components/champion-selector/results/cancel-run-dialog.tsx b/frontend/src/components/champion-selector/results/cancel-run-dialog.tsx new file mode 100644 index 00000000..d85c08ca --- /dev/null +++ b/frontend/src/components/champion-selector/results/cancel-run-dialog.tsx @@ -0,0 +1,62 @@ +import { Loader2, X } from 'lucide-react' +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, + AlertDialogTrigger, +} from '@/components/ui/alert-dialog' +import { Button } from '@/components/ui/button' + +interface CancelRunDialogProps { + onConfirm: () => void + isCancelling?: boolean + disabled?: boolean +} + +/** + * Cancel-run confirmation (Slice B). Mirrors the batch cancel dialog and reuses + * the honest pending-skip / running-yield copy. + */ +export function CancelRunDialog({ onConfirm, isCancelling, disabled }: CancelRunDialogProps) { + return ( + + + + + + + Cancel this comparison? + + Candidates that haven't started will be skipped. A candidate + already mid-fit stops at the next safe point — sklearn / LightGBM + fits are uncancellable mid-call, so an in-flight fit may finish + first. Results from candidates that already completed are kept. + + + + Keep running + + Cancel run + + + + + ) +} diff --git a/frontend/src/components/champion-selector/results/comparison-charts.test.tsx b/frontend/src/components/champion-selector/results/comparison-charts.test.tsx new file mode 100644 index 00000000..d1ea60bf --- /dev/null +++ b/frontend/src/components/champion-selector/results/comparison-charts.test.tsx @@ -0,0 +1,36 @@ +import { afterEach, beforeAll, describe, expect, it, vi } from 'vitest' +import { cleanup, render, screen } from '@testing-library/react' +import { ComparisonCharts } from './comparison-charts' +import type { ModelSelectionChartData } from '@/types/api' + +// Recharts' ResponsiveContainer needs ResizeObserver in jsdom. +beforeAll(() => { + class ResizeObserverStub { + observe() {} + unobserve() {} + disconnect() {} + } + vi.stubGlobal('ResizeObserver', ResizeObserverStub) +}) + +afterEach(cleanup) + +const chartData: ModelSelectionChartData = { + wape_by_model: { regression: 10, naive: 14 }, + bias_by_model: { regression: -0.2, naive: 0.5 }, + fold_stability: { regression: [10, 11] }, + winner_actual_vs_predicted: [ + { dates: ['2026-01-01', '2026-01-02'], actuals: [10, 12], predictions: [9.5, 12.5] }, + ], +} + +describe('ComparisonCharts', () => { + it('renders WAPE + bias bars from chart_data', () => { + render() + expect(screen.getByTestId('comparison-charts')).toBeTruthy() + expect(screen.getByTestId('metric-bars-wape-by-model')).toBeTruthy() + expect(screen.getByTestId('metric-bars-bias-by-model')).toBeTruthy() + // Winner is starred in the bar list. + expect(screen.getAllByText('★ regression').length).toBeGreaterThan(0) + }) +}) diff --git a/frontend/src/components/champion-selector/results/comparison-charts.tsx b/frontend/src/components/champion-selector/results/comparison-charts.tsx new file mode 100644 index 00000000..5e192a22 --- /dev/null +++ b/frontend/src/components/champion-selector/results/comparison-charts.tsx @@ -0,0 +1,105 @@ +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card' +import { MultiSeriesChart } from '@/components/charts/multi-series-chart' +import { BIAS_EXPLANATION } from '@/components/champion-selector/copy' +import type { ModelSelectionChartData } from '@/types/api' + +interface ComparisonChartsProps { + chartData: ModelSelectionChartData + winnerModelType?: string +} + +/** One labelled horizontal bar (CSS — deterministic, no chart lib needed). */ +function MetricBars({ + title, + byModel, + winnerModelType, + signed = false, +}: { + title: string + byModel: Record + winnerModelType?: string + signed?: boolean +}) { + const entries = Object.entries(byModel) + const max = Math.max(1, ...entries.map(([, v]) => Math.abs(v))) + return ( +
+

{title}

+ {entries.map(([model, value]) => ( +
+ + {model === winnerModelType ? `★ ${model}` : model} + +
+
+
+ {value.toFixed(2)} +
+ ))} +
+ ) +} + +/** + * Comparison charts (Slice B): WAPE-by-model + bias-by-model bars, and the + * winner's actual-vs-predicted overlay. Reads the backend `chart_data` payload. + */ +export function ComparisonCharts({ chartData, winnerModelType }: ComparisonChartsProps) { + // Build actual-vs-predicted rows for the winner from the fold chart points. + const avpRows: Record[] = [] + for (const fold of chartData.winner_actual_vs_predicted as Array<{ + dates?: string[] + actuals?: number[] + predictions?: number[] + }>) { + const dates = fold.dates ?? [] + const actuals = fold.actuals ?? [] + const predictions = fold.predictions ?? [] + for (let i = 0; i < dates.length; i++) { + avpRows.push({ + date: dates[i] ?? String(i), + actual: actuals[i] ?? 0, + predicted: predictions[i] ?? 0, + }) + } + } + + return ( + + + Comparison + {BIAS_EXPLANATION} + + +
+ + +
+ {avpRows.length > 0 && ( + + )} +
+
+ ) +} diff --git a/frontend/src/components/champion-selector/results/constants.ts b/frontend/src/components/champion-selector/results/constants.ts new file mode 100644 index 00000000..41aa3bb2 --- /dev/null +++ b/frontend/src/components/champion-selector/results/constants.ts @@ -0,0 +1,17 @@ +import type { ModelSelectionStatus } from '@/types/api' + +/** + * Terminal selection-run statuses (Slice B). Polling stops once a run reaches + * one of these. Kept in a `.ts` module so the + * `react-refresh/only-export-components` lint rule never trips. + */ +export const TERMINAL_SELECTION_STATES: ReadonlySet = new Set([ + 'completed', + 'partial', + 'failed', + 'cancelled', +]) + +export function isTerminalSelectionStatus(status: ModelSelectionStatus): boolean { + return TERMINAL_SELECTION_STATES.has(status) +} diff --git a/frontend/src/components/champion-selector/results/model-detail-drawer.test.tsx b/frontend/src/components/champion-selector/results/model-detail-drawer.test.tsx new file mode 100644 index 00000000..83d90d1b --- /dev/null +++ b/frontend/src/components/champion-selector/results/model-detail-drawer.test.tsx @@ -0,0 +1,43 @@ +import { afterEach, beforeAll, describe, expect, it, vi } from 'vitest' +import { cleanup, render, screen } from '@testing-library/react' +import { ModelDetailDrawer } from './model-detail-drawer' +import type { ModelRankEntry } from '@/types/api' + +// Radix Dialog (Sheet) needs these layout APIs in jsdom. +beforeAll(() => { + class ResizeObserverStub { + observe() {} + unobserve() {} + disconnect() {} + } + vi.stubGlobal('ResizeObserver', ResizeObserverStub) + if (!Element.prototype.hasPointerCapture) { + Element.prototype.hasPointerCapture = () => false + } +}) + +afterEach(cleanup) + +const entry: ModelRankEntry = { + rank: 1, + model_type: 'regression', + params: { max_depth: 6 }, + included: true, + exclusion_reason: null, + metrics: { wape: 10, smape: 8, mae: 4, rmse: 5, bias: 0.1 }, +} + +describe('ModelDetailDrawer', () => { + it('renders the candidate metrics + params when open', () => { + render( {}} />) + const drawer = screen.getByTestId('model-detail-drawer') + expect(drawer.textContent).toContain('regression') + expect(drawer.textContent).toContain('WAPE') + expect(drawer.textContent).toContain('max_depth') + }) + + it('renders nothing meaningful when closed', () => { + render( {}} />) + expect(screen.queryByTestId('model-detail-drawer')).toBeNull() + }) +}) diff --git a/frontend/src/components/champion-selector/results/model-detail-drawer.tsx b/frontend/src/components/champion-selector/results/model-detail-drawer.tsx new file mode 100644 index 00000000..f7ac0148 --- /dev/null +++ b/frontend/src/components/champion-selector/results/model-detail-drawer.tsx @@ -0,0 +1,79 @@ +import { + Sheet, + SheetContent, + SheetDescription, + SheetHeader, + SheetTitle, +} from '@/components/ui/sheet' +import { Badge } from '@/components/ui/badge' +import type { ModelRankEntry } from '@/types/api' + +interface ModelDetailDrawerProps { + entry: ModelRankEntry | null + open: boolean + onOpenChange: (open: boolean) => void +} + +function fmt(value: number | undefined): string { + if (typeof value !== 'number' || !Number.isFinite(value)) return '—' + return value.toFixed(3) +} + +const METRIC_KEYS: { key: string; label: string }[] = [ + { key: 'wape', label: 'WAPE' }, + { key: 'smape', label: 'sMAPE' }, + { key: 'mae', label: 'MAE' }, + { key: 'rmse', label: 'RMSE' }, + { key: 'bias', label: 'Bias' }, +] + +/** + * Per-model detail drawer (Slice B). Opens from a ranking-row click; shows one + * candidate's metrics, params, and exclusion reason (read-only). + */ +export function ModelDetailDrawer({ entry, open, onOpenChange }: ModelDetailDrawerProps) { + return ( + + + {entry && ( + <> + + + {entry.model_type} + {!entry.included && ( + {entry.exclusion_reason ?? 'excluded'} + )} + + + {entry.rank !== null ? `Ranked #${entry.rank}` : 'Not ranked'} + + +
+
+

Metrics

+ + + {METRIC_KEYS.map((m) => ( + + + + + ))} + +
{m.label} + {fmt(entry.metrics?.[m.key])} +
+
+
+

Parameters

+
+                  {JSON.stringify(entry.params, null, 2)}
+                
+
+
+ + )} +
+
+ ) +} diff --git a/frontend/src/components/champion-selector/results/ranking-table.test.tsx b/frontend/src/components/champion-selector/results/ranking-table.test.tsx new file mode 100644 index 00000000..9943ff6b --- /dev/null +++ b/frontend/src/components/champion-selector/results/ranking-table.test.tsx @@ -0,0 +1,50 @@ +import { afterEach, describe, expect, it, vi } from 'vitest' +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { RankingTable } from './ranking-table' +import type { ModelRankEntry } from '@/types/api' + +afterEach(cleanup) + +const ranking: ModelRankEntry[] = [ + { + rank: 1, + model_type: 'regression', + params: {}, + included: true, + exclusion_reason: null, + metrics: { wape: 10, smape: 8, mae: 4, bias: 0.1 }, + }, + { + rank: 2, + model_type: 'naive', + params: {}, + included: true, + exclusion_reason: null, + metrics: { wape: 14, smape: 12, mae: 6, bias: 0.5 }, + }, + { + rank: null, + model_type: 'moving_average', + params: { window_size: 0 }, + included: false, + exclusion_reason: 'failed', + metrics: null, + }, +] + +describe('RankingTable', () => { + it('renders a row per entry; excluded rows show their reason', () => { + render( {}} />) + expect(screen.getByTestId('ranking-row-regression')).toBeTruthy() + expect(screen.getByTestId('ranking-row-naive')).toBeTruthy() + const excluded = screen.getByTestId('ranking-row-moving_average') + expect(excluded.textContent).toContain('failed') + }) + + it('calls onSelectModel with the clicked entry', () => { + const onSelect = vi.fn() + render() + fireEvent.click(screen.getByTestId('ranking-row-naive')) + expect(onSelect).toHaveBeenCalledWith(ranking[1]) + }) +}) diff --git a/frontend/src/components/champion-selector/results/ranking-table.tsx b/frontend/src/components/champion-selector/results/ranking-table.tsx new file mode 100644 index 00000000..a8c0515a --- /dev/null +++ b/frontend/src/components/champion-selector/results/ranking-table.tsx @@ -0,0 +1,90 @@ +import { Trophy } from 'lucide-react' +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card' +import { Badge } from '@/components/ui/badge' +import { cn } from '@/lib/utils' +import { RANKING_TIE_BREAK } from '@/components/champion-selector/copy' +import type { ModelRankEntry } from '@/types/api' + +interface RankingTableProps { + ranking: ModelRankEntry[] + onSelectModel: (entry: ModelRankEntry) => void +} + +function fmt(value: number | undefined): string { + if (typeof value !== 'number' || !Number.isFinite(value)) return '—' + return value.toFixed(2) +} + +/** + * Candidate ranking table (Slice B). Winner row highlighted; excluded + * (failed/cancelled/filtered) rows show their reason and stay visible. Clicking + * a row opens the model-detail drawer. + */ +export function RankingTable({ ranking, onSelectModel }: RankingTableProps) { + return ( + + + Ranking + {RANKING_TIE_BREAK} + + + + + + + + + + + + + + + {ranking.map((entry) => ( + onSelectModel(entry)} + className={cn( + 'cursor-pointer border-t hover:bg-accent/50', + entry.rank === 1 && 'bg-primary/5 font-medium', + !entry.included && 'text-muted-foreground', + )} + > + + + + + + + + ))} + +
RankModelWAPEsMAPEMAEBias
+ {entry.rank === 1 ? ( + + 1 + + ) : ( + (entry.rank ?? '—') + )} + + {entry.model_type} + {!entry.included && ( + + {entry.exclusion_reason ?? 'excluded'} + + )} + + {fmt(entry.metrics?.['wape'])} + + {fmt(entry.metrics?.['smape'])} + + {fmt(entry.metrics?.['mae'])} + + {fmt(entry.metrics?.['bias'])} +
+
+
+ ) +} diff --git a/frontend/src/components/champion-selector/results/run-progress-panel.test.tsx b/frontend/src/components/champion-selector/results/run-progress-panel.test.tsx new file mode 100644 index 00000000..13c4ef54 --- /dev/null +++ b/frontend/src/components/champion-selector/results/run-progress-panel.test.tsx @@ -0,0 +1,57 @@ +import { afterEach, describe, expect, it } from 'vitest' +import { cleanup, render, screen } from '@testing-library/react' +import { RunProgressPanel } from './run-progress-panel' +import type { CandidateProgress, SelectionProgress } from '@/types/api' + +afterEach(cleanup) + +const progress: SelectionProgress = { + total: 3, + pending: 1, + running: 1, + completed: 1, + failed: 0, + cancelled: 0, +} + +function cand(model_type: string, status: CandidateProgress['status']): CandidateProgress { + return { + candidate_id: `id-${model_type}`, + ordinal: 0, + model_type, + status, + error: status === 'failed' ? 'boom' : null, + started_at: null, + completed_at: null, + duration_ms: status === 'completed' ? 1500 : null, + } +} + +describe('RunProgressPanel', () => { + it('renders status badge, counts, and a per-candidate row', () => { + render( + , + ) + expect(screen.getByTestId('run-status-badge').textContent).toContain('running') + expect(screen.getByText('Total')).toBeTruthy() + expect(screen.getByTestId('candidate-row-naive')).toBeTruthy() + expect(screen.getByTestId('candidate-row-regression')).toBeTruthy() + }) + + it('keeps a failed candidate visible with its error', () => { + render( + , + ) + const row = screen.getByTestId('candidate-row-xgboost') + expect(row.textContent).toContain('failed') + expect(row.textContent).toContain('boom') + }) +}) diff --git a/frontend/src/components/champion-selector/results/run-progress-panel.tsx b/frontend/src/components/champion-selector/results/run-progress-panel.tsx new file mode 100644 index 00000000..4c5699a3 --- /dev/null +++ b/frontend/src/components/champion-selector/results/run-progress-panel.tsx @@ -0,0 +1,87 @@ +import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card' +import { StatusBadge } from '@/components/common/status-badge' +import { getStatusVariant } from '@/lib/status-utils' +import type { + CandidateProgress, + ModelSelectionStatus, + SelectionProgress, +} from '@/types/api' + +interface RunProgressPanelProps { + status: ModelSelectionStatus + progress: SelectionProgress | null + candidates: CandidateProgress[] +} + +function Count({ label, value }: { label: string; value: number }) { + return ( +
+

{label}

+

{value}

+
+ ) +} + +/** + * Live async-run progress (Slice B): the run status, per-status counts, and a + * per-candidate table. Failed/cancelled candidates stay visible. + */ +export function RunProgressPanel({ status, progress, candidates }: RunProgressPanelProps) { + return ( + + +
+ Comparison progress + + {status} + +
+
+ + {progress && ( +
+ + + + + + +
+ )} + {candidates.length > 0 && ( + + + + + + + + + + {candidates.map((c) => ( + + + + + + ))} + +
ModelStatusDuration
{c.model_type} + + {c.status} + + {c.error && ( + {c.error} + )} + + {c.duration_ms === null ? '—' : `${(c.duration_ms / 1000).toFixed(1)}s`} +
+ )} +
+
+ ) +} diff --git a/frontend/src/components/champion-selector/results/winner-card.test.tsx b/frontend/src/components/champion-selector/results/winner-card.test.tsx new file mode 100644 index 00000000..54054253 --- /dev/null +++ b/frontend/src/components/champion-selector/results/winner-card.test.tsx @@ -0,0 +1,40 @@ +import { afterEach, describe, expect, it } from 'vitest' +import { cleanup, render, screen } from '@testing-library/react' +import { WinnerCard } from './winner-card' +import type { WinnerSummary } from '@/types/api' + +afterEach(cleanup) + +const winner: WinnerSummary = { + model_type: 'regression', + params: {}, + metrics: { wape: 10, smape: 8, mae: 4, bias: 0.1 }, + rank: 1, +} + +describe('WinnerCard', () => { + it('renders the winner, confidence, metrics, and bias copy', () => { + render() + expect(screen.getByTestId('winner-card').textContent).toContain('regression') + expect(screen.getByTestId('winner-confidence-badge').textContent).toContain('high') + expect(screen.getByText('clear lead')).toBeTruthy() + expect(screen.getByText(/Positive bias means the model under-forecasts/)).toBeTruthy() + }) + + it('renders a no-winner state when winner is null', () => { + render() + expect(screen.getByText('No champion selected')).toBeTruthy() + }) + + it('surfaces the deterministic business_summary headline read-only', () => { + render( + , + ) + expect(screen.getByText('regression wins by 28% WAPE')).toBeTruthy() + }) +}) diff --git a/frontend/src/components/champion-selector/results/winner-card.tsx b/frontend/src/components/champion-selector/results/winner-card.tsx new file mode 100644 index 00000000..c5fa0b8a --- /dev/null +++ b/frontend/src/components/champion-selector/results/winner-card.tsx @@ -0,0 +1,100 @@ +import { Trophy } from 'lucide-react' +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card' +import { Badge } from '@/components/ui/badge' +import { StatusBadge } from '@/components/common/status-badge' +import { BIAS_EXPLANATION } from '@/components/champion-selector/copy' +import type { ConfidenceLevel, WinnerSummary } from '@/types/api' + +interface WinnerCardProps { + winner: WinnerSummary | null + confidence: ConfidenceLevel | null + reasons: string[] + /** The deterministic backend `business_summary` (read-only; Slice C extends). */ + businessSummary?: Record | null +} + +const CONFIDENCE_VARIANT: Record = { + high: 'success', + medium: 'info', + low: 'warning', +} + +function Metric({ label, value }: { label: string; value: number | undefined }) { + return ( +
+

{label}

+

+ {typeof value === 'number' && Number.isFinite(value) ? value.toFixed(2) : '—'} +

+
+ ) +} + +/** + * Winner summary card (Slice B). Null-safe — renders a "no winner" state for a + * failed/cancelled run. Renders the deterministic `business_summary` headline + * READ-ONLY (Slice C adds the decision-layer interpretation on top). + */ +export function WinnerCard({ winner, confidence, reasons, businessSummary }: WinnerCardProps) { + if (winner === null) { + return ( + + + No champion selected + + No candidate produced a valid backtest. Review the failed candidates + below or adjust the selection. + + + + ) + } + + const headline = + typeof businessSummary?.['headline'] === 'string' + ? (businessSummary['headline'] as string) + : null + + return ( + + +
+ + + {winner.model_type} + + {confidence && ( + + {confidence} confidence + + )} +
+ {headline && {headline}} +
+ +
+ + + + +
+ {reasons.length > 0 && ( +
+ {reasons.map((reason, i) => ( +
+ + why + + {reason} +
+ ))} +
+ )} +

{BIAS_EXPLANATION}

+
+
+ ) +} diff --git a/frontend/src/hooks/use-model-selection.test.ts b/frontend/src/hooks/use-model-selection.test.ts index a1187321..4209a072 100644 --- a/frontend/src/hooks/use-model-selection.test.ts +++ b/frontend/src/hooks/use-model-selection.test.ts @@ -5,12 +5,23 @@ * availability `enabled` gating. No real backend is exercised. */ import { QueryClient, QueryClientProvider } from '@tanstack/react-query' -import { renderHook, waitFor } from '@testing-library/react' +import { act, renderHook, waitFor } from '@testing-library/react' import { afterEach, describe, expect, it, vi } from 'vitest' import { createElement, type ReactNode } from 'react' -import { useModelCatalog, usePairAvailability } from './use-model-selection' -import type { ModelCatalogResponse, PairAvailability } from '@/types/api' +import { + useCancelSelectionRun, + useModelCatalog, + usePairAvailability, + useSelectionRun, + useSubmitSelectionRun, +} from './use-model-selection' +import type { + ModelCatalogResponse, + ModelSelectionRunRequest, + PairAvailability, + SubmitRunResponse, +} from '@/types/api' function makeWrapper(client: QueryClient) { return function Wrapper({ children }: { children: ReactNode }) { @@ -124,3 +135,139 @@ describe('usePairAvailability', () => { expect(fetchMock).not.toHaveBeenCalled() }) }) + +// --------------------------------------------------------------------- Slice B + +const SUBMIT_RESPONSE: SubmitRunResponse = { + selection_id: 'sel_b', + store_id: 7, + product_id: 12, + status: 'running', + selection_window: { start_date: '2026-01-01', end_date: '2026-05-31' }, + forecast_horizon: 14, + ranking_metric: 'wape', + availability: null, + ranking: [], + winner: null, + recommendation_confidence: null, + confidence_reasons: [], + chart_data: null, + final_model: null, + forecast: null, + business_summary: null, + error_message: null, + created_at: '2026-06-01T12:00:00Z', + started_at: '2026-06-01T12:00:00Z', + completed_at: null, + progress: { total: 1, pending: 1, running: 0, completed: 0, failed: 0, cancelled: 0 }, + candidate_progress: [ + { + candidate_id: 'c0', + ordinal: 0, + model_type: 'naive', + status: 'pending', + error: null, + started_at: null, + completed_at: null, + duration_ms: null, + }, + ], + monitor_url: '/model-selection/sel_b', + cancel_url: '/model-selection/sel_b', +} + +const RUN_REQUEST: ModelSelectionRunRequest = { + store_id: 7, + product_id: 12, + selection_window: { start_date: '2026-01-01', end_date: '2026-05-31' }, + forecast_horizon: 14, + ranking_metric: 'wape', + split_config: { + strategy: 'expanding', + n_splits: 5, + min_train_size: 30, + gap: 0, + horizon: 14, + }, + candidate_models: [{ model_type: 'naive', params: {} }], + feature_frame_version: 1, + feature_groups: null, + auto_train_winner: false, + auto_predict: false, +} + +describe('useSubmitSelectionRun', () => { + it('POSTs to /model-selection/runs and seeds the poll cache', async () => { + const fetchMock = vi.fn().mockResolvedValue( + new Response(JSON.stringify(SUBMIT_RESPONSE), { + status: 202, + headers: { 'content-type': 'application/json' }, + }), + ) + vi.stubGlobal('fetch', fetchMock) + const client = makeClient() + const { result } = renderHook(() => useSubmitSelectionRun(), { + wrapper: makeWrapper(client), + }) + await act(async () => { + result.current.mutate(RUN_REQUEST) + }) + await waitFor(() => expect(result.current.isSuccess).toBe(true)) + const call = fetchMock.mock.calls[0]! + expect(String(call[0])).toContain('/model-selection/runs') + expect((call[1] as RequestInit).method).toBe('POST') + // The poll cache is seeded so useSelectionRun starts warm. + expect( + client.getQueryData(['model-selection', 'run', 'sel_b']), + ).toEqual(SUBMIT_RESPONSE) + }) +}) + +describe('useSelectionRun', () => { + it('GETs /model-selection/{id} when given a selection id', async () => { + const fetchMock = vi.fn().mockResolvedValue( + new Response(JSON.stringify({ ...SUBMIT_RESPONSE, status: 'completed' }), { + status: 200, + headers: { 'content-type': 'application/json' }, + }), + ) + vi.stubGlobal('fetch', fetchMock) + const { result } = renderHook(() => useSelectionRun('sel_b'), { + wrapper: makeWrapper(makeClient()), + }) + await waitFor(() => expect(result.current.isSuccess).toBe(true)) + expect(String(fetchMock.mock.calls[0]![0])).toContain('/model-selection/sel_b') + expect(result.current.data?.status).toBe('completed') + }) + + it('does NOT fetch without a selection id (enabled gating)', async () => { + const fetchMock = vi.fn() + vi.stubGlobal('fetch', fetchMock) + renderHook(() => useSelectionRun(null), { wrapper: makeWrapper(makeClient()) }) + await new Promise((resolve) => setTimeout(resolve, 20)) + expect(fetchMock).not.toHaveBeenCalled() + }) +}) + +describe('useCancelSelectionRun', () => { + it('DELETEs /model-selection/{id}', async () => { + const cancelled = { ...SUBMIT_RESPONSE, status: 'cancelled' as const } + const fetchMock = vi.fn().mockResolvedValue( + new Response(JSON.stringify(cancelled), { + status: 200, + headers: { 'content-type': 'application/json' }, + }), + ) + vi.stubGlobal('fetch', fetchMock) + const { result } = renderHook(() => useCancelSelectionRun(), { + wrapper: makeWrapper(makeClient()), + }) + await act(async () => { + result.current.mutate('sel_b') + }) + await waitFor(() => expect(result.current.isSuccess).toBe(true)) + const call = fetchMock.mock.calls[0]! + expect(String(call[0])).toContain('/model-selection/sel_b') + expect((call[1] as RequestInit).method).toBe('DELETE') + }) +}) diff --git a/frontend/src/hooks/use-model-selection.ts b/frontend/src/hooks/use-model-selection.ts index 726f8072..2cf7286f 100644 --- a/frontend/src/hooks/use-model-selection.ts +++ b/frontend/src/hooks/use-model-selection.ts @@ -1,12 +1,19 @@ -import { useQuery } from '@tanstack/react-query' +import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query' import { api } from '@/lib/api' -import type { ModelCatalogResponse, PairAvailability } from '@/types/api' +import { isTerminalSelectionStatus } from '@/components/champion-selector/results/constants' +import type { + ModelCatalogResponse, + ModelSelectionRunRequest, + ModelSelectionRunResponse, + PairAvailability, + SubmitRunResponse, +} from '@/types/api' /** - * Model-selection query hooks (Champion Selector, Slice A). + * Model-selection query hooks (Champion Selector). * - * Read-only: the catalog and pair-availability GETs. The run mutation, - * progress, and results hooks are owned by Slice B; train/predict by Slice C. + * Slice A: catalog + availability GETs. Slice B: async submit / poll / cancel. + * Train/predict/promotion are owned by Slice C. */ /** @@ -55,3 +62,59 @@ export function usePairAvailability({ enabled: enabled && !!storeId && storeId > 0 && !!productId && productId > 0, }) } + +/** + * Submit an async selection run (Slice B). `POST /model-selection/runs` returns + * 202 immediately; we seed the poll cache so `useSelectionRun` starts warm. + */ +export function useSubmitSelectionRun() { + const queryClient = useQueryClient() + return useMutation({ + mutationFn: (request: ModelSelectionRunRequest) => + api('/model-selection/runs', { + method: 'POST', + body: request, + }), + onSuccess: (data) => { + queryClient.setQueryData(['model-selection', 'run', data.selection_id], data) + }, + }) +} + +/** + * Poll one selection run. Refetches every 2s while pending/running, then stops + * once the run reaches a terminal status. Gated on a real selection id. + */ +export function useSelectionRun(selectionId: string | null, enabled = true) { + return useQuery({ + queryKey: ['model-selection', 'run', selectionId], + queryFn: () => + api(`/model-selection/${selectionId}`), + enabled: enabled && !!selectionId, + refetchInterval: (query) => { + const status = query.state.data?.status + return status && isTerminalSelectionStatus(status) ? false : 2000 + }, + }) +} + +/** + * Cancel an in-flight selection run (Slice B). `DELETE /model-selection/{id}` — + * 200 settled / 404 / 409 terminal / 504 drain timeout. Seeds + invalidates the + * poll query on success. + */ +export function useCancelSelectionRun() { + const queryClient = useQueryClient() + return useMutation({ + mutationFn: (selectionId: string) => + api(`/model-selection/${selectionId}`, { + method: 'DELETE', + }), + onSuccess: (data) => { + queryClient.setQueryData(['model-selection', 'run', data.selection_id], data) + void queryClient.invalidateQueries({ + queryKey: ['model-selection', 'run', data.selection_id], + }) + }, + }) +} diff --git a/frontend/src/pages/visualize/champion.test.tsx b/frontend/src/pages/visualize/champion.test.tsx index 123d4862..2ae297ca 100644 --- a/frontend/src/pages/visualize/champion.test.tsx +++ b/frontend/src/pages/visualize/champion.test.tsx @@ -69,6 +69,10 @@ vi.mock('@/hooks/use-model-selection', () => ({ isLoading: false, isError: false, }), + // Slice B — inert async hooks (no run in progress for the shell test). + useSubmitSelectionRun: () => ({ mutate: vi.fn(), isPending: false }), + useCancelSelectionRun: () => ({ mutate: vi.fn(), isPending: false }), + useSelectionRun: () => ({ data: undefined, isLoading: false, isError: false }), })) import ChampionSelectorPage from './champion' diff --git a/frontend/src/pages/visualize/champion.tsx b/frontend/src/pages/visualize/champion.tsx index d3e3106f..6157148e 100644 --- a/frontend/src/pages/visualize/champion.tsx +++ b/frontend/src/pages/visualize/champion.tsx @@ -1,10 +1,16 @@ import { useMemo, useState } from 'react' import { format } from 'date-fns' import { DateRange } from 'react-day-picker' -import { Trophy } from 'lucide-react' +import { Loader2, Trophy } from 'lucide-react' import { useStores } from '@/hooks/use-stores' import { useProducts } from '@/hooks/use-products' -import { useModelCatalog, usePairAvailability } from '@/hooks/use-model-selection' +import { + useCancelSelectionRun, + useModelCatalog, + usePairAvailability, + useSelectionRun, + useSubmitSelectionRun, +} from '@/hooks/use-model-selection' import { DateRangePicker } from '@/components/common/date-range-picker' import { ErrorDisplay } from '@/components/common/error-display' import { AvailabilityPanel } from '@/components/champion-selector/availability-panel' @@ -12,12 +18,20 @@ import { BacktestSettingsForm } from '@/components/champion-selector/backtest-se import { splitConfigErrors } from '@/components/champion-selector/split-config' import { CandidateModelPicker } from '@/components/champion-selector/candidate-model-picker' import { SearchableEntitySelect } from '@/components/champion-selector/searchable-entity-select' -import { RUN_COMPARISON_PENDING } from '@/components/champion-selector/copy' import { assembleRunRequest } from '@/components/champion-selector/run-request' +import { RunProgressPanel } from '@/components/champion-selector/results/run-progress-panel' +import { RankingTable } from '@/components/champion-selector/results/ranking-table' +import { WinnerCard } from '@/components/champion-selector/results/winner-card' +import { ComparisonCharts } from '@/components/champion-selector/results/comparison-charts' +import { ModelDetailDrawer } from '@/components/champion-selector/results/model-detail-drawer' +import { CancelRunDialog } from '@/components/champion-selector/results/cancel-run-dialog' +import { isTerminalSelectionStatus } from '@/components/champion-selector/results/constants' import { Button } from '@/components/ui/button' import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card' import { Input } from '@/components/ui/input' +import { getErrorMessage } from '@/lib/api' import type { + ModelRankEntry, ModelSelectionRunRequest, SplitConfig, } from '@/types/api' @@ -54,6 +68,12 @@ export default function ChampionSelectorPage() { // catalog's default candidate set (derived below, no effect needed). const [editedModels, setEditedModels] = useState(null) + // Slice B — the in-flight/terminal async run + the detail-drawer selection. + const [selectionId, setSelectionId] = useState(null) + const [submitError, setSubmitError] = useState(null) + const [drawerEntry, setDrawerEntry] = useState(null) + const [drawerOpen, setDrawerOpen] = useState(false) + // /dimensions/{stores,products} both cap page_size at 100 (client-filtered). const storesQuery = useStores({ page: 1, pageSize: 100 }) const productsQuery = useProducts({ page: 1, pageSize: 100 }) @@ -107,9 +127,8 @@ export default function ChampionSelectorPage() { selectedModels.length >= 1 && splitConfigErrors(effectiveSplit).length === 0 - // The assembled request — typed but NOT sent in Slice A (the CTA is disabled). - // `auto_train_winner`/`auto_predict` are pinned false by `assembleRunRequest`. - // Built defensively so it is valid the moment Slice B wires the mutation. + // The assembled request — `auto_train_winner`/`auto_predict` pinned false by + // `assembleRunRequest` (no-ops in the async path; Slice C owns train/predict). const runRequest: ModelSelectionRunRequest | null = formReady && dateRange?.from && dateRange?.to ? assembleRunRequest({ @@ -124,6 +143,28 @@ export default function ChampionSelectorPage() { }) : null + // Slice B — async submit → poll → cancel. + const submitRun = useSubmitSelectionRun() + const cancelRun = useCancelSelectionRun() + const runQuery = useSelectionRun(selectionId) + const run = runQuery.data + const isRunning = !!run && !isTerminalSelectionStatus(run.status) + const isTerminal = !!run && isTerminalSelectionStatus(run.status) + + function handleRunComparison() { + if (!runRequest) return + setSubmitError(null) + submitRun.mutate(runRequest, { + onSuccess: (data) => setSelectionId(data.selection_id), + onError: (err) => setSubmitError(getErrorMessage(err)), + }) + } + + function handleSelectModel(entry: ModelRankEntry) { + setDrawerEntry(entry) + setDrawerOpen(true) + } + return (
@@ -260,34 +301,75 @@ export default function ChampionSelectorPage() { - {/* Run CTA (disabled until Slice B) */} + {/* Run CTA (Slice B — submit the async comparison) */}
{formReady ? `Ready to compare ${selectedModels.length} model${ selectedModels.length === 1 ? '' : 's' - }. ${RUN_COMPARISON_PENDING}` + }.` : 'Pick a store, product, time period, horizon and at least one model to continue.'} + {submitError && ( + {submitError} + )} +
+
+ {isRunning && ( + selectionId && cancelRun.mutate(selectionId)} + isCancelling={cancelRun.isPending} + /> + )} +
-
- {/* Dev-only assurance that a valid request is assembled (not sent). */} - {runRequest && ( -

- {JSON.stringify(runRequest)} -

+ {/* Live progress + results (Slice B) */} + {run && ( + + )} + + {isTerminal && run && ( + <> + + {run.chart_data && ( + + )} + {run.ranking.length > 0 && ( + + )} + + )}
) diff --git a/frontend/src/types/api.ts b/frontend/src/types/api.ts index d6e0584f..63ebe3f4 100644 --- a/frontend/src/types/api.ts +++ b/frontend/src/types/api.ts @@ -1205,6 +1205,13 @@ export type ModelSelectionStatus = | 'completed' | 'partial' | 'failed' + | 'cancelled' // Slice B — async cancel terminal state +export type CandidateStatus = + | 'pending' + | 'running' + | 'completed' + | 'failed' + | 'cancelled' export type RankingMetric = 'wape' | 'smape' | 'mae' | 'bias' export type AvailabilityStatus = 'ready' | 'limited' | 'unusable' // `ConfidenceLevel` ('high' | 'medium' | 'low') is reused from the @@ -1325,6 +1332,27 @@ export interface ModelSelectionForecastSummary { horizon: number } +// Slice B — live async progress on a selection run. +export interface CandidateProgress { + candidate_id: string + ordinal: number + model_type: string + status: CandidateStatus + error: string | null + started_at: string | null + completed_at: string | null + duration_ms: number | null +} + +export interface SelectionProgress { + total: number + pending: number + running: number + completed: number + failed: number + cancelled: number +} + export interface ModelSelectionRunResponse { selection_id: string store_id: number @@ -1344,5 +1372,15 @@ export interface ModelSelectionRunResponse { business_summary: Record | null error_message: string | null created_at: string // ISO datetime + // Slice B — additive async fields (null/empty on a legacy sync `/run` row). + started_at?: string | null completed_at: string | null + progress?: SelectionProgress | null + candidate_progress?: CandidateProgress[] +} + +// Slice B — 202 response from `POST /model-selection/runs` (additive superset). +export interface SubmitRunResponse extends ModelSelectionRunResponse { + monitor_url: string + cancel_url: string } From 45b7a7043232c57ce7af37ac9055b6a19377b98a Mon Sep 17 00:00:00 2001 From: Gabor Szabo Date: Mon, 1 Jun 2026 09:50:16 +0200 Subject: [PATCH 2/3] docs: add forecast champion selector slice A/B/C PRPs (#360) --- ...n-selector-slice-a-selection-capability.md | 716 +++++++++++ ...lector-slice-b-async-comparison-results.md | 1010 +++++++++++++++ ...-c-forecast-decision-operationalization.md | 1107 +++++++++++++++++ 3 files changed, 2833 insertions(+) create mode 100644 PRPs/forecast-champion-selector-slice-a-selection-capability.md create mode 100644 PRPs/forecast-champion-selector-slice-b-async-comparison-results.md create mode 100644 PRPs/forecast-champion-selector-slice-c-forecast-decision-operationalization.md diff --git a/PRPs/forecast-champion-selector-slice-a-selection-capability.md b/PRPs/forecast-champion-selector-slice-a-selection-capability.md new file mode 100644 index 00000000..f43c0371 --- /dev/null +++ b/PRPs/forecast-champion-selector-slice-a-selection-capability.md @@ -0,0 +1,716 @@ +name: "Forecast Champion Selector — Slice A: Selection & Capability Foundation" +description: | + First usable frontend/backend surface for the Forecast Champion Selector. Adds + one backend-owned model-capability catalog endpoint to the existing + `model_selection` slice, then builds the React selection shell — searchable + store/product selectors, pair validation, live data-availability assessment, + a simple/advanced backtest-settings form, and a candidate-model picker — under + a new `/visualize/champion` page. Slice A deliberately STOPS before running the + comparison: it does NOT call `POST /model-selection/run`, render ranking/chart + results, train, predict, or promote. Those are Slice B (async run + results) + and Slice C (train/predict/business summary/override/promotion). + +**Created:** 2026-06-01 · **Slice:** A of 3 (A → B → C) +**Current repo base observed:** `dev` @ `6c3f8d4` (Merge PR #354 — model_selection backend merged) +**Backend foundation (source of truth):** `PRPs/forecast-champion-selector-backend.md` (issue #353, MERGED) + +the live slice `app/features/model_selection/` (schemas/service/routes/ranking/explanations verified 2026-06-01). +**Working-tree caveat:** `docker-compose.lan.yml` is an untracked local dogfood override; do NOT commit it. +**Tracking issue:** create before implementation, suggested title `feat(api,ui): forecast champion selector slice A — selection & capability`. +**Suggested branch:** `feat/champion-selector-slice-a` (off `dev`, per `.claude/rules/branch-naming.md`). +**Commit scope:** `api` (new catalog endpoint + slice schemas/service/routes) and `ui` (frontend page/components/hooks/types). +No migration in Slice A — no schema change. Every commit references the tracking issue. + +--- + +## Goal + +**Feature Goal:** Ship the first interactive Forecast Champion Selector surface — a `/visualize/champion` +React page that lets a user choose a **Store → Product → Time Period → Forecast Horizon → Model Types → +Backtest Settings**, see whether the chosen pair has enough history to model (live availability assessment), +and pick candidate models from a **backend-owned** capability catalog — backed by exactly one new backend +endpoint (`GET /model-selection/models`). The page is genuinely usable for *configuration + availability +triage* even though the comparison **run** itself lands in Slice B. + +**Deliverable:** +- **Backend:** `GET /model-selection/models` → `ModelCatalogResponse` (capability catalog), implemented via a new + pure module `app/features/model_selection/capabilities.py`, response schemas added to the slice's + `schemas.py`, a thin `ModelSelectionService.get_model_catalog()` delegate, and the route wired in the slice's + existing `routes.py`. No migration, no new mutation surface, no agent tool. +- **Frontend:** a lazy-loaded `pages/visualize/champion.tsx` page (route `ROUTES.VISUALIZE.CHAMPION`, + nav entry under **Visualize**), a `components/champion-selector/` component family (searchable store/product + selects, availability panel, backtest-settings form, candidate-model picker), a `hooks/use-model-selection.ts` + query-hook module (catalog + availability reads), and a `types/api.ts` "Model Selection" section that declares + the FULL workflow contract (so Slices B/C inherit, not redefine, the types). + +**Success Definition:** +1. `GET /model-selection/models` returns HTTP 200 with a non-empty `models` array — each entry carrying + `model_type`, `label`, `family ∈ {baseline,tree,additive}`, `feature_aware`, `requires_extra`, + `default_params`, `supports_auto_predict`, `description` — plus a `default_candidate_model_types` list. +2. The `/visualize/champion` page renders: a searchable store select, a searchable product select (each with a + secondary line — store `code · name`, product `sku · category`), a date-range picker, a horizon input, a + candidate-model picker fed by `GET /model-selection/models`, and a simple/advanced backtest-settings form. +3. Selecting a valid `(store, product, horizon)` triggers `GET /model-selection/availability` and renders a + `ready | limited | unusable` status block with coverage/observed-days/zero-sale/promotion/avg-demand and the + recommended split config; an unusable/empty pair shows a clear not-enough-data state. +4. The "Run comparison" primary CTA is present but **disabled** with explanatory copy (Slice B turns it on). +5. All Slice A validation gates pass (backend Level-1..4 + frontend `tsc`/`lint`/`test`). + +## Why + +- Business users want to ask "which model should I use for this store/product?" through a UI, not curl. Slice A + gives them the **configuration + triage** half of that workflow immediately, and a stable shell Slice B/C bolt + onto with minimal churn. +- The capability catalog must be **backend-owned** (coordination contract): the model union, families, opt-in + extras, and feature-aware flags live in Python (`app/features/forecasting/`), and shipping them over an API + prevents the TypeScript `MODEL_FAMILY_MAP`/`MODEL_TYPE_LABELS` from drifting out of sync as new models land. +- Declaring the full TS contract now (consumed read-only in A) means Slices B and C add behavior, not type + definitions — cleaner slice boundaries, fewer merge conflicts. +- Preserves the single-host architecture: one new read-only GET, no queue, no new dependency, no cloud SDK. + +## What + +### New backend endpoint (added to the existing slice router `APIRouter(prefix="/model-selection")`) + +```http +GET /model-selection/models +``` + +Response `ModelCatalogResponse`: + +```json +{ + "models": [ + { + "model_type": "naive", + "label": "Naive", + "family": "baseline", + "feature_aware": false, + "requires_extra": false, + "default_params": {}, + "supports_auto_predict": true, + "description": "Repeats the last observed value." + }, + { + "model_type": "seasonal_naive", + "label": "Seasonal Naive", + "family": "baseline", + "feature_aware": false, + "requires_extra": false, + "default_params": { "season_length": 7 }, + "supports_auto_predict": true, + "description": "Repeats the value from one season ago." + } + // ... one entry per forecasting ModelConfig member (11 total) + ], + "default_candidate_model_types": ["naive", "seasonal_naive", "moving_average", "regression", "prophet_like"] +} +``` + +### LOCKED Slice-A decisions (remove every "choose-one" ambiguity) + +1. **Exactly one new backend endpoint:** `GET /model-selection/models`. It is **declared in `routes.py` + BEFORE the `GET /{selection_id}` route** (literal path must precede the path-param route, mirroring the + existing `/availability` route at `routes.py:41` which sits before `/{selection_id}` at `:94`). Status 200. + No request body, no query params. +2. **Catalog is backend-owned and derived, not hand-duplicated.** `family` comes from the forecasting + authority `app.features.forecasting.feature_metadata.model_family_for(model_type)` (imported LAZILY inside + the builder, per the slice's cross-slice discipline) mapped to the lowercase literal + (`ModelFamily.BASELINE → "baseline"`, etc.). `model_type` iteration order + `default_params` + `label` + + `description` come from a slice-local ordered map in `capabilities.py` whose keys are asserted (in a test) to + exactly equal the `ModelType` Literal in `app/features/model_selection/schemas.py`. +3. **`requires_extra`** = `model_type in {"lightgbm", "xgboost"}` (opt-in extras that may `ImportError`). + **`feature_aware`** = `model_type in {"regression", "prophet_like", "lightgbm", "xgboost", "random_forest"}` + (the set the forecasting `predict()` rejects — see Known Gotchas to verify against `forecasting/service.py`). + **`supports_auto_predict`** = `not feature_aware` (feature-aware winners cannot auto-predict — backend + `predict()` rejects them; this flag lets Slice C grey-out the auto-predict toggle). +4. **`default_candidate_model_types`** = `["naive", "seasonal_naive", "moving_average", "regression", "prophet_like"]` + — the exact default five from the backend PRP's `POST /run` example, so the UI pre-selects the same set the + contract documents. +5. **No `model_selection_run` write in Slice A.** The page consumes `GET /models` and `GET /availability` only. + It assembles a typed `ModelSelectionRunRequest` in component state and exposes it through a **disabled** + "Run comparison" CTA; Slice B wires the `POST /run` mutation + results. Slice A MUST NOT call `POST /run`, + `/{id}`, `/{id}/ranking`, `/{id}/train-winner`, or `/{id}/predict`. +6. **Searchable selects use existing primitives only** (no new npm dependency). Stores/products are fetched at + `pageSize: 100` (the dimensions cap) and filtered **client-side** inside a `Popover` + text `Input` + + scrollable button list. (If the catalog ever exceeds 100, swap to the server-side `search` param the + `useStores`/`useProducts` hooks already support — out of scope here.) +7. **Bias-explanation copy (locked, reused by B/C):** wherever bias is explained in help text/tooltips, use + exactly — *"Positive bias means the model under-forecasts (risk of stockouts); negative bias means it + over-forecasts (risk of overstock)."* Export it as a shared constant so B/C reuse the same wording. +8. **WAPE is the default ranking metric**; the advanced form's ranking-metric select offers `wape` (default), + `smape`, `mae`, `bias`, with help text stating the tie-break chain *WAPE → sMAPE → |bias| → MAE* and the + bias copy from #7. + +### Success Criteria + +- [ ] `GET /model-selection/models` returns 200 with `models` (11 entries) + `default_candidate_model_types`. +- [ ] `capabilities.build_model_catalog()` is pure (no DB/IO) and its `model_type` set equals the slice + `ModelType` Literal (asserted by a test). +- [ ] `/model-selection/models` is matched correctly (NOT captured by `/{selection_id}`) — route-order test green. +- [ ] `/visualize/champion` route + Visualize nav entry render the page; lazy-loaded like its siblings. +- [ ] Searchable store + product selects filter client-side and show the secondary descriptor line. +- [ ] Pair validation: the form's primary CTA stays disabled until a store, product, valid date window, and + horizon are all chosen; the date window + horizon respect backend bounds. +- [ ] Availability auto-fetches for a valid pair and renders `ready/limited/unusable` + metrics + recommended + split config; an empty/unusable pair renders a not-enough-data `EmptyState`. +- [ ] The candidate-model picker is fed by `GET /model-selection/models`; opt-in-extra models are visibly + flagged; the default five are pre-selected. +- [ ] The simple/advanced settings form mirrors `SplitConfig` bounds and keeps `split_config.horizon === + forecast_horizon` (matching the backend request validator). +- [ ] The "Run comparison" CTA is present but disabled with copy indicating it arrives next. +- [ ] No `POST /model-selection/run` (or any mutation) is called; no chart/ranking results UI; no train/predict/ + promotion UI; no agent tool; no migration; no new npm dependency. +- [ ] `app/core/tests/test_strict_mode_policy.py` stays green (no new strict request model with date fields). +- [ ] All backend Level-1..4 gates + frontend `pnpm tsc --noEmit && pnpm lint && pnpm test --run` pass. + +## All Needed Context + +### Documentation & References + +```yaml +# Slice / contract source of truth +- file: PRPs/forecast-champion-selector-backend.md + why: The merged backend foundation. LOCKED decisions #1-#7, the full /run + /{id} contract, the + availability semantics (ready/limited/unusable thresholds), and the default-five candidate list. + Slice A consumes this contract read-only; do not re-derive ranking/confidence in TS. +- file: PRPs/ai_docs/forecast-champion-selector-backend-research.md + why: External-lib + runtime facts (FastAPI APIRouter, Pydantic strict mode, sklearn TimeSeriesSplit). +- file: PRPs/templates/prp_base.md + why: Base PRP template structure. NOTE — the referenced "PRPs/prp-readme.md.md" does NOT exist + (`find PRPs -iname '*readme*'` empty on 2026-06-01); the backend PRP records the same finding. + +# Live backend slice to read (the contract the UI consumes) +- file: app/features/model_selection/schemas.py + why: ModelType Literal (:34, the 11 model_types), RankingMetric (:48), AvailabilityStatus (:51), + ConfidenceLevel (:50), PairAvailabilityResponse (:239), ModelSelectionRunRequest (:118), + ModelSelectionRunResponse (:267), ModelRankEntry (:195), WinnerSummary (:216), ChartData (:225). + ADD the new ModelCatalogResponse + CandidateModelInfo here (plain BaseModel — outputs need no strict). +- file: app/features/model_selection/routes.py + why: APIRouter(prefix="/model-selection") (:38); the literal `/availability` (:41) precedes `/{selection_id}` + (:94) — MIRROR that ordering for the new `/models` route. Error mapping: ValueError→BadRequestError, + SQLAlchemyError→DatabaseError. +- file: app/features/model_selection/service.py + why: Stateless service pattern; lazy cross-slice imports inside methods (:215-219). ADD + get_model_catalog() delegating to capabilities.build_model_catalog() (no DB needed; keep signature + db-free or accept db and ignore — prefer db-free since the catalog is static). +- file: app/features/model_selection/ranking.py + why: PURE-module precedent (no DB/IO, unit-tested directly). MIRROR this style for capabilities.py. +- file: app/features/model_selection/explanations.py + why: Second pure-module precedent (deterministic text). Same import/style conventions. +- file: app/features/model_selection/tests/test_routes.py + why: Route-test pattern (ASGITransport + AsyncClient + dependency_overrides[get_db]); ADD a /models 200 + test + a route-ordering test (GET /model-selection/models is NOT treated as selection_id="models"). +- file: app/features/model_selection/tests/test_ranking.py + why: Pure-unit test pattern to MIRROR for tests/test_capabilities.py. + +# Backend authority for model family / union (catalog source) +- file: app/features/forecasting/feature_metadata.py + why: model_family_for(model_type) -> ModelFamily (:57) and _MODEL_FAMILY_MAP (:42). The catalog `family` + field derives from here. ModelFamily enum is BASELINE/TREE/ADDITIVE (lowercase .value). +- file: app/features/forecasting/schemas.py + why: ModelConfig union (the 11 flat members + their default params). Use to VERIFY default_params per model + (see Known Gotchas verification one-liner). ModelFamily enum lives here too (imported by feature_metadata). +- file: app/features/backtesting/schemas.py + why: SplitConfig (:24) — strategy Literal["expanding","sliding"] (def "expanding"), n_splits 2-20 (def 5), + min_train_size >=7 (def 30), gap 0-30 (def 0), horizon 1-90 (def 14), field_validator horizon>gap (:65). + The TS SplitConfig type + advanced form bounds mirror this exactly. + +# Frontend examples to MIRROR (verified 2026-06-01) +- file: frontend/src/pages/visualize/backtest.tsx + why: Canonical analytical page: Card sections, store/product Select fed by useStores/useProducts + ({page:1,pageSize:100}), DateRangePicker, numeric Inputs, a `formReady` gate, EmptyState/LoadingState, + getErrorMessage. Slice A's champion page mirrors this density (minus the results/charts). +- file: frontend/src/components/forecast-intelligence/model-type-select.tsx + why: shadcn Select-based model picker convention + data-testid pattern. The Slice-A candidate picker mirrors + the labelling style but sources options from GET /model-selection/models (NOT the hardcoded util). +- file: frontend/src/components/forecast-intelligence/model-type-utils.ts + why: The EXISTING hardcoded MODEL_FAMILY_MAP / MODEL_TYPE_LABELS used by OTHER pages. DO NOT refactor or + delete it in Slice A — other pages depend on it; the champion page just doesn't use it. +- file: frontend/src/components/forecast-intelligence/batch-matrix-picker.tsx + why: Multi-select-of-models pattern (checkbox list, max-rows cap, data-testid scheme, Badge for state). + The candidate-model picker mirrors this (checkbox per model, opt-in-extra Badge), but rows = model_types + from the catalog, no feature-frame matrix (that's B/C). +- file: frontend/src/components/forecast-intelligence/batch-matrix-picker.test.tsx + why: Component test convention — render + fireEvent + expect(onChange).toHaveBeenCalledWith; afterEach(cleanup). +- file: frontend/src/hooks/use-stores.ts + why: useStores({page,pageSize,...,search,enabled}) query-hook shape + keyed query + keepPreviousData. +- file: frontend/src/hooks/use-products.ts + why: useProducts(...) — identical shape; the searchable selects fetch at pageSize:100. +- file: frontend/src/hooks/use-batches.test.ts + why: Hook test convention — vi.fn() fetch mock via vi.stubGlobal('fetch',...), QueryClient wrapper, + renderHook + waitFor, afterEach(vi.unstubAllGlobals()). MIRROR for use-model-selection.test.ts. +- file: frontend/src/hooks/index.ts + why: Star-export barrel; ADD `export * from './use-model-selection'`. +- file: frontend/src/lib/api.ts + why: `api(endpoint,{params})` typed fetch helper; getErrorMessage(); ApiError. All hooks call `api`. +- file: frontend/src/lib/constants.ts + why: ROUTES (VISUALIZE.* block) + NAV_ITEMS (Visualize group). ADD ROUTES.VISUALIZE.CHAMPION + + a { label:'Champion Selector', href: ROUTES.VISUALIZE.CHAMPION } nav entry under Visualize. +- file: frontend/src/App.tsx + why: Lazy-page + }> pattern. ADD the + champion route mirroring the BATCH/PLANNER entries. +- file: frontend/src/types/api.ts + why: Section-commented type file. ModelFamily (:177 = 'baseline'|'tree'|'additive'), ProblemDetail (:652), + Store/StoreListResponse (:10/:21), Product/ProductListResponse (:25/:37). ADD a new + "// === Model Selection (Champion Selector) ===" section near the Registry block. +- file: frontend/src/components/common/error-display.tsx + why: EmptyState({title,description,action?,icon?}) — used for the not-enough-data state. +- file: frontend/src/components/common/loading-state.tsx + why: LoadingState({message}) — used while availability/catalog load. +- file: frontend/src/components/common/date-range-picker.tsx + why: DateRangePicker({value:DateRange|undefined,onChange}) — the time-period selector. +- file: frontend/src/components/ui/{select,popover,input,card,button,badge,checkbox,table}.tsx + why: Available shadcn primitives. NOTE: there is NO command/combobox/cmdk primitive — build the searchable + select from Popover + Input + a filtered button list (LOCKED #6). +- file: frontend/src/components/layout/top-nav.tsx + why: Renders NAV_ITEMS (grouped via NavigationMenu). No edit needed beyond the constants.ts NAV_ITEMS entry. +- file: frontend/vitest.config.ts + why: jsdom env; include 'src/**/*.test.{ts,tsx}'; `@`→./src alias. No setup file. `pnpm test --run` runs once. + +# External official docs (with reasoning) +- url: https://fastapi.tiangolo.com/tutorial/bigger-applications/#include-an-apirouter-with-a-custom-prefix-tags-responses-and-dependencies + why: APIRouter route-registration + the literal-before-path-param ordering rule that LOCKED #1 depends on. +- url: https://www.ibm.com/design/language/ # (progressive disclosure principle) + why: Simple/advanced settings split — show the recommended split config by default, reveal n_splits/min_train/ + gap/strategy under an "Advanced" toggle so novice users aren't overwhelmed. NOTE: the originally-cited + IBM technical-content URL 404s; use the IBM Design language site / Nielsen Norman + (https://www.nngroup.com/articles/progressive-disclosure/) as the canonical reference instead. +- url: https://help.tableau.com/current/pro/desktop/en-us/dashboards_best_practices.htm + why: Analytical dashboard layout — lead with the question (which model?), group related controls, keep the + availability triage adjacent to the selection. Informs the Card grouping of the champion page. +- url: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.TimeSeriesSplit.html + why: The split semantics behind SplitConfig (expanding window, n_splits, gap, horizon) — so the advanced + form's help text describes folds correctly. +- url: https://tanstack.com/query/latest/docs/framework/react/guides/queries + why: useQuery enabled-gating (only fetch availability once a valid pair exists) + queryKey conventions. +``` + +### Current Codebase Tree (relevant) + +```bash +app/features/model_selection/ # MERGED backend slice (issue #353) +├── __init__.py +├── models.py # ModelSelectionRun ORM (NOT touched in Slice A) +├── schemas.py # request/response contract ← ADD catalog response models +├── ranking.py # pure ranking (precedent for capabilities.py) +├── explanations.py # pure explanations (precedent) +├── service.py # ModelSelectionService ← ADD get_model_catalog() +├── routes.py # APIRouter(/model-selection) ← ADD GET /models (before /{selection_id}) +└── tests/ # ← ADD test_capabilities.py; extend test_routes.py +app/features/forecasting/feature_metadata.py # model_family_for() — catalog family authority +frontend/src/ +├── App.tsx # ← ADD lazy champion route +├── lib/{api,constants}.ts # ← constants: ROUTES.VISUALIZE.CHAMPION + NAV_ITEMS entry +├── types/api.ts # ← ADD "Model Selection" section +├── hooks/{use-stores,use-products,index}.ts # ← index: export use-model-selection +├── pages/visualize/{backtest,batch,...}.tsx # page-density precedent +└── components/ + ├── common/{error-display,loading-state,date-range-picker}.tsx + ├── ui/{select,popover,input,card,button,badge,checkbox,table}.tsx + └── forecast-intelligence/{model-type-select,batch-matrix-picker}.tsx # picker precedents +``` + +### Desired Codebase Tree (Slice A additions) + +```bash +# Backend +app/features/model_selection/capabilities.py # NEW: pure build_model_catalog() +app/features/model_selection/schemas.py # MODIFIED: + CandidateModelInfo, ModelCatalogResponse +app/features/model_selection/service.py # MODIFIED: + get_model_catalog() +app/features/model_selection/routes.py # MODIFIED: + GET /models (before /{selection_id}) +app/features/model_selection/tests/test_capabilities.py # NEW: pure catalog unit tests +app/features/model_selection/tests/test_routes.py # MODIFIED: + /models route + ordering tests + +# Frontend +frontend/src/lib/constants.ts # MODIFIED: ROUTES.VISUALIZE.CHAMPION + NAV_ITEMS entry +frontend/src/App.tsx # MODIFIED: lazy ChampionSelectorPage route +frontend/src/types/api.ts # MODIFIED: Model Selection section (full contract) +frontend/src/hooks/use-model-selection.ts # NEW: useModelCatalog + usePairAvailability +frontend/src/hooks/use-model-selection.test.ts # NEW +frontend/src/hooks/index.ts # MODIFIED: + export +frontend/src/pages/visualize/champion.tsx # NEW: the page shell +frontend/src/components/champion-selector/searchable-entity-select.tsx # NEW (generic combobox) +frontend/src/components/champion-selector/searchable-entity-select.test.tsx # NEW +frontend/src/components/champion-selector/availability-panel.tsx # NEW +frontend/src/components/champion-selector/availability-panel.test.tsx # NEW +frontend/src/components/champion-selector/backtest-settings-form.tsx # NEW +frontend/src/components/champion-selector/backtest-settings-form.test.tsx # NEW +frontend/src/components/champion-selector/candidate-model-picker.tsx # NEW +frontend/src/components/champion-selector/candidate-model-picker.test.tsx # NEW +frontend/src/components/champion-selector/copy.ts # NEW: BIAS_EXPLANATION const (LOCKED #7) +``` + +### Known Gotchas & VERIFIED Contracts + +```python +# ── ROUTE ORDERING (LOCKED #1) ──────────────────────────────────────────────── +# Starlette matches routes in DECLARATION ORDER. The literal `GET /models` MUST be declared BEFORE +# `GET /{selection_id}` or a request to /model-selection/models is captured as selection_id="models" +# and 404s in the service. The existing `/availability` route (routes.py:41) already sits before +# `/{selection_id}` (:94) — place `/models` immediately after `/availability`. + +# ── CATALOG default_params — VERIFY before hardcoding ───────────────────────── +# default_params per model must match the forecasting ModelConfig member defaults. Verify with: +# uv run python -c " +# from pydantic import TypeAdapter +# from app.features.forecasting.schemas import ModelConfig +# a=TypeAdapter(ModelConfig) +# for mt in ['naive','seasonal_naive','moving_average','weighted_moving_average','seasonal_average', +# 'trend_regression_baseline','regression','prophet_like','random_forest','lightgbm','xgboost']: +# try: +# m=a.validate_python({'model_type':mt}); d=m.model_dump(); d.pop('model_type',None) +# print(mt, d) +# except Exception as e: +# print(mt, 'NEEDS-PARAMS:', e)" +# Use the printed defaults as `default_params` in capabilities.py. If a member REQUIRES a param (validation +# error with only model_type), supply the contract default (seasonal_naive→{'season_length':7}, +# moving_average→{'window_size':7}) — match the backend PRP /run example. Pin these in test_capabilities.py. + +# ── feature_aware / requires_extra — VERIFY against forecasting predict() reject ── +# LOCKED #3 sets feature_aware = {regression, prophet_like, lightgbm, xgboost, random_forest}. Confirm this +# equals the set ForecastingService.predict() rejects (the backend PRP cites forecasting/service.py:491 +# "rejects feature-aware models"). If the live reject-set differs, the live code wins — update the +# capabilities set and the test to match, and note the discrepancy in the PR description. + +# ── family literal mapping ──────────────────────────────────────────────────── +# model_family_for(mt) returns a ModelFamily enum; serialize via `.value` → "baseline"|"tree"|"additive" +# which already matches the frontend ModelFamily TS union (types/api.ts:177). Import model_family_for +# LAZILY inside build_model_catalog() (mirror service.py lazy cross-slice imports). + +# ── NO new strict request model ─────────────────────────────────────────────── +# GET /models has no body and no query params → no ConfigDict(strict=True) model, no date fields → the +# strict-mode policy linter is unaffected. Do NOT add an AvailabilityQuery-style model for /models. + +# ── catalog is static/pure ───────────────────────────────────────────────────── +# build_model_catalog() takes no args and does no I/O — it is unit-testable like ranking.py. get_model_catalog() +# on the service is a thin pass-through (no db round-trip needed); keep it sync-pure or trivially async. +``` + +```typescript +// ── FRONTEND ──────────────────────────────────────────────────────────────── +// NO combobox/cmdk primitive exists (only select/popover/input/dialog under components/ui). Build the +// searchable select from + (filter box) + a scrollable list of