# In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import statsmodels.api as sm
from sklearn.linear_model import ElasticNet, Ridge, Lasso
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score, mean_squared_error
import seaborn as sns
from scipy.stats import norm
from datetime import datetime, timedelta
import random
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

# Seed NumPy's module-level (legacy) RNG once so that the simulation draws below are reproducible.
np.random.seed(seed=42)

# ----- EXTENDED DATA GENERATION FUNCTIONS WITH SOCIAL & INFLUENCER CHANNELS -----

def generate_date_range(start_date='2022-01-01', periods=104):
    """Return ``periods`` consecutive weekly timestamps starting at ``start_date``.

    Uses pandas frequency ``'W'`` (weeks anchored on Sunday), so the first
    timestamp is the first Sunday on or after ``start_date``.
    """
    return pd.date_range(start_date, periods=periods, freq='W')

def generate_baseline_sales(periods=104, baseline=100000, noise_level=0.05):
    """Simulate organic (non-media) sales.

    ``baseline`` is scaled by ``1 + noise + trend + seasonality`` where:
      * noise is Gaussian with standard deviation ``noise_level``;
      * trend rises linearly from 0 to 0.3 across the horizon;
      * seasonality is a sine wave of amplitude 0.2 completing four cycles.
    """
    random_component = np.random.normal(0, noise_level, periods)
    growth = np.linspace(0, 0.3, periods)
    cycle = 0.2 * np.sin(np.linspace(0, 8 * np.pi, periods))
    multiplier = 1 + random_component + growth + cycle
    return baseline * multiplier

def generate_tv_spend(periods=104, budget=40000, noise_level=0.3):
    """Simulate weekly TV spend.

    A flat ``budget`` is doubled over the four weeks ``[anchor-2, anchor+2)``
    around each campaign anchor week. The result is perturbed by
    multiplicative Gaussian noise (std ``noise_level``) and clipped at zero.
    """
    spend = budget * np.ones(periods)
    for anchor in (13, 26, 52, 65, 78, 91):  # campaign bursts, roughly quarterly
        spend[anchor - 2:anchor + 2] *= 2
    jitter = 1 + np.random.normal(0, noise_level, periods)
    return np.maximum(spend * jitter, 0)

def generate_digital_spend(periods=104, budget=25000, noise_level=0.2):
    """Simulate weekly digital spend.

    The budget is steady, except for two-week +50% "test" bursts starting at
    fixed weeks. Multiplicative Gaussian noise (std ``noise_level``) is then
    applied, and the result is clipped at zero.
    """
    spend = budget * np.ones(periods)
    for start in (8, 22, 36, 50, 64, 78, 92):
        spend[start:start + 2] *= 1.5
    shock = np.random.normal(0, noise_level, periods)
    return np.maximum(spend * (1 + shock), 0)

def generate_radio_spend(periods=104, budget=15000, noise_level=0.4):
    """Simulate weekly radio spend.

    The budget is modulated by a sine wave of amplitude 0.3 completing two
    cycles. Multiplicative Gaussian noise (std ``noise_level``) is then
    applied, and the result is clipped at zero.
    """
    wave = np.sin(np.linspace(0, 4 * np.pi, periods))
    seasonal_budget = budget * (1 + 0.3 * wave)
    shock = np.random.normal(0, noise_level, periods)
    return np.maximum(seasonal_budget * (1 + shock), 0)

def generate_print_spend(periods=104, budget=10000, noise_level=0.5):
    """Simulate sporadic weekly print spend.

    Each week is drawn from a gamma distribution with shape 1.5 and mean
    ``budget``. Then ``int(periods * 0.2)`` distinct, randomly chosen weeks
    are set to zero (no print activity).

    Note: ``noise_level`` is accepted for signature parity with the other
    spend generators but is not used.
    """
    shape = 1.5
    draws = np.random.gamma(shape=shape, scale=budget / shape, size=periods)
    dark_weeks = np.random.choice(periods, size=int(periods * 0.2), replace=False)
    draws[dark_weeks] = 0
    return draws

# ----- NEW SOCIAL MEDIA & INFLUENCER CHANNEL FUNCTIONS -----

