
Retail Sales Forecasting with Databricks Model Serving

End-to-end time series forecasting solution using scikit-learn, MLflow, and Databricks Model Serving.

Overview

This repository contains three Databricks notebooks that demonstrate:

  1. Generating synthetic retail sales data
  2. Training and deploying a forecasting model
  3. Running batch predictions at scale

Files

1. 00_generate_retail_sales_data

Generates realistic synthetic retail sales data with temporal patterns.

Features:

  • Multi-store, multi-product transactions
  • Seasonal effects (weekends, holidays, summer)
  • Growth trends over time
  • Saves to Databricks Unity Catalog

Usage:

# Update catalog and schema names
catalog = "your_catalog"
schema = "your_schema"

# Run the notebook to generate ~200K transaction records

Output: retail_sales_forecast table in Unity Catalog
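
For orientation, here is a minimal sketch of how such a generator might look. The seasonal multipliers, store count, and date range are illustrative placeholders rather than the notebook's exact values, and the region column is omitted for brevity:

import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
dates = pd.date_range("2023-01-01", "2024-12-31", freq="D")
stores = [f"S{i:03d}" for i in range(1, 6)]

rows = []
for store in stores:
    base = rng.uniform(5000, 10000)              # store-level baseline sales
    for i, d in enumerate(dates):
        trend = 1 + 0.0005 * i                   # slow growth over time
        weekend = 1.3 if d.dayofweek >= 5 else 1.0
        summer = 1.2 if d.month in (6, 7, 8) else 1.0
        noise = rng.normal(1.0, 0.1)
        rows.append((d.date(), store, base * trend * weekend * summer * noise))

pdf = pd.DataFrame(rows, columns=["date", "store_id", "final_sales"])

# In a Databricks notebook, write the result to Unity Catalog:
# spark.createDataFrame(pdf).write.saveAsTable(f"{catalog}.{schema}.retail_sales_forecast")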

2. 01_train_and_deploy_forecasting_model

Trains a scikit-learn LinearRegression model and deploys to Model Serving.

Features:

  • Time-based feature engineering
  • Train/test split with evaluation metrics
  • MLflow model registration in Unity Catalog
  • Automatic deployment to Model Serving endpoint
  • Model aliasing for production management

Requirements:

  • Pro or Serverless SQL warehouse
  • Unity Catalog enabled workspace

Usage:

# Configure your Unity Catalog
catalog = "your_catalog"
schema = "your_schema"
model_name = f"{catalog}.{schema}.retail_forecast_model"
endpoint_name = "retail-forecast-endpoint"

# Run notebook - it will:
# 1. Train the model
# 2. Register to Unity Catalog
# 3. Deploy to Model Serving

Output:

  • Registered model in Unity Catalog
  • Live Model Serving endpoint
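
A condensed sketch of the training and registration flow follows. It assumes a Databricks notebook (where spark is predefined) and the catalog, schema, and model_name variables from the configuration cell above; the notebook presumably also wraps the model in a pyfunc so the endpoint can accept a periods input, which is omitted here:

import mlflow
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error

mlflow.set_registry_uri("databricks-uc")         # register in Unity Catalog

# Load a single store's history and sort chronologically
pdf = spark.table(f"{catalog}.{schema}.retail_sales_forecast") \
           .where("store_id = 'S001'").toPandas()
pdf["date"] = pd.to_datetime(pdf["date"])
pdf = pdf.sort_values("date")

# Time-based features (see "Model Details" below)
X = pd.DataFrame({
    "day_index": (pdf["date"] - pdf["date"].min()).dt.days,
    "day_of_week": pdf["date"].dt.dayofweek,
    "month": pdf["date"].dt.month,
})
y = pdf["final_sales"]

split = int(len(pdf) * 0.8)                      # chronological train/test split
model = LinearRegression().fit(X[:split], y[:split])
print("MAE:", mean_absolute_error(y[split:], model.predict(X[split:])))

with mlflow.start_run():
    mlflow.sklearn.log_model(model, "model", registered_model_name=model_name)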

3. 02_batch_inference_ai_query

Generates batch forecasts using the MLflow Deployments SDK.

Features:

  • Single and batch predictions
  • Multiple forecast horizons (7, 30, 90 days)
  • Multi-store processing
  • Automatic table creation and storage
  • Visualization-ready outputs

Usage:

from mlflow.deployments import get_deploy_client

client = get_deploy_client("databricks")

# Single prediction
result = client.predict(
    endpoint="retail-forecast-endpoint",
    inputs={"dataframe_records": [{"periods": 30}]}
)

# Batch processing
# Run the notebook for automated batch forecasting

Output Tables:

  • daily_forecasts - Single store forecasts
  • batch_forecasts - Multi-store forecasts
  • scenario_forecasts - Multiple time horizons
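
A sketch of what the automated multi-store loop could look like. It assumes the endpoint response has the shape {"predictions": [...]} and that catalog and schema are set; since the demo model is trained on one store, each pass returns the same series, so the loop is purely illustrative:

from mlflow.deployments import get_deploy_client
import pandas as pd

client = get_deploy_client("databricks")
stores = ["S001", "S002", "S003"]                # illustrative store list

frames = []
for store in stores:
    result = client.predict(
        endpoint="retail-forecast-endpoint",
        inputs={"dataframe_records": [{"periods": 30}]},
    )
    fc = pd.DataFrame(result["predictions"])     # date / forecast / bounds
    fc["store_id"] = store
    frames.append(fc)

all_forecasts = pd.concat(frames, ignore_index=True)
spark.createDataFrame(all_forecasts).write.mode("append") \
     .saveAsTable(f"{catalog}.{schema}.batch_forecasts")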

Architecture

Data Generation → Model Training → Model Serving → Batch Predictions
     ↓                 ↓                ↓                ↓
  Unity Catalog    MLflow Registry   REST API        Tables

Getting Started

Prerequisites

  • Databricks workspace (AWS, Azure, or GCP)
  • Unity Catalog enabled
  • Pro or Serverless SQL warehouse
  • Cluster with ML runtime (13.3 LTS ML or higher)

Installation

  1. Clone this repository to your Databricks workspace

  2. Update configuration variables in each notebook:

    • catalog: Your Unity Catalog name
    • schema: Your schema name
    • endpoint_name: Your model serving endpoint name
  3. Run notebooks in order:

    1. 00_generate_retail_sales_data
    2. 01_train_and_deploy_forecasting_model
    3. 02_batch_inference_ai_query
    

Model Details

Algorithm: scikit-learn LinearRegression

Features:

  • Day index (days since start)
  • Day of week (0-6)
  • Month (1-12)

Output:

[
  {
    "date": "2025-01-01",
    "forecast": 8500.50,
    "forecast_lower": 7650.45,
    "forecast_upper": 9350.55
  }
]
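
The symmetric bounds suggest a residual-based interval. One plausible construction, assuming roughly normal hold-out errors (1.645 is the two-sided 90% z-score); the residual values below are hypothetical:

import numpy as np

residuals = np.array([120.0, -80.0, 45.0, -150.0, 60.0])   # hypothetical hold-out errors
sigma = residuals.std()

point_forecast = 8500.50
forecast_lower = point_forecast - 1.645 * sigma            # ~90% lower bound
forecast_upper = point_forecast + 1.645 * sigma            # ~90% upper bound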

API Usage

Model Serving Endpoint

Input:

{
  "dataframe_records": [
    {"periods": 30}  # Number of days to forecast
  ]
}

Output:

[
  {"date": "2025-01-01", "forecast": 8500.50, ...},
  {"date": "2025-01-02", "forecast": 8524.56, ...},
  ...
]
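
The endpoint can also be invoked directly over REST. A sketch using requests; the workspace URL and token are placeholders you must supply:

import requests

host = "https://<your-workspace>.cloud.databricks.com"     # placeholder workspace URL
token = "<personal-access-token>"                          # placeholder PAT

resp = requests.post(
    f"{host}/serving-endpoints/retail-forecast-endpoint/invocations",
    headers={"Authorization": f"Bearer {token}"},
    json={"dataframe_records": [{"periods": 30}]},
)
resp.raise_for_status()
print(resp.json())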

Python SDK

from mlflow.deployments import get_deploy_client

client = get_deploy_client("databricks")
predictions = client.predict(
    endpoint="retail-forecast-endpoint",
    inputs={"dataframe_records": [{"periods": 30}]}
)

Tables Schema

retail_sales_forecast (input data)

Column        Type     Description
date          DATE     Transaction date
store_id      STRING   Store identifier
region        STRING   Store region
final_sales   DOUBLE   Sales amount

daily_forecasts (predictions)

Column               Type        Description
forecast_timestamp   TIMESTAMP   When the forecast was made
store_id             STRING      Store identifier
date                 STRING      Forecast date
forecast             DOUBLE      Predicted sales
forecast_lower       DOUBLE      Lower bound (90% CI)
forecast_upper       DOUBLE      Upper bound (90% CI)
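
Once populated, the latest forecasts can be read back with a Spark SQL query in a Databricks notebook (catalog and schema as configured above):

latest = spark.sql(f"""
    SELECT store_id, date, forecast, forecast_lower, forecast_upper
    FROM {catalog}.{schema}.daily_forecasts
    QUALIFY forecast_timestamp = MAX(forecast_timestamp) OVER (PARTITION BY store_id)
    ORDER BY store_id, date
""")
display(latest)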

Performance

  • Training time: ~30 seconds (single store)
  • Prediction latency: ~200ms for 30-day forecast
  • Batch throughput: ~1000 forecasts/minute
  • Model size: <5 MB

Limitations

  • Model is trained on a single store (S001) - extend it for multi-store use
  • Simple linear model - consider more sophisticated algorithms for production
  • No hyperparameter tuning implemented
  • 90% confidence intervals are approximations

Troubleshooting

Model serving endpoint not ready:

  • Wait 1-2 minutes after deployment
  • Check endpoint status in Model Serving UI

Permission errors:

  • Ensure you have USE CATALOG and USE SCHEMA privileges
  • Verify CREATE MODEL privileges on the schema

Prediction errors:

  • Verify endpoint is in "Ready" state
  • Check input format matches expected schema
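
To check endpoint state programmatically rather than in the UI, one option is the Databricks SDK (preinstalled on recent Databricks runtimes); a minimal sketch:

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import EndpointStateReady

w = WorkspaceClient()                            # uses notebook/ambient credentials
ep = w.serving_endpoints.get("retail-forecast-endpoint")
print(ep.state)
if ep.state.ready == EndpointStateReady.READY:
    print("Endpoint is ready to serve predictions")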

License

MIT

Contributing

This is a reference implementation for learning purposes. Feel free to extend and customize for your use case.

Support

For issues or questions, open an issue on this repository.
