Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -126,5 +126,15 @@ BATCH_GLOBAL_MAX_PARALLEL=4
# mid-call, so a long fit can stall the drain.
BATCH_CANCEL_DRAIN_TIMEOUT_SECONDS=30

# Model selection (champion selector) async runner (Slice B)
# Hard upper bound on concurrent candidate backtests across all active selection
# runs on this host. Effective parallelism per run is min(this, candidates).
# Set to 1 for sequential execution. Requires uvicorn restart to apply.
MODEL_SELECTION_GLOBAL_MAX_PARALLEL=4
# Max seconds DELETE /model-selection/{id} waits for in-flight candidates to
# drain before returning an RFC 7807 problem response with HTTP status 504.
# sklearn / LightGBM fits are uncancellable mid-call, so a long fit can stall
# the drain.
MODEL_SELECTION_CANCEL_DRAIN_TIMEOUT_SECONDS=30

# Frontend (Vite)
VITE_API_BASE_URL=http://localhost:8123
716 changes: 716 additions & 0 deletions PRPs/forecast-champion-selector-slice-a-selection-capability.md

Large diffs are not rendered by default.

1,010 changes: 1,010 additions & 0 deletions PRPs/forecast-champion-selector-slice-b-async-comparison-results.md

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
"""add model_selection_candidate and async progress columns

Revision ID: d3e4f5a6b7c8
Revises: b667d321603c
Create Date: 2026-06-01 09:30:00.000000

Slice B of the Forecast Champion Selector (issue #360). Converts the selection
run into a DB-backed async LRO:

- creates ``model_selection_candidate`` (one row per candidate, FK CASCADE to
``model_selection_run.selection_id``) carrying per-candidate status, result
JSONB, error, and timing — the live-progress + audit surface;
- adds ``started_at`` + the four final count columns to ``model_selection_run``;
- widens the run status CheckConstraint to include ``'cancelled'`` (forward-only
drop + recreate of the named constraint).

Mirrors ``c1d2e3f40512_create_batch_tables`` for JSONB / index / FK style.
"""

from collections.abc import Sequence

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "d3e4f5a6b7c8"
down_revision: str | None = "b667d321603c"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None

_OLD_RUN_STATUS = "status IN ('pending', 'running', 'completed', 'partial', 'failed')"
_NEW_RUN_STATUS = (
"status IN ('pending', 'running', 'completed', 'partial', 'failed', 'cancelled')"
)


