# Interactive MMM Visualizations with `plot_interactive`

This notebook demonstrates the **interactive plotting capabilities** of the
`plot_interactive` module in `pymc-marketing`. These Plotly-based visualizations
are designed for exploring MMM results interactively — hovering over data points,
zooming into time ranges, and faceting across custom dimensions like markets.

We use the **multidimensional MMM** example data (two geos: `geo_a` and `geo_b`,
two channels: `x1` and `x2`) and fit the model using the same setup as the
[MMM Multidimensional Example Notebook](https://www.pymc-marketing.io/en/stable/notebooks/mmm/mmm_multidimensional_example.html).

**What this notebook covers:**

- Posterior predictive checks (actual vs predicted)
- ROAS analysis at different time granularities
- Channel contributions over time
- Saturation curves (diminishing returns)
- Adstock curves (carryover effects)
- Auto-faceting across custom dimensions
- Advanced: Filtering and aggregating data with `MMMSummaryFactory` and `MMMPlotlyFactory`

In [None]:
import warnings

import numpy as np
import pandas as pd
import plotly.io as pio
from pymc_extras.prior import Prior

from pymc_marketing.mmm import GeometricAdstock, LogisticSaturation
from pymc_marketing.mmm.multidimensional import MMM
from pymc_marketing.paths import data_dir
from pymc_marketing.special_priors import LaplacePrior, LogNormalPrior

warnings.filterwarnings("ignore", category=UserWarning)

pio.renderers.default = "notebook_connected"

seed: int = sum(map(ord, "plot_interactive"))
rng: np.random.Generator = np.random.default_rng(seed=seed)

## Load Data & Fit Model

We load the same simulated multidimensional dataset used in the
[MMM Multidimensional Example](https://www.pymc-marketing.io/en/stable/notebooks/mmm/mmm_multidimensional_example.html).
The data has two geographies (`geo_a`, `geo_b`) and two channels (`x1`, `x2`).

For model fitting details (prior specification, adstock/saturation choices,
pooling strategies), please refer to the original notebook. Here we set up
and fit the model quickly so we can focus on **interactive visualizations**.

In [None]:
data_path = data_dir / "mmm_multidimensional_example.csv"
data_df = pd.read_csv(data_path, parse_dates=["date"])
data_df.head()

In [None]:
# --- Prior Specification ---
# Hierarchical beta (partially pooled across geos)
beta_prior = LogNormalPrior(
    mean=Prior("Gamma", mu=0.25, sigma=0.10, dims="channel"),
    std=Prior("Exponential", scale=0.10, dims="channel"),
    dims=("channel", "geo"),
    centered=False,
)

# Saturation: beta is hierarchical, lambda is fully pooled
saturation = LogisticSaturation(
    priors={
        "beta": beta_prior,
        "lam": Prior("Gamma", mu=0.5, sigma=0.25, dims="channel"),
    }
)

# Adstock: unpooled (each geo x channel has its own alpha)
adstock = GeometricAdstock(
    priors={"alpha": Prior("Beta", alpha=2, beta=5, dims=("geo", "channel"))},
    l_max=8,
)

# Model config
model_config = {
    "intercept": Prior("Gamma", mu=0.5, sigma=0.25, dims="geo"),
    "gamma_control": Prior("Normal", mu=0, sigma=0.5, dims="control"),
    "gamma_fourier": LaplacePrior(
        mu=0,
        b=Prior("HalfNormal", sigma=0.2),
        dims=("geo", "fourier_mode"),
        centered=False,
    ),
    "likelihood": Prior(
        "TruncatedNormal",
        lower=0,
        sigma=Prior("HalfNormal", sigma=1.5),
        dims=("date", "geo"),
    ),
}

# --- Model Definition ---
mmm = MMM(
    date_column="date",
    target_column="y",
    channel_columns=["x1", "x2"],
    control_columns=["event_1", "event_2"],
    dims=("geo",),
    scaling={
        "channel": {"method": "max", "dims": ()},
        "target": {"method": "max", "dims": ()},
    },
    adstock=adstock,
    saturation=saturation,
    yearly_seasonality=2,
    model_config=model_config,
)

# --- Fit ---
x_train = data_df.drop(columns=["y"])
y_train = data_df["y"]

mmm.fit(
    X=x_train,
    y=y_train,
    chains=4,
    target_accept=0.95,
    random_seed=rng,
)

# --- Add original scale deterministic variables ---
mmm.build_model(X=x_train, y=y_train)
mmm.add_original_scale_contribution_variable(
    var=[
        "channel_contribution",
        "control_contribution",
        "intercept_contribution",
        "yearly_seasonality_contribution",
        "y",
    ]
)

# --- Posterior predictive ---
mmm.sample_posterior_predictive(X=x_train, random_seed=rng)

---

## 1. Posterior Predictive: How Well Does the Model Fit?

The posterior predictive plot shows the model's predictions (with uncertainty)
against the observed data. This is the first thing to check after fitting —
do the predictions track the actual sales?

The interactive plot lets you **hover** over any point to see exact values,
and **zoom** into specific time periods.
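A simple numeric companion to this visual check is the share of observed points that fall inside the predicted band. The sketch below assumes you already have the observed series and the band's lower and upper edges as NumPy arrays. The numbers are made up, and the helper `interval_coverage` is hypothetical; it is not part of `plot_interactive`.

```python
import numpy as np


def interval_coverage(observed, lower, upper):
    """Fraction of observations that fall inside [lower, upper] (inclusive)."""
    observed = np.asarray(observed, dtype=float)
    lower = np.asarray(lower, dtype=float)
    upper = np.asarray(upper, dtype=float)
    inside = (observed >= lower) & (observed <= upper)
    return float(inside.mean())


# Toy values for illustration only (not taken from the fitted model).
observed = np.array([10.0, 12.0, 9.0, 15.0, 11.0])
lower = np.array([8.0, 10.0, 9.5, 12.0, 9.0])
upper = np.array([12.0, 14.0, 11.0, 16.0, 13.0])

print(interval_coverage(observed, lower, upper))  # 0.8
```

If the coverage is well below the nominal interval probability, that hints the bands are too narrow. Coverage close to 1.0 with very wide bands can be just as uninformative.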

In [None]:
mmm.plot_interactive.posterior_predictive()

---

## 2. ROAS Analysis: Which Channels Give the Best Return?

ROAS (Return on Ad Spend), the contribution a channel generates per unit of spend, is one of the most important metrics for marketers.
The `plot_interactive` module makes it easy to slice and dice ROAS across
different time granularities and dimensions.
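To make clear what is being aggregated: for a given channel and period, ROAS is the channel's total contribution divided by its total spend over that period. The pandas sketch below uses a toy weekly table with hypothetical column names (`channel`, `spend`, `contribution`). The library works from posterior samples, so treat this only as a point-estimate illustration of the ratio at two granularities.

```python
import pandas as pd

# Toy weekly data (hypothetical numbers and column names, for illustration only).
df = pd.DataFrame(
    {
        "date": pd.to_datetime(
            ["2023-06-05", "2023-06-12", "2024-06-03", "2024-06-10"] * 2
        ),
        "channel": ["x1"] * 4 + ["x2"] * 4,
        "spend": [100.0, 100.0, 200.0, 200.0, 50.0, 50.0, 50.0, 50.0],
        "contribution": [150.0, 250.0, 300.0, 300.0, 100.0, 100.0, 75.0, 25.0],
    }
)

# Yearly granularity: ROAS = total contribution / total spend within each year.
df["year"] = df["date"].dt.year
yearly = df.groupby(["year", "channel"])[["spend", "contribution"]].sum()
yearly["roas"] = yearly["contribution"] / yearly["spend"]
print(yearly["roas"])

# "All time" granularity: the same ratio computed over the whole sample.
all_time = df.groupby("channel")[["spend", "contribution"]].sum()
all_time["roas"] = all_time["contribution"] / all_time["spend"]
print(all_time["roas"])
```

Note that this is a ratio of sums, not an average of weekly ratios. Weeks with little spend therefore carry little weight.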

### Q: How did the ROAS of each channel change year after year?

In [None]:
mmm.plot_interactive.roas(
    frequency="yearly",
    color="date",
    x="channel",
)

### Q: Within each year, which channel performed better in each geo?

By swapping `x` and `color`, we get a different perspective — now the x-axis
shows time and the color differentiates channels.

In [None]:
mmm.plot_interactive.roas(
    frequency="yearly",
    color="channel",
    x="date",
)

### Q: Looking over all the data, which channel performed better in each geo?

Using `frequency="all_time"` aggregates everything into a single time period,
giving us the overall ROAS per channel per geo.

In [None]:
mmm.plot_interactive.roas(frequency="all_time")

---

## 3. Channel Contributions: What Drives Sales?

The contributions plot shows how much each channel contributes to total sales
over time. It supports the same `frequency`, `color`, and `x` parameters as
`roas()` above, so you can slice contributions in exactly the same way —
by time granularity, channel, or geography.

### Q: What is the overall contribution of each channel across all geos?

Error bars show the 94% HDI (Highest Density Interval).
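For intuition about what these error bars represent, here is a NumPy sketch of the usual highest-density-interval estimate for a unimodal sample. It finds the narrowest window that contains the requested share of the sorted draws. It follows the same idea as ArviZ's `hdi`, but it is a simplified stand-in (the name `hdi_interval` is hypothetical), not necessarily the code path `plot_interactive` uses.

```python
import numpy as np


def hdi_interval(samples, prob=0.94):
    """Narrowest interval containing `prob` of the draws (assumes a unimodal posterior)."""
    x = np.sort(np.asarray(samples, dtype=float))
    n = x.size
    span = int(np.floor(prob * n))  # index distance between the interval ends
    widths = x[span:] - x[: n - span]  # width of every candidate window
    start = int(np.argmin(widths))  # the narrowest window wins
    return float(x[start]), float(x[start + span])


# Right-skewed toy draws, where the HDI and an equal-tailed interval can differ.
rng = np.random.default_rng(0)
draws = rng.gamma(shape=2.0, scale=1.0, size=10_000)
low, high = hdi_interval(draws, prob=0.94)
print(low, high)
```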

In [None]:
mmm.plot_interactive.contributions(frequency="all_time")

### Q: How did channel contributions change year by year?

Just like the ROAS examples, setting `frequency="yearly"`, `color="channel"`,
and `x="date"` gives us a time-series view colored by channel.

In [None]:
mmm.plot_interactive.contributions(
    frequency="yearly",
    color="channel",
    x="date",
)

### Q: How did the contributions of control variables change year by year?

On top of channel contributions, we can also plot the contributions of control, seasonality, and baseline variables. Passing `hdi_prob=None` removes the error bars.
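The seasonality term mentioned above comes from `yearly_seasonality=2` in the model definition, which adds a small Fourier basis. The following is a minimal sketch of such a basis. It assumes the common convention of sine/cosine pairs over a 365.25-day period with day-of-year as the phase; pymc-marketing's exact scaling and column order may differ. The helper `yearly_fourier_features` is hypothetical.

```python
import numpy as np
import pandas as pd


def yearly_fourier_features(dates, n_order=2, period=365.25):
    """Return 2 * n_order sine/cosine columns describing a yearly cycle."""
    dates = pd.to_datetime(pd.Series(dates))
    t = dates.dt.dayofyear.to_numpy() / period  # position within the year, in cycles
    columns = {}
    for k in range(1, n_order + 1):
        columns[f"sin_{k}"] = np.sin(2 * np.pi * k * t)
        columns[f"cos_{k}"] = np.cos(2 * np.pi * k * t)
    return pd.DataFrame(columns, index=dates)


features = yearly_fourier_features(
    pd.date_range("2024-01-01", periods=52, freq="W-MON"), n_order=2
)
print(features.shape)  # (52, 4)
```

Each geo gets its own coefficients on these columns (`dims=("geo", "fourier_mode")` in `gamma_fourier` above). Their weighted sum is the seasonal pattern.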

In [None]:
mmm.plot_interactive.contributions(
    component="control",
    frequency="yearly",
    color="control",
    x="date",
    hdi_prob=None,
)

---

## 4. Saturation Curves: Where Are Diminishing Returns?

Saturation curves show how the response (contribution) changes as spend
increases. These are essential for understanding where additional spend
will have diminishing returns.
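The model above uses `LogisticSaturation` with parameters `lam` and `beta`. The sketch below assumes the logistic form commonly used in pymc-marketing, `beta * (1 - exp(-lam * x)) / (1 + exp(-lam * x))` applied to scaled spend. It uses toy parameter values to show numerically why the curves flatten: each extra unit of spend adds less response than the one before.

```python
import numpy as np


def logistic_saturation(x, lam, beta=1.0):
    """beta * (1 - exp(-lam*x)) / (1 + exp(-lam*x)); rises from 0 toward beta."""
    x = np.asarray(x, dtype=float)
    return beta * (1.0 - np.exp(-lam * x)) / (1.0 + np.exp(-lam * x))


spend = np.linspace(0.0, 10.0, 11)  # scaled spend levels (toy grid)
response = logistic_saturation(spend, lam=0.5, beta=0.25)
marginal = np.diff(response)  # extra response per extra unit of spend

print(response.round(3))
print(marginal.round(4))  # shrinking increments = diminishing returns
```

This form equals `beta * tanh(lam * x / 2)`. `beta` sets the ceiling, and `lam` controls how quickly the ceiling is approached.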

### Q: Show me the saturation curves in original scale

By default, the x-axis is in original scale (e.g., dollars of spend).
Each channel gets its own line, faceted by geo.

In [None]:
mmm.plot_interactive.saturation_curves()

### Q: Now show me with uncertainty (HDI bands)

Adding `hdi_prob=0.9` draws shaded bands showing the 90% HDI around
each curve — capturing posterior uncertainty in the saturation parameters.

In [None]:
mmm.plot_interactive.saturation_curves(hdi_prob=0.9)

---

## 5. Adstock Curves: How Long Do Effects Last?

Adstock (carryover) curves show how the effect of a marketing impulse
decays over time. A slow decay means the channel has long-lasting effects;
a fast decay means the effect is short-lived.
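In this model, `GeometricAdstock` with `l_max=8` spreads each impulse over the current period and the next seven, with weights proportional to `alpha**l`. The following minimal NumPy sketch feeds a single unit of spend through that kernel. Normalizing the weights to sum to 1 is shown as an option. pymc-marketing exposes a similar `normalize` setting, but check your version's default rather than relying on this sketch.

```python
import numpy as np


def geometric_adstock(x, alpha, l_max=8, normalize=True):
    """Carry each period's input forward with weights alpha**l, l = 0..l_max-1."""
    x = np.asarray(x, dtype=float)
    weights = alpha ** np.arange(l_max)
    if normalize:
        weights = weights / weights.sum()
    # Convolution truncated to the input length: out[t] = sum_l w[l] * x[t - l]
    return np.convolve(x, weights)[: x.size]


impulse = np.zeros(12)
impulse[0] = 1.0  # one unit of spend in period 0
for alpha in (0.2, 0.7):
    print(alpha, geometric_adstock(impulse, alpha).round(3))
```

With `alpha=0.2`, almost all of the effect lands in the first couple of periods. With `alpha=0.7`, it lingers until the `l_max` cutoff.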

### Q: How do the decay curves look?

Passing `hdi_prob=None` draws the curves without uncertainty bands.

In [None]:
mmm.plot_interactive.adstock_curves(hdi_prob=None)

### Q: Show me adstock curves with uncertainty

Adding HDI bands helps us understand how confident we are about the
carryover duration for each channel.

In [None]:
mmm.plot_interactive.adstock_curves(hdi_prob=0.9)

---

## 6. Auto-faceting Across Custom Dimensions

When your model includes **custom dimensions** (like `geo` in our example),
`plot_interactive` **automatically creates subplots** (facets) for each
dimension value. You've already seen this in all the plots above — each
subplot corresponds to a different geography (`geo_a` and `geo_b`).

This behavior is controlled by the `auto_facet` parameter, which is
**enabled by default**. You can also control the faceting layout with
`facet_col`, `facet_row`, and `single_dim_facet`:

- `facet_col` — creates side-by-side columns (one per dimension value)
- `facet_row` — stacks subplots vertically (one per dimension value)
- `single_dim_facet` — controls the default direction (`"col"` or `"row"`)

### Q: Compare saturation curves across geos stacked vertically

Using `facet_row="geo"` overrides the default column layout and stacks
the saturation curves vertically instead.

In [None]:
mmm.plot_interactive.saturation_curves(
    facet_row="geo",
)

### Q: What if I want to have all lines in the same plot?

Setting `auto_facet=False` disables the automatic subplot creation.
For line-based plots (like saturation and adstock curves), the custom
dimension is then shown using **line dash styles** instead of separate
subplots — all curves appear on a single plot, differentiated by dashing.

In [None]:
mmm.plot_interactive.saturation_curves(
    auto_facet=False,
)

### Q: Show adstock curves with geos as columns instead of the default

By passing `single_dim_facet="col"`, the single custom dimension (`geo`)
is faceted as columns rather than the default rows for this plot type.

In [None]:
mmm.plot_interactive.adstock_curves(
    single_dim_facet="col",
)

---

## 7. Advanced: Filtering & Aggregating with the Data Layer

> **This is an advanced section.** The examples above cover the most common
> use cases through `mmm.plot_interactive`. This section shows how to work
> directly with the underlying components for custom data slicing.

Under the hood, `mmm.plot_interactive` is powered by two key classes:

- **`MMMSummaryFactory`** — Takes a data wrapper (from `mmm.data`) and the
  fitted model, and computes summary statistics like contributions, ROAS,
  and posterior predictive values. It handles HDI computation, time
  aggregation, and proper scaling. You can think of it as the "data engine"
  that turns raw InferenceData into plottable DataFrames.

- **`MMMPlotlyFactory`** — Takes an `MMMSummaryFactory` and provides all
  the interactive plotting methods (`posterior_predictive()`, `roas()`,
  `contributions()`, `saturation_curves()`, `adstock_curves()`). It reads
  from the summary factory and creates Plotly figures.
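Conceptually, the "data engine" step turns an array of posterior draws into a tidy table. The sketch below mimics that step for a toy `(chain, draw, date, channel)` array using plain NumPy and pandas. The array shape, the column names, and the use of mean plus equal-tailed 3%/97% quantiles (rather than an HDI) are illustrative assumptions, not `MMMSummaryFactory`'s actual output schema.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
chains, draws = 4, 250
dates = pd.date_range("2024-01-01", periods=3, freq="W-MON")
channels = ["x1", "x2"]

# Toy posterior draws of channel contributions: (chain, draw, date, channel).
samples = rng.gamma(shape=5.0, scale=10.0, size=(chains, draws, len(dates), len(channels)))

# Pool chains and draws, then summarize each (date, channel) cell.
flat = samples.reshape(chains * draws, len(dates), len(channels))
summary = pd.DataFrame(
    {
        "mean": flat.mean(axis=0).ravel(),
        "lower": np.quantile(flat, 0.03, axis=0).ravel(),
        "upper": np.quantile(flat, 0.97, axis=0).ravel(),
    },
    index=pd.MultiIndex.from_product([dates, channels], names=["date", "channel"]),
)
print(summary)
```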

When you access `mmm.plot_interactive`, it automatically creates both of
these using your full dataset. But you can also create them **manually**
with filtered or aggregated data — this is the key to custom views.

The workflow is:

1. **Transform the data** using `mmm.data.filter_dims()`, `mmm.data.filter_dates()`,
   or `mmm.data.aggregate_dims()`
2. **Create a new `MMMSummaryFactory`** with the transformed data
3. **Create a new `MMMPlotlyFactory`** with that summary
4. **Plot** using the factory's methods

### Q: Aggregating over geos, how did the ROAS of each channel change year after year?

Here we aggregate both geos into a single `"all_geos"` label, then plot ROAS.
This collapses the geo dimension so we get a single aggregated view.
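One consequence to keep in mind: if aggregation sums spend and contributions across geos (assumed here), the aggregated ROAS is a spend-weighted figure, not the simple average of the per-geo ROAS values. A toy pandas illustration with made-up numbers follows. It does not call `aggregate_dims`.

```python
import pandas as pd

# Toy per-geo totals for one channel (hypothetical numbers).
per_geo = pd.DataFrame(
    {"geo": ["geo_a", "geo_b"], "spend": [100.0, 300.0], "contribution": [300.0, 300.0]}
)
per_geo["roas"] = per_geo["contribution"] / per_geo["spend"]  # 3.0 and 1.0

# Pool both geos under one label (a sum, standing in for aggregate_dims).
pooled = per_geo[["spend", "contribution"]].sum()
pooled_roas = pooled["contribution"] / pooled["spend"]  # 600 / 400 = 1.5

print(per_geo["roas"].mean(), pooled_roas)  # 2.0 vs 1.5
```

The geo with the larger budget pulls the pooled figure toward its own ROAS.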

In [None]:
from pymc_marketing.mmm.plot_interactive import MMMPlotlyFactory
from pymc_marketing.mmm.summary import MMMSummaryFactory

# Aggregate geo_a and geo_b into "all_geos"
agg_data = mmm.data.aggregate_dims(
    dim="geo", values=["geo_a", "geo_b"], new_label="all_geos"
)
agg_summary = MMMSummaryFactory(agg_data, mmm)
agg_factory = MMMPlotlyFactory(summary=agg_summary)

agg_factory.roas(
    frequency="yearly",
    color="channel",
    x="date",
)

### Q: Filtering to only one geo, what was the yearly ROAS?

You can filter the data to a single geography and create a dedicated factory.

In [None]:
filtered_data_geo_a = mmm.data.filter_dims(geo="geo_a")
filtered_summary_geo_a = MMMSummaryFactory(
    filtered_data_geo_a, mmm, validate_data=False
)
filtered_factory_geo_a = MMMPlotlyFactory(summary=filtered_summary_geo_a)

filtered_factory_geo_a.roas(
    frequency="yearly",
    color="channel",
    x="date",
    title="ROAS for geo_a (yearly)",
)

### Q: How did the ROAS change quarter after quarter starting in 2024?

Filter the data to dates from 2024-01-01 onward with `filter_dates()`, then view quarterly ROAS.
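The following is a rough pandas analogue of the cell's two steps: keep rows on or after a start date, then bucket them into calendar quarters. The toy data and column names are hypothetical. Whether `filter_dates` treats `start_date` as inclusive is an assumption here.

```python
import pandas as pd

# Toy weekly spend/contribution for one channel (hypothetical values).
weekly = pd.DataFrame(
    {
        "date": pd.date_range("2023-12-04", periods=20, freq="W-MON"),
        "spend": 10.0,
        "contribution": 15.0,
    }
)

# Step 1: keep dates from 2024-01-01 onward (inclusive start, assumed).
recent = weekly[weekly["date"] >= "2024-01-01"]

# Step 2: bucket into calendar quarters and compute ROAS per quarter.
quarterly = recent.groupby(recent["date"].dt.to_period("Q"))[["spend", "contribution"]].sum()
quarterly["roas"] = quarterly["contribution"] / quarterly["spend"]
print(quarterly)
```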

In [None]:
filtered_data_2024 = mmm.data.filter_dates(start_date="2024-01-01")
filtered_summary_2024 = MMMSummaryFactory(filtered_data_2024, mmm)
filtered_factory_2024 = MMMPlotlyFactory(summary=filtered_summary_2024)

filtered_factory_2024.roas(
    frequency="quarterly",
    color="channel",
    x="date",
    hdi_prob=None,
    title="ROAS from 2024 onwards (quarterly)",
)

---

## Summary

The `plot_interactive` module provides a rich set of interactive visualizations for
exploring MMM results:

| Method | What It Shows | Key Parameters |
|--------|--------------|----------------|
| `posterior_predictive()` | Actual vs predicted with HDI band | `hdi_prob`, `frequency` |
| `contributions()` | Channel/control/seasonality contributions | `component`, `frequency`, `color`, `x` |
| `roas()` | Return on Ad Spend | `frequency`, `color`, `x` |
| `saturation_curves()` | Diminishing returns curves | `hdi_prob`, `max_value`, `original_scale` |
| `adstock_curves()` | Carryover effect curves | `hdi_prob`, `amount` |

**Key features:**
- **Auto-faceting**: Custom dimensions (e.g., geo) automatically create subplots
- **Facet control**: Use `facet_col`, `facet_row`, `single_dim_facet`, and `auto_facet` to customize layout
- **Filtering**: Use `mmm.data.filter_dims()` or `filter_dates()` to focus on subsets
- **Aggregating**: Use `mmm.data.aggregate_dims()` to combine dimensions
- **Error bars**: Control with `hdi_prob` (set to `None` to remove)
- **Customizable**: All Plotly Express kwargs (title, height, width, colors, etc.) are supported