In [None]:
import json
import httpx
from PIL import Image
from io import BytesIO
import numpy as np
from sklearn.cluster import KMeans
from skimage.color import rgb2lab, lab2rgb

In [None]:
# Load the candidate image URLs (presumably a flat JSON list of URL strings --
# TODO confirm against the data file). The encoding is pinned because JSON is
# UTF-8 and open() would otherwise fall back to the platform's locale encoding.
with open("../data/image_urls.json", "r", encoding="utf-8") as f:
    image_urls = json.load(f)

In [None]:
def rgb_to_palette_visualisation(rgb_colours, size=50):
    """Render a sequence of RGB colours as a horizontal strip of square swatches.

    Args:
        rgb_colours: iterable of RGB triples, one per swatch, left to right.
        size: edge length of each square swatch, in pixels.

    Returns:
        A PIL image that is `size` pixels tall and `size` pixels wide per colour.
    """
    swatches = []
    for colour in rgb_colours:
        # Repeat the single colour over a (size, size) grid, keeping the
        # three channels on the last axis.
        swatches.append(np.tile(np.asarray(colour), (size, size, 1)))
    # Lay the swatches side by side (axis 1 is the image width).
    strip = np.hstack(swatches)
    return Image.fromarray(strip)

In [None]:
def get_5d_coordinates(image):
    """Describe every pixel as a 5-D point: Lab colour plus normalised position.

    Args:
        image: a PIL image. It is converted to RGB first so that images with
            an alpha channel, a palette, or a single grey channel still
            produce exactly three colour values per pixel.

    Returns:
        Array of shape (width * height, 5). Columns 0-2 are the Lab colour,
        column 3 is x / width and column 4 is y / height. Rows follow the
        row-major (top-to-bottom, left-to-right) pixel order.
    """
    rgb_pixels = np.array(image.convert("RGB"))
    height, width = rgb_pixels.shape[:2]

    lab_colour_coords = rgb2lab(rgb_pixels.reshape(-1, 3))

    # reshape(-1, 3) flattens row by row, so the positions must be generated
    # in that same order (rows outermost). Iterating width-first would attach
    # each colour to the position of a different pixel.
    row_idx, col_idx = np.indices((height, width))
    spatial_coords = np.stack(
        [col_idx.ravel() / width, row_idx.ravel() / height],
        axis=1,
    )

    coords = np.concatenate([lab_colour_coords, spatial_coords], axis=1)
    return coords

def cluster(coords, n_clusters=6, random_state=None):
    """Run k-means on `coords` and return the cluster centres.

    Args:
        coords: 2-D array-like with one row per sample.
        n_clusters: number of clusters to fit; defaults to 6.
        random_state: forwarded to KMeans. Pass an int to make the result
            reproducible between runs; None keeps KMeans' default seeding.

    Returns:
        The fitted estimator's `cluster_centers_` array, one row per cluster.
    """
    clusterer = KMeans(n_clusters=n_clusters, random_state=random_state).fit(coords)
    return clusterer.cluster_centers_

def get_palette(image):
    """Extract a colour palette from `image` by clustering its pixels.

    Pixels are clustered on the 5-D coordinates from `get_5d_coordinates`
    (Lab colour plus position); only the colour part of each centre is kept.

    Returns:
        uint8 array of RGB triples, one row per cluster centre.
    """
    centres = cluster(get_5d_coordinates(image))
    # The first three columns are the Lab colour; the remaining columns are
    # the spatial position, which a palette does not need.
    lab_part = centres[:, :3]
    rgb_unit = lab2rgb(lab_part)
    # Scale to the 0-255 range; astype truncates the fractional part.
    return (rgb_unit * 255).astype(np.uint8)

In [None]:
# Pick one URL at random and download it. httpx does not follow redirects
# unless asked to, and a failed request would otherwise hand an error page to
# Image.open, so both cases are handled explicitly here.
image_url = np.random.choice(image_urls)
response = httpx.get(image_url, follow_redirects=True)
response.raise_for_status()
image = Image.open(BytesIO(response.content))

image

In [None]:
# Extract the palette of the downloaded image (clusters on Lab colour plus pixel position).
palette = get_palette(image)

In [None]:
# Show the palette as a strip of swatches; as the cell's last expression it is displayed inline.
rgb_to_palette_visualisation(palette)

## Try the same thing without the two spatial dimensions

In [None]:
# Cluster on colour alone, using the raw RGB values with no Lab conversion and
# no position columns. `clusterer` is only a local variable inside cluster(),
# so the helper is called here instead of relying on a notebook-level name
# that does not exist after a kernel restart.
rgb_pixels = np.array(image).reshape(-1, 3)
only_colour_centres = cluster(rgb_pixels).astype(np.uint8)

In [None]:
# Show the colour-only palette as a strip of swatches for comparison with the one above.
rgb_to_palette_visualisation(only_colour_centres)