<a href="https://colab.research.google.com/github/PGM-Lab/2023-probai-private/blob/main/python/Day2-BeforeLunch/notebooks/students_simple_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Exercise

<center>
<img src="./students_simple_model.png" alt="Drawing" width="650">
</center>


### Imports

In [None]:
%pip install -q numpy scipy matplotlib
%matplotlib inline
import numpy as np
from scipy import special, stats
import matplotlib.pyplot as plt


### Startup: Define priors, and sample artificial training data

In [None]:
# Prior hyper-parameters
alpha_prior = 1E-2  # first parameter of the prior over gamma
beta_prior = 1E-2   # second parameter of the prior over gamma
mu_prior = 0        # a priori mean for mu
tau_prior = 1E-6    # a priori precision for mu

# Ground-truth settings used to generate the artificial training data
N = 100
correct_mean = 5
correct_precision = 1

# Draw the data set; the fixed seed makes the sample reproducible.
# The Gaussian is parameterised by its standard deviation, i.e. precision ** -1/2.
np.random.seed(123)
x = np.random.normal(correct_mean, 1. / np.sqrt(correct_precision), N)

## Helper-routine: Make plot of density

In [None]:
#@title
def plot_density(posterior_mean_mu, posterior_prec_mu,
                 posterior_alpha_gamma, posterior_beta_gamma,
                 correct_mean, correct_precision):
    """
    Draw a contour plot of a factorised log-density over (mean, precision).

    The plotted surface is the sum of a Gaussian log-pdf over the mean and a
    Gamma log-pdf over the precision, evaluated on a fixed 500 x 500 grid
    (mean in [-15, 15], precision in [0.01, 3]). A blue dot marks the true
    parameters. A new figure is created; ``plt.show()`` is left to the caller.

    :param posterior_mean_mu: mean of the Gaussian factor
    :param posterior_prec_mu: precision of the Gaussian factor
    :param posterior_alpha_gamma: shape of the Gamma factor
    :param posterior_beta_gamma: rate of the Gamma factor (scipy gets scale = 1 / rate)
    :param correct_mean: true mean, x-coordinate of the marker
    :param correct_precision: true precision, y-coordinate of the marker
    """
    # Evaluation grid
    mu_axis = np.linspace(-15, 15, 500).astype(np.float32)
    tau_axis = np.linspace(1E-2, 3, 500).astype(np.float32)
    mu_grid, tau_grid = np.meshgrid(mu_axis, tau_axis)

    # The density factorises, so the joint log-density is the sum of the two factors
    gauss_std = 1. / np.sqrt(posterior_prec_mu)
    log_q_mu = stats.norm.logpdf(mu_grid, loc=posterior_mean_mu, scale=gauss_std)
    log_q_tau = stats.gamma.logpdf(x=tau_grid,
                                   a=posterior_alpha_gamma,
                                   scale=1. / posterior_beta_gamma)
    joint_log_pdf = log_q_mu + log_q_tau

    plt.figure()
    plt.contour(mu_grid, tau_grid, joint_log_pdf, 25)
    plt.plot(correct_mean, correct_precision, "bo")
    plt.xlabel("Mean $\\mu$")
    plt.ylabel("Precision $\\tau$")
    plt.title('Density over $(\\mu, \\tau)$. Blue dot: True parameters')

## Helper-routine: Calculate ELBO

In [None]:
#@title
def calculate_ELBO(data, tau, alpha, beta, nu_p, tau_p, alpha_p, beta_p):
    """
    Helper routine: Calculate the ELBO of the Normal-Gamma model used in this exercise.

    Model:        mu ~ N(0, 1/tau),  gamma ~ Gamma(alpha, beta),  x_i | mu, gamma ~ N(mu, 1/gamma)
    Variational:  q(mu) = N(nu_p, 1/tau_p),  q(gamma) = Gamma(alpha_p, beta_p)

    Anything without a _p relates to the prior, everything _with_ a _p relates to the
    variational posterior. There is no nu without a _p: the prior mean of mu is fixed to zero.
    Gamma distributions use the shape/rate parameterisation.

    Note: This function only works when the model is as in this code challenge,
    and is not a general solution.

    :param data: The sampled data (any 1-D sequence of numbers; may be empty)
    :param tau: prior precision for mu, the mean for the data generation
    :param alpha: prior shape of dist for gamma, the precision of the data generation
    :param beta: prior rate of dist for gamma, the precision of the data generation
    :param nu_p: VB posterior mean for the distribution of mu
    :param tau_p: VB posterior precision for the distribution of mu
    :param alpha_p: VB posterior shape of dist for gamma
    :param beta_p: VB posterior rate of dist for gamma
    :return: the ELBO, E_q log p(x, z) - E_q log q(z)
    """
    data = np.asarray(data, dtype=float)
    n_obs = data.size

    # Moments under q that are reused by several terms
    expected_gamma = alpha_p / beta_p                                  # E_q[gamma]
    expected_log_gamma = special.digamma(alpha_p) - np.log(beta_p)     # E_q[log gamma]
    expected_mu_sq = 1 / tau_p + nu_p * nu_p                           # E_q[mu^2]

    # E_q log p(mu), with the prior mean fixed to zero
    log_p = -.5 * np.log(2 * np.pi) + .5 * np.log(tau) - .5 * tau * expected_mu_sq

    # E_q log p(gamma). The normaliser -log Gamma(alpha) is part of the Gamma log-density;
    # leaving it out shifts the ELBO by a constant, so that e.g. q == prior with no data
    # would not give the expected value -KL(q || p) = 0.
    log_p += alpha * np.log(beta) - special.gammaln(alpha) \
             + (alpha - 1) * expected_log_gamma - beta * expected_gamma

    # sum_i E_q log p(x_i | mu, gamma), in closed form:
    # sum_i E_q[(x_i - mu)^2] = sum x_i^2 - 2 nu_p sum x_i + n E_q[mu^2]
    expected_sq_dev = np.sum(data * data) - 2 * nu_p * np.sum(data) + n_obs * expected_mu_sq
    log_p += n_obs * (-.5 * np.log(2 * np.pi) + .5 * expected_log_gamma) \
             - .5 * expected_gamma * expected_sq_dev

    # Entropy of q(mu) (Gaussian)
    entropy = .5 * np.log(2 * np.pi * np.exp(1) / tau_p)
    # Entropy of q(gamma) (Gamma, shape/rate)
    entropy += alpha_p - np.log(beta_p) + special.gammaln(alpha_p) \
               + (1 - alpha_p) * special.digamma(alpha_p)

    return log_p + entropy


