# <u>Using Deep Features to Train an Image Classifier</u>

In [48]:
import turicreate

# <u>Load Data</u>

In [49]:
# Load the pre-built train/test SFrames (ids, images, labels, deep features, raw pixel arrays).
image_train, image_test = (turicreate.SFrame('image_train_data/'),
                           turicreate.SFrame('image_test_data/'))

# <u>Explore Image Data</u>

In [50]:
# Visually inspect the training images.
image_train['image'].explore()



Unnamed: 0,SArray
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,


# <u>Train an Image Classifier on Raw Image Pixels</u>

In [51]:
# Baseline classifier: logistic regression on the raw pixel values in 'image_array'.
# NOTE(review): create() holds out 5% of image_train as a validation set (see the
# PROGRESS message); presumably a random split, so pass validation_set explicitly
# if the run must be reproducible.
raw_pixel_model = turicreate.logistic_classifier.create(
    image_train,
    target='label',
    features=['image_array'],
)

PROGRESS: Creating a validation set from 5 percent of training data. This may take a while.
          You can set ``validation_set=None`` to disable validation tracking.



# <u>Make Predictions Using the Simple Raw Pixel Model</u>

In [52]:
# Show the first three test images, used below to sanity-check predictions.
image_test[0:3]['image'].explore()

Unnamed: 0,SArray
0,
1,
2,


In [53]:
# Ground-truth labels of the first three test images, for comparison with predictions.
image_test[0:3]['label']

dtype: str
Rows: 3
['cat', 'automobile', 'cat']

In [54]:
# Raw-pixel model predictions for the same three test images.
raw_pixel_model.predict(image_test[0:3])

dtype: str
Rows: 3
['bird', 'cat', 'bird']

# <u>Evaluate the Raw Pixel Model on the Test Data</u>

In [55]:
# Evaluate on the full test set; the result is a dict of metrics
# (accuracy, auc, confusion_matrix, f1_score, log_loss, ...).
raw_pixel_model.evaluate(image_test)

