# 🧠 Full MMM Model - Adstock & Saturation

This notebook implements the full Bayesian MMM (marketing mix model) in four steps:
1. **Simple model**: Bayesian linear regression
2. **+ Adstock**: carry-over of the advertising effect over time
3. **+ Saturation**: diminishing returns
4. **Comparison**: which model best explains the data?

**Author**: Ivan  
**Project**: MMM Bayésien - MSMIN5IN43

In [None]:
# Imports
import sys
sys.path.insert(0, '../src')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import arviz as az

from data.loader import load_csv_data, split_train_test
from data.preprocessing import prepare_mmm_data
from models.base_mmm import BayesianMMM
from inference.diagnostics import check_convergence

# Config
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (14, 8)
%matplotlib inline

# Seed for reproducibility
np.random.seed(42)

## 1️⃣ Loading and preparing the data

In [None]:
# Load the data
df = load_csv_data('../data/raw/sample_data.csv')

print(f"✓ Dataset loaded: {len(df)} periods")
print(f"✓ Columns: {list(df.columns)}")
print(f"\n📅 Period: {df['date'].min()} → {df['date'].max()}")

df.head()

In [None]:
# Train/test split (80/20)
train_df, test_df = split_train_test(df, train_ratio=0.8)

print(f"Train : {len(train_df)} periods")
print(f"Test  : {len(test_df)} periods")

In [None]:
# Prepare the data
train_prep, meta = prepare_mmm_data(train_df, normalize=True)
# Caution: the test set is normalized independently here, presumably with its own
# statistics, so train and test features may end up on slightly different scales.
test_prep, _ = prepare_mmm_data(test_df, normalize=True)

# Extract X and y
media_cols = ['media_1_spend', 'media_2_spend', 'media_3_spend']

X_train = train_prep[media_cols].values
y_train = train_prep['sales_log'].values

X_test = test_prep[media_cols].values
y_test = test_prep['sales_log'].values

print(f"✓ X_train : {X_train.shape}")
print(f"✓ y_train : {y_train.shape}")
print(f"✓ X_test  : {X_test.shape}")
print(f"✓ y_test  : {y_test.shape}")

## 2️⃣ Model 1: Simple (no transformations)

**Equation**: `sales = β₀ + β₁·media₁ + β₂·media₂ + β₃·media₃ + ε` (in this notebook the target is the log-transformed column `sales_log`)
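
As a quick, non-Bayesian sanity check of this equation, the sketch below simulates data from the linear form and recovers the coefficients by ordinary least squares. The coefficient values, noise level, and variable names are illustrative assumptions, not values from this project; the notebook's `BayesianMMM` class estimates a posterior over the coefficients with MCMC rather than a single point estimate.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical ground truth: intercept β₀ and three media coefficients β₁..β₃
beta0_true = 2.0
beta_true = np.array([0.5, 0.3, 0.1])

n_periods = 500
X = rng.uniform(0.0, 1.0, size=(n_periods, 3))   # stand-in for normalized media spend
noise = rng.normal(0.0, 0.1, size=n_periods)     # ε
y = beta0_true + X @ beta_true + noise           # sales = β₀ + Σ βᵢ·mediaᵢ + ε

# Least-squares point estimate (not a posterior): a column of ones carries β₀
design = np.column_stack([np.ones(n_periods), X])
coef, *_ = np.linalg.lstsq(design, y, rcond=None)

beta0_hat, beta_hat = coef[0], coef[1:]
print("beta0:", np.round(beta0_hat, 3), "betas:", np.round(beta_hat, 3))
```

With enough periods and low noise, the estimates should land close to the assumed true values, which is a useful baseline to compare posterior means against.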

In [None]:
# Build and fit
print("🔨 Building the simple model...")
mmm_simple = BayesianMMM(use_adstock=False, use_saturation=False)

print("⏳ MCMC sampling (takes ~2-3 minutes)...")
trace_simple = mmm_simple.fit(
    X_train, y_train,
    draws=1000,
    tune=1000,
    chains=2,
    random_seed=42
)

print("\n✅ Simple model fitted!")

In [None]:
# Convergence diagnostics
report_simple = check_convergence(trace_simple)

print("📊 Convergence diagnostics\n")
print(f"Converged   : {'✅ YES' if report_simple['converged'] else '❌ NO'}")
print(f"R-hat max   : {report_simple['r_hat_max']:.4f}")
print(f"ESS bulk min: {report_simple['ess_bulk_min']:.0f}")
print(f"Divergences : {report_simple['n_divergences']}")

if report_simple['warnings']:
    print("\n⚠️  Warnings:")
    for w in report_simple['warnings']:
        print(f"  - {w}")

In [None]:
# Parameter summary
print("📈 Estimated parameters (simple model)\n")
summary_simple = mmm_simple.summary()
print(summary_simple[['mean', 'sd', 'hdi_3%', 'hdi_97%']].round(3))

In [None]:
# Predictions
y_pred_simple = mmm_simple.predict(X_test)

mae_simple = np.mean(np.abs(y_test - y_pred_simple))
rmse_simple = np.sqrt(np.mean((y_test - y_pred_simple) ** 2))

print("📊 Test-set performance (simple model)")
print(f"MAE  : {mae_simple:.4f}")
print(f"RMSE : {rmse_simple:.4f}")

## 3️⃣ Model 2: With Adstock

**Addition**: carry-over (persistent) effect of advertising

**Transformation**: `x_adstocked[t] = x[t] + α·x_adstocked[t-1]`
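
The recursion above can be sketched in plain NumPy as follows. This is a minimal illustration with a fixed α; the function name `geometric_adstock` is hypothetical and is not taken from `BayesianMMM`, where α is estimated from the data rather than fixed.

```python
import numpy as np

def geometric_adstock(x, alpha):
    """Apply x_adstocked[t] = x[t] + alpha * x_adstocked[t-1] to a 1-D series."""
    x = np.asarray(x, dtype=float)
    out = np.zeros_like(x)
    carry = 0.0  # x_adstocked[-1] is taken to be 0 (assumption)
    for t, spend in enumerate(x):
        carry = spend + alpha * carry
        out[t] = carry
    return out

# A single burst of spend at t=0 decays geometrically: 100, 50, 25, 12.5
impulse = np.array([100.0, 0.0, 0.0, 0.0])
print(geometric_adstock(impulse, alpha=0.5))
```

With α = 0 the transformation returns the input unchanged, which is why the simple model is a special case of the adstock model.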

In [None]:
# Build and fit
print("🔨 Building the adstock model...")
mmm_adstock = BayesianMMM(use_adstock=True, use_saturation=False)

print("⏳ MCMC sampling (takes ~3-4 minutes)...")
trace_adstock = mmm_adstock.fit(
    X_train, y_train,
    draws=1000,
    tune=1000,
    chains=2,
    random_seed=42
)

print("\n✅ Adstock model fitted!")

In [None]:
# Diagnostics
report_adstock = check_convergence(trace_adstock)

print("üìä Diagnostics (mod√®le adstock)\n")
print(f"Convergence : {'‚úÖ OUI' if report_adstock['converged'] else '‚ùå NON'}")
print(f"R-hat max   : {report_adstock['r_hat_max']:.4f}")
print(f"ESS bulk min: {report_adstock['ess_bulk_min']:.0f}")

In [None]:
# Alpha parameters (retention rate)
alpha_mean = trace_adstock.posterior['alpha'].mean(dim=['chain', 'draw']).values

print("📈 Retention rate (alpha) per channel\n")
for i, alpha in enumerate(alpha_mean):
    # 1/(1-α) is the total carry-over multiplier of the geometric adstock
    print(f"Channel {i+1} : α = {alpha:.3f} → effect persists ~{int(1/(1-alpha))} periods")

In [None]:
# Performance
y_pred_adstock = mmm_adstock.predict(X_test)

mae_adstock = np.mean(np.abs(y_test - y_pred_adstock))
rmse_adstock = np.sqrt(np.mean((y_test - y_pred_adstock) ** 2))

print(f"üìä Performance sur test set (mod√®le adstock)")
print(f"MAE  : {mae_adstock:.4f}")
print(f"RMSE : {rmse_adstock:.4f}")

## 4️⃣ Model 3: Full (Adstock + Saturation)

**Addition**: diminishing returns

**Transformation**: `x_saturated = x^s / (k^s + x^s)`
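
The saturation curve above is a Hill function. The sketch below evaluates it in NumPy for fixed, illustrative values of `k` (half-saturation point) and `s` (slope); the function name `hill_saturation` is hypothetical, and in the model both parameters are estimated rather than fixed.

```python
import numpy as np

def hill_saturation(x, k, s):
    """Hill curve x^s / (k^s + x^s) for non-negative x, with k > 0 and s > 0."""
    x = np.asarray(x, dtype=float)
    xs = np.power(x, s)
    return xs / (np.power(k, s) + xs)

# Illustrative values only: the response equals 0.5 exactly at x = k
spend = np.array([0.0, 1.0, 2.0, 4.0])
response = hill_saturation(spend, k=1.0, s=2.0)
print(np.round(response, 3))

# Doubling spend from 2k to 4k adds less than doubling from k to 2k
gain_low = response[2] - response[1]
gain_high = response[3] - response[2]
print(gain_low > gain_high)
```

The output is bounded between 0 and 1, so the channel coefficient sets the maximum contribution a channel can reach however much is spent.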

In [None]:
# Build and fit
print("🔨 Building the full model...")
mmm_full = BayesianMMM(use_adstock=True, use_saturation=True)

print("⏳ MCMC sampling (takes ~4-5 minutes)...")
trace_full = mmm_full.fit(
    X_train, y_train,
    draws=1000,
    tune=1000,
    chains=2,
    random_seed=42
)

print("\n✅ Full model fitted!")

In [None]:
# Diagnostics
report_full = check_convergence(trace_full)

print("üìä Diagnostics (mod√®le complet)\n")
print(f"Convergence : {'‚úÖ OUI' if report_full['converged'] else '‚ùå NON'}")
print(f"R-hat max   : {report_full['r_hat_max']:.4f}")
print(f"ESS bulk min: {report_full['ess_bulk_min']:.0f}")

In [None]:
# Saturation parameters
k_mean = trace_full.posterior['half_saturation'].mean(dim=['chain', 'draw']).values
s_mean = trace_full.posterior['slope'].mean(dim=['chain', 'draw']).values

print("📈 Saturation parameters\n")
for i in range(len(k_mean)):
    print(f"Channel {i+1} : k={k_mean[i]:.2f}, s={s_mean[i]:.2f}")

In [None]:
# Performance
y_pred_full = mmm_full.predict(X_test)

mae_full = np.mean(np.abs(y_test - y_pred_full))
rmse_full = np.sqrt(np.mean((y_test - y_pred_full) ** 2))

print(f"üìä Performance sur test set (mod√®le complet)")
print(f"MAE  : {mae_full:.4f}")
print(f"RMSE : {rmse_full:.4f}")

## 5️⃣ Model comparison

In [None]:
# Comparison table
comparison = pd.DataFrame({
    'Model': ['Simple', 'Adstock', 'Full'],
    'MAE': [mae_simple, mae_adstock, mae_full],
    'RMSE': [rmse_simple, rmse_adstock, rmse_full],
    'Convergence': [
        '✅' if report_simple['converged'] else '❌',
        '✅' if report_adstock['converged'] else '❌',
        '✅' if report_full['converged'] else '❌'
    ]
})

print("\n📊 MODEL COMPARISON\n")
print(comparison.to_string(index=False))

# Best model (lowest test-set MAE)
best_idx = comparison['MAE'].idxmin()
print(f"\n🏆 Best model: {comparison.loc[best_idx, 'Model']}")

In [None]:
# Plot predicted vs. actual values
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

models = [
    ('Simple', y_pred_simple, mae_simple),
    ('Adstock', y_pred_adstock, mae_adstock),
    ('Full', y_pred_full, mae_full)
]

for ax, (name, y_pred, mae) in zip(axes, models):
    ax.scatter(y_test, y_pred, alpha=0.6)
    # Dashed diagonal marks perfect prediction
    ax.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--')
    ax.set_xlabel('Actual sales (log)', fontsize=11)
    ax.set_ylabel('Predicted sales (log)', fontsize=11)
    ax.set_title(f'{name}\nMAE: {mae:.4f}', fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 6️⃣ Analysis of the full model

In [None]:
# Trace plots (visual convergence check)
az.plot_trace(trace_full, var_names=['beta_media', 'alpha', 'half_saturation'])
plt.tight_layout()
plt.show()

In [None]:
# Posterior distributions
az.plot_posterior(trace_full, var_names=['beta_media', 'alpha'])
plt.tight_layout()
plt.show()

In [None]:
# Contributions per channel
contributions = mmm_full.get_channel_contributions()
contributions['pct'] = (contributions['total_contribution'] /
                        contributions['total_contribution'].sum() * 100)

print("\n💰 CONTRIBUTIONS PER CHANNEL\n")
print(contributions.round(2))

# Plot
plt.figure(figsize=(10, 6))
plt.bar(contributions['channel'], contributions['pct'])
plt.xlabel('Channel', fontsize=12)
plt.ylabel('Contribution (%)', fontsize=12)
plt.title('Contribution of each channel to sales', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3, axis='y')
plt.show()

## 7️⃣ Conclusions

### ✅ Key results

**Performance:**
- Simple model: baseline
- Adstock model: captures carry-over over time
- Full model: captures adstock + saturation

**Business insights:**
1. **Adstock**: the advertising effect persists for X weeks
2. **Saturation**: diminishing returns beyond Y€ of spend
3. **Attribution**: relative contributions per channel
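
To turn an estimated retention rate α into the "X weeks" figure of the first insight, two common summaries follow from the geometric adstock recursion: the total carry-over multiplier 1/(1-α) and the half-life ln(0.5)/ln(α). The sketch below computes both; the helper names are hypothetical and the α values are illustrative, not results from this notebook.

```python
import math

def carryover_multiplier(alpha):
    """Sum of the geometric series 1 + alpha + alpha^2 + ... for 0 <= alpha < 1."""
    if not 0.0 <= alpha < 1.0:
        raise ValueError("alpha must be in [0, 1)")
    return 1.0 / (1.0 - alpha)

def half_life(alpha):
    """Periods until the carried-over effect falls to 50%, for 0 < alpha < 1."""
    if not 0.0 < alpha < 1.0:
        raise ValueError("alpha must be in (0, 1)")
    return math.log(0.5) / math.log(alpha)

for alpha in (0.3, 0.5, 0.8):  # illustrative values only
    print(f"alpha={alpha}: multiplier={carryover_multiplier(alpha):.2f}, "
          f"half-life={half_life(alpha):.2f} periods")
```

The half-life is often the easier number to communicate, since it is expressed directly in periods of the data (weeks, if the dataset is weekly).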

**Next steps:**
1. Budget optimization
2. What-if scenarios
3. Cross-validation
4. Deployment