Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion app/features/jobs/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class JobCreate(BaseModel):
**Job Types and Required Params**:

- **train**: Train a forecasting model
- `model_type`: Required - 'naive', 'seasonal_naive', 'linear_regression', etc.
- `model_type`: Required - 'naive', 'seasonal_naive', 'moving_average', 'regression'.
- `store_id`: Required - Store ID from /dimensions/stores
- `product_id`: Required - Product ID from /dimensions/products
- `start_date`: Required - Training data start (YYYY-MM-DD)
Expand Down
7 changes: 7 additions & 0 deletions app/features/jobs/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ async def _execute_train(
from app.features.forecasting.schemas import (
MovingAverageModelConfig,
NaiveModelConfig,
RegressionModelConfig,
SeasonalNaiveModelConfig,
)
from app.features.forecasting.service import ForecastingService
Expand Down Expand Up @@ -457,6 +458,12 @@ async def _execute_train(
elif model_type == "moving_average":
window_size = params.get("window_size", 7)
config = MovingAverageModelConfig(window_size=window_size)
elif model_type == "regression":
config = RegressionModelConfig(
max_iter=params.get("max_iter", 200),
learning_rate=params.get("learning_rate", 0.05),
max_depth=params.get("max_depth", 6),
)
else:
msg = f"Unsupported model_type: {model_type}"
raise ValueError(msg)
Expand Down
65 changes: 64 additions & 1 deletion app/features/jobs/tests/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@

import math
from datetime import date
from typing import Any, cast
from unittest.mock import AsyncMock, patch

import pytest
from sqlalchemy.ext.asyncio import AsyncSession

from app.features.backtesting.schemas import (
BacktestResponse,
Expand All @@ -16,7 +21,9 @@
SplitBoundary,
SplitConfig,
)
from app.features.jobs.service import _finite, _shape_backtest_result
from app.features.forecasting.schemas import RegressionModelConfig, TrainResponse
from app.features.forecasting.service import ForecastingService
from app.features.jobs.service import JobService, _finite, _shape_backtest_result


def _fold(idx: int, mae: float, smape: float, wape: float, bias: float) -> FoldResult:
Expand Down Expand Up @@ -158,3 +165,59 @@ def test_finite_coerces_non_finite_values() -> None:
assert _finite(math.nan) == 0.0
assert _finite(math.inf) == 0.0
assert _finite(-math.inf) == 0.0


# =============================================================================
# _execute_train regression-model support (#229)
# =============================================================================


def _fake_train_response(model_type: str) -> TrainResponse:
"""Build a TrainResponse stub for mocking ForecastingService.train_model."""
return TrainResponse(
store_id=1,
product_id=1,
model_type=model_type,
model_path="/data/artifacts/model_abc123def456.joblib",
config_hash="cfg-hash",
n_observations=400,
train_start_date=date(2024, 1, 1),
train_end_date=date(2024, 12, 31),
duration_ms=12.0,
)


_REGRESSION_PARAMS: dict[str, Any] = {
"model_type": "regression",
"store_id": 1,
"product_id": 1,
"start_date": "2024-01-01",
"end_date": "2024-12-31",
}


async def test_execute_train_builds_regression_config() -> None:
"""A train job with model_type='regression' builds a RegressionModelConfig (#229)."""
fake = _fake_train_response("regression")
with patch.object(
ForecastingService, "train_model", new=AsyncMock(return_value=fake)
) as mock_train:
result = await JobService()._execute_train(
db=cast(AsyncSession, AsyncMock()),
params=_REGRESSION_PARAMS,
)
assert mock_train.call_args is not None
config = mock_train.call_args.kwargs["config"]
assert isinstance(config, RegressionModelConfig)
assert result["model_type"] == "regression"
# run_id is parsed from the model_abc123def456.joblib artifact path.
assert result["run_id"] == "abc123def456"


async def test_execute_train_rejects_unsupported_model_type() -> None:
"""_execute_train still rejects a genuinely unsupported model_type (e.g. lightgbm)."""
with pytest.raises(ValueError, match="Unsupported model_type"):
await JobService()._execute_train(
db=cast(AsyncSession, AsyncMock()),
params={**_REGRESSION_PARAMS, "model_type": "lightgbm"},
)
59 changes: 59 additions & 0 deletions frontend/src/lib/scenario-utils.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { describe, expect, it } from 'vitest'
import {
assumptionDateErrors,
buildMultiSeries,
coverageLabel,
coverageVariant,
Expand Down Expand Up @@ -143,3 +144,61 @@ describe('methodLabel', () => {
expect(methodLabel('model_exogenous')).toBe('Model-driven')
})
})

describe('assumptionDateErrors', () => {
const NONE = {
priceEnabled: false,
priceStart: '',
priceEnd: '',
promoEnabled: false,
promoStart: '',
promoEnd: '',
}

it('reports no errors when nothing is enabled', () => {
expect(assumptionDateErrors(NONE).hasErrors).toBe(false)
})

it('flags both price dates when price is enabled and blank', () => {
const e = assumptionDateErrors({ ...NONE, priceEnabled: true })
expect(e.priceStart).toBe(true)
expect(e.priceEnd).toBe(true)
expect(e.hasErrors).toBe(true)
})

it('clears price errors once both dates are filled', () => {
const e = assumptionDateErrors({
...NONE,
priceEnabled: true,
priceStart: '2026-07-01',
priceEnd: '2026-07-14',
})
expect(e.hasErrors).toBe(false)
})

it('flags only the blank promotion date', () => {
const e = assumptionDateErrors({
...NONE,
promoEnabled: true,
promoStart: '2026-07-01',
promoEnd: '',
})
expect(e.promoStart).toBe(false)
expect(e.promoEnd).toBe(true)
expect(e.hasErrors).toBe(true)
})

it('isolates errors per assumption (price ok, promo blank)', () => {
const e = assumptionDateErrors({
priceEnabled: true,
priceStart: '2026-07-01',
priceEnd: '2026-07-14',
promoEnabled: true,
promoStart: '',
promoEnd: '',
})
expect(e.priceStart).toBe(false)
expect(e.promoStart).toBe(true)
expect(e.hasErrors).toBe(true)
})
})
38 changes: 38 additions & 0 deletions frontend/src/lib/scenario-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,41 @@ export function buildMultiSeries(comparison: MultiScenarioComparison): MultiSeri
export function methodLabel(method: 'heuristic' | 'model_exogenous'): string {
return method === 'model_exogenous' ? 'Model-driven' : 'Heuristic'
}

/** Form state for the date-bearing assumptions (price, promotion). */
export interface AssumptionDateState {
priceEnabled: boolean
priceStart: string
priceEnd: string
promoEnabled: boolean
promoStart: string
promoEnd: string
}

/** Which enabled assumption date inputs are still blank. */
export interface AssumptionDateErrors {
priceStart: boolean
priceEnd: boolean
promoStart: boolean
promoEnd: boolean
hasErrors: boolean
}

/**
* Flag every enabled Price/Promotion assumption whose From/To date is blank.
* The planner blocks Run/Save while `hasErrors` is true so the backend never
* receives an empty-string date (which fails Pydantic date validation → 422).
*/
export function assumptionDateErrors(state: AssumptionDateState): AssumptionDateErrors {
const priceStart = state.priceEnabled && !state.priceStart
const priceEnd = state.priceEnabled && !state.priceEnd
const promoStart = state.promoEnabled && !state.promoStart
const promoEnd = state.promoEnabled && !state.promoEnd
return {
priceStart,
priceEnd,
promoStart,
promoEnd,
hasErrors: priceStart || priceEnd || promoStart || promoEnd,
}
}
58 changes: 49 additions & 9 deletions frontend/src/pages/visualize/planner.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import {
import { downloadCsv, toCsv } from '@/lib/csv-export'
import { formatCurrency, formatNumber, getErrorMessage } from '@/lib/api'
import {
assumptionDateErrors,
buildMultiSeries,
coverageLabel,
coverageVariant,
Expand Down Expand Up @@ -73,8 +74,11 @@ export default function WhatIfPlannerPage() {
const [selectedJobId, setSelectedJobId] = useState('')
const [horizon, setHorizon] = useState(14)
const { data: job } = useJob(selectedJobId, !!selectedJobId)
// A predict job's params.run_id is the baseline model artifact key.
const baselineRunId = typeof job?.params?.run_id === 'string' ? job.params.run_id : null
// A completed `train` job stores result.run_id — the model-artifact key
// POST /scenarios/simulate resolves. (This is NOT a registry run id.)
// A `regression` baseline routes the simulate call down the model_exogenous
// re-forecast branch; other model types fall back to the heuristic factor.
const baselineRunId = typeof job?.result?.run_id === 'string' ? job.result.run_id : null

// -- Assumption form state ---------------------------------------------
const [priceEnabled, setPriceEnabled] = useState(false)
Expand All @@ -97,6 +101,19 @@ export default function WhatIfPlannerPage() {
const [lifecycleStage, setLifecycleStage] =
useState<(typeof LIFECYCLE_STAGES)[number]>('maturity')

// -- Derived validation ------------------------------------------------
// Enabling Price/Promotion without filling both dates would submit empty
// strings — Pydantic date validation rejects those with an RFC 7807 422.
// Gate Run/Save on this so the form can never produce that request (#228).
const dateErrors = assumptionDateErrors({
priceEnabled,
priceStart,
priceEnd,
promoEnabled,
promoStart,
promoEnd,
})

// -- Results / persistence state ---------------------------------------
const [simulated, setSimulated] = useState<ScenarioComparison | null>(null)
const [planName, setPlanName] = useState('')
Expand Down Expand Up @@ -152,7 +169,7 @@ export default function WhatIfPlannerPage() {
}

async function handleRun() {
if (!baselineRunId) return
if (!baselineRunId || dateErrors.hasErrors) return
setRunError(null)
setReloadId('')
try {
Expand All @@ -169,7 +186,7 @@ export default function WhatIfPlannerPage() {
}

async function handleSave() {
if (!baselineRunId || !planName.trim()) return
if (!baselineRunId || !planName.trim() || dateErrors.hasErrors) return
setRunError(null)
try {
await createScenario.mutateAsync({
Expand Down Expand Up @@ -245,12 +262,15 @@ export default function WhatIfPlannerPage() {
<CardHeader>
<CardTitle>1. Pick a baseline</CardTitle>
<CardDescription>
Choose a completed prediction job — its model is the baseline this scenario adjusts.
Choose a completed training job — its model is the baseline this scenario
adjusts. A regression baseline is genuinely re-forecast through the model
(model-driven); naive, seasonal-naive and moving-average baselines use a
heuristic adjustment factor.
</CardDescription>
</CardHeader>
<CardContent className="space-y-4">
<JobPicker
jobType="predict"
jobType="train"
selectedJobId={selectedJobId}
onSelect={setSelectedJobId}
autoSelectLatest
Expand All @@ -274,7 +294,7 @@ export default function WhatIfPlannerPage() {
</div>
{selectedJobId && !baselineRunId && (
<p className="text-sm text-muted-foreground">
The selected job has no model artifact — pick a completed predict job.
The selected job has no model artifact — pick a completed train job.
</p>
)}
</CardContent>
Expand Down Expand Up @@ -317,6 +337,9 @@ export default function WhatIfPlannerPage() {
value={priceStart}
onChange={(event) => setPriceStart(event.target.value)}
/>
{dateErrors.priceStart && (
<p className="text-xs text-destructive">Required</p>
)}
</div>
<div className="space-y-1">
<span className="text-xs text-muted-foreground">To</span>
Expand All @@ -326,6 +349,9 @@ export default function WhatIfPlannerPage() {
value={priceEnd}
onChange={(event) => setPriceEnd(event.target.value)}
/>
{dateErrors.priceEnd && (
<p className="text-xs text-destructive">Required</p>
)}
</div>
</div>
)}
Expand Down Expand Up @@ -368,6 +394,9 @@ export default function WhatIfPlannerPage() {
value={promoStart}
onChange={(event) => setPromoStart(event.target.value)}
/>
{dateErrors.promoStart && (
<p className="text-xs text-destructive">Required</p>
)}
</div>
<div className="space-y-1">
<span className="text-xs text-muted-foreground">To</span>
Expand All @@ -377,6 +406,9 @@ export default function WhatIfPlannerPage() {
value={promoEnd}
onChange={(event) => setPromoEnd(event.target.value)}
/>
{dateErrors.promoEnd && (
<p className="text-xs text-destructive">Required</p>
)}
</div>
</div>
)}
Expand Down Expand Up @@ -462,7 +494,10 @@ export default function WhatIfPlannerPage() {
</div>

<div className="flex flex-wrap items-center gap-3 border-t pt-4">
<Button onClick={handleRun} disabled={!baselineRunId || simulate.isPending}>
<Button
onClick={handleRun}
disabled={!baselineRunId || simulate.isPending || dateErrors.hasErrors}
>
{simulate.isPending ? (
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
) : (
Expand Down Expand Up @@ -593,7 +628,12 @@ export default function WhatIfPlannerPage() {
</div>
<Button
onClick={handleSave}
disabled={!baselineRunId || !planName.trim() || createScenario.isPending}
disabled={
!baselineRunId ||
!planName.trim() ||
createScenario.isPending ||
dateErrors.hasErrors
}
>
{createScenario.isPending ? (
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
Expand Down