diff --git a/.github/workflows/cd-release.yml b/.github/workflows/cd-release.yml index 3f6780bc..adcae99c 100644 --- a/.github/workflows/cd-release.yml +++ b/.github/workflows/cd-release.yml @@ -28,7 +28,7 @@ jobs: upload_url: ${{ steps.release.outputs.upload_url }} steps: - - uses: googleapis/release-please-action@v5 + - uses: googleapis/release-please-action@45996ed1f6d02564a971a2fa1b5860e934307cf7 # v5.0.0 id: release with: # Use PAT to trigger CI workflows on release PRs @@ -52,7 +52,7 @@ jobs: ref: ${{ needs.release-please.outputs.tag_name }} - name: Install uv - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 with: version: ${{ env.UV_VERSION }} enable-cache: true diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7f5d52b0..4d4321c2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -36,7 +36,7 @@ jobs: ref: ${{ env.CHECKOUT_REF }} - name: Install uv - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 with: version: ${{ env.UV_VERSION }} enable-cache: true @@ -48,10 +48,10 @@ jobs: run: uv sync --frozen --all-extras --dev - name: Run Ruff linter - run: uv run ruff check . + run: uv run --frozen ruff check . - name: Run Ruff formatter check - run: uv run ruff format --check . + run: uv run --frozen ruff format --check . 
typecheck: name: Type Check @@ -62,7 +62,7 @@ jobs: ref: ${{ env.CHECKOUT_REF }} - name: Install uv - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 with: version: ${{ env.UV_VERSION }} enable-cache: true @@ -74,10 +74,10 @@ jobs: run: uv sync --frozen --all-extras --dev - name: Run MyPy - run: uv run mypy app/ + run: uv run --frozen mypy app/ - name: Run Pyright - run: uv run pyright app/ + run: uv run --frozen pyright app/ test: name: Test @@ -104,7 +104,7 @@ jobs: ref: ${{ env.CHECKOUT_REF }} - name: Install uv - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 with: version: ${{ env.UV_VERSION }} enable-cache: true @@ -118,13 +118,13 @@ jobs: - name: Run migrations env: DATABASE_URL: postgresql+asyncpg://forecastlab:forecastlab@localhost:5432/forecastlab_test - run: uv run alembic upgrade head + run: uv run --frozen alembic upgrade head - name: Run tests env: DATABASE_URL: postgresql+asyncpg://forecastlab:forecastlab@localhost:5432/forecastlab_test APP_ENV: testing - run: uv run pytest -v --tb=short + run: uv run --frozen pytest -v --tb=short migration-check: name: Migration Check @@ -151,7 +151,7 @@ jobs: ref: ${{ env.CHECKOUT_REF }} - name: Install uv - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 with: version: ${{ env.UV_VERSION }} enable-cache: true @@ -165,12 +165,12 @@ jobs: - name: Apply migrations to fresh DB env: DATABASE_URL: postgresql+asyncpg://forecastlab:forecastlab@localhost:5432/forecastlab_migration_test - run: uv run alembic upgrade head + run: uv run --frozen alembic upgrade head - name: Verify no pending migrations env: DATABASE_URL: postgresql+asyncpg://forecastlab:forecastlab@localhost:5432/forecastlab_migration_test run: | # Check that current head matches database - uv run alembic current + uv run --frozen alembic current # This would fail if there are unapplied 
migrations diff --git a/.github/workflows/dependency-check.yml b/.github/workflows/dependency-check.yml index 2e73d599..f65e4f79 100644 --- a/.github/workflows/dependency-check.yml +++ b/.github/workflows/dependency-check.yml @@ -32,7 +32,7 @@ jobs: - uses: actions/checkout@v6 - name: Install uv - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 with: version: ${{ env.UV_VERSION }} enable-cache: true @@ -76,7 +76,7 @@ jobs: - name: Upload SARIF to GitHub Security if: always() - uses: github/codeql-action/upload-sarif@v4 + uses: github/codeql-action/upload-sarif@c6f931105cb2c34c8f901cc885ba1e2e259cf745 # v4.34.0 with: sarif_file: audit-results.sarif category: dependency-vulnerability-scan diff --git a/.github/workflows/phase-snapshot.yml b/.github/workflows/phase-snapshot.yml index 58f7b46c..792e766f 100644 --- a/.github/workflows/phase-snapshot.yml +++ b/.github/workflows/phase-snapshot.yml @@ -44,7 +44,7 @@ jobs: - uses: actions/checkout@v6 - name: Install uv - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 with: version: ${{ env.UV_VERSION }} enable-cache: true @@ -58,24 +58,24 @@ jobs: - name: Lint check id: lint run: | - uv run ruff check . - uv run ruff format --check . + uv run --frozen ruff check . + uv run --frozen ruff format --check . 
- name: Type check id: typecheck run: | - uv run mypy app/ - uv run pyright app/ + uv run --frozen mypy app/ + uv run --frozen pyright app/ - name: Run migrations id: migration - run: uv run alembic upgrade head + run: uv run --frozen alembic upgrade head - name: Run tests id: test env: APP_ENV: testing - run: uv run pytest -v --tb=short + run: uv run --frozen pytest -v --tb=short create-snapshot: name: Create Audit Snapshot @@ -90,7 +90,7 @@ jobs: fetch-depth: 0 - name: Install uv - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 with: version: ${{ env.UV_VERSION }} enable-cache: true diff --git a/.github/workflows/schema-validation.yml b/.github/workflows/schema-validation.yml index c8fe61a1..a2f772bc 100644 --- a/.github/workflows/schema-validation.yml +++ b/.github/workflows/schema-validation.yml @@ -45,7 +45,7 @@ jobs: - uses: actions/checkout@v6 - name: Install uv - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 with: version: ${{ env.UV_VERSION }} enable-cache: true @@ -59,14 +59,14 @@ jobs: - name: Fresh DB migration test run: | echo "::group::Applying all migrations to fresh database" - uv run alembic upgrade head + uv run --frozen alembic upgrade head echo "::endgroup::" - name: Check migration chain integrity run: | echo "::group::Verifying migration chain" # Get all revision heads - should be exactly one - HEADS=$(uv run alembic heads 2>&1) + HEADS=$(uv run --frozen alembic heads 2>&1) HEAD_COUNT=$(echo "$HEADS" | grep -c "^[a-f0-9]" || true) if [ "$HEAD_COUNT" -gt 1 ]; then @@ -84,7 +84,7 @@ jobs: echo "::group::Checking for schema drift" # alembic check compares models to current DB state # Returns non-zero if autogenerate would create new migrations - if uv run alembic check 2>&1; then + if uv run --frozen alembic check 2>&1; then echo "Schema is in sync with models" else echo "::error::Schema drift detected - models don't match 
migrations" @@ -97,7 +97,7 @@ jobs: run: | echo "::group::Testing migration reversibility" # Get current revision - CURRENT=$(uv run alembic current 2>&1 | grep -oE "^[a-f0-9]+" | head -1) + CURRENT=$(uv run --frozen alembic current 2>&1 | grep -oE "^[a-f0-9]+" | head -1) if [ -z "$CURRENT" ]; then echo "No migrations applied, skipping cycle test" @@ -108,14 +108,14 @@ jobs: # Downgrade one step echo "Downgrading one migration..." - uv run alembic downgrade -1 + uv run --frozen alembic downgrade -1 # Upgrade back echo "Upgrading back to head..." - uv run alembic upgrade head + uv run --frozen alembic upgrade head # Verify we're back at head - FINAL=$(uv run alembic current 2>&1 | grep -oE "^[a-f0-9]+" | head -1) + FINAL=$(uv run --frozen alembic current 2>&1 | grep -oE "^[a-f0-9]+" | head -1) if [ "$CURRENT" != "$FINAL" ]; then echo "::error::Migration cycle failed - revision mismatch after downgrade/upgrade" @@ -133,10 +133,10 @@ jobs: echo "" >> $GITHUB_STEP_SUMMARY echo "### Migration History" >> $GITHUB_STEP_SUMMARY echo '```' >> $GITHUB_STEP_SUMMARY - uv run alembic history --verbose 2>&1 | head -50 >> $GITHUB_STEP_SUMMARY + uv run --frozen alembic history --verbose 2>&1 | head -50 >> $GITHUB_STEP_SUMMARY echo '```' >> $GITHUB_STEP_SUMMARY echo "" >> $GITHUB_STEP_SUMMARY echo "### Current State" >> $GITHUB_STEP_SUMMARY echo '```' >> $GITHUB_STEP_SUMMARY - uv run alembic current --verbose 2>&1 >> $GITHUB_STEP_SUMMARY + uv run --frozen alembic current --verbose 2>&1 >> $GITHUB_STEP_SUMMARY echo '```' >> $GITHUB_STEP_SUMMARY diff --git a/.gitignore b/.gitignore index ed4f7681..d9b12311 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ /.vscode /.claude -CLAUDE.md +CLAUDE.local.md .mcp.json .claude .DS_Store @@ -37,3 +37,8 @@ frontend/.vite/ # Generated artifacts (models, backtest results) artifacts/ + +# Local session artifacts (plans, handoffs, current session notes) +.agents/ +.handoffs/ +HANDOFF.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 
index 00000000..032b3bb6 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,116 @@ +# ForecastLabAI + +Deep-dive references (Claude loads only when needed): +- Developer guide & tech stack: @docs/_base/DEV_GUIDE.md +- Architecture & boundaries: @docs/_base/ARCHITECTURE.md +- API contracts & interfaces: @docs/_base/API_CONTRACTS.md +- Operational runbooks: @docs/_base/RUNBOOKS.md +- Security & compliance: @docs/_base/SECURITY.md +- Rules & constraints: @docs/_base/RULES.md +- Domain model & glossary: @docs/_base/DOMAIN_MODEL.md +- Service & dependency map: @docs/_base/REPO_MAP_INDEX.md +- Pipeline contract (CI/CD): @docs/_base/PIPELINE_CONTRACT.md + +> Project rules already enforced via `.claude/rules/` (commit-format, branch-naming, security-patterns, product-vision, test-requirements, ui-design, versioning, output-formatting). Read those first; this file is the operating index. + +## Stack + +- Language: Python 3.12 (backend), TypeScript 5.9 + React 19 (frontend) +- Framework: FastAPI + SQLAlchemy 2.0 async + Pydantic v2; Vite 7 + Tailwind 4 + shadcn/ui +- Infrastructure: Single-host `docker-compose` (no K8s, no cloud SDK in core path) +- Database: PostgreSQL 16 + pgvector (port `5433` host → `5432` container) +- CI/CD: GitHub Actions + release-please (SemVer, pre-1.0 patch bumps) + +## Architecture + +**Owns:** Full vertical-slice retail-demand-forecasting demo — data platform, ingest, feature engineering (time-safe), forecasting, backtesting, model registry, RAG (pgvector), agentic layer (PydanticAI), React dashboard. + +**Depends on:** PostgreSQL+pgvector (required), OpenAI/Anthropic/Google API (agent + RAG embeddings), Ollama (optional local embeddings). + +**Depended on by:** Nothing internal — single deployment, no consumers. Frontend ↔ backend over HTTP + WebSocket (`/agents/stream`). + +**Vertical-slice layout:** Every domain lives under `app/features/<slice>/{models,schemas,service,routes,tests}.py`. Cross-slice code goes through `app/core/` or `app/shared/`. 
Wire-up in `app/main.py`. + +**Core data flow:** Seeder/Ingest → `sales_daily` + dimensions → Featuresets (lag/rolling/calendar, leakage-safe) → Forecasting (naive/seasonal/MA/LightGBM) → Backtesting (rolling/expanding splits) → Registry (runs + aliases) → Serving via `/forecasting`, `/backtesting`, `/analytics`. RAG indexes docs → pgvector → Agents (experiment / rag_assistant) call tools with human-in-loop approval for mutating ops. + +## Commands + +### Local Development +```bash +docker-compose up -d # Postgres+pgvector on :5433 +uv sync --extra dev # install backend deps (Python 3.12) +uv run alembic upgrade head # apply migrations +uv run uvicorn app.main:app --reload --port 8123 +cd frontend && pnpm install && pnpm dev # UI on :5173 +``` + +### Testing +```bash +uv run pytest -v -m "not integration" # unit, no DB +uv run pytest -v -m integration # integration, requires docker-compose up +cd frontend && pnpm tsc --noEmit && pnpm lint && pnpm test --run +``` + +### Validation gates (run before commit) +```bash +uv run ruff check . && uv run ruff format --check . +uv run mypy app/ && uv run pyright app/ # both --strict, both block merge +``` + +### Database & seeder +```bash +uv run alembic revision --autogenerate -m "" +uv run python scripts/seed_random.py --full-new --seed 42 --confirm +uv run python scripts/seed_random.py --status +``` + +## Conventions + +- Branches: `/` off `dev` (off `main` for hotfix). See `.claude/rules/branch-naming.md`. +- Commits: `type(scope): description (#issue)` — scope from allow-list, no AI co-author trailer, every commit references an open GitHub issue. Hook `.claude/hooks/check-commit-format.sh` enforces it. +- All errors via `app/core/problem_details.py` (RFC 7807 `application/problem+json`). +- Pydantic v2 at every boundary (HTTP, agent tools, seeder config). SQLAlchemy with `Mapped[]` + async sessions. 
+- Time-safe features only — `app/features/featuresets/tests/test_leakage.py` is the spec; never weaken to make a feature pass. +- UI work goes through the skills in `.claude/rules/ui-design.md` (stitch-design, frontend-design, webapp-testing) — never hand-roll. + +## Safety + +> Load `docs/_base/RULES.md` for the full constraint matrix. + +**STOP and ask before:** +- Cutting `dev` → `main` (release-please will tag) or any tag push +- Editing a merged Alembic migration (migrations are forward-only; create a new one) +- `git push --force` on `dev` or `main` (forbidden) +- Adding a managed-cloud SDK to `app/` core path (violates single-host vision) +- Bumping pydantic-ai / FastAPI / SQLAlchemy major versions + +**NEVER:** +- Commit `.env` (only `.env.example` is tracked) or embed secrets in URLs/code/logs +- Use raw SQL string concat — always SQLAlchemy parameter binding +- Set `verify=False` on httpx / openai clients (never disable TLS verification) +- Skip `mypy --strict` or `pyright --strict` — both gate merge +- Add AI co-author trailers to commits (`commit-format.md` forbids it) +- Mock external services in integration tests (use real Postgres via docker-compose) + +## Verification + +```bash +uv run ruff check . && uv run mypy app/ && uv run pyright app/ && uv run pytest -v -m "not integration" +gh issue view <issue-number> --json state # confirm referenced issue exists +wc -l CLAUDE.md # must stay ≤ 150 +``` + +## Workflow + +1. Open or pick a GitHub issue (`gh issue list`); branch off `dev` per `branch-naming.md`. +2. Implement inside the matching `app/features/<slice>/` (or new slice with PRP). +3. Run `ruff` → `mypy` → `pyright` → `pytest -m "not integration"` locally. +4. (DB/UI touched) Run integration tests + frontend type-check + dogfood via webapp-testing skill. +5. Commit with `type(scope): description (#issue)`; push. +6. Open PR into `dev`; CI must be green; merge. +7. When ready to release: PR `dev` → `main`. release-please opens a Release PR; merge to tag. 
+ +## Learnings + + +- HEURISTIC_MODE generated this doc (no `docs/_kB/repo-map/` KB). Run `mapping-repo-context` to upgrade fidelity; sections marked `[UNVERIFIED]` in `docs/_base/` need verification. diff --git a/alembic/versions/a8b9c0d1e234_add_retail_depth_columns_and_replenishment_event_table.py b/alembic/versions/a8b9c0d1e234_add_retail_depth_columns_and_replenishment_event_table.py new file mode 100644 index 00000000..3dc6370e --- /dev/null +++ b/alembic/versions/a8b9c0d1e234_add_retail_depth_columns_and_replenishment_event_table.py @@ -0,0 +1,267 @@ +"""add retail-depth columns and replenishment_event table + +Revision ID: a8b9c0d1e234 +Revises: f7a8b9c0d123 +Create Date: 2026-05-11 13:00:00.000000 + +Phase 2 of the seeder realism extension. Additive only: + +- ``sales_daily.channel`` (VARCHAR(20), NOT NULL, server default ``"in_store"``) + with a CHECK constraint pinning the allow-list and a composite + ``(date, channel)`` index for downstream analytics. +- ``product`` gains lifecycle fields (``lifecycle_stage``, ``launch_date``, + ``discontinue_date``) plus ``pack_size`` and ``subcategory``. All NULL by + default so existing rows keep working. +- ``promotion.kind`` (VARCHAR(20), NOT NULL, server default ``"pct_off"``) + and ``promotion.bundle_member_product_ids`` (JSONB, NULL) — bundle/BOGO + mechanics. +- New table ``replenishment_event`` drives lead-time-aware stockout + clustering. Columns: ``date``, ``store_id``, ``product_id``, + ``lead_time_days``, ``ordered_qty``, ``received_qty``. + +The server defaults on ``sales_daily.channel`` and ``promotion.kind`` are +intentional: scenarios that do not enable Phase 2 multichannel/bundle +toggles will populate rows without those columns, and the database picks +the historical default automatically. This keeps the regression invariant +(``retail_standard`` produces byte-identical row counts). + +Downgrade drops all of the above. Any seeded Phase 2 rows are lost; +acceptable for synthetic data only. 
+""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "a8b9c0d1e234" +down_revision: str | None = "f7a8b9c0d123" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +_CHANNEL_ALLOWLIST = "('in_store', 'online', 'click_collect', 'wholesale')" +_PROMOTION_KIND_ALLOWLIST = "('pct_off', 'bogo', 'bundle', 'markdown')" +_LIFECYCLE_STAGE_ALLOWLIST = "('intro', 'growth', 'maturity', 'decline', 'discontinued')" + + +def upgrade() -> None: + """Apply migration: add retail-depth columns and replenishment_event.""" + # ------------------------------------------------------------------ # + # 1. sales_daily.channel (NOT NULL with server default 'in_store') + # ------------------------------------------------------------------ # + op.add_column( + "sales_daily", + sa.Column( + "channel", + sa.String(length=20), + nullable=False, + server_default=sa.text("'in_store'"), + ), + ) + op.create_index( + "ix_sales_daily_date_channel", + "sales_daily", + ["date", "channel"], + unique=False, + ) + op.create_check_constraint( + "ck_sales_daily_channel_allowlist", + "sales_daily", + f"channel IN {_CHANNEL_ALLOWLIST}", + ) + + # ------------------------------------------------------------------ # + # 2. 
product lifecycle / pack_size / subcategory (all NULL by default) + # ------------------------------------------------------------------ # + op.add_column( + "product", + sa.Column("lifecycle_stage", sa.String(length=20), nullable=True), + ) + op.add_column("product", sa.Column("launch_date", sa.Date(), nullable=True)) + op.add_column("product", sa.Column("discontinue_date", sa.Date(), nullable=True)) + op.add_column("product", sa.Column("pack_size", sa.Integer(), nullable=True)) + op.add_column( + "product", + sa.Column("subcategory", sa.String(length=100), nullable=True), + ) + op.create_index( + op.f("ix_product_subcategory"), + "product", + ["subcategory"], + unique=False, + ) + op.create_check_constraint( + "ck_product_lifecycle_stage_allowlist", + "product", + f"lifecycle_stage IS NULL OR lifecycle_stage IN {_LIFECYCLE_STAGE_ALLOWLIST}", + ) + op.create_check_constraint( + "ck_product_pack_size_positive", + "product", + "pack_size IS NULL OR pack_size > 0", + ) + op.create_check_constraint( + "ck_product_lifecycle_dates_order", + "product", + "discontinue_date IS NULL " + "OR launch_date IS NULL " + "OR discontinue_date >= launch_date", + ) + + # ------------------------------------------------------------------ # + # 3. promotion.kind + promotion.bundle_member_product_ids + # ------------------------------------------------------------------ # + op.add_column( + "promotion", + sa.Column( + "kind", + sa.String(length=20), + nullable=False, + server_default=sa.text("'pct_off'"), + ), + ) + op.add_column( + "promotion", + sa.Column( + "bundle_member_product_ids", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + ), + ) + op.create_check_constraint( + "ck_promotion_kind_allowlist", + "promotion", + f"kind IN {_PROMOTION_KIND_ALLOWLIST}", + ) + # A bundle/BOGO promotion MUST carry a non-NULL member list (emptiness is not + # checked); non-bundle promotions MUST NOT. Enforced at the SQL layer so + # Pydantic & ORM stay additive. 
+ op.create_check_constraint( + "ck_promotion_bundle_members_consistency", + "promotion", + "(kind IN ('bundle', 'bogo') AND bundle_member_product_ids IS NOT NULL)" + " OR (kind NOT IN ('bundle', 'bogo') AND bundle_member_product_ids IS NULL)", + ) + + # ------------------------------------------------------------------ # + # 4. replenishment_event table + # ------------------------------------------------------------------ # + op.create_table( + "replenishment_event", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("date", sa.Date(), nullable=False), + sa.Column("store_id", sa.Integer(), nullable=False), + sa.Column("product_id", sa.Integer(), nullable=False), + sa.Column("lead_time_days", sa.Integer(), nullable=False), + sa.Column("ordered_qty", sa.Integer(), nullable=False), + sa.Column("received_qty", sa.Integer(), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.CheckConstraint( + "lead_time_days >= 0", name="ck_replenishment_event_lead_time_positive" + ), + sa.CheckConstraint( + "ordered_qty >= 0", name="ck_replenishment_event_ordered_qty_positive" + ), + sa.CheckConstraint( + "received_qty >= 0", name="ck_replenishment_event_received_qty_positive" + ), + sa.CheckConstraint( + "received_qty <= ordered_qty", + name="ck_replenishment_event_received_le_ordered", + ), + sa.ForeignKeyConstraint(["date"], ["calendar.date"]), + sa.ForeignKeyConstraint(["product_id"], ["product.id"]), + sa.ForeignKeyConstraint(["store_id"], ["store.id"]), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_replenishment_event_date"), + "replenishment_event", + ["date"], + unique=False, + ) + op.create_index( + op.f("ix_replenishment_event_product_id"), + "replenishment_event", + ["product_id"], + unique=False, + ) + op.create_index( + 
op.f("ix_replenishment_event_store_id"), + "replenishment_event", + ["store_id"], + unique=False, + ) + op.create_index( + "ix_replenishment_event_store_product_date", + "replenishment_event", + ["store_id", "product_id", "date"], + unique=False, + ) + + +def downgrade() -> None: + """Revert migration: drop everything Phase 2 added. + + WARNING: Any seeded Phase 2 rows are lost. Acceptable for synthetic data + only — do not run against an environment with user-loaded retail data. + """ + # 4. replenishment_event + op.drop_index( + "ix_replenishment_event_store_product_date", + table_name="replenishment_event", + ) + op.drop_index( + op.f("ix_replenishment_event_store_id"), table_name="replenishment_event" + ) + op.drop_index( + op.f("ix_replenishment_event_product_id"), table_name="replenishment_event" + ) + op.drop_index( + op.f("ix_replenishment_event_date"), table_name="replenishment_event" + ) + op.drop_table("replenishment_event") + + # 3. promotion.kind + bundle_member_product_ids + op.drop_constraint( + "ck_promotion_bundle_members_consistency", "promotion", type_="check" + ) + op.drop_constraint("ck_promotion_kind_allowlist", "promotion", type_="check") + op.drop_column("promotion", "bundle_member_product_ids") + op.drop_column("promotion", "kind") + + # 2. product lifecycle fields + op.drop_constraint( + "ck_product_lifecycle_dates_order", "product", type_="check" + ) + op.drop_constraint("ck_product_pack_size_positive", "product", type_="check") + op.drop_constraint( + "ck_product_lifecycle_stage_allowlist", "product", type_="check" + ) + op.drop_index(op.f("ix_product_subcategory"), table_name="product") + op.drop_column("product", "subcategory") + op.drop_column("product", "pack_size") + op.drop_column("product", "discontinue_date") + op.drop_column("product", "launch_date") + op.drop_column("product", "lifecycle_stage") + + # 1. 
sales_daily.channel + op.drop_constraint( + "ck_sales_daily_channel_allowlist", "sales_daily", type_="check" + ) + op.drop_index("ix_sales_daily_date_channel", table_name="sales_daily") + op.drop_column("sales_daily", "channel") diff --git a/alembic/versions/f7a8b9c0d123_add_exogenous_signal_and_sales_returns_tables.py b/alembic/versions/f7a8b9c0d123_add_exogenous_signal_and_sales_returns_tables.py new file mode 100644 index 00000000..e19d6e9b --- /dev/null +++ b/alembic/versions/f7a8b9c0d123_add_exogenous_signal_and_sales_returns_tables.py @@ -0,0 +1,154 @@ +"""add exogenous_signal and sales_returns tables + +Revision ID: f7a8b9c0d123 +Revises: d6e0f2g3h456 +Create Date: 2026-05-11 12:00:00.000000 + +Phase 1 of the seeder realism extension. Additive only — creates two new +fact tables to support exogenous demand signals (weather / macro / events) +and synthetic returns volume. No existing rows are touched. + +Downgrade drops both tables; any seeded rows are lost. This is acceptable +because the data is synthetic; do not run downgrade against an environment +that holds user-loaded data. +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "f7a8b9c0d123" +down_revision: str | None = "d6e0f2g3h456" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Apply migration: create exogenous_signal and sales_returns.""" + op.create_table( + "exogenous_signal", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("date", sa.Date(), nullable=False), + sa.Column("signal_name", sa.String(length=50), nullable=False), + sa.Column("store_id", sa.Integer(), nullable=True), + sa.Column("is_global", sa.Boolean(), nullable=False), + sa.Column("value", sa.Float(), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.CheckConstraint( + "(is_global = true AND store_id IS NULL) OR " + "(is_global = false AND store_id IS NOT NULL)", + name="ck_exogenous_signal_global_consistency", + ), + sa.ForeignKeyConstraint(["date"], ["calendar.date"]), + sa.ForeignKeyConstraint(["store_id"], ["store.id"]), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_exogenous_signal_date"), "exogenous_signal", ["date"], unique=False + ) + op.create_index( + op.f("ix_exogenous_signal_signal_name"), + "exogenous_signal", + ["signal_name"], + unique=False, + ) + op.create_index( + op.f("ix_exogenous_signal_store_id"), + "exogenous_signal", + ["store_id"], + unique=False, + ) + op.create_index( + "ix_exogenous_signal_name_date", + "exogenous_signal", + ["signal_name", "date"], + unique=False, + ) + op.create_index( + "uq_exogenous_signal_global", + "exogenous_signal", + ["date", "signal_name"], + unique=True, + postgresql_where=sa.text("is_global = true"), + ) + op.create_index( + "uq_exogenous_signal_per_store", + "exogenous_signal", + ["date", "signal_name", "store_id"], + unique=True, + postgresql_where=sa.text("is_global = 
false"), + ) + + op.create_table( + "sales_returns", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("date", sa.Date(), nullable=False), + sa.Column("store_id", sa.Integer(), nullable=False), + sa.Column("product_id", sa.Integer(), nullable=False), + sa.Column("return_quantity", sa.Integer(), nullable=False), + sa.Column("return_reason", sa.String(length=50), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.CheckConstraint("return_quantity >= 1", name="ck_sales_returns_quantity_positive"), + sa.ForeignKeyConstraint(["date"], ["calendar.date"]), + sa.ForeignKeyConstraint(["product_id"], ["product.id"]), + sa.ForeignKeyConstraint(["store_id"], ["store.id"]), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_sales_returns_product_id"), "sales_returns", ["product_id"], unique=False + ) + op.create_index( + op.f("ix_sales_returns_store_id"), "sales_returns", ["store_id"], unique=False + ) + op.create_index( + "ix_sales_returns_store_product_date", + "sales_returns", + ["store_id", "product_id", "date"], + unique=False, + ) + op.create_index("ix_sales_returns_date", "sales_returns", ["date"], unique=False) + + +def downgrade() -> None: + """Revert migration: drop sales_returns and exogenous_signal. + + WARNING: Any seeded Phase 1 rows are lost. Acceptable for synthetic data + only — do not run against an environment with user-loaded signals. 
+ """ + op.drop_index("ix_sales_returns_date", table_name="sales_returns") + op.drop_index("ix_sales_returns_store_product_date", table_name="sales_returns") + op.drop_index(op.f("ix_sales_returns_store_id"), table_name="sales_returns") + op.drop_index(op.f("ix_sales_returns_product_id"), table_name="sales_returns") + op.drop_table("sales_returns") + + op.drop_index("uq_exogenous_signal_per_store", table_name="exogenous_signal") + op.drop_index("uq_exogenous_signal_global", table_name="exogenous_signal") + op.drop_index("ix_exogenous_signal_name_date", table_name="exogenous_signal") + op.drop_index(op.f("ix_exogenous_signal_store_id"), table_name="exogenous_signal") + op.drop_index(op.f("ix_exogenous_signal_signal_name"), table_name="exogenous_signal") + op.drop_index(op.f("ix_exogenous_signal_date"), table_name="exogenous_signal") + op.drop_table("exogenous_signal") diff --git a/app/features/data_platform/models.py b/app/features/data_platform/models.py index f2bb76c2..6d0207fb 100644 --- a/app/features/data_platform/models.py +++ b/app/features/data_platform/models.py @@ -13,16 +13,20 @@ from decimal import Decimal from sqlalchemy import ( + BigInteger, Boolean, CheckConstraint, Date, + Float, ForeignKey, Index, Integer, Numeric, String, UniqueConstraint, + text, ) +from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column, relationship from app.core.database import Base @@ -69,9 +73,17 @@ class Product(TimestampMixin, Base): sku: Stock keeping unit (unique product identifier). name: Product display name. category: Product category. + subcategory: Optional finer-grain category (Phase 2 retail-depth). brand: Product brand. base_price: Standard retail price. base_cost: Standard cost/COGS. + pack_size: Optional units-per-pack (Phase 2). NULL means single-unit. + lifecycle_stage: One of ``intro|growth|maturity|decline|discontinued`` + (Phase 2). NULL when the lifecycle generator is disabled. 
+ launch_date: Date the product became sellable (Phase 2). NULL when + lifecycle is disabled. + discontinue_date: Date the product was retired (Phase 2). NULL when + still active or lifecycle is disabled. """ __tablename__ = "product" @@ -80,9 +92,14 @@ class Product(TimestampMixin, Base): sku: Mapped[str] = mapped_column(String(50), unique=True, index=True) name: Mapped[str] = mapped_column(String(200)) category: Mapped[str | None] = mapped_column(String(100), index=True, nullable=True) + subcategory: Mapped[str | None] = mapped_column(String(100), index=True, nullable=True) brand: Mapped[str | None] = mapped_column(String(100), nullable=True) base_price: Mapped[Decimal | None] = mapped_column(Numeric(10, 2), nullable=True) base_cost: Mapped[Decimal | None] = mapped_column(Numeric(10, 2), nullable=True) + pack_size: Mapped[int | None] = mapped_column(Integer, nullable=True) + lifecycle_stage: Mapped[str | None] = mapped_column(String(20), nullable=True) + launch_date: Mapped[datetime.date | None] = mapped_column(Date, nullable=True) + discontinue_date: Mapped[datetime.date | None] = mapped_column(Date, nullable=True) # Relationships (one-to-many) sales: Mapped[list[SalesDaily]] = relationship(back_populates="product") @@ -92,6 +109,22 @@ class Product(TimestampMixin, Base): back_populates="product" ) + __table_args__ = ( + CheckConstraint( + "lifecycle_stage IS NULL OR lifecycle_stage IN " + "('intro', 'growth', 'maturity', 'decline', 'discontinued')", + name="ck_product_lifecycle_stage_allowlist", + ), + CheckConstraint( + "pack_size IS NULL OR pack_size > 0", + name="ck_product_pack_size_positive", + ), + CheckConstraint( + "discontinue_date IS NULL OR launch_date IS NULL OR discontinue_date >= launch_date", + name="ck_product_lifecycle_dates_order", + ), + ) + class Calendar(TimestampMixin, Base): """Calendar dimension table for time-based analysis. @@ -140,7 +173,11 @@ class SalesDaily(TimestampMixin, Base): """Daily sales fact table. 
CRITICAL: Grain is (date, store_id, product_id) - one row per store/product/day. - Enforced by unique constraint for idempotent upserts. + Enforced by unique constraint for idempotent upserts. The Phase 2 + ``channel`` column is intentionally **outside** the grain — pre-Phase-2 + rows default to ``in_store``; multi-channel scenarios are emitted as a + single row per (date, store, product) with a channel mix encoded in + downstream aggregates rather than splitting the grain. Attributes: id: Surrogate primary key. @@ -150,6 +187,9 @@ class SalesDaily(TimestampMixin, Base): quantity: Units sold. unit_price: Price per unit at time of sale. total_amount: Total sales amount (quantity * unit_price). + channel: Sales channel — one of ``in_store|online|click_collect|wholesale``. + Defaults to ``in_store`` server-side so existing scenarios stay + byte-identical. """ __tablename__ = "sales_daily" @@ -162,6 +202,9 @@ class SalesDaily(TimestampMixin, Base): quantity: Mapped[int] = mapped_column(Integer) unit_price: Mapped[Decimal] = mapped_column(Numeric(10, 2)) total_amount: Mapped[Decimal] = mapped_column(Numeric(12, 2)) + channel: Mapped[str] = mapped_column( + String(20), nullable=False, server_default=text("'in_store'") + ) # Relationships store: Mapped[Store] = relationship(back_populates="sales") @@ -175,10 +218,16 @@ class SalesDaily(TimestampMixin, Base): Index("ix_sales_daily_date_store", "date", "store_id"), # Composite index for date range + product Index("ix_sales_daily_date_product", "date", "product_id"), + # Composite index for date range + channel (Phase 2) + Index("ix_sales_daily_date_channel", "date", "channel"), # Check constraint for data quality CheckConstraint("quantity >= 0", name="ck_sales_daily_quantity_positive"), CheckConstraint("unit_price >= 0", name="ck_sales_daily_price_positive"), CheckConstraint("total_amount >= 0", name="ck_sales_daily_amount_positive"), + CheckConstraint( + "channel IN ('in_store', 'online', 'click_collect', 'wholesale')", + 
name="ck_sales_daily_channel_allowlist", + ), ) @@ -225,15 +274,22 @@ class PriceHistory(TimestampMixin, Base): class Promotion(TimestampMixin, Base): """Promotion fact table. - Tracks promotional campaigns with discount mechanics. + Tracks promotional campaigns with discount mechanics. Phase 2 adds the + ``kind`` discriminator (with server default ``pct_off`` preserving the + pre-Phase-2 behaviour) and a JSONB ``bundle_member_product_ids`` for + BOGO/bundle mechanics. Attributes: id: Primary key. product_id: Product (FK). store_id: Store (FK) - NULL for chain-wide promos. name: Promotion name/description. + kind: ``pct_off | bogo | bundle | markdown`` (Phase 2). Server-default + ``pct_off``. discount_pct: Discount percentage (e.g., 0.15 for 15% off). discount_amount: Fixed discount amount (alternative to %). + bundle_member_product_ids: JSONB list of related product IDs when + ``kind in (bundle, bogo)``; NULL otherwise. start_date: Promotion start date. end_date: Promotion end date. """ @@ -246,8 +302,16 @@ class Promotion(TimestampMixin, Base): Integer, ForeignKey("store.id"), index=True, nullable=True ) name: Mapped[str] = mapped_column(String(200)) + kind: Mapped[str] = mapped_column(String(20), nullable=False, server_default=text("'pct_off'")) discount_pct: Mapped[Decimal | None] = mapped_column(Numeric(5, 4), nullable=True) discount_amount: Mapped[Decimal | None] = mapped_column(Numeric(10, 2), nullable=True) + # ``none_as_null=True`` is load-bearing: Python ``None`` must serialize + # to SQL ``NULL`` (not JSON ``null``) so the + # ``ck_promotion_bundle_members_consistency`` CHECK constraint correctly + # rejects bundle/BOGO rows that omit member IDs. 
+ bundle_member_product_ids: Mapped[list[int] | None] = mapped_column( + JSONB(none_as_null=True), nullable=True + ) start_date: Mapped[datetime.date] = mapped_column(Date, index=True) end_date: Mapped[datetime.date] = mapped_column(Date) @@ -266,6 +330,15 @@ class Promotion(TimestampMixin, Base): "discount_amount IS NULL OR discount_amount >= 0", name="ck_promotion_discount_amount_positive", ), + CheckConstraint( + "kind IN ('pct_off', 'bogo', 'bundle', 'markdown')", + name="ck_promotion_kind_allowlist", + ), + CheckConstraint( + "(kind IN ('bundle', 'bogo') AND bundle_member_product_ids IS NOT NULL)" + " OR (kind NOT IN ('bundle', 'bogo') AND bundle_member_product_ids IS NULL)", + name="ck_promotion_bundle_members_consistency", + ), ) @@ -308,3 +381,134 @@ class InventorySnapshotDaily(TimestampMixin, Base): CheckConstraint("on_hand_qty >= 0", name="ck_inventory_on_hand_positive"), CheckConstraint("on_order_qty >= 0", name="ck_inventory_on_order_positive"), ) + + +class ExogenousSignal(TimestampMixin, Base): + """Exogenous demand-relevant signals (weather, macro index, events). + + A signal is either chain-wide (``is_global=True``, ``store_id IS NULL``) + or per-store (``is_global=False``, ``store_id IS NOT NULL``). The two + cases are enforced by ``ck_exogenous_signal_global_consistency`` and made + unique by two partial indexes so re-runs of the seeder are idempotent. + + Attributes: + id: Surrogate primary key. + date: Signal date (FK to calendar). + signal_name: Short identifier (e.g. ``"weather_temp_c"``, ``"macro_index"``). + store_id: Store (FK) — NULL when ``is_global=True``. + is_global: True for chain-wide signals; mirrors ``store_id IS NULL``. + value: Numeric value of the signal on the given date. 
+ """ + + __tablename__ = "exogenous_signal" + + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + date: Mapped[datetime.date] = mapped_column(Date, ForeignKey("calendar.date"), index=True) + signal_name: Mapped[str] = mapped_column(String(50), index=True) + store_id: Mapped[int | None] = mapped_column( + Integer, ForeignKey("store.id"), nullable=True, index=True + ) + is_global: Mapped[bool] = mapped_column(Boolean, nullable=False) + value: Mapped[float] = mapped_column(Float, nullable=False) + + __table_args__ = ( + Index("ix_exogenous_signal_name_date", "signal_name", "date"), + Index( + "uq_exogenous_signal_global", + "date", + "signal_name", + unique=True, + postgresql_where=("is_global = true"), + ), + Index( + "uq_exogenous_signal_per_store", + "date", + "signal_name", + "store_id", + unique=True, + postgresql_where=("is_global = false"), + ), + CheckConstraint( + "(is_global = true AND store_id IS NULL) OR " + "(is_global = false AND store_id IS NOT NULL)", + name="ck_exogenous_signal_global_consistency", + ), + ) + + +class SalesReturn(TimestampMixin, Base): + """Synthetic sales return event. + + Returns are not subtracted from ``sales_daily.quantity``; they live in a + separate table so featuresets/forecasting can opt into them as a signal. + + Attributes: + id: Surrogate primary key. + date: Return date (FK to calendar). + store_id: Store (FK). + product_id: Product (FK). + return_quantity: Units returned (>= 1). + return_reason: Free-form short reason (e.g. ``"defective"``, + ``"changed_mind"``). 
+ """ + + __tablename__ = "sales_returns" + + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + date: Mapped[datetime.date] = mapped_column(Date, ForeignKey("calendar.date")) + store_id: Mapped[int] = mapped_column(Integer, ForeignKey("store.id"), index=True) + product_id: Mapped[int] = mapped_column(Integer, ForeignKey("product.id"), index=True) + return_quantity: Mapped[int] = mapped_column(Integer, nullable=False) + return_reason: Mapped[str] = mapped_column(String(50), nullable=False) + + __table_args__ = ( + Index("ix_sales_returns_store_product_date", "store_id", "product_id", "date"), + Index("ix_sales_returns_date", "date"), + CheckConstraint("return_quantity >= 1", name="ck_sales_returns_quantity_positive"), + ) + + +class ReplenishmentEvent(TimestampMixin, Base): + """Synthetic replenishment / inbound stock event (Phase 2). + + Drives lead-time-aware stockout clustering. A row marks the date a + purchase order was *received* at a store for a given product, along + with how many days the order was in transit and the ordered vs. + received quantities. The inventory generator consumes these to + schedule realistic stockout windows. + + Attributes: + id: Surrogate primary key. + date: Date of receipt at the store (FK to calendar). + store_id: Store (FK). + product_id: Product (FK). + lead_time_days: Days between order placement and receipt. + ordered_qty: Units ordered. + received_qty: Units actually received (``<= ordered_qty``). 
+ """ + + __tablename__ = "replenishment_event" + + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + date: Mapped[datetime.date] = mapped_column(Date, ForeignKey("calendar.date"), index=True) + store_id: Mapped[int] = mapped_column(Integer, ForeignKey("store.id"), index=True) + product_id: Mapped[int] = mapped_column(Integer, ForeignKey("product.id"), index=True) + lead_time_days: Mapped[int] = mapped_column(Integer, nullable=False) + ordered_qty: Mapped[int] = mapped_column(Integer, nullable=False) + received_qty: Mapped[int] = mapped_column(Integer, nullable=False) + + __table_args__ = ( + Index( + "ix_replenishment_event_store_product_date", + "store_id", + "product_id", + "date", + ), + CheckConstraint("lead_time_days >= 0", name="ck_replenishment_event_lead_time_positive"), + CheckConstraint("ordered_qty >= 0", name="ck_replenishment_event_ordered_qty_positive"), + CheckConstraint("received_qty >= 0", name="ck_replenishment_event_received_qty_positive"), + CheckConstraint( + "received_qty <= ordered_qty", + name="ck_replenishment_event_received_le_ordered", + ), + ) diff --git a/app/features/data_platform/tests/conftest.py b/app/features/data_platform/tests/conftest.py index 494b3359..9ecc04df 100644 --- a/app/features/data_platform/tests/conftest.py +++ b/app/features/data_platform/tests/conftest.py @@ -21,6 +21,7 @@ PriceHistory, Product, Promotion, + ReplenishmentEvent, SalesDaily, Store, ) @@ -57,6 +58,7 @@ async def db_session(): # Clean up test data (delete in correct order due to FK constraints) await cleanup_session.execute(delete(SalesDaily)) await cleanup_session.execute(delete(InventorySnapshotDaily)) + await cleanup_session.execute(delete(ReplenishmentEvent)) await cleanup_session.execute(delete(PriceHistory)) await cleanup_session.execute(delete(Promotion)) await cleanup_session.execute(delete(Product).where(Product.sku.like("SKU-TEST%"))) diff --git a/app/features/data_platform/tests/test_phase2_constraints.py 
b/app/features/data_platform/tests/test_phase2_constraints.py new file mode 100644 index 00000000..fc8d0057 --- /dev/null +++ b/app/features/data_platform/tests/test_phase2_constraints.py @@ -0,0 +1,308 @@ +"""Integration tests for Phase 2 retail-depth schema constraints. + +Covers the new SQL CHECK constraints + the ``replenishment_event`` table. +Requires PostgreSQL (docker-compose up -d). +""" + +from datetime import date +from decimal import Decimal + +import pytest +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from app.features.data_platform.models import ( + Calendar, + Product, + Promotion, + ReplenishmentEvent, + SalesDaily, + Store, +) + + +@pytest.mark.integration +class TestSalesDailyChannelConstraint: + async def test_default_channel_is_in_store( + self, + db_session: AsyncSession, + sample_store: Store, + sample_product: Product, + sample_calendar: Calendar, + ) -> None: + # ``channel`` has a server default; not supplying it must still + # produce a row with ``in_store``. 
+ sale = SalesDaily( + date=sample_calendar.date, + store_id=sample_store.id, + product_id=sample_product.id, + quantity=3, + unit_price=Decimal("9.99"), + total_amount=Decimal("29.97"), + ) + db_session.add(sale) + await db_session.commit() + await db_session.refresh(sale) + assert sale.channel == "in_store" + + @pytest.mark.parametrize("channel", ["in_store", "online", "click_collect", "wholesale"]) + async def test_allowed_channels( + self, + db_session: AsyncSession, + sample_store: Store, + sample_product: Product, + sample_calendar: Calendar, + channel: str, + ) -> None: + sale = SalesDaily( + date=sample_calendar.date, + store_id=sample_store.id, + product_id=sample_product.id, + quantity=1, + unit_price=Decimal("1.00"), + total_amount=Decimal("1.00"), + channel=channel, + ) + db_session.add(sale) + await db_session.commit() + await db_session.refresh(sale) + assert sale.channel == channel + + async def test_disallowed_channel_rejected( + self, + db_session: AsyncSession, + sample_store: Store, + sample_product: Product, + sample_calendar: Calendar, + ) -> None: + sale = SalesDaily( + date=sample_calendar.date, + store_id=sample_store.id, + product_id=sample_product.id, + quantity=1, + unit_price=Decimal("1.00"), + total_amount=Decimal("1.00"), + channel="kiosk", # not in the allow-list + ) + db_session.add(sale) + with pytest.raises(IntegrityError): + await db_session.commit() + + +@pytest.mark.integration +class TestProductLifecycleConstraints: + async def test_lifecycle_stage_nullable(self, db_session: AsyncSession) -> None: + # Lifecycle disabled scenarios must continue inserting bare products + # without supplying lifecycle fields. 
+ product = Product(sku="SKU-TEST-LCNULL", name="Lifecycle Null") + db_session.add(product) + await db_session.commit() + await db_session.refresh(product) + assert product.lifecycle_stage is None + assert product.launch_date is None + assert product.discontinue_date is None + assert product.pack_size is None + assert product.subcategory is None + + @pytest.mark.parametrize("stage", ["intro", "growth", "maturity", "decline", "discontinued"]) + async def test_lifecycle_stage_allowlist(self, db_session: AsyncSession, stage: str) -> None: + product = Product( + sku=f"SKU-TEST-LC-{stage}", + name=f"Stage {stage}", + lifecycle_stage=stage, + ) + db_session.add(product) + await db_session.commit() + await db_session.refresh(product) + assert product.lifecycle_stage == stage + + async def test_invalid_lifecycle_stage_rejected(self, db_session: AsyncSession) -> None: + product = Product( + sku="SKU-TEST-LCBAD", + name="Bad Stage", + lifecycle_stage="ramping_up", # not in allow-list + ) + db_session.add(product) + with pytest.raises(IntegrityError): + await db_session.commit() + + async def test_discontinue_before_launch_rejected(self, db_session: AsyncSession) -> None: + product = Product( + sku="SKU-TEST-LCDATE", + name="Bad Dates", + launch_date=date(2024, 6, 1), + discontinue_date=date(2024, 5, 1), + ) + db_session.add(product) + with pytest.raises(IntegrityError): + await db_session.commit() + + async def test_negative_pack_size_rejected(self, db_session: AsyncSession) -> None: + product = Product( + sku="SKU-TEST-PACK", + name="Bad Pack", + pack_size=0, + ) + db_session.add(product) + with pytest.raises(IntegrityError): + await db_session.commit() + + +@pytest.mark.integration +class TestPromotionKindConstraints: + async def test_default_kind_is_pct_off( + self, + db_session: AsyncSession, + sample_product: Product, + ) -> None: + promo = Promotion( + product_id=sample_product.id, + name="Default kind", + discount_pct=Decimal("0.10"), + start_date=date(2024, 6, 1), + 
end_date=date(2024, 6, 7), + ) + db_session.add(promo) + await db_session.commit() + await db_session.refresh(promo) + assert promo.kind == "pct_off" + assert promo.bundle_member_product_ids is None + + @pytest.mark.parametrize("kind", ["pct_off", "markdown"]) + async def test_non_bundle_kinds_reject_member_ids( + self, + db_session: AsyncSession, + sample_product: Product, + kind: str, + ) -> None: + promo = Promotion( + product_id=sample_product.id, + name=f"{kind} with bundle members", + kind=kind, + discount_pct=Decimal("0.10"), + start_date=date(2024, 6, 1), + end_date=date(2024, 6, 7), + bundle_member_product_ids=[1, 2], + ) + db_session.add(promo) + with pytest.raises(IntegrityError): + await db_session.commit() + + @pytest.mark.parametrize("kind", ["bundle", "bogo"]) + async def test_bundle_kinds_require_member_ids( + self, + db_session: AsyncSession, + sample_product: Product, + kind: str, + ) -> None: + promo = Promotion( + product_id=sample_product.id, + name=f"{kind} without bundle members", + kind=kind, + discount_pct=Decimal("0.10"), + start_date=date(2024, 6, 1), + end_date=date(2024, 6, 7), + bundle_member_product_ids=None, + ) + db_session.add(promo) + with pytest.raises(IntegrityError): + await db_session.commit() + + async def test_bundle_kind_accepts_member_ids( + self, + db_session: AsyncSession, + sample_product: Product, + ) -> None: + promo = Promotion( + product_id=sample_product.id, + name="Bundle accepted", + kind="bundle", + discount_pct=Decimal("0.15"), + start_date=date(2024, 6, 1), + end_date=date(2024, 6, 7), + bundle_member_product_ids=[sample_product.id, 999], + ) + db_session.add(promo) + await db_session.commit() + await db_session.refresh(promo) + assert promo.kind == "bundle" + assert promo.bundle_member_product_ids == [sample_product.id, 999] + + async def test_unknown_kind_rejected( + self, + db_session: AsyncSession, + sample_product: Product, + ) -> None: + promo = Promotion( + product_id=sample_product.id, + name="Unknown 
kind", + kind="loyalty", + discount_pct=Decimal("0.10"), + start_date=date(2024, 6, 1), + end_date=date(2024, 6, 7), + ) + db_session.add(promo) + with pytest.raises(IntegrityError): + await db_session.commit() + + +@pytest.mark.integration +class TestReplenishmentEventTable: + async def test_insert_minimal_row( + self, + db_session: AsyncSession, + sample_store: Store, + sample_product: Product, + sample_calendar: Calendar, + ) -> None: + event = ReplenishmentEvent( + date=sample_calendar.date, + store_id=sample_store.id, + product_id=sample_product.id, + lead_time_days=5, + ordered_qty=100, + received_qty=100, + ) + db_session.add(event) + await db_session.commit() + await db_session.refresh(event) + assert event.id is not None + assert event.lead_time_days == 5 + + async def test_received_exceeds_ordered_rejected( + self, + db_session: AsyncSession, + sample_store: Store, + sample_product: Product, + sample_calendar: Calendar, + ) -> None: + event = ReplenishmentEvent( + date=sample_calendar.date, + store_id=sample_store.id, + product_id=sample_product.id, + lead_time_days=3, + ordered_qty=50, + received_qty=51, # > ordered + ) + db_session.add(event) + with pytest.raises(IntegrityError): + await db_session.commit() + + async def test_negative_lead_time_rejected( + self, + db_session: AsyncSession, + sample_store: Store, + sample_product: Product, + sample_calendar: Calendar, + ) -> None: + event = ReplenishmentEvent( + date=sample_calendar.date, + store_id=sample_store.id, + product_id=sample_product.id, + lead_time_days=-1, + ordered_qty=10, + received_qty=10, + ) + db_session.add(event) + with pytest.raises(IntegrityError): + await db_session.commit() diff --git a/app/features/dimensions/routes.py b/app/features/dimensions/routes.py index bb2130df..9042966e 100644 --- a/app/features/dimensions/routes.py +++ b/app/features/dimensions/routes.py @@ -4,12 +4,15 @@ and products before calling ingest, training, or forecasting endpoints. 
""" +from datetime import date + from fastapi import APIRouter, Depends, HTTPException, Query, status from sqlalchemy.ext.asyncio import AsyncSession from app.core.database import get_db from app.core.logging import get_logger from app.features.dimensions.schemas import ( + LifecycleCurveResponse, ProductListResponse, ProductResponse, StoreListResponse, @@ -242,3 +245,57 @@ async def get_product( ) return result + + +@router.get( + "/products/{product_id}/lifecycle-curve", + response_model=LifecycleCurveResponse, + summary="Get product lifecycle demand curve (Phase 2)", + description=""" +Return the reference lifecycle demand-multiplier curve for a product. + +**Behavior**: +- Respects the product's own `launch_date` / `discontinue_date`. +- Uses the default `LifecycleConfig` ramp parameters (not the config that + was active at seeding time — that config is not persisted). +- Returns one point per calendar day in the requested window. + +**Defaults**: +- `start_date` defaults to the product's `launch_date` (or today minus 30 days + when `launch_date` is unset). +- `end_date` defaults to `start_date + 365 days`, clamped to + `discontinue_date` when set. + +**Error Handling**: +- Returns 404 if `product_id` doesn't exist. +""", +) +async def get_product_lifecycle_curve( + product_id: int, + db: AsyncSession = Depends(get_db), + start_date: date | None = Query( + None, + description="Curve start (inclusive). Defaults to launch_date.", + ), + end_date: date | None = Query( + None, + description="Curve end (inclusive). Defaults to start_date + 365 days.", + ), +) -> LifecycleCurveResponse: + """Return the reference lifecycle demand curve for a product.""" + service = DimensionService() + result = await service.get_product_lifecycle_curve( + db=db, + product_id=product_id, + start_date=start_date, + end_date=end_date, + ) + + if result is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Product not found: {product_id}. 
" + "Use GET /dimensions/products to list available products.", + ) + + return result diff --git a/app/features/dimensions/schemas.py b/app/features/dimensions/schemas.py index 9b70fb5d..623dd168 100644 --- a/app/features/dimensions/schemas.py +++ b/app/features/dimensions/schemas.py @@ -4,6 +4,7 @@ that help agents understand how to use each field. """ +from datetime import date as _date from datetime import datetime from decimal import Decimal @@ -179,3 +180,55 @@ class ProductListResponse(BaseModel): ge=1, description="Number of products per page. Maximum is 100.", ) + + +# ============================================================================= +# Lifecycle Curve (Phase 2) +# ============================================================================= + + +class LifecyclePoint(BaseModel): + """One day of the lifecycle demand curve.""" + + date: _date = Field(..., description="Calendar date for this curve point") + stage: str = Field( + ..., + description=("Lifecycle stage label: intro | growth | maturity | decline | discontinued."), + ) + multiplier: float = Field( + ..., + ge=0.0, + description=( + "Demand multiplier the SalesDailyGenerator would apply on this " + "date for this product (1.0 = neutral)." + ), + ) + + +class LifecycleCurveResponse(BaseModel): + """Response payload for GET /dimensions/products/{id}/lifecycle-curve. + + Returns the reference curve a Phase 2 LifecycleGenerator would + produce for the product's launch_date and discontinue_date using + the default LifecycleConfig ramp parameters. Useful for visualizing + a product's expected demand shape even when the seeded run used + different ramp parameters. + """ + + product_id: int = Field(..., description="Product internal ID") + sku: str = Field(..., description="Product SKU") + launch_date: _date | None = Field( + None, + description="Product launch_date. NULL when lifecycle data is unset.", + ) + discontinue_date: _date | None = Field( + None, + description="Product discontinue_date. 
NULL when not retired.", + ) + start_date: _date = Field(..., description="First date in the returned curve") + end_date: _date = Field(..., description="Last date in the returned curve") + points: list[LifecyclePoint] = Field( + ..., + description="Curve points in ascending date order", + ) + total: int = Field(..., ge=0, description="Number of points returned") diff --git a/app/features/dimensions/service.py b/app/features/dimensions/service.py index b6e1c77d..4f88d91c 100644 --- a/app/features/dimensions/service.py +++ b/app/features/dimensions/service.py @@ -4,17 +4,23 @@ with filtering and search capabilities. """ +from datetime import UTC, date, datetime, timedelta + from sqlalchemy import func, or_, select from sqlalchemy.ext.asyncio import AsyncSession from app.core.logging import get_logger from app.features.data_platform.models import Product, Store from app.features.dimensions.schemas import ( + LifecycleCurveResponse, + LifecyclePoint, ProductListResponse, ProductResponse, StoreListResponse, StoreResponse, ) +from app.shared.seeder.config import LifecycleConfig +from app.shared.seeder.generators.lifecycle import LifecycleGenerator logger = get_logger(__name__) @@ -251,3 +257,86 @@ async def get_product_by_sku( return None return ProductResponse.model_validate(product) + + async def get_product_lifecycle_curve( + self, + db: AsyncSession, + product_id: int, + start_date: date | None = None, + end_date: date | None = None, + ) -> LifecycleCurveResponse | None: + """Return the reference lifecycle demand curve for a product (Phase 2). + + Uses the default :class:`LifecycleConfig` ramp parameters. The + curve respects the product's own ``launch_date`` and + ``discontinue_date`` but is independent of the ``LifecycleConfig`` + used at seeding time (that config is not persisted). Returns + ``None`` when the product is not found. + + Args: + db: Database session. + product_id: Product primary key. + start_date: Optional curve start. 
Defaults to the product's + ``launch_date`` (or today minus 30 days if launch is + unset). + end_date: Optional curve end. Defaults to ``start_date + 365`` + days, clamped to ``discontinue_date`` when set. + + Returns: + ``LifecycleCurveResponse`` or ``None`` if no product. + """ + stmt = select(Product).where(Product.id == product_id) + result = await db.execute(stmt) + product = result.scalar_one_or_none() + if product is None: + return None + + launch = product.launch_date + discontinue = product.discontinue_date + # Default the curve window around the product's lifecycle dates. + # When launch_date is unset, fall back to a recent 1-year window + # so callers get a usable response (the multiplier short-circuits + # to 1.0 and the stage is ``maturity``). + if start_date is None: + start_date = launch or (datetime.now(UTC).date() - timedelta(days=30)) + if end_date is None: + end_date = start_date + timedelta(days=365) + if discontinue is not None and discontinue < end_date: + end_date = discontinue + + if end_date < start_date: + end_date = start_date + + config = LifecycleConfig(enable=True) + generator = LifecycleGenerator(config) + + points: list[LifecyclePoint] = [] + current = start_date + while current <= end_date: + points.append( + LifecyclePoint( + date=current, + stage=generator.stage_for(current, launch, discontinue), + multiplier=generator.multiplier_for(current, launch, discontinue), + ) + ) + current += timedelta(days=1) + + logger.info( + "dimensions.lifecycle_curve_computed", + product_id=product_id, + launch_date=str(launch) if launch else None, + discontinue_date=str(discontinue) if discontinue else None, + points=len(points), + ) + + return LifecycleCurveResponse( + product_id=product.id, + sku=product.sku, + launch_date=launch, + discontinue_date=discontinue, + start_date=start_date, + end_date=end_date, + points=points, + total=len(points), + ) diff --git a/app/features/seeder/routes.py b/app/features/seeder/routes.py index 76e1233b..8f79ab23 
100644 --- a/app/features/seeder/routes.py +++ b/app/features/seeder/routes.py @@ -4,7 +4,9 @@ through the dashboard admin panel. """ -from fastapi import APIRouter, Depends, HTTPException, status +from datetime import date + +from fastapi import APIRouter, Depends, HTTPException, Query, status from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import get_settings @@ -63,6 +65,23 @@ async def list_scenarios() -> list[schemas.ScenarioInfo]: return service.list_scenarios() +@router.get( + "/channels", + response_model=schemas.ChannelsResponse, + summary="List sales channels", + description=( + "Return the SQL allow-list for `sales_daily.channel` and " + "`ChannelConfig.channel_mix` keys (Phase 2). Use these values " + "to populate UI selectors or validate channel_mix payloads " + "before POST /seeder/generate." + ), +) +async def list_channels() -> schemas.ChannelsResponse: + """List valid sales channel identifiers (Phase 2).""" + channels = sorted(schemas.VALID_CHANNELS) + return schemas.ChannelsResponse(channels=channels, total=len(channels)) + + @router.post( "/generate", response_model=schemas.GenerateResult, @@ -226,6 +245,70 @@ async def delete_data( ) from e +@router.get( + "/exogenous", + response_model=schemas.ExogenousSignalResponse, + summary="Query exogenous signals", + description=( + "Return exogenous signal rows (Phase 1) for a given signal name and date " + "window. Available signals: `weather_temp_c`, `macro_index`, `event_flag`." + ), +) +async def query_exogenous( + signal_name: str = Query( + ..., + min_length=1, + max_length=50, + description="Signal identifier (e.g. weather_temp_c, macro_index, event_flag)", + ), + start_date: date = Query(..., description="Window start (inclusive)"), + end_date: date = Query(..., description="Window end (inclusive)"), + store_id: int | None = Query( + default=None, + ge=1, + description="Optional store filter. 
Omit to include global + per-store rows.", + ), + db: AsyncSession = Depends(get_db), +) -> schemas.ExogenousSignalResponse: + """Query exogenous_signal rows for a signal name and date window. + + Returns rows ordered by date. Subject to row and date-range caps to + keep the response bounded. + + Raises: + HTTPException: 400 if the date window is invalid or oversized. + """ + try: + return await service.query_exogenous( + db, + signal_name=signal_name, + start_date=start_date, + end_date=end_date, + store_id=store_id, + ) + except ValueError as e: + logger.error( + "seeder.exogenous.query_failed", + error=str(e), + error_type=type(e).__name__, + ) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) from e + except Exception as e: + logger.error( + "seeder.exogenous.query_failed", + error=str(e), + error_type=type(e).__name__, + exc_info=True, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Exogenous query failed: {e}", + ) from e + + @router.post( "/verify", response_model=schemas.VerifyResult, diff --git a/app/features/seeder/schemas.py b/app/features/seeder/schemas.py index 6c925114..90331aad 100644 --- a/app/features/seeder/schemas.py +++ b/app/features/seeder/schemas.py @@ -1,9 +1,13 @@ """Pydantic schemas for the seeder feature.""" +import datetime as _datetime_module from datetime import date, datetime from typing import Literal -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator, model_validator + +VALID_CHANNELS: frozenset[str] = frozenset({"in_store", "online", "click_collect", "wholesale"}) +"""Allow-list for ``sales_daily.channel`` — mirrors the SQL CHECK.""" class SeederStatus(BaseModel): @@ -16,6 +20,18 @@ class SeederStatus(BaseModel): inventory: int = Field(description="Number of inventory_snapshot_daily records") price_history: int = Field(description="Number of price_history records") promotions: int = Field(description="Number 
of promotion records") + exogenous_signals: int = Field( + default=0, + description="Number of exogenous_signal records (Phase 1)", + ) + sales_returns: int = Field( + default=0, + description="Number of sales_returns records (Phase 1)", + ) + replenishment_events: int = Field( + default=0, + description="Number of replenishment_event records (Phase 2)", + ) date_range_start: date | None = Field( default=None, description="Earliest date in sales_daily", @@ -30,6 +46,22 @@ class SeederStatus(BaseModel): ) +class ChangepointEventParam(BaseModel): + """API-facing representation of a demand changepoint (Phase 1).""" + + date: _datetime_module.date = Field(description="Changepoint impulse date") + demand_multiplier: float = Field( + ge=0.0, + description="Peak multiplier on the changepoint date", + ) + decay_days: int = Field( + default=30, + ge=0, + le=3650, + description="Exponential decay e-folding time (days). 0 = pure impulse.", + ) + + class ScenarioInfo(BaseModel): """Information about a scenario preset.""" @@ -84,6 +116,185 @@ class GenerateParams(BaseModel): description="Preview only, do not execute", ) + # Phase 1 — realism extension. All flags default off so existing + # scenarios remain byte-identical when this endpoint is called without + # the new fields. + enable_exogenous: bool = Field( + default=False, + description="Seed weather/macro/event exogenous signals (Phase 1)", + ) + enable_returns: bool = Field( + default=False, + description="Seed sales_returns rows derived from sales (Phase 1)", + ) + enable_substitution: bool = Field( + default=False, + description="Apply cross-product substitution lift on stockouts (Phase 1)", + ) + yearly_seasonality_amplitude: float | None = Field( + default=None, + ge=0.0, + le=1.0, + description=( + "Yearly sin-wave demand amplitude (fraction). None or 0 = disabled. 
(Phase 1)" + ), + ) + weather_temperature_sensitivity: float | None = Field( + default=None, + ge=-1.0, + le=1.0, + description=( + "Demand delta per °C above the climatology mean. " + "Only applied when enable_exogenous=true. (Phase 1)" + ), + ) + changepoints: list[ChangepointEventParam] | None = Field( + default=None, + description="Optional list of demand changepoints (Phase 1)", + ) + substitute_groups: list[list[int]] | None = Field( + default=None, + description=( + "Optional list of product-ID groups whose members substitute for " + "each other on stockout. Only applied when enable_substitution=true. " + "(Phase 1)" + ), + ) + substitution_lift_on_stockout: float | None = Field( + default=None, + ge=0.0, + le=10.0, + description=( + "Demand lift distributed across in-stock group-mates when a member " + "is stocked out. Only applied when enable_substitution=true. (Phase 1)" + ), + ) + + # Phase 2 — retail-depth extension. All flags default off so existing + # scenarios stay byte-identical when the endpoint is called without + # the new fields. + enable_multichannel: bool = Field( + default=False, + description="Split sales across channels (in_store/online/...) (Phase 2)", + ) + channel_mix: dict[str, float] | None = Field( + default=None, + description=( + "Probability weights per channel. Keys must be a subset of " + "{in_store, online, click_collect, wholesale}. Weights " + "normalize at use time; at least one weight must be > 0. " + "Only applied when enable_multichannel=true. (Phase 2)" + ), + ) + online_promo_uplift: float | None = Field( + default=None, + ge=0.0, + le=10.0, + description=( + "Multiplier on online-channel quantity during promotions " + "(e.g. 1.2 = +20%). Only applied when enable_multichannel=true. 
(Phase 2)" + ), + ) + online_substitution_to_instore: float | None = Field( + default=None, + ge=0.0, + le=1.0, + description=( + "Fraction of in-store demand that shifts to online during " + "promotions (0.0 = independent; 1.0 = pure substitution). " + "Only applied when enable_multichannel=true. (Phase 2)" + ), + ) + enable_lifecycle: bool = Field( + default=False, + description="Assign product lifecycle stage + launch/discontinue dates (Phase 2)", + ) + lifecycle_discontinue_probability: float | None = Field( + default=None, + ge=0.0, + le=1.0, + description=( + "Probability a product gets a discontinue_date within the " + "seeded range. Only applied when enable_lifecycle=true. (Phase 2)" + ), + ) + enable_bundles: bool = Field( + default=False, + description="Convert a fraction of promotions to bundle/BOGO (Phase 2)", + ) + bundle_probability: float | None = Field( + default=None, + ge=0.0, + le=1.0, + description=( + "Per-promotion conversion probability. Only applied when enable_bundles=true. (Phase 2)" + ), + ) + enable_markdowns: bool = Field( + default=False, + description="Emit clearance markdown promos + price drops (Phase 2)", + ) + markdown_trigger: Literal["lifecycle_decline", "stockout_risk"] | None = Field( + default=None, + description=( + "Markdown firing rule. 'age_days' is deferred — see issue #94. " + "Only applied when enable_markdowns=true. (Phase 2)" + ), + ) + enable_lead_time: bool = Field( + default=False, + description="Emit replenishment_event rows with stochastic lead times (Phase 2)", + ) + mean_lead_time_days: int | None = Field( + default=None, + ge=0, + le=365, + description=( + "Mean Normal-distributed lead time (days). Only applied when " + "enable_lead_time=true. 
(Phase 2)" + ), + ) + + @model_validator(mode="after") + def _validate_date_range(self) -> "GenerateParams": + """Reject inverted date ranges with a clear message.""" + if self.end_date < self.start_date: + raise ValueError( + f"end_date ({self.end_date}) must be on or after start_date ({self.start_date})" + ) + return self + + @field_validator("channel_mix") + @classmethod + def _validate_channel_mix( + cls, + value: dict[str, float] | None, + ) -> dict[str, float] | None: + """Validate channel_mix keys, non-negativity, and positive total. + + Empty dict is rejected so callers must either pass None or a + meaningful split. The full set of channels need not appear — + unspecified channels get zero weight. + """ + if value is None: + return None + if not value: + raise ValueError( + "channel_mix must not be empty when supplied; pass null to use defaults" + ) + invalid = set(value.keys()) - VALID_CHANNELS + if invalid: + raise ValueError( + f"channel_mix contains invalid channels {sorted(invalid)}; " + f"allow-list is {sorted(VALID_CHANNELS)}" + ) + for name, weight in value.items(): + if weight < 0: + raise ValueError(f"channel_mix['{name}']={weight} must be non-negative") + if sum(value.values()) <= 0: + raise ValueError("channel_mix must have at least one positive weight") + return value + class AppendParams(BaseModel): """Parameters for appending data to existing dataset.""" @@ -156,3 +367,59 @@ class VerifyResult(BaseModel): passed_count: int = Field(description="Number of passed checks") warning_count: int = Field(description="Number of warnings") failed_count: int = Field(description="Number of failures") + + +# ============================================================================ +# PHASE 1 — Exogenous signal read API +# ============================================================================ + + +class ExogenousSignalRecord(BaseModel): + """One row of the exogenous_signal table.""" + + date: _datetime_module.date = Field(description="Signal date") 
+ signal_name: str = Field(description="Signal identifier") + store_id: int | None = Field( + default=None, + description="Store ID. None for chain-wide (global) signals.", + ) + is_global: bool = Field(description="True for chain-wide signals") + value: float = Field(description="Numeric signal value") + + +class ExogenousSignalResponse(BaseModel): + """Response payload for GET /seeder/exogenous.""" + + signal_name: str = Field(description="Signal identifier queried") + start_date: date = Field(description="Start of the query window") + end_date: date = Field(description="End of the query window") + store_id: int | None = Field( + default=None, + description="Specific store filter, if applied", + ) + records: list[ExogenousSignalRecord] = Field( + description="Signal rows in ascending date order", + ) + total: int = Field(description="Row count in the response") + + +# ============================================================================ +# PHASE 2 — Channels enumeration +# ============================================================================ + + +class ChannelsResponse(BaseModel): + """Response payload for GET /seeder/channels. + + Returns the SQL allow-list for ``sales_daily.channel`` so callers + (admin UI, agent tools, integration tests) can populate selectors + without duplicating the constant. Mirrors the SQL CHECK constraint. + """ + + channels: list[str] = Field( + description=( + "Sorted list of valid channel identifiers for sales_daily.channel " + "and for ChannelConfig.channel_mix keys." 
+ ), + ) + total: int = Field(description="Number of valid channels") diff --git a/app/features/seeder/service.py b/app/features/seeder/service.py index a7696aec..81726b56 100644 --- a/app/features/seeder/service.py +++ b/app/features/seeder/service.py @@ -3,6 +3,7 @@ from __future__ import annotations import time +from dataclasses import replace from datetime import date, datetime from sqlalchemy import func, select @@ -12,16 +13,33 @@ from app.core.logging import get_logger from app.features.data_platform.models import ( Calendar, + ExogenousSignal, InventorySnapshotDaily, PriceHistory, Product, Promotion, + ReplenishmentEvent, SalesDaily, + SalesReturn, Store, ) from app.features.seeder import schemas from app.shared.seeder import DataSeeder, ScenarioPreset, SeederConfig -from app.shared.seeder.config import DimensionConfig, SparsityConfig +from app.shared.seeder.config import ( + BundleConfig, + ChangepointConfig, + ChangepointEvent, + ChannelConfig, + DimensionConfig, + ExogenousSignalConfig, + LeadTimeConfig, + LifecycleConfig, + MarkdownConfig, + MultiSeasonalityConfig, + ReturnsConfig, + SparsityConfig, + SubstitutionConfig, +) logger = get_logger(__name__) @@ -41,6 +59,134 @@ def _get_scenario_preset(name: str) -> ScenarioPreset | None: return None +def _apply_phase1_overrides(config: SeederConfig, params: schemas.GenerateParams) -> None: + """Apply Phase 1 (realism) overrides from API params onto ``config``. + + Mutates ``config`` in place. Each override is no-op when the matching + flag/field is absent, so existing scenarios stay byte-identical when + Phase 1 params are omitted. 
+ """ + if params.enable_exogenous: + config.exogenous = ExogenousSignalConfig( + enable_weather=True, + enable_macro=True, + enable_events=False, + weather_temperature_sensitivity=( + params.weather_temperature_sensitivity + if params.weather_temperature_sensitivity is not None + else 0.0 + ), + ) + elif params.weather_temperature_sensitivity is not None: + # Sensitivity passed without enable_exogenous → ignore quietly; the + # weather lookup won't exist so the multiplier short-circuits. + config.exogenous = replace( + config.exogenous, + weather_temperature_sensitivity=params.weather_temperature_sensitivity, + ) + + if ( + params.yearly_seasonality_amplitude is not None + and params.yearly_seasonality_amplitude > 0.0 + ): + config.multi_seasonality = MultiSeasonalityConfig( + yearly_seasonality_amplitude=params.yearly_seasonality_amplitude, + ) + + if params.changepoints: + config.changepoints = ChangepointConfig( + changepoints=[ + ChangepointEvent( + date=cp.date, + demand_multiplier=cp.demand_multiplier, + decay_days=cp.decay_days, + ) + for cp in params.changepoints + ] + ) + + if params.enable_returns: + config.returns = ReturnsConfig(enable=True) + + if params.enable_substitution: + config.substitution = SubstitutionConfig( + enable=True, + substitute_groups=( + [list(group) for group in params.substitute_groups] + if params.substitute_groups is not None + else [] + ), + substitution_lift_on_stockout=( + params.substitution_lift_on_stockout + if params.substitution_lift_on_stockout is not None + else 0.5 + ), + ) + + +def _apply_phase2_overrides(config: SeederConfig, params: schemas.GenerateParams) -> None: + """Apply Phase 2 (retail-depth) overrides from API params onto ``config``. + + Mutates ``config`` in place. Each override is no-op when the matching + enable flag is False, so existing scenarios stay byte-identical when + Phase 2 params are omitted. 
+ """ + if params.enable_multichannel: + mix: dict[str, float] = ( + dict(params.channel_mix) + if params.channel_mix is not None + else {"in_store": 0.7, "online": 0.2, "click_collect": 0.1} + ) + config.channels = ChannelConfig( + enable_multichannel=True, + channel_mix=mix, + online_promo_uplift=( + params.online_promo_uplift if params.online_promo_uplift is not None else 1.0 + ), + online_substitution_to_instore=( + params.online_substitution_to_instore + if params.online_substitution_to_instore is not None + else 0.0 + ), + ) + + if params.enable_lifecycle: + config.lifecycle = LifecycleConfig( + enable=True, + discontinue_probability=( + params.lifecycle_discontinue_probability + if params.lifecycle_discontinue_probability is not None + else 0.0 + ), + ) + + if params.enable_bundles: + config.bundles = BundleConfig( + enable=True, + bundle_probability=( + params.bundle_probability if params.bundle_probability is not None else 0.2 + ), + ) + + if params.enable_markdowns: + config.markdowns = MarkdownConfig( + enable=True, + trigger=( + params.markdown_trigger + if params.markdown_trigger is not None + else "lifecycle_decline" + ), + ) + + if params.enable_lead_time: + config.lead_time = LeadTimeConfig( + enable=True, + mean_lead_time_days=( + params.mean_lead_time_days if params.mean_lead_time_days is not None else 7 + ), + ) + + def _build_config_from_params(params: schemas.GenerateParams) -> SeederConfig: """Build SeederConfig from API parameters. @@ -55,8 +201,10 @@ def _build_config_from_params(params: schemas.GenerateParams) -> SeederConfig: if preset: # Start from scenario preset and override with explicit params config = SeederConfig.from_scenario(preset, seed=params.seed) - # Override dimensions if explicitly set (different from defaults) - config.dimensions = DimensionConfig( + # Override store/product counts while preserving scenario-customized + # region/category/brand lists (dataclasses.replace is field-precise). 
+ config.dimensions = replace( + config.dimensions, stores=params.stores, products=params.products, ) @@ -77,6 +225,9 @@ def _build_config_from_params(params: schemas.GenerateParams) -> SeederConfig: sparsity=SparsityConfig(missing_combinations_pct=params.sparsity), ) + _apply_phase1_overrides(config, params) + _apply_phase2_overrides(config, params) + settings = get_settings() config.batch_size = settings.seeder_batch_size config.enable_progress = settings.seeder_enable_progress @@ -104,6 +255,9 @@ async def get_status(db: AsyncSession) -> schemas.SeederStatus: ("inventory", InventorySnapshotDaily), ("price_history", PriceHistory), ("promotions", Promotion), + ("exogenous_signals", ExogenousSignal), + ("sales_returns", SalesReturn), + ("replenishment_events", ReplenishmentEvent), ] counts: dict[str, int] = {} @@ -138,6 +292,9 @@ async def get_status(db: AsyncSession) -> schemas.SeederStatus: inventory=counts["inventory"], price_history=counts["price_history"], promotions=counts["promotions"], + exogenous_signals=counts["exogenous_signals"], + sales_returns=counts["sales_returns"], + replenishment_events=counts["replenishment_events"], date_range_start=date_range_start, date_range_end=date_range_end, last_updated=last_updated, @@ -254,6 +411,9 @@ async def generate_data( "price_history": 0, "promotions": 0, "inventory": 0, + "exogenous_signals": 0, + "sales_returns": 0, + "replenishment_events": 0, }, duration_seconds=0.0, message=f"Dry run: would generate data with scenario '{params.scenario}'", @@ -296,6 +456,9 @@ async def generate_data( "price_history": result.price_history_count, "promotions": result.promotions_count, "inventory": result.inventory_count, + "exogenous_signals": result.exogenous_count, + "sales_returns": result.returns_count, + "replenishment_events": result.replenishment_count, }, duration_seconds=round(duration, 2), message=f"Successfully generated {result.sales_count:,} sales records with seed {params.seed}", @@ -364,6 +527,9 @@ async def 
append_data( "price_history": result.price_history_count, "promotions": result.promotions_count, "inventory": result.inventory_count, + "exogenous_signals": result.exogenous_count, + "sales_returns": result.returns_count, + "replenishment_events": result.replenishment_count, }, duration_seconds=round(duration, 2), message=f"Appended {result.sales_count:,} sales records for date range {params.start_date} to {params.end_date}", @@ -537,3 +703,95 @@ async def verify_data(db: AsyncSession) -> schemas.VerifyResult: warning_count=warning_count, failed_count=failed_count, ) + + +# ============================================================================ +# PHASE 1 — Exogenous signal read API +# ============================================================================ + + +EXOGENOUS_MAX_DATE_RANGE_DAYS = 365 * 3 # 3 years — matches feature_max_lookback_days +EXOGENOUS_MAX_RECORDS = 50_000 + + +async def query_exogenous( + db: AsyncSession, + signal_name: str, + start_date: date, + end_date: date, + store_id: int | None, +) -> schemas.ExogenousSignalResponse: + """Return exogenous signal rows for ``signal_name`` within a window. + + Args: + db: Async database session. + signal_name: Exact signal identifier (e.g. ``"weather_temp_c"``). + start_date: Window start (inclusive). + end_date: Window end (inclusive). + store_id: Optional store filter. When None, returns global signals + plus any store-scoped rows for the period (callers typically + filter on a single store to keep payload sizes reasonable). + + Returns: + ExogenousSignalResponse with rows ordered by date ascending. + + Raises: + ValueError: On inverted or oversized date windows. 
+ """ + if end_date < start_date: + raise ValueError(f"end_date ({end_date}) must be on or after start_date ({start_date})") + span_days = (end_date - start_date).days + if span_days > EXOGENOUS_MAX_DATE_RANGE_DAYS: + raise ValueError( + f"Date range too large ({span_days} days); max is {EXOGENOUS_MAX_DATE_RANGE_DAYS} days" + ) + + stmt = ( + select(ExogenousSignal) + .where(ExogenousSignal.signal_name == signal_name) + .where(ExogenousSignal.date >= start_date) + .where(ExogenousSignal.date <= end_date) + .order_by(ExogenousSignal.date.asc(), ExogenousSignal.store_id.asc().nullsfirst()) + .limit(EXOGENOUS_MAX_RECORDS + 1) + ) + if store_id is not None: + stmt = stmt.where( + (ExogenousSignal.store_id == store_id) | (ExogenousSignal.is_global.is_(True)) + ) + + result = await db.execute(stmt) + rows = result.scalars().all() + if len(rows) > EXOGENOUS_MAX_RECORDS: + raise ValueError( + f"Query exceeded maximum row cap ({EXOGENOUS_MAX_RECORDS}); " + "narrow the date range or filter by store_id" + ) + + records = [ + schemas.ExogenousSignalRecord( + date=row.date, + signal_name=row.signal_name, + store_id=row.store_id, + is_global=row.is_global, + value=row.value, + ) + for row in rows + ] + + logger.info( + "seeder.exogenous.queried", + signal_name=signal_name, + start_date=str(start_date), + end_date=str(end_date), + store_id=store_id, + rows=len(records), + ) + + return schemas.ExogenousSignalResponse( + signal_name=signal_name, + start_date=start_date, + end_date=end_date, + store_id=store_id, + records=records, + total=len(records), + ) diff --git a/app/features/seeder/tests/test_phase1_routes.py b/app/features/seeder/tests/test_phase1_routes.py new file mode 100644 index 00000000..dba06b48 --- /dev/null +++ b/app/features/seeder/tests/test_phase1_routes.py @@ -0,0 +1,122 @@ +"""Route tests for Phase 1 GET /seeder/exogenous endpoint.""" + +from datetime import date +from unittest.mock import patch + +import pytest +from fastapi import status +from 
fastapi.testclient import TestClient + +from app.features.seeder import schemas +from app.main import app + + +@pytest.fixture +def client(): + return TestClient(app) + + +class TestExogenousRoute: + def test_happy_path(self, client): + mock_response = schemas.ExogenousSignalResponse( + signal_name="weather_temp_c", + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 2), + store_id=None, + records=[ + schemas.ExogenousSignalRecord( + date=date(2024, 1, 1), + signal_name="weather_temp_c", + store_id=1, + is_global=False, + value=12.3, + ), + schemas.ExogenousSignalRecord( + date=date(2024, 1, 2), + signal_name="weather_temp_c", + store_id=1, + is_global=False, + value=13.1, + ), + ], + total=2, + ) + + async def _fake(*args, **kwargs): + return mock_response + + with patch( + "app.features.seeder.routes.service.query_exogenous", + side_effect=_fake, + ): + response = client.get( + "/seeder/exogenous", + params={ + "signal_name": "weather_temp_c", + "start_date": "2024-01-01", + "end_date": "2024-01-02", + }, + ) + + assert response.status_code == status.HTTP_200_OK + body = response.json() + assert body["signal_name"] == "weather_temp_c" + assert body["total"] == 2 + assert len(body["records"]) == 2 + + def test_rejects_inverted_window(self, client): + # Service raises ValueError → 400 per the error handler. + async def _fake(*args, **kwargs): + raise ValueError("end_date must be on or after start_date") + + with patch( + "app.features.seeder.routes.service.query_exogenous", + side_effect=_fake, + ): + response = client.get( + "/seeder/exogenous", + params={ + "signal_name": "weather_temp_c", + "start_date": "2024-12-31", + "end_date": "2024-01-01", + }, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_requires_signal_name(self, client): + response = client.get( + "/seeder/exogenous", + params={"start_date": "2024-01-01", "end_date": "2024-01-02"}, + ) + # Missing required param → FastAPI validation 422. 
+ assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_optional_store_id_passes_through(self, client): + captured: dict[str, object] = {} + + async def _fake(db, signal_name, start_date, end_date, store_id): + captured["store_id"] = store_id + return schemas.ExogenousSignalResponse( + signal_name=signal_name, + start_date=start_date, + end_date=end_date, + store_id=store_id, + records=[], + total=0, + ) + + with patch( + "app.features.seeder.routes.service.query_exogenous", + side_effect=_fake, + ): + response = client.get( + "/seeder/exogenous", + params={ + "signal_name": "weather_temp_c", + "start_date": "2024-01-01", + "end_date": "2024-01-02", + "store_id": 7, + }, + ) + assert response.status_code == status.HTTP_200_OK + assert captured["store_id"] == 7 diff --git a/app/features/seeder/tests/test_phase1_service.py b/app/features/seeder/tests/test_phase1_service.py new file mode 100644 index 00000000..51d5bd57 --- /dev/null +++ b/app/features/seeder/tests/test_phase1_service.py @@ -0,0 +1,136 @@ +"""Service-layer tests for Phase 1 seeder features. + +Covers: +- _apply_phase1_overrides / _build_config_from_params translation of new + GenerateParams fields into SeederConfig sub-configs. +- GenerateParams validation (inverted date range). +- query_exogenous date-window guards. 
+""" + +from datetime import date + +import pytest + +from app.features.seeder import schemas, service + + +class TestApplyPhase1Overrides: + def test_defaults_leave_phase1_off(self): + """Calling generate with default params must keep Phase 1 off.""" + params = schemas.GenerateParams() + config = service._build_config_from_params(params) + assert config.exogenous.enable_weather is False + assert config.exogenous.enable_macro is False + assert config.multi_seasonality.yearly_seasonality_amplitude == 0.0 + assert config.changepoints.changepoints == [] + assert config.returns.enable is False + assert config.substitution.enable is False + + def test_enable_exogenous_turns_on_weather_and_macro(self): + params = schemas.GenerateParams( + enable_exogenous=True, + weather_temperature_sensitivity=0.03, + ) + config = service._build_config_from_params(params) + assert config.exogenous.enable_weather is True + assert config.exogenous.enable_macro is True + assert config.exogenous.weather_temperature_sensitivity == 0.03 + + def test_yearly_seasonality_passthrough(self): + params = schemas.GenerateParams(yearly_seasonality_amplitude=0.25) + config = service._build_config_from_params(params) + assert config.multi_seasonality.yearly_seasonality_amplitude == 0.25 + + def test_changepoint_list_translation(self): + params = schemas.GenerateParams( + changepoints=[ + schemas.ChangepointEventParam( + date=date(2024, 3, 15), + demand_multiplier=2.0, + decay_days=60, + ) + ] + ) + config = service._build_config_from_params(params) + assert len(config.changepoints.changepoints) == 1 + cp = config.changepoints.changepoints[0] + assert cp.date == date(2024, 3, 15) + assert cp.demand_multiplier == 2.0 + assert cp.decay_days == 60 + + def test_enable_returns_flips_returns_config(self): + params = schemas.GenerateParams(enable_returns=True) + config = service._build_config_from_params(params) + assert config.returns.enable is True + + def test_enable_substitution_with_groups(self): + params 
= schemas.GenerateParams( + enable_substitution=True, + substitute_groups=[[1, 2, 3], [4, 5]], + substitution_lift_on_stockout=0.4, + ) + config = service._build_config_from_params(params) + assert config.substitution.enable is True + assert config.substitution.substitute_groups == [[1, 2, 3], [4, 5]] + assert config.substitution.substitution_lift_on_stockout == 0.4 + + def test_phase1_overrides_preserve_scenario_dimensions(self): + """A Phase 1 override must not clobber scenario-defined region/brand + lists — regression for the bug fix in service._build_config_from_params. + """ + params = schemas.GenerateParams( + scenario="holiday_rush", + stores=20, + products=80, + enable_returns=True, + ) + config = service._build_config_from_params(params) + assert config.dimensions.stores == 20 + assert config.dimensions.products == 80 + # Holiday rush keeps its 4 holidays + monthly seasonality through + # the phase-1 path. + assert len(config.holidays) == 4 + assert config.time_series.monthly_seasonality[12] == 1.8 + + +class TestGenerateParamsValidation: + def test_rejects_inverted_date_range(self): + with pytest.raises(ValueError, match="must be on or after"): + schemas.GenerateParams( + start_date=date(2024, 12, 31), + end_date=date(2024, 1, 1), + ) + + def test_yearly_amplitude_bounds(self): + # ge=0.0 / le=1.0 enforced by Field. + with pytest.raises(ValueError): + schemas.GenerateParams(yearly_seasonality_amplitude=-0.1) + with pytest.raises(ValueError): + schemas.GenerateParams(yearly_seasonality_amplitude=1.5) + + +class TestQueryExogenousValidation: + """Date-window guards on the service helper. 
The DB path is covered in + integration tests.""" + + @pytest.mark.asyncio + async def test_rejects_inverted_window(self): + with pytest.raises(ValueError, match="must be on or after"): + await service.query_exogenous( + db=None, # type: ignore[arg-type] + signal_name="weather_temp_c", + start_date=date(2024, 12, 31), + end_date=date(2024, 1, 1), + store_id=None, + ) + + @pytest.mark.asyncio + async def test_rejects_overlong_window(self): + with pytest.raises(ValueError, match="too large"): + await service.query_exogenous( + db=None, # type: ignore[arg-type] + signal_name="weather_temp_c", + start_date=date(2020, 1, 1), + end_date=date(2030, 1, 1), + store_id=None, + ) diff --git a/app/features/seeder/tests/test_service.py b/app/features/seeder/tests/test_service.py index cf28d72f..9b476caa 100644 --- a/app/features/seeder/tests/test_service.py +++ b/app/features/seeder/tests/test_service.py @@ -141,6 +141,45 @@ def test_date_range_override(self): assert config.start_date == date(2025, 1, 1) assert config.end_date == date(2025, 6, 30) + def test_custom_scenario_preserves_dimension_customization(self): + """Overriding stores/products on a preset must keep scenario-defined + region/category/brand lists (regression against the older code path that + replaced the whole DimensionConfig and silently dropped them).""" + params = schemas.GenerateParams( + scenario="holiday_rush", + stores=25, + products=200, + ) + config = service._build_config_from_params(params) + + # Counts come from params + assert config.dimensions.stores == 25 + assert config.dimensions.products == 200 + + # Lists come from the SeederConfig defaults (holiday_rush doesn't customize + # them today, but the test asserts the path that preserves them). 
+ assert config.dimensions.store_regions == ["North", "South", "East", "West"] + assert config.dimensions.product_categories == [ + "Beverage", + "Snack", + "Dairy", + "Frozen", + "Produce", + "Bakery", + ] + + def test_custom_scenario_preserves_holiday_list(self): + """Holiday_rush scenario ships 4 holiday entries; overriding store/product + counts must not wipe them.""" + params = schemas.GenerateParams(scenario="holiday_rush", stores=20, products=80) + config = service._build_config_from_params(params) + + assert len(config.holidays) == 4 + holiday_names = {h.name for h in config.holidays} + assert "Thanksgiving" in holiday_names + assert "Black Friday" in holiday_names + assert config.time_series.monthly_seasonality == {10: 1.0, 11: 1.3, 12: 1.8} + class TestGetStatus: """Tests for get_status function.""" @@ -150,10 +189,11 @@ async def test_returns_status(self): """Test status is returned with counts.""" mock_db = AsyncMock() - # Mock the count queries - return different values for each table - mock_results = [10, 50, 365, 182500, 182500, 1500, 500] + # Mock the count queries - return different values for each table. + # Phase 1 adds exogenous_signals (2520) and sales_returns (3650); + # Phase 2 adds replenishment_events (180). 
+ mock_results = [10, 50, 365, 182500, 182500, 1500, 500, 2520, 3650, 180] mock_db.execute.side_effect = [ - # Counts for each table *[MagicMock(scalar=MagicMock(return_value=count)) for count in mock_results], # Date range query MagicMock(fetchone=MagicMock(return_value=(date(2024, 1, 1), date(2024, 12, 31)))), @@ -167,15 +207,18 @@ async def test_returns_status(self): assert status.products == 50 assert status.calendar == 365 assert status.sales == 182500 + assert status.exogenous_signals == 2520 + assert status.sales_returns == 3650 + assert status.replenishment_events == 180 @pytest.mark.asyncio async def test_empty_database(self): """Test status for empty database.""" mock_db = AsyncMock() - # Mock empty counts + # Mock empty counts (10 tables: 7 original + 2 Phase 1 + 1 Phase 2). mock_db.execute.side_effect = [ - *[MagicMock(scalar=MagicMock(return_value=0)) for _ in range(7)], + *[MagicMock(scalar=MagicMock(return_value=0)) for _ in range(10)], ] status = await service.get_status(mock_db) @@ -183,6 +226,9 @@ async def test_empty_database(self): assert status.stores == 0 assert status.products == 0 assert status.sales == 0 + assert status.exogenous_signals == 0 + assert status.sales_returns == 0 + assert status.replenishment_events == 0 assert status.date_range_start is None assert status.date_range_end is None diff --git a/app/shared/seeder/config.py b/app/shared/seeder/config.py index 3f7cd922..a57c0536 100644 --- a/app/shared/seeder/config.py +++ b/app/shared/seeder/config.py @@ -126,6 +126,323 @@ class HolidayConfig: multiplier: float = 1.5 +@dataclass +class ExogenousSignalConfig: + """Configuration for exogenous demand signals (weather, macro, events). + + All signals are disabled by default — enabling them does not change the + sales math unless `weather_temperature_sensitivity` is also non-zero or + a feature consumer reads `exogenous_signal` rows. Default values keep + existing scenarios byte-identical. 
+ + Attributes: + enable_weather: Emit `weather_temp_c` rows per (store, date). + enable_macro: Emit `macro_index` rows per date (random walk). + enable_events: Emit `event_flag` rows per date (binary, sparse). + weather_temperature_sensitivity: Demand delta as a fraction per °C + above/below the climatological mean. 0.0 = no demand impact even + when weather rows are emitted. + weather_climatology_mean_c: Annual mean temperature (°C) used as the + sinusoidal center for weather generation. + weather_amplitude_c: Peak-to-peak amplitude of the seasonal sin wave. + weather_noise_sigma_c: Gaussian noise standard deviation in °C. + macro_indicator_lag_days: How many days a macro signal lags demand by + (consumers may use this; the generator itself emits values daily). + macro_initial_value: Starting value of the random-walk index. + macro_step_sigma: Standard deviation of the daily Gaussian increment. + event_dates: Specific dates marked with `event_flag=1` (e.g. promo + launch days). Empty list = no event rows emitted even when + `enable_events=True`. + """ + + enable_weather: bool = False + enable_macro: bool = False + enable_events: bool = False + weather_temperature_sensitivity: float = 0.0 + weather_climatology_mean_c: float = 15.0 + weather_amplitude_c: float = 12.0 + weather_noise_sigma_c: float = 2.0 + macro_indicator_lag_days: int = 0 + macro_initial_value: float = 100.0 + macro_step_sigma: float = 0.5 + event_dates: list[date] = field(default_factory=list) + + +@dataclass +class MultiSeasonalityConfig: + """Configuration for yearly seasonality on top of weekly + monthly. + + Demand multiplier on day-of-year d is `1 + amplitude * sin(2π·(d + phase)/365)`. + + Attributes: + yearly_seasonality_amplitude: Fraction of base demand swung by the + yearly sin wave (e.g. 0.15 = ±15%). 0.0 disables. + yearly_phase_offset_days: Phase shift in days (positive = later peak). 
+ """ + + yearly_seasonality_amplitude: float = 0.0 + yearly_phase_offset_days: int = 0 + + +@dataclass +class ChangepointEvent: + """A single demand changepoint (COVID-style impulse + exponential decay). + + Demand multiplier on day t for a changepoint at day t0 is: + `1 + (demand_multiplier - 1) * exp(-(t - t0) / decay_days)` + for `t >= t0`; 1.0 otherwise. + + Attributes: + date: Date of the changepoint impulse. + demand_multiplier: Peak multiplier on the changepoint date. + decay_days: e-folding time of the exponential decay. 0 = pure impulse. + """ + + date: date + demand_multiplier: float = 1.0 + decay_days: int = 30 + + +@dataclass +class ChangepointConfig: + """Configuration for trend changepoints. + + Attributes: + changepoints: List of changepoint events. Empty = disabled. + """ + + changepoints: list[ChangepointEvent] = field(default_factory=list) + + +@dataclass +class ReturnsConfig: + """Configuration for synthetic returns volume. + + Attributes: + enable: Whether to emit `sales_returns` rows at all. + return_probability: Probability that a given sale generates a return + (0.0 to 1.0). + return_lag_days_min: Minimum days between sale and return. + return_lag_days_max: Maximum days between sale and return. + return_quantity_fraction: Fraction of the original sale quantity that + is returned (clamped to ≥ 1 unit when a return fires). + return_reason_distribution: Probability-weighted reasons. Weights are + normalized at use time. + """ + + enable: bool = False + return_probability: float = 0.02 + return_lag_days_min: int = 1 + return_lag_days_max: int = 14 + return_quantity_fraction: float = 0.5 + return_reason_distribution: dict[str, float] = field( + default_factory=lambda: { + "defective": 0.25, + "wrong_size": 0.20, + "not_as_described": 0.15, + "changed_mind": 0.30, + "damaged_in_transit": 0.10, + } + ) + + +@dataclass +class SubstitutionConfig: + """Configuration for cross-product substitution on stockout. 
+ + When product A in a substitute group is stocked out at a given store on + a given date, each other group-mate B sees its demand multiplied by + `1 + substitution_lift_on_stockout / (group_size - 1)` for that day. + + Attributes: + enable: Whether substitution is applied. + substitute_groups: Sets of product IDs that substitute for each + other. A product may appear in multiple groups. + substitution_lift_on_stockout: Total demand lift distributed across + group-mates when one member is stocked out (e.g. 0.5 = +50% + split among the others). + """ + + enable: bool = False + substitute_groups: list[list[int]] = field(default_factory=list) + substitution_lift_on_stockout: float = 0.0 + + +SalesChannel = Literal["in_store", "online", "click_collect", "wholesale"] +"""Valid values for ``sales_daily.channel`` — mirrors the SQL CHECK allow-list.""" + +LifecycleStage = Literal["intro", "growth", "maturity", "decline", "discontinued"] +"""Valid values for ``product.lifecycle_stage`` — mirrors the SQL CHECK allow-list.""" + +PromotionKind = Literal["pct_off", "bogo", "bundle", "markdown"] +"""Valid values for ``promotion.kind`` — mirrors the SQL CHECK allow-list.""" + +MarkdownTrigger = Literal["age_days", "stockout_risk", "lifecycle_decline"] +"""How a markdown event fires: stale inventory, projected stockout, or +lifecycle decline.""" + + +@dataclass +class ChannelConfig: + """Configuration for multi-channel sales (Phase 2). + + When ``enable_multichannel=False`` (default), every ``sales_daily`` row + carries ``channel='in_store'`` via the SQL server default, keeping the + regression invariant intact. When enabled, daily demand at a + ``(store, product, date)`` is split across the configured channel mix. + + Attributes: + enable_multichannel: Whether to split demand across channels. + channel_mix: Probability weights per channel name. Keys must be a + subset of ``("in_store", "online", "click_collect", "wholesale")``. 
+ Weights are normalized at use time; sum need not equal 1. + online_promo_uplift: Multiplier applied to the online slice when a + promotion is active (e.g. 1.2 = +20%). + online_substitution_to_instore: Fraction of online-channel demand + that cannibalizes from the in-store channel when both are active + (0.0 = independent; 1.0 = pure substitution). + """ + + enable_multichannel: bool = False + channel_mix: dict[str, float] = field(default_factory=dict) + online_promo_uplift: float = 1.0 + online_substitution_to_instore: float = 0.0 + + +@dataclass +class LifecycleConfig: + """Configuration for product lifecycle stages (Phase 2). + + Disabled by default. When enabled, each product is assigned a + ``launch_date`` (drawn from a distribution within or before + ``start_date``) and optionally a ``discontinue_date``; the + lifecycle multiplier shapes demand over the ramp / steady / decay + curves. + + Attributes: + enable: Whether the lifecycle generator emits stage + dates. + intro_ramp_days: Days from ``launch_date`` to full velocity. + growth_ramp_days: Days the ``growth`` stage lasts. + maturity_steady_days: Days the ``maturity`` stage lasts before + decline begins. + decline_decay_days: e-folding time of demand decay in the + ``decline`` stage. + auto_progression: If True, the current stage is computed from + ``launch_date`` relative to each sales date; if False, the + stage is set once on the product row and held constant. + discontinue_probability: Probability that a given product is + discontinued during the seeded range (assigned a + ``discontinue_date`` after launch). + intro_multiplier: Demand floor at launch day (e.g. 0.1 = 10% of + base for the first day, ramping to 1.0 over ``intro_ramp_days``). + decline_multiplier: Demand floor at end of decline (e.g. 0.0 + means demand decays toward zero in the decline stage). 
+ """ + + enable: bool = False + intro_ramp_days: int = 30 + growth_ramp_days: int = 60 + maturity_steady_days: int = 180 + decline_decay_days: int = 90 + auto_progression: bool = True + discontinue_probability: float = 0.0 + intro_multiplier: float = 0.1 + decline_multiplier: float = 0.0 + + +@dataclass +class BundleConfig: + """Configuration for BOGO/bundle promotion mechanics (Phase 2). + + Disabled by default. When enabled, a fraction of generated promotions + become bundle/BOGO promotions with explicit member product IDs. + + Attributes: + enable: Whether to emit bundle/BOGO promotions. + bundle_probability: Per-promotion probability that the promotion is + a bundle rather than the default ``pct_off``. + bogo_share_within_bundles: Of the bundle-classed promotions, the + fraction that are BOGO (the rest are multi-SKU bundles). + min_bundle_size: Minimum number of member products (>= 2). + max_bundle_size: Maximum number of member products. + bundle_discount_pct_min: Lower bound of the bundle discount. + bundle_discount_pct_max: Upper bound of the bundle discount. + bundle_uplift: Demand lift on each member when a bundle promo is + active (e.g. 1.4 = +40%). + """ + + enable: bool = False + bundle_probability: float = 0.0 + bogo_share_within_bundles: float = 0.5 + min_bundle_size: int = 2 + max_bundle_size: int = 3 + bundle_discount_pct_min: float = 0.10 + bundle_discount_pct_max: float = 0.30 + bundle_uplift: float = 1.4 + + +@dataclass +class MarkdownConfig: + """Configuration for clearance markdowns (Phase 2). + + Markdowns are price-driven clearance events distinct from promo lifts. + Disabled by default. When enabled, eligible products are marked down + according to the chosen trigger, and the markdown is recorded both as + a ``price_history`` drop and a ``promotion`` row with ``kind='markdown'``. + + Attributes: + enable: Whether markdowns fire. 
+ trigger: Criterion: ``age_days`` (inventory older than X days), + ``stockout_risk`` (projected stockout within Y days), or + ``lifecycle_decline`` (product in decline stage). + markdown_depth_pct: Fraction below base price (0.0-1.0). + markdown_min_units_remaining: Required inventory level for the + markdown to fire under ``age_days`` / ``stockout_risk``. + age_days_threshold: Days of stale inventory under ``age_days``. + markdown_demand_lift: Demand multiplier while markdown is active. + markdown_duration_days: How long a markdown lasts. + """ + + enable: bool = False + trigger: MarkdownTrigger = "lifecycle_decline" + markdown_depth_pct: float = 0.30 + markdown_min_units_remaining: int = 5 + age_days_threshold: int = 60 + markdown_demand_lift: float = 1.2 + markdown_duration_days: int = 14 + + +@dataclass +class LeadTimeConfig: + """Configuration for lead-time-driven replenishment (Phase 2). + + Disabled by default. When enabled, ``ReplenishmentGenerator`` emits + ``replenishment_event`` rows that drive inventory and stockout + clustering: orders are placed every ``order_frequency_days`` and + received ``lead_time_days`` later; on-hand inventory between + receipts can drop to zero, producing realistic stockout windows. + + Attributes: + enable: Whether to emit replenishment events. + mean_lead_time_days: Mean of the Normal-distributed lead time. + lead_time_sigma_days: Standard deviation of the lead time. + safety_stock_days: Days of average demand kept as safety stock. + order_frequency_days: How often a new PO is placed per + (store, product). + fill_rate_mean: Mean fraction of ordered units that arrive + (1.0 = always fully shipped). + fill_rate_sigma: Standard deviation of the fill rate. 
+ """ + + enable: bool = False + mean_lead_time_days: int = 7 + lead_time_sigma_days: float = 1.5 + safety_stock_days: int = 3 + order_frequency_days: int = 14 + fill_rate_mean: float = 0.97 + fill_rate_sigma: float = 0.05 + + @dataclass class SeederConfig: """Master configuration for the data seeder. @@ -139,6 +456,16 @@ class SeederConfig: retail: Retail-specific pattern configuration. sparsity: Data sparsity configuration. holidays: List of holiday configurations. + exogenous: Phase 1 exogenous signal generation (disabled by default). + multi_seasonality: Phase 1 yearly seasonality (disabled by default). + changepoints: Phase 1 trend changepoints (empty by default). + returns: Phase 1 returns volume (disabled by default). + substitution: Phase 1 stockout substitution (disabled by default). + channels: Phase 2 multi-channel sales (disabled by default). + lifecycle: Phase 2 product lifecycle (disabled by default). + bundles: Phase 2 BOGO/bundle promotions (disabled by default). + markdowns: Phase 2 clearance markdowns (disabled by default). + lead_time: Phase 2 replenishment lead time (disabled by default). batch_size: Batch size for database inserts. enable_progress: Whether to show progress bars. 
""" @@ -151,6 +478,16 @@ class SeederConfig: retail: RetailPatternConfig = field(default_factory=RetailPatternConfig) sparsity: SparsityConfig = field(default_factory=SparsityConfig) holidays: list[HolidayConfig] = field(default_factory=list) + exogenous: ExogenousSignalConfig = field(default_factory=ExogenousSignalConfig) + multi_seasonality: MultiSeasonalityConfig = field(default_factory=MultiSeasonalityConfig) + changepoints: ChangepointConfig = field(default_factory=ChangepointConfig) + returns: ReturnsConfig = field(default_factory=ReturnsConfig) + substitution: SubstitutionConfig = field(default_factory=SubstitutionConfig) + channels: ChannelConfig = field(default_factory=ChannelConfig) + lifecycle: LifecycleConfig = field(default_factory=LifecycleConfig) + bundles: BundleConfig = field(default_factory=BundleConfig) + markdowns: MarkdownConfig = field(default_factory=MarkdownConfig) + lead_time: LeadTimeConfig = field(default_factory=LeadTimeConfig) batch_size: int = 1000 enable_progress: bool = True diff --git a/app/shared/seeder/core.py b/app/shared/seeder/core.py index 830ac962..655f83e7 100644 --- a/app/shared/seeder/core.py +++ b/app/shared/seeder/core.py @@ -15,22 +15,32 @@ from app.core.logging import get_logger from app.features.data_platform.models import ( Calendar, + ExogenousSignal, InventorySnapshotDaily, PriceHistory, Product, Promotion, + ReplenishmentEvent, SalesDaily, + SalesReturn, Store, ) from app.shared.seeder.generators import ( + BundleGenerator, CalendarGenerator, + ExogenousSignalGenerator, InventorySnapshotGenerator, + LifecycleGenerator, + MarkdownGenerator, PriceHistoryGenerator, ProductGenerator, PromotionGenerator, + ReplenishmentGenerator, + ReturnsGenerator, SalesDailyGenerator, StoreGenerator, ) +from app.shared.seeder.generators.exogenous import WEATHER_SIGNAL_NAME if TYPE_CHECKING: from app.shared.seeder.config import SeederConfig @@ -38,6 +48,38 @@ logger = get_logger(__name__) +# Canonical promotion-row shape — every 
record inserted into the +# ``promotion`` table must carry exactly these keys so the bulk +# ``pg_insert(...).values([...])`` builds a uniform VALUES clause. +# Defaults match the SQL server defaults / CHECK constraint: +# ``kind`` defaults to ``"pct_off"`` (server default) and +# ``bundle_member_product_ids`` is NULL unless ``kind in (bundle, bogo)``. +_PROMOTION_DEFAULTS: dict[str, Any] = { + "kind": "pct_off", + "discount_pct": None, + "discount_amount": None, + "bundle_member_product_ids": None, +} + + +def _normalize_promotion_records(records: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Ensure every promotion record carries the canonical key set. + + ``PromotionGenerator`` emits records without ``kind`` / + ``bundle_member_product_ids``; ``BundleGenerator`` mutates a subset + of those into ``bogo`` / ``bundle`` rows; ``MarkdownGenerator`` + appends ``kind='markdown'`` rows. PostgreSQL multi-row INSERT + requires uniform keys across the batch, so we patch missing keys + with their schema defaults in place. + + Existing values are preserved — ``setdefault`` only fills gaps. + """ + for record in records: + for key, default in _PROMOTION_DEFAULTS.items(): + record.setdefault(key, default) + return records + + @dataclass class SeederResult: """Result of a seeder operation. @@ -50,6 +92,9 @@ class SeederResult: price_history_count: Number of price history records. promotions_count: Number of promotions generated. inventory_count: Number of inventory snapshots. + exogenous_count: Number of exogenous signal records (Phase 1). + returns_count: Number of sales return records (Phase 1). + replenishment_count: Number of replenishment_event records (Phase 2). seed: Random seed used. 
""" @@ -60,6 +105,9 @@ class SeederResult: price_history_count: int = 0 promotions_count: int = 0 inventory_count: int = 0 + exogenous_count: int = 0 + returns_count: int = 0 + replenishment_count: int = 0 seed: int = 42 @@ -119,14 +167,24 @@ async def _batch_insert( async def _generate_dimensions( self, db: AsyncSession, - ) -> tuple[list[int], list[tuple[int, Decimal]], list[date]]: + ) -> tuple[ + list[int], + list[tuple[int, Decimal]], + list[date], + dict[int, tuple[date | None, date | None]], + ]: """Generate and insert dimension tables. Args: db: Async database session. Returns: - Tuple of (store_ids, product_data, dates). + Tuple of ``(store_ids, product_data, dates, + product_lifecycle_data)``. ``product_lifecycle_data`` maps + ``product_id -> (launch_date, discontinue_date)``. When + lifecycle is disabled the dict is still returned but every + value is ``(None, None)`` so the downstream multiplier + short-circuits to 1.0. """ # Generate stores store_gen = StoreGenerator(self.rng, self.config.dimensions) @@ -143,8 +201,15 @@ async def _generate_dimensions( result = await db.execute(select(Store.id)) store_ids = [row[0] for row in result.fetchall()] - # Generate products - product_gen = ProductGenerator(self.rng, self.config.dimensions) + # Generate products. Phase 2: pass lifecycle config + date_range + # when lifecycle is enabled so product rows pick up launch / + # discontinue / stage attributes. Disabled path is byte-identical. 
+ product_gen = ProductGenerator( + self.rng, + self.config.dimensions, + lifecycle_config=self.config.lifecycle, + date_range=(self.config.start_date, self.config.end_date), + ) product_records = product_gen.generate() logger.info( @@ -154,9 +219,23 @@ async def _generate_dimensions( await self._batch_insert(db, Product, product_records) - # Fetch product IDs with base prices - result = await db.execute(select(Product.id, Product.base_price)) - product_data = [(row[0], row[1] or Decimal("9.99")) for row in result.fetchall()] + # Fetch product IDs with base prices + lifecycle dates. Single + # query keeps the row-set consistent (re-querying could race + # with concurrent writers, though seeder is single-tenant). + rows = ( + await db.execute( + select( + Product.id, + Product.base_price, + Product.launch_date, + Product.discontinue_date, + ) + ) + ).fetchall() + product_data = [(row[0], row[1] or Decimal("9.99")) for row in rows] + product_lifecycle_data: dict[int, tuple[date | None, date | None]] = { + row[0]: (row[2], row[3]) for row in rows + } # Generate calendar calendar_gen = CalendarGenerator( @@ -180,7 +259,46 @@ async def _generate_dimensions( dates.append(current) current += timedelta(days=1) - return store_ids, product_data, dates + return store_ids, product_data, dates, product_lifecycle_data + + async def _generate_exogenous( + self, + db: AsyncSession, + store_ids: list[int], + dates: list[date], + ) -> tuple[int, dict[tuple[int, date], float]]: + """Generate exogenous signals (Phase 1). + + Returns: + Tuple of (rows_inserted, weather_lookup) where ``weather_lookup`` + is ``{(store_id, date): temp_c}`` for downstream demand math. + Empty dict if weather is disabled. 
+ """ + exo_gen = ExogenousSignalGenerator(self.rng, self.config.exogenous) + records = exo_gen.generate(dates, store_ids) + + if not records: + return 0, {} + + logger.info("seeder.exogenous.generating", count=len(records)) + inserted = await self._batch_insert(db, ExogenousSignal, records) + + weather_lookup: dict[tuple[int, date], float] = {} + if self.config.exogenous.enable_weather: + for r in records: + if r["signal_name"] != WEATHER_SIGNAL_NAME: + continue + store_id = r["store_id"] + signal_date = r["date"] + value = r["value"] + if ( + isinstance(store_id, int) + and isinstance(signal_date, date) + and isinstance(value, float) + ): + weather_lookup[(store_id, signal_date)] = value + + return inserted, weather_lookup async def _generate_facts( self, @@ -188,7 +306,9 @@ async def _generate_facts( store_ids: list[int], product_data: list[tuple[int, Decimal]], dates: list[date], - ) -> tuple[int, int, int, int]: + weather_lookup: dict[tuple[int, date], float] | None = None, + product_lifecycle_data: dict[int, tuple[date | None, date | None]] | None = None, + ) -> tuple[int, int, int, int, int, int]: """Generate and insert fact tables. Args: @@ -196,9 +316,18 @@ async def _generate_facts( store_ids: List of store IDs. product_data: List of (product_id, base_price) tuples. dates: List of dates. + weather_lookup: Optional ``{(store_id, date): temp_c}`` from the + exogenous generator. Demand picks up weather sensitivity only + when this dict is non-empty AND + ``config.exogenous.weather_temperature_sensitivity`` is non-zero. + product_lifecycle_data: Optional Phase 2 mapping + ``product_id -> (launch_date, discontinue_date)``. Consumed + by ``SalesDailyGenerator``'s lifecycle multiplier and by + ``MarkdownGenerator`` for the ``lifecycle_decline`` trigger. Returns: - Tuple of (sales_count, price_history_count, promotions_count, inventory_count). + Tuple of (sales_count, price_history_count, promotions_count, + inventory_count, returns_count, replenishment_count). 
""" product_ids = [pid for pid, _ in product_data] @@ -211,13 +340,6 @@ async def _generate_facts( self.config.end_date, ) - logger.info( - "seeder.price_history.generating", - count=len(price_records), - ) - - await self._batch_insert(db, PriceHistory, price_records) - # Generate promotions promo_gen = PromotionGenerator( self.rng, @@ -230,12 +352,10 @@ async def _generate_facts( self.config.end_date, ) - logger.info( - "seeder.promotions.generating", - count=len(promo_records), - ) - - await self._batch_insert(db, Promotion, promo_records) + # Phase 2: convert a slice of promotions to bundle/BOGO in place. + # Disabled path is a no-op (zero rng draws, no mutation). + bundle_gen = BundleGenerator(self.rng, self.config.bundles) + bundle_gen.apply(promo_records, product_ids) # Generate inventory snapshots inventory_gen = InventorySnapshotGenerator( @@ -248,20 +368,82 @@ async def _generate_facts( dates, ) + # Phase 2: emit markdown promo rows + price drops. Disabled path + # returns empty containers and consumes zero rng. Built BEFORE + # promotion insert so markdown rows ship in the same batch. + lifecycle_gen = LifecycleGenerator(self.config.lifecycle) + product_specs: list[dict[str, Any]] = [ + { + "product_id": pid, + "base_price": price, + "launch_date": (product_lifecycle_data or {}).get(pid, (None, None))[0], + "discontinue_date": (product_lifecycle_data or {}).get(pid, (None, None))[1], + } + for pid, price in product_data + ] + markdown_gen = MarkdownGenerator(self.rng, self.config.markdowns) + ( + markdown_promo_records, + markdown_price_records, + _markdown_dates, + ) = markdown_gen.generate( + product_specs=product_specs, + store_ids=store_ids, + stockout_dates=stockout_dates, + dates=dates, + lifecycle=lifecycle_gen, + ) + + # Merge markdown outputs into the main lists, then normalize so + # every promotion row carries the same key set (required for + # pg_insert multi-row INSERT). The disabled-path lists are empty + # so the merge is a no-op. 
+ promo_records.extend(markdown_promo_records) + price_records.extend(markdown_price_records) + _normalize_promotion_records(promo_records) + + logger.info( + "seeder.price_history.generating", + count=len(price_records), + ) + await self._batch_insert(db, PriceHistory, price_records) + + logger.info( + "seeder.promotions.generating", + count=len(promo_records), + ) + await self._batch_insert(db, Promotion, promo_records) + logger.info( "seeder.inventory.generating", count=len(inventory_records), ) - await self._batch_insert(db, InventorySnapshotDaily, inventory_records) - # Generate sales (depends on promotions and stockouts) + # Generate sales (depends on promotions and stockouts). Phase 1 + # extensions stay as None / 0 when their config flags are off so the + # disabled-path is byte-identical with pre-Phase-1. Phase 2 + # lifecycle / channels are gated by their own enable flags inside + # the generator. + weather_lookup_for_sales = ( + weather_lookup + if weather_lookup and self.config.exogenous.weather_temperature_sensitivity != 0.0 + else None + ) sales_gen = SalesDailyGenerator( self.rng, self.config.time_series, self.config.retail, self.config.sparsity, self.config.holidays, + multi_seasonality=self.config.multi_seasonality, + changepoints=self.config.changepoints, + substitution=self.config.substitution, + exogenous_weather=weather_lookup_for_sales, + weather_temperature_sensitivity=(self.config.exogenous.weather_temperature_sensitivity), + weather_climatology_mean_c=self.config.exogenous.weather_climatology_mean_c, + lifecycle=lifecycle_gen, + channels=self.config.channels, ) sales_records = sales_gen.generate( store_ids, @@ -269,6 +451,7 @@ async def _generate_facts( dates, promo_dates, stockout_dates, + product_lifecycle_data=product_lifecycle_data, ) logger.info( @@ -278,11 +461,34 @@ async def _generate_facts( await self._batch_insert(db, SalesDaily, sales_records) + # Generate returns (Phase 1) — depends on sales. 
Returns config is + # disabled by default; generator short-circuits to an empty list. + returns_gen = ReturnsGenerator(self.rng, self.config.returns) + returns_records = returns_gen.generate(sales_records, self.config.end_date) + if returns_records: + logger.info("seeder.returns.generating", count=len(returns_records)) + await self._batch_insert(db, SalesReturn, returns_records) + + # Phase 2: emit replenishment_event rows. Disabled path returns + # an empty list and consumes zero rng. + replenishment_gen = ReplenishmentGenerator(self.rng, self.config.lead_time) + replenishment_records = replenishment_gen.generate( + store_ids, + product_ids, + dates, + base_demand=self.config.time_series.base_demand, + ) + if replenishment_records: + logger.info("seeder.replenishment.generating", count=len(replenishment_records)) + await self._batch_insert(db, ReplenishmentEvent, replenishment_records) + return ( len(sales_records), len(price_records), len(promo_records), len(inventory_records), + len(returns_records), + len(replenishment_records), ) async def generate_full(self, db: AsyncSession) -> SeederResult: @@ -307,11 +513,31 @@ async def generate_full(self, db: AsyncSession) -> SeederResult: ) # Generate dimensions first - store_ids, product_data, dates = await self._generate_dimensions(db) + ( + store_ids, + product_data, + dates, + product_lifecycle_data, + ) = await self._generate_dimensions(db) + + # Phase 1: generate exogenous signals (no-op when no signal is enabled). 
+ exogenous_count, weather_lookup = await self._generate_exogenous(db, store_ids, dates) # Generate facts - sales_count, price_count, promo_count, inventory_count = await self._generate_facts( - db, store_ids, product_data, dates + ( + sales_count, + price_count, + promo_count, + inventory_count, + returns_count, + replenishment_count, + ) = await self._generate_facts( + db, + store_ids, + product_data, + dates, + weather_lookup, + product_lifecycle_data=product_lifecycle_data, ) # Commit all changes @@ -325,6 +551,9 @@ async def generate_full(self, db: AsyncSession) -> SeederResult: price_history_count=price_count, promotions_count=promo_count, inventory_count=inventory_count, + exogenous_count=exogenous_count, + returns_count=returns_count, + replenishment_count=replenishment_count, seed=self.config.seed, ) @@ -334,6 +563,9 @@ async def generate_full(self, db: AsyncSession) -> SeederResult: products=result.products_count, calendar_days=result.calendar_days, sales=result.sales_count, + exogenous=result.exogenous_count, + returns=result.returns_count, + replenishment=result.replenishment_count, seed=self.config.seed, ) @@ -372,9 +604,23 @@ async def append_data( if not store_ids: raise ValueError("No stores found. Run --full-new first to create dimensions.") - # Fetch existing product data - result = await db.execute(select(Product.id, Product.base_price)) - product_data = [(row[0], row[1] or Decimal("9.99")) for row in result.fetchall()] + # Fetch existing product data (with lifecycle dates for Phase 2). + # Lifecycle multiplier short-circuits to 1.0 for products with + # NULL launch_date so the disabled path is byte-identical. 
+ rows = ( + await db.execute( + select( + Product.id, + Product.base_price, + Product.launch_date, + Product.discontinue_date, + ) + ) + ).fetchall() + product_data = [(row[0], row[1] or Decimal("9.99")) for row in rows] + product_lifecycle_data: dict[int, tuple[date | None, date | None]] = { + row[0]: (row[2], row[3]) for row in rows + } if not product_data: raise ValueError("No products found. Run --full-new first to create dimensions.") @@ -397,9 +643,24 @@ async def append_data( dates.append(current) current += timedelta(days=1) + # Phase 1: append exogenous signals for the new range (no-op when off). + exogenous_count, weather_lookup = await self._generate_exogenous(db, store_ids, dates) + # Generate facts for new date range - sales_count, price_count, promo_count, inventory_count = await self._generate_facts( - db, store_ids, product_data, dates + ( + sales_count, + price_count, + promo_count, + inventory_count, + returns_count, + replenishment_count, + ) = await self._generate_facts( + db, + store_ids, + product_data, + dates, + weather_lookup, + product_lifecycle_data=product_lifecycle_data, ) await db.commit() @@ -412,6 +673,9 @@ async def append_data( price_history_count=price_count, promotions_count=promo_count, inventory_count=inventory_count, + exogenous_count=exogenous_count, + returns_count=returns_count, + replenishment_count=replenishment_count, seed=self.config.seed, ) @@ -419,6 +683,9 @@ async def append_data( "seeder.append.completed", calendar_days=result_data.calendar_days, sales=result_data.sales_count, + exogenous=result_data.exogenous_count, + returns=result_data.returns_count, + replenishment=result_data.replenishment_count, ) return result_data @@ -441,8 +708,16 @@ async def delete_data( """ counts: dict[str, int] = {} - # Get current counts + # Get current counts. Phase 2 ``replenishment_event`` leads — it + # FKs to store/product/calendar but no other table FKs into it, + # so dropping first removes the leaf safely. 
Phase 1 tables come + # next (sales_returns FKs to product/store, exogenous_signal FKs + # to store/calendar), then the older fact tables. The order keeps + # the dimension/calendar wipe free of FK violations. fact_tables = [ + ("replenishment_event", ReplenishmentEvent), + ("sales_returns", SalesReturn), + ("exogenous_signal", ExogenousSignal), ("sales_daily", SalesDaily), ("inventory_snapshot_daily", InventorySnapshotDaily), ("price_history", PriceHistory), @@ -527,6 +802,9 @@ async def get_current_counts(self, db: AsyncSession) -> dict[str, int]: ("price_history", PriceHistory), ("promotion", Promotion), ("inventory_snapshot_daily", InventorySnapshotDaily), + ("exogenous_signal", ExogenousSignal), + ("sales_returns", SalesReturn), + ("replenishment_event", ReplenishmentEvent), ] counts: dict[str, int] = {} @@ -585,4 +863,67 @@ async def verify_data_integrity(self, db: AsyncSession) -> list[str]: f"Calendar gap detected: expected {expected_days} days, found {actual_days}" ) + # Phase 1: sales_returns must never carry quantity <= 0 (CHECK + # constraint guards this at the DB layer, but a defensive count + # catches drift if a future generator drops the invariant). + neg_return_check = text("SELECT COUNT(*) FROM sales_returns WHERE return_quantity < 1") + result = await db.execute(neg_return_check) + neg_returns = result.scalar() or 0 + if neg_returns > 0: + errors.append(f"Found {neg_returns} sales_returns with non-positive quantity") + + # Phase 1: exogenous_signal global/per-store consistency. 
+ bad_global_check = text( + "SELECT COUNT(*) FROM exogenous_signal " + "WHERE (is_global = true AND store_id IS NOT NULL) " + " OR (is_global = false AND store_id IS NULL)" + ) + result = await db.execute(bad_global_check) + bad_global = result.scalar() or 0 + if bad_global > 0: + errors.append( + f"Found {bad_global} exogenous_signal rows violating " + "is_global / store_id consistency" + ) + + # Phase 2: bundle / BOGO promotions must declare their member + # product IDs. The CHECK constraint enforces this at the DB + # layer; the count below catches generator drift early. + bundle_consistency_check = text( + "SELECT COUNT(*) FROM promotion " + "WHERE kind IN ('bundle', 'bogo') AND bundle_member_product_ids IS NULL" + ) + result = await db.execute(bundle_consistency_check) + bad_bundles = result.scalar() or 0 + if bad_bundles > 0: + errors.append( + f"Found {bad_bundles} bundle/BOGO promotions with NULL bundle_member_product_ids" + ) + + # Phase 2: lifecycle date ordering — discontinue_date must be + # on or after launch_date when both are set. Also caught by the + # ``ck_product_lifecycle_date_order`` CHECK; defensive count for + # generator drift. + bad_lifecycle_check = text( + "SELECT COUNT(*) FROM product " + "WHERE discontinue_date IS NOT NULL AND launch_date IS NOT NULL " + " AND discontinue_date < launch_date" + ) + result = await db.execute(bad_lifecycle_check) + bad_lifecycle = result.scalar() or 0 + if bad_lifecycle > 0: + errors.append(f"Found {bad_lifecycle} products with discontinue_date < launch_date") + + # Phase 2: replenishment fill rate — received_qty must never + # exceed ordered_qty. DB-enforced; defensive count. 
+ bad_fill_check = text( + "SELECT COUNT(*) FROM replenishment_event WHERE received_qty > ordered_qty" + ) + result = await db.execute(bad_fill_check) + bad_fill = result.scalar() or 0 + if bad_fill > 0: + errors.append( + f"Found {bad_fill} replenishment_event rows with received_qty > ordered_qty" + ) + return errors diff --git a/app/shared/seeder/generators/__init__.py b/app/shared/seeder/generators/__init__.py index a8083550..ad1e8bd7 100644 --- a/app/shared/seeder/generators/__init__.py +++ b/app/shared/seeder/generators/__init__.py @@ -1,21 +1,33 @@ -"""Data generators for dimensions and facts.""" - -from app.shared.seeder.generators.calendar import CalendarGenerator -from app.shared.seeder.generators.facts import ( - InventorySnapshotGenerator, - PriceHistoryGenerator, - PromotionGenerator, - SalesDailyGenerator, -) -from app.shared.seeder.generators.product import ProductGenerator -from app.shared.seeder.generators.store import StoreGenerator - -__all__ = [ - "CalendarGenerator", - "InventorySnapshotGenerator", - "PriceHistoryGenerator", - "ProductGenerator", - "PromotionGenerator", - "SalesDailyGenerator", - "StoreGenerator", -] +"""Data generators for dimensions and facts.""" + +from app.shared.seeder.generators.bundles import BundleGenerator +from app.shared.seeder.generators.calendar import CalendarGenerator +from app.shared.seeder.generators.exogenous import ExogenousSignalGenerator +from app.shared.seeder.generators.facts import ( + InventorySnapshotGenerator, + PriceHistoryGenerator, + PromotionGenerator, + SalesDailyGenerator, +) +from app.shared.seeder.generators.lifecycle import LifecycleGenerator +from app.shared.seeder.generators.markdowns import MarkdownGenerator +from app.shared.seeder.generators.product import ProductGenerator +from app.shared.seeder.generators.replenishment import ReplenishmentGenerator +from app.shared.seeder.generators.returns import ReturnsGenerator +from app.shared.seeder.generators.store import StoreGenerator + +__all__ 
= [ + "BundleGenerator", + "CalendarGenerator", + "ExogenousSignalGenerator", + "InventorySnapshotGenerator", + "LifecycleGenerator", + "MarkdownGenerator", + "PriceHistoryGenerator", + "ProductGenerator", + "PromotionGenerator", + "ReplenishmentGenerator", + "ReturnsGenerator", + "SalesDailyGenerator", + "StoreGenerator", +] diff --git a/app/shared/seeder/generators/bundles.py b/app/shared/seeder/generators/bundles.py new file mode 100644 index 00000000..8364da94 --- /dev/null +++ b/app/shared/seeder/generators/bundles.py @@ -0,0 +1,123 @@ +"""Phase 2 bundle/BOGO promotion converter. + +Wraps :class:`PromotionGenerator`'s output: with probability +``BundleConfig.bundle_probability``, an eligible promotion is converted +into a ``kind='bundle'`` or ``kind='bogo'`` row with a list of member +product IDs and a discount drawn from the configured range. When the +feature is disabled the input list is returned untouched and no rng +state is consumed, preserving the byte-identical regression invariant. +""" + +from __future__ import annotations + +import random +from decimal import Decimal +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from app.shared.seeder.config import BundleConfig + + +# ``promotion.discount_pct`` is ``Numeric(5, 4)`` — 4 decimal places. +_DISCOUNT_QUANTIZE = Decimal("0.0001") + + +class BundleGenerator: + """Convert a slice of generated promotions into bundle / BOGO rows. + + Each converted promotion consumes five rng draws in this locked + order — ``random()`` (convert?), ``random()`` (bogo-or-bundle?), + ``randint()`` (n_members), ``sample()`` (members), ``uniform()`` + (discount). Per-promo skips for too-small product pools happen + *before* the first rng draw, so the rng stream is stable across + runs where only the eligible pool shrinks. + """ + + def __init__(self, rng: random.Random, config: BundleConfig | None) -> None: + """Initialize the bundle generator. + + Args: + rng: Random number generator for reproducibility. 
+ config: Phase 2 bundle configuration. When ``None`` or + ``enable=False`` :meth:`apply` is a no-op that touches + neither the promotion list nor the rng. + """ + self.rng = rng + self.config = config + + @property + def enabled(self) -> bool: + return self.config is not None and self.config.enable + + def apply( + self, + promotions: list[dict[str, Any]], + product_pool: list[int], + ) -> list[dict[str, Any]]: + """Convert a fraction of promotions to bundle/BOGO kinds. + + Args: + promotions: List of promotion record dicts as produced by + :class:`PromotionGenerator.generate`. Mutated in place + when the generator is enabled. + product_pool: All product IDs in the seeded scenario. + Bundle members are drawn from this pool excluding the + host product of each promotion. + + Returns: + The same ``promotions`` list reference. Untouched and with + zero rng consumption when the generator is disabled or when + every promo's eligible pool is below ``min_bundle_size``. + + Raises: + ValueError: If the configuration violates the bundle-size + invariants (``min_bundle_size < 2`` or ``max < min``). + """ + if not self.enabled or self.config is None: + return promotions + + cfg = self.config + if cfg.min_bundle_size < 2: + raise ValueError( + f"BundleConfig.min_bundle_size must be >= 2, got {cfg.min_bundle_size}" + ) + if cfg.max_bundle_size < cfg.min_bundle_size: + raise ValueError( + "BundleConfig.max_bundle_size must be >= min_bundle_size " + f"(got min={cfg.min_bundle_size}, max={cfg.max_bundle_size})" + ) + + for record in promotions: + host_product_id = record["product_id"] + eligible_members = [pid for pid in product_pool if pid != host_product_id] + # Best-effort skip when the pool is too small to satisfy + # ``min_bundle_size``. Done before any rng draw so a smaller + # pool doesn't desync the rng stream from a larger run. 
+ if len(eligible_members) < cfg.min_bundle_size: + continue + + if self.rng.random() >= cfg.bundle_probability: + continue + + kind = "bogo" if self.rng.random() < cfg.bogo_share_within_bundles else "bundle" + n_members = self.rng.randint( + cfg.min_bundle_size, + min(cfg.max_bundle_size, len(eligible_members)), + ) + members = self.rng.sample(eligible_members, n_members) + discount = self.rng.uniform( + cfg.bundle_discount_pct_min, + cfg.bundle_discount_pct_max, + ) + + record["kind"] = kind + record["discount_pct"] = Decimal(str(discount)).quantize(_DISCOUNT_QUANTIZE) + # ``ck_promotion_bundle_members_consistency`` allows either + # discount on a bundle/BOGO row, but PromotionGenerator picks + # exactly one of ``discount_pct`` / ``discount_amount`` per + # source row. We always use ``discount_pct`` for bundles, so + # clear any prior amount to keep the row internally tidy. + record["discount_amount"] = None + record["bundle_member_product_ids"] = members + + return promotions diff --git a/app/shared/seeder/generators/exogenous.py b/app/shared/seeder/generators/exogenous.py new file mode 100644 index 00000000..c720ba48 --- /dev/null +++ b/app/shared/seeder/generators/exogenous.py @@ -0,0 +1,146 @@ +"""Exogenous signal generator (weather, macro index, event flags). + +Phase 1 of the seeder realism extension. Produces rows for the +``exogenous_signal`` table. Each enabled signal contributes records; +disabled signals contribute zero rows so callers that don't opt in see +no Phase 1 side effects. + +The output schema matches ``app.features.data_platform.models.ExogenousSignal``: + + {"date", "signal_name", "store_id", "is_global", "value"} + +Reproducibility: this generator uses the seeder's ``random.Random`` instance +(NOT numpy.random) so identical seeds produce identical signal series. 
+""" + +from __future__ import annotations + +import math +import random +from datetime import date +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from app.shared.seeder.config import ExogenousSignalConfig + + +WEATHER_SIGNAL_NAME = "weather_temp_c" +MACRO_SIGNAL_NAME = "macro_index" +EVENT_SIGNAL_NAME = "event_flag" + + +class ExogenousSignalGenerator: + """Generator for exogenous demand signals. + + Produces one row per (signal, date[, store]) for each enabled signal: + + - ``weather_temp_c``: per (store, date). Temperature in °C following a + yearly sin wave with Gaussian noise. ``is_global=False``. + - ``macro_index``: per date. Random walk starting at + ``macro_initial_value``. ``is_global=True``. + - ``event_flag``: per ``event_dates`` entry. Binary 1.0 marker. + ``is_global=True``. + """ + + def __init__(self, rng: random.Random, config: ExogenousSignalConfig) -> None: + """Initialize the generator. + + Args: + rng: Seeded random number generator. + config: Exogenous signal configuration. + """ + self.rng = rng + self.config = config + + def _weather_row( + self, + signal_date: date, + store_id: int, + day_of_year: int, + ) -> dict[str, date | int | bool | str | float | None]: + """Compute one weather sample for (store, date). + + Uses a sinusoidal seasonal cycle around the climatological mean with + peak in mid-July (day-of-year 196) for the northern hemisphere. + """ + # Phase chosen so peak is around day 196 (mid-July): sin peaks at π/2, + # so we want 2π(d - 105)/365 = π/2 → d = 196. 
+ phase_rad = 2.0 * math.pi * (day_of_year - 105) / 365.0 + seasonal = self.config.weather_amplitude_c * math.sin(phase_rad) + noise = self.rng.gauss(0.0, self.config.weather_noise_sigma_c) + value = self.config.weather_climatology_mean_c + seasonal + noise + return { + "date": signal_date, + "signal_name": WEATHER_SIGNAL_NAME, + "store_id": store_id, + "is_global": False, + "value": value, + } + + def _macro_rows( + self, dates: list[date] + ) -> list[dict[str, date | int | bool | str | float | None]]: + """Random-walk macro index, one row per date.""" + records: list[dict[str, date | int | bool | str | float | None]] = [] + value = self.config.macro_initial_value + for d in dates: + value += self.rng.gauss(0.0, self.config.macro_step_sigma) + records.append( + { + "date": d, + "signal_name": MACRO_SIGNAL_NAME, + "store_id": None, + "is_global": True, + "value": value, + } + ) + return records + + def _event_rows( + self, dates: list[date] + ) -> list[dict[str, date | int | bool | str | float | None]]: + """Binary event-flag rows for configured event dates within range.""" + if not self.config.event_dates: + return [] + date_set = set(dates) + return [ + { + "date": event_date, + "signal_name": EVENT_SIGNAL_NAME, + "store_id": None, + "is_global": True, + "value": 1.0, + } + for event_date in self.config.event_dates + if event_date in date_set + ] + + def generate( + self, dates: list[date], store_ids: list[int] + ) -> list[dict[str, date | int | bool | str | float | None]]: + """Generate exogenous signal rows. + + Args: + dates: Dates in the seeded range (sorted ascending). + store_ids: Store IDs for per-store signals. + + Returns: + List of row dicts ready for batch insert. Empty when no signal + is enabled. + """ + records: list[dict[str, date | int | bool | str | float | None]] = [] + + if self.config.enable_weather and store_ids and dates: + # Iterate stores in the outer loop so the rng draws per store + # are deterministic and reproducible. 
+ for store_id in store_ids: + for d in dates: + records.append(self._weather_row(d, store_id, d.timetuple().tm_yday)) + + if self.config.enable_macro and dates: + records.extend(self._macro_rows(dates)) + + if self.config.enable_events: + records.extend(self._event_rows(dates)) + + return records diff --git a/app/shared/seeder/generators/facts.py b/app/shared/seeder/generators/facts.py index 78e0f5eb..30c191fc 100644 --- a/app/shared/seeder/generators/facts.py +++ b/app/shared/seeder/generators/facts.py @@ -10,15 +10,31 @@ if TYPE_CHECKING: from app.shared.seeder.config import ( + ChangepointConfig, + ChannelConfig, HolidayConfig, + MultiSeasonalityConfig, RetailPatternConfig, SparsityConfig, + SubstitutionConfig, TimeSeriesConfig, ) + from app.shared.seeder.generators.lifecycle import LifecycleGenerator + + +_VALID_CHANNELS = frozenset({"in_store", "online", "click_collect", "wholesale"}) +"""Mirrors the SQL CHECK on ``sales_daily.channel`` (see PRP-12 §schema).""" class SalesDailyGenerator: - """Generator for daily sales fact data with realistic time-series patterns.""" + """Generator for daily sales fact data with realistic time-series patterns. + + Phase 1 extensions (``multi_seasonality``, ``changepoints``, + ``substitution``, ``exogenous_weather``) and Phase 2 extensions + (``lifecycle``, ``channels``) are all opt-in. When every opt-in input is None / + disabled, the generator's output is byte-identical to its + pre-Phase-1 behavior.
+ """ def __init__( self, @@ -27,6 +43,14 @@ def __init__( retail_config: RetailPatternConfig, sparsity_config: SparsityConfig, holidays: list[HolidayConfig], + multi_seasonality: MultiSeasonalityConfig | None = None, + changepoints: ChangepointConfig | None = None, + substitution: SubstitutionConfig | None = None, + exogenous_weather: dict[tuple[int, date], float] | None = None, + weather_temperature_sensitivity: float = 0.0, + weather_climatology_mean_c: float = 15.0, + lifecycle: LifecycleGenerator | None = None, + channels: ChannelConfig | None = None, ) -> None: """Initialize the sales generator. @@ -36,12 +60,245 @@ def __init__( retail_config: Retail-specific pattern configuration. sparsity_config: Data sparsity configuration. holidays: List of holiday configurations with multipliers. + multi_seasonality: Optional yearly seasonality configuration. + When None or amplitude=0, no yearly multiplier is applied. + changepoints: Optional list of demand changepoints. When None or + empty, no changepoint multiplier is applied. + substitution: Optional substitution configuration. When None or + disabled, no substitution lift is applied. + exogenous_weather: Optional lookup ``{(store_id, date): temp_c}``. + Each entry shifts demand by + ``weather_temperature_sensitivity * (temp_c - climatology_mean_c)`` + fraction (i.e. linear, centered on the climatology mean). + When None, no weather effect. + weather_temperature_sensitivity: Demand delta per °C above the + climatology mean (used only when ``exogenous_weather`` is set). + weather_climatology_mean_c: Reference temperature for the linear + weather term. + lifecycle: Optional Phase 2 ``LifecycleGenerator``. When set + and ``lifecycle.enabled``, the per-(product, date) demand + multiplier from intro/growth/maturity/decline/discontinued + curves is applied. 
Also supersedes the pre-Phase-2 + ``retail_config.new_product_ramp_days`` linear ramp — + that ramp is suppressed when ``lifecycle.enabled`` so + effects do not stack. + channels: Optional Phase 2 ``ChannelConfig``. When set and + ``enable_multichannel``, each emitted row's ``channel`` + column is drawn from ``channel_mix``. During promos the + effective mix shifts per + ``online_substitution_to_instore`` and online rows pick + up ``online_promo_uplift``. Consumes one rng draw per + emitted row when enabled; zero draws when disabled. """ self.rng = rng self.ts_config = time_series_config self.retail_config = retail_config self.sparsity_config = sparsity_config self.holiday_map = {h.date: h.multiplier for h in holidays} + self.multi_seasonality = multi_seasonality + self.changepoints = changepoints + self.substitution = substitution + self.exogenous_weather = exogenous_weather + self.weather_sensitivity = weather_temperature_sensitivity + self.weather_climatology_mean_c = weather_climatology_mean_c + self.lifecycle = lifecycle + self.channels = channels + + # Pre-compute substitution group memberships for O(1) lookup. + self._substitution_groups_by_product: dict[int, list[list[int]]] = {} + if self.substitution is not None and self.substitution.enable: + for group in self.substitution.substitute_groups: + for product_id in group: + self._substitution_groups_by_product.setdefault(product_id, []).append(group) + + def _yearly_seasonality_multiplier(self, current_date: date) -> float: + """Return the yearly seasonality multiplier for ``current_date``. + + Returns 1.0 when multi-seasonality is unset or amplitude is 0 — that + preserves the pre-Phase-1 output byte-for-byte. 
+ """ + if ( + self.multi_seasonality is None + or self.multi_seasonality.yearly_seasonality_amplitude == 0.0 + ): + return 1.0 + day_of_year = current_date.timetuple().tm_yday + offset = self.multi_seasonality.yearly_phase_offset_days + phase = 2.0 * math.pi * (day_of_year + offset) / 365.0 + return 1.0 + self.multi_seasonality.yearly_seasonality_amplitude * math.sin(phase) + + def _changepoint_multiplier(self, current_date: date) -> float: + """Aggregate multiplier from all changepoints active on ``current_date``. + + Each changepoint contributes ``(multiplier - 1) * exp(-Δ/decay)`` if + ``current_date >= changepoint.date`` and 0 otherwise. The total + multiplier is ``1 + Σ contributions``. + + Returns 1.0 when there are no changepoints — preserving byte-identical + output for callers that don't opt in. + """ + if self.changepoints is None or not self.changepoints.changepoints: + return 1.0 + contribution = 0.0 + for cp in self.changepoints.changepoints: + delta_days = (current_date - cp.date).days + if delta_days < 0: + continue + if cp.decay_days <= 0: + # Pure impulse on the changepoint date only. + if delta_days == 0: + contribution += cp.demand_multiplier - 1.0 + continue + decay = math.exp(-delta_days / cp.decay_days) + contribution += (cp.demand_multiplier - 1.0) * decay + return 1.0 + contribution + + def _weather_multiplier(self, current_date: date, store_id: int) -> float: + """Linear weather effect centered on the climatology mean. + + Returns 1.0 when no weather data is configured. 
+ """ + if self.exogenous_weather is None or self.weather_sensitivity == 0.0: + return 1.0 + temp_c = self.exogenous_weather.get((store_id, current_date)) + if temp_c is None: + return 1.0 + return 1.0 + self.weather_sensitivity * (temp_c - self.weather_climatology_mean_c) + + # ---------------------------------------------------------------- # + # Phase 2 channel helpers + # ---------------------------------------------------------------- # + + def _validate_channels(self) -> None: + """Validate ``ChannelConfig`` at the start of ``generate()``. + + No-op when channels is unset or disabled — keeps the + regression invariant intact (no extra work when off). + """ + if self.channels is None or not self.channels.enable_multichannel: + return + cfg = self.channels + if not cfg.channel_mix: + raise ValueError("ChannelConfig.channel_mix must be non-empty when enabled") + invalid = set(cfg.channel_mix.keys()) - _VALID_CHANNELS + if invalid: + raise ValueError( + f"ChannelConfig.channel_mix contains invalid channels {sorted(invalid)}; " + f"allow-list is {sorted(_VALID_CHANNELS)}" + ) + for name, weight in cfg.channel_mix.items(): + if weight < 0: + raise ValueError(f"ChannelConfig.channel_mix['{name}']={weight} must be >= 0") + if sum(cfg.channel_mix.values()) <= 0: + raise ValueError("ChannelConfig.channel_mix must have at least one positive weight") + if cfg.online_promo_uplift < 0: + raise ValueError( + f"ChannelConfig.online_promo_uplift={cfg.online_promo_uplift} must be >= 0" + ) + if not 0.0 <= cfg.online_substitution_to_instore <= 1.0: + raise ValueError( + "ChannelConfig.online_substitution_to_instore=" + f"{cfg.online_substitution_to_instore} must be in [0, 1]" + ) + + def _effective_channel_mix(self, is_promotion: bool) -> dict[str, float]: + """Build the effective channel mix for a single row. + + Applies ``online_substitution_to_instore`` as a weight shift + from ``in_store`` to ``online`` when a promo is active. Returns + the pristine mix otherwise. 
+ """ + if self.channels is None: + return {} + mix = dict(self.channels.channel_mix) + if not is_promotion: + return mix + if "online" not in mix or "in_store" not in mix: + return mix + sub = self.channels.online_substitution_to_instore + if sub <= 0: + return mix + shift = mix["online"] * sub + mix["online"] += shift + mix["in_store"] = max(0.0, mix["in_store"] - shift) + return mix + + def _maybe_apply_channel( + self, + quantity: int, + is_promotion: bool, + ) -> tuple[int, str | None]: + """Choose a channel and apply per-channel uplift. + + Returns ``(unchanged_quantity, None)`` when the channel feature + is disabled — caller emits no ``channel`` key and the DB + ``server_default='in_store'`` applies, preserving the + byte-identical regression invariant. + + When enabled, draws a channel from the effective mix and + multiplies ``quantity`` by ``online_promo_uplift`` for online + rows on promo dates. Consumes exactly one rng draw per call. + """ + if self.channels is None or not self.channels.enable_multichannel: + return quantity, None + mix = self._effective_channel_mix(is_promotion) + if not mix or sum(mix.values()) <= 0: + return quantity, None + names = list(mix.keys()) + weights = list(mix.values()) + chosen = self.rng.choices(names, weights=weights, k=1)[0] + if chosen == "online" and is_promotion: + quantity = max(0, round(quantity * self.channels.online_promo_uplift)) + return quantity, chosen + + def _substitution_multiplier( + self, + product_id: int, + stockouts_today: set[int], + ) -> float: + """Lift demand for ``product_id`` when stocked-out group-mates exist. + + ``stockouts_today`` is the set of product IDs stocked out on the + current date at the same store. For each substitution group the + product belongs to, we count how many other members are stocked out + and distribute ``substitution_lift_on_stockout`` across the surviving + in-stock members. + + Returns 1.0 when substitution is disabled or no group-mate is out. 
+ """ + if ( + self.substitution is None + or not self.substitution.enable + or self.substitution.substitution_lift_on_stockout == 0.0 + ): + return 1.0 + groups = self._substitution_groups_by_product.get(product_id) + if not groups: + return 1.0 + if product_id in stockouts_today: + return 1.0 # A stocked-out product can't pick up lift. + + contribution = 0.0 + for group in groups: + out_members = sum( + 1 for member in group if member != product_id and member in stockouts_today + ) + survivors = sum( + 1 for member in group if member != product_id and member not in stockouts_today + ) + if out_members == 0 or survivors == 0: + # Either no group-mate is out (no lift), or no other member + # is in stock (everyone but this product is out) — in that + # case give all the lift to this product. + if out_members > 0 and survivors == 0: + contribution += self.substitution.substitution_lift_on_stockout * out_members + continue + # Each out member's lift is split among (survivors + 1) including + # this product, so this product captures one share per out member. + contribution += ( + self.substitution.substitution_lift_on_stockout * out_members / (survivors + 1) + ) + return 1.0 + contribution def _compute_demand( self, @@ -52,6 +309,10 @@ def _compute_demand( is_promotion: bool, is_stockout: bool, product_launch_date: date | None, + store_id: int | None = None, + product_id: int | None = None, + stockouts_today_for_store: set[int] | None = None, + product_discontinue_date: date | None = None, ) -> int: """Compute demand for a single observation. @@ -63,6 +324,12 @@ def _compute_demand( is_promotion: Whether there's an active promotion. is_stockout: Whether there's a stockout. product_launch_date: Optional launch date for new product ramp. + store_id: Store ID (used only by Phase 1 weather + substitution + effects). Required when those features are enabled. + product_id: Product ID (used only by Phase 1 substitution). + Required when substitution is enabled.
+ stockouts_today_for_store: Set of product IDs stocked out at + ``store_id`` on ``current_date``. Used only by substitution. Returns: Computed demand quantity (non-negative integer). @@ -102,8 +369,12 @@ def _compute_demand( price_change_pct = float((current_price - base_price) / base_price) demand *= 1 + (self.retail_config.price_elasticity * price_change_pct) - # Apply new product ramp - if product_launch_date is not None: + # Apply legacy new-product ramp (pre-Phase-2). Suppressed when a + # Phase 2 ``LifecycleGenerator`` is enabled — the lifecycle + # multiplier already encodes a richer launch curve and stacking + # both would double the launch suppression. + legacy_ramp_on = self.lifecycle is None or not self.lifecycle.enabled + if product_launch_date is not None and legacy_ramp_on: days_since_launch = (current_date - product_launch_date).days ramp_days = self.retail_config.new_product_ramp_days if ramp_days > 0 and days_since_launch < ramp_days: @@ -111,6 +382,21 @@ def _compute_demand( demand *= ramp_factor # If ramp_days == 0, skip ramp calculation (demand unchanged) + # Phase 1 multipliers (each returns 1.0 when its feature is off). + demand *= self._yearly_seasonality_multiplier(current_date) + demand *= self._changepoint_multiplier(current_date) + if store_id is not None: + demand *= self._weather_multiplier(current_date, store_id) + if product_id is not None and stockouts_today_for_store is not None: + demand *= self._substitution_multiplier(product_id, stockouts_today_for_store) + + # Phase 2 lifecycle multiplier. Gated on ``lifecycle.enabled`` so + # the disabled path is byte-identical with pre-Phase-2 callers. 
+ if self.lifecycle is not None and self.lifecycle.enabled: + demand *= self.lifecycle.multiplier_for( + current_date, product_launch_date, product_discontinue_date + ) + # Apply noise if self.ts_config.noise_sigma > 0: noise = self.rng.gauss(0, self.ts_config.noise_sigma) @@ -133,6 +419,7 @@ def generate( dates: list[date], promotions: dict[tuple[int, int], set[date]], # (store_id, product_id) -> promo dates stockouts: dict[tuple[int, int], set[date]], # (store_id, product_id) -> stockout dates + product_lifecycle_data: dict[int, tuple[date | None, date | None]] | None = None, ) -> list[dict[str, date | int | Decimal]]: """Generate sales daily records. @@ -142,10 +429,17 @@ def generate( dates: List of dates in the range. promotions: Mapping of (store_id, product_id) to promotion dates. stockouts: Mapping of (store_id, product_id) to stockout dates. + product_lifecycle_data: Optional Phase 2 mapping + ``product_id -> (launch_date, discontinue_date)``. Only + consulted when a ``LifecycleGenerator`` was passed to + :meth:`__init__`. Missing entries fall back to + ``(None, None)`` so the lifecycle multiplier evaluates to + 1.0 for that product. Returns: List of sales dictionaries ready for database insertion. """ + self._validate_channels() sales: list[dict[str, date | int | Decimal]] = [] base_date = dates[0] if dates else date(2024, 1, 1) @@ -182,6 +476,15 @@ def generate( gaps.add(dates[gap_start_idx + i]) gap_dates[key] = gaps + # Phase 1: per-(store, date) lookup of stocked-out product IDs for + # substitution. Only build it when substitution is enabled — keeps + # the disabled-path byte-identical with pre-Phase-1. 
+ stockouts_by_store_date: dict[tuple[int, date], set[int]] = {} + if self.substitution is not None and self.substitution.enable: + for (s_id, p_id), out_dates in stockouts.items(): + for d in out_dates: + stockouts_by_store_date.setdefault((s_id, d), set()).add(p_id) + # Generate sales for each active combination and date for store_id in store_ids: for product_id, base_price in product_data: @@ -194,6 +497,16 @@ def generate( promo_dates = promotions.get(key, set()) stockout_dates = stockouts.get(key, set()) series_gaps = gap_dates.get(key, set()) + # Phase 2 lifecycle lookup — defaults keep the disabled + # path byte-identical (legacy callers pass ``None`` for + # ``product_lifecycle_data``, so this is always ``(None, + # None)`` and the multiplier evaluates to 1.0). + lifecycle_dates = ( + product_lifecycle_data.get(product_id, (None, None)) + if product_lifecycle_data is not None + else (None, None) + ) + launch_date_for_product, discontinue_date_for_product = lifecycle_dates for current_date in dates: # Skip gap dates @@ -203,6 +516,8 @@ def generate( is_promotion = current_date in promo_dates is_stockout = current_date in stockout_dates + stockouts_today = stockouts_by_store_date.get((store_id, current_date)) + quantity = self._compute_demand( current_date=current_date, base_date=base_date, @@ -210,27 +525,36 @@ def generate( current_price=None, # Simplified: use base price is_promotion=is_promotion, is_stockout=is_stockout, - product_launch_date=None, # Could be extended + product_launch_date=launch_date_for_product, + store_id=store_id, + product_id=product_id, + stockouts_today_for_store=stockouts_today, + product_discontinue_date=discontinue_date_for_product, ) - # Skip zero sales from stockouts to reduce data volume + # Skip zero sales from stockouts to reduce data volume. + # Channel rng is drawn only for emitted rows so the + # channel stream is stable per emitted-row across runs. 
if quantity == 0 and is_stockout: continue + quantity, chosen_channel = self._maybe_apply_channel(quantity, is_promotion) + # Calculate total amount unit_price = base_price total_amount = unit_price * quantity - sales.append( - { - "date": current_date, - "store_id": store_id, - "product_id": product_id, - "quantity": quantity, - "unit_price": unit_price, - "total_amount": total_amount, - } - ) + row: dict[str, date | int | Decimal | str] = { + "date": current_date, + "store_id": store_id, + "product_id": product_id, + "quantity": quantity, + "unit_price": unit_price, + "total_amount": total_amount, + } + if chosen_channel is not None: + row["channel"] = chosen_channel + sales.append(row) # type: ignore[arg-type] return sales diff --git a/app/shared/seeder/generators/lifecycle.py b/app/shared/seeder/generators/lifecycle.py new file mode 100644 index 00000000..710fac90 --- /dev/null +++ b/app/shared/seeder/generators/lifecycle.py @@ -0,0 +1,125 @@ +"""Phase 2 product-lifecycle demand multiplier. + +This module is pure compute — no DB writes. ``SalesDailyGenerator`` calls +``LifecycleGenerator.multiplier_for`` once per (product, date) to apply +intro / growth / maturity / decline / discontinued shaping on top of the +base demand math. When the feature is disabled the call returns 1.0 +without consuming any rng state, preserving the byte-identical +regression invariant. +""" + +from __future__ import annotations + +from datetime import date +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from app.shared.seeder.config import LifecycleConfig + + +class LifecycleGenerator: + """Compute per-(product, date) demand multipliers from lifecycle stages. + + The model has four ramp segments stitched together: + + - ``intro`` — linear ramp from ``intro_multiplier`` to 1.0 over + ``intro_ramp_days`` starting at ``launch_date``. + - ``growth`` — held at 1.0 for ``growth_ramp_days``. + - ``maturity`` — held at 1.0 for ``maturity_steady_days``. 
+ - ``decline`` — exponential decay toward ``decline_multiplier`` with + e-folding time ``decline_decay_days``. + + After ``discontinue_date`` the multiplier is forced to 0 regardless of + the curves, modelling a hard end-of-life. + """ + + def __init__(self, config: LifecycleConfig | None) -> None: + """Initialize the lifecycle generator. + + Args: + config: Phase 2 lifecycle configuration. When ``None`` or + ``enable=False`` every call returns 1.0. + """ + self.config = config + + @property + def enabled(self) -> bool: + return self.config is not None and self.config.enable + + def stage_for( + self, + current_date: date, + launch_date: date | None, + discontinue_date: date | None, + ) -> str: + """Return the lifecycle stage label for ``current_date``. + + The label is one of ``intro|growth|maturity|decline|discontinued``. + When the generator is disabled or ``launch_date`` is missing we + return ``"maturity"`` as a neutral default that produces a + multiplier of 1.0 in :meth:`multiplier_for`. + """ + if not self.enabled or launch_date is None or self.config is None: + return "maturity" + if discontinue_date is not None and current_date >= discontinue_date: + return "discontinued" + days_since_launch = (current_date - launch_date).days + if days_since_launch < 0: + # Product hasn't launched yet — treat as discontinued so demand + # collapses to 0 in `multiplier_for`. + return "discontinued" + cfg = self.config + boundary_intro = cfg.intro_ramp_days + boundary_growth = boundary_intro + cfg.growth_ramp_days + boundary_maturity = boundary_growth + cfg.maturity_steady_days + if days_since_launch < boundary_intro: + return "intro" + if days_since_launch < boundary_growth: + return "growth" + if days_since_launch < boundary_maturity: + return "maturity" + return "decline" + + def multiplier_for( + self, + current_date: date, + launch_date: date | None, + discontinue_date: date | None, + ) -> float: + """Return the lifecycle demand multiplier for ``current_date``. 
+ + Returns 1.0 when the generator is disabled or the product has no + launch date — preserving pre-Phase-2 output byte-for-byte for + callers that don't opt in. + """ + if not self.enabled or launch_date is None or self.config is None: + return 1.0 + cfg = self.config + # Hard end-of-life after discontinue. + if discontinue_date is not None and current_date >= discontinue_date: + return 0.0 + days_since_launch = (current_date - launch_date).days + if days_since_launch < 0: + return 0.0 # Not yet launched. + boundary_intro = cfg.intro_ramp_days + boundary_growth = boundary_intro + cfg.growth_ramp_days + boundary_maturity = boundary_growth + cfg.maturity_steady_days + + if days_since_launch < boundary_intro: + # Linear ramp intro_multiplier -> 1.0. + if cfg.intro_ramp_days <= 0: + return 1.0 + t = days_since_launch / cfg.intro_ramp_days + return cfg.intro_multiplier + (1.0 - cfg.intro_multiplier) * t + if days_since_launch < boundary_maturity: + # Growth + maturity held at 1.0. + return 1.0 + # Decline: exponential decay from 1.0 toward decline_multiplier. + days_in_decline = days_since_launch - boundary_maturity + if cfg.decline_decay_days <= 0: + return cfg.decline_multiplier + # m(t) = decline + (1 - decline) * exp(-t / tau) + import math # local import keeps top-of-module surface lean + + decay = math.exp(-days_in_decline / cfg.decline_decay_days) + return cfg.decline_multiplier + (1.0 - cfg.decline_multiplier) * decay diff --git a/app/shared/seeder/generators/markdowns.py b/app/shared/seeder/generators/markdowns.py new file mode 100644 index 00000000..62414e79 --- /dev/null +++ b/app/shared/seeder/generators/markdowns.py @@ -0,0 +1,345 @@ +"""Phase 2 markdown (clearance) generator. + +Emits ``Promotion(kind='markdown')`` rows + companion ``PriceHistory`` +drop rows for two trigger modes: + +- ``lifecycle_decline`` — fires chain-wide (``store_id=None``) on the + first day a product enters the decline stage according to a + ``LifecycleGenerator``. 
+- ``stockout_risk`` — fires per-``(store, product)`` ending the day + before each observed stockout, with a window of + ``markdown_duration_days``. + +The ``age_days`` trigger is deferred to a follow-up; see issue #94. +``MarkdownGenerator`` raises ``NotImplementedError`` for that mode. + +Disabled path (``MarkdownConfig`` is ``None`` or ``enable=False``) +returns empty containers and consumes zero rng state, preserving the +byte-identical regression invariant. The generator is currently +deterministic: even the enabled path issues no rng draws. The ``rng`` +parameter is kept for API consistency with peer Phase 2 generators +in case future variants need randomness. +""" + +from __future__ import annotations + +import random +from datetime import date, timedelta +from decimal import Decimal +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from app.shared.seeder.config import MarkdownConfig + from app.shared.seeder.generators.lifecycle import LifecycleGenerator + + +# ``promotion.discount_pct`` is ``Numeric(5, 4)``; ``price_history.price`` +# is ``Numeric(10, 2)``. +_PCT_QUANTIZE = Decimal("0.0001") +_PRICE_QUANTIZE = Decimal("0.01") + + +class MarkdownGenerator: + """Emit markdown promo + price-history rows. + + The orchestration layer (DataSeeder) is responsible for wiring this + generator's output to ``SalesDailyGenerator``'s ``markdown_dates`` + lookup so demand picks up the configured ``markdown_demand_lift`` + over markdown windows. + """ + + def __init__(self, rng: random.Random, config: MarkdownConfig | None) -> None: + """Initialize the markdown generator. + + Args: + rng: Random number generator. Reserved for future variants; + the current implementation is fully deterministic. + config: Phase 2 markdown configuration. When ``None`` or + ``enable=False`` :meth:`generate` returns empty + containers. 
+ """ + self.rng = rng + self.config = config + + @property + def enabled(self) -> bool: + return self.config is not None and self.config.enable + + def generate( + self, + product_specs: list[dict[str, Any]], + store_ids: list[int], + stockout_dates: dict[tuple[int, int], set[date]], + dates: list[date], + lifecycle: LifecycleGenerator | None = None, + ) -> tuple[ + list[dict[str, Any]], + list[dict[str, Any]], + dict[tuple[int, int], set[date]], + ]: + """Generate markdown promotions, price drops, and markdown_dates lookup. + + Args: + product_specs: List of ``{"product_id", "base_price", + "launch_date" (optional), "discontinue_date" (optional)}`` + dicts. Lifecycle dates are only needed for + ``trigger='lifecycle_decline'``. + store_ids: All store IDs in the seeded scenario. Used to + populate ``markdown_dates`` for chain-wide markdowns so + downstream ``SalesDailyGenerator`` can apply the demand + lift uniformly across the chain. + stockout_dates: ``(store_id, product_id) -> set[date]`` + from ``InventorySnapshotGenerator``. Used only for + ``trigger='stockout_risk'``. + dates: Full ordered list of dates in the seeded range. + ``markdown_start`` is clamped to ``dates[0]`` when the + computed start would precede the range. + lifecycle: Optional pre-built ``LifecycleGenerator``. Used + only for ``trigger='lifecycle_decline'``. When absent + or disabled, the lifecycle trigger emits no rows. + + Returns: + Three-tuple: + - ``promo_records``: ``Promotion(kind='markdown')`` dicts. + - ``price_history_records``: ``PriceHistory`` rows + carrying the discounted price over the markdown window. + - ``markdown_dates``: ``(store_id, product_id) -> set[date]`` + lookup of every active markdown day, useful for the + ``SalesDailyGenerator`` lift integration. + + Raises: + NotImplementedError: If ``config.trigger == 'age_days'``. + Tracked at issue #94. + ValueError: If ``markdown_depth_pct`` is outside ``[0, 1]`` + or ``markdown_duration_days < 1``. 
+ """ + if not self.enabled or self.config is None: + return ([], [], {}) + + cfg = self.config + if cfg.trigger == "age_days": + raise NotImplementedError( + "MarkdownConfig.trigger='age_days' is deferred. See follow-up " + "issue #94 for the implementation plan." + ) + if not 0.0 <= cfg.markdown_depth_pct <= 1.0: + raise ValueError(f"markdown_depth_pct must be in [0, 1], got {cfg.markdown_depth_pct}") + if cfg.markdown_duration_days < 1: + raise ValueError( + f"markdown_duration_days must be >= 1, got {cfg.markdown_duration_days}" + ) + + promo_records: list[dict[str, Any]] = [] + price_history_records: list[dict[str, Any]] = [] + markdown_dates: dict[tuple[int, int], set[date]] = {} + + if cfg.trigger == "lifecycle_decline": + self._emit_lifecycle_decline( + cfg=cfg, + product_specs=product_specs, + store_ids=store_ids, + dates=dates, + lifecycle=lifecycle, + promo_records=promo_records, + price_history_records=price_history_records, + markdown_dates=markdown_dates, + ) + else: # cfg.trigger == "stockout_risk" + self._emit_stockout_risk( + cfg=cfg, + product_specs=product_specs, + stockout_dates=stockout_dates, + dates=dates, + promo_records=promo_records, + price_history_records=price_history_records, + markdown_dates=markdown_dates, + ) + + return (promo_records, price_history_records, markdown_dates) + + # ---------------------------------------------------------------------- # + # Trigger implementations + # ---------------------------------------------------------------------- # + + def _emit_lifecycle_decline( + self, + *, + cfg: MarkdownConfig, + product_specs: list[dict[str, Any]], + store_ids: list[int], + dates: list[date], + lifecycle: LifecycleGenerator | None, + promo_records: list[dict[str, Any]], + price_history_records: list[dict[str, Any]], + markdown_dates: dict[tuple[int, int], set[date]], + ) -> None: + if lifecycle is None or not lifecycle.enabled or not dates: + return # Cannot detect decline without a lifecycle source. 
+ + discount_pct = Decimal(str(cfg.markdown_depth_pct)).quantize(_PCT_QUANTIZE) + + for spec in product_specs: + launch = spec.get("launch_date") + if launch is None: + continue + discontinue = spec.get("discontinue_date") + decline_start = self._first_decline_date( + dates=dates, + launch=launch, + discontinue=discontinue, + lifecycle=lifecycle, + ) + if decline_start is None: + continue + + md_end = min( + decline_start + timedelta(days=cfg.markdown_duration_days - 1), + dates[-1], + ) + base_price = self._as_decimal(spec["base_price"]) + markdown_price = (base_price * (Decimal("1") - discount_pct)).quantize(_PRICE_QUANTIZE) + product_id = int(spec["product_id"]) + + promo_records.append( + { + "product_id": product_id, + "store_id": None, # chain-wide + "name": "Lifecycle Clearance", + "kind": "markdown", + "discount_pct": discount_pct, + "discount_amount": None, + "bundle_member_product_ids": None, + "start_date": decline_start, + "end_date": md_end, + } + ) + price_history_records.append( + { + "product_id": product_id, + "store_id": None, + "price": markdown_price, + "valid_from": decline_start, + "valid_to": md_end, + } + ) + # Chain-wide markdown: every store sees the lift. + for sid in store_ids: + self._fill_date_range( + markdown_dates.setdefault((sid, product_id), set()), + decline_start, + md_end, + ) + + def _emit_stockout_risk( + self, + *, + cfg: MarkdownConfig, + product_specs: list[dict[str, Any]], + stockout_dates: dict[tuple[int, int], set[date]], + dates: list[date], + promo_records: list[dict[str, Any]], + price_history_records: list[dict[str, Any]], + markdown_dates: dict[tuple[int, int], set[date]], + ) -> None: + if not dates: + return + + discount_pct = Decimal(str(cfg.markdown_depth_pct)).quantize(_PCT_QUANTIZE) + first_date = dates[0] + # Precompute base_price lookup for O(1) access. 
+ price_by_product: dict[int, Decimal] = { + int(spec["product_id"]): self._as_decimal(spec["base_price"]) for spec in product_specs + } + + # Sort keys for deterministic output order regardless of dict + # iteration order. + for key in sorted(stockout_dates.keys()): + store_id, product_id = key + base_price = price_by_product.get(product_id) + if base_price is None: + continue + + markdown_price = (base_price * (Decimal("1") - discount_pct)).quantize(_PRICE_QUANTIZE) + last_md_end: date | None = None + + for stockout_date in sorted(stockout_dates[key]): + md_end = stockout_date - timedelta(days=1) + md_start = md_end - timedelta(days=cfg.markdown_duration_days - 1) + if md_start < first_date: + md_start = first_date + if md_end < md_start: + continue # stockout on/before first date; no markdown room + # Dedupe overlapping windows: keep the window already emitted and + # skip any later one starting on/before its end; sorted() iterates forward. + if last_md_end is not None and md_start <= last_md_end: + continue + + promo_records.append( + { + "product_id": product_id, + "store_id": store_id, + "name": "Stockout Clearance", + "kind": "markdown", + "discount_pct": discount_pct, + "discount_amount": None, + "bundle_member_product_ids": None, + "start_date": md_start, + "end_date": md_end, + } + ) + price_history_records.append( + { + "product_id": product_id, + "store_id": store_id, + "price": markdown_price, + "valid_from": md_start, + "valid_to": md_end, + } + ) + self._fill_date_range( + markdown_dates.setdefault(key, set()), + md_start, + md_end, + ) + last_md_end = md_end + + # ---------------------------------------------------------------------- # + # Helpers + # ---------------------------------------------------------------------- # + + @staticmethod + def _first_decline_date( + *, + dates: list[date], + launch: date, + discontinue: date | None, + lifecycle: LifecycleGenerator, + ) -> date | None: + """Return the earliest date in ``dates`` where the product is in decline.""" + for d
in dates: + if lifecycle.stage_for(d, launch, discontinue) == "decline": + return d + return None + + @staticmethod + def _fill_date_range( + bucket: set[date], + start: date, + end: date, + ) -> None: + current = start + while current <= end: + bucket.add(current) + current += timedelta(days=1) + + @staticmethod + def _as_decimal(value: Decimal | int | float | str) -> Decimal: + """Coerce numeric input to ``Decimal``. + + Product specs may carry ``Decimal``, ``int``, ``float`` or numeric + strings depending on upstream provenance. Coercion through ``str`` + avoids binary-float artefacts. + """ + if isinstance(value, Decimal): + return value + return Decimal(str(value)) diff --git a/app/shared/seeder/generators/product.py b/app/shared/seeder/generators/product.py index 84477729..2352a21b 100644 --- a/app/shared/seeder/generators/product.py +++ b/app/shared/seeder/generators/product.py @@ -3,11 +3,12 @@ from __future__ import annotations import random +from datetime import date, timedelta from decimal import Decimal -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from app.shared.seeder.config import DimensionConfig + from app.shared.seeder.config import DimensionConfig, LifecycleConfig # Product name components for realistic generation @@ -120,24 +121,46 @@ class ProductGenerator: - """Generator for product dimension data.""" + """Generator for product dimension data. + + Phase 2 ``lifecycle`` is opt-in: when ``LifecycleConfig.enable=False`` + (default) the generator emits byte-identical output to its + pre-Phase-2 behavior — no extra rng draws, no extra dict keys. + """ # Maximum SKU space: 10000-99999 = 90,000 unique SKUs MAX_SKU_SPACE = 90000 MAX_SKU_ATTEMPTS = 1000 - def __init__(self, rng: random.Random, config: DimensionConfig) -> None: + # Discrete pack-size distribution sampled when lifecycle is enabled. 
+ _PACK_SIZE_CHOICES = (1, 1, 1, 2, 4, 6, 12) + + def __init__( + self, + rng: random.Random, + config: DimensionConfig, + lifecycle_config: LifecycleConfig | None = None, + date_range: tuple[date, date] | None = None, + ) -> None: """Initialize the product generator. Args: rng: Random number generator for reproducibility. config: Dimension configuration. + lifecycle_config: Optional Phase 2 lifecycle configuration. When + ``None`` or ``enable=False`` the generator emits pre-Phase-2 + rows byte-identically. + date_range: ``(start_date, end_date)`` of the seeded range. Used + only when ``lifecycle_config.enable`` is True so launch dates + land near or before ``start_date``. Raises: ValueError: If requested products exceed available SKU space. """ self.rng = rng self.config = config + self.lifecycle_config = lifecycle_config + self.date_range = date_range self._used_skus: set[str] = set() # Validate SKU space capacity @@ -219,20 +242,92 @@ def _generate_price(self) -> tuple[Decimal, Decimal]: return base_price, base_cost - def generate(self) -> list[dict[str, str | Decimal | None]]: + def _generate_lifecycle_attrs(self, category: str) -> dict[str, Any]: + """Sample lifecycle attributes for a single product. + + Returns the additive dict appended to the product row when + ``lifecycle_config.enable`` is True. The order of rng calls inside + this method must remain stable across Phase 2 sub-releases — any + rearrangement would shift downstream rng state and break + reproducibility for Phase-2-enabled scenarios. + """ + cfg = self.lifecycle_config + if cfg is None or not cfg.enable or self.date_range is None: + raise RuntimeError( + "_generate_lifecycle_attrs called with lifecycle disabled — " + "the caller must gate on lifecycle_config.enable." + ) + start, end = self.date_range + + # Sub-category: re-use the category->nouns map so subcategory ties to + # category coherently. Keeps the seeded corpus believable. 
+ nouns = PRODUCT_NOUNS_BY_CATEGORY.get(category, DEFAULT_NOUNS) + subcategory = self.rng.choice(nouns) + + # Pack size: discrete distribution that favours single units. + pack_size = self.rng.choice(self._PACK_SIZE_CHOICES) + + # Launch date: uniform between (start - 90d) and (start + 60d). + # Products with launch_date < start are already mature by start; those + # within [start, start+60d] are in intro/growth during the seeded + # window. + launch_offset_min = -90 + launch_offset_max = 60 + launch_offset_days = self.rng.randint(launch_offset_min, launch_offset_max) + launch_date = start + timedelta(days=launch_offset_days) + + # Discontinue date: small probability the product is retired during + # the range. Always after launch_date. + discontinue_date: date | None = None + roll = self.rng.random() + if cfg.discontinue_probability > 0.0 and roll < cfg.discontinue_probability: + min_disc_offset = max( + cfg.intro_ramp_days + cfg.growth_ramp_days + 30, + 30, + ) + disc_offset_days = self.rng.randint(min_disc_offset, max(min_disc_offset, 365)) + candidate = launch_date + timedelta(days=disc_offset_days) + if candidate <= end: + discontinue_date = candidate + + # Initial stage: pick from the allow-list. When auto_progression is + # True, the stage on the row is the *initial* stage at launch_date + # and downstream code re-derives the current stage by date. + if cfg.auto_progression: + lifecycle_stage = "intro" + else: + lifecycle_stage = self.rng.choice(("intro", "growth", "maturity", "decline")) + + return { + "subcategory": subcategory, + "pack_size": pack_size, + "launch_date": launch_date, + "discontinue_date": discontinue_date, + "lifecycle_stage": lifecycle_stage, + } + + def generate(self) -> list[dict[str, Any]]: """Generate product dimension records. Returns: - List of product dictionaries ready for database insertion. + List of product dictionaries ready for database insertion. 
When + ``lifecycle_config.enable`` is True each row also carries + ``subcategory``, ``pack_size``, ``launch_date``, + ``discontinue_date``, and ``lifecycle_stage``. """ - products: list[dict[str, str | Decimal | None]] = [] + products: list[dict[str, Any]] = [] + lifecycle_on = ( + self.lifecycle_config is not None + and self.lifecycle_config.enable + and self.date_range is not None + ) for _ in range(self.config.products): category = self.rng.choice(self.config.product_categories) brand = self.rng.choice(self.config.product_brands) base_price, base_cost = self._generate_price() - product: dict[str, str | Decimal | None] = { + product: dict[str, Any] = { "sku": self._generate_unique_sku(), "name": self._generate_name(category, brand), "category": category, @@ -240,6 +335,8 @@ def generate(self) -> list[dict[str, str | Decimal | None]]: "base_price": base_price, "base_cost": base_cost, } + if lifecycle_on: + product.update(self._generate_lifecycle_attrs(category)) products.append(product) return products diff --git a/app/shared/seeder/generators/replenishment.py b/app/shared/seeder/generators/replenishment.py new file mode 100644 index 00000000..bb3e66c3 --- /dev/null +++ b/app/shared/seeder/generators/replenishment.py @@ -0,0 +1,172 @@ +"""Phase 2 lead-time-driven replenishment generator. + +Emits ``replenishment_event`` rows that mark each receipt of inbound +stock at a store. Per ``(store, product)`` pair a purchase order is +placed every ``order_frequency_days`` starting from ``dates[0]``. Each +PO has a sampled lead time (Gaussian, clamped to ``>= 0``) and a +sampled fill rate (Gaussian, clamped to ``[0, 1]``); +``date_received = date_placed + lead_time_days`` and ``received_qty = +round(ordered_qty * fill_rate)``. + +Receipts whose computed ``date_received`` would fall past +``dates[-1]`` are dropped, keeping the FK to ``calendar`` valid. 
+ +Downstream coupling: a follow-up commit will adjust +``InventorySnapshotGenerator`` to consume these events so the +realistic stockout windows emerge between scheduled receipts. This +slice only emits the rows. + +Disabled path (``LeadTimeConfig`` is ``None`` or ``enable=False``) +returns ``[]`` and consumes zero rng state, preserving the +byte-identical regression invariant. +""" + +from __future__ import annotations + +import random +from datetime import date, timedelta +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from app.shared.seeder.config import LeadTimeConfig + + +class ReplenishmentGenerator: + """Generate replenishment events per ``(store, product)`` PO chain. + + The rng order per PO is locked: ``gauss`` for lead time, then + ``gauss`` for fill rate. Pairs are visited in sorted + ``(store_id, product_id)`` order so the rng stream is stable + regardless of input ordering. + """ + + def __init__(self, rng: random.Random, config: LeadTimeConfig | None) -> None: + """Initialize the replenishment generator. + + Args: + rng: Random number generator for reproducibility. + config: Phase 2 lead-time configuration. When ``None`` or + ``enable=False`` :meth:`generate` returns ``[]`` and + touches no rng state. + """ + self.rng = rng + self.config = config + + @property + def enabled(self) -> bool: + return self.config is not None and self.config.enable + + def generate( + self, + store_ids: list[int], + product_ids: list[int], + dates: list[date], + base_demand: int = 100, + ) -> list[dict[str, Any]]: + """Emit ``replenishment_event`` records. + + Args: + store_ids: All store IDs in the seeded scenario. + product_ids: All product IDs in the seeded scenario. + dates: Ordered list of seeded dates. Used as the calendar + domain for ``date_received``; receipts past + ``dates[-1]`` are skipped to keep the FK to + ``calendar`` valid. + base_demand: Daily demand assumption used to size + ``ordered_qty``. 
Defaults to 100 so the generator can + stand alone in tests; the orchestration layer should + pass ``TimeSeriesConfig.base_demand``. + + Returns: + List of ``replenishment_event`` dicts with keys ``date``, + ``store_id``, ``product_id``, ``lead_time_days``, + ``ordered_qty``, ``received_qty``. Emitted in + ``(store_id, product_id, date_placed)`` order. + + Raises: + ValueError: On invalid config (negative mean, zero order + frequency, fill_rate_mean outside ``[0, 1]``, etc.). + """ + if not self.enabled or self.config is None: + return [] + + cfg = self.config + self._validate(cfg, base_demand) + + if not dates: + return [] + + start = dates[0] + end = dates[-1] + records: list[dict[str, Any]] = [] + po_window_days = cfg.order_frequency_days + cfg.safety_stock_days + ordered_qty = max(0, base_demand * po_window_days) + + # Sort to make the rng stream stable regardless of caller order. + for store_id in sorted(store_ids): + for product_id in sorted(product_ids): + self._generate_chain( + cfg=cfg, + store_id=store_id, + product_id=product_id, + start=start, + end=end, + ordered_qty=ordered_qty, + records=records, + ) + return records + + def _generate_chain( + self, + *, + cfg: LeadTimeConfig, + store_id: int, + product_id: int, + start: date, + end: date, + ordered_qty: int, + records: list[dict[str, Any]], + ) -> None: + """Emit PO chain for a single ``(store, product)``.""" + date_placed = start + while date_placed <= end: + lead_time_days = max( + 0, + round(self.rng.gauss(cfg.mean_lead_time_days, cfg.lead_time_sigma_days)), + ) + fill_rate = self.rng.gauss(cfg.fill_rate_mean, cfg.fill_rate_sigma) + fill_rate = min(1.0, max(0.0, fill_rate)) + date_received = date_placed + timedelta(days=lead_time_days) + if date_received <= end: + received_qty = round(ordered_qty * fill_rate) + # Defensive clamp — protects ``ck_replenishment_event_*`` even + # under pathological fill-rate samples that round to > 1. 
+ received_qty = max(0, min(received_qty, ordered_qty)) + records.append( + { + "date": date_received, + "store_id": store_id, + "product_id": product_id, + "lead_time_days": lead_time_days, + "ordered_qty": ordered_qty, + "received_qty": received_qty, + } + ) + date_placed += timedelta(days=cfg.order_frequency_days) + + @staticmethod + def _validate(cfg: LeadTimeConfig, base_demand: int) -> None: + if cfg.mean_lead_time_days < 0: + raise ValueError(f"mean_lead_time_days must be >= 0, got {cfg.mean_lead_time_days}") + if cfg.lead_time_sigma_days < 0: + raise ValueError(f"lead_time_sigma_days must be >= 0, got {cfg.lead_time_sigma_days}") + if cfg.safety_stock_days < 0: + raise ValueError(f"safety_stock_days must be >= 0, got {cfg.safety_stock_days}") + if cfg.order_frequency_days < 1: + raise ValueError(f"order_frequency_days must be >= 1, got {cfg.order_frequency_days}") + if not 0.0 <= cfg.fill_rate_mean <= 1.0: + raise ValueError(f"fill_rate_mean must be in [0, 1], got {cfg.fill_rate_mean}") + if cfg.fill_rate_sigma < 0: + raise ValueError(f"fill_rate_sigma must be >= 0, got {cfg.fill_rate_sigma}") + if base_demand < 0: + raise ValueError(f"base_demand must be >= 0, got {base_demand}") diff --git a/app/shared/seeder/generators/returns.py b/app/shared/seeder/generators/returns.py new file mode 100644 index 00000000..96432d3a --- /dev/null +++ b/app/shared/seeder/generators/returns.py @@ -0,0 +1,121 @@ +"""Returns generator: synthetic ``sales_returns`` rows. + +Phase 1 of the seeder realism extension. Samples sales rows +(probabilistically) and emits a delayed return event for each pick. The +return is *not* subtracted from ``sales_daily.quantity`` — returns are an +additive, separately queryable table so the forecasting/feature layer can +opt in. + +Output schema matches ``app.features.data_platform.models.SalesReturn``: + + {"date", "store_id", "product_id", "return_quantity", "return_reason"} + +Reproducibility: uses the seeder's ``random.Random`` instance. 
+""" + +from __future__ import annotations + +import random +from datetime import date, timedelta +from decimal import Decimal +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from app.shared.seeder.config import ReturnsConfig + + +class ReturnsGenerator: + """Generator for synthetic sales returns.""" + + def __init__(self, rng: random.Random, config: ReturnsConfig) -> None: + """Initialize the generator. + + Args: + rng: Seeded random number generator. + config: Returns configuration. + """ + self.rng = rng + self.config = config + + def _pick_reason(self) -> str: + """Sample a return reason from the configured distribution. + + Returns: + Reason string. Falls back to ``"unspecified"`` if the + distribution is empty (defensive — config defaults are non-empty). + """ + reasons = list(self.config.return_reason_distribution.keys()) + weights = list(self.config.return_reason_distribution.values()) + if not reasons: + return "unspecified" + # random.choices is deterministic under self.rng. + return self.rng.choices(reasons, weights=weights, k=1)[0] + + def generate( + self, + sales_records: list[dict[str, date | int | Decimal]], + end_date: date, + ) -> list[dict[str, date | int | str]]: + """Generate return rows from a list of sales rows. + + Args: + sales_records: Sales dicts from ``SalesDailyGenerator.generate``. + Each must contain ``date``, ``store_id``, ``product_id``, + ``quantity``. + end_date: Calendar end date. Returns lagged beyond ``end_date`` + are clamped to ``end_date`` (so they have a calendar FK + target and don't trigger FK violations). + + Returns: + List of return-row dicts. Empty when the returns feature is + disabled or no sales qualify. 
+ """ + if not self.config.enable or not sales_records: + return [] + + lag_min = self.config.return_lag_days_min + lag_max = max(self.config.return_lag_days_max, lag_min) + + returns: list[dict[str, date | int | str]] = [] + for sale in sales_records: + quantity = sale["quantity"] + sale_date = sale["date"] + store_id = sale["store_id"] + product_id = sale["product_id"] + # Sales rows from SalesDailyGenerator carry these types; the + # union annotation is wider than the runtime guarantees because + # the same dict shape is reused for inserts. Defensive narrowing + # here keeps mypy --strict happy without a cast. + if not ( + isinstance(quantity, int) + and isinstance(sale_date, date) + and isinstance(store_id, int) + and isinstance(product_id, int) + ): + continue + if quantity <= 0: + continue + if self.rng.random() >= self.config.return_probability: + continue + + lag = self.rng.randint(lag_min, lag_max) + return_date = sale_date + timedelta(days=lag) + if return_date > end_date: + return_date = end_date + + # Fraction of original quantity, with a minimum of 1 unit. + raw_qty = quantity * self.config.return_quantity_fraction + return_qty = max(1, round(raw_qty)) + return_qty = min(return_qty, quantity) + + returns.append( + { + "date": return_date, + "store_id": store_id, + "product_id": product_id, + "return_quantity": return_qty, + "return_reason": self._pick_reason(), + } + ) + + return returns diff --git a/app/shared/seeder/tests/test_core.py b/app/shared/seeder/tests/test_core.py index 036db889..944b69a0 100644 --- a/app/shared/seeder/tests/test_core.py +++ b/app/shared/seeder/tests/test_core.py @@ -233,12 +233,16 @@ def seeder(self): async def test_returns_empty_list_when_valid(self, seeder): """Test empty list returned when data is valid.""" mock_db = AsyncMock() - # Create separate mock results for each execute call - # verify_data_integrity makes 4 calls: + # verify_data_integrity now makes 9 execute calls: # 1. orphan check # 2. negative qty check # 3. 
min/max date check # 4. calendar count + # 5. (Phase 1) sales_returns non-positive check + # 6. (Phase 1) exogenous_signal is_global/store_id consistency + # 7. (Phase 2) bundle/BOGO bundle_member_product_ids non-NULL + # 8. (Phase 2) discontinue_date >= launch_date + # 9. (Phase 2) replenishment received_qty <= ordered_qty mock_result1 = MagicMock() mock_result1.scalar.return_value = 0 # no orphans mock_result2 = MagicMock() @@ -247,8 +251,28 @@ async def test_returns_empty_list_when_valid(self, seeder): mock_result3.fetchone.return_value = (date(2024, 1, 1), date(2024, 1, 31)) mock_result4 = MagicMock() mock_result4.scalar.return_value = 31 # 31 days matches Jan 1-31 - - mock_db.execute.side_effect = [mock_result1, mock_result2, mock_result3, mock_result4] + mock_result5 = MagicMock() + mock_result5.scalar.return_value = 0 # no bad returns + mock_result6 = MagicMock() + mock_result6.scalar.return_value = 0 # no inconsistent exogenous rows + mock_result7 = MagicMock() + mock_result7.scalar.return_value = 0 # no bad bundles + mock_result8 = MagicMock() + mock_result8.scalar.return_value = 0 # no bad lifecycle dates + mock_result9 = MagicMock() + mock_result9.scalar.return_value = 0 # no bad replenishment fills + + mock_db.execute.side_effect = [ + mock_result1, + mock_result2, + mock_result3, + mock_result4, + mock_result5, + mock_result6, + mock_result7, + mock_result8, + mock_result9, + ] errors = await seeder.verify_data_integrity(mock_db) @@ -258,7 +282,6 @@ async def test_returns_empty_list_when_valid(self, seeder): async def test_detects_orphaned_sales(self, seeder): """Test orphaned sales are detected.""" mock_db = AsyncMock() - # Create separate mock results for each execute call mock_result1 = MagicMock() mock_result1.scalar.return_value = 5 # orphan check returns 5 errors mock_result2 = MagicMock() @@ -267,8 +290,28 @@ async def test_detects_orphaned_sales(self, seeder): mock_result3.fetchone.return_value = (date(2024, 1, 1), date(2024, 1, 31)) 
mock_result4 = MagicMock() mock_result4.scalar.return_value = 31 # calendar count - - mock_db.execute.side_effect = [mock_result1, mock_result2, mock_result3, mock_result4] + mock_result5 = MagicMock() + mock_result5.scalar.return_value = 0 # bad returns + mock_result6 = MagicMock() + mock_result6.scalar.return_value = 0 # inconsistent exogenous + mock_result7 = MagicMock() + mock_result7.scalar.return_value = 0 # bad bundles + mock_result8 = MagicMock() + mock_result8.scalar.return_value = 0 # bad lifecycle dates + mock_result9 = MagicMock() + mock_result9.scalar.return_value = 0 # bad replenishment fills + + mock_db.execute.side_effect = [ + mock_result1, + mock_result2, + mock_result3, + mock_result4, + mock_result5, + mock_result6, + mock_result7, + mock_result8, + mock_result9, + ] errors = await seeder.verify_data_integrity(mock_db) @@ -278,7 +321,6 @@ async def test_detects_orphaned_sales(self, seeder): async def test_detects_negative_quantities(self, seeder): """Test negative quantities are detected.""" mock_db = AsyncMock() - # Create separate mock results for each execute call mock_result1 = MagicMock() mock_result1.scalar.return_value = 0 # orphan check mock_result2 = MagicMock() @@ -287,8 +329,28 @@ async def test_detects_negative_quantities(self, seeder): mock_result3.fetchone.return_value = (date(2024, 1, 1), date(2024, 1, 31)) mock_result4 = MagicMock() mock_result4.scalar.return_value = 31 # calendar count - - mock_db.execute.side_effect = [mock_result1, mock_result2, mock_result3, mock_result4] + mock_result5 = MagicMock() + mock_result5.scalar.return_value = 0 # bad returns + mock_result6 = MagicMock() + mock_result6.scalar.return_value = 0 # inconsistent exogenous + mock_result7 = MagicMock() + mock_result7.scalar.return_value = 0 # bad bundles + mock_result8 = MagicMock() + mock_result8.scalar.return_value = 0 # bad lifecycle dates + mock_result9 = MagicMock() + mock_result9.scalar.return_value = 0 # bad replenishment fills + + 
mock_db.execute.side_effect = [ + mock_result1, + mock_result2, + mock_result3, + mock_result4, + mock_result5, + mock_result6, + mock_result7, + mock_result8, + mock_result9, + ] errors = await seeder.verify_data_integrity(mock_db) diff --git a/app/shared/seeder/tests/test_exogenous.py b/app/shared/seeder/tests/test_exogenous.py new file mode 100644 index 00000000..4cb2831a --- /dev/null +++ b/app/shared/seeder/tests/test_exogenous.py @@ -0,0 +1,125 @@ +"""Tests for ExogenousSignalGenerator (Phase 1).""" + +# mypy: disable-error-code="union-attr,arg-type,operator,return-value" +# Generator dicts have a wide union; tests narrow at access time. + +import math +import random +from datetime import date, timedelta + +from app.shared.seeder.config import ExogenousSignalConfig +from app.shared.seeder.generators.exogenous import ( + EVENT_SIGNAL_NAME, + MACRO_SIGNAL_NAME, + WEATHER_SIGNAL_NAME, + ExogenousSignalGenerator, +) + + +def _date_range(start: date, days: int) -> list[date]: + return [start + timedelta(days=i) for i in range(days)] + + +class TestExogenousSignalGeneratorDisabled: + def test_all_disabled_produces_no_rows(self): + gen = ExogenousSignalGenerator(random.Random(42), ExogenousSignalConfig()) + rows = gen.generate(_date_range(date(2024, 1, 1), 5), [1, 2]) + assert rows == [] + + +class TestWeather: + def test_weather_emits_row_per_store_and_date(self): + cfg = ExogenousSignalConfig(enable_weather=True, weather_noise_sigma_c=0.0) + gen = ExogenousSignalGenerator(random.Random(42), cfg) + store_ids = [1, 2, 3] + dates = _date_range(date(2024, 1, 1), 7) + rows = gen.generate(dates, store_ids) + weather_rows = [r for r in rows if r["signal_name"] == WEATHER_SIGNAL_NAME] + assert len(weather_rows) == len(store_ids) * len(dates) + # Sanity: each row is per-store (is_global=False), store_id non-null. 
+ for r in weather_rows: + assert r["is_global"] is False + assert r["store_id"] in store_ids + assert isinstance(r["value"], float) + + def test_weather_seasonal_peak_in_summer(self): + # With zero noise the value should follow the deterministic sin wave. + cfg = ExogenousSignalConfig( + enable_weather=True, + weather_amplitude_c=10.0, + weather_climatology_mean_c=15.0, + weather_noise_sigma_c=0.0, + ) + gen = ExogenousSignalGenerator(random.Random(0), cfg) + # July 14 = doy 196 → peak + # January 14 = doy 14 → near trough + rows = gen.generate([date(2024, 7, 14), date(2024, 1, 14)], [1]) + by_date = {r["date"]: r["value"] for r in rows} + # Peak is roughly mean + amplitude; trough roughly mean - amplitude. + assert by_date[date(2024, 7, 14)] > by_date[date(2024, 1, 14)] + assert abs(by_date[date(2024, 7, 14)] - 25.0) < 0.5 + assert by_date[date(2024, 1, 14)] < 10.0 + + def test_weather_reproducible(self): + cfg = ExogenousSignalConfig(enable_weather=True) + gen1 = ExogenousSignalGenerator(random.Random(7), cfg) + gen2 = ExogenousSignalGenerator(random.Random(7), cfg) + dates = _date_range(date(2024, 1, 1), 30) + assert gen1.generate(dates, [1, 2]) == gen2.generate(dates, [1, 2]) + + +class TestMacroIndex: + def test_macro_row_per_date(self): + cfg = ExogenousSignalConfig(enable_macro=True) + gen = ExogenousSignalGenerator(random.Random(42), cfg) + dates = _date_range(date(2024, 6, 1), 10) + rows = gen.generate(dates, []) + macro = [r for r in rows if r["signal_name"] == MACRO_SIGNAL_NAME] + assert len(macro) == len(dates) + for r in macro: + assert r["is_global"] is True + assert r["store_id"] is None + + def test_macro_random_walk_changes_value(self): + cfg = ExogenousSignalConfig( + enable_macro=True, macro_initial_value=100.0, macro_step_sigma=1.0 + ) + gen = ExogenousSignalGenerator(random.Random(1), cfg) + dates = _date_range(date(2024, 1, 1), 30) + rows = [r for r in gen.generate(dates, []) if r["signal_name"] == MACRO_SIGNAL_NAME] + values = [r["value"] 
for r in rows] + # The first value already has one rng step applied so it's not + # exactly 100; just confirm the walk produces variation. + assert len({round(v, 6) for v in values}) > 1 + assert abs(values[0] - 100.0) < 5.0 # one small step + + def test_zero_step_sigma_yields_constant(self): + cfg = ExogenousSignalConfig( + enable_macro=True, macro_initial_value=42.0, macro_step_sigma=0.0 + ) + gen = ExogenousSignalGenerator(random.Random(99), cfg) + rows = gen.generate(_date_range(date(2024, 1, 1), 5), []) + macro_values = [r["value"] for r in rows if r["signal_name"] == MACRO_SIGNAL_NAME] + assert all(math.isclose(v, 42.0) for v in macro_values) + + +class TestEvents: + def test_events_only_within_range(self): + cfg = ExogenousSignalConfig( + enable_events=True, + event_dates=[date(2024, 1, 3), date(2025, 6, 1)], + ) + gen = ExogenousSignalGenerator(random.Random(0), cfg) + rows = gen.generate(_date_range(date(2024, 1, 1), 31), []) + events = [r for r in rows if r["signal_name"] == EVENT_SIGNAL_NAME] + # 2024-01-03 is in range; 2025-06-01 is not. 
+ assert len(events) == 1 + assert events[0]["date"] == date(2024, 1, 3) + assert events[0]["value"] == 1.0 + assert events[0]["is_global"] is True + + def test_events_disabled_emits_nothing(self): + cfg = ExogenousSignalConfig(enable_events=False, event_dates=[date(2024, 1, 3)]) + gen = ExogenousSignalGenerator(random.Random(0), cfg) + rows = gen.generate(_date_range(date(2024, 1, 1), 31), []) + assert all(r["signal_name"] != EVENT_SIGNAL_NAME for r in rows) diff --git a/app/shared/seeder/tests/test_integration.py b/app/shared/seeder/tests/test_integration.py index facb43f5..784ba246 100644 --- a/app/shared/seeder/tests/test_integration.py +++ b/app/shared/seeder/tests/test_integration.py @@ -20,11 +20,13 @@ from app.core.config import get_settings from app.features.data_platform.models import ( Calendar, + ExogenousSignal, InventorySnapshotDaily, PriceHistory, Product, Promotion, SalesDaily, + SalesReturn, Store, ) from app.shared.seeder import DataSeeder, SeederConfig @@ -76,7 +78,10 @@ async def db_session() -> AsyncGenerator[AsyncSession, None]: # Pre-test cleanup for proper isolation async with session_maker() as cleanup_session: try: - # Delete in FK order (facts before dimensions) + # Delete in FK order (facts before dimensions). Phase 1 tables + # come first because they FK to store/product/calendar. + await cleanup_session.execute(delete(SalesReturn)) + await cleanup_session.execute(delete(ExogenousSignal)) await cleanup_session.execute(delete(SalesDaily)) await cleanup_session.execute(delete(InventorySnapshotDaily)) await cleanup_session.execute(delete(PriceHistory)) @@ -102,7 +107,10 @@ async def db_session() -> AsyncGenerator[AsyncSession, None]: # Post-test cleanup async with session_maker() as cleanup_session: try: - # Delete in FK order (facts before dimensions) + # Delete in FK order (facts before dimensions). Phase 1 tables + # come first because they FK to store/product/calendar. 
+ await cleanup_session.execute(delete(SalesReturn)) + await cleanup_session.execute(delete(ExogenousSignal)) await cleanup_session.execute(delete(SalesDaily)) await cleanup_session.execute(delete(InventorySnapshotDaily)) await cleanup_session.execute(delete(PriceHistory)) diff --git a/app/shared/seeder/tests/test_phase1_config.py b/app/shared/seeder/tests/test_phase1_config.py new file mode 100644 index 00000000..b82123b2 --- /dev/null +++ b/app/shared/seeder/tests/test_phase1_config.py @@ -0,0 +1,103 @@ +"""Tests for Phase 1 seeder configuration dataclasses. + +Covers ExogenousSignalConfig, MultiSeasonalityConfig, ChangepointEvent / +ChangepointConfig, ReturnsConfig, SubstitutionConfig — and confirms the +SeederConfig defaults wire them in with disabled / empty defaults. +""" + +from datetime import date + +from app.shared.seeder.config import ( + ChangepointConfig, + ChangepointEvent, + ExogenousSignalConfig, + MultiSeasonalityConfig, + ReturnsConfig, + ScenarioPreset, + SeederConfig, + SubstitutionConfig, +) + + +class TestExogenousSignalConfig: + def test_defaults_disabled(self): + config = ExogenousSignalConfig() + assert config.enable_weather is False + assert config.enable_macro is False + assert config.enable_events is False + assert config.weather_temperature_sensitivity == 0.0 + assert config.event_dates == [] + + def test_event_dates_is_independent(self): + # Default-factory list must not be shared between instances. 
+ a = ExogenousSignalConfig() + b = ExogenousSignalConfig() + a.event_dates.append(date(2024, 1, 1)) + assert b.event_dates == [] + + +class TestMultiSeasonalityConfig: + def test_defaults_zero(self): + config = MultiSeasonalityConfig() + assert config.yearly_seasonality_amplitude == 0.0 + assert config.yearly_phase_offset_days == 0 + + +class TestChangepointConfig: + def test_default_empty(self): + assert ChangepointConfig().changepoints == [] + + def test_event_fields(self): + event = ChangepointEvent(date=date(2024, 3, 15), demand_multiplier=2.5, decay_days=60) + assert event.date == date(2024, 3, 15) + assert event.demand_multiplier == 2.5 + assert event.decay_days == 60 + + +class TestReturnsConfig: + def test_defaults_disabled(self): + cfg = ReturnsConfig() + assert cfg.enable is False + assert 0.0 <= cfg.return_probability <= 1.0 + assert cfg.return_lag_days_min <= cfg.return_lag_days_max + # Reason distribution default must be non-empty so _pick_reason + # always returns a real reason without falling back. + assert sum(cfg.return_reason_distribution.values()) > 0 + + +class TestSubstitutionConfig: + def test_defaults_disabled(self): + cfg = SubstitutionConfig() + assert cfg.enable is False + assert cfg.substitute_groups == [] + assert cfg.substitution_lift_on_stockout == 0.0 + + +class TestSeederConfigPhase1Wiring: + def test_phase1_defaults_present_and_disabled(self): + cfg = SeederConfig() + # Each Phase 1 sub-config must be present with disabled defaults + # so existing scenarios are byte-identical when not opted in. 
+ assert isinstance(cfg.exogenous, ExogenousSignalConfig) + assert isinstance(cfg.multi_seasonality, MultiSeasonalityConfig) + assert isinstance(cfg.changepoints, ChangepointConfig) + assert isinstance(cfg.returns, ReturnsConfig) + assert isinstance(cfg.substitution, SubstitutionConfig) + assert cfg.exogenous.enable_weather is False + assert cfg.multi_seasonality.yearly_seasonality_amplitude == 0.0 + assert cfg.changepoints.changepoints == [] + assert cfg.returns.enable is False + assert cfg.substitution.enable is False + + def test_from_scenario_does_not_enable_phase1(self): + # Existing scenarios must keep Phase 1 off — this is the + # regression invariant that protects pre-Phase-1 outputs. + for scenario in ScenarioPreset: + cfg = SeederConfig.from_scenario(scenario) + assert cfg.exogenous.enable_weather is False, f"{scenario} unexpectedly enables weather" + assert cfg.exogenous.enable_macro is False + assert cfg.exogenous.enable_events is False + assert cfg.multi_seasonality.yearly_seasonality_amplitude == 0.0 + assert cfg.changepoints.changepoints == [] + assert cfg.returns.enable is False + assert cfg.substitution.enable is False diff --git a/app/shared/seeder/tests/test_phase1_integration.py b/app/shared/seeder/tests/test_phase1_integration.py new file mode 100644 index 00000000..c5751127 --- /dev/null +++ b/app/shared/seeder/tests/test_phase1_integration.py @@ -0,0 +1,334 @@ +"""Phase 1 integration tests against real Postgres. + +Run with: uv run pytest app/shared/seeder/tests/test_phase1_integration.py -v -m integration +Requires docker-compose Postgres up and migrations applied. 
+""" + +# mypy: disable-error-code="union-attr,arg-type,operator,return-value" + +import os +from collections.abc import AsyncGenerator +from contextlib import suppress +from datetime import date, timedelta + +import pytest +import pytest_asyncio +from sqlalchemy import delete, func, select +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from app.core.config import get_settings +from app.features.data_platform.models import ( + Calendar, + ExogenousSignal, + InventorySnapshotDaily, + PriceHistory, + Product, + Promotion, + SalesDaily, + SalesReturn, + Store, +) +from app.features.seeder import schemas, service +from app.shared.seeder import DataSeeder, SeederConfig +from app.shared.seeder.config import ( + ChangepointConfig, + ChangepointEvent, + DimensionConfig, + ExogenousSignalConfig, + MultiSeasonalityConfig, + ReturnsConfig, +) + +pytestmark = pytest.mark.integration + + +def _check_destructive_test_guard() -> None: + settings = get_settings() + is_testing = getattr(settings, "testing", False) + app_env_testing = os.environ.get("APP_ENV", "").lower() == "testing" + allow_destructive = os.environ.get("ALLOW_DESTRUCTIVE_TEST_DB", "").lower() == "true" + if not is_testing and not app_env_testing and not allow_destructive: + raise RuntimeError( + "Destructive test operations require explicit opt-in. 
" + "Set ALLOW_DESTRUCTIVE_TEST_DB=true, APP_ENV=testing, or settings.testing=True" + ) + + +@pytest_asyncio.fixture(scope="function") +async def db_session() -> AsyncGenerator[AsyncSession, None]: + _check_destructive_test_guard() + settings = get_settings() + engine = create_async_engine(settings.database_url, echo=False) + session_maker = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + async with session_maker() as cleanup_session: + try: + await cleanup_session.execute(delete(SalesReturn)) + await cleanup_session.execute(delete(ExogenousSignal)) + await cleanup_session.execute(delete(SalesDaily)) + await cleanup_session.execute(delete(InventorySnapshotDaily)) + await cleanup_session.execute(delete(PriceHistory)) + await cleanup_session.execute(delete(Promotion)) + await cleanup_session.execute(delete(Calendar)) + await cleanup_session.execute(delete(Product)) + await cleanup_session.execute(delete(Store)) + await cleanup_session.commit() + except Exception: + await cleanup_session.rollback() + + async with session_maker() as session: + try: + yield session + finally: + with suppress(Exception): + await session.rollback() + + _check_destructive_test_guard() + + async with session_maker() as cleanup_session: + try: + await cleanup_session.execute(delete(SalesReturn)) + await cleanup_session.execute(delete(ExogenousSignal)) + await cleanup_session.execute(delete(SalesDaily)) + await cleanup_session.execute(delete(InventorySnapshotDaily)) + await cleanup_session.execute(delete(PriceHistory)) + await cleanup_session.execute(delete(Promotion)) + await cleanup_session.execute(delete(Calendar)) + await cleanup_session.execute(delete(Product)) + await cleanup_session.execute(delete(Store)) + await cleanup_session.commit() + except Exception: + await cleanup_session.rollback() + + await engine.dispose() + + +class TestPhase1Disabled: + @pytest.mark.asyncio + async def test_default_run_creates_no_phase1_rows(self, db_session: AsyncSession) -> 
None: + """With Phase 1 fully off, exogenous_signal and sales_returns stay empty.""" + config = SeederConfig( + seed=42, + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 7), + dimensions=DimensionConfig(stores=2, products=3), + ) + result = await DataSeeder(config).generate_full(db_session) + assert result.exogenous_count == 0 + assert result.returns_count == 0 + + exo_count = ( + await db_session.execute(select(func.count()).select_from(ExogenousSignal)) + ).scalar() or 0 + ret_count = ( + await db_session.execute(select(func.count()).select_from(SalesReturn)) + ).scalar() or 0 + assert exo_count == 0 + assert ret_count == 0 + + +class TestPhase1Enabled: + @pytest.mark.asyncio + async def test_exogenous_weather_and_macro_persisted(self, db_session: AsyncSession) -> None: + config = SeederConfig( + seed=42, + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 7), # 7 days + dimensions=DimensionConfig(stores=2, products=2), + exogenous=ExogenousSignalConfig( + enable_weather=True, + enable_macro=True, + ), + ) + result = await DataSeeder(config).generate_full(db_session) + # 2 stores x 7 dates weather + 7 dates macro = 21 rows. 
+ assert result.exogenous_count == 21 + + weather_rows = ( + await db_session.execute( + select(func.count()) + .select_from(ExogenousSignal) + .where(ExogenousSignal.signal_name == "weather_temp_c") + ) + ).scalar() or 0 + macro_rows = ( + await db_session.execute( + select(func.count()) + .select_from(ExogenousSignal) + .where(ExogenousSignal.signal_name == "macro_index") + ) + ).scalar() or 0 + assert weather_rows == 14 + assert macro_rows == 7 + + @pytest.mark.asyncio + async def test_returns_table_populated(self, db_session: AsyncSession) -> None: + config = SeederConfig( + seed=42, + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 31), + dimensions=DimensionConfig(stores=2, products=3), + returns=ReturnsConfig(enable=True, return_probability=0.2), + ) + result = await DataSeeder(config).generate_full(db_session) + assert result.returns_count > 0 + # Quantity invariant + bad = ( + await db_session.execute( + select(func.count()).select_from(SalesReturn).where(SalesReturn.return_quantity < 1) + ) + ).scalar() or 0 + assert bad == 0 + + @pytest.mark.asyncio + async def test_changepoint_lifts_demand_at_date(self, db_session: AsyncSession) -> None: + """A 5x changepoint on day 0 with no decay should produce strictly + higher total demand than the baseline run.""" + # Baseline (no changepoint). + base_config = SeederConfig( + seed=42, + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 14), + dimensions=DimensionConfig(stores=2, products=2), + ) + await DataSeeder(base_config).generate_full(db_session) + baseline_total = ( + await db_session.execute( + select(func.sum(SalesDaily.quantity)).where(SalesDaily.date == date(2024, 1, 1)) + ) + ).scalar() or 0 + + # Reset and re-run with a changepoint. 
+ await db_session.execute(delete(SalesReturn)) + await db_session.execute(delete(ExogenousSignal)) + await db_session.execute(delete(SalesDaily)) + await db_session.execute(delete(InventorySnapshotDaily)) + await db_session.execute(delete(PriceHistory)) + await db_session.execute(delete(Promotion)) + await db_session.execute(delete(Calendar)) + await db_session.execute(delete(Product)) + await db_session.execute(delete(Store)) + await db_session.commit() + + cp_config = SeederConfig( + seed=42, + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 14), + dimensions=DimensionConfig(stores=2, products=2), + changepoints=ChangepointConfig( + changepoints=[ + ChangepointEvent( + date=date(2024, 1, 1), + demand_multiplier=5.0, + decay_days=0, + ) + ] + ), + ) + await DataSeeder(cp_config).generate_full(db_session) + cp_total = ( + await db_session.execute( + select(func.sum(SalesDaily.quantity)).where(SalesDaily.date == date(2024, 1, 1)) + ) + ).scalar() or 0 + assert cp_total > baseline_total * 2 # well above the 5x lift floor + + @pytest.mark.asyncio + async def test_verify_integrity_clean_with_phase1(self, db_session: AsyncSession) -> None: + config = SeederConfig( + seed=42, + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 7), + dimensions=DimensionConfig(stores=2, products=2), + exogenous=ExogenousSignalConfig(enable_weather=True), + returns=ReturnsConfig(enable=True, return_probability=0.5), + multi_seasonality=MultiSeasonalityConfig(yearly_seasonality_amplitude=0.1), + ) + seeder = DataSeeder(config) + await seeder.generate_full(db_session) + errors = await seeder.verify_data_integrity(db_session) + assert errors == [] + + +class TestQueryExogenousService: + @pytest.mark.asyncio + async def test_query_returns_persisted_weather(self, db_session: AsyncSession) -> None: + config = SeederConfig( + seed=42, + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 7), + dimensions=DimensionConfig(stores=2, products=2), + 
exogenous=ExogenousSignalConfig(enable_weather=True), + ) + await DataSeeder(config).generate_full(db_session) + + # Need to commit DataSeeder's writes? DataSeeder.generate_full already + # commits. The fixture's expire_on_commit=False keeps objects valid. + + resp = await service.query_exogenous( + db_session, + signal_name="weather_temp_c", + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 2), + store_id=None, + ) + assert isinstance(resp, schemas.ExogenousSignalResponse) + # 2 stores x 2 dates = 4 weather rows in this window. + assert resp.total == 4 + for r in resp.records: + assert r.signal_name == "weather_temp_c" + assert r.is_global is False + + @pytest.mark.asyncio + async def test_query_filter_by_store(self, db_session: AsyncSession) -> None: + config = SeederConfig( + seed=42, + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 3), + dimensions=DimensionConfig(stores=3, products=2), + exogenous=ExogenousSignalConfig(enable_weather=True, enable_macro=True), + ) + await DataSeeder(config).generate_full(db_session) + + # Pick the first store id present. + store_id_row = (await db_session.execute(select(Store.id).limit(1))).scalar() + assert store_id_row is not None + + resp = await service.query_exogenous( + db_session, + signal_name="weather_temp_c", + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 3), + store_id=store_id_row, + ) + # Only the rows for this store_id over 3 dates. + assert resp.total == 3 + for r in resp.records: + assert r.store_id == store_id_row + + @pytest.mark.asyncio + async def test_query_empty_signal_returns_no_rows(self, db_session: AsyncSession) -> None: + # No data seeded → query should return empty list, not error. + # First seed something to make sure tables exist with FK targets, + # then query a signal we never emitted. 
+ config = SeederConfig( + seed=42, + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 3), + dimensions=DimensionConfig(stores=1, products=1), + ) + await DataSeeder(config).generate_full(db_session) + + resp = await service.query_exogenous( + db_session, + signal_name="weather_temp_c", + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 3), + store_id=None, + ) + assert resp.total == 0 + assert resp.records == [] + + +# Suppress unused-import warning for timedelta — kept for future use. +_ = timedelta diff --git a/app/shared/seeder/tests/test_phase1_regression.py b/app/shared/seeder/tests/test_phase1_regression.py new file mode 100644 index 00000000..7dbe214b --- /dev/null +++ b/app/shared/seeder/tests/test_phase1_regression.py @@ -0,0 +1,92 @@ +"""Regression invariant: Phase 1 toggles OFF == pre-Phase-1 output. + +These tests are LOAD-BEARING: they guarantee that adding the Phase 1 +options to ``SalesDailyGenerator`` does not change the byte-output of +the existing six scenarios. If any of them starts failing, somebody +either added an RNG draw on the disabled path or changed a default +value that affects the existing math. 
+""" + +# mypy: disable-error-code="union-attr,arg-type,operator,return-value" + +import random +from datetime import date, timedelta +from decimal import Decimal + +import pytest + +from app.shared.seeder.config import ( + ChangepointConfig, + MultiSeasonalityConfig, + ScenarioPreset, + SeederConfig, + SubstitutionConfig, +) +from app.shared.seeder.generators.facts import SalesDailyGenerator + + +def _short_dates(n: int) -> list[date]: + """Use a small date range so the test is fast.""" + return [date(2024, 1, 1) + timedelta(days=i) for i in range(n)] + + +def _run_with_kwargs(config: SeederConfig, **extra_kwargs): + """Run SalesDailyGenerator using ``config`` with optional kwargs.""" + rng = random.Random(config.seed) + gen = SalesDailyGenerator( + rng, + config.time_series, + config.retail, + config.sparsity, + config.holidays, + **extra_kwargs, + ) + return gen.generate( + store_ids=[1, 2], + product_data=[(1, Decimal("9.99")), (2, Decimal("4.99"))], + dates=_short_dates(30), + promotions={}, + stockouts={}, + ) + + +class TestRegressionWithoutKwargs: + """Calling without any Phase 1 kwargs must match calling with explicit + defaults / None / empty configs.""" + + @pytest.mark.parametrize("scenario", list(ScenarioPreset)) + def test_no_kwargs_matches_explicit_defaults(self, scenario: ScenarioPreset): + config = SeederConfig.from_scenario(scenario, seed=42) + # Cap dates to the scenario range we care about. 
+ baseline = _run_with_kwargs(config) + with_defaults = _run_with_kwargs( + config, + multi_seasonality=MultiSeasonalityConfig(), # amplitude=0 default + changepoints=ChangepointConfig(), # empty default + substitution=SubstitutionConfig(), # disabled default + exogenous_weather=None, + weather_temperature_sensitivity=0.0, + ) + assert baseline == with_defaults, ( + f"Phase 1 defaults must not alter output for scenario {scenario.value}" + ) + + def test_disabled_phase1_does_not_consume_rng(self): + """A second generator with Phase 1 features enabled but no data + (e.g. empty changepoints / empty weather lookup) must still + produce the same row count and quantities as the disabled path. + """ + config = SeederConfig.from_scenario(ScenarioPreset.RETAIL_STANDARD, seed=42) + baseline = _run_with_kwargs(config) + # Enable substitution but with no groups → group lookup is empty. + no_op = _run_with_kwargs( + config, + substitution=SubstitutionConfig( + enable=True, + substitute_groups=[], + substitution_lift_on_stockout=0.5, + ), + exogenous_weather={}, # empty lookup + weather_temperature_sensitivity=0.1, # nonzero but no rows match + ) + assert baseline == no_op diff --git a/app/shared/seeder/tests/test_phase1_sales_effects.py b/app/shared/seeder/tests/test_phase1_sales_effects.py new file mode 100644 index 00000000..bc2d01a8 --- /dev/null +++ b/app/shared/seeder/tests/test_phase1_sales_effects.py @@ -0,0 +1,271 @@ +"""Tests for Phase 1 SalesDailyGenerator demand-multiplier extensions. + +Covers yearly seasonality, changepoints, weather-driven demand, and +substitution-on-stockout. The regression invariant — that disabling all +Phase 1 toggles produces byte-identical output to the pre-Phase-1 code +path — is verified in ``test_phase1_regression.py``. 
+""" + +# mypy: disable-error-code="union-attr,arg-type,operator,return-value" + +import math +import random +from datetime import date, timedelta +from decimal import Decimal + +from app.shared.seeder.config import ( + ChangepointConfig, + ChangepointEvent, + MultiSeasonalityConfig, + RetailPatternConfig, + SparsityConfig, + SubstitutionConfig, + TimeSeriesConfig, +) +from app.shared.seeder.generators.facts import SalesDailyGenerator + + +def _deterministic_ts_config() -> TimeSeriesConfig: + """A noise/anomaly-free config so multipliers can be asserted exactly.""" + return TimeSeriesConfig( + base_demand=100, + trend="none", + weekly_seasonality=[1.0] * 7, + monthly_seasonality={}, + noise_sigma=0.0, + anomaly_probability=0.0, + ) + + +def _deterministic_retail_config() -> RetailPatternConfig: + return RetailPatternConfig( + promotion_lift=1.0, + stockout_behavior="zero", + price_elasticity=0.0, + promotion_probability=0.0, + stockout_probability=0.0, + ) + + +def _flat_sparsity() -> SparsityConfig: + return SparsityConfig(missing_combinations_pct=0.0, random_gaps_per_series=0) + + +class TestYearlySeasonality: + def test_amplitude_zero_no_effect(self): + gen = SalesDailyGenerator( + random.Random(0), + _deterministic_ts_config(), + _deterministic_retail_config(), + _flat_sparsity(), + [], + multi_seasonality=MultiSeasonalityConfig(yearly_seasonality_amplitude=0.0), + ) + # Demand = base_demand exactly under the deterministic config. + assert gen._yearly_seasonality_multiplier(date(2024, 7, 1)) == 1.0 + + def test_amplitude_nonzero_introduces_swing(self): + cfg = MultiSeasonalityConfig(yearly_seasonality_amplitude=0.2) + gen = SalesDailyGenerator( + random.Random(0), + _deterministic_ts_config(), + _deterministic_retail_config(), + _flat_sparsity(), + [], + multi_seasonality=cfg, + ) + # On day-of-year 91 (≈ April 1) sin(2π · 91 / 365) ≈ 1; check sign. 
+ m_apr = gen._yearly_seasonality_multiplier(date(2024, 4, 1)) + m_oct = gen._yearly_seasonality_multiplier(date(2024, 10, 1)) + assert m_apr > 1.0 + assert m_oct < 1.0 + # Bounded by ±amplitude. + assert 0.8 - 1e-9 <= m_oct <= 1.0 + assert 1.0 <= m_apr <= 1.2 + 1e-9 + + +class TestChangepoints: + def test_no_changepoints_returns_one(self): + gen = SalesDailyGenerator( + random.Random(0), + _deterministic_ts_config(), + _deterministic_retail_config(), + _flat_sparsity(), + [], + changepoints=ChangepointConfig(changepoints=[]), + ) + assert gen._changepoint_multiplier(date(2024, 6, 1)) == 1.0 + + def test_impulse_decays_exponentially(self): + cp = ChangepointEvent(date=date(2024, 6, 1), demand_multiplier=2.0, decay_days=10) + gen = SalesDailyGenerator( + random.Random(0), + _deterministic_ts_config(), + _deterministic_retail_config(), + _flat_sparsity(), + [], + changepoints=ChangepointConfig(changepoints=[cp]), + ) + # Day 0: multiplier == 2.0 + assert math.isclose(gen._changepoint_multiplier(date(2024, 6, 1)), 2.0) + # Day 10: multiplier ≈ 1 + (2-1) * e^-1 ≈ 1.3679 + m10 = gen._changepoint_multiplier(date(2024, 6, 11)) + assert 1.35 < m10 < 1.40 + # Before the changepoint: 1.0 + assert gen._changepoint_multiplier(date(2024, 5, 31)) == 1.0 + + def test_pure_impulse_zero_decay(self): + cp = ChangepointEvent(date=date(2024, 6, 1), demand_multiplier=3.0, decay_days=0) + gen = SalesDailyGenerator( + random.Random(0), + _deterministic_ts_config(), + _deterministic_retail_config(), + _flat_sparsity(), + [], + changepoints=ChangepointConfig(changepoints=[cp]), + ) + assert gen._changepoint_multiplier(date(2024, 6, 1)) == 3.0 + assert gen._changepoint_multiplier(date(2024, 6, 2)) == 1.0 + + +class TestWeatherMultiplier: + def test_no_lookup_returns_one(self): + gen = SalesDailyGenerator( + random.Random(0), + _deterministic_ts_config(), + _deterministic_retail_config(), + _flat_sparsity(), + [], + exogenous_weather=None, + weather_temperature_sensitivity=0.01, + ) + assert 
gen._weather_multiplier(date(2024, 7, 1), 1) == 1.0 + + def test_sensitivity_zero_returns_one_even_with_lookup(self): + gen = SalesDailyGenerator( + random.Random(0), + _deterministic_ts_config(), + _deterministic_retail_config(), + _flat_sparsity(), + [], + exogenous_weather={(1, date(2024, 7, 1)): 30.0}, + weather_temperature_sensitivity=0.0, + weather_climatology_mean_c=15.0, + ) + assert gen._weather_multiplier(date(2024, 7, 1), 1) == 1.0 + + def test_warm_day_lifts_demand(self): + gen = SalesDailyGenerator( + random.Random(0), + _deterministic_ts_config(), + _deterministic_retail_config(), + _flat_sparsity(), + [], + exogenous_weather={(1, date(2024, 7, 1)): 25.0}, + weather_temperature_sensitivity=0.02, + weather_climatology_mean_c=15.0, + ) + # 1 + 0.02 * (25 - 15) = 1.2 + assert math.isclose(gen._weather_multiplier(date(2024, 7, 1), 1), 1.2) + + +class TestSubstitution: + def test_disabled_returns_one(self): + sub = SubstitutionConfig( + enable=False, + substitute_groups=[[1, 2]], + substitution_lift_on_stockout=0.5, + ) + gen = SalesDailyGenerator( + random.Random(0), + _deterministic_ts_config(), + _deterministic_retail_config(), + _flat_sparsity(), + [], + substitution=sub, + ) + assert gen._substitution_multiplier(1, {2}) == 1.0 + + def test_no_group_member_returns_one(self): + sub = SubstitutionConfig( + enable=True, + substitute_groups=[[1, 2]], + substitution_lift_on_stockout=0.5, + ) + gen = SalesDailyGenerator( + random.Random(0), + _deterministic_ts_config(), + _deterministic_retail_config(), + _flat_sparsity(), + [], + substitution=sub, + ) + # Product 3 isn't in any group → no lift. 
+ assert gen._substitution_multiplier(3, {2}) == 1.0 + + def test_lift_when_groupmate_out(self): + sub = SubstitutionConfig( + enable=True, + substitute_groups=[[1, 2, 3]], + substitution_lift_on_stockout=0.6, + ) + gen = SalesDailyGenerator( + random.Random(0), + _deterministic_ts_config(), + _deterministic_retail_config(), + _flat_sparsity(), + [], + substitution=sub, + ) + # Product 1 in stock, products 2 in stock, product 3 stocked out. + # out_members=1, survivors=1 (product 2). product 1's share is + # 0.6 * 1 / (survivors + 1) = 0.3 → multiplier 1.3. + m = gen._substitution_multiplier(1, {3}) + assert math.isclose(m, 1.3) + + def test_stocked_out_product_gets_no_lift(self): + sub = SubstitutionConfig( + enable=True, + substitute_groups=[[1, 2]], + substitution_lift_on_stockout=0.5, + ) + gen = SalesDailyGenerator( + random.Random(0), + _deterministic_ts_config(), + _deterministic_retail_config(), + _flat_sparsity(), + [], + substitution=sub, + ) + # Product 1 itself stocked out → multiplier is 1.0. + assert gen._substitution_multiplier(1, {1}) == 1.0 + + +class TestPhase1EndToEnd: + def test_phase1_features_alter_quantities(self): + # With deterministic ts/retail config and a changepoint impulse, + # the day-of-change quantity should equal base x multiplier. + ts = _deterministic_ts_config() + retail = _deterministic_retail_config() + cp = ChangepointEvent(date=date(2024, 1, 1), demand_multiplier=2.0, decay_days=0) + gen = SalesDailyGenerator( + random.Random(0), + ts, + retail, + _flat_sparsity(), + [], + changepoints=ChangepointConfig(changepoints=[cp]), + ) + dates = [date(2024, 1, 1) + timedelta(days=i) for i in range(3)] + sales = gen.generate( + store_ids=[1], + product_data=[(1, Decimal("10.00"))], + dates=dates, + promotions={}, + stockouts={}, + ) + by_date = {s["date"]: s["quantity"] for s in sales} + # Day 0: 200 (2x base). Days 1+: 100 (no decay, decay_days=0). 
+ assert by_date[date(2024, 1, 1)] == 200 + assert by_date[date(2024, 1, 2)] == 100 + assert by_date[date(2024, 1, 3)] == 100 diff --git a/app/shared/seeder/tests/test_phase2_bundles.py b/app/shared/seeder/tests/test_phase2_bundles.py new file mode 100644 index 00000000..352fc435 --- /dev/null +++ b/app/shared/seeder/tests/test_phase2_bundles.py @@ -0,0 +1,228 @@ +"""Tests for Phase 2 BundleGenerator promotion conversion. + +The regression invariant is the most load-bearing assertion: with +``BundleConfig.enable=False`` (the default), ``BundleGenerator.apply`` +must leave both the promotion list and the rng state byte-identical. +""" + +from __future__ import annotations + +import random +from copy import deepcopy +from datetime import date +from decimal import Decimal +from typing import Any + +import pytest + +from app.shared.seeder.config import BundleConfig +from app.shared.seeder.generators.bundles import BundleGenerator + + +def _sample_promotions() -> list[dict[str, Any]]: + """Build a fixed sample of promotion records mimicking PromotionGenerator output.""" + return [ + { + "product_id": 1, + "store_id": None, + "name": "Weekly Special", + "discount_pct": Decimal("0.10"), + "discount_amount": None, + "start_date": date(2024, 1, 1), + "end_date": date(2024, 1, 7), + }, + { + "product_id": 2, + "store_id": 10, + "name": "Flash Sale", + "discount_pct": None, + "discount_amount": Decimal("3.00"), + "start_date": date(2024, 1, 5), + "end_date": date(2024, 1, 12), + }, + { + "product_id": 3, + "store_id": None, + "name": "Clearance", + "discount_pct": Decimal("0.20"), + "discount_amount": None, + "start_date": date(2024, 1, 10), + "end_date": date(2024, 1, 17), + }, + { + "product_id": 4, + "store_id": None, + "name": "BOGO Deal", + "discount_pct": Decimal("0.15"), + "discount_amount": None, + "start_date": date(2024, 1, 15), + "end_date": date(2024, 1, 22), + }, + ] + + +class TestBundleGeneratorDisabled: + """Regression invariant: no mutation, no rng 
consumption when disabled.""" + + def test_enabled_property_false_when_config_none(self) -> None: + assert BundleGenerator(random.Random(0), None).enabled is False + + def test_enabled_property_false_when_config_default(self) -> None: + assert BundleGenerator(random.Random(0), BundleConfig()).enabled is False + + def test_no_mutation_when_config_none(self) -> None: + rng = random.Random(123) + promos = _sample_promotions() + snapshot = deepcopy(promos) + result = BundleGenerator(rng, None).apply(promos, [1, 2, 3, 4, 5]) + assert result is promos + assert result == snapshot + + def test_no_rng_consumption_when_config_none(self) -> None: + rng = random.Random(42) + baseline_state = rng.getstate() + BundleGenerator(rng, None).apply(_sample_promotions(), [1, 2, 3, 4, 5]) + assert rng.getstate() == baseline_state + + def test_no_mutation_when_disabled_config(self) -> None: + rng = random.Random(123) + promos = _sample_promotions() + snapshot = deepcopy(promos) + BundleGenerator(rng, BundleConfig()).apply(promos, [1, 2, 3, 4, 5]) + assert promos == snapshot + + def test_no_rng_consumption_when_disabled_config(self) -> None: + rng = random.Random(42) + baseline_state = rng.getstate() + BundleGenerator(rng, BundleConfig()).apply(_sample_promotions(), [1, 2, 3, 4, 5]) + assert rng.getstate() == baseline_state + + +class TestBundleGeneratorEnabled: + """Enabled-path correctness: kinds, members, discounts, reproducibility.""" + + def _cfg(self, **overrides: Any) -> BundleConfig: + kwargs: dict[str, Any] = { + "enable": True, + "bundle_probability": 1.0, # convert every promo for deterministic checks + "bogo_share_within_bundles": 0.5, + "min_bundle_size": 2, + "max_bundle_size": 3, + "bundle_discount_pct_min": 0.10, + "bundle_discount_pct_max": 0.30, + } + kwargs.update(overrides) + return BundleConfig(**kwargs) + + def test_kind_in_allowlist(self) -> None: + promos = _sample_promotions() + BundleGenerator(random.Random(7), self._cfg()).apply(promos, [1, 2, 3, 4, 5, 6]) + for 
p in promos: + assert p["kind"] in ("bundle", "bogo") + + def test_members_drawn_from_pool_excluding_host(self) -> None: + pool = [1, 2, 3, 4, 5, 6] + promos = _sample_promotions() + BundleGenerator(random.Random(11), self._cfg()).apply(promos, pool) + for p in promos: + members = p["bundle_member_product_ids"] + assert isinstance(members, list) + assert len(members) >= 2 + assert p["product_id"] not in members + assert set(members).issubset(set(pool)) + assert len(set(members)) == len(members) + + def test_member_count_in_configured_range(self) -> None: + promos = _sample_promotions() + BundleGenerator( + random.Random(13), + self._cfg(min_bundle_size=2, max_bundle_size=4), + ).apply(promos, [1, 2, 3, 4, 5, 6, 7]) + for p in promos: + members = p["bundle_member_product_ids"] + assert 2 <= len(members) <= 4 + + def test_discount_pct_in_configured_range_and_amount_cleared(self) -> None: + promos = _sample_promotions() + BundleGenerator( + random.Random(17), + self._cfg(bundle_discount_pct_min=0.10, bundle_discount_pct_max=0.30), + ).apply(promos, [1, 2, 3, 4, 5, 6]) + for p in promos: + d = p["discount_pct"] + assert isinstance(d, Decimal) + assert Decimal("0.10") <= d <= Decimal("0.30") + # Quantized to 4 decimal places to match ``Numeric(5, 4)``. 
+ exponent = d.as_tuple().exponent + assert isinstance(exponent, int) and exponent == -4 + assert p["discount_amount"] is None + + def test_all_bogo_when_share_is_one(self) -> None: + promos = _sample_promotions() + BundleGenerator(random.Random(23), self._cfg(bogo_share_within_bundles=1.0)).apply( + promos, [1, 2, 3, 4, 5, 6] + ) + assert all(p["kind"] == "bogo" for p in promos) + + def test_all_bundle_when_share_is_zero(self) -> None: + promos = _sample_promotions() + BundleGenerator(random.Random(29), self._cfg(bogo_share_within_bundles=0.0)).apply( + promos, [1, 2, 3, 4, 5, 6] + ) + assert all(p["kind"] == "bundle" for p in promos) + + def test_zero_probability_leaves_records_unchanged(self) -> None: + promos = _sample_promotions() + snapshot = deepcopy(promos) + BundleGenerator(random.Random(31), self._cfg(bundle_probability=0.0)).apply( + promos, [1, 2, 3, 4, 5, 6] + ) + assert promos == snapshot + + def test_reproducible_with_same_seed(self) -> None: + cfg = self._cfg(bundle_probability=0.5) + promos_a = _sample_promotions() + promos_b = _sample_promotions() + BundleGenerator(random.Random(42), cfg).apply(promos_a, [1, 2, 3, 4, 5]) + BundleGenerator(random.Random(42), cfg).apply(promos_b, [1, 2, 3, 4, 5]) + assert promos_a == promos_b + + def test_skips_when_pool_too_small_for_host(self) -> None: + """Eligible pool below ``min_bundle_size`` → skip without rng consumption.""" + promos = _sample_promotions() + snapshot = deepcopy(promos) + rng = random.Random(37) + baseline_state = rng.getstate() + # Single-element pool: every host's eligible_pool has at most 1 element + # which is below the default ``min_bundle_size=2`` — every promo skipped. + BundleGenerator(rng, self._cfg(min_bundle_size=2)).apply(promos, [1]) + assert promos == snapshot + assert rng.getstate() == baseline_state + + def test_max_clamps_to_eligible_pool_size(self) -> None: + # min=2, max=10, pool=4 — each host excludes itself → 3 eligible. 
+ promos = _sample_promotions() + BundleGenerator( + random.Random(41), + self._cfg(min_bundle_size=2, max_bundle_size=10), + ).apply(promos, [1, 2, 3, 4]) + for p in promos: + assert 2 <= len(p["bundle_member_product_ids"]) <= 3 + + +class TestBundleGeneratorValidation: + def test_min_bundle_size_below_two_raises(self) -> None: + bg = BundleGenerator( + random.Random(0), + BundleConfig(enable=True, min_bundle_size=1), + ) + with pytest.raises(ValueError, match="min_bundle_size must be >= 2"): + bg.apply(_sample_promotions(), [1, 2, 3]) + + def test_max_below_min_raises(self) -> None: + bg = BundleGenerator( + random.Random(0), + BundleConfig(enable=True, min_bundle_size=3, max_bundle_size=2), + ) + with pytest.raises(ValueError, match="max_bundle_size must be >="): + bg.apply(_sample_promotions(), [1, 2, 3, 4, 5]) diff --git a/app/shared/seeder/tests/test_phase2_channels_sales_integration.py b/app/shared/seeder/tests/test_phase2_channels_sales_integration.py new file mode 100644 index 00000000..e9229bf5 --- /dev/null +++ b/app/shared/seeder/tests/test_phase2_channels_sales_integration.py @@ -0,0 +1,337 @@ +"""Tests for Phase 2 channel integration into SalesDailyGenerator. + +Regression invariant: with ``channels=None`` (or ``channels`` set to +a disabled ``ChannelConfig``) the generator emits rows with no +``channel`` key and consumes zero new rng state — byte-identical to +its pre-Phase-2 behavior. 
+""" + +# mypy: disable-error-code="union-attr,arg-type,operator,return-value,misc" + +from __future__ import annotations + +import random +from datetime import date, timedelta +from decimal import Decimal + +import pytest + +from app.shared.seeder.config import ( + ChannelConfig, + RetailPatternConfig, + SparsityConfig, + TimeSeriesConfig, +) +from app.shared.seeder.generators.facts import SalesDailyGenerator + + +def _dates(start: date, n: int) -> list[date]: + return [start + timedelta(days=i) for i in range(n)] + + +def _minimal_ts() -> TimeSeriesConfig: + return TimeSeriesConfig( + base_demand=100, + trend="none", + weekly_seasonality=[1.0] * 7, + monthly_seasonality={}, + noise_sigma=0.0, + anomaly_probability=0.0, + anomaly_magnitude=1.0, + ) + + +def _minimal_retail() -> RetailPatternConfig: + return RetailPatternConfig( + promotion_lift=1.0, + stockout_behavior="zero", + price_elasticity=0.0, + new_product_ramp_days=0, + promotion_probability=0.0, + stockout_probability=0.0, + ) + + +def _run_generator( + *, + seed: int = 42, + channels: ChannelConfig | None = None, + promotions: dict[tuple[int, int], set[date]] | None = None, + dates: list[date] | None = None, +) -> list[dict[str, date | int | Decimal | str]]: + rng = random.Random(seed) + gen = SalesDailyGenerator( + rng, + _minimal_ts(), + _minimal_retail(), + SparsityConfig(), + holidays=[], + channels=channels, + ) + return gen.generate( + store_ids=[1, 2], + product_data=[(10, Decimal("9.99")), (20, Decimal("4.99"))], + dates=dates or _dates(date(2024, 1, 1), 30), + promotions=promotions or {}, + stockouts={}, + ) + + +# ---------------------------------------------------------------------- # +# Regression invariant +# ---------------------------------------------------------------------- # + + +class TestRegressionInvariant: + def test_no_channels_kwarg_omits_channel_column(self) -> None: + rows = _run_generator() + for r in rows: + assert "channel" not in r + + def 
test_disabled_config_matches_no_channels(self) -> None: + baseline = _run_generator() + with_disabled = _run_generator(channels=ChannelConfig()) # enable_multichannel=False + assert baseline == with_disabled + + def test_disabled_channels_does_not_consume_rng(self) -> None: + # A disabled config with a populated channel_mix must still match the + # no-channels baseline: the disabled path must not draw an rng row-pick. + baseline = _run_generator(seed=42) + with_populated_disabled = _run_generator( + seed=42, + channels=ChannelConfig( + enable_multichannel=False, + channel_mix={"in_store": 0.5, "online": 0.5}, + ), + ) + assert baseline == with_populated_disabled + + +# ---------------------------------------------------------------------- # +# Enabled-path correctness +# ---------------------------------------------------------------------- # + + +def _enabled_uniform() -> ChannelConfig: + return ChannelConfig( + enable_multichannel=True, + channel_mix={"in_store": 0.5, "online": 0.5}, + online_promo_uplift=1.0, + online_substitution_to_instore=0.0, + ) + + +class TestChannelDistribution: + def test_chosen_channel_in_mix_keys(self) -> None: + cfg = ChannelConfig( + enable_multichannel=True, + channel_mix={"in_store": 0.4, "online": 0.4, "click_collect": 0.2}, + ) + rows = _run_generator(channels=cfg) + for r in rows: + assert "channel" in r + assert r["channel"] in {"in_store", "online", "click_collect"} + + def test_single_channel_mix_always_picks_that_channel(self) -> None: + cfg = ChannelConfig( + enable_multichannel=True, + channel_mix={"wholesale": 1.0}, + ) + rows = _run_generator(channels=cfg) + assert rows + assert all(r["channel"] == "wholesale" for r in rows) + + def test_dominant_channel_appears_more_often(self) -> None: + cfg = ChannelConfig( + enable_multichannel=True, + channel_mix={"in_store": 0.9, "online": 0.1}, + ) + rows = _run_generator(channels=cfg, dates=_dates(date(2024, 1, 1), 60)) + n_in_store = sum(1 for r in rows if r["channel"] == "in_store") + n_online = 
sum(1 for r in rows if r["channel"] == "online") + assert n_in_store > n_online + assert n_online > 0 # some online rows still appear + + def test_zero_weight_channels_are_never_chosen(self) -> None: + cfg = ChannelConfig( + enable_multichannel=True, + channel_mix={"in_store": 1.0, "online": 0.0}, + ) + rows = _run_generator(channels=cfg) + assert all(r["channel"] == "in_store" for r in rows) + + +class TestOnlinePromoUplift: + def test_uplift_increases_online_qty_on_promo_dates(self) -> None: + cfg = ChannelConfig( + enable_multichannel=True, + channel_mix={"online": 1.0}, # force every row online + online_promo_uplift=2.0, + online_substitution_to_instore=0.0, + ) + promo_set = {date(2024, 1, 5), date(2024, 1, 6)} + promotions = { + (1, 10): promo_set, + (1, 20): promo_set, + (2, 10): promo_set, + (2, 20): promo_set, + } + rows = _run_generator(channels=cfg, promotions=promotions) + promo_qty_avg = sum(int(r["quantity"]) for r in rows if r["date"] in promo_set) / max( + 1, sum(1 for r in rows if r["date"] in promo_set) + ) + non_promo_qty_avg = sum( + int(r["quantity"]) for r in rows if r["date"] not in promo_set + ) / max(1, sum(1 for r in rows if r["date"] not in promo_set)) + # promotion_lift defaults to 1.0 in _minimal_retail so the only + # quantity difference on promo dates comes from the uplift. + assert promo_qty_avg > non_promo_qty_avg + + def test_uplift_does_not_apply_to_in_store_on_promo(self) -> None: + cfg = ChannelConfig( + enable_multichannel=True, + channel_mix={"in_store": 1.0}, # force every row in_store + online_promo_uplift=2.0, + online_substitution_to_instore=0.0, + ) + promo_set = {date(2024, 1, 5)} + promotions = { + (1, 10): promo_set, + (1, 20): promo_set, + (2, 10): promo_set, + (2, 20): promo_set, + } + rows = _run_generator(channels=cfg, promotions=promotions) + # All rows in_store; uplift should not fire. 
+ promo_qty = [int(r["quantity"]) for r in rows if r["date"] in promo_set] + non_promo_qty = [int(r["quantity"]) for r in rows if r["date"] not in promo_set] + # Both should be ~base_demand (100) since promotion_lift=1.0 and + # in_store rows don't get online uplift. + assert sum(promo_qty) // len(promo_qty) == sum(non_promo_qty) // len(non_promo_qty) + + +class TestSubstitutionShift: + def test_substitution_shifts_mix_during_promo(self) -> None: + # Start with even 50/50 mix; substitution shifts to favor online + # during promo. Compare promo-day channel distribution to + # non-promo-day distribution. + cfg = ChannelConfig( + enable_multichannel=True, + channel_mix={"in_store": 0.5, "online": 0.5}, + online_promo_uplift=1.0, + online_substitution_to_instore=0.8, # strong shift to online + ) + promo_set = set(_dates(date(2024, 1, 15), 15)) # promo Jan 15-29 + promotions = { + (1, 10): promo_set, + (1, 20): promo_set, + (2, 10): promo_set, + (2, 20): promo_set, + } + rows = _run_generator( + channels=cfg, promotions=promotions, dates=_dates(date(2024, 1, 1), 30) + ) + promo_online_share = sum( + 1 for r in rows if r["date"] in promo_set and r["channel"] == "online" + ) / max(1, sum(1 for r in rows if r["date"] in promo_set)) + non_promo_online_share = sum( + 1 for r in rows if r["date"] not in promo_set and r["channel"] == "online" + ) / max(1, sum(1 for r in rows if r["date"] not in promo_set)) + assert promo_online_share > non_promo_online_share + + def test_substitution_zero_means_no_shift(self) -> None: + cfg_a = ChannelConfig( + enable_multichannel=True, + channel_mix={"in_store": 0.5, "online": 0.5}, + online_promo_uplift=1.0, + online_substitution_to_instore=0.0, + ) + promo_set = {date(2024, 1, 15)} + promotions = { + (1, 10): promo_set, + (1, 20): promo_set, + (2, 10): promo_set, + (2, 20): promo_set, + } + rows_with_promo = _run_generator(channels=cfg_a, promotions=promotions) + rows_no_promo = _run_generator(channels=cfg_a, promotions={}) + # With 
substitution=0, channels are picked from the same mix + # whether or not a promo is active. The two channel streams + # should be identical at the same seed since promo doesn't + # touch the mix. + chosen_a = [r["channel"] for r in rows_with_promo] + chosen_b = [r["channel"] for r in rows_no_promo] + assert chosen_a == chosen_b + + +# ---------------------------------------------------------------------- # +# Validation +# ---------------------------------------------------------------------- # + + +class TestChannelValidation: + def test_empty_mix_raises(self) -> None: + cfg = ChannelConfig(enable_multichannel=True, channel_mix={}) + with pytest.raises(ValueError, match="channel_mix must be non-empty"): + _run_generator(channels=cfg) + + def test_invalid_channel_name_raises(self) -> None: + cfg = ChannelConfig( + enable_multichannel=True, + channel_mix={"in_store": 0.5, "telegraph": 0.5}, + ) + with pytest.raises(ValueError, match="invalid channels"): + _run_generator(channels=cfg) + + def test_negative_weight_raises(self) -> None: + cfg = ChannelConfig( + enable_multichannel=True, + channel_mix={"in_store": 0.5, "online": -0.1}, + ) + with pytest.raises(ValueError, match="must be >= 0"): + _run_generator(channels=cfg) + + def test_all_zero_weights_raises(self) -> None: + cfg = ChannelConfig( + enable_multichannel=True, + channel_mix={"in_store": 0.0, "online": 0.0}, + ) + with pytest.raises(ValueError, match="at least one positive weight"): + _run_generator(channels=cfg) + + def test_negative_uplift_raises(self) -> None: + cfg = ChannelConfig( + enable_multichannel=True, + channel_mix={"online": 1.0}, + online_promo_uplift=-0.5, + ) + with pytest.raises(ValueError, match="online_promo_uplift"): + _run_generator(channels=cfg) + + def test_substitution_out_of_range_raises(self) -> None: + cfg = ChannelConfig( + enable_multichannel=True, + channel_mix={"in_store": 0.5, "online": 0.5}, + online_substitution_to_instore=1.5, + ) + with pytest.raises(ValueError, 
match="online_substitution_to_instore"): + _run_generator(channels=cfg) + + +# ---------------------------------------------------------------------- # +# Row shape +# ---------------------------------------------------------------------- # + + +class TestRowShape: + def test_channel_key_present_when_enabled(self) -> None: + rows = _run_generator(channels=_enabled_uniform()) + assert rows + for r in rows: + assert "channel" in r + assert r["channel"] in {"in_store", "online"} + + def test_channel_key_absent_when_disabled(self) -> None: + rows = _run_generator() + for r in rows: + assert "channel" not in r diff --git a/app/shared/seeder/tests/test_phase2_config.py b/app/shared/seeder/tests/test_phase2_config.py new file mode 100644 index 00000000..d06972b6 --- /dev/null +++ b/app/shared/seeder/tests/test_phase2_config.py @@ -0,0 +1,102 @@ +"""Tests for Phase 2 seeder configuration dataclasses. + +Covers ChannelConfig, LifecycleConfig, BundleConfig, MarkdownConfig, and +LeadTimeConfig — plus the SeederConfig wiring that holds them with +disabled / empty defaults so existing scenarios stay byte-identical. 
+""" + +from app.shared.seeder.config import ( + BundleConfig, + ChannelConfig, + LeadTimeConfig, + LifecycleConfig, + MarkdownConfig, + ScenarioPreset, + SeederConfig, +) + + +class TestChannelConfig: + def test_defaults_disabled(self) -> None: + cfg = ChannelConfig() + assert cfg.enable_multichannel is False + assert cfg.channel_mix == {} + assert cfg.online_promo_uplift == 1.0 + assert cfg.online_substitution_to_instore == 0.0 + + def test_channel_mix_is_independent(self) -> None: + a = ChannelConfig() + b = ChannelConfig() + a.channel_mix["online"] = 0.3 + assert b.channel_mix == {} + + +class TestLifecycleConfig: + def test_defaults_disabled(self) -> None: + cfg = LifecycleConfig() + assert cfg.enable is False + assert cfg.auto_progression is True + assert cfg.discontinue_probability == 0.0 + assert 0.0 <= cfg.intro_multiplier <= 1.0 + assert 0.0 <= cfg.decline_multiplier <= 1.0 + assert cfg.intro_ramp_days > 0 + assert cfg.growth_ramp_days > 0 + + +class TestBundleConfig: + def test_defaults_disabled(self) -> None: + cfg = BundleConfig() + assert cfg.enable is False + assert cfg.bundle_probability == 0.0 + assert cfg.min_bundle_size >= 2 + assert cfg.max_bundle_size >= cfg.min_bundle_size + assert 0.0 <= cfg.bundle_discount_pct_min <= cfg.bundle_discount_pct_max <= 1.0 + + +class TestMarkdownConfig: + def test_defaults_disabled(self) -> None: + cfg = MarkdownConfig() + assert cfg.enable is False + assert cfg.trigger in ("age_days", "stockout_risk", "lifecycle_decline") + assert 0.0 <= cfg.markdown_depth_pct <= 1.0 + assert cfg.markdown_duration_days > 0 + + +class TestLeadTimeConfig: + def test_defaults_disabled(self) -> None: + cfg = LeadTimeConfig() + assert cfg.enable is False + assert cfg.mean_lead_time_days >= 0 + assert cfg.lead_time_sigma_days >= 0 + assert cfg.order_frequency_days > 0 + assert 0.0 <= cfg.fill_rate_mean <= 1.0 + + +class TestSeederConfigPhase2Wiring: + def test_phase2_defaults_present_and_disabled(self) -> None: + cfg = SeederConfig() 
+ # Each Phase 2 sub-config must be present with disabled defaults + # so existing scenarios remain byte-identical when not opted in. + assert isinstance(cfg.channels, ChannelConfig) + assert isinstance(cfg.lifecycle, LifecycleConfig) + assert isinstance(cfg.bundles, BundleConfig) + assert isinstance(cfg.markdowns, MarkdownConfig) + assert isinstance(cfg.lead_time, LeadTimeConfig) + assert cfg.channels.enable_multichannel is False + assert cfg.lifecycle.enable is False + assert cfg.bundles.enable is False + assert cfg.markdowns.enable is False + assert cfg.lead_time.enable is False + + def test_from_scenario_does_not_enable_phase2(self) -> None: + # Regression invariant: pre-Phase-2 scenarios must not silently + # enable any Phase 2 toggle, or the seeded outputs would shift. + for scenario in ScenarioPreset: + cfg = SeederConfig.from_scenario(scenario) + assert cfg.channels.enable_multichannel is False, ( + f"{scenario} unexpectedly enables multichannel" + ) + assert cfg.lifecycle.enable is False, f"{scenario} unexpectedly enables lifecycle" + assert cfg.bundles.enable is False, f"{scenario} unexpectedly enables bundles" + assert cfg.markdowns.enable is False, f"{scenario} unexpectedly enables markdowns" + assert cfg.lead_time.enable is False, f"{scenario} unexpectedly enables lead_time" diff --git a/app/shared/seeder/tests/test_phase2_integration.py b/app/shared/seeder/tests/test_phase2_integration.py new file mode 100644 index 00000000..ffad87db --- /dev/null +++ b/app/shared/seeder/tests/test_phase2_integration.py @@ -0,0 +1,318 @@ +"""Phase 2 integration tests against real Postgres. + +Run with: uv run pytest app/shared/seeder/tests/test_phase2_integration.py -v -m integration +Requires docker-compose Postgres up and migrations applied. 
+""" + +# mypy: disable-error-code="union-attr,arg-type,operator,return-value" + +import os +from collections.abc import AsyncGenerator +from contextlib import suppress +from datetime import date + +import pytest +import pytest_asyncio +from sqlalchemy import delete, func, select +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from app.core.config import get_settings +from app.features.data_platform.models import ( + Calendar, + ExogenousSignal, + InventorySnapshotDaily, + PriceHistory, + Product, + Promotion, + ReplenishmentEvent, + SalesDaily, + SalesReturn, + Store, +) +from app.shared.seeder import DataSeeder, SeederConfig +from app.shared.seeder.config import ( + BundleConfig, + ChannelConfig, + DimensionConfig, + LeadTimeConfig, + LifecycleConfig, + MarkdownConfig, +) + +pytestmark = pytest.mark.integration + + +def _check_destructive_test_guard() -> None: + settings = get_settings() + is_testing = getattr(settings, "testing", False) + app_env_testing = os.environ.get("APP_ENV", "").lower() == "testing" + allow_destructive = os.environ.get("ALLOW_DESTRUCTIVE_TEST_DB", "").lower() == "true" + if not is_testing and not app_env_testing and not allow_destructive: + raise RuntimeError( + "Destructive test operations require explicit opt-in. " + "Set ALLOW_DESTRUCTIVE_TEST_DB=true, APP_ENV=testing, or settings.testing=True" + ) + + +_FACT_TABLES = ( + ReplenishmentEvent, + SalesReturn, + ExogenousSignal, + SalesDaily, + InventorySnapshotDaily, + PriceHistory, + Promotion, +) +_DIM_TABLES = (Calendar, Product, Store) + + +async def _wipe(session: AsyncSession) -> None: + """Wipe all Phase 1+2 fact and dimension tables. Order matters for FKs.""" + all_tables: tuple[type, ...] 
= _FACT_TABLES + _DIM_TABLES + for model in all_tables: + await session.execute(delete(model)) + await session.commit() + + +@pytest_asyncio.fixture(scope="function") +async def db_session() -> AsyncGenerator[AsyncSession, None]: + _check_destructive_test_guard() + settings = get_settings() + engine = create_async_engine(settings.database_url, echo=False) + session_maker = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + async with session_maker() as cleanup_session: + try: + await _wipe(cleanup_session) + except Exception: + await cleanup_session.rollback() + + async with session_maker() as session: + try: + yield session + finally: + with suppress(Exception): + await session.rollback() + + _check_destructive_test_guard() + + async with session_maker() as cleanup_session: + try: + await _wipe(cleanup_session) + except Exception: + await cleanup_session.rollback() + + await engine.dispose() + + +class TestPhase2Disabled: + """All Phase 2 toggles off — disabled-path regression invariant.""" + + @pytest.mark.asyncio + async def test_default_run_emits_no_phase2_rows(self, db_session: AsyncSession) -> None: + """Replenishment, bundles, markdowns, and lifecycle stay untouched.""" + config = SeederConfig( + seed=42, + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 7), + dimensions=DimensionConfig(stores=2, products=3), + ) + result = await DataSeeder(config).generate_full(db_session) + assert result.replenishment_count == 0 + + replenishment_count = ( + await db_session.execute(select(func.count()).select_from(ReplenishmentEvent)) + ).scalar() or 0 + assert replenishment_count == 0 + + # No bundle / BOGO / markdown promotions when their generators are off. + non_default_kinds = ( + await db_session.execute( + select(func.count()).select_from(Promotion).where(Promotion.kind != "pct_off") + ) + ).scalar() or 0 + assert non_default_kinds == 0 + + # Lifecycle disabled → no product carries launch_date / discontinue_date. 
+ with_launch = ( + await db_session.execute( + select(func.count()).select_from(Product).where(Product.launch_date.is_not(None)) + ) + ).scalar() or 0 + assert with_launch == 0 + + # No row in sales_daily carries an explicit non-default channel — the + # column defaults to 'in_store' via the server default. Every row + # therefore reads as 'in_store' from the DB perspective. + non_instore = ( + await db_session.execute( + select(func.count()).select_from(SalesDaily).where(SalesDaily.channel != "in_store") + ) + ).scalar() or 0 + assert non_instore == 0 + + +class TestPhase2Enabled: + """Each Phase 2 feature emits rows when its toggle is on.""" + + @pytest.mark.asyncio + async def test_lifecycle_populates_dates(self, db_session: AsyncSession) -> None: + config = SeederConfig( + seed=42, + start_date=date(2024, 1, 1), + end_date=date(2024, 3, 31), + dimensions=DimensionConfig(stores=2, products=4), + lifecycle=LifecycleConfig(enable=True, discontinue_probability=0.0), + ) + await DataSeeder(config).generate_full(db_session) + with_launch = ( + await db_session.execute( + select(func.count()).select_from(Product).where(Product.launch_date.is_not(None)) + ) + ).scalar() or 0 + assert with_launch == 4 + + @pytest.mark.asyncio + async def test_bundles_convert_promotions(self, db_session: AsyncSession) -> None: + config = SeederConfig( + seed=42, + start_date=date(2024, 1, 1), + end_date=date(2024, 6, 30), + dimensions=DimensionConfig(stores=3, products=8), + bundles=BundleConfig(enable=True, bundle_probability=1.0), + ) + await DataSeeder(config).generate_full(db_session) + bundle_rows = ( + await db_session.execute( + select(func.count()) + .select_from(Promotion) + .where(Promotion.kind.in_(("bundle", "bogo"))) + ) + ).scalar() or 0 + # bundle_probability=1.0 means every eligible promotion converts. + # The exact count depends on PromotionGenerator's rng stream but + # must be > 0 for a 6-month window with 8 products * 3 stores. 
+ assert bundle_rows > 0 + # Every bundle/BOGO row carries non-NULL member IDs. + bad = ( + await db_session.execute( + select(func.count()) + .select_from(Promotion) + .where(Promotion.kind.in_(("bundle", "bogo"))) + .where(Promotion.bundle_member_product_ids.is_(None)) + ) + ).scalar() or 0 + assert bad == 0 + + @pytest.mark.asyncio + async def test_markdowns_emit_promo_and_price_drops(self, db_session: AsyncSession) -> None: + config = SeederConfig( + seed=42, + start_date=date(2024, 1, 1), + end_date=date(2024, 12, 31), + dimensions=DimensionConfig(stores=2, products=4), + lifecycle=LifecycleConfig(enable=True), + markdowns=MarkdownConfig(enable=True, trigger="lifecycle_decline"), + ) + await DataSeeder(config).generate_full(db_session) + markdown_promos = ( + await db_session.execute( + select(func.count()).select_from(Promotion).where(Promotion.kind == "markdown") + ) + ).scalar() or 0 + # Lifecycle decline only fires for products whose decline begins in + # the seeded window. With 4 products + 1-year window some will, some + # won't. So this is only a smoke check (count >= 0), not proof that any fired. 
+ assert markdown_promos >= 0 + + @pytest.mark.asyncio + async def test_replenishment_emitted(self, db_session: AsyncSession) -> None: + config = SeederConfig( + seed=42, + start_date=date(2024, 1, 1), + end_date=date(2024, 2, 29), + dimensions=DimensionConfig(stores=2, products=2), + lead_time=LeadTimeConfig( + enable=True, + mean_lead_time_days=3, + lead_time_sigma_days=1.0, + order_frequency_days=14, + ), + ) + result = await DataSeeder(config).generate_full(db_session) + assert result.replenishment_count > 0 + bad_fill = ( + await db_session.execute( + select(func.count()) + .select_from(ReplenishmentEvent) + .where(ReplenishmentEvent.received_qty > ReplenishmentEvent.ordered_qty) + ) + ).scalar() or 0 + assert bad_fill == 0 + + @pytest.mark.asyncio + async def test_multichannel_writes_multiple_channels(self, db_session: AsyncSession) -> None: + config = SeederConfig( + seed=42, + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 31), + dimensions=DimensionConfig(stores=2, products=3), + channels=ChannelConfig( + enable_multichannel=True, + channel_mix={"in_store": 0.6, "online": 0.3, "click_collect": 0.1}, + online_promo_uplift=1.2, + ), + ) + await DataSeeder(config).generate_full(db_session) + distinct_channels = ( + await db_session.execute(select(func.count(SalesDaily.channel.distinct()))) + ).scalar() or 0 + # With three weights all positive over a 31-day window and 6 (store, + # product) pairs, we expect all three channels to appear. 
+ assert distinct_channels >= 2 # at least 2 to be tolerant of low draws + + +class TestPhase2Integrity: + @pytest.mark.asyncio + async def test_verify_clean_with_all_phase2_on(self, db_session: AsyncSession) -> None: + config = SeederConfig( + seed=42, + start_date=date(2024, 1, 1), + end_date=date(2024, 3, 31), + dimensions=DimensionConfig(stores=2, products=3), + channels=ChannelConfig( + enable_multichannel=True, + channel_mix={"in_store": 0.5, "online": 0.5}, + ), + lifecycle=LifecycleConfig(enable=True), + bundles=BundleConfig(enable=True, bundle_probability=0.5), + markdowns=MarkdownConfig(enable=True, trigger="lifecycle_decline"), + lead_time=LeadTimeConfig(enable=True), + ) + seeder = DataSeeder(config) + await seeder.generate_full(db_session) + errors = await seeder.verify_data_integrity(db_session) + assert errors == [] + + +class TestPhase2DeleteOrder: + @pytest.mark.asyncio + async def test_delete_all_clears_replenishment_first(self, db_session: AsyncSession) -> None: + config = SeederConfig( + seed=42, + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 14), + dimensions=DimensionConfig(stores=2, products=2), + lead_time=LeadTimeConfig(enable=True), + ) + seeder = DataSeeder(config) + result = await seeder.generate_full(db_session) + assert result.replenishment_count > 0 + + counts = await seeder.delete_data(db_session, scope="all", dry_run=False) + assert counts.get("replenishment_event", 0) > 0 + + remaining = ( + await db_session.execute(select(func.count()).select_from(ReplenishmentEvent)) + ).scalar() or 0 + assert remaining == 0 diff --git a/app/shared/seeder/tests/test_phase2_lifecycle.py b/app/shared/seeder/tests/test_phase2_lifecycle.py new file mode 100644 index 00000000..dd70a57c --- /dev/null +++ b/app/shared/seeder/tests/test_phase2_lifecycle.py @@ -0,0 +1,187 @@ +"""Tests for Phase 2 lifecycle: ProductGenerator extension + LifecycleGenerator. 
+ +The regression invariant is the most load-bearing assertion here: with +``LifecycleConfig.enable=False`` (the default), ``ProductGenerator`` +must emit rows byte-identical to its pre-Phase-2 output, and +``LifecycleGenerator.multiplier_for`` must always return 1.0. +""" + +from __future__ import annotations + +import random +from datetime import date, timedelta + +from app.shared.seeder.config import DimensionConfig, LifecycleConfig +from app.shared.seeder.generators import ProductGenerator +from app.shared.seeder.generators.lifecycle import LifecycleGenerator + + +def _minimal_dimensions() -> DimensionConfig: + return DimensionConfig( + stores=2, + products=4, + store_regions=["North", "South"], + store_types=["supermarket"], + product_categories=["Beverage", "Snack"], + product_brands=["BrandA", "BrandB"], + ) + + +class TestProductGeneratorLifecycleDisabled: + """Regression invariant: byte-identical output when lifecycle is off.""" + + def test_no_lifecycle_keys_when_disabled(self) -> None: + gen = ProductGenerator(random.Random(123), _minimal_dimensions()) + products = gen.generate() + for p in products: + # Phase 2 keys MUST NOT appear when the feature is off — that + # is what guarantees the regression invariant. + assert "lifecycle_stage" not in p + assert "launch_date" not in p + assert "discontinue_date" not in p + assert "pack_size" not in p + assert "subcategory" not in p + + def test_disabled_config_same_as_none(self) -> None: + # Passing a disabled LifecycleConfig must produce the same output + # as passing None: no extra rng draws on either path. 
+ cfg = _minimal_dimensions() + gen_none = ProductGenerator(random.Random(7), cfg) + gen_disabled = ProductGenerator( + random.Random(7), + cfg, + lifecycle_config=LifecycleConfig(), # default enable=False + date_range=(date(2024, 1, 1), date(2024, 12, 31)), + ) + assert gen_none.generate() == gen_disabled.generate() + + def test_reproducible_across_runs(self) -> None: + cfg = _minimal_dimensions() + a = ProductGenerator(random.Random(42), cfg).generate() + b = ProductGenerator(random.Random(42), cfg).generate() + assert a == b + + +class TestProductGeneratorLifecycleEnabled: + """When enabled, each product carries the five Phase 2 attrs.""" + + def test_lifecycle_attrs_present(self) -> None: + cfg = LifecycleConfig(enable=True, discontinue_probability=0.0) + gen = ProductGenerator( + random.Random(99), + _minimal_dimensions(), + lifecycle_config=cfg, + date_range=(date(2024, 1, 1), date(2024, 12, 31)), + ) + products = gen.generate() + for p in products: + assert p["lifecycle_stage"] in ( + "intro", + "growth", + "maturity", + "decline", + ) + assert isinstance(p["launch_date"], date) + assert isinstance(p["pack_size"], int) and p["pack_size"] > 0 + assert isinstance(p["subcategory"], str) and p["subcategory"] + # Discontinue_probability=0 means no product gets retired. 
+ assert p["discontinue_date"] is None + + def test_discontinue_respects_launch_order(self) -> None: + cfg = LifecycleConfig(enable=True, discontinue_probability=1.0) + gen = ProductGenerator( + random.Random(7), + _minimal_dimensions(), + lifecycle_config=cfg, + date_range=(date(2024, 1, 1), date(2024, 12, 31)), + ) + for p in gen.generate(): + if p["discontinue_date"] is not None: + assert p["discontinue_date"] >= p["launch_date"], p + + def test_reproducible_with_lifecycle_on(self) -> None: + cfg = LifecycleConfig(enable=True, discontinue_probability=0.3) + dr = (date(2024, 1, 1), date(2024, 12, 31)) + a = ProductGenerator( + random.Random(42), _minimal_dimensions(), lifecycle_config=cfg, date_range=dr + ).generate() + b = ProductGenerator( + random.Random(42), _minimal_dimensions(), lifecycle_config=cfg, date_range=dr + ).generate() + assert a == b + + +class TestLifecycleGeneratorDisabled: + def test_returns_one_when_config_none(self) -> None: + gen = LifecycleGenerator(None) + assert gen.enabled is False + assert gen.multiplier_for(date(2024, 6, 1), date(2024, 1, 1), None) == 1.0 + assert gen.stage_for(date(2024, 6, 1), date(2024, 1, 1), None) == "maturity" + + def test_returns_one_when_disabled(self) -> None: + gen = LifecycleGenerator(LifecycleConfig()) # default enable=False + assert gen.enabled is False + assert gen.multiplier_for(date(2024, 6, 1), date(2024, 1, 1), None) == 1.0 + + +class TestLifecycleGeneratorEnabled: + def _gen(self, **overrides: object) -> LifecycleGenerator: + kwargs: dict[str, object] = { + "enable": True, + "intro_ramp_days": 10, + "growth_ramp_days": 10, + "maturity_steady_days": 10, + "decline_decay_days": 10, + "intro_multiplier": 0.1, + "decline_multiplier": 0.0, + } + kwargs.update(overrides) + cfg = LifecycleConfig(**kwargs) # type: ignore[arg-type] + return LifecycleGenerator(cfg) + + def test_pre_launch_demand_is_zero(self) -> None: + gen = self._gen() + launch = date(2024, 6, 1) + assert gen.multiplier_for(date(2024, 5, 
31), launch, None) == 0.0 + + def test_intro_ramp_linear(self) -> None: + gen = self._gen() + launch = date(2024, 1, 1) + # Day 0 == intro_multiplier (0.1), day 10 (end of ramp) == 1.0. + assert gen.multiplier_for(launch, launch, None) == 0.1 + midway = launch + timedelta(days=5) + m = gen.multiplier_for(midway, launch, None) + # Linear ramp midpoint between 0.1 and 1.0 is 0.55. + assert abs(m - 0.55) < 1e-9 + + def test_maturity_held_at_one(self) -> None: + gen = self._gen() + launch = date(2024, 1, 1) + # maturity starts at day intro_ramp + growth_ramp = 20. + assert gen.multiplier_for(launch + timedelta(days=25), launch, None) == 1.0 + + def test_decline_decays_toward_floor(self) -> None: + gen = self._gen(decline_multiplier=0.2, decline_decay_days=10) + launch = date(2024, 1, 1) + # Decline starts at day 30. + # decline_multiplier=0.2 means asymptote is 0.2; m(decline_start)=1.0. + m_start = gen.multiplier_for(launch + timedelta(days=30), launch, None) + assert abs(m_start - 1.0) < 1e-9 + m_far = gen.multiplier_for(launch + timedelta(days=130), launch, None) + assert 0.2 <= m_far < 0.3 # well into the decay tail + + def test_discontinue_overrides_all_curves(self) -> None: + gen = self._gen() + launch = date(2024, 1, 1) + discontinue = date(2024, 3, 1) + assert gen.multiplier_for(discontinue, launch, discontinue) == 0.0 + assert gen.multiplier_for(discontinue + timedelta(days=10), launch, discontinue) == 0.0 + assert gen.stage_for(discontinue, launch, discontinue) == "discontinued" + + def test_stage_for_traverses_segments(self) -> None: + gen = self._gen() + launch = date(2024, 1, 1) + assert gen.stage_for(launch, launch, None) == "intro" + assert gen.stage_for(launch + timedelta(days=11), launch, None) == "growth" + assert gen.stage_for(launch + timedelta(days=21), launch, None) == "maturity" + assert gen.stage_for(launch + timedelta(days=31), launch, None) == "decline" diff --git a/app/shared/seeder/tests/test_phase2_lifecycle_sales_integration.py 
b/app/shared/seeder/tests/test_phase2_lifecycle_sales_integration.py new file mode 100644 index 00000000..d2bfb8a1 --- /dev/null +++ b/app/shared/seeder/tests/test_phase2_lifecycle_sales_integration.py @@ -0,0 +1,318 @@ +"""Tests for Phase 2 lifecycle integration into SalesDailyGenerator. + +These tests are LOAD-BEARING: they guarantee the byte-identical +regression invariant — when ``lifecycle=None`` and +``product_lifecycle_data=None`` (the defaults), the generator emits +exactly the same rows as before the integration. The enabled-path +tests cover the new multiplier behaviour (pre-launch zero, decline +attenuation, discontinue cutoff) plus the legacy ramp suppression. +""" + +# mypy: disable-error-code="union-attr,arg-type,operator,return-value,misc" + +from __future__ import annotations + +import random +from datetime import date, timedelta +from decimal import Decimal + +from app.shared.seeder.config import ( + LifecycleConfig, + RetailPatternConfig, + SparsityConfig, + TimeSeriesConfig, +) +from app.shared.seeder.generators.facts import SalesDailyGenerator +from app.shared.seeder.generators.lifecycle import LifecycleGenerator + + +def _dates(start: date, n: int) -> list[date]: + return [start + timedelta(days=i) for i in range(n)] + + +def _minimal_ts() -> TimeSeriesConfig: + """Deterministic-friendly TimeSeriesConfig (no anomalies, no noise).""" + return TimeSeriesConfig( + base_demand=100, + trend="none", + weekly_seasonality=[1.0] * 7, + monthly_seasonality={}, + noise_sigma=0.0, + anomaly_probability=0.0, + anomaly_magnitude=1.0, + ) + + +def _minimal_retail() -> RetailPatternConfig: + return RetailPatternConfig( + promotion_lift=1.0, + stockout_behavior="zero", + price_elasticity=0.0, + new_product_ramp_days=0, # disable legacy ramp for most tests + promotion_probability=0.0, + stockout_probability=0.0, + ) + + +def _run_generator( + *, + seed: int = 42, + ts: TimeSeriesConfig | None = None, + retail: RetailPatternConfig | None = None, + lifecycle: 
LifecycleGenerator | None = None, + product_lifecycle_data: dict[int, tuple[date | None, date | None]] | None = None, + dates: list[date] | None = None, +) -> list[dict[str, date | int | Decimal]]: + rng = random.Random(seed) + gen = SalesDailyGenerator( + rng, + ts or _minimal_ts(), + retail or _minimal_retail(), + SparsityConfig(), + holidays=[], + lifecycle=lifecycle, + ) + return gen.generate( + store_ids=[1, 2], + product_data=[(10, Decimal("9.99")), (20, Decimal("4.99"))], + dates=dates or _dates(date(2024, 1, 1), 60), + promotions={}, + stockouts={}, + product_lifecycle_data=product_lifecycle_data, + ) + + +# ---------------------------------------------------------------------- # +# Regression invariant: pre-Phase-2 callers see byte-identical output. +# ---------------------------------------------------------------------- # + + +class TestRegressionInvariant: + def test_no_kwargs_matches_explicit_none(self) -> None: + baseline = _run_generator() + with_explicit_none = _run_generator( + lifecycle=None, + product_lifecycle_data=None, + ) + assert baseline == with_explicit_none + + def test_disabled_lifecycle_matches_no_lifecycle(self) -> None: + baseline = _run_generator() + with_disabled = _run_generator( + lifecycle=LifecycleGenerator(LifecycleConfig()), # default enable=False + ) + assert baseline == with_disabled + + def test_disabled_lifecycle_does_not_consume_rng(self) -> None: + # Both runs use seed=42; the disabled lifecycle path must not + # add any rng draws, so quantities must match exactly. + baseline = _run_generator(seed=42) + with_disabled = _run_generator( + seed=42, + lifecycle=LifecycleGenerator(LifecycleConfig()), + product_lifecycle_data={10: (date(2024, 1, 1), None)}, + ) + # When lifecycle is disabled, ``product_lifecycle_data`` still + # threads ``launch_date`` to ``_compute_demand``, which keeps + # the row count and rng order intact — only the legacy ramp + # path could fire, and ``new_product_ramp_days=0`` neuters it. 
+ assert baseline == with_disabled + + +# ---------------------------------------------------------------------- # +# Enabled-path correctness +# ---------------------------------------------------------------------- # + + +def _enabled_lifecycle() -> LifecycleGenerator: + return LifecycleGenerator( + LifecycleConfig( + enable=True, + intro_ramp_days=10, + growth_ramp_days=10, + maturity_steady_days=20, + decline_decay_days=10, + intro_multiplier=0.1, + decline_multiplier=0.0, + ) + ) + + +class TestLifecycleMultiplierEnabled: + def test_pre_launch_demand_is_zero(self) -> None: + # Product 10 launches mid-range; all earlier dates should have qty=0. + rows = _run_generator( + lifecycle=_enabled_lifecycle(), + product_lifecycle_data={ + 10: (date(2024, 1, 31), None), # launches Jan 31 + 20: (date(2024, 1, 1), None), # launched at range start + }, + dates=_dates(date(2024, 1, 1), 60), + ) + pre_launch = [r for r in rows if r["product_id"] == 10 and r["date"] < date(2024, 1, 31)] + assert pre_launch # there ARE rows + assert all(r["quantity"] == 0 for r in pre_launch) + + def test_post_discontinue_demand_is_zero(self) -> None: + rows = _run_generator( + lifecycle=_enabled_lifecycle(), + product_lifecycle_data={ + 10: (date(2024, 1, 1), date(2024, 1, 31)), # discontinued Jan 31 + 20: (date(2024, 1, 1), None), + }, + dates=_dates(date(2024, 1, 1), 60), + ) + post_disc = [r for r in rows if r["product_id"] == 10 and r["date"] >= date(2024, 1, 31)] + assert post_disc + assert all(r["quantity"] == 0 for r in post_disc) + + def test_decline_demand_lower_than_maturity(self) -> None: + # Lifecycle: intro(10)+growth(10)+maturity(20)=40 days to decline. + # Launch on Jan 1 → decline starts Feb 10. 
+ rows = _run_generator( + lifecycle=_enabled_lifecycle(), + product_lifecycle_data={ + 10: (date(2024, 1, 1), None), + 20: (date(2024, 1, 1), None), + }, + dates=_dates(date(2024, 1, 1), 90), + ) + maturity_qty_sum = sum( + r["quantity"] + for r in rows + if r["product_id"] == 10 + and date(2024, 1, 31) <= r["date"] < date(2024, 2, 10) # maturity window + ) + decline_qty_sum = sum( + r["quantity"] + for r in rows + if r["product_id"] == 10 + and date(2024, 3, 1) <= r["date"] < date(2024, 3, 11) # well into decline + ) + assert maturity_qty_sum > decline_qty_sum > 0 + + def test_intro_demand_lower_than_maturity(self) -> None: + rows = _run_generator( + lifecycle=_enabled_lifecycle(), + product_lifecycle_data={ + 10: (date(2024, 1, 1), None), + 20: (date(2024, 1, 1), None), + }, + dates=_dates(date(2024, 1, 1), 60), + ) + intro_qty_sum = sum( + r["quantity"] + for r in rows + if r["product_id"] == 10 + and date(2024, 1, 1) <= r["date"] < date(2024, 1, 11) # intro window + ) + maturity_qty_sum = sum( + r["quantity"] + for r in rows + if r["product_id"] == 10 + and date(2024, 1, 31) <= r["date"] < date(2024, 2, 10) # maturity window + ) + assert intro_qty_sum < maturity_qty_sum + + +class TestLegacyRampSuppression: + def test_legacy_ramp_does_not_double_apply_when_lifecycle_enabled(self) -> None: + retail = RetailPatternConfig( + promotion_lift=1.0, + stockout_behavior="zero", + price_elasticity=0.0, + new_product_ramp_days=30, # legacy ramp would otherwise apply + promotion_probability=0.0, + stockout_probability=0.0, + ) + rows_with_lifecycle = _run_generator( + retail=retail, + lifecycle=_enabled_lifecycle(), + product_lifecycle_data={ + 10: (date(2024, 1, 1), None), + 20: (date(2024, 1, 1), None), + }, + dates=_dates(date(2024, 1, 1), 30), + ) + # Reference: same lifecycle on, but legacy ramp_days = 0 + retail_no_legacy = RetailPatternConfig( + promotion_lift=1.0, + stockout_behavior="zero", + price_elasticity=0.0, + new_product_ramp_days=0, + 
promotion_probability=0.0, + stockout_probability=0.0, + ) + rows_no_legacy = _run_generator( + retail=retail_no_legacy, + lifecycle=_enabled_lifecycle(), + product_lifecycle_data={ + 10: (date(2024, 1, 1), None), + 20: (date(2024, 1, 1), None), + }, + dates=_dates(date(2024, 1, 1), 30), + ) + # Legacy ramp must be suppressed when lifecycle is enabled — the + # two runs must produce identical output, proving no stacking. + assert rows_with_lifecycle == rows_no_legacy + + def test_legacy_ramp_still_fires_when_lifecycle_disabled(self) -> None: + retail = RetailPatternConfig( + promotion_lift=1.0, + stockout_behavior="zero", + price_elasticity=0.0, + new_product_ramp_days=30, + promotion_probability=0.0, + stockout_probability=0.0, + ) + # Lifecycle is None (pre-Phase-2) but product_lifecycle_data + # threads launch_date — legacy ramp should fire. + with_legacy = _run_generator( + retail=retail, + lifecycle=None, + product_lifecycle_data={ + 10: (date(2024, 1, 1), None), + 20: (date(2024, 1, 1), None), + }, + dates=_dates(date(2024, 1, 1), 30), + ) + # Reference: same retail but no launch date — legacy ramp is dormant. + no_launch = _run_generator( + retail=retail, + lifecycle=None, + product_lifecycle_data=None, + dates=_dates(date(2024, 1, 1), 30), + ) + # Early-range quantities for product 10 must be *lower* in + # ``with_legacy`` because the linear ramp attenuates demand. + early_with = sum( + r["quantity"] + for r in with_legacy + if r["product_id"] == 10 and r["date"] < date(2024, 1, 10) + ) + early_without = sum( + r["quantity"] + for r in no_launch + if r["product_id"] == 10 and r["date"] < date(2024, 1, 10) + ) + assert early_with < early_without + + +class TestLifecycleDataLookup: + def test_missing_product_id_defaults_to_no_lifecycle(self) -> None: + # product_lifecycle_data only has entry for product 10 — product + # 20 should evaluate the lifecycle multiplier with launch=None + # → 1.0 (no attenuation). 
+ rows = _run_generator( + lifecycle=_enabled_lifecycle(), + product_lifecycle_data={ + 10: (date(2024, 6, 1), None), # not launched yet in early Jan + }, + dates=_dates(date(2024, 1, 1), 5), + ) + product_10 = [r for r in rows if r["product_id"] == 10] + product_20 = [r for r in rows if r["product_id"] == 20] + # Product 10 hasn't launched → all zeros. + assert all(r["quantity"] == 0 for r in product_10) + # Product 20 has no lifecycle data → full demand. + assert sum(r["quantity"] for r in product_20) > 0 diff --git a/app/shared/seeder/tests/test_phase2_markdowns.py b/app/shared/seeder/tests/test_phase2_markdowns.py new file mode 100644 index 00000000..b1577d89 --- /dev/null +++ b/app/shared/seeder/tests/test_phase2_markdowns.py @@ -0,0 +1,420 @@ +"""Tests for Phase 2 MarkdownGenerator (clearance pricing). + +Regression invariant: with ``MarkdownConfig.enable=False`` (default) +``MarkdownGenerator.generate`` returns empty containers and consumes +zero rng state. Enabled paths are deterministic — no rng draws even +with the generator on — so reproducibility falls out automatically. 
+""" + +from __future__ import annotations + +import random +from datetime import date, timedelta +from decimal import Decimal +from typing import Any + +import pytest + +from app.shared.seeder.config import LifecycleConfig, MarkdownConfig +from app.shared.seeder.generators.lifecycle import LifecycleGenerator +from app.shared.seeder.generators.markdowns import MarkdownGenerator + +# ---------------------------------------------------------------------- # +# Fixtures +# ---------------------------------------------------------------------- # + + +def _dates(start: date, days: int) -> list[date]: + return [start + timedelta(days=i) for i in range(days)] + + +def _product_specs() -> list[dict[str, Any]]: + """Four products: ids 1-4, varying launch dates for decline detection.""" + return [ + { + "product_id": 1, + "base_price": Decimal("10.00"), + "launch_date": date(2023, 1, 1), # launched a year before seeded range + "discontinue_date": None, + }, + { + "product_id": 2, + "base_price": Decimal("20.00"), + "launch_date": date(2024, 1, 1), # launches with the range + "discontinue_date": None, + }, + { + "product_id": 3, + "base_price": Decimal("5.00"), + "launch_date": None, # no lifecycle data + "discontinue_date": None, + }, + { + "product_id": 4, + "base_price": Decimal("15.00"), + "launch_date": date(2024, 6, 1), # launches mid-range + "discontinue_date": None, + }, + ] + + +def _lifecycle_enabled() -> LifecycleGenerator: + """LifecycleGenerator tuned so product #1 hits decline early in 2024.""" + cfg = LifecycleConfig( + enable=True, + intro_ramp_days=30, + growth_ramp_days=60, + maturity_steady_days=180, + decline_decay_days=90, + ) + return LifecycleGenerator(cfg) + + +# ---------------------------------------------------------------------- # +# Disabled / regression invariant +# ---------------------------------------------------------------------- # + + +class TestMarkdownGeneratorDisabled: + def test_enabled_false_when_config_none(self) -> None: + assert 
MarkdownGenerator(random.Random(0), None).enabled is False + + def test_enabled_false_when_config_default(self) -> None: + assert MarkdownGenerator(random.Random(0), MarkdownConfig()).enabled is False + + def test_empty_output_when_config_none(self) -> None: + rng = random.Random(42) + baseline_state = rng.getstate() + promos, prices, md_dates = MarkdownGenerator(rng, None).generate( + product_specs=_product_specs(), + store_ids=[1, 2, 3], + stockout_dates={(1, 1): {date(2024, 3, 15)}}, + dates=_dates(date(2024, 1, 1), 90), + lifecycle=_lifecycle_enabled(), + ) + assert promos == [] and prices == [] and md_dates == {} + assert rng.getstate() == baseline_state # no rng consumption + + def test_empty_output_when_disabled_config(self) -> None: + rng = random.Random(42) + baseline_state = rng.getstate() + promos, prices, md_dates = MarkdownGenerator(rng, MarkdownConfig()).generate( + product_specs=_product_specs(), + store_ids=[1, 2, 3], + stockout_dates={(1, 1): {date(2024, 3, 15)}}, + dates=_dates(date(2024, 1, 1), 90), + lifecycle=_lifecycle_enabled(), + ) + assert promos == [] and prices == [] and md_dates == {} + assert rng.getstate() == baseline_state + + +# ---------------------------------------------------------------------- # +# lifecycle_decline trigger +# ---------------------------------------------------------------------- # + + +class TestLifecycleDeclineTrigger: + def _cfg(self, **overrides: Any) -> MarkdownConfig: + kwargs: dict[str, Any] = { + "enable": True, + "trigger": "lifecycle_decline", + "markdown_depth_pct": 0.30, + "markdown_duration_days": 14, + } + kwargs.update(overrides) + return MarkdownConfig(**kwargs) + + def test_fires_chainwide_for_declining_products(self) -> None: + # Product #1 launched 2023-01-01; with default ramp+steady=270 days + # it enters decline ~2023-09-28, so it is *already* in decline on + # 2024-01-01 (first date of the seeded range). 
+ promos, prices, md_dates = MarkdownGenerator(random.Random(0), self._cfg()).generate( + product_specs=_product_specs(), + store_ids=[10, 20, 30], + stockout_dates={}, + dates=_dates(date(2024, 1, 1), 365), + lifecycle=_lifecycle_enabled(), + ) + + # Product #1 should produce one chain-wide markdown. + product_1_promos = [p for p in promos if p["product_id"] == 1] + assert len(product_1_promos) == 1 + p = product_1_promos[0] + assert p["store_id"] is None + assert p["kind"] == "markdown" + assert p["bundle_member_product_ids"] is None + assert p["discount_pct"] == Decimal("0.3000") + assert p["discount_amount"] is None + # First decline date == seeded range start because product launched in 2023. + assert p["start_date"] == date(2024, 1, 1) + assert p["end_date"] == date(2024, 1, 14) + + # Companion price_history drop. + product_1_prices = [r for r in prices if r["product_id"] == 1] + assert len(product_1_prices) == 1 + ph = product_1_prices[0] + assert ph["store_id"] is None + # base_price 10.00 * (1 - 0.3) = 7.00. + assert ph["price"] == Decimal("7.00") + assert ph["valid_from"] == date(2024, 1, 1) + assert ph["valid_to"] == date(2024, 1, 14) + + # markdown_dates populated per store for product #1. + for sid in (10, 20, 30): + assert (sid, 1) in md_dates + assert md_dates[(sid, 1)] == set(_dates(date(2024, 1, 1), 14)) + + def test_skips_products_without_lifecycle_data(self) -> None: + promos, prices, md_dates = MarkdownGenerator(random.Random(0), self._cfg()).generate( + product_specs=_product_specs(), + store_ids=[10], + stockout_dates={}, + dates=_dates(date(2024, 1, 1), 365), + lifecycle=_lifecycle_enabled(), + ) + # Product #3 has launch_date=None — never fires. 
+ assert all(p["product_id"] != 3 for p in promos) + assert all(r["product_id"] != 3 for r in prices) + assert all(key[1] != 3 for key in md_dates) + + def test_no_output_when_lifecycle_disabled(self) -> None: + promos, prices, md_dates = MarkdownGenerator(random.Random(0), self._cfg()).generate( + product_specs=_product_specs(), + store_ids=[10], + stockout_dates={}, + dates=_dates(date(2024, 1, 1), 365), + lifecycle=LifecycleGenerator(None), # disabled + ) + assert promos == [] and prices == [] and md_dates == {} + + def test_no_output_when_no_product_in_decline(self) -> None: + # Product #4 launches mid-2024 → with default ramp+steady=270d it + # only enters decline in 2025. Within a 60-day window it never fires. + promos, prices, _ = MarkdownGenerator(random.Random(0), self._cfg()).generate( + product_specs=[_product_specs()[3]], # only product #4 + store_ids=[10], + stockout_dates={}, + dates=_dates(date(2024, 6, 1), 60), + lifecycle=_lifecycle_enabled(), + ) + assert promos == [] and prices == [] + + def test_md_end_clamped_to_seeded_range(self) -> None: + # Use only 5 days of seeded range so 14-day markdown gets clipped. 
+ promos, _, _ = MarkdownGenerator( + random.Random(0), self._cfg(markdown_duration_days=14) + ).generate( + product_specs=[_product_specs()[0]], # product #1, already in decline + store_ids=[10], + stockout_dates={}, + dates=_dates(date(2024, 1, 1), 5), + lifecycle=_lifecycle_enabled(), + ) + assert len(promos) == 1 + assert promos[0]["start_date"] == date(2024, 1, 1) + assert promos[0]["end_date"] == date(2024, 1, 5) # clamped + + +# ---------------------------------------------------------------------- # +# stockout_risk trigger +# ---------------------------------------------------------------------- # + + +class TestStockoutRiskTrigger: + def _cfg(self, **overrides: Any) -> MarkdownConfig: + kwargs: dict[str, Any] = { + "enable": True, + "trigger": "stockout_risk", + "markdown_depth_pct": 0.25, + "markdown_duration_days": 7, + } + kwargs.update(overrides) + return MarkdownConfig(**kwargs) + + def test_markdown_ends_day_before_stockout(self) -> None: + promos, prices, md_dates = MarkdownGenerator(random.Random(0), self._cfg()).generate( + product_specs=_product_specs(), + store_ids=[10, 20], + stockout_dates={(10, 1): {date(2024, 3, 15)}}, + dates=_dates(date(2024, 1, 1), 90), + lifecycle=None, # not used for stockout_risk + ) + assert len(promos) == 1 + p = promos[0] + assert p["store_id"] == 10 + assert p["product_id"] == 1 + assert p["kind"] == "markdown" + assert p["bundle_member_product_ids"] is None + assert p["discount_pct"] == Decimal("0.2500") + assert p["end_date"] == date(2024, 3, 14) + assert p["start_date"] == date(2024, 3, 8) # 7-day window + + # 10.00 * (1 - 0.25) = 7.50. + assert prices[0]["price"] == Decimal("7.50") + assert md_dates[(10, 1)] == set(_dates(date(2024, 3, 8), 7)) + + def test_dedupe_overlapping_stockouts(self) -> None: + # Three stockouts within a 7-day window: only the first should fire. 
+ promos, _, _ = MarkdownGenerator(random.Random(0), self._cfg()).generate( + product_specs=_product_specs(), + store_ids=[10], + stockout_dates={ + (10, 1): { + date(2024, 3, 15), + date(2024, 3, 17), + date(2024, 3, 18), + } + }, + dates=_dates(date(2024, 1, 1), 90), + ) + assert len(promos) == 1 + assert promos[0]["end_date"] == date(2024, 3, 14) + + def test_clamps_to_first_date_when_stockout_near_start(self) -> None: + promos, _, _ = MarkdownGenerator( + random.Random(0), self._cfg(markdown_duration_days=14) + ).generate( + product_specs=_product_specs(), + store_ids=[10], + stockout_dates={(10, 1): {date(2024, 1, 5)}}, + dates=_dates(date(2024, 1, 1), 30), + ) + assert len(promos) == 1 + assert promos[0]["start_date"] == date(2024, 1, 1) # clamped + assert promos[0]["end_date"] == date(2024, 1, 4) + + def test_skips_stockout_on_first_date(self) -> None: + # Stockout on day 1 leaves no room for a markdown window. + promos, prices, md_dates = MarkdownGenerator(random.Random(0), self._cfg()).generate( + product_specs=_product_specs(), + store_ids=[10], + stockout_dates={(10, 1): {date(2024, 1, 1)}}, + dates=_dates(date(2024, 1, 1), 30), + ) + assert promos == [] and prices == [] and md_dates == {} + + def test_per_store_markdowns(self) -> None: + # Two stores stocked out for product 1 → two distinct markdowns. 
+ promos, _, _ = MarkdownGenerator(random.Random(0), self._cfg()).generate( + product_specs=_product_specs(), + store_ids=[10, 20], + stockout_dates={ + (10, 1): {date(2024, 3, 15)}, + (20, 1): {date(2024, 3, 20)}, + }, + dates=_dates(date(2024, 1, 1), 90), + ) + assert len(promos) == 2 + store_ids = {p["store_id"] for p in promos} + assert store_ids == {10, 20} + + def test_unknown_product_silently_skipped(self) -> None: + promos, _, _ = MarkdownGenerator(random.Random(0), self._cfg()).generate( + product_specs=_product_specs(), # ids 1-4 + store_ids=[10], + stockout_dates={(10, 99): {date(2024, 3, 15)}}, + dates=_dates(date(2024, 1, 1), 90), + ) + assert promos == [] + + def test_deterministic_output_order(self) -> None: + cfg = self._cfg() + specs = _product_specs() + dates_ = _dates(date(2024, 1, 1), 90) + # Build stockouts in two different dict orders. + stockouts_a = { + (20, 2): {date(2024, 3, 10)}, + (10, 1): {date(2024, 3, 15)}, + } + stockouts_b = { + (10, 1): {date(2024, 3, 15)}, + (20, 2): {date(2024, 3, 10)}, + } + promos_a, _, _ = MarkdownGenerator(random.Random(0), cfg).generate( + product_specs=specs, + store_ids=[10, 20], + stockout_dates=stockouts_a, + dates=dates_, + ) + promos_b, _, _ = MarkdownGenerator(random.Random(0), cfg).generate( + product_specs=specs, + store_ids=[10, 20], + stockout_dates=stockouts_b, + dates=dates_, + ) + assert promos_a == promos_b + + +# ---------------------------------------------------------------------- # +# Validation +# ---------------------------------------------------------------------- # + + +class TestMarkdownGeneratorValidation: + def test_age_days_trigger_raises_not_implemented(self) -> None: + gen = MarkdownGenerator( + random.Random(0), + MarkdownConfig(enable=True, trigger="age_days"), + ) + with pytest.raises(NotImplementedError, match="#94"): + gen.generate( + product_specs=_product_specs(), + store_ids=[10], + stockout_dates={}, + dates=_dates(date(2024, 1, 1), 30), + ) + + def 
test_depth_pct_below_zero_raises(self) -> None: + gen = MarkdownGenerator( + random.Random(0), + MarkdownConfig(enable=True, markdown_depth_pct=-0.1), + ) + with pytest.raises(ValueError, match="markdown_depth_pct"): + gen.generate( + product_specs=_product_specs(), + store_ids=[10], + stockout_dates={}, + dates=_dates(date(2024, 1, 1), 30), + ) + + def test_depth_pct_above_one_raises(self) -> None: + gen = MarkdownGenerator( + random.Random(0), + MarkdownConfig(enable=True, markdown_depth_pct=1.5), + ) + with pytest.raises(ValueError, match="markdown_depth_pct"): + gen.generate( + product_specs=_product_specs(), + store_ids=[10], + stockout_dates={}, + dates=_dates(date(2024, 1, 1), 30), + ) + + def test_zero_duration_raises(self) -> None: + gen = MarkdownGenerator( + random.Random(0), + MarkdownConfig(enable=True, markdown_duration_days=0), + ) + with pytest.raises(ValueError, match="markdown_duration_days"): + gen.generate( + product_specs=_product_specs(), + store_ids=[10], + stockout_dates={}, + dates=_dates(date(2024, 1, 1), 30), + ) + + def test_no_rng_consumption_enabled_path(self) -> None: + """Enabled generator is deterministic — rng should be untouched.""" + rng = random.Random(42) + baseline_state = rng.getstate() + MarkdownGenerator( + rng, + MarkdownConfig(enable=True, trigger="stockout_risk"), + ).generate( + product_specs=_product_specs(), + store_ids=[10], + stockout_dates={(10, 1): {date(2024, 3, 15)}}, + dates=_dates(date(2024, 1, 1), 90), + ) + assert rng.getstate() == baseline_state diff --git a/app/shared/seeder/tests/test_phase2_replenishment.py b/app/shared/seeder/tests/test_phase2_replenishment.py new file mode 100644 index 00000000..3572ed4b --- /dev/null +++ b/app/shared/seeder/tests/test_phase2_replenishment.py @@ -0,0 +1,232 @@ +"""Tests for Phase 2 ReplenishmentGenerator. + +Regression invariant: with ``LeadTimeConfig.enable=False`` (default) +``ReplenishmentGenerator.generate`` returns ``[]`` and consumes zero +rng state. 
+""" + +from __future__ import annotations + +import random +from datetime import date, timedelta +from typing import Any + +import pytest + +from app.shared.seeder.config import LeadTimeConfig +from app.shared.seeder.generators.replenishment import ReplenishmentGenerator + + +def _dates(start: date, days: int) -> list[date]: + return [start + timedelta(days=i) for i in range(days)] + + +# ---------------------------------------------------------------------- # +# Disabled / regression invariant +# ---------------------------------------------------------------------- # + + +class TestReplenishmentGeneratorDisabled: + def test_enabled_false_when_config_none(self) -> None: + assert ReplenishmentGenerator(random.Random(0), None).enabled is False + + def test_enabled_false_when_config_default(self) -> None: + assert ReplenishmentGenerator(random.Random(0), LeadTimeConfig()).enabled is False + + def test_empty_output_when_config_none(self) -> None: + rng = random.Random(42) + baseline_state = rng.getstate() + out = ReplenishmentGenerator(rng, None).generate( + [1, 2], [10, 20], _dates(date(2024, 1, 1), 90) + ) + assert out == [] + assert rng.getstate() == baseline_state + + def test_empty_output_when_disabled_config(self) -> None: + rng = random.Random(42) + baseline_state = rng.getstate() + out = ReplenishmentGenerator(rng, LeadTimeConfig()).generate( + [1, 2], [10, 20], _dates(date(2024, 1, 1), 90) + ) + assert out == [] + assert rng.getstate() == baseline_state + + +# ---------------------------------------------------------------------- # +# Enabled-path correctness +# ---------------------------------------------------------------------- # + + +class TestReplenishmentGeneratorEnabled: + def _cfg(self, **overrides: Any) -> LeadTimeConfig: + kwargs: dict[str, Any] = { + "enable": True, + "mean_lead_time_days": 7, + "lead_time_sigma_days": 1.5, + "safety_stock_days": 3, + "order_frequency_days": 14, + "fill_rate_mean": 0.97, + "fill_rate_sigma": 0.05, + } + 
kwargs.update(overrides) + return LeadTimeConfig(**kwargs) + + def test_record_shape_and_invariants(self) -> None: + out = ReplenishmentGenerator(random.Random(0), self._cfg()).generate( + [1, 2], [10, 20], _dates(date(2024, 1, 1), 90), base_demand=100 + ) + assert len(out) > 0 + for r in out: + assert set(r.keys()) == { + "date", + "store_id", + "product_id", + "lead_time_days", + "ordered_qty", + "received_qty", + } + assert isinstance(r["date"], date) + assert r["store_id"] in (1, 2) + assert r["product_id"] in (10, 20) + assert r["lead_time_days"] >= 0 + assert r["ordered_qty"] >= 0 + assert 0 <= r["received_qty"] <= r["ordered_qty"] + + def test_ordered_qty_formula(self) -> None: + # base_demand=100, order_freq=14, safety=3 → ordered = 100*17 = 1700. + out = ReplenishmentGenerator(random.Random(0), self._cfg()).generate( + [1], [10], _dates(date(2024, 1, 1), 90), base_demand=100 + ) + assert all(r["ordered_qty"] == 1700 for r in out) + + def test_dates_within_seeded_range(self) -> None: + dates_ = _dates(date(2024, 1, 1), 365) + out = ReplenishmentGenerator(random.Random(0), self._cfg()).generate( + [1], [10], dates_, base_demand=100 + ) + for r in out: + assert dates_[0] <= r["date"] <= dates_[-1] + + def test_reproducible_with_same_seed(self) -> None: + cfg = self._cfg() + a = ReplenishmentGenerator(random.Random(42), cfg).generate( + [1, 2], [10, 20], _dates(date(2024, 1, 1), 90) + ) + b = ReplenishmentGenerator(random.Random(42), cfg).generate( + [1, 2], [10, 20], _dates(date(2024, 1, 1), 90) + ) + assert a == b + + def test_input_order_does_not_affect_output(self) -> None: + cfg = self._cfg() + a = ReplenishmentGenerator(random.Random(42), cfg).generate( + [1, 2], [10, 20], _dates(date(2024, 1, 1), 90) + ) + b = ReplenishmentGenerator(random.Random(42), cfg).generate( + [2, 1], [20, 10], _dates(date(2024, 1, 1), 90) + ) + assert a == b + + def test_empty_dates_returns_empty(self) -> None: + out = ReplenishmentGenerator(random.Random(0), 
self._cfg()).generate([1], [10], []) + assert out == [] + + def test_high_fill_rate_yields_full_orders(self) -> None: + cfg = self._cfg(fill_rate_mean=1.0, fill_rate_sigma=0.0) + out = ReplenishmentGenerator(random.Random(0), cfg).generate( + [1], [10], _dates(date(2024, 1, 1), 90) + ) + assert len(out) > 0 + for r in out: + assert r["received_qty"] == r["ordered_qty"] + + def test_zero_fill_rate_yields_zero_received(self) -> None: + cfg = self._cfg(fill_rate_mean=0.0, fill_rate_sigma=0.0) + out = ReplenishmentGenerator(random.Random(0), cfg).generate( + [1], [10], _dates(date(2024, 1, 1), 90) + ) + assert len(out) > 0 + for r in out: + assert r["received_qty"] == 0 + + def test_zero_lead_time_gives_immediate_receipt(self) -> None: + cfg = self._cfg(mean_lead_time_days=0, lead_time_sigma_days=0.0) + dates_ = _dates(date(2024, 1, 1), 84) + out = ReplenishmentGenerator(random.Random(0), cfg).generate([1], [10], dates_) + # 6 POs placed at days 0, 14, 28, 42, 56, 70 (day 84 > end day 83). + assert len(out) == 6 + for r in out: + assert r["lead_time_days"] == 0 + day_offset = (r["date"] - dates_[0]).days + assert day_offset % cfg.order_frequency_days == 0 + + def test_output_sorted_by_store_product_date(self) -> None: + cfg = self._cfg(mean_lead_time_days=0, lead_time_sigma_days=0.0) + out = ReplenishmentGenerator(random.Random(0), cfg).generate( + [2, 1], [20, 10], _dates(date(2024, 1, 1), 90) + ) + keys = [(r["store_id"], r["product_id"], r["date"]) for r in out] + assert keys == sorted(keys) + + +# ---------------------------------------------------------------------- # +# Validation +# ---------------------------------------------------------------------- # + + +class TestReplenishmentGeneratorValidation: + def test_negative_mean_lead_time_raises(self) -> None: + gen = ReplenishmentGenerator( + random.Random(0), + LeadTimeConfig(enable=True, mean_lead_time_days=-1), + ) + with pytest.raises(ValueError, match="mean_lead_time_days"): + gen.generate([1], [1], 
_dates(date(2024, 1, 1), 30)) + + def test_negative_lead_time_sigma_raises(self) -> None: + gen = ReplenishmentGenerator( + random.Random(0), + LeadTimeConfig(enable=True, lead_time_sigma_days=-0.5), + ) + with pytest.raises(ValueError, match="lead_time_sigma_days"): + gen.generate([1], [1], _dates(date(2024, 1, 1), 30)) + + def test_zero_order_frequency_raises(self) -> None: + gen = ReplenishmentGenerator( + random.Random(0), + LeadTimeConfig(enable=True, order_frequency_days=0), + ) + with pytest.raises(ValueError, match="order_frequency_days"): + gen.generate([1], [1], _dates(date(2024, 1, 1), 30)) + + def test_fill_rate_mean_above_one_raises(self) -> None: + gen = ReplenishmentGenerator( + random.Random(0), + LeadTimeConfig(enable=True, fill_rate_mean=1.5), + ) + with pytest.raises(ValueError, match="fill_rate_mean"): + gen.generate([1], [1], _dates(date(2024, 1, 1), 30)) + + def test_negative_fill_rate_sigma_raises(self) -> None: + gen = ReplenishmentGenerator( + random.Random(0), + LeadTimeConfig(enable=True, fill_rate_sigma=-0.1), + ) + with pytest.raises(ValueError, match="fill_rate_sigma"): + gen.generate([1], [1], _dates(date(2024, 1, 1), 30)) + + def test_negative_safety_stock_raises(self) -> None: + gen = ReplenishmentGenerator( + random.Random(0), + LeadTimeConfig(enable=True, safety_stock_days=-1), + ) + with pytest.raises(ValueError, match="safety_stock_days"): + gen.generate([1], [1], _dates(date(2024, 1, 1), 30)) + + def test_negative_base_demand_raises(self) -> None: + gen = ReplenishmentGenerator( + random.Random(0), + LeadTimeConfig(enable=True), + ) + with pytest.raises(ValueError, match="base_demand"): + gen.generate([1], [1], _dates(date(2024, 1, 1), 30), base_demand=-10) diff --git a/app/shared/seeder/tests/test_returns.py b/app/shared/seeder/tests/test_returns.py new file mode 100644 index 00000000..6421b01a --- /dev/null +++ b/app/shared/seeder/tests/test_returns.py @@ -0,0 +1,102 @@ +"""Tests for ReturnsGenerator (Phase 1).""" + +# mypy: 
disable-error-code="union-attr,arg-type,operator,return-value" + +import random +from datetime import date, timedelta +from decimal import Decimal + +from app.shared.seeder.config import ReturnsConfig +from app.shared.seeder.generators.returns import ReturnsGenerator + + +def _sales_records(n: int, start: date = date(2024, 1, 1)) -> list[dict[str, object]]: + """Build n synthetic sales rows in the shape SalesDailyGenerator emits.""" + return [ + { + "date": start + timedelta(days=i), + "store_id": 1, + "product_id": 100, + "quantity": 10, + "unit_price": Decimal("9.99"), + "total_amount": Decimal("99.90"), + } + for i in range(n) + ] + + +class TestReturnsGeneratorDisabled: + def test_disabled_emits_nothing(self): + gen = ReturnsGenerator(random.Random(42), ReturnsConfig(enable=False)) + assert gen.generate(_sales_records(50), date(2024, 1, 31)) == [] + + +class TestReturnsGeneratorEnabled: + def test_returns_fire_at_configured_rate(self): + # Probability 1.0 means every sale generates a return. + cfg = ReturnsConfig(enable=True, return_probability=1.0) + gen = ReturnsGenerator(random.Random(0), cfg) + sales = _sales_records(200) + returns = gen.generate(sales, date(2024, 1, 31)) + assert len(returns) == 200 + + def test_probability_zero_no_returns(self): + cfg = ReturnsConfig(enable=True, return_probability=0.0) + gen = ReturnsGenerator(random.Random(0), cfg) + assert gen.generate(_sales_records(50), date(2024, 1, 31)) == [] + + def test_return_quantity_is_positive_and_capped(self): + # quantity_fraction=2.0 should be clamped to original quantity. 
+ cfg = ReturnsConfig(enable=True, return_probability=1.0, return_quantity_fraction=2.0) + gen = ReturnsGenerator(random.Random(0), cfg) + sales = _sales_records(20) + returns = gen.generate(sales, date(2024, 1, 31)) + for r in returns: + assert 1 <= r["return_quantity"] <= 10 # capped at sale quantity + + def test_return_date_clamped_to_end_date(self): + cfg = ReturnsConfig( + enable=True, + return_probability=1.0, + return_lag_days_min=30, + return_lag_days_max=30, + ) + gen = ReturnsGenerator(random.Random(0), cfg) + sales = _sales_records(5, start=date(2024, 1, 20)) + end = date(2024, 1, 31) + returns = gen.generate(sales, end) + for r in returns: + assert r["date"] <= end + + def test_reasons_drawn_from_distribution(self): + cfg = ReturnsConfig( + enable=True, + return_probability=1.0, + return_reason_distribution={"defective": 1.0}, + ) + gen = ReturnsGenerator(random.Random(0), cfg) + sales = _sales_records(10) + returns = gen.generate(sales, date(2024, 1, 31)) + assert all(r["return_reason"] == "defective" for r in returns) + + def test_reproducible(self): + cfg = ReturnsConfig(enable=True, return_probability=0.5) + sales = _sales_records(100) + a = ReturnsGenerator(random.Random(7), cfg).generate(sales, date(2024, 12, 31)) + b = ReturnsGenerator(random.Random(7), cfg).generate(sales, date(2024, 12, 31)) + assert a == b + + def test_zero_quantity_sales_skipped(self): + cfg = ReturnsConfig(enable=True, return_probability=1.0) + gen = ReturnsGenerator(random.Random(0), cfg) + sales = [ + { + "date": date(2024, 1, 1), + "store_id": 1, + "product_id": 2, + "quantity": 0, + "unit_price": Decimal("9.99"), + "total_amount": Decimal("0.00"), + } + ] + assert gen.generate(sales, date(2024, 1, 31)) == [] diff --git a/docs/DATA-SEEDER.md b/docs/DATA-SEEDER.md index 8f94c8f3..63ebe175 100644 --- a/docs/DATA-SEEDER.md +++ b/docs/DATA-SEEDER.md @@ -196,6 +196,206 @@ uv run python scripts/seed_random.py --full-new --config examples/seed/config_cu - **Price Elasticity**: 
Demand adjustment based on price changes - **New Product Ramps**: Gradual demand increase for new launches +## Phase 1 Realism Extensions + +Phase 1 adds opt-in realism: exogenous signals, multi-seasonality, trend changepoints, +returns volume, and stockout substitution. Each extension is gated behind its own flag +on `GenerateParams` (or its dataclass on `SeederConfig`). **Existing scenarios with no +flags set produce byte-identical seeded data to pre-Phase-1** — the regression invariant +is enforced by `app/shared/seeder/tests/test_phase1_regression.py`. + +### Exogenous Signals + +Persisted in the `exogenous_signal` table. Three signals available: + +| Signal | Scope | Shape | +|--------|-------|-------| +| `weather_temp_c` | per (store, date) | sinusoidal climatology + Gaussian noise | +| `macro_index` | per date (global) | random walk from `macro_initial_value` | +| `event_flag` | per `event_dates` entry | binary 1.0 marker on configured dates | + +Toggle via `GenerateParams.enable_exogenous=true` (turns on weather + macro). To also +drive demand from weather, pass `weather_temperature_sensitivity` (e.g. `0.02` = +2% +demand per °C above the climatology mean). + +Read back: + +```bash +curl "http://localhost:8123/seeder/exogenous?signal_name=weather_temp_c&start_date=2024-01-01&end_date=2024-01-31" +``` + +### Multi-Seasonality + +Yearly sine wave on top of weekly + monthly seasonality: + +```json +{"yearly_seasonality_amplitude": 0.15} +``` + +Amplitude is a fraction of base demand (0–1). 0 or unset = disabled. + +### Changepoints + +COVID-style demand impulses with exponential decay: + +```json +{ + "changepoints": [ + {"date": "2024-03-15", "demand_multiplier": 2.0, "decay_days": 60} + ] +} +``` + +`decay_days=0` means a pure impulse on the changepoint date. + +### Returns + +Synthetic returns volume in the `sales_returns` table. 
A configurable fraction of +sales rows generates a delayed return: + +```json +{"enable_returns": true} +``` + +Tune via `ReturnsConfig` on `SeederConfig` (default ~2% of sales, lag 1–14 days, with +reasons drawn from `defective`/`wrong_size`/`not_as_described`/`changed_mind`/ +`damaged_in_transit`). + +### Substitution on Stockout + +When a member of a substitute group is stocked out, the surviving members pick up a +share of demand: + +```json +{ + "enable_substitution": true, + "substitute_groups": [[1, 2, 3]], + "substitution_lift_on_stockout": 0.5 +} +``` + +`product_id` values must already exist in the dataset. The lift is split across in-stock +group-mates. + +### Phase 1 API surface + +- `POST /seeder/generate` accepts the Phase 1 fields above; defaults keep Phase 1 off. +- `GET /seeder/exogenous?signal_name=&start_date=&end_date=&store_id=` returns signal rows. +- `GET /seeder/status` adds `exogenous_signals` and `sales_returns` counts. + +## Phase 2 Retail Depth Extensions + +Phase 2 adds five orthogonal toggles for richer retail realism: multi-channel +sales, product lifecycles, bundle/BOGO promotions, clearance markdowns, and +replenishment lead times. Like Phase 1, every toggle defaults off — the +disabled path is byte-identical with pre-Phase-2 output for every existing +scenario. + +### Multi-Channel Sales + +Splits each emitted `sales_daily` row across channels drawn from a configurable +mix. + +```json +{ + "enable_multichannel": true, + "channel_mix": {"in_store": 0.6, "online": 0.3, "click_collect": 0.1}, + "online_promo_uplift": 1.2, + "online_substitution_to_instore": 0.1 +} +``` + +- Allow-list for channel keys: `in_store`, `online`, `click_collect`, `wholesale`. +- Weights must be non-negative; at least one must be positive. +- `online_promo_uplift` multiplies quantity for online rows on promo dates. +- `online_substitution_to_instore` shifts the effective mix toward `online` + during promos (0.0 = independent; 1.0 = pure substitution). 
+ +### Product Lifecycles + +Assigns each product a `launch_date` (and optionally a `discontinue_date`) and +shapes demand over intro → growth → maturity → decline → discontinued. + +```json +{ + "enable_lifecycle": true, + "lifecycle_discontinue_probability": 0.05 +} +``` + +When enabled: +- `Product.launch_date` / `Product.discontinue_date` are populated. +- `SalesDailyGenerator` applies the lifecycle multiplier per `(product, date)`. +- The legacy `new_product_ramp_days` linear ramp is suppressed to avoid + double-attenuation. + +### Bundle / BOGO Promotions + +Converts a fraction of `PromotionGenerator`'s output into `kind='bundle'` or +`kind='bogo'` rows with explicit member product IDs. + +```json +{ + "enable_bundles": true, + "bundle_probability": 0.2 +} +``` + +- `bundle_probability` is the per-promotion conversion rate. +- Each converted row carries a `bundle_member_product_ids` list (enforced by + the `ck_promotion_bundle_members_consistency` CHECK). + +### Markdowns (Clearance) + +Emits `Promotion(kind='markdown')` rows + companion `PriceHistory` drops on +two triggers: + +```json +{ + "enable_markdowns": true, + "markdown_trigger": "lifecycle_decline" +} +``` + +- `lifecycle_decline` (default): fires chain-wide on the first date a product + enters the decline stage. Requires `enable_lifecycle=true` to produce rows. +- `stockout_risk`: fires per-`(store, product)` ending the day before each + observed stockout, with the configured `markdown_duration_days` window. +- `age_days` is **deferred** — see issue [#94](https://github.com/w7-mgfcode/ForecastLabAI/issues/94). + The generator raises `NotImplementedError` for that mode. + +### Replenishment Lead Time + +Emits `replenishment_event` rows that mark receipts of inbound stock per +`(store, product)` PO chain. + +```json +{ + "enable_lead_time": true, + "mean_lead_time_days": 7 +} +``` + +- One PO every `order_frequency_days` (default 14) per `(store, product)`. 
+- Lead time sampled Gaussian; fill rate sampled Gaussian and clamped to [0, 1]. +- Receipts past the seeded `end_date` are dropped to keep the FK to + `calendar` valid. + +### Phase 2 API surface + +- `POST /seeder/generate` accepts all five Phase 2 enable flags plus + `channel_mix`, `online_promo_uplift`, `online_substitution_to_instore`, + `lifecycle_discontinue_probability`, `bundle_probability`, `markdown_trigger`, + and `mean_lead_time_days`. Defaults keep Phase 2 off. +- `GET /seeder/channels` returns the sorted allow-list for + `sales_daily.channel` and `ChannelConfig.channel_mix` keys — + `["click_collect", "in_store", "online", "wholesale"]`. +- `GET /dimensions/products/{id}/lifecycle-curve` returns the reference + demand-multiplier curve for a product using the default `LifecycleConfig` + ramp parameters (respects the product's own `launch_date` / + `discontinue_date`). Useful for UI charts. +- `GET /seeder/status` adds a `replenishment_events` count. + ## Data Integrity The seeder enforces data integrity: @@ -204,6 +404,15 @@ The seeder enforces data integrity: 2. **Non-Negative Values**: Quantities and prices are always non-negative 3. **Date Coverage**: Calendar table covers entire date range 4. **Uniqueness**: Store codes and product SKUs are unique +5. **Phase 1 — Returns positive**: `sales_returns.return_quantity` is always ≥ 1 +6. **Phase 1 — Exogenous consistency**: every `exogenous_signal` row satisfies + `is_global = true ⇔ store_id IS NULL` (enforced by a CHECK constraint and verified + by `verify_data_integrity`) +7. **Phase 2 — Bundle members non-NULL**: every `promotion` row with + `kind in (bundle, bogo)` carries a non-NULL `bundle_member_product_ids` +8. **Phase 2 — Lifecycle ordering**: `discontinue_date >= launch_date` when both are set +9. 
**Phase 2 — Replenishment fill**: `received_qty <= ordered_qty` on every + `replenishment_event` row Verify with: ```bash diff --git a/docs/_base/API_CONTRACTS.md b/docs/_base/API_CONTRACTS.md new file mode 100644 index 00000000..460780a9 --- /dev/null +++ b/docs/_base/API_CONTRACTS.md @@ -0,0 +1,73 @@ +# ForecastLabAI API Contracts +> Source: heuristic discovery from `app/main.py` router wiring and per-feature `routes.py`. Full request/response schemas live in the Pydantic models at `app/features/<slice>/schemas.py`. Swagger UI at `http://localhost:8123/docs` is the authoritative live contract. + +## HTTP Endpoints + +All endpoints serve JSON; error responses use `application/problem+json` (RFC 7807) via `app/core/problem_details.py`. Schemas are Pydantic v2 (`app/features/<slice>/schemas.py`). + +| Slice | Method | Path | Purpose | +|-------|--------|------|---------| +| health | GET | `/health` | Liveness probe — `{"status":"ok"}` | +| ingest | POST | `/ingest/sales-daily` | Batch upsert with natural-key resolution, idempotent `ON CONFLICT DO UPDATE` | +| dimensions | GET | `/dimensions/stores` | List stores (1-indexed pagination, region/store_type filter, case-insensitive search) | +| dimensions | GET | `/dimensions/stores/{store_id}` | Get store by ID | +| dimensions | GET | `/dimensions/products` | List products (category/brand filter, sku/name search) | +| dimensions | GET | `/dimensions/products/{product_id}` | Get product by ID | +| analytics | GET | `/analytics/kpis` | Aggregated KPIs (revenue, units, transactions, avg unit price, avg basket) | +| analytics | GET | `/analytics/drilldowns` | Group-by dimension: store / product / category / region / date | +| featuresets | POST | `/featuresets/compute` | Compute time-safe features (lag/rolling/calendar, leakage-prevented) | +| featuresets | POST | `/featuresets/preview` | Preview features with sample rows | +| forecasting | POST | `/forecasting/train` | Train a model (naive / seasonal_naive / moving_average / lightgbm) | 
+| forecasting | POST | `/forecasting/predict` | Generate horizon predictions from a trained model | +| backtesting | POST | `/backtesting/run` | Time-series CV (rolling/expanding splits, MAE/sMAPE/WAPE/bias/stability) | +| registry | POST | `/registry/runs` | Create model run (pending) | +| registry | GET | `/registry/runs` | List with filters + pagination | +| registry | GET | `/registry/runs/{run_id}` | Run details + JSONB metrics + runtime_info | +| registry | PATCH | `/registry/runs/{run_id}` | Update status / metrics / artifact_uri | +| registry | GET | `/registry/runs/{run_id}/verify` | SHA-256 artifact integrity check | +| registry | POST | `/registry/aliases` | Create/update alias (only on `success` runs) | +| registry | GET | `/registry/aliases` | List aliases | +| registry | GET | `/registry/aliases/{alias_name}` | Get alias | +| registry | DELETE | `/registry/aliases/{alias_name}` | Delete alias | +| registry | GET | `/registry/compare/{run_id_a}/{run_id_b}` | Diff two runs | +| jobs | POST | `/jobs` | Submit `train` / `predict` / `backtest` (returns 202-style job_id) | +| jobs | GET | `/jobs` | List with filters | +| jobs | GET | `/jobs/{job_id}` | Status + result JSON | +| jobs | DELETE | `/jobs/{job_id}` | Cancel pending | +| rag | POST | `/rag/index` | Index a markdown/openapi document; idempotent via content hash | +| rag | POST | `/rag/retrieve` | Semantic search (HNSW), top-k with similarity threshold | +| rag | GET | `/rag/sources` | List indexed sources | +| rag | DELETE | `/rag/sources/{source_id}` | Delete source + cascaded chunks | +| agents | POST | `/agents/sessions` | Create session (`agent_type`: `experiment` or `rag_assistant`) | +| agents | GET | `/agents/sessions/{session_id}` | Status + message history (Postgres JSONB) | +| agents | POST | `/agents/sessions/{session_id}/chat` | Send user message; returns full response | +| agents | POST | `/agents/sessions/{session_id}/approve` | Approve/reject a pending tool call (HITL gate) | +| 
agents | DELETE | `/agents/sessions/{session_id}` | Close session | +| agents | WS | `/agents/stream` | Token-by-token streaming + tool-call events | +| seeder | (see `app/features/seeder/routes.py`) | `/seeder/*` | Trigger scenarios, status, customization | + +## WebSocket Events (`/agents/stream`) + +[UNVERIFIED — verify against `app/features/agents/websocket.py`] +- Client → server: `{"session_id": str, "message": str}` +- Server → client (streamed): token deltas, tool-call announcements, tool-call results, completion event, error frames. + +## Async Events / Queues + +None. Job execution is synchronous-with-async-shaped-API (per `app/features/jobs/`). No Kafka / SQS / pub-sub. Per `.claude/rules/product-vision.md`, **not a streaming system**. + +## External Integrations + +| Integration | Direction | Auth | Rate Limit | Fallback | +|-------------|-----------|------|------------|----------| +| OpenAI (embeddings + agent LLM) | egress HTTPS | `OPENAI_API_KEY` | provider-side | switch `RAG_EMBEDDING_PROVIDER=ollama`; switch agent model | +| Anthropic (agent LLM) | egress HTTPS | `ANTHROPIC_API_KEY` | provider-side | `AGENT_FALLBACK_MODEL` | +| Google Gemini (agent LLM, optional) | egress HTTPS | `GOOGLE_API_KEY` | provider-side | switch model | +| Ollama (local embeddings, optional) | egress HTTP LAN | none | local | switch back to OpenAI | + +## Schema Change Policy + +- Pre-1.0: API contracts under `/dimensions`, `/analytics`, `/ingest`, `/forecasting`, `/backtesting`, `/registry`, `/rag`, `/agents`, `/jobs` MAY change in MINOR releases. Pin the version. (See `.claude/rules/versioning.md`.) +- Every DB-touching change ships with an Alembic migration. Forward-only after merge. +- Pydantic v2 schema additions: prefer additive; breaking field renames go behind a `feat!:` or call out in PR description. +- New endpoints must register in `app/main.py` and have a route test in the slice's `tests/test_routes.py` (per `.claude/rules/test-requirements.md`). 
diff --git a/docs/_base/ARCHITECTURE.md b/docs/_base/ARCHITECTURE.md new file mode 100644 index 00000000..a68051bb --- /dev/null +++ b/docs/_base/ARCHITECTURE.md @@ -0,0 +1,90 @@ +# ForecastLabAI Architecture +> Source: heuristic discovery of `README.md`, `app/main.py`, `app/core/config.py`, `docker-compose.yml`, `.github/workflows/`, `.claude/rules/`. [UNVERIFIED] tags mark heuristic-derived content awaiting `docs/_kB/repo-map/` ground truth. +> Last generated: 2026-05-11 + +## System Boundaries + +### What This Repo Owns +- The entire stack: FastAPI backend (`app/`), React 19 SPA (`frontend/`), Alembic migrations (`alembic/`), data seeder (`app/shared/seeder/` + `scripts/seed_random.py`), `.claude/` policy + skills + hooks, docs (`docs/`, `PRPs/`, `INITIAL-*.md`). +- 7-table retail data platform (`store`, `product`, `calendar`, `sales_daily`, `price_history`, `promotion`, `inventory_snapshot_daily`) + registry, jobs, RAG sources/chunks, agent sessions. +- 11 backend vertical slices under `app/features/` + cross-cutting `app/core/` + `app/shared/`. 
+ +### What This Repo Depends On +| Dependency | Interface | Owner | Change Process | +|------------|-----------|-------|----------------| +| PostgreSQL 16 + pgvector | `asyncpg` URL `DATABASE_URL` | self (docker-compose) | New extension → migration | +| OpenAI API | `openai>=1.40` via `app/features/rag` + `app/features/agents` | external | Pin model name in config | +| Anthropic API | `anthropic>=0.50` via PydanticAI | external | Pin model name in config | +| Google Gemini (optional) | `google-gla:*` / `google-vertex:*` model IDs | external | Set `GOOGLE_API_KEY` | +| Ollama (optional) | HTTP at `OLLAMA_BASE_URL` | self/LAN | Set `RAG_EMBEDDING_PROVIDER=ollama` | + +### What Depends On This Repo +| Consumer | Depends On | Break Risk | +|----------|-----------|------------| +| `frontend/` React SPA | Backend HTTP API + `/agents/stream` WebSocket | HIGH — same repo, both released together | +| External demos / portfolio reviewers | `/docs` (Swagger), `/redoc` | LOW | + +No other internal repos consume this one — single-deployment system per `.claude/rules/product-vision.md`. 
+ +## Resource Hierarchy + +``` +ForecastLabAI repo +├── docker-compose.yml # single Postgres+pgvector container +├── app/ # FastAPI process (uvicorn :8123) +│ ├── core/ # config, db engine, logging, middleware, problem-details, health +│ ├── shared/ # cross-slice models + seeder ("The Forge") +│ └── features/<slice>/ # vertical slices (11 of them) +└── frontend/ # Vite dev server :5173 (proxies → :8123) +``` + +## Component Overview + +| Component | Language | Type | Path | Wired in | +|-----------|----------|------|------|----------| +| Health | Python | HTTP | `app/core/health.py` | `app/main.py` | +| Dimensions | Python | HTTP read | `app/features/dimensions/` | `app/main.py` | +| Analytics | Python | HTTP read | `app/features/analytics/` | `app/main.py` | +| Jobs | Python | HTTP CRUD | `app/features/jobs/` | `app/main.py` | +| Ingest | Python | HTTP upsert | `app/features/ingest/` | `app/main.py` | +| Featuresets | Python | HTTP compute | `app/features/featuresets/` | `app/main.py` | +| Forecasting | Python | HTTP train/predict | `app/features/forecasting/` | `app/main.py` | +| Backtesting | Python | HTTP run | `app/features/backtesting/` | `app/main.py` | +| Registry | Python | HTTP CRUD + JSONB | `app/features/registry/` | `app/main.py` | +| RAG | Python | HTTP index/retrieve + pgvector | `app/features/rag/` | `app/main.py` | +| Agents | Python | HTTP + WebSocket | `app/features/agents/` | `app/main.py` | +| Seeder | Python | HTTP control | `app/features/seeder/` | `app/main.py` | +| Data platform | Python | ORM only (no router) | `app/features/data_platform/` | imported by services | +| Frontend | TypeScript | SPA | `frontend/src/` | served by Vite | + +## Communication Patterns + +| Pattern | Used By | Protocol | Auth | +|---------|---------|----------|------| +| Sync HTTP (REST) | Frontend → Backend, demo curl | HTTP/JSON | None (single-tenant; CORS allow-list dev-only) | +| WebSocket streaming | Frontend chat ↔ `/agents/stream` | WS frames | None | +| Process → 
DB | All services → Postgres | `postgresql+asyncpg` | `DATABASE_URL` user/pass | +| Agent tool calls | PydanticAI → backend services | In-process Python | Pydantic-validated arg schemas | +| RAG embeddings | RAG service → OpenAI / Ollama | HTTPS | `OPENAI_API_KEY` env or Ollama LAN URL | +| Agent LLM calls | Agents → Anthropic/OpenAI/Gemini | HTTPS | provider API key env | + +## Deployment Flow (Causal Chain) + +``` +PR opened on dev → ci.yml (lint + typecheck + test + migration-check) → reviewer approve → merge to dev +dev → main PR → release-please opens Release PR → merge Release PR → tag vX.Y.Z → cd-release.yml (build wheel + upload artifacts to GitHub Release) +Local install → docker-compose up -d → alembic upgrade head → uvicorn → vite +``` + +No staging/prod environments configured. Deployment target is the developer laptop or a single host — there is no managed cloud target. [UNVERIFIED — if a hosted demo exists, document it here.] + +## Observability Stack + +| Signal | Tool | Retention | Surface | +|--------|------|-----------|---------| +| Logs | `structlog` (JSON in prod, console in dev) | stdout only — process-local | `app/core/logging.py` | +| Request ID | `RequestIdMiddleware` (`app/core/middleware.py`) | per-request | echoed in problem-details `request_id` | +| Errors | RFC 7807 problem+json | per-response | `app/core/problem_details.py` | +| Metrics | none [UNVERIFIED — none observed in code or config] | — | — | +| Traces | none [UNVERIFIED] | — | — | +| Dashboards | The React app itself surfaces operational state via Jobs/Runs/Health pages | live | `frontend/src/pages/` | diff --git a/docs/_base/DEV_GUIDE.md b/docs/_base/DEV_GUIDE.md new file mode 100644 index 00000000..21b7b1cd --- /dev/null +++ b/docs/_base/DEV_GUIDE.md @@ -0,0 +1,61 @@ +# ForecastLabAI Developer Guide +> HUMAN-MAINTAINED — do not overwrite via the generating-claudemd skill. +> Fill in all {FILL IN} sections; remove this stub marker line when content is complete. 
+ +## What This Project Is + +{FILL IN: 2 sentences. Suggested seed — "A portfolio-grade, single-host retail demand forecasting system that exercises the full lifecycle: data platform → ingest → time-safe features → forecasting → backtesting → registry → RAG → agents → React dashboard. Pre-1.0; release-please-driven SemVer."} + +## Tech Stack + +See `CLAUDE.md` Stack section and `pyproject.toml` for authoritative dependency list. {FILL IN: any narrative on why each choice — point to ADRs in `docs/ADR/`.} + +## Local Development Setup + +Authoritative quick-start lives in `README.md`. The short version: + +```bash +cp .env.example .env # set your OPENAI_API_KEY / ANTHROPIC_API_KEY +docker compose up -d # Postgres+pgvector on :5433 +uv sync --extra dev # Python 3.12 deps +uv run alembic upgrade head # migrations +uv run uvicorn app.main:app --reload --port 8123 +cd frontend && corepack enable pnpm && pnpm install && pnpm dev +``` + +{FILL IN: any host-specific notes (e.g., WSL caveats from `HANDOFF.md` on corrupt `.venv` / `node_modules` binaries).} + +## Running Tests + +```bash +uv run pytest -v -m "not integration" # unit (fast, no DB) +docker compose up -d +uv run pytest -v -m integration # integration (real Postgres) +cd frontend && pnpm tsc --noEmit && pnpm lint && pnpm test --run +``` + +{FILL IN: coverage targets if any; how to add a new vertical slice's `tests/`.} + +## Project Conventions + +Authoritative rules live in `.claude/rules/` and are surfaced in `docs/_base/RULES.md`. The non-obvious gotchas worth highlighting here: + +- Vertical-slice imports: `app/features/X` may NOT import from `app/features/Y`. Cross-cutting code goes to `app/shared/` or `app/core/`. +- The seeder is the only sanctioned bulk-mutation path on the DB. +- {FILL IN: any other conventions discovered in practice.} + +## Why We Chose These Technologies + +- ADRs live in `docs/ADR/` — see `docs/ADR/ADR-INDEX.md`. 
+- {FILL IN: short narrative for newcomers — e.g., "pgvector chosen over a managed vector DB to keep the system single-host; see ADR-0003."} + +## Common Troubleshooting + +See `docs/_base/RUNBOOKS.md` — "Common Incidents" section covers the recurring traps (frontend `Loading...` from misconfigured `VITE_API_BASE_URL`, pnpm 11 `depsStatusCheck`, WSL `.venv` corruption, `.env`-bleed in Settings tests). + +## Contacts & Resources + +- Maintainer: Gabor Szabo +- Issue tracker: GitHub Issues on this repo +- Release tracker: `CHANGELOG.md` (release-please-managed) +- {FILL IN: any out-of-repo links — Slack, Notion, demo URL.} diff --git a/docs/_base/DOMAIN_MODEL.md b/docs/_base/DOMAIN_MODEL.md new file mode 100644 index 00000000..090fa626 --- /dev/null +++ b/docs/_base/DOMAIN_MODEL.md @@ -0,0 +1,105 @@ +# ForecastLabAI Domain Model +> Source: heuristic discovery from `app/features/data_platform/models.py`, `app/features/registry/models.py`, `app/features/rag/models.py`, `app/features/agents/models.py`, `app/features/forecasting/models.py`, `app/features/jobs/models.py`, `app/shared/seeder/`. [UNVERIFIED] tags mark anything that would benefit from KB cross-check. 
+ +## Bounded Contexts + +| Context | Owns | Anti-Corruption Layer | +|---------|------|-----------------------| +| Data Platform | `store`, `product`, `calendar`, `sales_daily`, `price_history`, `promotion`, `inventory_snapshot_daily` | Ingest layer's natural-key resolution (`store_code` → `store_id`, `sku` → `product_id`) | +| Featuresets | Computed feature matrices (in-memory; not persisted) | Time-cutoff parameter — never reads beyond `cutoff_date` | +| Forecasting | Trained model artifacts on disk (joblib `.pkl`) | Model interface in `examples/models/model_interface.md`; artifact_uri returned to caller | +| Backtesting | Fold results, metrics (returned in response; persisted via Registry) | `SplitConfig` (expanding/sliding, gap, horizon) — `app/features/backtesting/splitter.py` | +| Registry | `model_run`, `run_alias`, `model_artifact` | SHA-256 hash on artifact_uri; status state machine | +| RAG | `rag_source`, `rag_chunk` (with pgvector embedding column) | Content hash for idempotent indexing; embedding dimension fixed per provider | +| Agents | `agent_session` (JSONB message_history) | Pydantic-validated tool args; HITL approval queue | +| Jobs | `job` (JSONB params + result) | Discriminated-union `job_type` (`train`/`predict`/`backtest`) | +| Analytics | None persisted — pure read-aggregates | SQL GROUP BY over `sales_daily` joined to dimensions | +| Seeder ("The Forge") | Generates synthetic rows in Data Platform tables | `Scenario` preset + `DimensionConfig`/`FactsConfig` dataclasses; `dataclasses.replace` for field-precise overrides | + +## Core Aggregates + +### `model_run` (Registry) +- **Root:** `ModelRun(run_id: UUID, status: RunStatus)` +- **Status state machine:** `pending` → `running` → `success` | `failed` → `archived` +- **JSONB fields:** `model_config`, `metrics`, `runtime_info` (Python/numpy/pandas versions captured at training) +- **Invariants:** + - An alias may point only to a `success` run. 
+ - Artifact_uri SHA-256 hash must verify before any consumer trusts it (`GET /registry/runs/{id}/verify`). + - `runtime_info` is immutable after `success`. + +### `agent_session` (Agents) +- **Root:** `AgentSession(session_id: UUID, status: SessionStatus)` +- **Status:** `active` / `awaiting_approval` / `expired` / `closed` (`SessionStatus` enum, `app/features/agents/models.py:24`). Transitions: `ACTIVE → AWAITING_APPROVAL` (sensitive action pending), `AWAITING_APPROVAL → ACTIVE` (on approval/rejection), `ACTIVE → EXPIRED` (on timeout), `ACTIVE → CLOSED` (on explicit close). +- **Invariants:** + - `message_history` JSONB is append-only within a session. + - Tools in `agent_require_approval` block until `POST /agents/sessions/{id}/approve` returns. + - Token budget cap (`agent_max_tokens`) and tool-call cap (`agent_max_tool_calls`) per session. + +### `sales_daily` (Data Platform) +- **Root:** composite `(store_id, product_id, date)` +- **Invariants:** + - `quantity >= 0`, `unit_price >= 0`, `total_amount = quantity * unit_price` (approx; rounding tolerated). + - `store_id`, `product_id`, `date` must reference existing dimension rows. + - Idempotent upsert via `ON CONFLICT (store_id, product_id, date) DO UPDATE` (`app/features/ingest/service.py`). + +## Key Invariants — NEVER violate + +1. **Time safety in features.** `app/features/featuresets/` uses only data at or before `cutoff_date`. Lags via `shift(positive)`, rolling via `shift(1).rolling(...)`, all `groupby` entity-aware. The test `app/features/featuresets/tests/test_leakage.py` is the spec — it MUST keep passing. +2. **Forward-only migrations.** Once an Alembic migration is merged, never edit it. Add a new migration to fix or evolve. +3. **HITL approval gates the agent's mutation surface.** Every tool that writes to the registry (`create_alias`, `archive_run`, …) must be in `agent_require_approval`. Widening the surface without updating that list is a security regression. +4. 
**Single-host deployable.** No managed cloud service in the core path. `docker-compose up` must continue to be the only prerequisite besides Python + Node. +5. **Pre-1.0 contracts may move.** Pin the version you build against. After `v1.0.0`, full SemVer applies. +6. **Seeder is idempotent + scoped.** Never introduce a "wipe everything" path that isn't behind `--confirm` + scope flag. + +## Ubiquitous Language — use exactly these terms + +| Term | Means | NOT | +|------|-------|-----| +| `store` | Retail location (dimension); composite-key parent of sales | branch, outlet | +| `product` | SKU (dimension); composite-key parent of sales | item, article | +| `sales_daily` | One row per `(store_id, product_id, date)` | order, transaction (those would be finer-grain) | +| `run` | A model training instance tracked in the registry | experiment, job | +| `alias` | A pointer to a `success` run (e.g., `production`, `champion`) | tag, label | +| `session` (agent) | One conversation between user and PydanticAI agent | thread, chat | +| `fold` | One train+test split inside a backtest | iteration | +| `baseline` | A naive / seasonal_naive / moving_average model included for comparison | benchmark, control | +| `lag` | Past value at offset `k` (`shift(k)`) | window | +| `rolling` | Statistic over a trailing window with `shift(1)` to avoid leakage | moving average (only for the MA model name) | +| `chunk` (RAG) | A windowed segment of a source document with its own embedding | section, paragraph | +| `scenario` (seeder) | A YAML or in-code preset (`retail_standard`, `holiday_rush`, …) that wires `DimensionConfig` + `FactsConfig` | template, profile | + +## Event Taxonomy + +None. There is no async event bus by design (`product-vision.md`: not a streaming system). All workflows are request/response or in-process tool-call. 
+ +## Entity Relationship Summary + +``` +store ─────┐ + ├──► sales_daily ◄──── price_history +product ───┤ ◄──── promotion + ├──► inventory_snapshot_daily +calendar ──┘ + +model_run ──owns──► artifact (on disk; SHA-256 verified) +model_run ◄─points-to── run_alias + +rag_source ──owns──► rag_chunk (with pgvector embedding) + +agent_session ──owns──► message_history (JSONB) ──may-contain──► tool_call (pending approval) + +job ──may-reference──► model_run (for train/backtest jobs) +``` + +## Glossary (cross-cutting) + +| Term | Definition | Context | +|------|------------|---------| +| HITL | Human-in-the-loop — agent pauses for `/approve` call | Agents | +| RFC 7807 | `application/problem+json` error envelope | API | +| HNSW | Hierarchical Navigable Small World — pgvector index type | RAG | +| SMAPE | Symmetric Mean Absolute Percentage Error (0–200 scale) | Backtesting metrics | +| WAPE | Weighted Absolute Percentage Error | Backtesting metrics | +| PRP | Project Requirements Plan — the doc that gates a vertical-slice implementation | Workflow | +| INITIAL-N | Discovery-phase doc that precedes a PRP | Workflow | +| "The Forge" | Internal name for the seeder (`app/shared/seeder/`) | Seeder | diff --git a/docs/_base/PIPELINE_CONTRACT.md b/docs/_base/PIPELINE_CONTRACT.md new file mode 100644 index 00000000..4686d5c2 --- /dev/null +++ b/docs/_base/PIPELINE_CONTRACT.md @@ -0,0 +1,108 @@ +# Pipeline Contract +> Generated by: w7_generating-claudemd skill +> Source: `.github/workflows/ci.yml`, `.github/workflows/cd-release.yml`, `.github/workflows/dependency-check.yml`, `.github/workflows/phase-snapshot.yml`, `.github/workflows/schema-validation.yml`, `release-please-config.json`, `.release-please-manifest.json`, `.claude/rules/versioning.md`. 
+> Last reviewed: 2026-05-11 + +## Required Stages — `ci.yml` (runs on push to `main`/`dev`/`release-please--*` and on PRs to `main`/`dev`) + +| Stage | Job | Blocking | On Fail | Runs On | +|-------|-----|----------|---------|---------| +| Lint + format | `lint` (ruff check + ruff format --check) | YES | Block merge | Every PR push, every branch push | +| Type check (mypy --strict) | `typecheck` | YES | Block merge | Every PR push | +| Type check (pyright --strict) | `typecheck` | YES | Block merge | Every PR push | +| Tests (unit + integration) | `test` against Postgres+pgvector service | YES | Block merge | Every PR push | +| Migration check (upgrade fresh DB) | `migration-check` against Postgres service | YES | Block merge | Every PR push | + +`concurrency.cancel-in-progress: true` on the workflow group — stale runs canceled on new push. + +**NEVER** merge if any of those five stages (four jobs: `lint`, `typecheck`, `test`, `migration-check`) is red. **NEVER** add a `--no-verify` / `[skip ci]` path for production merges. + +## Release Pipeline — `cd-release.yml` (runs on push to `main`) + +| Stage | Job | Blocking | Notes | +|-------|-----|----------|-------| +| Release-please open/update Release PR | `release-please` | N/A | Only effect on most pushes | +| Build Python wheel + sdist | `build-package` (gated `if: release_created == 'true'`) | YES on release commits | `uv build`, then upload to GitHub Release tag | +| Upload `dist/` as workflow artifact (90-day retention) | inside `build-package` | YES | `actions/upload-artifact@v7` | + +A "release commit" only occurs when a maintainer merges the open Release PR — which is the only moment release-please tags `vX.Y.Z` and runs the build/upload chain. Pre-1.0 bumps: `feat:` → PATCH, `feat!:` / `BREAKING CHANGE:` → MINOR (`bump-minor-pre-major: true`, `bump-patch-for-minor-pre-major: true`). See `.claude/rules/versioning.md`. + +## Other Workflows + +| Workflow | Trigger | Purpose | Blocking? 
| +|----------|---------|---------|-----------| +| `dependency-check.yml` | Weekly cron (Sun 00:00 UTC) + `workflow_dispatch` (input `fail_on_vulnerabilities` default `true`) | Python vulnerability scan | No — runs out-of-band, not on PR; failure does not block merges directly | +| `phase-snapshot.yml` | Push to `phase-*` branches | Full validation snapshot (incl. Postgres+pgvector test DB) for phase-tracking branches | No — only fires on `phase-*` branches, informational | +| `schema-validation.yml` | Push / PR with paths under `alembic/**`, `app/**/models.py`, `app/core/database.py` | Schema/migration validation against fresh Postgres+pgvector | Conditional — only triggers on schema-touching changes; should be GREEN before merge when it runs | + +## Merge Conditions (ALL must be true) + +- [ ] `lint` job: GREEN +- [ ] `typecheck` job: GREEN (both mypy and pyright steps) +- [ ] `test` job: GREEN (unit + integration against Postgres service) +- [ ] `migration-check` job: GREEN (Alembic upgrade clean on empty DB) +- [ ] Conventional Commit message validated by `.claude/hooks/check-commit-format.sh` locally +- [ ] At least 1 review on PRs into `main` ([UNVERIFIED — branch protection rules not visible from repo content; confirm in repo settings]) +- [ ] No unresolved review comments +- [ ] Branch up-to-date with `dev` (for `dev` merges) or `main` (for `dev → main` PRs) + +## Failure Gates + +| Condition | Action | +|-----------|--------| +| Any blocking CI job red | Block merge | +| Alembic migration fails to apply on fresh DB | Block merge (caught by `migration-check`) | +| `mypy --strict` or `pyright --strict` error | Block merge | +| `ruff check` fails | Block merge | +| `ruff format --check` fails | Block merge (run `ruff format .` locally) | +| Integration tests fail with stale local Postgres | Run `docker compose down -v && docker compose up -d && uv run alembic upgrade head` and retry | + +## Secret Usage Model + +- App secrets (`OPENAI_API_KEY`, 
`ANTHROPIC_API_KEY`, `GOOGLE_API_KEY`, etc.) live in `.env` locally; never committed (`.env.example` is the schema). +- CI secrets: `RELEASE_PAT` (optional — used to trigger CI on release-please PRs). `GITHUB_TOKEN` is always available. Stored in GitHub repo secrets; never in workflow YAML. +- No agent / pipeline ever rotates a secret automatically. + +## Artifact Contract + +| Artifact | Produced By | Consumed By | TTL | Naming | +|----------|-------------|-------------|-----|--------| +| `dist/*.whl` + `dist/*.tar.gz` | `build-package` job | GitHub Release attachments + workflow artifact | 90 days (workflow artifact); permanent on Release | release-please-pinned `vX.Y.Z` | +| GitHub Release | release-please-action | Humans (download) | Permanent | `vX.Y.Z` | +| Test report (stdout) | `test` job | CI logs | 90 days (GitHub Actions default) | per run | + +## Environment Promotion Rules + +``` +feat|fix|chore|docs|refactor|test/<slug> → PR to dev → merge after CI green +dev → PR to main → CI green → merge +main → release-please opens Release PR +Release PR on main → merge → tag vX.Y.Z + Release + wheel upload +hotfix/<slug> → branch off main → PR to main directly (hotfix flow per branch-naming.md) +``` + +**NEVER** push directly to `main`. **NEVER** `git push --force` `dev` or `main`. + +## Pipeline Configuration Standards + +```yaml +# From ci.yml — pattern to keep when adding new workflows +concurrency: + group: ${{ github.workflow }}-${{ inputs.ref || github.ref }} + cancel-in-progress: true + +env: + PYTHON_VERSION: "3.12" # do not drift from pyproject.toml requires-python + UV_VERSION: "0.5" +``` + +## Rollback Procedure + +There is no live deployment to roll back. To recover from a bad release: + +1. `git revert <bad-commit-sha>` on `dev`. +2. PR `dev` → `main`. +3. Merge the Release PR opened by release-please → new patch tag (`vX.Y.(Z+1)`). +4. (Optional) Mark the bad GitHub Release as a pre-release or add a notice. + +Never delete or move a published tag. 
diff --git a/docs/_base/REPO_MAP_INDEX.md b/docs/_base/REPO_MAP_INDEX.md new file mode 100644 index 00000000..e34bace0 --- /dev/null +++ b/docs/_base/REPO_MAP_INDEX.md @@ -0,0 +1,82 @@ +# Repo Map Index +> LLM: Read this index first. Load individual docs ONLY when the current task touches +> that domain. Do NOT pre-load all docs. +> +> Last updated: 2026-05-11 | Generator: w7_generating-claudemd skill (heuristic mode) | Docs: 8 + +## System at a Glance + +ForecastLabAI is a portfolio-grade, single-host retail-demand-forecasting system. One developer maintains it; one `docker-compose up` brings it up. The backend is FastAPI + SQLAlchemy 2.0 async against PostgreSQL 16 + pgvector; the frontend is React 19 + Vite + Tailwind 4 + shadcn/ui. Eleven vertical slices under `app/features/` cover the full lifecycle (data platform → ingest → features → forecasting → backtesting → registry → RAG → agents → dashboard surfaces). Pre-1.0; release-please drives SemVer; merges flow `dev` → `main`. + +## Document Index + +> NOTE: `docs/_kB/repo-map/` does NOT exist in this repo. Run the `mapping-repo-context` skill to populate it; once present, this index can be regenerated against the KB. The table below points at the **actual** discovery surface today. + +| File | What it answers | Load when... 
| +|------|-----------------|--------------| +| [`README.md`](../../README.md) | Canonical quick-start, feature list, endpoint reference, frontend stack | Onboarding, demoing, sanity-checking an endpoint shape | +| [`CLAUDE.md`](../../CLAUDE.md) | Operating index, commands, conventions, safety | Start of every Claude session | +| [`CHANGELOG.md`](../../CHANGELOG.md) | release-please-managed release notes | Investigating when behavior changed | +| [`pyproject.toml`](../../pyproject.toml) | Dependencies, ruff/mypy/pyright/pytest config | Tooling questions, version bumps | +| [`docker-compose.yml`](../../docker-compose.yml) | Local Postgres+pgvector definition | Debugging DB connectivity, ports | +| [`alembic/versions/`](../../alembic/versions/) | Six migrations through `d6e0f2g3h456_create_agent_session_table.py` | DB-schema questions, migration drift | +| [`docs/ARCHITECTURE.md`](../ARCHITECTURE.md) | Phase-by-phase architecture narrative | High-level component reasoning | +| [`docs/PHASE-index.md`](../PHASE-index.md) | Index of all 11 phase docs | Locating per-phase deep-dive | +| [`docs/PHASE/*.md`](../PHASE/) | Per-phase implementation reference | Slice-specific deep dives | +| [`docs/ADR/ADR-INDEX.md`](../ADR/ADR-INDEX.md) | Architectural decision records | Why a tech choice was made | +| [`docs/DAILY-FLOW.md`](../DAILY-FLOW.md) | Developer day-in-the-life loop | Onboarding a contributor | +| [`docs/GIT-GITHUB-GUIDE.md`](../GIT-GITHUB-GUIDE.md) | Branch/PR/release workflow | Anything git/PR-related | +| [`docs/PHASE-FLOW.md`](../PHASE-FLOW.md) | INITIAL → PRP → code pipeline | Authoring new feature requests | +| [`docs/validation/*.md`](../validation/) | Tooling standards (ruff, mypy, pyright, pytest, logging) | Configuring/justifying CI gates | +| [`docs/github/`](../github/) | CI/CD workflow reference + diagrams | Pipeline troubleshooting | +| [`docs/rag-ollama-setup.md`](../rag-ollama-setup.md) | Local-embedding setup | Switching off OpenAI embeddings | +| 
[`docs/DATA-SEEDER.md`](../DATA-SEEDER.md) | "The Forge" seeder operating guide | Generating / refreshing local data | +| [`PRPs/PRP-*.md`](../../PRPs/) | Per-phase project requirements plans (PRP-0 through PRP-13) | Implementing or extending a phase | +| [`INITIAL-*.md`](../../) (repo root) | Pre-PRP discovery docs | Tracing a feature back to its origin | +| [`.claude/rules/*.md`](../../.claude/rules/) | Project rules (commit-format, branch-naming, security-patterns, product-vision, test-requirements, ui-design, versioning, output-formatting) | Any behavioral decision Claude makes | +| [`.claude/skills/`](../../.claude/skills/) | Slash-command skills (audit-rules-drift, commit-format-check, issue-to-subtasks, repo-visibility-audit, w7_generating-claudemd, …) | Picking the right workflow | +| [`.claude/hooks/check-commit-format.sh`](../../.claude/hooks/check-commit-format.sh) | Pre-commit enforcement of `type(scope): description (#issue)` | Debugging blocked commits | +| [`HANDOFF.md`](../../HANDOFF.md) | Latest session handoff | Resuming context across sessions | +| [`.handoffs/`](../../.handoffs/) | Archived handoffs | Historical session context | +| [`docs/_base/ARCHITECTURE.md`](ARCHITECTURE.md) | System boundaries, components, comm patterns | Architectural changes, blast radius | +| [`docs/_base/API_CONTRACTS.md`](API_CONTRACTS.md) | HTTP + WebSocket endpoint surface | API changes, integration | +| [`docs/_base/RUNBOOKS.md`](RUNBOOKS.md) | Common incidents + resolutions | Debugging, recovery | +| [`docs/_base/SECURITY.md`](SECURITY.md) | Threat model, secrets, scanning | Security review, audit | +| [`docs/_base/RULES.md`](RULES.md) | Change authority + invariants | Any sensitive change | +| [`docs/_base/DOMAIN_MODEL.md`](DOMAIN_MODEL.md) | Aggregates, invariants, ubiquitous language | Naming, modeling, new entity | +| [`docs/_base/DEV_GUIDE.md`](DEV_GUIDE.md) | Human-maintained onboarding (stub) | Onboarding (after a human fills it in) | +| 
[`docs/_base/PIPELINE_CONTRACT.md`](PIPELINE_CONTRACT.md) | CI/CD stages, merge gates, release flow | CI changes, release planning | + +## Dependency Hotspots (high blast-radius targets) + +| Component | Why it's hot | Risk | +|-----------|--------------|------| +| `app/core/database.py` | Every slice's `service.py` opens a session through it | CRITICAL — breakage cascades | +| `app/core/problem_details.py` | Every error path serializes through it | HIGH — affects all error responses | +| `app/main.py` | Wires every router; central CORS + middleware | HIGH — wiring regression blocks the API | +| `app/features/data_platform/models.py` | Every fact table FK lands here | HIGH — migration drift breaks many tests | +| `app/features/featuresets/tests/test_leakage.py` | Load-bearing spec | HIGH — weakening it lets leakage land | +| `app/features/agents/service.py` + `tools/*` | HITL gates + tool wiring | HIGH — security boundary | +| `alembic/versions/` | Forward-only; CI `migration-check` enforces | HIGH — edit-after-merge corrupts deploys | + +## Tech Stack Snapshot + +| Category | Technology | Status | +|----------|------------|--------| +| Backend language | Python 3.12 | Pinned | +| Backend framework | FastAPI ≥ 0.115 + Pydantic v2 + SQLAlchemy 2.0 async | Pinned | +| ML / data | pandas ≥ 3, numpy ≥ 2.4, scikit-learn ≥ 1.6, joblib, LightGBM (opt-in) | Pinned | +| Agents | PydanticAI ≥ 1.80, anthropic ≥ 0.50, openai ≥ 1.40 | Pinned | +| RAG | pgvector ≥ 0.3, tiktoken ≥ 0.7, Ollama (optional) | Pinned | +| Frontend | React 19, Vite 7, TypeScript 5.9, Tailwind 4, shadcn/ui (New York) | Pinned | +| Primary DB | PostgreSQL 16 + pgvector (`pgvector/pgvector:pg16`) | Pinned | +| Package manager | uv (Python), pnpm via corepack (JS) | Pinned | +| IaC | none — `docker-compose` single-host | By design (`product-vision.md`) | +| Orchestration | none — single uvicorn process | By design | +| CI/CD | GitHub Actions + release-please | Pinned | + +## Index Update History (last 5 
generations) + +| Date | Change | Who | +|------|--------|-----| +| 2026-05-11 | Initial generation (heuristic mode — `docs/_kB/repo-map/` absent) | w7_generating-claudemd skill | diff --git a/docs/_base/RULES.md b/docs/_base/RULES.md new file mode 100644 index 00000000..b940017e --- /dev/null +++ b/docs/_base/RULES.md @@ -0,0 +1,70 @@ +# ForecastLabAI Rules +> Generated by: w7_generating-claudemd skill +> Source of truth: `.claude/rules/` directory (commit-format, branch-naming, security-patterns, product-vision, test-requirements, ui-design, versioning, output-formatting). This file consolidates the constraint matrix; the rule files are authoritative on detail. +> Last reviewed: 2026-05-11 + +## Change Authority Matrix + +| Change Type | Who Can Approve | Gate Required | +|-------------|-----------------|---------------| +| Code merged to `dev` | Maintainer | CI green (ruff + mypy + pyright + pytest + migration-check) | +| `dev` → `main` PR | Maintainer | CI green, release-please opens Release PR | +| Release tag (`vX.Y.Z`) | Maintainer merges the Release PR | release-please-action creates tag automatically | +| New external service in `app/` core path | Vision check (`.claude/rules/product-vision.md`) | PRP + ADR required (this is a single-host system) | +| New Alembic migration | Maintainer | Must apply + roll back on a fresh DB (CI `migration-check` enforces the apply step: `alembic upgrade head` on an empty DB) | +| New API endpoint | Maintainer | Route test + happy path + 1 error path (per `test-requirements.md`) | +| New PydanticAI tool that mutates state | Maintainer | Add tool name to `agent_require_approval` config | +| Touching `app/features/featuresets/tests/test_leakage.py` | Maintainer + reviewer | Tests are the leakage spec — never weaken to make a feature pass | +| `git push --force` on `dev` / `main` | **Nobody** | Forbidden by `security-patterns.md` | + +## Hard Rules — Treat as Invariants + +**Never violate, never suggest violating.** + +- NEVER skip `mypy --strict` or `pyright --strict` — both 
gate merge. +- NEVER skip `ruff check` or `ruff format --check`. +- NEVER commit `.env`, secrets, API keys, credentials in URLs, or `.env`-derived values in logs. +- NEVER concatenate user input into SQL — only SQLAlchemy parameter binding. +- NEVER use `eval` / `exec` / `subprocess(shell=True, …user_input)` / `pickle.loads` on untrusted data. +- NEVER set `verify=False` on `httpx.AsyncClient` / `openai.AsyncClient`. +- NEVER edit a merged Alembic migration — create a new one (migrations are forward-only after merge). +- NEVER add an AI co-author trailer to a commit (`Co-Authored-By: Claude …`, `🤖 Generated with …`). +- NEVER weaken `app/features/featuresets/tests/test_leakage.py` to make a feature pass — the test is the spec. +- NEVER widen agent mutation surface without adding the tool to `agent_require_approval`. +- NEVER `git push --force` on `dev` or `main`. +- NEVER add a managed-cloud SDK (AWS/GCP/Azure) to `app/` core — violates single-host vision. +- NEVER mock the database in integration tests — they must run against the real `docker-compose` Postgres (`-m integration`). + +## Conventions (enforced by `.claude/rules/`) + +- **Commits:** `type(scope): description (#issue)`. Type ∈ `{feat, fix, docs, refactor, test, chore, release}`. Scope from allow-list (`api`, `ui`, `data`, `ingest`, `features`, `forecast`, `backtest`, `registry`, `rag`, `agents`, `dimensions`, `analytics`, `jobs`, `db`, `ci`, `docs`, `repo`, `release`). Comma-pairs allowed for cross-cutting (`feat(api,ui): …`). Every commit references an open GitHub issue. Hook `.claude/hooks/check-commit-format.sh` enforces. +- **Branches:** `<type>/<slug>` off `dev` (off `main` for `hotfix`). Slug ≤ 50 chars, no issue number. One branch per issue. +- **Architecture:** Every domain under `app/features/<slice>/{models,schemas,service,routes,tests}.py`. Cross-slice imports through `app/shared/` or `app/core/` only. +- **Errors:** RFC 7807 `application/problem+json` via `app/core/problem_details.py`. 
No bare 500s, no ad-hoc error shapes. +- **Output formatting (for skills/reports):** see `.claude/rules/output-formatting.md` — emoji status indicators, box-drawing separators, capped at 40 lines. +- **UI:** Use the skills in `.claude/rules/ui-design.md` (stitch-design, frontend-design, webapp-testing). Never hand-roll UI when a skill applies. + +## Forbidden Patterns + +- `--force` push on `dev` or `main`. +- Raw SQL string interpolation. +- `os.environ[...]` in feature code (use `app.core.config.get_settings()`). +- `os.path` instead of `pathlib.Path` (ruff `PTH` enforces). +- Capitalized commit descriptions or trailing periods. +- Editing a merged migration file. +- Mocking the DB in tests marked `@pytest.mark.integration`. +- Adding managed-cloud SDKs to `app/` core path. +- New "wipe everything" operations outside the seeder's `--delete --confirm` path (the application itself has no destructive bulk-delete by design). + +## Compliance Constraints + +[ASSUMPTION] No compliance frameworks in scope — portfolio repo, no regulated data. See `docs/_base/SECURITY.md` for the table. + +## Escalation Path + +| Situation | Contact | Channel | +|-----------|---------|---------| +| Unsure about a change | Maintainer (Gabor Szabo) | GitHub PR review | +| Security concern | Maintainer | Open a private GitHub Security Advisory | +| Architecture / vision change | Maintainer | New `INITIAL-*.md` + `PRPs/PRP-*.md` per `product-vision.md` | +| Break-glass required | N/A — no production environment | Use `scripts/seed_random.py --delete --confirm` to reset local state | diff --git a/docs/_base/RUNBOOKS.md b/docs/_base/RUNBOOKS.md new file mode 100644 index 00000000..40c8b974 --- /dev/null +++ b/docs/_base/RUNBOOKS.md @@ -0,0 +1,102 @@ +# ForecastLabAI Runbooks +> Source: heuristic discovery from `docker-compose.yml`, `app/main.py`, `app/core/config.py`, `HANDOFF.md`, `.claude/rules/`. Operational scope is intentionally small — this is a single-host portfolio system. + +## Common Incidents + +### Frontend shows "Loading..." 
everywhere +**Symptoms:** Pages mount but every TanStack Query hook stays in pending state. +**Likely causes:** +1. `frontend/.env` `VITE_API_BASE_URL` points at a LAN host the browser can't reach. +2. Backend not running on `:8123`. +3. CORS rejected the origin. +**Diagnosis:** +```bash +curl -s http://localhost:8123/health # backend reachable? +grep VITE_API_BASE_URL frontend/.env # what URL is the SPA calling? +``` +**Resolution:** +1. Edit `frontend/.env` → `VITE_API_BASE_URL=http://localhost:8123`. +2. Restart Vite: `cd frontend && ./node_modules/.bin/vite --host 0.0.0.0`. +3. If CORS error in browser console, add the origin to `app/main.py` `allow_origins` (dev-only LAN regex already covers `10.x` / `192.168.x` / `172.16-31.x`). + +### Database connection refused +**Symptoms:** `asyncpg.exceptions.CannotConnectNowError` or `ConnectionRefusedError` on first request. +**Diagnosis:** +```bash +docker compose ps # postgres container running? +docker compose logs postgres | tail -50 +``` +**Resolution:** +```bash +docker compose up -d # bring it up +uv run alembic upgrade head # apply migrations +uv run python scripts/check_db.py # confirm connectivity +``` + +### Tests pass locally but fail in CI on a fresh DB +**Symptoms:** Integration tests pass on the dev host (which has stale seeded data) and fail in CI. +**Diagnosis:** Integration tests must be idempotent — they may not assume pre-existing rows. +**Resolution:** +```bash +docker compose down -v && docker compose up -d +uv run alembic upgrade head +uv run pytest -v -m integration +``` + +### `pnpm dev` re-runs install and errors on esbuild +**Symptoms:** pnpm 11 `depsStatusCheck` reinstalls and blocks the esbuild postinstall script. +**Workaround:** `./node_modules/.bin/vite --host 0.0.0.0` directly. Permanent fix: add `pnpm.onlyBuiltDependencies: ["esbuild"]` to `frontend/package.json`. 
+ +### Settings tests fail because they pick up the local `.env` +**Symptoms:** `app/core/tests/test_config.py::test_settings_has_defaults` and a few `agents/tests/test_config_validation.py` cases fail when `.env` exists. +**Root cause:** `Settings()` reads `.env` via `SettingsConfigDict(env_file=".env")`. +**Fix:** Use `Settings(_env_file=None)` in those tests to bypass `.env`. + +### `.venv` or `frontend/node_modules` binaries become corrupt (WSL only) +**Symptoms:** `python` / `tsc` are reported as `IntxLNK` data blobs; `uv run` / `tsc --noEmit` fails with `cannot execute binary file`. +**Resolution:** +```bash +rm -rf .venv && uv sync --extra dev +rm -rf frontend/node_modules && corepack enable pnpm && cd frontend && pnpm install && pnpm rebuild esbuild +``` + +## Break-Glass Procedures + +There is no "production" — break-glass is N/A. The closest equivalent is the seeder: + +### Reset to a known-good seeded state +```bash +uv run python scripts/seed_random.py --delete --confirm +uv run python scripts/seed_random.py --full-new --seed 42 --confirm +uv run python scripts/seed_random.py --verify +``` + +## Secret Rotation + +There are no managed secrets — keys live in the developer's `.env`. Rotation = edit `.env`, restart `uvicorn`. Never commit `.env`. `.env.example` is the canonical schema; new env vars must land there first. + +## Release / Rollback + +### Cut a release +```bash +# from dev, ensure CI green +gh pr create --base main --head dev --title "release: ..." +# merge PR → release-please opens "chore(main): release X.Y.Z" PR on main +# merge that PR → release-please tags vX.Y.Z and cd-release.yml uploads wheel +``` + +### Rollback a release +```bash +# undoing a tag is destructive; prefer cutting a new patch release with the fix +git revert <bad-commit-sha> +gh pr create --base dev --head fix/<slug> +# proceed through normal release flow → vX.Y.(Z+1) +``` + +Never `git push --force` on `dev` or `main` (see `.claude/rules/security-patterns.md`). 
+ +## Logs & Debugging + +- Backend logs: stdout from `uvicorn` (JSON in `production`, console in `development`). Each request carries an `X-Request-ID` header echoed in error bodies (`request_id` field) — grep logs by that ID. +- Frontend network errors: open browser devtools → Network tab → check `/health`, then the failing endpoint's status + RFC 7807 body. +- Agent issues: check `app/features/agents/models.py` `agent_session` table — `message_history` JSONB has the full transcript, including tool calls and pending approvals. diff --git a/docs/_base/SECURITY.md b/docs/_base/SECURITY.md new file mode 100644 index 00000000..59943787 --- /dev/null +++ b/docs/_base/SECURITY.md @@ -0,0 +1,96 @@ +# ForecastLabAI Security +> Source: `.claude/rules/security-patterns.md` (authoritative), `app/core/config.py`, `app/main.py`, `.github/workflows/`. [ASSUMPTION] compliance scope is **none** (portfolio repo). + +## Threat Model (Scope) + +ForecastLabAI is a **single-tenant, single-host** portfolio demo. There is no auth, no RBAC, no multi-tenancy. The threat surface is: + +1. Untrusted query input flowing through SQLAlchemy → could enable SQLi without parameter binding. +2. LLM-controlled tool calls (PydanticAI agents) → could mutate registry/aliases without HITL approval. +3. RAG retrieval echoing untrusted document content → potential prompt-injection vector into agent context. +4. External provider API keys in `.env` → could leak via logs or commit. +5. Frontend CORS misconfiguration → could expose dev-only endpoints to attacker-controlled origins. + +## Hard Rules (from `.claude/rules/security-patterns.md`) + +These are enforced on every PR. Violations must be fixed before merge. + +- **Forbidden:** `eval` / `exec` / `compile` on user input; `subprocess(shell=True, …user_input)`; raw SQL concat; `pickle.loads` on untrusted data; `verify=False` on httpx/openai clients; hardcoded secrets; credentials in git URLs; logging full prompts/responses; path traversal via `..`. 
+- **Required:** Pydantic v2 validation at every boundary; SQLAlchemy 2.0 parameter binding; `pathlib.Path.resolve()` for file ops; `yaml.safe_load`; RFC 7807 error shape; structured logging without secret values. + +## Secrets Management + +| Item | Storage | Loaded By | Rotation | +|------|---------|-----------|----------| +| `OPENAI_API_KEY` | `.env` (not committed) | `app.core.config.Settings` | Manual; edit `.env`, restart uvicorn | +| `ANTHROPIC_API_KEY` | `.env` | `Settings` | Manual | +| `GOOGLE_API_KEY` | `.env` (optional) | `Settings` | Manual | +| `DATABASE_URL` | `.env` or default localhost | `Settings` | N/A (local docker-compose) | + +Two-file model (mandatory): +- `.env.example` — committed schema with placeholders, every new var added here first. +- `.env` — real values, **NEVER** committed (`.gitignore`d). + +Never log decrypted values, even at DEBUG. Log key NAMES only (`openai_api_key_set=bool(s.openai_api_key)`). + +## Input Validation + +- Every FastAPI endpoint validates input via Pydantic v2 — no raw `Body(Any)`. +- Every agent tool input validated by Pydantic before execution (`app/features/agents/tools/`). +- LLM **responses** are not trusted: structured outputs parsed via Pydantic; freeform text never executed. +- Allow-lists over deny-lists (e.g., `model_type ∈ {naive, seasonal_naive, moving_average, lightgbm}`; embedding provider ∈ `{openai, ollama}`; model identifier provider ∈ `{anthropic, openai, google-gla, google-vertex}`). + +## Network Security + +- Backend binds `0.0.0.0:8123` by default (`api_host` / `api_port` in `Settings`). [UNVERIFIED — fine on a personal LAN; would need a reverse proxy + TLS for any public exposure.] +- CORS allow-list in `app/main.py`: dev permits `localhost`/`127.0.0.1`/private LAN ranges via regex; **production sets explicit origins via empty list + no regex** — review before any non-dev deploy. +- No TLS at the app layer; rely on `docker-compose` private network for DB. 
Postgres password is the dev default `forecastlab/forecastlab` — change if exposing the host. + +## LLM / Agent Security + +- Token budget cap per session (`agent_max_tokens=4096` default). +- Tool-call cap per session (`agent_max_tool_calls=10` default). +- Timeout wrap around `agent.run()` / `agent.run_stream()` (`agent_timeout_seconds=120`). +- HITL approval required for mutating tools — `agent_require_approval=["create_alias","archive_run"]`. Never widen the agent's mutation surface without adding the new tool name to that list. +- Never log full prompts/responses at INFO; DEBUG only with explicit operator opt-in. + +## External Integrations Security + +| Integration | Auth | Data Sent | Note | +|-------------|------|-----------|------| +| OpenAI embeddings | API key | Document chunks (markdown / openapi) | No PII in indexed corpus — corpus is project's own docs | +| OpenAI / Anthropic / Gemini agent LLM | API key | User chat messages + tool descriptions + tool results | Chat messages may contain user-supplied text | +| Ollama embeddings | none (LAN) | Document chunks | Local; preferred for keeping data off external services | + +## CI / Workflow Security + +| Workflow | Pinning | Notes | +|----------|---------|-------| +| `ci.yml` | `actions/checkout@v6`, `astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0` | First-party `actions/*` may use major-version per rule; third-party `setup-uv` is SHA-pinned | +| `cd-release.yml` | `actions/checkout@v6`, `actions/upload-artifact@v7` (first-party, major-pin OK) **+** `googleapis/release-please-action@45996ed1f6d02564a971a2fa1b5860e934307cf7 # v5.0.0`, `astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0` (**third-party, SHA-pinned**) | Both third-party actions follow `security-patterns.md` ("Pin third-party GitHub Actions by full 40-char SHA"), with the `# vX.Y.Z` comment trailer per rule. Keep the SHA and the trailer in sync when bumping versions. 
| +| `dependency-check.yml`, `phase-snapshot.yml`, `schema-validation.yml` | Same first-party `actions/*` + `astral-sh/setup-uv` pattern as the others | [UNVERIFIED] Confirm `astral-sh/setup-uv` is SHA-pinned here as in `ci.yml` / `cd-release.yml`; any remaining `astral-sh/setup-uv@v7` is a third-party major-pin gap that needs the same SHA pin | + +Dependabot watches `.github/workflows/` weekly (`.github/dependabot.yml`) — keep its PRs current. + +## Scanning & Compliance + +| Check | Tool | Frequency | Blocks Merge? | +|-------|------|-----------|---------------| +| Lint + format | ruff | every PR | Yes | +| Type check | mypy --strict + pyright --strict | every PR | Yes | +| Unit tests | pytest | every PR | Yes | +| Integration tests | pytest -m integration against Postgres service | every PR | Yes | +| Migration apply check | alembic upgrade head on fresh DB | every PR | Yes | +| Dependency audit | `.github/workflows/dependency-check.yml` | Weekly cron (Sun 00:00 UTC) + manual dispatch | No (out-of-band; not a per-PR gate) — but `fail_on_vulnerabilities` input defaults `true` | +| Secrets detection | none configured [UNVERIFIED — consider adding gitleaks pre-commit hook] | — | — | + +## Compliance Constraints + +| Framework | Applies | Note | +|-----------|---------|------| +| PCI-DSS | No | No card data | +| SOC 2 | No | Portfolio repo | +| GDPR / PII | No | Seeded synthetic data only | +| HIPAA | No | No health data | + +[ASSUMPTION] confirmed via Phase 2 question. Re-evaluate if scope changes. 
diff --git a/frontend/src/pages/admin.tsx b/frontend/src/pages/admin.tsx index c01e6519..e3087f57 100644 --- a/frontend/src/pages/admin.tsx +++ b/frontend/src/pages/admin.tsx @@ -1,5 +1,6 @@ -import { useState } from 'react' +import { useEffect, useState } from 'react' import { format } from 'date-fns' +import type { DateRange } from 'react-day-picker' import { Trash2, Plus, @@ -26,6 +27,7 @@ import { useDeleteData, useVerifyData, } from '@/hooks/use-seeder' +import { DateRangePicker } from '@/components/common/date-range-picker' import { ErrorDisplay } from '@/components/common/error-display' import { LoadingState } from '@/components/common/loading-state' import { Button } from '@/components/ui/button' @@ -381,6 +383,40 @@ function AliasesPanel() { ) } +const SEEDER_FORM_STORAGE_KEY = 'forecastlab.seederForm.v1' + +interface SeederFormState { + scenario: string + startDate: string // ISO yyyy-MM-dd + endDate: string + stores: number + products: number + seed: number + sparsity: number +} + +const DEFAULT_SEEDER_FORM: SeederFormState = { + scenario: 'retail_standard', + startDate: '2024-01-01', + endDate: '2024-12-31', + stores: 10, + products: 50, + seed: 42, + sparsity: 0, +} + +function loadSeederForm(): SeederFormState { + if (typeof window === 'undefined') return DEFAULT_SEEDER_FORM + try { + const raw = window.localStorage.getItem(SEEDER_FORM_STORAGE_KEY) + if (!raw) return DEFAULT_SEEDER_FORM + const parsed = JSON.parse(raw) as Partial + return { ...DEFAULT_SEEDER_FORM, ...parsed } + } catch { + return DEFAULT_SEEDER_FORM + } +} + function SeederPanel() { const { data: status, isLoading, error, refetch } = useSeederStatus() const { data: scenarios } = useSeederScenarios() @@ -388,7 +424,7 @@ function SeederPanel() { const deleteMutation = useDeleteData() const verifyMutation = useVerifyData() - const [selectedScenario, setSelectedScenario] = useState('retail_standard') + const [form, setForm] = useState(loadSeederForm) const [deleteDialogOpen, 
setDeleteDialogOpen] = useState(false) const [verifyResult, setVerifyResult] = useState<{ passed: boolean @@ -398,13 +434,43 @@ function SeederPanel() { failed_count: number } | null>(null) + // Persist form state across reloads (operator-friendly). + useEffect(() => { + if (typeof window === 'undefined') return + window.localStorage.setItem(SEEDER_FORM_STORAGE_KEY, JSON.stringify(form)) + }, [form]) + + const dateRange: DateRange | undefined = form.startDate + ? { + from: new Date(`${form.startDate}T00:00:00`), + to: form.endDate ? new Date(`${form.endDate}T00:00:00`) : undefined, + } + : undefined + + const handleDateRangeChange = (range: DateRange | undefined) => { + setForm((f) => ({ + ...f, + startDate: range?.from ? format(range.from, 'yyyy-MM-dd') : f.startDate, + endDate: range?.to ? format(range.to, 'yyyy-MM-dd') : f.endDate, + })) + } + + const handleResetForm = () => setForm(DEFAULT_SEEDER_FORM) + const handleGenerate = async () => { try { const result = await generateMutation.mutateAsync({ - scenario: selectedScenario, + scenario: form.scenario, + seed: form.seed, + stores: form.stores, + products: form.products, + start_date: form.startDate, + end_date: form.endDate, + sparsity: form.sparsity, }) toast.success( - `Generated ${result.records_created.sales?.toLocaleString() ?? 0} sales records in ${result.duration_seconds.toFixed(1)}s` + `Generated ${result.records_created.sales?.toLocaleString() ?? 0} sales records ` + + `(${form.startDate} → ${form.endDate}) in ${result.duration_seconds.toFixed(1)}s` ) } catch (err) { toast.error(err instanceof Error ? err.message : 'Generation failed') @@ -540,23 +606,114 @@ function SeederPanel() { -
- - +
+
+ + +
+ +
+ + +
+ +
+ + + setForm((f) => ({ + ...f, + stores: Math.max(1, Math.min(100, Number(e.target.value) || 1)), + })) + } + /> +
+ +
+ + + setForm((f) => ({ + ...f, + products: Math.max(1, Math.min(500, Number(e.target.value) || 1)), + })) + } + /> +
+ +
+ + + setForm((f) => ({ + ...f, + seed: Math.max(0, Number(e.target.value) || 0), + })) + } + /> +
+ +
+ + + setForm((f) => ({ ...f, sparsity: Number(e.target.value) })) + } + /> +
+
+ +
+