In [None]:
from __future__ import print_function
import os,time,cv2, sys, math
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np
import time, datetime
import argparse
import random
import subprocess
import configuration
from utils import utils, helpers
from builders import model_builder
import matplotlib.pyplot as plt
import matplotlib

# Expose only the first GPU to CUDA. CUDA_VISIBLE_DEVICES expects a
# comma-separated list of device indices (or UUIDs); a TensorFlow device string
# such as '/device:GPU:0' is not a valid entry and leaves no GPU visible,
# which silently forces training onto the CPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# import the train configuration
train_config = configuration.train_config

# Use the non-interactive 'Agg' backend so plots can be written without an X server.
# pyplot is already imported at this point; plt.switch_backend() switches reliably,
# whereas matplotlib.use() after importing pyplot is ignored by older matplotlib.
plt.switch_backend('Agg')

# define loolean type for data augmentation
def str2bool(v):
    """Convert a command-line string into a bool (argparse ``type=`` hook).

    Case-insensitively accepts yes/true/t/y/1 as True and no/false/f/n/0 as
    False. Any other string raises argparse.ArgumentTypeError, which argparse
    turns into a usage error.
    """
    token = v.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    if token in truthy:
        return True
    if token in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

# specify the training hyper-parameters
# Training hyper-parameters, declared as (flag, type, default, help) rows and
# registered in this exact order so the --help listing is unchanged.
parser = argparse.ArgumentParser()
for _flag, _type, _default, _help in (
    ('--num_epochs', int, 62, 'Number of epochs to train for'),
    ('--epoch_start_i', int, 0, 'Start counting epochs from this number'),
    ('--checkpoint_step', int, 4, 'How often to save checkpoints (epochs)'),
    ('--validation_step', int, 1, 'How often to perform validation (epochs)'),
    ('--image', str, None, 'The image you want to predict on. Only valid in "predict" mode.'),
    ('--dataset', str, "CamVid", 'Dataset you are using.'),
    ('--crop_height', int, 640, 'Height of cropped input image to network'),
    ('--crop_width', int, 800, 'Width of cropped input image to network'),
    ('--batch_size', int, 1, 'Number of images in each batch'),
    ('--num_val_images', int, 30, 'The number of images to used for validations'),
    ('--h_flip', str2bool, True, 'Whether to randomly flip the image horizontally for data augmentation'),
    ('--v_flip', str2bool, True, 'Whether to randomly flip the image vertically for data augmentation'),
    ('--brightness', float, 0.1, 'Whether to randomly change the image brightness for data augmentation. Specifies the max bightness change as a factor between 0.0 and 1.0. For example, 0.1 represents a max brightness change of 10%% (+-).'),
    ('--rotation', float, 30, 'Whether to randomly rotate the image for data augmentation. Specifies the max rotation angle in degrees.'),
    ('--model', str, "BiSeNet", 'The model you are using. See model_builder.py for supported models'),
    ('--frontend', str, "xception", 'The frontend you are using. See frontend_builder.py for supported models'),
):
    parser.add_argument(_flag, type=_type, default=_default, help=_help)
# An empty argument list is parsed on purpose, so every value is the default
# above and sys.argv is ignored (presumably so this runs inside a notebook
# kernel, whose own argv would otherwise be rejected).
args = parser.parse_args(args=[])

