# Generate and save the data to be used for comparison

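For an extensive comparison we're going to generate a dataset with 20 channels and 10 geos. Seven control variables are also defined below, but they are currently not passed to the generator (the `control_variables` argument is commented out in the configuration cell), so the saved dataset contains no control columns.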

In [1]:
# Import the MMM data generator
from mmm_param_recovery.data_generator import (
    generate_mmm_dataset,
    get_preset_config,
    list_available_presets,
    customize_preset,
    MMMDataConfig,
    ChannelConfig,
    RegionConfig,
    TransformConfig,
    plot_channel_spend,
    plot_channel_contributions,
    plot_roas_comparison,
    plot_regional_comparison,
    plot_data_quality,
    calculate_roas_values,
    calculate_attribution_percentages
)

# Standard data science imports
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Plotting defaults: stock matplotlib style, seaborn "husl" palette, 12x8 figures.
plt.style.use('default')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (12, 8)

# Relax pandas' display limits (max columns, width, max column width are all
# set to None) so wide frames render without column truncation.
for display_option in ('display.max_columns', 'display.width', 'display.max_colwidth'):
    pd.set_option(display_option, None)

In [2]:
from typing import List

# Spend patterns, handed out to the channels round-robin by index.
channel_patterns = ["seasonal", "linear_trend", "on_off"]

# For each pattern: the pattern-specific ChannelConfig keyword arguments as a
# function of the zero-based channel index, so that every channel gets
# slightly different settings.
pattern_specific_kwargs = {
    "seasonal": lambda idx: {
        "base_spend": 5000.0 + idx * 100,
        "seasonal_amplitude": 0.3 + 0.01 * idx,
        "base_effectiveness": 0.7 + 0.01 * (idx % 5),
    },
    "linear_trend": lambda idx: {
        "base_spend": 3000.0 + idx * 120,
        "spend_trend": 0.04 + 0.002 * (idx % 7),
        "base_effectiveness": 0.6 + 0.01 * (idx % 5),
    },
    "on_off": lambda idx: {
        "base_spend": 1500.0 + idx * 90,
        "activation_probability": 0.5 + 0.02 * (idx % 10),
        "base_effectiveness": 0.4 + 0.01 * (idx % 6),
    },
}

# Build 20 channels named "x-<pattern>-<n>", where n is 1-based and the
# underscores of the pattern are replaced by hyphens in the name only.
channels: List[ChannelConfig] = []
for i in range(20):
    pattern = channel_patterns[i % len(channel_patterns)]
    pattern_name_part = pattern.replace("_", "-")
    name = f"x-{pattern_name_part}-{i+1}"
    channel = ChannelConfig(name=name, pattern=pattern, **pattern_specific_kwargs[pattern](i))
    channels.append(channel)


In [3]:
# One region per lowercase letter: ['geo_a', 'geo_b', ..., 'geo_j'].
region_names = [f"geo_{chr(ord('a') + i)}" for i in range(10)]

# Region setup for all ten geos.
# NOTE(review): the *_variation arguments presumably control how much the
# regions differ from one another — confirm against RegionConfig.
regions = RegionConfig(
    n_regions=len(region_names),
    region_names=region_names,
    baseline_variation=0.1,
    channel_scale_variation=0.05,
    effectiveness_variation=0.08,
    transform_variation=0.03
)

In [4]:
# Settings recorded for every control variable, in this order.
CONTROL_FIELDS = ("base_value", "volatility", "trend", "seasonal_amplitude")

# Seven control variables, each mapped to a dict with the four settings above.
# Note: this dict is defined here but is not currently passed to MMMDataConfig
# (the `control_variables` argument is commented out where the config is built).
control_variables = {
    control_name: dict(zip(CONTROL_FIELDS, settings))
    for control_name, settings in [
        #                      base   volat.  trend   seasonal
        ("price",             (10.0,  0.1,    0.02,   0.05)),
        ("promotion",         (0.2,   0.05,   0.0,    0.1)),
        ("distribution",      (1.0,   0.02,   0.01,   0.03)),
        ("product_quality",   (0.8,   0.01,   0.0,    0.02)),
        ("macroeconomics",    (5.0,   0.2,    -0.01,  0.03)),
        ("weather",           (0.5,   0.3,    0.0,    0.2)),
        ("competitor_price",  (9.5,   0.1,    0.015,  0.04)),
    ]
}


In [5]:
# Build the custom configuration from the channels and regions defined above,
# generate the dataset, and print a short summary of what came out.

# Adstock / saturation settings. Three parameter sets are supplied per transform.
# NOTE(review): how these three sets are mapped onto the 20 channels is decided
# inside the generator — confirm against TransformConfig.
transform_settings = TransformConfig(
    adstock_fun="geometric_adstock",
    adstock_kwargs=[{"alpha": alpha} for alpha in (0.6, 0.7, 0.8)],
    saturation_fun="hill_function",
    saturation_kwargs=[{"slope": 1.0, "kappa": kappa} for kappa in (2000.0, 2500.0, 3000.0)],
)

custom_config = MMMDataConfig(
    n_periods=156,  # 3 years
    channels=channels,
    regions=regions,
    transforms=transform_settings,
    #control_variables=control_variables,  # controls are currently disabled
    seed=42,
)

print("Generating data with custom configuration...")
custom_result = generate_mmm_dataset(custom_config)
custom_data = custom_result['data']
custom_ground_truth = custom_result['ground_truth']

print(f"\nCustom dataset generated!")
print(f"Shape: {custom_data.shape}")
# Channel columns are the ones whose name starts with 'x'.
channel_columns = [col for col in custom_data.columns if col.startswith('x')]
print(f"Channels: {channel_columns}")
geo_values = custom_data['geo'].unique()
print(f"Regions: {geo_values}")

Generating data with custom configuration...

Custom dataset generated!
Shape: (1560, 23)
Channels: ['x1_x-seasonal-1', 'x2_x-linear-trend-2', 'x3_x-on-off-3', 'x4_x-seasonal-4', 'x5_x-linear-trend-5', 'x6_x-on-off-6', 'x7_x-seasonal-7', 'x8_x-linear-trend-8', 'x9_x-on-off-9', 'x10_x-seasonal-10', 'x11_x-linear-trend-11', 'x12_x-on-off-12', 'x13_x-seasonal-13', 'x14_x-linear-trend-14', 'x15_x-on-off-15', 'x16_x-seasonal-16', 'x17_x-linear-trend-17', 'x18_x-on-off-18', 'x19_x-seasonal-19', 'x20_x-linear-trend-20']
Regions: ['geo_a' 'geo_b' 'geo_c' 'geo_d' 'geo_e' 'geo_f' 'geo_g' 'geo_h' 'geo_i'
 'geo_j']


In [6]:
# Persist the generated dataset for the comparison runs (pandas writes the
# index as the first column by default).
custom_data.to_csv(path_or_buf="mmm_input_data.csv")

In [7]:
# Persist the ground-truth transformed spend: move its index into regular
# columns first, then write the result to CSV.
transformed_spend_flat = custom_ground_truth['transformed_spend'].reset_index()
transformed_spend_flat.to_csv("mmm_true_output_data.csv")

# Loading in new data presets

### Small business data

In [2]:
# Generate the 'small_business' preset dataset. The recorded output of this
# cell shows a single region ('Local'), so this is not multi-region data.
small_business_config = get_preset_config('small_business')
small_business_result = generate_mmm_dataset(small_business_config)
small_business_data = small_business_result['data']

# Summarise size, regions and the covered date span.
print(f"Small Business dataset shape: {small_business_data.shape}")
print(f"Regions: {small_business_data['geo'].unique()}")
small_business_dates = small_business_data['date']
print(f"Date range: {small_business_dates.min()} to {small_business_dates.max()}")

Small Business dataset shape: (104, 11)
Regions: ['Local']
Date range: 2020-01-05 00:00:00 to 2021-12-26 00:00:00


In [3]:
# Display the generated small-business frame as the cell output.
small_business_data

Unnamed: 0,date,geo,x1_Search-Ads,x2_Social-Media,x3_Local-Ads,x4_Email,c1,c1_effect,c2,c2_effect,y
0,2020-01-05,Local,99.395498,491.929897,0.000000,0.000000,0.0,0.0,0.00000,0.000000,10249.752599
1,2020-01-12,Local,101.019857,524.915682,0.000000,0.000000,0.0,0.0,0.00000,0.000000,10359.393106
2,2020-01-19,Local,100.533384,535.040384,0.000000,0.000000,0.0,0.0,0.00000,0.000000,10907.838319
3,2020-01-26,Local,101.522134,564.481943,0.000000,0.000000,0.0,0.0,0.00000,0.000000,11387.561845
4,2020-02-02,Local,98.691781,550.372063,0.000000,0.000000,0.0,0.0,0.00000,0.000000,11519.351206
...,...,...,...,...,...,...,...,...,...,...,...
99,2021-11-28,Local,101.427500,479.898108,519.528262,0.000000,0.0,0.0,0.00000,0.000000,12563.357847
100,2021-12-05,Local,100.910636,473.467682,533.183783,96.156181,0.0,0.0,0.00000,0.000000,16160.363544
101,2021-12-12,Local,98.896190,452.086372,0.000000,0.000000,0.0,0.0,0.00000,0.000000,12704.575524
102,2021-12-19,Local,104.338930,521.616932,0.000000,0.000000,0.0,0.0,0.00000,0.000000,11962.477025


### Medium business data

In [6]:
# Generate the 'medium_business' preset dataset and keep its data frame.
medium_business_config = get_preset_config('medium_business')
medium_business_result = generate_mmm_dataset(medium_business_config)
medium_business_data = medium_business_result['data']

# Summarise size, regions and the covered date span.
print(f"Medium Business dataset shape: {medium_business_data.shape}")
medium_business_regions = medium_business_data['geo'].unique()
print(f"Regions: {medium_business_regions}")
medium_business_dates = medium_business_data['date']
print(f"Date range: {medium_business_dates.min()} to {medium_business_dates.max()}")

Medium Business dataset shape: (262, 13)
Regions: ['geo_a' 'geo_b']
Date range: 2020-01-05 00:00:00 to 2022-07-03 00:00:00


In [7]:
# Display the generated medium-business frame as the cell output.
medium_business_data

Unnamed: 0,date,geo,x1_Search-Ads,x2_Search-Ads-Brand,x3_Video,x4_Social-Media,x5_Display-Ads,x6_Email,c1,c1_effect,c2,c2_effect,y
0,2020-01-05,geo_a,298.186495,94.707645,0.000000,520.089275,0.000000,0.000000,0.0,0.0,0.0,0.0,26203.697665
1,2020-01-12,geo_a,303.049471,95.476134,0.000000,554.963255,0.000000,0.000000,0.0,0.0,0.0,0.0,27722.809941
2,2020-01-19,geo_a,301.579950,95.246118,0.000000,565.667522,0.000000,0.000000,0.0,0.0,0.0,0.0,29162.113591
3,2020-01-26,geo_a,304.536100,95.713940,0.000000,596.794394,0.000000,0.000000,0.0,0.0,0.0,0.0,30262.555472
4,2020-02-02,geo_a,296.034939,94.375178,0.000000,581.876827,0.000000,0.000000,0.0,0.0,0.0,0.0,30535.839437
...,...,...,...,...,...,...,...,...,...,...,...,...,...
257,2022-06-05,geo_b,297.666189,97.337038,520.227398,918.824621,552.248343,0.000000,0.0,0.0,0.0,0.0,44947.033707
258,2022-06-12,geo_b,295.398563,96.968942,510.276290,906.564091,482.509612,0.000000,0.0,0.0,0.0,0.0,43462.820653
259,2022-06-19,geo_b,315.446983,100.239935,598.168590,987.215962,484.493143,97.607004,0.0,0.0,0.0,0.0,50736.953077
260,2022-06-26,geo_b,297.267635,97.277106,518.453455,900.848103,0.000000,0.000000,0.0,0.0,0.0,0.0,39152.540480


### Large business data

In [8]:
# Generate the 'large_business' preset dataset and keep its data frame.
large_business_config = get_preset_config('large_business')
large_business_result = generate_mmm_dataset(large_business_config)
large_business_data = large_business_result['data']

# Summarise size, regions and the covered date span.
print(f"large Business dataset shape: {large_business_data.shape}")
large_business_regions = large_business_data['geo'].unique()
print(f"Regions: {large_business_regions}")
large_business_dates = large_business_data['date']
print(f"Date range: {large_business_dates.min()} to {large_business_dates.max()}")

large Business dataset shape: (655, 19)
Regions: ['geo_a' 'geo_b' 'geo_c' 'geo_d' 'geo_e']
Date range: 2020-01-05 00:00:00 to 2022-07-03 00:00:00


In [9]:
# Display the generated large-business frame as the cell output.
large_business_data

Unnamed: 0,date,geo,x1_Search-Ads,x2_Search-Ads-Brand,x3_Video,x4_Video-2,x5_Social-Media,x6_Social-Media-2,x7_Display-Ads,x8_Influencer,c1,c1_effect,c2,c2_effect,c3,c3_effect,c4,c4_effect,y
0,2020-01-05,geo_a,298.186495,94.707645,0.000000,0.000000,1022.607527,929.050047,0.0,0.0,0.0,0.0,0.000000,0.000000,769.837536,647.748437,0.000000,0.000000,3.097005e+06
1,2020-01-12,geo_a,303.049471,95.476134,0.000000,0.000000,1134.842089,1012.804238,0.0,0.0,0.0,0.0,0.000000,0.000000,2361.315067,1986.832380,17.636674,1740.034362,3.209898e+06
2,2020-01-19,geo_a,301.579950,95.246118,0.000000,0.000000,1144.148232,994.843215,0.0,0.0,0.0,0.0,975.395044,187.719255,1878.223701,1580.354827,22.636704,2233.337360,3.322502e+06
3,2020-01-26,geo_a,304.536100,95.713940,0.000000,0.000000,1232.576990,1052.687378,0.0,0.0,0.0,0.0,920.335606,177.122813,2121.733024,1785.245828,0.000000,0.000000,3.429922e+06
4,2020-02-02,geo_a,296.034939,94.375178,0.000000,0.000000,1130.066975,923.984984,0.0,0.0,0.0,0.0,880.025087,169.364869,989.376286,832.470376,0.000000,0.000000,3.413038e+06
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
650,2022-06-05,geo_e,322.331511,103.068409,542.464029,1095.662995,1852.906238,1542.295921,0.0,0.0,0.0,0.0,0.000000,0.000000,2038.959972,1606.049686,35.377541,3283.669341,4.962196e+06
651,2022-06-12,geo_e,313.730317,101.681671,507.418604,961.383008,1715.920254,1410.896283,0.0,0.0,0.0,0.0,1062.839245,222.925608,2901.834955,2285.719770,11.904205,1104.923359,4.982507e+06
652,2022-06-19,geo_e,310.869957,101.221634,495.758633,916.706686,1661.638885,1370.779529,0.0,0.0,0.0,0.0,908.888318,190.635114,1064.202957,838.252270,0.000000,0.000000,4.862331e+06
653,2022-06-26,geo_e,305.175106,100.304045,472.552283,827.789280,1558.406703,1281.280124,0.0,0.0,0.0,0.0,957.062237,200.739370,0.000000,0.000000,0.000000,0.000000,4.706779e+06
