# EfficientNetB7 迁移学习
这可能是一个会让你非常自豪的程序！完成这个程序之后，你或许就有信心做出类似花伴侣这样酷的应用！<br/>
<a href="http://www.aiplants.net/">花伴侣</a><br/>
花草树木，一拍呈名。只需要拍摄植物的花、果、叶等特征部位，即可快速识别植物。花伴侣能识别中国野生及栽培植物3000属，近5000种，几乎涵盖身边所有常见花草树木。

In [None]:
# 导入库
import math, re, os
# 设置log等级，只输出Error和Fatal级别的信息
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
from matplotlib import pyplot as plt
from kaggle_datasets import KaggleDatasets   # 用于访问Google云服务器上的数据集
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix
print("Tensorflow version " + tf.__version__)

## 1 使用TPU训练模型
为了使用TPU，需要首先检测并连接TPU，然后根据可用的TPU加速单元数量决定批处理大小。学习率动态调度策略也考虑了这一点。

In [None]:
# Detect and connect to the TPU, then wrap it in a distribution strategy.
tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
strategy = tf.distribute.TPUStrategy(tpu)
# tf.data autotune sentinel; passed later in this file as num_parallel_reads,
# num_parallel_calls and the prefetch buffer size.
AUTO = tf.data.experimental.AUTOTUNE
print("Number of accelerators: ", strategy.num_replicas_in_sync)

In [None]:
IMAGE_SIZE = [512, 512]   # input image size; its first entry also selects which dataset variant is read
EPOCHS = 13  # number of training epochs
BATCH_SIZE = 32 * strategy.num_replicas_in_sync  # batch size: 32 times the number of replicas in sync

## 2 数据集目录观察
花朵数据集包含四种尺寸规格的图片，可以根据需要任意选择。这里选定512x512的数据集。

In [None]:
GCS_DS_PATH = KaggleDatasets().get_gcs_path() # Google Cloud Storage location of the dataset
GCS_DS_PATH

In [None]:
GCS_PATH_SELECT = { # dataset directory for each available image size
    192: GCS_DS_PATH + '/tfrecords-jpeg-192x192',
    224: GCS_DS_PATH + '/tfrecords-jpeg-224x224',
    331: GCS_DS_PATH + '/tfrecords-jpeg-331x331',
    512: GCS_DS_PATH + '/tfrecords-jpeg-512x512'
}
GCS_PATH = GCS_PATH_SELECT[IMAGE_SIZE[0]]  # directory matching IMAGE_SIZE (KeyError for an unlisted size)