# define the data augmentation method
def data_augmentation(input_image, output_image):
    """Randomly crop and augment an (image, label) training pair.

    The crop, flips and rotation are applied identically to both images so
    the label stays aligned with the input. The brightness change is applied
    to the input image only. Which augmentations run, and how strong they
    are, is read from the module-level ``args`` (crop_height, crop_width,
    h_flip, v_flip, brightness, rotation).

    Returns:
        (input_image, output_image) after augmentation.
    """
    # crop both images to the network input size
    input_image, output_image = utils.random_crop(input_image, output_image, args.crop_height, args.crop_width)
    # cv2.flip codes: 1 = horizontal flip, 0 = vertical flip
    if args.h_flip and random.randint(0, 1):
        input_image = cv2.flip(input_image, 1)
        output_image = cv2.flip(output_image, 1)
    if args.v_flip and random.randint(0, 1):
        input_image = cv2.flip(input_image, 0)
        output_image = cv2.flip(output_image, 0)
    # scale brightness by a random factor in [1 - brightness, 1 + brightness]
    if args.brightness:
        factor = 1.0 + random.uniform(-1.0 * args.brightness, args.brightness)
        # Clip before the uint8 cast: with factor > 1 bright pixels exceed 255,
        # and an unclipped astype(np.uint8) wraps them around to dark values.
        table = np.clip(np.arange(0, 256) / 255.0 * factor * 255, 0, 255).astype(np.uint8)
        input_image = cv2.LUT(input_image, table)
    # rotate by a random angle in [-rotation, rotation] degrees about the image centre
    if args.rotation:
        angle = random.uniform(-1 * args.rotation, args.rotation)
        M = cv2.getRotationMatrix2D((input_image.shape[1] // 2, input_image.shape[0] // 2), angle, 1.0)
        # nearest-neighbour interpolation keeps label values exact (no blended classes)
        input_image = cv2.warpAffine(input_image, M, (input_image.shape[1], input_image.shape[0]), flags=cv2.INTER_NEAREST)
        output_image = cv2.warpAffine(output_image, M, (output_image.shape[1], output_image.shape[0]), flags=cv2.INTER_NEAREST)
    return input_image, output_image

# Global_step is used as a count in training, adding 1 for each batch of training
# Non-trainable step counter registered in both the GLOBAL_STEP and
# GLOBAL_VARIABLES collections. It is read by the learning-rate schedule and
# only advances if it is passed as `global_step=` to optimizer.minimize().
global_step = tf.Variable(
    0,
    trainable=False,
    name='global_step',
    collections=[tf.GraphKeys.GLOBAL_STEP, tf.GraphKeys.GLOBAL_VARIABLES],
)

# define the configuration of learning rate
def _configure_learning_rate(train_config, global_step, num_train_images=421):
    """Build the polynomial-decay learning-rate tensor.

    lr = initial_lr * max(1 - global_step / T_total, 0) ** power

    Args:
        train_config: dict with a 'lr_config' entry holding 'initial_lr' and 'power'.
        global_step: step-counter variable/tensor.
        num_train_images: training-set size used to derive the iterations per
            epoch. The default of 421 is the value previously hard-coded here
            (presumably the CamVid training split; confirm for other datasets).

    Returns:
        A scalar float tensor with the current learning rate.
    """
    lr_config = train_config['lr_config']
    # iterations per epoch
    num_batches_per_epoch = int(num_train_images / args.batch_size)
    # total decay horizon, with one extra batch of headroom per epoch
    T_total = (int(num_batches_per_epoch) + 1) * args.num_epochs
    # Clamp the base at 0: if global_step ever passes T_total (e.g. the dataset
    # is larger than num_train_images), a negative base raised to a fractional
    # power would produce NaN.
    remaining = tf.maximum(1.0 - tf.to_float(global_step) / T_total, 0.0)
    return lr_config['initial_lr'] * remaining ** lr_config['power']


# Get the names of the classes so we can record the evaluation results
class_names_list, label_values = helpers.get_label_info(os.path.join(args.dataset, "class_dict.csv"))
# Comma-separated class names, used as the per-class column headers of
# val_scores.csv. join() places the separator correctly even if a class name
# repeats the last one, which the previous "is this the last name?" test did not.
class_names_string = ", ".join(class_names_list)
# number of segmentation classes, one per entry of class_dict.csv
num_classes = len(label_values)
# Let GPU memory grow on demand instead of reserving it all at session start.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# net_input: batch of RGB images (3 channels); net_output: one-hot label maps
# with num_classes channels.
net_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
net_output = tf.placeholder(tf.float32, shape=[None, None, None, num_classes])
# Build the network. init_fn, when not None, loads pre-trained weights and must
# be called after the variables have been initialized.
network, init_fn = model_builder.build_model(model_name=args.model, frontend=args.frontend, net_input=net_input, num_classes=num_classes, crop_width=args.crop_width,
                                             crop_height=args.crop_height, is_training=True)
# Softmax cross-entropy over the class axis, averaged over all remaining positions.
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=network, labels=net_output))
# polynomial-decay learning rate driven by global_step
learning_rate = _configure_learning_rate(train_config, global_step)
# momentum optimizer configured from train_config
optimizer_config = train_config['optimizer_config']
optimizer = tf.train.MomentumOptimizer(learning_rate,
                                       momentum=optimizer_config['momentum'],
                                       use_nesterov=optimizer_config['use_nesterov'],
                                       name='Momentum')
