From 4a2e5dc81b05ebda6330070da1f9b59f07e4b1e1 Mon Sep 17 00:00:00 2001 From: Gabor Szabo Date: Tue, 19 May 2026 14:02:26 +0200 Subject: [PATCH 1/2] fix(jobs,ui): reach model_exogenous from the what-if planner (#229) Add a regression branch to JobService._execute_train so a train job can train a regression model and complete with result.run_id set. Rewire the What-If Planner baseline picker to train jobs and read the artifact key from job.result.run_id, making the model_exogenous re-forecast path reachable from the browser. --- app/features/jobs/schemas.py | 2 +- app/features/jobs/service.py | 7 +++ app/features/jobs/tests/test_service.py | 65 +++++++++++++++++++++++- frontend/src/pages/visualize/planner.tsx | 16 ++++-- 4 files changed, 83 insertions(+), 7 deletions(-) diff --git a/app/features/jobs/schemas.py b/app/features/jobs/schemas.py index 0f411dfa..b7e9663d 100644 --- a/app/features/jobs/schemas.py +++ b/app/features/jobs/schemas.py @@ -25,7 +25,7 @@ class JobCreate(BaseModel): **Job Types and Required Params**: - **train**: Train a forecasting model - - `model_type`: Required - 'naive', 'seasonal_naive', 'linear_regression', etc. + - `model_type`: Required - 'naive', 'seasonal_naive', 'moving_average', 'regression'. - `store_id`: Required - Store ID from /dimensions/stores - `product_id`: Required - Product ID from /dimensions/products - `start_date`: Required - Training data start (YYYY-MM-DD) diff --git a/app/features/jobs/service.py b/app/features/jobs/service.py index 528bb20e..32a66308 100644 --- a/app/features/jobs/service.py +++ b/app/features/jobs/service.py @@ -426,6 +426,7 @@ async def _execute_train( from app.features.forecasting.schemas import ( MovingAverageModelConfig, NaiveModelConfig, + RegressionModelConfig, SeasonalNaiveModelConfig, ) from app.features.forecasting.service import ForecastingService @@ -457,6 +458,12 @@ async def _execute_train( elif model_type == "moving_average": window_size = params.get("window_size", 7) config = MovingAverageModelConfig(window_size=window_size) + elif model_type == "regression": + config = RegressionModelConfig( + max_iter=params.get("max_iter", 200), + learning_rate=params.get("learning_rate", 0.05), + max_depth=params.get("max_depth", 6), + ) else: msg = f"Unsupported model_type: {model_type}" raise ValueError(msg) diff --git a/app/features/jobs/tests/test_service.py b/app/features/jobs/tests/test_service.py index cd05540d..18d10c56 100644 --- a/app/features/jobs/tests/test_service.py +++ b/app/features/jobs/tests/test_service.py @@ -8,6 +8,11 @@ import math from datetime import date +from typing import Any, cast +from unittest.mock import AsyncMock, patch + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession from app.features.backtesting.schemas import ( BacktestResponse, @@ -16,7 +21,9 @@ SplitBoundary, SplitConfig, ) -from app.features.jobs.service import _finite, _shape_backtest_result +from app.features.forecasting.schemas import RegressionModelConfig, TrainResponse +from app.features.forecasting.service import ForecastingService +from app.features.jobs.service import JobService, _finite, _shape_backtest_result def _fold(idx: int, mae: float, smape: float, wape: float, bias: float) -> FoldResult: @@ -158,3 +165,59 @@ def test_finite_coerces_non_finite_values() -> None: assert _finite(math.nan) == 0.0 assert _finite(math.inf) == 0.0 assert _finite(-math.inf) == 0.0 + + +# ============================================================================= +# _execute_train regression-model support (#229) +# ============================================================================= + + +def _fake_train_response(model_type: str) -> TrainResponse: + """Build a TrainResponse stub for mocking ForecastingService.train_model.""" + return TrainResponse( + store_id=1, + product_id=1, + model_type=model_type, + model_path="/data/artifacts/model_abc123def456.joblib", + config_hash="cfg-hash", + n_observations=400, + train_start_date=date(2024, 1, 1), + train_end_date=date(2024, 12, 31), + duration_ms=12.0, + ) + + +_REGRESSION_PARAMS: dict[str, Any] = { + "model_type": "regression", + "store_id": 1, + "product_id": 1, + "start_date": "2024-01-01", + "end_date": "2024-12-31", +} + + +async def test_execute_train_builds_regression_config() -> None: + """A train job with model_type='regression' builds a RegressionModelConfig (#229).""" + fake = _fake_train_response("regression") + with patch.object( + ForecastingService, "train_model", new=AsyncMock(return_value=fake) + ) as mock_train: + result = await JobService()._execute_train( + db=cast(AsyncSession, AsyncMock()), + params=_REGRESSION_PARAMS, + ) + assert mock_train.call_args is not None + config = mock_train.call_args.kwargs["config"] + assert isinstance(config, RegressionModelConfig) + assert result["model_type"] == "regression" + # run_id is parsed from the model_abc123def456.joblib artifact path. + assert result["run_id"] == "abc123def456" + + +async def test_execute_train_rejects_unsupported_model_type() -> None: + """_execute_train still rejects a genuinely unsupported model_type (e.g. lightgbm).""" + with pytest.raises(ValueError, match="Unsupported model_type"): + await JobService()._execute_train( + db=cast(AsyncSession, AsyncMock()), + params={**_REGRESSION_PARAMS, "model_type": "lightgbm"}, + ) diff --git a/frontend/src/pages/visualize/planner.tsx b/frontend/src/pages/visualize/planner.tsx index 24aa3194..710a3fe8 100644 --- a/frontend/src/pages/visualize/planner.tsx +++ b/frontend/src/pages/visualize/planner.tsx @@ -73,8 +73,11 @@ export default function WhatIfPlannerPage() { const [selectedJobId, setSelectedJobId] = useState('') const [horizon, setHorizon] = useState(14) const { data: job } = useJob(selectedJobId, !!selectedJobId) - // A predict job's params.run_id is the baseline model artifact key. - const baselineRunId = typeof job?.params?.run_id === 'string' ? job.params.run_id : null + // A completed `train` job stores result.run_id — the model-artifact key + // POST /scenarios/simulate resolves. (This is NOT a registry run id.) + // A `regression` baseline routes the simulate call down the model_exogenous + // re-forecast branch; other model types fall back to the heuristic factor. + const baselineRunId = typeof job?.result?.run_id === 'string' ? job.result.run_id : null // -- Assumption form state --------------------------------------------- const [priceEnabled, setPriceEnabled] = useState(false) @@ -245,12 +248,15 @@ export default function WhatIfPlannerPage() { 1. Pick a baseline - Choose a completed prediction job — its model is the baseline this scenario adjusts. + Choose a completed training job — its model is the baseline this scenario + adjusts. A regression baseline is genuinely re-forecast through the model + (model-driven); naive, seasonal-naive and moving-average baselines use a + heuristic adjustment factor. {selectedJobId && !baselineRunId && (

- The selected job has no model artifact — pick a completed predict job. + The selected job has no model artifact — pick a completed train job.

)}
From 34104c9000b39fb701e6a642c91903a68cf3af3d Mon Sep 17 00:00:00 2001 From: Gabor Szabo Date: Tue, 19 May 2026 14:02:32 +0200 Subject: [PATCH 2/2] fix(ui): block planner runs with empty assumption dates (#228) Add a pure assumptionDateErrors helper (unit-tested) and use it in the What-If Planner to disable Run simulation and Save as plan while any enabled Price/Promotion assumption has a blank date, with inline Required hints. The empty-string-date 422 is now structurally unreachable from the form. --- frontend/src/lib/scenario-utils.test.ts | 59 ++++++++++++++++++++++++ frontend/src/lib/scenario-utils.ts | 38 +++++++++++++++ frontend/src/pages/visualize/planner.tsx | 42 +++++++++++++++-- 3 files changed, 135 insertions(+), 4 deletions(-) diff --git a/frontend/src/lib/scenario-utils.test.ts b/frontend/src/lib/scenario-utils.test.ts index b091470a..4331619b 100644 --- a/frontend/src/lib/scenario-utils.test.ts +++ b/frontend/src/lib/scenario-utils.test.ts @@ -1,5 +1,6 @@ import { describe, expect, it } from 'vitest' import { + assumptionDateErrors, buildMultiSeries, coverageLabel, coverageVariant, @@ -143,3 +144,61 @@ describe('methodLabel', () => { expect(methodLabel('model_exogenous')).toBe('Model-driven') }) }) + +describe('assumptionDateErrors', () => { + const NONE = { + priceEnabled: false, + priceStart: '', + priceEnd: '', + promoEnabled: false, + promoStart: '', + promoEnd: '', + } + + it('reports no errors when nothing is enabled', () => { + expect(assumptionDateErrors(NONE).hasErrors).toBe(false) + }) + + it('flags both price dates when price is enabled and blank', () => { + const e = assumptionDateErrors({ ...NONE, priceEnabled: true }) + expect(e.priceStart).toBe(true) + expect(e.priceEnd).toBe(true) + expect(e.hasErrors).toBe(true) + }) + + it('clears price errors once both dates are filled', () => { + const e = assumptionDateErrors({ + ...NONE, + priceEnabled: true, + priceStart: '2026-07-01', + priceEnd: '2026-07-14', + }) + expect(e.hasErrors).toBe(false) + }) + + it('flags only the blank promotion date', () => { + const e = assumptionDateErrors({ + ...NONE, + promoEnabled: true, + promoStart: '2026-07-01', + promoEnd: '', + }) + expect(e.promoStart).toBe(false) + expect(e.promoEnd).toBe(true) + expect(e.hasErrors).toBe(true) + }) + + it('isolates errors per assumption (price ok, promo blank)', () => { + const e = assumptionDateErrors({ + priceEnabled: true, + priceStart: '2026-07-01', + priceEnd: '2026-07-14', + promoEnabled: true, + promoStart: '', + promoEnd: '', + }) + expect(e.priceStart).toBe(false) + expect(e.promoStart).toBe(true) + expect(e.hasErrors).toBe(true) + }) +}) diff --git a/frontend/src/lib/scenario-utils.ts b/frontend/src/lib/scenario-utils.ts index bbd7b7ed..cbe7795e 100644 --- a/frontend/src/lib/scenario-utils.ts +++ b/frontend/src/lib/scenario-utils.ts @@ -135,3 +135,41 @@ export function buildMultiSeries(comparison: MultiScenarioComparison): MultiSeri export function methodLabel(method: 'heuristic' | 'model_exogenous'): string { return method === 'model_exogenous' ? 'Model-driven' : 'Heuristic' } + +/** Form state for the date-bearing assumptions (price, promotion). */ +export interface AssumptionDateState { + priceEnabled: boolean + priceStart: string + priceEnd: string + promoEnabled: boolean + promoStart: string + promoEnd: string +} + +/** Which enabled assumption date inputs are still blank. */ +export interface AssumptionDateErrors { + priceStart: boolean + priceEnd: boolean + promoStart: boolean + promoEnd: boolean + hasErrors: boolean +} + +/** + * Flag every enabled Price/Promotion assumption whose From/To date is blank. + * The planner blocks Run/Save while `hasErrors` is true so the backend never + * receives an empty-string date (which fails Pydantic date validation → 422). + */ +export function assumptionDateErrors(state: AssumptionDateState): AssumptionDateErrors { + const priceStart = state.priceEnabled && !state.priceStart + const priceEnd = state.priceEnabled && !state.priceEnd + const promoStart = state.promoEnabled && !state.promoStart + const promoEnd = state.promoEnabled && !state.promoEnd + return { + priceStart, + priceEnd, + promoStart, + promoEnd, + hasErrors: priceStart || priceEnd || promoStart || promoEnd, + } +} diff --git a/frontend/src/pages/visualize/planner.tsx b/frontend/src/pages/visualize/planner.tsx index 710a3fe8..077ce6e2 100644 --- a/frontend/src/pages/visualize/planner.tsx +++ b/frontend/src/pages/visualize/planner.tsx @@ -35,6 +35,7 @@ import { import { downloadCsv, toCsv } from '@/lib/csv-export' import { formatCurrency, formatNumber, getErrorMessage } from '@/lib/api' import { + assumptionDateErrors, buildMultiSeries, coverageLabel, coverageVariant, @@ -100,6 +101,19 @@ export default function WhatIfPlannerPage() { const [lifecycleStage, setLifecycleStage] = useState<(typeof LIFECYCLE_STAGES)[number]>('maturity') + // -- Derived validation ------------------------------------------------ + // Enabling Price/Promotion without filling both dates would submit empty + // strings — Pydantic date validation rejects those with an RFC 7807 422. + // Gate Run/Save on this so the form can never produce that request (#228). + const dateErrors = assumptionDateErrors({ + priceEnabled, + priceStart, + priceEnd, + promoEnabled, + promoStart, + promoEnd, + }) + // -- Results / persistence state --------------------------------------- const [simulated, setSimulated] = useState(null) const [planName, setPlanName] = useState('') @@ -155,7 +169,7 @@ export default function WhatIfPlannerPage() { } async function handleRun() { - if (!baselineRunId) return + if (!baselineRunId || dateErrors.hasErrors) return setRunError(null) setReloadId('') try { @@ -172,7 +186,7 @@ export default function WhatIfPlannerPage() { } async function handleSave() { - if (!baselineRunId || !planName.trim()) return + if (!baselineRunId || !planName.trim() || dateErrors.hasErrors) return setRunError(null) try { await createScenario.mutateAsync({ @@ -323,6 +337,9 @@ export default function WhatIfPlannerPage() { value={priceStart} onChange={(event) => setPriceStart(event.target.value)} /> + {dateErrors.priceStart && ( +

Required

+ )}
To @@ -332,6 +349,9 @@ export default function WhatIfPlannerPage() { value={priceEnd} onChange={(event) => setPriceEnd(event.target.value)} /> + {dateErrors.priceEnd && ( +

Required

+ )}
)} @@ -374,6 +394,9 @@ export default function WhatIfPlannerPage() { value={promoStart} onChange={(event) => setPromoStart(event.target.value)} /> + {dateErrors.promoStart && ( +

Required

+ )}
To @@ -383,6 +406,9 @@ export default function WhatIfPlannerPage() { value={promoEnd} onChange={(event) => setPromoEnd(event.target.value)} /> + {dateErrors.promoEnd && ( +

Required

+ )}
)} @@ -468,7 +494,10 @@ export default function WhatIfPlannerPage() {
-