In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.utils import shuffle

In [None]:
# Number of colors in the quantized palette (used as the K-Means cluster count)
n_colors = 7

In [None]:
# Load input image.
# cv2.imread reports a missing or unreadable file by returning None rather than
# raising, so fail here with a clear message instead of inside cvtColor.
image_path = "../data/example.jpeg"
input_image = cv2.imread(image_path)
if input_image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)

# Count the distinct colors before rescaling, so the figure title reflects the
# image that was actually loaded instead of a hard-coded number.
n_unique_colors = len(
    np.unique(input_image.reshape(-1, input_image.shape[-1]), axis=0)
)

# Transform image: scale channel values to floats in [0, 1] and flatten the
# pixel grid into a (n_pixels, n_channels) array for clustering.
input_image = np.array(input_image, dtype=np.float64) / 255
# NOTE: image arrays are laid out (rows, cols, channels), so `w` actually holds
# the height and `h` the width. The names are kept because later cells use them,
# and the reshape round-trip stays consistent either way.
w, h, d = original_shape = tuple(input_image.shape)
assert d == 3
image_array = np.reshape(input_image, (w * h, d))

# Display image
plt.clf()
plt.axis("off")
plt.title(f"Original image ({n_unique_colors:,} colors)")
plt.imshow(input_image)
plt.show()

In [None]:
# Fit the model on a small random sub-sample of the pixels to keep training fast.
# shuffle() samples without replacement and raises ValueError when asked for
# more rows than exist, so cap the sample size for images under 1,000 pixels.
n_fit_samples = min(1_000, image_array.shape[0])
image_array_sample = shuffle(
    image_array, random_state=0, n_samples=n_fit_samples
)
kmeans = KMeans(n_clusters=n_colors, n_init="auto", random_state=0).fit(
    image_array_sample
)

# Predict the palette (cluster) index of every pixel in the full image
labels = kmeans.predict(image_array)

# Rebuild the image by replacing each pixel with its cluster center
quantized_image = kmeans.cluster_centers_[labels].reshape(w, h, -1)

# Display quantized image
plt.clf()
plt.axis("off")
plt.title(f"Quantized image ({n_colors} colors, K-Means)")
plt.imshow(quantized_image)
plt.show()

In [None]:
# Display palette
fig, ax = plt.subplots()
# Convert the cluster centers to 8-bit channel values. Round to nearest rather
# than truncating, otherwise every channel is biased darker by up to 1/255
# compared with the colors drawn in the quantized image.
palette_color_list = np.rint(kmeans.cluster_centers_ * 255).astype(int)
# Count the pixels assigned to each cluster once, instead of rescanning
# `labels` on every loop iteration.
label_counts = np.bincount(labels, minlength=len(palette_color_list))
for i, color in enumerate(palette_color_list):
    hexa = "#" + "".join(f"{int(p):02x}" for p in color)  # rgb to hex
    print(f"Color #{i + 1}: {hexa} ({100 * label_counts[i] / labels.size:.2f}%)")
    ax.bar(i, 1, color=hexa)
ax.set_axis_off()
ax.set_title("Color palette")
plt.show()