## Sales Forecasting with Bayesian Inference

This notebook demonstrates the use of Bayesian inference for sales forecasting using various probabilistic programming techniques. We will use the `numpyro` library to define and fit our models, and `plotnine` for visualization.

## Introduction

Bayesian inference allows us to incorporate prior knowledge and quantify uncertainty in our predictions. This notebook will guide you through the process of building a Bayesian model for sales forecasting, fitting the model using Markov Chain Monte Carlo (MCMC) and Stochastic Variational Inference (SVI), and visualizing the results.

## Import Required Libraries And Define Functions

In [1]:
import os
import sys

# Set XLA_FLAGS before JAX is imported
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

In [2]:
import polars as pl
import pandas as pd
import numpy as np
from plotnine import ggplot, aes, geom_point, geom_line, labs, theme_minimal, theme_bw, scale_x_continuous, scale_x_discrete, scale_x_datetime
import patsy

import jax.numpy as jnp
import jax.random as random

import numpyro
import numpyro.distributions as dist
#from numpyro.infer import SVI, Trace_ELBO, Predictive
from numpyro.infer import MCMC, NUTS  # , SVI, Trace_ELBO
import arviz as az




## 1. Model

In this section, we define the Bayesian model used for sales forecasting. The model combines a random walk for the latent state with day-of-the-week effects, day-of-the-year effects, and price elasticity. It is implemented with the `numpyro` library, which allows for efficient and scalable Bayesian inference.

### 1.1 Auxiliary Functions

These auxiliary functions are essential for data preprocessing and transformation:

- `periodic_rbf`: Computes a periodic Gaussian radial basis function (RBF).
- `compute_doy_basis`: Computes 12 periodic Gaussian basis functions for seasonal effects.
- `read_data`: Reads and preprocesses the sales data from a CSV file.
- `init_values`: Initializes values for the model parameters (currently commented out and not used).
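
The wrap-around distance behind `periodic_rbf` is easiest to see with numbers. The sketch below restates the same formula in plain Python (standard library only, so it runs without JAX); the helper name `periodic_rbf_scalar` and the example inputs are hypothetical and only serve as an illustration.

```python
import math

def periodic_rbf_scalar(x, mu, sigma):
    """Scalar version of the periodic Gaussian RBF on the unit interval [0, 1]."""
    d = abs(x - mu)
    periodic_distance = min(d, 1.0 - d)  # shortest way around the year
    return math.exp(-(periodic_distance ** 2) / (2.0 * sigma ** 2))

sigma = 30 / 365.25  # same default width as compute_doy_basis

# x = 0.99 (late December) and mu = 0.02 (early January) are 0.97 apart on
# the line but only 0.03 apart on the circle, so the basis value stays high.
late_december = periodic_rbf_scalar(0.99, 0.02, sigma)

# x = 0.50 (mid-year) is far from the same center in both directions.
mid_year = periodic_rbf_scalar(0.50, 0.02, sigma)

print(late_december, mid_year)
```

Because the distance is symmetric, swapping `x` and `mu` gives the same value, which is what keeps the seasonal effect continuous across the turn of the year.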

In [3]:
# Define a periodic Gaussian radial basis function (RBF)
def periodic_rbf(x, mu, sigma):
    """
    Computes a periodic Gaussian radial basis function (RBF).
    
    Args:
        x: Scaled day-of-year values (range [0,1]).
        mu: Center of the Gaussian basis function.
        sigma: Controls the spread of the Gaussian.
    
    Returns:
        RBF values preserving periodicity.
    """
    # compute cyclic distance to mu
    periodic_distance = jnp.minimum(jnp.abs(x - mu), 1 - jnp.abs(x - mu))
    # compute RBF value
    return jnp.exp(- (periodic_distance ** 2) / (2 * sigma ** 2))

def compute_doy_basis(yday_fraction, sigma = 30/365.25, n_centers = 12):
    """
    Computes periodic Gaussian basis functions for seasonal effects.
    
    Args:
        yday_fraction: Normalized day of the year (range [0,1]).
        sigma: Width of each Gaussian basis function, as a fraction of the year.
        n_centers: Number of evenly spaced basis functions (12 by default, roughly one per month).
    
    Returns:
        A JAX array with n_centers columns, one per basis function, centered within each row.
    """
    # Define centers of Gaussian basis functions
    month_centers = jnp.linspace( 1/(2*n_centers), 1-1/(2*n_centers), n_centers)
    
    # Generate an array of shape (length of input, 12) with the RBF values
    doy_basis = jnp.stack([periodic_rbf(yday_fraction, mu, sigma) for mu in month_centers], axis=-1)

    # Subtract each row's mean to enforce sum-to-zero constraint
    doy_basis_centered = doy_basis - jnp.mean(doy_basis, axis=-1, keepdims=True)
    
    return doy_basis_centered

def read_data(fname):
    """
    Reads and preprocesses the sales data from a CSV file.
    
    Args:
        fname: The filename of the CSV file containing the sales data.
    
    Returns:
        A dictionary with the following keys:
            - sales: An array of sales data.
            - log_price: An array of log-transformed prices.
            - wday: An array of day-of-the-week values.
            - yday_fraction: An array of normalized day-of-the-year values.
    """
    # Read the CSV file using polars
    df = pl.read_csv(fname)
    
    # Convert the 'date' column to date type
    df = df.with_columns(pl.col("date").str.to_date())

    # Extract sales and log price data as numpy arrays
    sales = df["sales"].to_numpy()
    log_price = df["log_price"].to_numpy()
    
    # Extract day-of-the-week values
    wday = df["date"].dt.weekday().to_numpy()
    
    # Extract day-of-the-year values
    yday = df["date"].dt.ordinal_day().to_numpy()
    
    # Determine if the year is a leap year
    is_leap_year = df["date"].dt.is_leap_year().to_numpy()
    
    # Normalize day-of-the-year values
    yday_fraction = yday / (365 + is_leap_year)
    
    # Return the preprocessed data as a dictionary
    return {
        "sales": sales,
        "log_price": log_price,
        "wday": wday,
        "yday_fraction": yday_fraction
    }

#def init_values(sales: jnp.array, log_price_centered: jnp.array, wday, yday_fraction: jnp.array, downsampling_factor = 1):
#    """
#    """
#    # to-do: implement downsampling_factor
#    log_state_est = jnp.log(sales)
#    log_state_mean_est = jnp.mean(log_state_est)
#    log_state_delta_est = jnp.diff(log_state_est)
#    log_state_delta_sd_est = jnp.std(log_state_delta_est)
#
#    return {
#        "log_sigma": jnp.log( log_state_delta_sd_est ),
#        "log_state_mean": log_state_mean_est,
#        "log_state_delta": log_state_delta_est,
#        "wday_coefficients": jnp.array([0.0]*6),
#        "yday_coefficients": jnp.array([0.0]*12),
#        "log_elasticity": jnp.array([0.0])
#    }

### 1.2 Model

#### 1.2.1 Key Features

- Sales are modeled using a **stochastic Poisson process**, where the expected rate $\lambda_t$ evolves over time.
- The **latent sales rate** follows a random walk, allowing it to drift nonstationarily.
- **Seasonal components** (day-of-the-week and annual patterns) adjust for structured demand variations.
- **Price elasticity** is explicitly modeled, ensuring sensitivity to pricing dynamics.
- The model is implemented in numpyro, enabling scalable Bayesian inference.


#### 1.2.2 Model Overview

We model the sales time series as a **stochastic process** in which the underlying rate of sales evolves over time. This evolution follows a **random walk structure**, with systematic adjustments for covariates: price, day-of-the-week effects, and day-of-the-year effects. The log of the sales rate $\lambda_t$ on day $t$ is the sum of *(i)* systematic covariate effects ($z_t$), *(ii)* a global baseline ($\mu_\tau$), and *(iii)* a latent dynamic component ($\tau_t$):

$$
\log \lambda_t = z_t + \mu_\tau + \tau_t
$$

##### 1.2.2.1 Latent State Dynamics

The baseline sales level $\tau_t$ follows a **random walk**. Because all contrast matrices for structured effects are centered, $\mu_\tau + \tau_t$ can be interpreted as the average latent log sales rate on day $t$.

$$
\tau_t = \tau_{t-1} + \delta_t, \quad \delta_t \sim \mathcal{N}(0, \sigma_\tau)
$$

with:

$$
\mu_\tau \sim \mathcal{N}(0, 5), \quad \log \sigma_\tau \sim \text{Gumbel}(0, 1)
$$
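
To make the centering concrete, the following plain-Python sketch builds a sum-to-zero difference contrast matrix by hand and applies it to a few increments. It assumes that the matrix produced by `patsy.contrasts.Diff` in the fitting code has the standard backward-difference form; the helper names `diff_contrast_matrix` and `matvec` and all numeric values are illustrative, not taken from the notebook.

```python
def diff_contrast_matrix(k):
    """Backward-difference contrast matrix with k rows and k - 1 columns.

    Every column sums to zero, and consecutive rows differ by exactly one
    coefficient, so C @ delta is a cumulative sum of delta shifted to mean zero.
    """
    return [[(j + 1) / k - (0.0 if i > j else 1.0) for j in range(k - 1)]
            for i in range(k)]

def matvec(matrix, vector):
    """Matrix-vector product for lists of lists."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]

delta = [0.5, -0.2, 0.1]  # illustrative standardized increments
sigma_tau = 2.0           # illustrative random-walk scale
mu_tau = 3.0              # illustrative global baseline

C = diff_contrast_matrix(len(delta) + 1)
tau = [sigma_tau * v for v in matvec(C, delta)]  # centered random walk tau_t
log_level = [mu_tau + v for v in tau]            # mu_tau + tau_t

print([round(v, 3) for v in tau])
```

Under this assumption the walk sums to zero by construction, so $\mu_\tau$ carries the overall level and $\tau_t$ only the deviations from it.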


##### 1.2.2.2 Structured Effects

We further account for systematic effects of *(i)* day of the week, *(ii)* day of the year, and *(iii)* price.

- For day-of-the-week effects, we use a contrast matrix $\mathbf{C}_{\text{wday}}$ with sliding differences.
- For day-of-the-year effects, we use a seasonality basis matrix $\mathbf{B}_{\text{yday}}$ of periodic Gaussian radial basis functions (RBFs), which represents recurring seasonal patterns.
- Price elasticity is modelled using the centered log price.


- **Day-of-the-week effects**:

$$
  zw_t = \mathbf{C}_{\text{wday}} \cdot \beta_{\text{wday}}, \quad \beta_{\text{wday}} \sim \mathcal{N}(0, 1)
$$

- **Day-of-the-year effects**:

$$
  zy_t = \mathbf{B}_{\text{yday}} \cdot \beta_{\text{yday}}, \quad \beta_{\text{yday}} \sim \mathcal{N}(0, 1)
$$

- **Price elasticity**:

$$
  ze_t = \text{log\_price\_centered} \cdot e, \quad \log(-e) \sim \mathcal{N^{+}}(0, 1)
$$

- **Sum of structural effects**:

$$
  z_t = zw_t + zy_t + ze_t
$$
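
Since all structured effects enter on the log scale, each one acts as a multiplier on the sales rate. The sketch below works through the price term as an example; the elasticity and the two seasonal effects are hypothetical values chosen for illustration, not estimates from this notebook, and `price_multiplier` is a hypothetical helper.

```python
import math

def price_multiplier(elasticity, price_ratio):
    """Multiplicative change in the sales rate when the price changes by price_ratio.

    The price term is elasticity * log(price), so a shift of log(price_ratio)
    multiplies the rate by price_ratio ** elasticity.
    """
    return math.exp(elasticity * math.log(price_ratio))

elasticity = -1.5  # hypothetical elasticity
multiplier = price_multiplier(elasticity, 1.10)  # effect of a 10% price increase

# Hypothetical day-of-week and day-of-year effects, combined with the price term
zw_t, zy_t = 0.2, -0.1
ze_t = elasticity * math.log(1.10)
z_t = zw_t + zy_t + ze_t

print(multiplier, math.exp(z_t))
```

With a negative elasticity the multiplier is below one, so a price increase lowers the expected sales rate.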


##### 1.2.2.3 Emissions Model

Observed sales are assumed to follow a **Poisson distribution**, ensuring discrete, non-negative observations:

$$
S_t \sim \text{Poisson}(\lambda_t)
$$
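
As a small worked example of the emission step, the sketch below evaluates the Poisson log-probability for a rate assembled from the three components of $\log \lambda_t$. The component values are made up for illustration, and `poisson_log_pmf` is a hypothetical helper written with the standard library only.

```python
import math

def poisson_log_pmf(s, log_rate):
    """log P(S = s | lambda) for a Poisson distribution with lambda = exp(log_rate)."""
    return s * log_rate - math.exp(log_rate) - math.lgamma(s + 1)

# Illustrative components of log lambda_t = z_t + mu_tau + tau_t
z_t, mu_tau, tau_t = -0.1, 3.0, 0.2
log_rate = z_t + mu_tau + tau_t

# The probabilities over a wide range of counts should add up to about one
total = sum(math.exp(poisson_log_pmf(s, log_rate)) for s in range(200))

# The most probable count sits just below the rate exp(log_rate)
most_likely = max(range(200), key=lambda s: poisson_log_pmf(s, log_rate))

print(math.exp(log_rate), total, most_likely)
```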


In [None]:

def model_local_level_poisson(sales: jnp.array, log_price_centered: jnp.array, wday, yday_fraction: jnp.array, 
                              contrasts_sdif_t: jnp.array, contrasts_wday: jnp.array, contrasts_yday: jnp.array, 
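                              downsampling_factor = 1):
    """
    Local-level Poisson model for daily sales.

    Args:
        sales: Observed daily sales counts.
        log_price_centered: Mean-centered log prices.
        wday: Day-of-the-week value per observation (1-based, indexed with wday - 1 below).
        yday_fraction: Normalized day of the year (range [0,1]).
        contrasts_sdif_t: Difference contrast matrix for the random walk, shape (n_states, n_states - 1).
        contrasts_wday: Difference contrast matrix for day-of-the-week effects, shape (7, 6).
        contrasts_yday: Precomputed day-of-the-year basis (recomputed inside the model with a sampled width).
        downsampling_factor: Spacing of latent states in observations; 1 means one state per observation.
    """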

    n_obs = len(sales)
    n_states = contrasts_sdif_t.shape[0]
 
    def sample_random_walk(contrasts_sdif_t, n_states):
        log_sigma = numpyro.sample("log_sigma", dist.Gumbel(0, 1))
        sigma = numpyro.deterministic("sigma", jnp.exp(log_sigma))
        log_state_mean = numpyro.sample("log_state_mean", dist.Normal(0, 5))
        log_state_delta = numpyro.sample( "log_state_delta", dist.Normal(0, 1), sample_shape=(n_states-1,))
        log_state_base = numpyro.deterministic("log_state_base", jnp.dot(contrasts_sdif_t, log_state_delta) * sigma + log_state_mean )
        return log_state_base

    def sample_downsampled_random_walk(contrasts_sdif_t, n_obs, n_states):
        log_state_base_downsampled = sample_random_walk(contrasts_sdif_t, n_states)
        
        # Fractional position of each observation on the downsampled state grid
        idx_n_weight = jnp.arange(n_obs) / downsampling_factor
        idx_1 = jnp.floor(idx_n_weight).astype(int)
        idx_2 = jnp.ceil(idx_n_weight).astype(int)
        weight_2 = idx_n_weight - idx_1

        state_before = log_state_base_downsampled[idx_1]
        state_after = log_state_base_downsampled[idx_2]
        return (1-weight_2)*state_before + weight_2*state_after
        
    def sample_wday_effect(contrasts_wday, wday):
        # Prior for day-of-the-week effects (6 coefficients)
        wday_coefficients = numpyro.sample("wday_coefficients", dist.Normal(0, 1), sample_shape=(6,))

        # Compute wday effect per observation (sum-to-zero constraint applied via contrasts)
        wday_effects = jnp.dot(contrasts_wday, wday_coefficients)
        # Look up each observation's effect with an index vector (wday is 1-based)
        return wday_effects[jnp.asarray(wday) - 1]

    def sample_yday_effect(contrasts_yday, yday_fraction):
        
        # The basis is recomputed here with a sampled RBF width, so the precomputed
        # contrasts_yday argument is not used
        yday_rbf_width = numpyro.sample("yday_rbf_width", dist.Exponential(10))
        contrasts_yday = compute_doy_basis(yday_fraction, sigma = yday_rbf_width, n_centers = 12)

        # Prior for yearly seasonality effects (12 coefficients)
        yday_coefficients = numpyro.sample("yday_coefficients", dist.Normal(0, 1), sample_shape=(12,))

        return jnp.dot(contrasts_yday, yday_coefficients)

    def sample_price_effect(log_price_centered):
        # Prior for price elasticity
        log_elasticity = numpyro.sample( "log_elasticity", dist.Normal(0, 2) ) # dist.TruncatedNormal(-1, 2, high = 0)
        # Note: the sampled value is negated, not exponentiated, so elasticity can take either sign
        elasticity = numpyro.deterministic("elasticity", -1 * log_elasticity)
        return log_price_centered * elasticity


    # Sample random walk    
    if n_obs == n_states:
        log_state_base = sample_random_walk(contrasts_sdif_t, n_states)
    else:
        log_state_base = sample_downsampled_random_walk(contrasts_sdif_t, n_obs, n_states)

    # Sample day-of-the-week effects
    wday_effect = sample_wday_effect(contrasts_wday, wday)

    # Sample day-of-the-year effects
    yday_effect = sample_yday_effect(contrasts_yday, yday_fraction)

    # Sample elasticity effect
    price_effect = sample_price_effect(log_price_centered)

    # Compute state (price_effect is currently left out of the rate)
    state = numpyro.deterministic("state", jnp.exp(log_state_base  + yday_effect + wday_effect)) # + price_effect

    # Compute log-likelihood for poisson emissions
    numpyro.sample("sales", dist.Poisson(rate=state), obs=sales)
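
The interpolation in `sample_downsampled_random_walk` can be sketched in plain Python as follows. This is an illustration of the same floor/ceil weighting under made-up inputs, not the notebook's JAX code; `upsample_states` and the example state values are hypothetical.

```python
import math

def upsample_states(states, n_obs, downsampling_factor):
    """Linearly interpolate coarse latent states onto n_obs daily positions."""
    daily = []
    for t in range(n_obs):
        position = t / downsampling_factor  # fractional index on the coarse grid
        lower = math.floor(position)
        upper = math.ceil(position)
        weight_upper = position - lower     # 0 on a grid point, close to 1 just before the next
        daily.append((1 - weight_upper) * states[lower] + weight_upper * states[upper])
    return daily

n_obs, factor = 15, 7
n_states = n_obs // factor + 1       # floor(n_obs / factor) + 1 coarse states
weekly_states = [1.0, 2.0, 0.0]      # illustrative latent states, one per week
daily_states = upsample_states(weekly_states, n_obs, factor)

print(n_states, [round(v, 3) for v in daily_states])
```

Observations that fall exactly on a coarse grid point take that state's value, and the days in between move linearly from one state to the next, which is why a downsampling factor of 7 needs far fewer latent parameters than one state per day.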

### 1.3 Model Fitting Logic

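- `prepare_model_arguments` builds the contrast matrices, the seasonal basis, and the centered log price.
- `run_nuts` fits the model with the NUTS sampler.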

In [5]:
def prepare_model_arguments(sales: jnp.array, log_price: jnp.array, wday, yday_fraction: jnp.array, downsampling_factor = 1):
    """ Builds the contrast matrices, seasonal basis, and centered log price passed to the model
    """    
    n_obs = len(sales)
    if downsampling_factor == 1:
        n_states = n_obs
    else:
        n_states = int( np.floor(n_obs/downsampling_factor) + 1 ) 
    
    # Define contrast matrix for random walk (n_states - 1 coefficients, sum-to-zero constraint)
    contrasts_sdif_t = patsy.contrasts.Diff().code_without_intercept(range(0, n_states)).matrix

    # Define contrast matrix for day-of-the-week effects (6 coefficients, sum-to-zero constraint)
    contrasts_wday = patsy.contrasts.Diff().code_without_intercept(range(0,7)).matrix  # 7 days → 6 contrasts

    # Compute the day-of-the-year basis per observation (each row is centered to sum to zero)
    contrasts_yday = compute_doy_basis(yday_fraction, sigma = 30/365.25, n_centers = 12)

    # Compute centered log price differences
    log_price_centered = log_price - jnp.mean(log_price)

    # Set up the model parameters
    model_arguments = {'sales': sales, 'log_price_centered': log_price_centered, 'wday': wday, 'yday_fraction': yday_fraction,
                       'downsampling_factor': downsampling_factor,
                       'contrasts_sdif_t': contrasts_sdif_t, 'contrasts_wday': contrasts_wday, 'contrasts_yday': contrasts_yday}
    
    # # Prepare init values for parameters 
    # init_params = init_values(sales, log_price_centered, wday, yday_fraction)

    #return init_params, model_arguments
    return model_arguments

In [6]:
def run_nuts(sales: jnp.array, log_price: jnp.array, wday, yday_fraction: jnp.array, downsampling_factor = 1, n_chains = 1, num_warmup=1_000, num_samples=1_000):
    """ Runs NUTS MCMC inference on the model 
    """
    rng_key = random.PRNGKey(0)
    
    n_obs = len(sales)
    
    # Prepare model arguments
    model_arguments = prepare_model_arguments(sales = sales, log_price = log_price, wday = wday, yday_fraction = yday_fraction, downsampling_factor = downsampling_factor)

    rng_key, rng_key_ = random.split(rng_key)

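    # Note: set_host_device_count only takes effect before JAX is initialized;
    # the XLA_FLAGS setting at the top of the notebook already requests 8 host devices
    numpyro.set_host_device_count(n_chains)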

    reparam_model = model_local_level_poisson
    kernel = NUTS(reparam_model, step_size=0.01, max_tree_depth=8)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=n_chains)
    mcmc.run( rng_key_, **model_arguments) #, init_params = init_params

    return mcmc  

## 2. Fit the Model 

In [7]:
data = read_data("sales_synthetic.csv")

In [16]:
m_fit = run_nuts(data['sales'], data['log_price'], data['wday'], data['yday_fraction'], downsampling_factor = 7, n_chains = 4, num_warmup=1_000, num_samples=1_000)


In [17]:
summary = az.summary(m_fit)
summary.loc[['sigma', 'elasticity', 'yday_rbf_width']]

| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| sigma | 0.065 | 0.006 | 0.053 | 0.077 | 0.001 | 0.0 | 23.0 | 124.0 | 1.13 |
| elasticity | -0.114 | 0.349 | -0.761 | 0.526 | 0.067 | 0.017 | 27.0 | 113.0 | 1.11 |
| yday_rbf_width | 0.472 | 0.292 | 0.16 | 0.971 | 0.097 | 0.02 | 8.0 | 42.0 | 1.48 |
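
One way to read the diagnostics in this summary is to compare them against rule-of-thumb thresholds. The sketch below does this for the three reported parameters, with `r_hat` and `ess_bulk` copied from the output above; the thresholds (`R_HAT_MAX`, `ESS_MIN`) are commonly cited guidelines and are assumptions here, not settings from this notebook.

```python
# r_hat and bulk ESS copied from the az.summary output above
diagnostics = {
    "sigma": {"r_hat": 1.13, "ess_bulk": 23.0},
    "elasticity": {"r_hat": 1.11, "ess_bulk": 27.0},
    "yday_rbf_width": {"r_hat": 1.48, "ess_bulk": 8.0},
}

R_HAT_MAX = 1.01  # assumed rule of thumb for acceptable r_hat
ESS_MIN = 400     # assumed rule of thumb (about 100 effective draws per chain, 4 chains)

# Collect parameters that miss either threshold
flagged = [
    name
    for name, d in diagnostics.items()
    if d["r_hat"] > R_HAT_MAX or d["ess_bulk"] < ESS_MIN
]

print(flagged)
```

Under these assumed thresholds all three parameters are flagged, which suggests that the chains have not yet mixed well and that the posterior summaries should be read with caution.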


In [None]:
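# `df` and `m` are not defined earlier; we assume the raw data frame and the posterior mean of `state`
df = pl.read_csv("sales_synthetic.csv").with_columns(pl.col("date").str.to_date())
x = pd.DataFrame({ 'date': df["date"].to_numpy(), 'sales': df["sales"].to_numpy(), 'state': np.asarray(m_fit.get_samples()['state']).mean(axis=0) })
ggplot(x, aes(x='date', y='sales')) + geom_point() + geom_line(aes(y='state'), color = "red") + theme_bw()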