In [2]:
import os
import glob
import json
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tqdm import tqdm
from MobileNet_v2 import MobileNetV2


def main():
    # data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    # image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    # train_dir = os.path.join(image_path, "train")
    # validation_dir = os.path.join(image_path, "val")
    train_dir = r'../Datasets/landscape/train/'
    validation_dir = r'../Datasets/landscape/val/'
    assert os.path.exists(train_dir), "cannot find {}".format(train_dir)
    assert os.path.exists(validation_dir), "cannot find {}".format(validation_dir)

    im_height = 224
    im_width = 224
    batch_size = 16
    epochs = 20
    num_classes = 21

    def pre_function(img):
        # img = im.open('test.jpg')
        # img = np.array(img).astype(np.float32)
        img = img / 255.
        img = (img - 0.5) * 2.0
        return img

    # data generator with data augmentation
    train_image_generator = ImageDataGenerator(horizontal_flip=True,
                                               preprocessing_function=pre_function)

    validation_image_generator = ImageDataGenerator(preprocessing_function=pre_function)

    train_data_gen = train_image_generator.flow_from_directory(directory=train_dir,
                                                               batch_size=batch_size,
                                                               shuffle=True,
                                                               target_size=(im_height, im_width),
                                                               class_mode='categorical')
    total_train = train_data_gen.n

    # get class dict
    class_indices = train_data_gen.class_indices

    # transform value and key of dict
    inverse_dict = dict((val, key) for key, val in class_indices.items())
    # write dict into json file
    json_str = json.dumps(inverse_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    val_data_gen = validation_image_generator.flow_from_directory(directory=validation_dir,
                                                                  batch_size=batch_size,
                                                                  shuffle=False,
                                                                  target_size=(im_height, im_width),
                                                                  class_mode='categorical')
    # img, _ = next(train_data_gen)
    total_val = val_data_gen.n
    print("using {} images for training, {} images for validation.".format(total_train,
                                                                           total_val))

    # create model except fc layer
    feature = MobileNetV2(include_top=False)
    # download weights 链接: https://pan.baidu.com/s/1YgFoIKHqooMrTQg_IqI2hA  密码: 2qht
    pre_weights_path = r'./tf_mobilenet_weights/pretrain_weights.ckpt'
    assert len(glob.glob(pre_weights_path+"*")), "cannot find {}".format(pre_weights_path)
    feature.load_weights(pre_weights_path)
    feature.trainable = False
    feature.summary()

    # add last fc layer
    model = tf.keras.Sequential([feature,
                                 tf.keras.layers.GlobalAvgPool2D(),
                                 tf.keras.layers.Dropout(rate=0.5),
                                 tf.keras.layers.Dense(num_classes),
                                 tf.keras.layers.Softmax()])
    model.summary()

    # using keras low level api for training
    loss_object = tf.keras.losses.CategoricalCrossentropy(from_logits=False)
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')

    val_loss = tf.keras.metrics.Mean(name='val_loss')
    val_accuracy = tf.keras.metrics.CategoricalAccuracy(name='val_accuracy')

    @tf.function
    def train_step(images, labels):
        with tf.GradientTape() as tape:
            output = model(images, training=True)
            loss = loss_object(labels, output)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        train_loss(loss)
        train_accuracy(labels, output)

    @tf.function
    def val_step(images, labels):
        output = model(images, training=False)
        loss = loss_object(labels, output)

        val_loss(loss)
        val_accuracy(labels, output)

    best_val_acc = 0.
    for epoch in range(epochs):
        train_loss.reset_states()  # clear history info
        train_accuracy.reset_states()  # clear history info
        val_loss.reset_states()  # clear history info
        val_accuracy.reset_states()  # clear history info

        # train
        train_bar = tqdm(range(total_train // batch_size))
        for step in train_bar:
            images, labels = next(train_data_gen)
            train_step(images, labels)

            # print train process
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}, acc:{:.3f}".format(epoch + 1,
                                                                                 epochs,
                                                                                 train_loss.result(),
                                                                                 train_accuracy.result())

        # validate
        val_bar = tqdm(range(total_val // batch_size))
        for step in val_bar:
            val_images, val_labels = next(val_data_gen)
            val_step(val_images, val_labels)

            # print val process
            val_bar.desc = "valid epoch[{}/{}] loss:{:.3f}, acc:{:.3f}".format(epoch + 1,
                                                                               epochs,
                                                                               val_loss.result(),
                                                                               val_accuracy.result())

        # only save best weights
        if val_accuracy.result() > best_val_acc:
            best_val_acc = val_accuracy.result()
            model.save_weights("./save_weights/resMobileNetV2.ckpt", save_format="tf")
    print('train finished')

if __name__ == '__main__':
    main()


Found 1470 images belonging to 21 classes.
Found 630 images belonging to 21 classes.
using 1470 images for training, 630 images for validation.
Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
Conv (ConvBNReLU)            (None, 112, 112, 32)      992       
_________________________________________________________________
inverted_residual_17 (Invert (None, 112, 112, 16)      992       
_________________________________________________________________
inverted_residual_18 (Invert (None, 56, 56, 24)        5568      
_________________________________________________________________
inverted_residual_19 (Invert (None, 56, 56, 24)        9456      
_________________________________________________________________
inverted_residual_20 (Invert (None, 28, 28, 32)

  0%|                                                                                           | 0/91 [00:00<?, ?it/s]

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
model_1 (Functional)         (None, 7, 7, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d_1 ( (None, 1280)              0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 1280)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 21)                26901     
_________________________________________________________________
softmax_1 (Softmax)          (None, 21)                0         
Total params: 2,284,885
Trainable params: 26,901
Non-trainable params: 2,257,984
_________________________________________________________________


train epoch[1/20] loss:2.184, acc:0.402: 100%|█████████████████████████████████████████| 91/91 [00:38<00:00,  2.37it/s]
valid epoch[1/20] loss:0.714, acc:0.796: 100%|█████████████████████████████████████████| 39/39 [00:18<00:00,  2.07it/s]
train epoch[2/20] loss:0.705, acc:0.765: 100%|█████████████████████████████████████████| 91/91 [00:40<00:00,  2.27it/s]
valid epoch[2/20] loss:0.436, acc:0.870: 100%|█████████████████████████████████████████| 39/39 [00:15<00:00,  2.46it/s]
train epoch[3/20] loss:0.498, acc:0.835: 100%|█████████████████████████████████████████| 91/91 [00:35<00:00,  2.58it/s]
valid epoch[3/20] loss:0.339, acc:0.901: 100%|█████████████████████████████████████████| 39/39 [00:14<00:00,  2.74it/s]
train epoch[4/20] loss:0.360, acc:0.876: 100%|█████████████████████████████████████████| 91/91 [00:36<00:00,  2.50it/s]
valid epoch[4/20] loss:0.300, acc:0.906: 100%|█████████████████████████████████████████| 39/39 [00:15<00:00,  2.57it/s]
train epoch[5/20] loss:0.263, acc:0.919:

train finished