{'accuracy': 0.47875,
 'auc': 0.7275915833333333,
 'confusion_matrix': Columns:
 	target_label	str
 	predicted_label	str
 	count	int
 
 Rows: 16
 
 Data:
 +--------------+-----------------+-------+
 | target_label | predicted_label | count |
 +--------------+-----------------+-------+
 |     bird     |       dog       |  183  |
 |     dog      |       cat       |  197  |
 |     bird     |    automobile   |  140  |
 |  automobile  |    automobile   |  646  |
 |     cat      |       dog       |  325  |
 |     dog      |       dog       |  439  |
 |     dog      |    automobile   |  117  |
 |     bird     |       bird      |  533  |
 |  automobile  |       bird      |  113  |
 |     bird     |       cat       |  144  |
 +--------------+-----------------+-------+
 [16 rows x 3 columns]
 Note: Only the head of the SFrame is printed.
 You can use print_rows(num_rows=m, num_columns=n) to print more rows and columns.,
 'f1_score': 0.4732393989170201,
 'log_loss': 1.1994593105341385,
 'precisio

# <u>Train an Image Classifier Using Deep Features</u>

In [56]:
# Number of training examples.
len(image_train)

2005

In [57]:
# The saved SFrames already contain a 'deep_features' column (see the columns of
# image_train displayed below), so feature extraction is skipped here.
# The commented lines are kept for provenance: they load a saved model
# ('imagenet_model_iter45') and extract features from the training images.
# NOTE(review): image_test also needs 'deep_features' for evaluation; presumably
# it was produced the same way — confirm before regenerating.
#deep_learning_model = turicreate.load_model('imagenet_model_iter45')
#image_train['deep_features'] = deep_learning_model.extract_features(image_train)

In [58]:
# Inspect the training SFrame: id, image, label, deep_features and image_array columns.
image_train

id,image,label,deep_features,image_array
24,Height: 32 Width: 32,bird,"[0.24287176132202148, 1.0954537391662598, 0.0, ...","[73.0, 77.0, 58.0, 71.0, 68.0, 50.0, 77.0, 69.0, ..."
33,Height: 32 Width: 32,cat,"[0.5250879526138306, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[7.0, 5.0, 8.0, 7.0, 5.0, 8.0, 5.0, 4.0, 6.0, 7.0, ..."
36,Height: 32 Width: 32,cat,"[0.5660159587860107, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[169.0, 122.0, 65.0, 131.0, 108.0, 75.0, ..."
70,Height: 32 Width: 32,dog,"[1.129795789718628, 0.0, 0.0, 0.7781944870948792, ...","[154.0, 179.0, 152.0, 159.0, 183.0, 157.0, ..."
90,Height: 32 Width: 32,bird,"[1.7178692817687988, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[216.0, 195.0, 180.0, 201.0, 178.0, 160.0, ..."
97,Height: 32 Width: 32,automobile,"[1.5781855583190918, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[33.0, 44.0, 27.0, 29.0, 44.0, 31.0, 32.0, 45.0, ..."
107,Height: 32 Width: 32,dog,"[0.0, 0.0, 0.22067785263061523, ...","[97.0, 51.0, 31.0, 104.0, 58.0, 38.0, 107.0, 61.0, ..."
121,Height: 32 Width: 32,bird,"[0.0, 0.23753464221954346, ...","[93.0, 96.0, 88.0, 102.0, 106.0, 97.0, 117.0, ..."
136,Height: 32 Width: 32,automobile,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 7.57378625869751, ...","[35.0, 59.0, 53.0, 36.0, 56.0, 56.0, 42.0, 62.0, ..."
138,Height: 32 Width: 32,bird,"[0.6589357256889343, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[205.0, 193.0, 195.0, 200.0, 187.0, 193.0, ..."


# <u>Train a Logistic Classifier</u>

In [59]:
# Same logistic classifier, now trained on the precomputed deep features
# instead of raw pixels.
# NOTE(review): as above, create() holds out 5% of image_train for validation;
# pass validation_set explicitly if the run must be reproducible.
deep_features_model = turicreate.logistic_classifier.create(
    image_train,
    target='label',
    features=['deep_features'],
)

PROGRESS: Creating a validation set from 5 percent of training data. This may take a while.
          You can set ``validation_set=None`` to disable validation tracking.



### Apply the Deep Features Classifier to the First Few Test Images

In [60]:
# The same first three test images, now classified with the deep-features model.
image_test[0:3]['image'].explore()

Unnamed: 0,SArray
0,
1,
2,


In [61]:
# Deep-features model predictions; compare with the true labels
# ['cat', 'automobile', 'cat'] shown earlier.
deep_features_model.predict(image_test[0:3])

dtype: str
Rows: 3
['cat', 'automobile', 'cat']

# <u>Quantitatively Evaluate the Deep Features Classifier on Test Data</u>

In [62]:
# Evaluate the deep-features model on the full test set (same metrics as the raw-pixel model).
deep_features_model.evaluate(image_test)

{'accuracy': 0.79025,
 'auc': 0.9392359999999994,
 'confusion_matrix': Columns:
 	target_label	str
 	predicted_label	str
 	count	int
 
 Rows: 16
 
 Data:
 +--------------+-----------------+-------+
 | target_label | predicted_label | count |
 +--------------+-----------------+-------+
 |  automobile  |       dog       |   10  |
 |     cat      |    automobile   |   17  |
 |     dog      |       bird      |   60  |
 |  automobile  |       cat       |   20  |
 |  automobile  |    automobile   |  953  |
 |     bird     |    automobile   |   13  |
 |     dog      |       cat       |  194  |
 |     bird     |       dog       |   61  |
 |     cat      |       bird      |   88  |
 |  automobile  |       bird      |   17  |
 +--------------+-----------------+-------+
 [16 rows x 3 columns]
 Note: Only the head of the SFrame is printed.
 You can use print_rows(num_rows=m, num_columns=n) to print more rows and columns.,
 'f1_score': 0.7909071128411668,
 'log_loss': 0.5883743304184752,
 'precisio

In [72]:
# Nearest-neighbour retrieval model over the *test* images, compared by deep
# features and labelled by image id.
# NOTE(review): `animal` is itself drawn from image_test, so the closest match
# returned for it is expected to be the image itself (rank 1 at distance 0.0).
knn_model = turicreate.nearest_neighbors.create(image_test,
                                               features = ['deep_features'],
                                               label = 'id')
# Single-row query: test image at position 18.
animal = image_test[18:19]
animal['image'].explore()

Unnamed: 0,SArray
0,


In [73]:
# Ids and distances of the test images closest to `animal` (rank 1, at distance
# 0.0, is expected to be the query image itself).
knn_model.query(animal)

query_label,reference_label,distance,rank
0,65,0.0,1
0,9750,34.09602642323738,2
0,5419,34.54853689921873,3
0,7605,35.32610057100654,4
0,7797,36.27840973696042,5


In [109]:
def get_images_from_ids(query_result, images=None):
    """Look up the images referenced by a nearest-neighbours query result.

    Parameters
    ----------
    query_result : SFrame
        Result of a retrieval model's ``query``. Its 'reference_label' column
        holds image ids, because the models here are created with ``label='id'``.
    images : SFrame, optional
        Collection to look the ids up in. Defaults to ``image_test``, the
        collection ``knn_model`` indexes. Passing it explicitly lets the same
        helper serve the category-specific models without being redefined.

    Returns
    -------
    SFrame
        Rows of ``images`` whose 'id' is among the query's reference labels.
        NOTE(review): row order comes from ``filter_by`` and may not follow
        neighbour rank — confirm before relying on it.
    """
    if images is None:
        images = image_test
    return images.filter_by(query_result['reference_label'], 'id')

animal_neighbors = get_images_from_ids(knn_model.query(animal), image_test)

animal_neighbors['image'].explore()

Unnamed: 0,SArray
0,
1,
2,
3,
4,


# <u>Find Summary Statistics of the Data</u>

In [63]:
# Summary sketch of the training labels: length, missing values, approximate
# number of unique values and the most frequent labels with their counts.
sketch=turicreate.Sketch(image_train['label'])
sketch


+------------------+-------+----------+
|       item       | value | is exact |
+------------------+-------+----------+
|      Length      |  2005 |   Yes    |
| # Missing Values |   0   |   Yes    |
| # unique values  |   4   |    No    |
+------------------+-------+----------+

Most frequent items:
+------------+-------+
|   value    | count |
+------------+-------+
|    cat     |  509  |
|    dog     |  509  |
| automobile |  509  |
|    bird    |  478  |
+------------+-------+


# <u>Create Category-Specific Image Retrieval Models</u>

In [99]:
# Split the training set into one SFrame per class, so that each class gets its
# own retrieval model below.
dog_images, cat_images, automobile_images, bird_images = (
    image_train[image_train['label'] == class_name]
    for class_name in ('dog', 'cat', 'automobile', 'bird')
)

In [100]:
# Browse the cat subset of the training data.
cat_images.explore()



Unnamed: 0,id,image,label,deep_features,image_array
0,33,,cat,"[0.5250879526138306, 0.0, 0.0, 0.0, 0.0, 0.0, 9.948286056518555, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0126363039016724, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1646251678466797, 0.0, 0.0, 0.0, 0.0, 0.4600375294685364, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3580564260482788, 0.0, 0.0, 3.517725944519043, 2.9159154891967773, 0.0, 0.0, 0.0, 0.0, 0.4551548957824707, 0.0, 0.0, 0.0, 0.9146482944488525, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8120278716087341, 0.04379773139953613, 0.0, 1.6515014171600342, 0.0, 0.0, 0.0, 0.0, 0.0, 1.3530466556549072, 0.0, 0.0, 1.4851171970367432, 0.0, 0.0, 0.0, 1.5763869285583496, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.6830389499664307, 1.643679141998291, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]","[7.0, 5.0, 8.0, 7.0, 5.0, 8.0, 5.0, 4.0, 6.0, 7.0, 4.0, 7.0, 11.0, 5.0, 9.0, 11.0, 5.0, 9.0, 17.0, 11.0, 15.0, 25.0, 19.0, 23.0, 18.0, 9.0, 17.0, 9.0, 1.0, 9.0, 13.0, 7.0, 15.0, 24.0, 19.0, 27.0, 21.0, 19.0, 25.0, 34.0, 32.0, 35.0, 52.0, 43.0, 42.0, 40.0, 27.0, 26.0, 19.0, 9.0, 10.0, 11.0, 6.0, 10.0, 20.0, 16.0, 18.0, 36.0, 29.0, 28.0, 25.0, 20.0, 22.0, 23.0, 19.0, 22.0, 32.0, 21.0, 22.0, 33.0, 26.0, 29.0, 58.0, 57.0, 61.0, 73.0, 75.0, 81.0, 83.0, 91.0, 94.0, 61.0, 68.0, 66.0, 69.0, 71.0, 66.0, 82.0, 84.0, 78.0, 80.0, 85.0, 81.0, 69.0, 73.0, 68.0, 6.0, 4.0, 8.0, 6.0, ...]"
1,36,,cat,"[0.5660159587860107, 0.0, 0.0, 0.0, 0.0, 0.0, 9.997204780578613, 0.0, 0.0, 0.0, 1.38345205783844, 0.0, 0.7764788269996643, 0.0, 0.0, 0.044802725315093994, 0.22744053602218628, 0.5177360773086548, 0.0, 1.5863749980926514, 0.0, 1.238269329071045, 0.0, 0.6377935409545898, 1.8000781536102295, 1.9250504970550537, 0.06807559728622437, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.452914237976074, 0.7830307483673096, 0.0, 0.0, 0.014504671096801758, 0.0, 0.0, 1.1913495063781738, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.8588502407073975, 0.23833894729614258, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4233408570289612, 0.0, 0.9965428113937378, 2.3011531829833984, 0.0, 0.0, 0.0, 0.0, 0.0, 2.076474189758301, 0.07118624448776245, 0.0, 1.7414445877075195, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3547537922859192, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.167968273162842, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]","[169.0, 122.0, 65.0, 131.0, 108.0, 75.0, 193.0, 196.0, 192.0, 218.0, 221.0, 222.0, 211.0, 215.0, 216.0, 206.0, 214.0, 216.0, 200.0, 211.0, 212.0, 193.0, 206.0, 206.0, 186.0, 201.0, 201.0, 182.0, 197.0, 196.0, 178.0, 192.0, 192.0, 173.0, 187.0, 187.0, 167.0, 182.0, 182.0, 160.0, 175.0, 175.0, 155.0, 169.0, 169.0, 150.0, 164.0, 164.0, 144.0, 158.0, 158.0, 139.0, 153.0, 153.0, 134.0, 148.0, 148.0, 133.0, 144.0, 144.0, 130.0, 140.0, 141.0, 126.0, 136.0, 137.0, 125.0, 136.0, 137.0, 128.0, 138.0, 139.0, 131.0, 141.0, 142.0, 131.0, 144.0, 144.0, 149.0, 163.0, 163.0, 176.0, 194.0, 192.0, 173.0, 192.0, 190.0, 172.0, 187.0, 187.0, 169.0, 183.0, 183.0, 166.0, 181.0, 181.0, 165.0, 119.0, 62.0, 127.0, ...]"
2,159,,cat,"[0.0, 0.0, 0.0, 0.6432753205299377, 0.0, 0.0, 10.177190780639648, 0.0, 0.0, 0.0, 0.0, 0.0, 2.3863513469696045, 0.1998509168624878, 0.0, 0.2907107472419739, 0.7435265779495239, 1.0030467510223389, 1.4652762413024902, 0.0, 0.0, 0.14588844776153564, 0.0, 0.9578428268432617, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.3240588903427124, 0.0, 0.0, 0.8844139575958252, 0.0, 0.0, 0.7257575392723083, 1.9903068542480469, 0.0, 0.0, 0.0, 0.0, 0.0, 0.14905840158462524, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4211505055427551, 0.0, 0.0, 0.0, 0.0, 0.3523911237716675, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0920463800430298, 0.0, 0.4573858082294464, 2.408182144165039, 0.0, 0.0, 0.7307713031768799, 0.0, 0.0, 2.664341926574707, 0.1718875765800476, 0.0, 4.22159481048584, 0.0, 0.0, 0.0, 2.915095567703247, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.61897873878479, 0.8255078792572021, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]","[154.0, 145.0, 135.0, 152.0, 144.0, 135.0, 157.0, 146.0, 136.0, 152.0, 138.0, 125.0, 131.0, 100.0, 76.0, 171.0, 125.0, 86.0, 164.0, 112.0, 67.0, 154.0, 92.0, 49.0, 141.0, 111.0, 85.0, 112.0, 140.0, 137.0, 101.0, 155.0, 160.0, 106.0, 158.0, 158.0, 99.0, 135.0, 134.0, 112.0, 144.0, 146.0, 136.0, 192.0, 195.0, 139.0, 209.0, 211.0, 132.0, 200.0, 201.0, 114.0, 183.0, 182.0, 115.0, 182.0, 180.0, 115.0, 176.0, 175.0, 117.0, 172.0, 173.0, 112.0, 164.0, 166.0, 85.0, 138.0, 145.0, 80.0, 130.0, 139.0, 87.0, 134.0, 134.0, 66.0, 107.0, 96.0, 42.0, 80.0, 59.0, 40.0, 73.0, 49.0, 34.0, 60.0, 36.0, 33.0, 57.0, 31.0, 31.0, 58.0, 29.0, 30.0, 58.0, 32.0, 157.0, 148.0, 138.0, 155.0, ...]"
3,331,,cat,"[0.0, 0.0, 0.5109639167785645, 0.0, 0.0, 0.0, 11.272439956665039, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1253304481506348, 0.0, 0.5219167470932007, 0.5971229076385498, 0.0, 0.6972668170928955, 0.0, 0.24001091718673706, 0.0, 1.0731264352798462, 0.0, 0.0, 0.10970926284790039, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.6569366455078125, 0.0, 0.0, 0.0, 0.0, 1.0214426517486572, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.2922544479370117, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.8759312629699707, 0.8046186566352844, 0.0, 1.5726779699325562, 0.0, 0.0, 0.0, 0.0, 0.0, 2.1985647678375244, 0.0, 0.0, 2.738689422607422, 0.0, 0.0, 0.0, 0.48699140548706055, 0.023147881031036377, 0.0, 0.16829437017440796, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8470346331596375, 1.0001214742660522, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]","[45.0, 65.0, 92.0, 72.0, 95.0, 110.0, 106.0, 132.0, 129.0, 106.0, 132.0, 129.0, 109.0, 134.0, 131.0, 111.0, 137.0, 134.0, 111.0, 139.0, 135.0, 115.0, 145.0, 142.0, 91.0, 112.0, 107.0, 85.0, 100.0, 96.0, 92.0, 113.0, 110.0, 84.0, 102.0, 99.0, 116.0, 141.0, 138.0, 122.0, 148.0, 145.0, 122.0, 148.0, 145.0, 123.0, 148.0, 145.0, 122.0, 148.0, 145.0, 122.0, 148.0, 145.0, 122.0, 148.0, 145.0, 122.0, 148.0, 145.0, 122.0, 148.0, 145.0, 123.0, 148.0, 144.0, 123.0, 147.0, 144.0, 123.0, 147.0, 143.0, 122.0, 150.0, 146.0, 108.0, 104.0, 113.0, 93.0, 45.0, 73.0, 72.0, 38.0, 64.0, 41.0, 23.0, 43.0, 23.0, 18.0, 24.0, 36.0, 31.0, 20.0, 52.0, 44.0, 29.0, 46.0, 67.0, 96.0, 76.0, ...]"
4,367,,cat,"[1.3865805864334106, 0.0, 0.0, 0.0, 0.0, 0.18289107084274292, 10.395684242248535, 0.0, 0.0, 0.0, 0.0, 0.1731877326965332, 0.0, 0.0, 1.3926588296890259, 1.0117316246032715, 0.6179732084274292, 0.7718187570571899, 0.0, 0.0, 0.0, 0.21304845809936523, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.1230735778808594, 1.1230860948562622, 0.0, 0.0, 0.0, 1.0546326637268066, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.44656991958618164, 0.0, 0.0, 0.6871272325515747, 1.3032863140106201, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.16987138986587524, 0.0, 0.0, 3.236172914505005, 0.0, 0.0, 4.233730316162109, 0.0, 0.0, 0.0, 2.383932113647461, 1.214116096496582, 0.0, 0.0, 1.481264352798462, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.13467669486999512, 2.3781299591064453, 1.7878865003585815, 0.60386061668396, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]","[168.0, 151.0, 143.0, 145.0, 130.0, 124.0, 134.0, 134.0, 136.0, 178.0, 173.0, 170.0, 189.0, 191.0, 193.0, 186.0, 192.0, 196.0, 215.0, 221.0, 212.0, 231.0, 244.0, 226.0, 214.0, 229.0, 236.0, 184.0, 202.0, 213.0, 230.0, 223.0, 224.0, 229.0, 219.0, 221.0, 226.0, 239.0, 240.0, 234.0, 242.0, 242.0, 200.0, 188.0, 198.0, 155.0, 136.0, 136.0, 243.0, 251.0, 197.0, 230.0, 237.0, 199.0, 214.0, 221.0, 182.0, 211.0, 220.0, 143.0, 199.0, 199.0, 106.0, 195.0, 192.0, 82.0, 197.0, 186.0, 87.0, 206.0, 193.0, 56.0, 176.0, 196.0, 86.0, 161.0, 181.0, 152.0, 206.0, 207.0, 180.0, 223.0, 228.0, 207.0, 173.0, 183.0, 177.0, 222.0, 222.0, 189.0, 239.0, 239.0, 169.0, 229.0, 241.0, 92.0, 179.0, 159.0, 149.0, 154.0, ...]"
5,384,,cat,"[1.0440353155136108, 0.0, 0.0, 0.0, 0.0, 0.0, 9.495406150817871, 0.0, 0.0, 0.0, 0.0, 0.0, 1.015682339668274, 0.654792070388794, 0.0, 0.3271738886833191, 0.0, 0.0, 0.0, 1.240085482597351, 0.0, 2.6598682403564453, 0.0, 0.0, 0.0, 0.16509848833084106, 0.0, 0.6701491475105286, 0.0, 0.0, 0.0, 0.0, 0.417128324508667, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.8625438213348389, 0.0, 0.47617536783218384, 1.9455668926239014, 0.0, 0.0, 0.6290905475616455, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6789456605911255, 0.0, 0.1041874885559082, 4.074601173400879, 0.0, 0.0, 0.0, 0.08778756856918335, 0.0, 0.35365527868270874, 0.0, 0.0, 0.23898005485534668, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.9145865440368652, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]","[46.0, 45.0, 50.0, 47.0, 45.0, 51.0, 45.0, 44.0, 49.0, 41.0, 40.0, 45.0, 35.0, 34.0, 39.0, 28.0, 27.0, 32.0, 22.0, 21.0, 25.0, 16.0, 15.0, 18.0, 12.0, 12.0, 14.0, 12.0, 11.0, 14.0, 10.0, 10.0, 13.0, 8.0, 8.0, 10.0, 5.0, 5.0, 7.0, 3.0, 2.0, 4.0, 2.0, 2.0, 3.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 5.0, 5.0, 6.0, 4.0, 3.0, 7.0, 4.0, 3.0, 8.0, 4.0, 3.0, 8.0, 4.0, 3.0, 8.0, 4.0, 3.0, 8.0, 4.0, 3.0, 8.0, 4.0, 3.0, 8.0, 4.0, 3.0, 9.0, 4.0, 3.0, 11.0, 4.0, 3.0, 11.0, 40.0, 39.0, 44.0, 43.0, ...]"
6,494,,cat,"[0.0, 0.0539512038230896, 1.9574512243270874, 0.0, 0.0, 0.0, 8.590998649597168, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1748650074005127, 0.0, 0.314103364944458, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5537757873535156, 1.1642049551010132, 0.0, 0.0, 0.0, 0.18456339836120605, 0.0, 0.0, 0.0, 0.0, 1.8631043434143066, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.49724507331848145, 0.0, 0.0, 0.0, 1.067395567893982, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.222784399986267, 0.0, 0.0, 0.0, 0.3758436441421509, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.4855694770812988, 0.0, 0.0, 0.47085750102996826, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.4296932220458984, 0.4144173860549927, 1.1513252258300781, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]","[26.0, 34.0, 29.0, 24.0, 29.0, 25.0, 33.0, 43.0, 37.0, 31.0, 42.0, 36.0, 30.0, 39.0, 34.0, 34.0, 45.0, 40.0, 84.0, 111.0, 107.0, 100.0, 136.0, 130.0, 88.0, 114.0, 110.0, 162.0, 181.0, 185.0, 239.0, 246.0, 249.0, 251.0, 251.0, 251.0, 159.0, 159.0, 152.0, 119.0, 116.0, 109.0, 238.0, 238.0, 235.0, 255.0, 255.0, 255.0, 253.0, 254.0, 253.0, 247.0, 248.0, 247.0, 233.0, 234.0, 232.0, 196.0, 195.0, 191.0, 155.0, 156.0, 148.0, 155.0, 156.0, 151.0, 169.0, 172.0, 174.0, 174.0, 182.0, 186.0, 249.0, 250.0, 249.0, 217.0, 219.0, 222.0, 201.0, 206.0, 209.0, 209.0, 214.0, 215.0, 220.0, 224.0, 224.0, 231.0, 238.0, 237.0, 240.0, 248.0, 249.0, 243.0, 252.0, 253.0, 27.0, 35.0, 29.0, 24.0, ...]"
7,597,,cat,"[0.0, 0.0, 0.04706370830535889, 0.0, 0.0, 0.9712178111076355, 8.430437088012695, 0.2638046145439148, 0.0, 0.0, 2.298530101776123, 0.0, 0.0, 0.8137612342834473, 0.0, 0.9580428004264832, 0.0, 0.0, 0.15481221675872803, 1.0270837545394897, 0.0, 0.9986186027526855, 0.16812551021575928, 1.1448713541030884, 1.6568095684051514, 1.4422653913497925, 0.6090103387832642, 0.0, 0.0, 0.0, 0.2968716621398926, 0.9832152724266052, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4367617964744568, 2.3099989891052246, 0.0, 0.0, 0.2898378372192383, 0.0, 0.8834639191627502, 0.3079947829246521, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5830368995666504, 0.0, 0.0008624792098999023, 0.0, 0.0, 0.0, 0.1471460461616516, 0.0, 0.0, 0.731701672077179, 0.18507742881774902, 0.018941521644592285, 0.6119394302368164, 0.0, 0.0, 0.0, 0.1327456831932068, 0.0, 0.73091721534729, 0.5230916142463684, 0.0, 1.1849052906036377, 0.0, 0.0, 0.0, 0.1883857250213623, 0.0, 0.0, 0.0, 0.0, 0.7146134972572327, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.126732587814331, 0.12948018312454224, 0.8216015100479126, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]","[133.0, 153.0, 138.0, 126.0, 146.0, 136.0, 165.0, 184.0, 179.0, 170.0, 188.0, 187.0, 171.0, 187.0, 181.0, 153.0, 165.0, 155.0, 165.0, 177.0, 167.0, 180.0, 189.0, 180.0, 174.0, 180.0, 172.0, 175.0, 182.0, 173.0, 166.0, 180.0, 166.0, 167.0, 181.0, 166.0, 182.0, 194.0, 180.0, 183.0, 192.0, 179.0, 191.0, 198.0, 186.0, 185.0, 194.0, 183.0, 188.0, 193.0, 185.0, 192.0, 193.0, 183.0, 188.0, 188.0, 176.0, 133.0, 138.0, 122.0, 62.0, 73.0, 54.0, 95.0, 106.0, 89.0, 175.0, 186.0, 170.0, 185.0, 195.0, 183.0, 181.0, 191.0, 182.0, 180.0, 190.0, 183.0, 175.0, 187.0, 178.0, 170.0, 183.0, 172.0, 169.0, 183.0, 168.0, 170.0, 184.0, 168.0, 164.0, 179.0, 160.0, 158.0, 171.0, 153.0, 145.0, 164.0, 149.0, 130.0, ...]"
8,788,,cat,"[0.5058419108390808, 0.0, 0.0, 0.0, 0.4272115230560303, 0.0, 11.489871978759766, 0.0, 0.0, 0.0, 0.0, 0.0, 2.303081750869751, 0.0, 0.0, 0.0, 0.4100242853164673, 0.7948493361473083, 0.0, 0.0, 0.0, 1.8312119245529175, 0.0, 0.0, 0.04584550857543945, 0.0, 0.5371289849281311, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.463374137878418, 2.8007049560546875, 0.0, 0.0, 0.0, 0.0, 1.3103001117706299, 0.0, 0.0, 0.0, 0.8271439671516418, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.004378199577331543, 0.34384065866470337, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1925148963928223, 1.983250617980957, 1.1891417503356934, 3.561600685119629, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0976991653442383, 0.0, 0.0, 0.0, 0.2474733591079712, 0.0, 0.0, 0.6958598494529724, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3900628685951233, 0.4555479884147644, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]","[184.0, 200.0, 197.0, 189.0, 203.0, 200.0, 194.0, 205.0, 203.0, 198.0, 207.0, 206.0, 202.0, 207.0, 206.0, 203.0, 207.0, 206.0, 202.0, 206.0, 205.0, 200.0, 204.0, 203.0, 198.0, 201.0, 200.0, 197.0, 199.0, 198.0, 193.0, 195.0, 194.0, 187.0, 193.0, 192.0, 181.0, 190.0, 188.0, 175.0, 185.0, 184.0, 172.0, 181.0, 182.0, 168.0, 176.0, 178.0, 163.0, 171.0, 172.0, 158.0, 167.0, 166.0, 157.0, 164.0, 164.0, 158.0, 160.0, 162.0, 159.0, 162.0, 164.0, 168.0, 173.0, 177.0, 162.0, 168.0, 174.0, 150.0, 159.0, 164.0, 152.0, 164.0, 166.0, 151.0, 164.0, 166.0, 152.0, 165.0, 166.0, 152.0, 165.0, 167.0, 152.0, 165.0, 166.0, 149.0, 163.0, 166.0, 145.0, 161.0, 164.0, 141.0, 160.0, 161.0, 189.0, 204.0, 201.0, 193.0, ...]"
9,882,,cat,"[0.0, 0.0, 0.15620028972625732, 0.0, 0.0, 0.0, 8.879639625549316, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3530038595199585, 0.0, 0.0, 0.0, 0.41717779636383057, 0.0, 0.0, 0.0, 0.0, 0.23389798402786255, 0.0, 0.186390221118927, 0.0, 0.3493977189064026, 0.23430532217025757, 0.0, 0.0, 0.0, 1.1172078847885132, 0.0, 0.6166121959686279, 0.2698093056678772, 0.0, 0.0, 0.0, 0.6467511653900146, 0.0134202241897583, 0.05460095405578613, 0.0, 1.1291296482086182, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.081068515777588, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0593385696411133, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.5178027153015137, 0.0, 1.3639854192733765, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]","[141.0, 133.0, 112.0, 143.0, 133.0, 113.0, 155.0, 141.0, 123.0, 153.0, 139.0, 120.0, 150.0, 132.0, 116.0, 153.0, 134.0, 121.0, 163.0, 147.0, 132.0, 183.0, 169.0, 151.0, 195.0, 176.0, 158.0, 190.0, 164.0, 147.0, 187.0, 158.0, 139.0, 189.0, 159.0, 138.0, 190.0, 162.0, 141.0, 182.0, 161.0, 142.0, 175.0, 162.0, 148.0, 170.0, 160.0, 149.0, 171.0, 157.0, 142.0, 174.0, 159.0, 137.0, 175.0, 145.0, 125.0, 175.0, 137.0, 121.0, 169.0, 146.0, 129.0, 162.0, 147.0, 131.0, 164.0, 147.0, 133.0, 164.0, 146.0, 132.0, 167.0, 148.0, 134.0, 167.0, 148.0, 133.0, 166.0, 146.0, 132.0, 160.0, 141.0, 129.0, 156.0, 138.0, 126.0, 157.0, 142.0, 129.0, 153.0, 141.0, 127.0, 148.0, 139.0, 124.0, 146.0, 138.0, 119.0, 148.0, ...]"


In [101]:
# One nearest-neighbour retrieval model per class, all compared by deep features
# and labelled by image id (built in the order dog, cat, automobile, bird).
dog_model, cat_model, automobile_model, bird_model = (
    turicreate.nearest_neighbors.create(class_images, features=['deep_features'], label='id')
    for class_images in (dog_images, cat_images, automobile_images, bird_images)
)

In [103]:
# Query image: the first test image, whose label is 'cat'.
cat = image_test[0:1]
cat['image'].explore()

# Its nearest neighbours among the cat training images.
cat_model.query(cat)

Unnamed: 0,SArray
0,


query_label,reference_label,distance,rank
0,16289,34.62371920804245,1
0,45646,36.00687992842462,2
0,32139,36.52008134363789,3
0,25713,36.754850252057054,4
0,331,36.87312281675268,5


In [106]:
def get_images_from_ids(query_result, images=None):
    """Look up the images referenced by a nearest-neighbours query result.

    Parameters
    ----------
    query_result : SFrame
        Result of a retrieval model's ``query``; its 'reference_label' column
        holds image ids (the models are created with ``label='id'``).
    images : SFrame, optional
        Collection to look the ids up in. Defaults to ``cat_images``, the
        collection ``cat_model`` indexes.

    Returns
    -------
    SFrame
        Rows of ``images`` whose 'id' is among the query's reference labels.
        NOTE(review): row order comes from ``filter_by`` and may not follow
        neighbour rank.
    """
    if images is None:
        images = cat_images
    return images.filter_by(query_result['reference_label'], 'id')

cat_neighbors = get_images_from_ids(cat_model.query(cat), cat_images)
cat_neighbors['image'].explore()

Unnamed: 0,SArray
0,
1,
2,
3,
4,


In [120]:
# Show the single closest cat training image to the query. The id is read from
# the model rather than hard-coded from an earlier output, so the cell stays
# correct if the data or model changes.
nearest_cat_id = cat_model.query(cat, k=1, verbose=False)['reference_label'][0]
cat_image = image_train[image_train['id'] == nearest_cat_id]
cat_image['image'].explore()

Unnamed: 0,SArray
0,


In [114]:
# NOTE(review): despite its name, `dog` is the same first test image as `cat`
# (its label is 'cat'). This cell looks up that image's nearest neighbours among
# the *dog* training images.
dog = image_test[0:1]
dog['image'].explore()

dog_model.query(dog)

Unnamed: 0,SArray
0,


query_label,reference_label,distance,rank
0,16976,37.464262878423774,1
0,13387,37.56668321685285,2
0,35867,37.60472670789396,3
0,44603,37.70655851529755,4
0,6094,38.51132549073972,5


In [115]:
def get_images_from_ids(query_result, images=None):
    """Look up the images referenced by a nearest-neighbours query result.

    Parameters
    ----------
    query_result : SFrame
        Result of a retrieval model's ``query``; its 'reference_label' column
        holds image ids (the models are created with ``label='id'``).
    images : SFrame, optional
        Collection to look the ids up in. Defaults to ``dog_images``, the
        collection ``dog_model`` indexes.

    Returns
    -------
    SFrame
        Rows of ``images`` whose 'id' is among the query's reference labels.
        NOTE(review): row order comes from ``filter_by`` and may not follow
        neighbour rank.
    """
    if images is None:
        images = dog_images
    return images.filter_by(query_result['reference_label'], 'id')

# Query with `dog`, the same image queried against dog_model in the cell above.
dog_neighbors = get_images_from_ids(dog_model.query(dog), dog_images)
dog_neighbors['image'].explore()

Unnamed: 0,SArray
0,
1,
2,
3,
4,


In [119]:
# Show the single closest dog training image to the query. The id is read from
# the model rather than hard-coded from an earlier output.
nearest_dog_id = dog_model.query(dog, k=1, verbose=False)['reference_label'][0]
dog_image = image_train[image_train['id'] == nearest_dog_id]
dog_image['image'].explore()

Unnamed: 0,SArray
0,


# <u>Nearest Neighbor Classification</u>

In [121]:
# Split the test set by true label; the per-class nearest-neighbour
# classification below uses these subsets.
image_test_dog, image_test_cat, image_test_automobile, image_test_bird = (
    image_test[image_test['label'] == class_name]
    for class_name in ('dog', 'cat', 'automobile', 'bird')
)

In [122]:
# Repeat of the cat-model query above, shown again for comparison with the dog query.
cat_model.query(cat)

query_label,reference_label,distance,rank
0,16289,34.62371920804245,1
0,45646,36.00687992842462,2
0,32139,36.52008134363789,3
0,25713,36.754850252057054,4
0,331,36.87312281675268,5


In [123]:
# Repeat of the dog-model query above.
dog_model.query(dog)

query_label,reference_label,distance,rank
0,16976,37.464262878423774,1
0,13387,37.56668321685285,2
0,35867,37.60472670789396,3
0,44603,37.70655851529755,4
0,6094,38.51132549073972,5


In [124]:
# Mean distance from the query image to its nearest cat training images.
cat_model.query(cat)['distance'].mean()

36.15573070978294

In [125]:
# Mean distance from the same query image to its nearest dog training images.
dog_model.query(dog)['distance'].mean()

37.77071136184157

In [127]:
# For every dog test image: its single nearest cat training image.
dog_cat_neighbors = cat_model.query(image_test_dog, k=1)

In [128]:
# For every dog test image: its single nearest automobile training image.
dog_automobile_neighbors = automobile_model.query(image_test_dog, k=1)

In [129]:
# For every dog test image: its single nearest dog training image.
dog_dog_neighbors = dog_model.query(image_test_dog, k=1)

In [130]:
# For every dog test image: its single nearest bird training image.
dog_bird_neighbors = bird_model.query(image_test_dog, k=1)

In [135]:
# One row per dog test image, one column per class: the distance to the nearest
# training image of that class.
# NOTE(review): columns are combined positionally, which assumes each k=1 query
# returned exactly one row per dog test image in the same order.
dog_distances = turicreate.SFrame(dict(
    dog_automobile=dog_automobile_neighbors['distance'],
    dog_cat=dog_cat_neighbors['distance'],
    dog_dog=dog_dog_neighbors['distance'],
    dog_bird=dog_bird_neighbors['distance'],
))
dog_distances

dog_automobile,dog_bird,dog_cat,dog_dog
41.95797614571203,41.75386473035126,36.419607706754384,33.47735903726336
46.002133180677895,41.33829589248612,38.83532688735546,32.84584956840558
42.9462290692388,38.61575908528905,36.97634108541545,35.039707318905855
41.6866060048479,37.08922699538219,34.575007291446106,33.90103276968192
39.22696649347584,38.27228869398105,34.778824791016625,37.48492509092561
40.58451176980721,39.146208923590464,35.11715782924591,34.94516534398125
45.10673529610857,40.52304010596232,40.60958309132646,39.09572783446351
41.32211409739767,38.194791839269584,39.90368673062212,37.76961310322033
41.82446549950164,40.15671316613142,38.06747001682115,35.10891446032839
45.49769294011039,45.55979626027668,42.725873295060296,43.242283258453455


In [136]:
def is_dog_correct(row):
    """Score one row of per-class nearest-neighbour distances.

    Parameters
    ----------
    row : dict
        Maps column names to distances; must contain 'dog_dog'. Every value is
        treated as a distance to compare against.

    Returns
    -------
    int
        1 if the distance to the nearest dog is no larger than every other
        distance in the row (ties count as correct), otherwise 0.
    """
    dog_distance = row['dog_dog']
    return 1 if dog_distance <= min(row.values()) else 0

In [137]:
# Per dog test image: 1 if its nearest training neighbour class is 'dog', else 0.
dog_distances.apply(is_dog_correct)

dtype: int
Rows: 1000
[1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, ... ]

In [138]:
# Number of dog test images whose nearest training neighbour is a dog.
dog_distances.apply(is_dog_correct).sum()

678

In [139]:
# Number of dog test images evaluated.
dog_distances.num_rows()

1000

In [141]:
# Nearest-neighbour classification accuracy on dog test images, as a percentage.
# (Dividing before multiplying by 100 yields 67.80000000000001 rather than 67.8
# because of floating-point rounding.)
(dog_distances.apply(is_dog_correct).sum())/(dog_distances.num_rows())*100

67.80000000000001