# Generating Synthetic Order Volume

## Introduction


## Import Required Libraries And Define Functions

In [15]:
import os
import sys

# Set XLA_FLAGS before JAX is imported
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

In [16]:
import blackjax
import blackjax.smc.resampling as resampling
from blackjax.smc import extend_params

In [17]:
import polars as pl
import numpy as np
import pandas as pd
from plotnine import ggplot, aes, geom_point, geom_line, labs, theme_minimal, theme_bw, scale_x_continuous, scale_x_discrete, scale_x_datetime

In [18]:
import os
import polars as pl
import jax.numpy as jnp
import jax.random as random

import numpyro


In [19]:
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO, Predictive
from numpyro.infer.autoguide import AutoNormal, AutoMultivariateNormal, AutoLaplaceApproximation
import patsy
import matplotlib.pyplot as plt
import arviz as az
from numpyro.infer import MCMC, NUTS, HMC, HMCECS, MCMC, NUTS, SA, SVI, Trace_ELBO, init_to_value
import arviz as az

## Model

### 1. Create the Time Series

In [20]:
def periodic_rbf(x, mu, sigma):
    """
    Evaluate a Gaussian radial basis function on a cycle of period 1.

    The distance between ``x`` and ``mu`` is measured cyclically, so points
    close to 0 and points close to 1 are treated as neighbours.

    Args:
        x: Position(s) on the cycle, e.g. the scaled day of the year. Values
            outside [0, 1] are wrapped onto the cycle.
        mu: Center of the Gaussian basis function, on the same scale as ``x``.
        sigma: Controls the spread of the Gaussian.

    Returns:
        ``exp(-d**2 / (2 * sigma**2))`` where ``d`` is the cyclic distance
        between ``x`` and ``mu`` (never larger than 0.5).
    """
    # Reduce the raw offset modulo one period first. Without this, an offset
    # larger than 1 makes ``1 - offset`` negative and the result is no longer
    # periodic in ``x``.
    offset = jnp.mod(jnp.abs(x - mu), 1.0)
    periodic_distance = jnp.minimum(offset, 1.0 - offset)  # Cyclic distance
    return jnp.exp(-(periodic_distance ** 2) / (2 * sigma ** 2))

def compute_doy_basis(yday_fraction, sigma = 30/365.25, n_centers = 12):
    """
    Build a row-centered matrix of periodic Gaussian basis functions.

    Args:
        yday_fraction: Normalized day of the year (range [0,1]).
        sigma: Width passed to ``periodic_rbf`` for every basis function,
            expressed on the same [0,1] scale (default: 30 days of a year).
        n_centers: Number of basis functions; their centers are spaced evenly
            over the cycle, starting half a step in from 0.

    Returns:
        A JAX array whose last axis has ``n_centers`` entries, one per basis
        function, with the mean over that axis subtracted.
    """
    # Evenly spaced centers: 1/(2n), 3/(2n), ..., 1 - 1/(2n)
    half_step = 1 / (2 * n_centers)
    centers = jnp.linspace(half_step, 1 - half_step, n_centers)

    # One column per center, stacked along a new trailing axis
    columns = [periodic_rbf(yday_fraction, center, sigma) for center in centers]
    raw_basis = jnp.stack(columns, axis=-1)

    # Remove each row's mean so the basis values of every row sum to zero
    row_means = jnp.mean(raw_basis, axis=-1, keepdims=True)
    return raw_basis - row_means

