In [1]:
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import glob
import random

from utils.gan_utils import StarGAN_generator, Pix2PixGenerator
from utils.search_utils import AkiwiFeatureGenerator, ResnetFeatureGenerator, Search, CombinedSearch

%matplotlib inline

import warnings
warnings.filterwarnings('ignore')

In [2]:
# Folder holding the StarGAN checkpoints (e.g. pattern.pth). The last run printed
# "Couldn't find model ./models/stargan/pattern.pth" — NOTE(review): the 'pattern'
# cells further down may not work until that checkpoint exists; confirm.
stargan_model_dir = './models/stargan/'
StarGAN = StarGAN_generator(stargan_model_dir)

Couldn't find model ./models/stargan/pattern.pth


In [3]:
# Pix2Pix generator; the cells below feed it dress images and then search the
# fashion-model collection with its output (NOTE(review): direction inferred from usage).
pix2pix_checkpoint = './models/pix2pix_models.pth'
Pix2Pix = Pix2PixGenerator(pix2pix_checkpoint)

# Load Search Models

In [4]:
# Feature extractors, keyed by the sub-folder name under which each one's
# precomputed features are stored (joined onto the feature roots below).
folder_gens = dict([
    ('akiwi_50', AkiwiFeatureGenerator(50)),
    ('akiwi_64', AkiwiFeatureGenerator(64)),
    ('akiwi_114', AkiwiFeatureGenerator(114)),
    ('resnet', ResnetFeatureGenerator()),
    ('resnet_retrained', ResnetFeatureGenerator('./models/resnet152_retrained.pth')),
])

In [5]:
# Image collections to search: product (dress) photos and fashion-model photos.
dress_imgs, model_imgs = (
    '../../data/fashion/dresses/',
    '../../data/fashion_models/dresses_clustered2/',
)

# Roots of the precomputed features; one sub-folder per key of folder_gens.
dress_feats, model_feats = (
    './data/features/fashion/dresses/',
    './data/features/fashion_models/dresses/',
)

In [6]:
# One Search index per feature extractor over the dress photos, reading the
# features from <dress_feats>/<extractor name>.
dress_search = {
    dir_name: Search(dress_imgs, os.path.join(dress_feats, dir_name), gen)
    for dir_name, gen in folder_gens.items()
}

Loading features from: ./data/features/fashion/dresses/akiwi_50
Loading features from: ./data/features/fashion/dresses/akiwi_64
Loading features from: ./data/features/fashion/dresses/akiwi_114
Loading features from: ./data/features/fashion/dresses/resnet
Loading features from: ./data/features/fashion/dresses/resnet_retrained


In [7]:
# One Search index per feature extractor over the fashion-model photos, reading
# the features from <model_feats>/<extractor name>.
model_search = {
    dir_name: Search(model_imgs, os.path.join(model_feats, dir_name), gen)
    for dir_name, gen in folder_gens.items()
}

Loading features from: ./data/features/fashion_models/dresses/akiwi_50
Loading features from: ./data/features/fashion_models/dresses/akiwi_64
Loading features from: ./data/features/fashion_models/dresses/akiwi_114
Loading features from: ./data/features/fashion_models/dresses/resnet
Loading features from: ./data/features/fashion_models/dresses/resnet_retrained


In [8]:
# combined search
# Build the same akiwi_50 + resnet combination for both collections (dress first,
# then model). NOTE(review): factors=[2, 1] presumably weights akiwi_50 twice as
# heavily as resnet — confirm against CombinedSearch.
for search_by_name in (dress_search, model_search):
    search_by_name['resnet_50'] = CombinedSearch(
        [search_by_name['akiwi_50'], search_by_name['resnet']],
        factors=[2, 1],
    )

# StarGAN and Pix2Pix: generate images and search for similar ones

In [None]:
def plot_img_row(images, img_labels=None):
    """Show images side by side in one row, without axis ticks.

    Args:
        images: sequence of images accepted by ``Axes.imshow`` (PIL images,
            arrays) and/or file paths, which are opened with PIL first.
        img_labels: optional sequence of titles, indexed in step with ``images``.
    """
    images = list(images)
    if not images:
        # plt.subplots rejects ncols=0; there is nothing to draw.
        return

    # squeeze=False always returns a 2-D array of Axes, so a single image does
    # not collapse to a bare Axes object (which cannot be indexed).
    _, axarr = plt.subplots(nrows=1, ncols=len(images),
                            figsize=(len(images) * 2, 2), squeeze=False)

    for i, (ax, img) in enumerate(zip(axarr[0], images)):
        if isinstance(img, (str, os.PathLike)):
            # Search results are file paths; open them so callers can mix
            # generated images and retrieved paths in one row.
            img = Image.open(img)
        ax.imshow(img)
        ax.set_xticks([])
        ax.set_yticks([])

        if img_labels is not None:
            ax.set_title(img_labels[i])

    plt.show()