def generate_Social_Media_Spend(periods=104, budget=20000, noise_level=0.25, seed=50):
    """Generate weekly social-media spend with growth, campaign bursts, noise and viral pushes.

    Parameters
    ----------
    periods : int
        Number of weeks.
    budget : float
        Base weekly budget. It is scaled by a linear growth factor running
        from 0.8 to 1.2 over the horizon.
    noise_level : float
        Std of the multiplicative Gaussian noise.
    seed : int or None
        Seed for a private ``np.random.RandomState``. The default (50) gives
        the same values as the earlier ``np.random.seed(50)`` implementation.
        The module-level NumPy RNG is no longer reseeded, so callers' random
        streams are left untouched.

    Returns
    -------
    numpy.ndarray
        Non-negative spend of length ``periods``.
    """
    # Private generator: reseeding the global RNG here would silently reset
    # every caller's random stream.
    rng = np.random.RandomState(seed)

    # Base spend grows over time (reflecting the channel's rising importance).
    base_spend = budget * np.linspace(0.8, 1.2, periods)

    # Campaign bursts every 8 weeks; each lifts weeks [period-1, period] by 50%.
    window = 2
    for period in [8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96]:
        if period < periods:
            start_idx = max(0, period - window // 2)
            end_idx = min(periods, period + window // 2)
            base_spend[start_idx:end_idx] *= 1.5

    # Multiplicative noise (more variable than traditional channels).
    noise = rng.normal(0, noise_level, periods)
    spend = base_spend * (1 + noise)

    # Occasional viral content pushes on distinct random weeks.
    # min() avoids a ValueError when there are fewer than 4 weeks.
    viral_periods = rng.choice(periods, size=min(4, periods), replace=False)
    spend[viral_periods] *= 2.5

    # Ensure no negative spend.
    return np.maximum(spend, 0)

def generate_influencer_spend(periods=104, budget=15000, noise_level=0.6, seed=51):
    """Generate weekly influencer spend: quarterly campaigns, always-on spend, noise and mega pushes.

    Parameters
    ----------
    periods : int
        Number of weeks.
    budget : float
        Reference weekly budget.
    noise_level : float
        Std of the Gaussian noise multiplier (mean 1).
    seed : int or None
        Seed for a private ``np.random.RandomState``. The default (51) gives
        the same values as the earlier ``np.random.seed(51)`` implementation.
        The module-level NumPy RNG is no longer reseeded.

    Returns
    -------
    numpy.ndarray
        Non-negative spend of length ``periods``.
    """
    # Private generator: do not clobber the caller's global random stream.
    rng = np.random.RandomState(seed)

    # Influencer marketing is sporadic, with discrete campaigns.
    base_spend = np.zeros(periods)

    # Campaigns lasting 2-4 weeks; spend steps down 0.3*budget per week (1.5, 1.2, 0.9, 0.6 x budget).
    for period in [13, 26, 39, 52, 65, 78, 91]:
        if period < periods:
            duration = rng.choice([2, 3, 4])
            for i in range(duration):
                if period + i < periods:
                    base_spend[period + i] = budget * (1.5 - 0.3 * i)

    # 20% of budget for ongoing, always-on relationships.
    base_spend += budget * 0.2 * np.ones(periods)

    # Noise multiplier for varying influencer cost/availability (can go negative; clipped below).
    spend = base_spend * rng.normal(1, noise_level, periods)

    # Occasional mega-influencer weeks at 3x budget (overwrites that week's spend).
    # min() avoids a ValueError when there are fewer than 2 weeks.
    mega_periods = rng.choice(periods, size=min(2, periods), replace=False)
    spend[mega_periods] = budget * 3

    # Ensure no negative spend.
    return np.maximum(spend, 0)

# ----- TRADITIONAL CHANNEL EFFECT CALCULATION FUNCTIONS -----

def calculate_tv_effect(spend, adstock_rate=0.7, saturation=0.7, base_effectiveness=1.5):
    """Transform TV spend into a response series: geometric adstock, then power-law saturation.

    Parameters
    ----------
    spend : array-like of float
        Weekly spend. Read positionally, so lists and pandas Series with any
        index are accepted.
    adstock_rate : float
        Fraction of the previous period's adstocked value carried forward.
    saturation : float
        Exponent applied to the adstocked series (< 1 gives diminishing returns).
    base_effectiveness : float
        Multiplier applied after saturation.

    Returns
    -------
    numpy.ndarray
        ``base_effectiveness * adstock(spend) ** saturation``; empty for empty input.
    """
    # Positional float array: label-based indexing on a Series whose index
    # does not start at 0 would raise KeyError.
    spend = np.asarray(spend, dtype=float)
    effect = np.zeros(len(spend))
    if len(spend) == 0:
        return effect

    # Adstock (lagged carry-over): effect[i] = spend[i] + rate * effect[i-1].
    effect[0] = spend[0]
    for i in range(1, len(spend)):
        effect[i] = spend[i] + adstock_rate * effect[i - 1]

    # Diminishing returns (saturation).
    return base_effectiveness * np.power(effect, saturation)

def calculate_digital_effect(spend, adstock_rate=0.3, saturation=0.8, base_effectiveness=2.0):
    """Transform digital spend into a response series: geometric adstock, then power-law saturation.

    The defaults give digital less carry-over (``adstock_rate=0.3``) and a
    higher multiplier than TV.

    Parameters
    ----------
    spend : array-like of float
        Weekly spend, read positionally (lists and Series with any index work).
    adstock_rate : float
        Fraction of the previous period's adstocked value carried forward.
    saturation : float
        Exponent applied to the adstocked series.
    base_effectiveness : float
        Multiplier applied after saturation.

    Returns
    -------
    numpy.ndarray
        ``base_effectiveness * adstock(spend) ** saturation``; empty for empty input.
    """
    # Positional float array avoids KeyError on Series with a non-0-based index.
    spend = np.asarray(spend, dtype=float)
    effect = np.zeros(len(spend))
    if len(spend) == 0:
        return effect

    effect[0] = spend[0]
    for i in range(1, len(spend)):
        effect[i] = spend[i] + adstock_rate * effect[i - 1]

    # Diminishing returns.
    return base_effectiveness * np.power(effect, saturation)

def calculate_radio_effect(spend, adstock_rate=0.5, saturation=0.6, base_effectiveness=1.2):
    """Transform radio spend into a response series: geometric adstock, then power-law saturation.

    Parameters
    ----------
    spend : array-like of float
        Weekly spend, read positionally (lists and Series with any index work).
    adstock_rate : float
        Fraction of the previous period's adstocked value carried forward.
    saturation : float
        Exponent applied to the adstocked series.
    base_effectiveness : float
        Multiplier applied after saturation.

    Returns
    -------
    numpy.ndarray
        ``base_effectiveness * adstock(spend) ** saturation``; empty for empty input.
    """
    # Positional float array avoids KeyError on Series with a non-0-based index.
    spend = np.asarray(spend, dtype=float)
    effect = np.zeros(len(spend))
    if len(spend) == 0:
        return effect

    effect[0] = spend[0]
    for i in range(1, len(spend)):
        effect[i] = spend[i] + adstock_rate * effect[i - 1]

    # Diminishing returns.
    return base_effectiveness * np.power(effect, saturation)

def calculate_print_effect(spend, adstock_rate=0.4, saturation=0.5, base_effectiveness=1.0):
    """Transform print spend into a response series: geometric adstock, then power-law saturation.

    Parameters
    ----------
    spend : array-like of float
        Weekly spend, read positionally (lists and Series with any index work).
    adstock_rate : float
        Fraction of the previous period's adstocked value carried forward.
    saturation : float
        Exponent applied to the adstocked series.
    base_effectiveness : float
        Multiplier applied after saturation.

    Returns
    -------
    numpy.ndarray
        ``base_effectiveness * adstock(spend) ** saturation``; empty for empty input.
    """
    # Positional float array avoids KeyError on Series with a non-0-based index.
    spend = np.asarray(spend, dtype=float)
    effect = np.zeros(len(spend))
    if len(spend) == 0:
        return effect

    effect[0] = spend[0]
    for i in range(1, len(spend)):
        effect[i] = spend[i] + adstock_rate * effect[i - 1]

    # Diminishing returns.
    return base_effectiveness * np.power(effect, saturation)

# ----- NEW SOCIAL MEDIA & INFLUENCER EFFECT CALCULATION FUNCTIONS -----

def calculate_social_media_effect(spend, adstock_rate=0.2, saturation=0.75, base_effectiveness=2.5,
                                 viral_coefficient=0.1, viral_decay=0.5):
    """
    Calculate the effect of social media with:
    1. Adstock (default rate 0.2) plus a viral carry-over term
    2. A network multiplier in [1, 1.2] driven by 4-week rolling spend
    3. Power-law saturation scaled by ``base_effectiveness``

    Recurrence (i >= 1):
        viral[i]  = viral_coefficient * effect[i-1] + viral_decay * viral[i-1]
        effect[i] = spend[i] + adstock_rate * effect[i-1] + viral[i]
    with effect[0] = spend[0] and viral[0] = 0.

    Parameters
    ----------
    spend : array-like of float
        Weekly spend, read positionally (lists and Series with any index work).

    Returns
    -------
    numpy.ndarray
        Transformed effect, same length as ``spend`` (empty for empty input).
    """
    # Positional float array avoids KeyError on Series with a non-0-based index.
    spend = np.asarray(spend, dtype=float)
    n = len(spend)
    effect = np.zeros(n)
    if n == 0:
        # Nothing to do; also avoids min()/max() of an empty rolling series below.
        return effect

    viral = np.zeros(n)
    effect[0] = spend[0]
    for i in range(1, n):
        # Viral component is seeded by last period's effect and decays over time.
        viral[i] = viral_coefficient * effect[i - 1] + viral_decay * viral[i - 1]
        # Paid adstock (short memory) plus the organic viral lift.
        effect[i] = spend[i] + adstock_rate * effect[i - 1] + viral[i]

    # Network effect: consistently high spend earns up to a 20% boost.
    rolling_spend = pd.Series(spend).rolling(window=4, min_periods=1).mean().values
    spread = rolling_spend.max() - rolling_spend.min()
    network_multiplier = 1 + 0.2 * (rolling_spend - rolling_spend.min()) / (spread + 1e-10)
    effect = effect * network_multiplier

    # Diminishing returns (saturation).
    return base_effectiveness * np.power(effect, saturation)

def calculate_influencer_effect(spend, adstock_rate=0.4, saturation=0.6, base_effectiveness=2.0,
                               authenticity_factor=0.2, max_authenticity_bonus=0.5):
    """
    Calculate the effect of influencer marketing with:
    1. Geometric adstock (default rate 0.4)
    2. An authenticity multiplier that peaks when spend is half of its maximum
    3. A consistency bonus (up to +20%) from 8-week rolling average spend
    4. Power-law saturation scaled by ``base_effectiveness``

    Parameters
    ----------
    spend : array-like of float
        Weekly spend, read positionally (lists and Series with any index work).

    Returns
    -------
    numpy.ndarray
        Transformed effect, same length as ``spend``. An all-zero spend series
        yields zeros (previously NaN from a 0/0 division); empty input yields
        an empty array.
    """
    # Positional float array avoids KeyError on Series with a non-0-based index.
    spend = np.asarray(spend, dtype=float)
    n = len(spend)
    effect = np.zeros(n)
    if n == 0:
        return effect

    # Basic adstock calculation.
    effect[0] = spend[0]
    for i in range(1, n):
        effect[i] = spend[i] + adstock_rate * effect[i - 1]

    # Authenticity: bonus is max at spend == 0.5 * max(spend), zero at 0 and at the max.
    spend_normalized = spend / np.max(spend + 1e-10)
    authenticity_bonus = max_authenticity_bonus * (
        1 - np.abs(spend_normalized - 0.5) * 2
    ) ** 2
    effect = effect * (1 + authenticity_factor * authenticity_bonus)

    # Consistency: sustained spend relative to the overall mean builds credibility.
    rolling_avg = pd.Series(spend).rolling(window=8, min_periods=1).mean().values
    mean_spend = np.mean(spend)
    if mean_spend == 0:
        # Guard the 0/0 that would otherwise turn the whole series into NaN.
        consistency_bonus = np.zeros(n)
    else:
        consistency_bonus = 0.2 * (1 - np.exp(-rolling_avg / mean_spend))
    effect = effect * (1 + consistency_bonus)

    # Diminishing returns (saturation).
    return base_effectiveness * np.power(effect, saturation)

# ----- NON-MEDIA EFFECT FUNCTIONS (PRICE, COMPETITOR, HOLIDAY, WEATHER) -----

def generate_price_effect(periods=104, base_price=50, price_elasticity=-1.5):
    """Simulate weekly prices and their multiplicative demand effect.

    Prices are ``base_price`` perturbed by Gaussian noise with 5% relative
    std. The demand multiplier is ``(price / mean(price)) ** price_elasticity``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(price, price_effect)``, each of length ``periods``.
    """
    price = base_price * (1 + np.random.normal(0, 0.05, periods))
    relative_price = price / np.mean(price)
    return price, relative_price ** price_elasticity

def generate_competitor_effect(periods=104, impact_factor=0.3):
    """Simulate a multiplicative competitor-activity factor centred on 1.

    Gaussian shocks (std ``impact_factor``) are smoothed with a trailing
    4-week rolling mean (shorter windows at the start), then shifted by +1.
    """
    shocks = pd.Series(np.random.normal(0, impact_factor, periods))
    smoothed = shocks.rolling(window=4, min_periods=1).mean().to_numpy()
    return smoothed + 1

def generate_holiday_effect(periods=104, start_date='2022-01-01'):
    """Generate a multiplicative holiday factor for each weekly date.

    Dates come from ``pd.date_range(start_date, periods, freq='W')``. Each
    date is matched on its ISO (year, week) pair against that ISO year's
    holiday weeks:

      * Thanksgiving / Black Friday week: 1.8
      * the following ("Cyber") week: 1.5
      * the ISO week containing Dec 25: 1.7
      * the ISO weeks containing July 1, 8, 15 and 22: 1.2
      * all other weeks: 1.0

    Matching on the ISO year rather than the calendar year stops dates such
    as 2023-01-01 (ISO week 52 of 2022) from picking up the next Christmas's
    boost. Holiday years are taken from the dates themselves instead of a
    fixed range.

    Returns
    -------
    numpy.ndarray
        Float factors, length ``periods``.
    """
    dates = pd.date_range(start=start_date, periods=periods, freq='W')
    iso_keys = [(d.isocalendar()[0], d.isocalendar()[1]) for d in dates]

    boosts = {}
    for year in sorted({iso_year for iso_year, _ in iso_keys}):
        nov_first = pd.Timestamp(year=year, month=11, day=1)
        # 4th Thursday of November: the first Thursday is (3 - dayofweek) % 7 days
        # after Nov 1 (Monday == 0); add three weeks. The modulo matters when
        # Nov 1 falls Fri-Sun, where a plain (24 - dayofweek) lands a week early.
        thanksgiving = nov_first + pd.Timedelta(days=(3 - nov_first.dayofweek) % 7 + 21)
        bf_week = thanksgiving.isocalendar()[1]
        christmas_week = pd.Timestamp(year=year, month=12, day=25).isocalendar()[1]
        summer_weeks = [pd.Timestamp(year=year, month=7, day=day).isocalendar()[1]
                        for day in (1, 8, 15, 22)]

        # setdefault keeps the first multiplier registered for a week,
        # mirroring the priority order BF > Cyber > Christmas > Summer.
        boosts.setdefault((year, bf_week), 1.8)
        boosts.setdefault((year, bf_week + 1), 1.5)
        boosts.setdefault((year, christmas_week), 1.7)
        for week in summer_weeks:
            boosts.setdefault((year, week), 1.2)

    return np.array([boosts.get(key, 1.0) for key in iso_keys], dtype=float)

def generate_weather_effect(periods=104):
    """Simulate a multiplicative weather factor centred on 1.

    The pattern is a two-cycle sine wave plus Gaussian noise (std 0.2),
    scaled by 0.15 and shifted by +1, so weather can raise or lower sales.
    """
    phase = np.linspace(0, 4 * np.pi, periods)
    pattern = np.sin(phase) + np.random.normal(0, 0.2, periods)
    return 1 + 0.15 * pattern

# ----- EXTENDED COMBINATION FUNCTION WITH SOCIAL & INFLUENCER CHANNELS -----

def combine_effects_and_generate_sales(baseline_sales, tv_effect, digital_effect,
                                       radio_effect, print_effect, social_media_effect,
                                       influencer_effect, price_effect,
                                       competitor_effect, holiday_effect, weather_effect,
                                       error_std=0.05, channel_interaction=True):
    """Turn baseline sales and per-period effect series into simulated observed sales.

    Each media effect contributes a multiplicative lift ``1 + coef * effect``:
    TV 0.00015, digital 0.00025, radio 0.00015, print 0.00008, social 0.0003,
    influencer 0.00022. Price, competitor, holiday and weather effects are
    applied as raw multipliers.

    When ``channel_interaction`` is true, three mean-normalised interaction
    lifts are also applied:
      * digital x social: +0.00004
      * influencer x social: +0.00005
      * TV x (digital + social): -0.00002 (cannibalisation)

    Finally, multiplicative Gaussian noise with std ``error_std`` is applied.

    Returns
    -------
    array
        Simulated sales, same length as ``baseline_sales``.
    """
    media_lift = ((1 + 0.00015 * tv_effect)
                  * (1 + 0.00025 * digital_effect)
                  * (1 + 0.00015 * radio_effect)
                  * (1 + 0.00008 * print_effect)
                  * (1 + 0.0003 * social_media_effect)
                  * (1 + 0.00022 * influencer_effect))
    combined_effect = media_lift * price_effect * competitor_effect * holiday_effect * weather_effect

    if channel_interaction:
        digital_x_social = digital_effect * social_media_effect
        influencer_x_social = influencer_effect * social_media_effect
        tv_x_online = tv_effect * (digital_effect + social_media_effect)
        combined_effect = (combined_effect
                           * (1 + 0.00004 * digital_x_social / np.mean(digital_x_social))
                           * (1 + 0.00005 * influencer_x_social / np.mean(influencer_x_social))
                           * (1 - 0.00002 * tv_x_online / np.mean(tv_x_online)))

    sales = baseline_sales * combined_effect
    noise = np.random.normal(0, error_std, len(sales))
    return sales * (1 + noise)

# ----- MAIN SIMULATION CODE WITH SOCIAL & INFLUENCER CHANNELS -----

def generate_mmm_data_extended(periods=104, include_interactions=True, start_date='2022-01-01'):
    """Generate a complete weekly MMM dataset, including social and influencer channels.

    Parameters
    ----------
    periods : int
        Number of weekly rows (104 = two years).
    include_interactions : bool
        Passed to ``combine_effects_and_generate_sales`` as ``channel_interaction``.
    start_date : str
        First date of the weekly index. The same value drives the holiday
        calendar, so dates and holiday weeks always line up.

    Returns
    -------
    pandas.DataFrame
        One row per week with columns:
          * ``Date`` and ``Sales``;
          * per-channel ``*_Spend``;
          * ``Price``, ``Holiday_Factor``, ``Competitor_Activity``, ``Weather_Factor``;
          * calendar columns ``Year``, ``Month``, ``Week`` (ISO week);
          * ``WeekNum``, ``Sin_Week``, ``Cos_Week``.

    Note: the generator calls below are made in a fixed order that consumes
    the NumPy random stream; reordering them changes the simulated data.
    """
    # Weekly date index.
    dates = generate_date_range(start_date=start_date, periods=periods)

    # Organic baseline.
    baseline_sales = generate_baseline_sales(periods=periods)

    # Media spend, traditional channels.
    tv_spend = generate_tv_spend(periods=periods)
    digital_spend = generate_digital_spend(periods=periods)
    radio_spend = generate_radio_spend(periods=periods)
    print_spend = generate_print_spend(periods=periods)

    # Media spend, social and influencer channels.
    social_media_spend = generate_Social_Media_Spend(periods=periods)
    influencer_spend = generate_influencer_spend(periods=periods)

    # Media effects.
    tv_effect = calculate_tv_effect(tv_spend)
    digital_effect = calculate_digital_effect(digital_spend)
    radio_effect = calculate_radio_effect(radio_spend)
    print_effect = calculate_print_effect(print_spend)
    social_media_effect = calculate_social_media_effect(social_media_spend)
    influencer_effect = calculate_influencer_effect(influencer_spend)

    # Non-media factors.
    price, price_effect = generate_price_effect(periods=periods)
    competitor_effect = generate_competitor_effect(periods=periods)
    holiday_effect = generate_holiday_effect(periods=periods, start_date=start_date)
    weather_effect = generate_weather_effect(periods=periods)

    # Combine effects into observed sales.
    sales = combine_effects_and_generate_sales(
        baseline_sales, tv_effect, digital_effect, radio_effect, print_effect,
        social_media_effect, influencer_effect, price_effect, competitor_effect,
        holiday_effect, weather_effect, channel_interaction=include_interactions
    )

    data = pd.DataFrame({
        'Date': dates,
        'Sales': sales,
        'TV_Spend': tv_spend,
        'Digital_Spend': digital_spend,
        'Radio_Spend': radio_spend,
        'Print_Spend': print_spend,
        'Social_Media_Spend': social_media_spend,
        'Influencer_Spend': influencer_spend,
        'Price': price,
        'Holiday_Factor': holiday_effect,
        'Competitor_Activity': competitor_effect,
        'Weather_Factor': weather_effect,
        'Year': [d.year for d in dates],
        'Month': [d.month for d in dates],
        'Week': [d.isocalendar()[1] for d in dates]
    })

    # Time variables: running index plus a yearly Fourier pair on the ISO week.
    data['WeekNum'] = np.arange(len(data))
    data['Sin_Week'] = np.sin(2 * np.pi * data['Week'] / 52)
    data['Cos_Week'] = np.cos(2 * np.pi * data['Week'] / 52)

    return data

# ----- FEATURE ENGINEERING FOR SOCIAL & INFLUENCER CHANNELS -----

# prepare_model_data_extended applies the per-channel adstock/saturation transforms and guards against NaN values in the output.

def _prep_adstock(values, rate):
    """Geometric adstock on positional float64 data: out[0] = x[0], out[i] = x[i] + rate * out[i-1]."""
    values = np.asarray(values, dtype=float)
    out = np.zeros(len(values))
    if len(values) == 0:
        return out
    out[0] = values[0]
    for i in range(1, len(values)):
        out[i] = values[i] + rate * out[i - 1]
    return out


def _prep_viral_adstock(values, rate, coefficient, decay):
    """Social-media adstock with a viral term; returns ``(adstocked, viral)`` float64 arrays."""
    values = np.asarray(values, dtype=float)
    adstocked = np.zeros(len(values))
    viral = np.zeros(len(values))
    if len(values) == 0:
        return adstocked, viral
    adstocked[0] = values[0]
    for i in range(1, len(values)):
        # Viral spread is seeded by last period's adstock and decays geometrically.
        viral[i] = coefficient * adstocked[i - 1] + decay * viral[i - 1]
        basic_adstock = values[i] + rate * adstocked[i - 1]
        adstocked[i] = basic_adstock + viral[i]
    return adstocked, viral


def prepare_model_data_extended(data, adstock_params=None, saturation_params=None,
                              viral_params=None, authenticity_params=None):
    """Add model-ready transformed features for all six media channels to a copy of ``data``.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``TV_Spend``, ``Digital_Spend``, ``Radio_Spend``,
        ``Print_Spend``, ``Social_Media_Spend``, ``Influencer_Spend`` and
        ``Price``. Rows are processed in positional order; any index works.
    adstock_params, saturation_params : dict, optional
        Per-channel adstock rate / saturation exponent keyed by spend column.
        Missing keys fall back to the defaults below.
    viral_params : dict, optional
        ``coefficient`` and ``decay`` of the social-media viral term.
    authenticity_params : dict, optional
        ``factor`` and ``max_bonus`` of the influencer authenticity multiplier.

    Returns
    -------
    pandas.DataFrame
        A copy of ``data`` with these columns added:
          * ``*_Adstocked`` and ``*_Transformed`` per channel;
          * social ``_Viral`` / ``_Network``;
          * influencer ``_Authentic`` / ``_Consistent``;
          * three interaction terms;
          * ``Log_*`` spends and ``Price_Squared``.
        Any NaNs are filled: 0 for ``_Transformed``/``_Adstocked`` columns,
        the column median otherwise.
    """
    processed_data = data.copy()

    default_adstock = {
        'TV_Spend': 0.7,
        'Digital_Spend': 0.3,
        'Radio_Spend': 0.5,
        'Print_Spend': 0.4,
        'Social_Media_Spend': 0.2,
        'Influencer_Spend': 0.4
    }
    default_saturation = {
        'TV_Spend': 0.7,
        'Digital_Spend': 0.8,
        'Radio_Spend': 0.6,
        'Print_Spend': 0.5,
        'Social_Media_Spend': 0.75,
        'Influencer_Spend': 0.6
    }
    default_viral = {'coefficient': 0.1, 'decay': 0.5}
    default_authenticity = {'factor': 0.2, 'max_bonus': 0.5}

    # Merge caller overrides over the defaults so a partial dict does not raise KeyError.
    adstock_params = {**default_adstock, **(adstock_params or {})}
    saturation_params = {**default_saturation, **(saturation_params or {})}
    viral_params = {**default_viral, **(viral_params or {})}
    authenticity_params = {**default_authenticity, **(authenticity_params or {})}

    spend_channels = ['TV_Spend', 'Digital_Spend', 'Radio_Spend', 'Print_Spend',
                      'Social_Media_Spend', 'Influencer_Spend']

    # Convert all spend columns to float to avoid dtype issues.
    for col in spend_channels:
        if col in processed_data.columns:
            processed_data[col] = processed_data[col].astype(float)

    # Traditional channels: standard adstock, then saturation.
    # Positional numpy helpers replace .loc[i] loops that required a 0-based RangeIndex.
    for channel in ['TV_Spend', 'Digital_Spend', 'Radio_Spend', 'Print_Spend']:
        adstock_col = f"{channel}_Adstocked"
        processed_data[adstock_col] = _prep_adstock(processed_data[channel], adstock_params[channel])
        processed_data[f"{channel}_Transformed"] = np.power(
            processed_data[adstock_col],
            saturation_params[channel]
        )

    # Social media: adstock with a viral component.
    channel = 'Social_Media_Spend'
    adstock_col = f"{channel}_Adstocked"
    viral_col = f"{channel}_Viral"
    adstocked, viral = _prep_viral_adstock(
        processed_data[channel], adstock_params[channel],
        viral_params['coefficient'], viral_params['decay']
    )
    processed_data[adstock_col] = adstocked
    processed_data[viral_col] = viral

    # Network effect from 4-week rolling spend, guarded against a zero range.
    rolling_spend = processed_data[channel].rolling(window=4, min_periods=1).mean()
    spend_min = rolling_spend.min()
    spend_max = rolling_spend.max()
    if abs(spend_max - spend_min) < 1e-6:
        network_multiplier = np.ones(len(processed_data))
    else:
        # Deliberately large epsilon (1.0) in the denominator for numerical stability.
        epsilon = 1.0
        network_multiplier = 1 + (0.2 * (rolling_spend - spend_min) /
                                  (spend_max - spend_min + epsilon))
    processed_data[f"{channel}_Network"] = processed_data[adstock_col] * network_multiplier

    processed_data[f"{channel}_Transformed"] = np.power(
        processed_data[f"{channel}_Network"],
        saturation_params[channel]
    )

    # Influencer: adstock, authenticity multiplier, consistency bonus, saturation.
    channel = 'Influencer_Spend'
    adstock_col = f"{channel}_Adstocked"
    processed_data[adstock_col] = _prep_adstock(processed_data[channel], adstock_params[channel])

    channel_max = processed_data[channel].max()
    if channel_max == 0:
        spend_normalized = np.zeros(len(processed_data))
    else:
        spend_normalized = processed_data[channel] / (channel_max + 1e-6)

    # Authenticity bonus peaks at half of the maximum spend.
    authenticity_bonus = authenticity_params['max_bonus'] * (
        1 - np.abs(spend_normalized - 0.5) * 2
    ) ** 2
    authenticity_multiplier = 1 + authenticity_params['factor'] * authenticity_bonus
    processed_data[f"{channel}_Authentic"] = processed_data[adstock_col] * authenticity_multiplier

    rolling_avg = processed_data[channel].rolling(window=8, min_periods=1).mean()
    channel_mean = processed_data[channel].mean()
    if channel_mean == 0:
        # Avoid 0/0 -> NaN when the channel never spends.
        consistency_bonus = np.zeros(len(processed_data))
    else:
        consistency_bonus = 0.2 * (1 - np.exp(-rolling_avg / (channel_mean + 1e-6)))
    processed_data[f"{channel}_Consistent"] = processed_data[f"{channel}_Authentic"] * (1 + consistency_bonus)

    processed_data[f"{channel}_Transformed"] = np.power(
        processed_data[f"{channel}_Consistent"],
        saturation_params[channel]
    )

    # Interaction terms, scaled down by 1e6 to keep magnitudes manageable.
    processed_data['Digital_Social_Interaction'] = (
        processed_data['Digital_Spend_Transformed'] *
        processed_data['Social_Media_Spend_Transformed']
    ) / 1e6

    processed_data['Social_Influencer_Interaction'] = (
        processed_data['Social_Media_Spend_Transformed'] *
        processed_data['Influencer_Spend_Transformed']
    ) / 1e6

    processed_data['TV_Digital_Interaction'] = (
        processed_data['TV_Spend_Transformed'] *
        processed_data['Digital_Spend_Transformed']
    ) / 1e6

    # log1p handles zero spend without -inf.
    for channel in spend_channels:
        processed_data[f'Log_{channel}'] = np.log1p(processed_data[channel])

    # Squared price term for non-linear price effects.
    processed_data['Price_Squared'] = processed_data['Price'] ** 2

    # Final NaN safeguard.
    if processed_data.isna().any().any():
        nan_columns = processed_data.columns[processed_data.isna().any()].tolist()
        print(f"Warning: NaN values found in columns: {nan_columns}")

        numeric_cols = processed_data.select_dtypes(include=['number']).columns
        for col in numeric_cols:
            if processed_data[col].isna().any():
                if col.endswith('_Transformed') or col.endswith('_Adstocked'):
                    # Derived variables: 0 is a safe neutral value.
                    processed_data[col] = processed_data[col].fillna(0)
                else:
                    processed_data[col] = processed_data[col].fillna(processed_data[col].median())

    return processed_data

# train_model_extended fills any NaN values in the train/test splits before fitting.
def train_model_extended(data, model_type='ridge', test_size=26, feature_selection='transformed', **kwargs):
    """Fit a marketing-mix regression and evaluate it on a trailing hold-out window.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``Sales`` plus the columns returned by
        ``prepare_features_extended(data, feature_selection)`` (presumably a
        list of column names; confirm against that helper).
    model_type : {'ols', 'ridge', 'lasso', 'elasticnet'}
        Regression type.
    test_size : int
        Number of trailing rows held out. The split is chronological, with no
        shuffling.
    feature_selection : str
        Passed through to ``prepare_features_extended``.
    **kwargs
        ``alpha``: defaults are 1.0 for ridge, 0.1 for lasso and elasticnet.
        ``l1_ratio``: elasticnet only, default 0.5.

    Returns
    -------
    dict
        Fitted model, predictions (numpy arrays), R^2/RMSE on train and test,
        feature list, fitted scaler and the train/test frames.

    Raises
    ------
    ValueError
        If ``test_size`` does not leave at least one training and one test
        row, or ``model_type`` is unknown.
    """
    # iloc[:-0] is an empty slice, so test_size=0 would silently train on nothing.
    if not 0 < test_size < len(data):
        raise ValueError(
            f"test_size must be between 1 and len(data) - 1 (got {test_size} for {len(data)} rows)"
        )

    features = prepare_features_extended(data, feature_selection)

    # Chronological split: the last `test_size` rows are the hold-out.
    train_data = data.iloc[:-test_size].copy()
    test_data = data.iloc[-test_size:].copy()

    X_train = train_data[features]
    y_train = train_data['Sales']
    X_test = test_data[features]
    y_test = test_data['Sales']

    if X_train.isna().any().any() or y_train.isna().any():
        print("Warning: NaN values found in training data. Filling them...")
        X_train = X_train.fillna(X_train.median())
        y_train = y_train.fillna(y_train.median())

    if X_test.isna().any().any() or y_test.isna().any():
        print("Warning: NaN values found in test data. Filling them...")
        # Use training medians so no test information leaks into the fill.
        X_test = X_test.fillna(X_train.median())
        y_test = y_test.fillna(y_train.median())

    # Standardize features using training statistics only.
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    if model_type == 'ols':
        # DataFrames make statsmodels label coefficients with the feature names;
        # a bare ndarray yields 'x1', 'x2', ...
        # has_constant='add' always appends the intercept. The default 'skip'
        # silently omits it when every column looks constant (e.g. a single-row
        # or constant-feature test set), which breaks predict with a shape mismatch.
        X_train_sm = sm.add_constant(
            pd.DataFrame(X_train_scaled, columns=list(features), index=X_train.index),
            has_constant='add'
        )
        X_test_sm = sm.add_constant(
            pd.DataFrame(X_test_scaled, columns=list(features), index=X_test.index),
            has_constant='add'
        )

        model = sm.OLS(y_train, X_train_sm).fit()

        # Keep predictions as plain arrays, matching the sklearn branches.
        train_pred = np.asarray(model.predict(X_train_sm))
        test_pred = np.asarray(model.predict(X_test_sm))

    elif model_type == 'ridge':
        alpha = kwargs.get('alpha', 1.0)
        model = Ridge(alpha=alpha)
        model.fit(X_train_scaled, y_train)

        train_pred = model.predict(X_train_scaled)
        test_pred = model.predict(X_test_scaled)

    elif model_type == 'lasso':
        alpha = kwargs.get('alpha', 0.1)
        model = Lasso(alpha=alpha)
        model.fit(X_train_scaled, y_train)

        train_pred = model.predict(X_train_scaled)
        test_pred = model.predict(X_test_scaled)

    elif model_type == 'elasticnet':
        alpha = kwargs.get('alpha', 0.1)
        l1_ratio = kwargs.get('l1_ratio', 0.5)
        model = ElasticNet(alpha=alpha, l1_ratio=l1_ratio)
        model.fit(X_train_scaled, y_train)

        train_pred = model.predict(X_train_scaled)
        test_pred = model.predict(X_test_scaled)

    else:
        raise ValueError(f"Unknown model type: {model_type}")

    train_r2 = r2_score(y_train, train_pred)
    test_r2 = r2_score(y_test, test_pred)
    train_rmse = np.sqrt(mean_squared_error(y_train, train_pred))
    test_rmse = np.sqrt(mean_squared_error(y_test, test_pred))

    results = {
        'model': model,
        'model_type': model_type,
        'train_pred': train_pred,
        'test_pred': test_pred,
        'train_r2': train_r2,
        'test_r2': test_r2,
        'train_rmse': train_rmse,
        'test_rmse': test_rmse,
        'features': features,
        'feature_selection': feature_selection,
        'scaler': scaler,
        'train_data': train_data,
        'test_data': test_data
    }

    return results

# ----- EXTENDED ANALYSIS FUNCTIONS FOR SOCIAL & INFLUENCER -----

def calculate_roi_extended(model_results, data):
    """Calculate an ROI index for each marketing channel, including social & influencer.

    For every channel, the coefficients of its non-interaction model features
    are summed once each and converted with
    ``coef * mean(Sales) * n_rows / channel_total_spend``; the result is
    floored at 0.  When ``Digital_Social_Interaction`` or
    ``Social_Influencer_Interaction`` is among the features, its effect is
    split between the two participating channels in proportion to their spend
    and added to their ROI.

    NOTE(review): the models are fitted on standardized features, so the
    coefficients are per standard deviation of each feature; treat the result
    as a relative ranking index rather than a literal return per dollar.

    Parameters
    ----------
    model_results : dict
        Training output; reads ``'model'``, ``'model_type'`` and ``'features'``.
    data : pandas.DataFrame
        Must contain ``Sales`` and the six ``*_Spend`` columns.

    Returns
    -------
    dict
        Channel name -> ROI value (0 for channels without features or spend).

    Raises
    ------
    ValueError
        If OLS parameters are unnamed and their count cannot be aligned with
        ``features`` (expected ``len(features)`` or ``len(features) + 1``).
    """
    model = model_results['model']
    model_type = model_results['model_type']
    features = model_results['features']

    # --- Coefficients keyed by feature name ---------------------------------
    if model_type == 'ols':
        params = model.params
        raw_names = getattr(params, 'index', None)
        param_names = [] if raw_names is None or callable(raw_names) else list(raw_names)
        param_values = np.asarray(params, dtype=float)
        if param_names and all(f in param_names for f in features):
            coef_dict = {n: v for n, v in zip(param_names, param_values) if n != 'const'}
        else:
            # OLS is trained on sm.add_constant(<scaled matrix>); when that
            # matrix is a plain array (presumably, as StandardScaler returns
            # one) the parameters are unnamed ('const', 'x1', ... or a bare
            # ndarray).  Align positionally: optional leading intercept, then
            # one coefficient per feature in ``features`` order.
            offset = len(param_values) - len(features)
            if offset not in (0, 1):
                raise ValueError(
                    f"Cannot align {len(param_values)} OLS parameters with "
                    f"{len(features)} features"
                )
            coef_dict = dict(zip(features, param_values[offset:]))
    else:
        coef_dict = dict(zip(features, np.ravel(model.coef_)))

    # --- Channel totals -------------------------------------------------------
    spend_columns = {
        'TV': 'TV_Spend',
        'Digital': 'Digital_Spend',
        'Radio': 'Radio_Spend',
        'Print': 'Print_Spend',
        'Social_Media': 'Social_Media_Spend',
        'Influencer': 'Influencer_Spend'
    }
    channel_spend = {ch: data[col].sum() for ch, col in spend_columns.items()}
    avg_sales = data['Sales'].mean()
    n_rows = len(data)

    # Feature-name prefixes that identify each channel's model features
    channel_prefixes = {
        'TV': ['TV_Spend', 'TV_Spend_Adstocked', 'TV_Spend_Transformed', 'Log_TV_Spend'],
        'Digital': ['Digital_Spend', 'Digital_Spend_Adstocked', 'Digital_Spend_Transformed', 'Log_Digital_Spend'],
        'Radio': ['Radio_Spend', 'Radio_Spend_Adstocked', 'Radio_Spend_Transformed', 'Log_Radio_Spend'],
        'Print': ['Print_Spend', 'Print_Spend_Adstocked', 'Print_Spend_Transformed', 'Log_Print_Spend'],
        'Social_Media': ['Social_Media_Spend', 'Social_Media_Spend_Adstocked', 'Social_Media_Spend_Network',
                         'Social_Media_Spend_Transformed', 'Social_Media_Spend_Viral', 'Log_Social_Media_Spend'],
        'Influencer': ['Influencer_Spend', 'Influencer_Spend_Adstocked', 'Influencer_Spend_Authentic',
                       'Influencer_Spend_Consistent', 'Influencer_Spend_Transformed', 'Log_Influencer_Spend']
    }

    roi_results = {}
    for channel, prefixes in channel_prefixes.items():
        # Match each feature at most once.  The prefixes overlap ('TV_Spend'
        # is a prefix of 'TV_Spend_Adstocked'), so testing prefix-by-prefix
        # would add the same coefficient several times.
        prefix_tuple = tuple(prefixes)
        channel_features = [
            f for f in features
            if 'Interaction' not in f and f.startswith(prefix_tuple)
        ]
        if not channel_features:
            roi_results[channel] = 0
            continue

        channel_coef = sum(coef_dict.get(f, 0) for f in channel_features)
        channel_total_spend = channel_spend[channel]
        if channel_total_spend > 0:
            roi = (channel_coef * avg_sales * n_rows) / channel_total_spend
            roi_results[channel] = max(0, roi)  # per-channel part floored at 0
        else:
            roi_results[channel] = 0

    # --- Interaction effects --------------------------------------------------
    interaction_pairs = (
        ('Digital_Social_Interaction', 'Digital', 'Social_Media'),
        ('Social_Influencer_Interaction', 'Social_Media', 'Influencer'),
    )
    for interaction_feature, first, second in interaction_pairs:
        if interaction_feature not in features:
            continue
        interaction_coef = coef_dict.get(interaction_feature, 0)
        pair_spend = {first: channel_spend[first], second: channel_spend[second]}
        pair_total = pair_spend[first] + pair_spend[second]
        if pair_total <= 0:
            continue
        interaction_effect = interaction_coef * avg_sales * n_rows
        # Allocate the effect by spend share, then express it per unit of the
        # channel's own spend.
        for ch, spend in pair_spend.items():
            if spend > 0:
                roi_results[ch] += (interaction_effect * (spend / pair_total)) / spend

    return roi_results

def decompose_sales_extended(model_results, data):
    """Decompose modelled sales into additive contributions per driver group.

    Each feature's contribution is its standardized value
    (``scaler.transform``) multiplied by its fitted coefficient; contributions
    are summed within groups (media channels, price, seasonality, external
    factors, interactions) and the intercept is reported as ``'Baseline'``.
    Interaction features are assigned only to the ``'Interactions'`` group, so
    no feature is counted in two groups.  Features that match no group are
    not represented.

    Parameters
    ----------
    model_results : dict
        Training output; reads ``'model'``, ``'model_type'``, ``'features'``
        and ``'scaler'``.
    data : pandas.DataFrame
        Must contain every column listed in ``model_results['features']``.

    Returns
    -------
    dict
        Group name -> ndarray of length ``len(data)``; ``'Baseline'`` is last.

    Raises
    ------
    ValueError
        If OLS parameters are unnamed and their count cannot be aligned with
        ``features`` (expected ``len(features)`` or ``len(features) + 1``).
    """
    model = model_results['model']
    model_type = model_results['model_type']
    features = model_results['features']
    scaler = model_results['scaler']

    def _is_interaction(name):
        return 'Interaction' in name

    # Group features by category (interaction terms only in 'Interactions')
    feature_groups = {
        'TV': [f for f in features if 'TV_Spend' in f and not _is_interaction(f)],
        'Digital': [f for f in features if 'Digital_Spend' in f and not _is_interaction(f)],
        'Radio': [f for f in features if 'Radio_Spend' in f and not _is_interaction(f)],
        'Print': [f for f in features if 'Print_Spend' in f and not _is_interaction(f)],
        'Social_Media': [f for f in features if 'Social_Media_Spend' in f and not _is_interaction(f)],
        'Influencer': [f for f in features if 'Influencer_Spend' in f and not _is_interaction(f)],
        'Price': [f for f in features if 'Price' in f and not _is_interaction(f)],
        'Seasonality': ['Sin_Week', 'Cos_Week'],
        'External Factors': ['Holiday_Factor', 'Competitor_Activity', 'Weather_Factor'],
        'Interactions': [f for f in features if _is_interaction(f)]
    }

    X_scaled = np.asarray(scaler.transform(data[features]), dtype=float)
    n_rows = len(data)

    # One coefficient per feature (in ``features`` order) plus the intercept.
    if model_type == 'ols':
        params = model.params
        raw_names = getattr(params, 'index', None)
        param_names = [] if raw_names is None or callable(raw_names) else list(raw_names)
        param_values = np.asarray(params, dtype=float)
        if param_names and all(f in param_names for f in features):
            by_name = dict(zip(param_names, param_values))
            coefs = np.array([by_name[f] for f in features], dtype=float)
            intercept = by_name.get('const', 0.0)
        else:
            # Unnamed parameters (OLS fitted on sm.add_constant of a plain
            # array): optional leading intercept, then features in order.
            offset = len(param_values) - len(features)
            if offset not in (0, 1):
                raise ValueError(
                    f"Cannot align {len(param_values)} OLS parameters with "
                    f"{len(features)} features"
                )
            coefs = param_values[offset:]
            intercept = param_values[0] if offset == 1 else 0.0
    else:
        coefs = np.ravel(np.asarray(model.coef_, dtype=float))
        intercept = getattr(model, 'intercept_', 0.0)

    # Column position of each feature (first occurrence wins, like list.index).
    # Working on X_scaled directly avoids sm.add_constant, which silently skips
    # the constant column when some scaled feature happens to be constant.
    feature_index = {}
    for idx, name in enumerate(features):
        feature_index.setdefault(name, idx)

    contrib_dict = {}
    for group, group_features in feature_groups.items():
        group_contrib = np.zeros(n_rows)
        for feature in group_features:
            idx = feature_index.get(feature)
            if idx is not None:
                group_contrib += X_scaled[:, idx] * coefs[idx]
        contrib_dict[group] = group_contrib

    contrib_dict['Baseline'] = np.ones(n_rows) * intercept

    return contrib_dict

# simulate_budget_allocation_extended maps channel names to DataFrame column names
# explicitly (column_map) to avoid KeyErrors from mismatched channel/column names.

def simulate_budget_allocation_extended(model_results, data, budget_total, n_simulations=100, constraints=None):
    """Randomly search budget splits across the six media channels and score them.

    For each of ``n_simulations`` draws, a share vector over the channels is
    sampled (rejection-sampled against optional min/max share constraints,
    with a deterministic fallback). Each channel's spend column is then
    rescaled so its total equals the channel's new allocation, derived
    features are rebuilt according to ``model_results['feature_selection']``,
    and the fitted model predicts sales for every row.

    Parameters
    ----------
    model_results : dict
        Training output; reads ``'model'``, ``'model_type'``, ``'features'``,
        ``'feature_selection'`` and ``'scaler'``.
    data : pandas.DataFrame
        Must contain the ``*_Spend`` columns named in ``column_map`` and, after
        re-derivation, every column in ``model_results['features']``.
    budget_total : float
        Total spend to distribute; percentages are computed relative to it.
    n_simulations : int
        Number of random allocations to evaluate.
    constraints : dict, optional
        Keys ``'min_<channel>'`` / ``'max_<channel>'`` using the lower-cased
        channel name (e.g. ``'min_social_media'``), given as fractions of the
        budget. Missing keys default to 0.0 (min) and 1.0 (max).

    Returns
    -------
    list of dict
        One dict per simulation with ``'allocation'`` (absolute spend per
        channel), ``'total_sales'`` (sum of predicted sales over all rows) and
        ``'allocation_percentages'``. The list is sorted by ``'total_sales'``,
        highest first.
    """
    model = model_results['model']
    model_type = model_results['model_type']
    features = model_results['features']
    feature_selection = model_results['feature_selection']
    scaler = model_results['scaler']

    # Original media spend
    original_spend = {
        'TV': data['TV_Spend'].sum(),
        'Digital': data['Digital_Spend'].sum(),
        'Radio': data['Radio_Spend'].sum(),
        'Print': data['Print_Spend'].sum(),
        'Social_Media': data['Social_Media_Spend'].sum(),
        'Influencer': data['Influencer_Spend'].sum()
    }

    # NOTE(review): total_original is computed but not used below.
    total_original = sum(original_spend.values())

    # Define channels
    channels = ['TV', 'Digital', 'Radio', 'Print', 'Social_Media', 'Influencer']

    # Define corresponding column names in the DataFrame
    # This is the key fix - explicitly mapping channel names to column names
    column_map = {
    'TV': 'TV_Spend',
    'Digital': 'Digital_Spend',
    'Radio': 'Radio_Spend',
    'Print': 'Print_Spend',
    'Social_Media': 'Social_Media_Spend',
    'Influencer': 'Influencer_Spend'
}


    # Apply constraints if provided
    if constraints is None:
        constraints = {}

    # Share bounds per channel, looked up with the lower-cased channel name
    min_pct = {channel: constraints.get(f'min_{channel.lower()}', 0.0) for channel in channels}
    max_pct = {channel: constraints.get(f'max_{channel.lower()}', 1.0) for channel in channels}

    simulation_results = []

    # Create random allocations
    for _ in range(n_simulations):
        retry_count = 0
        max_retries = 50

        # Rejection sampling: draw uniform weights normalised to sum to 1 and
        # accept the first draw whose every share lies within [min, max].
        # Each attempt consumes len(channels) values from NumPy's global RNG.
        while retry_count < max_retries:
            # Generate random weights
            weights = np.random.random(len(channels))
            weights = weights / weights.sum()

            # Check if weights satisfy constraints
            valid_allocation = True
            for i, channel in enumerate(channels):
                if weights[i] < min_pct[channel] or weights[i] > max_pct[channel]:
                    valid_allocation = False
                    break

            if valid_allocation:
                break

            retry_count += 1

        # If we couldn't find a valid allocation after max retries, use a simple approach
        # Fallback: each channel gets its minimum share plus part of the
        # leftover share, split in proportion to its (max - min) headroom.
        # NOTE(review): if the minimums sum to more than 1 (remaining <= 0),
        # the resulting weights will not sum to 1.
        if not valid_allocation:
            # Start with minimum allocations
            weights = np.array([min_pct[channel] for channel in channels])
            # Distribute remaining budget proportionally
            remaining = 1.0 - sum(weights)
            if remaining > 0:
                # Calculate the range between min and max for each channel
                ranges = np.array([max_pct[channel] - min_pct[channel] for channel in channels])
                # Normalize the ranges
                if sum(ranges) > 0:
                    normalized_ranges = ranges / sum(ranges)
                    # Distribute remaining budget
                    weights += normalized_ranges * remaining

        # Calculate new budget allocation
        allocation = {}
        for i, channel in enumerate(channels):
            allocation[channel] = budget_total * weights[i]

        # Use this allocation to predict sales
        new_data = data.copy()

        # Apply allocation - Using the column map to get the correct column names
        # Each channel's per-row spend is multiplied by one factor, so the
        # row-to-row pattern is kept and only the total changes. A channel
        # with zero original spend gets factor 0 and stays at zero.
        for channel in channels:
            channel_col = column_map[channel]  # Use the column map here
            channel_scaling = allocation[channel] / original_spend[channel] if original_spend[channel] > 0 else 0
            new_data[channel_col] = data[channel_col] * channel_scaling

        # Recalculate derived variables based on feature selection
        if feature_selection == 'transformed':
            # Process the data with our extended transformation function
            # NOTE(review): these viral/authenticity parameters are hard-coded
            # here, while run_mmm_analysis prepares the training data with
            # prepare_model_data_extended's defaults; confirm they match.
            # Also, run_mmm_analysis passes already-processed data into this
            # function, so the transformation is applied to processed data.
            viral_params = {
                'coefficient': 0.1,
                'decay': 0.5
            }
            authenticity_params = {
                'factor': 0.2,
                'max_bonus': 0.5
            }

            new_data = prepare_model_data_extended(
                new_data,
                viral_params=viral_params,
                authenticity_params=authenticity_params
            )
        elif feature_selection == 'log':
            # Just recalculate log variables
            for channel in channels:
                channel_col = column_map[channel]  # Use the column map here
                new_data[f'Log_{channel_col}'] = np.log1p(new_data[channel_col])

        # Prepare features for prediction
        X_new = new_data[features]
        X_new_scaled = scaler.transform(X_new)

        # Predict sales
        # NOTE(review): sm.add_constant defaults to has_constant='skip'; if a
        # scaled feature column is constant (e.g. a channel scaled to zero)
        # no intercept column is added and the OLS prediction will not line
        # up with the fitted parameters.
        if model_type == 'ols':
            X_new_with_const = sm.add_constant(X_new_scaled)
            predicted_sales = model.predict(X_new_with_const)
        else:
            predicted_sales = model.predict(X_new_scaled)

        # Score = predicted sales summed over every row of the data
        total_sales = np.sum(predicted_sales)

        # Store results
        result = {
            'allocation': allocation,
            'total_sales': total_sales,
            'allocation_percentages': {k: v/budget_total*100 for k, v in allocation.items()}
        }

        simulation_results.append(result)

    # Sort results by total sales
    simulation_results.sort(key=lambda x: x['total_sales'], reverse=True)

    return simulation_results

# ----- EXTENDED VISUALIZATION FUNCTIONS FOR SOCIAL & INFLUENCER -----

def plot_media_spend_patterns_extended(data):
    """Build a line chart of weekly spend per media channel.

    Traditional channels (TV, Digital, Radio, Print) are drawn as solid lines
    and the newer channels (Social Media, Influencer) as dotted lines.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``Date`` and the six ``*_Spend`` columns.

    Returns
    -------
    plotly.graph_objects.Figure
        One trace per channel, legend laid out horizontally above the plot.
    """
    # (spend column, legend label, dash style) -- None keeps the default solid line
    series_specs = (
        ('TV_Spend', 'TV', None),
        ('Digital_Spend', 'Digital', None),
        ('Radio_Spend', 'Radio', None),
        ('Print_Spend', 'Print', None),
        ('Social_Media_Spend', 'Social Media', 'dot'),
        ('Influencer_Spend', 'Influencer', 'dot'),
    )

    figure = go.Figure()
    for column, label, dash in series_specs:
        line_style = dict(width=2) if dash is None else dict(width=2, dash=dash)
        figure.add_trace(go.Scatter(
            x=data['Date'],
            y=data[column],
            name=label,
            line=line_style
        ))

    horizontal_legend = dict(
        orientation="h",
        yanchor="bottom",
        y=1.02,
        xanchor="right",
        x=1
    )
    figure.update_layout(
        title='Media Spend Patterns Over Time - All Channels',
        xaxis_title='Date',
        yaxis_title='Spend ($)',
        legend=horizontal_legend,
        template="plotly_white"
    )

    return figure

def plot_roi_comparison_extended(roi_results, title='Channel ROI Comparison'):
    """Create a bar chart of ROI per channel with a dashed break-even line at 1.0.

    Bars are colour-coded by ROI band (> 2.0, > 1.5, > 1.0, > 0.5, otherwise).
    The y-axis always includes 0 and the break-even line, and it extends below
    zero when any ROI is negative so those bars remain visible.

    Parameters
    ----------
    roi_results : dict
        Channel name -> ROI value.
    title : str
        Figure title.

    Returns
    -------
    plotly.graph_objects.Figure
        The ROI figure; an empty, titled figure when ``roi_results`` is empty.
    """
    fig = go.Figure()
    base_layout = dict(
        title=title,
        xaxis_title='Channel',
        yaxis_title='ROI (Return on Ad Spend)',
        template="plotly_white",
    )

    channels = list(roi_results.keys())
    if not channels:
        # Nothing to plot; avoid channels[-1] / max([]) failures below.
        fig.update_layout(**base_layout)
        return fig

    roi_values = [roi_results[ch] for ch in channels]

    # Colour bands checked from the highest threshold down
    bands = (
        (2.0, '#1a9850'),  # Dark green for very high ROI
        (1.5, '#66bd63'),  # Green for high ROI
        (1.0, '#a6d96a'),  # Light green for good ROI
        (0.5, '#fee08b'),  # Yellow for moderate ROI
    )

    def _band_colour(roi):
        for threshold, colour in bands:
            if roi > threshold:
                return colour
        return '#fdae61'  # Orange for low ROI

    fig.add_trace(go.Bar(
        x=channels,
        y=roi_values,
        marker_color=[_band_colour(roi) for roi in roi_values],
        text=[f"{roi:.2f}" for roi in roi_values],
        textposition='outside'
    ))

    # Horizontal break-even line across all bars at ROI = 1.0
    fig.add_shape(
        type='line',
        x0=-0.5,
        y0=1.0,
        x1=len(channels) - 0.5,
        y1=1.0,
        line=dict(
            color='red',
            width=2,
            dash='dash'
        )
    )

    fig.add_annotation(
        x=channels[-1],
        y=1.0,
        text="Break-even (ROI=1.0)",
        showarrow=False,
        yshift=5,
        font=dict(color='red')
    )

    # Headroom above for the outside text labels; extend below 0 only when needed.
    y_low = min(0, min(roi_values) * 1.1)
    y_high = max(max(roi_values) * 1.1, 1.2)
    fig.update_layout(**base_layout, yaxis=dict(range=[y_low, y_high]))

    return fig

def plot_channel_response_curves_extended(data, model_results, max_multiplier=2.0, n_points=20):
    """Plot simulated sales-lift response curves for all six media channels.

    For each channel, its spend column is scaled by a grid of multipliers
    while all other channels stay unchanged. Derived features are rebuilt
    according to ``model_results['feature_selection']``, and total predicted
    sales are expressed as % lift over zero spend on that channel. A red
    marker shows current spend (multiplier exactly 1.0). A green marker shows
    the first interior point where the curve's second difference drops below
    a threshold, used here as a heuristic "diminishing returns" point.

    Parameters
    ----------
    data : pandas.DataFrame
        Data with the ``*_Spend`` columns and every model feature.
    model_results : dict
        Training output; reads ``'model'``, ``'model_type'``, ``'features'``,
        ``'feature_selection'`` and ``'scaler'``.
    max_multiplier : float
        Largest spend multiplier simulated.
    n_points : int
        Number of evenly spaced multipliers in ``[0, max_multiplier]``
        (at least 2). The value 1.0 is always added so current spend is
        simulated exactly.

    Returns
    -------
    plotly.graph_objects.Figure
        A 3x2 grid of subplots, one per channel.
    """
    model = model_results['model']
    model_type = model_results['model_type']
    features = model_results['features']
    feature_selection = model_results['feature_selection']
    scaler = model_results['scaler']

    # All channels including new ones
    channels = ['TV', 'Digital', 'Radio', 'Print', 'Social_Media', 'Influencer']

    # Define column mapping for consistent reference
    column_map = {
        'TV': 'TV_Spend',
        'Digital': 'Digital_Spend',
        'Radio': 'Radio_Spend',
        'Print': 'Print_Spend',
        'Social_Media': 'Social_Media_Spend',
        'Influencer': 'Influencer_Spend'
    }

    # Multiplier grid shared by all channels. np.linspace(0, 2.0, 20) does not
    # contain 1.0 (index 10 is ~1.053), so 1.0 is merged in explicitly and
    # located by search instead of being assumed at a fixed index.
    multipliers = np.union1d(np.linspace(0.0, max_multiplier, max(int(n_points), 2)), [1.0])
    current_idx = int(np.argmin(np.abs(multipliers - 1.0)))

    # Create a subplot with 3x2 grid
    fig = make_subplots(
        rows=3,
        cols=2,
        subplot_titles=[f"{channel} Response Curve" for channel in channels]
    )

    for i, channel in enumerate(channels):
        row = (i // 2) + 1
        col = (i % 2) + 1
        spend_col = column_map[channel]

        predicted_sales = []
        for mult in multipliers:
            # Modify spend for this channel only
            test_data = data.copy()
            test_data[spend_col] = data[spend_col] * mult

            # Recalculate derived variables based on feature selection
            if feature_selection == 'transformed':
                test_data = prepare_model_data_extended(
                    test_data,
                    viral_params={'coefficient': 0.1, 'decay': 0.5},
                    authenticity_params={'factor': 0.2, 'max_bonus': 0.5}
                )
            elif feature_selection == 'log':
                test_data[f'Log_{spend_col}'] = np.log1p(test_data[spend_col])

            X_test_scaled = scaler.transform(test_data[features])

            if model_type == 'ols':
                # has_constant='add': at multiplier 0 a scaled feature can be
                # constant, and the default 'skip' would then omit the
                # intercept column the fitted parameters expect.
                y_pred = model.predict(sm.add_constant(X_test_scaled, has_constant='add'))
            else:
                y_pred = model.predict(X_test_scaled)

            predicted_sales.append(np.sum(y_pred))

        # Lift relative to zero spend on this channel (multipliers[0] == 0)
        base_sales = predicted_sales[0]
        pct_change = [(s - base_sales) / base_sales * 100 if base_sales > 0 else 0 for s in predicted_sales]

        # x-axis: total spend over the whole period at each multiplier
        avg_spend = data[spend_col].mean()
        total_spend = [avg_spend * m * len(data) for m in multipliers]

        fig.add_trace(
            go.Scatter(
                x=total_spend,
                y=pct_change,
                mode='lines',
                name=f"{channel} Response",
                line=dict(color='blue', width=2),
                showlegend=False
            ),
            row=row,
            col=col
        )

        # Marker at current spend (multiplier 1.0)
        fig.add_trace(
            go.Scatter(
                x=[total_spend[current_idx]],
                y=[pct_change[current_idx]],
                mode='markers',
                marker=dict(color='red', size=10),
                name=f"Current {channel} Spend",
                text=f"Current Spend: ${total_spend[current_idx]/1e3:.0f}k",
                showlegend=False
            ),
            row=row,
            col=col
        )

        # Heuristic "optimal" point: first index where the second difference
        # falls below 10% of its minimum (or below -0.01 if it never goes negative)
        y = np.array(pct_change)
        grad = np.gradient(np.gradient(y))
        threshold = 0.1 * np.min(grad) if np.min(grad) < 0 else -0.01
        optimal_idx = np.where(grad < threshold)[0]
        if len(optimal_idx) > 0:
            optimal_idx = optimal_idx[0]
            if 0 < optimal_idx < len(multipliers) - 1:  # skip the endpoints
                fig.add_trace(
                    go.Scatter(
                        x=[total_spend[optimal_idx]],
                        y=[pct_change[optimal_idx]],
                        mode='markers',
                        marker=dict(color='green', size=10),
                        name=f"Optimal {channel} Spend",
                        text=f"Optimal Spend: ${total_spend[optimal_idx]/1e3:.0f}k",
                        showlegend=False
                    ),
                    row=row,
                    col=col
                )

    fig.update_layout(
        title='Channel Response Curves - Sales Lift vs Spend',
        template="plotly_white",
        height=800,
        width=1000
    )

    # Axis titles for every subplot in the 3x2 grid
    for i in range(1, 4):  # rows
        for j in range(1, 3):  # columns
            fig.update_xaxes(title_text="Total Spend ($)", row=i, col=j)
            fig.update_yaxes(title_text="Sales Lift (%)", row=i, col=j)

    return fig

def plot_sales_decomposition_extended(contrib_dict, data, title='Sales Decomposition'):
    """Create a stacked-area chart of per-driver sales contributions plus actual sales.

    ``'Baseline'`` (when present) is the bottom layer of the stack, and every
    other component is stacked on top of it. The top edge of the stack
    therefore equals the sum of all contributions and can be compared
    directly with the red actual-sales line.

    Parameters
    ----------
    contrib_dict : dict
        Component name -> per-row contribution values (same length as ``data``).
    data : pandas.DataFrame
        Must contain ``Date`` and ``Sales`` columns.
    title : str
        Figure title.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    fig = go.Figure()

    # Baseline first so it lies beneath everything else in the stack; other
    # components keep their dictionary order.
    ordered = ([c for c in contrib_dict if c == 'Baseline'] +
               [c for c in contrib_dict if c != 'Baseline'])

    for component in ordered:
        trace_kwargs = dict(
            x=data['Date'],
            y=contrib_dict[component],
            mode='lines',
            stackgroup='one',
            name=component,
            hoverinfo='x+y+name'
        )
        if component == 'Baseline':
            trace_kwargs['line'] = dict(color='black', width=2)
            trace_kwargs['fillcolor'] = 'rgba(0, 0, 0, 0.1)'
        fig.add_trace(go.Scatter(**trace_kwargs))

    # Actual sales, unstacked, for comparison with the top of the stack
    fig.add_trace(go.Scatter(
        x=data['Date'],
        y=data['Sales'],
        mode='lines',
        name='Actual Sales',
        line=dict(color='red', width=2),
        hoverinfo='x+y+name'
    ))

    fig.update_layout(
        title=title,
        xaxis_title='Date',
        yaxis_title='Sales Contribution',
        template="plotly_white",
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    return fig

def plot_budget_optimization_results(simulation_results, top_n=5, baseline_sales=None):
    """Create a grouped bar chart of channel budget shares for the top-N strategies.

    ``simulation_results`` is expected best-first, as returned by
    ``simulate_budget_allocation_extended``. Each strategy's legend label
    shows its total sales relative to ``baseline_sales``. When that is not
    given, the best (first) strategy is the reference, so labels read
    ``+0.0%``, ``-0.4%``, and so on.

    Parameters
    ----------
    simulation_results : list of dict
        Entries with ``'allocation_percentages'`` and ``'total_sales'``.
    top_n : int
        Number of leading strategies to plot.
    baseline_sales : float, optional
        Reference total sales (e.g. predicted sales of the current plan).

    Returns
    -------
    plotly.graph_objects.Figure
        The figure; an empty, titled figure when there are no results.
    """
    fig = go.Figure()
    layout = dict(
        title='Top Budget Allocation Strategies',
        xaxis_title='Channel',
        yaxis_title='Budget Allocation (%)',
        barmode='group',
        template="plotly_white",
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    top_results = simulation_results[:top_n]
    if not top_results:
        fig.update_layout(**layout)
        return fig

    channels = list(top_results[0]['allocation_percentages'].keys())
    reference_sales = top_results[0]['total_sales'] if baseline_sales is None else baseline_sales

    bar_width = 0.15
    n_strategies = len(top_results)

    for i, result in enumerate(top_results):
        allocations = [result['allocation_percentages'][ch] for ch in channels]
        if reference_sales:
            sales_change = (result['total_sales'] - reference_sales) / reference_sales * 100
            label = f"Strategy {i+1}: {sales_change:+.1f}%"  # explicit sign for gains and losses
        else:
            label = f"Strategy {i+1}"

        fig.add_trace(go.Bar(
            x=channels,
            y=allocations,
            name=label,
            width=bar_width,
            # Plotly draws a bar with an explicit offset from x + offset to
            # x + offset + width, so (i - n/2) * width centres the group.
            offset=(i - n_strategies / 2) * bar_width
        ))

    fig.update_layout(**layout)

    return fig

# ----- MAIN EXECUTION FUNCTION -----

def run_mmm_analysis(periods=104, train_model=True, visualize=True):
    """Run the end-to-end MMM pipeline, including social and influencer channels.

    Steps: generate synthetic data, prepare model features, fit a ridge model,
    compute channel ROI, decompose sales, simulate constrained budget splits,
    and optionally build the Plotly figures. Progress is printed throughout.

    Parameters
    ----------
    periods : int
        Number of periods passed to ``generate_mmm_data_extended``.
    train_model : bool
        When False, only the generated data is returned.
    visualize : bool
        When True, figures are built and returned under ``'figures'``.

    Returns
    -------
    dict
        ``{'data': ...}`` when ``train_model`` is False. Otherwise it holds
        data, processed_data, model_results, roi_results, contrib_dict and
        simulation_results, plus ``'figures'`` when ``visualize`` is True.
    """
    print("Generating MMM data...")
    data = generate_mmm_data_extended(periods=periods)

    # Without modelling there is nothing more to compute.
    if not train_model:
        return {'data': data}

    print("Processing data for modeling...")
    processed_data = prepare_model_data_extended(data)

    print("Training model...")
    model_results = train_model_extended(
        processed_data,
        model_type='ridge',
        test_size=26,
        feature_selection='transformed',
        alpha=100.0
    )

    print("Model Training Results:")
    metric_rows = (
        ("Training R²", 'train_r2', '.4f'),
        ("Test R²", 'test_r2', '.4f'),
        ("Training RMSE", 'train_rmse', '.2f'),
        ("Test RMSE", 'test_rmse', '.2f'),
    )
    for label, key, spec in metric_rows:
        print(f"{label}: {model_results[key]:{spec}}")

    print("Calculating ROI...")
    roi_results = calculate_roi_extended(model_results, processed_data)
    for channel_name, channel_roi in roi_results.items():
        print(f"{channel_name} ROI: {channel_roi:.2f}")

    print("Decomposing sales...")
    contrib_dict = decompose_sales_extended(model_results, processed_data)

    print("Simulating budget allocations...")
    spend_columns = ('TV_Spend', 'Digital_Spend', 'Radio_Spend',
                     'Print_Spend', 'Social_Media_Spend', 'Influencer_Spend')
    total_budget = sum(data[column].sum() for column in spend_columns)

    share_constraints = {
        'min_tv': 0.1,
        'max_tv': 0.5,
        'min_digital': 0.15,
        'max_digital': 0.6,
        'min_social_media': 0.1,
        'max_social_media': 0.4,
        'min_influencer': 0.05,
        'max_influencer': 0.3
    }
    simulation_results = simulate_budget_allocation_extended(
        model_results,
        processed_data,
        budget_total=total_budget,
        n_simulations=100,
        constraints=share_constraints
    )

    print("Top budget allocation strategy:")
    best_result = simulation_results[0]
    for channel_name, share in best_result['allocation_percentages'].items():
        print(f"{channel_name}: {share:.1f}%")
    print(f"Predicted Sales: {best_result['total_sales']:.2f}")

    outputs = {
        'data': data,
        'processed_data': processed_data,
        'model_results': model_results,
        'roi_results': roi_results,
        'contrib_dict': contrib_dict,
        'simulation_results': simulation_results
    }

    if visualize:
        print("Creating visualizations...")
        figures = {
            'spend_fig': plot_media_spend_patterns_extended(data),
            'roi_fig': plot_roi_comparison_extended(roi_results),
            'response_curves_fig': plot_channel_response_curves_extended(processed_data, model_results),
            'decomp_fig': plot_sales_decomposition_extended(contrib_dict, processed_data),
            'budget_fig': plot_budget_optimization_results(simulation_results)
        }

        for message in (
            "Visualization complete. Figures available in variables:",
            "- spend_fig: Media spend patterns",
            "- roi_fig: ROI comparison",
            "- response_curves_fig: Channel response curves",
            "- decomp_fig: Sales decomposition",
            "- budget_fig: Budget optimization",
        ):
            print(message)

        outputs['figures'] = figures

    return outputs

# Example usage
if __name__ == "__main__":
    # Script entry point: run the full pipeline, write every figure to
    # "<figure name>.html" in the working directory, then print a text report.
    # Names assigned here (including loop variables) become module globals.

    # Run the complete analysis
    results = run_mmm_analysis(periods=104, train_model=True, visualize=True)

    # To display a particular visualization (in a Jupyter notebook)
    results['figures']['roi_fig'].show()

    # To save figures to files
    for name, fig in results['figures'].items():fig.write_html(f"{name}.html")

    # Extract insights from the results
    data = results['data']
    model_results = results['model_results']
    roi_results = results['roi_results']
    simulation_results = results['simulation_results']

    # Print summary report
    print("\n===== MARKETING MIX MODEL SUMMARY REPORT =====")
    print("\nMODEL PERFORMANCE")
    print(f"Training R²: {model_results['train_r2']:.4f}")
    print(f"Test R²: {model_results['test_r2']:.4f}")

    # Channels listed from highest to lowest ROI
    print("\nCHANNEL ROI")
    channels_by_roi = sorted(roi_results.items(), key=lambda x: x[1], reverse=True)
    for channel, roi in channels_by_roi:
        print(f"{channel}: {roi:.2f}")

    print("\nBUDGET OPTIMIZATION")
    # simulation_results is sorted best-first, so index 0 is the top strategy
    best_allocation = simulation_results[0]['allocation_percentages']
    current_allocation = {
        'TV': data['TV_Spend'].sum(),
        'Digital': data['Digital_Spend'].sum(),
        'Radio': data['Radio_Spend'].sum(),
        'Print': data['Print_Spend'].sum(),
        'Social_Media': data['Social_Media_Spend'].sum(),
        'Influencer': data['Influencer_Spend'].sum()
    }
    # Convert absolute totals into percentage shares of total spend
    total_spend = sum(current_allocation.values())
    current_allocation = {k: v/total_spend*100 for k, v in current_allocation.items()}

    # NOTE: an unchanged share is shown with '↓' because the test is strictly '>'.
    print("Current allocation vs. Optimized allocation:")
    for channel in sorted(best_allocation.keys()):
        print(f"{channel}: {current_allocation[channel]:.1f}% -> {best_allocation[channel]:.1f}% " +
              f"({'↑' if best_allocation[channel] > current_allocation[channel] else '↓'})")

    print("\nKEY INSIGHTS:")
    # Identify highest ROI channel
    best_channel = max(roi_results.items(), key=lambda x: x[1])[0]
    print(f"1. {best_channel} has the highest ROI at {roi_results[best_channel]:.2f}")

    # Compare digital and traditional channels
    digital_roi = roi_results['Digital']
    social_roi = roi_results['Social_Media']
    influencer_roi = roi_results['Influencer']
    tv_roi = roi_results['TV']
    print(f"2. Digital channels (Digital: {digital_roi:.2f}, Social: {social_roi:.2f}, " +
          f"Influencer: {influencer_roi:.2f}) vs. TV: {tv_roi:.2f}")

    # Suggest channel shifts based on optimal allocation
    # Only shifts larger than 5 percentage points in either direction are
    # reported; the "4." line is printed even when no "3." line precedes it.
    increases = []
    decreases = []
    for channel in best_allocation:
        diff = best_allocation[channel] - current_allocation[channel]
        if diff > 5:
            increases.append(f"{channel} (+{diff:.1f}%)")
        elif diff < -5:
            decreases.append(f"{channel} ({diff:.1f}%)")

    if increases:
        print(f"3. Recommended budget increases: {', '.join(increases)}")
    if decreases:
        print(f"4. Recommended budget decreases: {', '.join(decreases)}")


# TODO: prepare_model_data_extended may need to handle potential NaN values
# (no such change is made in this section).
# The helper below saves the generated figures to HTML files.

def save_mmm_visualizations(results, output_dir="mmm_visualizations"):
    """
    Save all MMM visualizations to HTML files in the specified directory.

    Each entry of ``results['figures']`` is written to
    ``<output_dir>/<name>.html`` through the figure's ``write_html`` method.
    When IPython is importable, a clickable link to each saved file is also
    displayed; this is best effort and never prevents saving.

    Parameters:
    -----------
    results : dict
        The results dictionary returned by run_mmm_analysis. Figures are read
        from its ``'figures'`` entry, a mapping of name -> figure object that
        provides ``write_html(path)`` (e.g. a plotly Figure).
    output_dir : str
        Directory to save the visualizations (will be created if it doesn't exist)

    Returns:
    --------
    None
    """
    import html
    import os

    # IPython is only needed for the optional notebook links. Importing it
    # unconditionally would make saving fail in a plain Python environment.
    try:
        from IPython.display import display, HTML
    except ImportError:
        display = HTML = None

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Check if figures are in the results
    if 'figures' not in results:
        print("No figures found in results dictionary.")
        return

    # Save each figure to a file; one failing figure does not stop the others.
    for name, fig in results['figures'].items():
        file_path = os.path.join(output_dir, f"{name}.html")
        try:
            fig.write_html(file_path)
        except Exception as e:
            print(f"Error saving {name}: {str(e)}")
            continue
        print(f"Saved {name} to {file_path}")

        if display is not None:
            # Escape path and name so quotes/ampersands cannot break the markup.
            link = (f'<a href="{html.escape(file_path)}" target="_blank">'
                    f'View {html.escape(str(name))}</a>')
            try:
                display(HTML(link))
            except Exception:
                # The link is a convenience only; the file is already saved.
                pass

    print(f"\nAll figures saved to {output_dir} directory.")

# If you're having issues with figures not being generated, try this function to
# manually create and save the key visualizations
def create_and_save_visualizations(data, model_results, roi_results, simulation_results,
                                   contrib_dict=None, output_dir="mmm_visualizations"):
    """
    Manually create and save visualizations from MMM results.

    Each figure is built by its plotting helper and written to
    ``<output_dir>/<stem>.html``. A failure while building or saving one
    figure is reported and skipped, so the remaining figures are still saved.

    Parameters:
    -----------
    data : pandas DataFrame
        The original or processed data
    model_results : dict
        Model results dictionary from train_model_extended
    roi_results : dict
        ROI results dictionary from calculate_roi_extended
    simulation_results : list
        Budget simulation results from simulate_budget_allocation_extended
    contrib_dict : dict, optional
        Sales decomposition dictionary from decompose_sales_extended. The
        sales decomposition figure is only created when this is provided.
    output_dir : str, optional
        Directory for the HTML files (created if missing). Defaults to
        "mmm_visualizations", the directory this function always used before.

    Returns:
    --------
    list of str
        Paths of the HTML files that were written successfully.
    """
    import os
    os.makedirs(output_dir, exist_ok=True)

    # (label, output file stem, builder, extra note for the progress message).
    # Builders are lambdas so each plotting helper is looked up and run only
    # inside the per-figure error handling below.
    specs = [
        ("media spend patterns", "spend_fig",
         lambda: plot_media_spend_patterns_extended(data), ""),
        ("ROI comparison", "roi_fig",
         lambda: plot_roi_comparison_extended(roi_results), ""),
        # Response curves can be slow to compute.
        ("channel response curves", "response_curves_fig",
         lambda: plot_channel_response_curves_extended(data, model_results),
         " (this may take a moment)"),
    ]
    # Sales decomposition requires the optional contribution dictionary.
    if contrib_dict is not None:
        specs.append(("sales decomposition", "decomp_fig",
                      lambda: plot_sales_decomposition_extended(contrib_dict, data), ""))
    specs.append(("budget optimization", "budget_fig",
                  lambda: plot_budget_optimization_results(simulation_results), ""))

    saved_paths = []
    for label, stem, build, note in specs:
        print(f"Creating {label} visualization{note}...")
        path = os.path.join(output_dir, f"{stem}.html")
        try:
            build().write_html(path)
        except Exception as e:
            print(f"Error creating {label} visualization: {str(e)}")
            continue
        saved_paths.append(path)
        print(f"Saved {label} visualization.")

    print(f"\nAll available visualizations have been saved to {output_dir} directory.")
    return saved_paths

# Example usage (add this to your code).
# Kept as comments rather than a bare string literal: a module-level string is
# an executed expression statement, which notebooks echo as cell output.
#
#     # After running your analysis
#     results = run_mmm_analysis(periods=104, train_model=True, visualize=True)
#
#     # Save all visualizations
#     save_mmm_visualizations(results)
#
#     # Alternatively, create and save visualizations manually
#     create_and_save_visualizations(
#         results['data'],
#         results['model_results'],
#         results['roi_results'],
#         results['simulation_results'],
#         results.get('contrib_dict')
#     )
print("\n===== END OF REPORT =====")

Generating MMM data...
Processing data for modeling...
Training model...
Model Training Results:
Training R²: 0.6847
Test R²: 0.2708
Training RMSE: 673437.97
Test RMSE: 1232405.80
Calculating ROI...
TV ROI: 25063439.92
Digital ROI: 25279680.66
Radio ROI: 46769705.70
Print ROI: 0.00
Social_Media ROI: 60104995.34
Influencer ROI: 73617303.93
Decomposing sales...
Simulating budget allocations...
Top budget allocation strategy:
TV: 10.8%
Digital: 28.9%
Radio: 11.8%
Print: 4.3%
Social_Media: 35.5%
Influencer: 8.7%
Predicted Sales: 435945722.21
Creating visualizations...
Visualization complete. Figures available in variables:
- spend_fig: Media spend patterns
- roi_fig: ROI comparison
- response_curves_fig: Channel response curves
- decomp_fig: Sales decomposition
- budget_fig: Budget optimization



===== MARKETING MIX MODEL SUMMARY REPORT =====

MODEL PERFORMANCE
Training R²: 0.6847
Test R²: 0.2708

CHANNEL ROI
Influencer: 73617303.93
Social_Media: 60104995.34
Radio: 46769705.70
Digital: 25279680.66
TV: 25063439.92
Print: 0.00

BUDGET OPTIMIZATION
Current allocation vs. Optimized allocation:
Digital: 20.7% -> 28.9% (↑)
Influencer: 5.4% -> 8.7% (↑)
Print: 5.8% -> 4.3% (↓)
Radio: 11.8% -> 11.8% (↓)
Social_Media: 18.1% -> 35.5% (↑)
TV: 38.2% -> 10.8% (↓)

KEY INSIGHTS:
1. Influencer has the highest ROI at 73617303.93
2. Digital channels (Digital: 25279680.66, Social: 60104995.34, Influencer: 73617303.93) vs. TV: 25063439.92
3. Recommended budget increases: Digital (+8.2%), Social_Media (+17.3%)
4. Recommended budget decreases: TV (-27.4%)

===== END OF REPORT =====


"\n# After running your analysis\nresults = run_mmm_analysis(periods=104, train_model=True, visualize=True)\n\n# Save all visualizations\nsave_mmm_visualizations(results)\n\n# Alternatively, create and save visualizations manually\ncreate_and_save_visualizations(\n    results['data'], \n    results['model_results'], \n    results['roi_results'], \n    results['simulation_results'],\n    results.get('contrib_dict')\n)\n"