In [None]:
def model_local_level_poisson(sales: jnp.ndarray, log_price_centered: jnp.ndarray, wday, yday_fraction: jnp.ndarray,
                              contrasts_sdif_t: jnp.ndarray, contrasts_wday: jnp.ndarray, contrasts_yday: jnp.ndarray,
                              downsampling_factor = 1):
    """
    NumPyro model: Poisson-distributed sales with a latent random-walk level.

    The log of the Poisson rate is the sum of four components:
      * a random-walk base level (sampled on a coarser grid and linearly
        interpolated when the contrast matrix has fewer rows than there are
        observations),
      * a day-of-the-week effect,
      * a day-of-the-year effect,
      * a price effect with a strictly negative elasticity.

    Args:
        sales: Observed counts, one per observation.
        log_price_centered: Centered log price, one per observation.
        wday: 1-based weekday index per observation (used as ``wday - 1``
            to look up the per-weekday effect).
        yday_fraction: Position within the year per observation. Not used
            inside the model itself; the seasonal basis arrives via
            ``contrasts_yday``.
        contrasts_sdif_t: Contrast matrix mapping the random-walk increments
            to state levels; its row count defines the number of states.
        contrasts_wday: Contrast matrix mapping weekday coefficients to
            per-weekday effects.
        contrasts_yday: Seasonal basis matrix, one row per observation.
        downsampling_factor: Number of observations per state-grid step.

    Sample sites: ``log_sigma``, ``log_state_mean``, ``log_state_delta``,
    ``wday_coefficients``, ``yday_coefficients``, ``log_elasticity``, ``sales``.
    Deterministic sites: ``sigma``, ``log_state_base``, ``elasticity``, ``state``.
    """

    n_obs = len(sales)
    n_states = contrasts_sdif_t.shape[0]

    def sample_random_walk(contrasts_sdif_t, n_states):
        # Standard-normal increments are scaled by sigma and shifted by the
        # mean level after being mapped through the contrast matrix.
        log_sigma = numpyro.sample("log_sigma", dist.Gumbel(0, 1))
        sigma = numpyro.deterministic("sigma", jnp.exp(log_sigma))
        log_state_mean = numpyro.sample("log_state_mean", dist.Normal(0, 5))
        log_state_delta = numpyro.sample("log_state_delta", dist.Normal(0, 1), sample_shape=(n_states - 1,))
        log_state_base = numpyro.deterministic(
            "log_state_base",
            jnp.dot(contrasts_sdif_t, log_state_delta) * sigma + log_state_mean,
        )
        return log_state_base

    def sample_downsampled_random_walk(contrasts_sdif_t, n_obs, n_states):
        log_state_base_downsampled = sample_random_walk(contrasts_sdif_t, n_states)

        # Fractional position of every observation on the coarse state grid
        idx_n_weight = jnp.arange(n_obs) / downsampling_factor
        idx_1 = jnp.floor(idx_n_weight).astype(int)
        # JAX silently clamps out-of-range gather indices; clamp explicitly so
        # a too-short state grid visibly falls back to the last state.
        idx_2 = jnp.minimum(jnp.ceil(idx_n_weight).astype(int), n_states - 1)
        weight_2 = idx_n_weight - idx_1

        # Linear interpolation between the two neighbouring grid states
        state_before = log_state_base_downsampled[idx_1]
        state_after = log_state_base_downsampled[idx_2]
        return (1 - weight_2) * state_before + weight_2 * state_after

    def sample_wday_effect(contrasts_wday, wday):
        # Prior for day-of-the-week effects (one coefficient per contrast column)
        n_coefficients = contrasts_wday.shape[1]
        wday_coefficients = numpyro.sample("wday_coefficients", dist.Normal(0, 1), sample_shape=(n_coefficients,))

        # Compute per-weekday effects (sum-to-zero constraint applied via contrasts)
        wday_effects = jnp.dot(contrasts_wday, wday_coefficients)
        # wday is 1-based; look all observations up with a single index vector
        return wday_effects[jnp.asarray(wday) - 1]

    def sample_yday_effect(contrasts_yday, yday_fraction):
        # Prior for yearly seasonality effects (one coefficient per basis column)
        n_coefficients = contrasts_yday.shape[1]
        yday_coefficients = numpyro.sample("yday_coefficients", dist.Normal(0, 1), sample_shape=(n_coefficients,))
        return jnp.dot(contrasts_yday, yday_coefficients)

    def sample_price_effect(log_price_centered):
        # Prior for price elasticity; -exp(.) keeps the elasticity negative
        log_elasticity = numpyro.sample("log_elasticity", dist.Normal(0, 1))
        elasticity = numpyro.deterministic("elasticity", -1 * jnp.exp(log_elasticity))
        return log_price_centered * elasticity

    # Sample random walk (one state per observation, or interpolated from a coarser grid)
    if n_obs == n_states:
        log_state_base = sample_random_walk(contrasts_sdif_t, n_states)
    else:
        log_state_base = sample_downsampled_random_walk(contrasts_sdif_t, n_obs, n_states)

    # Sample day-of-the-week effects
    wday_effect = sample_wday_effect(contrasts_wday, wday)

    # Sample day-of-the-year effects
    yday_effect = sample_yday_effect(contrasts_yday, yday_fraction)

    # Sample elasticity effect
    price_effect = sample_price_effect(log_price_centered)

    # Compute state: all components are additive on the log scale
    state = numpyro.deterministic("state", jnp.exp(log_state_base + price_effect + yday_effect + wday_effect))

    # Compute log-likelihood for poisson emissions
    numpyro.sample("sales", dist.Poisson(rate=state), obs=sales)

