In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('white')
plt.rcParams['figure.figsize'] = (20, 20)

import os
import itertools
import numpy as np
import pandas as pd
from PIL import Image

from sklearn.cluster import KMeans
from skimage.color import rgb2lab, lab2rgb
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cosine
from scipy.spatial.distance import pdist, squareform

from tqdm import tqdm_notebook as tqdm

In [None]:
def display_palette(palette_colours, image_size=100, big=False):
    """Render a palette as a horizontal strip of solid colour swatches.

    Parameters
    ----------
    palette_colours : sequence of numpy arrays
        One Lab colour per entry; each entry must expose ``.tolist()``.
    image_size : int
        Width in pixels of each swatch (and its height when ``big`` is False).
    big : bool
        When True the swatches are five times taller than they are wide.

    Returns
    -------
    PIL.Image.Image
        An image ``image_size * len(palette_colours)`` pixels wide.
    """
    n_colours = len(palette_colours)
    height = image_size * (5 if big else 1)

    swatches = []
    for colour in palette_colours:
        # Repeat the colour triple once per pixel, then fold it into an image.
        lab_swatch = (np.tile(np.array(colour.tolist()), height * image_size)
                      .reshape(height, image_size, 3))
        # lab2rgb output is scaled by 255 before truncating to 8-bit channels.
        rgb_swatch = lab2rgb(lab_swatch) * 255
        swatches.append(rgb_swatch.astype(np.uint8))

    strip = np.hstack(swatches).reshape((height, image_size * n_colours, 3))
    return Image.fromarray(strip)

def get_palette(image, palette_size=5, image_size=75, random_state=None):
    """Extract a colour palette from an image by clustering its pixels in Lab space.

    Parameters
    ----------
    image : PIL.Image.Image
        Source image in any mode; it is converted to RGB before processing.
    palette_size : int
        Number of KMeans clusters, i.e. colours in the returned palette.
    image_size : int
        The image is resized to ``image_size x image_size`` before clustering.
    random_state : int, RandomState or None
        Passed to KMeans. Supply a fixed value for a reproducible palette;
        the default ``None`` keeps the original unseeded behaviour.

    Returns
    -------
    numpy.ndarray
        The KMeans cluster centres, one Lab colour per row.
    """
    # rgb2lab expects three colour channels, so normalise greyscale, palettised
    # and alpha-carrying images up front instead of letting them fail (or, for
    # palettised images, be clustered on palette indices rather than colours).
    thumbnail = (image.convert('RGB')
                 .resize((image_size, image_size), resample=Image.LANCZOS))
    lab_pixels = rgb2lab(np.array(thumbnail)).reshape(-1, 3)
    clusters = KMeans(n_clusters=palette_size,
                      random_state=random_state).fit(lab_pixels)
    return clusters.cluster_centers_


def colour_distance(colour_1, colour_2):
    """Return the Euclidean distance between two colours.

    The colours are iterated channel by channel in parallel; if one is longer
    than the other, the surplus channels are ignored.
    """
    squared_total = 0
    for channel_1, channel_2 in zip(colour_1, colour_2):
        difference = channel_1 - channel_2
        squared_total = squared_total + difference ** 2
    return squared_total ** 0.5


def palette_distance(palette_1, palette_2):
    """Return the minimum total Euclidean distance between two palettes.

    Each colour of one palette is matched to a distinct colour of the other
    (linear sum assignment) so that the summed colour-to-colour distance is
    as small as possible, which makes the result independent of the order in
    which either palette lists its colours.

    Parameters
    ----------
    palette_1, palette_2 : array-like of shape (n_colours, n_channels)

    Returns
    -------
    float
        Sum of the distances over the optimally matched colour pairs.
    """
    first = np.asarray(palette_1, dtype=float)
    second = np.asarray(palette_2, dtype=float)

    # distances[i, j] is the Euclidean distance between first[i] and second[j].
    distances = np.linalg.norm(first[:, np.newaxis, :] - second[np.newaxis, :, :],
                               axis=-1)

    # The returned column indices refer to palette_2. Summing the matched
    # cells directly avoids permuting the wrong palette, which would pair the
    # colours by the inverse of the optimal matching.
    rows, columns = linear_sum_assignment(distances)
    return float(distances[rows, columns].sum())

In [None]:
n_images = 5000
path_to_images = '../data/small_images/'

# Sampling without replacement raises ValueError when more items are requested
# than exist, so cap the sample at the number of files actually present.
available_files = os.listdir(path_to_images)
random_ids = np.random.choice(available_files,
                              min(n_images, len(available_files)),
                              replace=False)

image_dict = {}
palette_dict = {}
skipped_ids = []

for image_id in tqdm(random_ids):
    try:
        # The context manager releases the file handle; convert() loads the
        # pixel data into a new image, so it stays usable after the file is
        # closed. Converting to RGB also covers greyscale and alpha images.
        with Image.open(os.path.join(path_to_images, image_id)) as raw_image:
            image = raw_image.convert('RGB')
        palette = get_palette(image)
    except Exception:
        # Best effort: an unreadable or unclusterable file is skipped, but
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        skipped_ids.append(image_id)
        continue

    # Store both entries only once both succeeded, so image_dict and
    # palette_dict always share the same keys in the same order.
    image_dict[image_id] = image
    palette_dict[image_id] = palette

