tf.keras MobileNetV2 with weights=None fails to train #36065

Closed
dbacea opened this issue Jan 20, 2020 · 10 comments
Labels
comp:keras Keras related issues TF 2.1 for tracking issues in 2.1 release type:bug Bug

Comments

@dbacea

dbacea commented Jan 20, 2020

System information

  • Google Colab notebook
  • TensorFlow version: 2.1.0-rc1
  • Python version: 3.7
  • CUDA/cuDNN version: 10.0.130
  • GPU model and memory: Tesla T4 12 GB

Describe the current behavior
Based on the tutorial: https://www.tensorflow.org/tutorials/images/transfer_learning#format_the_data

Start running the cells inside the notebook of the tutorial.

#Create the base model from the pre-trained model MobileNet V2

base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights=None)

base_model.trainable = True

I then trained the model for 10 epochs with the parameters specified in the tutorial, but the validation loss does not go down and the validation accuracy stays stuck.

The results of training:
Epoch 1/10
582/582 [==============================] - 87s 149ms/step - loss: 0.6606 - accuracy: 0.5788 - val_loss: 0.6953 - val_accuracy: 0.5216
Epoch 2/10
582/582 [==============================] - 80s 138ms/step - loss: 0.6157 - accuracy: 0.6425 - val_loss: 0.7064 - val_accuracy: 0.5216
Epoch 3/10
582/582 [==============================] - 81s 139ms/step - loss: 0.5765 - accuracy: 0.6769 - val_loss: 0.7014 - val_accuracy: 0.5216
Epoch 4/10
582/582 [==============================] - 81s 139ms/step - loss: 0.5378 - accuracy: 0.7143 - val_loss: 0.7488 - val_accuracy: 0.4784
Epoch 5/10
582/582 [==============================] - 81s 139ms/step - loss: 0.5072 - accuracy: 0.7368 - val_loss: 0.8380 - val_accuracy: 0.4784
Epoch 6/10
582/582 [==============================] - 80s 138ms/step - loss: 0.4777 - accuracy: 0.7601 - val_loss: 0.9534 - val_accuracy: 0.4784
Epoch 7/10
582/582 [==============================] - 81s 138ms/step - loss: 0.4354 - accuracy: 0.7894 - val_loss: 1.0138 - val_accuracy: 0.4784
Epoch 8/10
582/582 [==============================] - 81s 138ms/step - loss: 0.3937 - accuracy: 0.8110 - val_loss: 1.2038 - val_accuracy: 0.4784
Epoch 9/10
582/582 [==============================] - 80s 138ms/step - loss: 0.3593 - accuracy: 0.8288 - val_loss: 1.7442 - val_accuracy: 0.4784
Epoch 10/10
582/582 [==============================] - 81s 139ms/step - loss: 0.3166 - accuracy: 0.8547 - val_loss: 1.6888 - val_accuracy: 0.4784


Describe the expected behavior
If MobileNet V1 is used instead, with the same weight initialization and same training parameters, the results are the following:

Epoch 1/10
582/582 [==============================] - 74s 126ms/step - loss: 0.6596 - accuracy: 0.5840 - val_loss: 0.7098 - val_accuracy: 0.5216
Epoch 2/10
582/582 [==============================] - 70s 120ms/step - loss: 0.6310 - accuracy: 0.6248 - val_loss: 0.6099 - val_accuracy: 0.6483
Epoch 3/10
582/582 [==============================] - 71s 122ms/step - loss: 0.6102 - accuracy: 0.6479 - val_loss: 0.6191 - val_accuracy: 0.6858
Epoch 4/10
582/582 [==============================] - 70s 121ms/step - loss: 0.5850 - accuracy: 0.6729 - val_loss: 0.5983 - val_accuracy: 0.6634
Epoch 5/10
582/582 [==============================] - 71s 122ms/step - loss: 0.5620 - accuracy: 0.6954 - val_loss: 0.6043 - val_accuracy: 0.6573
Epoch 6/10
582/582 [==============================] - 71s 122ms/step - loss: 0.5383 - accuracy: 0.7128 - val_loss: 0.5575 - val_accuracy: 0.6935
Epoch 7/10
582/582 [==============================] - 71s 122ms/step - loss: 0.5179 - accuracy: 0.7291 - val_loss: 0.6238 - val_accuracy: 0.7220
Epoch 8/10
582/582 [==============================] - 70s 121ms/step - loss: 0.4906 - accuracy: 0.7491 - val_loss: 0.5965 - val_accuracy: 0.6905
Epoch 9/10
582/582 [==============================] - 70s 121ms/step - loss: 0.4636 - accuracy: 0.7711 - val_loss: 0.5580 - val_accuracy: 0.7310
Epoch 10/10
582/582 [==============================] - 70s 120ms/step - loss: 0.4292 - accuracy: 0.7894 - val_loss: 0.5737 - val_accuracy: 0.7233


In this case, the loss and the accuracy move in the right direction.

Code to reproduce the issue
Opened the Google Colab notebook and ran all the cells up to the section named "Create the base model from the pre-trained convnets". Modified the first cell under this heading to the following:

IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)

# Create the base model from the pre-trained model MobileNet V2
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               # weights='imagenet')
                                               weights=None)

base_model.trainable = True

Then proceed by running the following cells inside the notebook, which create the classification head:

global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)

prediction_layer = keras.layers.Dense(1)
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)

model = tf.keras.Sequential([
    base_model,
    global_average_layer,
    prediction_layer
])

Compile the model as in the tutorial, with the same parameters:

base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])

Then train the model. The initial loss is 0.69 and initial accuracy is 0.51:
num_train, num_val, num_test = (
    metadata.splits['train'].num_examples * weight / 10
    for weight in SPLIT_WEIGHTS
)

initial_epochs = 10
steps_per_epoch = round(num_train) // BATCH_SIZE
validation_steps = 20

loss0, accuracy0 = model.evaluate(validation_batches, steps=validation_steps)

history = model.fit(train_batches,
                    epochs=initial_epochs,
                    validation_data=validation_batches)

@ravikyram ravikyram self-assigned this Jan 21, 2020
@ravikyram ravikyram added comp:keras Keras related issues TF 2.1 for tracking issues in 2.1 release type:bug Bug labels Jan 21, 2020
@ravikyram
Contributor

@dbacea

I tried on Colab with TF version 2.1 and was able to reproduce the issue. However, I am not seeing MobileNet V1 in tf.keras.applications. I tried with MobileNet in keras.applications and I am seeing the same accuracy and validation loss as with MobileNet V2. Please find the gist here. Thanks!

@ravikyram ravikyram added the stat:awaiting response Status - Awaiting response from author label Jan 21, 2020
@dbacea
Author

dbacea commented Jan 21, 2020

@ravikyram

I also used MobileNet (which is the MobileNet V1 architecture) from keras-applications, same as you did in the gist you provided. In that gist, the MobileNet code skips several steps. After the cell that contains:

global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)

You should then insert and run the following code before fitting the model:
prediction_layer = keras.layers.Dense(1)
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)

model = tf.keras.Sequential([
    base_model,
    global_average_layer,
    prediction_layer
])

base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])

By going through these steps, you will be able to reproduce the values I recorded for MobileNet V1.

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Jan 21, 2020
@ravikyram
Contributor

I tried on Colab with TF version 2.1.0-rc2 and was able to reproduce the issue. Please find the gists for MobileNetV2 and MobileNet. Thanks!

@jvishnuvardhan
Contributor

@dbacea The tutorial demonstrates the transfer learning approach, where pretrained weights are used. In that case, num_train=18609 and num_val=2326 are sufficient because there are only 1,281 trainable params. In your case there are 2,225,153 trainable params, but num_train and num_val are unchanged. To get a better val_accuracy, increase the number of epochs from 10 to 30 and change the validation steps as well. After those changes, I see a val_accuracy of ~76.68%. Please check the gist here.

Epoch 22/30
582/582 [==============================] - 80s 137ms/step - loss: 0.0777 - accuracy: 0.9684 - val_loss: 1.1010 - val_accuracy: 0.7668

I am closing this issue as it was resolved. Please feel free to reopen if the issue persists again. Thanks!


@AirBobby

I have tried MobileNet V1, MobileNet V2 and a ResNet backbone to train the same classification model with weights=None. Only with MobileNet V2 do the validation loss and accuracy fail to improve.
I finally figured out that the reason is the default momentum of the batch normalization layers in MobileNet V2, which is 0.999 and too large. In PyTorch the equivalent default is 0.9, and in TensorFlow 2 the default is 0.99. After changing the batch norm momentum to 0.99, the model works.
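
As a rough illustration of why a momentum of 0.999 can leave the inference statistics far behind, the sketch below applies the standard moving-average update, moving = momentum * moving + (1 - momentum) * batch_value, for one epoch of 582 steps (the step count from the logs above). It is plain Python rather than Keras code; the function name `ema_after` and the constant batch statistic of 1.0 are assumptions made only for this example.

```python
def ema_after(steps, momentum, initial=0.0, target=1.0):
    """Return the moving average after `steps` updates toward a constant value.

    `initial` plays the role of the freshly initialized moving mean and
    `target` the (hypothetical, constant) batch statistic seen during training.
    """
    moving = initial
    for _ in range(steps):
        moving = momentum * moving + (1.0 - momentum) * target
    return moving


STEPS_PER_EPOCH = 582  # taken from the training logs in this issue

for momentum in (0.9, 0.99, 0.999):
    value = ema_after(STEPS_PER_EPOCH, momentum)
    print(f"momentum={momentum}: moving value after one epoch = {value:.3f}")
```

Under these assumptions the moving value after n steps equals 1 - momentum**n, so a larger momentum needs many more steps before the statistics used at inference time catch up with the batch statistics.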

@aristako

aristako commented Dec 2, 2020

@AirBobby do you mind showing how you changed this value from the existing MobileNetV2 model class? Nice find btw, quite tricky to solve.

@aristako

aristako commented Dec 2, 2020

I solved my issue using @AirBobby's comment, as follows.

for layer in mobilenet_model.layers:
    # Lower the moving-average momentum of every batch normalization layer
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.momentum = 0.9

@MasterSkepticista

Issue persists on TF2.8, and setting momentum=0.9 seems to solve the issue. But this should be fixed within TensorFlow.

@lukaseller

I can reproduce the same issue with TF2.10.0. Interestingly it seems that the problem is related to the inference mode only --- when I call the model with training=True on the test_dataset I get a similar accuracy as for the training_dataset.
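
One way to picture this difference is a toy normalization in plain Python: with training=True a batch normalization layer normalizes with the statistics of the current batch, while in inference mode it uses the stored moving statistics. The activation values and the helper `normalize` below are hypothetical; the moving mean of 0 and moving variance of 1 correspond to the Keras initial values, standing in for statistics that have barely been updated.

```python
import statistics


def normalize(values, mean, variance, eps=1e-3):
    """Normalize values the way batch norm does, without scale and offset."""
    return [(v - mean) / (variance + eps) ** 0.5 for v in values]


# Hypothetical activations whose real mean is far from 0
activations = [4.0, 5.0, 6.0, 7.0]

# training=True: normalize with the statistics of the current batch
batch_mean = statistics.fmean(activations)
batch_var = statistics.pvariance(activations)
train_mode = normalize(activations, batch_mean, batch_var)

# Inference: normalize with moving statistics that are still at their
# initial values (mean 0, variance 1), as if they had barely been updated
inference_mode = normalize(activations, mean=0.0, variance=1.0)

print("training=True :", [round(v, 3) for v in train_mode])
print("inference mode:", [round(v, 3) for v in inference_mode])
```

The two modes hand very different inputs to the following layers, which would be consistent with a model that scores well with training=True and poorly in inference mode.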

I think this is indeed something that requires fixing.