In [22]:
def init_values(sales: jnp.array, log_price_centered: jnp.array, wday, yday_fraction: jnp.array):
    """
    Derive rough starting values for the model's latent sites from raw sales.

    Args:
        sales: Observed counts, one per observation.
        log_price_centered: Unused; kept for a uniform call signature.
        wday: Unused; kept for a uniform call signature.
        yday_fraction: Unused; kept for a uniform call signature.

    Returns:
        A dict keyed by sample-site name:
          * ``log_sigma``: log of the standard deviation of the first
            differences of log sales,
          * ``log_state_mean``: mean of log sales,
          * ``log_state_delta``: first differences of log sales
            (length ``len(sales) - 1``),
          * ``wday_coefficients`` / ``yday_coefficients``: zeros of length 6 / 12,
          * ``log_elasticity``: a length-1 array holding 0.
    """
    # Counts can be zero; floor them at 0.5 so the log stays finite.
    # Counts >= 1 are left untouched.
    log_state_est = jnp.log(jnp.maximum(jnp.asarray(sales), 0.5))
    log_state_mean_est = jnp.mean(log_state_est)
    log_state_delta_est = jnp.diff(log_state_est)
    log_state_delta_sd_est = jnp.std(log_state_delta_est)

    # NOTE(review): the model multiplies `log_state_delta` by sigma, so these
    # unscaled differences may be on a different scale than the sample site,
    # and the length only matches the non-downsampled model. Confirm intent.
    # NOTE(review): `log_elasticity` is a scalar site in the model while this
    # start value has shape (1,). Confirm intent.
    return {
        "log_sigma": jnp.log(log_state_delta_sd_est),
        "log_state_mean": log_state_mean_est,
        "log_state_delta": log_state_delta_est,
        "wday_coefficients": jnp.zeros(6),
        "yday_coefficients": jnp.zeros(12),
        "log_elasticity": jnp.array([0.0]),
    }

In [None]:
def prepare_model_arguments(sales: jnp.array, log_price: jnp.array, wday, yday_fraction: jnp.array, downsampling_factor = 1):
    """
    Assemble the keyword arguments and start values for the sales model.

    Args:
        sales: Observed counts, one per observation.
        log_price: Log price, one per observation; it is mean-centered here.
        wday: Weekday index per observation, passed through to the model.
        yday_fraction: Normalized day of the year per observation.
        downsampling_factor: Number of observations per step of the latent
            state grid; 1 means one state per observation.

    Returns:
        ``(init_params, model_arguments)`` where ``init_params`` is the dict
        produced by ``init_values`` and ``model_arguments`` is a dict of
        keyword arguments for ``model_local_level_poisson``.

    Raises:
        ValueError: If ``downsampling_factor`` is not positive.
    """
    if downsampling_factor <= 0:
        raise ValueError("downsampling_factor must be positive")

    n_obs = len(sales)
    if downsampling_factor == 1:
        n_states = n_obs
    else:
        # The last observation sits at grid position (n_obs - 1) / factor and
        # is interpolated towards the state at the ceiling of that position,
        # so that index must exist (+1 because indices are zero-based).
        # floor(n_obs / factor) + 1 is one state short when n_obs % factor >= 2.
        n_states = int(np.ceil((n_obs - 1) / downsampling_factor)) + 1

    # Define contrast matrix for random walk (T coefficients, sum-to-zero constraint)
    contrasts_sdif_t = patsy.contrasts.Diff().code_without_intercept(range(0, n_states)).matrix

    # Define contrast matrix for day-of-the-week effects (6 coefficients, sum-to-zero constraint)
    contrasts_wday = patsy.contrasts.Diff().code_without_intercept(range(0,7)).matrix  # 7 days → 6 contrasts

    # Compute yday basis per observation (each row is mean-centered by compute_doy_basis)
    contrasts_yday = compute_doy_basis(yday_fraction, sigma = 30/365.25, n_centers = 12)

    # Compute centered log price
    log_price_centered = log_price - jnp.mean(log_price)

    # Set up the model parameters
    model_arguments = {'sales': sales, 'log_price_centered': log_price_centered, 'wday': wday, 'yday_fraction': yday_fraction,
                       'downsampling_factor': downsampling_factor,
                       'contrasts_sdif_t': contrasts_sdif_t, 'contrasts_wday': contrasts_wday, 'contrasts_yday': contrasts_yday}

    # Prepare init values for parameters
    init_params = init_values(sales, log_price_centered, wday, yday_fraction)

    return init_params, model_arguments

