# Generate Synthetic Pricing Data

This notebook generates synthetic panel data for causal forecasting with pricing.

## Features

- Time-varying base demand (trend + seasonality)
- Constant item-level price elasticity
- Confounded pricing policy: discounts depend on latent demand
- Saves the data to a Databricks table for downstream analysis

In [None]:
import numpy as np
import pandas as pd

In [None]:
# Create the target catalog and schema up front; both statements use
# IF NOT EXISTS, so re-running this cell is harmless.

catalog = "causal_forecasting"  # Change here
schema = "default"  # Change here

for _ddl in (
    f"CREATE CATALOG IF NOT EXISTS {catalog}",
    f"CREATE SCHEMA IF NOT EXISTS {catalog}.{schema}",
):
    _ = spark.sql(_ddl)

## Synthetic Dataset Generator

The function below generates panel data with:
- Item-level characteristics (base_price, category, season_type)
- Time-varying demand with trend and seasonality
- Confounded treatment: discounts are applied only when an item's base demand falls well below its week-0 level, and they grow as demand falls further
- Constant price elasticity per item

In [None]:
import numpy as np
import pandas as pd


def generate_synthetic_pricing_data(
    n_items: int = 200,
    n_weeks: int = 60,
    seed: int = 0,
) -> pd.DataFrame:
    """
    Synthetic panel data {i, t} inspired by the simulation in
    'Causal Forecasting for Pricing' (Schultz et al., 2024).

    Demand model:
        base_demand_{i,t} from Eq. (14):
            q_b(i,t) = (0.15 * tau_{i,t} + 0.25 * s_{i,t} + 1) * c_{i,t}
        observed demand (slightly modified):
            q_{i,t} ~ Poisson( q_b(i,t) * (p_{i,t} / p0_i)^{elasticity_i} )

    Where:
      - tau_{i,t} ~ N(t * gamma_i, sigma_tau_i^2)
      - s_{i,t} is a sine with period 30 whose integer phase shift is shared
        by all items of the same season type
      - c_{i,t} = 0.05 * a_{i,t}^2 + 0.25 * a_{i,t} + 0.5 * b_{i,t}
      - a_{i,t} = alpha_{d(i)} + eps_{i,t}, alpha_d ~ N(10, 3^2), eps_{i,t} ~ N(0,1)
      - b_{i,t} = beta_{k(i)} + psi_{i,t}, beta_k ~ N(300, 50^2), psi_{i,t} ~ N(0,5^2)
      - season type = k(i) * 6 // 15: the 15 k-categories are split into 6
        contiguous, near-equal groups (sizes 3, 2, 3, 2, 3, 2), so every
        season type 0..5 can occur.

    Pricing policy (confounded by base demand):
      - norm_bd = (q_b(i,t) - q_b(i,0)) / (0.5 * q_b(i,0)) compares the
        current base demand with the item's week-0 base demand.
      - No discount while norm_bd >= -0.5 (base demand >= ~75% of week 0).
        Below that, the mean discount is 2 * (-0.5 - norm_bd), capped at 0.5.
      - N(0, 0.03^2) noise is added; the result is clipped to [0, 0.5] and
        snapped to the nearest value in {0, 0.1, ..., 0.5}.

    Args:
      n_items: number of items (articles) in the panel.
      n_weeks: number of weekly periods simulated per item.
      seed: seed for ``numpy.random.default_rng``; all randomness comes from
        this single generator, so equal arguments give identical data.

    Returns:
      One row per (item, week) with columns:
        item_id, week, demand, discount, base_price,
        category (d-category 0..44),
        k_category (k-category, 0..14),
        season_type (0..5),
        elasticity_true, base_demand, price, lag_demand, lag_discount,
        week_sin, week_cos
      ``lag_demand`` / ``lag_discount`` are 0.0 in each item's first week.
      With ``n_items == 0`` or ``n_weeks == 0`` an empty frame with the same
      columns is returned.

    Raises:
      ValueError: if ``n_items`` or ``n_weeks`` is negative.
    """
    if n_items < 0 or n_weeks < 0:
        raise ValueError(
            f"n_items and n_weeks must be non-negative, got {n_items=}, {n_weeks=}"
        )

    rng = np.random.default_rng(seed)
    rows = []

    # Explicit column order so an empty panel still has the full schema
    columns = [
        "item_id", "week", "demand", "discount", "base_price",
        "category", "k_category", "season_type",
        "elasticity_true", "base_demand", "price",
        "lag_demand", "lag_discount",
    ]

    # Period (weeks) of both the demand seasonality and the week_sin/cos features
    season_period = 30.0

    # ----- Synthetic "meta" parameters as in Appendix E -----
    # a_it categories: d(i) in {0,...,44}, alpha_d ~ N(10, 3^2)
    n_cat_a = 45
    alpha_d = rng.normal(loc=10.0, scale=3.0, size=n_cat_a)

    # b_it categories: k(i) in {0,...,14}, beta_k ~ N(300, 50^2)
    n_cat_b = 15
    beta_k = rng.normal(loc=300.0, scale=50.0, size=n_cat_b)

    # Seasonality groups: "subdivide k evenly into six subgroups".
    # k * 6 // 15 yields contiguous groups of sizes 3,2,3,2,3,2 covering all
    # six types (a repeat/truncate scheme would leave type 5 empty).
    n_season_types = 6
    season_type_for_k = np.arange(n_cat_b) * n_season_types // n_cat_b
    # Season shift (integer weeks) per season type, uniform over [-15, 15]
    season_shift_by_type = rng.integers(-15, 16, size=n_season_types)

    # Allowed discrete discount values d^(j) = 0.1 * j, j in {0,...,5}
    allowed_discounts = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5])

    for item_id in range(n_items):
        # Map article to synthetic categories d(i) and k(i)
        d_cat = rng.integers(0, n_cat_a)  # for a_it
        k_cat = rng.integers(0, n_cat_b)  # for b_it

        # Season type determined by k(i) subgroup
        season_type = season_type_for_k[k_cat]
        season_shift = season_shift_by_type[season_type]

        # Base (undiscounted) price p0 (kept simple, not exactly Eq. (21))
        base_price = rng.uniform(20.0, 100.0)

        # Trend parameters: gamma_i ~ U([-0.02, 0.02]), sigma_tau_i ~ U([0, 0.15])
        gamma_i = rng.uniform(-0.02, 0.02)
        sigma_tau_i = rng.uniform(0.0, 0.15)

        # Constant per-item elasticity (not in the original synthetic data
        # generation, but aligned with the elasticity framework in Appendix B)
        elasticity = rng.uniform(-3.0, -0.5)

        # Previous week's values for the lag features (0.0 before week 0)
        last_q = 0.0
        last_d = 0.0

        # Week-0 base demand, the reference level for the pricing policy
        base_demand_baseline = None

        for t in range(n_weeks):
            # ---------- Article-specific component c_{i,t} (Eq. (15)-(19)) ----------
            # a_it = alpha_{d(i)} + eps_it, eps_it ~ N(0,1)
            eps_it = rng.normal(0.0, 1.0)
            a_it = alpha_d[d_cat] + eps_it

            # b_it = beta_{k(i)} + psi_it, psi_it ~ N(0, 5^2)
            psi_it = rng.normal(0.0, 5.0)
            b_it = beta_k[k_cat] + psi_it

            # c_it = 0.05 * a_it^2 + 0.25 * a_it + 0.5 * b_it (Eq. (15))
            c_it = 0.05 * (a_it ** 2) + 0.25 * a_it + 0.5 * b_it

            # ---------- Trend tau_{i,t} (Eq. (22)-(24)) ----------
            # tau_it ~ N(t * gamma_i, sigma_tau_i^2)
            tau_it = rng.normal(loc=t * gamma_i, scale=sigma_tau_i)

            # ---------- Seasonality s_{i,t} ----------
            # Sine with period 30 and season-type-dependent shift (Appendix E)
            s_it = np.sin(2.0 * np.pi * (t + season_shift) / season_period)

            # ---------- Base demand q^{(b)}_{i,t} (Eq. (14)) ----------
            base_demand = (0.15 * tau_it + 0.25 * s_it + 1.0) * c_it

            if base_demand_baseline is None:
                base_demand_baseline = base_demand

            # ---------- Pricing policy (confounding) ----------
            # norm_bd = relative change vs. the week-0 base demand, measured in
            # units of half the baseline:
            #   norm_bd ≈ 0 → similar to baseline; > 0 above; < 0 below.
            norm_bd = (base_demand - base_demand_baseline) / (
                0.5 * base_demand_baseline + 1e-6
            )

            # norm_bd < -0.5  <=>  base_demand < ~75% of the week-0 baseline
            threshold = -0.5

            if norm_bd >= threshold:
                # Demand is not far below baseline → no discount
                disc_mean = 0.0
            else:
                # Mean discount grows linearly with the distance below the
                # threshold and saturates at 0.5 (reached at norm_bd <= -0.75)
                shortfall = threshold - norm_bd
                disc_mean = float(np.clip(2.0 * shortfall, 0.0, 0.5))

            # Small noise so discounts are not a deterministic function of demand
            discount_continuous = float(
                np.clip(disc_mean + rng.normal(0.0, 0.03), 0.0, 0.5)
            )

            # Snap to the nearest allowed level {0, 0.1, ..., 0.5}, as in the paper
            discount = float(
                allowed_discounts[
                    np.argmin(np.abs(allowed_discounts - discount_continuous))
                ]
            )

            # Price under this discount: p_t = p0 * (1 - d_t)
            price = base_price * (1.0 - discount)

            # ---------- Constant elasticity demand on top of q^{(b)} ----------
            # expected_q = q_b_it * (p_t / p0)^{elasticity_i}; floored so the
            # Poisson rate stays strictly positive
            expected_q = base_demand * ((price / base_price) ** elasticity)
            expected_q = max(expected_q, 1e-3)

            # Poisson noise for integer demand (stored as float)
            demand = float(rng.poisson(expected_q))

            rows.append(
                {
                    "item_id": item_id,
                    "week": t,
                    "demand": demand,
                    "discount": discount,
                    "base_price": base_price,
                    # For interpretability / alignment with paper:
                    "category": int(d_cat),  # d(i) in Eq. (16)-(17)
                    "k_category": int(k_cat),  # k(i)
                    "season_type": int(season_type),  # group defining season shift
                    "elasticity_true": elasticity,  # per-item epsilon (not in original DGP)
                    "base_demand": base_demand,  # q^{(b)}_{i,t}
                    "price": price,
                    "lag_demand": last_q,
                    "lag_discount": last_d,
                }
            )

            last_q = demand
            last_d = discount

    df = pd.DataFrame(rows, columns=columns)

    # Periodic time features (proxy for positional encodings). Casting to float
    # keeps this valid for an empty (object-dtype) frame as well.
    week = df["week"].astype(float)
    df["week_sin"] = np.sin(2 * np.pi * week / season_period)
    df["week_cos"] = np.cos(2 * np.pi * week / season_period)

    return df


