TFLite, 2.2.0, accuracy drops significantly when tf.lite.Optimize.DEFAULT option is used #40000

Closed
wwwind opened this issue May 29, 2020 · 10 comments
Labels: comp:lite (TF Lite related issues), stat:awaiting response (Status - Awaiting response from author), TF 2.2 (Issues related to TF 2.2), TFLiteConverter (For issues related to TFLite converter), type:support (Support issues)

wwwind (Contributor) commented May 29, 2020

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu
  • TensorFlow installed from (source or binary):
  • TensorFlow version (or github SHA if from source):

Command used to run the converter or code if you’re using the Python API
If possible, please share a link to Colab/Jupyter/any notebook.

Colab:
https://colab.research.google.com/drive/1Z2Xvh2dufYR8y9U-9735KgBOGYYd9NtN#scrollTo=X-vMKEjgTIp0

import tensorflow as tf
print(tf.__version__)

from tensorflow.keras.applications import MobileNet
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.models import Model
import numpy as np
import tensorflow_datasets as tfds

np.random.seed(42)
tf.random.set_seed(42)

train_ds, validation_ds = tfds.load(
    "tf_flowers",
    split=["train[:90%]", "train[90%:]"],
    as_supervised=True
)

size = (224, 224)
train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))

def normalize_img(img, label):
    img = tf.cast(img, tf.float32) / 255.
    return (img, label)

train_ds = (
    train_ds
    .map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .shuffle(1024)
    .batch(32)
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)
validation_ds = (
    validation_ds
    .map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(32)
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)

base = MobileNet(weights="imagenet", include_top=False,
                 input_shape=(224, 224, 3))

def get_training_model():
    base.trainable = False
    class_head = base.output
    class_head = GlobalAveragePooling2D()(class_head)
    class_head = Dense(512, activation="relu")(class_head)
    class_head = Dropout(0.5)(class_head)
    class_head = Dense(5, activation="softmax")(class_head)

    classifier = Model(inputs=base.input, outputs=class_head)

    classifier.compile(loss="sparse_categorical_crossentropy",
                       optimizer="adam",
                       metrics=["accuracy"])

    return classifier

test_model = get_training_model()
history = test_model.fit(train_ds,
              validation_data=validation_ds,
              epochs=5)

test_model_dir = "./test_model"
test_model.save(test_model_dir)

converter = tf.lite.TFLiteConverter.from_saved_model(test_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

quantized_tflite_model = converter.convert()
with open("test_model.tflite", "wb") as f:
    f.write(quantized_tflite_model)

# Referred from: https://www.tensorflow.org/lite/performance/post_training_integer_quant
def evaluate_model(interpreter):
    accurate_count = 0

    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]

    # Run predictions on every image in the validation dataset.
    predictions = []
    for (val_images, val_labels) in validation_ds:
        for val_image, val_label in zip(val_images, val_labels):
            val_image = tf.expand_dims(val_image, 0)
            interpreter.set_tensor(input_index, val_image)

            # Run inference.
            interpreter.invoke()

            # Post-processing: remove the batch dimension and find the class
            # with the highest probability.
            probability = interpreter.get_tensor(output_index)
            flower_id = np.argmax(probability[0])
            predictions.append(flower_id)

            # Compare prediction results with ground truth labels to calculate accuracy.
            if flower_id == val_label:
                accurate_count += 1
    
    accuracy = accurate_count * 1.0 / len(predictions)

    return accuracy

interpreter_test = tf.lite.Interpreter(model_path="test_model.tflite")
interpreter_test.allocate_tensors()

accuracy = evaluate_model(interpreter_test)
print("accuracy is {}".format(accuracy))

The output from the converter invocation

accuracy is 0.4332425068119891

Failure details
If I remove the line:
converter.optimizations = [tf.lite.Optimize.DEFAULT]
from the script, then
accuracy is 0.9264305177111717

The converted model with these settings is wrong.

@wwwind wwwind added the TFLiteConverter For issues related to TFLite converter label May 29, 2020
lheim commented May 30, 2020

Hi @wwwind,
enabling converter.optimizations = [tf.lite.Optimize.DEFAULT] results in the weights being quantized from float32 to int8 - see here.
If you disable this option, your converted network stores its weights in float32, like your original Keras model, so you shouldn't see a difference in accuracy or network size. With int8 weights, the network should be roughly 1/4 of the size, and some loss of accuracy is expected.

Quantizing the weights can result in a significant decrease in accuracy, since it limits the dynamic range of the weights. You might want to explore the dynamic range of your original network; it looks like your model's accuracy suffers immensely under the quantization.

You can also try quantizing the weights to float16 and check the accuracy again - see here.
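
For example, a minimal float16 sketch (assuming the same test_model_dir saved model from the script above):

import tensorflow as tf

# Float16 quantization: weights are stored as float16 instead of int8,
# which usually costs far less accuracy than dynamic-range quantization.
converter = tf.lite.TFLiteConverter.from_saved_model("./test_model")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]

tflite_fp16_model = converter.convert()
with open("test_model_fp16.tflite", "wb") as f:
    f.write(tflite_fp16_model)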

amahendrakar (Contributor) commented

@wwwind,
Could you please check @lheim's comment and let us know if it works? Thanks!

@amahendrakar amahendrakar added comp:lite TF Lite related issues TF 2.2 Issues related to TF 2.2 type:support Support issues stat:awaiting response Status - Awaiting response from author labels Jun 1, 2020
wwwind (Contributor, Author) commented Jun 1, 2020

@amahendrakar No.
The bug is that with the weights quantized, the accuracy drops significantly,
from 0.9264305177111717 to 0.4332425068119891.
This does not look right.

amahendrakar (Contributor) commented

I was able to reproduce the issue with TF v2.2 and tf-nightly. Please find the attached gist. Thanks!

@amahendrakar amahendrakar removed the stat:awaiting response Status - Awaiting response from author label Jun 1, 2020
jvishnuvardhan (Contributor) commented

@wwwind Are you looking for an int8 tflite model or a float tflite model? Let me check the converted model and respond to you. Thanks!

@jvishnuvardhan jvishnuvardhan added the stat:awaiting response Status - Awaiting response from author label Jun 5, 2020
wwwind (Contributor, Author) commented Jun 5, 2020

Hi @jvishnuvardhan, the problem is that the accuracy is much worse with only the weights quantized to int8 (dynamic-range quantization) than when the model is fully int8 quantized.
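
For comparison, a minimal sketch of full integer quantization with a representative dataset (assuming the validation_ds pipeline and saved model from the original script):

def representative_dataset():
    # Calibration data: up to 320 samples (10 batches of 32) from the
    # validation pipeline, each yielded with an explicit batch dimension.
    for images, _ in validation_ds.take(10):
        for image in images:
            yield [tf.expand_dims(image, 0)]

converter = tf.lite.TFLiteConverter.from_saved_model("./test_model")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset

full_int8_model = converter.convert()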

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Jun 7, 2020
@jvishnuvardhan jvishnuvardhan added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jun 8, 2020
@abattery abattery assigned liufengdb and unassigned abattery Jun 8, 2020
abattery (Contributor) commented Jun 8, 2020

@liufengdb could you take a look at this?

@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jun 24, 2020
jvishnuvardhan (Contributor) commented

@wwwind I think this was resolved in a recent tf-nightly; I cannot reproduce the low-accuracy issue with tf-nightly. Please check the gist here. Thanks!

Please verify once and close the issue if this was resolved for you. Thanks!

@jvishnuvardhan jvishnuvardhan added the stat:awaiting response Status - Awaiting response from author label Dec 13, 2020
wwwind (Contributor, Author) commented Dec 14, 2020

Thanks for fixing!

@wwwind wwwind closed this as completed Dec 14, 2020