In [24]:
def run_svi(sales: jnp.array, log_price: jnp.array, wday, yday_fraction: jnp.array, num_samples=1_000, num_steps=10_000):
    """Fit the local-level Poisson model with SVI and draw posterior samples.

    Uses an ``AutoNormal`` guide, the Adam optimizer (step size 0.01), the
    ``Trace_ELBO`` loss and a fixed PRNG seed of 42.

    Args:
        sales: Observed counts, one per observation.
        log_price: Log price, one per observation.
        wday: Weekday index per observation.
        yday_fraction: Normalized day of the year per observation.
        num_samples: Number of draws taken from the fitted guide and model.
        num_steps: Number of optimization steps.

    Returns:
        ``(svi_result, posterior_parameters, posterior_generated)``: the SVI
        run result, the draws produced by running ``Predictive`` on the guide,
        and the draws produced by running ``Predictive`` on the model with the
        fitted guide.
    """
    key = random.PRNGKey(seed=42)

    # Model inputs and start values (no downsampling is requested here)
    init_params, model_arguments = prepare_model_arguments(sales, log_price, wday, yday_fraction)

    model = model_local_level_poisson
    guide = AutoNormal(model=model)
    svi = SVI(model, guide, numpyro.optim.Adam(step_size=0.01), loss=Trace_ELBO())

    # Optimize the guide parameters
    key, fit_key = random.split(key=key)
    svi_result = svi.run(rng_key=fit_key, num_steps=num_steps, init_params=init_params, **model_arguments)
    fitted_params = svi_result.params

    # Posterior samples of the parameters, drawn from the fitted guide
    key, guide_key = random.split(key=key)
    guide_sampler = Predictive(model=guide, params=fitted_params, num_samples=num_samples)
    posterior_parameters = guide_sampler(guide_key, **model_arguments)

    # Posterior predictive draws (deterministics and likelihood) from the model
    key, model_key = random.split(key=key)
    model_sampler = Predictive(model=model, guide=guide, params=fitted_params, num_samples=num_samples)
    posterior_generated = model_sampler(model_key, **model_arguments)

    return svi_result, posterior_parameters, posterior_generated

In [54]:
def run_nuts(sales: jnp.array, log_price: jnp.array, wday, yday_fraction: jnp.array, downsampling_factor = 1, n_chains = 1, num_warmup=1_000, num_samples=1_000):
    """Run NUTS MCMC inference on the local-level Poisson model.

    The sampler uses a fixed PRNG seed of 0, an initial step size of 0.01 and
    a maximum tree depth of 8.

    Args:
        sales: Observed counts, one per observation.
        log_price: Log price, one per observation.
        wday: Weekday index per observation.
        yday_fraction: Normalized day of the year per observation.
        downsampling_factor: Number of observations per latent-state step.
        n_chains: Number of MCMC chains.
        num_warmup: Warmup iterations per chain.
        num_samples: Retained draws per chain.

    Returns:
        A dict with the posterior means of ``log_state_base`` and ``state``
        (averaged over all draws) and the fitted ``MCMC`` object under
        ``'mcmc'``.
    """
    base_key = random.PRNGKey(0)

    # The start values returned alongside the arguments are not handed to the sampler.
    _, model_arguments = prepare_model_arguments(
        sales=sales,
        log_price=log_price,
        wday=wday,
        yday_fraction=yday_fraction,
        downsampling_factor=downsampling_factor,
    )

    _, run_key = random.split(base_key)

    # NOTE(review): NumPyro's docs say this should run before JAX is
    # initialized; confirm it has any effect at this point.
    numpyro.set_host_device_count(n_chains)

    nuts_kernel = NUTS(model_local_level_poisson, step_size=0.01, max_tree_depth=8)
    mcmc = MCMC(nuts_kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=n_chains)
    mcmc.run(run_key, **model_arguments)

    draws = mcmc.get_samples()
    posterior_means = {site: jnp.mean(draws[site], axis=0) for site in ('log_state_base', 'state')}

    return {'log_state_base': posterior_means['log_state_base'],
            'state': posterior_means['state'],
            'mcmc': mcmc
           }
  

