# 🧠 Machine Learning Models - Vanilla Price Prediction

This notebook implements several forecasting models:
1. **Baseline**: moving average, SARIMA
2. **Machine Learning**: Random Forest, XGBoost
3. **Additive time-series model**: Prophet (Facebook)
4. **Evaluation and comparison**

In [7]:
# Imports
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# ML imports
from sklearn.model_selection import train_test_split, TimeSeriesSplit, cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.ensemble import RandomForestRegressor
import xgboost as xgb

# Time series
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.statespace.sarimax import SARIMAX
import pmdarima as pm

# Paths
DATA_PATH = Path('../data/processed')
MODEL_PATH = Path('../models')
OUTPUT_PATH = Path('../outputs/figures')

# Config
plt.style.use('seaborn-v0_8-whitegrid')
np.random.seed(42)

print("✅ Imports successful")

✅ Imports successful


In [8]:
# Load the data
df = pd.read_csv(DATA_PATH / 'vanilla_prices_clean.csv', parse_dates=['date'])
df = df.set_index('date')

print(f"📊 Dataset: {len(df)} observations")
print(f"📅 Period: {df.index.min().date()} → {df.index.max().date()}")
df.head()

📊 Dataset: 156 observations
📅 Period: 2011-01-01 → 2023-12-01


| date | price_usd_kg | year | month | quarter | month_sin | month_cos | harvest_season | cyclone_season | export_season | price_lag1 | ... | price_std3 | price_ma6 | price_std6 | price_ma12 | price_std12 | price_pct_change | price_pct_change_3m | price_pct_change_12m | price_vs_ma12 | volatility |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 2011-01-01 | 19.617162 | 2011 | 1 | 1 | 0.5 | 0.8660254 | 0 | 1 | 0 | 20.850118 | ... | 1.314423 | 23.075887 | 2.714827 | 26.65802 | 4.379298 | -0.059134 | -0.168747 | -0.443357 | -0.264118 | 0.117648 |
| 2011-02-01 | 18.891781 | 2011 | 2 | 1 | 0.866025 | 0.5 | 0 | 1 | 0 | 19.617162 | ... | 0.990071 | 21.760467 | 2.45861 | 25.482414 | 4.415694 | -0.036977 | -0.150716 | -0.427505 | -0.258635 | 0.112985 |
| 2011-03-01 | 18.479974 | 2011 | 3 | 1 | 1.0 | 6.123234e-17 | 0 | 1 | 0 | 18.891781 | ... | 0.575755 | 20.613816 | 2.007028 | 24.368337 | 4.349399 | -0.021798 | -0.113675 | -0.419761 | -0.24164 | 0.097363 |
| 2011-04-01 | 18.044784 | 2011 | 4 | 2 | 0.866025 | -0.5 | 0 | 0 | 0 | 18.479974 | ... | 0.423552 | 19.68803 | 1.592731 | 23.32824 | 4.233841 | -0.023549 | -0.080153 | -0.408871 | -0.226483 | 0.080898 |
| 2011-05-01 | 18.179897 | 2011 | 5 | 2 | 0.5 | -0.8660254 | 1 | 0 | 0 | 18.044784 | ... | 0.222745 | 19.01062 | 1.064922 | 22.354643 | 3.926672 | 0.007488 | -0.037682 | -0.391225 | -0.186751 | 0.056017 |
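
The `month_sin` and `month_cos` values in the preview are consistent with a cyclical encoding of the month on a 12-step circle: for January, sin(2π/12) = 0.5 and cos(2π/12) ≈ 0.866. The code that produced these columns is not part of this notebook, so the sketch below is only an assumption about how such columns could be derived; `add_cyclical_month` is a hypothetical helper name.

```python
import numpy as np
import pandas as pd


def add_cyclical_month(frame: pd.DataFrame) -> pd.DataFrame:
    """Return a copy with month_sin / month_cos derived from a DatetimeIndex."""
    out = frame.copy()
    # One full turn of the circle per year: month 12 lands next to month 1
    angle = 2 * np.pi * out.index.month.to_numpy() / 12
    out["month_sin"] = np.sin(angle)
    out["month_cos"] = np.cos(angle)
    return out


demo = pd.DataFrame(index=pd.date_range("2011-01-01", periods=3, freq="MS"))
demo = add_cyclical_month(demo)
print(demo.round(6))
```

Unlike the raw `month` integer, this encoding keeps December and January close together, which can matter for models that treat the month as a numeric distance.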


## 1. Data preparation

In [9]:
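# Features and target
target = 'price_usd_kg'

# Caution: in the CSV previewed above, the rolling and percentage-change columns
# include the current month's price (price_pct_change equals
# price_usd_kg / price_lag1 - 1), so they leak the target. Rebuild them from
# lagged prices before trusting the ML scores below.
features = [
    'year', 'month', 'quarter',
    'harvest_season', 'cyclone_season',
    'price_lag1', 'price_lag3', 'price_lag6', 'price_lag12',
    'price_ma3', 'price_ma6', 'price_ma12',
    'price_pct_change',
    'volatility'  # named `volatility` in the loaded CSV; `price_volatility` raised a KeyError
]

X = df[features]
y = df[target]

print(f"Features: {len(features)}")
print(f"Observations: {len(X)}")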

In [None]:
# Chronological split (80% train, 20% test)
# IMPORTANT: time series must NOT be shuffled before splitting!

train_size = int(len(df) * 0.8)
train_idx = df.index[:train_size]
test_idx = df.index[train_size:]

X_train, X_test = X.iloc[:train_size], X.iloc[train_size:]
y_train, y_test = y.iloc[:train_size], y.iloc[train_size:]

print(f"📊 Train: {len(X_train)} observations ({df.index[0].date()} → {df.index[train_size-1].date()})")
print(f"📊 Test: {len(X_test)} observations ({df.index[train_size].date()} → {df.index[-1].date()})")

In [None]:
# Scaling
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

print("✅ Features scaled")

## 2. Evaluation functions

In [None]:
def evaluate_model(y_true, y_pred, model_name):
    """Compute and print the evaluation metrics (RMSE, MAE, MAPE, R²)."""
    
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    # MAPE assumes no true value is zero (prices here are strictly positive)
    mape = np.mean(np.abs((y_true - y_pred) / y_true)) * 100
    
    print(f"\n📊 {model_name} - Results:")
    print(f"   RMSE: ${rmse:.2f}")
    print(f"   MAE:  ${mae:.2f}")
    print(f"   MAPE: {mape:.2f}%")
    print(f"   R²:   {r2:.4f}")
    
    return {'model': model_name, 'RMSE': rmse, 'MAE': mae, 'MAPE': mape, 'R2': r2}

def plot_predictions(y_true, y_pred, dates, model_name):
    """Plot predictions against actual values (time series and scatter)."""
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Time series plot
    ax1 = axes[0]
    ax1.plot(dates, y_true, 'b-', label='Actual', linewidth=2)
    ax1.plot(dates, y_pred, 'r--', label='Predicted', linewidth=2)
    ax1.fill_between(dates, y_true, y_pred, alpha=0.3, color='gray')
    ax1.set_title(f'{model_name} - Predicted vs Actual', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Date')
    ax1.set_ylabel('Price (USD/kg)')
    ax1.legend()
    
    # Scatter plot (the dashed diagonal marks a perfect prediction)
    ax2 = axes[1]
    ax2.scatter(y_true, y_pred, alpha=0.6, edgecolors='black')
    ax2.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'r--', linewidth=2)
    ax2.set_title(f'{model_name} - Scatter Plot', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Actual Price (USD/kg)')
    ax2.set_ylabel('Predicted Price (USD/kg)')
    
    plt.tight_layout()
    plt.savefig(OUTPUT_PATH / f'{model_name.lower().replace(" ", "_")}_predictions.png', dpi=150)
    plt.show()

# Store the results of every model
results = []
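
To make the four metrics concrete, here is a small worked example using NumPy only. The three prices are made up and chosen so that each figure can be checked by hand; they are not results from this notebook.

```python
import numpy as np

# Made-up prices, chosen so the arithmetic can be checked by hand
y_true = np.array([100.0, 200.0, 400.0])
y_pred = np.array([110.0, 190.0, 400.0])

errors = y_true - y_pred                       # [-10, 10, 0]
rmse = np.sqrt(np.mean(errors ** 2))           # sqrt(200 / 3)
mae = np.mean(np.abs(errors))                  # 20 / 3
mape = np.mean(np.abs(errors / y_true)) * 100  # (10% + 5% + 0%) / 3
# R² compares the squared errors with those of always predicting the mean
r2 = 1 - np.sum(errors ** 2) / np.sum((y_true - y_true.mean()) ** 2)

print(f"RMSE={rmse:.3f}  MAE={mae:.3f}  MAPE={mape:.2f}%  R2={r2:.4f}")
```

Two properties are worth keeping in mind when reading the comparison later on: MAPE weights errors on low prices more heavily than errors on high prices, and R² can be negative on a test set when a model does worse than predicting the test mean.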

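## 3. Baseline model - moving average

In [None]:
# Baseline: predict each month with the mean of the 3 preceding months.
# `price_ma3` in the loaded CSV may include the current month's price (the
# 3-month std in the preview does), so the average is rebuilt here from
# strictly earlier prices.
y_pred_baseline = df['price_usd_kg'].shift(1).rolling(window=3).mean().loc[test_idx].values

result = evaluate_model(y_test.values, y_pred_baseline, 'Baseline (MA3)')
results.append(result)

plot_predictions(y_test.values, y_pred_baseline, test_idx, 'Baseline MA3')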
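
A rolling mean computed directly on the price column includes the row it sits on, so a baseline built from it sees the value it is supposed to predict. The minimal sketch below, on a made-up five-month series, contrasts that with a mean taken over strictly earlier months.

```python
import pandas as pd

prices = pd.Series(
    [10.0, 20.0, 30.0, 40.0, 50.0],  # made-up monthly prices
    index=pd.date_range("2023-01-01", periods=5, freq="MS"),
)

ma3_with_current = prices.rolling(window=3).mean()     # months t-2 .. t
ma3_lagged = prices.shift(1).rolling(window=3).mean()  # months t-3 .. t-1

comparison = pd.DataFrame({
    "price": prices,
    "ma3_with_current": ma3_with_current,
    "ma3_lagged": ma3_lagged,
})
print(comparison)
```

In April, the first column averages 20, 30 and 40 (including April itself), while the lagged version averages 10, 20 and 30, which is all that would be known when forecasting April.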

## 4. SARIMA (Seasonal ARIMA)

In [None]:
# Auto-ARIMA search for the best SARIMA orders
print("🔍 Searching for the best SARIMA parameters...")

auto_arima = pm.auto_arima(
    y_train,
    seasonal=True,
    m=12,  # 12 observations per seasonal cycle (monthly data, yearly pattern)
    stepwise=True,
    suppress_warnings=True,
    error_action='ignore',
    max_p=3, max_q=3,
    max_P=2, max_Q=2,
    trace=True
)

print("\n✅ Best model:")
print(auto_arima.summary())

In [None]:
# SARIMA predictions over the whole test horizon (multi-step, starting at the end of the training set)
y_pred_sarima = np.asarray(auto_arima.predict(n_periods=len(y_test)))

result = evaluate_model(y_test.values, y_pred_sarima, 'SARIMA')
results.append(result)

plot_predictions(y_test.values, y_pred_sarima, test_idx, 'SARIMA')
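
Setting `m=12` tells the search that the pattern repeats every 12 observations. One quick way to judge what the seasonal model adds is to compare it with a seasonal-naive forecast, which simply repeats the last observed year. This is an added sanity check, not part of the original pipeline; `seasonal_naive` is a hypothetical helper and the history below is made up.

```python
import numpy as np


def seasonal_naive(history, n_periods, m=12):
    """Forecast by repeating the last m observed values."""
    history = np.asarray(history, dtype=float)
    if len(history) < m:
        raise ValueError("need at least one full season of history")
    last_season = history[-m:]
    reps = int(np.ceil(n_periods / m))  # enough copies to cover the horizon
    return np.tile(last_season, reps)[:n_periods]


history = np.arange(1.0, 25.0)  # 24 made-up monthly values: 1, 2, ..., 24
forecast = seasonal_naive(history, n_periods=14)
print(forecast)
```

With real data, the same call could be made as `seasonal_naive(y_train, len(y_test))` and scored with `evaluate_model`, which gives a floor that a tuned SARIMA model should be expected to beat.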

## 5. Random Forest

In [None]:
# Random Forest Regressor
rf_model = RandomForestRegressor(
    n_estimators=100,
    max_depth=10,
    min_samples_split=5,
    random_state=42,
    n_jobs=-1
)

rf_model.fit(X_train_scaled, y_train)
y_pred_rf = rf_model.predict(X_test_scaled)

result = evaluate_model(y_test.values, y_pred_rf, 'Random Forest')
results.append(result)

plot_predictions(y_test.values, y_pred_rf, test_idx, 'Random Forest')

In [None]:
# Feature importance
feature_importance = pd.DataFrame({
    'feature': features,
    'importance': rf_model.feature_importances_
}).sort_values('importance', ascending=True)

plt.figure(figsize=(10, 8))
plt.barh(feature_importance['feature'], feature_importance['importance'], color='steelblue')
plt.xlabel('Importance')
plt.title('Random Forest - Feature Importance', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(OUTPUT_PATH / 'rf_feature_importance.png', dpi=150)
plt.show()

## 6. XGBoost

In [None]:
# XGBoost Regressor
xgb_model = xgb.XGBRegressor(
    n_estimators=100,
    max_depth=6,
    learning_rate=0.1,
    subsample=0.8,
    colsample_bytree=0.8,
    random_state=42,
    verbosity=0
)

xgb_model.fit(
    X_train_scaled, y_train,
    eval_set=[(X_test_scaled, y_test)],
    verbose=False
)

y_pred_xgb = xgb_model.predict(X_test_scaled)

result = evaluate_model(y_test.values, y_pred_xgb, 'XGBoost')
results.append(result)

plot_predictions(y_test.values, y_pred_xgb, test_idx, 'XGBoost')

## 7. Prophet (Facebook)

In [None]:
# Prophet expects a specific input format: a `ds` date column and a `y` value column
try:
    from prophet import Prophet
    
    # Prepare the data for Prophet
    df_prophet_train = pd.DataFrame({
        'ds': train_idx,
        'y': y_train.values
    })
    
    df_prophet_test = pd.DataFrame({
        'ds': test_idx
    })
    
    # Create and train the model
    prophet_model = Prophet(
        yearly_seasonality=True,
        weekly_seasonality=False,
        daily_seasonality=False,
        changepoint_prior_scale=0.05
    )
    prophet_model.fit(df_prophet_train)
    
    # Predictions
    forecast = prophet_model.predict(df_prophet_test)
    y_pred_prophet = forecast['yhat'].values
    
    result = evaluate_model(y_test.values, y_pred_prophet, 'Prophet')
    results.append(result)
    
    plot_predictions(y_test.values, y_pred_prophet, test_idx, 'Prophet')
    
except ImportError:
    print("⚠️ Prophet is not installed. Run: pip install prophet")
    print("   Skipping Prophet model...")

## 8. Model comparison

In [None]:
# Comparison table, sorted by RMSE (lower is better)
results_df = pd.DataFrame(results)
results_df = results_df.sort_values('RMSE')

print("\n" + "="*60)
print("📊 MODEL COMPARISON")
print("="*60)
print(results_df.to_string(index=False))
print("\n🏆 Best model (lowest RMSE):", results_df.iloc[0]['model'])

In [None]:
# Comparison plots, one panel per metric
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

metrics = ['RMSE', 'MAE', 'MAPE', 'R2']
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4']

for i, (metric, ax) in enumerate(zip(metrics, axes.flat)):
    values = results_df[metric].values
    models = results_df['model'].values
    
    bars = ax.barh(models, values, color=colors[i], edgecolor='black')
    ax.set_xlabel(metric)
    ax.set_title(f'Comparison - {metric}', fontweight='bold')
    
    # Annotate each bar with its value
    for bar, val in zip(bars, values):
        if metric == 'MAPE':
            ax.text(val + 0.5, bar.get_y() + bar.get_height()/2, f'{val:.1f}%', va='center')
        elif metric == 'R2':
            ax.text(val + 0.01, bar.get_y() + bar.get_height()/2, f'{val:.3f}', va='center')
        else:
            ax.text(val + 1, bar.get_y() + bar.get_height()/2, f'${val:.1f}', va='center')

plt.tight_layout()
plt.savefig(OUTPUT_PATH / 'model_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

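## 9. Future predictions

In [None]:
# Forecast the 12 months that follow the last observation.
# SARIMA is used here whatever the ranking above, because it can extrapolate
# without needing future feature values.
best_model_name = results_df.iloc[0]['model']
print(f"\n🏆 Best model on the test set: {best_model_name}")
print("🔮 SARIMA forecast for the next 12 months:")

# `auto_arima` was fitted on the training window only, so its forecasts start
# right after the training set. Refit the selected orders on the full series
# so that the forecasts line up with `future_dates`.
final_sarima = pm.ARIMA(
    order=auto_arima.order,
    seasonal_order=auto_arima.seasonal_order,
    with_intercept=auto_arima.with_intercept,
    suppress_warnings=True
).fit(df['price_usd_kg'])

future_predictions = np.asarray(final_sarima.predict(n_periods=12))
future_dates = pd.date_range(start=df.index[-1] + pd.DateOffset(months=1), periods=12, freq='MS')

future_df = pd.DataFrame({
    'Date': future_dates,
    'Predicted Price (USD/kg)': future_predictions
})

print(future_df.to_string(index=False))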
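
The forecast dates only make sense if they begin one step after the last observation the model has seen. As a small check of the date arithmetic, the sketch below builds the 12 month-start dates that follow 2023-12-01, the last date reported when the dataset was loaded; `last_observed` and `demo_dates` are illustrative names.

```python
import pandas as pd

last_observed = pd.Timestamp("2023-12-01")  # last date reported for the dataset
demo_dates = pd.date_range(
    start=last_observed + pd.DateOffset(months=1),
    periods=12,
    freq="MS",  # month-start frequency, matching the monthly index
)
print(demo_dates[0].date(), "->", demo_dates[-1].date())
```

A model fitted only on the training window forecasts the months right after that window, so it should be refitted or updated with the held-out data before its output is paired with these dates.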

In [None]:
# Plot the future predictions
plt.figure(figsize=(14, 6))

# Historical data
plt.plot(df.index, df['price_usd_kg'], 'b-', label='Historical', linewidth=2)

# Future predictions
plt.plot(future_dates, future_predictions, 'r--', label='Forecast (next 12 months)', linewidth=2, marker='o')

# Approximate band: ±0.3 × the historical standard deviation.
# This is a visual heuristic, not a statistical confidence interval.
std = df['price_usd_kg'].std() * 0.3
plt.fill_between(future_dates, 
                 future_predictions - std, 
                 future_predictions + std, 
                 alpha=0.3, color='red', label='Approximate band (±0.3 std)')

plt.axvline(x=df.index[-1], color='gray', linestyle='--', alpha=0.5)
plt.title('Vanilla Price Forecast - Next 12 Months', fontsize=14, fontweight='bold')
plt.xlabel('Date')
plt.ylabel('Price (USD/kg)')
plt.legend(loc='upper right')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(OUTPUT_PATH / 'future_predictions.png', dpi=150, bbox_inches='tight')
plt.show()
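
The shaded band in the plot is a fixed fraction of the historical standard deviation, so it does not widen with the horizon. One common rough alternative, sketched below, scales the spread of one-step errors by the square root of the horizon. That scaling is exact for a random walk and only an approximation for other models. The residuals are made up and `horizon_band` is a hypothetical helper; pmdarima's `predict` can also return model-based intervals through `return_conf_int=True`.

```python
import numpy as np


def horizon_band(point_forecast, residuals, z=1.96):
    """Return (lower, upper) bounds whose width grows with sqrt(horizon)."""
    point_forecast = np.asarray(point_forecast, dtype=float)
    sigma = np.std(residuals, ddof=1)              # spread of one-step errors
    steps = np.arange(1, len(point_forecast) + 1)  # horizon h = 1, 2, ...
    half_width = z * sigma * np.sqrt(steps)
    return point_forecast - half_width, point_forecast + half_width


residuals = np.array([-2.0, 1.0, 0.0, 3.0, -2.0])  # made-up one-step errors
flat_forecast = np.full(4, 100.0)                   # made-up flat forecast
lower, upper = horizon_band(flat_forecast, residuals)
print(np.round(upper - lower, 2))
```

The band at four steps ahead is twice as wide as at one step, which reflects the growing uncertainty that a constant-width band hides.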

# 📦 Data Collection - Madagascar Vanilla Price Prediction

This notebook collects the data needed to predict vanilla prices.

## Data sources
1. **World Bank Pink Sheet** - Monthly commodity prices
2. **FAOSTAT** - Production and exports
3. **Additional data** - Exchange rates, climate

In [None]:
# Imports
import pandas as pd
import numpy as np
import requests
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Paths
RAW_DATA_PATH = Path('../data/raw')
PROCESSED_DATA_PATH = Path('../data/processed')

print("✅ Imports successful")

## 1. World Bank Commodity Prices (Pink Sheet)

Download the historical monthly commodity prices. The cells below check whether vanilla is among the listed commodities.

In [None]:
# World Bank file URLs (updated November 2024)
WORLD_BANK_MONTHLY_URL = "https://thedocs.worldbank.org/en/doc/18675f1d1639c7a34d463f59263ba0a2-0050012025/related/CMO-Historical-Data-Monthly.xlsx"
WORLD_BANK_ANNUAL_URL = "https://thedocs.worldbank.org/en/doc/18675f1d1639c7a34d463f59263ba0a2-0050012025/related/CMO-Historical-Data-Annual.xlsx"

def download_world_bank_data():
    """Download the World Bank Pink Sheet workbooks (monthly and annual)."""
    
    # Make sure the target folder exists before writing
    RAW_DATA_PATH.mkdir(parents=True, exist_ok=True)
    
    # Download the monthly data
    print("📥 Downloading World Bank monthly data...")
    monthly_path = RAW_DATA_PATH / 'world_bank_monthly.xlsx'
    
    response = requests.get(WORLD_BANK_MONTHLY_URL, timeout=60)
    if response.status_code == 200:
        with open(monthly_path, 'wb') as f:
            f.write(response.content)
        print(f"✅ Monthly data saved: {monthly_path}")
    else:
        print(f"❌ Download error: {response.status_code}")
        return None
    
    # Download the annual data
    print("📥 Downloading World Bank annual data...")
    annual_path = RAW_DATA_PATH / 'world_bank_annual.xlsx'
    
    response = requests.get(WORLD_BANK_ANNUAL_URL, timeout=60)
    if response.status_code == 200:
        with open(annual_path, 'wb') as f:
            f.write(response.content)
        print(f"✅ Annual data saved: {annual_path}")
    else:
        print(f"❌ Download error: {response.status_code}")
        annual_path = None  # do not return a path to a file that was never written
    
    return monthly_path, annual_path

# Download
paths = download_world_bank_data()

In [None]:
# Load and explore the World Bank data
monthly_path = RAW_DATA_PATH / 'world_bank_monthly.xlsx'

# Open the Excel file and list the available sheets
xl = pd.ExcelFile(monthly_path)
print("📋 Available sheets:")
for sheet in xl.sheet_names:
    print(f"  - {sheet}")

In [None]:
# Load the sheet that holds the monthly prices
# Note: the exact sheet name may vary; adjust if needed
candidate_sheets = ['Monthly Prices', 'Monthly', 'Prices']

# Use the first matching name, otherwise fall back to the first sheet (index 0).
# Read errors are left to propagate so that `df_prices` is never undefined below.
sheet_name = next((name for name in candidate_sheets if name in xl.sheet_names), 0)
df_prices = pd.read_excel(monthly_path, sheet_name=sheet_name)
print(f"✅ Loaded from sheet: {xl.sheet_names[0] if sheet_name == 0 else sheet_name}")

print(f"\n📊 Shape: {df_prices.shape}")
df_prices.head(10)

In [None]:
# Look for the vanilla column
print("🔍 Searching for columns containing 'vanilla':")
vanilla_cols = [col for col in df_prices.columns if 'vanilla' in str(col).lower()]
print(vanilla_cols)

print("\n📋 All columns:")
for i, col in enumerate(df_prices.columns):
    print(f"{i}: {col}")

In [None]:
# Extract the vanilla data
# Adapt to the actual structure of the file

def extract_vanilla_prices(df):
    """
    Extract vanilla prices from the World Bank DataFrame.
    Placeholder: the structure may require adjustments, so for now the
    function only prints a preview and returns the frame unchanged.
    """
    
    # Find the row where the data starts (often after a few header rows)
    # and the vanilla column
    
    # Option 1: the data is well structured, with dates as the index
    # Option 2: the first column contains the dates
    
    # Print the first rows to understand the structure
    print("Data structure:")
    print(df.iloc[:5, :5])
    
    return df

df_vanilla_raw = extract_vanilla_prices(df_prices)
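
Since the workbook layout is not known in advance, one way to fill in the placeholder above is to scan both the column labels and the first few rows for a keyword, because spreadsheet exports often carry the real header a few rows down. The sketch below is an assumption about that layout and runs on a toy frame rather than the real file; `find_columns` is a hypothetical helper, and it returns an empty list when nothing matches.

```python
import pandas as pd


def find_columns(frame: pd.DataFrame, keyword: str, header_rows: int = 5):
    """Return labels of columns whose name or first rows mention `keyword`."""
    keyword = keyword.lower()
    matches = []
    for col in frame.columns:
        # Compare as lowercase text so numbers and missing cells do not break the search
        head_cells = frame[col].head(header_rows).astype(str).str.lower()
        in_label = keyword in str(col).lower()
        in_cells = head_cells.str.contains(keyword, regex=False).any()
        if in_label or in_cells:
            matches.append(col)
    return matches


# Toy frame mimicking a sheet whose real header sits on the second row
toy = pd.DataFrame({
    "Unnamed: 0": ["World Bank", "Period", "2010M01"],
    "Unnamed: 1": [None, "Coffee, Arabica", 3.4],
    "Unnamed: 2": [None, "Cocoa", 3.5],
})
print(find_columns(toy, "coffee"))
print(find_columns(toy, "vanilla"))
```

An empty result for `"vanilla"` on the real workbook would be the signal to fall back on another source for the price series.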

## 2. Alternative data - Building a synthetic dataset

If the World Bank data does not contain vanilla prices directly, we build a dataset from known historical price levels.

In [None]:
def create_vanilla_dataset():
    """
    Build a vanilla price dataset from known historical price levels.
    Sources: FAO, industry reports, press articles
    
    Prices in USD/kg for Madagascar vanilla (beans)
    """
    
    # Approximate historical vanilla prices (USD/kg)
    # Based on FAO reports and market analyses
    
    historical_data = {
        '2010': 25,
        '2011': 30,
        '2012': 25,
        '2013': 20,
        '2014': 80,    # Start of the rise
        '2015': 120,
        '2016': 250,   # Cyclone + speculation
        '2017': 500,   # Historic peak
        '2018': 600,   # Maximum
        '2019': 450,   # Start of the decline
        '2020': 350,   # COVID impact
        '2021': 250,
        '2022': 200,
        '2023': 180,
        '2024': 150,
    }
    
    # Build a monthly series with seasonal variation
    dates = pd.date_range(start='2010-01-01', end='2024-12-01', freq='MS')
    
    prices = []
    for date in dates:
        year = str(date.year)
        base_price = historical_data.get(year, 150)
        
        # Add seasonality (higher prices in June-August, after the harvest)
        month = date.month
        if month in [6, 7, 8]:
            seasonal_factor = 1.1  # +10% post-harvest
        elif month in [1, 2, 3]:
            seasonal_factor = 0.95  # -5% at the start of the year
        else:
            seasonal_factor = 1.0
        
        # Add random noise (standard deviation = 5% of the annual base price)
        noise = np.random.normal(0, base_price * 0.05)
        
        price = base_price * seasonal_factor + noise
        prices.append(max(10, price))  # Floor the price at 10 USD
    
    df = pd.DataFrame({
        'date': dates,
        'price_usd_kg': prices
    })
    
    return df

# Build the dataset
np.random.seed(42)  # For reproducibility
df_vanilla = create_vanilla_dataset()

print(f"📊 Dataset created: {len(df_vanilla)} observations")
print(f"📅 Period: {df_vanilla['date'].min()} to {df_vanilla['date'].max()}")
df_vanilla.head(10)

In [None]:
# Descriptive statistics
print("📈 Vanilla price statistics (USD/kg):")
df_vanilla['price_usd_kg'].describe()

## 3. Adding extra features

In [None]:
def add_features(df):
    """
    Add time-based and price-derived features.

    Rolling and percentage-change features are computed on prices up to the
    previous month, so no feature contains the current month's target.
    """
    df = df.copy()
    past_price = df['price_usd_kg'].shift(1)  # price known at forecast time
    
    # Time features
    df['year'] = df['date'].dt.year
    df['month'] = df['date'].dt.month
    df['quarter'] = df['date'].dt.quarter
    
    # Harvest season indicator (May-July)
    df['harvest_season'] = df['month'].isin([5, 6, 7]).astype(int)
    
    # Cyclone season indicator (January-March)
    df['cyclone_season'] = df['month'].isin([1, 2, 3]).astype(int)
    
    # Lag features
    df['price_lag1'] = past_price
    df['price_lag3'] = df['price_usd_kg'].shift(3)
    df['price_lag6'] = df['price_usd_kg'].shift(6)
    df['price_lag12'] = df['price_usd_kg'].shift(12)
    
    # Moving averages over the preceding 3, 6 and 12 months
    df['price_ma3'] = past_price.rolling(window=3).mean()
    df['price_ma6'] = past_price.rolling(window=6).mean()
    df['price_ma12'] = past_price.rolling(window=12).mean()
    
    # Month-over-month change, as of the previous month
    df['price_pct_change'] = past_price.pct_change()
    
    # Volatility (standard deviation over the preceding 6 months)
    df['price_volatility'] = past_price.rolling(window=6).std()
    
    return df

df_vanilla_features = add_features(df_vanilla)
print(f"📊 Number of columns: {len(df_vanilla_features.columns)}")
df_vanilla_features.head(15)

## 4. Saving the data

In [None]:
# Save the final dataset
PROCESSED_DATA_PATH.mkdir(parents=True, exist_ok=True)
output_path = PROCESSED_DATA_PATH / 'vanilla_prices.csv'
df_vanilla_features.to_csv(output_path, index=False)
print(f"✅ Dataset saved: {output_path}")

# Also save a version without NaN rows (for the models)
df_clean = df_vanilla_features.dropna()
clean_path = PROCESSED_DATA_PATH / 'vanilla_prices_clean.csv'
df_clean.to_csv(clean_path, index=False)
print(f"✅ Cleaned dataset saved: {clean_path}")
print(f"   {len(df_clean)} observations (after dropping NaN rows)")
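
The number of rows lost to `dropna()` is set by the longest look-back: a 12-month lag leaves the first 12 rows incomplete. The sketch below checks this on a placeholder series covering the same 2010-01 to 2024-12 range; the prices are dummy values and only the row count matters.

```python
import numpy as np
import pandas as pd

dates = pd.date_range("2010-01-01", "2024-12-01", freq="MS")
frame = pd.DataFrame({
    "date": dates,
    "price_usd_kg": np.arange(len(dates), dtype=float) + 1,  # dummy prices
})

# The longest look-back decides how many leading rows contain NaN
frame["price_lag12"] = frame["price_usd_kg"].shift(12)

clean = frame.dropna()
print(len(frame), "->", len(clean), "rows; first complete month:",
      clean["date"].iloc[0].date())
```

Under these assumptions the cleaned file keeps 168 of the 180 monthly rows and starts in January 2011.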

## 📋 Summary

### Data assembled (synthetic series built from approximate annual prices):
- **Period**: 2010-2024 (15 years)
- **Frequency**: Monthly
- **Observations**: 180 points

### Features created:
- `price_usd_kg`: Target price (USD/kg)
- `year`, `month`, `quarter`: Time features
- `harvest_season`, `cyclone_season`: Seasonal indicators
- `price_lag*`: Lag features
- `price_ma*`: Moving averages
- `price_pct_change`, `price_volatility`: Trend and volatility indicators

### Next step:
→ Notebook `02_eda.ipynb` for the exploratory analysis

## 10. Saving the models

In [None]:
import joblib

# Save the models
MODEL_PATH.mkdir(parents=True, exist_ok=True)
joblib.dump(rf_model, MODEL_PATH / 'random_forest_model.joblib')
joblib.dump(xgb_model, MODEL_PATH / 'xgboost_model.joblib')
joblib.dump(scaler, MODEL_PATH / 'scaler.joblib')
joblib.dump(auto_arima, MODEL_PATH / 'sarima_model.joblib')

# Save the results
results_df.to_csv(MODEL_PATH / 'model_results.csv', index=False)

print("✅ Models saved in models/")
print("   - random_forest_model.joblib")
print("   - xgboost_model.joblib")
print("   - sarima_model.joblib")
print("   - scaler.joblib")
print("   - model_results.csv")

## 📋 Summary

### Models tested:
1. Baseline (3-month moving average)
2. SARIMA (auto-tuned)
3. Random Forest
4. XGBoost
5. Prophet (if installed)

### Possible next improvements:
- Hyperparameter tuning with GridSearchCV (using a `TimeSeriesSplit` as the `cv` argument)
- Ensemble methods (stacking)
- Add external features (exchange rates, weather)
- LSTM for deep learning

### Usage:
```python
import joblib
model = joblib.load('models/xgboost_model.joblib')
scaler = joblib.load('models/scaler.joblib')
prediction = model.predict(scaler.transform(new_data))
```