In [1]:
from PIL import Image
import numpy as np
import glob
import matplotlib.pyplot as plt
import os
import random

from utils.search_utils import Search, CombinedSearch, AkiwiFeatureGenerator, ResnetFeatureGenerator

import warnings
warnings.filterwarnings('ignore')

%matplotlib inline

# Load data

In [None]:
# Directory holding the product images that the searches index.
product_imgs = '../../data/fashion/dresses/'
# Root directory under which each search keeps its own feature sub-folder.
product_feats_root = './data/features/fashion/dresses/'
# Registry of search objects, keyed by a short name; filled in by later cells.
searches = dict()

In [None]:
# One Search per feature extractor. The registry key doubles as the name of
# the feature sub-directory under product_feats_root. Generators are built
# lazily (via the factories) right before their Search, in the listed order.
feature_generator_factories = [
    ('akiwi_50', lambda: AkiwiFeatureGenerator(50)),
    ('akiwi_64', lambda: AkiwiFeatureGenerator(64)),
    ('akiwi_114', lambda: AkiwiFeatureGenerator(114)),
    ('resnet', lambda: ResnetFeatureGenerator()),
    ('resnet_retrained', lambda: ResnetFeatureGenerator('./models/resnet152_retrained.pth')),
]

for feats_name, make_generator in feature_generator_factories:
    feats_dir = os.path.join(product_feats_root, feats_name)
    searches[feats_name] = Search(product_imgs, feats_dir, make_generator())

Loading features from: ./data/features/fashion/dresses/akiwi_50
Loading features from: ./data/features/fashion/dresses/akiwi_64


In [None]:
# Combinations of the akiwi_50 and resnet searches. The numeric suffix of each
# key encodes the two factors in order (e.g. '0703' -> [0.7, 0.3]); the
# unsuffixed 'resnet_50' entry uses equal factors.
# NOTE(review): factors presumably weight each search's contribution, in the
# same order as the list of searches -- confirm against CombinedSearch.
searches['resnet_50'] = CombinedSearch([searches['akiwi_50'], searches['resnet']], factors=[0.5, 0.5])
# Fixed: this entry previously used [0.5, 0.5], duplicating 'resnet_50'.
searches['resnet_50_0604'] = CombinedSearch([searches['akiwi_50'], searches['resnet']], factors=[0.6, 0.4])
searches['resnet_50_0703'] = CombinedSearch([searches['akiwi_50'], searches['resnet']], factors=[0.7, 0.3])
searches['resnet_50_0307'] = CombinedSearch([searches['akiwi_50'], searches['resnet']], factors=[0.3, 0.7])

# Plot Similar Images

In [None]:
def plot_similar_imgs(imgs, dist, title, input_img=None, save_path=None):
    """Plot a single row of images with a label under each and a shared title.

    Args:
        imgs: Sequence of image file paths; each is opened with ``Image.open``.
        dist: Sequence of numbers, one per entry of ``imgs``. Each value is
            rounded to two decimals and shown below the matching image.
        title: Text used as the figure's super-title.
        input_img: Optional path of the query image. When given, it is shown
            as the first image and labelled ``'query'``, since ``dist`` has no
            entry for it.
        save_path: Optional file name. When not ``None`` the figure is saved
            under ``./test_features_results/`` (created if missing) before it
            is shown.

    Raises:
        IndexError: If ``dist`` has fewer entries than ``imgs``.
    """
    imgs = list(imgs)
    labels = [round(d, 2) for d in dist]

    if input_img:
        # Previously this referenced an undefined name (`input_imgs`) and
        # raised NameError. The label list is extended as well so that labels
        # stay aligned with the images after the query is prepended.
        imgs = [input_img] + imgs
        labels = ['query'] + labels

    num_imgs = len(imgs)
    # squeeze=False always yields a 2-D array of axes, so indexing also works
    # when there is only one image (otherwise a bare Axes is returned).
    fig, axarr = plt.subplots(1, num_imgs, figsize=(num_imgs*3, 4), squeeze=False)
    axes_row = axarr[0]

    for i, img_path in enumerate(imgs):
        ax = axes_row[i]
        ax.set_xlabel(labels[i], fontsize=22)
        ax.set_xticks([])
        ax.set_yticks([])

        # Hide the frame around each image.
        for spine in ax.spines.values():
            spine.set_visible(False)

        # Draw inside the `with` block so the pixel data is read before the
        # underlying file is closed.
        with Image.open(img_path) as img:
            # Keep the box left=40, upper=0, right=216, lower=256.
            # NOTE(review): presumably tuned for 256x256 product images -- confirm.
            ax.imshow(img.crop((40, 0, 216, 256)))

    fig.suptitle(title, fontsize=28)
    if save_path is not None:
        out_path = os.path.join('./test_features_results/', save_path)
        # savefig does not create missing directories.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        fig.savefig(out_path)

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

In [None]:
def plot_all(img_path, num_imgs=6):
    """Run every registered search on one query image and plot each result.

    Args:
        img_path: Path of the query image.
        num_imgs: Number of similar images requested from each search.
    """
    for search_name, search in searches.items():
        # Open the query afresh for each search (a search might modify the
        # image it is given) and close the file as soon as the search is done;
        # the original left every opened image unclosed.
        with Image.open(img_path) as query_img:
            imgs, dist = search.get_similar_images(query_img, num_imgs)
        plot_similar_imgs(imgs, dist, search_name)

In [None]:
# Compare all registered searches on this query image, 10 results per search.
plot_all('../../data/fashion/dresses/5713733606269.jpg', num_imgs=10)

In [None]:
# Same comparison for a second query image, 10 results per search.
plot_all('../../data/fashion/dresses/L4221C071-G11.jpg', num_imgs=10)