TRAINING_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/train/*.tfrec')  # training TFRecord files
VALIDATION_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/val/*.tfrec')  # validation TFRecord files
TEST_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/test/*.tfrec') # test TFRecord files

In [None]:
# Show the list of training-set files
TRAINING_FILENAMES

In [None]:
# Show the list of validation-set files
VALIDATION_FILENAMES

In [None]:
# Show the list of test-set files
TEST_FILENAMES

In [None]:
# Names of the 104 flower classes. The position in this list is the integer
# class id used for labels and predictions throughout this file.
# Two entries used to end with a stray space ('cyclamen ', 'hippeastrum '),
# which leaked into plot titles and tick labels; they are trimmed here.
CLASSES = [
    # 00 - 09
    'pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea', 'wild geranium',
    'tiger lily', 'moon orchid', 'bird of paradise', 'monkshood', 'globe thistle',
    # 10 - 19
    'snapdragon', "colt's foot", 'king protea', 'spear thistle', 'yellow iris',
    'globe-flower', 'purple coneflower', 'peruvian lily', 'balloon flower', 'giant white arum lily',
    # 20 - 29
    'fire lily', 'pincushion flower', 'fritillary', 'red ginger', 'grape hyacinth',
    'corn poppy', 'prince of wales feathers', 'stemless gentian', 'artichoke', 'sweet william',
    # 30 - 39
    'carnation', 'garden phlox', 'love in the mist', 'cosmos', 'alpine sea holly',
    'ruby-lipped cattleya', 'cape flower', 'great masterwort', 'siam tulip', 'lenten rose',
    # 40 - 49
    'barberton daisy', 'daffodil', 'sword lily', 'poinsettia', 'bolero deep blue',
    'wallflower', 'marigold', 'buttercup', 'daisy', 'common dandelion',
    # 50 - 59
    'petunia', 'wild pansy', 'primula', 'sunflower', 'lilac hibiscus',
    'bishop of llandaff', 'gaura', 'geranium', 'orange dahlia', 'pink-yellow dahlia',
    # 60 - 69
    'cautleya spicata', 'japanese anemone', 'black-eyed susan', 'silverbush', 'californian poppy',
    'osteospermum', 'spring crocus', 'iris', 'windflower', 'tree poppy',
    # 70 - 79
    'gazania', 'azalea', 'water lily', 'rose', 'thorn apple',
    'morning glory', 'passion flower', 'lotus', 'toad lily', 'anthurium',
    # 80 - 89
    'frangipani', 'clematis', 'hibiscus', 'columbine', 'desert-rose',
    'tree mallow', 'magnolia', 'cyclamen', 'watercress', 'canna lily',
    # 90 - 99
    'hippeastrum', 'bee balm', 'pink quill', 'foxglove', 'bougainvillea',
    'camellia', 'mallow', 'mexican petunia', 'bromelia', 'blanket flower',
    # 100 - 103
    'trumpet creeper', 'blackberry lily', 'common tulip', 'wild rose',
]

##  3 学习率调度函数
学习率对模型训练异常重要。为此，指定学习率动态调度策略非常必要！

In [None]:
LR_START = 0.00001  # learning rate at epoch 0
LR_MAX = 0.00005 * strategy.num_replicas_in_sync  # peak learning rate, scaled by the number of replicas
LR_MIN = 0.00001   # floor that the decayed learning rate approaches
LR_RAMPUP_EPOCHS = 4   # number of epochs of linear ramp-up
LR_SUSTAIN_EPOCHS = 0  # number of epochs held at LR_MAX after the ramp-up
LR_EXP_DECAY = .8  # per-epoch exponential decay factor

# Learning-rate schedule: linear ramp-up, optional plateau, exponential decay.
def lrfn(epoch):
    """Return the learning rate for the given epoch index.

    - epoch < LR_RAMPUP_EPOCHS: rise linearly from LR_START towards LR_MAX.
    - next LR_SUSTAIN_EPOCHS epochs: stay at LR_MAX.
    - afterwards: decay from LR_MAX towards LR_MIN by a factor of
      LR_EXP_DECAY per epoch.
    """
    if epoch < LR_RAMPUP_EPOCHS:
        slope = (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS
        return slope * epoch + LR_START
    if epoch < LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS:
        return LR_MAX
    decay = LR_EXP_DECAY**(epoch - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS)
    return (LR_MAX - LR_MIN) * decay + LR_MIN
# Keras callback that asks lrfn for the learning rate of each epoch
lr_callback = tf.keras.callbacks.LearningRateScheduler(lrfn, verbose=True)

# Plot the schedule over all training epochs to see how the learning rate evolves
rng = range(EPOCHS)
y = [lrfn(x) for x in rng]
plt.plot(rng, y)
print(f"学习率调度策略，从最小值 {y[0]} 到最大值： {max(y)} 再衰减到： {y[-1]}")

## 4 可视化函数
定义显示图像函数、显示模型训练曲线函数、显示混淆矩阵的函数，分别用于观察数据集、观察模型训练效果和预测结果。

In [None]:
# Keep printed NumPy arrays short: summarise arrays longer than 15 elements, wrap at 80 columns
np.set_printoptions(threshold=15, linewidth=80)

def batch_to_numpy_images_and_labels(data):
    """Convert an (images, labels) batch into NumPy form.

    Both elements must expose ``.numpy()``. When the label array has dtype
    ``object`` (byte strings, i.e. image ids rather than class ids) the labels
    are replaced by a list holding one ``None`` per image.

    Returns:
        (numpy_images, numpy_labels)
    """
    image_batch, label_batch = data
    numpy_images = image_batch.numpy()
    numpy_labels = label_batch.numpy()
    has_string_ids = numpy_labels.dtype == object
    if has_string_ids:
        # ids are not class labels: hand back a list of None placeholders
        numpy_labels = [None] * len(numpy_images)
    return numpy_images, numpy_labels

def title_from_label_and_target(label, correct_label):
    """Build a plot title comparing a predicted class with the true class.

    Args:
        label: predicted class id (index into CLASSES).
        correct_label: true class id, or None when it is unknown.

    Returns:
        (title, correct). With no true label the title is just the predicted
        name and ``correct`` is True. Otherwise the title is
        "<predicted> [OK]" on a match, or "<predicted> [NO?<true name>]".
    """
    predicted_name = CLASSES[label]
    if correct_label is None:
        return predicted_name, True

    correct = (label == correct_label)
    if correct:
        verdict, separator, expected_name = 'OK', '', ''
    else:
        # NOTE(review): the "?" separator may be a mis-encoded arrow character - confirm
        verdict, separator, expected_name = 'NO', "?", CLASSES[correct_label]
    return "{} [{}{}{}]".format(predicted_name, verdict, separator, expected_name), correct

def display_one_flower(image, title, subplot, red=False, titlesize=16):
    """Draw one image into the given subplot slot and return the next slot.

    Args:
        image: image passed to ``plt.imshow``.
        title: title text; an empty title draws no title.
        subplot: (rows, cols, index) triple selecting the slot.
        red: when True the title is red and uses a smaller font.
        titlesize: base font size of the title.

    Returns:
        The (rows, cols, index + 1) triple for the following image.
    """
    n_rows, n_cols, position = subplot
    plt.subplot(n_rows, n_cols, position)
    plt.axis('off')
    plt.imshow(image)
    if title:
        if red:
            font_size = int(titlesize/1.2)
            title_color = 'red'
        else:
            font_size = int(titlesize)
            title_color = 'black'
        plt.title(title, fontsize=font_size,
                  color=title_color,
                  fontdict={'verticalalignment':'center'},
                  pad=int(titlesize/1.5))
    return (n_rows, n_cols, position+1)
    
def display_batch_of_images(databatch, predictions=None):
    """
    Draw a batch of images as a near-square grid.

    Usage:
    display_batch_of_images((images, labels))
    display_batch_of_images((images, labels), predictions)

    Args:
        databatch: an (images, labels) pair accepted by
            batch_to_numpy_images_and_labels. For batches whose second
            element holds string ids, no true label is shown.
        predictions: optional sequence of predicted class ids, one per image.
            When given, each title compares the prediction with the true label
            and wrong predictions are drawn in red.

    Images that do not fill a complete grid row are not shown. An empty batch
    draws nothing.
    """
    # data
    images, labels = batch_to_numpy_images_and_labels(databatch)
    if labels is None:
        labels = [None] * len(images)

    # An empty batch would give rows == 0 and a division by zero below.
    if len(images) == 0:
        return

    # Grid shape: rows = floor(sqrt(n)); leftover images are dropped
    rows = int(math.sqrt(len(images)))
    cols = len(images)//rows
    shown_images = images[:rows*cols]
    shown_labels = labels[:rows*cols]

    # Figure size and spacing
    FIGSIZE = 13.0
    SPACING = 0.1
    subplot = (rows, cols, 1)
    if rows < cols:
        plt.figure(figsize=(FIGSIZE, FIGSIZE/cols*rows))
    else:
        plt.figure(figsize=(FIGSIZE/rows*cols, FIGSIZE))

    # The title size depends only on the grid shape, so compute it once
    dynamic_titlesize = FIGSIZE*SPACING/max(rows, cols)*40+3

    # Draw
    for i, (image, label) in enumerate(zip(shown_images, shown_labels)):
        title = '' if label is None else CLASSES[label]
        correct = True
        if predictions is not None:
            title, correct = title_from_label_and_target(predictions[i], label)
        subplot = display_one_flower(image, title, subplot, not correct, titlesize=dynamic_titlesize)

    # Layout: pack images tightly only when no title is drawn at all.
    # Decided from the displayed labels rather than from the loop variable
    # left over by the drawing loop.
    has_no_labels = all(label is None for label in shown_labels)
    plt.tight_layout()
    if has_no_labels and predictions is None:
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()

def display_confusion_matrix(cmat, score, precision, recall):
    """Plot a confusion matrix with class names on both axes.

    Args:
        cmat: matrix drawn with ``matshow``; ticks are labelled from CLASSES.
        score: F1 value to print on the figure, or None to omit it.
        precision: precision value to print, or None to omit it.
        recall: recall value to print, or None to omit it.
    """
    plt.figure(figsize=(15,15))
    axes = plt.gca()
    axes.matshow(cmat, cmap='Reds')

    # One tick per class on each axis, labels rotated so they stay readable
    axes.set_xticks(range(len(CLASSES)))
    axes.set_xticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(axes.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")
    axes.set_yticks(range(len(CLASSES)))
    axes.set_yticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(axes.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Gather the metric lines that were actually supplied
    summary_parts = []
    if score is not None:
        summary_parts.append('f1 = {:.3f} '.format(score))
    if precision is not None:
        summary_parts.append('\nprecision = {:.3f} '.format(precision))
    if recall is not None:
        summary_parts.append('\nrecall = {:.3f} '.format(recall))
    summary = "".join(summary_parts)
    if summary:
        # Drawn at data coordinates (101, 1), right-aligned
        axes.text(101, 1, summary, fontdict={'fontsize': 18, 'horizontalalignment':'right',
                                             'verticalalignment':'top', 'color':'#804040'})
    plt.show()
    
# Plot training and validation curves for one metric
def display_training_curves(training, validation, title, subplot):
    """Plot one metric's training and validation history in a subplot.

    Args:
        training: sequence of per-epoch training values.
        validation: sequence of per-epoch validation values.
        title: metric name, used for the subplot title and the y-axis label.
        subplot: integer subplot code (e.g. 211). A code ending in 1 is
            treated as the first call and creates the figure.
    """
    is_first_call = (subplot % 10 == 1)
    if is_first_call:
        plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
        plt.tight_layout()
    panel = plt.subplot(subplot)
    panel.set_facecolor('#F8F8F8')
    for series in (training, validation):
        panel.plot(series)
    panel.set_title('model '+ title)
    panel.set_ylabel(title)
    panel.set_xlabel('epoch')
    panel.legend(['train', 'valid.'])

## 5 定义数据集
兵马未动，粮草先行，数据集预处理始终是建模第一步！细心对待训练集、验证集和测试集！

In [None]:
def decode_image(image_data):
    """Decode JPEG bytes into a 3-channel image tensor of shape [*IMAGE_SIZE, 3]."""
    decoded = tf.image.decode_jpeg(image_data, channels=3)  # uint8 values in [0, 255]
    # The explicit reshape pins the static shape, which the TPU needs
    return tf.reshape(decoded, [*IMAGE_SIZE, 3])

def read_labeled_tfrecord(example):
    """Parse one serialized labeled example into an (image, label) pair.

    The record must carry an "image" byte string and an int64 "class".
    The label is returned as int32.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),  # tf.string means byte string
        "class": tf.io.FixedLenFeature([], tf.int64),   # shape [] means a single element
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    label = tf.cast(parsed['class'], tf.int32)
    return decode_image(parsed['image']), label

def read_unlabeled_tfrecord(example):
    """Parse one serialized unlabeled example into an (image, id) pair.

    The record must carry an "image" byte string and a string "id".
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),  # tf.string means byte string
        "id": tf.io.FixedLenFeature([], tf.string),     # shape [] means a single element
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image']), parsed['id']

def load_dataset(filenames, labeled=True, ordered=False):
    """Build a dataset of parsed examples from TFRecord files.

    Args:
        filenames: TFRecord file paths to read.
        labeled: True yields (image, label) pairs, False yields (image, id) pairs.
        ordered: when False, deterministic ordering is switched off on the
            dataset options; when True the options are left at their defaults.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False

    parse_fn = read_labeled_tfrecord if labeled else read_unlabeled_tfrecord
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    records = records.with_options(options)
    return records.map(parse_fn, num_parallel_calls=AUTO)

def data_augment(image, label):
    """Randomly flip the image left-right; the label passes through unchanged."""
    flipped = tf.image.random_flip_left_right(image)
    return flipped, label

def get_training_dataset():
    """Return the training pipeline.

    Stages, in order: load, augment, repeat, shuffle with a 2048-element
    buffer, batch by BATCH_SIZE, prefetch.
    """
    return (
        load_dataset(TRAINING_FILENAMES, labeled=True)
        .map(data_augment, num_parallel_calls=AUTO)
        .repeat()
        .shuffle(2048)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)
    )

def get_validation_dataset(ordered=False):
    """Return the validation pipeline: load, batch by BATCH_SIZE, cache, prefetch.

    Args:
        ordered: forwarded to load_dataset.
    """
    return (
        load_dataset(VALIDATION_FILENAMES, labeled=True, ordered=ordered)
        .batch(BATCH_SIZE)
        .cache()
        .prefetch(AUTO)
    )

def get_test_dataset(ordered=False):
    """Return the test pipeline of (image, id) batches: load, batch by BATCH_SIZE, prefetch.

    Args:
        ordered: forwarded to load_dataset.
    """
    return (
        load_dataset(TEST_FILENAMES, labeled=False, ordered=ordered)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)
    )

# Matches a "-<digits>." fragment in a file name (for a name shaped like
# "xx-123.tfrec" it captures "123"). Compiled once instead of once per file.
_ITEM_COUNT_PATTERN = re.compile(r"-([0-9]+)\.")


def count_data_items(filenames):
    """Sum the item counts embedded in the given file names.

    Each name must contain a "-<digits>." fragment; the first such fragment
    is read as the number of items stored in that file.

    Args:
        filenames: iterable of file-name strings.

    Returns:
        The total as a NumPy int64 (0 for an empty input, not a float).

    Raises:
        ValueError: if a name contains no "-<digits>." fragment.
    """
    counts = []
    for filename in filenames:
        match = _ITEM_COUNT_PATTERN.search(filename)
        if match is None:
            raise ValueError(f"cannot read an item count from file name: {filename!r}")
        counts.append(int(match.group(1)))
    # An explicit dtype keeps the empty case an integer (np.sum([]) is 0.0)
    return np.sum(counts, dtype=np.int64)

# Dataset sizes, read from the counts embedded in the file names
NUM_TRAINING_IMAGES = count_data_items(TRAINING_FILENAMES)
NUM_VALIDATION_IMAGES = count_data_items(VALIDATION_FILENAMES)
NUM_TEST_IMAGES = count_data_items(TEST_FILENAMES)
# Training uses only full batches (floor division)
STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // BATCH_SIZE
# -(-a // b) is ceiling division, so the last partial batch is still covered
VALIDATION_STEPS = -(-NUM_VALIDATION_IMAGES // BATCH_SIZE)
TEST_STEPS = -(-NUM_TEST_IMAGES // BATCH_SIZE) 
print(f'训练集图像数量: {NUM_TRAINING_IMAGES} ，\
      验证集图像数量： {NUM_VALIDATION_IMAGES}，\
      测试集图像数量：{NUM_TEST_IMAGES}')

### 数据集观察
抽样显示数据集图片，建立感性认识

In [None]:
# data dump: print the shapes of the first three batches of each split,
# then the labels / ids of the last batch seen (the loop variables are reused
# after each loop on purpose)
print("Training data shapes:")
for image, label in get_training_dataset().take(3):
    print(image.numpy().shape, label.numpy().shape)
print("Training data label examples:", label.numpy())
print("Validation data shapes:")
for image, label in get_validation_dataset().take(3):
    print(image.numpy().shape, label.numpy().shape)
print("Validation data label examples:", label.numpy())
print("Test data shapes:")
for image, idnum in get_test_dataset().take(3):
    print(image.numpy().shape, idnum.numpy().shape)
print("Test data IDs:", idnum.numpy().astype('U')) # U=unicode string

In [None]:
# Inspect the training set: re-batch into groups of 20 images for display
training_dataset = get_training_dataset()
training_dataset = training_dataset.unbatch().batch(20)
train_batch = iter(training_dataset)

In [None]:
# Show the next batch of 20 training images (re-run the cell for another batch)
display_batch_of_images(next(train_batch))

In [None]:
# Inspect the test set: re-batch into groups of 20 images for display
test_dataset = get_test_dataset()
test_dataset = test_dataset.unbatch().batch(20)
test_batch = iter(test_dataset)

In [None]:
# Show the next batch of 20 test images (re-run the cell for another batch)
display_batch_of_images(next(test_batch))

##  6 定义模型 EfficientNetB7
用TPU模式定义

In [None]:
# Build AND compile the model inside the strategy scope. compile() creates the
# optimizer and metrics, and the TensorFlow distribution guide places it inside
# the scope together with model construction.
with strategy.scope():
    # Load the hub module through the local host
    load_locally = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
    pretrained_model = hub.KerasLayer('https://tfhub.dev/tensorflow/efficientnet/b7/feature-vector/1', 
                                      trainable=True,  # fine-tune the pretrained weights as well
                                      input_shape=[*IMAGE_SIZE, 3], 
                                      load_options=load_locally)
    model = tf.keras.Sequential([
        # the expected image format for all TFHub image models is float32 in [0,1) range
        tf.keras.layers.Lambda(lambda data: tf.image.convert_image_dtype(data, tf.float32), input_shape=[*IMAGE_SIZE, 3]),
        pretrained_model,
        # one softmax output per flower class
        tf.keras.layers.Dense(len(CLASSES), activation='softmax')
    ])

    model.compile(
        optimizer='adam',
        loss = 'sparse_categorical_crossentropy',  # labels are integer class ids
        metrics=['sparse_categorical_accuracy'],
        steps_per_execution=16
    )
model.summary()

## 7 模型训练

In [None]:
# Train with the learning-rate schedule. The training dataset repeats forever,
# so steps_per_epoch defines where each epoch ends.
history = model.fit(get_training_dataset(), 
                    steps_per_epoch=STEPS_PER_EPOCH, 
                    epochs=EPOCHS,
                    validation_data=get_validation_dataset(), 
                    validation_steps=VALIDATION_STEPS,
                    callbacks=[lr_callback])

## 8 显示准确率和损失函数曲线

In [None]:
# Two stacked panels: loss on top (211 also creates the figure), accuracy below (212)
display_training_curves(history.history['loss'], history.history['val_loss'], 'loss', 211)
display_training_curves(history.history['sparse_categorical_accuracy'], 
                        history.history['val_sparse_categorical_accuracy'], 'accuracy', 212)

## 9 保存模型

In [None]:
# Write the trained model to ./EfficientNetB7, with the save I/O routed through the local host
save_locally = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save('./EfficientNetB7', options=save_locally) # saving in Tensorflow's "SavedModel" format

## 10 模型评估--混淆矩阵、F1-Score

In [None]:
# Reload the saved model from ./EfficientNetB7; it replaces the in-memory `model`
load_locally = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
model = tf.keras.models.load_model('./EfficientNetB7', options=load_locally)

In [None]:
# Validation set read with ordered=True so that the two passes below
# (labels, then predictions) see the examples in the same order
cmdataset = get_validation_dataset(ordered=True)
images_ds = cmdataset.map(lambda image, label: image)
labels_ds = cmdataset.map(lambda image, label: label).unbatch()
# Ground-truth labels: collect every label in a single batch
cm_correct_labels = next(iter(labels_ds.batch(NUM_VALIDATION_IMAGES))).numpy()
# Predicted labels: index of the highest-probability class for each image
cm_probabilities = model.predict(images_ds, steps=VALIDATION_STEPS)
cm_predictions = np.argmax(cm_probabilities, axis=-1)
print("验证集真实标签: ", cm_correct_labels.shape, cm_correct_labels)
print("验证集预测标签: ", cm_predictions.shape, cm_predictions)

In [None]:
# Confusion matrix and macro-averaged metrics over all classes
cmat = confusion_matrix(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)))
score = f1_score(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)), average='macro')
precision = precision_score(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)), average='macro')
recall = recall_score(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)), average='macro')
# Normalize each row (true class) to sum to 1. A class with no validation
# sample has a zero row sum; its row is left at 0 instead of dividing by zero,
# which would produce NaN and a RuntimeWarning.
row_totals = cmat.sum(axis=1, keepdims=True)
cmat = np.divide(cmat, row_totals, out=np.zeros(cmat.shape, dtype=float), where=row_totals != 0)
display_confusion_matrix(cmat, score, precision, recall)
print('f1 score: {:.3f}, precision: {:.3f}, recall: {:.3f}'.format(score, precision, recall))

##  11 模型预测
在测试集上预测，保存预测结果到 submission.csv 文件中...

In [None]:
# ordered=True: the dataset is traversed twice below (images, then ids) and
# both passes must yield the examples in the same order
test_ds = get_test_dataset(ordered=True)

print('在测试集上做预测...')
test_images_ds = test_ds.map(lambda image, idnum: image)
probabilities = model.predict(test_images_ds, steps=TEST_STEPS)
predictions = np.argmax(probabilities, axis=-1)  # most probable class id per image
print(predictions)

print('生成预测结果文件 submission.csv...')
test_ids_ds = test_ds.map(lambda image, idnum: idnum).unbatch()
test_ids = next(iter(test_ids_ds.batch(NUM_TEST_IMAGES))).numpy().astype('U') # all in one batch
# Write one "id,label" row per test image under an "id,label" header
np.savetxt('submission.csv', np.rec.fromarrays([test_ids, predictions]), 
           fmt=['%s', '%d'], delimiter=',', header='id,label', comments='')
!head submission.csv

###  对验证集预测结果的可视化观察

In [None]:
# Validation set re-batched into groups of 20 images for visual inspection
dataset = get_validation_dataset()
dataset = dataset.unbatch().batch(20)
batch = iter(dataset)

In [None]:
# Predict the next batch of 20 validation images and show the predictions
# against the true labels (re-run the cell for another batch)
images, labels = next(batch)
probabilities = model.predict(images)
predictions = np.argmax(probabilities, axis=-1)
display_batch_of_images((images, labels), predictions)

### 对测试集预测结果的可视化观察

In [None]:
# Test set re-batched into groups of 20 images for visual inspection
dataset = get_test_dataset()
dataset = dataset.unbatch().batch(20)
batch = iter(dataset)

In [None]:
# Predict the next batch of 20 test images and show the predicted class names.
# For the test set `labels` holds image ids, not classes, so no correctness
# mark is shown (re-run the cell for another batch).
images, labels = next(batch)
probabilities = model.predict(images)
predictions = np.argmax(probabilities, axis=-1)
display_batch_of_images((images, labels), predictions)

至此，颇有成就感！我们得到了一个性能非常好的花朵识别模型。虽然目前只能识别104种花朵，但只要您有新的数据集，完全可以在此基础上继续运用迁移学习迭代下去，识别更多类型的花朵！