## Generate Data

Generate 2,000 items over 60 weeks of data (120,000 rows).

In [None]:
# 2,000 items x 60 weeks; the fixed seed makes the dataset reproducible.
df = generate_synthetic_pricing_data(n_items=2000, n_weeks=60, seed=42)

# Quick sanity check on the generated panel
print(f"Generated {len(df):,} rows")
print(f"Shape: {df.shape}")
df.head()  # preview the first rows

## Visualize Sample Time Series

Let's visualize the time series for 5 randomly sampled items to understand the data patterns.

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Set modern seaborn style
sns.set_style("whitegrid")
sns.set_context("notebook", font_scale=1.1)

n_sample_items = 5
discount_levels = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]

# Sample items with a dedicated, seeded generator rather than np.random.seed,
# so the notebook-wide global NumPy RNG state is left untouched.
sample_rng = np.random.default_rng(42)
sample_items = sample_rng.choice(
    df["item_id"].unique(), size=n_sample_items, replace=False
)

# One colour per sampled item from seaborn's colorblind-safe palette
colors = sns.color_palette("colorblind", len(sample_items))


def _style_time_axis(ax, ylabel, title):
    """Apply the shared label, title, legend, grid and spine styling to one subplot."""
    ax.set_xlabel("Week", fontsize=13, fontweight="bold")
    ax.set_ylabel(ylabel, fontsize=13, fontweight="bold")
    ax.set_title(title, fontsize=15, fontweight="bold", pad=15)
    ax.legend(loc="best", ncol=2, framealpha=0.95,
              edgecolor="gray", fancybox=True, shadow=True)
    ax.grid(True, alpha=0.3, linestyle="--", linewidth=0.8)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)