if skipped_ids:
    print('Skipped {} of {} files that could not be processed'
          .format(len(skipped_ids), len(random_ids)))

image_ids = list(image_dict.keys())

# Run the linear assignment problem for each palette, rearranging its colours to match the query palette

In [None]:
def assignment_switch(query_palette, palette_dict):
    """Reorder every palette in ``palette_dict`` to line up with ``query_palette``.

    For each stored palette a cost matrix of colour_distance values is built
    (rows: query colours, columns: that palette's colours) and solved with
    linear_sum_assignment; the palette's colours are then emitted in the
    resulting column order.

    Returns
    -------
    numpy.ndarray
        The reordered palettes, in the iteration order of ``palette_dict``.
    """
    def align(candidate):
        # One row of costs per query colour.
        cost_matrix = []
        for query_colour in query_palette:
            cost_matrix.append([colour_distance(query_colour, candidate_colour)
                                for candidate_colour in candidate])

        _, column_order = linear_sum_assignment(cost_matrix)
        return [candidate[index] for index in column_order]

    return np.array([align(candidate) for candidate in palette_dict.values()])

In [None]:
# Pick the palette of one randomly chosen image as the query, then reorder
# every palette in palette_dict so its colours line up with the query's.
query_palette = palette_dict[np.random.choice(image_ids)]
rearranged = assignment_switch(query_palette, palette_dict)

In [None]:
%%timeit
# Time one full pass of assignment_switch over every palette in palette_dict.
assignment_switch(query_palette, palette_dict)

In [None]:
# Show one randomly chosen palette from the reordered array.
display_palette(rearranged[np.random.choice(len(rearranged))])

# neat numpy implementation of colour_distance across the full reordered array

In [None]:
def vectorised_palette_distance(rearranged, query_palette, metric='euclidean'):
    """Distance from a query palette to every palette in a pre-aligned array.

    Colour ``k`` of the query is compared with colour ``k`` of each palette,
    and the per-colour distances are summed per palette.

    Parameters
    ----------
    rearranged : array-like of shape (n_palettes, palette_size, n_channels)
        Palettes whose colours have already been ordered to match the query.
    query_palette : array-like of shape (palette_size, n_channels)
    metric : str or callable
        Any metric accepted by ``scipy.spatial.distance.cdist``. The default,
        ``'euclidean'``, matches ``colour_distance``; pass ``'cosine'`` for the
        previous behaviour (which is undefined for an all-zero colour).

    Returns
    -------
    numpy.ndarray of shape (n_palettes,)

    Raises
    ------
    ValueError
        If the palettes in ``rearranged`` do not have the query's shape.
    """
    # cdist is not among the notebook's top-level imports, so import it here.
    from scipy.spatial.distance import cdist

    palettes = np.asarray(rearranged, dtype=float)
    query = np.asarray(query_palette, dtype=float)

    # np.array([]) is what an empty collection of palettes looks like.
    if palettes.size == 0 and palettes.ndim < 3:
        return np.zeros(0)
    if palettes.ndim != 3 or palettes.shape[1:] != query.shape:
        raise ValueError(
            'rearranged must have shape (n_palettes,) + query_palette.shape; '
            'got {} and {}'.format(palettes.shape, query.shape))

    # Accumulate slot by slot; the palette size comes from the data rather
    # than a fixed constant, and the result stays 1-D even for one palette.
    totals = np.zeros(palettes.shape[0])
    for slot in range(query.shape[0]):
        totals += cdist(query[slot:slot + 1], palettes[:, slot, :],
                        metric=metric)[0]
    return totals

In [None]:
# Iterate over palette_dict itself: assignment_switch walks palette_dict.values(),
# so these ids are guaranteed to line up with the rows it returns.
palette_ids = list(palette_dict.keys())
distance_columns = {}

for query_id in tqdm(palette_ids):
    query = palette_dict[query_id]
    # Palettes must be aligned to *this* query before comparing slot by slot;
    # reusing an alignment computed for a different query pairs the wrong
    # colours. NOTE: this runs one full assignment pass per query, so the cell
    # is considerably slower than a single vectorised distance call.
    aligned_palettes = assignment_switch(query, palette_dict)
    distance_columns[query_id] = pd.Series(
        vectorised_palette_distance(aligned_palettes, query),
        index=palette_ids)

# Build the frame in one go rather than inserting one column at a time.
palette_distances = pd.DataFrame(distance_columns)

In [None]:
# Pick a random image id as the query and display that image as the cell output.
query_id = np.random.choice(image_ids)
image_dict[query_id]

In [None]:
n_similar = 5
thumbnail_size = (300, 300)

# Remove the query by label instead of assuming it sorts first, then keep the
# closest remaining palettes.
ranked_ids = (palette_distances[query_id]
              .drop(query_id, errors='ignore')
              .sort_values()
              .index.values)
most_similar_ids = ranked_ids[:n_similar]

# Force RGB so every thumbnail has the same channel count and can be stacked.
similar_images = [image_dict[image_id].convert('RGB').resize(thumbnail_size)
                  for image_id in most_similar_ids]

# hstack already yields (height, n * width, 3); no fixed-size reshape is
# needed, so this also works when fewer than n_similar images are available.
Image.fromarray(np.hstack([np.array(image) for image in similar_images]))