# Image Compression #

We can use clustering to compress images, but it works a little differently than what you're used to. Instead of guessing at the "right" number of clusters, we deliberately choose the k-value based on how much compression we want. The lower the k-value, the higher the compression, but the more distorted the picture. The higher the k-value, the lower the compression, but the closer the picture looks to the original. We'll use a k-value of 256, because 256 cluster IDs (0 to 255) fit exactly in 8 bits, i.e., one byte per pixel.
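
The link between k and storage cost can be sketched in a few lines: a cluster ID for k clusters needs $\lceil \log_2 k \rceil$ bits, so the nominal ratio versus 24-bit RGB is 24 divided by that. The helper names below are hypothetical, the palette overhead is ignored, and ratios for IDs smaller than 8 bits assume several IDs are packed into each byte (this notebook always stores one ID per byte).

```python
def bits_per_index(k: int) -> int:
    """Smallest number of bits that can hold every cluster ID 0..k-1 (for k >= 2)."""
    return (k - 1).bit_length()

def nominal_ratio(k: int) -> float:
    """Compression ratio versus 24-bit RGB, ignoring the palette and any header."""
    return 24 / bits_per_index(k)

for k in (2, 16, 256):
    print(f"k={k:>3}: {bits_per_index(k)} bits per pixel, ratio {nominal_ratio(k):.0f}:1")
```

For k = 256 this gives 8 bits per pixel and a 3:1 ratio, matching the choice above.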

The clustering algorithm will give us a palette of RGB colors that "represent" the overall image. With 256 clusters, our palette has 256 colors, and these palette colors are the centroids found by K-means. Once K-means chooses the palette for us, we'll replace each original pixel with the ID of its cluster (i.e., the index of its representative color in the palette).
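
To make the palette/label relationship concrete, here is a toy run on four made-up "pixels" (two reddish, two bluish) with k = 2. The data is invented for illustration: `cluster_centers_` plays the role of the palette, and `labels_` holds each pixel's cluster ID.

```python
import numpy as np
from sklearn.cluster import KMeans

# Four made-up RGB pixels: two reddish, two bluish
toy_pixels = np.array([[255, 0, 0],
                       [250, 5, 0],
                       [0, 0, 255],
                       [5, 0, 250]], dtype=np.uint8)

toy_km = KMeans(n_clusters=2, n_init=10, random_state=0).fit(toy_pixels)

toy_palette = np.round(toy_km.cluster_centers_).astype(np.uint8)  # one color per cluster
toy_ids = toy_km.labels_.astype(np.uint8)                         # one cluster ID per pixel

print("palette:", toy_palette.tolist())
print("ids:", toy_ids.tolist())
# Looking each ID up in the palette rebuilds an approximation of the original pixels
print("rebuilt:", toy_palette[toy_ids].tolist())
```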

We'll use scikit-learn's implementation of K-means as our clustering algorithm. If you'd like to substitute your own version of K-means, that's OK too, although it might be slower.
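
If you do want to write your own, the core of Lloyd's algorithm (the classic K-means iteration: assign each sample to its nearest centroid, then move each centroid to the mean of its samples) fits in a short NumPy function. This is a minimal, unoptimized sketch: `simple_kmeans` is a hypothetical name, it initializes from random samples rather than k-means++, and it builds an n × k × 3 distance array, so try it on a subsample of the pixels rather than a full-size image.

```python
import numpy as np

def _assign(X, centroids):
    """Index of the nearest centroid (squared Euclidean distance) for each row of X."""
    # (n, 1, d) - (1, k, d) broadcasts to (n, k, d); summing over d gives (n, k)
    dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return dists.argmin(axis=1)

def simple_kmeans(X, k, n_iter=50, seed=0):
    """Minimal Lloyd's-algorithm K-means: returns (centroids, labels)."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=np.float64)
    # Start from k distinct samples chosen at random
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        labels = _assign(X, centroids)
        new_centroids = centroids.copy()
        for c in range(k):
            members = X[labels == c]
            if len(members):  # leave a centroid in place if its cluster is empty
                new_centroids[c] = members.mean(axis=0)
        converged = np.allclose(new_centroids, centroids)
        centroids = new_centroids
        if converged:
            break
    # Final assignment so the labels match the returned centroids
    return centroids, _assign(X, centroids)
```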

```python
import numpy as np
from sklearn.cluster import KMeans
from matplotlib.image import imread
import matplotlib.pyplot as plt
```

## Preprocessing ##

Load the image file from disk and reshape it from a 2D grid of RGB pixels into one long sequence of RGB pixels. Basically, we append the second row to the end of the first row, the third row to the end of the second row, and so on for the entire image.
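
This row-by-row flattening is NumPy's default (row-major, or "C") order. A tiny made-up 2×3 "image" shows that the first pixel of the second row lands right after the last pixel of the first row:

```python
import numpy as np

# A made-up 2x3 image with 3 channels; values 0..17 make each pixel easy to track
tiny = np.arange(2 * 3 * 3).reshape(2, 3, 3)
flat = tiny.reshape(2 * 3, 3)

# Row 0's three pixels come first, followed by row 1's three pixels
print(flat)
```

Reshaping `flat` back with `flat.reshape(2, 3, 3)` recovers `tiny` exactly, which is what the display step at the end relies on.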

```python
IMAGE_FILE = 'hershey.bmp'

# Read the file from disk into a (height, width, depth) array
img = imread(IMAGE_FILE)
print(f"\nnumpy shape = {img.shape}")

(height, width, depth) = img.shape
print(f"Height = {height} pixels")
print(f"Width = {width} pixels")
print(f"Depth = {depth} channels (bytes per pixel)")

print("Pixels:")
print(img)
```

```python
# Double check that we have the right image
plt.title(IMAGE_FILE)
plt.imshow(img)
plt.axis('off')
plt.show()
```

```python
# Convert image from HxWxD array to just a long list of pixel values
print(f"before reshape = {img.shape}")
X = img.reshape(width * height, depth)
print(f" after reshape = {X.shape}")
```

```python
# Show that the rows are just placed one after the other
print("Before:")
print(img[0][0:3])
print()
print("After:")
print(X[0:3])
```

## Compress the Data ##

K-means will cluster the individual pixels and give us $k$ representative colors, each one (approximately) the average color of the pixels in its cluster. The centroids will be our new (compressed) RGB values, and the labels will record, for each pixel, which compressed color represents it.
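
In symbols, K-means looks for the $k$ centroids $\boldsymbol{\mu}_1,\dots,\boldsymbol{\mu}_k$ that minimize the total squared distance between each pixel and its nearest centroid, where $\mathbf{x}_i$ is the RGB vector of pixel $i$ and $N$ is the number of pixels (height × width):

$$J(\boldsymbol{\mu}_1,\dots,\boldsymbol{\mu}_k) = \sum_{i=1}^{N} \min_{j \in \{1,\dots,k\}} \lVert \mathbf{x}_i - \boldsymbol{\mu}_j \rVert^2$$

Each pixel's label is the $j$ that achieves the inner minimum, and scikit-learn reports the final value of this sum as the fitted model's `inertia_` attribute. K-means only finds a local minimum of $J$, which is why running it several times (`n_init`) and keeping the best result can help.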

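```python
NUMBER_OF_COLORS = 256

# Cluster the pixels into NUMBER_OF_COLORS groups.
# With the default k-means++ initialization, n_init='auto' performs a single run;
# pass an integer (e.g., n_init=10) to run several times and keep the lowest-inertia result.
km = KMeans(n_clusters=NUMBER_OF_COLORS, n_init='auto')
km.fit(X)
print("Palette Identified")
```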
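
Fitting 256 clusters on every pixel of a large image can be slow. One common shortcut, sketched here under the assumption that a random subsample is representative of the image's colors, is to fit on a subset of pixels and then call `predict` to label all of them. Random data stands in for `X` so the snippet runs on its own, and the sample size and cluster count are arbitrary choices.

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
# Stand-in for the real pixel array X: 20,000 random RGB pixels
X_demo = rng.integers(0, 256, size=(20_000, 3), dtype=np.uint8)

# Fit on a random subsample of pixels (5,000 here, chosen arbitrarily)...
sample_idx = rng.choice(len(X_demo), size=5_000, replace=False)
km_fast = KMeans(n_clusters=16, n_init=1, random_state=0).fit(X_demo[sample_idx])

# ...then assign every pixel to its nearest centroid
labels_all = km_fast.predict(X_demo)
print(labels_all.shape, km_fast.cluster_centers_.shape)
```

scikit-learn's `MiniBatchKMeans` is another option aimed at the same problem.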

```python
# The new pixel values are just the cluster IDs (0-255 fit in one unsigned byte)
pixels = km.labels_.astype(np.uint8)
print(f"{len(pixels)} Compressed Pixels:  {pixels[0]} {pixels[1]} ... {pixels[-2]} {pixels[-1]}")

# We also need to store the representative RGB color of each cluster.
# This "color palette" maps cluster IDs back to RGB colors.
# The colors come from the K-means centroids, which are floats,
# so round them to the nearest integer before converting to uint8.
palette = np.round(km.cluster_centers_).astype(np.uint8)
print(f"{len(palette)} Palette Colors: {palette[0]} {palette[1]} ... {palette[-2]} {palette[-1]}")
```

### Exploring the Compressed Data ###

The original data defined each pixel with three 8-bit numbers (R, G, and B). We've used K-means to select 256 representative colors, so each pixel can now be stored as a single 8-bit cluster ID instead of 3 × 8 = 24 bits. That's a compression ratio of approximately $3:1$.

We say *approximately* rather than *exactly* because, in addition to the 8-bit pixel values, we also have to store the full RGB values of the 256 palette colors (256 × 3 = 768 bytes). For all but tiny images, this overhead is negligible.
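
A quick way to check that the palette really is negligible is to count bytes. This sketch assumes one byte per pixel ID, 3 bytes per palette color, and an optional fixed-size header; `compression_ratio` is a hypothetical helper, and the 640×480 dimensions are just an example, not the size of `hershey.bmp`.

```python
def compression_ratio(height, width, n_colors=256, header_bytes=0):
    """Original 24-bit RGB size divided by the palette-based size (1-byte IDs)."""
    original_bytes = height * width * 3
    compressed_bytes = header_bytes + n_colors * 3 + height * width
    return original_bytes / compressed_bytes

# Example dimensions only
print(f"{compression_ratio(480, 640):.4f}")
```

For 640×480 the ratio comes out just under 3. For a tiny 16×16 image, though, the 768-byte palette dominates and the "compressed" data is larger than the original (a ratio of 0.75).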

```python
# Original first 5 pixels
print("Original Data")
print(img[0][0:5])

# First 5 pixels as cluster IDs (indices into the palette)
print("\nCompressed Data")
print(pixels[:5])

print(f"\nWhat is compressed color #{pixels[0]}?")
print(palette[pixels[0]])
```

```python
print("Here are all the colors:")
print(palette)
```

### Viewing the Color Palette ###

Just for fun, let's see what the 256 representative colors look like. We'll create a 16x16 grid of subplots, where each subplot shows a 100x100 solid block of one palette color.

```python
# Let's try creating a 100x100 block and see what it looks like
block = np.full([100, 100, 3], palette[0])
block[0][:5]
```

```python
# That worked, so now let's create one of these blocks for each of the 256 colors in our palette
blocks = [np.full([100, 100, 3], color) for color in palette]
for i in range(5):
    print(f"Palette Color #{i}")
    print(f"{blocks[i][0][:3]}\n")
```

```python
# And now we'll plot them to the screen in a 16x16 grid (16 * 16 = 256 colors)
fig, ax = plt.subplots(16, 16, figsize=(8, 8))
for i in range(ax.shape[0]):
    for j in range(ax.shape[1]):
        ax[i][j].axis('off')
        # Row-major index into the palette: row i, column j
        color_index = i * ax.shape[1] + j
        ax[i][j].imshow(blocks[color_index])
plt.show()
```
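
Drawing 256 separate subplots is slow. An alternative sketch builds the whole palette as a single image: reshape the 256 colors into a 16×16 grid, then enlarge each cell with `np.repeat`. `palette_swatch` is a hypothetical helper, and random colors stand in for the real palette so the snippet runs on its own.

```python
import numpy as np

def palette_swatch(palette, cell=20):
    """Arrange a (256, 3) palette as a 16x16 grid, each color drawn as a cell x cell square."""
    grid = np.asarray(palette, dtype=np.uint8).reshape(16, 16, 3)
    # Repeat every row, then every column, `cell` times
    return np.repeat(np.repeat(grid, cell, axis=0), cell, axis=1)

demo_palette = np.random.default_rng(0).integers(0, 256, size=(256, 3), dtype=np.uint8)
swatch = palette_swatch(demo_palette)
print(swatch.shape)
```

With the real palette, `plt.imshow(palette_swatch(palette))` followed by `plt.axis('off')` and `plt.show()` displays the same colors in the same row-major order as the grid above.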

## Saving/Loading the Compressed Image ##

To save the image, we would need to store the palette, the sequence of pixels (ID#s from the palette), and a few other pieces of information.

Here is each piece of data, listed roughly in the reverse of the order in which it should appear in our file.

```python
# 1. We'll need to save the pixels
pixels[:1000]
```

```python
# 2. We'll need to save the palette
palette
```

```python
# 3. How about the dimensions of the picture?
width, height
```

```python
# 4. Don't forget that we'll need to explicitly save the number of colors in the palette (256)
len(palette)
```

```python
# 5. Some sort of header field that identifies this as a K-means Clustering Compressed Image
header = 'KMC:'
print(header)
```
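
Putting those pieces together, here is one possible file layout, header first, written with Python's `struct` module and NumPy's `tobytes`/`frombuffer`. The layout, the field widths, and the `save_kmc`/`load_kmc` names are all assumptions for illustration. Note that the palette size needs more than one byte, since 256 doesn't fit in an unsigned 8-bit number.

```python
import struct
import numpy as np

HEADER = b'KMC:'   # same magic string as the header above, as bytes
DIMS = '<IIH'      # little-endian: uint32 width, uint32 height, uint16 palette size

def save_kmc(path, pixels, palette, width, height):
    """Write header, dimensions, palette size, palette (RGB bytes), then pixel IDs."""
    with open(path, 'wb') as f:
        f.write(HEADER)
        f.write(struct.pack(DIMS, width, height, len(palette)))
        f.write(np.asarray(palette, dtype=np.uint8).tobytes())
        f.write(np.asarray(pixels, dtype=np.uint8).tobytes())

def load_kmc(path):
    """Read a file written by save_kmc; returns (pixels, palette, width, height)."""
    with open(path, 'rb') as f:
        if f.read(len(HEADER)) != HEADER:
            raise ValueError('not a KMC file')
        width, height, n_colors = struct.unpack(DIMS, f.read(struct.calcsize(DIMS)))
        palette = np.frombuffer(f.read(n_colors * 3), dtype=np.uint8).reshape(n_colors, 3)
        pixels = np.frombuffer(f.read(width * height), dtype=np.uint8)
    return pixels, palette, width, height
```

With the notebook's variables, `save_kmc('hershey.kmc', pixels, palette, width, height)` would write the compressed image, and `load_kmc('hershey.kmc')` would read it back.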

## Displaying a Compressed Image ##

Remember, the compressed image is nothing more than the color palette, a sequence of pixel values, and the dimensions of the picture. Standard image viewers and libraries don't understand this format, so we'll need to expand our data back into a form they can use.

To display one of our compressed images, we'll need to:
 1. Expand each pixel from its compressed value (a cluster ID from 0 to 255) to the actual 24-bit RGB value stored in the palette (e.g., RGB: 235, 195, 182).
 2. Reshape the data back into a 2D grid of RGB pixels (height × width × depth).
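
Both steps are one-liners in NumPy. With a made-up three-color palette and a 2×2 "image," indexing the palette with the whole ID array looks up every pixel at once, and `reshape` restores the grid:

```python
import numpy as np

toy_palette = np.array([[0, 0, 0],        # ID 0: black
                        [255, 0, 0],      # ID 1: red
                        [0, 255, 0]],     # ID 2: green
                       dtype=np.uint8)
toy_ids = np.array([1, 2, 0, 1], dtype=np.uint8)   # a 2x2 image, stored row by row

expanded = toy_palette[toy_ids]          # step 1: shape (4, 3)
restored = expanded.reshape(2, 2, 3)     # step 2: shape (2, 2, 3)
print(restored)
```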

```python
# Create a new array the same length as the pixel array, but holding RGB values instead of cluster IDs.
# Indexing the palette with the whole array of cluster IDs looks up every pixel's color in one step.
print(f"Original shape of pixels sequence: {pixels.shape}")
print(f"Sample pixels: {pixels[:5]}")
compressed_img = palette[pixels]
print(f"New shape of the expanded pixels: {compressed_img.shape}")
print(f"Sample of expanded pixels: {compressed_img[:5]}")
```

```python
compressed_img = compressed_img.reshape(height, width, depth)
print("Final shape after converting to 2D:", compressed_img.shape)
```

```python
fig, ax = plt.subplots(1, 2, figsize=(10,10))

ax[0].set_title('Original Image')
ax[0].imshow(img)
ax[0].axis('off')

ax[1].set_title('256 Color Compressed Image')
ax[1].imshow(compressed_img)
ax[1].axis('off')

plt.tight_layout()
plt.show()
```
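
To put a number on the distortion (lower k, more distortion), a common measure is the mean squared error (MSE) between the two images and the related peak signal-to-noise ratio (PSNR); a higher PSNR means the compressed image is closer to the original. This helper is a sketch assuming 8-bit images (peak value 255); calling it as `psnr(img, compressed_img)` is a suggested usage, not something computed above.

```python
import numpy as np

def psnr(original, compressed, peak=255.0):
    """Peak signal-to-noise ratio in dB between two same-shaped 8-bit images."""
    # Cast to float first so uint8 subtraction can't wrap around
    diff = original.astype(np.float64) - compressed.astype(np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return 10 * np.log10(peak ** 2 / mse)
```

Re-running the notebook with different values of `NUMBER_OF_COLORS` and comparing PSNR is one way to see the compression/quality trade-off described at the start.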