In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import scipy

In [None]:
# Global matplotlib defaults for every figure in this notebook:
# a square canvas and larger fonts for ticks, labels, titles and legends.
plt.rcParams.update({
    'figure.figsize': [10., 10.],
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
    'axes.labelsize': 16,
    'axes.titlesize': 16,
    'legend.fontsize': 14,
})

# Expectation Maximization (EM) in 1d

Here we draw a few samples from two normal distributions and then recover the parameters of both distributions by fitting a two-component model with the EM algorithm.

In [None]:
# Ground-truth parameters of the two generating normal distributions;
# the keys match the keyword arguments of stats.norm.
g1 = dict(loc=0, scale=2)
g2 = dict(loc=2, scale=1)

In [None]:
# Draw the same number of samples from each ground-truth distribution.
# A single seeded generator is shared by both draws so that a fresh
# "Restart & Run All" reproduces exactly the same data set, while the two
# samples still come from different parts of the random stream.
n_samples = 20
sample_rng = np.random.default_rng(42)

x1 = stats.norm(**g1).rvs(n_samples, random_state=sample_rng)
x2 = stats.norm(**g2).rvs(n_samples, random_state=sample_rng)

In [None]:
def E(x, c1, c2, e1, e2):
    ''' Expectation Step

    For every sample in ``x`` compute the responsibility of each of the two
    normal components (both components are weighted equally) and store the
    results in place in ``e1`` and ``e2``.

    Parameters
    ----------
    x : array of sample positions.
    c1, c2 : dicts that are unpacked as keyword arguments of ``stats.norm``
        (this notebook uses the keys 'loc' and 'scale').
    e1, e2 : arrays overwritten via slice assignment with the
        responsibilities of component 1 and component 2; for each sample
        the two values sum to one.

    Returns
    -------
    None; the result is written into ``e1`` and ``e2``.
    '''
    # Work with log-densities. Dividing plain pdf values gives 0/0 = NaN as
    # soon as both densities underflow to zero (samples far away from two
    # narrow components); the difference of log-densities stays finite.
    log_p1 = stats.norm(**c1).logpdf(x)
    log_p2 = stats.norm(**c2).logpdf(x)

    # log(p1 + p2) evaluated without leaving log-space.
    log_total = np.logaddexp(log_p1, log_p2)

    e1[:] = np.exp(log_p1 - log_total)
    e2[:] = np.exp(log_p2 - log_total)

In [None]:
def M(x, c1, c2, e1, e2):
    ''' Maximization Step

    Refit both components to the samples ``x``: each component's 'loc'
    becomes the responsibility-weighted mean of ``x`` and its 'scale' the
    square root of the responsibility-weighted mean squared deviation
    from that new 'loc'. ``c1`` and ``c2`` are modified in place.
    '''
    components = ((c1, e1), (c2, e2))

    # First update both centres ...
    for params, weights in components:
        params['loc'] = np.average(x, weights=weights)

    # ... then both widths, measured around the freshly updated centres.
    for params, weights in components:
        squared_dev = np.square(x - params['loc'])
        params['scale'] = np.sqrt(np.average(squared_dev, weights=weights))

In [None]:
def plot(x, c1, c2, e1, e2):
    ''' Draw the current fit, the true densities and the samples.

    Parameters
    ----------
    x : array of sample positions, drawn as points on the line y = 0.
    c1, c2 : dicts unpacked as keyword arguments of ``stats.norm``;
        their densities are drawn as the blue and the red curve.
    e1 : per-sample responsibilities of component 1, used as point colour.
    e2 : accepted so the signature matches ``E`` and ``M``; not used here.

    The dashed black curves show the generating distributions ``g1`` and
    ``g2``, which are read from the notebook's global namespace.
    '''
    xspace = np.linspace(-5, 5, 1000)
    plt.plot(xspace, stats.norm(**c1).pdf(xspace), c='b')
    plt.plot(xspace, stats.norm(**c2).pdf(xspace), c='r')

    plt.plot(xspace, stats.norm(**g1).pdf(xspace), c='k', ls='--', alpha=0.5)
    plt.plot(xspace, stats.norm(**g2).pdf(xspace), c='k', ls='--', alpha=0.5)

    # Pin the colour scale to the probability range [0, 1]. Without vmin/vmax
    # the colours are normalised to the current min/max of e1, so the same
    # colour would mean different probabilities in different figures.
    # With 'RdBu', e1 near 1 is blue (matching the c1 curve) and e1 near 0
    # is red (matching the c2 curve).
    plt.scatter(x, np.zeros_like(x), s=100, c=e1, cmap='RdBu',
                vmin=0., vmax=1., edgecolor='k')

    plt.xlabel('x')
    plt.ylabel('probability density')

In [None]:
# Pool both samples into one data set; the fit does not know which
# distribution a point was drawn from.
x = np.hstack((x1, x2))

In [None]:
# Random initialisation of the two fitted components.
# A seeded generator makes the starting point (and hence every figure below)
# reproducible on a fresh kernel.
init_rng = np.random.default_rng(7)

# Centres start somewhere in [0, 1). Widths are drawn from [0.5, 1.5) rather
# than [0, 1): a width close to zero makes the component's density underflow
# to zero for most samples, which breaks the responsibility computation.
c1 = {'loc': init_rng.random(), 'scale': init_rng.uniform(0.5, 1.5)}
c2 = {'loc': init_rng.random(), 'scale': init_rng.uniform(0.5, 1.5)}

In [None]:
# Responsibility buffers, one entry per sample. E() overwrites them in place
# before they are read, so start from a neutral, deterministic 50/50 split
# (which sums to one per sample) instead of unrelated random numbers.
e1 = np.full(x.size, 0.5)
e2 = np.full(x.size, 0.5)

In [None]:
# Initial E step: fill e1/e2 with the responsibilities implied by the
# starting parameters, so the first figure can colour the samples.
E(x=x, c1=c1, c2=c2, e1=e1, e2=e2)

This is what the fit looks like right after random initialization:

In [None]:
# Show the starting point: initial components, true densities and samples.
plot(x=x, c1=c1, c2=c2, e1=e1, e2=e2)

After 3 iterations

In [None]:
# Three EM iterations: each E step recomputes the responsibilities from the
# current parameters, each M step refits the parameters to them.
for i in range(3):
    E(x, c1, c2, e1, e2)
    M(x, c1, c2, e1, e2)

# The last M step changed c1/c2, so refresh e1/e2 once more before plotting;
# otherwise the point colours would lag one step behind the drawn curves.
# This does not alter later iterations, which begin with the same E step.
E(x, c1, c2, e1, e2)
plot(x, c1, c2, e1, e2)

After another 3 iterations

In [None]:
# Three more EM iterations, continuing from the current parameters.
for i in range(3):
    E(x, c1, c2, e1, e2)
    M(x, c1, c2, e1, e2)

# Recompute the responsibilities for the parameters produced by the last
# M step so that the point colours match the curves that are drawn.
E(x, c1, c2, e1, e2)
plot(x, c1, c2, e1, e2)

And another 3 iterations

In [None]:
# A final batch of three EM iterations, continuing from the current parameters.
for i in range(3):
    E(x, c1, c2, e1, e2)
    M(x, c1, c2, e1, e2)

# Bring the responsibilities in line with the parameters from the last
# M step, so that the colours in the figure belong to the plotted curves.
E(x, c1, c2, e1, e2)
plot(x, c1, c2, e1, e2)