In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('white')
plt.rcParams['figure.figsize'] = (20, 20)

import os
import itertools
import numpy as np
import pandas as pd
from PIL import Image

from sklearn.cluster import KMeans
from skimage.color import rgb2lab, lab2rgb
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cosine
from scipy.spatial.distance import cdist

from tqdm import tqdm_notebook as tqdm

In [None]:
def display_palette(palette_colours, image_size=100, big=False):
    """Render a palette of CIE-Lab colours as a horizontal strip of solid swatches.

    Parameters
    ----------
    palette_colours : array-like of Lab triples, one per swatch (e.g. the
        cluster centres returned by ``get_palette``).
    image_size : width in pixels of each swatch, and the strip height when
        ``big`` is False.
    big : if True the strip is five times taller; swatch width is unchanged.

    Returns
    -------
    PIL.Image.Image of size (image_size * n_colours) x (image_size * scale), RGB.

    Raises
    ------
    ValueError if ``palette_colours`` is empty.
    """
    palette_lab = np.asarray(palette_colours, dtype=float).reshape(-1, 3)
    palette_size = len(palette_lab)
    if palette_size == 0:
        raise ValueError('palette_colours must contain at least one colour')

    scale = 5 if big else 1

    # lab2rgb works pixel by pixel, so convert each colour once instead of a whole
    # image per colour; round rather than truncate so e.g. 254.99 becomes 255
    palette_rgb = np.round(lab2rgb(palette_lab.reshape(1, palette_size, 3)) * 255).astype(np.uint8)

    # widen every colour into an image_size-wide swatch, then stretch the strip vertically
    strip = np.repeat(palette_rgb, image_size, axis=1)
    palette_array = np.repeat(strip, image_size * scale, axis=0)

    return Image.fromarray(palette_array)

def get_palette(image, palette_size=5, image_size=75, random_state=None):
    """Extract ``palette_size`` representative colours from ``image`` via k-means in CIE-Lab space.

    The image is first downsampled to ``image_size`` x ``image_size`` pixels,
    which bounds the number of pixels that get clustered.

    Parameters
    ----------
    image : PIL.Image.Image in any mode; it is converted to RGB before use.
    palette_size : number of colours, i.e. k-means clusters.
    image_size : side length in pixels of the downsampled image.
    random_state : forwarded to ``KMeans``; pass an int for reproducible palettes.

    Returns
    -------
    np.ndarray of shape (palette_size, 3): the Lab cluster centres.
    """
    # convert('RGB') guarantees exactly 3 channels: rgb2lab cannot take RGBA/CMYK,
    # and greyscale or palette-mode ('P') images would otherwise give a 2-D array
    small_image = image.convert('RGB').resize((image_size, image_size),
                                              resample=Image.BILINEAR)
    lab_pixels = rgb2lab(np.array(small_image)).reshape(-1, 3)
    clusters = KMeans(n_clusters=palette_size, random_state=random_state).fit(lab_pixels)
    return clusters.cluster_centers_

In [None]:
n_images = 15000
path_to_images = '../data/small_images/'
random_seed = 42  # fixed seed: makes the sample, KMeans (global RNG) and later random picks reproducible

np.random.seed(random_seed)

# sort the listing first: os.listdir order is filesystem-dependent, so the same
# seed would otherwise still select different files on different machines
available_ids = sorted(os.listdir(path_to_images))

# never request more files than exist; sampling without replacement beyond the
# population size raises ValueError
random_ids = np.random.choice(available_ids,
                              min(n_images, len(available_ids)),
                              replace=False)

random_ids = np.sort(random_ids)

In [None]:
image_dict = {}
palette_dict = {}
failed_ids = {}  # image_id -> repr of the error, so skipped files are visible rather than silently dropped

for image_id in tqdm(random_ids):
    try:
        # convert('RGB') gives every image exactly 3 channels (greyscale, palette-mode,
        # RGBA and CMYK files otherwise break rgb2lab or yield wrong colours); the
        # with-block closes the file handle once the pixels have been read
        with Image.open(os.path.join(path_to_images, image_id)) as opened_image:
            image = opened_image.convert('RGB')
        palette = get_palette(image)
    except Exception as error:
        # deliberately best-effort (non-image or corrupt files are skipped), but
        # unlike a bare except this still lets KeyboardInterrupt through
        failed_ids[image_id] = repr(error)
        continue

    # store only once both steps succeeded, so image_dict and palette_dict share keys
    image_dict[image_id] = image
    palette_dict[image_id] = palette

len(failed_ids)

In [None]:
# take ids from palette_dict, the dict that is rearranged and compared later;
# image_dict can also hold images whose palette extraction failed, which would
# misalign these ids with the rows of `rearranged`
image_ids = np.sort(list(palette_dict.keys()))
len(image_ids)

# run the linear assignment problem for each palette, rearranging its colours to match the order of the query palette

In [None]:
def colour_distance(colour_1, colour_2):
    """Return the Euclidean distance between two colours.

    Components are paired positionally with ``zip``; if one colour has more
    components than the other, the extras are ignored.
    """
    squared_differences = ((c1 - c2) ** 2 for c1, c2 in zip(colour_1, colour_2))
    return sum(squared_differences) ** 0.5


def assignment_switch(query_palette, palette_dict):
    """Reorder every palette so its colours line up slot by slot with ``query_palette``.

    For each palette, the pairwise Euclidean distances between query colours
    and palette colours are fed to ``linear_sum_assignment``. The palette is then
    permuted so that row ``i`` is the colour matched to ``query_palette[i]``.

    Parameters
    ----------
    query_palette : array-like of shape (k, 3), the reference colour order.
    palette_dict : mapping of id -> array-like of shape (k, 3).

    Returns
    -------
    np.ndarray of shape (len(palette_dict), k, 3). Rows follow
    ``palette_dict``'s iteration order.
    """
    query = np.asarray(query_palette, dtype=float)
    rearranged = []
    for other_palette in palette_dict.values():
        other = np.asarray(other_palette)
        # cdist builds the full k x k Euclidean cost matrix in C: the same metric as
        # colour_distance, without a Python-level double loop per palette
        _, column_order = linear_sum_assignment(cdist(query, other, metric='euclidean'))
        # row indices come back sorted, so column_order[i] is the match for query colour i
        rearranged.append(other[column_order])

    return np.array(rearranged)

In [None]:
# draw one palette at random to act as the reference colour order for all others
reference_id = np.random.choice(image_ids)
query_palette = palette_dict[reference_id]
display_palette(query_palette)

In [None]:
# reorder every palette's colours to line up with the reference palette's slots
rearranged = assignment_switch(query_palette=query_palette,
                               palette_dict=palette_dict)

# initially
I assumed that I would have to recalculate the assignment for each row. This issue is approximately solved by pre-computing the assignment of every palette against a single randomly selected reference palette, and then indexing into that complete reordered set. It's approximate: two palettes are only aligned with each other indirectly, through their shared alignment to the reference. But it keeps this thing fast.

In [None]:
# assignment_switch walks palette_dict.values() in order, so palette_dict's own key
# order is exactly the row order of `rearranged`; pairing with image_ids instead
# silently mislabels palettes whenever the two id lists differ
assert len(rearranged) == len(palette_dict), 'rearranged and palette_dict are out of sync'
palette_dict = dict(zip(palette_dict.keys(), rearranged))

In [None]:
# spot-check one reordered palette chosen at random
sample_index = np.random.choice(len(rearranged))
display_palette(rearranged[sample_index])

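# vectorised palette distance across the full reordered array
Each colour slot of the query is compared with the same slot of every reordered palette using `cdist`, and the per-slot distances are summed. Note that it uses the cosine metric rather than the Euclidean `colour_distance` above. Because it is plain array code it could be ported to a GPU array library, although `scipy`'s `cdist` itself runs on the CPU.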

In [None]:
def vectorised_palette_distance(rearranged, query_palette, metric='cosine'):
    """Summed slot-wise distance from ``query_palette`` to every palette in ``rearranged``.

    Colour slot ``i`` of the query is compared only with slot ``i`` of each
    palette. This relies on the palettes having been put into a common order
    beforehand, e.g. by ``assignment_switch``.

    Parameters
    ----------
    rearranged : array-like of shape (n_palettes, k, 3).
    query_palette : array-like of shape (k, 3).
    metric : any ``scipy.spatial.distance.cdist`` metric. Defaults to ``'cosine'``
        (the original behaviour). NOTE(review): in Lab space cosine distance
        ignores lightness along the neutral axis (e.g. (20, 0, 0) and (80, 0, 0)
        are at distance 0) and is undefined for pure black (0, 0, 0);
        ``'euclidean'`` matches ``colour_distance``.

    Returns
    -------
    np.ndarray of shape (n_palettes,).

    Raises
    ------
    ValueError if ``rearranged`` is not 3-D with ``k`` colours per palette.
    """
    palettes = np.asarray(rearranged, dtype=float)
    query = np.asarray(query_palette, dtype=float)
    n_colours = query.shape[0]
    if palettes.ndim != 3 or palettes.shape[1] != n_colours:
        raise ValueError('rearranged must have shape (n_palettes, {}, 3)'.format(n_colours))

    # one (n_palettes,) distance row per colour slot; slicing [slot:slot + 1] keeps
    # cdist's inputs 2-D even when there is only a single palette
    per_slot = np.stack([cdist(query[slot:slot + 1], palettes[:, slot, :], metric=metric)[0]
                         for slot in range(n_colours)])
    return per_slot.sum(axis=0)

In [None]:
# palette_dict was rebuilt by zipping ids with `rearranged`, so its key order is the
# row order of `rearranged`; use it to label both axes of the distance matrix
palette_ids = list(palette_dict)

# collect one distance vector per query and build the matrix once: inserting
# thousands of columns into a DataFrame one at a time is slow and fragments memory
distance_columns = [vectorised_palette_distance(rearranged, palette_dict[query_id])
                    for query_id in tqdm(palette_ids)]
# rows = candidate images, columns = query images
palette_distances = pd.DataFrame(np.column_stack(distance_columns),
                                 index=palette_ids, columns=palette_ids)

# drawing every cell of an n x n matrix (n in the thousands) is impractically slow
# and memory-hungry, so plot only the top-left corner
n_heatmap = 500
ax = sns.heatmap(palette_distances.iloc[:n_heatmap, :n_heatmap])
ax.set(title='Summed palette distance (first {} images)'.format(n_heatmap),
       xlabel='query image', ylabel='candidate image');

In [None]:
# pick a random image to use as the search query, and show it
query_id = np.random.choice(image_ids)
query_image = image_dict[query_id]
query_image

In [None]:
# the query's palette, as stored in palette_dict for comparison
query_palette_ordered = palette_dict[query_id]
display_palette(query_palette_ordered)

In [None]:
res = 200  # side length in pixels of each thumbnail in the grid
n_similar = 36
# grid side length; rounding up means a non-square n_similar still fits every image
size = int(np.ceil(np.sqrt(n_similar)))

# np.zeros (black) rather than np.empty, so unused cells don't show leftover memory
big_image = np.zeros((res * size, res * size, 3), dtype=np.uint8)
grid = np.array(list(itertools.product(range(size), range(size))))  # (row, col) of each cell

# drop the query explicitly instead of assuming it sorts first: other palettes can
# tie with it at distance 0
most_similar_ids = (palette_distances[query_id]
                    .drop(query_id)
                    .sort_values()
                    .index.values[:n_similar])
# convert('RGB') guarantees 3 channels so every thumbnail fits the RGB canvas
similar_images = [image_dict[image_id].convert('RGB').resize((res, res), resample=Image.BILINEAR)
                  for image_id in most_similar_ids]

for pos, image in zip(grid, similar_images):
    block_t, block_l = pos * res
    block_b, block_r = (pos + 1) * res

    big_image[block_t:block_b, block_l:block_r] = np.array(image)

Image.fromarray(big_image)

In [None]:
# show the id of the current (randomly chosen) query image
query_id

In [None]:
# space-separated ids of the most similar images, without file extensions.
# os.path.splitext strips only the trailing extension (of any type); the old
# str.replace('.jpg', '') also matched '.jpg' mid-name and left '.png' etc. in place
' '.join(os.path.splitext(image_id)[0] for image_id in most_similar_ids)

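# custom reassignment for each work is way too slow (~2 s/it)
Recomputing the assignment against every individual query work would be more accurate than the single shared reference ordering. At roughly 2 s per iteration, though, it is far too slow. I could numpy-ify `assignment_switch()`...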