In [3]:
import hashlib, io, logging, os
import pickle
import PIL.Image
from PIL import Image
from IPython.display import display
from IPython.display import clear_output
import random
import numpy as np
from sklearn.model_selection import train_test_split
import csv

import tensorflow as tf
from tensorflow import keras
from object_detection.utils import dataset_util
from object_detection.utils import visualization_utils as vis_util

%run ../../global_variables.ipynb
%run ../detect_variables.ipynb
%run ./utils.ipynb
%run ../../utils/data_utils.ipynb
%run ../../utils/object_detection_utils.ipynb

# Function definitions

In [4]:
def _load_rgb_array(path):
    """Load the image at `path` as an (H, W, 3) array to draw overlays on.

    Single-channel images are replicated into three identical channels and an
    alpha channel (4th channel) is dropped; 3-channel images are returned as is.
    """
    with Image.open(path) as img:
        array = np.array(img)
    if array.ndim == 2:
        # Grayscale / palette image: stack three copies of the single channel.
        array = np.repeat(array[..., np.newaxis], 3, axis=2)
    elif array.shape[-1] == 4:
        # RGBA: the drawing utilities expect three channels.
        array = array[..., :3].copy()
    return array


def create_tf_example(annotation, id_map, binary=True, segmentation=False, verbose=False):
    """Build a ``tf.train.Example`` for one annotated image.

    The image is read from ``DATASET_ROOT`` (under ``IMAGE_FOLDER`` when it is
    set, otherwise under the annotation's own ``"folder"`` entry). Each entry of
    ``annotation["objects"]`` becomes one bounding box whose x coordinates are
    normalized by the image width and y coordinates by the image height.

    Args:
        annotation: Parsed annotation dict with ``"filename"``, ``"objects"``
            and, when ``IMAGE_FOLDER`` is None, ``"folder"``. Each object
            provides ``"xmin"``, ``"ymin"``, ``"xmax"``, ``"ymax"``
            (presumably pixel coordinates), ``"name"`` and, when
            ``segmentation`` is True, ``"mask_path"`` relative to ``DATASET_ROOT``.
        id_map: Mapping from class name to integer class id.
        binary: If True every box is labelled with the first class of
            ``id_map``; otherwise each object's ``"name"`` is looked up.
        segmentation: If True, one PNG-encoded mask per object is stored under
            ``image/object/mask``.
        verbose: If True, boxes (and masks) are drawn on a copy of the image,
            which is displayed in the notebook.

    Returns:
        The ``tf.train.Example`` holding the encoded image and its labels.

    Raises:
        KeyError: if ``binary`` is False and an object's name is not in ``id_map``.
    """
    ## LOADING IMAGE
    folder = annotation["folder"] if IMAGE_FOLDER is None else IMAGE_FOLDER
    img_path = os.path.join(DATASET_ROOT, folder, annotation["filename"])
    encoded_image_data, width, height = load_png(img_path)
    key = hashlib.sha256(encoded_image_data).hexdigest()
    filename = os.path.basename(img_path)

    full_image = _load_rgb_array(img_path) if verbose else None

    ## GENERAL FEATURES
    image_format = IMG_TYPE.encode('utf8')  # b'jpeg' or b'png'

    ## DEFINING BOUNDING BOXES (one entry per object in every list)
    xmins, xmaxs, ymins, ymaxs = [], [], [], []
    classes_text = []  # class names, utf-8 encoded
    classes = []  # integer class ids
    labels = annotation["objects"]
    for label in labels:
        # x is normalized by the width and y by the height of the image.
        xmin = label["xmin"] / width
        xmax = label["xmax"] / width
        ymin = label["ymin"] / height
        ymax = label["ymax"] / height
        xmins.append(xmin)
        xmaxs.append(xmax)
        ymins.append(ymin)
        ymaxs.append(ymax)
        if verbose:
            vis_util.draw_bounding_box_on_image_array(full_image,
                                                      ymin, xmin, ymax, xmax,
                                                      use_normalized_coordinates=True)
        # Binary mode: every object gets the first class of id_map.
        class_name = next(iter(id_map)) if binary else label["name"]
        classes_text.append(class_name.encode('utf8'))
        classes.append(id_map[class_name])

    feature_dict = {
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename.encode('utf8')),
        'image/source_id': dataset_util.bytes_feature(filename.encode('utf8')),
        'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
        'image/encoded': dataset_util.bytes_feature(encoded_image_data),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs)
    }

    if segmentation:
        masks_list = []
        for label in labels:
            # Re-encode each mask as PNG bytes, closing the source file afterwards.
            with Image.open(os.path.join(DATASET_ROOT, label["mask_path"])) as mask_img:
                mask_buffer = io.BytesIO()
                mask_img.save(mask_buffer, format='PNG')
            masks_list.append(mask_buffer.getvalue())
            if verbose:
                mask_buffer.seek(0)
                with Image.open(mask_buffer) as mask_pil:
                    mask_array = np.array(mask_pil)
                vis_util.draw_mask_on_image_array(full_image, mask_array)
        feature_dict['image/object/mask'] = dataset_util.bytes_list_feature(masks_list)

    if verbose:
        display(Image.fromarray(full_image))

    return tf.train.Example(features=tf.train.Features(feature=feature_dict))

