# SKU Dataset

## Environment Variables

In [None]:
%matplotlib widget

%reload_ext autoreload
%autoreload

%set_env XLA_PYTHON_CLIENT_PREALLOCATE false
%set_env CUDA_VISIBLE_DEVICES 0

## Imports

In [4]:
import keras

import numpy as np
import polars as pl

import matplotlib.pyplot as plt
import seaborn as sns

import jax
from jax import random
from jax import numpy as jnp
from jax.scipy import optimize

from loss_functions import MAAPE, SMASPE, MASPE
from sku_models import arima_run, additive_hw_run

from functools import partial

# Global seaborn/matplotlib styling for every figure produced below.
sns.set_theme(**{
    "context": "paper",
    "style": "whitegrid",
    "font": "Roboto",
    "font_scale": 2,
})

# Fixed-seed (0) JAX PRNG key; not referenced by the cells shown here.
KEY = random.key(0)
# Keras' fuzz factor; reused below as a division floor and optimizer tolerance.
EPS = keras.backend.epsilon()


def fft_com(x):
  """Return the spectral centre of mass of ``x`` on a normalised [0, 1] axis.

  The magnitudes of the real FFT are weighted by a frequency axis that runs
  linearly from 0 (the DC bin) to 1 (the highest rfft bin). The weighted sum is
  divided by the total non-DC magnitude. The DC bin has weight 0, so it drops
  out of the numerator as well. The result is therefore the centre of mass of
  the non-DC spectrum.

  Args:
    x: 1-D real-valued sequence (anything ``np.fft.rfft`` accepts).

  Returns:
    The centre of mass as a NumPy float. Returns NaN when ``x`` has no non-DC
    spectral content (e.g. a constant series), instead of silently evaluating
    0/0 with a RuntimeWarning.
  """
  x_fft = np.abs(np.fft.rfft(x))
  weights = np.linspace(0, 1, len(x_fft))
  non_dc_mass = np.sum(x_fft[1:])
  if non_dc_mass == 0:
    # Centre of mass is undefined without any non-DC energy.
    return np.float64(np.nan)
  return np.sum(x_fft * weights) / non_dc_mass


# Load the weekly unit-sales series (one row per CSV record, keyed by
# Scode/Pcode) and attach a spectral "intermittence" score from fft_com.
# NOTE(review): an earlier comment claimed only intermittent series are kept,
# but no rows are filtered on `intermittence` here; confirm whether a threshold
# filter was intended.
DATASET = (
    pl.read_csv("./data/sku.csv")
    # Anchored so the selection matches exactly the week columns that the
    # concat/drop steps below operate on.
    .select(["Scode", "Pcode", pl.selectors.matches(r"^Wk\d+$")])
    .with_columns(pl.concat_arr(pl.col(r"^Wk\d+$")).alias("ts"))
    .drop(r"^Wk\d+$")
    .with_columns(
        intermittence=pl.col("ts").map_elements(
            fft_com,
            returns_scalar=True,
            return_dtype=pl.Float64
        )
    )
    .sort(["Scode", "Pcode"])
    .to_dicts()
)

# Dense 2-D matrix: one row per (sorted) series, one column per week.
X = jnp.stack(
    [
        jnp.asarray(row["ts"])
        for row in DATASET
    ],
    dtype=float
)

## Visualization