# Pass global_step so minimize() increments it after every update. Without it
# global_step stays 0 and the learning-rate schedule never decays.
opt = optimizer.minimize(loss, global_step=global_step, var_list=tf.trainable_variables())
# keep up to 500 checkpoints
saver = tf.train.Saver(max_to_keep=500)
# Initialize the parameters of the model
sess.run(tf.global_variables_initializer())
# Count total number of parameters in the model
utils.count_params()


# Load pre-trained weights when the model builder supplied a loader. This has to
# happen AFTER sess.run(tf.global_variables_initializer()); otherwise the
# initializer would overwrite the loaded values.
if init_fn is not None:
    init_fn(sess)

# Path of the rolling "latest" checkpoint that is saved at the end of every epoch.
# NOTE(review): nothing in this script restores from it; training always starts
# from the freshly initialized / pre-trained weights.
model_checkpoint_name = "checkpoints/latest_model_%s_%s.ckpt" % (args.model, args.dataset)

# Load the train/val/test file lists for the dataset.
print("Loading the data ...")
(train_input_names, train_output_names,
 val_input_names, val_output_names,
 test_input_names, test_output_names) = utils.prepare_data(dataset_dir=args.dataset)

# Echo the run configuration before training starts.
print("\n***** Begin training *****")
for _label, _value in (("Dataset -->", args.dataset),
                       ("Model -->", args.model),
                       ("Crop Height -->", args.crop_height),
                       ("Crop Width -->", args.crop_width),
                       ("Num Epochs -->", args.num_epochs),
                       ("Batch Size -->", args.batch_size),
                       ("Num Classes -->", num_classes)):
    print(_label, _value)
print("Data Augmentation:")
for _label, _value in (("\tVertical Flip -->", args.v_flip),
                       ("\tHorizontal Flip -->", args.h_flip),
                       ("\tBrightness Alteration -->", args.brightness),
                       ("\tRotation -->", args.rotation)):
    print(_label, _value)
print("")

# Per-epoch histories that are plotted at the end of every epoch.
avg_loss_per_epoch = []
avg_scores_per_epoch = []
avg_iou_per_epoch = []

# Validate on at most num_val_images images. The fixed seed selects the same
# subset on every run, so results from different networks are comparable.
num_vals = min(args.num_val_images, len(val_input_names))
random.seed(16)
val_indices = random.sample(range(len(val_input_names)), num_vals)

def _save_metric_plot(epochs, values, title, ylabel, path):
    """Plot one per-epoch metric and save it to ``path``.

    Args:
        epochs: epoch numbers for the x-axis, one per entry of ``values``.
        values: metric values for the y-axis.
        title: figure title.
        ylabel: y-axis label.
        path: output image file.

    The figure is closed after saving, so calling this every epoch does not
    accumulate open matplotlib figures.
    """
    fig, ax = plt.subplots(figsize=(11, 8))
    ax.plot(epochs, values)
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    fig.savefig(path)
    plt.close(fig)


# Epoch numbers matching each entry of avg_loss_per_epoch (train_epochs) and of
# avg_scores_per_epoch / avg_iou_per_epoch (val_epochs). Plotting against these
# instead of range(epoch + 1) keeps x and y the same length when
# epoch_start_i > 0 or validation_step > 1; a length mismatch makes ax.plot raise.
train_epochs = []
val_epochs = []