In [None]:
# Sample dress photo; test_img holds the path here, and the image is displayed.
test_img = './data/test_images/dresses_sample/11834P1881-34.jpg'
test_img_preview = Image.open(test_img)
test_img_preview

In [None]:
# Run the sample dress photo through Pix2Pix and display the generated image.
source_dress = Image.open(test_img)
fake_img = Pix2Pix.generate_image(source_dress)
fake_img

In [None]:
# Retrieve the 6 fashion-model photos nearest to the Pix2Pix output using ResNet
# features (model_rn_search was never defined; the resnet index lives in model_search).
sim_imgs, _ = model_search['resnet'].get_similar_images(fake_img, num_imgs=6)
# Results are file paths; show them after the original dress photo.
plot_img_row([Image.open(i) for i in [test_img, *sim_imgs]])

In [None]:
# Same query with akiwi_50 features (model_50_search was never defined; the
# akiwi_50 index lives in model_search).
sim_imgs, _ = model_search['akiwi_50'].get_similar_images(fake_img, num_imgs=6)
# Results are file paths; show them after the original dress photo.
plot_img_row([Image.open(i) for i in [test_img, *sim_imgs]])

In [None]:
# Second sample dress. The path is rooted at ./data/test_images like the first
# sample (the old './test_images/...' prefix did not match). From here on
# test_img holds an opened PIL image rather than a path.
test_img_path = './data/test_images/dresses_sample/5641460_552693338.jpg'
test_img = Image.open(test_img_path)

In [None]:
fake_sleeves = get_stargan_imgs_for_attr(test_img, 'sleeve_length')
plot_img_row([test_img] + fake_sleeves, img_labels=['Original'] + StarGAN.LABELS['sleeve_length'])

In [None]:
fake_pattern = get_stargan_imgs_for_attr(test_img, 'pattern')
plot_img_row([test_img] + fake_pattern, img_labels=['Original'] + StarGAN.LABELS['pattern'])

In [None]:
def get_stargan_imgs_for_attr(img, attr):
    
    fake_imgs = []
    values = StarGAN.LABELS[attr]
    for idx, v in enumerate(values):
        
        fake_img = StarGAN.generate_image(img, attr, v)
        fake_imgs.append(fake_img)

    return fake_imgs

In [None]:
# For each StarGAN sleeve-length variant, show it followed by the 6 nearest dresses
# by ResNet features. get_similar_images returns (paths, extra), so unpack it and
# open the paths before plotting.
for fake_sleeve in fake_sleeves:
    sim_paths, _ = dress_search['resnet'].get_similar_images(fake_sleeve, num_imgs=6)
    plot_img_row([fake_sleeve] + [Image.open(p) for p in sim_paths])

In [None]:
# For each StarGAN pattern variant, show it followed by the 6 nearest dresses
# by ResNet features (results are (paths, extra); paths are opened for display).
for fake_pattern_img in fake_pattern:
    sim_paths, _ = dress_search['resnet'].get_similar_images(fake_pattern_img, num_imgs=6)
    plot_img_row([fake_pattern_img] + [Image.open(p) for p in sim_paths])

In [None]:
# Same pattern query with akiwi_114 features (dress_114_search was never defined;
# the index lives in dress_search). Results are (paths, extra); paths are opened.
for fake_pattern_img in fake_pattern:
    sim_paths, _ = dress_search['akiwi_114'].get_similar_images(fake_pattern_img, num_imgs=6)
    plot_img_row([fake_pattern_img] + [Image.open(p) for p in sim_paths])

In [None]:
# pix2pix
model_imgs = [Pix2Pix.generate_image(img) for img in [test_img] + fake_pattern]
plot_img_row(model_imgs, img_labels=['original'] + StarGAN.LABELS['pattern'])

In [None]:
# pix2pix
model_imgs = [Pix2Pix.generate_image(img) for img in [test_img] + fake_sleeves]
plot_img_row(model_imgs, img_labels=['original'] + StarGAN.LABELS['sleeve_length'])

In [None]:
for fake_img in model_imgs:
    sim_imgs = model_rn_search.get_similar_images(fake_img, num_imgs=6)
    sim_imgs = [fake_img] + sim_imgs
    plot_img_row(sim_imgs)

In [None]:
# Pix2Pix on the current sample dress (test_imgs was never defined; test_img
# holds the opened image at this point).
Pix2Pix.generate_image(test_img)