In [9]:
def create_tf_record(annotations_path, id_map, output_path, binary=True, segmentation=False, verbose=False):
    """Serialize a list of annotation files into one TFRecord file.

    Args:
        annotations_path: Annotation file names (a list, despite the name),
            relative to ``DATASET_ROOT/ANNOTATION_FOLDER``.
        id_map: Mapping from class name to integer class id, forwarded to
            ``create_tf_example``.
        output_path: Destination path of the ``.record`` file.
        binary, segmentation, verbose: Forwarded to ``create_tf_example``.

    Annotation files that do not exist are skipped with a printed warning.
    """
    total = len(annotations_path)
    # The context manager closes the writer even if building an example raises.
    with tf.io.TFRecordWriter(output_path) as writer:
        print(f"Creating tfrecord based on {total} example(s). Save location: {output_path}")
        # Live "i/total" counter, updated in place in the notebook output.
        progress = display(f"0/{total}", display_id=True)
        for i, annotation_file in enumerate(annotations_path, start=1):
            progress.update(f"{i}/{total}")
            annotation_path = os.path.join(DATASET_ROOT, ANNOTATION_FOLDER, annotation_file)
            if not os.path.isfile(annotation_path):
                print("WARNING: Error while loading annotation "+annotation_path+". Ignoring image.")
                continue
            annotation = parse_annotation(annotation_path)
            tf_example = create_tf_example(annotation, id_map, binary=binary,
                                           segmentation=segmentation, verbose=verbose)
            writer.write(tf_example.SerializeToString())
    print("Finished!")

In [10]:
def get_map(binary=True, save=True):
    """Load the class-name -> id label map and optionally export it as .pbtxt.

    Args:
        binary: Load ``binary_label_map.pickle`` if True, otherwise
            ``multiclass_label_map.pickle`` (both under ``DATASET_ROOT/MAP_FOLDER``).
        save: If True, also write ``<map name>.pbtxt`` into ``OUPUT_PATH`` via
            ``create_pbtxt``.

    Returns:
        The unpickled label map (used in this notebook as class name -> integer id).
    """
    pb_name = "binary_label_map" if binary else "multiclass_label_map"
    map_path = os.path.join(DATASET_ROOT, MAP_FOLDER, pb_name + ".pickle")
    # SECURITY: pickle.load can execute arbitrary code; only load label maps
    # produced by this project's own tooling, never untrusted files.
    with open(map_path, "rb") as map_file:
        id_map = pickle.load(map_file)
    if save:
        create_pbtxt(id_map, os.path.join(OUPUT_PATH, pb_name + ".pbtxt"))
    return id_map

