In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
# Use seaborn's plain 'white' style for every subsequent plot.
sns.set_style('white')
# Default figure size of 20 x 20 (matplotlib measures figsize in inches).
plt.rcParams['figure.figsize'] = (20, 20)

import os
import itertools
import numpy as np
import pandas as pd
from PIL import Image

from umap import UMAP
from scipy.spatial.distance import cosine

from tqdm import tqdm_notebook as tqdm

In [None]:
# Sample size**2 images at random from the image directory.
size = 20
n_images = size ** 2
path_to_images = '../data/small_images/'

# Sample without replacement so every image id is unique.
image_ids = np.random.choice(os.listdir(path_to_images), n_images, replace=False)

# Normalise every image to 3-channel RGB up front. Promoting only 2-D arrays
# (as a plain channel-stack would) lets RGBA / CMYK / palette images through,
# and those break the later `reshape(-1, 3)` pixel grouping.
# `convert` loads the pixel data and returns a new in-memory image, so the
# file handle can be closed straight away instead of leaving hundreds open.
images = []
for image_id in tqdm(image_ids):
    with Image.open(os.path.join(path_to_images, image_id)) as image_file:
        images.append(image_file.convert('RGB'))

Each pixel in an image is treated as a point in 3D color space. We divide that space into evenly sized bins and count how many of the image's pixels fall into each bin.

In [None]:
# Flatten every image into an (n_pixels, 3) array: one row per pixel, one
# column per channel.
pixel_lists = [np.array(picture).reshape((-1, 3)) for picture in images]

# Map each image id to its flattened pixel array.
image_dict = {image_id: pixels for image_id, pixels in zip(image_ids, pixel_lists)}

In [None]:
# Width of one bin along each colour channel.
step_size = 10

# Number of bins needed per channel to cover every 8-bit value (0-255).
# Using range(step_size) here would only create bins 0..9, and because pandas
# aligns on the index when a Series is assigned as a column, every pixel with
# a channel value >= 10 * step_size would be silently dropped from the counts.
bins_per_channel = -(-256 // step_size)  # ceiling division
r = range(bins_per_channel)

# Bin labels such as '[0, 3, 12]', in lexicographic order (last channel varies
# fastest) -- the same order np.ravel_multi_index uses for flat indices below.
bins = [str(list(triple)) for triple in itertools.product(r, r, r)]

histograms = {}
for image_id, image in tqdm(image_dict.items()):
    # Per-pixel bin index along each channel, shape (n_pixels, 3).
    channel_bins = np.asarray(image).reshape(-1, 3).astype(np.intp) // step_size
    # Collapse the three per-channel indices into one flat bin index.
    # Raises ValueError on out-of-range values instead of dropping them quietly.
    flat_bins = np.ravel_multi_index(channel_bins.T, (bins_per_channel,) * 3)
    histograms[image_id] = np.bincount(flat_bins, minlength=len(bins))

# Rows are bins, columns are image ids. Built in one go rather than inserting
# hundreds of columns one at a time; float dtype matches the previous output.
bin_counts = pd.DataFrame(histograms, index=bins, dtype=float)

In [None]:
# Embed the per-image histograms; the transpose puts one image on each row.
# UMAP is stochastic, so a fixed random_state keeps the layout repeatable
# across Restart & Run All.
embedding = UMAP(random_state=42).fit_transform(bin_counts.T.values)

In [None]:
# Plot the first two embedding columns against each other; the trailing
# semicolon suppresses the matplotlib object repr in the cell output.
umap_x, umap_y = embedding[:, 0], embedding[:, 1]
plt.scatter(x=umap_x, y=umap_y);

In [None]:
# Re-key image_dict so that it maps each image id to its PIL image.
image_dict = dict(zip(image_ids, images))

# Pairwise comparison of the colour histograms. scipy's `cosine` is a cosine
# *distance*, so smaller values mean more similar images.
distance_rows = []
for row_id in tqdm(image_ids):
    row_histogram = bin_counts[row_id]
    distance_rows.append([cosine(bin_counts[column_id], row_histogram)
                          for column_id in image_ids])

similarity = pd.DataFrame(distance_rows, index=image_ids, columns=image_ids)

In [None]:
# Heatmap of the full pairwise matrix; the trailing semicolon hides the Axes repr.
sns.heatmap(data=similarity);

In [None]:
# Pick one image id at random to use as the query for the similarity lookup.
# NOTE(review): `id` shadows the Python builtin; the name is kept because the
# following cell reads it.
id = np.random.choice(image_ids)
# Last expression in the cell: display the chosen image.
image_dict[id]

In [None]:
# Side length (pixels) of each thumbnail and how many neighbours to show.
resolution = 200
n_similar = 10

# `similarity` holds cosine distances, so ascending order puts the closest
# images first. Drop the query by label rather than slicing off position 0:
# with ties at distance 0 (e.g. duplicate images) the query is not guaranteed
# to sort first, and the slice would discard a real match instead.
neighbour_distances = similarity[id].drop(id).sort_values()
most_similar_ids = neighbour_distances.index.values[:n_similar]

# Convert to RGB before stacking so images of any mode share a 3-channel layout.
similar_images = [image_dict[similar_id].convert('RGB').resize((resolution, resolution))
                  for similar_id in most_similar_ids]

# hstack already yields (resolution, k * resolution, 3), so no reshape is
# needed. Last expression in the cell: display the strip of thumbnails.
Image.fromarray(np.hstack([np.array(image) for image in similar_images]))