In [26]:
# Load the sales data and parse the `date` column from string to a date type.
df = pl.read_csv("sales_synthetic.csv")
df = df.with_columns( pl.col("date").str.to_date())

# Extract the model inputs as NumPy arrays.
sales = df["sales"].to_numpy()
log_price = df["log_price"].to_numpy()
# NOTE(review): the model indexes weekdays as `wday - 1`, i.e. it assumes a
# 1-based weekday number — confirm against the polars `weekday()` convention.
wday = df["date"].dt.weekday().to_numpy()
yday = df["date"].dt.ordinal_day().to_numpy()
is_leap_year = df["date"].dt.is_leap_year().to_numpy()
# Scale the day of the year by the length of its year (365, or 366 in leap years).
yday_fraction = yday / (365 + is_leap_year)

In [None]:
# Fit the model with NUTS: 4 chains, one latent state per 7 observations.
m = run_nuts(sales, log_price, wday, yday_fraction, n_chains = 4, downsampling_factor = 7)

sample: 100%|██████████| 2000/2000 [01:46<00:00, 18.79it/s, 255 steps of size 1.06e-02. acc. prob=0.90]


In [58]:
# Print NumPyro's posterior summary table for the fitted MCMC object.
m['mcmc'].print_summary()


                           mean       std    median      5.0%     95.0%     n_eff     r_hat
       log_elasticity     -0.72      0.58     -0.63     -1.56      0.21    325.87      1.00
            log_sigma     -2.68      0.10     -2.68     -2.84     -2.51    270.53      1.00
   log_state_delta[0]      0.43      0.93      0.43     -1.13      1.82    537.92      1.00
   log_state_delta[1]      0.32      0.86      0.33     -0.98      1.79    647.03      1.00
   log_state_delta[2]      0.44      0.90      0.45     -0.91      2.03    381.45      1.00
   log_state_delta[3]     -0.04      0.94     -0.03     -1.58      1.51    427.28      1.00
   log_state_delta[4]     -0.59      0.83     -0.61     -1.92      0.72    620.58      1.00
   log_state_delta[5]      0.33      0.84      0.31     -1.10      1.68    521.19      1.00
   log_state_delta[6]      0.69      0.90      0.69     -0.67      2.18    645.59      1.00
   log_state_delta[7]      0.68      0.89      0.68     -0.64      2.18    513.

In [62]:
import arviz as az

# Build an ArviZ summary table from the fitted MCMC object and show only the
# row for the deterministic `sigma` site.
summary = az.summary(m['mcmc'])
print(summary.loc[['sigma']])



        mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  \
sigma  0.069  0.007   0.056    0.082        0.0      0.0     269.0     438.0   

       r_hat  
sigma    NaN  


In [None]:
import inspect
# Show the source code of the MCMC object's `print_summary` method.
print(inspect.getsource(m['mcmc'].print_summary))

In [None]:
import jax
from datetime import date
# Seed the PRNG with today's date as the integer YYYYMMDD: runs on the same
# day share a seed, runs on different days do not.
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))

import numpy as np
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


In [None]:
import jax
from datetime import date
# Seed the PRNG with today's date as the integer YYYYMMDD: runs on the same
# day share a seed, runs on different days do not.
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))

In [None]:
import numpy as np
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

import blackjax
import blackjax.smc.resampling as resampling
from blackjax.smc import extend_params

from numpyro.infer.util import initialize_model

In [None]:
# Build the model inputs (no downsampling). The `init_params` returned here is
# overwritten by the `initialize_model` result below.
init_params, model_arguments = prepare_model_arguments(sales, log_price, wday, yday_fraction)    

# Let NumPyro initialize the model; with dynamic_args=True the second return
# value is a generator that is later called with the model arguments to obtain
# the potential function.
rng_key, init_key = jax.random.split(rng_key)
init_params, potential_fn_gen, *_ = initialize_model(
    init_key,
    model_local_level_poisson,
    model_kwargs=model_arguments,
    dynamic_args=True,
)

In [None]:
def logdensity_fn(position):
    """Return the model's log density at ``position`` (the negated potential energy)."""
    potential_fn = potential_fn_gen(**model_arguments)
    return -potential_fn(position)

