# Obtaining image colour palettes
If we're going to search the collection by colour palette, we first need to be able to extract those palettes from images. How the search happens (and whether it involves the palettes at all) is almost irrelevant at this stage. We know that colour palettes are going to be an integral part of whatever this experience turns into, so we need them.

Our goal in this first notebook is just to extract a few of the dominant colours from a given image. There are several possible approaches to this, but rather than explore all of them through implementation, I've read up on them in detail and decided on what I think is the most appropriate method beforehand.  
If you're interested, here are a few links to things I've read in that research process:
- 
- 
- 

### The process
The basic process is as follows: 
- Turn the image into a numpy array of shape $(h, w, 3)$, where $h$ and $w$ are the height and width of the image, and the third axis holds the three channels of our colour space <sup id="a1">[1](#f1)</sup>
- Reshape that array into a 2D array with three columns representing the three colour channels, i.e. $(h \times w, 3)$
- Treat those three channels as axes in a 3D coordinate space, and each row as a point within that space. The distance between points should then represent the difference between colours; close points are similar colours, distant points are very different colours.
- Use k-means clustering to group the pixels into 5 distinct clusters in colour space.
- Return the `cluster_centers_` of those dominant groupings. This is the palette for the input image.
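
The reshaping and distance steps above can be sketched with NumPy alone. This is only a minimal illustration on a made-up 2×3 image: the pixel values and the name `toy_image` are assumptions for the example, not data from the collection, and it uses raw RGB values rather than the CIELAB values used later in the notebook.

```python
import numpy as np

# a hypothetical 2x3 RGB image containing three reddish and three bluish pixels
toy_image = np.array([
    [[250, 10, 10], [240, 20, 15], [10, 10, 250]],
    [[245, 15, 5], [15, 5, 240], [5, 20, 245]],
], dtype=float)

h, w, channels = toy_image.shape

# flatten the pixel grid so that each row is one point in colour space
points = toy_image.reshape(-1, 3)  # shape (h * w, 3)

# Euclidean distance between points stands in for the difference between colours
red_to_red = np.linalg.norm(points[0] - points[1])
red_to_blue = np.linalg.norm(points[0] - points[2])

print(points.shape)
print(red_to_red < red_to_blue)
```

The two reds sit close together in this space while the red and the blue are far apart, which is the property the clustering step relies on.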

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('white')
plt.rcParams['figure.figsize'] = (20, 20)

import os
from PIL import Image
import numpy as np
from skimage.color import rgb2lab, lab2rgb
from sklearn.cluster import KMeans, AgglomerativeClustering, SpectralClustering
from sklearn.mixture import GaussianMixture

Let's start by loading in an example image.

In [None]:
path_to_images = '../data/small_images/'
image_id = np.random.choice(os.listdir(path_to_images))

# convert to RGB so that greyscale, palette and RGBA images all have three channels
image = Image.open(os.path.join(path_to_images, image_id)).convert('RGB')

print(image_id.replace('.jpg', ''))

image

In [None]:
image_size = 100
image = image.resize((image_size, image_size))
lab_image = rgb2lab(np.array(image)).reshape(-1, 3)

In [None]:
n_clusters = 5
cluster = (KMeans(n_clusters=n_clusters)
           .fit(lab_image))

In [None]:
colours = [colour.tolist() for colour in cluster.cluster_centers_]
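
k-means returns its cluster centres in no particular order, so if we want the palette sorted from most to least dominant we can count how many pixels fall in each cluster. The sketch below is one possible way to do that with NumPy. `order_by_dominance` is a hypothetical helper, and the toy `centres` and `labels` arrays stand in for `cluster.cluster_centers_` and `cluster.labels_`.

```python
import numpy as np


def order_by_dominance(centres, labels):
    """Sort cluster centres from the most to the least populated cluster."""
    centres = np.asarray(centres)
    # count the pixels assigned to each cluster index
    counts = np.bincount(labels, minlength=len(centres))
    # argsort is ascending, so reverse it to put the biggest cluster first
    order = np.argsort(counts)[::-1]
    return centres[order], counts[order]


# made-up values standing in for cluster.cluster_centers_ and cluster.labels_
centres = np.array([[50.0, 10.0, 10.0], [80.0, -5.0, 20.0], [20.0, 0.0, -30.0]])
labels = np.array([1, 1, 0, 2, 1, 2, 1])

ordered_centres, ordered_counts = order_by_dominance(centres, labels)
print(ordered_counts)
```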

In [None]:
# `colour` is a list of three Lab values, so `colour * n` repeats it n times;
# reshaping gives a square swatch of that single colour, which is converted to RGB
palette = np.hstack([(lab2rgb(np.array(colour * image_size * image_size)
                              .reshape(image_size, image_size, 3)) * 255)
                     .astype(np.uint8)
                     for colour in colours])

Image.fromarray(palette)

We can bundle this process up into a pair of functions with easily-tweakable parameters.

In [None]:
def get_palette(image, palette_size=5, image_size=75):
    image = image.resize((image_size, image_size))
    lab_image = rgb2lab(np.array(image)).reshape(-1, 3)
    clusters = KMeans(n_clusters=palette_size).fit(lab_image)
    return [colour.tolist() for colour in clusters.cluster_centers_]


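def display_palette(palette_colours, palette_size=5, image_size=100):
    # palette_size is unused: the number of swatches is inferred from
    # palette_colours, so palettes of any length can be displayed
    # each swatch is a block of one colour, 5 * image_size tall and image_size wide
    stretched_colours = [(lab2rgb(np.array(colour * image_size * image_size * 5)
                                  .reshape(image_size * 5, image_size, 3)) * 255)
                         .astype(np.uint8)
                         for colour in palette_colours]

    # hstack places the swatches side by side, giving an array of shape
    # (5 * image_size, image_size * len(palette_colours), 3)
    palette_array = np.hstack(stretched_colours)

    return Image.fromarray(palette_array)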

As noted in the footnote below<sup>[1](#f1)</sup>, we're actually working in a non-RGB colour space (CIELAB) here. However, we can quite easily return the palette colours as RGB coordinates:

In [None]:
rgb_colours = [((lab2rgb(np.array(colour).reshape(1, 1, 3)) * 255).astype(np.uint8).squeeze().tolist())
               for colour in colours]


for colour in rgb_colours:
    print(colour)

or as HEX

In [None]:
def rgb_to_hex(rgb_list):
    r, g, b = [int(round(channel)) for channel in rgb_list]
    return "#{:02x}{:02x}{:02x}".format(r, g, b)

for colour in rgb_colours:
    print(rgb_to_hex(colour))
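
Going the other way can be handy if a colour is supplied as a hex string, for example from a colour picker. A minimal sketch, where `hex_to_rgb` is a hypothetical helper that isn't part of the original notebook:

```python
def hex_to_rgb(hex_string):
    """Convert a '#rrggbb' string into a list of three 0-255 integers."""
    hex_string = hex_string.lstrip('#')
    if len(hex_string) != 6:
        raise ValueError('expected a colour of the form #rrggbb')
    # read each pair of hex digits as one channel
    return [int(hex_string[i:i + 2], 16) for i in (0, 2, 4)]


print(hex_to_rgb('#ff8000'))
```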

tada!

#### Footnotes
<sup id="f1">1</sup> The colour space used is normally RGB as images are naturally encoded in RGB for digital presentation, though we'll actually be using CIELAB space in this project for various reasons. Don't worry about this for now - notebook 5 has some more detailed reasoning. [↩](#a1)