In [None]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
from sklearn.mixture import GaussianMixture
from sklearn.datasets import load_iris

In [None]:
# Load the iris dataset as pandas objects: a features DataFrame plus an integer target Series.
data = load_iris(as_frame = True)
df = data['data']
target_names = data['target_names']

# Translate every integer class code into its species name.
df['species'] = data['target'].map(lambda code: target_names[code])

# Give the columns short snake_case names (the four measurements, then species).
df.columns = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']

Just like in the KMeans practice notebook, we are not going to standardize the variables because we know they are all measured on the same scale. However, just as with KMeans, it is usually a good idea to standardize your variables before fitting a Gaussian mixture model.

Let's again start by plotting the variables that we are clustering on and coloring the points by species. Those are the clusters we hope to learn with our Gaussian mixture model.

In [None]:
# Pairwise scatter plots of every measurement, colored by the true species labels.
sns.pairplot(data = df, hue = 'species')
plt.show()

Now we will fit a Gaussian mixture model with 3 components (one for each species). Then we will get the cluster labels from our GMM and recreate the plot from above, this time coloring the points by cluster label. Hopefully the clusters we find look similar to the groups we see when we color by species.

In [None]:
# Fit a Gaussian mixture model with 3 components (one per species) and recolor
# the pairplot by the learned cluster assignments.

# Fixing random_state makes the fit -- and therefore which arbitrary integer each
# component receives -- reproducible across Restart & Run All. GaussianMixture.sample
# also draws from this random_state, so later simulations become reproducible too.
gmm = GaussianMixture(n_components = 3, random_state = 42)

# fit the model to the four measurement columns only (species is not a feature)
X = df[['sepal_length', 'sepal_width', 'petal_length', 'petal_width']].values
gmm.fit(X)

# hard cluster label for each point. Note: gmm.predict_proba(X) would instead give the
# probability of each point belonging to each component -- something KMeans can't do.
df['cluster_label'] = gmm.predict(X)

# Optional: the labels are integers, so seaborn would treat hue=cluster_label as a
# continuous hue. Casting to strings makes the hue discrete, like the species plot above.
df['cluster_label'] = df['cluster_label'].astype(str)

# see how well the GMM clustering did by redrawing the pairplot colored by cluster label;
# a sorted hue_order keeps the legend order stable ('0', '1', '2')
sns.pairplot(df, hue = 'cluster_label', hue_order = sorted(df['cluster_label'].unique()))
plt.show()

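As we hoped, our GMM learned clusters similar to the ones we see when we color the points by species. Keep in mind that the cluster numbers themselves are arbitrary: cluster 0 need not correspond to the first species, so compare the shapes of the groups rather than the label values.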

Since a GMM is a probability distribution, we can also sample (simulate) new points from it, as the following cell illustrates.

In [None]:
# Draw 1000 synthetic irises from the fitted mixture; `labels` records which
# component generated each sampled point.
sim_irises, labels = gmm.sample(n_samples = 1000)
sim_df = pd.DataFrame(sim_irises, columns = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width'])

# string-typed component labels make seaborn use a discrete hue
sim_df['label'] = pd.Series(labels, index = sim_df.index).astype(str)

sns.pairplot(sim_df, hue = 'label')
plt.show()