In [1]:
import os,time,cv2, sys, math
import tensorflow as tf
import argparse
import numpy as np
import random
from utils import utils, helpers
from builders import model_builder
# hyper-parameters setting
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint_path', type=str, default="./checkpoints/latest_model_BiSeNet_CamVid.ckpt", help='The path to the latest checkpoint weights for your model.')
parser.add_argument('--crop_height', type=int, default=640, help='Height of cropped input image to network')
parser.add_argument('--crop_width', type=int, default=800, help='Width of cropped input image to network')
parser.add_argument('--model', type=str, default="BiSeNet", help='The model you are using')
parser.add_argument('--dataset', type=str, default="CamVid", help='The dataset you are using')
parser.add_argument('--gpu', type=str, default='0', help='CUDA device index(es) to expose to TensorFlow, e.g. "0" or "0,1"')
# Inside a notebook the kernel's own argv must not be parsed, so parse an empty
# list and use the defaults declared above.
args = parser.parse_args(args=[])

# Specify the GPU. CUDA_VISIBLE_DEVICES takes CUDA device indices (or GPU
# UUIDs). It does NOT take TensorFlow device strings: a value such as
# '/device:GPU:0' is not a valid entry, so no GPU would be exposed and
# TensorFlow would run on the CPU. It must be set before TensorFlow initialises
# CUDA, i.e. before the tf.Session is created.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

# Get the names of the classes so we can record the evaluation results
print("Retrieving dataset information ...")
class_names_list, label_values = helpers.get_label_info(os.path.join(args.dataset, "class_dict.csv"))
# Header fragment listing the per-class columns of the scores CSV,
# e.g. "Sky, Building, Pole". str.join inserts the separator by position, so
# duplicate class names are handled correctly (the previous value comparison
# against the last name dropped the separator for any duplicate of it).
class_names_string = ", ".join(class_names_list)
num_classes = len(label_values)

# Build the inference graph in a TF1 session and restore trained weights.
# allow_growth asks TensorFlow to grow its GPU memory use on demand instead of
# reserving all GPU memory up front.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess=tf.Session(config=config)
# Input placeholder: a float32 batch of 3-channel images. Batch size, height
# and width are left dynamic (None) and supplied at sess.run time via feed_dict.
net_input = tf.placeholder(tf.float32,shape=[None,None,None,3])
# Build the network named by --model (default "BiSeNet"; the author's variants
# include depth-wise BiSeNet, depth-wise AAFF and depth-wise AAFF2) with
# is_training=False, i.e. in inference mode. The second return value of
# build_model is not used here.
network, _ = model_builder.build_model(args.model, net_input=net_input, num_classes=num_classes, crop_width=args.crop_width, crop_height=args.crop_height, is_training=False)

# Initialise all variables first; the restore below then overwrites every
# variable that is stored in the checkpoint.
sess.run(tf.global_variables_initializer())

print('Loading model checkpoint weights ...')
# Restore trained weights from --checkpoint_path. The Saver must be created
# after build_model so that it sees the network's variables. max_to_keep only
# affects saving and has no effect on this restore.
saver=tf.train.Saver(max_to_keep=1000)
saver.restore(sess, args.checkpoint_path)

# Load the data
print("Loading the data ...")
train_input_names,train_output_names, val_input_names, val_output_names, test_input_names, test_output_names = utils.prepare_data(dataset_dir=args.dataset)

# Which test images to evaluate. Indices are drawn from the *test* split so that
# the input image, its ground-truth label and the output file name all refer to
# the same image. Set NUM_TEST_IMAGES = None to evaluate every test image.
NUM_TEST_IMAGES = 30
# A fixed seed means every model is scored on the same subset of images, so the
# results of different models can be compared directly.
TEST_SAMPLE_SEED = 16
OUTPUT_DIR = "Test"

if NUM_TEST_IMAGES is None:
    num_tests = len(test_input_names)
else:
    num_tests = min(NUM_TEST_IMAGES, len(test_input_names))
random.seed(TEST_SAMPLE_SEED)
test_indices = random.sample(range(len(test_input_names)), num_tests)

# Create the output directory if needed
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Per-image metrics; they are averaged in the summary that follows.
scores_list = []
class_scores_list = []
precision_list = []
recall_list = []
f1_list = []
iou_list = []
run_times_list = []

