# Flower Classification with Data Augmentations & Densenet201
We added and tested random contrast and brightness changes as data augmentation methods in this notebook.

This notebook is based on the following work; its augmentation methods are kept.
https://www.kaggle.com/cdeotte/rotation-augmentation-gpu-tpu-0-96

In [None]:
''' Code structure
1 imports
2 def # (original)
        decode_image
        read_labeled_tfrecord
        read_unlabeled_tfrecord
        load_dataset
        data_augment
        get_training_dataset
        get_validation_dataset
        get_test_dataset
        count_data_items
   Data Augmentation
        get_mat
        transform_ori
        lrfn
   ANN functions
        get_model
        train_cross_validate
        train_and_predict
3 def # Modifications, our work
    
4 main function
    train_and_predict
5 Confusion Matrix and Validation Score   '''

In [None]:
import time
print('Start: '+time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime()))

import random, re, math
import numpy as np, pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix

import tensorflow as tf, tensorflow.keras.backend as K
print('    Tensorflow version ' + tf.__version__)
from tensorflow.keras.applications import DenseNet201
from kaggle_datasets import KaggleDatasets


def decode_image(image_data):
    """Decode a JPEG byte string into a float32 image tensor.

    The JPEG is decoded with 3 channels, scaled from 0..255 to the [0, 1]
    range and reshaped to ``[*IMAGE_SIZE, 3]`` (``IMAGE_SIZE`` is a
    module-level global).
    """
    pixels = tf.image.decode_jpeg(image_data, channels=3)
    # Convert to floats in the [0, 1] range.
    scaled = tf.cast(pixels, tf.float32) / 255.0
    # An explicit static size is needed for TPU.
    return tf.reshape(scaled, [*IMAGE_SIZE, 3])

def read_labeled_tfrecord(example):
    """Parse one serialized labeled example into an ``(image, label)`` pair.

    The record is expected to carry an ``"image"`` byte string and an
    int64 ``"class"``; the label is returned as int32.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "class": tf.io.FixedLenFeature([], tf.int64),  # shape [] means single element
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    decoded = decode_image(parsed['image'])
    class_id = tf.cast(parsed['class'], tf.int32)
    return decoded, class_id

def read_unlabeled_tfrecord(example):
    """Parse one serialized unlabeled example into an ``(image, id)`` pair.

    Test records carry no ``"class"`` feature; only the image bytes and a
    string ``"id"`` are read.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "id": tf.io.FixedLenFeature([], tf.string),  # shape [] means single element
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    decoded = decode_image(parsed['image'])
    return decoded, parsed['id']

def load_dataset(filenames, labeled = True, ordered = False):
    """Build a ``tf.data`` dataset of parsed examples from TFRecord files.

    Args:
        filenames: TFRecord file paths to read.
        labeled: if true, elements are ``(image, label)`` pairs; otherwise
            ``(image, id)`` pairs.
        ordered: if false, deterministic ordering is switched off so records
            can be consumed as soon as they stream in (faster).

    Returns:
        The parsed, unbatched dataset.
    """
    parse_fn = read_labeled_tfrecord if labeled else read_unlabeled_tfrecord

    # Order does not matter for training data because it is shuffled later.
    read_options = tf.data.Options()
    if not ordered:
        read_options.experimental_deterministic = False

    # Reads from multiple files are interleaved automatically.
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads = AUTO)
    records = records.with_options(read_options)
    return records.map(parse_fn, num_parallel_calls = AUTO)

def data_augment(image, label):
    """Randomly flip the image left/right; the label passes through unchanged.

    Intended to be mapped over a dataset whose pipeline ends in
    ``prefetch(AUTO)``, so the augmentation overlaps with training.
    """
    flipped = tf.image.random_flip_left_right(image)
    return flipped, label

def get_training_dataset(dataset,do_aug=True):
    """Turn a parsed dataset into an endlessly repeating training pipeline.

    Args:
        dataset: dataset of ``(image, label)`` pairs.
        do_aug: when true, ``transform`` is applied after the random flip.

    Returns:
        Dataset that is augmented, repeated, shuffled (buffer 2048), batched
        by ``BATCH_SIZE`` and prefetched.
    """
    dataset = dataset.map(data_augment, num_parallel_calls=AUTO)
    if do_aug:
        dataset = dataset.map(transform, num_parallel_calls=AUTO)
    # The training dataset must repeat for several epochs; the next batch is
    # prefetched while the current one is being trained on.
    return dataset.repeat().shuffle(2048).batch(BATCH_SIZE).prefetch(AUTO)

def get_validation_dataset(dataset):
    """Batch, cache and prefetch a dataset for validation (no shuffling)."""
    batched = dataset.batch(BATCH_SIZE)
    return batched.cache().prefetch(AUTO)

def get_test_dataset(ordered=False):
    """Load the unlabeled test files as batched, prefetched ``(image, id)`` pairs.

    Args:
        ordered: forwarded to ``load_dataset``; pass true when images and ids
            are iterated separately and must stay aligned.
    """
    unlabeled = load_dataset(TEST_FILENAMES, labeled=False, ordered=ordered)
    return unlabeled.batch(BATCH_SIZE).prefetch(AUTO)

def count_data_items(filenames): # V
    """Return the total number of records encoded in the given file names.

    The number of data items is written in the name of the .tfrec files,
    i.e. ``flowers00-230.tfrec`` holds 230 data items: the count is the run
    of digits between a ``-`` and the following ``.``.

    Args:
        filenames: iterable of TFRecord file names or paths.

    Returns:
        The summed count as a NumPy int64. An empty input yields integer 0
        (a plain ``np.sum([])`` would yield float 0.0).

    Raises:
        AttributeError: if a name contains no ``-<digits>.`` segment.
    """
    # Compile once instead of once per file name.
    count_pattern = re.compile(r"-([0-9]*)\.")
    counts = [int(count_pattern.search(filename).group(1)) for filename in filenames]
    # Force an integer dtype so the result stays integral for empty input.
    return np.sum(counts, dtype=np.int64)

###############     ##############
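# # Data Augmentation # helper functions for data augmentation
    # The code below uses the GPU/TPU to apply random rotation, shear, zoom and shift. When the image moves away from an edge and leaves a blank area, the blank is filled by stretching the colours of the original edge.
    # Change the variables in the function "transform()" below to control the amount of augmentation.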
def get_mat(rotation, shear, height_zoom, width_zoom, height_shift, width_shift):
    """Return the 3x3 matrix used to transform pixel indices.

    Args:
        rotation: rotation angle in degrees (tensor of shape [1]).
        shear: shear angle in degrees (tensor of shape [1]).
        height_zoom, width_zoom: zoom factors; the matrix uses their inverse.
        height_shift, width_shift: shifts placed in the last column.

    Returns:
        ``(rotation @ shear) @ (zoom @ shift)`` as a 3x3 float tensor.
    """
    def _mat3(entries):
        # Nine shape-[1] tensors, row-major, become one 3x3 matrix.
        return tf.reshape(tf.concat(entries, axis=0), [3, 3])

    # Convert degrees to radians.
    rotation = math.pi * rotation / 180.
    shear = math.pi * shear / 180.

    one = tf.constant([1],dtype='float32')
    zero = tf.constant([0],dtype='float32')

    cos_rot = tf.math.cos(rotation)
    sin_rot = tf.math.sin(rotation)
    rotation_matrix = _mat3([cos_rot, sin_rot, zero,
                             -sin_rot, cos_rot, zero,
                             zero, zero, one])

    cos_shear = tf.math.cos(shear)
    sin_shear = tf.math.sin(shear)
    shear_matrix = _mat3([one, sin_shear, zero,
                          zero, cos_shear, zero,
                          zero, zero, one])

    zoom_matrix = _mat3([one/height_zoom, zero, zero,
                         zero, one/width_zoom, zero,
                         zero, zero, one])

    shift_matrix = _mat3([one, zero, height_shift,
                          zero, one, width_shift,
                          zero, zero, one])

    rot_shear = K.dot(rotation_matrix, shear_matrix)
    zoom_shift = K.dot(zoom_matrix, shift_matrix)
    return K.dot(rot_shear, zoom_shift)

def transform_ori(image,label):
    """Randomly rotate, shear, zoom and shift a single image.

    Args:
        image: one image of size [dim,dim,3], not a batch of [b,dim,dim,3];
            ``dim`` is taken from ``IMAGE_SIZE[0]``.
        label: passed through unchanged.

    Returns:
        ``(transformed_image, label)`` where the image has shape [DIM,DIM,3].
    """
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image randomly rotated, sheared, zoomed, and shifted
    DIM = IMAGE_SIZE[0]
    XDIM = DIM%2 #fix for size 331
    # Random parameters, all drawn from a standard normal and then scaled.
    rot = 15. * tf.random.normal([1],dtype='float32')
    shr = 5. * tf.random.normal([1],dtype='float32') 
    h_zoom = 1.0 + tf.random.normal([1],dtype='float32')/10.
    w_zoom = 1.0 + tf.random.normal([1],dtype='float32')/10.
    h_shift = 16. * tf.random.normal([1],dtype='float32') 
    w_shift = 16. * tf.random.normal([1],dtype='float32') 
    # GET TRANSFORMATION MATRIX 
    m = get_mat(rot,shr,h_zoom,w_zoom,h_shift,w_shift)
    # LIST DESTINATION PIXEL INDICES
    # (centred coordinates; a row of ones is stacked on so the 3x3 matrix can be applied)
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    # ROTATE DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = K.dot(m,tf.cast(idx,dtype='float32'))
    idx2 = K.cast(idx2,dtype='int32')
    # Clamp so every lookup stays inside the image; out-of-range destinations
    # therefore sample the nearest edge pixel.
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    # FIND ORIGIN PIXEL VALUES           
    # Convert the centred coordinates back to array row/column indices.
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image,tf.transpose(idx3));#d: tf.Tensor(shape=(262144, 3), dtype=float32)  (for 512px image)        
    return tf.reshape(d,[DIM,DIM,3]),label

# Learning-rate schedule function
def lrfn(epoch):
    """Return the learning rate for the given epoch index.

    The schedule has three phases driven by module-level constants:
    a linear ramp from ``LR_START`` to ``LR_MAX`` over ``LR_RAMPUP_EPOCHS``,
    a plateau at ``LR_MAX`` for ``LR_SUSTAIN_EPOCHS``, then an exponential
    decay towards ``LR_MIN`` with factor ``LR_EXP_DECAY`` per epoch.
    """
    # Phase 1: linear ramp-up.
    if epoch < LR_RAMPUP_EPOCHS:
        return (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS * epoch + LR_START
    # Phase 2: hold at the maximum.
    if epoch < LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS:
        return LR_MAX
    # Phase 3: exponential decay towards the minimum.
    return (LR_MAX - LR_MIN) * LR_EXP_DECAY**(epoch - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS) + LR_MIN

# ANN functions # neural-network helper functions
def get_model():
    """Build and compile the DenseNet201-based classifier.

    The network is an ImageNet-pretrained DenseNet201 without its top
    (fully trainable), followed by global average pooling and a float32
    softmax layer with ``len(CLASSES)`` outputs.

    Both construction and ``compile`` happen inside ``strategy.scope()`` so
    that the model, optimizer and metric variables are all created under
    the distribution strategy.

    Returns:
        The compiled ``tf.keras`` model.
    """
    print('    -------------get_model is running')
    with strategy.scope():
        rnet = DenseNet201(
            input_shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 3),
            weights='imagenet',
            include_top=False)
        # Fine-tune the whole backbone.
        rnet.trainable = True
        model = tf.keras.Sequential([
            rnet,
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(len(CLASSES), activation='softmax', dtype='float32'),
        ])
        model.compile(
            optimizer='adam',
            loss='sparse_categorical_crossentropy',
            metrics=['sparse_categorical_accuracy'])
    return model

def train_cross_validate(folds = 5):
    """Train one model per K-fold split of ``TRAINING_FILENAMES``.

    The split is made over TFRecord files (not individual images) with a
    shuffled ``KFold`` seeded by ``SEED``. Behaviour is steered by the
    module-level flags:

    * ``no_training_mode``: build the first fold's model, skip fitting and
      return immediately (``histories`` stays empty).
    * ``no_folds``: train the first fold only, then return.

    Args:
        folds: number of cross-validation folds.

    Returns:
        ``(histories, models)``: the Keras ``History`` objects and the
        models, one per fold that was actually run.
    """
    print('    -------------train_cross_validate is running')
    histories = []
    models = []
    # NOTE: early stopping is not used; only the LR scheduler callback is
    # passed to fit(). Add a tf.keras.callbacks.EarlyStopping to `callbacks`
    # below to enable it.
    kfold = KFold(folds, shuffle = True, random_state = SEED)
    for f, (trn_ind, val_ind) in enumerate(kfold.split(TRAINING_FILENAMES)):
        print('    '+'#'*25);print('    ### FOLD',f+1)
        # Select this fold's files directly by index.
        train_files = [TRAINING_FILENAMES[i] for i in trn_ind]
        val_files = [TRAINING_FILENAMES[i] for i in val_ind]
        train_dataset = load_dataset(train_files, labeled = True)
        val_dataset = load_dataset(val_files, labeled = True, ordered = True)
        model = get_model()
        if no_training_mode:
            print('    No training mode is on, training is skipped: no_training_mode == True')
            models.append(model)
            return histories, models

        history = model.fit(
            get_training_dataset(train_dataset),
            steps_per_epoch = STEPS_PER_EPOCH,
            epochs = EPOCHS,
            callbacks = [lr_callback],
            validation_data = get_validation_dataset(val_dataset),
            verbose=1)
        models.append(model)
        histories.append(history)

        if no_folds:
            print('    No folds mode is on, the other folds is skipped: no_training_mode == no_folds')
            return histories, models
    return histories, models

def train_and_predict(folds = 5):
    """Train the fold models, predict the test set and write submission.csv.

    Test-set class probabilities are averaged over every model returned by
    ``train_cross_validate``. That may be fewer than ``folds`` models when
    ``no_folds`` or ``no_training_mode`` is set, so the average is taken
    over the returned list rather than indexing ``range(folds)``.

    Args:
        folds: number of cross-validation folds to train.

    Returns:
        ``(histories, models)`` as returned by ``train_cross_validate``.
    """
    print('    -------------train_and_predict is running')
    # Since we are splitting the dataset and iterating separately on images
    # and ids, order matters.
    test_ds = get_test_dataset(ordered=True)
    test_images_ds = test_ds.map(lambda image, idnum: image)

    # Actual training run.
    print('Start training %i folds'%folds)
    histories, models = train_cross_validate(folds = folds)

    print('Computing predictions...')
    # Mean probability over all available fold models (a single model in
    # quick single-fold runs).
    fold_probabilities = [model.predict(test_images_ds) for model in models]
    probabilities = np.average(fold_probabilities, axis = 0)

    predictions = np.argmax(probabilities, axis=-1)
    print('Generating submission.csv file...')
    test_ids_ds = test_ds.map(lambda image, idnum: idnum).unbatch()
    test_ids = next(iter(test_ids_ds.batch(NUM_TEST_IMAGES))).numpy().astype('U') # all in one batch
    np.savetxt('submission.csv', np.rec.fromarrays([test_ids, predictions]), fmt=['%s', '%d'], delimiter=',', header='id,label', comments='')
    return histories, models
    
######################### Modified Functions ######################### Our work
def plot_history(history, label, epcohs):
    """Plot a training metric and its validation counterpart per epoch.

    Args:
        history: Keras ``History`` object; ``history.history`` must contain
            ``label`` and ``'val_' + label``.
        label: metric name, e.g. ``'loss'``.
        epcohs: right-hand limit of the x axis (number of epochs). The
            misspelled name is kept so existing callers keep working; the
            body now uses this parameter instead of a global ``epochs``.
    """
    val_label = 'val_' + label
    # A leading 0 shifts the curves so that index 1 is the first epoch.
    data = {
        label: [0] + list(history.history[label]),
        val_label: [0] + list(history.history[val_label]),
    }
    pd.DataFrame(data).plot(figsize=(8, 5))
    plt.grid(True)
    # The y range is scaled from the first-epoch training value.
    plt.axis([1, epcohs, 0, data[label][1]*1.5])
    plt.show()
    
def contrastImg_ts(img_file1,alpha,beta,gamma):  #Adjust Contrast & Brightness randomly
    """Randomly adjust the contrast and brightness of one image tensor.

    The result is ``img_file1 * alpha + gamma`` with
    ``alpha`` uniform in [0.7, 1.0) (contrast weight of the original image)
    and ``gamma`` uniform in [-0.4, 0.4) (brightness offset). This equals
    blending with an all-black image of weight ``1 - alpha``, so no zero
    image has to be allocated. The output is not clipped to [0, 1].

    NOTE: the ``alpha``, ``beta`` and ``gamma`` arguments are accepted for
    interface compatibility but are ignored; fresh random values are always
    drawn, so existing call sites keep getting random augmentation.

    Args:
        img_file1: float image tensor.
        alpha, beta, gamma: ignored (see note above).

    Returns:
        The adjusted image tensor, same shape as the input.
    """
    alpha = tf.random.uniform([1],dtype='float32')*0.3+0.7
    gamma = (tf.random.uniform([1],dtype='float32')-0.5)*0.8
    # Debug output: the label reads "weight of the original image =".
    print("原图权重=",alpha)
    return img_file1*alpha + gamma
    
def transform(image,label): # Basicly same as the original
    """Apply random contrast/brightness, then the random geometric transform.

    The geometric part (random rotation, shear, zoom and shift) is exactly
    what ``transform_ori`` does, so it is delegated to that function rather
    than duplicated; the only addition here is the contrast/brightness
    change applied to the pixels beforehand.

    Args:
        image: one image of size [dim,dim,3], not a batch of [b,dim,dim,3].
        label: passed through unchanged.

    Returns:
        ``(augmented_image, label)``.
    """
    # The explicit alpha/beta/gamma values are ignored by contrastImg_ts,
    # which always draws random ones.
    img_changed = contrastImg_ts(image,alpha=1,beta=0,gamma=0)
    return transform_ori(img_changed, label)

####   main function # 主函数
if __name__ == "__main__":
    # Selects between the two setting groups below.
    upload_mode = True
    # Configuration and hyper-parameters
    print('Configuring: ' + time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime()))
    if upload_mode:
        IMAGE_SIZE = [512, 512]  # alternatives tried: [331,331], [192,192]
        EPOCHS = 15
        FOLDS = 2
        SEED = 777
        # BATCH_SIZE is set further down, because `strategy` does not exist yet here.
        quick_test = False
        no_training_mode = False
        show_augmented_exampe = True
        no_folds = False
    else:
        IMAGE_SIZE = [512, 512]  # alternative: [192,192]
        EPOCHS = 20
        FOLDS = 2
        SEED = 777
        no_folds = True  # skip the remaining folds after the first one
        quick_test = False  # load only a small amount of data
        no_training_mode = False  # set True to skip fitting
        show_augmented_exampe = False  # set True to plot augmented samples

    MIXED_PRECISION = False
    XLA_ACCELERATE = False

    # Hardware configuration
    print('Detecting Hardware:'+time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime()))
    # Detect hardware, return appropriate distribution strategy
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()# Cluster Resolver for Google Cloud TPUs.
        # TPU detection. No parameters necessary if TPU_NAME
        #    environment variable is set. On Kaggle this is always the case.
        print('    Running on TPU ', tpu.master())
    except ValueError:
        # The resolver raised ValueError: fall back to the default strategy below.
        tpu = None

    if tpu:
        tf.config.experimental_connect_to_cluster(tpu)# Connects to the given cluster.
        tf.tpu.experimental.initialize_tpu_system(tpu)# Initialize the TPU devices.
        strategy = tf.distribute.experimental.TPUStrategy(tpu)# TPU distribution strategy implementation.
    else:
        strategy = tf.distribute.get_strategy() 
        # default distribution strategy in Tensorflow. Works on CPU and single GPU.

    BATCH_SIZE = 16 * strategy.num_replicas_in_sync # 16 samples per replica
    AUTO = tf.data.experimental.AUTOTUNE
    print("    REPLICAS: ", strategy.num_replicas_in_sync) # number of replicas the strategy keeps in sync
    
    

    # Mixed Precision and/or XLA 
    # The booleans below can enable mixed precision and/or XLA on GPU/TPU. By default TPU already uses some mixed precision, but more can be added.
    # These let GPU/TPU memory handle larger batch sizes and can speed up training.
    # Nvidia V100 GPUs have special Tensor Cores that are used when mixed precision is enabled. Kaggle's Nvidia P100 GPUs have no Tensor Cores to accelerate this.
    if MIXED_PRECISION:
        print('    Mixed precision enabled:'+time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime()))
        from tensorflow.keras.mixed_precision import experimental as mixed_precision
        # bfloat16 on TPU, float16 otherwise.
        if tpu: policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
        else: policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
        mixed_precision.set_policy(policy)
        print('    Mixed precision enabled')
        
    if XLA_ACCELERATE:
        print('   Accelerated Linear Algebra:'+time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime()))
        tf.config.optimizer.set_jit(True)
        print('    Accelerated Linear Algebra enabled')
        
# # Data Directories 数据路径与获取
    # Data access # data paths
    print('Accessing Data: '+time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime()))
    GCS_DS_PATH = KaggleDatasets().get_gcs_path('flower-classification-with-tpus')
    # available image sizes :
    GCS_PATH_SELECT = { 
        192: GCS_DS_PATH + '/tfrecords-jpeg-192x192',224: GCS_DS_PATH + '/tfrecords-jpeg-224x224',
        331: GCS_DS_PATH + '/tfrecords-jpeg-331x331',512: GCS_DS_PATH + '/tfrecords-jpeg-512x512'}
    print('    IMAGE_SIZE:' + str(IMAGE_SIZE)) # image size
    GCS_PATH = GCS_PATH_SELECT[IMAGE_SIZE[0]]
    # The 'train' and 'val' files are pooled; the K-fold split is made over this combined list.
    TRAINING_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/train/*.tfrec') + tf.io.gfile.glob(GCS_PATH + '/val/*.tfrec')
    # predictions on this dataset should be submitted for the competition
    TEST_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/test/*.tfrec')
    
    if quick_test: # truncate the file lists (less data, faster test runs)
        print('数据截取中(减少数据加速测试): '+ time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime()))
        TRAINING_FILENAMES = TRAINING_FILENAMES[0:2];#print(TRAINING_FILENAMES);
        TEST_FILENAMES = TEST_FILENAMES[0:1];#print(TEST_FILENAMES)
        print('\n    QUICK TEST MODEL!!! very little amount of data \n')
    else:
        print('\n    FULL TEST MODEL!!! ALL data\n')
    
    # Classes: flower class names, indexed by class id (some names carry a trailing space)
    CLASSES = ['pink primrose',    'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',     'wild geranium',     'tiger lily',           'moon orchid',              'bird of paradise', 'monkshood',        'globe thistle',         # 00 - 09
               'snapdragon',       "colt's foot",               'king protea',      'spear thistle', 'yellow iris',       'globe-flower',         'purple coneflower',        'peruvian lily',    'balloon flower',   'giant white arum lily', # 10 - 19
               'fire lily',        'pincushion flower',         'fritillary',       'red ginger',    'grape hyacinth',    'corn poppy',           'prince of wales feathers', 'stemless gentian', 'artichoke',        'sweet william',         # 20 - 29
               'carnation',        'garden phlox',              'love in the mist', 'cosmos',        'alpine sea holly',  'ruby-lipped cattleya', 'cape flower',              'great masterwort', 'siam tulip',       'lenten rose',           # 30 - 39
               'barberton daisy',  'daffodil',                  'sword lily',       'poinsettia',    'bolero deep blue',  'wallflower',           'marigold',                 'buttercup',        'daisy',            'common dandelion',      # 40 - 49
               'petunia',          'wild pansy',                'primula',          'sunflower',     'lilac hibiscus',    'bishop of llandaff',   'gaura',                    'geranium',         'orange dahlia',    'pink-yellow dahlia',    # 50 - 59
               'cautleya spicata', 'japanese anemone',          'black-eyed susan', 'silverbush',    'californian poppy', 'osteospermum',         'spring crocus',            'iris',             'windflower',       'tree poppy',            # 60 - 69
               'gazania',          'azalea',                    'water lily',       'rose',          'thorn apple',       'morning glory',        'passion flower',           'lotus',            'toad lily',        'anthurium',             # 70 - 79
               'frangipani',       'clematis',                  'hibiscus',         'columbine',     'desert-rose',       'tree mallow',          'magnolia',                 'cyclamen ',        'watercress',       'canna lily',            # 80 - 89
               'hippeastrum ',     'bee balm',                  'pink quill',       'foxglove',      'bougainvillea',     'camellia',             'mallow',                   'mexican petunia',  'bromelia',         'blanket flower',        # 90 - 99
               'trumpet creeper',  'blackberry lily',           'common tulip',     'wild rose']                                                                                                                                               # 100 - 102
    
    # Build the learning-rate schedule
    print('Generating Learning Rate: '+time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime()))
    # # Custom LR scheduler
    # From starter [kernel][1]: https://www.kaggle.com/mgornergoogle/getting-started-with-100-flowers-on-tpu
    # Learning-rate schedule for TPU, GPU and CPU. A ramp-up is used because a
    # pre-trained model is being fine-tuned: starting with a high LR would
    # destroy the pre-trained weights.

    LR_START = 0.00001
    LR_MAX = 0.00005 * strategy.num_replicas_in_sync
    LR_MIN = 0.00001
    LR_RAMPUP_EPOCHS = 5
    LR_SUSTAIN_EPOCHS = 0
    LR_EXP_DECAY = .8

    lr_callback = tf.keras.callbacks.LearningRateScheduler(lrfn, verbose = True)

    # Preview the schedule over at least 25 epochs.
    rng = list(range(max(EPOCHS, 25)))
    y = [lrfn(x) for x in rng]
    if show_augmented_exampe:
        plt.plot(rng, y)
    print("    Learning rate schedule: {:.3g} to {:.3g} to {:.3g}".format(y[0], max(y), y[-1]))
    
    # # Dataset Functions # dataset sizes
    # From starter [kernel][1]: https://www.kaggle.com/mgornergoogle/getting-started-with-100-flowers-on-tpu

    # Number of training, validation and test images. The labeled images are
    # shared between training and validation in the ratio (FOLDS-1) : 1.
    labeled_image_count = count_data_items(TRAINING_FILENAMES)
    NUM_TRAINING_IMAGES = int( labeled_image_count * (FOLDS-1.)/FOLDS )
    NUM_VALIDATION_IMAGES = int( labeled_image_count * (1./FOLDS) )
    NUM_TEST_IMAGES = count_data_items(TEST_FILENAMES)
    STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // BATCH_SIZE
    print('    Dataset: {} training images, {} validation images, {} unlabeled test images'.format(
        NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES))
    
    # # Display Example Augmentation
    # Shows 2 training images, each with 12 random augmentations in a 3 x 4 grid.
    if show_augmented_exampe:
        print('Plotting augmented exampe: '+time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime()))
        row = 3; col = 4;
        for ex_No in range(2): # Examples
            # Take one (flipped, but otherwise un-augmented) training image ...
            all_elements = get_training_dataset(load_dataset(TRAINING_FILENAMES),do_aug=False).unbatch()
            one_element = tf.data.Dataset.from_tensors( next(iter(all_elements)) )
            # ... and augment it repeatedly; only the first batch of
            # row*col copies is needed.
            augmented_element = one_element.repeat().map(transform).batch(row*col)
            img, label = next(iter(augmented_element))
            plt.figure(figsize=(15,int(15*row/col)))
            for j in range(row*col):
                plt.subplot(row,col,j+1)
                plt.axis('off')
                plt.imshow(img[j,])
            # Show once per example, after the whole grid is drawn. Calling
            # plt.show() inside the loop displayed each subplot on its own.
            plt.show()
        
    # run train and predict 运行训练
    print('Start training: '+time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime()))
    histories, models = train_and_predict(folds = FOLDS) # Original code
    
    # Plot loss descrease history # 绘制loss下降曲线
    epochs = EPOCHS;plot_history(histories[0], 'loss', epochs);

print('Everything End at: '+ time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime()))

In [None]:
# Confusion Matrix and Validation Score   
print('Evaluating:'+time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime()))

# # Confusion Matrix and Validation Score
def display_confusion_matrix(cmat, score, precision, recall):
    """Draw the confusion matrix with class names on both axes.

    Args:
        cmat: confusion matrix to render.
        score, precision, recall: optional metrics; each one that is not
            ``None`` is added to the annotation text on the plot.
    """
    plt.figure(figsize=(15,15))
    ax = plt.gca()
    ax.matshow(cmat, cmap='Reds')

    class_ticks = range(len(CLASSES))
    # Class names on the x axis, rotated so they stay readable.
    ax.set_xticks(class_ticks)
    ax.set_xticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")
    # Same for the y axis.
    ax.set_yticks(class_ticks)
    ax.set_yticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(ax.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # One text fragment per metric that was supplied.
    metric_templates = (
        ('f1 = {:.3f} ', score),
        ('\nprecision = {:.3f} ', precision),
        ('\nrecall = {:.3f} ', recall),
    )
    titlestring = "".join(
        template.format(value) for template, value in metric_templates if value is not None)
    if titlestring:
        ax.text(101, 1, titlestring, fontdict={'fontsize': 18, 'horizontalalignment':'right', 'verticalalignment':'top', 'color':'#804040'})
    plt.show()
    
# Collect true labels, probabilities and predicted classes on each fold's
# validation files, re-creating the same K-fold split that training used.
all_labels = []; all_prob = []; all_pred = []
kfold = KFold(FOLDS, shuffle = True, random_state = SEED)
# In quick-test / single-fold runs only the first fold and model are evaluated.
first_fold_only = quick_test or no_folds
for j, (trn_ind, val_ind) in enumerate( kfold.split(TRAINING_FILENAMES) ):
    if j >= len(models):
        # Fewer models than folds (e.g. training was skipped): stop before
        # labels are collected, so labels and predictions stay aligned.
        print('    No model available for fold',j+1,'- stopping evaluation')
        break
    print('    Inferring fold',j+1,'validation images...'+time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime()))
    VAL_FILES = [TRAINING_FILENAMES[i] for i in val_ind]
    NUM_VALIDATION_IMAGES = count_data_items(VAL_FILES)
    # ordered=True keeps images and labels aligned across the two passes.
    cmdataset = get_validation_dataset(load_dataset(VAL_FILES, labeled = True, ordered = True))
    images_ds = cmdataset.map(lambda image, label: image)
    labels_ds = cmdataset.map(lambda image, label: label).unbatch()
    all_labels.append( next(iter(labels_ds.batch(NUM_VALIDATION_IMAGES))).numpy() ) # get everything as one batch

    fold_model = models[0] if first_fold_only else models[j]
    prob = fold_model.predict(images_ds)
    all_prob.append( prob )
    all_pred.append( np.argmax(prob, axis=-1) )
    if first_fold_only:
        break

print('Calculating Scores:'+time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime()))
# Merge the per-fold results into flat arrays.
cm_correct_labels = np.concatenate(all_labels)
cm_probabilities = np.concatenate(all_prob)
cm_predictions = np.concatenate(all_pred)
# show predicted labels (only the first 10 in quick / single-fold runs)
if quick_test or no_folds:
    print("    Correct   labels: ", cm_correct_labels.shape, cm_correct_labels[0:10])
    print("    Predicted labels: ", cm_predictions.shape, cm_predictions[0:10])
else:
    print("    Correct   labels: ", cm_correct_labels.shape, cm_correct_labels)
    print("    Predicted labels: ", cm_predictions.shape, cm_predictions)
# Confusion matrix and macro-averaged f1 / precision / recall over all classes.
class_ids = range(len(CLASSES))
cmat = confusion_matrix(cm_correct_labels, cm_predictions, labels=class_ids)
score = f1_score(cm_correct_labels, cm_predictions, labels=class_ids, average='macro')
precision = precision_score(cm_correct_labels, cm_predictions, labels=class_ids, average='macro')
recall = recall_score(cm_correct_labels, cm_predictions, labels=class_ids, average='macro')
# display confusion matrix (skipped when the model was never trained)
if not no_training_mode:
    display_confusion_matrix(cmat, score, precision, recall)

print('f1 score: {:.3f}, precision: {:.3f}, recall: {:.3f}'.format(score, precision, recall))

# Test Area
# (scratch area; the bare undefined name that stood here raised NameError when the cell ran)