In [12]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from matplotlib import animation
from IPython.display import HTML

In [31]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib import animation
from IPython.display import HTML

# Seed the legacy global NumPy generator so the synthetic data set is reproducible.
np.random.seed(1)

# Ground-truth parameters of the three 2-D Gaussian components we sample from.
true_means = np.array([[0, 0], [4, 4], [-4, 4]])
true_covs = np.array([
    [[0.8, 0.2], [0.2, 0.5]],
    [[0.6, -0.2], [-0.2, 0.6]],
    [[0.5, 0.0], [0.0, 0.9]],
])
true_weights = np.array([0.4, 0.35, 0.25])

# Pick a component label for every sample according to the mixing weights,
# then draw one point from the Gaussian that label selects.
n = 400
components = np.random.choice(len(true_weights), size=n, p=true_weights)
draws = []
for label in components:
    draws.append(np.random.multivariate_normal(true_means[label], true_covs[label], size=1))
X = np.vstack(draws)

# Add a little isotropic jitter on top of the sampled points.
X += 0.02 * np.random.randn(*X.shape)

def multivariate_gaussian(x, mean, cov):
    """Evaluate a multivariate normal density at one or more points.

    Parameters
    ----------
    x : array_like
        A single point or a stack of points (one per row); it is promoted
        to 2-D with ``np.atleast_2d``.
    mean : numpy.ndarray
        Mean vector; its ``size`` is taken as the dimensionality.
    cov : numpy.ndarray
        Covariance matrix. If its determinant is not positive, ``1e-6`` is
        added to the diagonal before it is inverted.

    Returns
    -------
    numpy.ndarray
        1-D array with one density value per row of ``x``.
    """
    dim = mean.size
    points = np.atleast_2d(x)

    determinant = np.linalg.det(cov)
    if determinant <= 0:
        # Degenerate covariance: nudge the diagonal so it can be inverted.
        cov = cov + 1e-6 * np.eye(dim)
        determinant = np.linalg.det(cov)

    precision = np.linalg.inv(cov)
    centered = points - mean
    # Row-wise quadratic form (x - mean)^T cov^{-1} (x - mean).
    mahalanobis_sq = np.sum(centered @ precision * centered, axis=1)
    normalizer = np.sqrt((2 * np.pi) ** dim * determinant)
    return np.exp(-0.5 * mahalanobis_sq) / normalizer

def cov_ellipse_params(cov, nsig=2.0):
    """Convert a 2x2 covariance matrix into ellipse drawing parameters.

    Parameters
    ----------
    cov : numpy.ndarray
        Symmetric 2x2 covariance matrix (the result is unpacked into exactly
        two axis lengths, so other sizes are not supported).
    nsig : float, optional
        Number of standard deviations the ellipse should span along each
        principal axis, measured from the centre.

    Returns
    -------
    width, height, angle
        Full axis lengths ``2 * nsig * sqrt(eigenvalue)`` for the largest and
        smallest eigenvalue respectively, and the orientation in degrees of
        the eigenvector belonging to the largest eigenvalue.
    """
    vals, vecs = np.linalg.eigh(cov)
    # eigh returns ascending eigenvalues; put the major axis first.
    order = vals.argsort()[::-1]
    vals, vecs = vals[order], vecs[:, order]
    # A rank-deficient covariance can produce tiny negative eigenvalues from
    # round-off; clip them so the square root yields 0 instead of NaN.
    vals = np.maximum(vals, 0.0)
    width, height = 2 * nsig * np.sqrt(vals)
    angle = np.degrees(np.arctan2(vecs[1, 0], vecs[0, 0]))
    return width, height, angle

# EM initialisation: K components in the data's dimensionality.
K = 3
D = X.shape[1]

# Dedicated generator so the starting means do not depend on the global seed state.
rng = np.random.RandomState(2)
indices = rng.choice(n, K, replace=False)

# Start each mean at a distinct data point.
means = X[indices].copy()

# Every component starts from the same broad covariance: the sample covariance
# of the whole data set, inflated on the diagonal.
shared_cov = np.cov(X.T) + 0.5 * np.eye(D)
covs = np.tile(shared_cov, (K, 1, 1))

# Uniform mixing weights.
weights = np.full(K, 1.0 / K)

def em_step(X, means, covs, weights, reg_covar=1e-6):
    """Perform one EM iteration for a Gaussian mixture model.

    Parameters
    ----------
    X : numpy.ndarray
        Data matrix with one sample per row.
    means : numpy.ndarray
        Current component means, one row per component.
    covs : numpy.ndarray
        Current component covariances, indexed by component along axis 0.
    weights : numpy.ndarray
        Current mixing weights, one per component.
    reg_covar : float, optional
        Value added to the diagonal of every updated covariance to keep it
        invertible. Defaults to the previously hard-coded ``1e-6``.

    Returns
    -------
    means_new, covs_new, weights_new, r
        The updated parameters and the responsibilities ``r`` (samples by
        components). ``r`` is computed from the *input* parameters, i.e. it
        is the E-step result that produced the returned M-step update.
    """
    # Take the dimensionality from the data rather than from a notebook global.
    n, dim = X.shape
    K = means.shape[0]

    # E-step: weighted component densities, normalised per sample.
    r = np.zeros((n, K))
    for k in range(K):
        r[:, k] = weights[k] * multivariate_gaussian(X, means[k], covs[k])
    r_sum = r.sum(axis=1, keepdims=True)
    r_sum[r_sum == 0] = 1e-16  # avoid 0/0 for samples every component assigns zero density
    r /= r_sum

    # M-step.
    Nk = r.sum(axis=0)
    weights_new = Nk / n
    # Same guard for components: an empty component would otherwise divide by
    # zero and fill its mean and covariance with NaN.
    Nk_safe = np.maximum(Nk, 1e-16)
    means_new = (r.T @ X) / Nk_safe[:, None]

    covs_new = np.zeros_like(covs, dtype=float)
    jitter = reg_covar * np.eye(dim)
    for k in range(K):
        diff = X - means_new[k]
        covs_new[k] = (r[:, k][:, None] * diff).T @ diff / Nk_safe[k] + jitter
    return means_new, covs_new, weights_new, r

