In [1]:
import logging
import os
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from IPython.display import Image
from tensorflow import keras
from tensorflow.keras import layers, initializers
from tensorflow.keras.utils import plot_model
# Emit TensorFlow's own log records down to DEBUG level.
logging.getLogger("tensorflow").setLevel(logging.DEBUG)
# Require TF >= 2.3. Compare (major, minor) as integers: slicing the version
# string to three characters would read "2.10.0" as 2.1 and wrongly fail.
_tf_major_minor = tuple(int(part) for part in tf.__version__.split(".")[:2])
assert _tf_major_minor >= (2, 3), (
    "TensorFlow >= 2.3 is required, found " + tf.__version__
)

#tfds.list_builders()

In [2]:
# Load the trained Keras model from its HDF5 checkpoint.
model_dir = 'cifar10/models'
model_path = model_dir + "/Best-model-MobileNetV2-L2.h5"
model = tf.keras.models.load_model(model_path)

In [3]:
# Fetch the CIFAR-10 test split together with its dataset metadata.
test_data, info = tfds.load(
    name="cifar10",
    split="test",
    with_info=True,
)

In [4]:
# Per-example preprocessing: rescale, resize, one-hot encode.
def parse_aug_fn(dataset, image_size=(224, 224), num_classes=10):
    """Turn one example dict into an ``(image, one_hot_label)`` pair.

    Args:
        dataset: A single example (not a whole dataset, despite the name)
            exposing an ``'image'`` entry and a ``'label'`` entry, as passed
            by ``Dataset.map``.
        image_size: ``(height, width)`` handed to ``tf.image.resize``.
            Defaults to 224x224.
        num_classes: Depth of the one-hot label vector. Defaults to 10.

    Returns:
        Tuple ``(x, y)``: ``x`` is the image cast to float32, divided by 255
        and resized to ``image_size``; ``y`` is the one-hot encoded label.
    """
    # Cast to float and divide by 255 (maps 8-bit pixel values into [0, 1]).
    x = tf.cast(dataset['image'], tf.float32) / 255.
    # Enlarge the image to the requested spatial size.
    x = tf.image.resize(x, image_size)
    y = tf.one_hot(dataset['label'], num_classes)

    return x, y

In [5]:
# Build the evaluation input pipeline.
batch_size = 256
AUTOTUNE = tf.data.experimental.AUTOTUNE
# NOTE: this rebinds `test_data` to the parsed dataset (later cells read it
# under this name), so re-running this cell without reloading the raw split
# would feed already-parsed elements back into parse_aug_fn.
test_data = test_data.map(map_func=parse_aug_fn, num_parallel_calls=AUTOTUNE)
# Batch first and prefetch last, so that whole batches are prepared in the
# background while the model consumes the previous one.
test_data_batch = test_data.batch(batch_size).prefetch(buffer_size=AUTOTUNE)

In [6]:
# Evaluate the Keras model on the batched test set (batching speeds this up).
# The returned list holds the loss followed by the compiled metric value(s).
model.evaluate(x=test_data_batch)



[0.3845701515674591, 0.9185000061988831]

In [7]:
# Convert the Keras model to TFLite twice: a float baseline and a fully
# integer-quantized variant.
converter = tf.lite.TFLiteConverter.from_keras_model(model)

# 1) Float model: convert before any optimization is configured.
tflite_float_model = converter.convert()


def representative_dataset_gen():
    """Yield 100 single-image batches used by the converter for calibration.

    The samples only let the converter measure value ranges during
    quantization; they are not used for training or evaluation here.
    """
    for data in test_data.batch(1).take(100):
        # Each element is an (image, label) pair; only the image is fed in.
        yield [data[0]]


# 2) Quantized model: set every option first, then convert exactly once.
#    (An intermediate convert() here would be discarded and only cost time.)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
# Ensure that if any ops can't be quantized, the converter throws an error
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Set the input and output tensors to uint8 (APIs added in r2.3)
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

tflite_quant_model = converter.convert()

Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.


Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.


Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.


Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.


INFO:tensorflow:Assets written to: C:\Users\user\AppData\Local\Temp\tmpgglvl4s7\assets


INFO:tensorflow:Assets written to: C:\Users\user\AppData\Local\Temp\tmpgglvl4s7\assets


INFO:tensorflow:Assets written to: C:\Users\user\AppData\Local\Temp\tmpv26wp5vm\assets


INFO:tensorflow:Assets written to: C:\Users\user\AppData\Local\Temp\tmpv26wp5vm\assets


INFO:tensorflow:Assets written to: C:\Users\user\AppData\Local\Temp\tmp0mzfplc8\assets


INFO:tensorflow:Assets written to: C:\Users\user\AppData\Local\Temp\tmp0mzfplc8\assets


In [9]:
# Save the TFLite models to disk (pathlib is imported in the top import cell).
tflite_models_dir = pathlib.Path("cifar10/models")
# Create the output directory, including parents, if it does not exist yet.
tflite_models_dir.mkdir(exist_ok=True, parents=True)

# Save the unquantized/float model; the cell displays the byte count written.
tflite_model_file = tflite_models_dir/"cifar10_MobileNet.tflite"
tflite_model_file.write_bytes(tflite_float_model)


14156268

In [10]:
# Write the integer-quantized model next to the float one; the cell displays
# the number of bytes written.
tflite_model_quant_file = tflite_models_dir.joinpath("cifar10_MobileNet_quant.tflite")
tflite_model_quant_file.write_bytes(tflite_quant_model)

4179520