In [None]:
# Need for speed
While the results from the palette-based search are great, unshuffling the palettes is a relatively expensive process, and the colour-distance computation looks like something that could be done much more efficiently at scale. Notebook 04 becomes cripplingly slow at 5000 images, and the collection currently holds approx. 120,000 images. Clearly, if this is going to scale, the approach needs to change.  
In this notebook I'll try to cut down on the expense while retaining the goodness of the results. This matters less than it might seem, since we've already got the theory down and the _actual_ implementation will probably differ significantly (lots and lots of precomputing), but it's a nice experiment and a test of my linear algebra abilities.
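To put rough numbers on the scaling problem: assuming the search builds a full pairwise distance matrix (one palette comparison per ordered pair of images, as notebook 04 does), the work grows with the square of the collection size. A back-of-the-envelope sketch; `n_comparisons` is just a hypothetical helper name.

```python
# Back-of-the-envelope estimate: a full distance matrix needs one palette
# comparison per ordered pair of images, so the work grows with n ** 2.
def n_comparisons(n_images):
    return n_images * n_images


notebook_04_scale = n_comparisons(5_000)    # 25,000,000 comparisons
collection_scale = n_comparisons(120_000)   # 14,400,000,000 comparisons
growth = collection_scale / notebook_04_scale

print(f"{notebook_04_scale:,} -> {collection_scale:,} comparisons ({growth:.0f}x more)")
```

A 24× larger collection means 576× more comparisons, which is why shaving a constant factor off each comparison isn't enough on its own.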

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('white')
plt.rcParams['figure.figsize'] = (20, 20)

import os
import itertools
import numpy as np
import pandas as pd
from PIL import Image

from sklearn.cluster import KMeans
from skimage.color import rgb2lab, lab2rgb
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cosine
from scipy.spatial.distance import cdist

from tqdm import tqdm_notebook as tqdm

Let's get a few more images than usual,

In [None]:
n_images = 5000
path_to_images = '../data/small_images/'

random_ids = np.random.choice(os.listdir(path_to_images), 
                              n_images, 
                              replace=False)

random_ids = np.sort(random_ids)

Build palettes for our images as usual

In [None]:
def display_palette(palette_colours, image_size=100, big=False):
    palette_size=len(palette_colours)
    
    scale = 1
    if big: scale = 5
    
    stretched_colours = [(lab2rgb(np.array(colour.tolist() * image_size * image_size * scale)
                                  .reshape(image_size * scale, image_size, 3)) * 255)
                         .astype(np.uint8) 
                         for colour in palette_colours]
    
    palette_array = (np.hstack(stretched_colours)
                     .reshape((image_size * scale, 
                               image_size * palette_size, 
                               3)))

    return Image.fromarray(palette_array)

def get_palette(image, palette_size=5, image_size=75):
    image = image.resize((image_size, image_size),
                         resample=Image.BILINEAR)
    lab_image = rgb2lab(np.array(image)).reshape(-1, 3)
    clusters = KMeans(n_clusters=palette_size).fit(lab_image)
    return clusters.cluster_centers_

and build the `image_dict` and `palette_dict` as usual

In [None]:
image_dict = {}
palette_dict = {}

for image_id in tqdm(random_ids):
    try: 
        # convert('RGB') turns greyscale, palette-mode and RGBA files into
        # plain 3-channel RGB, so every image can go through rgb2lab
        image = Image.open(path_to_images + image_id).convert('RGB')
        
        image_dict[image_id] = image
        palette_dict[image_id] = get_palette(image)
    except Exception: 
        # skip unreadable or corrupt files
        pass

image_ids = np.sort(list(image_dict.keys()))
len(image_ids)

# Linear assignment
I initially assumed that I would have to reassign palette order for every pair of image palettes individually, or _at least_ for each row of the matrix at once. However, we can get quite close by pre-computing the assignment of every palette against a single reference palette, and indexing off that reordered set without any further reordering. Most rearrangements seem to depend largely on the `L` component of LAB space, so we can get close for the vast majority of palette pairs, as long as the chosen palette's colours are spread across the full range of `L`. This works well for `(greyscale, greyscale)` and `(greyscale, colourful)` pairs. The only real trouble is a `(colourful, colourful)` pair. While this represents a very small share of the available matches, it's arguably the most important type. Nevertheless, we'll try it here.
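The single-reference idea can be sketched on a toy example: solve one small assignment per palette against the reference, then store each palette in the reference's slot order. This is a minimal sketch on synthetic LAB values; `align_to_reference` is a hypothetical helper, and it uses `scipy.spatial.distance.cdist` (Euclidean by default) to build the same cost matrix that `colour_distance` builds pair by pair.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist


def align_to_reference(reference, palettes):
    """Reorder each palette's colours to best match `reference`.

    reference: (K, 3) array of LAB colours
    palettes:  (N, K, 3) array of LAB palettes
    returns:   (N, K, 3) array where slot k of row n holds the colour of
               palette n that was assigned to reference colour k
    """
    aligned = np.empty_like(palettes)
    for n, palette in enumerate(palettes):
        # Euclidean distance between every reference colour and every palette colour
        cost = cdist(reference, palette)
        _, cols = linear_sum_assignment(cost)
        # row indices come back as 0..K-1, so cols[k] is the colour for slot k
        aligned[n] = palette[cols]
    return aligned


rng = np.random.default_rng(0)
# A greyscale-ish reference spanning the full L range
reference = np.array([[10., 0, 0], [30, 0, 0], [50, 0, 0], [70, 0, 0], [90, 0, 0]])
# A slightly perturbed, shuffled copy of the reference
shuffled = reference[[3, 0, 4, 1, 2]] + rng.normal(scale=1.0, size=(5, 3))

aligned = align_to_reference(reference, shuffled[None])
print(aligned[0, :, 0].round(1))  # L values come back in the reference's order
```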

In [None]:
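def colour_distance(colour_1, colour_2):
    # Euclidean distance between two LAB colours
    return sum([(a - b) ** 2 for a, b in zip(colour_1, colour_2)]) ** 0.5


def assignment_switch(query_palette, palette_dict):
    # reorder every palette so that its k-th colour is the one matched to query_palette[k]
    rearranged = []
    for other_palette in palette_dict.values():
        # cost matrix: rows are query colours, columns are other_palette colours
        distances = [[colour_distance(c1, c2)
                      for c2 in other_palette]
                     for c1 in query_palette]
        
        # row indices come back sorted, so rearrangement[k] pairs with query colour k
        _, rearrangement = linear_sum_assignment(distances)
        rearranged.append([other_palette[i] for i in rearrangement])

    return np.array(rearranged)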

In [None]:
query_palette = palette_dict[np.random.choice(image_ids)]
display_palette(query_palette)

In [None]:
rearranged = assignment_switch(query_palette, palette_dict)

In [None]:
palette_dict = dict(zip(image_ids, rearranged))

Rerunning the cell below a few times will give you a sense of how well the distribution of palettes is matched by this process.

In [None]:
display_palette(rearranged[np.random.choice(len(rearranged))])
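Beyond eyeballing, a rough numeric check of the approximation is to take a pair of aligned palettes (two rows of `rearranged`) and compare the cost of pairing them slot by slot against the cost of solving that pair's own assignment. A sketch, assuming the same Euclidean LAB distance as `colour_distance`; `alignment_gap` and `matching_cost` are hypothetical helpers and the palettes below are random stand-ins.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist


def matching_cost(a, b, cols):
    """Total Euclidean distance when colour k of `a` is paired with colour cols[k] of `b`."""
    return np.linalg.norm(a - b[cols], axis=1).sum()


def alignment_gap(a_aligned, b_aligned):
    """Extra cost of pairing slot k with slot k instead of solving the assignment.

    Both palettes are assumed to be aligned to the same reference palette,
    so the approximation simply pairs them position by position.
    """
    k = len(a_aligned)
    approx = matching_cost(a_aligned, b_aligned, np.arange(k))
    _, cols = linear_sum_assignment(cdist(a_aligned, b_aligned))
    optimal = matching_cost(a_aligned, b_aligned, cols)
    return approx - optimal  # >= 0 up to floating-point error


rng = np.random.default_rng(1)
a = rng.uniform([0, -60, -60], [100, 60, 60], size=(5, 3))
b = rng.uniform([0, -60, -60], [100, 60, 60], size=(5, 3))
print(alignment_gap(a, b))
```

A gap of zero means the shared ordering already gave the pair its optimal matching; averaging the gap over many random pairs (or only over colourful pairs) would quantify how much the shortcut costs.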

# Vectorised palette_distance
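We can significantly speed up the computation by reformulating the `palette_distance()` problem to cover a full row of palettes at once (i.e. comparing one `query_palette` to every other palette in one function call). It involves some slightly sneaky maths: as long as the palettes have all been ordered to match the `query_palette` (or approximately so, as above), the distance between two palettes is just the sum of the distances between colours in the same position, so the k-th query colour can be compared against the k-th colour of every palette in a single `cdist` call.  
Note that because the computation is expressed as array operations throughout, we could also get some giant speedups by moving the work onto a GPU with a NumPy-compatible array library (`scipy`'s `cdist` itself runs on the CPU).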
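Written out, assuming $K$ colours per palette and the cosine metric that `vectorised_palette_distance` passes to `cdist`, the distance from a query palette $q$ to the $n$-th aligned palette $p^{(n)}$ is a plain sum over colour slots:

$$
D\left(q, p^{(n)}\right) = \sum_{k=1}^{K} d\left(q_k, p^{(n)}_k\right),
\qquad
d(u, v) = 1 - \frac{u \cdot v}{\lVert u \rVert \, \lVert v \rVert}
$$

For a fixed slot $k$, the row $\left[d(q_k, p^{(1)}_k), \dots, d(q_k, p^{(N)}_k)\right]$ is a single $1 \times N$ `cdist` call, so $K$ calls and one sum over $k$ give the distances to all $N$ palettes. Without the shared ordering, each $D$ would first need its own colour assignment.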

In [None]:
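def vectorised_palette_distance(rearranged, query_palette):
    # rearranged: (n_images, palette_size, 3), every palette ordered to match
    # the same reference; query_palette: (palette_size, 3) in that same order
    palette_size = query_palette.shape[0]
    query = query_palette.reshape(-1, 1, 3)
    # one (n_images, 3) array per colour slot
    palettes = [p[:, 0, :] for p in np.split(rearranged, palette_size, axis=1)]

    # cosine distance from the k-th query colour to the k-th colour of every palette
    colour_distances = np.stack([cdist(q, p, metric='cosine')[0] 
                                 for q, p in zip(query, palettes)])
    
    # sum over colour slots: one distance per palette
    palette_distances = np.sum(colour_distances, axis=0)
    return palette_distances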
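The same per-slot cosine sum can also be written with plain NumPy broadcasting instead of `cdist`, which is the kind of formulation that would port to a NumPy-compatible GPU array library. A sketch on random stand-in palettes; `cosine_palette_distances` is a hypothetical name, and it assumes no colour is exactly LAB (0, 0, 0), where cosine distance is undefined.

```python
import numpy as np
from scipy.spatial.distance import cdist


def cosine_palette_distances(aligned, query):
    """Cosine palette distance from `query` to every aligned palette, NumPy only.

    aligned: (N, K, 3) palettes, all ordered to match the same reference
    query:   (K, 3) palette in that same order
    returns: (N,) array, sum over slots k of 1 - cos(query[k], aligned[n, k])
    """
    # Normalise every colour to unit length so a dot product is a cosine
    a_unit = aligned / np.linalg.norm(aligned, axis=-1, keepdims=True)
    q_unit = query / np.linalg.norm(query, axis=-1, keepdims=True)
    # Slot-wise dot products: cosines[n, k] = a_unit[n, k] . q_unit[k]
    cosines = np.einsum('nkc,kc->nk', a_unit, q_unit)
    return (1.0 - cosines).sum(axis=1)


rng = np.random.default_rng(0)
aligned = rng.uniform([1, -60, -60], [100, 60, 60], size=(1000, 5, 3))
query = aligned[0]

# Per-slot cdist formulation, as in vectorised_palette_distance
reference = np.stack([cdist(query[k:k + 1], aligned[:, k], metric='cosine')[0]
                      for k in range(query.shape[0])]).sum(axis=0)
fast = cosine_palette_distances(aligned, query)
print(np.allclose(fast, reference))
```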

Generating the similarity matrix for 5000 images the old way would have taken around 3 hours. The function above generates almost exactly the same results in <30 seconds. The scaling is also demonstrably better. A distance matrix for 15,000 images now takes ~3 minutes to compute.

In [None]:
# one column of distances per query image, rows indexed by image_ids
distance_columns = [vectorised_palette_distance(rearranged, palette_dict[query_id])
                    for query_id in tqdm(image_ids)]

palette_distances = pd.DataFrame(np.stack(distance_columns, axis=1),
                                 index=image_ids, columns=image_ids)
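Since every query palette here is itself a row of `rearranged`, the loop can in principle be collapsed entirely: flattening each aligned palette of unit-length colours into one vector turns the sum of per-slot cosines into an ordinary dot product, so the whole matrix is one matrix product. A sketch on random stand-in palettes; `full_cosine_distance_matrix` is a hypothetical helper. The caveat is memory: an N×N float64 matrix for 120,000 images would need roughly 115 GB, so at full-collection scale this would have to be done in chunks.

```python
import numpy as np
from scipy.spatial.distance import cdist


def full_cosine_distance_matrix(aligned):
    """All-pairs palette distances for aligned palettes.

    aligned: (N, K, 3) palettes ordered to match a common reference
    returns: (N, N) matrix with D[i, j] = sum_k 1 - cos(aligned[i, k], aligned[j, k])
    """
    n, k, _ = aligned.shape
    unit = aligned / np.linalg.norm(aligned, axis=-1, keepdims=True)
    # Dot products of the flattened (K * 3) vectors sum the K per-slot cosines
    flat = unit.reshape(n, k * 3)
    cosine_sum = flat @ flat.T
    return k - cosine_sum


rng = np.random.default_rng(0)
aligned = rng.uniform([1, -60, -60], [100, 60, 60], size=(200, 5, 3))
matrix = full_cosine_distance_matrix(aligned)

# Column-by-column check against the per-slot cdist formulation
loop = np.stack([
    np.stack([cdist(q[k:k + 1], aligned[:, k], metric='cosine')[0]
              for k in range(5)]).sum(axis=0)
    for q in aligned
], axis=1)
print(np.allclose(matrix, loop))
```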

In [None]:
sns.heatmap(palette_distances);

While the approximate nature of the computation is frustrating, I think the speedup is large enough to justify giving this a go. Here's a larger sample of similar images for a randomly chosen query.

In [None]:
query_id = np.random.choice(image_ids)
image_dict[query_id]

In [None]:
display_palette(palette_dict[query_id])

In [None]:
res = 300
n_similar = 49
size = int(n_similar ** 0.5)

big_image = np.zeros((res * size, res * size, 3), dtype=np.uint8)
grid = np.array(list(itertools.product(range(size), range(size))))

most_similar_ids = palette_distances[query_id].sort_values().index.values[1 : n_similar+1]
similar_images = [image_dict[image_id].resize((res, res), resample=Image.BILINEAR) 
                  for image_id in most_similar_ids]

for pos, image in zip(grid, similar_images):
    block_t, block_l = pos * res
    block_b, block_r = (pos + 1) * res
    
    big_image[block_t : block_b, block_l : block_r] = np.array(image)

Image.fromarray(big_image)
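The grid above sorts a whole column of `palette_distances` for every query. When only the top few dozen matches are needed from a very large collection, a partial sort is enough. A sketch using NumPy positional indices rather than the DataFrame's image-id index; `most_similar` is a hypothetical helper, it assumes `n_similar` is smaller than the number of images, and unlike `index.values[1 : n_similar+1]` it excludes the query explicitly rather than assuming the query sorts first.

```python
import numpy as np


def most_similar(distances, query_index, n_similar=49):
    """Indices of the n_similar smallest distances, excluding the query itself.

    np.argpartition (linear time) shortlists the candidates, and only that
    shortlist is sorted, instead of sorting the whole column.
    """
    distances = np.asarray(distances, dtype=float).copy()
    distances[query_index] = np.inf  # never return the query as its own match
    shortlist = np.argpartition(distances, n_similar)[:n_similar]
    return shortlist[np.argsort(distances[shortlist])]


rng = np.random.default_rng(0)
column = rng.random(5000)
column[123] = 0.0  # the query's distance to itself
top = most_similar(column, query_index=123, n_similar=49)
print(len(top), 123 in top)
```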

I think that these results are pretty great. The drop in fidelity might not be worth it if everything's being pre-computed (arguable...), but this is probably worth doing if we ever have to do the computation live. The vectorised distance will be especially helpful if users want to search the collection according to a new, custom palette, since that query can't be pre-computed.
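Searching by a custom palette fits this scheme naturally: the custom palette needs just one small assignment against the reference palette that was used to build `rearranged`, followed by a single vectorised distance computation. A sketch under those assumptions; `align` and `search_custom_palette` are hypothetical helpers, the collection is random stand-in data, and it mirrors the notebook's pairing of Euclidean assignment with cosine distance.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist


def align(reference, palette):
    """Reorder palette's colours so slot k holds the colour matched to reference[k]."""
    _, cols = linear_sum_assignment(cdist(reference, palette))
    return palette[cols]


def search_custom_palette(custom_palette, reference, aligned, n_results=10):
    """Rank pre-aligned palettes by distance to a user-supplied palette.

    custom_palette: (K, 3) LAB colours in any order
    reference:      (K, 3) palette every row of `aligned` was ordered against
    aligned:        (N, K, 3) pre-aligned collection palettes
    """
    # One small assignment puts the custom palette into the reference's slot order
    query = align(reference, custom_palette)
    # Per-slot cosine distances, summed over slots
    distances = np.stack([cdist(query[k:k + 1], aligned[:, k], metric='cosine')[0]
                          for k in range(len(query))]).sum(axis=0)
    return np.argsort(distances)[:n_results], distances


rng = np.random.default_rng(0)
reference = np.array([[10., 0, 0], [30, 5, 5], [50, -5, 10], [70, 10, -10], [90, 0, 5]])
raw = rng.uniform([1, -60, -60], [100, 60, 60], size=(500, 5, 3))
aligned = np.stack([align(reference, p) for p in raw])

# A "custom" palette: palette 42's colours, supplied in a different order
custom = aligned[42][[2, 0, 4, 1, 3]]
top, distances = search_custom_palette(custom, reference, aligned, n_results=5)
print(top)
```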

# Vectorised reassignment
I'm not going to do this now because it's a much harder problem and one that I don't really _need_ to solve, but I'm almost certain there's a way of vectorising the linear assignment problem so that it's applied row-wise too. Doing so would make this notebook's output match the previous one exactly, and would also deliver a pretty giant overall speedup.
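One way this could work, exploiting the fact that the palettes here only have five colours: instead of running a solver per palette, enumerate all 5! = 120 colour orderings once, score every ordering for every palette with array operations, and keep the cheapest. This is brute force rather than a vectorised version of the Hungarian algorithm, and it's only a sketch: `batch_assignment` is a hypothetical helper, it uses Euclidean LAB distance to match `colour_distance`, and its cost grows as K!, so it stops being practical for much larger palettes.

```python
import itertools

import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist


def batch_assignment(query, palettes):
    """Exact linear assignment of `query` against many palettes at once.

    query:    (K, 3) palette
    palettes: (N, K, 3) palettes in arbitrary colour order
    returns:  (N, K, 3) palettes reordered so slot k best matches query[k]
    """
    k = query.shape[0]
    perms = np.array(list(itertools.permutations(range(k))))  # (P, K), P = K!
    # cost[n, i, j] = Euclidean distance between query[i] and palettes[n, j]
    cost = np.linalg.norm(query[None, :, None, :] - palettes[:, None, :, :], axis=-1)
    # total[n, p] = sum over i of cost[n, i, perms[p, i]]
    total = cost[:, np.arange(k), perms].sum(axis=-1)  # (N, P)
    best = perms[np.argmin(total, axis=1)]  # (N, K)
    return np.take_along_axis(palettes, best[:, :, None], axis=1)


rng = np.random.default_rng(0)
query = rng.uniform([0, -60, -60], [100, 60, 60], size=(5, 3))
palettes = rng.uniform([0, -60, -60], [100, 60, 60], size=(300, 5, 3))

batched = batch_assignment(query, palettes)
# Per-palette solver, as in assignment_switch
looped = np.stack([p[linear_sum_assignment(cdist(query, p))[1]] for p in palettes])

batched_cost = np.linalg.norm(batched - query, axis=-1).sum(axis=1)
looped_cost = np.linalg.norm(looped - query, axis=-1).sum(axis=1)
print(np.allclose(batched_cost, looped_cost))
```

The intermediate `(N, 120, 5)` array is the memory bottleneck; for 120,000 palettes it is a few hundred megabytes in float64, so chunking over N would be the obvious next step.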

In [None]:
palette_distances.to_pickle('../src/api/palette_distances.pkl')

In [None]:
import pickle
with open('../src/api/palettes.pkl', 'wb') as f:
    pickle.dump(palette_dict, f)

In [None]:
with open('../src/api/palettes.pkl', 'rb') as f:
    print(np.array(list(pickle.load(f).values())).shape)