In [None]:
import numpy as np
import cv2
import os
import pandas as pd
from matplotlib import pyplot as plt

In [None]:
# Load in the color name databases (small / medium / large).
# Each CSV provides R, G, B columns plus a Name column; the RGB values are kept
# as an (n, 3) array and the names as a parallel array (row i matches row i).
def load_color_db(path):
    """Read one color database CSV.

    Parameters
    ----------
    path : str
        Path to a CSV with ``R``, ``G``, ``B`` and ``Name`` columns.

    Returns
    -------
    (frame, rgb, names)
        The raw DataFrame, its ``R``/``G``/``B`` columns as a NumPy array,
        and its ``Name`` column as a NumPy array aligned row-for-row with ``rgb``.
    """
    frame = pd.read_csv(path)
    return frame, frame[['R', 'G', 'B']].to_numpy(), frame.Name.to_numpy()


colors_small, color_array_small, color_names_small = load_color_db("colors_small.csv")
colors_medium, color_array_medium, color_names_medium = load_color_db("colors_medium.csv")
colors_large, color_array_large, color_names_large = load_color_db("colors_large.csv")

In [None]:
# Load the image to analyse. Other sample images used during development:
#   ../img/eardrops.jpg, ../img/fake_flowers.jpeg, ../img/leaves.jpg
IMAGE_PATH = "../img/lego.png"
img = cv2.imread(IMAGE_PATH)
if img is None:
    # cv2.imread signals a missing/unreadable file by returning None instead of
    # raising; fail here with a clear message rather than inside cvtColor.
    raise FileNotFoundError(f"Could not read image: {IMAGE_PATH}")

# OpenCV decodes to BGR channel order; matplotlib and the R/G/B color
# databases used below expect RGB.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


# K Means clustering of the pixel colors.

# Convert to float32 array of N x 3, where each row is a pixel (R, G, B);
# cv2.kmeans expects float32 samples, one sample per row.
pixels = np.float32(img.reshape(-1, 3))

K = 6  # number of color clusters
# Stop after at most 10 iterations or once the epsilon of 1.0 is reached.
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
flags = cv2.KMEANS_RANDOM_CENTERS

# KMEANS_RANDOM_CENTERS draws from OpenCV's RNG; seed it so a fresh
# Restart & Run All produces the same palette.
SEED = 0
cv2.setRNGSeed(SEED)

# attempts = 2*K: k-means is run that many times and the most compact result kept.
_, labels, palette = cv2.kmeans(pixels, K, None, criteria, 2*K, flags)

# Pixel count per cluster, indexed by cluster id so counts[i] pairs with
# palette[i]. (np.unique would omit empty clusters and misalign the two arrays.)
counts = np.bincount(labels.ravel(), minlength=K)

# Mean color over all pixels, and the center of the most populated cluster.
average = img.mean(axis=(0, 1))
dominant = palette[np.argmax(counts)]


# Extract cluster colors ordered from most to least frequent.
# Sort cluster indices by pixel count instead of keying a dict on the count:
# two clusters with the same count would collide and one would be lost.
# A stable sort keeps equal-count clusters in cluster-index order.
order = np.argsort(-counts, kind="stable")
colors_sorted = [np.array(palette[i]) for i in order]
    
# Generic names for the large color database. It was already read from
# colors_large.csv in the color-database cell, so alias those arrays instead
# of parsing the same CSV from disk a second time.
colors = colors_large
color_array = color_array_large
color_names = color_names_large

# Match top K clusters with each color database using Euclidean distance in RGB.
def nearest_color_names(query_colors, color_array, color_names):
    """Return, for each query color, the name of the closest database color.

    Parameters
    ----------
    query_colors : iterable of array-like of length 3
        RGB colors to look up.
    color_array : numpy.ndarray
        Database RGB values, one color per row.
    color_names : numpy.ndarray
        Names aligned row-for-row with ``color_array``.

    Returns
    -------
    list
        One name per query color, in the same order. Ties resolve to the
        first matching row (``np.argmin`` semantics).
    """
    names = []
    for color in query_colors:
        dist = np.linalg.norm(color - color_array, axis=1)
        names.append(color_names[np.argmin(dist)])
    return names


ans_small = nearest_color_names(colors_sorted, color_array_small, color_names_small)
ans_medium = nearest_color_names(colors_sorted, color_array_medium, color_names_medium)
ans_large = nearest_color_names(colors_sorted, color_array_large, color_names_large)

# Visualize results.
# Image-sized patch filled with the mean image color (truncated to uint8).
avg_patch = np.full(img.shape, np.uint8(average), dtype=np.uint8)

# Horizontal bands of cluster colors, most frequent first, each band's height
# proportional to the fraction of pixels in that cluster.
indices = np.argsort(counts)[::-1]
freqs = np.cumsum(np.hstack([[0], counts[indices] / float(counts.sum())]))
rows = np.int_(img.shape[0] * freqs)
# Float round-off can leave the cumulative sum just below 1.0; truncation would
# then stop the last band short and leave the bottom row(s) black.
rows[-1] = img.shape[0]

dom_patch = np.zeros(shape=img.shape, dtype=np.uint8)
for i in range(len(rows) - 1):
    dom_patch[rows[i]:rows[i + 1], :, :] = np.uint8(palette[indices[i]])
    
# Side-by-side: the input image and its dominant-color bands.
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(12, 6))
ax0.imshow(img)
ax0.set_title('Image')
ax0.axis('off')
ax1.imshow(dom_patch)
ax1.set_title('Dominant colors')
ax1.axis('off')
# pyplot.show() does not take a figure (its only parameter is keyword `block`);
# it displays all open figures.
plt.show()

# Report the names matched to the three most frequent clusters, per database.
for _label, _names in (("Small:", ans_small), ("Medium:", ans_medium), ("Large:", ans_large)):
    print(_label, _names[:3])