# Start sampling from the position chosen by NumPyro's initialization.
initial_position = init_params.z

In [None]:
import blackjax

num_warmup = 2000

# Tune the NUTS parameters with BlackJAX's window adaptation, targeting an
# acceptance rate of 0.8.
adapt = blackjax.window_adaptation(
    blackjax.nuts, logdensity_fn, target_acceptance_rate=0.8
)
rng_key, warmup_key = jax.random.split(rng_key)
# Run `num_warmup` adaptation steps; keep the final state and tuned parameters.
(last_state, parameters), _ = adapt.run(warmup_key, initial_position, num_warmup)
# Build the NUTS transition function from the tuned parameters.
kernel = blackjax.nuts(logdensity_fn, **parameters).step

In [None]:
def inference_loop(rng_key, kernel, initial_state, num_samples):
    """Apply ``kernel`` repeatedly with ``jax.lax.scan`` and collect the results.

    Args:
        rng_key: PRNG key; it is split into one key per step.
        kernel: Callable ``(rng_key, state) -> (state, info)``.
        initial_state: State handed to the first kernel call.
        num_samples: Number of kernel steps to run.

    Returns:
        ``(states, (acceptance_rate, is_divergent, num_integration_steps))``
        where ``states`` stacks the state after every step and the tuple
        holds the matching fields of the stacked ``info`` objects.
    """
    step_keys = jax.random.split(rng_key, num_samples)

    def advance(current_state, step_key):
        next_state, step_info = kernel(step_key, current_state)
        # Carry the new state forward and also emit it (with its info) as output.
        return next_state, (next_state, step_info)

    _, collected = jax.lax.scan(jax.jit(advance), initial_state, step_keys)
    states, infos = collected

    diagnostics = (
        infos.acceptance_rate,
        infos.is_divergent,
        infos.num_integration_steps,
    )
    return states, diagnostics

In [None]:
# Draw 1000 samples with the tuned NUTS kernel, starting from the last warmup state.
num_sample = 1000
rng_key, sample_key = jax.random.split(rng_key)
# `infos` holds the tuple (acceptance_rate, is_divergent, num_integration_steps).
states, infos = inference_loop(sample_key, kernel, last_state, num_sample)
#_ = states.position["mu"].block_until_ready()

In [None]:
# Display the sampled states returned by the inference loop.
states

In [None]:
# Fit the same model with SVI (default number of steps and samples).
svi_result, posterior_parameters, posterior_generated = run_svi(sales, log_price, wday, yday_fraction)

In [None]:
# Wrap the SVI draws for ArviZ. Each array gets a new leading axis of length 1
# (presumably serving as the chain dimension ArviZ expects — confirm).
idata_parameters = az.from_dict(
     posterior = { k: np.expand_dims(a=np.asarray(v), axis=0) for k, v in posterior_parameters.items() },
 )
idata_generated = az.from_dict(
     posterior={ k: np.expand_dims(a=np.asarray(v), axis=0) for k, v in posterior_generated.items() },
 )

In [None]:
#az.summary(data=idata_generated, var_names=["sigma"], round_to=3)
# Summarize the parameter draws. The price parameter's sample site in the
# model is named "log_elasticity"; no site is called "log_price_coefficient".
az.summary(data=idata_parameters, var_names=["log_sigma", "log_elasticity", "yday_coefficients", "wday_coefficients"], round_to=3)

In [None]:
# Summarize the draws of the deterministic `state` site (the Poisson rate).
az.summary(data=idata_generated, var_names=["state"], round_to=3)

In [None]:
# Average the SVI draws of `state` over the sample axis and attach the result
# to the data frame as a new `state` column.
state = jnp.mean(posterior_generated['state'], axis=0).tolist()
df = df.with_columns([ pl.Series("state", state) ])
df

In [None]:
# Display the raw dictionary of draws generated from the fitted model.
posterior_generated

In [None]:
# Plot observed sales (points) against the NUTS posterior-mean state (red line).
# Note that `m['state']` comes from the NUTS fit, not from the SVI fit above.
x = pd.DataFrame({ 'date': df["date"].to_numpy(), 'sales': df["sales"].to_numpy(), 'state': m['state'] })
ggplot(x, aes(x='date', y='sales')) + geom_point() + geom_line(aes(y='state'), color = "red") + theme_bw()