In [None]:
import keras
import numpy as np
from keras.applications import mobilenet, mobilenet_v2
from keras.models import Model
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.applications.mobilenet import preprocess_input
from keras.preprocessing import image
import os
import json
from PIL import Image
import math
import matplotlib.pyplot as plt
%matplotlib inline

## 1. Load Pre-trained MobileNet

In [None]:
# Load MobileNet (v1) with ImageNet weights; it is truncated in section 1.1 to
# act as an image-embedding model.
mobilenet_model = mobilenet.MobileNet(weights='imagenet')

In [None]:
# Inspect the full layer stack to decide which trailing layers to drop in 1.1.
mobilenet_model.summary()

### 1.1 Remove the last 5 layers to use MobileNet as a feature extractor

In [None]:
# Reuse the pretrained graph up to the 6th-from-last layer (i.e. drop the last
# 5 layers), so the model emits that layer's activations instead of the
# original classifier output. encode_image flattens this into the embedding.
new_model = Model(
    inputs=mobilenet_model.inputs,
    outputs=mobilenet_model.layers[-6].output,
)
new_model.summary()

## 2. Create Helper Functions

In [None]:
def prepare_image(file):
    """Load an image file as a preprocessed batch of one for MobileNet.

    The image is resized to 224x224, converted to an array, given a leading
    batch axis, and normalised with MobileNet's ``preprocess_input``.

    Args:
        file: path to the image on disk.

    Returns:
        The preprocessed array with a batch dimension of size 1.
    """
    pil_image = image.load_img(file, target_size=(224, 224))
    batch = image.img_to_array(pil_image)[np.newaxis, ...]
    return keras.applications.mobilenet.preprocess_input(batch)

In [None]:
def encode_image(file):
    """Return the embedding of one image file as a flat 1-D NumPy array.

    The image is preprocessed with ``prepare_image`` and run through
    ``new_model`` (the truncated MobileNet). The prediction for the single-image
    batch is flattened with ``reshape(-1)`` instead of a hard-coded
    ``reshape(1024)``. The result is the same for the current model, but the
    function keeps working if the backbone or the cut point changes the feature
    width (e.g. swapping in the imported ``mobilenet_v2``).

    Args:
        file: path to the image on disk.

    Returns:
        1-D array of the model's output features for this image.
    """
    preprocessed_image = prepare_image(file)
    predictions = new_model.predict(preprocessed_image)
    return predictions.reshape(-1)

## 3. Visualise All the Shirts

In [None]:
shirt_folder = os.path.join(os.curdir, 'shirts')
# Keep only shirt photos whose extension is .jpg (case-insensitive). Matching
# the extension, rather than testing "'jpg' in fname", avoids picking up names
# such as "shirt.jpg.bak". Sorting makes the order of the plot grid and the
# embedding dict reproducible, since os.listdir returns entries in arbitrary
# order.
shirt_imgs = sorted(
    os.path.join(shirt_folder, fname)
    for fname in os.listdir(shirt_folder)
    if 'shirt' in fname and fname.lower().endswith('.jpg')
)

In [None]:
# Show every shirt in a grid of `columns` images per row, titled by file name.
fig = plt.figure(figsize=(20, 10))
columns = 5
rows = math.ceil(len(shirt_imgs) / columns)
for i, img_path in enumerate(shirt_imgs, start=1):
    sample_img = Image.open(img_path)
    ax = fig.add_subplot(rows, columns, i)
    # The paths were built with os.path.join, so they use the OS separator;
    # os.path.basename handles that, whereas split('/') breaks on Windows.
    ax.set_title(os.path.basename(img_path))
    ax.imshow(sample_img)
plt.show()

## 4. Create Image Embeddings for All Shirts

In [None]:
# Embed every shirt image, keyed by its file path (in shirt_imgs order).
shirt_encoding_map = dict(zip(shirt_imgs, map(encode_image, shirt_imgs)))

In [None]:
# Summarise the embeddings (one entry per shirt and its vector shape) instead
# of dumping every float into the notebook output. np.shape works whether the
# values are NumPy arrays or lists reloaded from JSON.
{fname: np.shape(emb) for fname, emb in shirt_encoding_map.items()}

## 5. Save Image Embeddings

In [None]:
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also serialises NumPy values.

    Arrays become (nested) lists, and NumPy scalars such as ``np.float32`` or
    ``np.int64`` become the equivalent native Python numbers. Any other
    unsupported object is delegated to the base encoder, which raises
    ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # NumPy scalar types (e.g. np.float32, np.int64) are not JSON
        # serialisable by the stdlib encoder; .item() yields a Python value.
        if isinstance(obj, np.generic):
            return obj.item()
        return json.JSONEncoder.default(self, obj)

