In [1]:
!pip install numpy tensorflow_hub tensorflow_datasets
!pip install tensorflow



In [15]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_hub as hub
import tensorflow_datasets as tfds
from tensorflow.keras import layers

In [16]:
flower_dataset_url = "http://download.tensorflow.org/example_images/flower_photos.tgz"

In [17]:
(train_set, val_set), dataset_info = tfds.load('tf_flowers',
    split=['train[:70%]', 'train[70%:]'],
    with_info=True,
    as_supervised=True,
)

In [18]:
print(dataset_info)

tfds.core.DatasetInfo(
    name='tf_flowers',
    full_name='tf_flowers/3.0.1',
    description="""
    A large set of images of flowers
    """,
    homepage='https://www.tensorflow.org/tutorials/load_data/images',
    data_dir='/root/tensorflow_datasets/tf_flowers/3.0.1',
    file_format=tfrecord,
    download_size=218.21 MiB,
    dataset_size=221.83 MiB,
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=uint8),
        'label': ClassLabel(shape=(), dtype=int64, num_classes=5),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'train': <SplitInfo num_examples=3670, num_shards=2>,
    },
    citation="""@ONLINE {tfflowers,
    author = "The TensorFlow Team",
    title = "Flowers",
    month = "jan",
    year = "2019",
    url = "http://download.tensorflow.org/example_images/flower_photos.tgz" }""",
)


In [19]:
class_num = dataset_info.features['label'].num_classes

train_num = len(train_set)
valid_num = len(val_set)

print(train_num, valid_num, class_num)

2569 1101 5


In [20]:
image_res = 224

def normalize(image, label):
    image = tf.image.resize(image, (image_res, image_res))/255.0
    return image, label

batch_size = 32
train_batch = train_set.shuffle(train_num//4).map(normalize).batch(batch_size).prefetch(1)
val_batch = val_set.map(normalize).batch(batch_size).prefetch(1)

In [21]:
IMAGE_SHAPE = (224, 224)
resnet_model = "https://www.kaggle.com/models/google/resnet-v2/TensorFlow2/101-classification/2"

In [22]:
feature_extractor_layer = hub.KerasLayer(
    resnet_model,
    input_shape=(224, 224, 3),
    trainable=False)

In [23]:
model = tf.keras.Sequential([
  feature_extractor_layer,
  layers.Dense(class_num)
])

In [24]:
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

In [None]:
history = model.fit(train_batch,
                    epochs=6,
                    validation_data=val_batch)

Epoch 1/6
 2/81 [..............................] - ETA: 10:35 - loss: 3.8604 - accuracy: 0.3906

In [None]:
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(6)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='train_acc')
plt.plot(epochs_range, val_acc, label='val_acc')
plt.legend(loc='lower right')
plt.title('Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='train_loss')
plt.plot(epochs_range, val_loss, label='val_loss')
plt.legend(loc='upper right')
plt.title('Loss function')
plt.show()