In [None]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import widgets
from IPython.display import display

The multivariate normal distribution is a generalisation of the univariate normal distribution to higher dimensions. It describes a set of correlated random variables, each of which clusters around its mean value. It has two parameters: a $k$-dimensional mean vector $\boldsymbol{\mu}$ and a $k \times k$ symmetric positive definite covariance matrix $\Sigma$.

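This notebook provides an interactive visualisation of 2D normal distributions for different mean vectors and covariance matrices. Because a covariance matrix is symmetric, the off-diagonal entry (Cov 1, 2) is entered once and used for both off-diagonal positions. The interactive visualisation at the bottom of the notebook only displays the random sample if the covariance matrix is positive definite.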

In [None]:
# Static example: 100 draws from a zero-mean 2D normal whose negative
# off-diagonal covariance (-2) makes x_1 and x_2 negatively correlated.
mean = np.array([0, 0])
cov = np.array([[4, -2], [-2, 2]])
sample = np.random.multivariate_normal(mean, cov, 100)
plt.scatter(sample[:, 0], sample[:, 1])
plt.xlabel('$x_1$')
# Trailing ';' suppresses the Text/PathCollection repr in the cell output.
plt.ylabel('$x_2$');

In [None]:
# Single figure/axes pair that the interactive callback below redraws into.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.grid(True)
# Turn on interactive mode so figure updates are pushed without plt.show().
plt.ion()

def new_widget(text, value):
    """Build a labelled numeric input box.

    Parameters
    ----------
    text : str
        Description label displayed next to the input.
    value : float
        Initial value of the input.

    Returns
    -------
    widgets.FloatText
        The new, not-yet-displayed widget.
    """
    options = dict(value=value, description=text)
    return widgets.FloatText(**options)

# Inputs for the two mean components and the three free entries of the
# symmetric 2x2 covariance matrix (defaults match the static example above).
mean_1 = new_widget('Mean 1: ', 0)
mean_2 = new_widget('Mean 2: ', 0)
cov_11 = new_widget('Cov 1, 1: ', 4)
cov_12 = new_widget('Cov 1, 2: ', -2)
cov_22 = new_widget('Cov 2, 2: ', 2)

# First row holds the mean; the next two lay the covariance matrix out as a
# grid. cov_12 fills both off-diagonal slots: both views belong to the same
# widget, so editing either one updates the other.
layout = widgets.VBox([
    widgets.HBox(row)
    for row in ([mean_1, mean_2],
                [cov_11, cov_12],
                [cov_12, cov_22])
])

def on_change(change):
    """Redraw the scatter plot from the current widget values.

    Used both as the ``observe`` callback of every input widget and for the
    initial draw (called with ``change=None``); the ``change`` payload is not
    inspected.

    A new sample of 100 points is drawn only when the covariance matrix is
    positive definite; otherwise the axes show a short notice instead.
    """
    # clear() also resets the grid, so re-enable it every time.
    ax.clear()
    ax.grid(True)

    mean = np.array([float(mean_1.value), float(mean_2.value)])
    # cov_12 fills both off-diagonal entries, so the matrix is symmetric
    # by construction.
    cov = np.array([[float(cov_11.value), float(cov_12.value)],
                    [float(cov_12.value), float(cov_22.value)]])

    # eigvalsh is the routine for symmetric matrices and always returns
    # real eigenvalues.
    if np.all(np.linalg.eigvalsh(cov) > 0):
        sample = np.random.multivariate_normal(mean, cov, 100)
        ax.scatter(sample[:, 0], sample[:, 1])
        # Use the same range on both axes, centred on the origin.
        lim = max(np.max(np.abs(ax.get_xlim())),
                  np.max(np.abs(ax.get_ylim())))
        ax.set_xlim([-lim, lim])
        ax.set_ylim([-lim, lim])
    else:
        ax.text(0.5, 0.5, 'Covariance matrix is not positive definite',
                ha='center', va='center', transform=ax.transAxes)

    # Request a redraw explicitly instead of relying on interactive mode.
    fig.canvas.draw_idle()

# All input widgets, in the order mean_1, mean_2, cov_11, cov_12, cov_22.
wid = [mean_1, mean_2, cov_11, cov_12, cov_22]
# Observe only the 'value' trait: observing all traits (the default) can fire
# the callback for other trait changes too, resampling several times per edit.
for w in wid:
    w.observe(on_change, names='value')

# Initial draw with the default parameters.
on_change(None)

display(layout)