# Practice With K-Means

In this exercise we'll use k-means to cluster the pixels in an image, and display the image using only the resulting centroid colors. This process is called _vector quantization_, a method sometimes used for data compression.
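
To make the idea of vector quantization concrete before working with a real image, here is a minimal sketch on a made-up 2 x 3 "image". The array `tiny_img` and the two-color `codebook` are hypothetical values chosen for illustration; in the exercise the codebook would be the centroids found by k-means.

```python
import numpy as np

# Hypothetical 2 x 3 "image" with RGB values in [0, 1]
tiny_img = np.array([
    [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.0, 0.0, 1.0]],
    [[0.0, 0.1, 0.9], [1.0, 0.0, 0.1], [0.1, 0.0, 1.0]],
])

# Hypothetical codebook of two colors (stand-in for k-means centroids)
codebook = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])

# Assign each pixel to its nearest codebook entry (squared Euclidean distance)
flat = tiny_img.reshape(-1, 3)
dists = ((flat[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=2)
indices = dists.argmin(axis=1)

# Rebuild the image using only codebook colors
quantized = codebook[indices].reshape(tiny_img.shape)
print(indices)
```

The compression comes from storing one small index per pixel plus the codebook, instead of three color values per pixel.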

Here we're not training a model for prediction. Instead, we're trying to group our data points (the pixels in the image) into clusters. Ideally, we might want the clusters to correspond to the different objects in the image (e.g. a cat or the floor), although we probably won't achieve that with just RGB features.

Working in pairs, complete each TO DO in the notebook below. If you have time, go back and try the optional variations. If you get stuck, raise your hand and we will come around to help you.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from sklearn.cluster import KMeans
from PIL import Image

In [None]:
# Load the image from your local computer
img = np.array(Image.open("./example_images/Fuzz.jpg"))

### **TO DO**: Normalize pixel values to the range `[0,1]`

(hint: first change the data type, then divide by 255 to rescale values)

In [None]:
pixels = img.astype("float32") / 255
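
As a quick sanity check of the normalization step, the sketch below applies the same expression to a tiny hand-made `uint8` array (`raw` is a hypothetical stand-in for `img`). It only illustrates the resulting data type and value range.

```python
import numpy as np

# Hypothetical stand-in for an 8-bit image: values span the full 0..255 range
raw = np.array([[0, 128, 255]], dtype=np.uint8)

# Same expression as in the exercise: cast first, then rescale
scaled = raw.astype("float32") / 255

print(scaled.dtype, scaled.min(), scaled.max())
```

`plt.imshow` accepts RGB data either as floats in `[0, 1]` or as integers in `[0, 255]`, so the normalized array can be displayed directly.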

### **TO DO**: Display the image

(hint: use the `plt.imshow` function)

Also print the image dimensions (i.e. shape)

In [None]:
fig, ax = plt.subplots(1, 1)
ax.imshow(pixels)
ax.set_title(f"Image with dimensions: {pixels.shape[1]} x {pixels.shape[0]}")

### **TO DO**: Reshape the image into a feature vector

To start we'll use the Red, Green, and Blue values for each pixel as the features.

Our resulting data matrix should have the shape `[num_samples, num_features]`, where each row is a sample (i.e. a pixel) and each column is a feature (i.e. a color channel).

In [None]:
num_samples = pixels.shape[0] * pixels.shape[1]
num_features = pixels.shape[2]
feature_matrix = np.reshape(pixels, [num_samples, num_features])
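
The sketch below checks, on a small made-up array, that this kind of reshape keeps each pixel's channels together in one row and can be undone exactly. The array `tiny` and its dimensions are hypothetical; passing `-1` asks NumPy to infer the number of samples.

```python
import numpy as np

# Hypothetical 2 x 3 image with 3 channels, filled with distinct values
height, width, channels = 2, 3, 3
tiny = np.arange(height * width * channels, dtype="float32").reshape(height, width, channels)

# Flatten to [num_samples, num_features]; -1 lets NumPy infer num_samples
flat = tiny.reshape(-1, channels)

# The pixel at (row, col) ends up in row `row * width + col` of the matrix
row, col = 1, 1
same_pixel = np.array_equal(flat[row * width + col], tiny[row, col])

# Reshaping back recovers the original image layout
restored = flat.reshape(tiny.shape)
print(flat.shape, same_pixel)
```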

### _Optional Variation_:

Try using a different feature vector to describe the pixels. For example, convert the image to HSV (Hue, Saturation, Value) using the `matplotlib.colors.rgb_to_hsv` function (docs [here](https://matplotlib.org/stable/api/_as_gen/matplotlib.colors.rgb_to_hsv.html)). Does this improve the quality of your clusters?

In [None]:
feature_matrix = colors.rgb_to_hsv(feature_matrix)
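
One thing to keep in mind when clustering on HSV features is that hue is circular, while k-means measures plain Euclidean distance. The sketch below uses a few hand-picked colors (hypothetical values, not taken from the image) to show that two nearly identical reds can land at opposite ends of the hue axis.

```python
import numpy as np
import matplotlib.colors as colors

# Hand-picked RGB colors: red, green, blue, and a red with a hint of blue
rgb = np.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [1.0, 0.0, 0.02],
])

hsv = colors.rgb_to_hsv(rgb)
print(hsv[:, 0])  # hue column only

# Converting back recovers the original colors
roundtrip = colors.hsv_to_rgb(hsv)
```

The first and last colors look almost the same, yet their hues sit near 0 and near 1, so a distance-based method may treat them as far apart.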

### **TO DO**: Fit our k-means model to the data

When creating your model, specify the number of clusters (`n_clusters=3`), the initialization method (`init='random'`), and the number of times to re-run the k-means algorithm (`n_init=1`).

Then call the `.fit()` method on your feature matrix (RGB, or HSV if you tried the optional variation above). After fitting, you can get the labels from `.labels_` and the centroids from `.cluster_centers_`.

The documentation for the k-means method can be found [here](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html).

In [None]:
n_clusters = 4  # the text above suggests 3; any small value works (see the optional variation below)
model = KMeans(n_clusters=n_clusters, init="random", n_init=1)
model.fit(feature_matrix)

labels = model.labels_
centroids = model.cluster_centers_

# Optional: view the raw cluster labels as an image
#label_img = np.reshape(model.labels_, [pixels.shape[0], pixels.shape[1]])
#fig, ax = plt.subplots(1, 1)
#ax.imshow(label_img)
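
To see what `.labels_` and `.cluster_centers_` look like without waiting on a full image, here is a minimal sketch on synthetic data. The two blobs, the variable names (`toy`, `toy_model`), and the `random_state` value are assumptions made for illustration.

```python
import numpy as np
from sklearn.cluster import KMeans

# Synthetic stand-in for pixel data: 50 dark and 50 bright "pixels"
rng = np.random.default_rng(0)
dark = rng.normal(0.1, 0.01, size=(50, 3))
bright = rng.normal(0.9, 0.01, size=(50, 3))
toy = np.vstack([dark, bright])

# Same settings as the exercise, plus a fixed seed for repeatability
toy_model = KMeans(n_clusters=2, init="random", n_init=1, random_state=0)
toy_model.fit(toy)

toy_labels = toy_model.labels_             # one integer label per sample
toy_centroids = toy_model.cluster_centers_  # one row per cluster

print(toy_labels.shape, toy_centroids.shape)
```

There is one label per row of the input and one centroid per cluster, with as many columns as there are features.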

### _Optional Variation_:

Try changing the parameters of your k-means model, e.g. `n_clusters` or `init`. How does the quality of the clusters change?
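
One rough way to compare settings numerically is the model's `inertia_` attribute, the sum of squared distances from each sample to its assigned centroid. The sketch below is only an illustration on random data (`sample` is a hypothetical stand-in for `feature_matrix`), and the chosen values of `k`, `n_init`, and `random_state` are assumptions.

```python
import numpy as np
from sklearn.cluster import KMeans

# Random stand-in for a feature matrix of 500 pixels with 3 features each
rng = np.random.default_rng(0)
sample = rng.random((500, 3))

# Fit one model per cluster count and record its inertia
inertias = {}
for k in (2, 4, 8):
    km = KMeans(n_clusters=k, init="k-means++", n_init=5, random_state=0)
    km.fit(sample)
    inertias[k] = km.inertia_

for k, value in inertias.items():
    print(k, round(value, 2))
```

Inertia generally drops as `n_clusters` grows, so a lower value alone does not mean the clusters are more meaningful; it is still worth inspecting the images by eye.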

### Display the resulting images

This code displays the resulting clusters for you, assuming you have already created the following variables:

`img`, `n_clusters`, `labels`, `centroids`

The left image shows a colorized version of the clusters, and the right image replaces each pixel with the value of its closest centroid.

In [None]:
# Pick the colors for the colorized image
cmap = plt.cm.plasma
centroid_colors = np.array([cmap(i)[:3] for i in np.linspace(0, 1, n_clusters)])

# Create arrays shaped the same as the original image, but with new colors
# corresponding to the labels from our k-means model
img_clusters = np.array([centroids[i] for i in labels]).reshape(img.shape)
img_cluster_labels = np.array([centroid_colors[i] for i in labels]).reshape(img.shape)
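
The per-pixel lookups above can also be written with NumPy integer-array indexing, which avoids the Python-level loop. The sketch below compares the two forms on small made-up arrays (`toy_centroids` and `toy_labels` are hypothetical values).

```python
import numpy as np

# Hypothetical centroids (2 clusters, 3 features) and labels for 6 pixels
toy_centroids = np.array([[0.1, 0.1, 0.1], [0.9, 0.9, 0.9]])
toy_labels = np.array([0, 1, 1, 0, 0, 1])

# Loop form, as used in the cell above
via_loop = np.array([toy_centroids[i] for i in toy_labels]).reshape(2, 3, 3)

# Equivalent form: index the centroid array with the whole label array
via_index = toy_centroids[toy_labels].reshape(2, 3, 3)

print(np.array_equal(via_loop, via_index))
```

On a full-size image the indexed form is usually noticeably faster, since the lookup happens inside NumPy.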

In [None]:
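# Convert the centroid colors back from HSV to RGB.
# Skip this line if feature_matrix was left in RGB (i.e. you did not run the HSV variation).
img_clusters = colors.hsv_to_rgb(img_clusters)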

# Display the images side by side
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(6,3))
ax[0].imshow(img_cluster_labels)
ax[0].axis('off')
ax[1].imshow(img_clusters)
ax[1].axis('off')
plt.show()