# Quantization Aware Training
+ 以lenet为例介绍量化感知训练的基本流程
+ 部署量化感知训练后的网络模型

In [1]:
import tensorflow as tf
print(f"tf verion = {tf.__version__}")

import tensorflow_model_optimization as tfmot
from tensorflow.keras.layers import InputLayer,Reshape,Conv2D,MaxPool2D,Flatten,Dense,Dropout
from tensorflow.keras.models import load_model

tf verion = 2.2.0


## 解决GPU内存不足报错，对GPU进行按需分配

In [2]:
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

## 加载数据集

In [3]:
# 加载 MNIST 数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 归一化输入图片，这样每个像素的值都在[0, 1]之间
x_train, x_test = x_train / 255.0, x_test / 255.0

# 扩张输入数据维度[height, width, channels(depth)]
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

## 构建LeNet模型

In [4]:
    '''
    LeNet5
        Conv2D(filters=6,kernel_size=5,strides=(1,1),padding='same',activation='relu',use_bias=False,input_shape=(28,28,1)),
        MaxPool2D(pool_size=(3,3),strides=2,padding="same"),
        Conv2D(filters=16,kernel_size=5,strides=(1,1),padding='same',activation='relu',use_bias=False),
        MaxPool2D(pool_size=(3,3),strides=2,padding="same"),
        Flatten(input_shape=(7, 7)),
        Dense(120, activation='relu'),
        Dense(84, activation='relu'),
        Dropout(0.2),
        Dense(10, activation='softmax')
    '''
model = tf.keras.models.Sequential([
        Conv2D(filters=12, kernel_size=(3, 3),activation='relu',input_shape=(28, 28, 1)),
        MaxPool2D(pool_size=(2,2)),
        Flatten(),
        Dense(10)
    ])

## 模型训练和评估（普通方式）

In [5]:
print("float32 model:")
model.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
model.summary()
print("==> training")
model.fit(x_train, y_train, epochs=1, validation_split=0.1)
print("==> evaluate")
model.evaluate(x_test, y_test, verbose=2)

float32 model:
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 26, 26, 12)        120       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 13, 13, 12)        0         
_________________________________________________________________
flatten (Flatten)            (None, 2028)              0         
_________________________________________________________________
dense (Dense)                (None, 10)                20290     
Total params: 20,410
Trainable params: 20,410
Non-trainable params: 0
_________________________________________________________________
==> training
==> evaluate
313/313 - 2s - loss: 0.1600 - accuracy: 0.9544


[0.15999484062194824, 0.9544000029563904]

## 保存预训练模型

In [6]:
model.save("lenet_normal.hdf5")
model_json = model.to_json()
with open('lenet_normal.json', 'w') as file:
    file.write(model_json)

## 量化感知训练

In [8]:
# 加载预训练模型
pretrained_model = load_model("lenet_normal.hdf5") 

# 量化感知训练
print("quantized model:")
qat_model = tfmot.quantization.keras.quantize_model(pretrained_model)

qat_model.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
qat_model.summary()
print("==> training")
x_train_subset = x_train[0:1000]
y_train_subset = y_train[0:1000]
qat_model.fit(x_train_subset, y_train_subset,
              batch_size=500,
              epochs=1,
              validation_split=0.1)
print("==> evaluate")
qat_model.evaluate(x_test, y_test, verbose=2)

quantized model:
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
quant_conv2d (QuantizeWrappe (None, 26, 26, 12)        147       
_________________________________________________________________
quant_max_pooling2d (Quantiz (None, 13, 13, 12)        1         
_________________________________________________________________
quant_flatten (QuantizeWrapp (None, 2028)              1         
_________________________________________________________________
quant_dense (QuantizeWrapper (None, 10)                20295     
Total params: 20,444
Trainable params: 20,410
Non-trainable params: 34
_________________________________________________________________
==> training
==> evaluate
313/313 - 3s - loss: 2.8189 - accuracy: 0.3049


[2.8189375400543213, 0.30489999055862427]

In [14]:
qat_model.save("lenet_qat.hdf5")
qat_model.save_weights("./weights/lenet_qat.hd5")
qat_model_json = qat_model.to_json()
with open('lenet_qat.json', 'w') as file:
    file.write(qat_model_json)

In [15]:
model.load_weights("./weights/lenet_qat.hd5")
model.save("lenet_qat_nofake_quant.hdf5")


Two checkpoint references resolved to different objects (<tensorflow.python.keras.layers.core.Dense object at 0x7f35a4440310> and <tensorflow.python.keras.layers.pooling.MaxPooling2D object at 0x7f36088dea60>).


## 导出量化后的模型
量化之前通过量化感知训练得到的模型，格式为TFLite

In [9]:
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.inference_type = tf.contrib.lite.constants.QUANTIZED_UINT8

quantized_tflite_model = converter.convert()


AttributeError: module 'tensorflow' has no attribute 'contrib'

In [None]:
with open('lenet.tflite', 'wb') as file:
    file.write(quantized_tflite_model)