# Image similarity search

In [1]:
import numpy as np
import pandas as pd
%matplotlib inline
import matplotlib.pyplot as plt

In [2]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.xception import Xception
from tensorflow.keras.applications.xception import preprocess_input

**Important:** Make sure you unzip the images contained in `data/sports_train.zip` into `data/train`

In [3]:
# Settings shared by the cells below.
train_path = './data/train/'
img_size = 299    # images are resized to img_size x img_size when they are loaded
batch_size = 90   # matches the 90 images found in train_path, so one batch holds them all

### Image Data Generator

In [4]:
# Load images from disk, applying Xception's preprocess_input to each one.
datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

# shuffle=False keeps the batches in a fixed, repeatable order, so a given
# position in the batch always refers to the same image.
bottleneck_generator = datagen.flow_from_directory(
    directory=train_path,
    shuffle=False,
    batch_size=batch_size,
    target_size=(img_size, img_size),
)

Found 90 images belonging to 3 classes.


### Get one batch of images

In [5]:
# Pull the first batch with the builtin next(): it goes through the iterator
# protocol, which works on both old and new Keras versions, whereas the
# explicit .next() method is not available on newer Keras iterators.
images, labels = next(bottleneck_generator)

In [6]:
# Sanity check: expect (n_images, img_size, img_size, 3).
images.shape

(90, 299, 299, 3)

### Load Xception base model

In [7]:
# Xception without its classification head, initialised with ImageNet weights.
# pooling='avg' asks Keras to average the last feature map, so the model is
# expected to return one flat feature vector per image.
base_model = Xception(
    weights='imagenet',
    include_top=False,
    pooling='avg',
    input_shape=(img_size, img_size, 3),
)

### Transform Images to bottleneck features

In [None]:
# Run all images through the headless network; the outputs ("bottlenecks")
# are the feature vectors compared in the rest of the notebook.
bottlenecks = base_model.predict(images, verbose=1)

In [None]:
# Sanity check: expect one row of features per input image.
bottlenecks.shape

### Distance Metric with Scikit Learn

In [None]:
def imshow_scaled(img):
    """Display an image after shifting and halving its pixel values.

    The transformation (img + 1) / 2 maps values from [-1, 1] onto [0, 1],
    which is the range matplotlib expects for floating-point images.
    Presumably [-1, 1] is the range produced by the preprocess_input step
    applied when the images were loaded -- confirm if the preprocessing changes.

    Parameters
    ----------
    img : array-like
        Image data to display; passed to ``plt.imshow`` after rescaling.
    """
    rescaled = (img + 1) / 2
    plt.imshow(rescaled)

In [None]:
# Show three sample images side by side: indices 0, 1 and 80.
# Subplot positions are 1-based, hence start=1.
for position, img_idx in enumerate((0, 1, 80), start=1):
    plt.subplot(1, 3, position)
    imshow_scaled(images[img_idx])

plt.tight_layout()

In [None]:
# Overlay the feature vectors of images 0, 1 and 80.
# Each curve is labelled so the three can be told apart in the figure.
for img_idx in (0, 1, 80):
    plt.plot(bottlenecks[img_idx], label=f'image {img_idx}')
plt.xlabel('bottleneck feature index')
plt.ylabel('activation')
# The trailing semicolon keeps the Legend repr out of the cell output.
plt.legend();

### Pairwise distances

In [None]:
# DistanceMetric lives in sklearn.metrics in current scikit-learn releases;
# the old sklearn.neighbors location was removed in version 1.3.
# Fall back to the old location for older installations.
try:
    from sklearn.metrics import DistanceMetric
except ImportError:
    from sklearn.neighbors import DistanceMetric

In [None]:
# Metric object used below to compare bottleneck vectors by Euclidean distance.
dist = DistanceMetric.get_metric('euclidean')

In [None]:
# All-pairs distances: bn_dist[i, j] is the distance between the bottleneck
# features of image i and image j.
bn_dist = dist.pairwise(bottlenecks)

In [None]:
# Sanity check: expect a square (n_images, n_images) matrix.
bn_dist.shape

In [None]:
# Distance matrix as a grayscale image: with the 'gray' colormap small values
# are dark, so similar image pairs show up dark and dissimilar pairs bright.
plt.imshow(bn_dist, cmap='gray')

In [None]:
# Position (in images and in bn_dist) of the query image to search with.
test_img_id = 0

In [None]:
# Show the query image.
imshow_scaled(images[test_img_id])

In [None]:
# Distances from the query image to every image. Wrapping the row in a Series
# keeps each distance paired with its image position (the index) when sorting.
dist_from_sel = pd.Series(bn_dist[test_img_id])

In [None]:
# The nine smallest distances. The query image itself is among them, because
# its distance to itself is 0.
dist_from_sel.sort_values().head(9)

In [None]:
# Layout of the results grid; the number of images to retrieve follows from it.
n_rows, n_cols = 3, 3
n_images = n_cols * n_rows

In [None]:
# The n_images entries closest to the query, nearest first. The index holds
# the image positions and the values hold the distances.
retrieved = dist_from_sel.sort_values().head(n_images)

In [None]:
plt.figure(figsize=(10, 10))
# Fill the grid in the order given by retrieved.index.
# Subplot positions are 1-based, hence start=1.
for position, idx in enumerate(retrieved.index, start=1):
    plt.subplot(n_rows, n_cols, position)
    imshow_scaled(images[idx])
plt.tight_layout()

## Exercise

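Define a function `image_search` that encapsulates the code above: given an image index, it retrieves and plots the images closest to that image. With the default 3×3 grid, this shows the query image itself (distance 0) together with its 8 nearest neighbours: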

```python
def image_search(img_index, n_rows=3, n_columns=3):
    ....your code here...
```