# Creating the TFRecord datasets

In [11]:
# Listing annotations: every .xml file under the annotation folder, stored as a
# path relative to that folder (create_tf_record joins it back onto the folder).
annotation_dir = os.path.join(DATASET_ROOT, ANNOTATION_FOLDER)
annotations = []
for dirpath, _dirnames, filenames in os.walk(annotation_dir):
    for filename in filenames:
        # splitext checks the *last* extension: "scan.v2.xml" is accepted and
        # names without a dot no longer raise IndexError.
        if os.path.splitext(filename)[1].lower() == ".xml":
            annotations.append(os.path.relpath(os.path.join(dirpath, filename), annotation_dir))
if not annotations:
    raise FileNotFoundError("No .xml annotation found under " + annotation_dir)

# os.walk order depends on the filesystem: sort, then shuffle with a fixed seed
# so that the train/test split is reproducible from one run to the next.
SHUFFLE_SEED = 42
annotations.sort()
random.Random(SHUFFLE_SEED).shuffle(annotations)

# Split
if PERCENTAGE_TRAIN == 1:
    train, test = annotations, []
else:
    train, test = train_test_split(annotations, train_size=PERCENTAGE_TRAIN, random_state=SHUFFLE_SEED)

In [12]:
# Start from an empty output folder: everything below is regenerated.
delete_all_files_in_folder(OUPUT_PATH)
print("Scheme:", SCHEME)
# Parameters
verbose = False  # True draws and displays every annotated image (slow).
# Get id_map (save=True also writes the .pbtxt label map into OUPUT_PATH)
id_map = get_map(binary=BINARY, save=True)


def export_split(names, csv_name, record_name, label_map, show):
    """Write one split's file list as CSV and its examples as a TFRecord.

    Args:
        names: Annotation file names belonging to the split.
        csv_name: File name of the CSV list, written into ``OUPUT_PATH``.
        record_name: File name of the TFRecord, written into ``OUPUT_PATH``.
        label_map: Class name -> id mapping forwarded to ``create_tf_record``.
        show: Forwarded as ``verbose`` to ``create_tf_record``.

    Returns:
        ``(csv_path, record_path)`` of the two files written.
    """
    csv_path = os.path.join(OUPUT_PATH, csv_name)
    check_dirs(csv_path)
    array_to_csv(names, csv_path)
    record_path = os.path.join(OUPUT_PATH, record_name)
    check_dirs(record_path)
    create_tf_record(names,
                     label_map,
                     record_path,
                     binary=BINARY,
                     segmentation=SEGMENTATION,
                     verbose=show)
    return csv_path, record_path


# Train dataset
if len(train) != 0:
    csv_train_path, train_record_path = export_split(
        train, "train_list.csv", TRAIN_OUTPUT_FILE, id_map, verbose)
# Test dataset
if len(test) != 0:
    csv_test_path, test_record_path = export_split(
        test, "test_list.csv", TEST_OUTPUT_FILE, id_map, verbose)

Deleting all files in /mnt/nvme-storage/pfauregi/training/obj_detection/ws_bd/dataset
Scheme: dataset01
Creating: /mnt/nvme-storage/pfauregi/training/obj_detection/ws_bd/dataset/train_list.csv
Creating tfrecord based on 900 example(s). Save location: /mnt/nvme-storage/pfauregi/training/obj_detection/ws_bd/dataset/train_dataset.record


'900/900'

Finished!
Creating: /mnt/nvme-storage/pfauregi/training/obj_detection/ws_bd/dataset/test_list.csv
Creating tfrecord based on 101 example(s). Save location: /mnt/nvme-storage/pfauregi/training/obj_detection/ws_bd/dataset/test_dataset.record


'101/101'

Finished!