In [None]:
# Plot an evenly spaced sample of at most `n_vis` series from DATASET.
n_vis = 5
n_rows = len(DATASET)
# max(1, ...) avoids a zero slice step (ValueError) when n_rows < n_vis.
step = max(1, n_rows // n_vis)

fig = plt.figure(figsize=(12, 6))
# `[::step]` alone yields more than n_vis rows whenever n_rows is not a
# multiple of n_vis, so cap the sample explicitly.
for row in DATASET[::step][:n_vis]:
  label = f"{row['Scode']}: {row['Pcode']}"
  plt.plot(
      np.asarray(row["ts"]),
      label=label
  )
plt.ylabel("Units Sold")
plt.xlabel("Week")
plt.legend(loc="upper right")
fig.tight_layout(pad=2)
plt.autoscale(tight=True)
fig.savefig("./plots/sku_timeseries.svg", transparent=True)
plt.close(fig)

## Benchmarks

### Global Definitions

In [11]:
SPLIT_IDX = 95
# Chronological split: the first SPLIT_IDX weeks of every series are training
# data, the remaining weeks are the test window.
X_TRAIN = X[:, :SPLIT_IDX]
X_TEST = X[:, SPLIT_IDX:]

# Per-series [min, max] of the training window. keepdims gives shape
# (n_series, 1) so the bounds broadcast against per-series rows.
BOUNDS = [
    X_TRAIN.min(axis=-1, keepdims=True),
    X_TRAIN.max(axis=-1, keepdims=True),
]
DELTA = BOUNDS[1] - BOUNDS[0]

# Training objectives and evaluation metrics, keyed by display name.
# "SMASPE (loose)" widens each series' training range by 5% of its span on
# both sides; "SMASPE (tight)" uses the raw training range.
losses = dict([
    ("MAE", keras.losses.MeanAbsoluteError()),
    ("MSE", keras.losses.MeanSquaredError()),
    ("MAAPE", MAAPE()),
    ("MASPE", MASPE()),
    ("SMASPE (tight)", SMASPE(y_minus=BOUNDS[0], y_plus=BOUNDS[1])),
    ("SMASPE (loose)", SMASPE(
        y_minus=BOUNDS[0] - 0.05 * DELTA,
        y_plus=BOUNDS[1] + 0.05 * DELTA,
    )),
])


def process_sku_results(results: dict[str, dict], suffix: str) -> None:
  """Print the LaTeX metric table and save every diagnostic plot for one model family.

  Args:
    results: Maps a training-loss name to a dict with keys ``"predictions"``
      (one row per series; its last ``X_TEST.shape[-1]`` columns are compared
      against ``X_TEST``) and ``"median_epochs"``.
    suffix: Tag appended to every file name written under ``./plots/``.

  Reads the notebook globals ``X``, ``X_TEST``, ``SPLIT_IDX``, ``losses`` and
  ``EPS``. Prints one table row per model and writes four SVG files.
  """
  labels = list(results)
  n_test = X_TEST.shape[-1]

  _print_metric_table(results)

  # Shape (n_series * n_test, n_models): flattened test errors y - y_hat,
  # one column per model in `results` order.
  test_err_matrix = jnp.stack([
      (X_TEST - results[k]["predictions"][:, -n_test:]).flatten()
      for k in labels
  ], -1)

  _save_error_boxplots(test_err_matrix, labels, suffix)
  _save_error_histograms(test_err_matrix, labels, suffix)
  _save_forecast_examples(results, test_err_matrix, suffix)


def _print_metric_table(results: dict[str, dict]) -> None:
  """Print one LaTeX row per model: name, median iterations, then every metric in `losses`."""
  n_test = X_TEST.shape[-1]
  for loss_name, result in results.items():
    test_pred = result["predictions"][:, -n_test:]
    cells = [loss_name, f"{result['median_epochs']}"]
    cells.extend(
        f"{metric_fn(X_TEST, test_pred):.3}" for metric_fn in losses.values()
    )
    print(" & ".join(cells) + r" \\")


def _save_error_boxplots(test_err_matrix, labels: list[str], suffix: str) -> None:
  """Save boxplots of absolute (y - y_hat) and relative test errors per model."""
  # Relative error; y is floored at EPS so zero-demand weeks do not divide by 0.
  rel_err = test_err_matrix / jnp.clip(X_TEST.flatten(), EPS, jnp.inf)[:, None]
  specs = [
      (test_err_matrix, r"Error Distribution ($y - \hat{y}$)", 1., "err"),
      (rel_err, r"Relative Error Distribution ($\bar{\delta}$)", .1, "delta"),
  ]
  for data, ylabel, linthresh, tag in specs:
    boxplot_fig = plt.figure(figsize=(12, 6))
    plt.boxplot(data, tick_labels=labels)
    plt.xlabel("Training Loss Function")
    plt.ylabel(ylabel)
    plt.yscale("symlog", linthresh=linthresh, linscale=1.)
    boxplot_fig.tight_layout(pad=2)
    boxplot_fig.savefig(
        f"./plots/sku_{tag}_boxplot_{suffix}.svg",
        transparent=True
    )
    plt.close(boxplot_fig)


def _save_error_histograms(test_err_matrix, labels: list[str], suffix: str) -> None:
  """Save 2-D histograms of |y - y_hat| against y for the MSE and SMASPE (loose) models."""
  hist_fig, hist_axs = plt.subplots(
      nrows=2,
      ncols=1,
      sharex=True,
      sharey=True,
      figsize=(12, 12)
  )
  err_bins = np.logspace(np.log10(5E-6), np.log10(5E11), 10)
  # Leading bin edge 0 adds a [0, 1) bin for zero demand; the remaining edges
  # are log-spaced from 1 up to max(X_TEST).
  ref_bins = np.pad(
      np.logspace(0, np.log10(np.max(X_TEST)), 9),
      (1, 0),
      constant_values=0.
  )
  for ax, loss_name in zip(hist_axs, ["MSE", "SMASPE (loose)"]):
    _, _, _, mappable = ax.hist2d(
        x=X_TEST.flatten(),
        y=np.abs(test_err_matrix[:, labels.index(loss_name)]),
        bins=[ref_bins, err_bins],
        cmap="viridis",
        norm="log"
    )
    ax.set_ylabel(
        r"$\left| y - \hat{y}_\text{" + loss_name + r"} \right|$"
    )
    ax.set_xscale("symlog", linthresh=1., linscale=.5)
    ax.set_yscale("log")
    ax.set_ylim((1e-6, 1e12))
    cbar = plt.colorbar(mappable, ax=ax)
    cbar.ax.set_ylabel("Frequency")
  hist_axs[-1].set_xlabel(r"$y$")
  hist_fig.tight_layout(pad=2)
  hist_fig.savefig(
      f"./plots/sku_err_histogram_{suffix}.svg",
      transparent=True
  )
  plt.close(hist_fig)


def _save_forecast_examples(results: dict[str, dict], test_err_matrix, suffix: str) -> None:
  """Plot the best and worst series (by test error) with every model's forecast."""
  # All models in one family share the prediction length; read it from the
  # first entry instead of stacking every prediction array just for a shape.
  eff_length = next(iter(results.values()))["predictions"].shape[-1]
  # Predictions cover the last `eff_length` weeks of X (the leading weeks are
  # presumably consumed as model warm-up), so map them to absolute week indices.
  week_range = np.arange(eff_length) + (X.shape[-1] - eff_length)

  # Rank series by total |test error| of the LAST TWO models in `results`.
  # With the loss ordering used in this notebook, those are the two SMASPE variants.
  per_series_abs_err = jnp.sum(
      jnp.abs(test_err_matrix.T[-2:].reshape([-1, *X_TEST.shape])),
      (0, -1)
  )
  best_idx = np.argmin(per_series_abs_err)
  worst_idx = np.argmax(per_series_abs_err)

  ex_fig, ex_axs = plt.subplots(
      nrows=2,
      ncols=1,
      sharex=True,
      figsize=(12, 12)
  )
  panels = [
      (ex_axs[0], best_idx, [10, 1E3]),
      (ex_axs[1], worst_idx, [0, 1E4]),
  ]
  for ax, series_idx, ylim in panels:
    ax.plot(week_range, X[series_idx, -eff_length:], ":", label="Reference")
    ax.set_ylim(ylim)
    # X[:, SPLIT_IDX] is the first test week, so the train/test divider sits
    # halfway between weeks SPLIT_IDX - 1 and SPLIT_IDX. The same dashed style
    # is used on both panels and is distinct from the dotted reference line.
    ax.axvline(SPLIT_IDX - 0.5, linestyle="--", color="red")
    for loss_name, result in results.items():
      ax.plot(week_range, result["predictions"][series_idx], label=loss_name)
    ax.set_yscale("symlog", linthresh=1., linscale=1.)
    ax.set_xlim(left=20, right=week_range[-1])

  ex_axs[0].legend(loc="upper left")
  ex_axs[1].set_xlabel("Week")

  ex_fig.tight_layout(pad=2)
  ex_fig.savefig(
      f"./plots/sku_forecasts_{suffix}.svg",
      transparent=True
  )
  plt.close(ex_fig)

### Non-Seasonal ARIMA

Each model is fit on its series as a single contiguous sequence so that the moving-average (MA) terms stay coherent. Because we study the one-step-ahead case, however, the observed values of the signal (not the model's own forecasts) are fed back during the iterative forecast.

In [None]:
P, D, Q = 4, 1, 4


def arima_helper(
        z_train: jnp.ndarray,
        z_test: jnp.ndarray,
        loss: keras.Loss) -> tuple[jnp.ndarray, int]:
  """Fit ARIMA(P, D, Q) weights on ``z_train`` with BFGS, then run the model over train+test.

  The objective compares ``arima_run``'s output against ``z_train[P + D:]``.
  The first ``P + D`` samples therefore only seed the recursion. The returned
  predictions follow the same alignment over the concatenated series.

  Args:
    z_train: Training segment of one series (called per series under vmap).
    z_test: Test segment of the same series, appended after training.
    loss: Callable ``loss(y_true, y_pred)`` minimised by the optimiser.

  Returns:
    ``(predictions, iterations)``: the fitted model's output over
    ``concat(z_train, z_test)[P + D:]`` and the BFGS iteration count.
  """
  warmup = P + D
  n_fit = len(z_train) - warmup

  def objective(w):
    fitted = arima_run(w, z_train, [P, D, Q], n_fit)
    return loss(z_train[warmup:], fitted)

  # P + Q + 2 weights. NOTE(review): presumably an intercept, P AR, Q MA and one
  # extra term; confirm against sku_models.arima_run.
  initial_weights = jnp.zeros(P + Q + 2)
  fit = optimize.minimize(
      objective,
      initial_weights,
      method="BFGS",
      tol=EPS
  )

  full_series = jnp.concatenate([z_train, z_test])
  pred = arima_run(fit.x, full_series, [P, D, Q], len(full_series) - warmup)
  return pred, fit.nit


results_arima = {}

for loss_name, loss_fn in losses.items():
  # MAE is excluded as a training objective; it still serves as a test metric.
  if loss_name == "MAE":
    continue
  print(f"Training with {loss_name}...", flush=True)

  if isinstance(loss_fn, SMASPE):
    # SMASPE bounds are per series, so rebuild the loss inside vmap from each
    # series' own scalar (y_minus, y_plus) pair.
    def fit_series(z_train, z_test, ym, yp):
      return arima_helper(
          z_train=z_train,
          z_test=z_test,
          loss=SMASPE(ym, yp)
      )

    mapped_args = (
        X_TRAIN,
        X_TEST,
        loss_fn.y_minus[:, 0],
        loss_fn.y_plus[:, 0],
    )
  else:
    fit_series = partial(arima_helper, loss=loss_fn)
    mapped_args = (X_TRAIN, X_TEST)

  # One independent fit per series (vmap over the leading axis).
  predictions, iterations = jax.vmap(fit_series)(*mapped_args)

  results_arima[loss_name] = {
      # Demand cannot be negative: clip forecasts at 0.
      "predictions": jnp.clip(predictions, 0, jnp.inf),
      "median_epochs": jnp.median(iterations),
  }

In [None]:
# Report metrics and write the "arima" plots.
process_sku_results(results=results_arima, suffix="arima")

### Holt-Winters (Triple Exponential Smoothing)

In [None]:
M = 4


def hw_helper(
        z_train: jnp.ndarray,
        z_test: jnp.ndarray,
        loss: keras.Loss) -> tuple[jnp.ndarray, int]:
  """Fit additive Holt-Winters weights on ``z_train`` with BFGS, then run over train+test.

  The objective compares ``additive_hw_run``'s output against
  ``z_train[3 * M + 1:]``. The first ``3 * M + 1`` samples therefore only
  initialise the model. The returned predictions use the same alignment over
  the concatenated series.

  Args:
    z_train: Training segment of one series (called per series under vmap).
    z_test: Test segment of the same series, appended after training.
    loss: Callable ``loss(y_true, y_pred)`` minimised by the optimiser.

  Returns:
    ``(predictions, iterations)``: the fitted model's output over
    ``concat(z_train, z_test)[3 * M + 1:]`` and the BFGS iteration count.
  """
  warmup = 3 * M + 1

  def objective(w):
    fitted = additive_hw_run(w, z_train, M, len(z_train) - warmup)
    return loss(z_train[warmup:], fitted)

  # Three weights, all starting at 0.5. NOTE(review): presumably the
  # level/trend/season smoothing factors; confirm against sku_models.
  fit = optimize.minimize(
      objective,
      0.5 * jnp.ones(3),
      method="BFGS",
      tol=EPS
  )

  full_series = jnp.concatenate([z_train, z_test])
  pred = additive_hw_run(fit.x, full_series, M, len(full_series) - warmup)
  return pred, fit.nit


results_hw = {}

for loss_name, loss_fn in losses.items():
  # MAE is excluded as a training objective; it still serves as a test metric.
  if loss_name == "MAE":
    continue
  print(f"Training with {loss_name}...", flush=True)

  if isinstance(loss_fn, SMASPE):
    # SMASPE bounds are per series, so rebuild the loss inside vmap from each
    # series' own scalar (y_minus, y_plus) pair.
    def fit_series(z_train, z_test, ym, yp):
      return hw_helper(
          z_train=z_train,
          z_test=z_test,
          loss=SMASPE(ym, yp)
      )

    mapped_args = (
        X_TRAIN,
        X_TEST,
        loss_fn.y_minus[:, 0],
        loss_fn.y_plus[:, 0],
    )
  else:
    fit_series = partial(hw_helper, loss=loss_fn)
    mapped_args = (X_TRAIN, X_TEST)

  # One independent fit per series (vmap over the leading axis).
  predictions, iterations = jax.vmap(fit_series)(*mapped_args)

  results_hw[loss_name] = {
      # Demand cannot be negative: clip forecasts at 0.
      "predictions": jnp.clip(predictions, 0, jnp.inf),
      "median_epochs": jnp.median(iterations),
  }

In [None]:
# Report metrics and write the "hw" plots.
process_sku_results(results=results_hw, suffix="hw")