<a href="https://colab.research.google.com/github/xuziyue/tensorflow-models/blob/main/flowers_transfer_learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
import tensorflow as tf

In [8]:
import numpy as np
import matplotlib.pyplot as plt

import tensorflow_hub as hub
import tensorflow_datasets as tfds

from tensorflow.keras import layers

In [9]:
import logging
# Raise TensorFlow's logger threshold to ERROR so that lower-severity
# messages (INFO / WARNING) do not clutter the notebook output.
logger = tf.get_logger()
logger.setLevel(logging.ERROR)

In [10]:
# tf_flowers only ships a single 'train' split, so carve it up by percentage:
# the first 70% is used for training and the remaining 30% for validation.
# `as_supervised=True` makes each element an (image, label) pair.
flower_splits, dataset_info = tfds.load(
    name='tf_flowers',
    split=['train[:70%]', 'train[70%:]'],
    as_supervised=True,
    with_info=True,
)
training_set, validation_set = flower_splits

In [11]:
num_classes = dataset_info.features['label'].num_classes


def count_examples(dataset):
  """Return the number of elements in a finite `tf.data.Dataset`.

  Prefers the cardinality TensorFlow already knows for the dataset, which
  avoids reading any data. Only when that is unavailable does it fall back
  to walking the dataset element by element.

  Args:
    dataset: the `tf.data.Dataset` to measure.

  Returns:
    The element count as a plain Python `int`.
  """
  cardinality = int(tf.data.experimental.cardinality(dataset).numpy())
  if cardinality >= 0:
    return cardinality
  # Negative values are sentinel codes (e.g. unknown size), so count by
  # iterating over the dataset instead.
  return sum(1 for _ in dataset)


num_training_examples = count_examples(training_set)
num_validation_examples = count_examples(validation_set)

print('Total Number of Classes: {}'.format(num_classes))
print('Total Number of Training Images: {}'.format(num_training_examples))
print('Total Number of Validation Images: {} \n'.format(num_validation_examples))

Total Number of Classes: 5
Total Number of Training Images: 2569
Total Number of Validation Images: 1101 



In [12]:
IMAGE_RES = 224
BATCH_SIZE = 32

# Let tf.data pick the parallelism / prefetch depth at runtime.
AUTOTUNE = tf.data.experimental.AUTOTUNE


def format_image(image, label):
  """Resize an image to IMAGE_RES x IMAGE_RES and divide its pixels by 255.

  Args:
    image: image tensor taken from the dataset.
    label: label paired with `image`; passed through unchanged.

  Returns:
    A `(image, label)` tuple with the resized, rescaled image.
  """
  image = tf.image.resize(image, (IMAGE_RES, IMAGE_RES)) / 255.0
  return image, label


# Shuffle before mapping so the shuffle buffer holds the raw examples, then
# preprocess in parallel, batch, and overlap input preparation with training.
train_batches = (
    training_set
    .shuffle(num_training_examples // 4)
    .map(format_image, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

# Validation data is evaluated in its stored order, so no shuffling here.
validation_batches = (
    validation_set
    .map(format_image, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

In [13]:
# TF-Hub handle of the MobileNet V2 feature-vector module used as the backbone.
MOBILENET_V2_URL = 'https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4'

# Wrap the module as a Keras layer that takes IMAGE_RES x IMAGE_RES images
# with 3 channels.
feature_extractor = hub.KerasLayer(
    MOBILENET_V2_URL,
    input_shape=(IMAGE_RES, IMAGE_RES, 3),
)

In [14]:
# dataset = tf.data.Dataset.range(8)
# dataset = dataset.shuffle(4).batch(2)
# dataset = dataset.prefetch(2)
# list(dataset.as_numpy_iterator())

In [15]:
# Freeze the feature extractor so its weights are not updated during training.
feature_extractor.trainable = False

In [16]:
# Stack the feature extractor and a Dense head with one output per class.
# The Dense layer has no activation, so the model outputs raw logits.
model = tf.keras.Sequential()
model.add(feature_extractor)
model.add(layers.Dense(num_classes))

In [17]:
# Print the layer-by-layer overview, including trainable vs. non-trainable
# parameter counts.
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
keras_layer (KerasLayer)     (None, 1280)              2257984   
_________________________________________________________________
dense (Dense)                (None, 5)                 6405      
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________


In [None]:
# The final Dense layer has no softmax, so the loss is told to expect logits.
# Labels are integer class ids, hence the sparse variant of cross-entropy.
logits_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model.compile(
    optimizer='adam',
    loss=logits_loss,
    metrics=['accuracy'],
)

In [None]:
EPOCHS = 6

# `train_batches` and `validation_batches` are already-batched, finite
# tf.data pipelines, so:
#   * `batch_size` is not passed -- Keras documents that it must not be
#     specified for dataset inputs, since the dataset produces the batches.
#   * `steps_per_epoch` / `validation_steps` are left unset -- each epoch then
#     runs until its dataset is exhausted, so the step count can never drift
#     out of sync with the real number of batches.
history = model.fit(train_batches,
                    epochs=EPOCHS,
                    validation_data=validation_batches)