# start the training process
for epoch in range(args.epoch_start_i, args.num_epochs):
    current_losses = []
    count = 0

    # visit the training images in a new random order each epoch
    id_list = np.random.permutation(len(train_input_names))
    # number of full batches this epoch; a trailing partial batch is skipped
    num_iters = int(np.floor(len(id_list) / args.batch_size))
    # st times each logging window, epoch_st the whole epoch
    st = time.time()
    epoch_st = time.time()
    for i in range(num_iters):
        input_image_batch = []
        output_image_batch = []
        # Collect a batch of images
        for j in range(args.batch_size):
            index = i * args.batch_size + j
            sample_id = id_list[index]
            input_image = utils.load_image(train_input_names[sample_id])
            output_image = utils.load_image(train_output_names[sample_id])

            # Augmentation and preprocessing are host-side cv2/NumPy work;
            # tf.device() only places TensorFlow ops, so no scope is used here.
            input_image, output_image = data_augmentation(input_image, output_image)

            # Prep the data. Make sure the labels are in one-hot format
            input_image = np.float32(input_image) / 255.0
            output_image = np.float32(helpers.one_hot_it(label=output_image, label_values=label_values))
            input_image_batch.append(np.expand_dims(input_image, axis=0))
            output_image_batch.append(np.expand_dims(output_image, axis=0))

        # Each sample carries a leading size-1 axis from expand_dims, so
        # concatenating on axis 0 stacks them into one batch. Unlike
        # np.squeeze(np.stack(..., axis=1)), this never drops another size-1 axis.
        input_image_batch = np.concatenate(input_image_batch, axis=0)
        output_image_batch = np.concatenate(output_image_batch, axis=0)

        # One optimisation step; also fetch this batch's cross-entropy loss.
        _, current = sess.run([opt, loss], feed_dict={net_input: input_image_batch, net_output: output_image_batch})
        current_losses.append(current)
        count = count + args.batch_size
        # log progress whenever the number of images seen is a multiple of 20
        if count % 20 == 0:
            string_print = "Epoch = %d Count = %d Current_Loss = %.4f Time = %.2f" % (epoch, count, current, time.time() - st)
            utils.LOG(string_print)
            st = time.time()

    # average training loss over this epoch's batches
    mean_loss = np.mean(current_losses)
    avg_loss_per_epoch.append(mean_loss)
    train_epochs.append(epoch)

    # per-epoch folder for the validation CSV and the prediction images
    epoch_dir = "%s/%04d" % ("checkpoints", epoch)
    if not os.path.isdir(epoch_dir):
        os.makedirs(epoch_dir)

    # Save latest checkpoint to same file name
    print("Saving latest checkpoint")
    saver.save(sess, model_checkpoint_name)

    # validate every args.validation_step epochs
    if epoch % args.validation_step == 0:
        print("Performing validation")
        scores_list = []
        class_scores_list = []
        precision_list = []
        recall_list = []
        f1_list = []
        iou_list = []

        # The with-block closes the CSV even if validation raises midway.
        with open("%s/val_scores.csv" % epoch_dir, 'w') as target:
            target.write("val_name, avg_accuracy, precision, recall, f1 score, mean iou, %s\n" % (class_names_string))
            # evaluate the fixed validation subset chosen before training
            for ind in val_indices:
                input_image = np.expand_dims(np.float32(utils.load_image(val_input_names[ind])[:args.crop_height, :args.crop_width]), axis=0) / 255.0
                gt = utils.load_image(val_output_names[ind])[:args.crop_height, :args.crop_width]
                gt = helpers.reverse_one_hot(helpers.one_hot_it(gt, label_values))
                # forward pass on the single validation image
                output_image = sess.run(network, feed_dict={net_input: input_image})
                output_image = np.array(output_image[0, :, :, :])
                # per-class scores -> class-index map (presumably an argmax over the last axis)
                output_image = helpers.reverse_one_hot(output_image)
                # colour-coded prediction for visual inspection
                out_vis_image = helpers.colour_code_segmentation(output_image, label_values)
                accuracy, class_accuracies, prec, rec, f1, iou = utils.evaluate_segmentation(pred=output_image, label=gt, num_classes=num_classes)
                file_name = utils.filepath_to_name(val_input_names[ind])
                # one CSV row: summary metrics followed by one accuracy per class
                target.write("%s, %f, %f, %f, %f, %f" % (file_name, accuracy, prec, rec, f1, iou))
                for item in class_accuracies:
                    target.write(", %f" % (item))
                target.write("\n")

                scores_list.append(accuracy)
                class_scores_list.append(class_accuracies)
                precision_list.append(prec)
                recall_list.append(rec)
                f1_list.append(f1)
                iou_list.append(iou)

                # save colour-coded ground truth and prediction for this image
                gt = helpers.colour_code_segmentation(gt, label_values)
                base_name = os.path.splitext(os.path.basename(val_input_names[ind]))[0]
                cv2.imwrite("%s/%s_pred.png" % (epoch_dir, base_name), cv2.cvtColor(np.uint8(out_vis_image), cv2.COLOR_RGB2BGR))
                cv2.imwrite("%s/%s_gt.png" % (epoch_dir, base_name), cv2.cvtColor(np.uint8(gt), cv2.COLOR_RGB2BGR))

        avg_score = np.mean(scores_list)
        class_avg_scores = np.mean(class_scores_list, axis=0)
        avg_scores_per_epoch.append(avg_score)
        avg_precision = np.mean(precision_list)
        avg_recall = np.mean(recall_list)
        avg_f1 = np.mean(f1_list)
        avg_iou = np.mean(iou_list)
        avg_iou_per_epoch.append(avg_iou)
        val_epochs.append(epoch)

        # report validation averages over the evaluated images
        print("\nAverage validation accuracy for epoch # %04d = %f" % (epoch, avg_score))
        print("Average per class validation accuracies for epoch # %04d:" % (epoch))
        for index, item in enumerate(class_avg_scores):
            print("%s = %f" % (class_names_list[index], item))
        print("Validation precision = ", avg_precision)
        print("Validation recall = ", avg_recall)
        print("Validation F1 score = ", avg_f1)
        print("Validation IoU score = ", avg_iou)

    # estimate the remaining training time from this epoch's duration
    epoch_time = time.time() - epoch_st
    remain_time = epoch_time * (args.num_epochs - 1 - epoch)
    m, s = divmod(remain_time, 60)
    h, m = divmod(m, 60)
    # Test the total remaining time, not the seconds remainder s, which is also
    # 0 whenever remain_time is a whole number of minutes.
    if remain_time > 0:
        train_time = "Remaining training time = %d hours %d minutes %d seconds\n" % (h, m, s)
    else:
        train_time = "Remaining training time : Training completed.\n"
    utils.LOG(train_time)

    # redraw the validation accuracy, training loss and mean IoU histories
    _save_metric_plot(val_epochs, avg_scores_per_epoch, "Average validation accuracy vs epochs", "Avg. val. accuracy", 'accuracy_vs_epochs.png')
    _save_metric_plot(train_epochs, avg_loss_per_epoch, "Average loss vs epochs", "Current loss", 'loss_vs_epochs.png')
    _save_metric_plot(val_epochs, avg_iou_per_epoch, "Average IoU vs epochs", "Current IoU", 'iou_vs_epochs.png')


  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


Preparing the model ...
This model has 2674832 trainable parameters
Loading the data ...

***** Begin training *****
Dataset --> CamVid
Model --> BiSeNet
Crop Height --> 224
Crop Width --> 224
Num Epochs --> 2
Batch Size --> 1
Num Classes --> 32
Data Augmentation:
	Vertical Flip --> True
	Horizontal Flip --> True
	Brightness Alteration --> 0.1
	Rotation --> 30

[2020-08-26 23:50:47] Epoch = 0 Count = 20 Current_Loss = 2.1063 Time = 22.04
