#Building an image retrieval system with deep features


#Fire up GraphLab Create

In [1]:
import graphlab
graphlab.product_key.set_product_key("7348-CE53-3B3E-DBED-152B-828E-A99E-F303")

#Load the CIFAR-10 dataset

We will use a popular benchmark dataset in computer vision called CIFAR-10.  

(We've reduced the data to just 4 categories = {'cat','bird','automobile','dog'}.)

This dataset is already split into a training set and test set. In this simple retrieval example, there is no notion of "testing", so we will only use the training data.

In [22]:
image_train = graphlab.SFrame('image_train_data/')
image_test = graphlab.SFrame('image_test_data/')

#Computing deep features for our images

The two lines below allow us to compute deep features.  This computation takes a little while, so we have already computed them and saved the results as a column in the data you loaded. 

(Note that if you would like to compute such deep features and have a GPU on your machine, you should use the GPU enabled GraphLab Create, which will be significantly faster for this task.)

In [None]:
#deep_learning_model = graphlab.load_model('http://s3.amazonaws.com/GraphLab-Datasets/deeplearning/imagenet_model_iter45')
#image_train['deep_features'] = deep_learning_model.extract_features(image_train)

In [3]:
image_train.head()

id,image,label,deep_features,image_array
24,Height: 32 Width: 32,bird,"[0.242871761322, 1.09545373917, 0.0, ...","[73.0, 77.0, 58.0, 71.0, 68.0, 50.0, 77.0, 69.0, ..."
33,Height: 32 Width: 32,cat,"[0.525087952614, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[7.0, 5.0, 8.0, 7.0, 5.0, 8.0, 5.0, 4.0, 6.0, 7.0, ..."
36,Height: 32 Width: 32,cat,"[0.566015958786, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[169.0, 122.0, 65.0, 131.0, 108.0, 75.0, ..."
70,Height: 32 Width: 32,dog,"[1.12979578972, 0.0, 0.0, 0.778194487095, 0.0, ...","[154.0, 179.0, 152.0, 159.0, 183.0, 157.0, ..."
90,Height: 32 Width: 32,bird,"[1.71786928177, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[216.0, 195.0, 180.0, 201.0, 178.0, 160.0, ..."
97,Height: 32 Width: 32,automobile,"[1.57818555832, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[33.0, 44.0, 27.0, 29.0, 44.0, 31.0, 32.0, 45.0, ..."
107,Height: 32 Width: 32,dog,"[0.0, 0.0, 0.220677852631, 0.0, ...","[97.0, 51.0, 31.0, 104.0, 58.0, 38.0, 107.0, 61.0, ..."
121,Height: 32 Width: 32,bird,"[0.0, 0.23753464222, 0.0, 0.0, 0.0, 0.0, ...","[93.0, 96.0, 88.0, 102.0, 106.0, 97.0, 117.0, ..."
136,Height: 32 Width: 32,automobile,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 7.5737862587, 0.0, ...","[35.0, 59.0, 53.0, 36.0, 56.0, 56.0, 42.0, 62.0, ..."
138,Height: 32 Width: 32,bird,"[0.658935725689, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[205.0, 193.0, 195.0, 200.0, 187.0, 193.0, ..."


#Train a nearest-neighbors model for retrieving images using deep features

We will now build a simple image retrieval system that finds the nearest neighbors for any image.

In [5]:
knn_model = graphlab.nearest_neighbors.create(image_train,features=['deep_features'],
                                             label='id')

PROGRESS: Starting brute force nearest neighbors model training.


#Use image retrieval model with deep features to find similar images

Let's find similar images to this cat picture.

In [4]:
graphlab.canvas.set_target('ipynb')
cat = image_train[18:19]
cat['image'].show()

In [73]:
knn_model.query(cat)
knn_model.query(cat)['distance'].mean()

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.196464    | 15.509ms     |
PROGRESS: | Done         |         | 100         | 118.437ms    |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.196464    | 9.193ms      |
PROGRESS: | Done         |         | 100         | 115.905ms    |
PROGRESS: +--------------+---------+-------------+--------------+


37.77071136184157

We are going to create a simple function to view the nearest neighbors to save typing:

In [7]:
def get_images_from_ids(query_result):
    return image_train.filter_by(query_result['reference_label'],'id')

In [8]:
cat_neighbors = get_images_from_ids(knn_model.query(cat))

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 12.022ms     |
PROGRESS: | Done         |         | 100         | 366.666ms    |
PROGRESS: +--------------+---------+-------------+--------------+


In [9]:
cat_neighbors['image'].show()

Very cool results showing similar cats.

##Finding similar images to a car

In [10]:
car = image_train[8:9]
car['image'].show()

In [11]:
get_images_from_ids(knn_model.query(car))['image'].show()

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 12.543ms     |
PROGRESS: | Done         |         | 100         | 338.858ms    |
PROGRESS: +--------------+---------+-------------+--------------+


#Just for fun, let's create a lambda to find and show nearest neighbor images

In [12]:
show_neighbors = lambda i: get_images_from_ids(knn_model.query(image_train[i:i+1]))['image'].show()

In [13]:
show_neighbors(8)

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 12.226ms     |
PROGRESS: | Done         |         | 100         | 283.235ms    |
PROGRESS: +--------------+---------+-------------+--------------+


In [36]:
show_neighbors(26)

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 14.524ms     |
PROGRESS: | Done         |         | 100         | 404.055ms    |
PROGRESS: +--------------+---------+-------------+--------------+


In [17]:
image_train['label'].sketch_summary()


+------------------+-------+----------+
|       item       | value | is exact |
+------------------+-------+----------+
|      Length      |  2005 |   Yes    |
| # Missing Values |   0   |   Yes    |
| # unique values  |   4   |    No    |
+------------------+-------+----------+

Most frequent items:
+-------+------------+-----+-----+------+
| value | automobile | cat | dog | bird |
+-------+------------+-----+-----+------+
| count |    509     | 509 | 509 | 478  |
+-------+------------+-----+-----+------+


In [23]:
graphlab.canvas.set_target('ipynb')
cat = image_test[0:1]
cat['image'].show()

In [62]:
knn_model.query(cat)
knn_model.query(cat)['distance'].mean()

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.196464    | 13.219ms     |
PROGRESS: | Done         |         | 100         | 101.305ms    |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.196464    | 6.374ms      |
PROGRESS: | Done         |         | 100         | 68.28ms      |
PROGRESS: +--------------+---------+-------------+--------------+


37.77071136184157

In [25]:
cat_neighbors = get_images_from_ids(knn_model.query(cat))
cat_neighbors['image'].show()

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 12.458ms     |
PROGRESS: | Done         |         | 100         | 388.037ms    |
PROGRESS: +--------------+---------+-------------+--------------+


In [44]:
graphlab.canvas.set_target('ipynb')

In [75]:
knn1_model = graphlab.nearest_neighbors.create(image_train,features=['deep_features'],
                                             label='id')
knn1_model.query(cat)['distance'].mean()

PROGRESS: Starting brute force nearest neighbors model training.
PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 6.163ms      |
PROGRESS: | Done         |         | 100         | 264.805ms    |
PROGRESS: +--------------+---------+-------------+--------------+


36.15573070978294

In [71]:
dog_neighbors = get_images_from_ids(knn_model.query(cat))
dog_neighbors['image'].show()

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.196464    | 10.907ms     |
PROGRESS: | Done         |         | 100         | 102.094ms    |
PROGRESS: +--------------+---------+-------------+--------------+


In [74]:
train_dog = image_train[image_train['label'=='dog']]
train_cat = image_train[image_train['label'=='cat']]
train_automobile = image_train[image_train['label'=='automobile']]
train_bird = image_train[image_train['label'=='bird']]
knn_dog_model = graphlab.nearest_neighbors.create(train_dog,features=['deep_features'],                                             label='id')

knn_cat_model = graphlab.nearest_neighbors.create(train_cat,features=['deep_features'],                                             label='id')

knn_automobile_model = graphlab.nearest_neighbors.create(train_automobile,features=['deep_features'],                                             label='id')

cat = image_test[0:1]
cat['image'].show()

knn_dog_model.query(cat)
knn_cat_model.query(cat)
knn_automobile_model.query(cat)


PROGRESS: Starting brute force nearest neighbors model training.
PROGRESS: Starting brute force nearest neighbors model training.
PROGRESS: Starting brute force nearest neighbors model training.


PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 12.868ms     |
PROGRESS: | Done         |         | 100         | 426.081ms    |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 5.738ms      |
PROGRESS: | Done         |         | 100         | 254.606ms    |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | #

query_label,reference_label,distance,rank
0,16289,34.623719208,1
0,45646,36.0068799284,2
0,32139,36.5200813436,3
0,25713,36.7548502521,4
0,331,36.8731228168,5


PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 16.552ms     |
PROGRESS: | Done         |         | 100         | 461.469ms    |
PROGRESS: +--------------+---------+-------------+--------------+
