In [None]:
import math
from typing import Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.colors import ListedColormap
from sklearn.mixture import GaussianMixture
from torch import Tensor
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.normal import Normal

In [None]:
# Global matplotlib styling: compact fonts and thin lines for the small figures below.
plt.rcParams.update({
    'figure.titlesize': 12,
    'axes.titlesize':   10,
    'axes.labelsize':   10,
    'font.size':        8,
    'xtick.labelsize':  8,
    'ytick.labelsize':  8,
    'legend.fontsize':  8,
    'lines.linewidth':  1,
})

# Seed every RNG the notebook draws from so that Restart & Run All is reproducible:
# torch (random centroids/means, synthetic clusters, shuffling) and NumPy's global
# generator (used by scikit-learn estimators when no `random_state` is given).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Colour of the k-th class/cluster/Gaussian in every plot (indexed as COLORS[k]).
COLORS = ['red', 'blue', 'green', 'orange', 'purple',
          'brown', 'pink', 'gray', 'olive', 'cyan',
          'tab:red', 'tab:blue', 'tab:green', 'tab:orange', 'tab:purple',
          'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan']

<font size="+3">1-D data (Salary)</font>

Load & preview data

In [None]:
# --- Settings for the 1-D (salary) experiments ---
# Plot window along each data dimension, as [min, max] pairs (normalized units).
BOUNDARIES = [[-3, 3]]
# Number of points of the x-axis grid on which density curves are evaluated.
RESOLUTION = 500
# Values below the first threshold are labeled class 0, values at or above the
# second are labeled class 1; everything in between gets a random label.
TIER_THRESHOLDS = [-0.5, 0.5]
NUM_CLASSES = 2
# Number of assignment/update iterations run by the clustering loops.
NUM_EPOCHS = 500

In [None]:
# Load the data and keep the column at position 1 as a single feature of shape [N, 1].
dataset = pd.read_csv('./data/Salary_Data.csv')
X_train = torch.tensor(dataset.iloc[:, 1].values, dtype=torch.float).unsqueeze(dim=1)
# Normalize data to zero mean and unit standard deviation, so the plot window
# BOUNDARIES and the TIER_THRESHOLDS are expressed in standard deviations.
X_mean = X_train.mean(dim=0)
X_std = X_train.std(dim=0)
X_train = (X_train - X_mean)/X_std

# Histogram of the normalized values, with every sample marked on the x-axis.
salary = X_train.squeeze(dim=1).numpy()
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 3), constrained_layout=True)
ax.set(
    xlim=BOUNDARIES[0],
    xlabel='Salary',
    title='Salary Distribution (Normalized)'
)
# Same window as the axis limits (was a separate hard-coded (-3, 3)).
ax.hist(salary, bins=50, range=BOUNDARIES[0])
ax.scatter(x=salary, y=np.zeros_like(salary), color='black')
plt.show()

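What if we had labels?  

Assume:
+ Salaries more than 0.5 standard deviations below the mean are "the poor" (class 0)
+ Salaries at least 0.5 standard deviations above the mean are "the rich" (class 1)
+ Salaries in between are mixed: each one is assigned to one of the two classes at random

With labels, we can easily construct the distribution of each class by calculating the mean and standard deviation of its samples.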

In [None]:
# Synthetic labels: class 0 below the lower threshold, class 1 at or above the upper
# threshold, and a random class for everything in between.
tier = -torch.ones_like(X_train, dtype=torch.long)
tier[X_train < TIER_THRESHOLDS[0]] = 0
tier[X_train >= TIER_THRESHOLDS[1]] = 1
# Index with the boolean mask itself. Going through `.nonzero()` on the [N, 1] mask
# yields (row, col) pairs, and using those as row indices also relabels sample 0.
mid_mask = (X_train >= TIER_THRESHOLDS[0]) & (X_train < TIER_THRESHOLDS[1])
tier[mid_mask] = torch.randint(low=0, high=NUM_CLASSES, size=[int(mid_mask.sum())])

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 3), constrained_layout=True)
ax.set(
    xlim=BOUNDARIES[0],
    xlabel='Salary',
    title='Salary Distribution (labeled data)'
)
# Evaluation grid of shape [RESOLUTION, 1]; later cells reuse it for their curves.
x_plot = torch.linspace(*BOUNDARIES[0], steps=RESOLUTION).unsqueeze(dim=1)
for k in range(NUM_CLASSES):
    members = X_train[tier == k]
    ax.scatter(x=members, y=torch.zeros_like(members), color=COLORS[k])

    # Uncomment the next line to show only the labeled points (no fitted Gaussians).
    # continue
    # With labels, each class Gaussian is just the sample mean/std of its members.
    gaussian = Normal(loc=members.mean(dim=0), scale=members.std(dim=0))
    y_plot = gaussian.log_prob(value=x_plot).exp()
    ax.plot(x_plot, y_plot, color=COLORS[k])
plt.show()

Back to our unsupervised problem  
(The Gaussian Mixture occasionally fails to find a good solution; if that happens, rerun the cell)

In [None]:
# Reference solution: fit a Gaussian mixture with scikit-learn and plot its components.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 3), constrained_layout=True)
ax.set(
    xlim=BOUNDARIES[0],
    xlabel='Salary',
    title='Salary Distribution (Normalized) - Gaussian Mixture from sklearn'
)
ax.hist(X_train.squeeze(dim=1), bins=50, range=BOUNDARIES[0])

# One component per class (the plotting loop below indexes NUM_CLASSES components).
# n_init=10 keeps the best of ten random initializations, which makes a degenerate
# fit from a single unlucky random start much less likely.
gauss_mix = GaussianMixture(n_components=NUM_CLASSES, init_params='random', n_init=10)
X_np = X_train.numpy()
gauss_mix.fit(X_np)
y_gm = torch.tensor(gauss_mix.predict(X_np)).unsqueeze(dim=1)
ax.plot(x_plot, torch.zeros_like(x_plot), color='black')
for k in range(NUM_CLASSES):
    members = X_train[y_gm == k]
    ax.scatter(x=members, y=torch.zeros_like(members), color=COLORS[k])

    # Uncomment the next line to show only the clustered points (no fitted Gaussians).
    # continue
    # covariances_[k] holds the variance of component k, hence the square root.
    gaussian = Normal(
        loc=torch.tensor(gauss_mix.means_[k], dtype=torch.float).squeeze(),
        scale=torch.tensor(gauss_mix.covariances_[k], dtype=torch.float).squeeze().sqrt(),
    )
    y_plot = gaussian.log_prob(value=x_plot).exp()
    ax.plot(x_plot, y_plot, color=COLORS[k])
plt.show()

K-means clustering

**Inputs:** Data $X$, number of clusters $K$  
**Outputs:** Centroids $\{m^{(k)}\}_{k=1}^K$  
**Algorithm**:
1. Init centroids $\{m^{(k)}\}_{k=1}^K$ to random points in $X$
2. WHILE (not converged):
    1. Assignment step: assign each observation to the nearest cluster (centroid)
        + Points belonging to the $k$-th cluster: $S^{(k)} = \{x_p:\|x_p-m^{(k)}\|^2 \le \|x_p-m^{(j)}\|^2 \ \forall j, 1\le j\le K\}$
    2. Update step: recompute each centroid as the mean of the points currently in its cluster
        + $m^{(k)} = \frac{1}{|S^{(k)}|} \sum_{x_j \in S^{(k)}}{x_j}$

In [None]:
class KMeansClassifier:
    """K-means clustering with squared Euclidean distance.

    Args:
    + `X_train`: Input data of shape [N * d].
    + `K`: Number of clusters, must satisfy `1 <= K <= N`.

    The centroids (`self.centroids`, shape [K * d]) are initialized to `K`
    distinct rows of `X_train` picked at random.

    Raises:
    + `ValueError`: If `K` is not in `[1, N]`.
    """
    def __init__(self, X_train:Tensor, K:int):
        num_examples = X_train.shape[0]
        if not 1 <= K <= num_examples:
            raise ValueError(f'K must be in [1, {num_examples}], got {K}')
        self.X_train = X_train
        self.K = K

        # Indexing with an index tensor copies the rows, so updating the centroids
        # in `backward` never writes into `X_train`.
        self.centroids:Tensor = self.X_train[torch.randperm(num_examples)[:self.K]]

    # Fix centroids, update labels
    def forward(self, X_train:Tensor) -> Tensor:
        """Assignment step: returns the index of the nearest centroid for every
        row of `X_train`, as a tensor of shape [N * 1]."""
        # [N, 1, d] - [1, K, d] -> squared distances of shape [N, K]. The square
        # root is skipped because it does not change which centroid is nearest.
        sq_distance = ((X_train.unsqueeze(dim=1) - self.centroids.unsqueeze(dim=0))**2).sum(dim=2)
        return sq_distance.argmin(dim=1, keepdim=True)

    # Fix labels, update centroids
    def backward(self, X_train:torch.Tensor, yhat:torch.Tensor):
        """Update step: moves each centroid to the mean of the rows of `X_train`
        assigned to it by `yhat` (shape [N * 1], as returned by `forward`)."""
        # reshape(-1) instead of squeeze() so a single-row input stays 1-D.
        labels = yhat.reshape(-1)
        for k in range(self.K):
            members = X_train[labels == k]
            # The mean of an empty cluster is NaN and would poison every later
            # distance; keep the previous centroid instead.
            if members.shape[0] > 0:
                self.centroids[k] = members.mean(dim=0)

In [None]:
# Run K-means on the normalized 1-D data and plot clusters, centroids and decision rule.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 3), constrained_layout=True)
fig.suptitle('Salary Distribution (Normalized) - K-means')
ax.set(xlim=BOUNDARIES[0])

ax.hist(X_train.squeeze(dim=1), bins=50, range=BOUNDARIES[0])
km = KMeansClassifier(X_train=X_train, K=NUM_CLASSES)
# Alternate update/assignment steps. Once the assignment stops changing the
# centroids stop changing too, so further epochs would be no-ops.
y_km = km.forward(X_train=X_train)
for _ in range(NUM_EPOCHS):
    km.backward(X_train=X_train, yhat=y_km)
    y_next = km.forward(X_train=X_train)
    if torch.equal(y_next, y_km):
        break
    y_km = y_next

for k in range(km.K):
    members = X_train[y_km == k]
    ax.scatter(x=members, y=torch.zeros_like(members), color=COLORS[k])
    # axvline expects a scalar x, so unwrap the 1-element centroid.
    ax.axvline(x=km.centroids[k].item(), color=COLORS[k], label=f'Centroid {k}')

# Predicted class along the x-axis grid (step function).
ax.step(x_plot.squeeze(dim=1), km.forward(x_plot).squeeze(dim=1), color='black', label='Class')
ax.legend()
plt.show()

Expectation-Maximization for 1D Gaussian Mixture

**Inputs:** Data $X_{N\times d}$, number of Gaussians $K$, priors $\{\pi^{(k)}\}_{k=1}^K$  
**Outputs:** Gaussians $\{\theta^{(k)}\}_{k=1}^K$, where $\theta^{(k)} = \{\mu^{(k)}, \sigma^{(k)}\}$  
**Algorithm**:
1. Init Gaussians $\{\theta^{(k)}\}_{k=1}^K$
2. WHILE (not converged):
    1. Expectation step
    2. Maximization step

**Expectation step:**  
Build the Gaussian densities $\{\phi_{\theta^{(k)}}\}_{k=1}^K$ from the current parameters $\{\theta^{(k)}\}_{k=1}^K$  
Compute the responsibility matrix $R_{N\times K}$, i.e., the **class-specific weights** of each sample
$$R_{i, k} = \frac{\pi^{(k)} \phi_{\theta^{(k)}}(x_i)}
                  {\sum_{k'=1}^K {\pi^{(k')} \phi_{\theta^{(k')}}(x_i)}}$$

**Maximization step:**  
Update the weighted mean, weighted variance, and (optionally) the priors of the Gaussians, using the responsibilities as per-sample weights:
- Mean: $\mu^{(k)} = \frac{\sum_{i=1}^N R_{i, k}\, x_i}{\sum_{i=1}^N R_{i, k}}$
- Variance: ${\sigma^{(k)}}^2 = \frac{\sum_{i=1}^N R_{i, k}\, (x_i - \mu^{(k)})^2}{\sum_{i=1}^N R_{i, k}}$
- Prior: $\pi^{(k)} = \frac{\sum_{i=1}^N R_{i, k}}{N}$

**Side notes**:
1. $\mu^{(k)}$ are initialized to random points sampled from a standard Normal distribution
2. $\sigma^{(k)}$ are initialized to the standard deviation of $X$
3. $\pi^{(k)}$ are initialized to $1/K$
4. $\pi^{(k)}$ may or may not be trainable

In [None]:
class ExpectationMaximization:
    """The Expectation-Maximization algorithm with Gaussian Mixture.

    Args:
    + `X_train`: Input data of shape [N * d].
    + `K`: Number of Gaussians.
    + `mus`: Mean of Gaussians, must have `K * d` elements (reshaped to
        [K * d]). Defaults to `None`, leave to initialize from a standard
        normal distribution.
    + `Sigmas`: Covariance matrix of Gaussians, must have `K * d * d` elements
        (reshaped to [K * d * d]). Defaults to `None`, leave to initialize as
        the covariance matrix of `X_train` for each Gaussian.
    + `pis`: Prior of Gaussians, must have `K` elements. It is normalized to
        sum to 1. Defaults to `None`, leave to initialize as `1/K` for each
        Gaussian.
    + `trainable_mus`: Flag to set `mus` trainable. Defaults to `True`.
    + `trainable_Sigmas`: Flag to set `Sigmas` trainable. Defaults to `True`.
    + `trainable_pis`: Flag to set `pis` trainable. Defaults to `True`.
    + `reg_covar`: Non-negative value added to the diagonal of every updated
        covariance matrix so that it stays positive definite when a Gaussian
        collapses onto a few points. Defaults to `1e-6`.
    """
    def __init__(
        self,
        X_train:Tensor,
        K:int,
        mus:Optional[Sequence[float]]=None,
        Sigmas:Optional[Sequence[float]]=None,
        pis:Optional[Sequence[float]]=None,
        trainable_mus:bool=True,
        trainable_Sigmas:bool=True,
        trainable_pis:bool=True,
        reg_covar:float=1e-6,
    ):
        self.X_train = X_train
        self.d = self.X_train.shape[1]
        self.K = K
        self.trainable_mus = trainable_mus
        self.trainable_Sigmas = trainable_Sigmas
        self.trainable_pis = trainable_pis
        self.reg_covar = reg_covar
        # All parameters share the dtype of the data: integer inputs such as
        # `pis=[4, 1, 1]` must not produce integer tensors.
        dtype = self.X_train.dtype

        # `as_tensor(...).clone()` accepts both sequences and tensors and always
        # copies, because the parameters are updated in place later on.
        if mus is None:
            self.mus = torch.normal(mean=0.0, std=1.0, size=[self.K, self.d]).to(dtype)
        else:
            self.mus = torch.as_tensor(mus, dtype=dtype).clone().reshape(self.K, self.d)

        if Sigmas is None:
            # reshape covers d == 1, where the covariance is not a [d, d] matrix yet.
            data_cov = torch.cov(self.X_train.t()).reshape(self.d, self.d)
            # repeat() copies, so each Gaussian owns its [d, d] slice.
            self.Sigmas = data_cov.repeat(self.K, 1, 1)
        else:
            self.Sigmas = torch.as_tensor(Sigmas, dtype=dtype).clone().reshape(self.K, self.d, self.d)

        if pis is None:
            self.pis = torch.ones(size=[self.K], dtype=dtype)/self.K
        else:
            self.pis = torch.as_tensor(pis, dtype=dtype).clone().reshape(self.K)
            # Normalize: sum of priors is 1
            self.pis = self.pis/self.pis.sum(dim=0, keepdim=True)

        self.responsibility = torch.zeros(size=[self.X_train.shape[0], self.K], dtype=dtype)
        self.gaussians = self._build_gaussians()

    def _build_gaussians(self) -> list:
        """Creates one `MultivariateNormal` per component from the current parameters."""
        return [MultivariateNormal(loc=self.mus[k], covariance_matrix=self.Sigmas[k]) for k in range(self.K)]

    def _posterior(self, X:Tensor) -> Tensor:
        """Returns the [N * K] matrix of posterior probabilities of each Gaussian
        for each row of `X`; every row sums to 1."""
        # Work in log space: exponentiating first underflows to 0 for points far
        # from every Gaussian, and normalizing 0/0 gives NaN. softmax applied to
        # log(pi_k) + log(phi_k(x)) is the same ratio, computed stably.
        log_joint = torch.stack(
            [self.pis[k].log() + self.gaussians[k].log_prob(X) for k in range(self.K)],
            dim=1,
        )
        return torch.softmax(log_joint, dim=1)

    def expectation_step(self) -> Tensor:
        """Updates and returns the responsibility matrix of shape [N * K]."""
        self.responsibility = self._posterior(self.X_train)
        return self.responsibility

    def maximization_step(self) -> Tuple[Tensor, Tensor, Tensor]:
        """Updates the trainable parameters from the current responsibilities.

        Returns:
        + `(mus, Sigmas, pis)`: the parameter tensors held by this object
            (not copies).
        """
        # Smallest positive normal number: prevents 0/0 for a Gaussian that
        # received zero responsibility, without distorting non-zero totals.
        tiny = torch.finfo(self.responsibility.dtype).tiny
        jitter = self.reg_covar*torch.eye(self.d, dtype=self.Sigmas.dtype)
        for k in range(self.K):
            weights = self.responsibility[:, [k]]       # [N, 1]
            total = weights.sum().clamp_min(tiny)
            mean = (weights*self.X_train).sum(dim=0)/total
            centered = self.X_train - mean
            # Weighted sample covariance around the freshly computed mean.
            covariance = (weights*centered).t()@centered/total + jitter

            if self.trainable_mus:
                self.mus[k] = mean
            if self.trainable_Sigmas:
                self.Sigmas[k] = covariance
            if self.trainable_pis:
                self.pis[k] = weights.mean()

        # Update Gaussians
        self.gaussians = self._build_gaussians()
        return self.mus, self.Sigmas, self.pis

    def predict(self, input:Tensor) -> Tensor:
        """Returns the posterior probability of each Gaussian for each row of
        `input` (shape [M * d]), as a tensor of shape [M * K]."""
        return self._posterior(input)

    def __repr__(self) -> str:
        params = {
            'X_train':          self.X_train,
            'K':                self.K,
            'mus':              self.mus,
            'Sigmas':           self.Sigmas,
            'pis':              self.pis,
            'trainable_mus':    self.trainable_mus,
            'trainable_Sigmas': self.trainable_Sigmas,
            'trainable_pis':    self.trainable_pis,
        }
        return f"{self.__class__.__name__}({', '.join([f'{k}={v}' for k, v in params.items() if v is not None])})"

In [None]:
# Fit a 4-component mixture twice, with fixed and with trainable priors, and compare.
fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(10, 6), constrained_layout=True, squeeze=False, sharex='all', sharey='all')
fig.suptitle('Salary Distribution (Normalized) - Expectation-Maximization algorithm')
ax[0, 0].set(xlim=BOUNDARIES[0])

# Factor applied to the plotted densities, presumably so that the curves are
# visible next to the histogram counts (value kept from the original cell).
DENSITY_SCALE = 5
x_grid = x_plot.squeeze(dim=1)

for i, trainable_pis in enumerate([False, True]):
    ax[i, 0].hist(X_train.squeeze(dim=1), bins=50, range=BOUNDARIES[0])
    em = ExpectationMaximization(X_train=X_train, K=4, trainable_pis=trainable_pis, pis=[1, 1, 1, 1])
    for _ in range(NUM_EPOCHS):
        em.expectation_step()
        em.maximization_step()
    assignments = em.predict(X_train).argmax(dim=1)

    for k in range(em.K):
        members = X_train[assignments == k]
        ax[i, 0].scatter(x=members, y=torch.zeros_like(members), color=COLORS[k])

        # The model already holds the fitted Gaussians; there is no need to rebuild them.
        density = em.gaussians[k].log_prob(x_plot).exp()
        ax[i, 0].plot(x_grid, em.pis[k]*density*DENSITY_SCALE, color=COLORS[k], label=f'Gaussian {k}')

    # Most probable Gaussian along the x-axis grid (step function).
    ax[i, 0].step(x_grid, em.predict(x_plot).argmax(dim=1), color='black', label='Class')

    if trainable_pis:
        ax[i, 0].set(title=f'Trainable Prior ({em.pis})')
    else:
        ax[i, 0].set(title=f'Fixed Prior ({em.pis})')

ax[0, 0].legend()
plt.show()

<font size="+3">2-D data (Gaussian clusters)</font>

In [None]:
# --- Settings for the 2-D experiments ---
# Spacing of the grid used to draw the class regions. A smaller PLOT_STEP gives a
# more accurate plot but a longer inference time.
# For PLOT_STEP = 0.05: RESOLUTION = (MAX - MIN BOUNDARIES)//PLOT_STEP + 1 = 121 --> 121*121 pixel pcolormesh
PLOT_STEP = 0.05
# Plot window as one [min, max] pair per data dimension.
BOUNDARIES = [[-3, 3], [-3, 3]]
NUM_EPOCHS = 100
NUM_CLUSTERS = 3

In [None]:
def set_axes_equal(ax):
    '''Give the x, y and z axes of a 3D plot the same scale, so that spheres
    look like spheres and cubes like cubes. Works around Matplotlib's
    ax.set_aspect('equal') and ax.axis('equal') not working for 3D.

    Input
      ax: a matplotlib 3D axis, e.g., as output from plt.gca().

    Every axis keeps its current midpoint and is resized to the largest of the
    three current ranges.
    '''
    limits = [ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d()]
    spans = [abs(pair[1] - pair[0]) for pair in limits]
    middles = [np.mean(pair) for pair in limits]

    # Half of the widest range: the "radius" of the cubic bounding box
    # (a sphere in the sense of the infinity norm).
    plot_radius = 0.5*max(spans)
    setters = (ax.set_xlim3d, ax.set_ylim3d, ax.set_zlim3d)
    for set_limits, middle in zip(setters, middles):
        set_limits([middle - plot_radius, middle + plot_radius])
    ax.set_box_aspect((1, 1, 1))

For EM with a 2D Gaussian Mixture, we parameterize the Gaussians with covariance matrices $\{{\Sigma^{(k)}}\}_{k=1}^K$ instead of standard deviations $\{{\sigma^{(k)}}\}_{k=1}^K$.

Thus, in the Maximization step, we follow [this formula](https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Weighted_sample_covariance) to compute the weighted covariance matrix. (Each sample contributes to the covariance of the $k$-th Gaussian with its own weight, given by the corresponding entry of the responsibility matrix $R$.)
- Weighted covariance matrix: $\Sigma^{(k)} = \frac{\left(R_{:, k}\times(X - \mu^{(k)})\right)^T(X - \mu^{(k)})}{\sum{R_{:, k}}}$


In [None]:
def get_clusters_2D(
    num_clusters:int=2,
    radius:float=1,
    cluster_scale:float=0.2,
    num_examples:int=600,
) -> Sequence[Tensor]:
    """Generates 2D Gaussian clusters evenly spaced around the origin.

    Args:
    + `num_clusters`: Number of clusters, must be at least `1`. Defaults to `2`.
    + `radius`: Distance of cluster to origin. Defaults to `1`.
    + `cluster_scale`: Scale of cluster, i.e., the covariance matrix of each
        cluster is `cluster_scale` times identity matrix. Defaults to `0.2`.
    + `num_examples`: Total number of examples. When it is not a multiple of
        `num_clusters`, the first `num_examples % num_clusters` clusters get
        one extra example each. Defaults to `600`.

    Returns:
    + Shuffled inputs (float tensor of shape [num_examples * 2]) and labels
        (long tensor of shape [num_examples * 1]) of the data points in the
        clusters.

    Raises:
    + `ValueError`: If `num_clusters < 1` or `num_examples < 0`.
    """
    if num_clusters < 1:
        raise ValueError(f'num_clusters must be at least 1, got {num_clusters}')
    if num_examples < 0:
        raise ValueError(f'num_examples must be non-negative, got {num_examples}')

    # mu and Sigma for Gaussian distributions: means evenly spaced on a circle.
    angles = 2*math.pi*torch.arange(num_clusters, dtype=torch.float)/num_clusters
    mus = radius*torch.stack([angles.cos(), angles.sin()], dim=1)
    Sigma = cluster_scale*torch.eye(2)

    # Spread the remainder over the first clusters so that exactly
    # `num_examples` points are returned.
    examples_per_cluster, remainder = divmod(num_examples, num_clusters)
    xs, ys = [], []
    for k in range(num_clusters):
        size = examples_per_cluster + (1 if k < remainder else 0)
        # Sample x from Gaussian distributions
        xs.append(MultivariateNormal(loc=mus[k], covariance_matrix=Sigma).sample([size]))
        # Class indices are stored as integers (long).
        ys.append(torch.full(size=[size, 1], fill_value=k, dtype=torch.long))
    # Concatenate once instead of growing the tensors inside the loop.
    x = torch.cat(xs, dim=0)
    y = torch.cat(ys, dim=0)
    # Shuffle data
    shuffle = torch.randperm(x.shape[0])
    return x[shuffle], y[shuffle]

In [None]:
# Task 1: Classification for 2D clusters
X_train, y_train = get_clusters_2D(num_clusters=NUM_CLUSTERS, radius=1, cluster_scale=0.3, num_examples=600)

# Scatter the examples, coloured by their true cluster
fig, ax = plt.subplots(figsize=(4, 4), constrained_layout=True)
fig.suptitle(f'Gaussian clusters')
ax.set(xlim=BOUNDARIES[0], ylim=BOUNDARIES[1])

cluster_of = y_train.squeeze(dim=1)
for k in range(NUM_CLUSTERS):
    in_cluster = cluster_of==k
    ax.scatter(
        X_train[in_cluster, 0],
        X_train[in_cluster, 1],
        s=3, color=COLORS[k], label=f'Cluster {k}'
    )
ax.legend()
pass

In [None]:
# The same points without their labels: everything is drawn in a single colour.
fig, ax = plt.subplots(figsize=(4, 4), constrained_layout=True)
ax.scatter(X_train[:, 0], X_train[:, 1], s=3)
ax.set(xlim=BOUNDARIES[0], ylim=BOUNDARIES[1])
fig.suptitle(f'Gaussian clusters')
pass

In [None]:
fig, ax = plt.subplots(figsize=(4, 4), constrained_layout=True)
fig.suptitle(f'Gaussian clusters - K-means')
ax.set(xlim=BOUNDARIES[0], ylim=BOUNDARIES[1])

# Grid covering the plot window, used to visualize the region of each class.
# The small offset on `end` makes arange include the upper boundary.
x0_plot = torch.arange(start=BOUNDARIES[0][0], end=BOUNDARIES[0][1]+1e-8, step=PLOT_STEP)
x1_plot = torch.arange(start=BOUNDARIES[1][0], end=BOUNDARIES[1][1]+1e-8, step=PLOT_STEP)
# Explicit 'ij' indexing: same layout as the old implicit default, without the
# deprecation warning.
x0_grid, x1_grid = torch.meshgrid(x0_plot, x1_plot, indexing='ij')
X_plot = torch.stack([x0_grid.flatten(), x1_grid.flatten()], dim=1)

km = KMeansClassifier(X_train=X_train, K=NUM_CLUSTERS)
# Alternate update/assignment steps. Once the assignment stops changing the
# centroids stop changing too, so further epochs would be no-ops.
y_km = km.forward(X_train=X_train)
for _ in range(NUM_EPOCHS):
    km.backward(X_train=X_train, yhat=y_km)
    y_next = km.forward(X_train=X_train)
    if torch.equal(y_next, y_km):
        break
    y_km = y_next

assignments = y_km.squeeze(dim=1)
for k in range(km.K):
    # Training data
    members = X_train[assignments == k]
    ax.scatter(members[:, 0], members[:, 1],
               color=COLORS[k], alpha=0.7, s=3, zorder=100, label=f'Cluster {k}')
    centroid_x0, centroid_x1 = km.centroids[k].tolist()
    ax.scatter(centroid_x0, centroid_x1,
               color=COLORS[k], alpha=1, s=200, marker='x', linewidth=4, zorder=100)

y_plot = km.forward(X_plot).reshape(x0_grid.shape)
# Fixed vmin/vmax map class k to COLORS[k] even if some class is absent from the grid.
ax.pcolormesh(x0_grid.numpy(), x1_grid.numpy(), y_plot.numpy(),
              cmap=ListedColormap(COLORS[0:km.K]), vmin=-0.5, vmax=km.K - 0.5,
              alpha=0.3, shading='auto')
ax.legend()
plt.show()

In [None]:
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(8, 4), constrained_layout=True, squeeze=False, sharex='all', sharey='all')
fig.suptitle(f'Gaussian clusters - Expectation-Maximization algorithm')
ax[0, 0].set(xlim=BOUNDARIES[0], ylim=BOUNDARIES[1])

# Grid covering the plot window, used to visualize the region of each class.
# It is built here under the names the plotting code below uses, so the cell
# does not depend on grids left over from an earlier cell.
x0_plot = torch.arange(start=BOUNDARIES[0][0], end=BOUNDARIES[0][1]+1e-8, step=PLOT_STEP)
x1_plot = torch.arange(start=BOUNDARIES[1][0], end=BOUNDARIES[1][1]+1e-8, step=PLOT_STEP)
x0_grid, x1_grid = torch.meshgrid(x0_plot, x1_plot, indexing='ij')
X_plot = torch.stack([x0_grid.flatten(), x1_grid.flatten()], dim=1)

# Deliberately unbalanced initial priors: the first Gaussian is 4x as likely.
initial_pis = [4] + [1]*(NUM_CLUSTERS - 1)

for i, trainable_pis in enumerate([True, False]):
    em = ExpectationMaximization(
        X_train=X_train, K=NUM_CLUSTERS, trainable_pis=trainable_pis,
        pis=initial_pis,
    )
    for _ in range(NUM_EPOCHS):
        em.expectation_step()
        em.maximization_step()
    assignments = em.predict(X_train).argmax(dim=1)

    for k in range(em.K):
        # Training data
        members = X_train[assignments == k]
        ax[0, i].scatter(members[:, 0], members[:, 1],
                         color=COLORS[k], alpha=0.7, s=3, zorder=100, label=f'Gaussian {k}')
        mean_x0, mean_x1 = em.mus[k].tolist()
        ax[0, i].scatter(mean_x0, mean_x1,
                         color=COLORS[k], alpha=1, s=200, marker='x', linewidth=4, zorder=100)

    y_plot = em.predict(X_plot).argmax(dim=1).reshape(x0_grid.shape)
    # Fixed vmin/vmax map class k to COLORS[k] even if some class is absent from the grid.
    ax[0, i].pcolormesh(x0_grid.numpy(), x1_grid.numpy(), y_plot.numpy(),
                        cmap=ListedColormap(COLORS[0:em.K]), vmin=-0.5, vmax=em.K - 0.5,
                        alpha=0.3, shading='auto')

    if trainable_pis:
        ax[0, i].set(title=f'Trainable Prior ({em.pis})')
    else:
        ax[0, i].set(title=f'Fixed Prior ({em.pis})')

ax[0, 0].legend()
plt.show()

In [None]:
%matplotlib qt

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(8, 4), subplot_kw={"projection": "3d"}, constrained_layout=True, squeeze=False)
for i, trainable_pis in enumerate([True, False]):
    em = ExpectationMaximization(
        X_train=X_train, K=3, trainable_pis=trainable_pis,
        # mus=torch.rand(size=(4, 2)),
        pis=[4, 1, 1],
    )
    for j in range(NUM_EPOCHS):
        r = em.expectation_step()
        mus, sigmas, pis = em.maximization_step()
    y_em = em.predict(X_train)

    for k in range(em.K):
        # Training data
        ax[0, i].scatter(
            X_train[(y_em.argmax(dim=1)==k), 0], X_train[(y_em.argmax(dim=1)==k), 1],
            color=COLORS[k], alpha=0.7, s=3, zorder=100,
        )
        ax[0, i].scatter(
            em.mus[k][0], em.mus[k][1],
            color=COLORS[k], alpha=1, s=200, marker='x', linewidth=4, zorder = 100
        )
        ax[0, i].plot_surface(
            x0_grid.numpy(), x1_grid.numpy(), em.pis[k]*em.gaussians[k].log_prob(X_plot).exp().reshape(shape=[x0_plot.shape[0], x1_plot.shape[0]]).numpy()*5,
            # color=COLORS[k],
            cmap=COLORS[k][0].capitalize() + COLORS[k][1:]+'s',#ListedColormap('white', [COLORS[k]]),
            alpha=0.3, label=f'Gassian {k}',
        )
        ax[0, i].contour3D(
            x0_grid.numpy(), x1_grid.numpy(), em.pis[k]*em.gaussians[k].log_prob(X_plot).exp().reshape(shape=[x0_plot.shape[0], x1_plot.shape[0]]).numpy()*5,
            levels=5, cmap=ListedColormap(COLORS[k], [COLORS[k]]), label=f'Gaussian {k}'
        )

    set_axes_equal(ax[0, i])
    if trainable_pis == True:
        ax[0, i].set(title=f'Trainable Prior ({em.pis})')
    else:
        ax[0, i].set(title=f'Fixed Prior ({em.pis})')

ax[0, 0].shareview(ax[0, 1])
ax[0, 0].legend()
fig.show()

In [None]:
# %matplotlib inline