In [None]:
%matplotlib notebook

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from ipywidgets import interact, interactive, fixed, interact_manual, GridspecLayout
import ipywidgets as widgets
from IPython.display import display

# Model Fitting

Suppose you have obtained some data (observations) and are now wondering what to do with them. If you suspect the data were generated by a particular model, you can optimise the parameters of that model so that the model best explains the data.

The aim of this tutorial is to familiarise you with Bayesian model fitting. Like all good tutorials, we'll assume our data was generated from a Normal distribution. Ultimately, what we want to estimate is the mean and the variance of the distribution from which we obtained our data.

# Sample some data

Below, we'll sample some data from a 1-dimensional Normal distribution and plot their histogram. Feel free to play with the value of the mean, $\mu$, and the standard deviation, $\sigma$, and see how this changes the data.

Things to note/try:
* For large values of $N$, you should see the typical bell-curve shape of the Normal distribution.
* If you change $\mu$, the centre of the data will move (i.e. the centre value on the $x$ axis will move).
* If you increase $\sigma$, the spread of the data will increase (i.e. the range on the $x$ axis will increase).

In [None]:
# Colour and line style for each kind of curve, keyed by the same names in both dicts
colours = dict(
    ground_truth="#4daf4a",
    empirical="#f781bf",
    exact="#377eb8",
    approx="#ff7f00",
)
linestyles = dict(
    ground_truth="solid",
    empirical="dashdot",
    exact="dashed",
    approx="dotted",
)

def get_contours(x, mu, std):
    """Normal probability density with mean `mu` and standard deviation `std`.

    Evaluated element-wise at `x` (a scalar or NumPy array).
    """
    # The normalising constant of a Normal pdf is 1 / (std * sqrt(2*pi)).
    return np.exp(-0.5 * np.square((x - mu) / std)) / (std * np.sqrt(2 * np.pi))

def plot_contours(ax, x1, y1, c1, l1, ls1, x2=None, y2=None, c2=None, l2=None, ls2=None):
    """Plot one curve on `ax`, and optionally a second one.

    Args:
        ax: matplotlib Axes to draw on.
        x1, y1, c1, l1, ls1: x values, y values, colour, legend label and
            linestyle of the first curve.
        x2, y2, c2, l2, ls2: the same for an optional second curve, which is
            drawn only when `x2` is given.
    """
    ax.plot(x1, y1, color=c1, label=l1, linestyle=ls1)
    # Compare with None explicitly: `if x2:` raises ValueError for NumPy arrays
    # with more than one element, which is what the callers pass.
    if x2 is not None:
        ax.plot(x2, y2, color=c2, label=l2, linestyle=ls2)

In [None]:
def update_contours(change):
    """Redraw figure 10: ground-truth Normal density vs. the empirical Normal fit.

    `change` (the widget change notification) is ignored. The current slider
    values and the current sample in `w.value` are read directly.
    """
    # Recompute the empirical statistics from the current sample. The
    # module-level emp_mean/emp_std are only refreshed when a cell is re-run,
    # not when the sliders change.
    sample_mean, sample_std = np.mean(w.value), np.std(w.value)
    x = np.linspace(mu.value - 3*std.value, mu.value + 3*std.value, 1000)
    ax10.clear()
    # Plot directly rather than through the global `plot_contours`: a later
    # cell redefines that name with a different signature.
    ax10.plot(x, get_contours(x, mu.value, std.value), color=colours["ground_truth"],
              label="Ground Truth", linestyle=linestyles["ground_truth"])
    ax10.plot(x, get_contours(x, sample_mean, sample_std), color=colours["approx"],
              label="Empirical", linestyle=linestyles["approx"])

# NOTE(review): `w`, `mu` and `std` are created in the data-sampling cell further
# down this notebook; that cell must be run before this one.
emp_mean, emp_std = np.mean(w.value), np.std(w.value)
display(w)

# Figure 10: ground-truth density against the empirical Normal fit
fig10, ax10 = plt.subplots(num=10)
x = np.linspace(mu.value - 3*std.value, mu.value + 3*std.value, 1000)
data_contour = get_contours(x, mu.value, std.value)
emp_contour = get_contours(x, emp_mean, emp_std)
plot_contours(ax10, x, data_contour, colours["ground_truth"], "Ground Truth", linestyles["ground_truth"],
              x, emp_contour, colours["approx"], "Empirical", linestyles["approx"])

# Redraw whenever any of the data widgets changes
for child in w.children:
    child.observe(update_contours, 'value')

In [None]:
# Empirical (ddof=0) mean and standard deviation of the observed sample
emp_mean, emp_std = np.mean(w.value), np.std(w.value)

def plot_histograms(ax, mean1, std1, mean2, std2,
                    label1="Empirical Approximation", label2="Ground Truth",
                    c1="#377eb8", c2="#4daf4a", meanstd1=False,
                    n_samples=50000, bins=50):
    """Overlay sampled histograms of two Normal distributions on `ax`.

    Draws `n_samples` samples from Normal(mean1, std1) and from
    Normal(mean2, std2) and plots them as semi-transparent histograms. A solid
    vertical line marks each mean (black for the first, green for the second).

    Args:
        ax: matplotlib Axes to draw on.
        mean1, std1: parameters of the first (fitted) Normal.
        mean2, std2: parameters of the second (reference) Normal.
        label1, label2: legend labels.
        c1, c2: histogram colours.
        meanstd1: if truthy, also draw dashed lines at mean1 +/- 2*meanstd1.
        n_samples: number of samples drawn for each histogram.
        bins: number of histogram bins.
    """
    # align='mid' centres each bar on its bin. 'left' centred the bars on the
    # left bin edges, shifting them half a bin away from the mean lines.
    ax.hist(np.random.normal(mean1, std1, [n_samples]), color=c1, alpha=0.3, bins=bins, align='mid', label=label1)
    ax.axvline(mean1, color='k', linestyle='solid')
    ax.hist(np.random.normal(mean2, std2, [n_samples]), color=c2, alpha=0.3, bins=bins, align='mid', label=label2)
    ax.axvline(mean2, color='g', linestyle='solid')
    ax.legend(loc='upper right')
    if meanstd1:
        ax.axvline(mean1 - 2*meanstd1, color='k', linestyle='--')
        ax.axvline(mean1 + 2*meanstd1, color='k', linestyle='--')

def update_empirical_histograms(change):
    """Redraw figure 1: empirical Normal fit vs. ground truth; `change` is ignored."""
    ax1.clear()
    sample = w.value
    plot_histograms(ax1, np.mean(sample), np.std(sample), mu.value, std.value)

display(w)

# Figure 1: empirical fit vs. ground truth, redrawn whenever the data widgets change
fig1, ax1 = plt.subplots(num=1)
plot_histograms(ax1, emp_mean, emp_std, mu.value, std.value)
for child in w.children:
    child.observe(update_empirical_histograms, 'value')

In [None]:
def get_data(N, mu, std):
    """Draw N samples from Normal(mu, std) and store them on the global widget as `w.value`."""
    sample = np.random.normal(mu, std, [N])
    w.value = sample

def display_data(ax, data):
    """Plot a 10-bin green histogram of `data` on `ax` and show the figure."""
    # Plot the sample the caller passed in. align='mid' centres each bar on its
    # bin instead of on the bin's left edge.
    ax.hist(data, bins=10, color='g', align='mid')
    plt.show()
    
def update_plot(change):
    """Redraw the histogram in figure 0 from the current sample; `change` is ignored."""
    current = w.value
    ax0.clear()
    display_data(ax0, data=current)

# set up figure
fig0, ax0 = plt.subplots(num=0)

# Sliders for the ground-truth distribution. Raw strings stop the LaTeX
# backslashes being read as (invalid) escape sequences, which newer Pythons warn about.
N = widgets.IntSlider(value=50, min=5, max=500, step=5, description="N", continuous_update=False)
mu = widgets.IntSlider(value=0, min=-10, max=10, step=1, description=r"$\mu$", continuous_update=False)
std = widgets.FloatSlider(value=1., min=0.1, max=10, step=0.1, description=r"$\sigma$", continuous_update=False)

# `interactive` wires get_data to the sliders; get_data stores the sample on w.value
w = interactive(get_data, N=N, mu=mu, std=std)
display(w)

# display our data
display_data(ax0, w.value)

# Redraw the histogram whenever any of the widget's children changes value
for child in w.children:
    child.observe(update_plot, 'value')

# Empirical Mean and Standard Deviation

The first thing we might think to do is simply to take the mean and the standard deviation of our observations without trying to include prior knowledge about their values or estimate confidence in their values.

Below, we draw lots of samples from the ground truth distribution, i.e. $\mathcal{N} \left( \mu, \sigma^2 \right)$, and plot the ground truth histogram in green. This is overlaid on a histogram of our learned (empirical) distribution, which is shown in blue.

If you play around with the values of the parameters above and rerun the cell below you should notice:
* Increasing the number of observations, $N$, drives our empirical approximation closer to the ground truth;
* Similarly, decreasing the standard deviation, $\sigma$, of our distribution generally results in a better approximation;
* If we have a small number of observations, $N$, and a large standard deviation, $\sigma$, our empirical approximation can be quite poor!

In [4]:
# Empirical (ddof=0) mean and standard deviation of the observed sample
emp_mean, emp_std = np.mean(w.value), np.std(w.value)

def plot_histograms(ax, mean1, std1, mean2, std2,
                    label1="Empirical Approximation", label2="Ground Truth",
                    c1="#377eb8", c2="#4daf4a", meanstd1=False,
                    n_samples=50000, bins=50):
    """Overlay sampled histograms of two Normal distributions on `ax`.

    Draws `n_samples` samples from Normal(mean1, std1) and from
    Normal(mean2, std2) and plots them as semi-transparent histograms. A solid
    vertical line marks each mean (black for the first, green for the second).

    Args:
        ax: matplotlib Axes to draw on.
        mean1, std1: parameters of the first (fitted) Normal.
        mean2, std2: parameters of the second (reference) Normal.
        label1, label2: legend labels.
        c1, c2: histogram colours.
        meanstd1: if truthy, also draw dashed lines at mean1 +/- 2*meanstd1.
        n_samples: number of samples drawn for each histogram.
        bins: number of histogram bins.
    """
    # align='mid' centres each bar on its bin. 'left' centred the bars on the
    # left bin edges, shifting them half a bin away from the mean lines.
    ax.hist(np.random.normal(mean1, std1, [n_samples]), color=c1, alpha=0.3, bins=bins, align='mid', label=label1)
    ax.axvline(mean1, color='k', linestyle='solid')
    ax.hist(np.random.normal(mean2, std2, [n_samples]), color=c2, alpha=0.3, bins=bins, align='mid', label=label2)
    ax.axvline(mean2, color='g', linestyle='solid')
    ax.legend(loc='upper right')
    if meanstd1:
        ax.axvline(mean1 - 2*meanstd1, color='k', linestyle='--')
        ax.axvline(mean1 + 2*meanstd1, color='k', linestyle='--')

def update_empirical_histograms(change):
    """Redraw figure 1: empirical Normal fit vs. ground truth; `change` is ignored."""
    ax1.clear()
    sample = w.value
    plot_histograms(ax1, np.mean(sample), np.std(sample), mu.value, std.value)

display(w)

# Figure 1: empirical fit vs. ground truth, redrawn whenever the data widgets change
fig1, ax1 = plt.subplots(num=1)
plot_histograms(ax1, emp_mean, emp_std, mu.value, std.value)
for child in w.children:
    child.observe(update_empirical_histograms, 'value')

NameError: name 'w' is not defined

# Bayesian Inference

Suppose we also want to express our confidence in our parameter estimates, i.e. how confident are we in our estimate of the mean, $\mu$? The simple method above can't do this. It only gives us estimates of $\mu$ and $\sigma$. It also doesn't let us include prior knowledge about the ground truth distribution.

If you aren't familiar with probability or feel like you could do with a refresher, don't worry! The aim of this section is to give you some intuition into the different parts of a Bayesian model, not to be able to reproduce the code below. You should play around with the values used in the model via the sliders and watch how this changes the solution. This will hopefully make you more comfortable when reading about methods using Bayesian models - these pop up very frequently in medical image analysis!

## Set up our prior

Bayesian inference allows us to include any prior information we have about the distribution which generated our observations. This can be very helpful in two situations in particular:
* When we don't have many observations;
* When our observations are very noisy.

In the following model, we have 4 parameters. These are:
* $m_0$ = Our prior knowledge of the mean;
* $v_0$ = Our prior confidence in our prior mean;
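* $\beta_0$ (the "Beta mean" slider) = Our prior mean for the precision (the inverse variance, $1/\sigma^2$) of the data;
* $\sigma_0$ (the "Beta var" slider) = The variance of our prior on the precision, i.e. how uncertain we are about $\beta_0$.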

Our uncertainty about our prior's mean can be shown by dashed lines drawn at $m_0 \pm 2\sqrt{v_0 \beta_0}$ (this option is currently commented out in the code below). The default setting of the prior, $v_0 = 1000$, reflects the fact that we are highly uncertain about our prior value for the mean.

In [None]:
def print_prior(m0, v0, beta_mean0, beta_var0):
    """Print the four prior parameters on one line, separated by spaces."""
    params = (m0, v0, beta_mean0, beta_var0)
    print(*params)

def update_prior_histograms(change):
    """Redraw figure 2: the prior vs. the ground-truth data distribution.

    `change` (the widget change notification) is ignored.
    """
    # NOTE(review): beta_mean0 is passed here as the prior *standard deviation*,
    # while the solution cells treat beta as a precision (std = sqrt(1/beta)).
    # The two agree only at beta_mean0 = 1 — confirm which is intended.
    ax2.clear()
    plot_histograms(
        ax2, m0.value, beta_mean0.value, mu.value, std.value,
        label1="Prior distribution", label2="Data distribution",
    )
    # Dashed lines around the prior mean are disabled; to enable, pass
    # meanstd1=np.sqrt(v0.value*beta_mean0.value)
    
# Sliders for the prior parameters of our NormalGamma distribution
m0 = widgets.IntSlider(value=0, min=-10, max=10, step=1, description="$m_0$", continuous_update=False)
v0 = widgets.IntSlider(value=1000, min=1, max=1000, step=1, description="$v_0$", continuous_update=False)
beta_mean0 = widgets.IntSlider(value=1, min=1, max=100, step=1, description="Beta mean", continuous_update=False)
beta_var0 = widgets.IntSlider(value=1000, min=1, max=1000, step=1, description="Beta var", continuous_update=False)
prior_widget = interactive(print_prior, m0=m0, v0=v0, beta_mean0=beta_mean0, beta_var0=beta_var0)

# Figure 2: prior vs. data distribution
fig2, ax2 = plt.subplots(num=2)
plot_histograms(ax2, m0.value, beta_mean0.value, mu.value, std.value,
                label1="Prior distribution", label2="Data distribution")

# Data sliders and prior sliders side by side
grid = GridspecLayout(1, 2)
grid[0, 0] = w
grid[0, 1] = prior_widget
display(grid)

# Redraw the prior plot when either set of sliders changes
for widget_group in (w, prior_widget):
    for child in widget_group.children:
        child.observe(update_prior_histograms, 'value')

## Exact Solution

In a lot of cases, inferring the posterior distribution of our parameters is an intractable problem. In the case we're considering here however, we can find the exact solution!

We'll see later how we can still get reasonable approximations of the posterior in cases where our problem is intractable. Also, because we will have the exact solution, we can easily compare the approximated solution with the exact solution to see how good it is - handy!

In [None]:
def update_exact_histograms(change):
    """Redraw figure 4 with the Normal implied by the exact posterior; `change` is ignored.

    The plotted std is sqrt(1 / E[precision]) from the posterior.
    """
    ax4.clear()
    post_mean, _post_v, expected_precision, _post_beta_var = exact_solution(
        w.value, m0.value, v0.value, beta_mean0.value, beta_var0.value)
    fitted_std = np.sqrt(1 / expected_precision)
    plot_histograms(ax4, post_mean, fitted_std, mu.value, std.value,
                    label1="Exact solution", label2="Data distribution")

def exact_solution(data, m0, v0, beta_mean0, beta_var0):
    """Exact Normal-Gamma posterior for the mean and precision of Normal data.

    The prior is given in "moment" form: m0 is the prior mean of mu, v0 = 1/l0
    is the inverse of the Normal-Gamma precision scale l0, and
    beta_mean0 / beta_var0 are the mean and variance of the Gamma prior on the
    precision.

    Args:
        data: 1-D sequence or array of observations.
        m0, v0, beta_mean0, beta_var0: prior parameters (see above).

    Returns:
        (m, v, beta_mean, beta_var): the posterior in the same form as the
        prior. With no observations the prior is returned unchanged.
    """
    data = np.asarray(data, dtype=float)
    n = data.size
    if n == 0:
        return m0, v0, beta_mean0, beta_var0

    # Gamma(shape a, rate b) has mean a/b and variance a/b**2, so the shape and
    # rate matching the requested mean/variance are mean**2/var and mean/var.
    l0 = 1 / v0
    a0 = (beta_mean0**2) / beta_var0
    b0 = beta_mean0 / beta_var0

    # Conjugate update
    xbar = np.mean(data)
    s = np.mean(np.square(data - xbar))  # biased (ddof=0) sample variance
    m = (l0*m0 + n*xbar) / (l0 + n)
    l = l0 + n
    a = a0 + n/2
    b = b0 + 0.5*n*(s + l0*((xbar - m0)**2)/(l0 + n))

    # Convert back to the prior's parameterisation
    v = 1 / l
    beta_mean = a / b
    beta_var = a / (b**2)
    return m, v, beta_mean, beta_var

# Initial exact fit with the current slider values
mE, vE, beta_meanE, beta_varE = exact_solution(w.value, m0.value, v0.value, beta_mean0.value, beta_var0.value)

# Figure 4: exact posterior vs. data distribution
fig4, ax4 = plt.subplots(num=4)
plot_histograms(ax4, mE, np.sqrt(1/beta_meanE), mu.value, std.value,
                label1="Exact solution", label2="Data distribution")

display(grid)

for widget_group in (w, prior_widget):
    for child in widget_group.children:
        child.observe(update_exact_histograms, 'value')

In [None]:
def update_exact_histograms(change):
    """Redraw figure 5 with the Normal implied by the exact posterior; `change` is ignored."""
    ax5.clear()
    posterior = exact_solution(w.value, m0.value, v0.value, beta_mean0.value, beta_var0.value)
    post_mean, expected_precision = posterior[0], posterior[2]
    plot_histograms(ax5, post_mean, np.sqrt(1 / expected_precision), mu.value, std.value,
                    label1="Exact solution", label2="Data distribution")

def exact_solution(data, m0, v0, beta_mean0, beta_var0):
    """Exact Normal-Gamma posterior for the mean and precision of Normal data.

    The prior is given in "moment" form: m0 is the prior mean of mu, v0 = 1/l0
    is the inverse of the Normal-Gamma precision scale l0, and
    beta_mean0 / beta_var0 are the mean and variance of the Gamma prior on the
    precision.

    Args:
        data: 1-D sequence or array of observations.
        m0, v0, beta_mean0, beta_var0: prior parameters (see above).

    Returns:
        (m, v, beta_mean, beta_var): the posterior in the same form as the
        prior. With no observations the prior is returned unchanged.
    """
    data = np.asarray(data, dtype=float)
    n = data.size
    if n == 0:
        return m0, v0, beta_mean0, beta_var0

    # Gamma(shape a, rate b) has mean a/b and variance a/b**2, so the shape and
    # rate matching the requested mean/variance are mean**2/var and mean/var.
    l0 = 1 / v0
    a0 = (beta_mean0**2) / beta_var0
    b0 = beta_mean0 / beta_var0

    # Conjugate update
    xbar = np.mean(data)
    s = np.mean(np.square(data - xbar))  # biased (ddof=0) sample variance
    m = (l0*m0 + n*xbar) / (l0 + n)
    l = l0 + n
    a = a0 + n/2
    b = b0 + 0.5*n*(s + l0*((xbar - m0)**2)/(l0 + n))

    # Convert back to the prior's parameterisation
    v = 1 / l
    beta_mean = a / b
    beta_var = a / (b**2)
    return m, v, beta_mean, beta_var

# Initial exact fit with the current slider values
mE, vE, beta_meanE, beta_varE = exact_solution(w.value, m0.value, v0.value, beta_mean0.value, beta_var0.value)

# Figure 5: exact posterior vs. data distribution
fig5, ax5 = plt.subplots(num=5)
plot_histograms(ax5, mE, np.sqrt(1/beta_meanE), mu.value, std.value,
                label1="Exact solution", label2="Data distribution")

display(grid)

for widget_group in (w, prior_widget):
    for child in widget_group.children:
        child.observe(update_exact_histograms, 'value')

# Variational Bayes

In [None]:
def update_vb_histograms(change):
    """Redraw figure 3 with the Normal implied by the VB fit; `change` is ignored."""
    ax3.clear()
    fit = vb_solution(w.value, m0.value, v0.value, beta_mean0.value, beta_var0.value)
    fitted_mean, expected_precision = fit[0], fit[2]
    plot_histograms(ax3, fitted_mean, np.sqrt(1 / expected_precision), mu.value, std.value,
                    label1="Approximated distribution", label2="Data distribution")

def vb_solution(data, m0, v0, beta_mean0, beta_var0, iterations=10):
    """Coordinate-ascent variational fit of a Normal's mean and precision.

    The Gamma prior/posterior on the precision is held in (scale, shape) form,
    where mean = scale*shape and variance = scale**2*shape. In the updates the
    prior variance v0 of the mean is not scaled by the precision, so this appears
    to use independent priors on mean and precision (confirm if relevant).

    Args:
        data: array of observations.
        m0, v0: prior mean and variance of the mean.
        beta_mean0, beta_var0: prior mean and variance of the precision.
        iterations: number of coordinate-ascent sweeps.

    Returns:
        (m, v, beta_mean, beta_var): mean/variance of q(mean) and mean/variance
        of q(precision).
    """
    n_obs = len(data)
    sum_x = np.sum(data)
    sum_x2 = np.sum(np.square(data))
    scale0 = beta_var0 / beta_mean0
    shape0 = (beta_mean0**2) / beta_var0

    mean, var, scale, shape = m0, v0, scale0, shape0
    for _ in range(iterations):
        # Update q(mean) using the current expected precision scale*shape
        denom = 1 + n_obs*v0*scale*shape
        mean = (m0 + v0*scale*shape*sum_x) / denom
        var = v0 / denom
        # Expected sum of squared residuals under q(mean), then update q(precision)
        expected_sq = sum_x2 - 2*sum_x*mean + n_obs*(mean**2 + var)
        scale = 1 / (1/scale0 + expected_sq/2)
        shape = n_obs/2 + shape0
    return mean, var, scale * shape, (scale**2) * shape

# Initial VB fit with the current slider values
mVB, vVB, beta_meanVB, beta_varVB = vb_solution(w.value, m0.value, v0.value, beta_mean0.value, beta_var0.value)

# Figure 3: VB approximation vs. data distribution
fig3, ax3 = plt.subplots(num=3)
plot_histograms(ax3, mVB, np.sqrt(1/beta_meanVB), mu.value, std.value,
                label1="Approximated distribution", label2="Data distribution")

display(grid)

for widget_group in (w, prior_widget):
    for child in widget_group.children:
        child.observe(update_vb_histograms, 'value')

# Comparing our approximation to the exact solution

In [None]:
def get_contour(x, mean, std):
    """Normal probability density with the given `mean` and `std`, evaluated element-wise at `x`."""
    # The normalising constant of a Normal pdf is 1 / (std * sqrt(2*pi)).
    return np.exp(-0.5 * np.square((x - mean) / std)) / (std * np.sqrt(2 * np.pi))

def plot_contours(ax, n_runs=100):
    """Compare the exact and VB solutions, averaged over repeated datasets.

    For each of `n_runs` runs, a fresh dataset of N.value points is drawn from
    Normal(mu.value, std.value). Both `exact_solution` and `vb_solution` are
    fitted to it with the current prior sliders. The posterior mean and the std
    implied by the expected precision (1/sqrt(beta_mean)) are averaged over
    runs and drawn as Normal curves on `ax`.
    """
    means_exact, stds_exact = [], []
    means_vb, stds_vb = [], []
    prior = (m0.value, v0.value, beta_mean0.value, beta_var0.value)
    for _ in range(n_runs):
        # Fit each run to its own freshly sampled dataset so the average is
        # taken over datasets.
        data = np.random.normal(mu.value, std.value, [N.value])
        mE, _, beta_meanE, _ = exact_solution(data, *prior)
        mVB, _, beta_meanVB, _ = vb_solution(data, *prior)
        means_exact.append(mE)
        stds_exact.append(1 / np.sqrt(beta_meanE))
        means_vb.append(mVB)
        stds_vb.append(1 / np.sqrt(beta_meanVB))

    # take mean across all runs
    mE, stdE = np.mean(means_exact), np.mean(stds_exact)
    mVB, stdVB = np.mean(means_vb), np.mean(stds_vb)

    # x-range covers +/- 3 std of the averaged exact solution
    x = np.linspace(mE - 3*stdE, mE + 3*stdE, 1000)
    ax.plot(x, get_contour(x, mE, stdE), color='b', label='Exact solution')
    ax.plot(x, get_contour(x, mVB, stdVB), color='b', ls='--', label='Approximate solution')
    ax.legend(loc='upper right')

def update_contour_plot(change, _plot=plot_contours):
    """Redraw figure 6; `change` (the widget change notification) is ignored.

    `_plot` is bound at definition time to the `plot_contours` defined just
    above. A later cell redefines `plot_contours` for the `new_*` solvers, and a
    late-bound global lookup would silently switch this figure to that version.
    """
    ax6.clear()
    _plot(ax6)

# Figure 6: exact vs. VB solution
fig6, ax6 = plt.subplots(num=6)
plot_contours(ax6)

display(grid)

for widget_group in (w, prior_widget):
    for child in widget_group.children:
        child.observe(update_contour_plot, 'value')

In [None]:
def new_exact_solution(data, m0, v0, beta_mean0, beta_var0):
    """Exact Normal-Gamma posterior, with the prior given in "moment" form.

    Args:
        data: 1-D sequence or array of observations.
        m0: prior mean of mu.
        v0: 1 / k0, where k0 is the Normal-Gamma precision scale of the mean.
        beta_mean0, beta_var0: mean and variance of the Gamma prior on the precision.

    Returns:
        (mN, vN, beta_mean, beta_var): the posterior in the same form. With no
        observations the prior is returned unchanged.
    """
    data = np.asarray(data, dtype=float)
    n = data.size
    if n == 0:
        return m0, v0, beta_mean0, beta_var0

    # Gamma(shape a, rate b): mean a/b and variance a/b**2
    k0 = 1 / v0
    a0 = (beta_mean0**2) / beta_var0
    b0 = beta_mean0 / beta_var0

    # Conjugate update
    xbar = np.mean(data)
    s = np.mean(np.square(data - xbar))  # biased (ddof=0) sample variance
    mN = (k0*m0 + n*xbar) / (k0 + n)
    kN = k0 + n
    aN = a0 + n/2
    bN = b0 + 0.5*n*(s + k0*((xbar - m0)**2)/(k0 + n))

    # Convert back to the prior's parameterisation
    vN = 1 / kN
    beta_mean = aN / bN
    beta_var = aN / (bN**2)
    return mN, vN, beta_mean, beta_var

def new_vb_solution(data, m0, v0, beta_mean0, beta_var0, iterations=10):
    """Mean-field VB for the Normal-Gamma model, q(mu) q(lambda).

    q(mu) = Normal(mN, 1/k) and q(lambda) = Gamma(aN, b) in shape/rate form.
    mN and aN are fixed; coordinate ascent iterates k and b. The prior is in the
    same "moment" form as `new_exact_solution`.

    Args:
        data: 1-D sequence or array of observations.
        m0, v0: prior mean of mu and v0 = 1/k0.
        beta_mean0, beta_var0: mean and variance of the Gamma prior on the precision.
        iterations: number of coordinate-ascent sweeps.

    Returns:
        (m, v, beta_mean, beta_var): mean and variance of q(mu), and mean and
        variance of q(lambda).
    """
    data = np.asarray(data, dtype=float)
    N = data.size
    s1 = np.sum(data)
    s2 = np.sum(np.square(data))
    # Gamma(shape a, rate b): mean a/b, variance a/b**2
    a0, b0 = (beta_mean0**2) / beta_var0, (beta_mean0 / beta_var0)
    k0 = 1 / v0
    # N * xbar == s1; using s1 keeps mN finite when there is no data
    mN = (k0*m0 + s1) / (k0 + N)
    aN = a0 + ((N + 1) / 2)
    m, k, a, b = mN, k0, aN, b0
    for _ in range(iterations):
        k = (k0 + N) * a / b
        # E_q(mu)[sum (x_i - mu)^2 + k0 (mu - m0)^2], expanded in s1 and s2
        b = b0 + 0.5*((k0 + N) * ((1/k) + (m**2)) - 2*m*(k0*m0 + s1) + s2 + k0*(m0**2))
    v = 1 / k
    beta_mean = a / b
    beta_var = a / (b**2)  # variance of Gamma(a, b) in shape/rate form
    return m, v, beta_mean, beta_var

def get_contour(x, mean, std):
    """Normal probability density with the given `mean` and `std`, evaluated element-wise at `x`."""
    # The normalising constant of a Normal pdf is 1 / (std * sqrt(2*pi)).
    return np.exp(-0.5 * np.square((x - mean) / std)) / (std * np.sqrt(2 * np.pi))

def plot_contours(ax, n_runs=100):
    """Compare the `new_*` exact and VB solutions, averaged over repeated datasets.

    For each of `n_runs` runs, a fresh dataset of N.value points is drawn from
    Normal(mu.value, std.value). `new_exact_solution` and `new_vb_solution`
    (10 iterations) are fitted to it with the current prior sliders. The
    posterior mean and the std implied by the expected precision
    (1/sqrt(beta_mean)) are averaged over runs and drawn as Normal curves on `ax`.
    """
    means_exact, stds_exact = [], []
    means_vb, stds_vb = [], []
    prior = (m0.value, v0.value, beta_mean0.value, beta_var0.value)
    for _ in range(n_runs):
        # Fit each run to its own freshly sampled dataset so the average is
        # taken over datasets.
        data = np.random.normal(mu.value, std.value, [N.value])
        mE, _, beta_meanE, _ = new_exact_solution(data, *prior)
        mVB, _, beta_meanVB, _ = new_vb_solution(data, *prior, iterations=10)
        means_exact.append(mE)
        stds_exact.append(1 / np.sqrt(beta_meanE))
        means_vb.append(mVB)
        stds_vb.append(1 / np.sqrt(beta_meanVB))

    # take mean across all runs
    mE, stdE = np.mean(means_exact), np.mean(stds_exact)
    mVB, stdVB = np.mean(means_vb), np.mean(stds_vb)

    # x-range covers +/- 3 std of the averaged exact solution
    x = np.linspace(mE - 3*stdE, mE + 3*stdE, 1000)
    ax.plot(x, get_contour(x, mE, stdE), color='b', label='Exact solution')
    ax.plot(x, get_contour(x, mVB, stdVB), color='b', ls='--', label='Approximate solution')
    ax.legend(loc='upper right')

def update_contour_plot(change):
    """Redraw figure 7 via the current global `plot_contours`; `change` is ignored."""
    target_ax = ax7
    target_ax.clear()
    plot_contours(target_ax)

# Figure 7: the `new_*` exact vs. VB solutions
fig7, ax7 = plt.subplots(num=7)
plot_contours(ax7)

display(grid)

for widget_group in (w, prior_widget):
    for child in widget_group.children:
        child.observe(update_contour_plot, 'value')

In [None]:
# this cell has the MLaPP VB solution, written out step by step for the current
# slider values (it mirrors new_vb_solution)

# Gamma(shape a0, rate b0) prior on the precision with the requested mean/variance
a0 = (beta_mean0.value**2) / beta_var0.value
b0 = beta_mean0.value / beta_var0.value

k0 = 1 / v0.value

xbar = np.mean(w.value)

# The mean of q(mu) and the shape of q(lambda) are fixed; only k and b are iterated
mN = (k0*m0.value + N.value*xbar) / (k0 + N.value)
aN = a0 + (N.value+1)/2

m, k, a, b = mN, k0, aN, b0
s1 = np.sum(w.value)
s2 = np.sum(np.square(w.value))
for i in range(10):
    k = (k0 + N.value) * a / b
    b = b0 + 0.5 * ((k0 + N.value)*((1/k) + (m**2)) - 2 * m * (k0*m0.value + s1) + s2 + k0*(m0.value**2))

# mean and variance of q(mu), then mean (a/b) and variance (a/b**2) of the Gamma q(lambda)
print(m, 1/k, a/b, a/(b**2))
gamma_mean = a / b
print(f"Gamma mean = {gamma_mean}")
standard_deviation = np.sqrt(1/gamma_mean)
print(f"Standard deviation = {standard_deviation}")