Merged
13 changes: 13 additions & 0 deletions INITIAL-7.md
@@ -12,6 +12,17 @@
- Artifact storage abstraction:
  - local filesystem by default (Settings-driven)
  - compatible with future S3-like storage backends
- Lifecycle Management:
  - State machine tracking: PENDING | RUNNING | SUCCESS | FAILED | ARCHIVED.
  - Deployment Aliases: Mutable pointers (e.g., 'prod-v1') to specific successful runs.
- Metadata & Lineage:
  - JSONB storage for ModelConfig, FeatureConfig, and Performance Metrics.
  - Runtime Snapshot: Recording Python/Library versions for environment parity.
  - Agent Context: Integration of agent_id and session_id for autonomous run traceability.
- Artifact Integrity:
  - Checksum-based verification (SHA-256) for all serialized artifacts.
- Storage Strategy:
  - Pluggable storage providers (LocalFS, future S3/GCS) via Abstract Registry Interface.
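
The checksum-based verification item could be sketched roughly as follows. This is a minimal illustration, not the PR's implementation: the function names `compute_sha256` and `verify_artifact` and the chunk size are assumptions. A hex SHA-256 digest is 64 characters long, which matches the `artifact_hash` column width in the migration.

```python
import hashlib
from pathlib import Path


def compute_sha256(path: Path, chunk_size: int = 1024 * 1024) -> str:
    """Return the hex SHA-256 digest of a file, reading it in chunks to bound memory."""
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        # iter() with a sentinel stops when read() returns an empty bytes object (EOF)
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_artifact(path: Path, expected_hash: str) -> bool:
    """Compare the file's current digest with the digest recorded when it was saved."""
    return compute_sha256(path) == expected_hash.lower()
```

Under this sketch, the digest would be stored when the artifact is written and recomputed on each verification request, so any later modification of the file shows up as a mismatch.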

## EXAMPLES:
- `examples/registry/create_run.py` — create run record + persist configs.
@@ -21,6 +32,8 @@
## DOCUMENTATION:
- Postgres JSONB patterns
- Artifact integrity (hashing) best practices
- https://scalegrid.io/blog/using-jsonb-in-postgresql-how-to-effectively-store-index-json-data-in-postgresql/
- https://www.fortra.com/blog/supply-chain-vulnerability

## OTHER CONSIDERATIONS:
- No hardcoded artifact paths: derived from `ARTIFACT_ROOT` + run_id.
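
One way to derive artifact locations from `ARTIFACT_ROOT` + run_id is sketched below. The helper name `artifact_dir` and the path-escape check are assumptions added for illustration; the PR may derive paths differently. `Path.is_relative_to` requires Python 3.9 or later.

```python
from pathlib import Path


def artifact_dir(artifact_root: str, run_id: str) -> Path:
    """Build the per-run artifact directory under the configured root (hypothetical helper)."""
    root = Path(artifact_root).resolve()
    candidate = (root / run_id).resolve()
    # Reject empty ids and ids such as "../x" or "/etc" that would resolve outside the root
    if candidate == root or not candidate.is_relative_to(root):
        raise ValueError(f"run_id {run_id!r} does not map to a directory under {root}")
    return candidate
```

Because the root comes from settings and the run_id from the run record, no absolute path needs to appear in application code.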
1,253 changes: 1,253 additions & 0 deletions PRPs/PRP-7-model-registry.md

Large diffs are not rendered by default.

46 changes: 44 additions & 2 deletions README.md
@@ -118,7 +118,8 @@ app/
│ ├── ingest/ # Batch upsert endpoints for sales data
│ ├── featuresets/ # Time-safe feature engineering (lags, rolling, calendar)
│ ├── forecasting/ # Model training, prediction, persistence
│ └── backtesting/ # Time-series CV, metrics, baseline comparisons
│ ├── backtesting/ # Time-series CV, metrics, baseline comparisons
│ └── registry/ # Model run tracking, artifacts, deployment aliases
└── main.py # FastAPI entry point

tests/ # Test fixtures and helpers
@@ -129,7 +130,8 @@ examples/
├── queries/ # Example SQL queries
├── models/ # Baseline model examples (naive, seasonal_naive, moving_average)
├── backtest/ # Backtesting examples (run_backtest, inspect_splits, metrics_demo)
└── compute_features_demo.py # Feature engineering demo
├── compute_features_demo.py # Feature engineering demo
└── registry_demo.py # Model registry workflow demo
scripts/ # Utility scripts
```

@@ -301,6 +303,46 @@ When `include_baselines=true`, automatically compares against naive and seasonal

See [examples/backtest/](examples/backtest/) for usage examples.

### Model Registry

- `POST /registry/runs` - Create a new model run
- `GET /registry/runs` - List runs with filtering and pagination
- `GET /registry/runs/{run_id}` - Get run details
- `PATCH /registry/runs/{run_id}` - Update run (status, metrics, artifacts)
- `GET /registry/runs/{run_id}/verify` - Verify artifact integrity
- `POST /registry/aliases` - Create or update deployment alias
- `GET /registry/aliases` - List all aliases
- `GET /registry/aliases/{alias_name}` - Get alias details
- `DELETE /registry/aliases/{alias_name}` - Delete an alias
- `GET /registry/compare/{run_id_a}/{run_id_b}` - Compare two runs

**Example Create Run Request:**
```bash
curl -X POST http://localhost:8123/registry/runs \
  -H "Content-Type: application/json" \
  -d '{
    "model_type": "seasonal_naive",
    "model_config": {"season_length": 7},
    "data_window_start": "2024-01-01",
    "data_window_end": "2024-03-31",
    "store_id": 1,
    "product_id": 1
  }'
```

**Run Lifecycle:**
- `pending` → `running` → `success` | `failed` → `archived`
- Aliases can only point to runs with `success` status
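
The lifecycle above can be sketched as a small transition table. This is an illustrative reading, not the PR's code: `RunStatus`, `ALLOWED_TRANSITIONS`, `can_transition`, and `can_alias` are hypothetical names, and the table assumes a strict interpretation of the arrows (for example, no direct `pending` to `failed` transition).

```python
from enum import Enum


class RunStatus(str, Enum):
    PENDING = "pending"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"
    ARCHIVED = "archived"


# Assumed strict reading of: pending -> running -> success | failed -> archived
ALLOWED_TRANSITIONS: dict[RunStatus, frozenset[RunStatus]] = {
    RunStatus.PENDING: frozenset({RunStatus.RUNNING}),
    RunStatus.RUNNING: frozenset({RunStatus.SUCCESS, RunStatus.FAILED}),
    RunStatus.SUCCESS: frozenset({RunStatus.ARCHIVED}),
    RunStatus.FAILED: frozenset({RunStatus.ARCHIVED}),
    RunStatus.ARCHIVED: frozenset(),  # terminal state
}


def can_transition(current: RunStatus, target: RunStatus) -> bool:
    """Return True if the lifecycle permits moving from current to target."""
    return target in ALLOWED_TRANSITIONS[current]


def can_alias(status: RunStatus) -> bool:
    """Aliases may only point at runs that finished successfully."""
    return status is RunStatus.SUCCESS
```

Keeping the rules in one table means a status update and an alias assignment can both be validated against the same source of truth.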

**Features:**
- JSONB storage for model_config, metrics, runtime_info
- SHA-256 artifact integrity verification
- Duplicate detection (configurable: allow/deny/detect)
- Runtime environment capture (Python, numpy, pandas versions)
- Agent context tracking for autonomous workflows

See [examples/registry_demo.py](examples/registry_demo.py) for a complete workflow demo.
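
Duplicate detection could plausibly rest on a deterministic hash of the run configuration, as in the sketch below. Everything here is an assumption for illustration: which fields feed the hash, the helper names `config_hash` and `check_duplicate`, and the policy semantics (here `allow` skips the check, `deny` raises, `detect` reports the duplicate without blocking). The 16-character truncation matches the `config_hash` column width in the migration.

```python
import hashlib
import json


def config_hash(model_type: str, model_config: dict) -> str:
    """Hash a canonical JSON encoding so that key order does not change the result."""
    payload = json.dumps(
        {"model_type": model_type, "model_config": model_config},
        sort_keys=True,
        separators=(",", ":"),
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16]


def check_duplicate(new_hash: str, existing_hashes: set[str], policy: str) -> bool:
    """Return True if a duplicate was found; raise under the assumed 'deny' semantics."""
    if policy == "allow":
        return False  # assumed: duplicates are not checked at all
    is_duplicate = new_hash in existing_hashes
    if is_duplicate and policy == "deny":
        raise ValueError(f"a run with config hash {new_hash} already exists")
    return is_duplicate  # assumed 'detect': report the duplicate, do not block
```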

## API Documentation

Once the server is running:
1 change: 1 addition & 0 deletions alembic/env.py
@@ -13,6 +13,7 @@

# Import all models for Alembic autogenerate detection
from app.features.data_platform import models as data_platform_models # noqa: F401
from app.features.registry import models as registry_models # noqa: F401

# Alembic Config object
config = context.config
173 changes: 173 additions & 0 deletions alembic/versions/a2f7b3c8d901_create_model_registry_tables.py
@@ -0,0 +1,173 @@
"""create_model_registry_tables

Revision ID: a2f7b3c8d901
Revises: e1165ebcef61
Create Date: 2026-02-01 10:00:00.000000

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = "a2f7b3c8d901"
down_revision: Union[str, None] = "e1165ebcef61"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Apply migration - create model_run and deployment_alias tables."""
    # Create model_run table
    op.create_table(
        "model_run",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("run_id", sa.String(length=32), nullable=False),
        sa.Column("status", sa.String(length=20), nullable=False, server_default="pending"),
        # Model configuration
        sa.Column("model_type", sa.String(length=50), nullable=False),
        sa.Column("model_config", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
        sa.Column("feature_config", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        sa.Column("config_hash", sa.String(length=16), nullable=False),
        # Data window
        sa.Column("data_window_start", sa.Date(), nullable=False),
        sa.Column("data_window_end", sa.Date(), nullable=False),
        sa.Column("store_id", sa.Integer(), nullable=False),
        sa.Column("product_id", sa.Integer(), nullable=False),
        # Metrics
        sa.Column("metrics", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        # Artifact info
        sa.Column("artifact_uri", sa.String(length=500), nullable=True),
        sa.Column("artifact_hash", sa.String(length=64), nullable=True),
        sa.Column("artifact_size_bytes", sa.Integer(), nullable=True),
        # Environment & lineage
        sa.Column("runtime_info", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        sa.Column("agent_context", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        sa.Column("git_sha", sa.String(length=40), nullable=True),
        # Error tracking
        sa.Column("error_message", sa.String(length=2000), nullable=True),
        # Timing
        sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
        # Timestamps (from TimestampMixin)
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        # Constraints
        sa.PrimaryKeyConstraint("id"),
        sa.CheckConstraint(
            "status IN ('pending', 'running', 'success', 'failed', 'archived')",
            name="ck_model_run_valid_status",
        ),
        sa.CheckConstraint(
            "data_window_end >= data_window_start",
            name="ck_model_run_valid_data_window",
        ),
    )

    # Create indexes for model_run
    op.create_index(op.f("ix_model_run_run_id"), "model_run", ["run_id"], unique=True)
    op.create_index(op.f("ix_model_run_status"), "model_run", ["status"], unique=False)
    op.create_index(op.f("ix_model_run_model_type"), "model_run", ["model_type"], unique=False)
    op.create_index(op.f("ix_model_run_config_hash"), "model_run", ["config_hash"], unique=False)
    op.create_index(op.f("ix_model_run_store_id"), "model_run", ["store_id"], unique=False)
    op.create_index(op.f("ix_model_run_product_id"), "model_run", ["product_id"], unique=False)

    # Composite indexes
    op.create_index(
        "ix_model_run_store_product", "model_run", ["store_id", "product_id"], unique=False
    )
    op.create_index(
        "ix_model_run_data_window",
        "model_run",
        ["data_window_start", "data_window_end"],
        unique=False,
    )

    # GIN indexes for JSONB containment queries
    op.create_index(
        "ix_model_run_model_config_gin",
        "model_run",
        ["model_config"],
        unique=False,
        postgresql_using="gin",
    )
    op.create_index(
        "ix_model_run_metrics_gin",
        "model_run",
        ["metrics"],
        unique=False,
        postgresql_using="gin",
    )

    # Create deployment_alias table
    op.create_table(
        "deployment_alias",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("alias_name", sa.String(length=100), nullable=False),
        sa.Column("run_id", sa.Integer(), nullable=False),
        sa.Column("description", sa.String(length=500), nullable=True),
        # Timestamps (from TimestampMixin)
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        # Constraints
        sa.PrimaryKeyConstraint("id"),
        sa.ForeignKeyConstraint(["run_id"], ["model_run.id"]),
        sa.UniqueConstraint("alias_name", name="uq_deployment_alias_name"),
    )

    # Create indexes for deployment_alias
    op.create_index(
        op.f("ix_deployment_alias_alias_name"),
        "deployment_alias",
        ["alias_name"],
        unique=True,
    )
    op.create_index(
        op.f("ix_deployment_alias_run_id"), "deployment_alias", ["run_id"], unique=False
    )


def downgrade() -> None:
    """Revert migration - drop model_run and deployment_alias tables."""
    # Drop deployment_alias table and indexes
    op.drop_index(op.f("ix_deployment_alias_run_id"), table_name="deployment_alias")
    op.drop_index(op.f("ix_deployment_alias_alias_name"), table_name="deployment_alias")
    op.drop_table("deployment_alias")

    # Drop model_run indexes
    op.drop_index("ix_model_run_metrics_gin", table_name="model_run")
    op.drop_index("ix_model_run_model_config_gin", table_name="model_run")
    op.drop_index("ix_model_run_data_window", table_name="model_run")
    op.drop_index("ix_model_run_store_product", table_name="model_run")
    op.drop_index(op.f("ix_model_run_product_id"), table_name="model_run")
    op.drop_index(op.f("ix_model_run_store_id"), table_name="model_run")
    op.drop_index(op.f("ix_model_run_config_hash"), table_name="model_run")
    op.drop_index(op.f("ix_model_run_model_type"), table_name="model_run")
    op.drop_index(op.f("ix_model_run_status"), table_name="model_run")
    op.drop_index(op.f("ix_model_run_run_id"), table_name="model_run")

    # Drop model_run table
    op.drop_table("model_run")
4 changes: 4 additions & 0 deletions app/core/config.py
@@ -53,6 +53,10 @@ class Settings(BaseSettings):
    backtest_max_gap: int = 30
    backtest_results_dir: str = "./artifacts/backtests"

    # Registry
    registry_artifact_root: str = "./artifacts/registry"
    registry_duplicate_policy: Literal["allow", "deny", "detect"] = "detect"

    @property
    def is_development(self) -> bool:
        """Check if running in development mode."""
4 changes: 2 additions & 2 deletions app/features/backtesting/tests/test_schemas.py
@@ -93,7 +93,7 @@ def test_frozen_config(self):
"""Test SplitConfig is immutable."""
config = SplitConfig()
with pytest.raises(ValidationError):
config.n_splits = 10
config.n_splits = 10 # type: ignore[misc]


class TestBacktestConfig:
@@ -136,7 +136,7 @@ def test_frozen_config(self):
"""Test BacktestConfig is immutable."""
config = BacktestConfig(model_config_main=NaiveModelConfig())
with pytest.raises(ValidationError):
config.include_baselines = False
config.include_baselines = False # type: ignore[misc]

def test_invalid_schema_version(self):
"""Test invalid schema_version raises error."""
47 changes: 34 additions & 13 deletions app/features/data_platform/tests/conftest.py
@@ -6,31 +6,36 @@
pytest behavior to allow feature tests to be self-contained.
"""

from contextlib import suppress
from datetime import date
from decimal import Decimal

import pytest
from sqlalchemy import delete
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from app.core.config import get_settings
from app.core.database import Base
from app.features.data_platform.models import Calendar, Product, Store
from app.features.data_platform.models import (
Calendar,
InventorySnapshotDaily,
PriceHistory,
Product,
Promotion,
SalesDaily,
Store,
)


@pytest.fixture
async def db_session():
"""Create async database session for integration tests.

This fixture creates all tables, provides a session, and cleans up after.
Requires PostgreSQL to be running (docker-compose up -d).
Uses existing tables from migrations. Cleans up test data after each test.
Requires PostgreSQL to be running (docker-compose up -d) and migrations applied.
"""
settings = get_settings()
engine = create_async_engine(settings.database_url, echo=False)

# Create tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)

# Create session
async_session_maker = async_sessionmaker(
engine,
@@ -42,11 +47,27 @@ async def db_session():
try:
yield session
finally:
await session.rollback()

# Cleanup: drop all tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
# Rollback any pending transaction first (required if test caused an error)
with suppress(Exception):
await session.rollback()

# Use a fresh session for cleanup to avoid transaction state issues
async with async_session_maker() as cleanup_session:
with suppress(Exception):
# Clean up test data (delete in correct order due to FK constraints)
await cleanup_session.execute(delete(SalesDaily))
await cleanup_session.execute(delete(InventorySnapshotDaily))
await cleanup_session.execute(delete(PriceHistory))
await cleanup_session.execute(delete(Promotion))
await cleanup_session.execute(delete(Product).where(Product.sku.like("SKU-TEST%")))
await cleanup_session.execute(delete(Product).where(Product.sku.like("TEST-%")))
await cleanup_session.execute(delete(Store).where(Store.code.like("TEST%")))
await cleanup_session.execute(
delete(Calendar).where(
(Calendar.date >= date(2024, 1, 1)) & (Calendar.date <= date(2024, 12, 31))
)
)
await cleanup_session.commit()

await engine.dispose()

2 changes: 1 addition & 1 deletion app/features/featuresets/tests/test_schemas.py
@@ -202,7 +202,7 @@ def test_config_is_frozen(self):
"""Config should be immutable (frozen)."""
config = FeatureSetConfig(name="test")
with pytest.raises(ValidationError):
config.name = "modified"
config.name = "modified" # type: ignore[misc]

def test_rejects_empty_name(self):
"""Empty name should be rejected."""