<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google/meridian/blob/main/demo/Meridian_TFP_on_JAX_Pilot.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google/meridian/blob/main/demo/Meridian_TFP_on_JAX_Pilot.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

# **Meridian TFP-on-JAX Pilot**

This notebook demonstrates how to implement and use JAX/TFP-based versions of the Adstock and Hill functions, which are core components of Media Mix Models (MMMs). 

The Adstock function models the lagged carryover effects of media advertising, while the Hill function models the diminishing returns or saturation effect.

We will use the standard sample dataset provided by the Meridian library (`google-meridian`) to illustrate these concepts in a TFP-on-JAX context.

## Step 0: Install necessary libraries

In [None]:
# Install JAX, TFP nightly, and Meridian for data loading
!pip install -qU jax jaxlib
!pip install -qU tfp-nightly
!pip install -qU google-meridian

## Step 1: Imports

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
# TFP-on-JAX: the JAX "substrate" exposes the TFP API on top of jax.numpy.
# (`tfp.experimental.substrates.jax` is not the supported entry point.)
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

# Import Meridian's data loading utility
from meridian.data import load

## Step 2: Load the data

Define mappings from column names to variable types and channel names.

In [None]:
# Column-name -> role mapping consumed by Meridian's CSV data loader.
# Grouped as: index columns, KPI-related columns, controls, then the
# paid media (impressions and spend), organic media and other treatments.
coord_to_columns = load.CoordToColumns(
    geo='geo',
    time='time',
    population='population',
    kpi='conversions',
    revenue_per_kpi='revenue_per_conversion',
    controls=['GQV', 'Competitor_Sales'],
    media=['Channel0_impression', 'Channel1_impression', 'Channel2_impression',
           'Channel3_impression', 'Channel4_impression'],
    media_spend=['Channel0_spend', 'Channel1_spend', 'Channel2_spend',
                 'Channel3_spend', 'Channel4_spend'],
    organic_media=['Organic_channel0_impression'],
    non_media_treatments=['Promo'],
)

# Map each paid-media impression column and each spend column to a channel
# name. Both mappings use the same channel names, Channel_0 ... Channel_4.
correct_media_to_channel = dict([
    ('Channel0_impression', 'Channel_0'),
    ('Channel1_impression', 'Channel_1'),
    ('Channel2_impression', 'Channel_2'),
    ('Channel3_impression', 'Channel_3'),
    ('Channel4_impression', 'Channel_4'),
])
correct_media_spend_to_channel = dict([
    ('Channel0_spend', 'Channel_0'),
    ('Channel1_spend', 'Channel_1'),
    ('Channel2_spend', 'Channel_2'),
    ('Channel3_spend', 'Channel_3'),
    ('Channel4_spend', 'Channel_4'),
])

In [None]:
# Read the simulated geo-level sample CSV from the Meridian GitHub repository
# and build Meridian's input data from it, using the column and channel
# mappings defined in the previous cell.
loader = load.CsvDataLoader(
    csv_path="https://raw.githubusercontent.com/google/meridian/refs/heads/main/meridian/data/simulated_data/csv/geo_all_channels.csv",
    coord_to_columns=coord_to_columns,
    media_to_channel=correct_media_to_channel,
    media_spend_to_channel=correct_media_spend_to_channel,
    kpi_type='non_revenue',
)
data = loader.load()

In [None]:
# Extract media impression data as a JAX array.
# `loader.load()` returns a Meridian `InputData` object (not an xarray
# Dataset), so the media DataArray is reached via the `media` attribute.
media_data_jax = jnp.asarray(data.media.values)

# Print the shape (expected: Geo x Time x Media Channel)
print("Shape of media data (Geo x Time x Media Channel):", media_data_jax.shape)

## Step 3: Define Adstock Function (TFP-on-JAX)

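The Adstock transformation models the carryover effect of advertising: the impact of media exposure decays geometrically over time. We implement a normalized geometric-decay adstock in JAX, in which each output time point is a weighted average of the current and the previous `max_lag` media values. This is equivalent to a 1-D convolution along the time axis.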

In [None]:
def adstock_jax(media, alpha, max_lag, n_times_output):
  """Applies a normalized geometric adstock transformation using JAX.

  Each output value is a weighted average of the current media value and the
  previous `max_lag` values. A value `l` periods in the past gets a weight
  proportional to `alpha**l`, and each channel's weights are normalized to
  sum to 1:

      out[t] = sum_{l=0}^{max_lag} w[l] * media[t - l]
      w[l]   = alpha**l / sum_{k=0}^{max_lag} alpha**k

  This is the same computation as a 1-D convolution along the time axis, but
  written as an explicit sum over shifted windows. That form needs no
  convolution dimension bookkeeping and broadcasts naturally over batch
  dimensions.

  Args:
    media: JAX array with shape [..., n_geos, n_media_times, n_channels].
      Media values over time for different geos and channels.
    alpha: JAX array with shape [..., n_channels]. Adstock decay parameter for
      each channel. Its leading (batch) dimensions must broadcast against the
      leading dimensions of `media`.
    max_lag: Non-negative Python integer, the maximum duration of the adstock
      effect (number of past periods that contribute).
    n_times_output: Python integer, the desired number of time points in the
      output.

  Returns:
    JAX array with shape [..., n_geos, n_times_output, n_channels] containing
    the adstocked media for the last `n_times_output` time points of `media`.
    If `media` has fewer than `n_times_output + max_lag` time points, the
    missing history before its first time point is treated as zero.

  Raises:
    ValueError: If `media` has fewer than 3 dimensions, if the trailing
      dimension of `alpha` does not match the number of media channels, or if
      `max_lag` is negative.
  """
  media = jnp.asarray(media)
  alpha = jnp.asarray(alpha)

  # --- Validate inputs ---
  if media.ndim < 3:
    raise ValueError(
        "media must have shape [..., n_geos, n_media_times, n_channels], "
        f"got shape {media.shape}."
    )
  n_media_times, n_channels = media.shape[-2], media.shape[-1]
  if alpha.ndim == 0:
    raise ValueError(
        "alpha must have a trailing channel dimension, got a scalar."
    )
  if alpha.shape[-1] != n_channels:
    raise ValueError(
        f"Trailing dimension of alpha ({alpha.shape[-1]}) must match "
        f"trailing dimension of media ({n_channels})."
    )
  if max_lag < 0:
    raise ValueError(f"max_lag must be non-negative, got {max_lag}.")

  window_size = max_lag + 1
  required_n_media_times = n_times_output + max_lag

  # --- Align the time axis to exactly `required_n_media_times` points ---
  if n_media_times < required_n_media_times:
    # Not enough history: left-pad the time axis with zeros (no prior media).
    paddings = [(0, 0)] * (media.ndim - 2) + [
        (required_n_media_times - n_media_times, 0),
        (0, 0),
    ]
    media = jnp.pad(media, paddings)
  else:
    # Keep only the most recent history that can influence the output.
    media = media[..., n_media_times - required_n_media_times:, :]

  # --- Adstock weights ---
  # Window i (i = 0..max_lag) below is the media shifted so that it lags each
  # output time point by (max_lag - i) periods, so it gets alpha**(max_lag - i).
  lags = jnp.arange(max_lag, -1, -1, dtype=jnp.float32)
  # Shape: [..., n_channels, window_size]
  weights = jnp.power(alpha[..., jnp.newaxis], lags)
  normalization = jnp.sum(weights, axis=-1, keepdims=True)
  # Defensive: never divide by zero when normalizing.
  weights = weights / jnp.where(normalization == 0, 1.0, normalization)
  # Add singleton geo/time axes so the weights broadcast against the media
  # windows. Shape: [..., 1, 1, n_channels, window_size]
  weights = weights[..., jnp.newaxis, jnp.newaxis, :, :]

  # --- Weighted sum over the shifted windows ---
  # Each window has shape [..., n_geos, n_times_output, n_channels].
  return sum(
      media[..., i:i + n_times_output, :] * weights[..., i]
      for i in range(window_size)
  )


## Step 4: Define Hill Function (TFP-on-JAX)

The Hill function models the saturation effect or diminishing returns of advertising. As media exposure increases, its marginal effect typically decreases. The Hill function captures this non-linear relationship.

In [None]:
def hill_jax(media, ec, slope):
  """Applies the Hill saturation transformation using JAX.

  Computes `media**slope / (media**slope + ec**slope)` elementwise. Where the
  denominator is exactly zero (e.g. media == 0 and ec == 0) the result is 0.

  Args:
    media: JAX array with shape [..., n_geos, n_times, n_channels].
      Adstocked media values.
    ec: JAX array broadcastable against `media` under NumPy broadcasting
      rules. The half-saturation (concentration) parameter controlling the
      inflection point. A shape of [n_channels] broadcasts directly;
      parameters with extra leading batch dimensions need singleton geo/time
      axes, i.e. shape [..., 1, 1, n_channels].
    slope: JAX array broadcastable against `media`, with the same shape
      conventions as `ec`. Controls the steepness of the curve.

  Returns:
    JAX array with the broadcast shape of the inputs (the shape of `media`
    for per-channel parameters), containing the transformed values.
  """
  media_pow_slope = jnp.power(media, slope)
  denominator = media_pow_slope + jnp.power(ec, slope)
  is_zero = denominator == 0

  # "Double where": divide by a safe denominator so the branch that jnp.where
  # discards never evaluates 0/0. Otherwise the NaN from that branch leaks
  # into gradients (e.g. w.r.t. media or ec) even though the forward value is
  # masked out.
  safe_denominator = jnp.where(is_zero, 1.0, denominator)
  return jnp.where(is_zero, 0.0, media_pow_slope / safe_denominator)

## Step 5: Demonstrate Usage

Now, let's define some example parameters for the Adstock and Hill functions. We'll derive the number of channels and time points from the `media_data_jax` we loaded earlier.

In [None]:
# Example (not fitted) transformation parameters, one value per channel.
# The dimensions are read from the loaded media array.
n_geos, n_times, n_channels = media_data_jax.shape

# Adstock: decay rates increase linearly across channels.
max_lag = 13  # Example max lag (weeks)
n_times_output = n_times  # Output the same number of time points as the input.
alpha = jnp.linspace(0.1, 0.8, n_channels)

# Hill: EC50 (half-saturation point, scaled up on the assumption that media
# values are impression/spend sized) and slope per channel.
ec = 1e5 * jnp.linspace(0.1, 0.5, n_channels)
slope = jnp.linspace(1.0, 3.0, n_channels)

print(f"Detected {n_geos} geos, {n_times} time points, {n_channels} channels.")
print(f"\nAdstock alpha (shape {alpha.shape}):\n{alpha}")
print(f"Adstock max_lag: {max_lag}")
print(f"Adstock n_times_output: {n_times_output}")
print(f"\nHill EC50 (shape {ec.shape}):\n{ec}")
print(f"Hill slope (shape {slope.shape}):\n{slope}")

Apply the `adstock_jax` function to the raw media data.

In [None]:
# Adstock the raw media with the example parameters defined above.
adstocked_media = adstock_jax(
    media_data_jax, alpha, max_lag=max_lag, n_times_output=n_times_output)
print("Shape of adstocked_media:", adstocked_media.shape)

Next, apply the `hill_jax` function to the *output* of the adstock transformation (`adstocked_media`) to model saturation effects.

In [None]:
# Saturate the adstocked media with the per-channel Hill parameters.
hill_media = hill_jax(adstocked_media, ec=ec, slope=slope)
print("Shape of hill_media:", hill_media.shape)

Display a small slice of the final transformed data (first 2 geos, first 5 time points, first channel).

In [None]:
# First 2 geos, first 5 time points, channel 0.
print("Slice of hill_media[:2, :5, 0]:\n", hill_media[0:2, 0:5, 0])

Before defining the model, we also need to extract the target Key Performance Indicator (KPI) data that the model will try to predict.

In [None]:
# Extract KPI data (e.g., conversions) as a JAX array.
# `loader.load()` returns a Meridian `InputData` object (not an xarray
# Dataset), so the KPI DataArray is reached via the `kpi` attribute.
target_kpi_jax = jnp.asarray(data.kpi.values)

# Print the shape (expected: Geo x Time, matching the model's mu)
print("Shape of target KPI data (Geo x Time):", target_kpi_jax.shape)

## Step 6: Define a Simple TFP-on-JAX Model

Now, let's define a simple Media Mix Model using TFP-on-JAX's `JointDistributionCoroutine`. This model will incorporate the `adstock_jax` and `hill_jax` functions we defined earlier.

For simplicity in this demonstration, we will:
1. Use **fixed** parameters for the Adstock (`alpha`, `max_lag`) and Hill (`ec`, `slope`) transformations, using the example values defined in the previous section.
2. Define **priors** only for the intercept, the media channel coefficients (`beta_media`), and the observation noise (`sigma`).
3. Assume a simple linear relationship between the transformed media contributions and the mean predicted KPI (`mu`).
4. Use a Normal distribution for the likelihood of observing the actual KPI data given the predicted mean `mu` and noise `sigma`.

In [None]:
def create_mmm_model_fixed_transforms(media_data, fixed_alpha, fixed_ec, fixed_slope, fixed_max_lag, fixed_n_times_output):
  """Creates a TFP-on-JAX joint distribution for an MMM with fixed transforms.

  The Adstock and Hill parameters are fixed, so the transformed media is a
  constant and is computed once here. The random variables are the
  observation noise, the intercept and the per-channel media coefficients,
  followed by the observed KPI.

  Args:
    media_data: Array with shape [n_geos, n_times, n_channels] of raw media.
    fixed_alpha: Adstock decay parameter per channel, shape [n_channels].
    fixed_ec: Hill half-saturation parameter per channel, shape [n_channels].
    fixed_slope: Hill slope parameter per channel, shape [n_channels].
    fixed_max_lag: Maximum adstock lag (Python int).
    fixed_n_times_output: Number of time points produced by the adstock step.

  Returns:
    A `tfd.JointDistributionCoroutineAutoBatched` whose components are, in
    order: `sigma` (scalar), `intercept` (scalar), `beta_media`
    ([n_channels]) and `kpi` ([n_geos, fixed_n_times_output]).
  """
  # Unpacking enforces the rank-3 [geo, time, channel] layout.
  _, _, n_channels = media_data.shape

  # --- Transformations (using fixed parameters) ---
  # These depend only on fixed inputs, so compute them once instead of on
  # every sample / log_prob evaluation (i.e. every MCMC step).
  adstocked_media = adstock_jax(media_data, fixed_alpha, fixed_max_lag, fixed_n_times_output)
  transformed_media = hill_jax(adstocked_media, fixed_ec, fixed_slope)
  # transformed_media shape: [n_geos, n_times_output, n_channels]

  # AutoBatched lets the model be written for a single draw (scalar sigma and
  # intercept, [n_channels] beta_media) and vectorizes it automatically over
  # `sample_shape` and over MCMC chains. A plain JointDistributionCoroutine
  # would need `Root` wrappers plus batch-aware broadcasting in `mu`.
  @tfd.JointDistributionCoroutineAutoBatched
  def model():
    # --- Priors ---
    # Noise level
    sigma = yield tfd.HalfCauchy(loc=0., scale=5., name='sigma')
    # Baseline KPI
    intercept = yield tfd.Normal(loc=0., scale=5., name='intercept')
    # Coefficients for each media channel's contribution
    beta_media = yield tfd.Sample(tfd.Normal(loc=0., scale=1.),
                                  sample_shape=[n_channels],
                                  name='beta_media')

    # --- Mean prediction (mu) ---
    # [n_geos, n_times, n_channels] * [n_channels] summed over channels
    # -> [n_geos, n_times], plus the scalar intercept.
    mu = intercept + jnp.sum(transformed_media * beta_media, axis=-1)

    # --- Likelihood ---
    # The scalar sigma broadcasts against mu. Independent declares geo and
    # time as event dimensions, so log_prob sums over all observations.
    yield tfd.Independent(tfd.Normal(loc=mu, scale=sigma),
                          reinterpreted_batch_ndims=2,
                          name='kpi')

  return model

# --- Instantiate the model ---
# Use the JAX media data and the fixed parameters defined in Step 5.
mmm_model_fixed_jax = create_mmm_model_fixed_transforms(
    media_data=media_data_jax,
    fixed_alpha=alpha,
    fixed_ec=ec,
    fixed_slope=slope,
    fixed_max_lag=max_lag,
    fixed_n_times_output=n_times_output,
)

print("Instantiated TFP-on-JAX Model (fixed transforms):\n", mmm_model_fixed_jax)

### Step 6.1: Prior Predictive Sampling

Before fitting the model to the actual data (which involves MCMC sampling from the posterior), it's crucial to perform a **prior predictive check**. This involves sampling from the model using only the prior distributions.

The goal is to see if the priors we've defined generate plausible data *before* we let the model see the real `target_kpi_jax`. If the prior predictive samples look completely unreasonable (e.g., predicting wildly different scales, shapes, or distributions than expected for KPI), it suggests our priors might be poorly chosen or the model structure might be misspecified. For example, we might expect KPI values to be generally positive and unimodal.

We will:
1. Sample from the `mmm_model_fixed_jax` using its `.sample()` method.
2. Extract the sampled KPIs (`kpi` variable from the model).
3. Visualize the distribution of these prior predictive KPIs for a specific data point.

In [None]:
# Import plotting library
import matplotlib.pyplot as plt

# Set number of prior samples
num_prior_samples = 500

# Create a JAX PRNG key
key = jax.random.PRNGKey(123)

# Sample every model variable (parameters and KPI) jointly from the prior.
prior_samples = mmm_model_fixed_jax.sample(num_prior_samples, seed=key)

# JointDistributionCoroutine samples are tuple-like with one named field per
# model variable; fields are read by attribute, not by string key.
prior_kpi_samples = prior_samples.kpi

# Print the shape of the prior predictive KPI samples
# Expected shape: [num_prior_samples, n_geos, n_times]
print("Shape of prior predictive KPI samples:", prior_kpi_samples.shape)

Let's visualize the distribution of these prior predictive KPI samples for a single geo (index 0) at a single time point (index 10).

Check the histogram: Does the range of predicted KPI values seem plausible based on domain knowledge? Is the shape roughly what you might expect, or does it reveal potential issues with the priors (e.g., unintended multi-modality)?

In [None]:
# Prior predictive draws of the KPI for one geo at one time point.
geo_idx, time_idx = 0, 10
samples_point = prior_kpi_samples[:, geo_idx, time_idx]

# Histogram of the draws (object-oriented matplotlib API).
fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(samples_point, bins=30, density=True)
ax.set_title(f'Prior Predictive Distribution for KPI (Geo {geo_idx}, Time {time_idx})')
ax.set_xlabel('Predicted KPI Value')
ax.set_ylabel('Density')
ax.grid(True, linestyle='--', alpha=0.6)
plt.show()

# Basic summary statistics of the same draws
print(f"\nPrior Predictive KPI Stats (Geo {geo_idx}, Time {time_idx}):")
print(f"Mean: {jnp.mean(samples_point):.2f}")
print(f"Std Dev: {jnp.std(samples_point):.2f}")
print(f"Min: {jnp.min(samples_point):.2f}")
print(f"Max: {jnp.max(samples_point):.2f}")

### Step 6.2: Posterior Sampling (MCMC)

Now that we have defined the model and checked the priors, we can perform **posterior inference** using Markov Chain Monte Carlo (MCMC). The goal is to sample from the posterior distribution `p(parameters | observed_kpi)`, which tells us the plausible values for our parameters (`sigma`, `intercept`, `beta_media`) given the data.

We need:
1.  **Target Log Probability Function:** A function that takes the parameters we want to sample (`sigma`, `intercept`, `beta_media`) as input and returns the log probability of the model *given the observed KPI data*. This function effectively calculates `log p(parameters) + log p(observed_kpi | parameters)` by calling `model.log_prob()` and supplying (or "pinning") the actual `target_kpi_jax` to the `kpi` argument of `log_prob`.
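2.  **MCMC Sampler:** We'll use the No-U-Turn Sampler (NUTS) through TFP's `tfp.experimental.mcmc.windowed_adaptive_nuts`. This is a complete sampling driver that builds and runs NUTS for us and returns the draws; it is not a kernel object. NUTS is an efficient gradient-based MCMC algorithm. The windowed-adaptive variant automatically tunes the step size (via dual averaging) and a diagonal mass matrix (estimated over expanding windows) during an initial adaptation/warmup phase.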
3.  **Initial State:** Starting values for the parameters (`sigma`, `intercept`, `beta_media`). We often initialize multiple chains starting from different points, typically drawn from the prior distribution, to help assess convergence.
4.  **Sampling Loop:** Run the NUTS kernel for a specified number of steps. This typically includes:
    *   *Burn-in (or Warmup/Adaptation):* An initial period where the sampler adapts and converges towards the high-density region of the posterior. These samples are usually discarded.
    *   *Sampling (or Results):* The period after burn-in where we collect samples to approximate the posterior distribution.
5.  **JIT Compilation:** Use `@jax.jit` to compile the MCMC sampling loop. This is crucial for performance, especially with complex models or large datasets, as it optimizes the computation graph.

The sampling cell below produces `states`, a list with one entry per sampled parameter. The entries are in the same order as the arguments of `target_log_prob_fn` (`sigma`, `intercept`, `beta_media`). Each entry has shape `[num_chains, num_steps, ...]`, where `num_steps` includes the burn-in steps so that they can be shown in the trace plots.

In [None]:
# Define the target log probability function.
# It takes the parameters as input and returns the model's joint log
# probability evaluated with the *observed* KPI data pinned.
def target_log_prob_fn(sigma, intercept, beta_media):
  """Unnormalized log posterior of the parameters given the observed KPI.

  Returns log p(sigma, intercept, beta_media, kpi=target_kpi_jax), which
  equals log p(parameters) + log p(observed_kpi | parameters).
  """
  # Component values are passed by name, which is the documented way to
  # address JointDistribution components.
  return mmm_model_fixed_jax.log_prob(
      sigma=sigma,
      intercept=intercept,
      beta_media=beta_media,
      kpi=target_kpi_jax,  # Pin observed data here!
  )

# Example: evaluate the log prob at a single prior draw. Samples are
# tuple-like with named fields, so parameters are read by attribute.
example_params = mmm_model_fixed_jax.sample(seed=jax.random.PRNGKey(0))
example_log_prob = target_log_prob_fn(
    example_params.sigma, example_params.intercept, example_params.beta_media)
print(f"Example Log Prob at prior sample: {example_log_prob}")

In [None]:
import time

# MCMC settings
num_burnin_steps = 500  # Warmup steps; all are used for sampler adaptation.
num_results = 1000      # Post-warmup draws kept per chain.
num_chains = 4          # Independent chains, run together as one batch.

# --- Initial state ---
key_init, key_mcmc = jax.random.split(jax.random.PRNGKey(42))
# Pinning the observed KPI leaves (sigma, intercept, beta_media) as the
# unobserved parameters to sample.
pinned_model = mmm_model_fixed_jax.experimental_pin(kpi=target_kpi_jax)
# One prior draw per chain, structured like the pinned model's parameters.
initial_state = pinned_model.sample_unpinned(num_chains, seed=key_init)
print(f"Initial state shapes: {[s.shape for s in initial_state]}")

# --- Sampler ---
# `windowed_adaptive_nuts` is a full driver (not a kernel): it runs
# `num_adaptation_steps` warmup steps (tuning step size and a diagonal mass
# matrix), then `num_results` draws, for all chains at once.
num_adaptation_steps = num_burnin_steps


@jax.jit
def run_mcmc_chains(seed):
  """Runs all chains; returns (draws, trace) with draws as [steps, chains, ...]."""
  return tfp.experimental.mcmc.windowed_adaptive_nuts(
      num_results,
      mmm_model_fixed_jax,
      n_chains=num_chains,
      num_adaptation_steps=num_adaptation_steps,
      current_state=initial_state,
      discard_tuning=False,  # Keep warmup draws so the trace plots can show them.
      seed=seed,
      kpi=target_kpi_jax,    # Condition on the observed KPI.
  )


# --- Run MCMC ---
print(f"Running {num_chains} NUTS chains for {num_burnin_steps} burn-in + {num_results} results...")
start_time = time.time()
draws, kernel_results = run_mcmc_chains(key_mcmc)
# JAX dispatches asynchronously; wait for the results so the timing is real.
draws = jax.block_until_ready(draws)
end_time = time.time()
print(f"MCMC finished in {end_time - start_time:.2f} seconds.")

# Reorder each parameter's draws from [num_steps, num_chains, ...] to
# [num_chains, num_steps, ...], the layout used by the following cells.
states = [
    jnp.swapaxes(param_draws, 0, 1)
    for param_draws in (draws.sigma, draws.intercept, draws.beta_media)
]
assert states[0].shape[:2] == (num_chains, num_burnin_steps + num_results), states[0].shape

# states is a list [sigma_samples, intercept_samples, beta_media_samples]
# Each element has shape [num_chains, num_steps, ...parameter_dims...]
print(f"\nRaw output shapes (chains, steps, ...): {[s.shape for s in states]}")

Extract the posterior samples after discarding the burn-in steps.

In [None]:
# Drop the burn-in draws and pool the chains for each parameter:
# [num_chains, num_steps, ...] -> [num_chains * num_results, ...].
# `states` holds the parameters in the order sigma, intercept, beta_media.
posterior_samples = {
    param: chain_draws[:, num_burnin_steps:].reshape(-1, *chain_draws.shape[2:])
    for param, chain_draws in zip(('sigma', 'intercept', 'beta_media'), states)
}

print("Posterior sample shapes (merged_samples, ...):")
for name, samples in posterior_samples.items():
  print(f"  {name}: {samples.shape}")

### Step 6.3: Analyze MCMC Results

After running MCMC, we need to analyze the results to understand the posterior distribution and check if the sampler behaved well.

Key checks include:
1.  **Summary Statistics:** Calculate metrics like mean, median, and standard deviation for each parameter's posterior samples. This gives us point estimates (like the posterior mean) and a measure of uncertainty (like the posterior standard deviation).
2.  **Trace Plots:** Visualize the sampled values for each parameter across all MCMC steps (including burn-in) for each chain. Well-behaved chains should:
    *   Look like stationary "fuzzy caterpillars" after the burn-in period, indicating the sampler is exploring a stable distribution.
    *   Show good mixing, meaning the chains explore the parameter space effectively without getting stuck.
    *   Have different chains overlapping in the same region, suggesting they all converged to the same posterior distribution.
3.  **Diagnostics:**
    *   **R-hat (Potential Scale Reduction Factor):** Compares the variance within each chain to the variance between chains. Values close to 1.0 suggest that all chains have converged to the same distribution. Current guidance (Vehtari et al., 2021) recommends **< 1.01**; older rules of thumb used < 1.05 or < 1.1.
    *   **Effective Sample Size (ESS):** Estimates the number of *effectively independent* samples obtained, accounting for autocorrelation within the chains. Higher ESS values are better, indicating more efficient exploration of the posterior and more reliable estimates. Low ESS might suggest needing longer sampling runs or sampler tuning.

In [None]:
# Posterior point estimates (mean, median) and spread (standard deviation)
# for every parameter, reduced over the pooled draw axis.
print("Posterior Summary Statistics:")
for name in posterior_samples:
  samples = posterior_samples[name]
  print(f"\nParameter: {name}")
  print(f"  Mean:   {jnp.mean(samples, axis=0)}")
  print(f"  Median: {jnp.median(samples, axis=0)}")
  print(f"  Std Dev:{jnp.std(samples, axis=0)}")

In [None]:
# Trace plots use the full `states` (burn-in included, chains not merged),
# where each entry is shaped [num_chains, num_steps, ...]. For beta_media
# only the first channel's coefficient is shown.
param_names = ['sigma', 'intercept', 'beta_media']
num_params_to_plot = 3  # Plot sigma, intercept, and beta_media[0]

fig, axes = plt.subplots(num_params_to_plot, 1, figsize=(10, 2 * num_params_to_plot), sharex=True)

for ax, param_name, samples_all_chains in zip(axes, param_names, states):
    plot_title = f'Trace Plot: {param_name}'
    if param_name == 'beta_media':
        samples_all_chains = samples_all_chains[:, :, 0]
        plot_title = f'Trace Plot: {param_name}[0]'

    # Transposing to [num_steps, num_chains] draws one line per chain.
    ax.plot(samples_all_chains.T, alpha=0.7)
    # Mark where burn-in ends.
    ax.axvline(num_burnin_steps, color='red', linestyle='--', label=f'Burn-in ({num_burnin_steps} steps)')
    ax.set_title(plot_title)
    ax.set_ylabel('Sample Value')
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.6)

axes[-1].set_xlabel('MCMC Step')
fig.tight_layout()
plt.show()

In [None]:
# Calculate MCMC diagnostics (R-hat, ESS).
# TFP's diagnostics expect draws along axis 0 and independent chains along
# axis 1, i.e. [num_results, num_chains, ...]. `states` is laid out as
# [num_chains, num_steps, ...], so drop the burn-in and swap the first two
# axes. (`tfp` is the JAX-substrate module object, so it is used directly
# rather than through `from tfp import ...`, which is not a valid import.)
states_after_burnin = [
    jnp.swapaxes(s[:, num_burnin_steps:], 0, 1) for s in states
]

print("MCMC Diagnostics (calculated on post-burn-in samples):")
for name, samples in zip(['sigma', 'intercept', 'beta_media'], states_after_burnin):
    try:
        # R-hat (Potential Scale Reduction), comparing chains along axis 1.
        rhat = tfp.mcmc.potential_scale_reduction(samples, independent_chain_ndims=1)
        # ESS (Effective Sample Size), summed over the chains on axis 1.
        ess = tfp.mcmc.effective_sample_size(samples, cross_chain_dims=1)

        print(f"\nParameter: {name}")
        print(f"  R-hat: {rhat}")
        print(f"  ESS:   {ess}")
    except Exception as e:
        # Best effort: report and continue with the next parameter (e.g. when
        # shapes are unexpected or the samples contain NaNs).
        print(f"\nCould not compute diagnostics for {name}: {e}")
