In [None]:
import numpy as np
from sklearn.cluster import KMeans
import cv2
import matplotlib.pyplot as plt

In [None]:
def get_color_palette(img, n_colors, random_state=0):
    """Estimate the ``n_colors`` dominant colors of an image with k-means.

    Every pixel is treated as one sample in channel space and clustered with
    ``KMeans``; the cluster centers are returned as the palette.

    Parameters
    ----------
    img : array-like
        Image of shape (H, W, C), or (H, W) for a single-channel image.
    n_colors : int
        Number of clusters (palette entries) to fit.
    random_state : int, optional
        Seed forwarded to ``KMeans`` so the palette is reproducible (default 0).

    Returns
    -------
    numpy.ndarray
        ``KMeans.cluster_centers_``, shape (n_colors, C). Rows are in KMeans
        label order; they are not sorted by brightness or pixel count.
    """
    img_array = np.asarray(img)
    # Support any channel count (RGB, RGBA, ...) and plain 2-D grayscale images.
    channels = img_array.shape[2] if img_array.ndim == 3 else 1
    # KMeans expects a 2-D (n_samples, n_features) matrix: one row per pixel.
    pixels = img_array.reshape(-1, channels)
    kmeans = KMeans(n_clusters=n_colors, random_state=random_state, n_init="auto").fit(pixels)
    return kmeans.cluster_centers_

In [None]:
def replace_colors(img, palette_a, palette_b):
    """Recolor an image by mapping each pixel to its nearest source-palette color.

    Every pixel is assigned to the closest entry of ``palette_a`` (Euclidean
    distance over channels) and replaced by the entry of ``palette_b`` at the
    same index, i.e. ``palette_a[i] -> palette_b[i]``.

    Parameters
    ----------
    img : array-like
        Image of shape (H, W, C), or (H, W) for a single-channel image.
    palette_a : array-like
        Source palette, shape (K, C); C must match the image's channel count.
    palette_b : array-like
        Target palette, shape (K', C_out). Rows are looked up by the index of
        the nearest ``palette_a`` color, so it needs at least K rows for every
        index to be valid; rows beyond K are never used.

    Returns
    -------
    numpy.ndarray
        ``uint8`` image of shape (H, W, C_out). Target colors are rounded and
        then clipped to [0, 255] so out-of-range values saturate instead of
        wrapping around.
    """
    img_array = np.asarray(img)
    height, width = img_array.shape[:2]
    channels = img_array.shape[2] if img_array.ndim == 3 else 1
    # Work in float: subtracting a uint8 palette from uint8 pixels would wrap
    # around modulo 256 and select the wrong nearest color.
    pixels = img_array.reshape(height * width, channels).astype(np.float64)
    source = np.asarray(palette_a, dtype=np.float64)
    target = np.asarray(palette_b)

    # (N, K) matrix of distances from every pixel to every source color.
    distances = np.linalg.norm(pixels[:, None, :] - source[None, :, :], axis=-1)
    nearest = np.argmin(distances, axis=1)

    new_pixels = np.clip(np.round(target[nearest]), 0, 255).astype(np.uint8)
    return new_pixels.reshape(height, width, target.shape[-1])

In [None]:
# Load the input image. OpenCV decodes images in BGR channel order, so convert
# to RGB before handing the array to matplotlib.
IMAGE_PATH = "../data/example.jpeg"
img_bgr = cv2.imread(IMAGE_PATH)
if img_bgr is None:
    # cv2.imread returns None for a missing/unreadable file instead of raising.
    raise FileNotFoundError(f"Could not read image: {IMAGE_PATH}")
img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

# Display the input image
fig, ax = plt.subplots()
ax.set_axis_off()
ax.set_title("Input image")
ax.imshow(img)
plt.show()

In [None]:
# Number of dominant colors to extract from the input image. Source color i is
# mapped to palette_b[i], so only the first N_COLORS rows of palette_b are used.
# KMeans does not order its clusters meaningfully, so which input color lands
# on which output color is arbitrary.
N_COLORS = 3

# Calculate the color palette of the image
palette_a = get_color_palette(img, N_COLORS)

# Define output color palette.
# NOTE(review): the rows are reversed along the channel axis below, which
# presumably means they were written in BGR order and are converted to RGB to
# match `img`. Confirm the intended channel order.
palette_b = np.array([
    [8, 16, 63],
    [193, 199, 196],
    [142, 132, 107],
    [154, 19, 15],
    [82, 86, 69],
    [17, 101, 48],
    [8, 8, 7],
])
palette_b = palette_b[:, [2, 1, 0]]
assert len(palette_b) >= N_COLORS, "palette_b needs at least N_COLORS rows"

# Replace the colors of the input image according to the output palette
new_img = replace_colors(img, palette_a, palette_b)

# Display the input and output images side by side
fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(10, 5))
for ax, image, title in ((ax_in, img, "Input image"), (ax_out, new_img, "Output image")):
    ax.imshow(image)
    ax.set_title(title)
    ax.set_axis_off()
plt.show()