In [1]:
import tensorflow as tf
from tensorflow.keras import layers

In [2]:
# Data-loading configuration shared by the cells below.
train_dir = './datasets/flower_photos'  # dataset root directory
batch_size = 8  # number of images processed per batch
img_height, img_width = 256, 256  # image height and width, in pixels

In [3]:
# Training subset: the part of the images under train_dir that is NOT held
# out for validation (validation_split=0.2).
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    train_dir,
    validation_split=0.2,  # fraction of the data reserved for validation
    subset="training",
    seed=123,  # fixed seed for the shuffle / train-validation split
    # Pass the configured size explicitly instead of relying on the library
    # default, so changing img_height / img_width actually affects the
    # loaded images and stays consistent with the RandomCrop layer.
    image_size=(img_height, img_width),
    batch_size=batch_size)

Found 3670 files belonging to 5 classes.
Using 2936 files for training.


In [4]:
# Validation subset: the held-out part of the images under train_dir
# (validation_split=0.2).
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    train_dir,
    validation_split=0.2,  # fraction of the data reserved for validation
    subset="validation",
    seed=123,  # fixed seed for the shuffle / train-validation split
    # Pass the configured size explicitly instead of relying on the library
    # default, so changing img_height / img_width actually affects the
    # loaded images and stays consistent with the RandomCrop layer.
    image_size=(img_height, img_width),
    batch_size=batch_size)

Found 3670 files belonging to 5 classes.
Using 734 files for validation.


In [5]:
# Sanity check: list the class labels, then show the tensor shapes of a
# single (images, labels) batch drawn from the training dataset.
print(train_ds.class_names)
for image_batch, labels_batch in train_ds.take(1):
    print(image_batch.shape)
    print(labels_batch.shape)

['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
(8, 256, 256, 3)
(8,)


In [6]:
# Catalogue of ready-made augmentation layers, keyed by the layer's class
# name, so that a model definition can pick one by key.
augmentation_dict = {
    # random flips, both horizontally and vertically
    'RandomFlip': layers.experimental.preprocessing.RandomFlip(
        "horizontal_and_vertical"),
    # random rotation, factor 0.2
    'RandomRotation': layers.experimental.preprocessing.RandomRotation(0.2),
    # random contrast adjustment, factor 0.2
    'RandomContrast': layers.experimental.preprocessing.RandomContrast(0.2),
    # random zoom, factor 0.1 on each axis
    'RandomZoom': layers.experimental.preprocessing.RandomZoom(
        height_factor=0.1, width_factor=0.1),
    # random shift, factor 0.1 on each axis
    'RandomTranslation': layers.experimental.preprocessing.RandomTranslation(
        height_factor=0.1, width_factor=0.1),
    # random crop to the configured image size
    'RandomCrop': layers.experimental.preprocessing.RandomCrop(
        img_height, img_width),
}

In [7]:
# Derive the number of output units from the dataset instead of hard-coding
# it, so the classifier head always matches the classes found on disk
# (5 for the flower_photos data loaded above).
num_classes = len(train_ds.class_names)

# Small CNN classifier: rescale pixels, apply one augmentation layer, run
# three Conv2D + MaxPooling2D stages, then a dense head.
model = tf.keras.Sequential([
    # scale pixel values by 1/255
    layers.experimental.preprocessing.Rescaling(1. / 255),
    # only the random-translation augmentation is used by this model
    augmentation_dict['RandomTranslation'],
    layers.Conv2D(32, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    # no activation here: the layer outputs raw logits, one per class
    layers.Dense(num_classes)
])

In [8]:
# Configure training. from_logits=True because the model's last Dense layer
# has no activation; the loss takes integer class labels (sparse).
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

In [9]:
# Train for 5 epochs on the training dataset, evaluating on the validation
# dataset alongside training.
model.fit(x=train_ds, validation_data=val_ds, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7f99fc9af490>

In [10]:
# Export the trained model to disk.
# ./models is the model root directory for tensorflow-serving; the trailing
# "1" is presumably the model version sub-directory — TODO confirm.
tf.keras.models.save_model(
    model=model,
    filepath='./models/image/1/',
    overwrite=True,
    include_optimizer=True,
    save_format=None,
    signatures=None,
    options=None,
)

INFO:tensorflow:Assets written to: ./models/image/1/assets