# `with` guarantees the CSV is flushed and closed even if evaluation raises or
# the kernel is interrupted part-way through.
with open(os.path.join(OUTPUT_DIR, "test_scores.csv"), 'w') as target:
    target.write("test_name, test_accuracy, precision, recall, f1 score, mean iou, %s\n" % (class_names_string))

    for ind in test_indices:
        input_path = test_input_names[ind]
        # Crop image and label identically to the network input size, scale
        # pixels to [0, 1] (assumes 8-bit images) and add a batch axis.
        input_image = np.expand_dims(np.float32(utils.load_image(input_path)[:args.crop_height, :args.crop_width]), axis=0) / 255.0
        gt = utils.load_image(test_output_names[ind])[:args.crop_height, :args.crop_width]
        # Presumably converts the colour-coded label image to a per-pixel class-index map.
        gt = helpers.reverse_one_hot(helpers.one_hot_it(gt, label_values))

        # Time only the forward pass. The first call may include one-off warm-up cost.
        st = time.time()
        output_image = sess.run(network, feed_dict={net_input: input_image})
        run_times_list.append(time.time() - st)

        # Drop the batch axis, reduce the network output to a class map and
        # colour it for saving.
        output_image = helpers.reverse_one_hot(np.array(output_image[0, :, :, :]))
        out_vis_image = helpers.colour_code_segmentation(output_image, label_values)

        # Compute the corresponding metrics
        accuracy, class_accuracies, prec, rec, f1, iou = utils.evaluate_segmentation(pred=output_image, label=gt, num_classes=num_classes)

        # One CSV row per image: overall metrics followed by per-class accuracies.
        file_name = utils.filepath_to_name(input_path)
        target.write("%s, %f, %f, %f, %f, %f" % (file_name, accuracy, prec, rec, f1, iou))
        for item in class_accuracies:
            target.write(", %f" % (item))
        target.write("\n")

        scores_list.append(accuracy)
        class_scores_list.append(class_accuracies)
        precision_list.append(prec)
        recall_list.append(rec)
        f1_list.append(f1)
        iou_list.append(iou)

        gt_vis_image = helpers.colour_code_segmentation(gt, label_values)

        # The colour maps are RGB; cv2.imwrite expects BGR.
        cv2.imwrite(os.path.join(OUTPUT_DIR, "%s_pred.png" % file_name), cv2.cvtColor(np.uint8(out_vis_image), cv2.COLOR_RGB2BGR))
        cv2.imwrite(os.path.join(OUTPUT_DIR, "%s_gt.png" % file_name), cv2.cvtColor(np.uint8(gt_vis_image), cv2.COLOR_RGB2BGR))

# Dataset-level summary: the mean of every per-image metric collected during evaluation.
avg_score = np.mean(scores_list)
# Average the per-class accuracies across images (axis 0 = image index).
class_avg_scores = np.mean(class_scores_list, axis=0)
avg_precision, avg_recall, avg_f1, avg_iou, avg_time = [
    np.mean(metric_values)
    for metric_values in (precision_list, recall_list, f1_list, iou_list, run_times_list)
]

print("Average test accuracy = ", avg_score)
print("Average per class test accuracies = \n")
for class_idx, class_score in enumerate(class_avg_scores):
    print("%s = %f" % (class_names_list[class_idx], class_score))
for caption, value in (
    ("Average precision = ", avg_precision),
    ("Average recall = ", avg_recall),
    ("Average F1 score = ", avg_f1),
    ("Average mean IoU score = ", avg_iou),
    ("Average run time = ", avg_time),
):
    print(caption, value)


  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


Retrieving dataset information ...
Preparing the model ...
Loading model checkpoint weights ...
INFO:tensorflow:Restoring parameters from ./checkpoints/latest_model_BiSeNet_CamVid.ckpt
Loading the data ...


  _warn_prf(average, modifier, msg_start, len(result))


KeyboardInterrupt: 

In [None]:
import tensorflow as tf
# Report which TensorFlow build the kernel loaded. The evaluation cell relies on
# TF1-only APIs (tf.ConfigProto, tf.Session, tf.placeholder).
installed_tf_version = tf.__version__
print(installed_tf_version)