# MMM Calibration with Geo-Level Lift Tests

## Introduction

This notebook demonstrates how to calibrate a multidimensional MMM using lift test results from a geo-level experiment. By incorporating experimental lift measurements, we can achieve better parameter recovery--both reduced bias and increased precision--compared to fitting the MMM alone.

We follow the same pattern as the [national-level lift test notebook](mmm_lift_test.ipynb), generating data directly from the model to ensure perfect consistency.

### Notebook Overview

This is a long technical document, so here's a roadmap of what we'll cover:

**Setup (Sections 1-2)**
- We simulate a marketing dataset with **8 geographic regions** and **2 media channels**
- The two channels are highly correlated (~0.99), making them difficult to separate--a common real-world challenge

**The Experiment (Section 3)**
- We run a geo-level lift test on **channel 1 only**
- **4 treated geos** receive an incremental spend increase; **4 control geos** remain unchanged
- This produces **4 lift test measurements**--one per treated geo--each estimating the causal effect of channel 1

**Analysis (Sections 4-6)**
- We fit the MMM twice: once without calibration (baseline) and once with lift test calibration
- We compare parameter recovery, showing that calibration reduces bias and increases precision
- Key result: the calibrated model's saturation curves more closely match the true curves

### When to Use Geo-Level Lift Tests

Geo-level lift tests are appropriate when you can:
- **Target specific geographic regions**: Digital campaigns with location targeting, regional TV markets
- **Hold out control regions**: Some geos receive the treatment while others serve as controls
- **Measure regional outcomes**: Sales, conversions, or other metrics at the geographic level

### The Practical Workflow: From Experiment to Calibration

In practice, running a geo-level lift test and using it to calibrate an MMM involves several steps:

1. **Design and run the experiment**: Increase (or decrease) spend on one channel in a subset of geographic regions (treated geos), while keeping spend unchanged in the remaining regions (control geos). Run the experiment for several weeks.

2. **Analyze the experiment with synthetic control**: Use a method like [CausalPy's multi-cell GeoLift](https://causalpy.readthedocs.io/en/latest/notebooks/multi_cell_geolift.html) to estimate the causal effect. Synthetic control constructs a weighted combination of control geos that matches each treated geo's pre-experiment trend, then measures the post-experiment divergence as the treatment effect.

3. **Extract lift test measurements**: From the CausalPy analysis, obtain for each treated geo:
   - `x`: the baseline spend level during the test period
   - `delta_x`: the incremental spend change applied
   - `delta_y`: the estimated causal lift in outcomes
   - `sigma`: the uncertainty (standard error) of the lift estimate

4. **Calibrate the MMM**: Pass these lift test measurements to the MMM's `add_lift_test_measurements()` method to constrain the saturation curve parameters during model fitting.
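
As a rough sketch of what step 3 hands to step 4, the measurements can be collected into a table with one row per treated geo. The column names follow the ones used later in this notebook; the numeric values are made-up placeholders, not results from any experiment.

```python
import pandas as pd

# Hypothetical placeholder values, for illustration only
lift_rows = [
    {"channel": "channel_1", "geo": "geo_00", "x": 0.40,
     "delta_x": 0.10, "delta_y": 0.020, "sigma": 0.02},
    {"channel": "channel_1", "geo": "geo_02", "x": 0.45,
     "delta_x": 0.10, "delta_y": 0.018, "sigma": 0.02},
]
example_lift_df = pd.DataFrame(lift_rows)

# Columns used for lift tests in this notebook
required = ["channel", "geo", "x", "delta_x", "delta_y", "sigma"]
missing = [col for col in required if col not in example_lift_df.columns]
if missing:
    raise ValueError(f"Lift test table is missing columns: {missing}")

print(example_lift_df)
```

Checking the columns before passing the table on is a cheap way to catch a mislabeled export from the experiment analysis.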

**Scope of this notebook**: This notebook focuses on step 4--demonstrating how lift test measurements improve MMM parameter recovery. We simulate the lift test results directly rather than running a full CausalPy analysis, since this is a PyMC-Marketing tutorial. For the experiment analysis step, see the [CausalPy documentation](https://causalpy.readthedocs.io/en/latest/notebooks/multi_cell_geolift.html).

### Why Lift Tests Work: Pinning Down the Saturation Curve

It's worth understanding *why* lift tests improve parameter estimation--the benefit goes beyond simply adding more variation to your data.

**The problem with observational data alone**

When fitting an MMM to observational data, you observe spend (X) and outcomes (Y) moving together. But correlation isn't causation:
- Maybe sales rise when spend rises because the spend *caused* more sales
- Or maybe both rose because of an external factor (seasonality, promotions, economic conditions)
- With highly correlated channels, the model can't tell which channel is actually driving sales--many different parameter combinations fit the data equally well
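
The last point can be illustrated with a simplified linear analogy (an assumption made for brevity; the MMM itself is nonlinear). The sketch below repeatedly fits ordinary least squares to two nearly collinear regressors: the individual coefficients swing widely from one simulated dataset to the next, while their sum is estimated tightly.

```python
import numpy as np

rng = np.random.default_rng(0)
n_weeks, n_sims = 104, 300
b1_true, b2_true = 0.6, 0.5

estimates = np.empty((n_sims, 2))
for s in range(n_sims):
    # Two "channels" driven by the same base pattern plus small independent noise
    base = rng.uniform(0.2, 1.0, size=n_weeks)
    x1 = base + rng.normal(0, 0.02, size=n_weeks)
    x2 = 0.95 * base + rng.normal(0, 0.02, size=n_weeks)
    y = b1_true * x1 + b2_true * x2 + rng.normal(0, 0.15, size=n_weeks)

    # Least-squares fit without intercept (linear stand-in for the MMM)
    X = np.column_stack([x1, x2])
    coef = np.linalg.lstsq(X, y, rcond=None)[0]
    estimates[s] = coef

sd_individual = estimates.std(axis=0)    # spread of each coefficient across simulations
sd_sum = estimates.sum(axis=1).std()     # spread of the combined effect

print(f"Std of b1, b2 estimates: {sd_individual.round(3)}")
print(f"Std of (b1 + b2):        {sd_sum:.3f}")
```

The data pin down the combined effect of the two channels but not how to split it between them, which is the gap an experiment on a single channel can fill.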

**What a lift test provides: causal ground truth**

A lift test is an *experiment*. By randomly assigning some geos to receive increased spend while others serve as controls, the difference in outcomes between groups is a **causal effect**, not just a correlation. The lift test tells you:

> "When we increased channel 1 spend from x to x + delta_x, the true causal lift in sales was delta_y"

This is a measured increment along the saturation curve, grounded in an experiment rather than inferred from correlational patterns. It still carries measurement uncertainty, which the lift test's sigma captures.

**How calibration uses this information**

When we add lift test measurements to the MMM, we're adding a constraint that says: the change in the model's saturation curve between x and x + delta_x must match the experimentally measured lift, up to the measurement uncertainty sigma. Mathematically:

$$\text{saturation}(x + \Delta x) - \text{saturation}(x) \approx \Delta y$$

This pins down the saturation curve at the operating point where the experiment was run, dramatically reducing the set of feasible (lambda, beta) parameter values.
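
A minimal NumPy sketch of this constraint, assuming the logistic saturation form used later in this notebook, `beta * (1 - exp(-lam * x)) / (1 + exp(-lam * x))`. The helper names and the candidate parameter values are illustrative only. It compares the lift implied by a few (lam, beta) pairs against the lift produced by the "true" pair.

```python
import numpy as np


def logistic_saturation_np(x, lam):
    """Logistic saturation on plain NumPy arrays."""
    return (1 - np.exp(-lam * x)) / (1 + np.exp(-lam * x))


def implied_lift(x, delta_x, lam, beta):
    """Change in contribution when spend moves from x to x + delta_x."""
    return beta * (
        logistic_saturation_np(x + delta_x, lam) - logistic_saturation_np(x, lam)
    )


x, delta_x = 0.5, 0.1
observed_lift = implied_lift(x, delta_x, lam=8.0, beta=0.6)  # stand-in for delta_y

# Illustrative candidates: same beta, different saturation speeds
candidates = {"true": (8.0, 0.6), "slower": (3.0, 0.6), "faster": (15.0, 0.6)}
mismatch = {
    name: implied_lift(x, delta_x, lam, beta) - observed_lift
    for name, (lam, beta) in candidates.items()
}

for name, gap in mismatch.items():
    print(f"{name:>6}: implied lift minus observed lift = {gap:+.4f}")
```

Parameter pairs whose implied lift differs from the measured one by much more than sigma become implausible once the lift test is included.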

**The key distinction**

| Aspect | More Spend Variation | Lift Test Calibration |
|--------|----------------------|----------------------|
| **Provides** | More data points | Causal ground truth |
| **Helps with** | Precision (narrower posteriors) | Accuracy (reduced bias) |
| **Addresses** | Uncertainty in parameters | Bias in parameters |

In short: more data variation helps you estimate parameters more precisely, but a lift test helps you estimate the *right* parameters.

### Key Design Principles

1. **Normalized data**: Channel spend is normalized to the [0, 1] range for consistent saturation behavior
2. **Model-based data generation**: We use `pm.do` and `pm.draw` to generate synthetic data from the model itself, ensuring the data perfectly matches model assumptions
3. **Consistent lift tests**: Lift test measurements are calculated using the same saturation function the model uses

## Prepare Notebook

In [None]:
from functools import partial

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import xarray as xr
from pymc_extras.prior import Prior

from pymc_marketing.mmm import GeometricAdstock, LogisticSaturation
from pymc_marketing.mmm.multidimensional import MMM
from pymc_marketing.mmm.transformers import logistic_saturation

# warnings.filterwarnings("ignore", category=UserWarning)

plt.style.use("bmh")
plt.rcParams["figure.figsize"] = (12, 7)
plt.rcParams["figure.dpi"] = 100
az.style.use("arviz-darkgrid")

In [None]:
# Set random seed for reproducibility
seed = sum(map(ord, "Geo lift tests for MMM calibration"))
rng = np.random.default_rng(seed)
print(f"Random seed: {seed}")

## Generate Synthetic Data

Our data generation follows a two-step process that ensures perfect consistency between the synthetic data and model assumptions:

### Step 1: Manually Simulate Channel Spend (X)

We manually create the **input data** — the marketing spend decisions:
- `channel_1`, `channel_2`: Normalized spend values in [0, 1] range
- `date`: Time dimension (weekly data)
- `geo`: Geographic dimension

This represents **business decisions** that are external to the model. We intentionally create highly correlated channels to simulate the identification problem that lift tests help solve.

### Step 2: Generate Outcomes from the Model (y)

We generate the **target variable** (sales/conversions) using `pm.do` and `pm.draw`:
1. Build the MMM with the spend data X
2. Use `pm.do` to fix model parameters to known "true" values
3. Use `pm.draw` to sample y from the model

This approach ensures that y is generated **exactly as the model expects**, including:
- Adstock transformations applied correctly
- Saturation curves computed consistently
- Noise structure matching the model's likelihood
- Internal scaling handled properly

**Why not generate y manually?** Manual generation (e.g., `y = intercept + beta * saturation(adstock(x)) + noise`) can introduce subtle mismatches with how the model actually computes predictions, leading to poor parameter recovery.
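
To make the moving parts visible, here is a manual NumPy sketch of the adstock-then-saturation pipeline. It is for illustration only and is not how this notebook generates data. The function names are hypothetical, and the adstock weights are left unnormalized, which is an assumption that may not match the model's configuration; details like this are where manual generation can drift from the model.

```python
import numpy as np


def geometric_adstock_np(x, alpha, l_max=8):
    """Carry spend over with weights alpha**k for lags k = 0..l_max-1 (unnormalized)."""
    weights = alpha ** np.arange(l_max)
    # Full convolution truncated to the input length, so output[t] only uses x[t - k]
    return np.convolve(x, weights)[: len(x)]


def logistic_saturation_np(x, lam):
    """Logistic saturation on plain NumPy arrays."""
    return (1 - np.exp(-lam * x)) / (1 + np.exp(-lam * x))


rng = np.random.default_rng(0)
spend = rng.uniform(0.2, 1.0, size=12)

# Contribution = beta * saturation(adstock(spend))
contribution = 0.6 * logistic_saturation_np(
    geometric_adstock_np(spend, alpha=0.5), lam=8.0
)

# A single unit of spend decays geometrically over the following weeks
impulse = np.zeros(10)
impulse[0] = 1.0
impulse_response = geometric_adstock_np(impulse, alpha=0.5, l_max=4)
print(impulse_response)
```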

---

### Create Normalized Channel Spend Data

Following the national-level notebook, we normalize spend data to the [0, 1] range. This ensures:
- Saturation parameters (lam) have intuitive values (e.g., 5-15)
- The model's internal scaling doesn't create mismatches
- Lift test measurements are consistent with model assumptions

In [None]:
# Define dimensions
n_dates = 104  # 2 years of weekly data
geos = [f"geo_{i:02d}" for i in range(8)]  # 8 geos for tractable computation
n_geos = len(geos)
channels = ["channel_1", "channel_2"]
n_channels = len(channels)

# Create date range
dates = pd.date_range(start="2022-01-03", periods=n_dates, freq="W-MON")

print("Data dimensions:")
print(f"  Dates: {n_dates} weeks ({dates[0].date()} to {dates[-1].date()})")
print(f"  Geos: {n_geos}")
print(f"  Channels: {n_channels}")

In [None]:
# Generate normalized spend data (0-1 range)
# Channel 1 and 2 are highly correlated (the identification problem)
rows = []

for geo in geos:
    # Generate base spend pattern (shared across channels for correlation)
    base_spend = pm.draw(
        pm.Uniform.dist(lower=0.2, upper=1.0, size=n_dates), random_seed=rng
    )
    base_spend = base_spend / base_spend.max()  # Normalize to max=1

    # Add geo-specific scaling
    geo_scale = 0.7 + 0.3 * (geos.index(geo) / (n_geos - 1))  # 0.7 to 1.0

    for i, date in enumerate(dates):
        # Channel 1: base pattern with small noise
        ch1 = base_spend[i] * geo_scale + rng.normal(0, 0.02)
        ch1 = np.clip(ch1, 0.1, 1.0)

        # Channel 2: highly correlated with channel 1 (shifted slightly)
        ch2 = base_spend[i] * geo_scale * 0.95 + rng.normal(0, 0.02)
        ch2 = np.clip(ch2, 0.1, 1.0)

        rows.append(
            {
                "date": date,
                "geo": geo,
                "channel_1": ch1,
                "channel_2": ch2,
            }
        )

df = pd.DataFrame(rows)

# Verify normalization
print("Channel spend ranges (should be ~0-1):")
for ch in channels:
    print(f"  {ch}: [{df[ch].min():.3f}, {df[ch].max():.3f}]")

# Check correlation
corr = df[["channel_1", "channel_2"]].corr().iloc[0, 1]
print(
    f"\nChannel correlation: {corr:.3f} (high correlation = identification challenge)"
)

In [None]:
# Visualize spend patterns
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Spend over time for sample geos
ax = axes[0]
sample_geos = [geos[0], geos[4], geos[7]]
for geo in sample_geos:
    geo_data = df[df["geo"] == geo]
    ax.plot(geo_data["date"], geo_data["channel_1"], label=f"{geo} - ch1", alpha=0.7)
ax.set_title("Channel 1 Spend Over Time (Sample Geos)")
ax.set_xlabel("Date")
ax.set_ylabel("Normalized Spend")
ax.legend()

# Plot 2: Channel correlation
ax = axes[1]
ax.scatter(df["channel_1"], df["channel_2"], alpha=0.3, s=10)
ax.set_title(f"Channel 1 vs Channel 2 (Correlation: {corr:.3f})")
ax.set_xlabel("Channel 1 Spend")
ax.set_ylabel("Channel 2 Spend")
ax.set_xlim(0, 1.1)
ax.set_ylim(0, 1.1)

plt.tight_layout()
plt.show()

### Define True Parameters

We define parameters that work well with normalized [0, 1] spend data:
- **Saturation lam**: Values around 5-15 (half-saturation at x = ln(3)/lam ≈ 0.07-0.22)
- **Saturation beta**: Contribution coefficients ~0.3-0.8
- **Adstock alpha**: Carryover effects ~0.3-0.7

Channel 1 and 2 have similar parameters, making them hard to separate without lift tests.
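
The half-saturation rule follows from setting the logistic saturation `(1 - exp(-lam * x)) / (1 + exp(-lam * x))` equal to 0.5, which gives `exp(-lam * x) = 1/3` and therefore `x = ln(3) / lam`. A short check using only the standard library (the helper names are illustrative):

```python
import math


def logistic_saturation_scalar(x, lam):
    """Logistic saturation for a single spend value."""
    return (1 - math.exp(-lam * x)) / (1 + math.exp(-lam * x))


def half_saturation_point(lam):
    """Spend level at which the saturation function reaches 0.5."""
    return math.log(3) / lam


for lam in (5.0, 6.0, 8.0, 15.0):
    x_half = half_saturation_point(lam)
    value = logistic_saturation_scalar(x_half, lam)
    print(f"lam={lam:>4}: half-saturation at x={x_half:.3f} (saturation={value:.3f})")
```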

In [None]:
# Define true parameters for data generation
# These will be used with pm.do to fix the model parameters

# Lam values appropriate for normalized [0,1] inputs
# Higher lam = saturates faster (half-saturation at ~ln(3)/lam)
true_lam_c1 = 8.0  # Half-saturation at ~0.14
true_lam_c2 = 6.0  # Half-saturation at ~0.18 (similar to c1)

# Beta values (contribution coefficients)
true_beta_c1 = 0.6
true_beta_c2 = 0.5  # Similar to c1 (hard to separate)

# Adstock alpha (carryover)
true_alpha_c1 = 0.5
true_alpha_c2 = 0.5

# Intercept per geo (base level)
true_intercept = np.array([0.3 + 0.05 * i for i in range(n_geos)])

# Create parameter arrays matching model structure
# Note: lam and beta have dims (geo, channel) in multidimensional MMM
true_lam = np.array([[true_lam_c1, true_lam_c2] for _ in range(n_geos)])
true_beta = np.array([[true_beta_c1, true_beta_c2] for _ in range(n_geos)])

true_params = {
    "adstock_alpha": np.array([true_alpha_c1, true_alpha_c2]),
    "saturation_lam": true_lam,
    "saturation_beta": true_beta,
    "intercept_contribution": true_intercept,
    "y_sigma": np.full(n_geos, 0.15),  # Noise level per geo
}

print("True parameters:")
print(f"  Adstock alpha: {true_params['adstock_alpha']}")
print(f"  Saturation lam (channel 1): {true_lam_c1}, (channel 2): {true_lam_c2}")
print(f"  Saturation beta (channel 1): {true_beta_c1}, (channel 2): {true_beta_c2}")
print(f"  Intercept range: [{true_intercept.min():.2f}, {true_intercept.max():.2f}]")

### Generate Target Variable Using the Model

Following the national-level notebook, we generate the target variable `y` directly from the model using `pm.do` to fix the true parameters. This ensures **perfect consistency** between data generation and model assumptions.

In [None]:
# Initialize placeholder y
df["y"] = np.ones(len(df))

# Prepare data for model
X = df[["date", "geo", "channel_1", "channel_2"]].copy()
y = df["y"]

print(f"Data shapes: X={X.shape}, y={y.shape}")

In [None]:
# Define priors appropriate for normalized [0,1] inputs
adstock_priors = {
    "alpha": Prior("Beta", alpha=2, beta=2, dims="channel"),  # Centered around 0.5
}

saturation_priors = {
    # Lam prior: Gamma with mean ~8, appropriate for normalized inputs
    # Half-saturation at ~ln(3)/8 ≈ 0.14 for mean value
    "lam": Prior("Gamma", alpha=4, beta=0.5, dims=("geo", "channel")),
    # Beta prior: moderate contribution
    "beta": Prior("HalfNormal", sigma=1, dims=("geo", "channel")),
}

print("Priors defined:")
print("  Adstock alpha: Beta(2, 2) - mean=0.5")
print("  Saturation lam: Gamma(4, 0.5) - mean=8")
print("  Saturation beta: HalfNormal(1) - mean≈0.8")

In [None]:
# Build a temporary model to generate data
mmm_temp = MMM(
    date_column="date",
    channel_columns=channels,
    adstock=GeometricAdstock(priors=adstock_priors, l_max=8),
    saturation=LogisticSaturation(priors=saturation_priors),
    dims=("geo",),
)

mmm_temp.build_model(X, y)
print("Temporary model built for data generation")
print(f"Model coords: {list(mmm_temp.model.coords.keys())}")

In [None]:
# Generate y from the model with fixed true parameters
# This ensures perfect consistency between data and model assumptions
fixed_model = pm.do(mmm_temp.model, true_params)
y_drawn = pm.draw(fixed_model["y"], random_seed=rng)

# y_drawn has shape (n_dates, n_geos) with dims ("date", "geo")
# Convert to DataFrame format matching our row order
y_xr = xr.DataArray(
    y_drawn,
    dims=["date", "geo"],
    coords={"date": dates, "geo": geos},
)
y_df = y_xr.to_dataframe(name="y").reset_index()

# Merge back to our DataFrame
df = df.drop(columns=["y"]).merge(y_df, on=["date", "geo"])

# Clean up temporary model
del mmm_temp.model

print("Target variable generated from model:")
print(f"  Shape: {y_drawn.shape} (date, geo)")
print(f"  Mean: {df['y'].mean():.3f}")
print(f"  Std: {df['y'].std():.3f}")
print(f"  Range: [{df['y'].min():.3f}, {df['y'].max():.3f}]")

In [None]:
# Visualize generated data
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: y over time for sample geos
ax = axes[0]
for geo in sample_geos:
    geo_data = df[df["geo"] == geo]
    ax.plot(geo_data["date"], geo_data["y"], label=geo, alpha=0.7)
ax.set_title("Target Variable (y) Over Time")
ax.set_xlabel("Date")
ax.set_ylabel("y")
ax.legend()

# Plot 2: Distribution of y by geo
ax = axes[1]
df.boxplot(column="y", by="geo", ax=ax)
ax.set_title("Distribution of y by Geo")
ax.set_xlabel("Geo")
ax.set_ylabel("y")
plt.suptitle("")

plt.tight_layout()
plt.show()

## Fit MMM Without Calibration

First, let's fit a standard MMM without lift test calibration to establish a baseline.

In [None]:
# Prepare data
X = df[["date", "geo", "channel_1", "channel_2"]].copy()
y = df["y"]

# Initialize MMM (same structure as data generation)
mmm_uncalibrated = MMM(
    date_column="date",
    channel_columns=channels,
    adstock=GeometricAdstock(priors=adstock_priors, l_max=8),
    saturation=LogisticSaturation(priors=saturation_priors),
    dims=("geo",),
)

mmm_uncalibrated.build_model(X, y)
print("Uncalibrated model built")

In [None]:
# Fit the model using nutpie for faster sampling
fit_kwargs = {
    "tune": 1000,
    "draws": 1000,
    "chains": 4,
    "random_seed": rng,
    "nuts_sampler": "nutpie",
}

idata_uncalibrated = mmm_uncalibrated.fit(X, y, **fit_kwargs)
print("\nUncalibrated model fitted")

In [None]:
# Visualize parameter estimates vs true values
posterior = idata_uncalibrated.posterior

# We'll focus on geo_idx=0 since the true parameters are identical across geos
geo_idx = 0

# Parameter names
param_names = [
    "lam (Ch1)",
    "lam (Ch2)",
    "beta (Ch1)",
    "beta (Ch2)",
    "alpha (Ch1)",
    "alpha (Ch2)",
]

# True values
true_values = [
    true_lam[0, 0],
    true_lam[0, 1],
    true_beta[0, 0],
    true_beta[0, 1],
    true_params["adstock_alpha"][0],
    true_params["adstock_alpha"][1],
]

# Extract posterior samples
lam_c1 = posterior["saturation_lam"][:, :, geo_idx, 0].values.flatten()
lam_c2 = posterior["saturation_lam"][:, :, geo_idx, 1].values.flatten()
beta_c1 = posterior["saturation_beta"][:, :, geo_idx, 0].values.flatten()
beta_c2 = posterior["saturation_beta"][:, :, geo_idx, 1].values.flatten()
alpha_c1 = posterior["adstock_alpha"][:, :, 0].values.flatten()
alpha_c2 = posterior["adstock_alpha"][:, :, 1].values.flatten()

samples = [lam_c1, lam_c2, beta_c1, beta_c2, alpha_c1, alpha_c2]

# Compute means and 94% HDI (az.hdi returns [lower, upper] for a 1-D array)
est_means = [np.mean(s) for s in samples]
hdis = [az.hdi(s, hdi_prob=0.94) for s in samples]
hdi_low = [h[0] for h in hdis]
hdi_high = [h[1] for h in hdis]

# Create figure
fig, ax = plt.subplots(figsize=(12, 6))

x_pos = np.arange(len(param_names))
bar_width = 0.35

# True values as bars
ax.bar(
    x_pos - bar_width / 2,
    true_values,
    bar_width,
    label="True",
    color="red",
    alpha=0.7,
)

# Estimated values as bars with error bars for HDI
errors = [
    [est_means[i] - hdi_low[i] for i in range(len(est_means))],
    [hdi_high[i] - est_means[i] for i in range(len(est_means))],
]
ax.bar(
    x_pos + bar_width / 2,
    est_means,
    bar_width,
    label="Uncalibrated (94% HDI)",
    color="C0",
    alpha=0.7,
    yerr=errors,
    capsize=5,
)

ax.set_xlabel("Parameter")
ax.set_ylabel("Value")
ax.set_title("Uncalibrated Model: Parameter Estimates vs True Values")
ax.set_xticks(x_pos)
ax.set_xticklabels(param_names)
ax.legend()
ax.grid(True, alpha=0.3, axis="y")

plt.tight_layout()
plt.show()

## Create Lift Test Measurements

Now we simulate a geo-level lift test on **channel 1** in selected treatment geos.

**Key principle**: The lift test uses the **same saturation function** that the model uses, ensuring consistency. Each measurement consists of:
- `x`: baseline (normalized) spend during test period
- `delta_x`: incremental spend change
- `delta_y`: lift = saturation(x + delta_x) - saturation(x), computed from the true curve without added noise
- `sigma`: the assumed measurement uncertainty of the lift estimate
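
Because the curve flattens as spend grows, the same spend increment produces a smaller lift at higher baselines. The sketch below, which assumes the true channel 1 parameters and the `delta_x` and `sigma` values used in this notebook, compares the lift to the measurement uncertainty at a few illustrative baseline levels.

```python
import numpy as np


def logistic_saturation_np(x, lam):
    """Logistic saturation on plain NumPy arrays."""
    return (1 - np.exp(-lam * x)) / (1 + np.exp(-lam * x))


lam, beta = 8.0, 0.6          # true channel 1 values from this notebook
delta_x, sigma = 0.1, 0.02    # spend increment and measurement uncertainty

baselines = np.array([0.1, 0.3, 0.5, 0.7])  # illustrative baseline spend levels
lifts = beta * (
    logistic_saturation_np(baselines + delta_x, lam)
    - logistic_saturation_np(baselines, lam)
)
signal_to_noise = lifts / sigma

for x0, lift, snr in zip(baselines, lifts, signal_to_noise):
    print(f"x={x0:.1f}: delta_y={lift:.4f}, delta_y / sigma={snr:.2f}")
```

A lift that is small relative to sigma constrains the curve only weakly, so where on the curve the experiment runs affects how informative it is.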

In [None]:
# Define the saturation function matching the model
def saturation_function(x, lam, beta):
    """Compute saturation contribution (same as model uses)."""
    return (beta * logistic_saturation(x, lam)).eval()


# Create partial functions for each channel with true parameters
c1_curve_fn = partial(saturation_function, lam=true_lam_c1, beta=true_beta_c1)
c2_curve_fn = partial(saturation_function, lam=true_lam_c2, beta=true_beta_c2)

# Visualize the true saturation curves
xx = np.linspace(0, 1.2, 100)
c1_curve = c1_curve_fn(xx)
c2_curve = c2_curve_fn(xx)

plt.figure(figsize=(10, 6))
plt.plot(
    xx,
    c1_curve,
    label=f"Channel 1 (lam={true_lam_c1}, beta={true_beta_c1})",
    linewidth=2,
)
plt.plot(
    xx,
    c2_curve,
    label=f"Channel 2 (lam={true_lam_c2}, beta={true_beta_c2})",
    linewidth=2,
)
plt.xlabel("Normalized Spend (x)")
plt.ylabel("Contribution")
plt.title("True Saturation Curves (what we want to recover)")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# Define lift test setup
treated_geos = [geos[i] for i in [0, 2, 4, 6]]  # 4 treated geos
control_geos = [geos[i] for i in [1, 3, 5, 7]]  # 4 control geos
test_channel = "channel_1"

print("Lift Test Setup:")
print(f"  Test channel: {test_channel}")
print(f"  Treated geos: {treated_geos}")
print(f"  Control geos: {control_geos}")

In [None]:
def create_lift_test(geo: str, x: float, delta_x: float, sigma: float) -> dict:
    """
    Create a lift test measurement using the true saturation curve.

    This directly uses the saturation function, ensuring consistency
    with what add_lift_test_measurements() expects.
    """
    delta_y = c1_curve_fn(x + delta_x) - c1_curve_fn(x)

    return {
        "channel": test_channel,
        "geo": geo,
        "x": x,
        "delta_x": delta_x,
        "delta_y": float(delta_y),
        "sigma": sigma,
    }


# Create lift tests at different points on the saturation curve
lift_test_results = []

for geo in treated_geos:
    # Get typical spend level for this geo (from data)
    geo_data = df[df["geo"] == geo]
    x_typical = geo_data[test_channel].mean()

    # Create lift test: increase spend by 0.1 (on 0-1 scale)
    delta_x = 0.1
    sigma = 0.02  # Measurement uncertainty

    lift_test = create_lift_test(geo, x_typical, delta_x, sigma)
    lift_test_results.append(lift_test)

    print(
        f"{geo}: x={x_typical:.3f}, delta_x={delta_x}, delta_y={lift_test['delta_y']:.4f}"
    )

df_lift_test = pd.DataFrame(lift_test_results)
print("\nLift Test DataFrame:")
df_lift_test

In [None]:
# Visualize lift tests on the saturation curve
fig, ax = plt.subplots(figsize=(10, 6))

# Plot true saturation curve
ax.plot(xx, c1_curve, "--", color="C0", linewidth=2, label="True saturation curve")

# Plot lift test triangles
for idx, row in df_lift_test.iterrows():
    x = row["x"]
    delta_x = row["delta_x"]
    delta_y = row["delta_y"]
    y_base = c1_curve_fn(x)

    # Draw triangle showing lift
    ax.plot([x, x + delta_x], [y_base, y_base], "k-", alpha=0.5)
    ax.plot([x + delta_x, x + delta_x], [y_base, y_base + delta_y], "k-", alpha=0.5)
    ax.scatter([x], [y_base], color="C0", s=80, zorder=5)
    ax.scatter(
        [x + delta_x],
        [y_base + delta_y],
        color="C2",
        s=80,
        zorder=5,
        marker="^",
        # Label only the first marker so the legend shows a single entry
        label="Spend after lift (x + delta_x)" if idx == 0 else "",
    )

ax.set_xlabel("Normalized Spend (x)")
ax.set_ylabel("Contribution")
ax.set_title(f"Lift Tests on {test_channel} Saturation Curve")
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## Fit MMM With Lift Test Calibration

Now we fit a new MMM and add the lift test measurements to calibrate it.

In [None]:
# Initialize calibrated MMM with same priors
mmm_calibrated = MMM(
    date_column="date",
    channel_columns=channels,
    adstock=GeometricAdstock(priors=adstock_priors, l_max=8),
    saturation=LogisticSaturation(priors=saturation_priors),
    dims=("geo",),
)

mmm_calibrated.build_model(X, y)
print("Calibrated model built")

In [None]:
# Add lift test measurements
mmm_calibrated.add_lift_test_measurements(df_lift_test)
print(f"Added {len(df_lift_test)} lift test measurements")
print(f"Lift tests cover geos: {df_lift_test['geo'].unique().tolist()}")

In [None]:
# Fit the calibrated model
idata_calibrated = mmm_calibrated.fit(X, y, **fit_kwargs)
print("\nCalibrated model fitted")

## Compare Results: Calibrated vs Uncalibrated

Let's compare parameter recovery between the two models.
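
Two summaries are useful for this comparison: bias (posterior mean minus true value) and interval width. The sketch below uses synthetic gamma-distributed draws as a stand-in for a posterior (an assumption, not output from either model). It also shows that a highest-density interval and an equal-tailed percentile interval differ for skewed distributions, so it helps to report one consistently. `hdi_1d` is a hypothetical helper written for this sketch.

```python
import numpy as np


def hdi_1d(samples, prob=0.94):
    """Narrowest interval containing roughly `prob` of the sorted samples."""
    s = np.sort(np.asarray(samples))
    n = len(s)
    k = int(np.floor(prob * n))      # index offset spanned by each candidate window
    widths = s[k:] - s[: n - k]      # width of every window of that size
    i = int(np.argmin(widths))
    return s[i], s[i + k]


def equal_tailed(samples, prob=0.94):
    """Interval that cuts the same probability mass from each tail."""
    lower, upper = np.percentile(
        samples, [(1 - prob) / 2 * 100, (1 + prob) / 2 * 100]
    )
    return lower, upper


rng = np.random.default_rng(0)
fake_posterior = rng.gamma(shape=4.0, scale=2.0, size=20_000)  # skewed stand-in
true_value = 8.0

bias = fake_posterior.mean() - true_value
hdi_low, hdi_high = hdi_1d(fake_posterior)
et_low, et_high = equal_tailed(fake_posterior)

print(f"Bias: {bias:+.3f}")
print(f"94% HDI:          [{hdi_low:.2f}, {hdi_high:.2f}]")
print(f"94% equal-tailed: [{et_low:.2f}, {et_high:.2f}]")
```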

In [None]:
# Extract posteriors for comparison
posterior_uncal = idata_uncalibrated.posterior
posterior_cal = idata_calibrated.posterior

# Plot posterior comparison
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Focus on a treated geo
focus_geo_idx = 0  # geo_00
focus_geo = geos[focus_geo_idx]

# Saturation lam - Channel 1
ax = axes[0, 0]
samples_uncal = posterior_uncal["saturation_lam"][
    :, :, focus_geo_idx, 0
].values.flatten()
samples_cal = posterior_cal["saturation_lam"][:, :, focus_geo_idx, 0].values.flatten()
ax.hist(samples_uncal, bins=50, alpha=0.5, label="Uncalibrated", color="C0")
ax.hist(samples_cal, bins=50, alpha=0.5, label="Calibrated", color="C2")
ax.axvline(true_lam_c1, color="red", linestyle="--", linewidth=2, label="True value")
ax.set_title(f"Saturation lam - Channel 1 ({focus_geo})")
ax.set_xlabel("lam")
ax.legend()

# Saturation beta - Channel 1
ax = axes[0, 1]
samples_uncal = posterior_uncal["saturation_beta"][
    :, :, focus_geo_idx, 0
].values.flatten()
samples_cal = posterior_cal["saturation_beta"][:, :, focus_geo_idx, 0].values.flatten()
ax.hist(samples_uncal, bins=50, alpha=0.5, label="Uncalibrated", color="C0")
ax.hist(samples_cal, bins=50, alpha=0.5, label="Calibrated", color="C2")
ax.axvline(true_beta_c1, color="red", linestyle="--", linewidth=2, label="True value")
ax.set_title(f"Saturation beta - Channel 1 ({focus_geo})")
ax.set_xlabel("beta")
ax.legend()

# Saturation lam - Channel 2 (not lift-tested, so expect a smaller change)
ax = axes[1, 0]
samples_uncal = posterior_uncal["saturation_lam"][
    :, :, focus_geo_idx, 1
].values.flatten()
samples_cal = posterior_cal["saturation_lam"][:, :, focus_geo_idx, 1].values.flatten()
ax.hist(samples_uncal, bins=50, alpha=0.5, label="Uncalibrated", color="C0")
ax.hist(samples_cal, bins=50, alpha=0.5, label="Calibrated", color="C2")
ax.axvline(true_lam_c2, color="red", linestyle="--", linewidth=2, label="True value")
ax.set_title(f"Saturation lam - Channel 2 ({focus_geo})")
ax.set_xlabel("lam")
ax.legend()

# Saturation beta - Channel 2
ax = axes[1, 1]
samples_uncal = posterior_uncal["saturation_beta"][
    :, :, focus_geo_idx, 1
].values.flatten()
samples_cal = posterior_cal["saturation_beta"][:, :, focus_geo_idx, 1].values.flatten()
ax.hist(samples_uncal, bins=50, alpha=0.5, label="Uncalibrated", color="C0")
ax.hist(samples_cal, bins=50, alpha=0.5, label="Calibrated", color="C2")
ax.axvline(true_beta_c2, color="red", linestyle="--", linewidth=2, label="True value")
ax.set_title(f"Saturation beta - Channel 2 ({focus_geo})")
ax.set_xlabel("beta")
ax.legend()

plt.suptitle("Posterior Distributions: Uncalibrated vs Calibrated", fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# Visualize HDI width comparison (narrower = better precision)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

params_to_check = [
    ("saturation_lam", 0, "lam (Ch1)"),
    ("saturation_beta", 0, "beta (Ch1)"),
    ("saturation_lam", 1, "lam (Ch2)"),
    ("saturation_beta", 1, "beta (Ch2)"),
]

hdi_widths_uncal = []
hdi_widths_cal = []
improvements = []
param_labels = []

for param_name, ch_idx, display_name in params_to_check:
    # Pool posterior draws from all geos into a single sample
    samples_uncal = posterior_uncal[param_name][:, :, :, ch_idx].values.flatten()
    samples_cal = posterior_cal[param_name][:, :, :, ch_idx].values.flatten()

    hdi_uncal = az.hdi(samples_uncal, hdi_prob=0.94)
    hdi_cal = az.hdi(samples_cal, hdi_prob=0.94)

    width_uncal = hdi_uncal[1] - hdi_uncal[0]
    width_cal = hdi_cal[1] - hdi_cal[0]
    improvement = (width_uncal - width_cal) / width_uncal * 100

    hdi_widths_uncal.append(width_uncal)
    hdi_widths_cal.append(width_cal)
    improvements.append(improvement)
    param_labels.append(display_name)

# Plot HDI widths
x_pos = np.arange(len(param_labels))
bar_width = 0.35

ax = axes[0]
ax.bar(
    x_pos - bar_width / 2,
    hdi_widths_uncal,
    bar_width,
    label="Uncalibrated",
    color="C0",
    alpha=0.8,
)
ax.bar(
    x_pos + bar_width / 2,
    hdi_widths_cal,
    bar_width,
    label="Calibrated",
    color="C2",
    alpha=0.8,
)
ax.set_xlabel("Parameter")
ax.set_ylabel("94% HDI Width")
ax.set_title("Posterior Uncertainty (Narrower = More Precise)")
ax.set_xticks(x_pos)
ax.set_xticklabels(param_labels)
ax.legend()
ax.grid(True, alpha=0.3, axis="y")

# Plot improvement percentage
ax = axes[1]
colors = ["C2" if imp > 0 else "C3" for imp in improvements]
ax.bar(x_pos, improvements, color=colors, alpha=0.8)
ax.axhline(y=0, color="black", linestyle="-", linewidth=0.5)
ax.set_xlabel("Parameter")
ax.set_ylabel("HDI Width Reduction (%)")
ax.set_title("Precision Improvement from Calibration")
ax.set_xticks(x_pos)
ax.set_xticklabels(param_labels)
ax.grid(True, alpha=0.3, axis="y")

# Add value labels on bars
for i, imp in enumerate(improvements):
    va = "bottom" if imp >= 0 else "top"
    ax.text(i, imp, f"{imp:.1f}%", ha="center", va=va, fontsize=10)

plt.tight_layout()
plt.show()

### Saturation Curve Recovery

A key benefit of lift test calibration is better recovery of the saturation curves. Let's compare the true curves with the inferred curves from both models.

In [None]:
# Plot true vs inferred saturation curves with HDI bands
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# x values for plotting curves
x_plot = np.linspace(0, 1.2, 100)

# Get posterior samples for saturation parameters (focus on first geo)
geo_idx = 0


def compute_curve_hdi(lam_samples, beta_samples, x_vals, hdi_prob=0.94):
    """Compute the posterior mean curve and pointwise HDI bands."""
    # Evaluate every posterior draw on the grid: shape (n_samples, len(x_vals))
    decay = np.exp(-lam_samples[:, None] * x_vals[None, :])
    curves = beta_samples[:, None] * (1 - decay) / (1 + decay)
    mean_curve = curves.mean(axis=0)
    # Pointwise HDI: one 1-D interval per grid point
    bounds = np.array(
        [az.hdi(curves[:, j], hdi_prob=hdi_prob) for j in range(len(x_vals))]
    )
    return mean_curve, bounds[:, 0], bounds[:, 1]


for col, (ch_idx, ch_name) in enumerate([(0, "channel_1"), (1, "channel_2")]):
    # True parameters (scalar names avoid overwriting the true_lam / true_beta arrays)
    true_lam_ch = true_params["saturation_lam"][geo_idx, ch_idx]
    true_beta_ch = true_params["saturation_beta"][geo_idx, ch_idx]
    true_curve = (
        true_beta_ch
        * (1 - np.exp(-true_lam_ch * x_plot))
        / (1 + np.exp(-true_lam_ch * x_plot))
    )

    # Uncalibrated model
    ax = axes[0, col]
    lam_samples = posterior_uncal["saturation_lam"][
        :, :, geo_idx, ch_idx
    ].values.flatten()
    beta_samples = posterior_uncal["saturation_beta"][
        :, :, geo_idx, ch_idx
    ].values.flatten()

    # Compute HDI bands
    mean_uncal, lower_uncal, upper_uncal = compute_curve_hdi(
        lam_samples, beta_samples, x_plot
    )

    ax.fill_between(
        x_plot, lower_uncal, upper_uncal, alpha=0.3, color="C0", label="94% HDI"
    )
    ax.plot(x_plot, mean_uncal, color="C0", linewidth=2, label="Posterior mean")
    ax.plot(x_plot, true_curve, "r--", linewidth=2, label="True curve")
    ax.set_title(f"Uncalibrated - {ch_name}")
    ax.set_xlabel("Normalized Spend")
    ax.set_ylabel("Contribution")
    ax.legend()
    ax.set_ylim(0, 1.0)

    # Calibrated model
    ax = axes[1, col]
    lam_samples = posterior_cal["saturation_lam"][
        :, :, geo_idx, ch_idx
    ].values.flatten()
    beta_samples = posterior_cal["saturation_beta"][
        :, :, geo_idx, ch_idx
    ].values.flatten()

    # Compute HDI bands
    mean_cal, lower_cal, upper_cal = compute_curve_hdi(
        lam_samples, beta_samples, x_plot
    )

    ax.fill_between(
        x_plot, lower_cal, upper_cal, alpha=0.3, color="C2", label="94% HDI"
    )
    ax.plot(x_plot, mean_cal, color="C2", linewidth=2, label="Posterior mean")
    ax.plot(x_plot, true_curve, "r--", linewidth=2, label="True curve")
    ax.set_title(f"Calibrated - {ch_name}")
    ax.set_xlabel("Normalized Spend")
    ax.set_ylabel("Contribution")
    ax.legend()
    ax.set_ylim(0, 1.0)

plt.suptitle(
    f"Saturation Curve Recovery ({geos[geo_idx]})\nShaded: 94% HDI, Dashed: True curve",
    fontsize=12,
)
plt.tight_layout()
plt.show()

### Saturation Parameter Comparison

To see the effect of calibration more clearly, let's compare the true, uncalibrated, and calibrated estimates for the saturation parameters side by side.

In [None]:
# Bar chart comparing true, uncalibrated, and calibrated saturation parameters
# Split into two panels: lambda (left) and beta (right) since they have different scales
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Focus on geo_idx = 0 for clarity
geo_idx = 0
bar_width = 0.25
channel_labels = ["Channel 1", "Channel 2"]
x_pos = np.arange(len(channel_labels))


def get_mean_and_hdi(samples, hdi_prob=0.94):
    """Extract mean and HDI bounds from posterior samples."""
    flat = samples.values.flatten()
    mean = np.mean(flat)
    hdi = az.hdi(flat, hdi_prob=hdi_prob)
    return mean, hdi[0], hdi[1]


# --- Lambda parameters (left panel) ---
ax = axes[0]

# True values
true_lam = [
    true_params["saturation_lam"][geo_idx, 0],
    true_params["saturation_lam"][geo_idx, 1],
]

# Uncalibrated: mean and HDI
uncal_lam_stats = [
    get_mean_and_hdi(posterior_uncal["saturation_lam"][:, :, geo_idx, i])
    for i in range(2)
]
uncal_lam_means = [s[0] for s in uncal_lam_stats]
uncal_lam_err = [
    [s[0] - s[1] for s in uncal_lam_stats],  # lower error
    [s[2] - s[0] for s in uncal_lam_stats],  # upper error
]

# Calibrated: mean and HDI
cal_lam_stats = [
    get_mean_and_hdi(posterior_cal["saturation_lam"][:, :, geo_idx, i])
    for i in range(2)
]
cal_lam_means = [s[0] for s in cal_lam_stats]
cal_lam_err = [
    [s[0] - s[1] for s in cal_lam_stats],
    [s[2] - s[0] for s in cal_lam_stats],
]

# Plot bars with error bars
ax.bar(x_pos - bar_width, true_lam, bar_width, label="True", color="red", alpha=0.8)
ax.bar(
    x_pos,
    uncal_lam_means,
    bar_width,
    label="Uncalibrated",
    color="C0",
    alpha=0.8,
    yerr=uncal_lam_err,
    capsize=5,
)
ax.bar(
    x_pos + bar_width,
    cal_lam_means,
    bar_width,
    label="Calibrated",
    color="C2",
    alpha=0.8,
    yerr=cal_lam_err,
    capsize=5,
)
ax.set_xlabel("Channel")
ax.set_ylabel("Lambda (λ)")
ax.set_title(f"Saturation Lambda Estimates ({geos[geo_idx]})")
ax.set_xticks(x_pos)
ax.set_xticklabels(channel_labels)
ax.legend()
ax.grid(True, alpha=0.3, axis="y")

# --- Beta parameters (right panel) ---
ax = axes[1]

# True values
true_beta = [
    true_params["saturation_beta"][geo_idx, 0],
    true_params["saturation_beta"][geo_idx, 1],
]

# Uncalibrated: mean and HDI
uncal_beta_stats = [
    get_mean_and_hdi(posterior_uncal["saturation_beta"][:, :, geo_idx, i])
    for i in range(2)
]
uncal_beta_means = [s[0] for s in uncal_beta_stats]
uncal_beta_err = [
    [s[0] - s[1] for s in uncal_beta_stats],
    [s[2] - s[0] for s in uncal_beta_stats],
]

# Calibrated: mean and HDI
cal_beta_stats = [
    get_mean_and_hdi(posterior_cal["saturation_beta"][:, :, geo_idx, i])
    for i in range(2)
]
cal_beta_means = [s[0] for s in cal_beta_stats]
cal_beta_err = [
    [s[0] - s[1] for s in cal_beta_stats],
    [s[2] - s[0] for s in cal_beta_stats],
]

# Plot bars with error bars
ax.bar(x_pos - bar_width, true_beta, bar_width, label="True", color="red", alpha=0.8)
ax.bar(
    x_pos,
    uncal_beta_means,
    bar_width,
    label="Uncalibrated",
    color="C0",
    alpha=0.8,
    yerr=uncal_beta_err,
    capsize=5,
)
ax.bar(
    x_pos + bar_width,
    cal_beta_means,
    bar_width,
    label="Calibrated",
    color="C2",
    alpha=0.8,
    yerr=cal_beta_err,
    capsize=5,
)
ax.set_xlabel("Channel")
ax.set_ylabel("Beta (β)")
ax.set_title(f"Saturation Beta Estimates ({geos[geo_idx]})")
ax.set_xticks(x_pos)
ax.set_xticklabels(channel_labels)
ax.legend()
ax.grid(True, alpha=0.3, axis="y")

plt.suptitle("Saturation Parameter Recovery: True vs Estimated (94% HDI)", fontsize=12)
plt.tight_layout()
plt.show()
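
The bar chart can be complemented by a numeric summary of bias and interval width. The following is a minimal sketch, not part of the original analysis: `summarize_recovery` is a hypothetical helper, the samples are synthetic stand-ins rather than the notebook's posteriors, and the interval is an equal-tailed percentile interval rather than a true HDI.

```python
import numpy as np


def summarize_recovery(samples, true_value, prob=0.94):
    """Return bias, interval width, and coverage for one parameter.

    Hypothetical helper: uses an equal-tailed percentile interval,
    which only approximates the HDI for roughly symmetric posteriors.
    """
    samples = np.asarray(samples, dtype=float).ravel()
    lower, upper = np.percentile(
        samples, [(1 - prob) / 2 * 100, (1 + prob) / 2 * 100]
    )
    return {
        "bias": float(samples.mean() - true_value),
        "interval_width": float(upper - lower),
        "covers_truth": bool(lower <= true_value <= upper),
    }


# Synthetic stand-ins (assumed values, for illustration only)
rng_demo = np.random.default_rng(0)
true_value = 0.5
wide_biased = rng_demo.normal(loc=0.8, scale=0.3, size=4000)
narrow_centered = rng_demo.normal(loc=0.52, scale=0.05, size=4000)

uncal_summary = summarize_recovery(wide_biased, true_value)
cal_summary = summarize_recovery(narrow_centered, true_value)
print("wide/biased:    ", uncal_summary)
print("narrow/centered:", cal_summary)
```

With the notebook's objects, the same helper could be applied to flattened draws such as `posterior_cal["saturation_lam"][:, :, geo_idx, 0].values` together with the matching entry of `true_params`.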

In [None]:
# Generate posterior predictive and compare
mmm_uncalibrated.sample_posterior_predictive(X, extend_idata=True, random_seed=rng)
mmm_calibrated.sample_posterior_predictive(X, extend_idata=True, random_seed=rng)

print("Posterior predictive samples generated")

In [None]:
# Plot predicted vs actual for a sample geo
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Get predictions (in scaled space) and scale back to original space
# The model fits to scaled y, so predictions need to be multiplied by target_scale
target_scale_uncal = mmm_uncalibrated.get_scales_as_xarray()["target_scale"]
target_scale_cal = mmm_calibrated.get_scales_as_xarray()["target_scale"]

y_pred_uncal = (
    mmm_uncalibrated.idata.posterior_predictive["y"].mean(dim=["chain", "draw"])
    * target_scale_uncal
)
y_pred_cal = (
    mmm_calibrated.idata.posterior_predictive["y"].mean(dim=["chain", "draw"])
    * target_scale_cal
)

# Plot for focus geo
geo_mask = df["geo"] == focus_geo
geo_dates = df.loc[geo_mask, "date"]
y_actual = df.loc[geo_mask, "y"].values

# Uncalibrated
ax = axes[0]
y_pred_geo = y_pred_uncal.sel(geo=focus_geo).values
ax.plot(geo_dates, y_actual, "k-", label="Actual", linewidth=2)
ax.plot(geo_dates, y_pred_geo, "--", color="C0", label="Predicted", linewidth=2)
ax.set_title(f"Uncalibrated Model - {focus_geo}")
ax.set_xlabel("Date")
ax.set_ylabel("y")
ax.legend()

# Calibrated
ax = axes[1]
y_pred_geo = y_pred_cal.sel(geo=focus_geo).values
ax.plot(geo_dates, y_actual, "k-", label="Actual", linewidth=2)
ax.plot(geo_dates, y_pred_geo, "--", color="C2", label="Predicted", linewidth=2)
ax.set_title(f"Calibrated Model - {focus_geo}")
ax.set_xlabel("Date")
ax.set_ylabel("y")
ax.legend()

plt.suptitle("Predicted vs Actual (Original Scale)")
plt.tight_layout()
plt.show()

## Conclusion

This notebook demonstrated how to calibrate a multidimensional MMM using geo-level lift tests:

1. **Key Design Principles**:
   - Normalize channel data to [0, 1] range for consistent saturation behavior
   - Generate data from the model itself using `pm.do`/`pm.draw` to ensure perfect consistency
   - Calculate lift tests using the same saturation function the model uses

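2. **The Problem**: Without lift tests, the two highly correlated channels (correlation ~0.99) are hard to separate, so the MMM alone cannot reliably attribute effects to each channel

3. **The Solution**: Lift test measurements give direct causal evidence on channel 1, which constrains its saturation curve parameters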

4. **Results**: Calibrated models show:
   - Narrower posterior distributions (higher precision)
   - Parameter estimates closer to true values (better accuracy)
   - Better separation between channel effects
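
To make the constraint concrete, here is a minimal sketch of the lift implied by a saturation curve. It uses the same logistic form as the true curves plotted earlier; `expected_lift` and the numeric values are hypothetical and chosen only for illustration.

```python
import numpy as np


def logistic_saturation(x, lam, beta):
    """Logistic saturation, matching the form used for the true curves."""
    return beta * (1 - np.exp(-lam * x)) / (1 + np.exp(-lam * x))


def expected_lift(x, delta_x, lam, beta):
    """Hypothetical helper: change in contribution when spend moves from x to x + delta_x."""
    return logistic_saturation(x + delta_x, lam, beta) - logistic_saturation(
        x, lam, beta
    )


# Illustrative (assumed) baseline spend and spend increase, in normalized units
x, delta_x = 0.5, 0.2

# Two candidate parameter sets imply different lifts for the same experiment
lift_a = expected_lift(x, delta_x, lam=2.0, beta=0.6)
lift_b = expected_lift(x, delta_x, lam=4.0, beta=0.6)
print(f"lam=2.0 -> lift {lift_a:.4f}")
print(f"lam=4.0 -> lift {lift_b:.4f}")
```

Because each parameter combination implies a specific lift, an observed `delta_y` with uncertainty `sigma` makes some combinations far more plausible than others. That is the sense in which a lift test narrows the posterior.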

### Practical Application

In practice:
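1. **Conduct geo-level experiments**: change spend on one channel in the treated geos and keep it unchanged in the control geos
2. **Analyze with CausalPy** (for example, using synthetic control methods) to get lift estimates (`delta_y`, `sigma`)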
3. **Format as DataFrame** with columns: `[channel, geo, x, delta_x, delta_y, sigma]`
4. **Add to MMM**: `mmm.add_lift_test_measurements(df_lift_test)`
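
As a rough sketch of step 3, the lift test table could be assembled as below. The geo names and all numeric values are invented placeholders, not results from this notebook. The final call is left commented out because it needs a built MMM instance.

```python
import pandas as pd

# Placeholder lift test results: one row per treated geo (values are made up)
df_lift_test = pd.DataFrame(
    {
        "channel": ["channel_1", "channel_1"],
        "geo": ["geo_a", "geo_b"],  # hypothetical geo names
        "x": [0.50, 0.40],  # baseline spend (normalized)
        "delta_x": [0.10, 0.10],  # spend change during the test
        "delta_y": [0.05, 0.04],  # estimated lift in the outcome
        "sigma": [0.01, 0.01],  # uncertainty of the lift estimate
    }
)

# Basic sanity checks before handing the table to the model
required_columns = ["channel", "geo", "x", "delta_x", "delta_y", "sigma"]
missing = [c for c in required_columns if c not in df_lift_test.columns]
if missing:
    raise ValueError(f"Lift test table is missing columns: {missing}")
if (df_lift_test["sigma"] <= 0).any():
    raise ValueError("sigma must be strictly positive")

# mmm.add_lift_test_measurements(df_lift_test)  # requires a built model
print(df_lift_test)
```

Checking the columns up front is a design choice in this sketch: it surfaces formatting problems before the model is touched.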

### References

- [CausalPy Multi-Cell GeoLift](https://causalpy.readthedocs.io/en/latest/notebooks/multi_cell_geolift.html)
- [PyMC-Marketing National-Level Lift Tests](mmm_lift_test.ipynb)
- [PyMC-Marketing Multidimensional MMM](mmm_multidimensional_example.ipynb)