# Retraining an Image Classifier

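## 1. Environment Setup

Install tensorflow_hub and pillow, which the image does not ship with, and work around an error I ran into: *OSError: image file is truncated*. I suspect it is caused by a limit on the size of a single data transfer allowed by the server; see [Python PIL “IOError: image file truncated” with big images](https://stackoverflow.com/questions/12984426/python-pil-ioerror-image-file-truncated-with-big-images) for a detailed discussion.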

In [None]:
!pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -U tensorflow_hub pillow

In [None]:
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
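
Setting `LOAD_TRUNCATED_IMAGES` makes PIL load incomplete files rather than raise, so it can be useful to know which files are affected. The following is a minimal sketch, not part of the original notebook, that flags JPEG files whose bytes do not start with the SOI marker (`FF D8`) or do not end with the EOI marker (`FF D9`). The helper names `looks_truncated_jpeg` and `find_suspect_jpegs` are hypothetical, and the check is only a heuristic: some valid files carry trailing bytes after EOI, and other formats such as PNG are not covered.

```python
import os

JPEG_SOI = b"\xff\xd8"  # start-of-image marker
JPEG_EOI = b"\xff\xd9"  # end-of-image marker


def looks_truncated_jpeg(data):
    """Heuristic: True if the bytes do not look like a complete JPEG stream."""
    return not (data.startswith(JPEG_SOI) and data.endswith(JPEG_EOI))


def find_suspect_jpegs(root):
    """Walk `root` and return paths of .jpg/.jpeg files that look truncated."""
    suspects = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in sorted(filenames):
            if name.lower().endswith((".jpg", ".jpeg")):
                path = os.path.join(dirpath, name)
                with open(path, "rb") as f:
                    if looks_truncated_jpeg(f.read()):
                        suspects.append(path)
    return suspects
```

Calling `find_suspect_jpegs("./datasets")` would list candidate files to re-download or remove before training.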

In [None]:
import itertools
import os

import matplotlib.pyplot as plt
import numpy as np

import tensorflow as tf
import tensorflow_hub as hub

print("TF version:", tf.__version__)
print("Hub version:", hub.__version__)
# tf.test.is_gpu_available() is deprecated in TF 2.x; query the device list instead.
print("GPU is", "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")

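## 2. Model Loading
The model here was downloaded manually. The archive must be extracted after downloading; uncomment the line below and run it to do so.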

In [None]:
#!tar zxvf ./models/imagenet_mobilenet_v2_140_224_classification_4.tar.gz -C ./models

In [None]:
!ls models

In [None]:
MODULE_HANDLE = "./models"
IMAGE_SIZE = (224, 224)
print("Using {} with input size {}".format(MODULE_HANDLE, IMAGE_SIZE))

BATCH_SIZE = 16 #@param {type:"integer"}

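## 3. Dataset Preparation

Resize the images to match the model's input size, split them into training and validation sets, shuffle the training data, and apply data augmentation.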

In [None]:
data_dir = './datasets'
#!ls ./datasets

In [None]:
datagen_kwargs = dict(rescale=1./255, validation_split=.20)
dataflow_kwargs = dict(target_size=IMAGE_SIZE, batch_size=BATCH_SIZE,
                   interpolation="bilinear")

valid_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    **datagen_kwargs)
valid_generator = valid_datagen.flow_from_directory(
    data_dir, subset="validation", shuffle=False, **dataflow_kwargs)

do_data_augmentation = True #@param {type:"boolean"}
if do_data_augmentation:
    train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
        rotation_range=40,
        horizontal_flip=True,
        width_shift_range=0.2, height_shift_range=0.2,
        shear_range=0.2, zoom_range=0.2,
        **datagen_kwargs
    )
else:
    train_datagen = valid_datagen

train_generator = train_datagen.flow_from_directory(
    data_dir, subset="training", shuffle=True, **dataflow_kwargs)
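
Both generators above receive the same `validation_split=.20`. The sketch below illustrates how such a fractional split can partition one class directory deterministically. To my understanding, `flow_from_directory` slices the sorted file list of each class in a similar way, with the leading fraction going to validation, but treat this as an approximation rather than the library's code. The helper `split_class_files` and the file names are hypothetical.

```python
def split_class_files(filenames, validation_split):
    """Hypothetical helper: return (training, validation) files for one class.

    The leading fraction of the sorted list goes to validation and the rest
    to training, so repeated calls always produce the same partition.
    """
    ordered = sorted(filenames)
    cut = int(validation_split * len(ordered))
    return ordered[cut:], ordered[:cut]


# Ten made-up file names standing in for one class directory.
files = ["img_{:02d}.jpg".format(i) for i in range(10)]
train_files, valid_files = split_class_files(files, 0.20)
print(len(train_files), "training /", len(valid_files), "validation")
```

Because the partition depends only on the sorted names and the fraction, a training generator and a validation generator built with the same fraction do not overlap, even when only the training one applies augmentation.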

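## 4. Model Definition

Discard the original model's output layer and add a new output layer sized to the number of classes in our dataset. Note that the archive extracted in section 2 is the `classification` variant, which appears to keep its ImageNet output layer, so the new `Dense` layer here would be stacked on top of those logits; the `feature_vector` variant of the same module is the one published without a classification head.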

In [None]:
do_fine_tuning = True #@param {type:"boolean"}

In [None]:
print("Building model with", MODULE_HANDLE)
model = tf.keras.Sequential([
    # Explicitly define the input shape so the model can be properly
    # loaded by the TFLiteConverter
    tf.keras.layers.InputLayer(input_shape=IMAGE_SIZE + (3,)),
    hub.KerasLayer(MODULE_HANDLE, trainable=do_fine_tuning),
    tf.keras.layers.Dropout(rate=0.2),
    tf.keras.layers.Dense(train_generator.num_classes,
                          kernel_regularizer=tf.keras.regularizers.l2(0.0001))
])
model.build((None,)+IMAGE_SIZE+(3,))
model.summary()

## 5. Model Training

### 5.1 Optimizer

This uses SGD (stochastic gradient descent) with momentum.
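
As a rough illustration of what the optimizer does on each step, here is a plain-Python sketch of the momentum update rule (non-Nesterov form: `velocity = momentum * velocity - learning_rate * gradient`, then `w = w + velocity`). The function `sgd_momentum_step` and the toy objective `f(w) = w**2` are assumptions made for this example; the real optimizer applies the same rule to every trainable tensor.

```python
def sgd_momentum_step(w, velocity, grad, learning_rate=0.005, momentum=0.9):
    """One SGD-with-momentum update for a single scalar parameter."""
    velocity = momentum * velocity - learning_rate * grad
    return w + velocity, velocity


# Toy problem: minimize f(w) = w**2, whose gradient is 2 * w.
w, v = 1.0, 0.0
first_w, first_v = sgd_momentum_step(w, v, 2.0 * w)

for _ in range(2000):
    w, v = sgd_momentum_step(w, v, 2.0 * w)

print("w after 2000 steps:", w)
```

The velocity term accumulates past gradients, which is why a small learning rate such as 0.005 can still make steady progress.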

In [None]:
model.compile(
    # `lr` is a deprecated alias; `learning_rate` is the supported argument name.
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.005, momentum=0.9),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),
    metrics=['accuracy']
)
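
The loss above combines two options: `from_logits=True` (the `Dense` layer has no softmax, so the loss applies it) and `label_smoothing=0.1`. The sketch below reproduces the arithmetic for a single example in plain Python, assuming the usual smoothing formula `y * (1 - s) + s / num_classes`. The function name `smoothed_cross_entropy` and the sample logits are made up for illustration.

```python
import math


def smoothed_cross_entropy(logits, one_hot, label_smoothing=0.1):
    """Softmax cross-entropy on raw logits with label smoothing (one example)."""
    num_classes = len(logits)
    # Move a little probability mass from the true class to all classes.
    targets = [y * (1.0 - label_smoothing) + label_smoothing / num_classes
               for y in one_hot]
    # Numerically stable log-softmax.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    log_probs = [z - log_sum_exp for z in logits]
    return -sum(t * lp for t, lp in zip(targets, log_probs))


confident = [8.0, 0.0, 0.0, 0.0]   # made-up logits, class 0 is correct
label = [1.0, 0.0, 0.0, 0.0]
print("without smoothing:", smoothed_cross_entropy(confident, label, 0.0))
print("with smoothing:   ", smoothed_cross_entropy(confident, label, 0.1))
```

With smoothing, a very confident correct prediction is still penalized slightly, which discourages the logits from growing without bound.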

### 5.2 Training

In [None]:
steps_per_epoch = train_generator.samples // train_generator.batch_size
validation_steps = valid_generator.samples // valid_generator.batch_size
hist = model.fit(
    train_generator,
    epochs=5, steps_per_epoch=steps_per_epoch,
    validation_data=valid_generator,
    validation_steps=validation_steps).history
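
The step counts above use floor division, so a final partial batch is not counted. A small worked example with made-up sample counts (1000 training images and 250 validation images are assumptions, not numbers from this dataset):

```python
import math


def full_batches(samples, batch_size):
    """Number of complete batches, matching `samples // batch_size` above."""
    return samples // batch_size


def all_batches(samples, batch_size):
    """Number of batches if the final partial batch is also counted."""
    return math.ceil(samples / batch_size)


BATCH = 16
train_steps = full_batches(1000, BATCH)            # 62 steps cover 992 images
valid_steps = full_batches(250, BATCH)             # 15 steps cover 240 images
leftover_train = 1000 - train_steps * BATCH        # 8 images not counted
print(train_steps, valid_steps, leftover_train, all_batches(1000, BATCH))
```

With small datasets the uncounted remainder can be a noticeable share of the validation set, in which case the ceiling variant may be worth considering.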

### 5.3 Loss and Accuracy Curves

In [None]:
plt.figure()
plt.ylabel("Loss (training and validation)")
plt.xlabel("Epoch")  # history holds one value per epoch, not per step
plt.ylim([0,2])
plt.plot(hist["loss"], label='loss')
plt.plot(hist["val_loss"], label='val_loss')
plt.legend();

plt.figure()
plt.ylabel("Accuracy (training and validation)")
plt.xlabel("Epoch")  # history holds one value per epoch, not per step
plt.ylim([0,1])
plt.plot(hist["accuracy"], label='accuracy')
plt.plot(hist["val_accuracy"], label='val_accuracy')
plt.legend();

### 5.4 Saving the Model

In [None]:
saved_model_path = "./output/saved_model"
tf.saved_model.save(model, saved_model_path)