This notebook accompanies cl/723634204, which implements contribution priors for paid, organic, and non-media treatments. That CL also rewrites much of the code for paid-media ROI and mROI priors, so that code needs to be tested as well.

The tests in this notebook are meant to be converted to unit tests eventually.

# Imports

In [None]:
import importlib
import arviz as az
from colabtools import adhoc_import
import numpy as np
from absl.testing import parameterized
import tensorflow as tf
from google3.pyglib import gfile
# with adhoc_import.Google3Head():  # import from head
with adhoc_import.Google3CitcClient("fig", "lukmaz"):
  import meridian
  from meridian import data
  from meridian import model

  meridian.constants = importlib.import_module(
      "google3.ads.lift.mmm.modeling.research.contr_priors.constants"
  )
  data.input_data = importlib.import_module(
      "google3.ads.lift.mmm.modeling.research.contr_priors.data.input_data"
  )
  model.spec = importlib.import_module(
      "google3.ads.lift.mmm.modeling.research.contr_priors.model.spec"
  )
  model.prior_distribution = importlib.import_module(
      "google3.ads.lift.mmm.modeling.research.contr_priors.model.prior_distribution"
  )
  model.media = importlib.import_module(
      "google3.ads.lift.mmm.modeling.research.contr_priors.model.media"
  )
  model.prior_sampler = importlib.import_module(
      "google3.ads.lift.mmm.modeling.research.contr_priors.model.prior_sampler"
  )
  model.posterior_sampler = importlib.import_module(
      "google3.ads.lift.mmm.modeling.research.contr_priors.model.posterior_sampler"
  )
  model.model = importlib.import_module(
      "google3.ads.lift.mmm.modeling.research.contr_priors.model.model"
  )
  from meridian import constants
  from meridian.data import test_utils
  from meridian.data import input_data
  from meridian.model import media
  from meridian.model import prior_sampler
  from meridian.model import posterior_sampler
  from meridian.model import model
  from meridian.model import prior_distribution
  from meridian.model import spec
  from meridian.analysis import analyzer
  constants = importlib.reload(constants)
  test_utils = importlib.reload(test_utils)
  input_data = importlib.reload(input_data)
  media = importlib.reload(media)
  prior_sampler = importlib.reload(prior_sampler)
  posterior_sampler = importlib.reload(posterior_sampler)
  model = importlib.reload(model)
  spec = importlib.reload(spec)
  prior_distribution = importlib.reload(prior_distribution)
  analyzer = importlib.reload(analyzer)

In [None]:
def sample_joint_dist_unpinned_as_posterior(self, n_draws, n_chains=1, seed=None):
  """Stores draws from the unpinned joint distribution as the posterior group.

  Ideally, we would call `Meridian.sample_posterior()` to test posterior
  draws. However, this is not practical due to the long computation time
  required. Instead we take draws from `_get_joint_dist_unpinned()`, which is
  the same function used to construct the log-likelihood for posterior
  sampling - if these draws are accurate, then posterior draws should also be
  accurate.

  Args:
    self: The `Meridian` instance (this function is attached to
      `model.Meridian` right after its definition).
    n_draws: Number of draws per chain.
    n_chains: Number of chains. Defaults to 1, the single-chain layout used
      previously.
    seed: Optional seed forwarded to the distribution's `sample()` call.
      `None` keeps the unseeded behavior.
  """
  joint_dist = self.posterior_sampler_callable._get_joint_dist_unpinned()
  draws = joint_dist.sample([n_chains, n_draws], seed=seed)._asdict()
  # Drop the parameters listed in `constants.UNSAVED_PARAMETERS`.
  draws = {
      name: value
      for name, value in draws.items()
      if name not in constants.UNSAVED_PARAMETERS
  }
  # Create ArviZ InferenceData for the pseudo-posterior draws.
  infdata_posterior = az.convert_to_inference_data(
      draws,
      coords=self.create_inference_data_coords(n_chains, n_draws),
      dims=self.create_inference_data_dims(),
  )
  # join="right": if `self.inference_data` already has a `posterior` group,
  # the newly created one replaces it.
  self.inference_data.extend(infdata_posterior, join="right")


