# Transfer Learning

Most of the time you won't want to train a whole convolutional network yourself. Modern ConvNets trained on huge datasets like ImageNet take weeks to train on multiple GPUs. Instead, most people use a pretrained network either as a fixed feature extractor or as an initial network to fine-tune. In this notebook, you'll be using [VGGNet](https://arxiv.org/pdf/1409.1556.pdf) trained on the [ImageNet dataset](http://www.image-net.org/) as a feature extractor. Below is a diagram of the VGGNet architecture.

<img src="assets/cnnarchitecture.jpg" width=700px>

VGGNet is a good choice because it's simple and performs well: it placed second in the 2014 ImageNet (ILSVRC) classification challenge. The idea here is that we keep all the convolutional layers but replace the final fully connected layers with our own classifier. This way we can use VGGNet as a feature extractor for our images and then easily train a simple classifier on top of that. Specifically, we'll take the output of the first fully connected layer, which has 4096 units, after its ReLU activation. We can use those values as a code for each image, then build a classifier on top of those codes.

You can read more about transfer learning from [the CS231n course notes](http://cs231n.github.io/transfer-learning/#tf).

## Pretrained VGGNet

We'll be using a pretrained network from https://github.com/machrisaa/tensorflow-vgg. Make sure to clone this repository into the directory you're working from. Name the cloned folder with an underscore instead of a dash (`tensorflow_vgg`) so it can be imported as a Python package. The command below does both:

```
git clone https://github.com/machrisaa/tensorflow-vgg.git tensorflow_vgg
```

This is a really nice implementation of VGGNet and quite easy to work with. The network has already been trained, and the trained parameters are distributed separately from the code as a `vgg16.npy` file. **You'll need to clone the repo into the folder containing this notebook.** Then download the parameter file using the next cell.

```python
from urllib.request import urlretrieve
from os.path import isfile, isdir
from tqdm import tqdm

vgg_dir = 'tensorflow_vgg/'
# Make sure vgg exists
if not isdir(vgg_dir):
    raise Exception("VGG directory doesn't exist!")

class DLProgress(tqdm):
    last_block = 0

    def hook(self, block_num=1, block_size=1, total_size=None):
        self.total = total_size
        self.update((block_num - self.last_block) * block_size)
        self.last_block = block_num

if not isfile(vgg_dir + "vgg16.npy"):
    with DLProgress(unit='B', unit_scale=True, miniters=1, desc='VGG16 Parameters') as pbar:
        urlretrieve(
            'https://s3.amazonaws.com/content.udacity-data.com/nd101/vgg16.npy',
            vgg_dir + 'vgg16.npy',
            pbar.hook)
else:
    print("Parameter file already exists!")
```

```
Parameter file already exists!
```
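For reference, `urlretrieve` calls its `reporthook` with three arguments: the number of blocks transferred so far, the block size in bytes, and the total file size. The first call uses block 0 and happens when the connection opens. Those counts are cumulative, but `tqdm.update` expects an increment, so `DLProgress.hook` remembers the last block number and reports only the difference. Below is a minimal sketch that replays a few such calls without any network access. `FakeBar` is a hypothetical stand-in for `tqdm` that only records the total and the progress so far, and the 24,576-byte file is made up.

```python
class FakeBar:
    """Hypothetical stand-in for tqdm: records the total and the progress so far."""

    def __init__(self):
        self.total = None
        self.n = 0

    def update(self, n):
        self.n += n


class DLProgress(FakeBar):
    # Same hook as in the cell above, but built on FakeBar instead of tqdm
    last_block = 0

    def hook(self, block_num=1, block_size=1, total_size=None):
        self.total = total_size
        # block_num is cumulative, so report only the blocks read since the last call
        self.update((block_num - self.last_block) * block_size)
        self.last_block = block_num


bar = DLProgress()
# Replay urlretrieve's calls for a made-up 24576-byte file read in 8192-byte blocks:
# block 0 when the connection opens, then blocks 1, 2 and 3 as they arrive.
for block_num in range(4):
    bar.hook(block_num, 8192, 24576)
print(bar.n, "of", bar.total, "bytes")
```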


## Flower power

Here we'll be using VGGNet to classify images of flowers. To get the flower dataset, run the cell below. This dataset comes from the [TensorFlow inception tutorial](https://www.tensorflow.org/tutorials/image_retraining).

```python
import tarfile

dataset_folder_path = 'flower_photos'

# Reuses DLProgress, urlretrieve, isfile and isdir from the previous cell
if not isfile('flower_photos.tar.gz'):
    with DLProgress(unit='B', unit_scale=True, miniters=1, desc='Flowers Dataset') as pbar:
        urlretrieve(
            'http://download.tensorflow.org/example_images/flower_photos.tgz',
            'flower_photos.tar.gz',
            pbar.hook)

if not isdir(dataset_folder_path):
    # The with-block closes the archive automatically
    with tarfile.open('flower_photos.tar.gz') as tar:
        tar.extractall()
```

## ConvNet Codes

Below, we'll run through all the images in our dataset and get codes for each of them. That is, we'll run the images through the VGGNet convolutional layers and record the values of the first fully connected layer. We can then write these to a file for later when we build our own classifier.

Here we're using the `vgg16` module from `tensorflow_vgg`. The network takes images of size $224 \times 224 \times 3$ as input. Then it has 5 sets of convolutional layers. The network implemented here has this structure (copied from [the source code](https://github.com/machrisaa/tensorflow-vgg/blob/master/vgg16.py)):

```
self.conv1_1 = self.conv_layer(bgr, "conv1_1")
self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2")
self.pool1 = self.max_pool(self.conv1_2, 'pool1')

self.conv2_1 = self.conv_layer(self.pool1, "conv2_1")
self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2")
self.pool2 = self.max_pool(self.conv2_2, 'pool2')

self.conv3_1 = self.conv_layer(self.pool2, "conv3_1")
self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2")
self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3")
self.pool3 = self.max_pool(self.conv3_3, 'pool3')

self.conv4_1 = self.conv_layer(self.pool3, "conv4_1")
self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2")
self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3")
self.pool4 = self.max_pool(self.conv4_3, 'pool4')

self.conv5_1 = self.conv_layer(self.pool4, "conv5_1")
self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2")
self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3")
self.pool5 = self.max_pool(self.conv5_3, 'pool5')

self.fc6 = self.fc_layer(self.pool5, "fc6")
self.relu6 = tf.nn.relu(self.fc6)
```
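As a rough check on where the 4096-unit codes come from, here is a sketch of the shape arithmetic for this stack. It assumes the convolutions use 'SAME' padding, so they keep the spatial size, and that each 2x2, stride-2 max pool halves it. The per-block channel counts (64, 128, 256, 512, 512) come from the VGG16 configuration in the paper, not from the listing above.

```python
# (conv layers, output channels) per block; channel counts are an assumption taken
# from the VGG16 configuration, since the listing above does not show them.
blocks = [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)]

size, channels = 224, 3   # input image: 224 x 224 x 3
total_convs = 0
for n_convs, out_channels in blocks:
    # Assumed 'SAME'-padded convolutions keep height and width and set the channel count
    total_convs += n_convs
    channels = out_channels
    # The 2x2, stride-2 max pool at the end of the block halves height and width
    size //= 2
    print(f"after pool: {size}x{size}x{channels}")

flattened = size * size * channels  # length of the vector fed into fc6
print("fc6 input:", flattened, "-> fc6/relu6 output: 4096")
print("conv layers:", total_convs, "+ 3 fully connected = 16 weight layers")
```

Five halvings take 224 down to 7, so `fc6` receives a 7x7x512 = 25088-long input and maps it to the 4096 units we use as codes.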

So what we want are the values of the first fully connected layer, after being ReLUd (`self.relu6`). To build the network, we use

```
with tf.Session() as sess:
    vgg = vgg16.Vgg16()
    input_ = tf.placeholder(tf.float32, [None, 224, 224, 3])
    with tf.name_scope("content_vgg"):
        vgg.build(input_)
```

This creates the `vgg` object, then builds the graph with `vgg.build(input_)`. Then to get the values from the layer,

```
feed_dict = {input_: images}
codes = sess.run(vgg.relu6, feed_dict=feed_dict)
```

```python
import os

import numpy as np
import tensorflow as tf

from tensorflow_vgg import vgg16
from tensorflow_vgg import utils
```

```python
data_dir = 'flower_photos/'
contents = os.listdir(data_dir)
classes = [each for each in contents if os.path.isdir(data_dir + each)]
```

```python
classes
```

```
['dandelion', 'roses', 'daisy', 'tulips', 'sunflowers']
```

```python
classes
t = [[1,1], [2,2], [3,3], [4,4]]
np.concatenate(t)
```

```
array([1, 1, 2, 2, 3, 3, 4, 4])
```

```python
next(enumerate(os.listdir(data_dir+'roses')), 1)
```

```
(0, '15172358234_28706749a5.jpg')
```

Below I'm running images through the VGG network in batches.
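The loop in the next cell sends a batch through VGG whenever `batch_size` images have accumulated or when it reaches the last file of a class. As a result, a batch never mixes classes and no images are left behind. The hypothetical helper `batch_sizes` below replays only that condition to show the resulting batch sizes for one class. The file counts are made up.

```python
def batch_sizes(n_files, batch_size):
    """Hypothetical helper: sizes of the batches the VGG loop would run for one class."""
    sizes, batch = [], 0
    for ii in range(1, n_files + 1):   # like enumerate(files, 1): ii counts from 1
        batch += 1                     # stands in for batch.append(img)
        if ii % batch_size == 0 or ii == n_files:
            sizes.append(batch)        # the batch goes through the network here
            batch = 0                  # then a new batch starts
    return sizes


print(batch_sizes(38, 16))  # a class with 38 images
print(batch_sizes(32, 16))  # an exact multiple: no empty trailing batch
```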

> **Exercise:** Below, build the VGG network. Also get the codes from the first fully connected layer (make sure you get the ReLUd values).

```python
# Set the batch size higher if you can fit it in your GPU memory
batch_size = 16
codes_list = []
labels = []
batch = []

codes = None

with tf.Session() as sess:

    # TODO: Build the vgg network here

    vgg = vgg16.Vgg16()
    input_ = tf.placeholder(tf.float32, [None, 224, 224, 3])
    with tf.name_scope("content_vgg"):
        vgg.build(input_)

    for each in classes:
        print("Starting {} images".format(each))
        class_path = data_dir + each
        files = os.listdir(class_path)
        for ii, file in enumerate(files, 1):
            # Add images to the current batch
            # utils.load_image crops the input images for us, from the center
            img = utils.load_image(os.path.join(class_path, file))
            batch.append(img.reshape((1, 224, 224, 3)))
            labels.append(each)

            # Running the batch through the network to get the codes
            if ii % batch_size == 0 or ii == len(files):

                # Image batch to pass to VGG network
                images = np.concatenate(batch)

                # TODO: Get the values from the relu6 layer of the VGG network

                feed_dict = {input_: images}
                codes_batch = sess.run(vgg.relu6, feed_dict=feed_dict)

                # Here I'm building an array of the codes
                if codes is None:
                    codes = codes_batch
                else:
                    codes = np.concatenate((codes, codes_batch))

                # Reset to start building the next batch
                batch = []
                print('{} images processed'.format(ii))
```

```
/home/carnd/deep-learning/transfer-learning/tensorflow_vgg/vgg16.npy
npy file loaded
build model started
build model finished: 0s
Starting dandelion images
16 images processed
32 images processed
48 images processed
...
640 images processed
...
```

```python
# Write codes to file as raw float32 bytes (binary mode)
with open('codes', 'wb') as f:
    codes.tofile(f)

# Write labels to file, one label per line
import csv
with open('labels', 'w') as f:
    writer = csv.writer(f, delimiter='\n')
    writer.writerow(labels)
```

## Building the Classifier

Now that we have codes for all the images, we can build a simple classifier on top of them. The codes behave just like normal input into a simple neural network. Below I'm going to have you do most of the work.

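```python
# Read codes and labels from file
import csv

with open('labels') as f:
    reader = csv.reader(f, delimiter='\n')
    labels = np.array([each for each in reader if len(each) > 0]).squeeze()
with open('codes', 'rb') as f:
    # The file holds raw float32 values with no shape information, so restore the
    # shape: one row per label, with the row length inferred by -1
    codes = np.fromfile(f, dtype=np.float32)
    codes = codes.reshape((len(labels), -1))
```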
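`ndarray.tofile` writes only the raw values, with no dtype or shape header. That is why the cell above passes `dtype=np.float32` and reshapes to `(len(labels), -1)`. The sketch below mimics that round trip with the standard-library `array` module instead of NumPy, using hypothetical 4-dimensional codes. The reader has to supply the element type, and the row length is the total count divided by the number of labels. Reading the same bytes as 8-byte doubles, which is NumPy's default for `np.fromfile`, yields half as many numbers, and they no longer line up with the original rows.

```python
from array import array

# Hypothetical toy data: 3 images with 4-D codes (the real codes are 4096-D).
labels = ["daisy", "roses", "daisy"]
codes = [[0.0, 1.5, 0.0, 2.0],
         [3.0, 0.0, 0.0, 0.5],
         [0.0, 0.0, 4.25, 0.0]]

# Write: flatten row by row and keep only the raw 4-byte float values, like ndarray.tofile.
raw = array("f", [v for row in codes for v in row]).tobytes()

# Read: the reader must choose the element type ("f" is a 4-byte float, like float32)...
flat = array("f")
flat.frombytes(raw)
# ...and the row length, which is the total count divided by the number of labels
# (this is what the -1 in reshape((len(labels), -1)) works out).
row_len = len(flat) // len(labels)
restored = [list(flat[i * row_len:(i + 1) * row_len]) for i in range(len(labels))]

# Misreading the same bytes as 8-byte doubles gives half as many, meaningless, values.
wrong = array("d")
wrong.frombytes(raw)
print(row_len, len(flat), len(wrong))
```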

### Data prep

As usual, now we need to one-hot encode our labels and create validation/test sets. First up, creating our labels!

> **Exercise:** From scikit-learn, use [LabelBinarizer](http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelBinarizer.html) to create one-hot encoded vectors from the labels. 

```python
codes[:5]
```

```
array([[  0.        ,   0.        ,   0.        , ...,   0.        ,
         20.2116127 ,   0.        ],
       [  7.44284296,   0.        ,   0.        , ...,   0.        ,
          3.22595954,   9.75505924],
       [  0.        ,   0.        ,   2.33914661, ...,   0.        ,
          0.        ,   2.38483214],
       [  0.        ,   0.        ,   0.        , ...,   0.        ,
          0.53543526,   0.19504809],
       [  0.        ,   1.70371854,   0.        , ...,   0.        ,
          0.        ,   9.58523464]], dtype=float32)
```

```python
from sklearn import preprocessing as pp
lb = pp.LabelBinarizer()
lb.fit(labels)
labels_vecs = lb.transform(labels) # Your one-hot encoded labels array here
```
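One-hot encoding maps each label to a vector with a single 1 in that class's column. `LabelBinarizer` orders its columns by its sorted `classes_`, which is why dandelion images show up as `[0 1 0 0 0]` later in this notebook. Here is a minimal pure-Python sketch of the same idea, assuming alphabetical class order. Note that `LabelBinarizer` itself returns a single column when there are only two classes, and this sketch does not replicate that behavior.

```python
def one_hot(labels):
    """Minimal one-hot encoder; columns follow the sorted class names (assumed LabelBinarizer order)."""
    classes = sorted(set(labels))
    column = {c: i for i, c in enumerate(classes)}
    vecs = []
    for label in labels:
        row = [0] * len(classes)
        row[column[label]] = 1
        vecs.append(row)
    return classes, vecs


classes, vecs = one_hot(["dandelion", "roses", "daisy", "tulips", "sunflowers", "dandelion"])
print(classes)   # column order
print(vecs[0])   # the first dandelion image
```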

Now you'll want to create your training, validation, and test sets. An important thing to note here is that our labels and data aren't shuffled yet: the codes are still grouped by class. We'll want to shuffle our data so the validation and test sets contain data from all classes. Otherwise, you could end up with testing sets that are all one class. Typically, you'll also want to make sure that each smaller set has the same distribution of classes as the whole dataset. The easiest way to accomplish both these goals is to use [`StratifiedShuffleSplit`](http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedShuffleSplit.html) from scikit-learn.

You can create the splitter like so:
```
ss = StratifiedShuffleSplit(n_splits=1, test_size=0.2)
```
Then split the data with 
```
splitter = ss.split(x, y)
```

`ss.split` returns a generator that yields one `(train_indices, test_indices)` pair of index arrays per split. You can use the indices to index into the arrays and get the split sets. Because it's a generator, you either need to iterate over it or use `next(splitter)` to get the indices. Be sure to read the [documentation](http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedShuffleSplit.html) and the [user guide](http://scikit-learn.org/stable/modules/cross_validation.html#random-permutations-cross-validation-a-k-a-shuffle-split).
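To make stratification concrete, here is a simplified pure-Python sketch. It groups the indices by class, shuffles each group, and moves the same fraction of every group into the test set. This is not scikit-learn's implementation, since `StratifiedShuffleSplit` allocates per-class counts more carefully when they don't divide evenly. The `stratified_split` helper and the 50/30/20 toy labels are hypothetical.

```python
import random
from collections import Counter


def stratified_split(labels, test_size=0.2, seed=0):
    """Hypothetical helper: (train_idx, test_idx) with each class split in the same proportion."""
    rng = random.Random(seed)
    by_class = {}
    for i, label in enumerate(labels):
        by_class.setdefault(label, []).append(i)

    train_idx, test_idx = [], []
    for idx in by_class.values():
        rng.shuffle(idx)                      # shuffle within the class
        n_test = round(len(idx) * test_size)  # same fraction from every class
        test_idx.extend(idx[:n_test])
        train_idx.extend(idx[n_test:])

    # Shuffle again so the classes are interleaved rather than in blocks
    rng.shuffle(train_idx)
    rng.shuffle(test_idx)
    return train_idx, test_idx


labels = ["daisy"] * 50 + ["roses"] * 30 + ["tulips"] * 20
train_idx, test_idx = stratified_split(labels)
print(Counter(labels[i] for i in test_idx))
```

The test set keeps the 50/30/20 proportions of the full data (10, 6, and 4 examples), which a plain shuffle would only match approximately.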

> **Exercise:** Use StratifiedShuffleSplit to split the codes and labels into training, validation, and test sets.

```python
from sklearn.model_selection import StratifiedShuffleSplit

ss = StratifiedShuffleSplit(n_splits=1, test_size=0.2)
splitter = ss.split(codes, labels_vecs)
train_index, test_index = next(splitter)
print(len(train_index), len(test_index))
```

```
2936 734
```


```python
train_x, train_y = codes[train_index], labels_vecs[train_index]
# Split the held-out 20% in half: the first half for validation, the second for testing
other_x, other_y = codes[test_index], labels_vecs[test_index]
mid_other = len(other_x) // 2
val_x, val_y = other_x[:mid_other], other_y[:mid_other]
test_x, test_y = other_x[mid_other:], other_y[mid_other:]
```

```python
print("Train shapes (x, y):", train_x.shape, train_y.shape)
print("Validation shapes (x, y):", val_x.shape, val_y.shape)
print("Test shapes (x, y):", test_x.shape, test_y.shape)
```

```
Train shapes (x, y): (2936, 4096) (2936, 5)
Validation shapes (x, y): (367, 4096) (367, 5)
Test shapes (x, y): (367, 4096) (367, 5)
```


If you did it right, you should see these sizes for the training, validation, and test sets:

```
Train shapes (x, y): (2936, 4096) (2936, 5)
Validation shapes (x, y): (367, 4096) (367, 5)
Test shapes (x, y): (367, 4096) (367, 5)
```
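These sizes follow from simple arithmetic. The split above holds 2936 + 734 = 3670 codes. A `test_size` of 0.2 sets aside 734 of them; this sketch assumes the test count is rounded up, as scikit-learn does. Halving the held-out indices with `len(other_x) // 2` then gives 367 codes for validation and 367 for test. A quick check:

```python
import math

n_samples = 2936 + 734                 # 3670 codes in total
n_test = math.ceil(0.2 * n_samples)    # test_size=0.2, rounded up (assumed, as in scikit-learn)
n_train = n_samples - n_test
mid_other = n_test // 2                # the held-out set is halved as in the split cell
n_val, n_test_final = mid_other, n_test - mid_other
print(n_train, n_val, n_test_final)
```

If the held-out count were odd, the extra code would land in the test half, because `other_x[mid_other:]` takes everything after the midpoint.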

### Classifier layers

Once you have the convolutional codes, you just need to build a classifier from some fully connected layers. You use the codes as the inputs and the image labels as targets. Otherwise, the classifier is a typical neural network.

> **Exercise:** With the codes and labels loaded, build the classifier. Consider the codes as your inputs; each of them is a 4096-dimensional vector. You'll want to use a hidden layer and an output layer as your classifier. Remember that the output layer needs to have one unit for each class and a softmax activation function. Use the cross entropy to calculate the cost.

```python
inputs_ = tf.placeholder(tf.float32, shape=[None, codes.shape[1]])
# Float one-hot labels, matching the dtype of the logits for the cross-entropy op
labels_ = tf.placeholder(tf.float32, shape=[None, labels_vecs.shape[1]])

# TODO: Classifier layers and operations

# Hidden layer with 256 units (fully_connected applies a ReLU by default)
fc = tf.contrib.layers.fully_connected(inputs_, 256)

# Output layer: one linear unit per class; the softmax is applied inside the loss
logits = tf.contrib.layers.fully_connected(fc, labels_vecs.shape[1], activation_fn=None)

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels_))

optimizer = tf.train.AdamOptimizer().minimize(cost)

# Operations for validation/test accuracy
predicted = tf.nn.softmax(logits)
correct_pred = tf.equal(tf.argmax(predicted, 1), tf.argmax(labels_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
```
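For intuition about the cost and accuracy operations above, here is the same math on plain Python lists. Softmax turns logits into probabilities, and the cross entropy against a one-hot target is the negative log-probability of the true class. `tf.reduce_mean` averages that over the batch, and accuracy compares the argmax of the prediction with the argmax of the label. The logits and labels are hypothetical, and this is only a sketch of the math, not of how TensorFlow computes it internally.

```python
import math


def softmax(logits):
    m = max(logits)                           # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs, one_hot):
    # With a one-hot target, only the true class's log-probability contributes
    return -sum(t * math.log(p) for t, p in zip(one_hot, probs))


def argmax(values):
    return max(range(len(values)), key=values.__getitem__)


# Hypothetical logits for a batch of two images over the five flower classes
batch_logits = [[2.0, 1.0, 0.1, 0.0, -1.0],
                [0.5, 0.2, 3.0, 0.1, 0.0]]
batch_labels = [[1, 0, 0, 0, 0],
                [0, 1, 0, 0, 0]]

probs = [softmax(z) for z in batch_logits]
cost = sum(cross_entropy(p, y) for p, y in zip(probs, batch_labels)) / len(probs)  # like reduce_mean
correct = [argmax(p) == argmax(y) for p, y in zip(probs, batch_labels)]
accuracy = sum(correct) / len(correct)
print(f"cost {cost:.4f}, accuracy {accuracy}")
```

Because softmax is monotonic, the argmax of `predicted` equals the argmax of `logits`, so the softmax is not strictly needed for the accuracy op.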

### Batches!

Here is just a simple way to do batches. I've written it so that it includes all the data. Sometimes you'll throw out some data at the end to make sure you have full batches. Here I just extend the last batch to include the remaining data.

```python
def get_batches(x, y, n_batches=10):
    """ Return a generator that yields batches from arrays x and y. """
    batch_size = len(x)//n_batches

    for ii in range(0, n_batches*batch_size, batch_size):
        # If we're not on the last batch, grab data with size batch_size
        if ii != (n_batches-1)*batch_size:
            X, Y = x[ii: ii+batch_size], y[ii: ii+batch_size]
        # On the last batch, grab the rest of the data
        else:
            X, Y = x[ii:], y[ii:]
        # I love generators
        yield X, Y
```
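To see what `get_batches` yields, the sketch below runs the same function on plain lists, where slicing works the same way. With 23 items and `n_batches=5`, the batch size is 23 // 5 = 4, and the last batch absorbs the three leftover items. With the 2936 training codes and the default `n_batches=10`, nine batches hold 293 codes and the last holds 299.

```python
def get_batches(x, y, n_batches=10):
    """Same logic as the cell above; works on plain lists as well as NumPy arrays."""
    batch_size = len(x) // n_batches
    for ii in range(0, n_batches * batch_size, batch_size):
        if ii != (n_batches - 1) * batch_size:
            X, Y = x[ii: ii + batch_size], y[ii: ii + batch_size]
        else:
            # The last batch takes everything that is left
            X, Y = x[ii:], y[ii:]
        yield X, Y


items = list(range(23))
sizes = [len(bx) for bx, _ in get_batches(items, items, n_batches=5)]
print(sizes)

train = list(range(2936))  # stand-in for the 2936 training codes
train_sizes = [len(bx) for bx, _ in get_batches(train, train)]
print(train_sizes)
```

One caveat: if `x` has fewer items than `n_batches`, the integer division gives a batch size of 0 and `range` raises a `ValueError`.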

### Training

Here, we'll train the network.

> **Exercise:** So far we've been providing the training code for you. Here, I'm going to give you a bit more of a challenge and have you write the code to train the network. Of course, you'll be able to see my solution if you need help. Use the `get_batches` function I wrote before to get your batches like `for x, y in get_batches(train_x, train_y)`. Or write your own!
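If you want to check the shape of a training loop before writing the TensorFlow version, here is a toy, pure-Python analog. It loops over epochs, pulls batches from `get_batches`, computes the batch-mean softmax cross-entropy, and takes a gradient step. It is deliberately simplified and is not the notebook's model. It uses hypothetical 2-D codes for two classes, a softmax-regression classifier with no hidden layer, and plain gradient descent instead of Adam. In the notebook itself, the inner step would presumably be a `sess.run` of `[cost, optimizer]` with `feed_dict={inputs_: x, labels_: y}`, after running `tf.global_variables_initializer()`.

```python
import math


def get_batches(x, y, n_batches=10):
    # Same scheme as the notebook's get_batches: the last batch absorbs the remainder
    batch_size = len(x) // n_batches
    for ii in range(0, n_batches * batch_size, batch_size):
        if ii != (n_batches - 1) * batch_size:
            yield x[ii:ii + batch_size], y[ii:ii + batch_size]
        else:
            yield x[ii:], y[ii:]


def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [v / s for v in exps]


def logits(W, b, x):
    # One linear output per class (no hidden layer in this toy model)
    return [sum(w * xi for w, xi in zip(W[k], x)) + b[k] for k in range(len(b))]


# Hypothetical toy "codes" (2-D instead of 4096-D) with one-hot labels for 2 classes
train_x = [[1.0, 0.1], [0.9, 0.0], [0.8, 0.2], [0.1, 1.0], [0.0, 0.9], [0.2, 0.8]]
train_y = [[1, 0], [1, 0], [1, 0], [0, 1], [0, 1], [0, 1]]

n_classes, n_features = 2, 2
W = [[0.0] * n_features for _ in range(n_classes)]
b = [0.0] * n_classes
lr, epochs = 0.5, 50
losses = []

for epoch in range(epochs):
    epoch_loss = 0.0
    for xb, yb in get_batches(train_x, train_y, n_batches=3):
        gW = [[0.0] * n_features for _ in range(n_classes)]
        gb = [0.0] * n_classes
        for x, y in zip(xb, yb):
            p = softmax(logits(W, b, x))
            epoch_loss -= math.log(p[y.index(1)])
            for k in range(n_classes):
                d = (p[k] - y[k]) / len(xb)  # gradient of the batch-mean cross-entropy w.r.t. logit k
                gb[k] += d
                for j in range(n_features):
                    gW[k][j] += d * x[j]
        for k in range(n_classes):  # plain gradient-descent step (the notebook uses Adam)
            b[k] -= lr * gb[k]
            for j in range(n_features):
                W[k][j] -= lr * gW[k][j]
    losses.append(epoch_loss / len(train_x))

predictions = [max(range(n_classes), key=logits(W, b, x).__getitem__) for x in train_x]
accuracy = sum(pred == y.index(1) for pred, y in zip(predictions, train_y)) / len(train_x)
print(f"final loss {losses[-1]:.4f}, accuracy {accuracy:.2f}")
```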

```python
# Sanity check: print the first code vector and one-hot label of each batch.
# With 3670 codes and n_batches=2935, every batch except the last holds a single code.
for x, y in get_batches(codes, labels_vecs, n_batches=2935):
    print(x[0])
    print(y[0])
```

```
[  0.          0.          0.        ...,   0.         20.2116127   0.       ]
[0 1 0 0 0]
[ 7.44284296  0.          0.         ...,  0.          3.22595954
  9.75505924]
[0 1 0 0 0]
[ 0.          0.          2.33914661 ...,  0.          0.          2.38483214]
[0 1 0 0 0]
...
```
[0 0 0 0 1]
[  0.           6.1687727    6.94321728 ...,   0.          11.42752743   0.        ]
[0 0 0 0 1]
[  0.          20.72277451   0.         ...,   0.           4.44229174   0.        ]
[0 0 0 0 1]
[  0.           0.          19.41306686 ...,   0.           7.04669809   0.        ]
[0 0 0 0 1]
[ 13.14661

[ 1.77367091  0.          5.96120358 ...,  0.          0.          0.        ]
[0 0 0 0 1]
[ 16.49816322   0.           0.         ...,   0.           0.           0.        ]
[0 0 0 0 1]
[ 0.64627445  6.30214453  0.         ...,  0.          0.          0.        ]
[0 0 0 0 1]
[  0.           9.49120712  15.33527184 ...,   0.           0.           4.81128168]
[0 0 0 0 1]
[ 0.          1.60546911  0.         ...,  0.          4.22051239  0.        ]
[0 0 0 0 1]
[ 0.13722029  0.          0.         ...,  0.          0.          3.5260644 ]
[0 0 0 0 1]
[ 5.22468996  2.90198898  0.         ...,  0.          0.          0.        ]
[0 0 0 0 1]
[ 5.53313684  0.          9.3206377  ...,  0.          5.32548666
  3.23818445]
[0 0 0 0 1]
[ 7.2706461   0.          0.         ...,  0.          1.56273222
  1.11439812]
[0 0 0 0 1]
[ 0.  0.  0. ...,  0.  0.  0.]
[0 0 0 0 1]
[  0.           0.          14.37492847 ...,   0.          15.58024025
   0.78309286]
[0 0 0 0 1]
[ 0.          0.          

In [20]:
import os

# Hyperparameters
epochs = 20
iteration = 0
number_of_batches = 10

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for epoch in range(epochs):
        # Loop over all batches
        for batch_x, batch_y in get_batches(train_x, train_y, n_batches=number_of_batches):
            feed = {inputs_: batch_x, labels_: batch_y}
            # Fetch the loss into a new name: assigning it to `cost` would replace
            # the cost tensor with a float, and the next sess.run call would fail
            loss, _ = sess.run([cost, optimizer], feed_dict=feed)
            iteration += 1
            print("Epoch: {}/{}".format(epoch+1, epochs),
                  "Iteration: {}".format(iteration),
                  "Training loss: {:.5f}".format(loss))

            # Check validation accuracy every 5 iterations
            if iteration % 5 == 0:
                feed = {inputs_: val_x,
                        labels_: val_y}
                val_acc = sess.run(accuracy, feed_dict=feed)
                print("Epoch: {}/{}".format(epoch+1, epochs),
                      "Iteration: {}".format(iteration),
                      "Validation Acc: {:.4f}".format(val_acc))

    # Save the model, making sure the checkpoint directory exists first
    os.makedirs("checkpoints", exist_ok=True)
    saver.save(sess, "checkpoints/flowers.ckpt")
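The training loop relies on the `get_batches` helper defined earlier in the notebook. As a reminder of what it has to do, here is a minimal sketch of a generator with the same call shape, named `get_batches_sketch` so it does not replace the real helper. It assumes the data is split into `n_batches` slices of `len(x) // n_batches` rows, with any leftover rows folded into the last batch; the notebook's own definition may differ in detail.

```python
import numpy as np

def get_batches_sketch(x, y, n_batches=10):
    """Yield (x, y) mini-batches by splitting the data into n_batches slices.

    Illustrative stand-in: every batch has len(x) // n_batches rows, except
    the last one, which also absorbs any leftover rows.
    """
    batch_size = len(x) // n_batches
    for i in range(n_batches):
        start = i * batch_size
        # The final batch runs to the end of the arrays so no rows are dropped
        end = (i + 1) * batch_size if i < n_batches - 1 else len(x)
        yield x[start:end], y[start:end]

# Usage: 23 toy samples split into 5 batches
toy_codes = np.zeros((23, 4096), dtype=np.float32)
toy_labels = np.zeros((23, 5), dtype=np.int64)
sizes = [len(bx) for bx, by in get_batches_sketch(toy_codes, toy_labels, n_batches=5)]
print(sizes)  # [4, 4, 4, 4, 7]
```

Folding the remainder into the last batch means every training example is seen once per epoch, at the cost of one slightly larger batch.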

### Testing

Below, we compute the test accuracy. You can also see the predictions the classifier returns for individual images.

In [None]:
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
    
    feed = {inputs_: test_x,
            labels_: test_y}
    test_acc = sess.run(accuracy, feed_dict=feed)
    print("Test accuracy: {:.4f}".format(test_acc))
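The `accuracy` tensor is defined earlier in the notebook. Assuming it follows the common pattern of comparing the arg-max of the predicted class probabilities with the arg-max of the one-hot labels, the same metric can be reproduced in NumPy, which is handy for checking predictions pulled out of a session. The function name and toy arrays below are illustrative only.

```python
import numpy as np

def accuracy_from_probs(probs, one_hot_labels):
    """Fraction of rows whose highest-probability class matches the one-hot label."""
    predicted_class = np.argmax(probs, axis=1)
    true_class = np.argmax(one_hot_labels, axis=1)
    return np.mean(predicted_class == true_class)

# Toy example with 5 classes: the first two rows are right, the third is wrong
toy_probs = np.array([[0.7, 0.1, 0.1, 0.05, 0.05],
                      [0.1, 0.1, 0.1, 0.1, 0.6],
                      [0.2, 0.5, 0.1, 0.1, 0.1]])
toy_labels = np.array([[1, 0, 0, 0, 0],
                       [0, 0, 0, 0, 1],
                       [0, 0, 0, 0, 1]])
print("{:.4f}".format(accuracy_from_probs(toy_probs, toy_labels)))  # 0.6667
```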

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
# scipy.ndimage.imread was removed in SciPy 1.2; matplotlib's imread reads JPEGs via Pillow
from matplotlib.pyplot import imread

Below, feel free to pick your own images and see which flower the trained classifier predicts for each one.

In [None]:
test_img_path = 'flower_photos/roses/10894627425_ec76bbc757_n.jpg'
test_img = imread(test_img_path)
plt.imshow(test_img)

In [None]:
# Run this cell if you don't have a vgg graph built
if 'vgg' in globals():
    print('"vgg" object already exists.  Will not create again.')
else:
    #create vgg
    with tf.Session() as sess:
        input_ = tf.placeholder(tf.float32, [None, 224, 224, 3])
        vgg = vgg16.Vgg16()
        vgg.build(input_)

In [None]:
# Step 1: run the image through VGG16 to get its 4096-dimensional relu6 code
with tf.Session() as sess:
    img = utils.load_image(test_img_path)
    img = img.reshape((1, 224, 224, 3))

    feed_dict = {input_: img}
    code = sess.run(vgg.relu6, feed_dict=feed_dict)

# Step 2: restore the trained classifier and predict class probabilities from the code
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))

    feed = {inputs_: code}
    prediction = sess.run(predicted, feed_dict=feed).squeeze()
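The cell above works in two steps: VGG16 turns the 224×224×3 image into a 4096-dimensional `relu6` code, and the restored classifier maps that code to one probability per flower class. A minimal NumPy sketch of the second step makes the shapes explicit. It assumes a classifier with one hidden layer of 256 ReLU units and a softmax output over 5 classes; the layer size and the random, untrained weights are hypothetical, not the notebook's trained parameters.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(logits):
    # Subtract the row max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def classify_codes(codes, w1, b1, w2, b2):
    """Map (batch, 4096) VGG codes to (batch, n_classes) probabilities."""
    hidden = np.maximum(codes @ w1 + b1, 0.0)  # fully connected layer + ReLU
    return softmax(hidden @ w2 + b2)

# Hypothetical, untrained parameters: 4096 -> 256 -> 5
w1 = rng.normal(scale=0.01, size=(4096, 256))
b1 = np.zeros(256)
w2 = rng.normal(scale=0.01, size=(256, 5))
b2 = np.zeros(5)

# Stand-in for `code`: one non-negative 4096-d vector, like the vgg.relu6 output
toy_code = np.maximum(rng.normal(size=(1, 4096)), 0.0)
toy_prediction = classify_codes(toy_code, w1, b1, w2, b2).squeeze()
print(toy_prediction.shape, toy_prediction.sum())  # shape (5,); probabilities sum to 1
```

The `toy_` prefixes keep the sketch from overwriting the notebook's real `code` and `prediction` variables, which the plotting cells below depend on.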

In [None]:
plt.imshow(test_img)

In [None]:
plt.barh(np.arange(5), prediction)
_ = plt.yticks(np.arange(5), lb.classes_)
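Assuming `lb` is the scikit-learn `LabelBinarizer` fitted earlier, `lb.classes_` lists the class names in the same order as the columns of `prediction`, so the bar chart can also be summarised as a ranked list. A small sketch, using a hypothetical probability vector and the five folder names of the `flower_photos` dataset (assumed to be sorted alphabetically, which is how `LabelBinarizer` stores `classes_`):

```python
import numpy as np

def top_k(probs, class_names, k=3):
    """Return the k most likely (class name, probability) pairs, best first."""
    order = np.argsort(probs)[::-1][:k]
    return [(str(class_names[i]), float(probs[i])) for i in order]

# Hypothetical values standing in for `prediction` and `lb.classes_`
toy_classes = np.array(['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'])
toy_prediction = np.array([0.03, 0.01, 0.85, 0.02, 0.09])

for name, p in top_k(toy_prediction, toy_classes):
    print("{:<10s} {:.2f}".format(name, p))
```

In the notebook itself, the same call would be `top_k(prediction, lb.classes_)`.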