In [1]:
# First would be to install lightweight_mmm
!pip install --upgrade git+https://github.com/google/lightweight_mmm.git
# !pip uninstall -y matplotlib
# !pip install matplotlib==3.1.3

Collecting git+https://github.com/google/lightweight_mmm.git
  Cloning https://github.com/google/lightweight_mmm.git to /private/var/folders/zy/0yc78h6x5xq7w_xzp3jlfvzr0000gn/T/pip-req-build-0prkq5ca
  Running command git clone --filter=blob:none --quiet https://github.com/google/lightweight_mmm.git /private/var/folders/zy/0yc78h6x5xq7w_xzp3jlfvzr0000gn/T/pip-req-build-0prkq5ca
  Resolved https://github.com/google/lightweight_mmm.git to commit 4406aaa77bddc5b0d73d31c6cf4f2ace03f3ffda
  Preparing metadata (setup.py) ... done

In [2]:
!pip install jax jaxlib

[0m

In [3]:
# Import jax.numpy and any other library we might need.
import jax.numpy as jnp
import numpyro
import pandas as pd

RuntimeError: This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support. You may be able work around this issue by building jaxlib from source.

In [None]:
# Import the relevant modules of the library
from lightweight_mmm import lightweight_mmm
from lightweight_mmm import optimize_media
from lightweight_mmm import plot
from lightweight_mmm import preprocessing
from lightweight_mmm import utils

## Organising the data for modelling

In [None]:
# Read the bike-sales CSV into a DataFrame and display it.
# NOTE(review): the absolute "/content/..." path looks environment-specific
# (presumably Colab) - adjust when running elsewhere.
csv = "/content/bike_sales_data.csv"
df = pd.read_csv(csv)
df



In [None]:

# Spend columns that act as media channels. Declared once so that the model
# inputs and the per-channel cost totals are always built from the same columns.
media_channel_columns = [
    'branded_search_spend', 'nonbranded_search_spend', 'facebook_spend',
    'print_spend', 'ooh_spend', 'tv_spend', 'radio_spend',
]
# Spend matrix: one row per row of `df`, one column per media channel.
media_data = df[media_channel_columns].to_numpy()
# Target to model; the double brackets keep it 2-D, i.e. shape (n_rows, 1).
target = df[['sales']].to_numpy()
# Total spend per channel, summed over all rows of `df`.
costs = df[media_channel_columns].sum().to_numpy()

In [None]:
# Sanity check: (number of rows in df, number of media channel columns).
media_data.shape

In [None]:
# Number of rows (first axis) of the media matrix.
data_size = len(media_data)

In [None]:
# Train/test split: the final 30 rows are held out, everything before them is
# used for fitting. Scaling happens in the next cell.
split_point = data_size - 30
# Media data: training rows first, held-out rows second.
media_data_train, media_data_test = media_data[:split_point], media_data[split_point:]
# Target: flatten the (n, 1) training slice into a 1-D vector.
target_train = target[:split_point].reshape(-1)

In [None]:
# Three independent scalers (media, target, costs), each configured with
# `jnp.mean` as its divide operation.
media_scaler, target_scaler, cost_scaler = (
    preprocessing.CustomScaler(divide_operation=jnp.mean) for _ in range(3)
)

# NOTE: the training arrays are overwritten with their scaled versions, so this
# cell is not idempotent. Re-running it refits the scalers on already-scaled
# data; re-run the split cell first if this cell has to be executed again.
media_data_train = media_scaler.fit_transform(media_data_train)
target_train = target_scaler.fit_transform(target_train)
costs2 = cost_scaler.fit_transform(costs)

In [None]:
# Build the media mix model; `model_name` selects the media transformation used.
mmm = lightweight_mmm.LightweightMMM(model_name="carryover") # alternatives: "hill_adstock" or "adstock"

In [None]:
# Warm-up and sampling iteration counts that are passed to `mmm.fit`.
# NOTE(review): 100/100 looks like a quick-demo setting - confirm it is enough
# for the sampler to converge before trusting the results.
number_warmup = number_samples = 100


In [None]:

# Fit the model on the scaled training data; the scaled total cost per channel
# is used as the media prior.
mmm.fit(
    media=media_data_train,
    media_prior=costs2,
    target=target_train,
    number_warmup=number_warmup,
    number_samples=number_samples,
    number_chains=1,
    # Fixed seed so that Restart & Run All reproduces the same posterior draws;
    # when no seed is given the library picks a time-based one.
    seed=0,
    )

In [None]:
# Print the summary of the fitted model's posterior samples.
mmm.print_summary()

In [None]:
# Plot the posterior distributions of the media channel effects.
plot.plot_media_channel_posteriors(media_mix_model=mmm)

In [None]:
# In-sample fit; the target scaler is passed so the plot can undo the scaling.
plot.plot_model_fit(
    media_mix_model=mmm,
    target_scaler=target_scaler,
)

In [None]:
# The held-out media data is still in raw units, so it is first scaled with the
# scaler that was fitted on the training data and only then fed to the model.
media_data_test_scaled = media_scaler.transform(media_data_test)
new_predictions = mmm.predict(media=media_data_test_scaled)
new_predictions.shape

In [None]:
# Compare the out-of-sample predictions with the held-out target. The target is
# passed through the fitted target scaler because `predict` was called without
# one (presumably leaving the predictions on the scaled target scale).
target_test_scaled = target_scaler.transform(target[split_point:].squeeze())
plot.plot_out_of_sample_model_fit(
    out_of_sample_predictions=new_predictions,
    out_of_sample_target=target_test_scaled,
)

### Media insights

In [None]:
# Posterior media contribution and ROI estimates; both scalers are passed so the
# metrics can be computed from unscaled target and cost values.
media_contribution, roi_hat = mmm.get_posterior_metrics(
    target_scaler=target_scaler,
    cost_scaler=cost_scaler,
)

In [None]:
from matplotlib import pyplot as plt
import numpy as np

def custom_plot_media_baseline_contribution_area_plot(
        media_mix_model,
        target_scaler=None,
        channel_names=None,
        fig_size=(20, 7)):
    """Plots a stacked area chart of baseline and media contribution over time.

    Builds the contribution dataframe with
    `plot.create_media_baseline_contribution_df`, floors negative values at
    zero, and draws every column whose name contains "contribution" as one
    layer of a stacked area chart against a 1-based period index.

    Args:
      media_mix_model: Media mix model.
      target_scaler: Scaler used for scaling the target.
      channel_names: Names of media channels.
      fig_size: Size of the figure to plot as used by matplotlib.

    Returns:
      The matplotlib figure containing the stacked area chart. The figure is
      closed in pyplot before it is returned.
    """
    # Create media channels & baseline contribution dataframe.
    contribution_df = plot.create_media_baseline_contribution_df(
        media_mix_model=media_mix_model,
        target_scaler=target_scaler,
        channel_names=channel_names)
    # Floor negative contributions at zero; presumably needed because a stacked
    # pandas area plot cannot mix positive and negative values within a column.
    contribution_df = contribution_df.clip(lower=0)

    # Keep only the contribution columns, in reversed order for the stack.
    contribution_columns = [
        col for col in contribution_df.columns if "contribution" in col
    ]
    # `assign` returns a new frame, so the period column is added without
    # writing into a derived slice (which can raise SettingWithCopyWarning).
    contribution_df_for_plot = (
        contribution_df
        .loc[:, contribution_columns[::-1]]
        .assign(period=np.arange(1, contribution_df.shape[0] + 1))
    )
    period = contribution_df_for_plot["period"]

    # Plot the stacked area chart.
    fig, ax = plt.subplots(figsize=fig_size)
    contribution_df_for_plot.plot.area(x="period", stacked=True, ax=ax)
    ax.set_title("Attribution Over Time", fontsize="x-large")
    ax.tick_params(axis="y")
    ax.set_ylabel("Baseline & Media Channels Attribution")
    ax.set_xlabel("Period")
    ax.set_xlim(1, period.max())
    # One tick per period, labelled with the period number and rotated 45 degrees.
    ax.set_xticks(period)
    ax.set_xticklabels(period, rotation=45)
    # Close this specific figure in pyplot (rather than whichever is current);
    # presumably so the notebook shows it only once, via the returned object.
    plt.close(fig)
    return fig


In [None]:
# Draw the clipped contribution area chart on a wider canvas than the default.
custom_plot_media_baseline_contribution_area_plot(
    media_mix_model=mmm,
    target_scaler=target_scaler,
    fig_size=(30, 10),
)

In [None]:
# Bar chart of the estimated media contribution per channel.
plot.plot_bars_media_metrics(
    metric=media_contribution,
    metric_name="Media Contribution Percentage",
)

In [None]:
# Bar chart of the estimated ROI per channel.
plot.plot_bars_media_metrics(
    metric=roi_hat,
    metric_name="ROI hat",
)

In [None]:
# Response curves of the fitted model, using the target scaler for the KPI axis.
plot.plot_response_curves(media_mix_model=mmm, target_scaler=target_scaler)

# Optimization

In [None]:
# Unit price of 1 for every media channel of the fitted model.
# NOTE(review): presumably 1 because the media columns are already spend amounts - confirm.
prices = jnp.ones(mmm.n_media_channels)

In [None]:
# Optimisation horizon, in periods.
n_time_periods = 10
# Budget = price-weighted average spend per period (over the full, unscaled
# media data) multiplied by the number of periods in the horizon.
average_spend_per_period = jnp.dot(prices, media_data.mean(axis=0))
budget = jnp.sum(average_spend_per_period) * n_time_periods

In [None]:
# Run optimization with the parameters of choice.
solution, kpi_without_optim, previous_budget_allocation = optimize_media.find_optimal_budgets(
    n_time_periods=n_time_periods,
    media_mix_model=mmm,
    budget=budget,
    prices=prices,
    media_scaler=media_scaler,
    target_scaler=target_scaler)

In [None]:
# Obtain the optimal weekly allocation.
optimal_buget_allocation = prices * solution.x
optimal_buget_allocation

## We can plot the following:
1. Pre- vs. post-optimization budget allocation comparison for each channel
2. Pre- vs. post-optimization predicted target variable comparison

In [None]:
# Side-by-side view of the budget allocation and the predicted target before
# and after optimization.
plot.plot_pre_post_budget_allocation_comparison(
    media_mix_model=mmm,
    kpi_with_optim=solution.fun,
    kpi_without_optim=kpi_without_optim,
    optimal_buget_allocation=optimal_buget_allocation,
    previous_budget_allocation=previous_budget_allocation,
    figure_size=(10, 10),
)