# Attach as a method so tests can call
# `mmm.sample_joint_dist_unpinned_as_posterior(n_draws)`.
model.Meridian.sample_joint_dist_unpinned_as_posterior = (
    sample_joint_dist_unpinned_as_posterior
)

In [None]:
def _calibrated_roi(
    aaa,
    calibration_period,
    spend,
    n_channels,
    channel_offset,
    template,
    n_times,
    use_posterior,
):
  """Returns per-channel ROI restricted to each channel's calibration period.

  For channel `i`, incremental outcome is computed with
  `media_selected_times=calibration_period[:, i]` and divided by that
  channel's spend summed over the last `n_times` entries of the same mask.

  Args:
    aaa: `analyzer.Analyzer` for the model under test.
    calibration_period: Boolean mask indexed `[media_time, channel]`.
    spend: Spend indexed `[geo, time, channel]` (time axis of length
      `n_times`).
    n_channels: Number of channels of this treatment type.
    channel_offset: Position of this treatment's first channel on the channel
      axis of `aaa.incremental_outcome(...)`.
    template: Array whose shape and dtype the result is built from.
    n_times: Number of model time periods (the spend time-axis length).
    use_posterior: Whether to use posterior (True) or prior (False) draws.

  Returns:
    A tensor built from a zero array shaped like `template`.
  """
  calculated = np.zeros_like(template)
  spend = np.asarray(spend)
  for i in range(n_channels):
    times = calibration_period[:, i]
    ii = aaa.incremental_outcome(
        media_selected_times=times.tolist(),
        use_posterior=use_posterior,
    )
    # Spend has no lag-history prefix, so align the mask with its last
    # `n_times` entries.
    channel_spend = np.einsum("gtm,t->m", spend, times[-n_times:])
    calculated[:, :, i] = ii[:, :, channel_offset + i] / channel_spend[i]
  return tf.convert_to_tensor(calculated)