## Do the VB

The task is to implement the variational updating equations shown below.

<center>
<img src="./updating_equations.png" alt="Drawing" width="650">
</center>

In [None]:
# Coordinate-ascent VB for q(gamma) = Gamma(alpha_q, beta_q) and q(mu) = N(nu_q, 1/tau_q).
# This cell is an exercise scaffold: the lines marked "## Code the updating equation"
# hold placeholder values that are meant to be replaced by the real update equations.

# Initialization: every variational parameter starts at its prior value
alpha_q = alpha_prior
beta_q = beta_prior
nu_q = mu_prior
tau_q = tau_prior
# -inf so that the "ELBO is decreasing" check below cannot fire on the first iteration
previous_elbo = -np.inf

# Start iterating
print("\n" + 100 * "=" + "\n   VB iterations:\n" + 100 * "=")
for iteration in range(1000):
    # Update gamma distribution: q(\gamma)=Gamma(\alpha_q,\beta_q)
    # beta_q uses the current nu_q and tau_q, i.e. the values from the previous iteration
    alpha_q = 0 ## Code the updating equation
    beta_q = beta_prior + .5 * np.sum(x * x) - nu_q * np.sum(x) + .5 * N * (1. / tau_q + nu_q * nu_q)

    # Update Gaussian distribution: q(\mu)=N(\nu_q,\tau_q^{-1})
    # These three lines run after the gamma update, so they can use the fresh alpha_q and beta_q
    expected_gamma = 0 ## Code the updating equation
    tau_q = 0.1 ## Code the updating equation
    nu_q = 0 ## Code the updating equation

    # Calculate Lower-bound for the current variational parameters
    # (the *_p arguments of calculate_ELBO are the variational posterior, the others the prior)
    current_elbo = calculate_ELBO(data=x, tau=tau_prior, alpha=alpha_prior, beta=beta_prior,
                                    nu_p=nu_q, tau_p=tau_q, alpha_p=alpha_q, beta_p=beta_q)

    print("{:2d}:  ELBO: {:12.7f}, alpha_q: {:6.3f}, beta_q: {:12.3f}, nu_q: {:6.3f}, tau_q: {:6.3f}".format(
        iteration + 1,  current_elbo, alpha_q, beta_q, nu_q, tau_q))

    # ELBO should always increase. Check, but be a bit lenient to avoid crash due to numerical instability
    if current_elbo < previous_elbo - 1E-6:
        raise ValueError("ELBO is decreasing. Something is wrong! Goodbye...")

    # Stop on a negligible relative change of the ELBO; skipped on the first pass,
    # where previous_elbo is still -inf.
    # NOTE(review): 1E-20 is below double-precision resolution, so this presumably only
    # triggers once the ELBO stops changing entirely - confirm that this is intended.
    if iteration > 0 and np.abs((current_elbo - previous_elbo) / previous_elbo) < 1E-20:
        # Very little improvement. We are done.
        break

    # If we didn't break we need to run again. Update the value for "previous"
    previous_elbo = current_elbo


# Report the variational estimates next to simple data statistics and the prior values
print("\n" + 100 * "=" + "\n   Result:\n" + 100 * "=")
# NOTE(review): the prior mean is printed as the literal 0. rather than mu_prior;
# the two agree only while mu_prior is 0.
print("E[mu] = {:5.3f} with data average {:5.3f} and prior mean {:5.3f}.".format(nu_q, np.mean(x), 0.))
# E[gamma] under q is alpha_q / beta_q; it is compared with 1 / np.cov(x)
print("E[gamma] = {:5.3f} with inverse of data covariance {:5.3f} and prior {:5.3f}.".format(
    alpha_q / beta_q, 1. / np.cov(x), alpha_prior / beta_prior))

### Plot of the Prior density

In [None]:
# Contour plot of the prior: the prior hyper-parameters are passed in the
# "posterior" slots of plot_density.
plot_density(
    posterior_mean_mu=mu_prior,
    posterior_prec_mu=tau_prior,
    posterior_alpha_gamma=alpha_prior,
    posterior_beta_gamma=beta_prior,
    correct_mean=correct_mean,
    correct_precision=correct_precision,
)
plt.show()

### Plot of the Variational Posterior density

In [None]:
# Contour plot of the variational posterior, using the current variational parameters.
plot_density(
    posterior_mean_mu=nu_q,
    posterior_prec_mu=tau_q,
    posterior_alpha_gamma=alpha_q,
    posterior_beta_gamma=beta_q,
    correct_mean=correct_mean,
    correct_precision=correct_precision,
)
plt.show()