In [1]:
from tensorflow import keras


In [3]:
def Mobilenet_v2(input_size, weights, Dropout_rate, Trainable, alpha=0.35, num_classes=15):
    """Build an image classifier on top of a MobileNetV2 backbone.

    Pipeline: Input -> Rescaling([0, 255] -> [-1, 1]) -> MobileNetV2 (no top)
    -> GlobalAveragePooling2D -> Dropout -> Dense(softmax).

    Args:
        input_size: Side length in pixels of the square RGB input images.
        weights: Passed straight to ``keras.applications.MobileNetV2``
            (e.g. ``'imagenet'`` or ``None``).
        Dropout_rate: Dropout rate applied before the classification layer.
        Trainable: If truthy, the backbone weights are fine-tuned; otherwise the
            backbone is frozen and only the new head is trained.
        alpha: MobileNetV2 width multiplier.
        num_classes: Number of units in the softmax output layer. Defaults to 15,
            which matches the 15 class folders used in this notebook.

    Returns:
        An uncompiled ``keras.Model`` producing class probabilities.
    """
    base_model = keras.applications.MobileNetV2(
        input_shape=(input_size, input_size, 3),
        alpha=alpha,
        weights=weights,
        include_top=False,
    )
    # Decide trainability before wiring the backbone into the new graph.
    if Trainable:
        base_model.trainable = True
    else:
        base_model.trainable = False
        print("特征层已冻结")  # prints "feature layers frozen"

    inputs = keras.Input(shape=(input_size, input_size, 3))
    # Maps pixel values [0, 255] -> [-1, 1], the range MobileNetV2 expects;
    # assumes raw (un-rescaled) pixels are fed in - TODO confirm for new loaders.
    x = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)(inputs)
    # training=False keeps the backbone's BatchNormalization layers in inference
    # mode, even when the backbone is unfrozen for fine-tuning.
    x = base_model(x, training=False)
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dropout(Dropout_rate, name='Dropout')(x)
    outputs = keras.layers.Dense(num_classes, activation='softmax')(x)
    return keras.Model(inputs, outputs)

In [4]:
# Training hyper-parameters shared by the model, data and training cells.
epochs = 100        # maximum number of passes over the training set
batch_size = 128    # images per gradient step
lr = 1e-4           # Adam learning rate
Dropout_rate = 0.3  # dropout applied before the softmax head
input_size = 128    # images are resized to input_size x input_size

In [5]:
# Build the classifier with ImageNet-pretrained weights and an unfrozen backbone.
model = Mobilenet_v2(
    input_size=input_size,
    weights='imagenet',
    Dropout_rate=Dropout_rate,
    Trainable=True,
)
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 128, 128, 3)]     0         
                                                                 
 rescaling (Rescaling)       (None, 128, 128, 3)       0         
                                                                 
 mobilenetv2_0.35_128 (Funct  (None, 4, 4, 1280)       410208    
 ional)                                                          
                                                                 
 global_average_pooling2d (G  (None, 1280)             0         
 lobalAveragePooling2D)                                          
                                                                 
 Dropout (Dropout)           (None, 1280)              0         
                                                                 
 dense (Dense)               (None, 15)                19215 

In [6]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_root = "./train/"

# Random augmentation applied on the fly to every training batch.
train_generator = ImageDataGenerator(rotation_range=360,
                                     zoom_range=0.2,
                                     horizontal_flip=True)
# class_mode is left at its default ('categorical'), so labels are one-hot and
# match the CategoricalCrossentropy loss. `seed` makes the shuffle order and the
# random augmentations reproducible across runs.
train_dataset = train_generator.flow_from_directory(directory=train_root,
                                                    target_size=(input_size, input_size),
                                                    batch_size=batch_size,
                                                    shuffle=True,
                                                    seed=42)
# TODO: there is no validation set yet. Add one (a held-out directory, or
# ImageDataGenerator(validation_split=...) with subset=...) before monitoring
# val_* metrics during training.
train_dataset.class_indices

Found 2369 images belonging to 15 classes.
{'大客车': 0, '小汽车': 1, '榴莲': 2, '橙子': 3, '火车': 4, '牛': 5, '狗': 6, '猪': 7, '猫': 8, '苹果': 9, '葡萄': 10, '轮船': 11, '飞机': 12, '香蕉': 13, '马': 14}


In [7]:
# The head ends in a softmax and flow_from_directory yields one-hot labels, so
# the loss consumes probabilities (from_logits=False, the default), not logits.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=lr),
    loss=keras.losses.CategoricalCrossentropy(from_logits=False),
    metrics=["accuracy"],
)


In [9]:
# Train on the augmented generator only. No validation data is passed, so
# callbacks that monitor val_* metrics (ReduceLROnPlateau on val_loss,
# EarlyStopping / ModelCheckpoint on val_accuracy) cannot be used until a
# validation set is added.
hist = model.fit(train_dataset, epochs=epochs)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100

KeyboardInterrupt: 