def check_treatment_parameters(mmm, use_posterior, rtol=1e-3, atol=1e-3):
  """Asserts sampled treatment parameters agree with Analyzer-derived values.

  Each sampled parameter is compared with the corresponding quantity
  recomputed by `analyzer.Analyzer`:
    * Paid media/RF with "roi" priors vs. `Analyzer.roi` (or ROI over the
      calibration period, when `roi_calibration_period` /
      `rf_roi_calibration_period` is set).
    * Paid media/RF with "mroi" priors vs. `Analyzer.marginal_roi`.
    * Paid media/RF with "contribution" priors, and all organic media,
      organic RF and non-media treatments, vs. incremental outcome divided by
      total outcome (sum of `kpi * revenue_per_kpi`).

  Args:
    mmm: A `Meridian` model whose `inference_data` holds the draws to check.
    use_posterior: If True, check `mmm.inference_data.posterior`; otherwise
      check `mmm.inference_data.prior`.
    rtol: Relative tolerance passed to `tf.debugging.assert_near`.
    atol: Absolute tolerance passed to `tf.debugging.assert_near`.

  Raises:
    ValueError: If `mmm.model_spec.paid_media_prior_type` is not "roi",
      "mroi" or "contribution".
    tf.errors.InvalidArgumentError: If any comparison is outside tolerance.
  """
  infdata = (
      mmm.inference_data.posterior
      if use_posterior
      else mmm.inference_data.prior
  )
  # Calculate total_outcome from input_data instead of using
  # mmm.total_outcome. `.values` keeps this a plain numpy scalar rather than
  # a 0-d xarray object before it is mixed into TensorFlow ops.
  total_outcome = np.sum(
      mmm.input_data.kpi.values * mmm.input_data.revenue_per_kpi.values
  )
  aaa = analyzer.Analyzer(mmm)
  paid_media_prior_type = mmm.model_spec.paid_media_prior_type
  n_m = mmm.n_media_channels
  n_rf = mmm.n_rf_channels
  n_om = mmm.n_organic_media_channels
  n_orf = mmm.n_organic_rf_channels

  # Channel-axis boundaries: media, RF, organic media, organic RF, then
  # non-media treatments (the remainder).
  rf_start = n_m
  om_start = rf_start + n_rf
  orf_start = om_start + n_om
  non_media_start = orf_start + n_orf
  incremental_outcome = aaa.incremental_outcome(
      include_non_paid_channels=True,
      use_posterior=use_posterior,
  )
  ii_m = incremental_outcome[:, :, :rf_start]
  ii_rf = incremental_outcome[:, :, rf_start:om_start]
  ii_om = incremental_outcome[:, :, om_start:orf_start]
  ii_orf = incremental_outcome[:, :, orf_start:non_media_start]
  ii_n = incremental_outcome[:, :, non_media_start:]

  calculated_om = ii_om / total_outcome
  calculated_orf = ii_orf / total_outcome
  calculated_n = ii_n / total_outcome
  param_om = infdata.contribution_om
  param_orf = infdata.contribution_orf
  param_n = infdata.contribution_n

  if paid_media_prior_type == "roi":
    param_m = infdata.roi_m
    param_rf = infdata.roi_rf
    roi_period = mmm.model_spec.roi_calibration_period
    rf_roi_period = mmm.model_spec.rf_roi_calibration_period
    # Unrestricted ROI is only needed for channels without a calibration
    # period; compute it at most once.
    roi = (
        aaa.roi(use_posterior=use_posterior)
        if roi_period is None or rf_roi_period is None
        else None
    )
    if roi_period is None:
      calculated_m = roi[:, :, :n_m]
    else:
      calculated_m = _calibrated_roi(
          aaa,
          roi_period,
          mmm.input_data.media_spend,
          n_channels=n_m,
          channel_offset=0,
          template=param_m,
          n_times=mmm.n_times,
          use_posterior=use_posterior,
      )
    if rf_roi_period is None:
      calculated_rf = roi[:, :, n_m:]
    else:
      calculated_rf = _calibrated_roi(
          aaa,
          rf_roi_period,
          mmm.input_data.rf_spend,
          n_channels=n_rf,
          channel_offset=n_m,
          template=param_rf,
          n_times=mmm.n_times,
          use_posterior=use_posterior,
      )
  elif paid_media_prior_type == "mroi":
    mroi = aaa.marginal_roi(use_posterior=use_posterior)
    calculated_m = mroi[:, :, :n_m]
    calculated_rf = mroi[:, :, n_m:]
    param_m = infdata.mroi_m
    param_rf = infdata.mroi_rf
  elif paid_media_prior_type == "contribution":
    calculated_m = ii_m / total_outcome
    calculated_rf = ii_rf / total_outcome
    param_m = infdata.contribution_m
    param_rf = infdata.contribution_rf
  else:
    # Previously this fell through to an UnboundLocalError on calculated_m.
    raise ValueError(
        f"Unsupported paid_media_prior_type: {paid_media_prior_type!r}"
    )

  for calculated, param in (
      (calculated_m, param_m),
      (calculated_rf, param_rf),
      (calculated_om, param_om),
      (calculated_orf, param_orf),
      (calculated_n, param_n),
  ):
    tf.debugging.assert_near(calculated, param, rtol=rtol, atol=atol)

