# Training a Double Machine Learning (DML) Forecaster

This notebook implements a **Double Machine Learning (DML) Forecaster** that follows the method from  
*Causal Forecasting for Pricing* (Schultz et al.). Instead of transformers, we use **LightGBM** regressors for all three models (outcome, treatment, and effect).

---

## 1. Overall Idea

We want to estimate how demand $q_t$ changes when we change the discount $d_t$, while controlling for  
a set of covariates $z_t$ (time features, product features, lags, etc.).

A naive regression of $q_t$ on $d_t$ is biased because discounts are **confounded** with these  
covariates (seasonality, life-cycle, etc.). The DML Forecaster fixes this by:

1. Learning expected demand from the covariates alone, **without** using the current discount as an input (`outcome model`).
2. Learning how the **discount policy** behaves given covariates (`treatment model`).
3. Learning an **elasticity function** $\psi(z_t)$ on top of those two nuisance models.
4. Using $\psi(z_t)$ to simulate demand under new, counterfactual discounts.
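The confounding problem, and why modelling both demand and discount from $z_t$ fixes it, shows up clearly in a small simulation. This is a minimal sketch under an assumed log-linear data-generating process with a single confounder `z`. Ordinary least squares replaces LightGBM, and the last step regresses residual on residual (the linear "partialling-out" form of DML) instead of fitting the ratio labels used later in this notebook. All names and numbers are illustrative.

```python
import numpy as np

# Assumed toy process (not the notebook's data):
#   z      : one confounder, e.g. a seasonality signal
#   d      : discount set by the policy, higher when z is high
#   log_q  : log demand, driven by z directly and by log(1 - d) with elasticity true_psi
rng = np.random.default_rng(0)
n = 50_000
true_psi = -2.0

z = rng.normal(size=n)
d = np.clip(0.2 + 0.1 * z + 0.05 * rng.normal(size=n), 0.0, 0.8)
log_1md = np.log(1.0 - d)
log_q = 1.0 + 1.5 * z + true_psi * log_1md + 0.1 * rng.normal(size=n)


def ols_slope(x, y):
    """Slope of y on x, fitted with an intercept."""
    X = np.column_stack([np.ones_like(x), x])
    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    return coef[1]


def residualize(target, z):
    """Remove the part of `target` explained linearly by z (a stand-in nuisance model)."""
    Z = np.column_stack([np.ones_like(z), z])
    coef, *_ = np.linalg.lstsq(Z, target, rcond=None)
    return target - Z @ coef


# Naive: regress log q on log(1 - d) and ignore z, so the effect of z leaks into the slope
naive_psi = ols_slope(log_1md, log_q)

# Partialling out: residualize both demand and discount on z, then regress residual on residual
psi_dml = ols_slope(residualize(log_1md, z), residualize(log_q, z))

print(f"naive: {naive_psi:.2f}, partialled-out: {psi_dml:.2f}, true: {true_psi}")
```

Because high `z` raises both the discount and demand, the naive slope attributes part of the seasonal lift to the discount and lands far from `true_psi`; the residualized slope lands close to it.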

---

## 2. Two Nuisance Models

We first train two "nuisance" models:

- **Outcome model**:  
  $$
  \hat q_t(z) \approx \mathbb{E}[q_t \mid z_t]
  $$
  Predicts demand from covariates, *without* using the current discount as an input.

- **Treatment model**:  
  $$
  \hat d_t(z) \approx \mathbb{E}[d_t \mid z_t]
  $$
  Predicts the discount that the pricing policy would choose, given the same covariates.

Both are LightGBM regressors. Together they capture the "business as usual" structure in the data: the typical demand level in each context and the discount the historical policy tends to choose there.

---

## 3. Sample-Splitting and Cross-Fitting

To keep overfitting in the nuisance models from leaking into the residuals (and from there into the elasticity estimate), we use **sample splitting**:

1. Split items into two disjoint sets, e.g.
   - **Even** item IDs
   - **Odd** item IDs

2. Train nuisance models on one half and predict on the other:
   - Train on **even** items, predict $\hat q_t, \hat d_t$ for **odd** items.
   - Train on **odd** items, predict $\hat q_t, \hat d_t$ for **even** items.

This gives **cross-fitted predictions** $\hat q_t$ and $\hat d_t$ for every observation, each produced by a model that never saw that item during training.  
These cross-fitted nuisance predictions are then treated as "fixed" when estimating the effect.
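The even/odd scheme fits in a few lines of NumPy. This is a minimal sketch, assuming integer item IDs and a toy single-covariate panel. A least-squares fit (`fit_linear`, a hypothetical helper) stands in for the LightGBM regressors, and `cross_fit` is illustrative rather than part of the notebook's class.

```python
import numpy as np


def fit_linear(X, y):
    """Stand-in nuisance learner: least squares with an intercept (the notebook uses LightGBM)."""
    Xb = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(Xb, y, rcond=None)
    return lambda X_new: np.column_stack([np.ones(len(X_new)), X_new]) @ coef


def cross_fit(X, y, item_ids, fit=fit_linear):
    """Cross-fitted predictions: each row is scored by a model trained only on the other parity."""
    is_even = np.asarray(item_ids) % 2 == 0
    y_hat = np.empty(len(y), dtype=float)
    model_even = fit(X[is_even], y[is_even])    # trained on even items only
    model_odd = fit(X[~is_even], y[~is_even])   # trained on odd items only
    y_hat[~is_even] = model_even(X[~is_even])   # odd rows <- even-trained model
    y_hat[is_even] = model_odd(X[is_even])      # even rows <- odd-trained model
    return y_hat


# Toy panel (assumed): 10 items x 20 weeks, one covariate
rng = np.random.default_rng(1)
item_ids = np.repeat(np.arange(10), 20)
X = rng.normal(size=(len(item_ids), 1))
y = 3.0 + 2.0 * X[:, 0] + 0.1 * rng.normal(size=len(item_ids))

y_hat_cf = cross_fit(X, y, item_ids)
print(f"mean abs error of cross-fitted predictions: {np.abs(y - y_hat_cf).mean():.3f}")
```

A useful property to check: changing the targets of odd items leaves their cross-fitted predictions untouched, because those rows are only ever scored by the model trained on even items.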

---

## 4. Effect Model: Learning Elasticity $\psi(z)$

The key quantity we want is a **context-dependent elasticity** $\psi(z_t)$.

Starting from the multiplicative pricing formula:

$$
\tilde q_t(d_t) = \hat q_t(z_t)
\left( \frac{1 - d_t}{1 - \hat d_t(z_t)} \right)^{\psi(z_t)},
$$

we replace the model prediction $\tilde q_t(d_t)$ with the observed demand $q_t$, take logs, and rearrange to get a **pseudo-regression target**:

$$
\log q_t - \log \hat q_t
\;\approx\;
\psi(z_t)\,
\Big[\,\log(1 - d_t) - \log(1 - \hat d_t)\,\Big].
$$
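A quick numerical check of the log form, with illustrative values that are not taken from the data: when observed demand equals the formula's prediction, the ratio of the two log differences returns $\psi$ exactly.

```python
import numpy as np

# Illustrative values (assumed): baseline forecast, expected and actual discount, elasticity
q_hat, d_hat, d, psi = 100.0, 0.10, 0.30, -1.5

# Multiplicative pricing formula
q_tilde = q_hat * ((1 - d) / (1 - d_hat)) ** psi

# Log form: y = psi * x holds exactly when observed demand equals q_tilde
y = np.log(q_tilde) - np.log(q_hat)
x = np.log(1 - d) - np.log(1 - d_hat)
print(f"q_tilde = {q_tilde:.2f}, y / x = {y / x:.6f}")
```

Here a discount 20 points deeper than the expected one, with $\psi = -1.5$, raises demand by about 46% over $\hat q_t$.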

Define:

- **Numerator (response)**:
  $$
  y^{(\psi)}_t
  = \log q_t - \log \hat q_t
  $$
- **Denominator (design scalar)**:
  $$
  x^{(\psi)}_t
  = \log(1 - d_t) - \log(1 - \hat d_t)
  $$

Then, pointwise,

$$
\psi(z_t) \approx \frac{y^{(\psi)}_t}{x^{(\psi)}_t}.
$$

In practice:

1. We compute $y^{(\psi)}_t$ and $x^{(\psi)}_t$ using cross-fitted $\hat q_t$ and $\hat d_t$.
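2. We drop rows where $|x^{(\psi)}_t|$ is below a small threshold (`min_delta_log_price`): when the realized discount is almost equal to the expected one there is no elasticity signal, and dividing by a near-zero $x^{(\psi)}_t$ would make the label explode.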
3. We form **labels**  
   $$
   \tilde \psi_t = \frac{y^{(\psi)}_t}{x^{(\psi)}_t}
   $$
   and regress them on the covariates $z_t$ using another LightGBM regressor:
   $$
   \hat\psi(z_t) \approx \tilde \psi_t.
   $$

This third model is the **effect head**: it estimates how sensitive demand is to relative price changes, conditional on covariates.
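The three steps above reduce to a few array operations. This sketch mirrors the label construction inside `DMLForecasterLGBM.fit` (same clipping constants and threshold), packaged as a hypothetical standalone helper `elasticity_labels`. The four rows are made-up values, with the third chosen so that $|x^{(\psi)}_t|$ falls below the threshold.

```python
import numpy as np


def elasticity_labels(q, q_hat_cf, d, d_hat_cf, min_delta_log_price=1e-3, eps_q=1e-3, eps_d=1e-6):
    """Ratio labels psi_tilde = y / x, keeping only rows with enough price variation."""
    # Response: log q - log q_hat (clipped so zero demand does not give -inf)
    y = np.log(np.clip(q, eps_q, None)) - np.log(np.clip(q_hat_cf, eps_q, None))
    # Design scalar: log(1 - d) - log(1 - d_hat)
    x = np.log(np.clip(1.0 - d, eps_d, None)) - np.log(np.clip(1.0 - d_hat_cf, eps_d, None))
    keep = np.abs(x) > min_delta_log_price
    return y[keep] / x[keep], keep, x


q        = np.array([120.0, 95.0, 100.0, 80.0])
q_hat_cf = np.array([100.0, 100.0, 100.0, 100.0])
d        = np.array([0.30, 0.05, 0.2000, 0.00])
d_hat_cf = np.array([0.20, 0.10, 0.2004, 0.10])

labels, keep, x = elasticity_labels(q, q_hat_cf, d, d_hat_cf)
print(keep)    # third row dropped: its |x| is about 5e-4
print(labels)
```

One design option, not used by the class below: since $(\tilde\psi_t - \psi)^2\,(x^{(\psi)}_t)^2 = (y^{(\psi)}_t - \psi\,x^{(\psi)}_t)^2$, fitting the ratio labels with sample weights $(x^{(\psi)}_t)^2$ (via the `sample_weight` argument of `LGBMRegressor.fit`) corresponds to the residual-on-residual least-squares loss and down-weights the noisiest labels. The class fits the labels unweighted.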

---

## 5. Final Prediction

Once we have:

- $\hat q_t(z)$ – outcome model prediction
- $\hat d_t(z)$ – treatment model prediction
- $\hat\psi(z)$ – elasticity estimate

we can predict demand for any discount scenario $d^\star_t$:

$$
\tilde q_t(d^\star_t)
= \hat q_t(z_t)
\left(
  \frac{1 - d^\star_t}{1 - \hat d_t(z_t)}
\right)^{\hat\psi(z_t)}.
$$

Special cases:

- **On-policy prediction** (using the actually observed discount $d_t$):  
  set $d^\star_t = d_t$.

- **Off-policy / counterfactual prediction** (e.g. "What if we always discounted 30%?"):  
  set $d^\star_t = 0.3$ for all rows, or pass any other scalar / vector of discounts.

This gives a **causally aware forecaster**: it reuses the structure learned by the nuisance models,  
but adjusts demand as if we had chosen a different discount, using the learned elasticity $\hat\psi(z_t)$.
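Once the three model outputs are available, the prediction formula is a single vectorized expression. The sketch below reproduces the final step of the `predict` method on made-up arrays; `predict_demand` is an illustrative helper, and the clipping constant is copied from the notebook.

```python
import numpy as np


def predict_demand(q_hat, d_hat, psi_hat, d_star, eps_d=1e-6):
    """Demand under the discount scenario d_star (a scalar or one value per row)."""
    q_hat = np.asarray(q_hat, dtype=float)
    d_hat = np.asarray(d_hat, dtype=float)
    psi_hat = np.asarray(psi_hat, dtype=float)
    # A scalar scenario is broadcast to every row
    d_star = np.broadcast_to(np.asarray(d_star, dtype=float), q_hat.shape)
    # Relative price factor (1 - d*) / (1 - d_hat), clipped away from zero as in the notebook
    ratio = np.clip(1.0 - d_star, eps_d, None) / np.clip(1.0 - d_hat, eps_d, None)
    return q_hat * ratio ** psi_hat


# Illustrative nuisance and effect outputs for three rows (assumed values)
q_hat   = [50.0, 80.0, 20.0]
d_hat   = [0.10, 0.20, 0.00]
psi_hat = [-1.2, -2.0, -0.5]
d_obs   = [0.10, 0.30, 0.20]

on_policy = predict_demand(q_hat, d_hat, psi_hat, d_obs)   # d* = logged discount
off_policy = predict_demand(q_hat, d_hat, psi_hat, 0.3)    # d* = 0.3 for every row
print(on_policy, off_policy)
```

In the first row the logged discount equals the expected one, so the on-policy forecast is exactly $\hat q_t$; in the second row the scenario matches the logged discount, so both forecasts coincide.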

---

In [None]:
import numpy as np
import pandas as pd
from lightgbm import LGBMRegressor
from sklearn.metrics import mean_absolute_error, mean_squared_error

## Load Data

Load the synthetic pricing data generated in the previous notebook.

In [None]:
# Load from Databricks table (`spark` is the SparkSession predefined in Databricks notebooks)

catalog = "causal_forecasting"      # Change here
schema = "default"                  # Change here

table_name = f"{catalog}.{schema}.synthetic_data"
df = spark.table(table_name).toPandas()

print(f"Loaded {len(df):,} rows")
print(f"Shape: {df.shape}")
df.head()

## Train/Test Split

Split data along time dimension (weeks 0-39 for training, weeks 40-59 for testing).

In [None]:
train_df = df[df["week"] < 40].copy()
test_df = df[df["week"] >= 40].copy()

print(f"Training set: {len(train_df):,} rows")
print(f"Test set: {len(test_df):,} rows")

## Define Feature Columns

Covariate set $z_t$ (excludes the current discount, but includes `lag_discount`).

In [None]:
feature_cols = [
    "week",
    "base_price",
    "category",
    "k_category",
    "season_type",
    "lag_demand",
    "lag_discount",
    "week_sin",
    "week_cos",
]

In [None]:
class DMLForecasterLGBM:
    """
    LGBM implementation of the DML Forecaster (Section 3, Eq. (6))
    from 'Causal Forecasting for Pricing' (Schultz et al., 2024).

    Structure:
      - outcome nuisance model:   q_hat(z)
      - treatment nuisance model: d_hat(z)
      - effect model:             psi(z) = elasticity

    Training:
      1) Split items into even / odd sets (sample splitting).
      2) Train nuisance models on even, predict on odd; then
         train nuisance models on odd, predict on even
         -> cross-fitted q_hat_cf, d_hat_cf for all rows.
      3) Build elasticity labels from Eq. (6) in log form:
           log(q) - log(q_hat_cf)
           --------------------------------------------  ≈ psi(z)
           log(1 - d) - log(1 - d_hat_cf)
      4) Fit effect LGBM on z -> psi(z).

    Prediction:
      - Use both nuisance models (even/odd) on all rows and
        average their predictions:
          q_hat = (q_hat_even + q_hat_odd) / 2
          d_hat = (d_hat_even + d_hat_odd) / 2
      - Compute psi(z) from effect model.
      - Apply Eq. (6) for desired discount scenario d*:
          q_pred = q_hat * ((1 - d*) / (1 - d_hat)) ** psi(z)
    """

    def __init__(
        self,
        outcome_params=None,
        treatment_params=None,
        effect_params=None,
        random_state: int = 0,
        min_delta_log_price: float = 1e-3,
    ):
        # Default LGBM params (you can override via *_params)
        base_params = {
            "n_estimators": 200,
            "learning_rate": 0.05,
            "max_depth": -1,
            "random_state": random_state,
        }
        def _merge(user):
            p = base_params.copy()
            if user is not None:
                p.update(user)
            return p

        self.outcome_params = _merge(outcome_params)
        self.treatment_params = _merge(treatment_params)
        self.effect_params = _merge(effect_params)

        self.random_state = random_state
        self.min_delta_log_price = min_delta_log_price

        # Will be populated in fit()
        self.outcome_model_even_ = None
        self.outcome_model_odd_ = None
        self.treatment_model_even_ = None
        self.treatment_model_odd_ = None
        self.effect_model_ = None

        self.feature_cols_ = None
        self.treatment_col_ = None
        self.outcome_col_ = None
        self.item_col_ = None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _split_even_odd(self, df: pd.DataFrame):
        # Split by item_id parity, like even/odd batches in the paper.
        items = df[self.item_col_].unique()
        items_even = set([i for i in items if int(i) % 2 == 0])
        mask_even = df[self.item_col_].isin(items_even)
        mask_odd = ~mask_even
        return mask_even.values, mask_odd.values

    # ------------------------------------------------------------------
    # Fit
    # ------------------------------------------------------------------
    def fit(
        self,
        df: pd.DataFrame,
        feature_cols,
        treatment_col: str = "discount",
        outcome_col: str = "demand",
        item_col: str = "item_id",
    ):
        """
        Fit nuisance and effect models following the DML Forecaster
        approach (sample splitting + elasticity head).

        Parameters
        ----------
        df : pd.DataFrame
            Training data with columns outcome_col, treatment_col, item_col
            and feature_cols (z_t).
        feature_cols : list of str
            Covariates z_t (no current discount; past discount OK).
        """
        self.feature_cols_ = list(feature_cols)
        self.treatment_col_ = treatment_col
        self.outcome_col_ = outcome_col
        self.item_col_ = item_col

        # Ensure no NaNs for training
        df = df.copy()
        df = df.dropna(subset=self.feature_cols_ + [treatment_col, outcome_col, item_col])

        X_all = df[self.feature_cols_].values
        d_all = df[treatment_col].values.astype(float)
        y_all = df[outcome_col].values.astype(float)

        # 1) Split into even/odd item sets
        mask_even, mask_odd = self._split_even_odd(df)
        X_even, X_odd = X_all[mask_even], X_all[mask_odd]
        y_even, y_odd = y_all[mask_even], y_all[mask_odd]
        d_even, d_odd = d_all[mask_even], d_all[mask_odd]

        # 2) Train nuisance models on even, predict on odd; and vice versa
        # Outcome models
        self.outcome_model_even_ = LGBMRegressor(**self.outcome_params)
        self.outcome_model_even_.fit(X_even, y_even)

        self.outcome_model_odd_ = LGBMRegressor(**self.outcome_params)
        self.outcome_model_odd_.fit(X_odd, y_odd)

        # Treatment models
        self.treatment_model_even_ = LGBMRegressor(**self.treatment_params)
        self.treatment_model_even_.fit(X_even, d_even)

        self.treatment_model_odd_ = LGBMRegressor(**self.treatment_params)
        self.treatment_model_odd_.fit(X_odd, d_odd)

        # Cross-fitted nuisance predictions: q_hat_cf, d_hat_cf
        q_hat_cf = np.zeros_like(y_all, dtype=float)
        d_hat_cf = np.zeros_like(d_all, dtype=float)

        # For odd rows, use models trained on even
        q_hat_cf[mask_odd] = self.outcome_model_even_.predict(X_odd)
        d_hat_cf[mask_odd] = self.treatment_model_even_.predict(X_odd)

        # For even rows, use models trained on odd
        q_hat_cf[mask_even] = self.outcome_model_odd_.predict(X_even)
        d_hat_cf[mask_even] = self.treatment_model_odd_.predict(X_even)

        # 3) Build effect training data: elasticity labels from Eq. (6) in log form
        eps_q = 1e-3
        eps_d = 1e-6

        # log q and log q_hat
        log_q = np.log(np.clip(y_all, eps_q, None))
        log_q_hat = np.log(np.clip(q_hat_cf, eps_q, None))
        num = log_q - log_q_hat

        # log(1-d) and log(1-d_hat)
        log_1_minus_d = np.log(np.clip(1.0 - d_all, eps_d, None))
        log_1_minus_d_hat = np.log(np.clip(1.0 - d_hat_cf, eps_d, None))
        den = log_1_minus_d - log_1_minus_d_hat

        # Drop rows where denominator is ~0 (no price difference => no elasticity info)
        mask_effect = np.abs(den) > self.min_delta_log_price
        if mask_effect.sum() == 0:
            raise RuntimeError("No rows with sufficient discount variation for effect model training.")

        psi_labels = num[mask_effect] / den[mask_effect]
        X_effect = X_all[mask_effect]

        # 4) Fit effect model psi(z)
        self.effect_model_ = LGBMRegressor(**self.effect_params)
        self.effect_model_.fit(X_effect, psi_labels)

        return self

    # ------------------------------------------------------------------
    # Prediction
    # ------------------------------------------------------------------
    def predict(
        self,
        df: pd.DataFrame,
        discount_scenario=None,
    ) -> np.ndarray:
        """
        Predict demand under a given discount scenario using Eq. (6).

        Parameters
        ----------
        df : pd.DataFrame
            Data with the feature columns (plus treatment_col_ when
            discount_scenario is None). No rows are dropped: one prediction
            is returned per input row, in the same order, so the output
            aligns with e.g. df[outcome_col_]. LightGBM accepts NaN feature
            values at prediction time.
        discount_scenario : float, array-like, or None
            - None: use logged discount from df[treatment_col_].
            - scalar: use this discount for all rows.
            - array-like: per-row discounts.

        Returns
        -------
        np.ndarray
            Predicted demand for each row.
        """
        if any(m is None for m in [
            self.outcome_model_even_,
            self.outcome_model_odd_,
            self.treatment_model_even_,
            self.treatment_model_odd_,
            self.effect_model_,
        ]):
            raise RuntimeError("Model not fitted. Call .fit() first.")

        # No dropna here: dropping rows would misalign predictions with the
        # caller's targets (e.g. y_true in the evaluation cells below)
        X = df[self.feature_cols_].values

        # Nuisance predictions: average even/odd models (forecast-style)
        q_hat_even = self.outcome_model_even_.predict(X)
        q_hat_odd = self.outcome_model_odd_.predict(X)
        q_hat = 0.5 * (q_hat_even + q_hat_odd)

        d_hat_even = self.treatment_model_even_.predict(X)
        d_hat_odd = self.treatment_model_odd_.predict(X)
        d_hat = 0.5 * (d_hat_even + d_hat_odd)

        # Effect model: psi(z)
        psi_hat = self.effect_model_.predict(X)

        # Desired discount scenario
        if discount_scenario is None:
            d_new = df[self.treatment_col_].values.astype(float)
        else:
            if np.isscalar(discount_scenario):
                d_new = np.full(len(df), float(discount_scenario), dtype=float)
            else:
                d_new = np.asarray(discount_scenario, dtype=float)
                if d_new.shape[0] != len(df):
                    raise ValueError("discount_scenario has wrong length.")

        eps_d = 1e-6
        # Eq. (6): q_hat * ((1 - d_new) / (1 - d_hat)) ** psi_hat
        ratio = np.clip(1.0 - d_new, eps_d, None) / np.clip(1.0 - d_hat, eps_d, None)
        q_pred = np.clip(q_hat, 1e-3, None) * (ratio ** psi_hat)

        return q_pred


## Train DML Forecaster

Instantiate and train the DML forecaster with 2-fold cross-fitting.

In [None]:
forecaster = DMLForecasterLGBM(
    outcome_params={"n_estimators": 200, "learning_rate": 0.05, "max_depth": -1},
    treatment_params={"n_estimators": 200, "learning_rate": 0.05, "max_depth": -1},
    effect_params={"n_estimators": 200, "learning_rate": 0.05, "max_depth": -1},
    # No n_folds argument: __init__ does not accept one; 2-fold (even/odd item)
    # cross-fitting is built into fit()
    random_state=0,
)

forecaster.fit(
    train_df,
    feature_cols=feature_cols,
    treatment_col="discount",
    outcome_col="demand",
    item_col="item_id",
)

print("DML forecaster trained successfully!")

## Evaluate: On-Policy Predictions

Evaluate the model using the actual discounts in the test set.

In [None]:
y_true = test_df["demand"].values
y_pred_on = forecaster.predict(test_df)

mae_on = mean_absolute_error(y_true, y_pred_on)
mse_on = mean_squared_error(y_true, y_pred_on)
rmse_on = np.sqrt(mse_on)

print("On-policy performance (using logged discounts):")
print(f"  MAE = {mae_on:.3f}")
print(f"  RMSE = {rmse_on:.3f}")
print(f"  MSE = {mse_on:.3f}")

## Evaluate: Off-Policy Predictions

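Evaluate counterfactual scenarios with fixed discount levels. Keep in mind that the realized test demand was generated under the logged discounts, so these errors do not measure counterfactual accuracy; they show how far each scenario's forecast departs from what actually happened. Measuring counterfactual accuracy requires ground-truth demand under the scenario discount, which a synthetic data generator can in principle provide.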

In [None]:
for disc in [0.1, 0.3, 0.5]:
    y_pred_off = forecaster.predict(test_df, discount_scenario=disc)
    mae_off = mean_absolute_error(y_true, y_pred_off)
    mse_off = mean_squared_error(y_true, y_pred_off)
    rmse_off = np.sqrt(mse_off)
    print(f"\nOff-policy performance (counterfactual discount = {disc:.1f}):")
    print(f"  MAE vs actual realized demand = {mae_off:.3f}")
    print(f"  RMSE vs actual realized demand = {rmse_off:.3f}")
    print(f"  MSE vs actual realized demand = {mse_off:.3f}")

## Demand Sensitivity Analysis

Inspect how average predicted demand changes with different discount levels.
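Averages across rows hide how the elasticity varies with $z_t$. Because the forecast is multiplicative, two scenario forecasts per row are enough to recover the elasticity the model applies: $\log\big(\tilde q_t(d_1)/\tilde q_t(d_2)\big) \big/ \log\big((1-d_1)/(1-d_2)\big) = \hat\psi(z_t)$. A minimal sketch, with a toy prediction function (assumed values) standing in for `forecaster.predict`; `implied_elasticity` and `toy_predict` are illustrative names:

```python
import numpy as np


def implied_elasticity(q_at_d1, q_at_d2, d1, d2):
    """Per-row elasticity implied by two scenario forecasts under the multiplicative formula."""
    log_demand_ratio = np.log(np.asarray(q_at_d1, dtype=float) / np.asarray(q_at_d2, dtype=float))
    log_price_ratio = np.log(1.0 - d1) - np.log(1.0 - d2)
    return log_demand_ratio / log_price_ratio


# Toy stand-in for forecaster.predict (assumed values, not the notebook's trained model)
q_hat = np.array([50.0, 80.0])
d_hat = np.array([0.1, 0.2])
psi = np.array([-1.2, -2.0])


def toy_predict(discount):
    return q_hat * ((1.0 - discount) / (1.0 - d_hat)) ** psi


psi_implied = implied_elasticity(toy_predict(0.4), toy_predict(0.2), 0.4, 0.2)
print(psi_implied)
```

Applied to two `forecaster.predict(test_df, discount_scenario=...)` calls with different scalar discounts, the same helper should reproduce $\hat\psi(z_t)$ row by row, because both calls share the same $\hat q_t$ and $\hat d_t$.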

In [None]:
print("Average predicted demand at different discount levels:\n")
for disc in [0.0, 0.2, 0.4, 0.6]:
    y_pred = forecaster.predict(test_df, discount_scenario=disc)
    print(f"  Discount {disc:.1f}: {y_pred.mean():.2f}")