def upgrade() -> None:
    """Apply migration.

    Runs three steps in order: widen the run-status check constraint, add the
    progress columns to ``model_selection_run``, then create the
    ``model_selection_candidate`` table together with its indexes.
    """
    run_table = "model_selection_run"
    candidate_table = "model_selection_candidate"
    run_status_check = "ck_model_selection_run_valid_status"

    # --- Step 1: replace the named status check with the wider value set ---
    # (a check constraint's expression cannot be edited in place, so it is
    # dropped and re-created under the same name).
    op.drop_constraint(run_status_check, run_table, type_="check")
    op.create_check_constraint(run_status_check, run_table, _NEW_RUN_STATUS)

    # --- Step 2: additive progress columns on the parent run ---
    op.add_column(
        run_table,
        sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
    )
    # The counters are NOT NULL; the server default of "0" backfills rows that
    # already exist when the column is added.
    for counter_name in (
        "total_candidates",
        "completed_candidates",
        "failed_candidates",
        "cancelled_candidates",
    ):
        op.add_column(
            run_table,
            sa.Column(counter_name, sa.Integer(), nullable=False, server_default="0"),
        )

    # --- Step 3: per-candidate child table (FK CASCADE on selection_id) ---
    op.create_table(
        candidate_table,
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("candidate_id", sa.String(length=32), nullable=False),
        sa.Column("selection_id", sa.String(length=32), nullable=False),
        sa.Column("ordinal", sa.Integer(), nullable=False),
        sa.Column("model_type", sa.String(length=40), nullable=False),
        sa.Column("params", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
        sa.Column("status", sa.String(length=20), nullable=False),
        sa.Column("result", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        sa.Column("error_message", sa.String(length=2000), nullable=True),
        sa.Column("error_type", sa.String(length=100), nullable=True),
        sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("duration_ms", sa.Integer(), nullable=True),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.CheckConstraint(
            "status IN ('pending', 'running', 'completed', 'failed', 'cancelled')",
            name="ck_model_selection_candidate_valid_status",
        ),
        sa.ForeignKeyConstraint(
            ["selection_id"],
            ["model_selection_run.selection_id"],
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("id"),
    )

    # Single-column indexes, named through op.f(); only candidate_id is unique.
    for indexed_column, is_unique in (
        ("candidate_id", True),
        ("selection_id", False),
        ("status", False),
    ):
        op.create_index(
            op.f(f"ix_model_selection_candidate_{indexed_column}"),
            candidate_table,
            [indexed_column],
            unique=is_unique,
        )

    # Composite index over (selection_id, status), created with a literal name.
    op.create_index(
        "ix_model_selection_candidate_selection_status",
        candidate_table,
        ["selection_id", "status"],
        unique=False,
    )


def downgrade() -> None:
    """Revert migration.

    Undoes ``upgrade()`` in reverse order: drops the candidate indexes and
    table, drops the progress columns from ``model_selection_run``, then
    narrows the run-status check constraint back to the pre-migration set.

    Data note: the narrowed constraint does not allow ``'cancelled'``. Any run
    row still carrying that status is rewritten to ``'failed'`` first;
    otherwise re-adding the constraint would be rejected by the database and
    the downgrade would abort. This remapping is lossy and is not restored by
    a later ``upgrade()``.
    """
    op.drop_index(
        "ix_model_selection_candidate_selection_status",
        table_name="model_selection_candidate",
    )
    op.drop_index(
        op.f("ix_model_selection_candidate_status"),
        table_name="model_selection_candidate",
    )
    op.drop_index(
        op.f("ix_model_selection_candidate_selection_id"),
        table_name="model_selection_candidate",
    )
    op.drop_index(
        op.f("ix_model_selection_candidate_candidate_id"),
        table_name="model_selection_candidate",
    )
    op.drop_table("model_selection_candidate")

    op.drop_column("model_selection_run", "cancelled_candidates")
    op.drop_column("model_selection_run", "failed_candidates")
    op.drop_column("model_selection_run", "completed_candidates")
    op.drop_column("model_selection_run", "total_candidates")
    op.drop_column("model_selection_run", "started_at")

    op.drop_constraint(
        "ck_model_selection_run_valid_status",
        "model_selection_run",
        type_="check",
    )
    # A new CHECK constraint is validated against existing rows, so rows with
    # the status value introduced by this migration must be remapped to a
    # value the old constraint accepts before it is re-created.
    op.execute(
        "UPDATE model_selection_run SET status = 'failed' WHERE status = 'cancelled'"
    )
    op.create_check_constraint(
        "ck_model_selection_run_valid_status",
        "model_selection_run",
        _OLD_RUN_STATUS,
    )
11 changes: 11 additions & 0 deletions app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,17 @@ class Settings(BaseSettings):
# are uncancellable mid-call, so a long fit can stall the drain.
batch_cancel_drain_timeout_seconds: int = 30

# Model selection (champion selector) async runner (Slice B) — mirrors the
# batch runner. Hard upper bound on concurrent candidate backtests across
# all active selection runs on this host; sized for the same Postgres pool
# (pool_size=5, max_overflow=10). Setting this to 1 makes the runner
# sequential. Env override: MODEL_SELECTION_GLOBAL_MAX_PARALLEL=8 (restart).
model_selection_global_max_parallel: int = 4
# Max seconds DELETE /model-selection/{id} waits for in-flight candidates to
# settle before returning RFC 7807 504. In-flight sklearn/LightGBM fits are
# uncancellable mid-call, so a long fit can stall the drain.
model_selection_cancel_drain_timeout_seconds: int = 30

# RAG Embedding Configuration
rag_embedding_provider: Literal["openai", "ollama"] = "openai"
openai_api_key: str = ""
Expand Down
9 changes: 9 additions & 0 deletions app/core/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ def test_settings_has_defaults(monkeypatch):
assert settings.api_port == 8123


def test_model_selection_runner_defaults(monkeypatch):
    """Slice B async-runner settings fall back to the batch-mirrored defaults."""
    expected_defaults = {
        "model_selection_global_max_parallel": 4,
        "model_selection_cancel_drain_timeout_seconds": 30,
    }
    # Clear any ambient overrides; the env var is the upper-cased field name.
    for field_name in expected_defaults:
        monkeypatch.delenv(field_name.upper(), raising=False)

    # _env_file=None keeps a local .env file from supplying values.
    settings = Settings(_env_file=None)

    for field_name, expected_value in expected_defaults.items():
        assert getattr(settings, field_name) == expected_value


def test_settings_is_development_property():
"""is_development should return True for development env."""
settings = Settings(app_env="development")
Expand Down
87 changes: 83 additions & 4 deletions app/features/model_selection/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from enum import Enum
from typing import Any

from sqlalchemy import CheckConstraint, Date, DateTime, Index, Integer, String
from sqlalchemy import CheckConstraint, Date, DateTime, ForeignKey, Index, Integer, String
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column

Expand All @@ -27,17 +27,42 @@ class ModelSelectionStatus(str, Enum):
"""Lifecycle states of a selection run.

Transitions:
- PENDING -> RUNNING -> {COMPLETED, PARTIAL, FAILED}
- PARTIAL fires when >=1 candidate succeeded AND >=1 candidate failed.
- PENDING -> RUNNING -> {COMPLETED, PARTIAL, FAILED, CANCELLED}
- PARTIAL fires when >=1 candidate succeeded AND >=1 candidate failed/cancelled.
- FAILED fires when availability is unusable (fail-fast) OR every
candidate's backtest errored (no valid winner).
- CANCELLED (Slice B) fires when a cancel drained before any candidate
reached a non-cancelled terminal state.
"""

PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
PARTIAL = "partial"
FAILED = "failed"
CANCELLED = "cancelled"


# Run statuses with no outgoing transition (Slice B); the DELETE route uses
# this set to answer 409. Mirrors ``batch.models.TERMINAL_BATCH_STATES``.
TERMINAL_SELECTION_STATES: frozenset[str] = frozenset(
    terminal_state.value
    for terminal_state in (
        ModelSelectionStatus.COMPLETED,
        ModelSelectionStatus.PARTIAL,
        ModelSelectionStatus.FAILED,
        ModelSelectionStatus.CANCELLED,
    )
)


class CandidateStatus(str, Enum):
    """Execution state of a single candidate in an async selection run (Slice B).

    Members are ``str`` subclasses, so each one compares equal to its plain
    string value. The value set matches the ``status`` check constraint on
    ``model_selection_candidate``.
    """

    # Not yet finished.
    PENDING = "pending"
    RUNNING = "running"

    # Finished, one way or another.
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class ModelSelectionRun(TimestampMixin, Base):
Expand Down Expand Up @@ -74,13 +99,21 @@ class ModelSelectionRun(TimestampMixin, Base):
forecast_result: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True)
business_summary: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True)
error_message: Mapped[str | None] = mapped_column(String(2000), nullable=True)
# Slice B (async) — set when the run starts executing; the four count
# columns cache the FINAL per-status candidate tally written once at settle
# (live progress is derived from a GROUP BY over the child rows).
started_at: Mapped[_dt.datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
total_candidates: Mapped[int] = mapped_column(Integer, default=0, server_default="0")
completed_candidates: Mapped[int] = mapped_column(Integer, default=0, server_default="0")
failed_candidates: Mapped[int] = mapped_column(Integer, default=0, server_default="0")
cancelled_candidates: Mapped[int] = mapped_column(Integer, default=0, server_default="0")
completed_at: Mapped[_dt.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)

__table_args__ = (
CheckConstraint(
"status IN ('pending', 'running', 'completed', 'partial', 'failed')",
"status IN ('pending', 'running', 'completed', 'partial', 'failed', 'cancelled')",
name="ck_model_selection_run_valid_status",
),
Index(
Expand All @@ -91,3 +124,49 @@ class ModelSelectionRun(TimestampMixin, Base):
),
Index("ix_model_selection_run_status_created", "status", "created_at"),
)


class ModelSelectionCandidate(TimestampMixin, Base):
    """One candidate's async execution record inside a selection run (Slice B).

    Concurrent candidate tasks each write their OWN row in their OWN session —
    no shared-row write race. ``result`` carries the full ``CandidateResult``
    JSONB (incl. folds) on success; failed/cancelled candidates keep their row
    so they stay visible in the results UI. Mirrors ``batch.BatchJobItem``.

    Maps to the ``model_selection_candidate`` table. Rows reference their
    parent through ``selection_id`` with ``ON DELETE CASCADE``.
    """

    __tablename__ = "model_selection_candidate"

    # Integer primary key.
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # Candidate identifier; unique and indexed.
    candidate_id: Mapped[str] = mapped_column(String(32), unique=True, index=True)
    # Parent run reference; deleting the parent run cascades to this row.
    selection_id: Mapped[str] = mapped_column(
        String(32),
        ForeignKey("model_selection_run.selection_id", ondelete="CASCADE"),
        index=True,
    )
    ordinal: Mapped[int] = mapped_column(Integer)  # submit order — stable display
    model_type: Mapped[str] = mapped_column(String(40))
    # Required JSONB payload (explicitly NOT NULL).
    params: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False)
    # Python-side default is PENDING; allowed values are enforced by the
    # check constraint declared in __table_args__.
    status: Mapped[str] = mapped_column(
        String(20), default=CandidateStatus.PENDING.value, index=True
    )
    # Nullable outcome fields: result payload and error details.
    result: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True)
    error_message: Mapped[str | None] = mapped_column(String(2000), nullable=True)
    error_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
    # Nullable timing fields (timezone-aware datetimes plus a millisecond count).
    started_at: Mapped[_dt.datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    completed_at: Mapped[_dt.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    duration_ms: Mapped[int | None] = mapped_column(Integer, nullable=True)

    __table_args__ = (
        # Restricts status to the five CandidateStatus string values.
        CheckConstraint(
            "status IN ('pending', 'running', 'completed', 'failed', 'cancelled')",
            name="ck_model_selection_candidate_valid_status",
        ),
        # Composite index over (selection_id, status).
        Index(
            "ix_model_selection_candidate_selection_status",
            "selection_id",
            "status",
        ),
    )
Loading