# Building an image retrieval system with deep features


# Fire up GraphLab Create

In [1]:
import graphlab

# Load the CIFAR-10 dataset

We will use a popular benchmark dataset in computer vision called CIFAR-10.  

(We've reduced the data to just 4 categories: 'cat', 'bird', 'automobile', and 'dog'.)

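This dataset is already split into a training set and a test set. The retrieval example has no notion of "testing", so it uses only the training data. (The test set is loaded near the end of the notebook, where each test image is compared against per-category nearest-neighbor models.)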

In [2]:
# Load the reduced CIFAR-10 training subset; it already includes a precomputed 'deep_features' column.
image_train = graphlab.SFrame('image_train_data/')

[INFO] graphlab.cython.cy_server: GraphLab Create v2.1 started. Logging: /tmp/graphlab_server_1482905463.log
INFO:graphlab.cython.cy_server:GraphLab Create v2.1 started. Logging: /tmp/graphlab_server_1482905463.log


This non-commercial license of GraphLab Create for academic use is assigned to sangeet.saurabh@gmail.com and will expire on December 02, 2017.


# Computing deep features for our images

The two (commented-out) lines below compute deep features. This computation takes a while, so we have already computed the features and saved them as the `deep_features` column in the data you loaded.

(Note that if you would like to compute such deep features and have a GPU on your machine, you should use the GPU enabled GraphLab Create, which will be significantly faster for this task.)

In [3]:
# Kept for reference: these two lines would load the pretrained model at the S3 URL
# (presumably ImageNet-trained, judging by its path) and extract deep features for every
# training image. They stay commented out because image_train already ships with a
# precomputed 'deep_features' column.
#deep_learning_model = graphlab.load_model('http://s3.amazonaws.com/GraphLab-Datasets/deeplearning/imagenet_model_iter45')
#image_train['deep_features'] = deep_learning_model.extract_features(image_train)

In [4]:
# Preview the first rows: id, image, label, the precomputed deep_features, and raw pixel values.
image_train.head()

id,image,label,deep_features,image_array
24,Height: 32 Width: 32,bird,"[0.242871761322, 1.09545373917, 0.0, ...","[73.0, 77.0, 58.0, 71.0, 68.0, 50.0, 77.0, 69.0, ..."
33,Height: 32 Width: 32,cat,"[0.525087952614, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[7.0, 5.0, 8.0, 7.0, 5.0, 8.0, 5.0, 4.0, 6.0, 7.0, ..."
36,Height: 32 Width: 32,cat,"[0.566015958786, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[169.0, 122.0, 65.0, 131.0, 108.0, 75.0, ..."
70,Height: 32 Width: 32,dog,"[1.12979578972, 0.0, 0.0, 0.778194487095, 0.0, ...","[154.0, 179.0, 152.0, 159.0, 183.0, 157.0, ..."
90,Height: 32 Width: 32,bird,"[1.71786928177, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[216.0, 195.0, 180.0, 201.0, 178.0, 160.0, ..."
97,Height: 32 Width: 32,automobile,"[1.57818555832, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[33.0, 44.0, 27.0, 29.0, 44.0, 31.0, 32.0, 45.0, ..."
107,Height: 32 Width: 32,dog,"[0.0, 0.0, 0.220677852631, 0.0, ...","[97.0, 51.0, 31.0, 104.0, 58.0, 38.0, 107.0, 61.0, ..."
121,Height: 32 Width: 32,bird,"[0.0, 0.23753464222, 0.0, 0.0, 0.0, 0.0, ...","[93.0, 96.0, 88.0, 102.0, 106.0, 97.0, 117.0, ..."
136,Height: 32 Width: 32,automobile,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 7.5737862587, 0.0, ...","[35.0, 59.0, 53.0, 36.0, 56.0, 56.0, 42.0, 62.0, ..."
138,Height: 32 Width: 32,bird,"[0.658935725689, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[205.0, 193.0, 195.0, 200.0, 187.0, 193.0, ..."


# Train a nearest-neighbors model for retrieving images using deep features

We will now build a simple image retrieval system that finds the nearest neighbors for any image.

In [5]:
# Index every training image by its deep-feature vector. Query results report the
# matching row's 'id' in their reference_label column.
knn_model = graphlab.nearest_neighbors.create(
    image_train,
    features=['deep_features'],
    label='id',
)

# Use the image retrieval model with deep features to find similar images

Let's find images similar to this cat picture.

In [6]:
# Render canvas output (e.g. .show()) inline in this notebook.
graphlab.canvas.set_target('ipynb')
# Slice rather than index, so `cat` stays a one-row SFrame that can be passed to query().
cat = image_train[18:19]
cat['image'].show()

In [7]:
# Nearest training images to the cat. Rank 1 has distance 0.0, presumably the query image
# itself, since `cat` was taken from the same training set the model indexes.
knn_model.query(cat)

query_label,reference_label,distance,rank
0,384,0.0,1
0,6910,36.9403137951,2
0,39777,38.4634888975,3
0,36870,39.7559623119,4
0,41734,39.7866014148,5


To save typing, we'll create a simple function that looks up the nearest-neighbor images:

In [8]:
def get_images_from_ids(query_result, reference_frame=None):
    """Return the rows of a reference SFrame whose 'id' appears in a kNN query result.

    Parameters
    ----------
    query_result : SFrame
        Output of a nearest-neighbors ``query``. Its 'reference_label' column holds
        the 'id' values of the matching reference rows.
    reference_frame : SFrame, optional
        Frame in which to look up the ids. Defaults to the global ``image_train``
        (resolved at call time). The per-category training frames are filtered from
        it, so it covers ids from every model in this notebook.

    Returns
    -------
    SFrame
        Matching rows of ``reference_frame``.
        NOTE(review): ``filter_by`` is not documented to preserve the neighbors'
        rank order. Confirm before relying on the display order.
    """
    if reference_frame is None:
        reference_frame = image_train
    return reference_frame.filter_by(query_result['reference_label'], 'id')

In [9]:
# Map the cat's neighbor ids back to full training rows (including the images).
cat_neighbors = get_images_from_ids(knn_model.query(cat))

In [10]:
# Display the retrieved neighbor images.
cat_neighbors['image'].show()

Very cool results showing similar cats.

## Finding similar images to a car

In [11]:
# Training row 8 (id 136, labeled automobile in the preview above) as a one-row query SFrame.
car = image_train[8:9]
car['image'].show()

In [12]:
# Same lookup as for the cat, chained into one expression (show_neighbors(8) below repeats it).
get_images_from_ids(knn_model.query(car))['image'].show()

# Just for fun, let's create a helper to find and show nearest-neighbor images

In [13]:
def show_neighbors(i, model=None, frame=None):
    """Show the images of the nearest neighbors of row ``i`` of ``frame``.

    Parameters
    ----------
    i : int
        Row position of the query image within ``frame``.
    model : optional
        Fitted nearest-neighbors model to query. Defaults to ``knn_model``.
    frame : SFrame, optional
        Frame from which the query row is taken. Defaults to ``image_train``.

    Neighbor ids are mapped back to rows with ``get_images_from_ids``, which looks them
    up in ``image_train``. ``model`` should therefore be built on (a subset of) the
    training set.
    """
    if model is None:
        model = knn_model
    if frame is None:
        frame = image_train
    # Slice (not index) so the query is a one-row SFrame rather than a single row.
    query_row = frame[i:i + 1]
    return get_images_from_ids(model.query(query_row))['image'].show()

In [14]:
# Training row 8: the car from above, now via the helper.
show_neighbors(8)

In [15]:
# Neighbors of training row 26 (row position in image_train).
show_neighbors(26)

In [16]:
# Neighbors of training row 623.
show_neighbors(623)

In [17]:
# Neighbors of training row 708.
show_neighbors(708)

In [18]:
# Neighbors of training row 814.
show_neighbors(814)

In [19]:
# Neighbors of training row 409.
show_neighbors(409)

In [20]:
# Per-category counts of the training labels (approximate sketch; see the "is exact" column).
image_train['label'].sketch_summary()


+------------------+-------+----------+
|       item       | value | is exact |
+------------------+-------+----------+
|      Length      |  2005 |   Yes    |
| # Missing Values |   0   |   Yes    |
| # unique values  |   4   |    No    |
+------------------+-------+----------+

Most frequent items:
+-------+------------+-----+-----+------+
| value | automobile | cat | dog | bird |
+-------+------------+-----+-----+------+
| count |    509     | 509 | 509 | 478  |
+-------+------------+-----+-----+------+


In [21]:
# Training rows labeled 'dog'; used below to fit a dog-only nearest-neighbors model.
dog_frame = image_train[image_train['label'] == 'dog']
dog_frame

id,image,label,deep_features,image_array
70,Height: 32 Width: 32,dog,"[1.12979578972, 0.0, 0.0, 0.778194487095, 0.0, ...","[154.0, 179.0, 152.0, 159.0, 183.0, 157.0, ..."
107,Height: 32 Width: 32,dog,"[0.0, 0.0, 0.220677852631, 0.0, ...","[97.0, 51.0, 31.0, 104.0, 58.0, 38.0, 107.0, 61.0, ..."
177,Height: 32 Width: 32,dog,"[0.0, 1.45965671539, 0.0, 0.422992348671, 0.0, ...","[55.0, 75.0, 42.0, 51.0, 76.0, 37.0, 57.0, 83.0, ..."
424,Height: 32 Width: 32,dog,"[0.942399680614, 0.0, 0.220352768898, 0.0, ...","[60.0, 35.0, 18.0, 63.0, 49.0, 38.0, 66.0, 56.0, ..."
462,Height: 32 Width: 32,dog,"[1.43462562561, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[86.0, 69.0, 75.0, 57.0, 41.0, 48.0, 46.0, 35.0, ..."
542,Height: 32 Width: 32,dog,"[0.451547086239, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[196.0, 174.0, 113.0, 140.0, 117.0, 65.0, 8 ..."
573,Height: 32 Width: 32,dog,"[0.592360973358, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[153.0, 103.0, 52.0, 151.0, 102.0, 49.0, ..."
851,Height: 32 Width: 32,dog,"[0.690123438835, 0.0, 0.0, 0.0, 0.305860161 ...","[39.0, 6.0, 4.0, 53.0, 23.0, 24.0, 57.0, 37.0, ..."
919,Height: 32 Width: 32,dog,"[0.0, 0.177558660507, 0.139396846294, 0.0, ...","[29.0, 43.0, 4.0, 24.0, 35.0, 6.0, 24.0, 37.0, ..."
1172,Height: 32 Width: 32,dog,"[0.517601490021, 0.0, 1.96418333054, 0.0, 0.0, ...","[182.0, 180.0, 197.0, 196.0, 192.0, 209.0, ..."


In [22]:
# Remaining per-category training subsets (dog_frame was built above). Each one backs
# its own nearest-neighbors model below.
cat_frame, bird_frame, automobile_frame = (
    image_train[image_train['label'] == category]
    for category in ('cat', 'bird', 'automobile')
)


In [23]:
# Number of bird training images (should agree with the sketch summary above).
len(bird_frame)

478

In [24]:
# Nearest-neighbors index over the dog training images only.
dog_model = graphlab.nearest_neighbors.create(
    dog_frame,
    features=['deep_features'],
    label='id',
)

In [25]:
# Nearest-neighbors index over the cat training images only.
cat_model = graphlab.nearest_neighbors.create(
    cat_frame,
    features=['deep_features'],
    label='id',
)

In [26]:
# Nearest-neighbors index over the automobile training images only.
automobile_model = graphlab.nearest_neighbors.create(
    automobile_frame,
    features=['deep_features'],
    label='id',
)

In [27]:
# Nearest-neighbors index over the bird training images only.
bird_model = graphlab.nearest_neighbors.create(
    bird_frame,
    features=['deep_features'],
    label='id',
)

In [28]:
# Load the held-out test split; its images are used as queries against the per-category models.
image_test = graphlab.SFrame('image_test_data/')

In [29]:
# First test image, used as the query in the next few cells.
image_test[0:1]['image'].show()

In [30]:
# Nearest cat training images (default k) to the first test image.
nearest_cats = cat_model.query(image_test[0:1])

In [59]:
# Nearest dog training images (default k) to the first test image.
nearest_dogs = dog_model.query(image_test[0:1])

In [56]:
def show_dog_neighbors(i):
    """Show the dog training images nearest to row ``i`` of ``image_test``.

    Queries ``dog_model`` and maps the returned reference ids back to rows of
    ``image_train`` (dog_frame is filtered from it) via ``get_images_from_ids``.
    """
    # Slice (not index) so the query is a one-row SFrame.
    test_row = image_test[i:i + 1]
    return get_images_from_ids(dog_model.query(test_row))['image'].show()

In [58]:
# Dog-model neighbors of the first test image.
show_dog_neighbors(0)

In [54]:
def show_cat_neighbors(i):
    """Show the cat training images nearest to row ``i`` of ``image_test``.

    Queries ``cat_model`` and maps the returned reference ids back to rows of
    ``image_train`` (cat_frame is filtered from it) via ``get_images_from_ids``.
    """
    # Slice (not index) so the query is a one-row SFrame.
    test_row = image_test[i:i + 1]
    return get_images_from_ids(cat_model.query(test_row))['image'].show()

In [55]:
# Cat-model neighbors of the first test image.
show_cat_neighbors(0)

In [62]:
# Mean distance from the first test image to its nearest cat training images.
nearest_cats['distance'].mean()

36.15573070978294

In [63]:
# Same for the dog neighbors; compare with the cat mean above (smaller = closer on average).
nearest_dogs['distance'].mean()

37.77071136184157

In [36]:
# Split the test set by label so each category can be queried separately.
image_test_cat, image_test_dog, image_test_bird, image_test_automobile = (
    image_test[image_test['label'] == label]
    for label in ('cat', 'dog', 'bird', 'automobile')
)

In [37]:
# Preview the bird test images.
image_test_bird

id,image,label,deep_features,image_array
25,Height: 32 Width: 32,bird,"[0.0, 0.317288756371, 0.0, 1.36552882195, ...","[100.0, 103.0, 74.0, 68.0, 91.0, 65.0, 116.0, ..."
35,Height: 32 Width: 32,bird,"[0.778077363968, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[66.0, 73.0, 84.0, 66.0, 71.0, 81.0, 64.0, 67.0, ..."
65,Height: 32 Width: 32,bird,"[0.888774394989, 0.0, 0.0, 1.24411165714, ...","[201.0, 206.0, 166.0, 187.0, 180.0, 132.0, ..."
67,Height: 32 Width: 32,bird,"[0.315794527531, 0.0, 0.0, 0.586381316185, ...","[76.0, 170.0, 228.0, 77.0, 171.0, 225.0, 8 ..."
70,Height: 32 Width: 32,bird,"[1.34134876728, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[193.0, 181.0, 145.0, 181.0, 172.0, 147.0, ..."
75,Height: 32 Width: 32,bird,"[1.92161560059, 0.0, 0.0, 0.0, 0.905619382858, ...","[63.0, 111.0, 53.0, 63.0, 110.0, 53.0, 65.0, 11 ..."
84,Height: 32 Width: 32,bird,"[0.472827553749, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[233.0, 231.0, 230.0, 226.0, 225.0, 223.0, ..."
86,Height: 32 Width: 32,bird,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0170202255249, ...","[116.0, 167.0, 136.0, 110.0, 167.0, 139.0, ..."
113,Height: 32 Width: 32,bird,"[1.47401452065, 0.0, 0.0, 0.219570279121, ...","[114.0, 117.0, 122.0, 117.0, 120.0, 125.0, ..."
118,Height: 32 Width: 32,bird,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.465785324574, ...","[4.0, 4.0, 2.0, 2.0, 2.0, 2.0, 3.0, 4.0, 4.0, 4.0, ..."


In [38]:
# Label distribution of the test split.
image_test['label'].sketch_summary()


+------------------+-------+----------+
|       item       | value | is exact |
+------------------+-------+----------+
|      Length      |  4000 |   Yes    |
| # Missing Values |   0   |   Yes    |
| # unique values  |   4   |    No    |
+------------------+-------+----------+

Most frequent items:
+-------+------------+------+------+------+
| value | automobile | cat  | bird | dog  |
+-------+------------+------+------+------+
| count |    1000    | 1000 | 1000 | 1000 |
+-------+------------+------+------+------+


In [39]:
# For every dog test image, its single nearest cat training image (k=1).
dog_cat_neighbors = cat_model.query(image_test_dog, k=1)

In [40]:
# query_label appears to be the 0-based row position within image_test_dog (confirm in docs).
dog_cat_neighbors

query_label,reference_label,distance,rank
0,33,36.4196077068,1
1,30606,38.8353268874,1
2,5545,36.9763410854,1
3,19631,34.5750072914,1
4,7493,34.778824791,1
5,47044,35.1171578292,1
6,13918,40.6095830913,1
7,10981,39.9036867306,1
8,45456,38.0674700168,1
9,44673,42.7258732951,1


In [44]:
# One column per category model: the distance from each dog test image to its single
# nearest training image in that category.
# NOTE(review): the columns line up row-by-row only if every k=1 query returns exactly
# one row per test image, in the same query order. Confirm for tied distances.
dog_distances = graphlab.SFrame({
    column: model.query(image_test_dog, k=1)['distance']
    for column, model in [
        ('dog-automobile', automobile_model),
        ('dog-bird', bird_model),
        ('dog-cat', cat_model),
        ('dog-dog', dog_model),
    ]
})

In [45]:
# Preview the per-category nearest distances for the dog test images.
dog_distances

dog-automobile,dog-bird,dog-cat,dog-dog
41.9579761457,41.7538647304,36.4196077068,33.4773590373
46.0021331807,41.3382958925,38.8353268874,32.8458495684
42.9462290692,38.6157590853,36.9763410854,35.0397073189
41.6866060048,37.0892269954,34.5750072914,33.9010327697
39.2269664935,38.272288694,34.778824791,37.4849250909
40.5845117698,39.1462089236,35.1171578292,34.945165344
45.1067352961,40.523040106,40.6095830913,39.0957278345
41.3221140974,38.1947918393,39.9036867306,37.7696131032
41.8244654995,40.1567131661,38.0674700168,35.1089144603
45.4976929401,45.5597962603,42.7258732951,43.2422832585


In [72]:
def is_dog_correct(row):
    """Return 1 if the dog model is strictly the closest for this row, else 0.

    ``row`` is one row of ``dog_distances``: a mapping from the columns 'dog-dog',
    'dog-cat', 'dog-automobile' and 'dog-bird' to nearest-neighbor distances.
    A tie with any other category counts as incorrect.
    """
    dog_distance = row['dog-dog']
    other_columns = ('dog-cat', 'dog-automobile', 'dog-bird')
    strictly_closest = all(dog_distance < row[column] for column in other_columns)
    return 1 if strictly_closest else 0

In [73]:
# Add a 1/0 column marking whether each dog test image is strictly closest to the dog model.
dog_distances['dog_correct'] = dog_distances.apply(is_dog_correct)

In [74]:
# Number of dog test images whose nearest neighbor overall is a dog.
dog_distances['dog_correct'].sum()

678

In [51]:
# Number of dog test images NOT matched to the dog model.
# NOTE(review): the saved output (11) is stale. This cell ran as In [51], before
# is_dog_correct and the dog_correct column were (re)computed in In [72]-[74]. With 678
# correct, the count should presumably be len(dog_distances) - 678. Re-run top to bottom.
len(dog_distances[dog_distances['dog_correct'] == 0])

11

In [64]:
# One column per category model: the distance from each cat test image to its single
# nearest training image in that category.
# NOTE(review): the columns line up row-by-row only if every k=1 query returns exactly
# one row per test image, in the same query order. Confirm for tied distances.
cat_distances = graphlab.SFrame({
    column: model.query(image_test_cat, k=1)['distance']
    for column, model in [
        ('cat-automobile', automobile_model),
        ('cat-bird', bird_model),
        ('cat-cat', cat_model),
        ('cat-dog', dog_model),
    ]
})

In [69]:
def is_cat_correct(row):
    """Return 1 if the cat model is strictly the closest for this row, else 0.

    ``row`` is one row of ``cat_distances``: a mapping from the columns 'cat-cat',
    'cat-dog', 'cat-automobile' and 'cat-bird' to nearest-neighbor distances.
    A tie with any other category counts as incorrect.
    """
    cat_distance = row['cat-cat']
    rivals = ('cat-dog', 'cat-automobile', 'cat-bird')
    return 1 if all(cat_distance < row[name] for name in rivals) else 0

In [70]:
# Add a 1/0 column marking whether each cat test image is strictly closest to the cat model.
cat_distances['cat_correct'] = cat_distances.apply(is_cat_correct)

In [71]:
# Number of cat test images whose nearest neighbor overall is a cat.
cat_distances['cat_correct'].sum()

548