In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import image_utils
from sklearn.metrics.pairwise import pairwise_distances

%matplotlib inline

In [None]:
# Directory that is listed to discover the images; later cells open files
# from it as `<name>.jpg`.
data_path = '../../data/fashion/dresses/'
# Directory handed to image_utils.download_feature_vectors together with
# feature_size=50 (same location as `feature50_path` defined further down).
feature_path = '../../data/features50/dresses/'

In [None]:
# NOTE(review): presumably produces one 50-element feature vector per image
# found in data_path and stores it under feature_path — confirm against
# image_utils, whose implementation is not visible here.
image_utils.download_feature_vectors(data_path, feature_path, feature_size=50)

In [None]:
# Collect the image stems (file name without extension) from the data directory.
#  * The listing is sorted because os.listdir returns entries in arbitrary
#    order, which would make the split below differ between runs and machines.
#  * Only '.jpg' entries are kept: later cells reopen every image as
#    `<stem>.jpg`, so stray entries (hidden files, sub-directories, other
#    formats) would otherwise turn into stems that cannot be opened.
#  * os.path.splitext strips only the final extension, so stems that contain
#    dots themselves stay intact.
image_stems = sorted(
    os.path.splitext(entry)[0]
    for entry in os.listdir(data_path)
    if os.path.splitext(entry)[1] == '.jpg'
)
filenames_test = image_stems[:10]  # first 10 stems, kept out of `filenames`
filenames = image_stems[10:]       # remaining stems, used to build the distance matrices

In [None]:
# Directories holding one `<stem>.npy` feature file per image, one directory
# per feature set compared in this notebook.
feature64_path = '../../data/features64/dresses/'
feature50_path = '../../data/features50/dresses/'
feature114_path = '../../data/features114/dresses/'
# NOTE(review): the only visible call that targets feature114_path is
# image_utils.concat_feature_vectors(...) in the last cell. On a fresh kernel
# this cell runs first, so the features114 files must already exist on disk —
# confirm, or move that call above this cell.


def _load_feature_matrix(feature_dir, names):
    """Load `<name>.npy` from `feature_dir` for every name and stack the results.

    Returns a NumPy array with one entry per name, in the order of `names`.
    """
    return np.array([
        image_utils.load_feature_vector(os.path.join(feature_dir, name + '.npy'))
        for name in names
    ])


def _distance_frame(features, names):
    """Wrap pairwise_distances(features) in a DataFrame labelled by `names` on both axes.

    The distance metric is whatever pairwise_distances uses by default.
    """
    return pd.DataFrame(data=pairwise_distances(features), index=names, columns=names)


features64 = _load_feature_matrix(feature64_path, filenames)
features50 = _load_feature_matrix(feature50_path, filenames)
features114 = _load_feature_matrix(feature114_path, filenames)

features64_dist = _distance_frame(features64, filenames)
features50_dist = _distance_frame(features50, filenames)
features114_dist = _distance_frame(features114, filenames)

In [None]:
def get_closest_images(filenames, df_distances, n_closest=5, image_dir=None):
    """Plot each query image followed by its nearest neighbours.

    One figure is created per query: the query image sits in the left-most
    panel and the `n_closest` images with the smallest distance to it follow,
    ordered from closest to farthest.

    Parameters
    ----------
    filenames : iterable of str
        Image names (without extension) to use as queries. Each one must be a
        label of `df_distances`.
    df_distances : pandas.DataFrame
        Distance matrix labelled by image name on both axes; the column of a
        query is used to rank the other images.
    n_closest : int, default 5
        Number of neighbours shown next to each query.
    image_dir : str, optional
        Directory containing `<name>.jpg` for every image. Defaults to the
        notebook-level `data_path`.

    Raises
    ------
    ValueError
        If `n_closest` is negative.
    KeyError
        If a query name is not a label of `df_distances`.
    """
    if n_closest < 0:
        raise ValueError('n_closest must be non-negative')
    if image_dir is None:
        image_dir = data_path

    queries = list(filenames)
    # Fail before drawing anything when a query is unknown.
    missing = [name for name in queries if name not in df_distances.index]
    if missing:
        raise KeyError(missing)

    def _show(ax, name):
        # imshow copies the pixel data when called, so the file handle can be
        # released right away instead of leaking one open file per panel.
        with Image.open(os.path.join(image_dir, name + '.jpg')) as img:
            ax.imshow(img)

    n_panels = n_closest + 1
    for query in queries:
        # squeeze=False keeps a 2-D axes array even when there is one panel.
        fig, axes = plt.subplots(ncols=n_panels, nrows=1,
                                 figsize=(2 * n_panels, 12), squeeze=False)
        panels = axes[0]
        # Hide every frame up front so panels left empty (fewer candidates
        # than n_closest) do not show bare axes.
        for ax in panels:
            ax.axis('off')

        _show(panels[0], query)

        # Exclude the query by label rather than by position: when another
        # image is at distance 0 from the query, the query is not guaranteed
        # to be the first entry returned by nsmallest.
        distances = df_distances[query]
        candidates = distances[distances.index != query]
        for ax, neighbour in zip(panels[1:], candidates.nsmallest(n_closest).index):
            _show(ax, neighbour)

        plt.tight_layout()

In [None]:
# The queries are sliced from `filenames` (the names that label the distance
# DataFrames) rather than from `filenames_test`, because get_closest_images
# looks every query up in the distance matrix it is given.
test_files = filenames[7:10]
# Neighbours ranked with the distances computed from the features64 vectors.
get_closest_images(test_files, features64_dist)

In [None]:
# Same queries as the previous cell, ranked with the features114 distances
# so the two results can be compared side by side.
get_closest_images(test_files, features114_dist)

In [None]:
# NOTE(review): presumably joins the vectors from feature64_path and
# feature50_path and stores the result under feature114_path — confirm against
# image_utils. An earlier cell already assigns this same path and loads
# features from it, so on a fresh kernel this call likely needs to run before
# that cell; verify the intended execution order.
feature114_path = '../../data/features114/dresses/'
image_utils.concat_feature_vectors(feature64_path, feature50_path, feature114_path)