In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('white')
plt.rcParams['figure.figsize'] = (20, 20)

import os
import itertools
import numpy as np
import pandas as pd
from PIL import Image

from sklearn.cluster import KMeans
from skimage.color import rgb2lab, lab2rgb
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cosine
from scipy.spatial.distance import cdist

from tqdm import tqdm_notebook as tqdm

In [None]:
n_images = 5000
path_to_images = '../data/small_images/'
random_seed = 42  # fixed so a fresh kernel run draws the same sample

# os.listdir returns names in arbitrary order; sort first so that the seeded
# draw below selects the same files on every machine / file system.
available_ids = sorted(os.listdir(path_to_images))

# Sampling without replacement raises ValueError when more items are requested
# than exist, so cap the request at the number of files actually present.
sampler = np.random.RandomState(random_seed)
random_ids = sampler.choice(available_ids,
                            min(n_images, len(available_ids)),
                            replace=False)

random_ids = np.sort(random_ids)

In [None]:
def display_palette(palette_colours, image_size=100, big=False):
    """Render a palette as a horizontal strip of solid colour swatches.

    Parameters
    ----------
    palette_colours : sequence of length-3 colours
        Colours expressed in CIE Lab (they are passed through ``lab2rgb``).
    image_size : int
        Width in pixels of each swatch; also the strip height when ``big``
        is false.
    big : bool
        When true the strip is five times taller (``5 * image_size`` rows).

    Returns
    -------
    PIL.Image.Image
        Image of height ``image_size`` (or ``5 * image_size``) and width
        ``image_size * len(palette_colours)``, swatches in input order.
    """
    height = image_size * (5 if big else 1)

    # Convert every palette colour exactly once: a 1-pixel-high row with one
    # pixel per colour. lab2rgb works pixel-wise, so stretching afterwards
    # gives the same values as converting a full block of identical pixels.
    lab_row = np.asarray(palette_colours, dtype=float).reshape(1, -1, 3)
    rgb_row = (lab2rgb(lab_row) * 255).astype(np.uint8)

    # Widen each pixel into an image_size-wide swatch (input order is kept),
    # then repeat the single row down to the requested height.
    swatch_row = np.repeat(rgb_row, image_size, axis=1)
    palette_array = np.repeat(swatch_row, height, axis=0)

    return Image.fromarray(palette_array)

def get_palette(image, palette_size=5, image_size=75, random_state=None):
    """Extract a colour palette from an image by k-means clustering in Lab space.

    Parameters
    ----------
    image : PIL.Image.Image
        Source image in any mode; it is converted to RGB before processing.
    palette_size : int
        Number of clusters, i.e. number of colours in the returned palette.
    image_size : int
        The image is resized to ``image_size x image_size`` before clustering,
        which bounds the number of pixels fed to KMeans.
    random_state : int, RandomState instance or None
        Forwarded to ``KMeans``; pass an int for a repeatable palette. The
        default ``None`` keeps the original, unseeded behaviour.

    Returns
    -------
    numpy.ndarray
        ``(palette_size, 3)`` array of Lab cluster centres, in whatever order
        KMeans produced them.
    """
    # rgb2lab expects a trailing axis of exactly three channels, so normalise
    # grayscale / palette / alpha images to RGB first. Converting before the
    # resize also means the bilinear filter runs on real colour values rather
    # than on palette indices.
    image = image.convert('RGB').resize((image_size, image_size),
                                        resample=Image.BILINEAR)
    lab_image = rgb2lab(np.array(image)).reshape(-1, 3)
    clusters = KMeans(n_clusters=palette_size,
                      random_state=random_state).fit(lab_image)
    return clusters.cluster_centers_

In [None]:
image_dict = {}
palette_dict = {}
failed_ids = []  # ids that could not be opened or clustered, kept for inspection

for image_id in tqdm(random_ids):
    try:
        # The context manager releases the file handle; convert() returns an
        # independent, fully loaded copy, so `image` stays usable afterwards.
        # Converting to RGB covers grayscale as well as alpha/palette images.
        with Image.open(os.path.join(path_to_images, image_id)) as raw_image:
            image = raw_image.convert('RGB')
        palette = get_palette(image)
    except Exception:
        # Best effort: skip files that cannot be read or clustered, but record
        # them. Catching Exception (not a bare except) lets KeyboardInterrupt
        # stop the loop.
        failed_ids.append(image_id)
        continue

    # Store both entries only after every step succeeded, so image_dict and
    # palette_dict always hold exactly the same keys.
    image_dict[image_id] = image
    palette_dict[image_id] = palette

image_ids = np.sort(list(image_dict.keys()))
len(image_ids)

# Brute-force palette matching with numpy linear algebra

In [None]:
# One row per image: stack every stored palette into a single array.
palettes = np.array(list(palette_dict.values()))

# Draw the query id from palette_dict's own keys so the lookup below can never
# raise KeyError, regardless of what image_ids contains.
query_id = np.random.choice(list(palette_dict.keys()))
query_palette = palette_dict[query_id]

In [None]:
# Euclidean distance between each query colour and the colour at the same
# position in the first palette. reshape(-1, 3) keeps this working for any
# palette size instead of hard-coding five colours.
np.linalg.norm(query_palette.reshape(-1, 3) - palettes[0], axis=1)

In [None]:
%%timeit
# Time materialising every ordering of a single palette's colours as one array
# (5! = 120 orderings for the default five-colour palette).
np.array(list(itertools.permutations(palettes[0])))

In [None]:
%%time
# Expand the first 500 palettes into all of their colour orderings, stack the
# results into one array and display its shape.
np.stack([list(itertools.permutations(palette)) for palette in palettes[:500]]).shape

In [None]:
%%time
big = np.stack([list(itertools.permutations(palette, 5)) for palette in palettes])
np.linalg.norm(big-query_palette, axis=3).sum(axis=2).min(axis=1)

In [None]:
# Distance from the query to every palette under that palette's best ordering.
colour_gaps = np.linalg.norm(big - query_palette, axis=3)  # per-colour Lab distance
ordering_costs = colour_gaps.sum(axis=2)                   # total cost of each ordering
ordering_costs.min(axis=1)                                 # cheapest ordering per palette

In [None]:
%%timeit
# Time the full brute-force ranking: best-ordering distance from the query to
# every palette, then argsort so the indices of the closest palettes come first.
np.argsort(np.linalg.norm(big-query_palette, axis=3).sum(axis=2).min(axis=1))