In [None]:
json_dump = json.dumps(shirt_encoding_map, cls=NumpyEncoder)
# A context manager guarantees the file is flushed and closed before it is
# read back. With a bare open() that is never closed, the buffered tail of a
# large JSON string can stay unwritten, leaving a truncated file on disk.
with open("shirt_encoding.txt", "w") as f:
    f.write(json_dump)

## 6. Create and Test Similarity Function

In [None]:
def load_encoding(path):
    """Read a JSON file and return the decoded object.

    For the embeddings file written in section 5, this is a dict mapping image
    path to embedding. The embeddings were serialised as lists, so they come
    back as plain Python lists rather than NumPy arrays.

    Args:
        path: location of the JSON file.
    """
    with open(path, 'r') as fh:
        return json.load(fh)

In [None]:
# Reload the saved embeddings. JSON has no array type (NumpyEncoder wrote each
# array via .tolist()), so the values are plain Python lists from here on.
shirt_encoding_map = load_encoding('shirt_encoding.txt')

In [None]:
# Cosine Similarity function
def getCS(encoding1, encoding2):
    """Return the cosine similarity between two embedding vectors.

    Args:
        encoding1, encoding2: 1-D array-likes of equal length (NumPy arrays
            or plain lists, e.g. embeddings reloaded from JSON).

    Returns:
        Python float, approximately in [-1, 1]. Returns 0.0 when either vector
        has zero norm. Cosine similarity is undefined there, and dividing by
        zero would produce NaN, which np.argmax treats as the maximum.
    """
    vec1 = np.asarray(encoding1, dtype=float)
    vec2 = np.asarray(encoding2, dtype=float)
    denom = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if denom == 0:
        return 0.0
    return float(np.dot(vec1, vec2) / denom)

In [None]:
def find_similar(test_image_file, base_encoding_dict, threshold=0.5):
    """Find the stored image whose embedding is most similar to a test image.

    Args:
        test_image_file: path of the query image; it is embedded with
            encode_image.
        base_encoding_dict: mapping of image path -> embedding (array or list),
            e.g. shirt_encoding_map.
        threshold: minimum cosine similarity (getCS) for a match to count.

    Returns:
        The key of the best-matching entry, or None if base_encoding_dict is
        empty or the best similarity is below threshold.

    Side effect: prints a dict of every candidate's similarity score.
    """
    test_img_encoding = encode_image(test_image_file)
    keys = list(base_encoding_dict.keys())
    vals = list(base_encoding_dict.values())
    sim_arr = [getCS(test_img_encoding, val) for val in vals]
    print(dict(zip(keys, sim_arr)))
    # np.argmax raises ValueError on an empty sequence; treat "no candidates"
    # the same as "no candidate clears the threshold".
    if not sim_arr:
        return None
    idx = int(np.argmax(sim_arr))
    if sim_arr[idx] < threshold:
        return None
    return keys[idx]

In [None]:
# Image to be used for testing our function (the query for find_similar).
test_img_path = 'test/test4.jpg'
test_img = Image.open(test_img_path)
# Title the figure and end with plt.show() so the cell renders just the image
# rather than also echoing the AxesImage repr.
plt.title('Test Shirt')
plt.imshow(test_img)
plt.show()

In [None]:
# Look up the closest stored shirt; None means nothing cleared the threshold.
sim_img_fname = find_similar(
    test_image_file=test_img_path,
    base_encoding_dict=shirt_encoding_map,
)
print('\nSimilar Image:', sim_img_fname)

### 6.1 Visualise Similar Images

In [None]:
def get_concat_h(im1, im2):
    """Return a new RGB image with im1 and im2 placed side by side.

    The canvas is as wide as both images combined and as tall as the taller
    one. Both images are top-aligned, im1 on the left and im2 on the right;
    any uncovered area keeps Image.new's default fill.
    """
    canvas_size = (im1.width + im2.width, max(im1.height, im2.height))
    canvas = Image.new('RGB', canvas_size)
    for picture, x_offset in ((im1, 0), (im2, im1.width)):
        canvas.paste(picture, (x_offset, 0))
    return canvas

In [None]:
# find_similar returns None when no stored shirt clears its threshold, and
# Image.open(None) would fail, so only draw the comparison when a match exists.
if sim_img_fname is None:
    print('No shirt is similar enough to the test image.')
else:
    sim_img = Image.open(sim_img_fname)
    combined_images = get_concat_h(test_img, sim_img)
    plt.title('Test Shirt - Similar Shirt')
    plt.imshow(combined_images)
    plt.show()