# Generate and save the data to be used for comparison

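For an extensive comparison, we generate a synthetic dataset with 20 channels and 10 geos over 156 periods (3 years). Seven control variables are also defined below, but they are currently not passed to the generator (the `control_variables` argument is commented out), so the saved data contains no control columns.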

In [1]:
# Import the MMM data generator
from mmm_param_recovery.data_generator import (
    generate_mmm_dataset,
    get_preset_config,
    list_available_presets,
    customize_preset,
    MMMDataConfig,
    ChannelConfig,
    RegionConfig,
    TransformConfig,
    plot_channel_spend,
    plot_channel_contributions,
    plot_roas_comparison,
    plot_regional_comparison,
    plot_data_quality,
    calculate_roas_values,
    calculate_attribution_percentages
)

# Standard data science imports
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Plot appearance: start from matplotlib's default style, switch the seaborn
# colour palette to "husl", then enlarge the default figure size. The order
# matters because applying a style sheet rewrites rcParams.
plt.style.use('default')
sns.set_palette("husl")
plt.rcParams.update({'figure.figsize': (12, 8)})

# pandas display: setting these options to None lifts the column-count,
# line-width and cell-width display limits.
for display_option in ('display.max_columns', 'display.width', 'display.max_colwidth'):
    pd.set_option(display_option, None)

In [2]:
from typing import List

# Spend patterns are handed out round-robin: the channel at 0-based position k
# receives channel_patterns[k % len(channel_patterns)].
channel_patterns = [
    "seasonal",
    "linear_trend",
    "on_off"
]

# Pattern-specific ChannelConfig keyword arguments. Each entry takes the
# 0-based channel position and returns the settings for that pattern, so the
# values drift slightly from one channel to the next.
pattern_settings = {
    "seasonal": lambda idx: dict(
        base_spend=5000.0 + idx * 100,
        seasonal_amplitude=0.3 + 0.01 * idx,
        base_effectiveness=0.7 + 0.01 * (idx % 5),
    ),
    "linear_trend": lambda idx: dict(
        base_spend=3000.0 + idx * 120,
        spend_trend=0.04 + 0.002 * (idx % 7),
        base_effectiveness=0.6 + 0.01 * (idx % 5),
    ),
    "on_off": lambda idx: dict(
        base_spend=1500.0 + idx * 90,
        activation_probability=0.5 + 0.02 * (idx % 10),
        base_effectiveness=0.4 + 0.01 * (idx % 6),
    ),
}

# Build 20 channels (7 seasonal, 7 linear_trend, 6 on_off). Names follow
# "x-<pattern with hyphens>-<1-based position>", e.g. "x-linear-trend-2".
channels: List[ChannelConfig] = []
for idx in range(20):
    kind = channel_patterns[idx % len(channel_patterns)]
    label = f"x-{kind.replace('_', '-')}-{idx + 1}"
    channels.append(
        ChannelConfig(name=label, pattern=kind, **pattern_settings[kind](idx))
    )


In [3]:
# Ten region labels with lowercase letter suffixes and an underscore
# separator: ['geo_a', 'geo_b', ..., 'geo_j']
region_names = ["geo_" + letter for letter in "abcdefghij"]

# Regional configuration: one region per entry in region_names, together with
# the four variation settings handed to RegionConfig.
region_settings = {
    "n_regions": len(region_names),
    "region_names": region_names,
    "baseline_variation": 0.1,
    "channel_scale_variation": 0.05,
    "effectiveness_variation": 0.08,
    "transform_variation": 0.03,
}
regions = RegionConfig(**region_settings)

In [4]:
# Settings for seven control variables, keyed by control name. Each row lists
# (base_value, volatility, trend, seasonal_amplitude) and is expanded into a
# dict with those four keys.
# NOTE(review): this mapping is defined but not currently passed to
# MMMDataConfig (the control_variables argument is commented out there).
control_variables = {
    control_name: dict(
        zip(("base_value", "volatility", "trend", "seasonal_amplitude"), row)
    )
    for control_name, row in (
        ("price",            (10.0, 0.1,  0.02,  0.05)),
        ("promotion",        (0.2,  0.05, 0.0,   0.1)),
        ("distribution",     (1.0,  0.02, 0.01,  0.03)),
        ("product_quality",  (0.8,  0.01, 0.0,   0.02)),
        ("macroeconomics",   (5.0,  0.2,  -0.01, 0.03)),
        ("weather",          (0.5,  0.3,  0.0,   0.2)),
        ("competitor_price", (9.5,  0.1,  0.015, 0.04)),
    )
}


In [5]:
# Assemble the generator configuration from the channels and regions defined
# in the earlier cells, generate the dataset, and print a short summary.

# Adstock / saturation settings: three parameter sets are supplied for each
# transform.
transform_settings = TransformConfig(
    adstock_fun="geometric_adstock",
    adstock_kwargs=[{"alpha": alpha} for alpha in (0.6, 0.7, 0.8)],
    saturation_fun="hill_function",
    saturation_kwargs=[
        {"slope": 1.0, "kappa": kappa} for kappa in (2000.0, 2500.0, 3000.0)
    ],
)

custom_config = MMMDataConfig(
    n_periods=156,  # 3 years
    channels=channels,
    regions=regions,
    transforms=transform_settings,
    # NOTE(review): the control variables defined in this notebook are left
    # disabled here, so they are not handed to the generator — confirm whether
    # that is intentional before enabling the line below.
    # control_variables=control_variables,
    seed=42,
)

print("Generating data with custom configuration...")
custom_result = generate_mmm_dataset(custom_config)
custom_data, custom_ground_truth = custom_result['data'], custom_result['ground_truth']

# Channel columns are picked out by their leading 'x'.
channel_columns = [column for column in custom_data.columns if column.startswith('x')]

print(f"\nCustom dataset generated!")
print(f"Shape: {custom_data.shape}")
print(f"Channels: {channel_columns}")
print(f"Regions: {custom_data['geo'].unique()}")

Generating data with custom configuration...

Custom dataset generated!
Shape: (1560, 23)
Channels: ['x1_x-seasonal-1', 'x2_x-linear-trend-2', 'x3_x-on-off-3', 'x4_x-seasonal-4', 'x5_x-linear-trend-5', 'x6_x-on-off-6', 'x7_x-seasonal-7', 'x8_x-linear-trend-8', 'x9_x-on-off-9', 'x10_x-seasonal-10', 'x11_x-linear-trend-11', 'x12_x-on-off-12', 'x13_x-seasonal-13', 'x14_x-linear-trend-14', 'x15_x-on-off-15', 'x16_x-seasonal-16', 'x17_x-linear-trend-17', 'x18_x-on-off-18', 'x19_x-seasonal-19', 'x20_x-linear-trend-20']
Regions: ['geo_a' 'geo_b' 'geo_c' 'geo_d' 'geo_e' 'geo_f' 'geo_g' 'geo_h' 'geo_i'
 'geo_j']


In [6]:
# Write the generated dataset to a CSV file in the current working directory.
# The index is written as well (to_csv default).
input_csv_path = "mmm_input_data.csv"
custom_data.to_csv(input_csv_path)

In [7]:
# Write the ground-truth transformed spend to CSV. reset_index() turns the
# existing index into ordinary columns first.
# NOTE(review): to_csv still writes the fresh integer index as an extra
# leading column — confirm downstream readers expect it before changing this.
transformed_spend_table = custom_ground_truth['transformed_spend'].reset_index()
transformed_spend_table.to_csv("mmm_true_output_data.csv")