##### Copyright 2018 The TensorFlow Authors.

In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Hub with Keras

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/tutorials/images/hub_with_keras"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/images/hub_with_keras.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/docs/blob/master/site/en/tutorials/images/hub_with_keras.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

[TensorFlow Hub](http://tensorflow.org/hub) is a way to share pretrained model components. See the [TensorFlow Hub site](https://tfhub.dev/) for a searchable listing of pretrained models.

This tutorial demonstrates:

1. How to use TensorFlow Hub with `tf.keras`.
1. How to do image classification using TensorFlow Hub.
1. How to do simple transfer learning.

## Setup

### Imports

In [0]:
!pip install tensorflow_hub

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow_hub as hub

from tensorflow.keras import layers

tf.VERSION

### Dataset

For this example, we'll use the TensorFlow flowers dataset:

In [0]:
import pathlib
data_root = tf.keras.utils.get_file(
  'flower_photos',
  'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
  untar=True)

data_root = pathlib.Path(data_root)
print(data_root)

In [0]:
IMAGE_SIZE = 224  # the mobilenet_v2_100_224 modules used below take 224x224 images

The simplest way to load this data into our model is with `tf.keras.preprocessing.image.ImageDataGenerator`.

All of TensorFlow Hub's image modules expect float inputs in the `[0, 1]` range. Use the `ImageDataGenerator`'s `rescale` parameter to achieve this.

In [0]:
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255)
image_data = image_generator.flow_from_directory(str(data_root), target_size=(IMAGE_SIZE, IMAGE_SIZE))
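As a rough illustration of what `rescale=1/255` does: the generator multiplies every pixel value by the given factor, so 8-bit values in `0..255` end up in `[0, 1]`. The helper `rescale_pixels` below is hypothetical, a plain-Python stand-in for that arithmetic rather than the Keras implementation.

```python
def rescale_pixels(pixels, scale=1/255):
    """Multiply each 8-bit pixel value by `scale` (hypothetical helper).

    Mimics the effect of ImageDataGenerator(rescale=scale) on raw pixel data.
    """
    return [p * scale for p in pixels]

# A few representative 8-bit channel values: black, a mid-dark value, white.
raw = [0, 51, 255]
scaled = rescale_pixels(raw)
print(scaled)
```

Every output lies in `[0, 1]` (up to floating-point rounding), which is the input range the Hub image modules expect.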

The resulting object is an iterator that returns `image_batch, label_batch` pairs.

In [0]:
for image_batch,label_batch in image_data:
  print(image_batch.shape)
  print(label_batch.shape)
  break
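`flow_from_directory` infers the classes from the subdirectory names, assigns indices in sorted (alphanumeric) order, and, with its default `class_mode='categorical'`, returns one-hot labels, so `label_batch` has shape `(batch_size, num_classes)`. The sketch below reproduces that bookkeeping in plain Python; the helper names are hypothetical, and the subdirectory names are assumed to be the five flower classes in this dataset.

```python
def infer_class_indices(subdirs):
    """Map each class subdirectory name to an integer index, in sorted order."""
    return {name: i for i, name in enumerate(sorted(subdirs))}

def one_hot(index, num_classes):
    """Return a one-hot vector with a 1.0 at position `index`."""
    vec = [0.0] * num_classes
    vec[index] = 1.0
    return vec

# Assumed subdirectory names; the order they are listed in does not matter.
subdirs = ["tulips", "daisy", "sunflowers", "roses", "dandelion"]
class_indices = infer_class_indices(subdirs)

# A hypothetical batch of three images, identified by their class names.
batch_classes = ["roses", "daisy", "tulips"]
label_batch = [one_hot(class_indices[c], len(class_indices)) for c in batch_classes]
print(class_indices)
print(label_batch)
```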

## An ImageNet classifier

### Download the classifier

Use `hub.Module` to load a MobileNet, and `tf.keras.layers.Lambda` to wrap it as a Keras layer.

In [0]:
mobilenet_layer = layers.Lambda(hub.Module("https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/classification/2"))
mobilenet_classifier = tf.keras.Sequential([mobilenet_layer])

### Run it on a batch of images

TensorFlow Hub requires that you manually initialize its variables.

In [0]:
import tensorflow.keras.backend as K

# Run the variable initializer in the same TF 1.x session that Keras uses.
sess = K.get_session()
init = tf.global_variables_initializer()

sess.run(init)

Now run the classifier on the image batch.

In [0]:
result_batch = mobilenet_classifier.predict(image_batch)

### Decode the predictions

Fetch the ImageNet labels and decode the predictions:

In [0]:
import numpy as np

# Download the ImageNet label list (one class name per line).
labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
with open(labels_path) as f:
  imagenet_labels = np.array(f.read().splitlines())

# Pick the highest-scoring class for each image and look up its name.
labels_batch = imagenet_labels[np.argmax(result_batch, axis=-1)]
labels_batch
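Decoding comes down to taking, for each image, the index of the largest score and using it to index the label list, which is what `imagenet_labels[np.argmax(result_batch, axis=-1)]` does. The plain-Python sketch below shows the same step; the function names, labels, and scores are hypothetical.

```python
def argmax(row):
    """Index of the largest value in a list of scores (first one wins on ties)."""
    return max(range(len(row)), key=lambda i: row[i])

def decode_predictions(score_batch, labels):
    """Map each row of scores to the label at its argmax."""
    return [labels[argmax(row)] for row in score_batch]

# Hypothetical label list and scores for a batch of two images.
labels = ["daisy", "bee", "umbrella"]
scores = [[2.5, 0.1, 0.3],
          [-1.0, 4.2, 0.0]]
print(decode_predictions(scores, labels))
```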

Now check how these predictions line up with the images:

In [0]:
import matplotlib.pyplot as plt

plt.figure(figsize=(10,9))
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  plt.title(labels_batch[n])
  plt.axis('off')
_ = plt.suptitle("ImageNet predictions")

See the `LICENSE.txt` file in the downloaded `flower_photos` directory for image attributions.

The results are far from perfect, but reasonable considering that these are not the classes the model was trained on (except "daisy").

## Simple transfer learning

### Download the headless model

TensorFlow Hub also distributes models without the top classification layer. These "headless" models make transfer learning easy: reuse the pretrained feature extractor and train only a new classifier on top.

In [0]:
mobilenet_features_module = hub.Module("https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/2")
mobilenet_features = layers.Lambda(mobilenet_features_module)
mobilenet_features.trainable = False
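In general Keras terms, setting `trainable = False` on a layer moves its weights from `trainable_weights` to `non_trainable_weights`, so the optimizer does not update them. The toy classes below are hypothetical, plain-Python stand-ins that mimic that bookkeeping for one SGD step; they are not the Keras implementation, and they only illustrate the general idea of freezing a feature extractor while training a head.

```python
class ToyLayer:
    """Hypothetical stand-in for a Keras layer: a name, weights, and a `trainable` flag."""

    def __init__(self, name, weights, trainable=True):
        self.name = name
        self.weights = weights
        self.trainable = trainable

    @property
    def trainable_weights(self):
        return list(self.weights) if self.trainable else []

    @property
    def non_trainable_weights(self):
        return [] if self.trainable else list(self.weights)

def sgd_step(layers, grads, lr=0.1):
    """Apply one gradient-descent step, but only to layers marked trainable."""
    for layer in layers:
        if layer.trainable:
            layer.weights = [w - lr * g for w, g in zip(layer.weights, grads[layer.name])]

# A frozen "feature extractor" and a trainable "head", with made-up gradients.
features = ToyLayer("features", [1.0, 2.0], trainable=False)
head = ToyLayer("head", [0.5, -0.5])
grads = {"features": [1.0, 1.0], "head": [1.0, -1.0]}

sgd_step([features, head], grads)
print(features.weights, head.weights)
```

The frozen layer keeps its weights, while the head moves against its gradient.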

Again, the module's variables must be initialized.

In [0]:
init = tf.global_variables_initializer()
sess.run(init)

### Attach a classification head

Now wrap the Hub layer in a `tf.keras.Sequential` model, and add a new classification layer with one output per flower class.

In [0]:
model = tf.keras.Sequential([
  mobilenet_features,
  layers.Dense(image_data.num_classes, activation='softmax')
])
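For each image, the new `Dense(..., activation='softmax')` head computes `softmax(features @ W + b)`: one logit per class, turned into probabilities that sum to 1. The plain-Python sketch below shows that forward pass for a batch; the 3-dimensional feature vectors and the 2-class weights are hypothetical (the real MobileNet feature vector is much longer, and the head here has one output per flower class).

```python
import math

def dense_softmax(features, weights, bias):
    """Compute softmax(features @ weights + bias) for one feature vector.

    features: list of length d; weights: d x k nested list; bias: list of length k.
    """
    k = len(bias)
    logits = [sum(features[i] * weights[i][j] for i in range(len(features))) + bias[j]
              for j in range(k)]
    m = max(logits)  # subtract the max logit for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical batch of two 3-dimensional feature vectors and a 2-class head.
feature_batch = [[1.0, 0.0, 2.0], [0.5, 0.5, 0.5]]
W = [[0.2, -0.1], [0.0, 0.3], [0.1, 0.1]]
b = [0.0, 0.0]

probs = [dense_softmax(f, W, b) for f in feature_batch]
print(len(probs), len(probs[0]))  # batch size, number of classes
```

This is why the model's output has shape `(batch_size, num_classes)`.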

Test-run a single batch to check that the result comes back with the expected shape, `(batch_size, num_classes)`.

In [0]:
result = model.predict(image_batch)
result.shape

### Train the model

Use `compile` to configure the training process:

In [0]:
model.compile(
  optimizer=tf.train.AdamOptimizer(), 
  loss='categorical_crossentropy',
  metrics=['accuracy'])
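With one-hot labels, `categorical_crossentropy` is the batch mean of `-sum(y_true * log(y_pred))`, which reduces to minus the log-probability the model assigns to the correct class. The plain-Python sketch below computes it; the clamp to `eps` is an assumption standing in for the clipping Keras applies to avoid `log(0)`, and the example labels and predictions are made up.

```python
import math

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Batch mean of -sum(y_true * log(y_pred)), with predictions clamped below at eps."""
    losses = []
    for t, p in zip(y_true, y_pred):
        losses.append(-sum(ti * math.log(max(pi, eps)) for ti, pi in zip(t, p)))
    return sum(losses) / len(losses)

# One-hot labels for two images, plus two hypothetical sets of predictions.
y_true = [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]
confident = [[0.05, 0.90, 0.05], [0.80, 0.10, 0.10]]
unsure = [[0.34, 0.33, 0.33], [0.34, 0.33, 0.33]]

print(categorical_crossentropy(y_true, confident))
print(categorical_crossentropy(y_true, unsure))
```

Confident, correct predictions give a lower loss than near-uniform ones, which is what training pushes toward.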

Now use the `.fit` method to train the model.

Normally you would set the number of steps per epoch so that each epoch covers the whole dataset, but to keep this example short we'll train just long enough to see that the loss is decreasing.

In [0]:
# For a full pass over the data, use steps_per_epoch=image_data.samples//image_data.batch_size
model.fit((item for item in image_data), epochs=1, steps_per_epoch=50)
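The Keras image iterator keeps yielding batches indefinitely, so `steps_per_epoch` is what defines where an epoch ends. The `image_data.samples // image_data.batch_size` expression in the comment above uses floor division, which leaves out the final partial batch; rounding up covers every image. The hypothetical helper below shows both choices with made-up sizes (`flow_from_directory` defaults to `batch_size=32`).

```python
import math

def steps_per_epoch(num_samples, batch_size, drop_remainder=False):
    """Number of batches per epoch (hypothetical helper).

    drop_remainder=True matches floor division; otherwise round up so the
    last, smaller batch is included.
    """
    if drop_remainder:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)

# Hypothetical dataset of 1000 images with the default batch size of 32.
print(steps_per_epoch(1000, 32))        # includes the final partial batch
print(steps_per_epoch(1000, 32, True))  # full batches only
```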

Even after just a few training iterations, we can already see that the model is making progress on the task.

### Decode the predictions

To redo the plot from before, first get the ordered list of class names:

In [0]:
label_names = sorted(image_data.class_indices.items(), key=lambda pair:pair[1])
label_names = np.array([key.title() for key, value in label_names])
label_names
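Ordering matters because position `i` of `label_names` must line up with output unit `i` of the classification head, and `class_indices` maps names to those unit indices. An equivalent way to build the list is to invert the mapping directly, as in the sketch below; `index_to_name` is a hypothetical helper, and the mapping is assumed to look like `image_data.class_indices` for the five flower classes.

```python
def index_to_name(class_indices):
    """Build a list whose position i holds the title-cased class with index i."""
    names = [None] * len(class_indices)
    for name, idx in class_indices.items():
        names[idx] = name.title()
    return names

# Hypothetical mapping shaped like image_data.class_indices.
class_indices = {"roses": 2, "daisy": 0, "tulips": 4, "dandelion": 1, "sunflowers": 3}
label_names = index_to_name(class_indices)

# Position i now corresponds to output unit i of the classifier.
predicted_index = 3
print(label_names[predicted_index])
```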

Run the image batch through the model and convert the indices to class names.

In [0]:
result_batch = model.predict(image_batch)

labels_batch = label_names[np.argmax(result_batch, axis=-1)]
labels_batch

Plot the result:

In [0]:
plt.figure(figsize=(10,9))
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  plt.title(labels_batch[n])
  plt.axis('off')
_ = plt.suptitle("Model predictions")