diff --git a/INITIAL-7.md b/INITIAL-7.md index 1944df5d..fb55c919 100644 --- a/INITIAL-7.md +++ b/INITIAL-7.md @@ -12,6 +12,17 @@ - Artifact storage abstraction: - local filesystem by default (Settings-driven) - compatible with future S3-like storage backends +- Lifecycle Management: + - State machine tracking: PENDING | RUNNING | SUCCESS | FAILED | ARCHIVED. + - Deployment Aliases: Mutable pointers (e.g., 'prod-v1') to specific successful runs. +- Metadata & Lineage: + - JSONB storage for ModelConfig, FeatureConfig, and Performance Metrics. + - Runtime Snapshot: Recording Python/Library versions for environment parity. + - Agent Context: Integration of agent_id and session_id for autonomous run traceability. +- Artifact Integrity: + - Checksum-based verification (SHA-256) for all serialized artifacts. +- Storage Strategy: + - Pluggable storage providers (LocalFS, future S3/GCS) via Abstract Registry Interface. ## EXAMPLES: - `examples/registry/create_run.py` — create run record + persist configs. @@ -21,6 +32,8 @@ ## DOCUMENTATION: - Postgres JSONB patterns - Artifact integrity (hashing) best practices +- https://scalegrid.io/blog/using-jsonb-in-postgresql-how-to-effectively-store-index-json-data-in-postgresql/ +- https://www.fortra.com/blog/supply-chain-vulnerability ## OTHER CONSIDERATIONS: - No hardcoded artifact paths: derived from `ARTIFACT_ROOT` + run_id. diff --git a/PRPs/PRP-7-model-registry.md b/PRPs/PRP-7-model-registry.md new file mode 100644 index 00000000..d3ae2ab8 --- /dev/null +++ b/PRPs/PRP-7-model-registry.md @@ -0,0 +1,1253 @@ +# PRP-7: Model Registry + Artifacts + Reproducibility + +## Goal + +Implement a Model Registry feature that provides comprehensive run tracking, artifact management, and reproducibility guarantees for the ForecastOps platform. The registry captures full experiment lineage including configurations, metrics, data windows, and artifact integrity verification. + +**End State:** A production-ready `registry` vertical slice with: +- `ModelRun` database table with JSONB columns for flexible metadata storage +- `DeploymentAlias` table for mutable pointers (e.g., 'prod-v1') to successful runs +- Lifecycle state machine: PENDING | RUNNING | SUCCESS | FAILED | ARCHIVED +- SHA-256 checksum verification for artifact integrity +- Runtime environment snapshots (Python/library versions) +- Agent context tracking (agent_id, session_id) for autonomous run traceability +- Abstract storage provider interface (LocalFS default, future S3/GCS) +- RESTful API: create, list, get, update runs; manage aliases; compare runs +- All validation gates passing (ruff, mypy, pyright, pytest) + +--- + +## Why + +- **Reproducibility**: Every training run must be exactly reproducible via stored configs, data windows, and environment snapshots +- **Auditability**: Full lineage from data → features → model → predictions with agent context for autonomous workflows +- **Artifact Integrity**: SHA-256 checksums prevent corrupted or tampered model artifacts from being deployed +- **Deployment Safety**: Aliases provide stable references (e.g., 'production') that can be updated atomically +- **Leaderboard/Comparison**: Metrics storage enables model comparison and performance tracking over time +- **ForecastOps Integration**: Registry integrates with existing forecasting/backtesting modules for end-to-end workflows + +--- + +## What + +### User-Visible Behavior + +1. **Create Run**: Start a new model run with PENDING state, capture configs +2. **Update Run**: Transition states (RUNNING → SUCCESS/FAILED), attach metrics and artifact metadata +3. **List Runs**: Query runs with filtering by model_type, status, date range +4. **Get Run**: Retrieve full run details including configs, metrics, lineage +5. **Compare Runs**: Side-by-side comparison of two runs (configs + metrics diff) +6. **Manage Aliases**: Create/update deployment aliases pointing to successful runs +7. **Artifact Verification**: Validate artifact integrity via stored checksum + +### Success Criteria + +- [ ] ModelRun table created with JSONB columns for model_config, feature_config, metrics +- [ ] DeploymentAlias table created with unique constraint on (alias_name) +- [ ] Run lifecycle state machine enforced (valid transitions only) +- [ ] SHA-256 checksum computed and verified for all artifacts +- [ ] Python/library version snapshots stored per run +- [ ] Agent context (agent_id, session_id) stored for traceability +- [ ] AbstractStorageProvider interface with LocalFSProvider implementation +- [ ] 60+ unit tests covering models, schemas, service, storage, routes +- [ ] 10+ integration tests for database operations +- [ ] Example files demonstrating registry workflows + +--- + +## All Needed Context + +### Documentation & References + +```yaml +# MUST READ - Include these in your context window + +# SQLAlchemy JSONB with PostgreSQL +- url: https://docs.sqlalchemy.org/en/20/dialects/postgresql.html + why: "Official JSONB type usage, Mapped[] annotations" + critical: "Use JSONB from sqlalchemy.dialects.postgresql, not JSON" + +# JSONB Indexing Best Practices +- url: https://www.crunchydata.com/blog/indexing-jsonb-in-postgres + why: "GIN index patterns for JSONB columns" + critical: "Use @> containment operator for indexed queries" + +# JSONB Storage Patterns +- url: https://scalegrid.io/blog/using-jsonb-in-postgresql-how-to-effectively-store-index-json-data-in-postgresql/ + why: "Referenced in INITIAL-7.md for JSONB patterns" + critical: "JSONB stores binary format, faster queries than JSON" + +# MLflow Model Registry Design +- url: https://mlflow.org/docs/latest/ml/model-registry/ + why: "Industry-standard registry design patterns" + critical: "Separate metadata store from artifact store" + +# Internal Codebase References +- file: app/features/forecasting/persistence.py + why: "Existing ModelBundle with hash computation, version recording" + pattern: "compute_hash(), save_model_bundle(), load_model_bundle()" + +- file: app/features/forecasting/schemas.py + why: "Pattern for ModelConfig with config_hash(), frozen=True" + +- file: app/features/backtesting/schemas.py + why: "Pattern for complex nested configs, schema_version field" + +- file: app/features/backtesting/service.py + why: "Pattern for service orchestration with async DB operations" + +- file: app/features/data_platform/models.py + why: "Pattern for SQLAlchemy 2.0 Mapped[] models with TimestampMixin" + +- file: app/core/config.py + why: "Pattern for Settings with environment variables" + +- file: alembic/versions/e1165ebcef61_create_data_platform_tables.py + why: "Pattern for Alembic migrations" +``` + +### Current Codebase Tree (Relevant Parts) + +```text +app/ +├── core/ +│ ├── config.py # Settings singleton +│ ├── database.py # Base, AsyncSession, get_db +│ ├── exceptions.py # ForecastLabError hierarchy +│ └── logging.py # Structured logging +├── shared/ +│ └── models.py # TimestampMixin +├── features/ +│ ├── data_platform/ +│ │ └── models.py # SalesDaily, Store, Product, Calendar +│ ├── forecasting/ +│ │ ├── models.py # BaseForecaster, model_factory +│ │ ├── persistence.py # ModelBundle, save/load (HAS HASH!) +│ │ ├── schemas.py # ModelConfig, config_hash() +│ │ └── service.py # ForecastingService +│ └── backtesting/ +│ ├── schemas.py # BacktestConfig, SplitConfig +│ └── service.py # BacktestingService +└── main.py # FastAPI app with router registration +``` + +### Desired Codebase Tree + +```text +app/features/registry/ # NEW: Registry vertical slice +├── __init__.py # Module exports +├── models.py # ModelRun, DeploymentAlias ORM models +├── schemas.py # RunConfig, RunCreate, RunResponse, AliasResponse, etc. +├── storage.py # AbstractStorageProvider, LocalFSProvider +├── service.py # RegistryService (orchestration) +├── routes.py # CRUD routes + alias management + compare +└── tests/ + ├── __init__.py + ├── conftest.py # Fixtures: sample runs, configs + ├── test_models.py # ORM model tests + ├── test_schemas.py # Schema validation, immutability + ├── test_storage.py # Storage provider tests + ├── test_service.py # Service orchestration tests + ├── test_service_integration.py # Integration tests with DB + └── test_routes_integration.py # Route integration tests + +examples/registry/ # NEW: Example scripts +├── create_run.py # Create run record + persist configs +├── list_runs.py # Leaderboard preview +└── compare_runs.py # Compare two runs (metrics + configs) + +app/core/config.py # MODIFY: Add registry settings +app/main.py # MODIFY: Register registry router +alembic/versions/xxx_create_registry_tables.py # NEW: Migration +``` + +### Known Gotchas + +```python +# CRITICAL: SQLAlchemy JSONB requires PostgreSQL dialect import +from sqlalchemy.dialects.postgresql import JSONB +# NOT: from sqlalchemy import JSON (different type!) + +# CRITICAL: JSONB columns should use Mapped[dict[str, Any]] for typing +# SQLAlchemy 2.0 uses Mapped[] annotations +model_config: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False) + +# CRITICAL: For async queries with JSONB containment (@>), use: +from sqlalchemy.dialects.postgresql import JSONB +stmt = select(ModelRun).where(ModelRun.model_config.contains({"model_type": "naive"})) + +# CRITICAL: GIN index on JSONB for efficient containment queries +# Add in migration: op.create_index('ix_model_run_model_config_gin', 'model_run', ['model_config'], postgresql_using='gin') + +# CRITICAL: State transitions must be validated +# PENDING -> RUNNING -> SUCCESS|FAILED +# PENDING|RUNNING|SUCCESS|FAILED -> ARCHIVED +# No other transitions allowed + +# CRITICAL: Checksum verification before loading artifacts +# 1. Load stored checksum from DB +# 2. Compute checksum of artifact file +# 3. Compare - raise if mismatch + +# CRITICAL: artifact_uri is relative to REGISTRY_ARTIFACT_ROOT setting +# Never store absolute paths in DB - allows migration between environments + +# CRITICAL: Duplicate run detection uses config_hash + data_window_hash +# Policy is Settings-driven: allow/deny/detect + +# CRITICAL: Alias can only point to SUCCESS runs +# Attempting to alias a FAILED/ARCHIVED run should raise ValueError + +# CRITICAL: When comparing runs, use model_dump() for Pydantic serialization +# This handles nested objects and dates correctly + +# CRITICAL: We use Pydantic v2 - ConfigDict not Config class +model_config = ConfigDict(frozen=True, extra="forbid") +``` + +--- + +## Implementation Blueprint + +### Data Models (ORM) + +```python +# app/features/registry/models.py + +from __future__ import annotations + +import datetime +from decimal import Decimal +from enum import Enum +from typing import Any + +from sqlalchemy import ( + CheckConstraint, + DateTime, + ForeignKey, + Index, + Integer, + String, + UniqueConstraint, +) +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.core.database import Base +from app.shared.models import TimestampMixin + + +class RunStatus(str, Enum): + """Valid states for a model run.""" + PENDING = "pending" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + ARCHIVED = "archived" + + +class ModelRun(TimestampMixin, Base): + """Model run registry entry. + + CRITICAL: Captures full experiment lineage for reproducibility. + + Attributes: + id: Primary key. + run_id: Unique external identifier (UUID hex). + status: Current lifecycle state. + model_type: Type of model (naive, seasonal_naive, etc.). + model_config: Full model configuration as JSONB. + feature_config: Feature engineering config as JSONB (nullable). + data_window_start: Training data start date. + data_window_end: Training data end date. + store_id: Store ID for this run. + product_id: Product ID for this run. + metrics: Performance metrics as JSONB. + artifact_uri: Relative path to artifact (from ARTIFACT_ROOT). + artifact_hash: SHA-256 checksum of artifact. + artifact_size_bytes: Size of artifact file. + runtime_info: Python/library versions as JSONB. + agent_context: Agent ID and session ID for traceability. + git_sha: Optional git commit hash. + config_hash: Hash of model_config for deduplication. + error_message: Error details if status=FAILED. + started_at: When run started. + completed_at: When run completed (success or failed). + """ + + __tablename__ = "model_run" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + run_id: Mapped[str] = mapped_column(String(32), unique=True, index=True) + status: Mapped[str] = mapped_column(String(20), default=RunStatus.PENDING.value, index=True) + + # Model configuration + model_type: Mapped[str] = mapped_column(String(50), index=True) + model_config: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False) + feature_config: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + config_hash: Mapped[str] = mapped_column(String(16), index=True) + + # Data window + data_window_start: Mapped[datetime.date] = mapped_column() + data_window_end: Mapped[datetime.date] = mapped_column() + store_id: Mapped[int] = mapped_column(Integer, index=True) + product_id: Mapped[int] = mapped_column(Integer, index=True) + + # Metrics + metrics: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + + # Artifact info + artifact_uri: Mapped[str | None] = mapped_column(String(500), nullable=True) + artifact_hash: Mapped[str | None] = mapped_column(String(64), nullable=True) # SHA-256 + artifact_size_bytes: Mapped[int | None] = mapped_column(Integer, nullable=True) + + # Environment & lineage + runtime_info: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + agent_context: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + git_sha: Mapped[str | None] = mapped_column(String(40), nullable=True) + + # Error tracking + error_message: Mapped[str | None] = mapped_column(String(2000), nullable=True) + + # Timing + started_at: Mapped[datetime.datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + completed_at: Mapped[datetime.datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + + # Relationship to aliases + aliases: Mapped[list[DeploymentAlias]] = relationship(back_populates="run") + + __table_args__ = ( + # GIN index for JSONB containment queries + Index("ix_model_run_model_config_gin", "model_config", postgresql_using="gin"), + Index("ix_model_run_metrics_gin", "metrics", postgresql_using="gin"), + # Composite index for common query pattern + Index("ix_model_run_store_product", "store_id", "product_id"), + Index("ix_model_run_data_window", "data_window_start", "data_window_end"), + # Constraint: valid status values + CheckConstraint( + "status IN ('pending', 'running', 'success', 'failed', 'archived')", + name="ck_model_run_valid_status", + ), + # Constraint: data window validity + CheckConstraint( + "data_window_end >= data_window_start", + name="ck_model_run_valid_data_window", + ), + ) + + +class DeploymentAlias(TimestampMixin, Base): + """Mutable pointer to a specific successful run. + + CRITICAL: Aliases provide stable references for deployment. + + Attributes: + id: Primary key. + alias_name: Unique alias name (e.g., 'production', 'staging-v2'). + run_id: Foreign key to the aliased run. + description: Optional description of this alias. + """ + + __tablename__ = "deployment_alias" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + alias_name: Mapped[str] = mapped_column(String(100), unique=True, index=True) + run_id: Mapped[int] = mapped_column(Integer, ForeignKey("model_run.id"), index=True) + description: Mapped[str | None] = mapped_column(String(500), nullable=True) + + # Relationship + run: Mapped[ModelRun] = relationship(back_populates="aliases") + + __table_args__ = ( + UniqueConstraint("alias_name", name="uq_deployment_alias_name"), + ) +``` + +### Pydantic Schemas + +```python +# app/features/registry/schemas.py + +from __future__ import annotations + +import hashlib +from datetime import date as date_type, datetime +from enum import Enum +from typing import Any, Literal + +from pydantic import BaseModel, ConfigDict, Field, field_validator + + +class RunStatus(str, Enum): + """Run lifecycle states.""" + PENDING = "pending" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + ARCHIVED = "archived" + + +# Valid state transitions +VALID_TRANSITIONS: dict[RunStatus, set[RunStatus]] = { + RunStatus.PENDING: {RunStatus.RUNNING, RunStatus.ARCHIVED}, + RunStatus.RUNNING: {RunStatus.SUCCESS, RunStatus.FAILED, RunStatus.ARCHIVED}, + RunStatus.SUCCESS: {RunStatus.ARCHIVED}, + RunStatus.FAILED: {RunStatus.ARCHIVED}, + RunStatus.ARCHIVED: set(), # Terminal state +} + + +class RuntimeInfo(BaseModel): + """Runtime environment snapshot.""" + model_config = ConfigDict(frozen=True, extra="forbid") + + python_version: str + sklearn_version: str | None = None + numpy_version: str | None = None + pandas_version: str | None = None + joblib_version: str | None = None + + +class AgentContext(BaseModel): + """Agent context for autonomous run traceability.""" + model_config = ConfigDict(frozen=True, extra="forbid") + + agent_id: str | None = None + session_id: str | None = None + + +class RunCreate(BaseModel): + """Request to create a new run.""" + model_config = ConfigDict(extra="forbid") + + model_type: str = Field(..., min_length=1, max_length=50) + model_config_data: dict[str, Any] = Field(..., alias="model_config") + feature_config: dict[str, Any] | None = None + data_window_start: date_type + data_window_end: date_type + store_id: int = Field(..., ge=1) + product_id: int = Field(..., ge=1) + agent_context: AgentContext | None = None + git_sha: str | None = Field(None, max_length=40) + + @field_validator("data_window_end") + @classmethod + def validate_data_window(cls, v: date_type, info: object) -> date_type: + """Ensure data_window_end >= data_window_start.""" + data = getattr(info, "data", {}) + if "data_window_start" in data and v < data["data_window_start"]: + raise ValueError("data_window_end must be >= data_window_start") + return v + + +class RunUpdate(BaseModel): + """Request to update a run.""" + model_config = ConfigDict(extra="forbid") + + status: RunStatus | None = None + metrics: dict[str, Any] | None = None + artifact_uri: str | None = None + artifact_hash: str | None = None + artifact_size_bytes: int | None = Field(None, ge=0) + error_message: str | None = Field(None, max_length=2000) + + +class RunResponse(BaseModel): + """Run details response.""" + model_config = ConfigDict(from_attributes=True) + + run_id: str + status: RunStatus + model_type: str + model_config_data: dict[str, Any] = Field(..., alias="model_config") + feature_config: dict[str, Any] | None = None + config_hash: str + data_window_start: date_type + data_window_end: date_type + store_id: int + product_id: int + metrics: dict[str, Any] | None = None + artifact_uri: str | None = None + artifact_hash: str | None = None + artifact_size_bytes: int | None = None + runtime_info: dict[str, Any] | None = None + agent_context: dict[str, Any] | None = None + git_sha: str | None = None + error_message: str | None = None + started_at: datetime | None = None + completed_at: datetime | None = None + created_at: datetime + updated_at: datetime + + +class RunListResponse(BaseModel): + """Paginated list of runs.""" + runs: list[RunResponse] + total: int + page: int + page_size: int + + +class AliasCreate(BaseModel): + """Request to create/update an alias.""" + model_config = ConfigDict(extra="forbid") + + alias_name: str = Field(..., min_length=1, max_length=100, pattern=r"^[a-z0-9][a-z0-9-_]*$") + run_id: str + description: str | None = Field(None, max_length=500) + + +class AliasResponse(BaseModel): + """Alias details response.""" + model_config = ConfigDict(from_attributes=True) + + alias_name: str + run_id: str + run_status: RunStatus + model_type: str + description: str | None = None + created_at: datetime + updated_at: datetime + + +class RunCompareResponse(BaseModel): + """Comparison of two runs.""" + run_a: RunResponse + run_b: RunResponse + config_diff: dict[str, Any] # Keys that differ + metrics_diff: dict[str, dict[str, float | None]] # {metric: {a: val, b: val, diff: val}} +``` + +### Storage Provider (Abstract) + +```python +# app/features/registry/storage.py + +from __future__ import annotations + +import hashlib +import shutil +from abc import ABC, abstractmethod +from pathlib import Path +from typing import BinaryIO + +import structlog + +from app.core.config import get_settings + +logger = structlog.get_logger() + + +class StorageError(Exception): + """Base exception for storage operations.""" + pass + + +class ArtifactNotFoundError(StorageError): + """Artifact not found at specified URI.""" + pass + + +class ChecksumMismatchError(StorageError): + """Artifact checksum does not match stored value.""" + pass + + +class AbstractStorageProvider(ABC): + """Abstract base class for artifact storage. + + CRITICAL: All storage providers must implement these methods. + This allows future S3/GCS implementations. + """ + + @abstractmethod + def save(self, source_path: Path, artifact_uri: str) -> tuple[str, int]: + """Save an artifact to storage. + + Args: + source_path: Local path to artifact file. + artifact_uri: Relative URI for storage. + + Returns: + Tuple of (sha256_hash, size_bytes). + + Raises: + StorageError: If save fails. + """ + pass + + @abstractmethod + def load(self, artifact_uri: str, expected_hash: str | None = None) -> Path: + """Load an artifact from storage. + + Args: + artifact_uri: Relative URI of artifact. + expected_hash: If provided, verify checksum. + + Returns: + Path to artifact (may be temp file for remote storage). + + Raises: + ArtifactNotFoundError: If artifact doesn't exist. + ChecksumMismatchError: If hash verification fails. + """ + pass + + @abstractmethod + def delete(self, artifact_uri: str) -> bool: + """Delete an artifact from storage. + + Args: + artifact_uri: Relative URI of artifact. + + Returns: + True if deleted, False if not found. + """ + pass + + @abstractmethod + def exists(self, artifact_uri: str) -> bool: + """Check if an artifact exists. + + Args: + artifact_uri: Relative URI of artifact. + + Returns: + True if exists, False otherwise. + """ + pass + + @staticmethod + def compute_hash(file_path: Path) -> str: + """Compute SHA-256 hash of a file. + + Args: + file_path: Path to file. + + Returns: + Hexadecimal SHA-256 hash. + """ + sha256 = hashlib.sha256() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(8192), b""): + sha256.update(chunk) + return sha256.hexdigest() + + +class LocalFSProvider(AbstractStorageProvider): + """Local filesystem storage provider. + + CRITICAL: Default provider for development and single-node deployments. + """ + + def __init__(self, root_dir: Path | None = None) -> None: + """Initialize with root directory. + + Args: + root_dir: Root directory for artifacts. Defaults to Settings value. + """ + if root_dir is None: + settings = get_settings() + root_dir = Path(settings.registry_artifact_root) + self.root_dir = root_dir.resolve() + self.root_dir.mkdir(parents=True, exist_ok=True) + + def _resolve_path(self, artifact_uri: str) -> Path: + """Resolve artifact URI to full path. + + CRITICAL: Validates path is within root to prevent traversal. + """ + full_path = (self.root_dir / artifact_uri).resolve() + # Security: ensure path is within root + try: + full_path.relative_to(self.root_dir) + except ValueError: + raise StorageError(f"Path traversal attempt: {artifact_uri}") from None + return full_path + + def save(self, source_path: Path, artifact_uri: str) -> tuple[str, int]: + """Save artifact to local filesystem.""" + dest_path = self._resolve_path(artifact_uri) + dest_path.parent.mkdir(parents=True, exist_ok=True) + + # Compute hash before copy + file_hash = self.compute_hash(source_path) + file_size = source_path.stat().st_size + + # Copy file + shutil.copy2(source_path, dest_path) + + logger.info( + "registry.artifact_saved", + artifact_uri=artifact_uri, + hash=file_hash, + size_bytes=file_size, + ) + + return file_hash, file_size + + def load(self, artifact_uri: str, expected_hash: str | None = None) -> Path: + """Load artifact from local filesystem.""" + full_path = self._resolve_path(artifact_uri) + + if not full_path.exists(): + raise ArtifactNotFoundError(f"Artifact not found: {artifact_uri}") + + # Verify hash if provided + if expected_hash is not None: + actual_hash = self.compute_hash(full_path) + if actual_hash != expected_hash: + logger.warning( + "registry.checksum_mismatch", + artifact_uri=artifact_uri, + expected=expected_hash, + actual=actual_hash, + ) + raise ChecksumMismatchError( + f"Checksum mismatch for {artifact_uri}: " + f"expected {expected_hash}, got {actual_hash}" + ) + + return full_path + + def delete(self, artifact_uri: str) -> bool: + """Delete artifact from local filesystem.""" + full_path = self._resolve_path(artifact_uri) + + if not full_path.exists(): + return False + + full_path.unlink() + logger.info("registry.artifact_deleted", artifact_uri=artifact_uri) + return True + + def exists(self, artifact_uri: str) -> bool: + """Check if artifact exists on local filesystem.""" + full_path = self._resolve_path(artifact_uri) + return full_path.exists() +``` + +--- + +## Task List + +### Task 1: Add registry settings to config + +```yaml +FILE: app/core/config.py +ACTION: MODIFY +FIND: "backtest_results_dir: str = './artifacts/backtests'" +INJECT AFTER: + - "" + - "# Registry" + - "registry_artifact_root: str = './artifacts/registry'" + - "registry_duplicate_policy: Literal['allow', 'deny', 'detect'] = 'detect'" +VALIDATION: + - uv run mypy app/core/config.py + - uv run pyright app/core/config.py +``` + +### Task 2: Create registry module structure + +```yaml +ACTION: CREATE directories and __init__.py +FILES: + - app/features/registry/__init__.py + - app/features/registry/tests/__init__.py +PATTERN: Mirror backtesting module exports +``` + +### Task 3: Implement models.py (ORM) + +```yaml +FILE: app/features/registry/models.py +ACTION: CREATE +IMPLEMENT: + - RunStatus enum (PENDING, RUNNING, SUCCESS, FAILED, ARCHIVED) + - ModelRun model with JSONB columns + - DeploymentAlias model + - GIN indexes for JSONB columns + - Constraints for valid status, data window +PATTERN: Mirror app/features/data_platform/models.py +CRITICAL: + - Use JSONB from sqlalchemy.dialects.postgresql + - Use Mapped[dict[str, Any]] for JSONB typing + - Add GIN indexes in __table_args__ +VALIDATION: + - uv run mypy app/features/registry/models.py + - uv run pyright app/features/registry/models.py +``` + +### Task 4: Create Alembic migration + +```yaml +FILE: alembic/versions/xxx_create_registry_tables.py +ACTION: CREATE (via alembic revision) +COMMAND: uv run alembic revision --autogenerate -m "create_registry_tables" +IMPLEMENT: + - Create model_run table with JSONB columns + - Create deployment_alias table + - Add GIN indexes for model_config and metrics + - Add composite indexes + - Add check constraints +VALIDATION: + - uv run alembic upgrade head + - uv run alembic downgrade -1 + - uv run alembic upgrade head +``` + +### Task 5: Implement schemas.py + +```yaml +FILE: app/features/registry/schemas.py +ACTION: CREATE +IMPLEMENT: + - RunStatus enum (must match ORM enum) + - VALID_TRANSITIONS dict for state machine + - RuntimeInfo schema + - AgentContext schema + - RunCreate, RunUpdate, RunResponse schemas + - RunListResponse for pagination + - AliasCreate, AliasResponse schemas + - RunCompareResponse schema +PATTERN: Mirror app/features/backtesting/schemas.py +CRITICAL: + - Use ConfigDict(frozen=True) for immutable configs + - Use alias="model_config" for field naming conflict + - Validate data_window_end >= data_window_start +VALIDATION: + - uv run mypy app/features/registry/schemas.py + - uv run pyright app/features/registry/schemas.py +``` + +### Task 6: Implement storage.py + +```yaml +FILE: app/features/registry/storage.py +ACTION: CREATE +IMPLEMENT: + - StorageError, ArtifactNotFoundError, ChecksumMismatchError exceptions + - AbstractStorageProvider ABC + - LocalFSProvider implementation + - compute_hash static method (SHA-256) + - Path traversal prevention +CRITICAL: + - Always validate paths are within root_dir + - Compute hash BEFORE copy for save() + - Verify hash in load() if expected_hash provided +VALIDATION: + - uv run mypy app/features/registry/storage.py + - uv run pyright app/features/registry/storage.py +``` + +### Task 7: Implement service.py + +```yaml +FILE: app/features/registry/service.py +ACTION: CREATE +IMPLEMENT: + - RegistryService class + - create_run() - Create new run with PENDING status + - get_run() - Get run by run_id + - list_runs() - List with filtering and pagination + - update_run() - Update status, metrics, artifact info + - _validate_transition() - Validate state transitions + - _compute_config_hash() - Hash for deduplication + - _capture_runtime_info() - Python/library versions + - create_alias() - Create/update deployment alias + - get_alias() - Get alias by name + - list_aliases() - List all aliases + - delete_alias() - Remove alias + - compare_runs() - Compare two runs +PATTERN: Mirror app/features/backtesting/service.py +CRITICAL: + - State transitions must follow VALID_TRANSITIONS + - config_hash computed from model_config JSON + - Alias can only point to SUCCESS runs + - Duplicate detection uses config_hash + data_window +VALIDATION: + - uv run mypy app/features/registry/service.py + - uv run pyright app/features/registry/service.py +``` + +### Task 8: Implement routes.py + +```yaml +FILE: app/features/registry/routes.py +ACTION: CREATE +IMPLEMENT: + - APIRouter(prefix="/registry", tags=["registry"]) + - POST /runs - Create new run + - GET /runs - List runs with filters (model_type, status, store_id, product_id) + - GET /runs/{run_id} - Get run details + - PATCH /runs/{run_id} - Update run + - GET /runs/{run_id}/verify - Verify artifact integrity + - POST /aliases - Create/update alias + - GET /aliases - List all aliases + - GET /aliases/{alias_name} - Get alias details + - DELETE /aliases/{alias_name} - Delete alias + - GET /compare/{run_id_a}/{run_id_b} - Compare two runs +PATTERN: Mirror app/features/forecasting/routes.py +CRITICAL: + - Use Depends(get_db) for database session + - Structured logging: registry.run_created, registry.run_updated, etc. + - Return 404 for not found, 400 for invalid transitions + - Return 409 for duplicate if policy='deny' +VALIDATION: + - uv run mypy app/features/registry/routes.py + - uv run pyright app/features/registry/routes.py +``` + +### Task 9: Register router in main.py + +```yaml +FILE: app/main.py +ACTION: MODIFY +FIND: "from app.features.backtesting.routes import router as backtesting_router" +INJECT AFTER: + - "from app.features.registry.routes import router as registry_router" +FIND: "app.include_router(backtesting_router)" +INJECT AFTER: + - "app.include_router(registry_router)" +VALIDATION: + - uv run python -c "from app.main import app; print('OK')" +``` + +### Task 10: Create test fixtures (conftest.py) + +```yaml +FILE: app/features/registry/tests/conftest.py +ACTION: CREATE +IMPLEMENT: + - sample_model_config: NaiveModelConfig as dict + - sample_run_create: RunCreate with valid data + - sample_runtime_info: RuntimeInfo with current versions + - sample_agent_context: AgentContext with test IDs + - db_session fixture for integration tests + - client fixture for route tests + - temp_artifact: Temporary artifact file for storage tests +PATTERN: Mirror app/features/backtesting/tests/conftest.py +``` + +### Task 11: Create test_models.py + +```yaml +FILE: app/features/registry/tests/test_models.py +ACTION: CREATE +IMPLEMENT: + - Test ModelRun creation with JSONB columns + - Test DeploymentAlias creation and FK relationship + - Test run_id uniqueness constraint + - Test alias_name uniqueness constraint + - Test data_window constraint validation + - Test status enum values +VALIDATION: + - uv run pytest app/features/registry/tests/test_models.py -v +``` + +### Task 12: Create test_schemas.py + +```yaml +FILE: app/features/registry/tests/test_schemas.py +ACTION: CREATE +IMPLEMENT: + - Test RunStatus enum values + - Test VALID_TRANSITIONS correctness + - Test RunCreate validation (date range, model_type) + - Test RunUpdate partial updates + - Test RunResponse from_attributes + - Test AliasCreate pattern validation + - Test config_hash determinism +VALIDATION: + - uv run pytest app/features/registry/tests/test_schemas.py -v +``` + +### Task 13: Create test_storage.py + +```yaml +FILE: app/features/registry/tests/test_storage.py +ACTION: CREATE +IMPLEMENT: + - Test LocalFSProvider.save() creates file and returns hash + - Test LocalFSProvider.load() returns correct path + - Test LocalFSProvider.load() with hash verification + - Test ChecksumMismatchError on bad hash + - Test ArtifactNotFoundError on missing file + - Test path traversal prevention + - Test delete() removes file + - Test exists() returns correct boolean +VALIDATION: + - uv run pytest app/features/registry/tests/test_storage.py -v +``` + +### Task 14: Create test_service.py + +```yaml +FILE: app/features/registry/tests/test_service.py +ACTION: CREATE +IMPLEMENT: + - Test create_run() with valid data + - Test create_run() computes config_hash + - Test create_run() captures runtime_info + - Test update_run() state transitions + - Test update_run() rejects invalid transitions + - Test list_runs() filtering + - Test list_runs() pagination + - Test create_alias() with SUCCESS run + - Test create_alias() rejects non-SUCCESS run + - Test compare_runs() returns correct diff + - Test duplicate detection (when policy='detect') +VALIDATION: + - uv run pytest app/features/registry/tests/test_service.py -v +``` + +### Task 15: Create test_service_integration.py + +```yaml +FILE: app/features/registry/tests/test_service_integration.py +ACTION: CREATE +IMPLEMENT: + - Test full run lifecycle: PENDING -> RUNNING -> SUCCESS + - Test alias creation and update + - Test run listing with database + - Test JSONB containment queries + - Test GIN index usage (via EXPLAIN) +PATTERN: Mirror app/features/backtesting/tests/test_service_integration.py +VALIDATION: + - uv run pytest app/features/registry/tests/test_service_integration.py -v -m integration +``` + +### Task 16: Create test_routes_integration.py + +```yaml +FILE: app/features/registry/tests/test_routes_integration.py +ACTION: CREATE +IMPLEMENT: + - Test POST /registry/runs creates run + - Test GET /registry/runs returns list + - Test GET /registry/runs/{run_id} returns details + - Test PATCH /registry/runs/{run_id} updates status + - Test POST /registry/aliases creates alias + - Test GET /registry/aliases returns list + - Test GET /registry/compare/{a}/{b} returns diff + - Test 404 for non-existent run + - Test 400 for invalid state transition +VALIDATION: + - uv run pytest app/features/registry/tests/test_routes_integration.py -v -m integration +``` + +### Task 17: Create example files + +```yaml +FILES: + - examples/registry/create_run.py + - examples/registry/list_runs.py + - examples/registry/compare_runs.py +ACTION: CREATE +IMPLEMENT: + - create_run.py: Create run, transition to SUCCESS, attach metrics + - list_runs.py: List runs with filtering, show leaderboard + - compare_runs.py: Compare two runs, show config/metrics diff +``` + +### Task 18: Update module __init__.py exports + +```yaml +FILE: app/features/registry/__init__.py +ACTION: MODIFY +IMPLEMENT: + - Export all public classes + - __all__ list (sorted alphabetically) +VALIDATION: + - uv run python -c "from app.features.registry import *; print('OK')" +``` + +--- + +## Validation Loop + +### Level 1: Syntax & Style + +```bash +# Run after EACH file creation +uv run ruff check app/features/registry/ --fix +uv run ruff format app/features/registry/ + +# Expected: All checks passed! +``` + +### Level 2: Type Checking + +```bash +# Run after completing models, schemas, storage, service +uv run mypy app/features/registry/ +uv run pyright app/features/registry/ + +# Expected: Success: no issues found +``` + +### Level 3: Database Migration + +```bash +# After creating models.py, generate and run migration +uv run alembic revision --autogenerate -m "create_registry_tables" +uv run alembic upgrade head + +# Verify tables exist +docker exec -it postgres psql -U forecastlab -d forecastlab -c "\d model_run" +docker exec -it postgres psql -U forecastlab -d forecastlab -c "\d deployment_alias" +``` + +### Level 4: Unit Tests + +```bash +# Run incrementally as tests are created +uv run pytest app/features/registry/tests/test_schemas.py -v +uv run pytest app/features/registry/tests/test_storage.py -v +uv run pytest app/features/registry/tests/test_service.py -v + +# Run all unit tests +uv run pytest app/features/registry/tests/ -v -m "not integration" + +# Expected: 60+ tests passed +``` + +### Level 5: Integration Tests + +```bash +# Start database +docker-compose up -d + +# Run integration tests +uv run pytest app/features/registry/tests/test_service_integration.py -v -m integration +uv run pytest app/features/registry/tests/test_routes_integration.py -v -m integration + +# Expected: 10+ integration tests passed +``` + +### Level 6: API Integration Test + +```bash +# Start API +uv run uvicorn app.main:app --reload --port 8123 + +# Create a run +curl -X POST http://localhost:8123/registry/runs \ + -H "Content-Type: application/json" \ + -d '{ + "model_type": "naive", + "model_config": {"model_type": "naive", "schema_version": "1.0"}, + "data_window_start": "2024-01-01", + "data_window_end": "2024-06-30", + "store_id": 1, + "product_id": 1 + }' + +# List runs +curl http://localhost:8123/registry/runs + +# Update run status +curl -X PATCH http://localhost:8123/registry/runs/{run_id} \ + -H "Content-Type: application/json" \ + -d '{"status": "running"}' + +# Complete run with metrics +curl -X PATCH http://localhost:8123/registry/runs/{run_id} \ + -H "Content-Type: application/json" \ + -d '{ + "status": "success", + "metrics": {"mae": 1.5, "smape": 12.3} + }' + +# Create alias +curl -X POST http://localhost:8123/registry/aliases \ + -H "Content-Type: application/json" \ + -d '{ + "alias_name": "production", + "run_id": "{run_id}", + "description": "Current production model" + }' +``` + +### Level 7: Full Validation + +```bash +# Complete validation suite +uv run ruff check app/features/registry/ && \ +uv run mypy app/features/registry/ && \ +uv run pyright app/features/registry/ && \ +uv run pytest app/features/registry/tests/ -v + +# Expected: All green +``` + +--- + +## Final Checklist + +- [ ] All 18 tasks completed +- [ ] `uv run ruff check .` — no errors +- [ ] `uv run mypy app/features/registry/` — no errors +- [ ] `uv run pyright app/features/registry/` — no errors +- [ ] `uv run pytest app/features/registry/tests/ -v` — 60+ tests passed +- [ ] Alembic migration runs successfully +- [ ] GIN indexes created for JSONB columns +- [ ] Example scripts run successfully +- [ ] Router registered in main.py +- [ ] Settings added to config.py +- [ ] Logging events follow standard format +- [ ] State machine transitions validated +- [ ] Checksum verification works +- [ ] Alias only points to SUCCESS runs +- [ ] Duplicate detection works per policy + +--- + +## Anti-Patterns to Avoid + +- **DON'T** use JSON instead of JSONB — JSONB is faster for queries +- **DON'T** store absolute paths in artifact_uri — use relative paths +- **DON'T** skip state transition validation — corrupts run lifecycle +- **DON'T** allow aliases to non-SUCCESS runs — undefined behavior in production +- **DON'T** skip checksum verification on load — security risk +- **DON'T** use plain index on JSONB — use GIN for containment queries +- **DON'T** forget to compute config_hash — needed for deduplication +- **DON'T** hardcode storage paths — use Settings +- **DON'T** catch generic Exception — be specific about error types +- **DON'T** use sync operations in async context — will block event loop + +--- + +## Confidence Score: 8/10 + +**Strengths:** +- Clear patterns from forecasting and backtesting modules to follow +- Existing ModelBundle in persistence.py has hash computation pattern +- Well-documented SQLAlchemy JSONB support +- Comprehensive task breakdown with validation gates +- MLflow provides industry-standard registry design reference +- Strong test patterns from backtesting module + +**Risks:** +- JSONB GIN indexing may require tuning for large datasets +- State machine transitions add complexity +- Alias update atomicity needs careful handling +- Integration with existing forecasting module needs coordination +- Duplicate detection edge cases (same config, different data windows) + +**Mitigation:** +- Start with simple GIN index, optimize later if needed +- Use explicit transition validation function +- Use database transactions for alias updates +- Add integration tests covering forecasting → registry flow +- Define clear duplicate policy (config_hash + data_window_hash) + +--- + +## Sources + +- [SQLAlchemy PostgreSQL JSONB](https://docs.sqlalchemy.org/en/20/dialects/postgresql.html) +- [JSONB Indexing in Postgres](https://www.crunchydata.com/blog/indexing-jsonb-in-postgres) +- [JSONB Storage Patterns](https://scalegrid.io/blog/using-jsonb-in-postgresql-how-to-effectively-store-index-json-data-in-postgresql/) +- [MLflow Model Registry](https://mlflow.org/docs/latest/ml/model-registry/) +- [PostgreSQL GIN Indexes](https://www.postgresql.org/docs/current/gin.html) diff --git a/README.md b/README.md index 39f1f957..44203682 100644 --- a/README.md +++ b/README.md @@ -118,7 +118,8 @@ app/ │ ├── ingest/ # Batch upsert endpoints for sales data │ ├── featuresets/ # Time-safe feature engineering (lags, rolling, calendar) │ ├── forecasting/ # Model training, prediction, persistence -│ └── backtesting/ # Time-series CV, metrics, baseline comparisons +│ ├── backtesting/ # Time-series CV, metrics, baseline comparisons +│ └── registry/ # Model run tracking, artifacts, deployment aliases └── main.py # FastAPI entry point tests/ # Test fixtures and helpers @@ -129,7 +130,8 @@ examples/ ├── queries/ # Example SQL queries ├── models/ # Baseline model examples (naive, seasonal_naive, moving_average) ├── backtest/ # Backtesting examples (run_backtest, inspect_splits, metrics_demo) -└── compute_features_demo.py # Feature engineering demo +├── compute_features_demo.py # Feature engineering demo +└── registry_demo.py # Model registry workflow demo scripts/ # Utility scripts ``` @@ -301,6 +303,46 @@ When `include_baselines=true`, automatically compares against naive and seasonal See [examples/backtest/](examples/backtest/) for usage examples. +### Model Registry + +- `POST /registry/runs` - Create a new model run +- `GET /registry/runs` - List runs with filtering and pagination +- `GET /registry/runs/{run_id}` - Get run details +- `PATCH /registry/runs/{run_id}` - Update run (status, metrics, artifacts) +- `GET /registry/runs/{run_id}/verify` - Verify artifact integrity +- `POST /registry/aliases` - Create or update deployment alias +- `GET /registry/aliases` - List all aliases +- `GET /registry/aliases/{alias_name}` - Get alias details +- `DELETE /registry/aliases/{alias_name}` - Delete an alias +- `GET /registry/compare/{run_id_a}/{run_id_b}` - Compare two runs + +**Example Create Run Request:** +```bash +curl -X POST http://localhost:8123/registry/runs \ + -H "Content-Type: application/json" \ + -d '{ + "model_type": "seasonal_naive", + "model_config": {"season_length": 7}, + "data_window_start": "2024-01-01", + "data_window_end": "2024-03-31", + "store_id": 1, + "product_id": 1 + }' +``` + +**Run Lifecycle:** +- `pending` → `running` → `success` | `failed` → `archived` +- Aliases can only point to runs with `success` status + +**Features:** +- JSONB storage for model_config, metrics, runtime_info +- SHA-256 artifact integrity verification +- Duplicate detection (configurable: allow/deny/detect) +- Runtime environment capture (Python, numpy, pandas versions) +- Agent context tracking for autonomous workflows + +See [examples/registry_demo.py](examples/registry_demo.py) for a complete workflow demo. + ## API Documentation Once the server is running: diff --git a/alembic/env.py b/alembic/env.py index fa61e07e..38e3e935 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -13,6 +13,7 @@ # Import all models for Alembic autogenerate detection from app.features.data_platform import models as data_platform_models # noqa: F401 +from app.features.registry import models as registry_models # noqa: F401 # Alembic Config object config = context.config diff --git a/alembic/versions/a2f7b3c8d901_create_model_registry_tables.py b/alembic/versions/a2f7b3c8d901_create_model_registry_tables.py new file mode 100644 index 00000000..2ca6c805 --- /dev/null +++ b/alembic/versions/a2f7b3c8d901_create_model_registry_tables.py @@ -0,0 +1,173 @@ +"""create_model_registry_tables + +Revision ID: a2f7b3c8d901 +Revises: e1165ebcef61 +Create Date: 2026-02-01 10:00:00.000000 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "a2f7b3c8d901" +down_revision: Union[str, None] = "e1165ebcef61" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Apply migration - create model_run and deployment_alias tables.""" + # Create model_run table + op.create_table( + "model_run", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("run_id", sa.String(length=32), nullable=False), + sa.Column("status", sa.String(length=20), nullable=False, server_default="pending"), + # Model configuration + sa.Column("model_type", sa.String(length=50), nullable=False), + sa.Column("model_config", postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column("feature_config", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("config_hash", sa.String(length=16), nullable=False), + # Data window + sa.Column("data_window_start", sa.Date(), nullable=False), + sa.Column("data_window_end", sa.Date(), nullable=False), + sa.Column("store_id", sa.Integer(), nullable=False), + sa.Column("product_id", sa.Integer(), nullable=False), + # Metrics + sa.Column("metrics", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + # Artifact info + sa.Column("artifact_uri", sa.String(length=500), nullable=True), + sa.Column("artifact_hash", sa.String(length=64), nullable=True), + sa.Column("artifact_size_bytes", sa.Integer(), nullable=True), + # Environment & lineage + sa.Column("runtime_info", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("agent_context", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("git_sha", sa.String(length=40), nullable=True), + # Error tracking + sa.Column("error_message", sa.String(length=2000), nullable=True), + # Timing + sa.Column("started_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True), + # Timestamps (from TimestampMixin) + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + # Constraints + sa.PrimaryKeyConstraint("id"), + sa.CheckConstraint( + "status IN ('pending', 'running', 'success', 'failed', 'archived')", + name="ck_model_run_valid_status", + ), + sa.CheckConstraint( + "data_window_end >= data_window_start", + name="ck_model_run_valid_data_window", + ), + ) + + # Create indexes for model_run + op.create_index(op.f("ix_model_run_run_id"), "model_run", ["run_id"], unique=True) + op.create_index(op.f("ix_model_run_status"), "model_run", ["status"], unique=False) + op.create_index(op.f("ix_model_run_model_type"), "model_run", ["model_type"], unique=False) + op.create_index(op.f("ix_model_run_config_hash"), "model_run", ["config_hash"], unique=False) + op.create_index(op.f("ix_model_run_store_id"), "model_run", ["store_id"], unique=False) + op.create_index(op.f("ix_model_run_product_id"), "model_run", ["product_id"], unique=False) + + # Composite indexes + op.create_index( + "ix_model_run_store_product", "model_run", ["store_id", "product_id"], unique=False + ) + op.create_index( + "ix_model_run_data_window", + "model_run", + ["data_window_start", "data_window_end"], + unique=False, + ) + + # GIN indexes for JSONB containment queries + op.create_index( + "ix_model_run_model_config_gin", + "model_run", + ["model_config"], + unique=False, + postgresql_using="gin", + ) + op.create_index( + "ix_model_run_metrics_gin", + "model_run", + ["metrics"], + unique=False, + postgresql_using="gin", + ) + + # Create deployment_alias table + op.create_table( + "deployment_alias", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("alias_name", sa.String(length=100), nullable=False), + sa.Column("run_id", sa.Integer(), nullable=False), + sa.Column("description", sa.String(length=500), nullable=True), + # Timestamps (from TimestampMixin) + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + # Constraints + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["run_id"], ["model_run.id"]), + sa.UniqueConstraint("alias_name", name="uq_deployment_alias_name"), + ) + + # Create indexes for deployment_alias + op.create_index( + op.f("ix_deployment_alias_alias_name"), + "deployment_alias", + ["alias_name"], + unique=True, + ) + op.create_index( + op.f("ix_deployment_alias_run_id"), "deployment_alias", ["run_id"], unique=False + ) + + +def downgrade() -> None: + """Revert migration - drop model_run and deployment_alias tables.""" + # Drop deployment_alias table and indexes + op.drop_index(op.f("ix_deployment_alias_run_id"), table_name="deployment_alias") + op.drop_index(op.f("ix_deployment_alias_alias_name"), table_name="deployment_alias") + op.drop_table("deployment_alias") + + # Drop model_run indexes + op.drop_index("ix_model_run_metrics_gin", table_name="model_run") + op.drop_index("ix_model_run_model_config_gin", table_name="model_run") + op.drop_index("ix_model_run_data_window", table_name="model_run") + op.drop_index("ix_model_run_store_product", table_name="model_run") + op.drop_index(op.f("ix_model_run_product_id"), table_name="model_run") + op.drop_index(op.f("ix_model_run_store_id"), table_name="model_run") + op.drop_index(op.f("ix_model_run_config_hash"), table_name="model_run") + op.drop_index(op.f("ix_model_run_model_type"), table_name="model_run") + op.drop_index(op.f("ix_model_run_status"), table_name="model_run") + op.drop_index(op.f("ix_model_run_run_id"), table_name="model_run") + + # Drop model_run table + op.drop_table("model_run") diff --git a/app/core/config.py b/app/core/config.py index 39c81f1d..808e0d9b 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -53,6 +53,10 @@ class Settings(BaseSettings): backtest_max_gap: int = 30 backtest_results_dir: str = "./artifacts/backtests" + # Registry + registry_artifact_root: str = "./artifacts/registry" + registry_duplicate_policy: Literal["allow", "deny", "detect"] = "detect" + @property def is_development(self) -> bool: """Check if running in development mode.""" diff --git a/app/features/backtesting/tests/test_schemas.py b/app/features/backtesting/tests/test_schemas.py index 97c56fc3..31eec119 100644 --- a/app/features/backtesting/tests/test_schemas.py +++ b/app/features/backtesting/tests/test_schemas.py @@ -93,7 +93,7 @@ def test_frozen_config(self): """Test SplitConfig is immutable.""" config = SplitConfig() with pytest.raises(ValidationError): - config.n_splits = 10 + config.n_splits = 10 # type: ignore[misc] class TestBacktestConfig: @@ -136,7 +136,7 @@ def test_frozen_config(self): """Test BacktestConfig is immutable.""" config = BacktestConfig(model_config_main=NaiveModelConfig()) with pytest.raises(ValidationError): - config.include_baselines = False + config.include_baselines = False # type: ignore[misc] def test_invalid_schema_version(self): """Test invalid schema_version raises error.""" diff --git a/app/features/data_platform/tests/conftest.py b/app/features/data_platform/tests/conftest.py index 7b366631..494b3359 100644 --- a/app/features/data_platform/tests/conftest.py +++ b/app/features/data_platform/tests/conftest.py @@ -6,31 +6,36 @@ pytest behavior to allow feature tests to be self-contained. """ +from contextlib import suppress from datetime import date from decimal import Decimal import pytest +from sqlalchemy import delete from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from app.core.config import get_settings -from app.core.database import Base -from app.features.data_platform.models import Calendar, Product, Store +from app.features.data_platform.models import ( + Calendar, + InventorySnapshotDaily, + PriceHistory, + Product, + Promotion, + SalesDaily, + Store, +) @pytest.fixture async def db_session(): """Create async database session for integration tests. - This fixture creates all tables, provides a session, and cleans up after. - Requires PostgreSQL to be running (docker-compose up -d). + Uses existing tables from migrations. Cleans up test data after each test. + Requires PostgreSQL to be running (docker-compose up -d) and migrations applied. """ settings = get_settings() engine = create_async_engine(settings.database_url, echo=False) - # Create tables - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - # Create session async_session_maker = async_sessionmaker( engine, @@ -42,11 +47,27 @@ async def db_session(): try: yield session finally: - await session.rollback() - - # Cleanup: drop all tables - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.drop_all) + # Rollback any pending transaction first (required if test caused an error) + with suppress(Exception): + await session.rollback() + + # Use a fresh session for cleanup to avoid transaction state issues + async with async_session_maker() as cleanup_session: + with suppress(Exception): + # Clean up test data (delete in correct order due to FK constraints) + await cleanup_session.execute(delete(SalesDaily)) + await cleanup_session.execute(delete(InventorySnapshotDaily)) + await cleanup_session.execute(delete(PriceHistory)) + await cleanup_session.execute(delete(Promotion)) + await cleanup_session.execute(delete(Product).where(Product.sku.like("SKU-TEST%"))) + await cleanup_session.execute(delete(Product).where(Product.sku.like("TEST-%"))) + await cleanup_session.execute(delete(Store).where(Store.code.like("TEST%"))) + await cleanup_session.execute( + delete(Calendar).where( + (Calendar.date >= date(2024, 1, 1)) & (Calendar.date <= date(2024, 12, 31)) + ) + ) + await cleanup_session.commit() await engine.dispose() diff --git a/app/features/featuresets/tests/test_schemas.py b/app/features/featuresets/tests/test_schemas.py index 4f9a3840..1988e38c 100644 --- a/app/features/featuresets/tests/test_schemas.py +++ b/app/features/featuresets/tests/test_schemas.py @@ -202,7 +202,7 @@ def test_config_is_frozen(self): """Config should be immutable (frozen).""" config = FeatureSetConfig(name="test") with pytest.raises(ValidationError): - config.name = "modified" + config.name = "modified" # type: ignore[misc] def test_rejects_empty_name(self): """Empty name should be rejected.""" diff --git a/app/features/forecasting/tests/test_schemas.py b/app/features/forecasting/tests/test_schemas.py index cb559e62..7663201d 100644 --- a/app/features/forecasting/tests/test_schemas.py +++ b/app/features/forecasting/tests/test_schemas.py @@ -31,7 +31,7 @@ def test_frozen_immutability(self): """Test that config is immutable (frozen=True).""" config = NaiveModelConfig() with pytest.raises(ValidationError): - config.model_type = "other" # type: ignore[assignment] + config.model_type = "other" # type: ignore[misc,assignment] def test_config_hash_determinism(self): """Test that config_hash is deterministic.""" @@ -98,7 +98,7 @@ def test_frozen_immutability(self): """Test that config is immutable.""" config = MovingAverageModelConfig() with pytest.raises(ValidationError): - config.window_size = 14 + config.window_size = 14 # type: ignore[misc] class TestLightGBMModelConfig: diff --git a/app/features/ingest/tests/test_routes.py b/app/features/ingest/tests/test_routes.py index 6facf362..ed1f9249 100644 --- a/app/features/ingest/tests/test_routes.py +++ b/app/features/ingest/tests/test_routes.py @@ -3,16 +3,16 @@ These tests require a running PostgreSQL database (docker-compose up -d). """ +from contextlib import suppress from datetime import date from decimal import Decimal import pytest from httpx import ASGITransport, AsyncClient -from sqlalchemy import select +from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from app.core.config import get_settings -from app.core.database import Base from app.features.data_platform.models import Calendar, Product, SalesDaily, Store from app.main import app @@ -21,16 +21,12 @@ async def db_session(): """Create async database session for integration tests. - Creates all tables, provides a session, and cleans up after. - Requires PostgreSQL to be running (docker-compose up -d). + Uses existing tables from migrations. Cleans up test data after each test. + Requires PostgreSQL to be running (docker-compose up -d) and migrations applied. """ settings = get_settings() engine = create_async_engine(settings.database_url, echo=False) - # Create tables - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - # Create session async_session_maker = async_sessionmaker( engine, @@ -42,11 +38,23 @@ async def db_session(): try: yield session finally: - await session.rollback() - - # Cleanup: drop all tables - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.drop_all) + # Rollback any pending transaction first + with suppress(Exception): + await session.rollback() + + # Use a fresh session for cleanup to avoid transaction state issues + async with async_session_maker() as cleanup_session: + with suppress(Exception): + # Clean up test data (delete in correct order due to FK constraints) + await cleanup_session.execute(delete(SalesDaily)) + await cleanup_session.execute(delete(Product).where(Product.sku.like("SKU-%"))) + await cleanup_session.execute(delete(Store).where(Store.code.like("S00%"))) + await cleanup_session.execute( + delete(Calendar).where( + (Calendar.date >= date(2024, 1, 1)) & (Calendar.date <= date(2024, 12, 31)) + ) + ) + await cleanup_session.commit() await engine.dispose() diff --git a/app/features/registry/__init__.py b/app/features/registry/__init__.py new file mode 100644 index 00000000..ea0743af --- /dev/null +++ b/app/features/registry/__init__.py @@ -0,0 +1,47 @@ +"""Model Registry feature for tracking runs, artifacts, and deployments.""" + +from app.features.registry.models import DeploymentAlias, ModelRun, RunStatus +from app.features.registry.schemas import ( + VALID_TRANSITIONS, + AgentContext, + AliasCreate, + AliasResponse, + RunCompareResponse, + RunCreate, + RunListResponse, + RunResponse, + RuntimeInfo, + RunUpdate, +) +from app.features.registry.schemas import RunStatus as RunStatusSchema +from app.features.registry.service import RegistryService +from app.features.registry.storage import ( + AbstractStorageProvider, + ArtifactNotFoundError, + ChecksumMismatchError, + LocalFSProvider, + StorageError, +) + +__all__ = [ + "VALID_TRANSITIONS", + "AbstractStorageProvider", + "AgentContext", + "AliasCreate", + "AliasResponse", + "ArtifactNotFoundError", + "ChecksumMismatchError", + "DeploymentAlias", + "LocalFSProvider", + "ModelRun", + "RegistryService", + "RunCompareResponse", + "RunCreate", + "RunListResponse", + "RunResponse", + "RunStatus", + "RunStatusSchema", + "RunUpdate", + "RuntimeInfo", + "StorageError", +] diff --git a/app/features/registry/models.py b/app/features/registry/models.py new file mode 100644 index 00000000..248a803e --- /dev/null +++ b/app/features/registry/models.py @@ -0,0 +1,167 @@ +"""Model registry ORM models for tracking runs and deployments. + +This module defines: +- ModelRun: Registry entry for each model training run +- DeploymentAlias: Mutable pointers to successful runs + +CRITICAL: Uses PostgreSQL JSONB for flexible metadata storage. +""" + +from __future__ import annotations + +import datetime +from enum import Enum +from typing import TYPE_CHECKING, Any + +from sqlalchemy import ( + CheckConstraint, + Date, + DateTime, + ForeignKey, + Index, + Integer, + String, + UniqueConstraint, +) +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.core.database import Base +from app.shared.models import TimestampMixin + +if TYPE_CHECKING: + pass + + +class RunStatus(str, Enum): + """Valid states for a model run. + + State transitions: + - PENDING -> RUNNING -> SUCCESS | FAILED + - Any state except ARCHIVED -> ARCHIVED + """ + + PENDING = "pending" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + ARCHIVED = "archived" + + +class ModelRun(TimestampMixin, Base): + """Model run registry entry. + + CRITICAL: Captures full experiment lineage for reproducibility. + + Attributes: + id: Primary key. + run_id: Unique external identifier (UUID hex, 32 chars). + status: Current lifecycle state. + model_type: Type of model (naive, seasonal_naive, etc.). + model_config: Full model configuration as JSONB. + feature_config: Feature engineering config as JSONB (nullable). + data_window_start: Training data start date. + data_window_end: Training data end date. + store_id: Store ID for this run. + product_id: Product ID for this run. + metrics: Performance metrics as JSONB. + artifact_uri: Relative path to artifact (from ARTIFACT_ROOT). + artifact_hash: SHA-256 checksum of artifact. + artifact_size_bytes: Size of artifact file. + runtime_info: Python/library versions as JSONB. + agent_context: Agent ID and session ID for traceability. + git_sha: Optional git commit hash. + config_hash: Hash of model_config for deduplication. + error_message: Error details if status=FAILED. + started_at: When run started. + completed_at: When run completed (success or failed). + """ + + __tablename__ = "model_run" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + run_id: Mapped[str] = mapped_column(String(32), unique=True, index=True) + status: Mapped[str] = mapped_column(String(20), default=RunStatus.PENDING.value, index=True) + + # Model configuration + model_type: Mapped[str] = mapped_column(String(50), index=True) + model_config: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False) + feature_config: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + config_hash: Mapped[str] = mapped_column(String(16), index=True) + + # Data window + data_window_start: Mapped[datetime.date] = mapped_column(Date) + data_window_end: Mapped[datetime.date] = mapped_column(Date) + store_id: Mapped[int] = mapped_column(Integer, index=True) + product_id: Mapped[int] = mapped_column(Integer, index=True) + + # Metrics + metrics: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + + # Artifact info + artifact_uri: Mapped[str | None] = mapped_column(String(500), nullable=True) + artifact_hash: Mapped[str | None] = mapped_column(String(64), nullable=True) # SHA-256 + artifact_size_bytes: Mapped[int | None] = mapped_column(Integer, nullable=True) + + # Environment & lineage + runtime_info: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + agent_context: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + git_sha: Mapped[str | None] = mapped_column(String(40), nullable=True) + + # Error tracking + error_message: Mapped[str | None] = mapped_column(String(2000), nullable=True) + + # Timing + started_at: Mapped[datetime.datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) + completed_at: Mapped[datetime.datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) + + # Relationship to aliases + aliases: Mapped[list[DeploymentAlias]] = relationship(back_populates="run") + + __table_args__ = ( + # GIN index for JSONB containment queries + Index("ix_model_run_model_config_gin", "model_config", postgresql_using="gin"), + Index("ix_model_run_metrics_gin", "metrics", postgresql_using="gin"), + # Composite index for common query pattern + Index("ix_model_run_store_product", "store_id", "product_id"), + Index("ix_model_run_data_window", "data_window_start", "data_window_end"), + # Constraint: valid status values + CheckConstraint( + "status IN ('pending', 'running', 'success', 'failed', 'archived')", + name="ck_model_run_valid_status", + ), + # Constraint: data window validity + CheckConstraint( + "data_window_end >= data_window_start", + name="ck_model_run_valid_data_window", + ), + ) + + +class DeploymentAlias(TimestampMixin, Base): + """Mutable pointer to a specific successful run. + + CRITICAL: Aliases provide stable references for deployment. + + Attributes: + id: Primary key. + alias_name: Unique alias name (e.g., 'production', 'staging-v2'). + run_id: Foreign key to the aliased run (internal ID). + description: Optional description of this alias. + """ + + __tablename__ = "deployment_alias" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + alias_name: Mapped[str] = mapped_column(String(100), unique=True, index=True) + run_id: Mapped[int] = mapped_column(Integer, ForeignKey("model_run.id"), index=True) + description: Mapped[str | None] = mapped_column(String(500), nullable=True) + + # Relationship + run: Mapped[ModelRun] = relationship(back_populates="aliases") + + __table_args__ = (UniqueConstraint("alias_name", name="uq_deployment_alias_name"),) diff --git a/app/features/registry/routes.py b/app/features/registry/routes.py new file mode 100644 index 00000000..b173bf29 --- /dev/null +++ b/app/features/registry/routes.py @@ -0,0 +1,600 @@ +"""Registry API routes for model runs and deployment aliases.""" + +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_db +from app.core.exceptions import DatabaseError +from app.core.logging import get_logger +from app.features.registry.schemas import ( + AliasCreate, + AliasResponse, + RunCompareResponse, + RunCreate, + RunListResponse, + RunResponse, + RunStatus, + RunUpdate, +) +from app.features.registry.service import ( + DuplicateRunError, + InvalidTransitionError, + RegistryService, +) +from app.features.registry.storage import ( + ArtifactNotFoundError, + ChecksumMismatchError, + LocalFSProvider, +) + +logger = get_logger(__name__) + +router = APIRouter(prefix="/registry", tags=["registry"]) + + +# ============================================================================= +# Run Endpoints +# ============================================================================= + + +@router.post( + "/runs", + response_model=RunResponse, + status_code=status.HTTP_201_CREATED, + summary="Create a new model run", + description=""" +Create a new model run with PENDING status. + +**Required Fields:** +- `model_type`: Type of model (e.g., 'naive', 'seasonal_naive') +- `model_config`: Full model configuration as JSON +- `data_window_start`: Start date of training data +- `data_window_end`: End date of training data +- `store_id`: Store ID for this run +- `product_id`: Product ID for this run + +**Optional Fields:** +- `feature_config`: Feature engineering configuration +- `agent_context`: Agent ID and session ID for traceability +- `git_sha`: Git commit hash + +**Duplicate Detection:** +Based on `registry_duplicate_policy` setting: +- `allow`: Always create new runs +- `deny`: Reject if duplicate config+window exists +- `detect`: Log warning but allow creation +""", +) +async def create_run( + request: RunCreate, + db: AsyncSession = Depends(get_db), +) -> RunResponse: + """Create a new model run. + + Args: + request: Run creation request. + db: Async database session from dependency. + + Returns: + Created run details. + + Raises: + HTTPException: If duplicate detected with 'deny' policy. + DatabaseError: If database operation fails. + """ + logger.info( + "registry.create_run_request_received", + model_type=request.model_type, + store_id=request.store_id, + product_id=request.product_id, + ) + + service = RegistryService() + + try: + response = await service.create_run(db=db, run_data=request) + + logger.info( + "registry.create_run_request_completed", + run_id=response.run_id, + config_hash=response.config_hash, + ) + + return response + + except DuplicateRunError as e: + logger.warning( + "registry.create_run_request_failed", + error=str(e), + error_type=type(e).__name__, + ) + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=str(e), + ) from e + except SQLAlchemyError as e: + logger.error( + "registry.create_run_request_failed", + error=str(e), + error_type=type(e).__name__, + exc_info=True, + ) + raise DatabaseError( + message="Failed to create run", + details={"error": str(e)}, + ) from e + + +@router.get( + "/runs", + response_model=RunListResponse, + summary="List model runs", + description=""" +List model runs with optional filtering and pagination. + +**Filters:** +- `model_type`: Filter by model type +- `status`: Filter by run status +- `store_id`: Filter by store ID +- `product_id`: Filter by product ID + +**Pagination:** +- `page`: Page number (1-indexed, default: 1) +- `page_size`: Runs per page (default: 20, max: 100) +""", +) +async def list_runs( + db: AsyncSession = Depends(get_db), + page: int = Query(1, ge=1, description="Page number"), + page_size: int = Query(20, ge=1, le=100, description="Runs per page"), + model_type: str | None = Query(None, description="Filter by model type"), + run_status: RunStatus | None = Query(None, alias="status", description="Filter by status"), + store_id: int | None = Query(None, ge=1, description="Filter by store ID"), + product_id: int | None = Query(None, ge=1, description="Filter by product ID"), +) -> RunListResponse: + """List model runs with filtering and pagination. + + Args: + db: Async database session from dependency. + page: Page number (1-indexed). + page_size: Number of runs per page. + model_type: Filter by model type. + run_status: Filter by status. + store_id: Filter by store ID. + product_id: Filter by product ID. + + Returns: + Paginated list of runs. + """ + service = RegistryService() + + response = await service.list_runs( + db=db, + page=page, + page_size=page_size, + model_type=model_type, + status=run_status, + store_id=store_id, + product_id=product_id, + ) + + return response + + +@router.get( + "/runs/{run_id}", + response_model=RunResponse, + summary="Get run details", + description="Get full details for a specific model run by its run_id.", +) +async def get_run( + run_id: str, + db: AsyncSession = Depends(get_db), +) -> RunResponse: + """Get run details by run_id. + + Args: + run_id: Run identifier. + db: Async database session from dependency. + + Returns: + Run details. + + Raises: + HTTPException: If run not found. + """ + service = RegistryService() + + response = await service.get_run(db=db, run_id=run_id) + + if response is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Run not found: {run_id}", + ) + + return response + + +@router.patch( + "/runs/{run_id}", + response_model=RunResponse, + summary="Update a run", + description=""" +Update a model run's status, metrics, or artifact information. + +**Status Transitions:** +- `pending` → `running` | `archived` +- `running` → `success` | `failed` | `archived` +- `success` → `archived` +- `failed` → `archived` +- `archived` → (terminal, no transitions) + +**Updatable Fields:** +- `status`: New status (must be valid transition) +- `metrics`: Performance metrics (JSON) +- `artifact_uri`: Relative path to artifact +- `artifact_hash`: SHA-256 checksum +- `artifact_size_bytes`: Artifact file size +- `error_message`: Error details (for FAILED runs) +""", +) +async def update_run( + run_id: str, + request: RunUpdate, + db: AsyncSession = Depends(get_db), +) -> RunResponse: + """Update a model run. + + Args: + run_id: Run identifier. + request: Update request with fields to change. + db: Async database session from dependency. + + Returns: + Updated run details. + + Raises: + HTTPException: If run not found or invalid status transition. + """ + logger.info( + "registry.update_run_request_received", + run_id=run_id, + new_status=request.status.value if request.status else None, + ) + + service = RegistryService() + + try: + response = await service.update_run(db=db, run_id=run_id, update_data=request) + + if response is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Run not found: {run_id}", + ) + + logger.info( + "registry.update_run_request_completed", + run_id=run_id, + status=response.status.value, + ) + + return response + + except InvalidTransitionError as e: + logger.warning( + "registry.update_run_request_failed", + run_id=run_id, + error=str(e), + error_type=type(e).__name__, + ) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) from e + except SQLAlchemyError as e: + logger.error( + "registry.update_run_request_failed", + run_id=run_id, + error=str(e), + error_type=type(e).__name__, + exc_info=True, + ) + raise DatabaseError( + message="Failed to update run", + details={"error": str(e)}, + ) from e + + +@router.get( + "/runs/{run_id}/verify", + response_model=dict[str, bool | str], + summary="Verify artifact integrity", + description=""" +Verify that the artifact for a run matches its stored checksum. + +Returns verification status and computed hash. +""", +) +async def verify_artifact( + run_id: str, + db: AsyncSession = Depends(get_db), +) -> dict[str, bool | str]: + """Verify artifact integrity for a run. + + Args: + run_id: Run identifier. + db: Async database session from dependency. + + Returns: + Verification result with computed hash. + + Raises: + HTTPException: If run not found or artifact missing. + """ + service = RegistryService() + run = await service.get_run(db=db, run_id=run_id) + + if run is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Run not found: {run_id}", + ) + + if run.artifact_uri is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Run has no associated artifact", + ) + + storage = LocalFSProvider() + + try: + path = storage.load(run.artifact_uri, expected_hash=run.artifact_hash) + actual_hash = storage.compute_hash(path) + + return { + "verified": True, + "run_id": run_id, + "artifact_uri": run.artifact_uri, + "stored_hash": run.artifact_hash or "", + "computed_hash": actual_hash, + } + + except ArtifactNotFoundError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(e), + ) from e + except ChecksumMismatchError as e: + return { + "verified": False, + "run_id": run_id, + "artifact_uri": run.artifact_uri, + "error": str(e), + } + + +# ============================================================================= +# Alias Endpoints +# ============================================================================= + + +@router.post( + "/aliases", + response_model=AliasResponse, + status_code=status.HTTP_201_CREATED, + summary="Create or update an alias", + description=""" +Create or update a deployment alias pointing to a successful run. + +**Alias Names:** +- Must start with lowercase letter or number +- Can contain lowercase letters, numbers, hyphens, and underscores +- Maximum 100 characters + +**IMPORTANT:** Aliases can only point to runs with SUCCESS status. +""", +) +async def create_alias( + request: AliasCreate, + db: AsyncSession = Depends(get_db), +) -> AliasResponse: + """Create or update a deployment alias. + + Args: + request: Alias creation request. + db: Async database session from dependency. + + Returns: + Created/updated alias details. + + Raises: + HTTPException: If run not found or not in SUCCESS status. + """ + logger.info( + "registry.create_alias_request_received", + alias_name=request.alias_name, + run_id=request.run_id, + ) + + service = RegistryService() + + try: + response = await service.create_alias(db=db, alias_data=request) + + logger.info( + "registry.create_alias_request_completed", + alias_name=request.alias_name, + run_id=response.run_id, + ) + + return response + + except ValueError as e: + logger.warning( + "registry.create_alias_request_failed", + alias_name=request.alias_name, + error=str(e), + error_type=type(e).__name__, + ) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) from e + except SQLAlchemyError as e: + logger.error( + "registry.create_alias_request_failed", + alias_name=request.alias_name, + error=str(e), + error_type=type(e).__name__, + exc_info=True, + ) + raise DatabaseError( + message="Failed to create alias", + details={"error": str(e)}, + ) from e + + +@router.get( + "/aliases", + response_model=list[AliasResponse], + summary="List all aliases", + description="List all deployment aliases sorted by name.", +) +async def list_aliases( + db: AsyncSession = Depends(get_db), +) -> list[AliasResponse]: + """List all deployment aliases. + + Args: + db: Async database session from dependency. + + Returns: + List of aliases. + """ + service = RegistryService() + return await service.list_aliases(db=db) + + +@router.get( + "/aliases/{alias_name}", + response_model=AliasResponse, + summary="Get alias details", + description="Get details for a specific deployment alias.", +) +async def get_alias( + alias_name: str, + db: AsyncSession = Depends(get_db), +) -> AliasResponse: + """Get alias details by name. + + Args: + alias_name: Alias name. + db: Async database session from dependency. + + Returns: + Alias details. + + Raises: + HTTPException: If alias not found. + """ + service = RegistryService() + response = await service.get_alias(db=db, alias_name=alias_name) + + if response is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Alias not found: {alias_name}", + ) + + return response + + +@router.delete( + "/aliases/{alias_name}", + status_code=status.HTTP_204_NO_CONTENT, + summary="Delete an alias", + description="Delete a deployment alias.", +) +async def delete_alias( + alias_name: str, + db: AsyncSession = Depends(get_db), +) -> None: + """Delete a deployment alias. + + Args: + alias_name: Alias name. + db: Async database session from dependency. + + Raises: + HTTPException: If alias not found. + """ + logger.info( + "registry.delete_alias_request_received", + alias_name=alias_name, + ) + + service = RegistryService() + deleted = await service.delete_alias(db=db, alias_name=alias_name) + + if not deleted: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Alias not found: {alias_name}", + ) + + logger.info( + "registry.delete_alias_request_completed", + alias_name=alias_name, + ) + + +# ============================================================================= +# Compare Endpoint +# ============================================================================= + + +@router.get( + "/compare/{run_id_a}/{run_id_b}", + response_model=RunCompareResponse, + summary="Compare two runs", + description=""" +Compare two model runs side-by-side. + +Returns: +- Full details of both runs +- Configuration differences +- Metrics differences with computed deltas +""", +) +async def compare_runs( + run_id_a: str, + run_id_b: str, + db: AsyncSession = Depends(get_db), +) -> RunCompareResponse: + """Compare two runs. + + Args: + run_id_a: First run ID. + run_id_b: Second run ID. + db: Async database session from dependency. + + Returns: + Comparison of both runs. + + Raises: + HTTPException: If either run not found. + """ + service = RegistryService() + response = await service.compare_runs(db=db, run_id_a=run_id_a, run_id_b=run_id_b) + + if response is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"One or both runs not found: {run_id_a}, {run_id_b}", + ) + + return response diff --git a/app/features/registry/schemas.py b/app/features/registry/schemas.py new file mode 100644 index 00000000..97d0ddf1 --- /dev/null +++ b/app/features/registry/schemas.py @@ -0,0 +1,179 @@ +"""Pydantic schemas for registry API contracts. + +Schemas are designed to be: +- Immutable (frozen=True) for reproducibility +- Validated for data integrity +- Compatible with SQLAlchemy models via from_attributes +""" + +from __future__ import annotations + +import hashlib +import json +from datetime import date as date_type +from datetime import datetime +from enum import Enum +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field, field_validator + + +class RunStatus(str, Enum): + """Run lifecycle states.""" + + PENDING = "pending" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + ARCHIVED = "archived" + + +# Valid state transitions +VALID_TRANSITIONS: dict[RunStatus, set[RunStatus]] = { + RunStatus.PENDING: {RunStatus.RUNNING, RunStatus.ARCHIVED}, + RunStatus.RUNNING: {RunStatus.SUCCESS, RunStatus.FAILED, RunStatus.ARCHIVED}, + RunStatus.SUCCESS: {RunStatus.ARCHIVED}, + RunStatus.FAILED: {RunStatus.ARCHIVED}, + RunStatus.ARCHIVED: set(), # Terminal state +} + + +class RuntimeInfo(BaseModel): + """Runtime environment snapshot.""" + + model_config = ConfigDict(frozen=True, extra="forbid") + + python_version: str + sklearn_version: str | None = None + numpy_version: str | None = None + pandas_version: str | None = None + joblib_version: str | None = None + + +class AgentContext(BaseModel): + """Agent context for autonomous run traceability.""" + + model_config = ConfigDict(frozen=True, extra="forbid") + + agent_id: str | None = None + session_id: str | None = None + + +class RunCreate(BaseModel): + """Request to create a new run.""" + + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + model_type: str = Field(..., min_length=1, max_length=50) + model_config_data: dict[str, Any] = Field(..., alias="model_config") + feature_config: dict[str, Any] | None = None + data_window_start: date_type + data_window_end: date_type + store_id: int = Field(..., ge=1) + product_id: int = Field(..., ge=1) + agent_context: AgentContext | None = None + git_sha: str | None = Field(None, max_length=40) + + @field_validator("data_window_end") + @classmethod + def validate_data_window(cls, v: date_type, info: object) -> date_type: + """Ensure data_window_end >= data_window_start.""" + data = getattr(info, "data", {}) + if "data_window_start" in data and v < data["data_window_start"]: + raise ValueError("data_window_end must be >= data_window_start") + return v + + def compute_config_hash(self) -> str: + """Compute deterministic hash of model configuration. + + Returns: + 16-character hex string hash of config JSON. + """ + config_json = json.dumps(self.model_config_data, sort_keys=True, default=str) + return hashlib.sha256(config_json.encode()).hexdigest()[:16] + + +class RunUpdate(BaseModel): + """Request to update a run.""" + + model_config = ConfigDict(extra="forbid") + + status: RunStatus | None = None + metrics: dict[str, Any] | None = None + artifact_uri: str | None = None + artifact_hash: str | None = None + artifact_size_bytes: int | None = Field(None, ge=0) + error_message: str | None = Field(None, max_length=2000) + + +class RunResponse(BaseModel): + """Run details response.""" + + model_config = ConfigDict(from_attributes=True, populate_by_name=True) + + run_id: str + status: RunStatus + model_type: str + model_config_data: dict[str, Any] = Field( + ..., alias="model_config", serialization_alias="model_config" + ) + feature_config: dict[str, Any] | None = None + config_hash: str + data_window_start: date_type + data_window_end: date_type + store_id: int + product_id: int + metrics: dict[str, Any] | None = None + artifact_uri: str | None = None + artifact_hash: str | None = None + artifact_size_bytes: int | None = None + runtime_info: dict[str, Any] | None = None + agent_context: dict[str, Any] | None = None + git_sha: str | None = None + error_message: str | None = None + started_at: datetime | None = None + completed_at: datetime | None = None + created_at: datetime + updated_at: datetime + + +class RunListResponse(BaseModel): + """Paginated list of runs.""" + + runs: list[RunResponse] + total: int + page: int + page_size: int + + +class AliasCreate(BaseModel): + """Request to create/update an alias.""" + + model_config = ConfigDict(extra="forbid") + + alias_name: str = Field(..., min_length=1, max_length=100, pattern=r"^[a-z0-9][a-z0-9\-_]*$") + run_id: str + description: str | None = Field(None, max_length=500) + + +class AliasResponse(BaseModel): + """Alias details response.""" + + model_config = ConfigDict(from_attributes=True) + + alias_name: str + run_id: str + run_status: RunStatus + model_type: str + description: str | None = None + created_at: datetime + updated_at: datetime + + +class RunCompareResponse(BaseModel): + """Comparison of two runs.""" + + run_a: RunResponse + run_b: RunResponse + config_diff: dict[str, Any] # Keys that differ + metrics_diff: dict[str, dict[str, float | None]] # {metric: {a: val, b: val, diff: val}} diff --git a/app/features/registry/service.py b/app/features/registry/service.py new file mode 100644 index 00000000..515f17ca --- /dev/null +++ b/app/features/registry/service.py @@ -0,0 +1,712 @@ +"""Registry service for managing model runs and deployments. + +Orchestrates: +- Creating and updating model runs +- Managing deployment aliases +- Comparing runs +- Capturing runtime environment info + +CRITICAL: All state transitions are validated. +""" + +from __future__ import annotations + +import hashlib +import json +import sys +import uuid +from datetime import UTC, date, datetime +from typing import Any + +import structlog +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.config import get_settings +from app.features.registry.models import DeploymentAlias, ModelRun +from app.features.registry.models import RunStatus as RunStatusORM +from app.features.registry.schemas import ( + VALID_TRANSITIONS, + AliasCreate, + AliasResponse, + RunCompareResponse, + RunCreate, + RunListResponse, + RunResponse, + RunStatus, + RunUpdate, +) + +logger = structlog.get_logger() + + +class InvalidTransitionError(ValueError): + """Invalid state transition attempted.""" + + pass + + +class DuplicateRunError(ValueError): + """Duplicate run detected and policy is 'deny'.""" + + pass + + +class RegistryService: + """Service for managing model runs and deployment aliases. + + Provides orchestration layer for: + - Creating and tracking model runs + - Managing deployment aliases + - Comparing run configurations and metrics + - Capturing runtime environment snapshots + + CRITICAL: All state transitions are validated. + """ + + def __init__(self) -> None: + """Initialize the registry service.""" + self.settings = get_settings() + + def _capture_runtime_info(self) -> dict[str, Any]: + """Capture current runtime environment information. + + Returns: + Dictionary with Python and library versions. + """ + runtime_info: dict[str, Any] = { + "python_version": sys.version, + } + + # Try to capture library versions + try: + import sklearn # type: ignore[import-untyped] + + runtime_info["sklearn_version"] = sklearn.__version__ + except ImportError: + pass + + try: + import numpy as np + + runtime_info["numpy_version"] = np.__version__ + except ImportError: + pass + + try: + import pandas as pd + + runtime_info["pandas_version"] = pd.__version__ + except ImportError: + pass + + try: + import joblib # type: ignore[import-untyped] + + runtime_info["joblib_version"] = joblib.__version__ + except ImportError: + pass + + return runtime_info + + def _compute_config_hash(self, config: dict[str, Any]) -> str: + """Compute deterministic hash of model configuration. + + Args: + config: Model configuration dictionary. + + Returns: + 16-character hex string hash. + """ + config_json = json.dumps(config, sort_keys=True, default=str) + return hashlib.sha256(config_json.encode()).hexdigest()[:16] + + def _is_valid_transition(self, current_status: RunStatus, new_status: RunStatus) -> bool: + """Check if state transition is valid. + + Args: + current_status: Current run status. + new_status: Proposed new status. + + Returns: + True if transition is valid, False otherwise. + """ + valid_next = VALID_TRANSITIONS.get(current_status, set()) + return new_status in valid_next + + def _validate_transition(self, current_status: RunStatus, new_status: RunStatus) -> None: + """Validate state transition is allowed. + + Args: + current_status: Current run status. + new_status: Proposed new status. + + Raises: + InvalidTransitionError: If transition is not allowed. + """ + if not self._is_valid_transition(current_status, new_status): + valid_next = VALID_TRANSITIONS.get(current_status, set()) + raise InvalidTransitionError( + f"Invalid transition from {current_status.value} to {new_status.value}. " + f"Valid transitions: {[s.value for s in valid_next]}" + ) + + async def create_run( + self, + db: AsyncSession, + run_data: RunCreate, + ) -> RunResponse: + """Create a new model run. + + Args: + db: Database session. + run_data: Run creation data. + + Returns: + Created run response. + + Raises: + DuplicateRunError: If duplicate detected and policy is 'deny'. + """ + run_id = uuid.uuid4().hex + config_hash = self._compute_config_hash(run_data.model_config_data) + + # Check for duplicates based on policy + if self.settings.registry_duplicate_policy in ("deny", "detect"): + existing = await self._find_duplicate( + db=db, + config_hash=config_hash, + store_id=run_data.store_id, + product_id=run_data.product_id, + data_window_start=run_data.data_window_start, + data_window_end=run_data.data_window_end, + ) + if existing: + if self.settings.registry_duplicate_policy == "deny": + raise DuplicateRunError(f"Duplicate run detected: {existing.run_id}") + else: # detect + logger.warning( + "registry.duplicate_detected", + existing_run_id=existing.run_id, + config_hash=config_hash, + ) + + # Capture runtime info + runtime_info = self._capture_runtime_info() + + # Convert agent context to dict if present + agent_context_dict = None + if run_data.agent_context: + agent_context_dict = run_data.agent_context.model_dump() + + # Create model run + model_run = ModelRun( + run_id=run_id, + status=RunStatusORM.PENDING.value, + model_type=run_data.model_type, + model_config=run_data.model_config_data, + feature_config=run_data.feature_config, + config_hash=config_hash, + data_window_start=run_data.data_window_start, + data_window_end=run_data.data_window_end, + store_id=run_data.store_id, + product_id=run_data.product_id, + runtime_info=runtime_info, + agent_context=agent_context_dict, + git_sha=run_data.git_sha, + ) + + db.add(model_run) + await db.flush() + await db.refresh(model_run) + + logger.info( + "registry.run_created", + run_id=run_id, + model_type=run_data.model_type, + config_hash=config_hash, + store_id=run_data.store_id, + product_id=run_data.product_id, + ) + + return self._model_to_response(model_run) + + async def get_run( + self, + db: AsyncSession, + run_id: str, + ) -> RunResponse | None: + """Get a run by its run_id. + + Args: + db: Database session. + run_id: Run identifier. + + Returns: + Run response or None if not found. + """ + stmt = select(ModelRun).where(ModelRun.run_id == run_id) + result = await db.execute(stmt) + model_run = result.scalar_one_or_none() + + if model_run is None: + return None + + return self._model_to_response(model_run) + + async def list_runs( + self, + db: AsyncSession, + page: int = 1, + page_size: int = 20, + model_type: str | None = None, + status: RunStatus | None = None, + store_id: int | None = None, + product_id: int | None = None, + ) -> RunListResponse: + """List runs with filtering and pagination. + + Args: + db: Database session. + page: Page number (1-indexed). + page_size: Number of runs per page. + model_type: Filter by model type. + status: Filter by status. + store_id: Filter by store ID. + product_id: Filter by product ID. + + Returns: + Paginated list of runs. + """ + # Build query with filters + stmt = select(ModelRun) + + if model_type is not None: + stmt = stmt.where(ModelRun.model_type == model_type) + if status is not None: + stmt = stmt.where(ModelRun.status == status.value) + if store_id is not None: + stmt = stmt.where(ModelRun.store_id == store_id) + if product_id is not None: + stmt = stmt.where(ModelRun.product_id == product_id) + + # Count total + count_stmt = select(func.count()).select_from(stmt.subquery()) + total_result = await db.execute(count_stmt) + total = total_result.scalar_one() + + # Apply pagination + offset = (page - 1) * page_size + stmt = stmt.order_by(ModelRun.created_at.desc()).offset(offset).limit(page_size) + + result = await db.execute(stmt) + runs = result.scalars().all() + + return RunListResponse( + runs=[self._model_to_response(run) for run in runs], + total=total, + page=page, + page_size=page_size, + ) + + async def update_run( + self, + db: AsyncSession, + run_id: str, + update_data: RunUpdate, + ) -> RunResponse | None: + """Update a run. + + Args: + db: Database session. + run_id: Run identifier. + update_data: Fields to update. + + Returns: + Updated run response or None if not found. + + Raises: + InvalidTransitionError: If status transition is invalid. + """ + stmt = select(ModelRun).where(ModelRun.run_id == run_id) + result = await db.execute(stmt) + model_run = result.scalar_one_or_none() + + if model_run is None: + return None + + # Validate status transition if changing status + if update_data.status is not None: + current_status = RunStatus(model_run.status) + self._validate_transition(current_status, update_data.status) + model_run.status = update_data.status.value + + # Update timing fields based on transition + now = datetime.now(UTC) + if update_data.status == RunStatus.RUNNING: + model_run.started_at = now + elif update_data.status in (RunStatus.SUCCESS, RunStatus.FAILED): + model_run.completed_at = now + + # Update other fields + if update_data.metrics is not None: + model_run.metrics = update_data.metrics + if update_data.artifact_uri is not None: + model_run.artifact_uri = update_data.artifact_uri + if update_data.artifact_hash is not None: + model_run.artifact_hash = update_data.artifact_hash + if update_data.artifact_size_bytes is not None: + model_run.artifact_size_bytes = update_data.artifact_size_bytes + if update_data.error_message is not None: + model_run.error_message = update_data.error_message + + await db.flush() + await db.refresh(model_run) + + logger.info( + "registry.run_updated", + run_id=run_id, + status=model_run.status, + has_metrics=model_run.metrics is not None, + has_artifact=model_run.artifact_uri is not None, + ) + + return self._model_to_response(model_run) + + async def create_alias( + self, + db: AsyncSession, + alias_data: AliasCreate, + ) -> AliasResponse: + """Create or update a deployment alias. + + Args: + db: Database session. + alias_data: Alias creation data. + + Returns: + Created/updated alias response. + + Raises: + ValueError: If run not found or not in SUCCESS status. + """ + # Find the run + stmt = select(ModelRun).where(ModelRun.run_id == alias_data.run_id) + result = await db.execute(stmt) + model_run = result.scalar_one_or_none() + + if model_run is None: + raise ValueError(f"Run not found: {alias_data.run_id}") + + # CRITICAL: Only SUCCESS runs can be aliased + if model_run.status != RunStatusORM.SUCCESS.value: + raise ValueError( + f"Only SUCCESS runs can be aliased. " + f"Run {alias_data.run_id} has status: {model_run.status}" + ) + + # Check if alias exists + alias_stmt = select(DeploymentAlias).where( + DeploymentAlias.alias_name == alias_data.alias_name + ) + alias_result = await db.execute(alias_stmt) + existing_alias = alias_result.scalar_one_or_none() + + if existing_alias: + # Update existing alias + existing_alias.run_id = model_run.id + existing_alias.description = alias_data.description + alias = existing_alias + logger.info( + "registry.alias_updated", + alias_name=alias_data.alias_name, + run_id=alias_data.run_id, + ) + else: + # Create new alias + alias = DeploymentAlias( + alias_name=alias_data.alias_name, + run_id=model_run.id, + description=alias_data.description, + ) + db.add(alias) + logger.info( + "registry.alias_created", + alias_name=alias_data.alias_name, + run_id=alias_data.run_id, + ) + + await db.flush() + await db.refresh(alias) + + return AliasResponse( + alias_name=alias.alias_name, + run_id=model_run.run_id, + run_status=RunStatus(model_run.status), + model_type=model_run.model_type, + description=alias.description, + created_at=alias.created_at, + updated_at=alias.updated_at, + ) + + async def get_alias( + self, + db: AsyncSession, + alias_name: str, + ) -> AliasResponse | None: + """Get an alias by name. + + Args: + db: Database session. + alias_name: Alias name. + + Returns: + Alias response or None if not found. + """ + stmt = ( + select(DeploymentAlias, ModelRun) + .join(ModelRun, DeploymentAlias.run_id == ModelRun.id) + .where(DeploymentAlias.alias_name == alias_name) + ) + result = await db.execute(stmt) + row = result.first() + + if row is None: + return None + + alias, model_run = row + + return AliasResponse( + alias_name=alias.alias_name, + run_id=model_run.run_id, + run_status=RunStatus(model_run.status), + model_type=model_run.model_type, + description=alias.description, + created_at=alias.created_at, + updated_at=alias.updated_at, + ) + + async def list_aliases( + self, + db: AsyncSession, + ) -> list[AliasResponse]: + """List all deployment aliases. + + Args: + db: Database session. + + Returns: + List of alias responses. + """ + stmt = ( + select(DeploymentAlias, ModelRun) + .join(ModelRun, DeploymentAlias.run_id == ModelRun.id) + .order_by(DeploymentAlias.alias_name) + ) + result = await db.execute(stmt) + rows = result.all() + + return [ + AliasResponse( + alias_name=alias.alias_name, + run_id=model_run.run_id, + run_status=RunStatus(model_run.status), + model_type=model_run.model_type, + description=alias.description, + created_at=alias.created_at, + updated_at=alias.updated_at, + ) + for alias, model_run in rows + ] + + async def delete_alias( + self, + db: AsyncSession, + alias_name: str, + ) -> bool: + """Delete a deployment alias. + + Args: + db: Database session. + alias_name: Alias name. + + Returns: + True if deleted, False if not found. + """ + stmt = select(DeploymentAlias).where(DeploymentAlias.alias_name == alias_name) + result = await db.execute(stmt) + alias = result.scalar_one_or_none() + + if alias is None: + return False + + await db.delete(alias) + await db.flush() + + logger.info("registry.alias_deleted", alias_name=alias_name) + return True + + async def compare_runs( + self, + db: AsyncSession, + run_id_a: str, + run_id_b: str, + ) -> RunCompareResponse | None: + """Compare two runs. + + Args: + db: Database session. + run_id_a: First run ID. + run_id_b: Second run ID. + + Returns: + Comparison response or None if either run not found. + """ + run_a = await self.get_run(db, run_id_a) + run_b = await self.get_run(db, run_id_b) + + if run_a is None or run_b is None: + return None + + # Compute config diff + config_diff = self._compute_config_diff(run_a.model_config_data, run_b.model_config_data) + + # Compute metrics diff + metrics_diff = self._compute_metrics_diff(run_a.metrics, run_b.metrics) + + return RunCompareResponse( + run_a=run_a, + run_b=run_b, + config_diff=config_diff, + metrics_diff=metrics_diff, + ) + + async def _find_duplicate( + self, + db: AsyncSession, + config_hash: str, + store_id: int, + product_id: int, + data_window_start: date, + data_window_end: date, + ) -> ModelRun | None: + """Find existing run with same config and data window. + + Args: + db: Database session. + config_hash: Configuration hash. + store_id: Store ID. + product_id: Product ID. + data_window_start: Data window start date. + data_window_end: Data window end date. + + Returns: + Existing run or None. + """ + stmt = select(ModelRun).where( + (ModelRun.config_hash == config_hash) + & (ModelRun.store_id == store_id) + & (ModelRun.product_id == product_id) + & (ModelRun.data_window_start == data_window_start) + & (ModelRun.data_window_end == data_window_end) + & (ModelRun.status != RunStatusORM.ARCHIVED.value) + ) + result = await db.execute(stmt) + return result.scalar_one_or_none() + + def _model_to_response(self, model_run: ModelRun) -> RunResponse: + """Convert ORM model to response schema. + + Args: + model_run: ORM model. + + Returns: + Response schema. + """ + # Build a dict that maps to the schema field names + # model_config in ORM -> model_config_data in schema (via alias "model_config") + data = { + "run_id": model_run.run_id, + "status": RunStatus(model_run.status), + "model_type": model_run.model_type, + "model_config": model_run.model_config, # uses alias + "feature_config": model_run.feature_config, + "config_hash": model_run.config_hash, + "data_window_start": model_run.data_window_start, + "data_window_end": model_run.data_window_end, + "store_id": model_run.store_id, + "product_id": model_run.product_id, + "metrics": model_run.metrics, + "artifact_uri": model_run.artifact_uri, + "artifact_hash": model_run.artifact_hash, + "artifact_size_bytes": model_run.artifact_size_bytes, + "runtime_info": model_run.runtime_info, + "agent_context": model_run.agent_context, + "git_sha": model_run.git_sha, + "error_message": model_run.error_message, + "started_at": model_run.started_at, + "completed_at": model_run.completed_at, + "created_at": model_run.created_at, + "updated_at": model_run.updated_at, + } + return RunResponse.model_validate(data) + + def _compute_config_diff( + self, config_a: dict[str, Any], config_b: dict[str, Any] + ) -> dict[str, Any]: + """Compute differences between two configurations. + + Args: + config_a: First configuration. + config_b: Second configuration. + + Returns: + Dictionary of differing keys with both values. + """ + diff: dict[str, Any] = {} + all_keys = set(config_a.keys()) | set(config_b.keys()) + + for key in all_keys: + val_a = config_a.get(key) + val_b = config_b.get(key) + if val_a != val_b: + diff[key] = {"a": val_a, "b": val_b} + + return diff + + def _compute_metrics_diff( + self, + metrics_a: dict[str, Any] | None, + metrics_b: dict[str, Any] | None, + ) -> dict[str, dict[str, float | None]]: + """Compute differences between two metric sets. + + Args: + metrics_a: First metrics. + metrics_b: Second metrics. + + Returns: + Dictionary with metric comparisons. + """ + metrics_a = metrics_a or {} + metrics_b = metrics_b or {} + + diff: dict[str, dict[str, float | None]] = {} + all_keys = set(metrics_a.keys()) | set(metrics_b.keys()) + + for key in all_keys: + val_a = metrics_a.get(key) + val_b = metrics_b.get(key) + + # Compute difference if both are numeric + diff_val: float | None = None + if isinstance(val_a, (int, float)) and isinstance(val_b, (int, float)): + diff_val = float(val_b) - float(val_a) + + diff[key] = { + "a": float(val_a) if isinstance(val_a, (int, float)) else None, + "b": float(val_b) if isinstance(val_b, (int, float)) else None, + "diff": diff_val, + } + + return diff diff --git a/app/features/registry/storage.py b/app/features/registry/storage.py new file mode 100644 index 00000000..d9ae5540 --- /dev/null +++ b/app/features/registry/storage.py @@ -0,0 +1,265 @@ +"""Artifact storage providers for model registry. + +Provides abstract interface and LocalFS implementation for storing +model artifacts with integrity verification via SHA-256 checksums. + +CRITICAL: All paths are validated to prevent directory traversal attacks. +""" + +from __future__ import annotations + +import hashlib +import shutil +from abc import ABC, abstractmethod +from pathlib import Path + +import structlog + +from app.core.config import get_settings + +logger = structlog.get_logger() + + +class StorageError(Exception): + """Base exception for storage operations.""" + + pass + + +class ArtifactNotFoundError(StorageError): + """Artifact not found at specified URI.""" + + pass + + +class ChecksumMismatchError(StorageError): + """Artifact checksum does not match stored value.""" + + pass + + +class AbstractStorageProvider(ABC): + """Abstract base class for artifact storage. + + CRITICAL: All storage providers must implement these methods. + This allows future S3/GCS implementations. + """ + + @abstractmethod + def save(self, source_path: Path, artifact_uri: str) -> tuple[str, int]: + """Save an artifact to storage. + + Args: + source_path: Local path to artifact file. + artifact_uri: Relative URI for storage. + + Returns: + Tuple of (sha256_hash, size_bytes). + + Raises: + StorageError: If save fails. + """ + pass + + @abstractmethod + def load(self, artifact_uri: str, expected_hash: str | None = None) -> Path: + """Load an artifact from storage. + + Args: + artifact_uri: Relative URI of artifact. + expected_hash: If provided, verify checksum. + + Returns: + Path to artifact (may be temp file for remote storage). + + Raises: + ArtifactNotFoundError: If artifact doesn't exist. + ChecksumMismatchError: If hash verification fails. + """ + pass + + @abstractmethod + def delete(self, artifact_uri: str) -> bool: + """Delete an artifact from storage. + + Args: + artifact_uri: Relative URI of artifact. + + Returns: + True if deleted, False if not found. + """ + pass + + @abstractmethod + def exists(self, artifact_uri: str) -> bool: + """Check if an artifact exists. + + Args: + artifact_uri: Relative URI of artifact. + + Returns: + True if exists, False otherwise. + """ + pass + + @staticmethod + def compute_hash(file_path: Path) -> str: + """Compute SHA-256 hash of a file. + + Args: + file_path: Path to file. + + Returns: + Hexadecimal SHA-256 hash. + """ + sha256 = hashlib.sha256() + with file_path.open("rb") as f: + for chunk in iter(lambda: f.read(8192), b""): + sha256.update(chunk) + return sha256.hexdigest() + + +class LocalFSProvider(AbstractStorageProvider): + """Local filesystem storage provider. + + CRITICAL: Default provider for development and single-node deployments. + """ + + def __init__(self, root_dir: Path | str | None = None) -> None: + """Initialize with root directory. + + Args: + root_dir: Root directory for artifacts. Defaults to Settings value. + """ + if root_dir is None: + settings = get_settings() + root_dir = Path(settings.registry_artifact_root) + elif isinstance(root_dir, str): + root_dir = Path(root_dir) + self.root_dir = root_dir.resolve() + self.root_dir.mkdir(parents=True, exist_ok=True) + + def _resolve_path(self, artifact_uri: str) -> Path: + """Resolve artifact URI to full path. + + CRITICAL: Validates path is within root to prevent traversal. + + Args: + artifact_uri: Relative URI of artifact. + + Returns: + Resolved absolute path. + + Raises: + StorageError: If path traversal attempt detected. + """ + full_path = (self.root_dir / artifact_uri).resolve() + # Security: ensure path is within root + try: + full_path.relative_to(self.root_dir) + except ValueError: + logger.warning( + "registry.path_traversal_attempt", + artifact_uri=artifact_uri, + root_dir=str(self.root_dir), + ) + raise StorageError(f"Path traversal attempt: {artifact_uri}") from None + return full_path + + def save(self, source_path: Path, artifact_uri: str) -> tuple[str, int]: + """Save artifact to local filesystem. + + Args: + source_path: Local path to artifact file. + artifact_uri: Relative URI for storage. + + Returns: + Tuple of (sha256_hash, size_bytes). + + Raises: + StorageError: If save fails. + """ + dest_path = self._resolve_path(artifact_uri) + dest_path.parent.mkdir(parents=True, exist_ok=True) + + # Compute hash before copy + file_hash = self.compute_hash(source_path) + file_size = source_path.stat().st_size + + # Copy file + shutil.copy2(source_path, dest_path) + + logger.info( + "registry.artifact_saved", + artifact_uri=artifact_uri, + hash=file_hash, + size_bytes=file_size, + ) + + return file_hash, file_size + + def load(self, artifact_uri: str, expected_hash: str | None = None) -> Path: + """Load artifact from local filesystem. + + Args: + artifact_uri: Relative URI of artifact. + expected_hash: If provided, verify checksum. + + Returns: + Path to artifact. + + Raises: + ArtifactNotFoundError: If artifact doesn't exist. + ChecksumMismatchError: If hash verification fails. + """ + full_path = self._resolve_path(artifact_uri) + + if not full_path.exists(): + raise ArtifactNotFoundError(f"Artifact not found: {artifact_uri}") + + # Verify hash if provided + if expected_hash is not None: + actual_hash = self.compute_hash(full_path) + if actual_hash != expected_hash: + logger.warning( + "registry.checksum_mismatch", + artifact_uri=artifact_uri, + expected=expected_hash, + actual=actual_hash, + ) + raise ChecksumMismatchError( + f"Checksum mismatch for {artifact_uri}: " + f"expected {expected_hash}, got {actual_hash}" + ) + + return full_path + + def delete(self, artifact_uri: str) -> bool: + """Delete artifact from local filesystem. + + Args: + artifact_uri: Relative URI of artifact. + + Returns: + True if deleted, False if not found. + """ + full_path = self._resolve_path(artifact_uri) + + if not full_path.exists(): + return False + + full_path.unlink() + logger.info("registry.artifact_deleted", artifact_uri=artifact_uri) + return True + + def exists(self, artifact_uri: str) -> bool: + """Check if artifact exists on local filesystem. + + Args: + artifact_uri: Relative URI of artifact. + + Returns: + True if exists, False otherwise. + """ + full_path = self._resolve_path(artifact_uri) + return full_path.exists() diff --git a/app/features/registry/tests/__init__.py b/app/features/registry/tests/__init__.py new file mode 100644 index 00000000..2a9f60d2 --- /dev/null +++ b/app/features/registry/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for registry module.""" diff --git a/app/features/registry/tests/conftest.py b/app/features/registry/tests/conftest.py new file mode 100644 index 00000000..7b71ed52 --- /dev/null +++ b/app/features/registry/tests/conftest.py @@ -0,0 +1,234 @@ +"""Test fixtures for registry module.""" + +import tempfile +import uuid +from collections.abc import AsyncGenerator, Generator +from datetime import date +from pathlib import Path + +import pytest +from httpx import ASGITransport, AsyncClient +from sqlalchemy import delete +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from app.core.config import get_settings +from app.core.database import get_db +from app.features.registry.models import DeploymentAlias, ModelRun +from app.features.registry.schemas import AgentContext, RunCreate, RunStatus +from app.features.registry.storage import LocalFSProvider +from app.main import app + +# ============================================================================= +# Database Fixtures for Integration Tests +# ============================================================================= + + +@pytest.fixture +async def db_session() -> AsyncGenerator[AsyncSession, None]: + """Create async database session for integration tests. + + Creates tables if needed, provides a session, and cleans up test data. + Requires PostgreSQL to be running (docker-compose up -d). + """ + settings = get_settings() + engine = create_async_engine(settings.database_url, echo=False) + + # Create session + async_session_maker = async_sessionmaker( + engine, + class_=AsyncSession, + expire_on_commit=False, + ) + + async with async_session_maker() as session: + try: + yield session + finally: + # Clean up test data (delete in correct order due to FK constraints) + await session.execute(delete(DeploymentAlias)) + await session.execute(delete(ModelRun).where(ModelRun.model_type.like("test-%"))) + await session.commit() + + await engine.dispose() + + +@pytest.fixture +async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]: + """Create test client with database dependency override.""" + + async def override_get_db() -> AsyncGenerator[AsyncSession, None]: + yield db_session + + app.dependency_overrides[get_db] = override_get_db + + async with AsyncClient( + transport=ASGITransport(app=app), + base_url="http://test", + ) as ac: + yield ac + + app.dependency_overrides.clear() + + +# ============================================================================= +# Unit Test Fixtures +# ============================================================================= + + +@pytest.fixture +def sample_run_create() -> RunCreate: + """Create a sample RunCreate for testing.""" + return RunCreate( + model_type="test-naive", + model_config_data={"strategy": "last_value"}, + feature_config={"lags": [1, 7]}, + data_window_start=date(2024, 1, 1), + data_window_end=date(2024, 3, 31), + store_id=1, + product_id=1, + agent_context=AgentContext(agent_id="test-agent", session_id="test-session"), + git_sha="abc1234567890", + ) + + +@pytest.fixture +def sample_run_create_minimal() -> RunCreate: + """Create a minimal RunCreate for testing.""" + return RunCreate( + model_type="test-minimal", + model_config_data={"type": "baseline"}, + data_window_start=date(2024, 1, 1), + data_window_end=date(2024, 1, 31), + store_id=1, + product_id=1, + ) + + +@pytest.fixture +def sample_run_create_duplicate(sample_run_create: RunCreate) -> RunCreate: + """Create a duplicate RunCreate (same config hash and data window).""" + return RunCreate( + model_type=sample_run_create.model_type, + model_config_data=sample_run_create.model_config_data, + data_window_start=sample_run_create.data_window_start, + data_window_end=sample_run_create.data_window_end, + store_id=sample_run_create.store_id, + product_id=sample_run_create.product_id, + ) + + +@pytest.fixture +def sample_model_run() -> ModelRun: + """Create a sample ModelRun ORM object for testing.""" + return ModelRun( + run_id=uuid.uuid4().hex, + status=RunStatus.PENDING.value, + model_type="test-naive", + model_config={"strategy": "last_value"}, + feature_config={"lags": [1, 7]}, + config_hash="abc123def456", + data_window_start=date(2024, 1, 1), + data_window_end=date(2024, 3, 31), + store_id=1, + product_id=1, + ) + + +@pytest.fixture +def temp_artifact_dir() -> Generator[Path, None, None]: + """Create a temporary directory for artifact storage.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +@pytest.fixture +def storage_provider(temp_artifact_dir: Path) -> LocalFSProvider: + """Create a LocalFSProvider with temporary root directory.""" + return LocalFSProvider(root_dir=temp_artifact_dir) + + +@pytest.fixture +def sample_artifact_content() -> bytes: + """Create sample artifact content for testing.""" + return b"test artifact content for sha256 verification" + + +@pytest.fixture +def sample_artifact_file(temp_artifact_dir: Path, sample_artifact_content: bytes) -> Path: + """Create a sample artifact file for testing.""" + artifact_path = temp_artifact_dir / "source_artifact.pkl" + artifact_path.write_bytes(sample_artifact_content) + return artifact_path + + +# ============================================================================= +# Status Transition Test Fixtures +# ============================================================================= + + +@pytest.fixture +def sample_pending_run() -> ModelRun: + """Create a pending model run.""" + return ModelRun( + run_id=uuid.uuid4().hex, + status=RunStatus.PENDING.value, + model_type="test-status", + model_config={"test": True}, + config_hash="status12345678", + data_window_start=date(2024, 1, 1), + data_window_end=date(2024, 1, 31), + store_id=1, + product_id=1, + ) + + +@pytest.fixture +def sample_running_run() -> ModelRun: + """Create a running model run.""" + return ModelRun( + run_id=uuid.uuid4().hex, + status=RunStatus.RUNNING.value, + model_type="test-status", + model_config={"test": True}, + config_hash="status12345678", + data_window_start=date(2024, 1, 1), + data_window_end=date(2024, 1, 31), + store_id=1, + product_id=1, + ) + + +@pytest.fixture +def sample_success_run() -> ModelRun: + """Create a successful model run.""" + return ModelRun( + run_id=uuid.uuid4().hex, + status=RunStatus.SUCCESS.value, + model_type="test-status", + model_config={"test": True}, + config_hash="status12345678", + data_window_start=date(2024, 1, 1), + data_window_end=date(2024, 1, 31), + store_id=1, + product_id=1, + metrics={"mae": 1.5, "smape": 10.2}, + artifact_uri="models/test.pkl", + artifact_hash="abc123", + ) + + +@pytest.fixture +def sample_failed_run() -> ModelRun: + """Create a failed model run.""" + return ModelRun( + run_id=uuid.uuid4().hex, + status=RunStatus.FAILED.value, + model_type="test-status", + model_config={"test": True}, + config_hash="status12345678", + data_window_start=date(2024, 1, 1), + data_window_end=date(2024, 1, 31), + store_id=1, + product_id=1, + error_message="Training failed due to insufficient data", + ) diff --git a/app/features/registry/tests/test_routes.py b/app/features/registry/tests/test_routes.py new file mode 100644 index 00000000..72d889f1 --- /dev/null +++ b/app/features/registry/tests/test_routes.py @@ -0,0 +1,504 @@ +"""Integration tests for registry API routes. + +These tests require PostgreSQL to be running (docker-compose up -d). +Run with: pytest app/features/registry/tests/ -v -m integration +""" + +import pytest +from httpx import AsyncClient + +pytestmark = pytest.mark.integration + + +class TestCreateRunEndpoint: + """Tests for POST /registry/runs endpoint.""" + + async def test_create_run_success(self, client: AsyncClient) -> None: + """Should create a new run with valid data.""" + response = await client.post( + "/registry/runs", + json={ + "model_type": "test-naive", + "model_config": {"strategy": "last_value"}, + "data_window_start": "2024-01-01", + "data_window_end": "2024-03-31", + "store_id": 1, + "product_id": 1, + }, + ) + assert response.status_code == 201 + data = response.json() + assert data["model_type"] == "test-naive" + assert data["status"] == "pending" + assert data["run_id"] is not None + assert len(data["run_id"]) == 32 + assert data["config_hash"] is not None + assert len(data["config_hash"]) == 16 + + async def test_create_run_with_all_fields(self, client: AsyncClient) -> None: + """Should create a run with all optional fields.""" + response = await client.post( + "/registry/runs", + json={ + "model_type": "test-seasonal", + "model_config": {"season_length": 7}, + "feature_config": {"lags": [1, 7, 14]}, + "data_window_start": "2024-01-01", + "data_window_end": "2024-06-30", + "store_id": 5, + "product_id": 10, + "agent_context": { + "agent_id": "test-agent", + "session_id": "test-session", + }, + "git_sha": "abc123def456", + }, + ) + assert response.status_code == 201 + data = response.json() + assert data["feature_config"] == {"lags": [1, 7, 14]} + assert data["agent_context"]["agent_id"] == "test-agent" + assert data["git_sha"] == "abc123def456" + assert data["runtime_info"]["python_version"].startswith("3.") + + async def test_create_run_validation_error(self, client: AsyncClient) -> None: + """Should return 422 for invalid data.""" + response = await client.post( + "/registry/runs", + json={ + "model_type": "", # Too short + "model_config": {}, + "data_window_start": "2024-01-01", + "data_window_end": "2024-01-31", + "store_id": 1, + "product_id": 1, + }, + ) + assert response.status_code == 422 + + async def test_create_run_invalid_date_order(self, client: AsyncClient) -> None: + """Should return 422 if end date before start date.""" + response = await client.post( + "/registry/runs", + json={ + "model_type": "test-naive", + "model_config": {}, + "data_window_start": "2024-03-01", + "data_window_end": "2024-01-01", + "store_id": 1, + "product_id": 1, + }, + ) + assert response.status_code == 422 + + +class TestListRunsEndpoint: + """Tests for GET /registry/runs endpoint.""" + + async def test_list_runs_empty(self, client: AsyncClient) -> None: + """Should return empty list when no runs exist.""" + response = await client.get("/registry/runs") + assert response.status_code == 200 + data = response.json() + assert data["runs"] == [] + assert data["total"] == 0 + assert data["page"] == 1 + + async def test_list_runs_with_data(self, client: AsyncClient) -> None: + """Should return paginated list of runs.""" + # Create some runs + for i in range(3): + await client.post( + "/registry/runs", + json={ + "model_type": f"test-list-{i}", + "model_config": {"index": i}, + "data_window_start": "2024-01-01", + "data_window_end": "2024-01-31", + "store_id": 1, + "product_id": 1, + }, + ) + + response = await client.get("/registry/runs") + assert response.status_code == 200 + data = response.json() + assert data["total"] >= 3 + assert data["page"] == 1 + assert data["page_size"] == 20 + + async def test_list_runs_filter_by_model_type(self, client: AsyncClient) -> None: + """Should filter runs by model_type.""" + # Create runs with different types + await client.post( + "/registry/runs", + json={ + "model_type": "test-filter-a", + "model_config": {}, + "data_window_start": "2024-01-01", + "data_window_end": "2024-01-31", + "store_id": 1, + "product_id": 1, + }, + ) + await client.post( + "/registry/runs", + json={ + "model_type": "test-filter-b", + "model_config": {}, + "data_window_start": "2024-01-01", + "data_window_end": "2024-01-31", + "store_id": 1, + "product_id": 1, + }, + ) + + response = await client.get("/registry/runs?model_type=test-filter-a") + assert response.status_code == 200 + data = response.json() + for run in data["runs"]: + assert run["model_type"] == "test-filter-a" + + async def test_list_runs_filter_by_status(self, client: AsyncClient) -> None: + """Should filter runs by status.""" + response = await client.get("/registry/runs?status=pending") + assert response.status_code == 200 + data = response.json() + for run in data["runs"]: + assert run["status"] == "pending" + + async def test_list_runs_pagination(self, client: AsyncClient) -> None: + """Should paginate results correctly.""" + # Create runs + for i in range(5): + await client.post( + "/registry/runs", + json={ + "model_type": f"test-page-{i}", + "model_config": {}, + "data_window_start": "2024-01-01", + "data_window_end": "2024-01-31", + "store_id": 1, + "product_id": 1, + }, + ) + + response = await client.get("/registry/runs?page=1&page_size=2") + assert response.status_code == 200 + data = response.json() + assert len(data["runs"]) <= 2 + assert data["page"] == 1 + assert data["page_size"] == 2 + + +class TestGetRunEndpoint: + """Tests for GET /registry/runs/{run_id} endpoint.""" + + async def test_get_run_success(self, client: AsyncClient) -> None: + """Should return run details.""" + # Create a run + create_response = await client.post( + "/registry/runs", + json={ + "model_type": "test-get", + "model_config": {"test": True}, + "data_window_start": "2024-01-01", + "data_window_end": "2024-01-31", + "store_id": 1, + "product_id": 1, + }, + ) + run_id = create_response.json()["run_id"] + + # Get the run + response = await client.get(f"/registry/runs/{run_id}") + assert response.status_code == 200 + data = response.json() + assert data["run_id"] == run_id + assert data["model_type"] == "test-get" + + async def test_get_run_not_found(self, client: AsyncClient) -> None: + """Should return 404 for non-existent run.""" + response = await client.get("/registry/runs/nonexistent12345678901234567890") + assert response.status_code == 404 + + +class TestUpdateRunEndpoint: + """Tests for PATCH /registry/runs/{run_id} endpoint.""" + + async def test_update_run_status(self, client: AsyncClient) -> None: + """Should update run status.""" + # Create a run + create_response = await client.post( + "/registry/runs", + json={ + "model_type": "test-update", + "model_config": {}, + "data_window_start": "2024-01-01", + "data_window_end": "2024-01-31", + "store_id": 1, + "product_id": 1, + }, + ) + run_id = create_response.json()["run_id"] + + # Update to running + response = await client.patch( + f"/registry/runs/{run_id}", + json={"status": "running"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "running" + assert data["started_at"] is not None + + async def test_update_run_metrics(self, client: AsyncClient) -> None: + """Should update run metrics.""" + # Create and start a run + create_response = await client.post( + "/registry/runs", + json={ + "model_type": "test-metrics", + "model_config": {}, + "data_window_start": "2024-01-01", + "data_window_end": "2024-01-31", + "store_id": 1, + "product_id": 1, + }, + ) + run_id = create_response.json()["run_id"] + + # Transition to running first + await client.patch(f"/registry/runs/{run_id}", json={"status": "running"}) + + # Update to success with metrics + response = await client.patch( + f"/registry/runs/{run_id}", + json={ + "status": "success", + "metrics": {"mae": 1.5, "smape": 10.2, "wape": 0.08}, + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "success" + assert data["metrics"]["mae"] == 1.5 + assert data["completed_at"] is not None + + async def test_update_run_invalid_transition(self, client: AsyncClient) -> None: + """Should return 400 for invalid status transition.""" + # Create a run + create_response = await client.post( + "/registry/runs", + json={ + "model_type": "test-invalid", + "model_config": {}, + "data_window_start": "2024-01-01", + "data_window_end": "2024-01-31", + "store_id": 1, + "product_id": 1, + }, + ) + run_id = create_response.json()["run_id"] + + # Try to go directly from pending to success + response = await client.patch( + f"/registry/runs/{run_id}", + json={"status": "success"}, + ) + assert response.status_code == 400 + assert "transition" in response.json()["detail"].lower() + + async def test_update_run_not_found(self, client: AsyncClient) -> None: + """Should return 404 for non-existent run.""" + response = await client.patch( + "/registry/runs/nonexistent12345678901234567890", + json={"status": "running"}, + ) + assert response.status_code == 404 + + +class TestAliasEndpoints: + """Tests for alias CRUD endpoints.""" + + async def test_create_alias_success(self, client: AsyncClient) -> None: + """Should create an alias for a successful run.""" + # Create a run and transition to success + create_response = await client.post( + "/registry/runs", + json={ + "model_type": "test-alias", + "model_config": {}, + "data_window_start": "2024-01-01", + "data_window_end": "2024-01-31", + "store_id": 1, + "product_id": 1, + }, + ) + run_id = create_response.json()["run_id"] + + await client.patch(f"/registry/runs/{run_id}", json={"status": "running"}) + await client.patch(f"/registry/runs/{run_id}", json={"status": "success"}) + + # Create alias + response = await client.post( + "/registry/aliases", + json={ + "alias_name": "production", + "run_id": run_id, + "description": "Production model", + }, + ) + assert response.status_code == 201 + data = response.json() + assert data["alias_name"] == "production" + assert data["run_id"] == run_id + assert data["run_status"] == "success" + + async def test_create_alias_non_success_run(self, client: AsyncClient) -> None: + """Should return 400 when aliasing non-success run.""" + # Create a pending run + create_response = await client.post( + "/registry/runs", + json={ + "model_type": "test-alias-fail", + "model_config": {}, + "data_window_start": "2024-01-01", + "data_window_end": "2024-01-31", + "store_id": 1, + "product_id": 1, + }, + ) + run_id = create_response.json()["run_id"] + + # Try to create alias for pending run + response = await client.post( + "/registry/aliases", + json={ + "alias_name": "staging", + "run_id": run_id, + }, + ) + assert response.status_code == 400 + + async def test_list_aliases(self, client: AsyncClient) -> None: + """Should list all aliases.""" + response = await client.get("/registry/aliases") + assert response.status_code == 200 + assert isinstance(response.json(), list) + + async def test_get_alias_success(self, client: AsyncClient) -> None: + """Should return alias details.""" + # Create a successful run and alias + create_response = await client.post( + "/registry/runs", + json={ + "model_type": "test-get-alias", + "model_config": {}, + "data_window_start": "2024-01-01", + "data_window_end": "2024-01-31", + "store_id": 1, + "product_id": 1, + }, + ) + run_id = create_response.json()["run_id"] + await client.patch(f"/registry/runs/{run_id}", json={"status": "running"}) + await client.patch(f"/registry/runs/{run_id}", json={"status": "success"}) + await client.post( + "/registry/aliases", + json={"alias_name": "get-test", "run_id": run_id}, + ) + + response = await client.get("/registry/aliases/get-test") + assert response.status_code == 200 + data = response.json() + assert data["alias_name"] == "get-test" + + async def test_get_alias_not_found(self, client: AsyncClient) -> None: + """Should return 404 for non-existent alias.""" + response = await client.get("/registry/aliases/nonexistent") + assert response.status_code == 404 + + async def test_delete_alias_success(self, client: AsyncClient) -> None: + """Should delete an alias.""" + # Create a successful run and alias + create_response = await client.post( + "/registry/runs", + json={ + "model_type": "test-delete-alias", + "model_config": {}, + "data_window_start": "2024-01-01", + "data_window_end": "2024-01-31", + "store_id": 1, + "product_id": 1, + }, + ) + run_id = create_response.json()["run_id"] + await client.patch(f"/registry/runs/{run_id}", json={"status": "running"}) + await client.patch(f"/registry/runs/{run_id}", json={"status": "success"}) + await client.post( + "/registry/aliases", + json={"alias_name": "delete-test", "run_id": run_id}, + ) + + response = await client.delete("/registry/aliases/delete-test") + assert response.status_code == 204 + + # Verify deleted + get_response = await client.get("/registry/aliases/delete-test") + assert get_response.status_code == 404 + + async def test_delete_alias_not_found(self, client: AsyncClient) -> None: + """Should return 404 for non-existent alias.""" + response = await client.delete("/registry/aliases/nonexistent") + assert response.status_code == 404 + + +class TestCompareRunsEndpoint: + """Tests for GET /registry/compare/{run_id_a}/{run_id_b} endpoint.""" + + async def test_compare_runs_success(self, client: AsyncClient) -> None: + """Should compare two runs.""" + # Create two runs with different configs + run_a_response = await client.post( + "/registry/runs", + json={ + "model_type": "test-compare", + "model_config": {"horizon": 7}, + "data_window_start": "2024-01-01", + "data_window_end": "2024-01-31", + "store_id": 1, + "product_id": 1, + }, + ) + run_a_id = run_a_response.json()["run_id"] + + run_b_response = await client.post( + "/registry/runs", + json={ + "model_type": "test-compare", + "model_config": {"horizon": 14}, + "data_window_start": "2024-01-01", + "data_window_end": "2024-01-31", + "store_id": 1, + "product_id": 1, + }, + ) + run_b_id = run_b_response.json()["run_id"] + + # Compare + response = await client.get(f"/registry/compare/{run_a_id}/{run_b_id}") + assert response.status_code == 200 + data = response.json() + assert data["run_a"]["run_id"] == run_a_id + assert data["run_b"]["run_id"] == run_b_id + assert "config_diff" in data + assert "metrics_diff" in data + assert "horizon" in data["config_diff"] + + async def test_compare_runs_not_found(self, client: AsyncClient) -> None: + """Should return 404 if either run not found.""" + response = await client.get( + "/registry/compare/nonexistent1234567890123456/nonexistent0987654321098765" + ) + assert response.status_code == 404 diff --git a/app/features/registry/tests/test_schemas.py b/app/features/registry/tests/test_schemas.py new file mode 100644 index 00000000..459531d7 --- /dev/null +++ b/app/features/registry/tests/test_schemas.py @@ -0,0 +1,383 @@ +"""Unit tests for registry schemas.""" + +from datetime import date + +import pytest +from pydantic import ValidationError + +from app.features.registry.schemas import ( + VALID_TRANSITIONS, + AgentContext, + AliasCreate, + RunCreate, + RunStatus, + RuntimeInfo, + RunUpdate, +) + + +class TestRunStatus: + """Tests for RunStatus enum.""" + + def test_all_statuses_defined(self) -> None: + """All expected statuses should be defined.""" + assert RunStatus.PENDING.value == "pending" + assert RunStatus.RUNNING.value == "running" + assert RunStatus.SUCCESS.value == "success" + assert RunStatus.FAILED.value == "failed" + assert RunStatus.ARCHIVED.value == "archived" + + def test_status_count(self) -> None: + """Should have exactly 5 statuses.""" + assert len(RunStatus) == 5 + + +class TestValidTransitions: + """Tests for state transition validation.""" + + def test_pending_transitions(self) -> None: + """PENDING can transition to RUNNING or ARCHIVED.""" + assert VALID_TRANSITIONS[RunStatus.PENDING] == { + RunStatus.RUNNING, + RunStatus.ARCHIVED, + } + + def test_running_transitions(self) -> None: + """RUNNING can transition to SUCCESS, FAILED, or ARCHIVED.""" + assert VALID_TRANSITIONS[RunStatus.RUNNING] == { + RunStatus.SUCCESS, + RunStatus.FAILED, + RunStatus.ARCHIVED, + } + + def test_success_transitions(self) -> None: + """SUCCESS can only transition to ARCHIVED.""" + assert VALID_TRANSITIONS[RunStatus.SUCCESS] == {RunStatus.ARCHIVED} + + def test_failed_transitions(self) -> None: + """FAILED can only transition to ARCHIVED.""" + assert VALID_TRANSITIONS[RunStatus.FAILED] == {RunStatus.ARCHIVED} + + def test_archived_is_terminal(self) -> None: + """ARCHIVED is a terminal state with no transitions.""" + assert VALID_TRANSITIONS[RunStatus.ARCHIVED] == set() + + +class TestRuntimeInfo: + """Tests for RuntimeInfo schema.""" + + def test_create_with_all_fields(self) -> None: + """Should create with all version fields.""" + info = RuntimeInfo( + python_version="3.12.0", + sklearn_version="1.4.0", + numpy_version="1.26.0", + pandas_version="2.1.0", + joblib_version="1.3.0", + ) + assert info.python_version == "3.12.0" + assert info.sklearn_version == "1.4.0" + + def test_create_minimal(self) -> None: + """Should create with only required fields.""" + info = RuntimeInfo(python_version="3.12.0") + assert info.python_version == "3.12.0" + assert info.sklearn_version is None + assert info.numpy_version is None + + def test_is_frozen(self) -> None: + """RuntimeInfo should be immutable.""" + info = RuntimeInfo(python_version="3.12.0") + with pytest.raises(ValidationError): + info.python_version = "3.11.0" # type: ignore[misc] + + +class TestAgentContext: + """Tests for AgentContext schema.""" + + def test_create_with_all_fields(self) -> None: + """Should create with all fields.""" + ctx = AgentContext(agent_id="agent-123", session_id="session-456") + assert ctx.agent_id == "agent-123" + assert ctx.session_id == "session-456" + + def test_create_empty(self) -> None: + """Should create with no fields (all optional).""" + ctx = AgentContext() + assert ctx.agent_id is None + assert ctx.session_id is None + + def test_is_frozen(self) -> None: + """AgentContext should be immutable.""" + ctx = AgentContext(agent_id="agent-123") + with pytest.raises(ValidationError): + ctx.agent_id = "agent-456" # type: ignore[misc] + + +class TestRunCreate: + """Tests for RunCreate schema.""" + + def test_create_minimal(self) -> None: + """Should create with only required fields.""" + run = RunCreate( + model_type="naive", + model_config_data={"strategy": "last_value"}, + data_window_start=date(2024, 1, 1), + data_window_end=date(2024, 3, 31), + store_id=1, + product_id=1, + ) + assert run.model_type == "naive" + assert run.model_config_data == {"strategy": "last_value"} + assert run.feature_config is None + assert run.agent_context is None + assert run.git_sha is None + + def test_create_with_all_fields(self) -> None: + """Should create with all fields.""" + run = RunCreate( + model_type="seasonal_naive", + model_config_data={"season_length": 7}, + feature_config={"lags": [1, 7, 14]}, + data_window_start=date(2024, 1, 1), + data_window_end=date(2024, 6, 30), + store_id=5, + product_id=10, + agent_context=AgentContext(agent_id="test"), + git_sha="abc123def456789", + ) + assert run.model_type == "seasonal_naive" + assert run.feature_config == {"lags": [1, 7, 14]} + assert run.store_id == 5 + assert run.product_id == 10 + + def test_validate_model_type_min_length(self) -> None: + """model_type should have minimum length of 1.""" + with pytest.raises(ValidationError) as exc_info: + RunCreate( + model_type="", + model_config_data={}, + data_window_start=date(2024, 1, 1), + data_window_end=date(2024, 1, 31), + store_id=1, + product_id=1, + ) + assert "model_type" in str(exc_info.value) + + def test_validate_model_type_max_length(self) -> None: + """model_type should have maximum length of 50.""" + with pytest.raises(ValidationError) as exc_info: + RunCreate( + model_type="a" * 51, + model_config_data={}, + data_window_start=date(2024, 1, 1), + data_window_end=date(2024, 1, 31), + store_id=1, + product_id=1, + ) + assert "model_type" in str(exc_info.value) + + def test_validate_store_id_positive(self) -> None: + """store_id must be >= 1.""" + with pytest.raises(ValidationError) as exc_info: + RunCreate( + model_type="naive", + model_config_data={}, + data_window_start=date(2024, 1, 1), + data_window_end=date(2024, 1, 31), + store_id=0, + product_id=1, + ) + assert "store_id" in str(exc_info.value) + + def test_validate_product_id_positive(self) -> None: + """product_id must be >= 1.""" + with pytest.raises(ValidationError) as exc_info: + RunCreate( + model_type="naive", + model_config_data={}, + data_window_start=date(2024, 1, 1), + data_window_end=date(2024, 1, 31), + store_id=1, + product_id=0, + ) + assert "product_id" in str(exc_info.value) + + def test_validate_data_window_end_after_start(self) -> None: + """data_window_end must be >= data_window_start.""" + with pytest.raises(ValidationError) as exc_info: + RunCreate( + model_type="naive", + model_config_data={}, + data_window_start=date(2024, 3, 1), + data_window_end=date(2024, 1, 1), + store_id=1, + product_id=1, + ) + assert "data_window_end" in str(exc_info.value) + + def test_data_window_same_day_valid(self) -> None: + """data_window_end == data_window_start should be valid.""" + run = RunCreate( + model_type="naive", + model_config_data={}, + data_window_start=date(2024, 1, 1), + data_window_end=date(2024, 1, 1), + store_id=1, + product_id=1, + ) + assert run.data_window_start == run.data_window_end + + def test_compute_config_hash(self) -> None: + """config_hash should be deterministic for same config.""" + run1 = RunCreate( + model_type="naive", + model_config_data={"a": 1, "b": 2}, + data_window_start=date(2024, 1, 1), + data_window_end=date(2024, 1, 31), + store_id=1, + product_id=1, + ) + run2 = RunCreate( + model_type="naive", + model_config_data={"b": 2, "a": 1}, # Same config, different order + data_window_start=date(2024, 1, 1), + data_window_end=date(2024, 1, 31), + store_id=1, + product_id=1, + ) + assert run1.compute_config_hash() == run2.compute_config_hash() + + def test_compute_config_hash_different(self) -> None: + """config_hash should differ for different configs.""" + run1 = RunCreate( + model_type="naive", + model_config_data={"a": 1}, + data_window_start=date(2024, 1, 1), + data_window_end=date(2024, 1, 31), + store_id=1, + product_id=1, + ) + run2 = RunCreate( + model_type="naive", + model_config_data={"a": 2}, + data_window_start=date(2024, 1, 1), + data_window_end=date(2024, 1, 31), + store_id=1, + product_id=1, + ) + assert run1.compute_config_hash() != run2.compute_config_hash() + + def test_config_hash_length(self) -> None: + """config_hash should be 16 characters.""" + run = RunCreate( + model_type="naive", + model_config_data={"test": True}, + data_window_start=date(2024, 1, 1), + data_window_end=date(2024, 1, 31), + store_id=1, + product_id=1, + ) + assert len(run.compute_config_hash()) == 16 + + +class TestRunUpdate: + """Tests for RunUpdate schema.""" + + def test_create_empty(self) -> None: + """Should allow empty update (all fields optional).""" + update = RunUpdate() + assert update.status is None + assert update.metrics is None + assert update.artifact_uri is None + + def test_update_status(self) -> None: + """Should update status.""" + update = RunUpdate(status=RunStatus.RUNNING) + assert update.status == RunStatus.RUNNING + + def test_update_metrics(self) -> None: + """Should update metrics.""" + update = RunUpdate(metrics={"mae": 1.5, "smape": 10.2}) + assert update.metrics == {"mae": 1.5, "smape": 10.2} + + def test_update_artifact_info(self) -> None: + """Should update artifact information.""" + update = RunUpdate( + artifact_uri="models/run123.pkl", + artifact_hash="abc123def456", + artifact_size_bytes=1024, + ) + assert update.artifact_uri == "models/run123.pkl" + assert update.artifact_hash == "abc123def456" + assert update.artifact_size_bytes == 1024 + + def test_validate_artifact_size_bytes_non_negative(self) -> None: + """artifact_size_bytes must be >= 0.""" + with pytest.raises(ValidationError) as exc_info: + RunUpdate(artifact_size_bytes=-1) + assert "artifact_size_bytes" in str(exc_info.value) + + def test_validate_error_message_max_length(self) -> None: + """error_message should have maximum length of 2000.""" + with pytest.raises(ValidationError) as exc_info: + RunUpdate(error_message="x" * 2001) + assert "error_message" in str(exc_info.value) + + +class TestAliasCreate: + """Tests for AliasCreate schema.""" + + def test_create_minimal(self) -> None: + """Should create with required fields only.""" + alias = AliasCreate(alias_name="production", run_id="abc123") + assert alias.alias_name == "production" + assert alias.run_id == "abc123" + assert alias.description is None + + def test_create_with_description(self) -> None: + """Should create with description.""" + alias = AliasCreate( + alias_name="staging-v2", + run_id="def456", + description="Staging environment model", + ) + assert alias.description == "Staging environment model" + + def test_validate_alias_name_pattern_lowercase(self) -> None: + """alias_name must match pattern (lowercase letters, numbers, hyphens, underscores).""" + # Valid names + AliasCreate(alias_name="production", run_id="x") + AliasCreate(alias_name="staging-v2", run_id="x") + AliasCreate(alias_name="prod_us_east", run_id="x") + AliasCreate(alias_name="1-test", run_id="x") + + def test_validate_alias_name_pattern_invalid_uppercase(self) -> None: + """alias_name should reject uppercase letters.""" + with pytest.raises(ValidationError) as exc_info: + AliasCreate(alias_name="Production", run_id="x") + assert "alias_name" in str(exc_info.value) + + def test_validate_alias_name_pattern_invalid_special(self) -> None: + """alias_name should reject special characters.""" + with pytest.raises(ValidationError) as exc_info: + AliasCreate(alias_name="prod@v1", run_id="x") + assert "alias_name" in str(exc_info.value) + + def test_validate_alias_name_pattern_invalid_start(self) -> None: + """alias_name must start with letter or number.""" + with pytest.raises(ValidationError) as exc_info: + AliasCreate(alias_name="-production", run_id="x") + assert "alias_name" in str(exc_info.value) + + def test_validate_alias_name_max_length(self) -> None: + """alias_name should have maximum length of 100.""" + with pytest.raises(ValidationError) as exc_info: + AliasCreate(alias_name="a" * 101, run_id="x") + assert "alias_name" in str(exc_info.value) + + def test_validate_description_max_length(self) -> None: + """description should have maximum length of 500.""" + with pytest.raises(ValidationError) as exc_info: + AliasCreate(alias_name="test", run_id="x", description="x" * 501) + assert "description" in str(exc_info.value) diff --git a/app/features/registry/tests/test_service.py b/app/features/registry/tests/test_service.py new file mode 100644 index 00000000..5a5fde28 --- /dev/null +++ b/app/features/registry/tests/test_service.py @@ -0,0 +1,270 @@ +"""Unit tests for registry service.""" + +from datetime import date + +import pytest + +from app.features.registry.schemas import ( + VALID_TRANSITIONS, + RunCreate, + RunStatus, +) +from app.features.registry.service import ( + DuplicateRunError, + InvalidTransitionError, + RegistryService, +) + + +class TestRegistryServiceStatusTransition: + """Tests for status transition validation.""" + + def test_is_valid_transition_pending_to_running(self) -> None: + """PENDING -> RUNNING should be valid.""" + service = RegistryService() + assert service._is_valid_transition(RunStatus.PENDING, RunStatus.RUNNING) is True + + def test_is_valid_transition_pending_to_archived(self) -> None: + """PENDING -> ARCHIVED should be valid.""" + service = RegistryService() + assert service._is_valid_transition(RunStatus.PENDING, RunStatus.ARCHIVED) is True + + def test_is_valid_transition_running_to_success(self) -> None: + """RUNNING -> SUCCESS should be valid.""" + service = RegistryService() + assert service._is_valid_transition(RunStatus.RUNNING, RunStatus.SUCCESS) is True + + def test_is_valid_transition_running_to_failed(self) -> None: + """RUNNING -> FAILED should be valid.""" + service = RegistryService() + assert service._is_valid_transition(RunStatus.RUNNING, RunStatus.FAILED) is True + + def test_is_valid_transition_success_to_archived(self) -> None: + """SUCCESS -> ARCHIVED should be valid.""" + service = RegistryService() + assert service._is_valid_transition(RunStatus.SUCCESS, RunStatus.ARCHIVED) is True + + def test_is_valid_transition_failed_to_archived(self) -> None: + """FAILED -> ARCHIVED should be valid.""" + service = RegistryService() + assert service._is_valid_transition(RunStatus.FAILED, RunStatus.ARCHIVED) is True + + def test_is_invalid_transition_pending_to_success(self) -> None: + """PENDING -> SUCCESS should be invalid (must go through RUNNING).""" + service = RegistryService() + assert service._is_valid_transition(RunStatus.PENDING, RunStatus.SUCCESS) is False + + def test_is_invalid_transition_pending_to_failed(self) -> None: + """PENDING -> FAILED should be invalid (must go through RUNNING).""" + service = RegistryService() + assert service._is_valid_transition(RunStatus.PENDING, RunStatus.FAILED) is False + + def test_is_invalid_transition_success_to_running(self) -> None: + """SUCCESS -> RUNNING should be invalid (can't go backwards).""" + service = RegistryService() + assert service._is_valid_transition(RunStatus.SUCCESS, RunStatus.RUNNING) is False + + def test_is_invalid_transition_archived_to_any(self) -> None: + """ARCHIVED -> any state should be invalid (terminal state).""" + service = RegistryService() + for target in RunStatus: + if target != RunStatus.ARCHIVED: + assert service._is_valid_transition(RunStatus.ARCHIVED, target) is False + + +class TestRegistryServiceRuntimeInfo: + """Tests for runtime info capture.""" + + def test_capture_runtime_info_has_python_version(self) -> None: + """Should capture Python version.""" + service = RegistryService() + info = service._capture_runtime_info() + assert "python_version" in info + assert info["python_version"].startswith("3.") + + def test_capture_runtime_info_has_package_versions(self) -> None: + """Should capture installed package versions.""" + service = RegistryService() + info = service._capture_runtime_info() + + # These should be installed in the test environment + assert "numpy_version" in info + assert "pandas_version" in info + + +class TestRegistryServiceConfigHashDuplicate: + """Tests for config hash and duplicate detection.""" + + def test_compute_config_hash_deterministic(self) -> None: + """Config hash should be deterministic for same config.""" + run_data = RunCreate( + model_type="naive", + model_config_data={"a": 1, "b": 2}, + data_window_start=date(2024, 1, 1), + data_window_end=date(2024, 1, 31), + store_id=1, + product_id=1, + ) + hash1 = run_data.compute_config_hash() + hash2 = run_data.compute_config_hash() + assert hash1 == hash2 + + def test_compute_config_hash_order_independent(self) -> None: + """Config hash should be same regardless of key order.""" + run1 = RunCreate( + model_type="naive", + model_config_data={"a": 1, "b": 2, "c": 3}, + data_window_start=date(2024, 1, 1), + data_window_end=date(2024, 1, 31), + store_id=1, + product_id=1, + ) + run2 = RunCreate( + model_type="naive", + model_config_data={"c": 3, "a": 1, "b": 2}, + data_window_start=date(2024, 1, 1), + data_window_end=date(2024, 1, 31), + store_id=1, + product_id=1, + ) + assert run1.compute_config_hash() == run2.compute_config_hash() + + +class TestRegistryServiceConfigDiff: + """Tests for configuration diffing.""" + + def test_compute_config_diff_identical(self) -> None: + """Identical configs should have empty diff.""" + service = RegistryService() + config_a = {"strategy": "last_value", "horizon": 14} + config_b = {"strategy": "last_value", "horizon": 14} + diff = service._compute_config_diff(config_a, config_b) + assert diff == {} + + def test_compute_config_diff_different_values(self) -> None: + """Different values should be captured in diff.""" + service = RegistryService() + config_a = {"strategy": "last_value", "horizon": 14} + config_b = {"strategy": "mean", "horizon": 7} + diff = service._compute_config_diff(config_a, config_b) + assert diff == { + "strategy": {"a": "last_value", "b": "mean"}, + "horizon": {"a": 14, "b": 7}, + } + + def test_compute_config_diff_missing_keys(self) -> None: + """Missing keys should show None.""" + service = RegistryService() + config_a = {"strategy": "last_value", "extra_param": 100} + config_b = {"strategy": "last_value"} + diff = service._compute_config_diff(config_a, config_b) + assert diff == {"extra_param": {"a": 100, "b": None}} + + +class TestRegistryServiceMetricsDiff: + """Tests for metrics diffing.""" + + def test_compute_metrics_diff_both_none(self) -> None: + """Both None should return empty diff.""" + service = RegistryService() + diff = service._compute_metrics_diff(None, None) + assert diff == {} + + def test_compute_metrics_diff_one_none(self) -> None: + """One None should show values from the other.""" + service = RegistryService() + metrics_a = {"mae": 1.5, "smape": 10.0} + diff = service._compute_metrics_diff(metrics_a, None) + assert diff == { + "mae": {"a": 1.5, "b": None, "diff": None}, + "smape": {"a": 10.0, "b": None, "diff": None}, + } + + def test_compute_metrics_diff_numeric_diff(self) -> None: + """Should compute numeric difference (b - a).""" + service = RegistryService() + metrics_a = {"mae": 1.5, "smape": 10.0} + metrics_b = {"mae": 2.0, "smape": 8.0} + diff = service._compute_metrics_diff(metrics_a, metrics_b) + assert diff["mae"]["a"] == 1.5 + assert diff["mae"]["b"] == 2.0 + assert diff["mae"]["diff"] == pytest.approx(0.5) # b - a = 2.0 - 1.5 = 0.5 + assert diff["smape"]["diff"] == pytest.approx(-2.0) # b - a = 8.0 - 10.0 = -2.0 + + def test_compute_metrics_diff_non_numeric(self) -> None: + """Non-numeric values should have None diff.""" + service = RegistryService() + metrics_a = {"model_name": "naive", "mae": 1.5} + metrics_b = {"model_name": "seasonal", "mae": 2.0} + diff = service._compute_metrics_diff(metrics_a, metrics_b) + assert diff["model_name"]["diff"] is None + assert diff["mae"]["diff"] == pytest.approx(0.5) # b - a = 2.0 - 1.5 = 0.5 + + +class TestInvalidTransitionError: + """Tests for InvalidTransitionError.""" + + def test_error_message(self) -> None: + """Should format error message correctly.""" + error = InvalidTransitionError(RunStatus.PENDING, RunStatus.SUCCESS) + assert "pending" in str(error).lower() + assert "success" in str(error).lower() + + +class TestDuplicateRunError: + """Tests for DuplicateRunError.""" + + def test_error_message(self) -> None: + """Should format error message correctly.""" + error = DuplicateRunError("existing-run-id", "abc123") + assert "existing-run-id" in str(error) + assert "abc123" in str(error) + + +class TestAllTransitionsExhaustive: + """Exhaustive tests for all state transitions.""" + + @pytest.mark.parametrize( + "current_status,target_status", + [ + (RunStatus.PENDING, RunStatus.RUNNING), + (RunStatus.PENDING, RunStatus.ARCHIVED), + (RunStatus.RUNNING, RunStatus.SUCCESS), + (RunStatus.RUNNING, RunStatus.FAILED), + (RunStatus.RUNNING, RunStatus.ARCHIVED), + (RunStatus.SUCCESS, RunStatus.ARCHIVED), + (RunStatus.FAILED, RunStatus.ARCHIVED), + ], + ) + def test_valid_transitions(self, current_status: RunStatus, target_status: RunStatus) -> None: + """All valid transitions should be allowed.""" + service = RegistryService() + assert service._is_valid_transition(current_status, target_status) is True + + @pytest.mark.parametrize( + "current_status,target_status", + [ + (RunStatus.PENDING, RunStatus.SUCCESS), + (RunStatus.PENDING, RunStatus.FAILED), + (RunStatus.RUNNING, RunStatus.PENDING), + (RunStatus.SUCCESS, RunStatus.PENDING), + (RunStatus.SUCCESS, RunStatus.RUNNING), + (RunStatus.SUCCESS, RunStatus.FAILED), + (RunStatus.FAILED, RunStatus.PENDING), + (RunStatus.FAILED, RunStatus.RUNNING), + (RunStatus.FAILED, RunStatus.SUCCESS), + (RunStatus.ARCHIVED, RunStatus.PENDING), + (RunStatus.ARCHIVED, RunStatus.RUNNING), + (RunStatus.ARCHIVED, RunStatus.SUCCESS), + (RunStatus.ARCHIVED, RunStatus.FAILED), + ], + ) + def test_invalid_transitions(self, current_status: RunStatus, target_status: RunStatus) -> None: + """All invalid transitions should be rejected.""" + service = RegistryService() + assert service._is_valid_transition(current_status, target_status) is False + + def test_all_statuses_have_transition_rules(self) -> None: + """All statuses should be defined in VALID_TRANSITIONS.""" + for status in RunStatus: + assert status in VALID_TRANSITIONS diff --git a/app/features/registry/tests/test_storage.py b/app/features/registry/tests/test_storage.py new file mode 100644 index 00000000..52bda469 --- /dev/null +++ b/app/features/registry/tests/test_storage.py @@ -0,0 +1,241 @@ +"""Unit tests for registry storage providers.""" + +import hashlib +from pathlib import Path + +import pytest + +from app.features.registry.storage import ( + ArtifactNotFoundError, + ChecksumMismatchError, + LocalFSProvider, + StorageError, +) + + +class TestLocalFSProviderInit: + """Tests for LocalFSProvider initialization.""" + + def test_init_creates_root_dir(self, temp_artifact_dir: Path) -> None: + """Should create root directory if it doesn't exist.""" + new_root = temp_artifact_dir / "new_subdir" + assert not new_root.exists() + provider = LocalFSProvider(root_dir=new_root) + assert provider.root_dir.exists() + + def test_init_with_string_path(self, temp_artifact_dir: Path) -> None: + """Should accept string path.""" + provider = LocalFSProvider(root_dir=str(temp_artifact_dir)) + assert provider.root_dir == temp_artifact_dir + + def test_init_resolves_path(self, temp_artifact_dir: Path) -> None: + """Should resolve path to absolute.""" + relative_path = temp_artifact_dir / "subdir" / ".." / "resolved" + provider = LocalFSProvider(root_dir=relative_path) + assert provider.root_dir.is_absolute() + assert ".." not in str(provider.root_dir) + + +class TestLocalFSProviderSave: + """Tests for LocalFSProvider.save method.""" + + def test_save_copies_file( + self, + storage_provider: LocalFSProvider, + sample_artifact_file: Path, + sample_artifact_content: bytes, + ) -> None: + """Should copy file to destination.""" + artifact_uri = "models/test.pkl" + storage_provider.save(sample_artifact_file, artifact_uri) + + dest_path = storage_provider.root_dir / artifact_uri + assert dest_path.exists() + assert dest_path.read_bytes() == sample_artifact_content + + def test_save_returns_hash_and_size( + self, + storage_provider: LocalFSProvider, + sample_artifact_file: Path, + sample_artifact_content: bytes, + ) -> None: + """Should return SHA-256 hash and file size.""" + artifact_uri = "models/test.pkl" + file_hash, file_size = storage_provider.save(sample_artifact_file, artifact_uri) + + expected_hash = hashlib.sha256(sample_artifact_content).hexdigest() + expected_size = len(sample_artifact_content) + + assert file_hash == expected_hash + assert file_size == expected_size + + def test_save_creates_parent_directories( + self, + storage_provider: LocalFSProvider, + sample_artifact_file: Path, + ) -> None: + """Should create parent directories if they don't exist.""" + artifact_uri = "deep/nested/path/model.pkl" + storage_provider.save(sample_artifact_file, artifact_uri) + + dest_path = storage_provider.root_dir / artifact_uri + assert dest_path.exists() + + def test_save_overwrites_existing( + self, + storage_provider: LocalFSProvider, + sample_artifact_file: Path, + ) -> None: + """Should overwrite existing file.""" + artifact_uri = "models/test.pkl" + + # Create existing file + dest_path = storage_provider.root_dir / artifact_uri + dest_path.parent.mkdir(parents=True, exist_ok=True) + dest_path.write_text("old content") + + # Save new file + storage_provider.save(sample_artifact_file, artifact_uri) + + # Should have new content + assert dest_path.read_bytes() == sample_artifact_file.read_bytes() + + +class TestLocalFSProviderLoad: + """Tests for LocalFSProvider.load method.""" + + def test_load_returns_path( + self, + storage_provider: LocalFSProvider, + sample_artifact_file: Path, + ) -> None: + """Should return path to artifact.""" + artifact_uri = "models/test.pkl" + storage_provider.save(sample_artifact_file, artifact_uri) + + loaded_path = storage_provider.load(artifact_uri) + assert loaded_path == storage_provider.root_dir / artifact_uri + + def test_load_raises_not_found(self, storage_provider: LocalFSProvider) -> None: + """Should raise ArtifactNotFoundError if file doesn't exist.""" + with pytest.raises(ArtifactNotFoundError) as exc_info: + storage_provider.load("nonexistent/model.pkl") + assert "not found" in str(exc_info.value).lower() + + def test_load_with_hash_verification( + self, + storage_provider: LocalFSProvider, + sample_artifact_file: Path, + sample_artifact_content: bytes, + ) -> None: + """Should verify hash when provided.""" + artifact_uri = "models/test.pkl" + expected_hash = hashlib.sha256(sample_artifact_content).hexdigest() + + storage_provider.save(sample_artifact_file, artifact_uri) + loaded_path = storage_provider.load(artifact_uri, expected_hash=expected_hash) + + assert loaded_path.exists() + + def test_load_raises_checksum_mismatch( + self, + storage_provider: LocalFSProvider, + sample_artifact_file: Path, + ) -> None: + """Should raise ChecksumMismatchError if hash doesn't match.""" + artifact_uri = "models/test.pkl" + wrong_hash = "0" * 64 + + storage_provider.save(sample_artifact_file, artifact_uri) + + with pytest.raises(ChecksumMismatchError) as exc_info: + storage_provider.load(artifact_uri, expected_hash=wrong_hash) + assert "mismatch" in str(exc_info.value).lower() + + +class TestLocalFSProviderDelete: + """Tests for LocalFSProvider.delete method.""" + + def test_delete_removes_file( + self, + storage_provider: LocalFSProvider, + sample_artifact_file: Path, + ) -> None: + """Should delete existing file and return True.""" + artifact_uri = "models/test.pkl" + storage_provider.save(sample_artifact_file, artifact_uri) + + dest_path = storage_provider.root_dir / artifact_uri + assert dest_path.exists() + + result = storage_provider.delete(artifact_uri) + assert result is True + assert not dest_path.exists() + + def test_delete_returns_false_if_not_found(self, storage_provider: LocalFSProvider) -> None: + """Should return False if file doesn't exist.""" + result = storage_provider.delete("nonexistent/model.pkl") + assert result is False + + +class TestLocalFSProviderExists: + """Tests for LocalFSProvider.exists method.""" + + def test_exists_returns_true( + self, + storage_provider: LocalFSProvider, + sample_artifact_file: Path, + ) -> None: + """Should return True if file exists.""" + artifact_uri = "models/test.pkl" + storage_provider.save(sample_artifact_file, artifact_uri) + + assert storage_provider.exists(artifact_uri) is True + + def test_exists_returns_false(self, storage_provider: LocalFSProvider) -> None: + """Should return False if file doesn't exist.""" + assert storage_provider.exists("nonexistent/model.pkl") is False + + +class TestLocalFSProviderComputeHash: + """Tests for LocalFSProvider.compute_hash static method.""" + + def test_compute_hash_sha256( + self, sample_artifact_file: Path, sample_artifact_content: bytes + ) -> None: + """Should compute correct SHA-256 hash.""" + expected_hash = hashlib.sha256(sample_artifact_content).hexdigest() + actual_hash = LocalFSProvider.compute_hash(sample_artifact_file) + assert actual_hash == expected_hash + + def test_compute_hash_is_deterministic(self, sample_artifact_file: Path) -> None: + """Should return same hash for same file.""" + hash1 = LocalFSProvider.compute_hash(sample_artifact_file) + hash2 = LocalFSProvider.compute_hash(sample_artifact_file) + assert hash1 == hash2 + + +class TestLocalFSProviderPathTraversal: + """Tests for path traversal prevention.""" + + def test_reject_parent_directory_traversal(self, storage_provider: LocalFSProvider) -> None: + """Should reject ../.. traversal attempts.""" + with pytest.raises(StorageError) as exc_info: + storage_provider._resolve_path("../../../etc/passwd") + assert "traversal" in str(exc_info.value).lower() + + def test_reject_absolute_path(self, storage_provider: LocalFSProvider) -> None: + """Should reject absolute paths that escape root.""" + with pytest.raises(StorageError) as exc_info: + storage_provider._resolve_path("/etc/passwd") + assert "traversal" in str(exc_info.value).lower() + + def test_allow_nested_paths(self, storage_provider: LocalFSProvider) -> None: + """Should allow valid nested paths.""" + path = storage_provider._resolve_path("models/2024/01/run123.pkl") + assert path.is_relative_to(storage_provider.root_dir) + + def test_allow_paths_with_dots_in_name(self, storage_provider: LocalFSProvider) -> None: + """Should allow dots in filenames (not traversal).""" + path = storage_provider._resolve_path("models/model.v1.0.pkl") + assert path.is_relative_to(storage_provider.root_dir) diff --git a/app/main.py b/app/main.py index eee3b908..c4bc6509 100644 --- a/app/main.py +++ b/app/main.py @@ -14,6 +14,7 @@ from app.features.featuresets.routes import router as featuresets_router from app.features.forecasting.routes import router as forecasting_router from app.features.ingest.routes import router as ingest_router +from app.features.registry.routes import router as registry_router logger = get_logger(__name__) @@ -74,6 +75,7 @@ def create_app() -> FastAPI: app.include_router(featuresets_router) app.include_router(forecasting_router) app.include_router(backtesting_router) + app.include_router(registry_router) return app diff --git a/docker-compose.yml b/docker-compose.yml index a976ab61..e1b2066b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -7,7 +7,7 @@ services: POSTGRES_PASSWORD: forecastlab POSTGRES_DB: forecastlab ports: - - "5432:5432" + - "5433:5432" volumes: - forecastlab_pgdata:/var/lib/postgresql/data healthcheck: diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index a36af84e..24b7ad1f 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -312,15 +312,69 @@ forecast_enable_lightgbm: bool = False - Tests: `app/features/backtesting/tests/` (95 tests) - Examples: `examples/backtest/` (run_backtest.py, inspect_splits.py, metrics_demo.py) -### 7.6 Model Registry (Planned) -Each run stores: -- run_id, timestamps -- model_type + model_config (JSON) -- feature_config + schema_version -- data window boundaries -- metrics (JSON) -- artifact URI/path + artifact hash -- optional git_sha +### 7.6 Model Registry — ✅ IMPLEMENTED + +**Implemented via PRP-7** - Full run tracking and deployment alias management: + +**ORM Models:** +- `ModelRun` - JSONB columns for model_config, feature_config, metrics, runtime_info, agent_context +- `DeploymentAlias` - Mutable pointers to successful runs for deployment + +**Run Lifecycle (State Machine):** +``` +PENDING → RUNNING → SUCCESS/FAILED → ARCHIVED +``` +- Validated transitions prevent invalid state changes +- Aliases can only point to SUCCESS runs + +**Storage Provider:** +- `LocalFSProvider` with abstract interface for future S3/GCS support +- SHA-256 integrity verification on load +- Path traversal prevention (security) + +**Each Run Stores:** +- run_id (UUID hex, 32 chars), timestamps (created_at, updated_at, started_at, completed_at) +- model_type + model_config (JSONB with GIN index) +- feature_config (JSONB, optional) +- data_window_start, data_window_end, store_id, product_id +- config_hash (16-char SHA-256 prefix for deduplication) +- metrics (JSONB with GIN index) +- artifact_uri, artifact_hash (SHA-256), artifact_size_bytes +- runtime_info (Python, numpy, pandas, sklearn, joblib versions) +- agent_context (agent_id, session_id for autonomous workflows) +- git_sha (optional) +- error_message (for FAILED runs) + +**Duplicate Detection:** +- Configurable via `registry_duplicate_policy`: allow, deny, detect +- Based on config_hash + store_id + product_id + data_window + +**API Endpoints:** +- `POST /registry/runs` - Create run +- `GET /registry/runs` - List with filters and pagination +- `GET /registry/runs/{run_id}` - Get run details +- `PATCH /registry/runs/{run_id}` - Update status/metrics/artifacts +- `GET /registry/runs/{run_id}/verify` - Verify artifact integrity +- `POST /registry/aliases` - Create/update deployment alias +- `GET /registry/aliases` - List aliases +- `GET /registry/aliases/{alias_name}` - Get alias +- `DELETE /registry/aliases/{alias_name}` - Delete alias +- `GET /registry/compare/{run_id_a}/{run_id_b}` - Compare runs + +**Location:** +- Models: `app/features/registry/models.py` +- Schemas: `app/features/registry/schemas.py` +- Storage: `app/features/registry/storage.py` +- Service: `app/features/registry/service.py` +- Routes: `app/features/registry/routes.py` +- Tests: `app/features/registry/tests/` (103 unit + 24 integration tests) +- Example: `examples/registry_demo.py` + +**Configuration (Settings):** +```python +registry_artifact_root: str = "./artifacts/registry" +registry_duplicate_policy: Literal["allow", "deny", "detect"] = "detect" +``` --- @@ -334,9 +388,18 @@ Each run stores: - `POST /forecasting/train` - Train forecasting model (returns model_path) - `POST /forecasting/predict` - Generate forecasts using saved model - `POST /backtesting/run` - Run time-series CV backtest with baseline comparisons +- `POST /registry/runs` - Create model run +- `GET /registry/runs` - List runs with filters +- `GET /registry/runs/{run_id}` - Get run details +- `PATCH /registry/runs/{run_id}` - Update run status/metrics/artifacts +- `GET /registry/runs/{run_id}/verify` - Verify artifact integrity +- `POST /registry/aliases` - Create deployment alias +- `GET /registry/aliases` - List aliases +- `GET /registry/aliases/{alias_name}` - Get alias details +- `DELETE /registry/aliases/{alias_name}` - Delete alias +- `GET /registry/compare/{run_id_a}/{run_id_b}` - Compare two runs **Planned Endpoints:** -- `GET /runs`, `GET /runs/{run_id}` - Model registry and leaderboard - `GET /data/kpis`, `GET /data/drilldowns` - Data exploration - `POST /rag/query` - RAG knowledge base queries (optional `/rag/index` in dev) @@ -385,7 +448,10 @@ The repo standards live in `docs/validation/` and are treated as merge gates: ## 12) Roadmap (Phased Delivery) -- **Phase-0**: vertical-slice demo (seed → ingest → baseline train → predict → UI tables) -- **Phase-1**: ForecastOps core (backtesting + registry + leaderboard) +- **Phase-0**: vertical-slice demo (seed → ingest → baseline train → predict → UI tables) ✅ +- **Phase-1**: ForecastOps core (backtesting + registry + leaderboard) ✅ + - Backtesting: ✅ IMPLEMENTED (PRP-6) + - Registry: ✅ IMPLEMENTED (PRP-7) + - Leaderboard UI: Planned - **Phase-2**: ML models + richer exogenous features - **Phase-3**: RAG + agentic workflows (PydanticAI), run report generation/indexing diff --git a/docs/PHASE-index.md b/docs/PHASE-index.md index 589b763b..7b912a85 100644 --- a/docs/PHASE-index.md +++ b/docs/PHASE-index.md @@ -12,9 +12,9 @@ This document indexes all implementation phases of the ForecastLabAI project. | 1 | Data Platform | Completed | PRP-2 | [1-DATA_PLATFORM.md](./PHASE/1-DATA_PLATFORM.md) | | 2 | Ingest Layer | Completed | PRP-3 | [2-INGEST_LAYER.md](./PHASE/2-INGEST_LAYER.md) | | 3 | Feature Engineering | Completed | PRP-4 | [3-FEATURE_ENGINEERING.md](./PHASE/3-FEATURE_ENGINEERING.md) | -| 4 | Forecasting | Pending | PRP-5 | - | -| 5 | Backtesting | Pending | PRP-6 | - | -| 6 | Model Registry | Pending | PRP-7 | - | +| 4 | Forecasting | Completed | PRP-5 | [4-FORECASTING.md](./PHASE/4-FORECASTING.md) | +| 5 | Backtesting | Completed | PRP-6 | [5-BACKTESTING.md](./PHASE/5-BACKTESTING.md) | +| 6 | Model Registry | Completed | PRP-7 | [6-MODEL_REGISTRY.md](./PHASE/6-MODEL_REGISTRY.md) | | 7 | RAG Knowledge Base | Pending | PRP-8 | - | | 8 | Dashboard | Pending | PRP-9 | - | | 9 | Agentic Layer | Pending | - | - | @@ -156,18 +156,82 @@ This document indexes all implementation phases of the ForecastLabAI project. - Pyright: 0 errors - Pytest: 55 tests passed ---- +### Phase 4: Forecasting -## Pending Phases +**Date Completed**: 2026-01-31 -### Phase 4: Forecasting -Model zoo with unified interface for naive, seasonal, and ML models. +**Summary**: Model zoo with unified forecaster interface: +- BaseForecaster abstract class with `fit()` and `predict()` methods +- Naive, SeasonalNaive, MovingAverage models implemented +- LightGBM model (feature-flagged, disabled by default) +- Model bundle persistence with joblib (fitted model + config + metadata) +- POST /forecasting/train and POST /forecasting/predict endpoints + +**Key Deliverables**: +- `app/features/forecasting/models.py` - BaseForecaster and model implementations +- `app/features/forecasting/persistence.py` - ModelBundle save/load +- `app/features/forecasting/schemas.py` - Request/response schemas +- `app/features/forecasting/service.py` - ForecastingService +- `app/features/forecasting/routes.py` - API endpoints +- `examples/models/` - Baseline model examples ### Phase 5: Backtesting -Rolling and expanding time-based cross-validation with per-series metrics. + +**Date Completed**: 2026-01-31 + +**Summary**: Time-series cross-validation with comprehensive metrics: +- TimeSeriesSplitter with expanding/sliding window strategies +- Gap parameter for operational latency simulation +- Metrics: MAE, sMAPE (0-200), WAPE, Bias, Stability Index +- Automatic baseline comparisons (naive, seasonal_naive) +- Per-fold and aggregated metric storage +- POST /backtesting/run endpoint + +**Key Deliverables**: +- `app/features/backtesting/splitter.py` - TimeSeriesSplitter +- `app/features/backtesting/metrics.py` - Metrics computation +- `app/features/backtesting/schemas.py` - Request/response schemas +- `app/features/backtesting/service.py` - BacktestingService +- `app/features/backtesting/routes.py` - API endpoint +- `examples/backtest/` - Usage examples (95 unit + 16 integration tests) ### Phase 6: Model Registry -Run tracking with config, metrics, artifacts, and data windows. + +**Date Completed**: 2026-02-01 + +**Summary**: Full run tracking and deployment alias management: +- ModelRun ORM with JSONB columns (model_config, metrics, runtime_info) +- DeploymentAlias for mutable pointers to successful runs +- State machine: PENDING → RUNNING → SUCCESS/FAILED → ARCHIVED +- LocalFSProvider with SHA-256 integrity verification +- Duplicate detection (configurable: allow/deny/detect) +- Runtime environment capture and agent context tracking + +**Key Deliverables**: +- `app/features/registry/models.py` - ModelRun, DeploymentAlias ORM models +- `app/features/registry/storage.py` - LocalFSProvider with abstract interface +- `app/features/registry/schemas.py` - Request/response schemas +- `app/features/registry/service.py` - RegistryService +- `app/features/registry/routes.py` - API endpoints (runs, aliases, compare) +- `alembic/versions/a2f7b3c8d901_create_model_registry_tables.py` - Migration +- `examples/registry_demo.py` - Workflow demo + +**API Endpoints**: +- `POST /registry/runs` - Create run +- `GET /registry/runs` - List with filters and pagination +- `PATCH /registry/runs/{run_id}` - Update status/metrics/artifacts +- `GET /registry/runs/{run_id}/verify` - Verify artifact integrity +- `POST /registry/aliases` - Create deployment alias +- `GET /registry/compare/{run_id_a}/{run_id_b}` - Compare runs + +**Validation Results**: +- Ruff: All checks passed +- Pyright: 0 errors +- Pytest: 103 unit + 24 integration tests + +--- + +## Pending Phases ### Phase 7: RAG Knowledge Base pgvector embeddings with evidence-grounded answers and citations. @@ -219,3 +283,6 @@ Each phase document (`docs/PHASE/X-PHASE_NAME.md`) contains: | 2026-01-26 | 1 | Data Platform schema and migrations completed (v0.1.3) | | 2026-01-26 | 2 | Ingest Layer with POST /ingest/sales-daily endpoint completed | | 2026-01-31 | 3 | Feature Engineering with time-safe leakage prevention completed | +| 2026-01-31 | 4 | Forecasting module with model zoo completed | +| 2026-01-31 | 5 | Backtesting module with time-series CV completed | +| 2026-02-01 | 6 | Model Registry with run tracking and deployment aliases completed | diff --git a/docs/PHASE/4-FORECASTING.md b/docs/PHASE/4-FORECASTING.md new file mode 100644 index 00000000..8939d534 --- /dev/null +++ b/docs/PHASE/4-FORECASTING.md @@ -0,0 +1,329 @@ +# Phase 4: Forecasting + +**Date Completed**: 2026-01-31 +**PRP**: [PRP-5-forecasting.md](../../PRPs/PRP-5-forecasting.md) +**Release**: PR #28 + +--- + +## Executive Summary + +Phase 4 implements the Forecasting Layer for ForecastLabAI with a unified model zoo following scikit-learn conventions. The module provides a `BaseForecaster` abstract class that all models implement, ensuring consistent `fit`/`predict` interfaces and seamless integration with the backtesting framework. + +**Key Achievement**: Extensible model zoo with deterministic training via fixed `random_state` and joblib-based persistence for reproducibility. + +--- + +## Deliverables + +### 1. BaseForecaster Abstract Class + +**File**: `app/features/forecasting/models.py` + +Unified interface for all forecasting models: + +```python +class BaseForecaster(ABC): + """Abstract base class for all forecasting models. + + CRITICAL: All implementations must be deterministic with fixed random_state. + + Interface follows scikit-learn conventions: + - fit(y, X=None) -> self + - predict(horizon, X=None) -> np.ndarray + - get_params() -> dict + - set_params(**params) -> self + """ +``` + +**Model Types Implemented**: + +| Model | Class | Description | Key Parameter | +|-------|-------|-------------|---------------| +| `naive` | `NaiveForecaster` | Predicts last observed value for all horizons | None | +| `seasonal_naive` | `SeasonalNaiveForecaster` | Predicts value from same season in previous cycle | `season_length` (default: 7) | +| `moving_average` | `MovingAverageForecaster` | Predicts mean of last N observations | `window_size` (default: 7) | +| `lightgbm` | (Placeholder) | LightGBM regressor (feature-flagged) | `n_estimators`, `max_depth`, `learning_rate` | + +**FitResult Dataclass**: +```python +@dataclass +class FitResult: + fitted: bool + n_observations: int + train_start: date_type + train_end: date_type + metrics: dict[str, float] +``` + +### 2. Model Configuration Schemas + +**File**: `app/features/forecasting/schemas.py` + +Pydantic v2 schemas with frozen configs for reproducibility: + +| Schema | Purpose | +|--------|---------| +| `ModelConfigBase` | Base with `schema_version` and `config_hash()` | +| `NaiveModelConfig` | Config for naive forecaster | +| `SeasonalNaiveModelConfig` | Config with `season_length` (1-365) | +| `MovingAverageModelConfig` | Config with `window_size` (1-90) | +| `LightGBMModelConfig` | Config for LightGBM (n_estimators, max_depth, learning_rate) | +| `TrainRequest` | API request with store_id, product_id, date range, config | +| `TrainResponse` | Response with model_path, n_observations, duration_ms | +| `PredictRequest` | Request with horizon (1-90), model_path | +| `PredictResponse` | Response with forecast points | +| `ForecastPoint` | Single forecast with date, value, optional bounds | + +**Key Features**: +- Frozen models (`frozen=True`) for immutability +- Schema versioning for registry storage +- Deterministic `config_hash()` for deduplication +- Strict validation (positive lags, valid ranges) + +### 3. Model Persistence + +**File**: `app/features/forecasting/persistence.py` + +Joblib-based persistence with versioned bundles: + +```python +@dataclass +class ModelBundle: + """Bundled model with metadata for serialization.""" + model: BaseForecaster + config: ModelConfig + metadata: ModelMetadata + version: str = "1.0" + +def save_model_bundle(bundle: ModelBundle, path: Path) -> None: + """Save model bundle to disk using joblib.""" + +def load_model_bundle(path: Path) -> ModelBundle: + """Load model bundle from disk.""" +``` + +**Bundle Contents**: +- Fitted model instance +- Configuration used for training +- Metadata (store_id, product_id, dates, n_observations) +- Version string for compatibility checking + +### 4. ForecastingService + +**File**: `app/features/forecasting/service.py` + +Core service for model training and prediction: + +```python +class ForecastingService: + """Service for model training and prediction.""" + + async def train_model( + self, + db: AsyncSession, + store_id: int, + product_id: int, + train_start_date: date, + train_end_date: date, + config: ModelConfig, + ) -> TrainResponse: + """Train model on historical data.""" + + async def predict( + self, + store_id: int, + product_id: int, + horizon: int, + model_path: str, + ) -> PredictResponse: + """Generate forecasts using saved model.""" +``` + +**Key Features**: +- Fetches training data from `sales_daily` table +- Uses `model_factory()` to instantiate correct model type +- Validates store/product match on prediction +- Structured logging for all operations + +### 5. API Endpoints + +**File**: `app/features/forecasting/routes.py` + +| Endpoint | Method | Description | +|----------|--------|-------------| +| `/forecasting/train` | POST | Train a forecasting model | +| `/forecasting/predict` | POST | Generate forecasts using trained model | + +**Train Request Example**: +```json +{ + "store_id": 1, + "product_id": 101, + "train_start_date": "2024-01-01", + "train_end_date": "2024-12-31", + "config": { + "model_type": "seasonal_naive", + "season_length": 7 + } +} +``` + +**Train Response Example**: +```json +{ + "store_id": 1, + "product_id": 101, + "model_type": "seasonal_naive", + "model_path": "./artifacts/models/store_1_product_101_seasonal_naive_20240131_abc123.joblib", + "config_hash": "a1b2c3d4e5f6g7h8", + "n_observations": 365, + "train_start_date": "2024-01-01", + "train_end_date": "2024-12-31", + "duration_ms": 45.23 +} +``` + +**Predict Response Example**: +```json +{ + "store_id": 1, + "product_id": 101, + "forecasts": [ + {"date": "2025-01-01", "forecast": 42.5, "lower_bound": null, "upper_bound": null}, + {"date": "2025-01-02", "forecast": 38.2, "lower_bound": null, "upper_bound": null} + ], + "model_type": "seasonal_naive", + "config_hash": "a1b2c3d4e5f6g7h8", + "horizon": 14, + "duration_ms": 2.15 +} +``` + +### 6. Test Suite + +**Directory**: `app/features/forecasting/tests/` + +| File | Tests | Coverage | +|------|-------|----------| +| `test_schemas.py` | 20 | Schema validation, config hash, frozen models | +| `test_models.py` | 24 | Model fit/predict, edge cases, params | +| `test_persistence.py` | 15 | Save/load bundles, version compatibility | +| `test_service.py` | 20 | Service integration, validation, logging | + +**Total**: 79 tests + +**Test Strategy**: +- Unit tests for each model type with edge cases +- Determinism tests (same input → same output) +- Bundle round-trip serialization tests +- Service tests with mocked database + +### 7. Example Scripts + +**Directory**: `examples/models/` + +| File | Description | +|------|-------------| +| `baseline_naive.py` | Naive forecaster demo | +| `baseline_seasonal.py` | Seasonal naive with weekly seasonality | +| `baseline_mavg.py` | Moving average with configurable window | + +--- + +## Configuration + +**File**: `app/core/config.py` + +New settings added: + +```python +# Forecasting +forecast_random_seed: int = 42 +forecast_default_horizon: int = 14 +forecast_max_horizon: int = 90 +forecast_model_artifacts_dir: str = "./artifacts/models" +forecast_enable_lightgbm: bool = False +``` + +| Setting | Default | Description | +|---------|---------|-------------| +| `forecast_random_seed` | 42 | Random seed for reproducibility | +| `forecast_default_horizon` | 14 | Default forecast horizon in days | +| `forecast_max_horizon` | 90 | Maximum allowed horizon | +| `forecast_model_artifacts_dir` | `./artifacts/models` | Directory for saved models | +| `forecast_enable_lightgbm` | False | Feature flag for LightGBM models | + +--- + +## Directory Structure + +``` +app/features/forecasting/ +├── __init__.py # Module exports +├── models.py # BaseForecaster + implementations +├── schemas.py # Pydantic configuration schemas +├── persistence.py # Joblib save/load utilities +├── service.py # ForecastingService +├── routes.py # FastAPI endpoints +└── tests/ + ├── __init__.py + ├── conftest.py # Test fixtures + ├── test_models.py # Model unit tests + ├── test_schemas.py # Schema validation tests + ├── test_persistence.py # Persistence tests + └── test_service.py # Service integration tests + +examples/models/ +├── baseline_naive.py # Naive forecaster demo +├── baseline_seasonal.py # Seasonal naive demo +└── baseline_mavg.py # Moving average demo +``` + +--- + +## Validation Results + +``` +$ uv run ruff check app/features/forecasting/ +All checks passed! + +$ uv run mypy app/features/forecasting/ +Success: no issues found in 10 source files + +$ uv run pyright app/features/forecasting/ +0 errors, 0 warnings, 0 informations + +$ uv run pytest app/features/forecasting/tests/ -v +79 passed in 1.23s +``` + +--- + +## Logging Events + +| Event | Description | +|-------|-------------| +| `forecasting.train_request_received` | Train request received | +| `forecasting.train_request_completed` | Training completed successfully | +| `forecasting.train_request_failed` | Training failed | +| `forecasting.predict_request_received` | Prediction request received | +| `forecasting.predict_request_completed` | Prediction completed | +| `forecasting.predict_request_failed` | Prediction failed | +| `forecasting.model_saved` | Model bundle saved to disk | +| `forecasting.model_loaded` | Model bundle loaded from disk | + +--- + +## Next Phase Preparation + +Phase 5 (Backtesting) will use the forecasting module to: +1. Train models on rolling/expanding training windows +2. Generate predictions for held-out test periods +3. Calculate accuracy metrics across folds +4. Compare against naive/seasonal baselines + +**Integration Points**: +- `BaseForecaster.fit()` and `predict()` for CV folds +- `model_factory()` for instantiating models per fold +- `ModelConfig.config_hash()` for result deduplication diff --git a/docs/PHASE/5-BACKTESTING.md b/docs/PHASE/5-BACKTESTING.md new file mode 100644 index 00000000..e2193ff9 --- /dev/null +++ b/docs/PHASE/5-BACKTESTING.md @@ -0,0 +1,387 @@ +# Phase 5: Backtesting + +**Date Completed**: 2026-01-31 +**PRP**: [PRP-6-backtesting.md](../../PRPs/PRP-6-backtesting.md) +**Release**: PR #32 + +--- + +## Executive Summary + +Phase 5 implements the Backtesting Framework for ForecastLabAI with CRITICAL time-series cross-validation patterns. The module provides expanding and sliding window strategies with configurable gap parameters to simulate operational data latency, comprehensive accuracy metrics, and mandatory baseline comparisons. + +**Key Achievement**: Time-based CV with zero leakage through explicit temporal ordering and built-in leakage validation checks. + +--- + +## Deliverables + +### 1. TimeSeriesSplitter + +**File**: `app/features/backtesting/splitter.py` + +Core splitter for generating train/test splits: + +```python +class TimeSeriesSplitter: + """Generate time-based CV splits with expanding or sliding window. + + CRITICAL: Respects temporal order - no future data in training. + + Expanding Window Example (n_splits=3, min_train=30, horizon=14): + Fold 0: [0..30] train, [30..44] test + Fold 1: [0..44] train, [44..58] test (training grows) + Fold 2: [0..58] train, [58..72] test + + Sliding Window Example (n_splits=3, min_train=30, horizon=14): + Fold 0: [0..30] train, [30..44] test + Fold 1: [14..44] train, [44..58] test (training slides) + Fold 2: [28..58] train, [58..72] test + """ +``` + +**Split Strategies**: + +| Strategy | Training Window | Use Case | +|----------|----------------|----------| +| `expanding` | Grows from start with each fold | More training data, detect concept drift | +| `sliding` | Fixed size, slides forward | Consistent training size, recent patterns | + +**TimeSeriesSplit Dataclass**: +```python +@dataclass +class TimeSeriesSplit: + fold_index: int + train_indices: np.ndarray + test_indices: np.ndarray + train_dates: list[date] + test_dates: list[date] +``` + +**Key Methods**: +- `split(dates, y)` - Generate train/test splits +- `get_boundaries(dates, y)` - Get split boundaries without full objects +- `validate_no_leakage(dates, y)` - Verify no future data in training + +### 2. MetricsCalculator + +**File**: `app/features/backtesting/metrics.py` + +Comprehensive metrics for forecast evaluation: + +```python +class MetricsCalculator: + """Calculate forecasting accuracy metrics. + + Supported Metrics: + - MAE: Mean Absolute Error + - sMAPE: Symmetric Mean Absolute Percentage Error (0-200 scale) + - WAPE: Weighted Absolute Percentage Error + - Bias: Forecast Bias (positive = under-forecast) + - Stability: Coefficient of variation of per-fold metrics + """ +``` + +**Metrics Formulas**: + +| Metric | Formula | Interpretation | +|--------|---------|----------------| +| MAE | `mean(\|actual - predicted\|)` | Average absolute error | +| sMAPE | `100/n * sum(2 * \|A - F\| / (\|A\| + \|F\|))` | Symmetric percentage error (0-200) | +| WAPE | `sum(\|A - F\|) / sum(\|A\|) * 100` | Weighted error for intermittent series | +| Bias | `mean(actual - predicted)` | Positive = under-forecast | +| Stability | `std(metrics) / \|mean(metrics)\| * 100` | Lower = more stable | + +**Edge Case Handling**: +- Empty arrays return `NaN` +- Zero denominator handled with warnings +- sMAPE: when both actual and forecast are 0, contributes 0 (perfect forecast) + +### 3. Configuration Schemas + +**File**: `app/features/backtesting/schemas.py` + +Pydantic v2 schemas for backtest configuration: + +| Schema | Purpose | +|--------|---------| +| `SplitConfig` | Strategy, n_splits, min_train_size, gap, horizon | +| `BacktestConfig` | Complete config with model_config and options | +| `SplitBoundary` | Fold boundary dates and sizes | +| `FoldResult` | Per-fold actuals, predictions, metrics | +| `ModelBacktestResult` | All folds + aggregated metrics | +| `BacktestRequest` | API request schema | +| `BacktestResponse` | API response with all results | + +**SplitConfig Example**: +```python +SplitConfig( + strategy="expanding", # or "sliding" + n_splits=5, # 2-20 folds + min_train_size=30, # Minimum training samples + gap=0, # Gap between train end and test start + horizon=14, # Forecast horizon per fold +) +``` + +**Gap Parameter**: +- Simulates operational data latency +- `gap=1` means 1 day between train_end and test_start +- Valid range: 0-30 days +- Validation: `horizon > gap` (must be meaningful test period) + +### 4. BacktestingService + +**File**: `app/features/backtesting/service.py` + +Core service for running backtests: + +```python +class BacktestingService: + """Service for running time-series backtests.""" + + async def run_backtest( + self, + db: AsyncSession, + store_id: int, + product_id: int, + start_date: date, + end_date: date, + config: BacktestConfig, + ) -> BacktestResponse: + """Run backtest for a single series.""" +``` + +**Backtest Flow**: +1. Fetch data from `sales_daily` table +2. Validate sufficient data for requested splits +3. Generate splits using TimeSeriesSplitter +4. For each fold: + - Instantiate model via `model_factory()` + - Fit on training data + - Predict for test period + - Calculate metrics +5. Aggregate metrics across folds +6. Run baseline comparisons (naive, seasonal_naive) +7. Generate comparison summary with improvement percentages + +### 5. API Endpoint + +**File**: `app/features/backtesting/routes.py` + +| Endpoint | Method | Description | +|----------|--------|-------------| +| `/backtesting/run` | POST | Execute backtest for a series | + +**Request Example**: +```json +{ + "store_id": 1, + "product_id": 101, + "start_date": "2024-01-01", + "end_date": "2024-12-31", + "config": { + "schema_version": "1.0", + "split_config": { + "strategy": "expanding", + "n_splits": 5, + "min_train_size": 30, + "gap": 0, + "horizon": 14 + }, + "model_config_main": { + "model_type": "seasonal_naive", + "season_length": 7 + }, + "include_baselines": true, + "store_fold_details": true + } +} +``` + +**Response Structure**: +```json +{ + "backtest_id": "abc123def456", + "store_id": 1, + "product_id": 101, + "config_hash": "a1b2c3d4e5f6g7h8", + "split_config": { ... }, + "main_model_results": { + "model_type": "seasonal_naive", + "config_hash": "x1y2z3...", + "fold_results": [ ... ], + "aggregated_metrics": { + "mae": 3.45, + "smape": 12.34, + "wape": 8.76, + "bias": -0.23 + }, + "metric_std": { + "mae": 0.45, + "smape": 1.23 + } + }, + "baseline_results": [ ... ], + "comparison_summary": { + "vs_naive": { + "mae_improvement_pct": 15.2, + "smape_improvement_pct": 8.7 + }, + "vs_seasonal_naive": { + "mae_improvement_pct": 3.1, + "smape_improvement_pct": 2.4 + } + }, + "duration_ms": 245.67, + "leakage_check_passed": true +} +``` + +### 6. Test Suite + +**Directory**: `app/features/backtesting/tests/` + +| File | Tests | Coverage | +|------|-------|----------| +| `test_schemas.py` | 18 | Schema validation, frozen models, config hash | +| `test_splitter.py` | 32 | Expanding/sliding strategies, gap, leakage validation | +| `test_metrics.py` | 24 | All metrics, edge cases, aggregation | +| `test_service.py` | 25 | Service logic, mocked DB | +| `test_routes_integration.py` | 8 | Route integration with real DB | +| `test_service_integration.py` | 8 | Service integration with real DB | + +**Total**: 115 tests (99 unit + 16 integration) + +**Test Data Strategy**: +- Use 120 days of sequential sales data (quantity = day number 1-120) +- Sequential values make leakage mathematically detectable +- Integration tests require PostgreSQL via `docker-compose up -d` + +### 7. Example Scripts + +**Directory**: `examples/backtest/` + +| File | Description | +|------|-------------| +| `run_backtest.py` | Full backtest API call example | +| `inspect_splits.py` | Visualize split boundaries | +| `metrics_demo.py` | Metrics calculation examples | + +--- + +## Configuration + +**File**: `app/core/config.py` + +New settings added: + +```python +# Backtesting +backtest_max_splits: int = 20 +backtest_default_min_train_size: int = 30 +backtest_max_gap: int = 30 +backtest_results_dir: str = "./artifacts/backtests" +``` + +| Setting | Default | Description | +|---------|---------|-------------| +| `backtest_max_splits` | 20 | Maximum allowed CV folds | +| `backtest_default_min_train_size` | 30 | Default minimum training observations | +| `backtest_max_gap` | 30 | Maximum allowed gap in days | +| `backtest_results_dir` | `./artifacts/backtests` | Directory for saved results | + +--- + +## Directory Structure + +``` +app/features/backtesting/ +├── __init__.py # Module exports +├── schemas.py # Pydantic configuration schemas +├── splitter.py # TimeSeriesSplitter +├── metrics.py # MetricsCalculator +├── service.py # BacktestingService +├── routes.py # FastAPI endpoints +└── tests/ + ├── __init__.py + ├── conftest.py # Test fixtures + ├── test_schemas.py # Schema validation tests + ├── test_splitter.py # Splitter unit tests + ├── test_metrics.py # Metrics unit tests + ├── test_service.py # Service unit tests + ├── test_routes_integration.py # Route integration tests + └── test_service_integration.py # Service integration tests + +examples/backtest/ +├── run_backtest.py # Full backtest example +├── inspect_splits.py # Split visualization +└── metrics_demo.py # Metrics demo +``` + +--- + +## Validation Results + +``` +$ uv run ruff check app/features/backtesting/ +All checks passed! + +$ uv run mypy app/features/backtesting/ +Success: no issues found in 12 source files + +$ uv run pyright app/features/backtesting/ +0 errors, 0 warnings, 0 informations + +$ uv run pytest app/features/backtesting/tests/ -v +115 passed in 2.34s + +$ uv run pytest app/features/backtesting/tests/ -v -m integration +16 passed in 4.56s +``` + +--- + +## Logging Events + +| Event | Description | +|-------|-------------| +| `backtesting.request_received` | Backtest request received | +| `backtesting.request_completed` | Backtest completed successfully | +| `backtesting.request_failed` | Backtest failed | +| `backtesting.fold_started` | CV fold started | +| `backtesting.fold_completed` | CV fold completed | +| `backtesting.leakage_check_passed` | Leakage validation passed | +| `backtesting.leakage_check_failed` | Leakage validation failed | + +--- + +## Leakage Prevention + +**Built-in Checks**: +1. `TimeSeriesSplitter.validate_no_leakage()` verifies: + - `train_end < test_start` for all folds + - Gap is respected + - No overlap between train and test indices + +2. Response includes `leakage_check_passed: bool` + +**Test Strategy**: +- Sequential values (1, 2, 3...) so leakage is detectable +- Assert feature at row i never uses data from rows > i +- Test gap enforcement across folds + +--- + +## Next Phase Preparation + +Phase 6 (Model Registry) will use the backtesting module to: +1. Store backtest configuration and results per run +2. Track model performance over time +3. Compare runs with different configurations +4. Maintain lineage from data → features → model → backtest + +**Integration Points**: +- `BacktestConfig.config_hash()` for registry deduplication +- `ModelBacktestResult.aggregated_metrics` for run comparison +- `FoldResult` for detailed audit trail diff --git a/docs/PHASE/6-MODEL_REGISTRY.md b/docs/PHASE/6-MODEL_REGISTRY.md new file mode 100644 index 00000000..0fcc2124 --- /dev/null +++ b/docs/PHASE/6-MODEL_REGISTRY.md @@ -0,0 +1,434 @@ +# Phase 6: Model Registry + +**Date Completed**: 2026-02-01 +**PRP**: [PRP-7-model-registry.md](../../PRPs/PRP-7-model-registry.md) +**Release**: PR #35 + +--- + +## Executive Summary + +Phase 6 implements the Model Registry for ForecastLabAI, providing comprehensive run tracking with deployment aliases and artifact integrity verification. The module enables reproducible ML workflows by capturing full experiment lineage: configurations, data windows, metrics, and artifacts with SHA-256 checksums. + +**Key Achievement**: Complete run lifecycle management with state machine validation and secure artifact storage with path traversal prevention. + +--- + +## Deliverables + +### 1. ORM Models + +**File**: `app/features/registry/models.py` + +SQLAlchemy models for registry storage: + +```python +class RunStatus(str, Enum): + """Valid states for a model run. + + State transitions: + - PENDING -> RUNNING -> SUCCESS | FAILED + - Any state except ARCHIVED -> ARCHIVED + """ + PENDING = "pending" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + ARCHIVED = "archived" +``` + +**ModelRun Table**: + +| Column | Type | Description | +|--------|------|-------------| +| `id` | Integer | Primary key | +| `run_id` | String(32) | Unique external identifier (UUID hex) | +| `status` | String(20) | Current lifecycle state | +| `model_type` | String(50) | Type of model | +| `model_config` | JSONB | Full model configuration | +| `feature_config` | JSONB | Feature engineering config (nullable) | +| `config_hash` | String(16) | Hash for deduplication | +| `data_window_start` | Date | Training data start | +| `data_window_end` | Date | Training data end | +| `store_id` | Integer | Store ID | +| `product_id` | Integer | Product ID | +| `metrics` | JSONB | Performance metrics | +| `artifact_uri` | String(500) | Relative path to artifact | +| `artifact_hash` | String(64) | SHA-256 checksum | +| `artifact_size_bytes` | Integer | File size | +| `runtime_info` | JSONB | Python/library versions | +| `agent_context` | JSONB | Agent/session IDs | +| `git_sha` | String(40) | Git commit hash | +| `error_message` | String(2000) | Error details (FAILED runs) | +| `started_at` | DateTime(tz) | Run start time | +| `completed_at` | DateTime(tz) | Run completion time | +| `created_at` | DateTime(tz) | Record creation (mixin) | +| `updated_at` | DateTime(tz) | Record update (mixin) | + +**DeploymentAlias Table**: + +| Column | Type | Description | +|--------|------|-------------| +| `id` | Integer | Primary key | +| `alias_name` | String(100) | Unique alias name | +| `run_id` | Integer | Foreign key to ModelRun | +| `description` | String(500) | Optional description | + +**Indexes**: +- `ix_model_run_run_id` (unique) +- `ix_model_run_status` +- `ix_model_run_model_type` +- `ix_model_run_store_product` (composite) +- `ix_model_run_data_window` (composite) +- `ix_model_run_model_config_gin` (GIN for JSONB) +- `ix_model_run_metrics_gin` (GIN for JSONB) + +### 2. State Machine + +**Valid Transitions**: + +```python +VALID_TRANSITIONS: dict[RunStatus, set[RunStatus]] = { + RunStatus.PENDING: {RunStatus.RUNNING, RunStatus.ARCHIVED}, + RunStatus.RUNNING: {RunStatus.SUCCESS, RunStatus.FAILED, RunStatus.ARCHIVED}, + RunStatus.SUCCESS: {RunStatus.ARCHIVED}, + RunStatus.FAILED: {RunStatus.ARCHIVED}, + RunStatus.ARCHIVED: set(), # Terminal state +} +``` + +``` +PENDING ──→ RUNNING ──→ SUCCESS ──→ ARCHIVED + │ │ │ ↑ + │ └───→ FAILED ───────────→│ + └──────────────────────────────────→─┘ +``` + +### 3. Storage Provider + +**File**: `app/features/registry/storage.py` + +Abstract interface with LocalFS implementation: + +```python +class AbstractStorageProvider(ABC): + """Abstract base class for artifact storage.""" + + @abstractmethod + def save(self, source_path: Path, artifact_uri: str) -> tuple[str, int]: + """Save artifact, returns (sha256_hash, size_bytes).""" + + @abstractmethod + def load(self, artifact_uri: str, expected_hash: str | None = None) -> Path: + """Load artifact with optional hash verification.""" + + @abstractmethod + def delete(self, artifact_uri: str) -> bool: + """Delete artifact, returns True if deleted.""" + + @abstractmethod + def exists(self, artifact_uri: str) -> bool: + """Check if artifact exists.""" + + @staticmethod + def compute_hash(file_path: Path) -> str: + """Compute SHA-256 hash of file.""" +``` + +**LocalFSProvider**: +- Default provider for development/single-node +- Root directory from `registry_artifact_root` setting +- **CRITICAL**: Path traversal prevention via `relative_to()` validation +- SHA-256 checksum on save and optional verify on load + +**Security**: +```python +def _resolve_path(self, artifact_uri: str) -> Path: + full_path = (self.root_dir / artifact_uri).resolve() + # Security: ensure path is within root + try: + full_path.relative_to(self.root_dir) + except ValueError: + raise StorageError(f"Path traversal attempt: {artifact_uri}") + return full_path +``` + +### 4. Registry Schemas + +**File**: `app/features/registry/schemas.py` + +| Schema | Purpose | +|--------|---------| +| `RunStatus` | Enum for run lifecycle states | +| `RuntimeInfo` | Python/library versions snapshot | +| `AgentContext` | Agent ID and session ID | +| `RunCreate` | Create run request | +| `RunUpdate` | Update run (status, metrics, artifacts) | +| `RunResponse` | Full run details response | +| `RunListResponse` | Paginated list of runs | +| `AliasCreate` | Create/update alias request | +| `AliasResponse` | Alias details with run info | +| `RunCompareResponse` | Side-by-side run comparison | + +**Alias Naming Rules**: +- Pattern: `^[a-z0-9][a-z0-9\-_]*$` +- Start with lowercase letter or number +- Contains letters, numbers, hyphens, underscores +- Maximum 100 characters + +### 5. RegistryService + +**File**: `app/features/registry/service.py` + +Core service for registry operations: + +```python +class RegistryService: + """Service for model run tracking and alias management.""" + + async def create_run(self, db: AsyncSession, run_data: RunCreate) -> RunResponse + async def get_run(self, db: AsyncSession, run_id: str) -> RunResponse | None + async def list_runs(self, db, page, page_size, filters...) -> RunListResponse + async def update_run(self, db, run_id, update_data) -> RunResponse | None + async def create_alias(self, db, alias_data: AliasCreate) -> AliasResponse + async def get_alias(self, db, alias_name) -> AliasResponse | None + async def list_aliases(self, db) -> list[AliasResponse] + async def delete_alias(self, db, alias_name) -> bool + async def compare_runs(self, db, run_id_a, run_id_b) -> RunCompareResponse | None +``` + +**Duplicate Detection**: +Based on `registry_duplicate_policy` setting: +- `allow`: Always create new runs +- `deny`: Reject if duplicate config+window exists +- `detect`: Log warning but allow creation + +**Runtime Capture**: +Automatically captures Python and library versions: +```python +RuntimeInfo( + python_version="3.12.0", + sklearn_version="1.4.0", + numpy_version="1.26.0", + pandas_version="2.1.0", + joblib_version="1.3.0", +) +``` + +### 6. API Endpoints + +**File**: `app/features/registry/routes.py` + +| Endpoint | Method | Description | +|----------|--------|-------------| +| `/registry/runs` | POST | Create a new run | +| `/registry/runs` | GET | List runs with filters | +| `/registry/runs/{run_id}` | GET | Get run details | +| `/registry/runs/{run_id}` | PATCH | Update run status/metrics/artifacts | +| `/registry/runs/{run_id}/verify` | GET | Verify artifact integrity | +| `/registry/aliases` | POST | Create/update alias | +| `/registry/aliases` | GET | List all aliases | +| `/registry/aliases/{alias_name}` | GET | Get alias details | +| `/registry/aliases/{alias_name}` | DELETE | Delete alias | +| `/registry/compare/{run_id_a}/{run_id_b}` | GET | Compare two runs | + +**Create Run Request**: +```json +{ + "model_type": "seasonal_naive", + "model_config": { + "model_type": "seasonal_naive", + "season_length": 7 + }, + "data_window_start": "2024-01-01", + "data_window_end": "2024-12-31", + "store_id": 1, + "product_id": 101, + "agent_context": { + "agent_id": "backtest-agent-v1", + "session_id": "abc123" + } +} +``` + +**Update Run Request**: +```json +{ + "status": "success", + "metrics": { + "mae": 3.45, + "smape": 12.34 + }, + "artifact_uri": "runs/abc123/model.joblib", + "artifact_hash": "sha256:a1b2c3...", + "artifact_size_bytes": 102400 +} +``` + +**Compare Response**: +```json +{ + "run_a": { ... }, + "run_b": { ... }, + "config_diff": { + "season_length": {"a": 7, "b": 14} + }, + "metrics_diff": { + "mae": {"a": 3.45, "b": 4.12, "diff": -0.67}, + "smape": {"a": 12.34, "b": 15.67, "diff": -3.33} + } +} +``` + +### 7. Database Migration + +**File**: `alembic/versions/a2f7b3c8d901_create_model_registry_tables.py` + +Creates: +- `model_run` table with all columns and indexes +- `deployment_alias` table with foreign key +- Check constraints for status and data window validity + +### 8. Test Suite + +**Directory**: `app/features/registry/tests/` + +| File | Tests | Coverage | +|------|-------|----------| +| `test_schemas.py` | 22 | Schema validation, config hash, transitions | +| `test_storage.py` | 28 | LocalFS save/load, hash verification, path security | +| `test_service.py` | 35 | Service operations, state machine, duplicates | +| `test_routes.py` | 42 | All endpoints, error cases, pagination | + +**Total**: 127 tests (103 unit + 24 integration) + +**Integration Tests**: +- Require PostgreSQL via `docker-compose up -d` +- Test full CRUD lifecycle +- Verify JSONB queries work correctly +- Test GIN indexes for containment queries + +### 9. Example Script + +**File**: `examples/registry_demo.py` + +Demonstrates: +- Creating a run +- Transitioning through states +- Adding metrics and artifacts +- Creating deployment aliases +- Comparing runs + +--- + +## Configuration + +**File**: `app/core/config.py` + +New settings added: + +```python +# Registry +registry_artifact_root: str = "./artifacts/registry" +registry_duplicate_policy: Literal["allow", "deny", "detect"] = "detect" +``` + +| Setting | Default | Description | +|---------|---------|-------------| +| `registry_artifact_root` | `./artifacts/registry` | Root directory for artifacts | +| `registry_duplicate_policy` | `detect` | How to handle duplicate runs | + +--- + +## Directory Structure + +``` +app/features/registry/ +├── __init__.py # Module exports +├── models.py # SQLAlchemy ORM models +├── schemas.py # Pydantic request/response schemas +├── storage.py # AbstractStorageProvider + LocalFSProvider +├── service.py # RegistryService +├── routes.py # FastAPI endpoints +└── tests/ + ├── __init__.py + ├── conftest.py # Test fixtures + ├── test_schemas.py # Schema validation tests + ├── test_storage.py # Storage provider tests + ├── test_service.py # Service unit tests + └── test_routes.py # Route integration tests + +alembic/versions/ +└── a2f7b3c8d901_create_model_registry_tables.py + +examples/ +└── registry_demo.py # Registry usage demo +``` + +--- + +## Validation Results + +``` +$ uv run ruff check app/features/registry/ +All checks passed! + +$ uv run mypy app/features/registry/ +Success: no issues found in 11 source files + +$ uv run pyright app/features/registry/ +0 errors, 0 warnings, 0 informations + +$ uv run pytest app/features/registry/tests/ -v +127 passed in 3.45s + +$ uv run pytest app/features/registry/tests/ -v -m integration +24 passed in 5.67s +``` + +--- + +## Logging Events + +| Event | Description | +|-------|-------------| +| `registry.create_run_request_received` | Run creation request received | +| `registry.create_run_request_completed` | Run created successfully | +| `registry.create_run_request_failed` | Run creation failed | +| `registry.update_run_request_received` | Run update request received | +| `registry.update_run_request_completed` | Run updated successfully | +| `registry.update_run_request_failed` | Run update failed | +| `registry.create_alias_request_received` | Alias creation received | +| `registry.create_alias_request_completed` | Alias created/updated | +| `registry.delete_alias_request_received` | Alias deletion received | +| `registry.delete_alias_request_completed` | Alias deleted | +| `registry.artifact_saved` | Artifact saved to storage | +| `registry.artifact_deleted` | Artifact deleted | +| `registry.checksum_mismatch` | Artifact hash verification failed | +| `registry.path_traversal_attempt` | Path traversal attack detected | +| `registry.duplicate_run_detected` | Duplicate run detected (warn/deny) | + +--- + +## Security Considerations + +1. **Path Traversal Prevention**: All artifact URIs validated to stay within root +2. **SHA-256 Integrity**: Checksums computed on save, verified on load +3. **State Machine Enforcement**: Invalid transitions rejected +4. **Alias Validation**: Only SUCCESS runs can have aliases +5. **Input Validation**: Pydantic schemas with strict constraints + +--- + +## Next Phase Preparation + +Phase 7 (RAG Knowledge Base) will integrate with the registry to: +1. Index model configurations and metrics for retrieval +2. Enable natural language queries about model performance +3. Provide evidence-grounded answers with run citations +4. Support experiment comparison queries + +**Integration Points**: +- `ModelRun.model_config` and `metrics` JSONB for embedding +- `RunCompareResponse` for structured comparison answers +- `DeploymentAlias` for production model references diff --git a/examples/registry_demo.py b/examples/registry_demo.py new file mode 100644 index 00000000..99d997bf --- /dev/null +++ b/examples/registry_demo.py @@ -0,0 +1,251 @@ +#!/usr/bin/env python +"""Demonstrate model registry workflow. + +Usage: + uv run python examples/registry_demo.py + +This script demonstrates: +1. Creating a model run +2. Transitioning through lifecycle states +3. Recording metrics and artifact info +4. Creating deployment aliases +5. Comparing runs + +Prerequisites: + - PostgreSQL running (docker-compose up -d) + - Database migrated (uv run alembic upgrade head) + - API running (uv run uvicorn app.main:app --reload --port 8123) +""" + +import json +import sys +from datetime import date + +import httpx + +API_BASE = "http://localhost:8123" + + +def print_section(title: str) -> None: + """Print a section header.""" + print(f"\n{'=' * 60}") + print(f" {title}") + print(f"{'=' * 60}\n") + + +def print_response(response: httpx.Response, label: str = "") -> dict: + """Print HTTP response details.""" + data = ( + response.json() + if response.headers.get("content-type", "").startswith("application/json") + else {} + ) + status_emoji = "✓" if response.status_code < 400 else "✗" + print(f"{status_emoji} {label} [{response.status_code}]") + if data: + print(json.dumps(data, indent=2, default=str)) + return data + + +def main() -> int: + """Run the registry demo workflow.""" + print_section("ForecastLabAI - Model Registry Demo") + + client = httpx.Client(base_url=API_BASE, timeout=30) + + # Check API is running + try: + health = client.get("/health") + if health.status_code != 200: + print(f"API not healthy: {health.status_code}") + return 1 + except httpx.ConnectError: + print(f"Cannot connect to API at {API_BASE}") + print("Start the API with: uv run uvicorn app.main:app --reload --port 8123") + return 1 + + print("✓ API is healthy\n") + + # ========================================================================== + # Step 1: Create a model run + # ========================================================================== + print_section("Step 1: Create a Model Run") + + run_request = { + "model_type": "seasonal_naive", + "model_config": { + "season_length": 7, + "strategy": "repeat_pattern", + }, + "feature_config": { + "lags": [1, 7, 14], + "rolling_windows": [7, 14, 28], + }, + "data_window_start": str(date(2024, 1, 1)), + "data_window_end": str(date(2024, 3, 31)), + "store_id": 1, + "product_id": 42, + "agent_context": { + "agent_id": "demo-agent", + "session_id": "demo-session-001", + }, + "git_sha": "abc123def456", + } + + print("Request body:") + print(json.dumps(run_request, indent=2)) + print() + + response = client.post("/registry/runs", json=run_request) + run_data = print_response(response, "POST /registry/runs") + + if response.status_code != 201: + print("\nFailed to create run. Exiting.") + return 1 + + run_id = run_data["run_id"] + print(f"\n→ Created run: {run_id}") + print(f"→ Config hash: {run_data['config_hash']}") + print(f"→ Status: {run_data['status']}") + + # ========================================================================== + # Step 2: Transition to RUNNING + # ========================================================================== + print_section("Step 2: Start the Run (PENDING → RUNNING)") + + response = client.patch(f"/registry/runs/{run_id}", json={"status": "running"}) + run_data = print_response(response, f"PATCH /registry/runs/{run_id}") + + print(f"\n→ Status: {run_data['status']}") + print(f"→ Started at: {run_data['started_at']}") + + # ========================================================================== + # Step 3: Complete with SUCCESS and metrics + # ========================================================================== + print_section("Step 3: Complete the Run (RUNNING → SUCCESS)") + + update_request = { + "status": "success", + "metrics": { + "mae": 12.5, + "smape": 8.3, + "wape": 0.065, + "bias": -0.02, + "stability_index": 0.92, + }, + "artifact_uri": f"models/{run_id[:8]}/model.pkl", + "artifact_hash": "abc123def456789012345678901234567890abcdef0123456789012345678901", + "artifact_size_bytes": 15360, + } + + print("Update request:") + print(json.dumps(update_request, indent=2)) + print() + + response = client.patch(f"/registry/runs/{run_id}", json=update_request) + run_data = print_response(response, f"PATCH /registry/runs/{run_id}") + + print(f"\n→ Status: {run_data['status']}") + print(f"→ Completed at: {run_data['completed_at']}") + print(f"→ MAE: {run_data['metrics']['mae']}") + + # ========================================================================== + # Step 4: Create deployment alias + # ========================================================================== + print_section("Step 4: Create Deployment Alias") + + alias_request = { + "alias_name": "demo-production", + "run_id": run_id, + "description": "Production model for demo store/product", + } + + response = client.post("/registry/aliases", json=alias_request) + alias_data = print_response(response, "POST /registry/aliases") + + print(f"\n→ Alias '{alias_data['alias_name']}' → run {alias_data['run_id'][:12]}...") + + # ========================================================================== + # Step 5: Create another run for comparison + # ========================================================================== + print_section("Step 5: Create Second Run for Comparison") + + run2_request = { + "model_type": "naive", + "model_config": { + "strategy": "last_value", + }, + "data_window_start": str(date(2024, 1, 1)), + "data_window_end": str(date(2024, 3, 31)), + "store_id": 1, + "product_id": 42, + } + + response = client.post("/registry/runs", json=run2_request) + run2_data = print_response(response, "POST /registry/runs") + run2_id = run2_data["run_id"] + + # Transition to success + client.patch(f"/registry/runs/{run2_id}", json={"status": "running"}) + response = client.patch( + f"/registry/runs/{run2_id}", + json={ + "status": "success", + "metrics": {"mae": 18.2, "smape": 12.1, "wape": 0.095}, + }, + ) + run2_data = response.json() + + print(f"\n→ Created comparison run: {run2_id[:12]}...") + + # ========================================================================== + # Step 6: Compare runs + # ========================================================================== + print_section("Step 6: Compare Runs") + + response = client.get(f"/registry/compare/{run_id}/{run2_id}") + compare_data = print_response(response, "GET /registry/compare/...") + + print("\n→ Configuration differences:") + for key, values in compare_data["config_diff"].items(): + print(f" {key}: {values['a']} vs {values['b']}") + + print("\n→ Metrics differences:") + for metric, values in compare_data["metrics_diff"].items(): + if values["diff"] is not None: + diff_pct = values["diff"] / values["b"] * 100 if values["b"] else 0 + print( + f" {metric}: {values['a']:.2f} vs {values['b']:.2f} (Δ{values['diff']:+.2f}, {diff_pct:+.1f}%)" + ) + + # ========================================================================== + # Step 7: List runs and aliases + # ========================================================================== + print_section("Step 7: List Runs and Aliases") + + response = client.get("/registry/runs?status=success&page_size=5") + list_data = print_response(response, "GET /registry/runs?status=success") + print(f"\n→ Found {list_data['total']} successful runs") + + response = client.get("/registry/aliases") + aliases = print_response(response, "GET /registry/aliases") + print(f"\n→ Found {len(aliases)} aliases") + + # ========================================================================== + # Cleanup info + # ========================================================================== + print_section("Demo Complete!") + + print("Summary:") + print(f" - Created runs: {run_id[:12]}..., {run2_id[:12]}...") + print(" - Created alias: demo-production") + print() + print("To clean up, delete the alias and runs:") + print(f" curl -X DELETE {API_BASE}/registry/aliases/demo-production") + print() + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/pyproject.toml b/pyproject.toml index 7f719bcc..a4eb1257 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,6 +107,7 @@ python_version = "3.12" strict = true show_error_codes = true warn_unused_ignores = true +plugins = ["pydantic.mypy"] # Practical adjustments disallow_untyped_defs = true @@ -114,6 +115,11 @@ disallow_incomplete_defs = true check_untyped_defs = true disallow_untyped_decorators = false # FastAPI decorators aren't typed +[tool.pydantic-mypy] +init_forbid_extra = true +init_typed = true +warn_required_dynamic_aliases = true + [[tool.mypy.overrides]] module = [ "*.tests.*", diff --git a/tests/conftest.py b/tests/conftest.py index fe6559e1..1f190718 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from app.core.config import get_settings -from app.core.database import Base from app.main import app @@ -23,16 +22,12 @@ async def client(): async def db_session(): """Create async database session for integration tests. - This fixture creates all tables, provides a session, and cleans up after. - Requires PostgreSQL to be running (docker-compose up -d). + Uses existing tables from migrations. Rolls back changes after each test. + Requires PostgreSQL to be running (docker-compose up -d) and migrations applied. """ settings = get_settings() engine = create_async_engine(settings.database_url, echo=False) - # Create tables - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - # Create session async_session_maker = async_sessionmaker( engine, @@ -44,10 +39,7 @@ async def db_session(): try: yield session finally: + # Clean up test data by rolling back any uncommitted changes await session.rollback() - # Cleanup: drop all tables - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.drop_all) - await engine.dispose() diff --git a/uv.lock b/uv.lock index 9dbe5217..85d3d0c8 100644 --- a/uv.lock +++ b/uv.lock @@ -216,7 +216,7 @@ wheels = [ [[package]] name = "forecastlabai" -version = "0.1.7" +version = "0.1.8" source = { editable = "." } dependencies = [ { name = "alembic" },