# Obtaining image colour palettes
If we're going to search the collection by colour palette, we first need to be able to extract palettes from images. How the search happens (and whether it involves the palettes at all) is almost irrelevant at this stage. We know that colour palettes are going to be an integral part of whatever the final product of this project becomes, so we need to be able to extract them.

Our goal in this first notebook is just to extract a few of the dominant colours from a given image. There are a few plausible approaches to this task, but rather than explore all of them through implementation, I've read up on them in detail and decided on what I think is the most appropriate method before starting.

### The process
The basic process is as follows: 
- Turn the image into a numpy array of shape $(h, w, 3)$, where $h$ and $w$ are the height and width of the image, and the 3 denotes the dimensions of our colour space <sup id="a1">[1](#f1)</sup> 
- Reshape the array, going from 3D to 2D, with rows representing individual pixels and columns representing the three colour channels, i.e. an array of shape $(h\times w, 3)$
- Treat those three channels as axes in a 3D coordinate space, and each row as a point within that space. The distance between points should then represent the difference between colours: close points are similar colours, distant points are very different colours.
- Use k-means clustering to obtain 5 distinct groups of pixels in colour space. 
- Return the center of each cluster. This set of 5 points in colour space makes up the palette for the given image.
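
The steps above can be sketched on a tiny synthetic image before touching any real data. This is a minimal illustration rather than the notebook's actual pipeline: the `toy_` names and the 4x4 two-colour image are made up for the example, the clustering is done directly on RGB values (the conversion to LAB is skipped for brevity), and `n_init` and `random_state` are set only to make the result repeatable.

```python
import numpy as np
from sklearn.cluster import KMeans

# Hypothetical 4x4 test image: left half pure red, right half pure blue
toy_image = np.zeros((4, 4, 3), dtype=np.uint8)
toy_image[:, :2] = [255, 0, 0]
toy_image[:, 2:] = [0, 0, 255]

# (h, w, 3) -> (h * w, 3): one row per pixel, one column per channel
toy_pixels = toy_image.reshape(-1, 3).astype(float)

# Two clusters are enough for a two-colour image; random_state fixes the seed
toy_kmeans = KMeans(n_clusters=2, n_init=10, random_state=0).fit(toy_pixels)

# Each cluster centre is one palette colour (here still in RGB coordinates)
toy_palette = toy_kmeans.cluster_centers_
print(toy_palette)
```

With only two distinct colours in the image, the two cluster centres should land exactly on red and blue, although the order in which they come back is not guaranteed.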

In [None]:
import os
from PIL import Image
import numpy as np
from skimage.color import rgb2lab, lab2rgb
from sklearn.cluster import KMeans

Let's start by loading in an example image

In [None]:
path_to_images = '../data/small_images/'
image_id = np.random.choice(os.listdir(path_to_images))
image = Image.open(path_to_images + image_id)

# Greyscale, palette and RGBA images don't have exactly 3 channels,
# so convert everything to RGB before going any further
image = image.convert('RGB')

print(image_id.replace('.jpg', ''))

image

Now we'll resize the image to reduce the overall number of pixels. Using a large input array will give us more fidelity to the original image and a lower chance of missing colour details, but will significantly increase the computational cost when performing the k-means clustering. We'll also transform the pixels from RGB to LAB space.

In [None]:
image_size = 100
image = image.resize((image_size, image_size), 
                     resample=Image.BILINEAR)
lab_image = rgb2lab(np.array(image))
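
As a quick sanity check of the colour space conversion, it can help to run `rgb2lab` on a few pixels whose LAB values are easy to predict. The sketch below assumes a made-up 1x3 image (`grey_ramp`) containing black, mid-grey and white; it is only an illustration of what the coordinates look like, not part of the palette extraction itself.

```python
import numpy as np
from skimage.color import rgb2lab

# Hypothetical 1x3 image: black, mid-grey and white pixels (uint8 RGB)
grey_ramp = np.array([[[0, 0, 0], [128, 128, 128], [255, 255, 255]]],
                     dtype=np.uint8)
grey_ramp_lab = rgb2lab(grey_ramp)

# The first channel (L) is lightness, from 0 for black to 100 for white.
# The a and b channels carry colour, so they stay close to 0 for greys.
print(grey_ramp_lab.round(2))
```

Note that the LAB values are floats on a different scale to the 0-255 integers we started with, which is why the palette has to be converted back to RGB before it can be displayed.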

Now that our image is in numpy array format, it's easy to get it into the $(h \times w, 3)$ shape we need to treat pixels as individual points in colour space

In [None]:
lab_pixels = lab_image.reshape(-1, 3)
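
To see what that reshape does to the pixel order, here is a small sketch on a made-up 2x3 array (`demo_image` and the other `demo_` names are hypothetical). Each value encodes its own position, so it's easy to check where a pixel ends up.

```python
import numpy as np

# Hypothetical 2x3 "image" whose channel values encode their own position
demo_image = np.arange(2 * 3 * 3).reshape(2, 3, 3)
demo_h, demo_w = demo_image.shape[:2]

# -1 tells numpy to work out the number of rows (h * w) for us
demo_pixels = demo_image.reshape(-1, 3)
print(demo_pixels.shape)  # (6, 3)

# Pixels are laid out row by row, so pixel (y, x) ends up at row y * w + x
y, x = 1, 1
print(demo_pixels[y * demo_w + x], demo_image[y, x])

# Nothing is lost: reshaping back recovers the original image
demo_restored = demo_pixels.reshape(demo_h, demo_w, 3)
```

The clustering step doesn't care where in the image a pixel came from, which is why throwing away the 2D layout is fine here.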

We can fit 5 clusters to our pixel data using `sklearn`'s k-means implementation

In [None]:
n_clusters = 5
cluster = (KMeans(n_clusters=n_clusters).fit(lab_pixels))

The coordinates of each cluster center are an integral part of [how k-means clustering works](https://en.wikipedia.org/wiki/K-means_clustering), so getting hold of them is also super easy

In [None]:
cluster.cluster_centers_
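
The fitted model also records which cluster each pixel was assigned to, in its `labels_` attribute. One possible use, sketched below on made-up data, is to order the palette from the most to the least common colour. The `demo_` names and the 30/10 split are assumptions for the example, and this ordering step is not something the rest of the notebook relies on.

```python
import numpy as np
from sklearn.cluster import KMeans

# Hypothetical pixel data: 30 dark pixels and 10 light pixels
demo_points = np.vstack([np.full((30, 3), 10.0), np.full((10, 3), 90.0)])
demo_kmeans = KMeans(n_clusters=2, n_init=10, random_state=0).fit(demo_points)

# labels_ holds the cluster index of every pixel; bincount tallies them
demo_counts = np.bincount(demo_kmeans.labels_, minlength=2)

# Sort the centres from the most to the least populated cluster
demo_order = np.argsort(demo_counts)[::-1]
demo_sorted_centres = demo_kmeans.cluster_centers_[demo_order]
print(demo_sorted_centres)
```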

Now that we have these coordinates in colour space, we just need to build an array to display them neatly. All we're doing below is building a solid block of colour for each colour in our palette (an $(n \times n \times 3)$ array where $n$ is the size of the colour block and each 3-vector is filled with our palette colour, transformed back into RGB space). Those blocks can then be stacked together to form a palette!

In [None]:
palette = (np.hstack([(lab2rgb(np.array(colour.tolist() * image_size * image_size)
                               .reshape(image_size, image_size, 3)) * 255)
                      .astype(np.uint8)
                      for colour in cluster.cluster_centers_])
           .reshape((image_size, image_size * n_clusters, 3)))

Image.fromarray(palette)
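
The block-building step can be written a little more directly with `np.tile`. The sketch below is an alternative illustration, not the cell above rewritten: it assumes a made-up palette (`demo_rgb`) that has already been converted to RGB, and a small hypothetical `swatch_size` so the shapes are easy to follow.

```python
import numpy as np

# Hypothetical palette of three RGB colours (already converted from LAB)
demo_rgb = np.array([[255, 0, 0], [0, 255, 0], [0, 0, 255]], dtype=np.uint8)
swatch_size = 4

# np.tile repeats each colour over an (n, n) grid, giving an (n, n, 3) block
demo_blocks = [np.tile(colour, (swatch_size, swatch_size, 1))
               for colour in demo_rgb]

# Stacking the blocks side by side gives an (n, n * k, 3) strip
demo_strip = np.hstack(demo_blocks)
print(demo_strip.shape)  # (4, 12, 3)
```

Because `np.hstack` already produces the final shape, no extra reshape is needed before passing the array to `Image.fromarray`.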

### Extra stuff
We can bundle this process up into a pair of functions with nice, tweakable parameters

In [None]:
def get_palette(image, palette_size=5, image_size=75):
    '''
    Return n dominant colours for a given image
    
    Parameters
    ----------
    image : PIL.Image
        The image for which we want to create a palette of dominant colours
    palette_size : int
        The number of dominant colours to extract
    image_size : int
        Images are resized and squared by default to reduce processing time.
        This value sets the side-length of the square. Higher values will 
        increase fidelity, but also increase processing time
    
    Returns
    -------
    palette : np.array
        palette coordinates in LAB space
    '''
    image = image.resize((image_size, image_size))
    lab_image = rgb2lab(np.array(image)).reshape(-1, 3)
    clusters = KMeans(n_clusters=palette_size).fit(lab_image)
    return clusters.cluster_centers_


def display_palette(palette_colours, image_size=100):
    '''
    Return an image of colour swatches for a given palette
    
    Parameters
    ----------
    palette_colours : np.array
        palette coordinates in LAB space
    image_size : int
        The width of each palette colour swatch to be returned. Swatches
        are five times as tall as they are wide
    
    Returns
    -------
    palette : PIL.Image
        image of colour swatches, stacked side by side
    '''
    # np.tile repeats each LAB colour across a full swatch. Multiplying a
    # numpy array by an integer would scale its values instead of repeating it
    stretched_colours = [(lab2rgb(np.tile(colour, (image_size * 5, image_size, 1))) * 255)
                         .astype(np.uint8) 
                         for colour in palette_colours]
    
    # hstack gives an array of shape (image_size * 5, image_size * n, 3)
    palette_array = np.hstack(stretched_colours)

    return Image.fromarray(palette_array)

As noted previously<sup>[1](#f1)</sup>, we're actually working in a non-RGB colour space here. However, we can quite easily return the colours as RGB coordinates, as we do when we build the palette image to be displayed:

In [None]:
rgb_colours = [((lab2rgb(np.array(colour)
                         .reshape(1, 1, 3)) * 255)
                .astype(np.uint8)
                .squeeze())
               for colour in cluster.cluster_centers_]


for colour in rgb_colours:
    print(colour)

We can also return them as HEX values, which is useful in a few contexts

In [None]:
def rgb_to_hex(rgb_list):
    r, g, b = [int(round(channel)) for channel in rgb_list]
    return "#{:02x}{:02x}{:02x}".format(r, g, b)

for colour in rgb_colours:
    print(rgb_to_hex(colour))
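
Going the other way can be handy too, for example when a colour arrives as a HEX string from somewhere else. `hex_to_rgb` below is a hypothetical helper added for illustration, and it assumes a six-digit HEX string with or without the leading `#`.

```python
def hex_to_rgb(hex_string):
    # Hypothetical helper: the inverse of rgb_to_hex
    hex_digits = hex_string.lstrip('#')
    # Read each pair of hex digits as one 0-255 channel value
    return tuple(int(hex_digits[i:i + 2], 16) for i in (0, 2, 4))

print(hex_to_rgb('#ff8000'))
```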

#### Footnotes
<sup id="f1">1</sup> The most commonly used colour space is RGB, as images are naturally encoded in RGB for digital presentation. We'll actually be using CIELAB space in this project for various reasons. You don't really need to worry about this, but notebook 00 has some more detailed reasoning if you're interested. [↩](#a1)