iterations = 40
# Per-frame snapshots: entry i holds the parameters before EM iteration i and
# the responsibilities computed under those parameters.
history = {'means': [], 'covs': [], 'weights': [], 'r': []}

for it in range(iterations):
    history['means'].append(means.copy())
    history['covs'].append(covs.copy())
    history['weights'].append(weights.copy())
    # One call is enough: em_step returns the responsibilities evaluated under
    # the parameters passed in, together with the updated parameters.
    means, covs, weights, r = em_step(X, means, covs, weights)
    history['r'].append(r)

# Final frame: the converged parameters plus the responsibilities under them.
history['means'].append(means.copy())
history['covs'].append(covs.copy())
history['weights'].append(weights.copy())
# Index the result instead of unpacking into `_`, which IPython uses for the
# last cell output.
r = em_step(X, means, covs, weights)[-1]
history['r'].append(r)

# Square canvas with a 2-unit margin around the data on every side.
fig, ax = plt.subplots(figsize=(6, 6))
x_lo, y_lo = X.min(axis=0) - 2
x_hi, y_hi = X.max(axis=0) + 2
ax.set_xlim(x_lo, x_hi)
ax.set_ylim(y_lo, y_hi)
ax.set_title("GMM fitting with EM: iteration 0")

# Data points; their colours are overwritten on every animation frame.
scat = ax.scatter(X[:, 0], X[:, 1], s=20, alpha=0.8)

# One tab10 colour per mixture component.
colors = plt.cm.tab10(np.arange(K))

# Placeholder outlines for the component covariances (sized per frame).
ellipses = []
for color in colors:
    outline = patches.Ellipse((0, 0), 0, 0, angle=0, linewidth=2, fill=False)
    outline.set_edgecolor(color)
    ax.add_patch(outline)
    ellipses.append(outline)

# Empty 'x' markers that will track the component means.
mean_plots = []
for _ in range(K):
    mean_plots.append(ax.scatter([], [], marker='x', s=100, linewidths=3))

# Legend built from proxy patches, one per component colour.
legend_plots = [patches.Patch(color=color, label=f'comp {k}') for k, color in enumerate(colors)]
ax.legend(handles=legend_plots, loc='upper right')

def init():
    """Reset the artists to the state shown before the first frame.

    Places the data points and the initial means, hides every covariance
    ellipse, resets the title, and returns the artists the animation redraws.
    """
    scat.set_offsets(X)
    starting_means = history['means'][0]
    for marker, start in zip(mean_plots, starting_means):
        marker.set_offsets([start])
    for outline in ellipses:
        outline.set_visible(False)
    ax.set_title("GMM fitting with EM: iteration 0")
    return [scat, *mean_plots, *ellipses]

def update(frame):
    """Draw the mixture state recorded for animation frame ``frame``.

    Recolours the data points by blending the component colours with each
    point's responsibilities, then moves the mean markers and covariance
    ellipses to the recorded parameters. Returns the redrawn artists.
    """
    frame_means = history['means'][frame]
    frame_covs = history['covs'][frame]
    resp = history['r'][frame]

    # Mix the RGB component colours by responsibility, then rescale each
    # channel by its maximum over all points.
    blended = resp @ colors[:, :3]
    blended = blended / blended.max(axis=0)
    scat.set_color(np.clip(blended, 0, 1))

    for k, (marker, outline) in enumerate(zip(mean_plots, ellipses)):
        marker.set_offsets([frame_means[k]])
        marker.set_edgecolor(colors[k])

        # 2-sigma ellipse of component k, centred on its mean.
        outline.width, outline.height, outline.angle = cov_ellipse_params(frame_covs[k], nsig=2.0)
        outline.center = (frame_means[k, 0], frame_means[k, 1])
        outline.set_visible(True)
        outline.set_alpha(0.9)

    ax.set_title(f"GMM fitting with EM: iteration {frame}")
    return [scat, *mean_plots, *ellipses]

# One frame per recorded snapshot, 400 ms apart, played once.
anim = animation.FuncAnimation(
    fig,
    func=update,
    init_func=init,
    frames=len(history['means']),
    interval=400,
    blit=True,
    repeat=False,
)
# Close the static figure so only the animation is rendered in the output.
plt.close(fig)
HTML(anim.to_jshtml())


In [32]:
from matplotlib.animation import PillowWriter

# Export the animation as a GIF at 3 frames per second.
save_path = "../../assets/images/clustering/gaussian_mix_expectation_maximization.gif"
gif_writer = PillowWriter(fps=3)
anim.save(save_path, writer=gif_writer)

