In [1]:
import numpy as np
import matplotlib.pyplot as plt

import math
import torch
from torch.distributions import Normal, MultivariateNormal
from torch import matmul

## Univariate Normal Distribution

In [2]:
from torch.distributions import Normal

# Location (mean) and scale (standard deviation) of the distribution, as 1-element float32 tensors
mu = torch.zeros(1, dtype=torch.float)
sigma = torch.full((1,), 5.0, dtype=torch.float)

# A univariate normal whose batch shape is (1,) because the parameters have one element
uvn_dist = Normal(loc=mu, scale=sigma)

In [3]:
# Instantiate single point test dataset
X = torch.tensor([0.0], dtype=torch.float)

# Function to evaluate log prob using math formula
def raw_eval(X, mu, sigma):
    """Log-density of a univariate normal N(mu, sigma^2) at X, from the closed-form formula.

        log p(x) = -log(sqrt(2 * pi) * sigma) - (x - mu)^2 / (2 * sigma^2)

    Args:
        X: point(s) at which to evaluate, as a tensor (any number of elements).
        mu: mean, as a tensor broadcastable against X.
        sigma: standard deviation, as a tensor broadcastable against X.

    Returns:
        Tensor of log-densities with the broadcast shape of X, mu and sigma.
    """
    # Work directly in log space: computing exp(...) and then log(...) underflows to
    # log(0) = -inf far in the tail (beyond ~39 standard deviations exp() returns 0.0).
    # Using torch ops instead of math.exp also lets X hold more than one element
    # (math.exp only accepts tensors that can be converted to a single Python float).
    log_norm = -torch.log(math.sqrt(2 * math.pi) * sigma)
    return log_norm - (X - mu) ** 2 / (2 * sigma ** 2)

# Log-density of the test point according to torch.distributions
log_prob = uvn_dist.log_prob(value=X)
print("Log Prob: {:.3f}".format(log_prob[0]))

# The same log-density from the hand-written formula in raw_eval
raw_eval_log_prob = raw_eval(X, mu, sigma)
print("Raw eval Log Prob: {:.3f}".format(raw_eval_log_prob[0]))

# Library result and formula must agree up to floating-point error
assert torch.isclose(log_prob, raw_eval_log_prob, atol=1e-4)

Log Prob: -2.528
Raw eval Log Prob: -2.528


In [4]:
# Seed PyTorch's global RNG so the sampled statistics printed below are reproducible
# under "Restart Kernel -> Run All" (later draws in the notebook are then also
# deterministic when cells run in order).
SEED = 0
torch.manual_seed(SEED)

# Number of samples to draw
num_samples = 100000

# Draw samples; the distribution's batch shape is (1,), so this has shape (num_samples, 1)
samples = uvn_dist.sample([num_samples])

In [5]:
# The mean obtained from the samples
sample_mean = samples.mean()
print("Sample Mean: {:.3f}".format(sample_mean))

# The mean of the distribution from Pytorch
dist_mean = uvn_dist.mean
print("Dist Mean: {:.3f}".format(dist_mean[0]))

# As expected, the two means approximately match.
# With sigma = 5 and num_samples = 100000 (set above) the standard error of the mean
# is sigma / sqrt(n) ~= 0.016, so atol=0.3 leaves a very wide margin.
assert torch.isclose(sample_mean, dist_mean, atol=0.3)

# The variance obtained from the *same* samples used for the mean
# (previously a second, independent batch was drawn here).
sample_var = samples.var()
print("Sample Variance: {:.3f}".format(sample_var))

# The variance of the distribution from Pytorch
dist_var = uvn_dist.variance
print("Dist Variance: {:.3f}".format(dist_var[0]))

# As expected, the two variances approximately match.
# The standard error of the sample variance is about sigma^2 * sqrt(2 / (n - 1)) ~= 0.11
# here, so a fixed atol=0.3 (~2.7 standard errors) fails roughly 1 run in 140.
# A 5% relative tolerance (~11 standard errors) is robust to sampling noise.
assert torch.isclose(sample_var, dist_var, rtol=0.05)

Sample Mean: 0.022
Dist Mean: 0.000
Sample Variance: 24.858
Dist Variance: 25.000


## Interactive Visualization

Here you can set the mean and the standard deviation of a univariate normal distribution and visualise the resulting density. 
Notice that changing the mean does not change the shape of the distribution; it only moves where the distribution peaks. Changing the standard deviation $\sigma$ (and hence the variance $\sigma^2$) changes the shape: a larger value makes the distribution more diffuse, and a smaller value makes it more sharply peaked.

Note: In order to run this section, please download the notebook. Interactive snippets do not work online. 

In [6]:
%matplotlib widget
import ipywidgets as widgets
from ipywidgets import interact, interact_manual

fig, ax = plt.subplots()

ax.set_title("Univariate Normal Distribution")
ax.set_ylabel("P(X)")
ax.set_xlabel("X")


@interact
def plot_univariate_normal(mu=(-40, 40, 0.5), sigma=(4, 30, 0.5)):
    """Redraw the univariate normal density for the current slider values.

    Args:
        mu: mean (location) of the distribution, from the ``mu`` slider.
        sigma: standard deviation (scale) of the distribution, from the ``sigma`` slider.
    """
    # Evaluate the density over +/- 3 standard deviations around the mean.
    x = np.linspace(mu - 3 * sigma, mu + 3 * sigma, 1000)
    # Remove the previously drawn curve(s). Iterate over a copy: on recent Matplotlib
    # `ax.lines` is a live view of the axes' children, so removing while iterating it
    # directly can skip artists. A plain loop also avoids a side-effect-only comprehension.
    for line in list(ax.lines):
        line.remove()
    slider_dist = Normal(mu, sigma)
    pdf = slider_dist.log_prob(torch.from_numpy(x)).exp()
    # Fixed limits so curves for different slider settings can be compared visually.
    ax.set_xlim(-75, 75)
    ax.set_ylim(0, 0.1)
    ax.plot(x, pdf.numpy(), 'green')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

interactive(children=(FloatSlider(value=0.0, description='mu', max=40.0, min=-40.0, step=0.5), FloatSlider(val…

## Multivariate Normal Distribution

In [7]:
from torch.distributions import MultivariateNormal

# Mean vector and covariance matrix of an isotropic 2-D Gaussian (variance 5 on each axis)
mu = torch.zeros(2, dtype=torch.float)
C = 5.0 * torch.eye(2, dtype=torch.float)

# The bivariate normal parameterised directly by its covariance matrix
mvn_dist = MultivariateNormal(loc=mu, covariance_matrix=C)

In [8]:
# Instantiate single point test dataset
X = torch.tensor([0.0, 0.0], dtype=torch.float)

# Function to evaluate log prob using math formula
def raw_eval(X, mu, C):
    K = (1 / (2 * math.pi * torch.sqrt(C.det())))
    X_minus_mu = (X - mu).reshape(-1, 1)
    E1 = torch.matmul(X_minus_mu.T, C.inverse())
    E = torch.exp(-1 / 2. * torch.matmul(E1, X_minus_mu))
    return torch.log(K * E)

# Log-density of the single 2-D test point according to torch.distributions (a 0-d tensor)
log_prob = mvn_dist.log_prob(value=X)
print("Log Prob: {:.3f}".format(log_prob))

# The same log-density from the explicit formula in raw_eval (a 1x1 matrix)
raw_eval_log_prob = raw_eval(X, mu, C)
print("Raw eval Log Prob: {:.3f}".format(raw_eval_log_prob[0, 0]))

# Library result and formula must agree up to floating-point error
assert torch.isclose(log_prob, raw_eval_log_prob, atol=1e-4)

Log Prob: -3.447
Raw eval Log Prob: -3.447


In [9]:
# Sample size for the Monte Carlo check of the moments
num_samples = 10 ** 5

# One row per draw from the bivariate normal: shape (num_samples, 2)
samples = mvn_dist.sample(sample_shape=torch.Size([num_samples]))

In [10]:
# The per-coordinate mean obtained from the samples
sample_mean = samples.mean(dim=0)
print("Sample Mean: {}".format(sample_mean))

# The mean of the distribution from Pytorch
dist_mean = mvn_dist.mean
print("Dist Mean: {}".format(dist_mean))

# As expected, the two means approximately match
assert torch.allclose(sample_mean, dist_mean, atol=1e-1)

# The per-coordinate variance obtained from the *same* samples used for the mean
# (previously a second, independent batch was drawn here).
sample_var = samples.var(dim=0)
print("Sample Variance: {}".format(sample_var))

# The variance of the distribution from Pytorch (the diagonal of the covariance matrix)
dist_var = mvn_dist.variance
print("Dist Variance: {}".format(dist_var))

# As expected, the two variances approximately match.
# With variance 5 and num_samples = 100000 (set above), each sample variance has a
# standard error of about 5 * sqrt(2 / n) ~= 0.022, so atol=0.1 is ~4.5 standard errors.
assert torch.allclose(sample_var, dist_var, atol=1e-1)

Sample Mean: tensor([-0.0001, -0.0004])
Dist Mean: tensor([0., 0.])
Sample Variance: tensor([4.9463, 5.0443])
Dist Variance: tensor([5., 5.])


## Interactive Visualization

Here we allow the user to set different values for the means and covariance matrix of a 2D Normal distribution and visualise the resulting distribution. 

Notice that changing the mean does not change the shape of the distribution; it only moves where the distribution peaks. Changing $\mu_{0}$ shifts the center along the X-axis. Similarly, changing $\mu_{1}$ shifts the center along the Y-axis.

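When choosing values for the covariance matrix $\Sigma = \begin{pmatrix} \sigma_{00} & \sigma_{01} \\ \sigma_{01} & \sigma_{11} \end{pmatrix}$, make sure it is positive definite, i.e. $\sigma_{00} > 0$, $\sigma_{11} > 0$ and $\sigma_{00}\sigma_{11} - \sigma_{01}^2 > 0$. Being non-singular is not enough: for example, $\sigma_{00} = \sigma_{11} = 1$, $\sigma_{01} = 2$ gives an invertible matrix that is not a valid covariance matrix.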

Note: In order to run this section, please download the notebook. Interactive snippets do not work online. 

In [11]:
from mpl_toolkits.mplot3d import Axes3D # <--- Registers the '3d' projection on older Matplotlib versions
from matplotlib import cm

# Create the figure with a single 3D axes directly. The previous pattern,
# `plt.subplots()` followed by `fig.gca(projection='3d')`, left an unused 2D axes
# (which held the title) underneath the 3D one. Passing keyword arguments to
# `Figure.gca` was deprecated in Matplotlib 3.4 and removed in 3.6.
fig_1 = plt.figure()
ax_1 = fig_1.add_subplot(1, 1, 1, projection='3d')
ax_1.set_title("Bivariate Normal Distribution")


@interact
def plot_2d_normal(
    mu_0=widgets.FloatSlider(min=-5, max=5, step=0.25, value=0.0),
    mu_1=widgets.FloatSlider(min=-5, max=5, step=0.25, value=0.0),
    sigma_00=widgets.FloatSlider(min=0, max=5, step=0.25, value=1.0),
    sigma_01=widgets.FloatSlider(min=-5, max=5, step=0.25, value=0.0),
    sigma_11 =widgets.FloatSlider(min=0, max=5, step=0.25, value=1.0)):
    """Redraw the 3D density surface of a bivariate normal for the current slider values.

    Args:
        mu_0, mu_1: mean of the X and Y coordinates.
        sigma_00, sigma_11: diagonal entries of the covariance matrix.
        sigma_01: off-diagonal entry, mirrored into sigma_10 to keep the matrix symmetric.

    If the resulting matrix is rejected by PyTorch (not positive definite), an error
    message is printed and the axes are cleared instead.
    """

    def _reset_plot(ax):
        """Clear `ax` and restore the title and axis labels that clear() removes."""
        ax.clear()
        ax.set_title("Bivariate Normal Distribution")
        ax.set_ylabel("Y")
        ax.set_xlabel("X")
        ax.set_zlabel("P(X,Y)")

    # Evaluate the density on a 1000 x 1000 grid covering [-10, 10] x [-10, 10].
    x = np.linspace(-10, 10, 1000)
    y = np.linspace(-10, 10, 1000)
    X, Y = np.meshgrid(x, y)
    XY = np.stack((X, Y), axis=2)  # shape (1000, 1000, 2): one (x, y) point per grid cell
    mu = np.array([mu_0, mu_1])
    sigma_10 = sigma_01 # Covariance matrix is symmetric

    C = np.array([[sigma_00, sigma_01], [sigma_10, sigma_11]])
    try:
        mvn_dist = MultivariateNormal(torch.from_numpy(mu), torch.from_numpy(C))
        Z = mvn_dist.log_prob(torch.from_numpy(XY)).exp().numpy()
        # Plot the surface.
        _reset_plot(ax_1)
        ax_1.plot_surface(X, Y, Z, cmap=cm.coolwarm,
                          linewidth=0, antialiased=False)
    except (RuntimeError, ValueError):
        # An invalid covariance is rejected by PyTorch either as a ValueError (argument
        # validation, on by default in recent releases) or as a RuntimeError (Cholesky
        # failure when validation is off). Catching only RuntimeError let the
        # ValueError escape as a traceback in the widget output.
        print("Error!: Covariance matrix must be positive definite!")
        ax_1.clear()
    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  


interactive(children=(FloatSlider(value=0.0, description='mu_0', max=5.0, min=-5.0, step=0.25), FloatSlider(va…

### Contour plots obtained from Bivariate Normal Distributions

For any bivariate normal distribution, the plot of the surface $p(x, y)$ against $(x, y)$ looks like a bell in 3D space. The shape of the bell’s base on the $(x, y)$ plane is governed by the $2 \times 2$ covariance matrix $\Sigma$.

If $\Sigma$ is a diagonal matrix with equal diagonal elements, the bell is symmetric in all directions and
its base is circular.

If $\Sigma$ is a diagonal matrix with unequal diagonal elements, the base of the bell is elliptical.
The axes of the ellipse are aligned with the coordinate axes.

For a general $\Sigma$ (with non-zero off-diagonal elements), the base of the bell is elliptical, and the axes of the ellipse are not necessarily
aligned with the coordinate axes.

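Observe the following as you interact with the visualization:
- When $\mu_{0}$ increases, the base of the bell shifts along the X-axis.
- When $\mu_{1}$ increases, the base of the bell shifts along the Y-axis.
- When $\sigma_{00}$ increases, the spread along the X-axis increases.
- When $\sigma_{11}$ increases, the spread along the Y-axis increases.
- When $\sigma_{01}$ moves away from zero, the elliptical base tilts: its long axis points up and to the right for positive $\sigma_{01}$, and up and to the left for negative $\sigma_{01}$.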

Note: In order to run this section, please download the notebook. Interactive snippets do not work online. 

In [12]:
# A fresh figure with a single 2-D axes for the contour plot (1x1 is the subplots default)
fig_2, ax_2 = plt.subplots()


@interact
def plot_2d_normal_contour(
    mu_0=widgets.FloatSlider(min=-5, max=5, step=0.25, value=0.0),
    mu_1=widgets.FloatSlider(min=-5, max=5, step=0.25, value=0.0),
    sigma_00=widgets.FloatSlider(min=0, max=5, step=0.25, value=1.0),
    sigma_01=widgets.FloatSlider(min=-5, max=5, step=0.25, value=0.0),
    sigma_11 =widgets.FloatSlider(min=0, max=5, step=0.25, value=1.0)):
    """Redraw the density contours (the "base of the bell") for the current slider values.

    Args:
        mu_0, mu_1: mean of the X and Y coordinates.
        sigma_00, sigma_11: diagonal entries of the covariance matrix.
        sigma_01: off-diagonal entry, mirrored into sigma_10 to keep the matrix symmetric.

    If the resulting matrix is rejected by PyTorch (not positive definite), an error
    message is printed and the axes are cleared instead.
    """

    def _reset_plot(ax):
        """Clear `ax` and restore its title, axis labels and equal aspect ratio."""
        ax.clear()
        ax.set_title("Base of the Bivariate Normal Distribution")
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        # Equal scaling so a circular base is actually drawn as a circle.
        ax.set_aspect("equal")

    # Evaluate the density on a 1000 x 1000 grid covering [-10, 10] x [-10, 10].
    x = np.linspace(-10, 10, 1000)
    y = np.linspace(-10, 10, 1000)
    X, Y = np.meshgrid(x, y)
    XY = np.stack((X, Y), axis=2)  # shape (1000, 1000, 2): one (x, y) point per grid cell
    mu = np.array([mu_0, mu_1])
    sigma_10 = sigma_01 # Covariance matrix is symmetrical

    C = np.array([[sigma_00, sigma_01], [sigma_10, sigma_11]])
    try:
        mvn_dist = MultivariateNormal(torch.from_numpy(mu), torch.from_numpy(C))
        Z = mvn_dist.log_prob(torch.from_numpy(XY)).exp().numpy()
        _reset_plot(ax_2)
        # Pass the grid coordinates so the axes show X/Y values; `contour(Z)` alone
        # plotted against array indices (0..999) instead.
        ax_2.contour(X, Y, Z)
    except (RuntimeError, ValueError):
        # An invalid covariance is rejected by PyTorch either as a ValueError (argument
        # validation, on by default in recent releases) or as a RuntimeError (Cholesky
        # failure when validation is off).
        print("Error!: The covariance matrix must be positive definite")
        ax_2.clear()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

interactive(children=(FloatSlider(value=0.0, description='mu_0', max=5.0, min=-5.0, step=0.25), FloatSlider(va…