fig, axes = plt.subplots(2, 1, figsize=(16, 11))
ax1, ax2 = axes

# ===== Plot 1: Demand Over Time =====
for color, item_id in zip(colors, sample_items):
    item_data = df[df["item_id"] == item_id].sort_values("week")
    ax1.plot(item_data["week"], item_data["demand"],
             label=f"Item {item_id}",
             color=color,
             linewidth=2.5,
             alpha=0.85,
             marker="o",
             markersize=4,
             markevery=5)
_style_time_axis(
    ax1,
    "Demand (units)",
    f"Demand Over Time - Sample of {n_sample_items} Items",
)

# ===== Plot 2: Discount Over Time (Step Plot) =====
for color, item_id in zip(colors, sample_items):
    item_data = df[df["item_id"] == item_id].sort_values("week")
    ax2.step(item_data["week"], item_data["discount"],
             label=f"Item {item_id}",
             color=color,
             linewidth=2.5,
             alpha=0.85,
             where="post")
_style_time_axis(
    ax2,
    "Discount Rate",
    "Discount Rate Over Time - Discrete Levels [0, 0.1, 0.2, 0.3, 0.4, 0.5]",
)

# Faint horizontal reference lines at each allowed discount level
for discount_level in discount_levels:
    ax2.axhline(y=discount_level, color="gray", linestyle=":",
                alpha=0.3, linewidth=1, zorder=0)

# Overall figure title
fig.suptitle("Synthetic Pricing Data: Time Series Visualization",
             fontsize=17, fontweight="bold", y=0.995)

plt.tight_layout(rect=[0, 0, 1, 0.99])
plt.show()

## Data Summary

In [None]:
# Summary statistics per column; the percentiles listed are pandas' defaults.
df.describe(percentiles=[0.25, 0.5, 0.75])

## Save to Databricks Table

Save the generated data to a Databricks table for use in downstream training.

In [None]:
# Convert to Spark DataFrame and save
spark_df = spark.createDataFrame(df)

# Save to table (adjust catalog/schema as needed)
table_name = f"{catalog}.{schema}.synthetic_data"
spark_df.write.mode("overwrite").saveAsTable(table_name)

print(f"Data saved to table: {table_name}")