In [None]:
class TestTreatmentParameterAccuracy(parameterized.TestCase):
  """Checks sampled treatment parameters against Analyzer-derived values."""

  @staticmethod
  def _calibration_period(n_media_times, n_channels, calibration_times):
    """Builds a `[n_media_times, n_channels]` boolean calibration mask.

    Args:
      n_media_times: Number of media time periods (mask rows).
      n_channels: Number of channels (mask columns).
      calibration_times: Row indices to set to True for every channel, or
        None.

    Returns:
      The mask, or None when `calibration_times` is None.
    """
    if calibration_times is None:
      return None
    period = np.full([n_media_times, n_channels], False)
    for time in calibration_times:
      period[time, :] = True
    return period

  @parameterized.product(
      n_channels_per_treatment=[1, 2],
      paid_media_prior_type=["roi", "mroi", "contribution"],
      roi_calibration_times=[None, [5, 6, 7]],
      rf_roi_calibration_times=[None, [5, 6, 7]],
  )
  def test_treatment_parameter_accuracy(
      self,
      n_channels_per_treatment,
      paid_media_prior_type,
      roi_calibration_times,
      rf_roi_calibration_times,
  ):
    # Named `sample_data` to avoid shadowing the imported `input_data` module.
    sample_data = test_utils.sample_input_data_non_revenue_revenue_per_kpi(
        n_geos=3,
        n_times=10,
        n_media_times=15,
        n_controls=1,
        n_media_channels=n_channels_per_treatment,
        n_rf_channels=n_channels_per_treatment,
        n_organic_media_channels=n_channels_per_treatment,
        n_organic_rf_channels=n_channels_per_treatment,
        n_non_media_channels=n_channels_per_treatment,
        seed=1,
    )

    # Scale each channel's spend to be between 4-6% of total revenue.
    # (Otherwise spend values can be so small that they cause numerical
    # inaccuracies with ROI priors.)
    total_outcome = np.sum(
        sample_data.kpi.values * sample_data.revenue_per_kpi.values
    )
    total_spend_m = np.sum(sample_data.media_spend.values, (0, 1))
    total_spend_rf = np.sum(sample_data.rf_spend.values, (0, 1))
    media_pcts = np.linspace(0.04, 0.06, len(sample_data.media_channel))
    rf_pcts = np.linspace(0.04, 0.06, len(sample_data.rf_channel))
    sample_data.media_spend *= media_pcts * total_outcome / total_spend_m
    sample_data.rf_spend *= rf_pcts * total_outcome / total_spend_rf

    n_media_times = len(sample_data.media_time)
    roi_calibration_period = self._calibration_period(
        n_media_times, len(sample_data.media_channel), roi_calibration_times
    )
    rf_roi_calibration_period = self._calibration_period(
        n_media_times, len(sample_data.rf_channel), rf_roi_calibration_times
    )

    # Calibration periods are only accepted with ROI priors. For any other
    # prior type, assert that ModelSpec rejects each period that is set
    # (both are checked, not just the first), then stop.
    if paid_media_prior_type != "roi" and (
        roi_calibration_period is not None
        or rf_roi_calibration_period is not None
    ):
      if roi_calibration_period is not None:
        with self.assertRaisesRegex(
            ValueError, "The `roi_calibration_period`"
        ):
          spec.ModelSpec(
              paid_media_prior_type=paid_media_prior_type,
              roi_calibration_period=roi_calibration_period,
          )
      if rf_roi_calibration_period is not None:
        with self.assertRaisesRegex(
            ValueError, "The `rf_roi_calibration_period`"
        ):
          spec.ModelSpec(
              paid_media_prior_type=paid_media_prior_type,
              rf_roi_calibration_period=rf_roi_calibration_period,
          )
      return

    model_spec = spec.ModelSpec(
        paid_media_prior_type=paid_media_prior_type,
        roi_calibration_period=roi_calibration_period,
        rf_roi_calibration_period=rf_roi_calibration_period,
    )
    mmm = model.Meridian(input_data=sample_data, model_spec=model_spec)
    mmm.sample_prior(100)
    # Pseudo-posterior: see sample_joint_dist_unpinned_as_posterior.
    mmm.sample_joint_dist_unpinned_as_posterior(100)

    check_treatment_parameters(mmm, use_posterior=False)
    check_treatment_parameters(mmm, use_posterior=True)

In [None]:
# Run every generated parameterization directly (no unittest runner), so the
# first failure raises in the notebook. Cases are discovered from the
# generated method names `<prefix><index>` instead of being hard-coded, so
# changes to the parameter grid are picked up automatically.
_TEST_PREFIX = "test_treatment_parameter_accuracy"
test_treatment_parameter_accuracy = TestTreatmentParameterAccuracy()
_case_names = sorted(
    (
        name
        for name in dir(test_treatment_parameter_accuracy)
        if name.startswith(_TEST_PREFIX)
        and name[len(_TEST_PREFIX):].isdigit()
    ),
    key=lambda name: int(name[len(_TEST_PREFIX):]),
)
if not _case_names:
  # Guard against a silent no-op if the naming scheme ever changes.
  raise RuntimeError(f"No generated test cases named {_TEST_PREFIX}<N> found.")
for _case_name in _case_names:
  getattr(test_treatment_parameter_accuracy, _case_name)()