In [None]:
import os
import cv2
import math
import imutils
import contextlib2
import pandas as pd
import tensorflow as tf

from glob import glob
from object_detection.core.standard_fields import TfExampleFields
from object_detection.dataset_tools import tf_record_creation_util
from object_detection.utils import dataset_util, label_map_util

# Widen pandas' column display so long cells (e.g. the per-image list columns built later) are not truncated.
pd.options.display.max_colwidth = 10000

In [None]:
# Root for generated artefacts (label map and TFRecords) and the folder holding the per-split image folders.
ROOT_OUT = ""
IMAGES_FOLDER = ""

# Which split to convert (one of "train", "validation") and how many TFRecord shards to write.
SPLIT = "validation"
NUM_SHARDS = 10

# Class definitions: the CSV the label map is generated from, and where the label map lives.
LABELS_CSV = "filteredLabels.csv"
LABEL_MAP_PATH = os.path.join(ROOT_OUT, "labelMap.pbtxt")

# Per-split output records and input annotations.
RECORDS_FILEPATH = os.path.join(ROOT_OUT, SPLIT, f"{SPLIT}.tfrecord")
FILTERED_ANNO_CSV = f"{SPLIT}/{SPLIT}-filtered-annotations-v2.csv"

In [None]:
def string_to_int_list(string_list):
    """Convert every element of ``string_list`` to ``int`` and return a new list.

    Despite the name, any element accepted by ``int()`` works (numeric strings,
    ints, floats, numpy integers).
    """
    return list(map(int, string_list))

def get_encoded_image(image_path):
    """Return the raw file contents of ``image_path`` (the still-encoded image bytes)."""
    with tf.io.gfile.GFile(image_path, 'rb') as handle:
        raw = handle.read()
    return bytes(raw)

def create_label_map(label_map_path, labels):
    """Write a TF Object Detection API label map (``.pbtxt``) built from ``labels``.

    Each row of ``labels`` becomes one ``item`` entry whose ``name`` is the row's
    ``LabelName`` and whose ``display_name`` is its ``ClassName``.

    Args:
        label_map_path: Destination file; overwritten if it already exists.
        labels: DataFrame with at least ``LabelName`` and ``ClassName`` columns.
            Ids are assigned 1, 2, 3, ... in row order, independent of the frame's index.
    """
    def _quote(value):
        # Escape characters that would end or corrupt a pbtxt string literal.
        text = str(value).replace('\\', '\\\\').replace('"', '\\"')
        return f'"{text}"'

    # Explicit UTF-8 so non-ASCII class names don't depend on the platform locale.
    with open(label_map_path, 'w', encoding='utf-8') as f:
        # enumerate() instead of the DataFrame index: ids stay 1-based and
        # contiguous even if `labels` was filtered or has a custom index.
        for label_id, (_, row) in enumerate(labels.iterrows(), start=1):
            f.write(
                'item {\n'
                f'id: {label_id}\n'
                f'name: {_quote(row.LabelName)}\n'
                f'display_name: {_quote(row.ClassName)}}}\n'
            )

In [None]:
def balance_classes(examples, bboxes):
    """Oversample images so every class approaches the box count of the largest class.

    Boxes are counted per ``LabelName`` as the number of non-null ``ClassName``
    entries in ``bboxes``. ``target`` is the largest such count. For each class with
    fewer boxes, every row of ``examples`` whose ``ImageID`` contains that class is
    appended ``ceil(target / count) - 1`` extra times. Whole image rows are copied,
    so other classes in the same images are oversampled too.

    Args:
        examples: One row per image, with an ``ImageID`` column.
        bboxes: One row per box, with ``LabelName``, ``ClassName`` and ``ImageID``.

    Returns:
        A new DataFrame: ``examples`` followed by the replicated rows, with a fresh
        0..n-1 index. ``examples`` itself is not modified.
    """
    class_load = (bboxes.groupby('LabelName')
                        .agg({"ClassName": 'count', "ImageID": list})
                        .reset_index())
    target = class_load.ClassName.max()

    to_append = []
    for _label, num_boxes, img_ids in class_load.itertuples(index=False):
        if num_boxes == 0:
            # Every ClassName in this group is missing; avoid dividing by zero.
            continue
        factor = math.ceil(target / num_boxes) - 1
        if factor <= 0:
            continue
        to_replicate = examples[examples.ImageID.isin(img_ids)]
        to_append.extend([to_replicate] * factor)

    # DataFrame.append was removed in pandas 2.0; pd.concat is the replacement.
    return pd.concat([examples, *to_append], ignore_index=True)

In [None]:
def tf_example_from_example_data_frame(example, label_map, image_format=b'jpg', height=480, width=640):
    """Build a ``tf.train.Example`` holding one image and all of its boxes.

    Args:
        example: A row from ``DataFrame.itertuples()``. It must provide ``Index``,
            ``ImageID``, ``ImagePath`` and the per-object list columns ``LabelName``,
            ``XMin``, ``XMax``, ``YMin``, ``YMax``, ``IsGroupOf``, ``IsOccluded``,
            ``IsTruncated`` and ``IsDepiction``, aligned by box.
        label_map: Mapping from ``LabelName`` to integer class id.
        image_format: Bytes stored under the ``image/format`` feature.
        height, width: Stored as the image size features. They are written as given,
            not read from the file. NOTE(review): the 480x640 defaults are only right
            if every image has that size; pass the real size otherwise.

    Returns:
        The populated ``tf.train.Example``.
    """
    # Suffix the row index so replicated copies of an image get distinct source ids.
    image_id = f"{example.ImageID}_{example.Index}"
    image_path = example.ImagePath
    class_texts = [x.encode() for x in example.LabelName]
    class_labels = [label_map[x] for x in example.LabelName]

    encoded_image = get_encoded_image(image_path)

    feature_map = {
        TfExampleFields.object_bbox_ymin: dataset_util.float_list_feature(example.YMin),
        TfExampleFields.object_bbox_xmin: dataset_util.float_list_feature(example.XMin),
        TfExampleFields.object_bbox_ymax: dataset_util.float_list_feature(example.YMax),
        TfExampleFields.object_bbox_xmax: dataset_util.float_list_feature(example.XMax),
        TfExampleFields.object_class_text: dataset_util.bytes_list_feature(class_texts),
        TfExampleFields.object_class_label: dataset_util.int64_list_feature(class_labels),
        TfExampleFields.filename: dataset_util.bytes_feature(image_path.encode("utf-8")),
        TfExampleFields.source_id: dataset_util.bytes_feature(image_id.encode("utf-8")),
        TfExampleFields.image_encoded: dataset_util.bytes_feature(encoded_image),
        # Per-box attribute flags, coerced to int for the int64 features.
        TfExampleFields.object_group_of: dataset_util.int64_list_feature(string_to_int_list(example.IsGroupOf)),
        TfExampleFields.object_occluded: dataset_util.int64_list_feature(string_to_int_list(example.IsOccluded)),
        TfExampleFields.object_truncated: dataset_util.int64_list_feature(string_to_int_list(example.IsTruncated)),
        TfExampleFields.object_depiction: dataset_util.int64_list_feature(string_to_int_list(example.IsDepiction)),
        TfExampleFields.image_format: dataset_util.bytes_feature(image_format),
        TfExampleFields.height: dataset_util.int64_feature(int(height)),
        TfExampleFields.width: dataset_util.int64_feature(int(width)),
    }

    return tf.train.Example(features=tf.train.Features(feature=feature_map))

In [None]:
labels = pd.read_csv(LABELS_CSV)

# Image ids actually present on disk for this split. basename/splitext (not
# splitting on "\\") keeps this working on any OS and for any IMAGES_FOLDER depth.
image_paths = glob(os.path.join(IMAGES_FOLDER, SPLIT, "*.jpg"))
available_ids = {os.path.splitext(os.path.basename(p))[0] for p in image_paths}

# Keep only annotations whose image exists.
bboxes = pd.read_csv(FILTERED_ANNO_CSV)
bboxes = bboxes[bboxes.ImageID.isin(available_ids)]

# One row per image; each per-object column becomes a list aligned box by box.
per_object_cols = ["LabelName", "ClassName", "XMin", "XMax", "YMin", "YMax",
                   "IsGroupOf", "IsOccluded", "IsTruncated", "IsDepiction"]
examples = (bboxes.groupby("ImageID")
                  .agg({col: list for col in per_object_cols})
                  .reset_index())
examples["ImagePath"] = examples.ImageID.apply(
    lambda img_id: os.path.join(IMAGES_FOLDER, f"{SPLIT}/{img_id}.jpg"))

assert not examples.empty, (
    f"No annotated images found: {len(image_paths)} .jpg files under "
    f"{os.path.join(IMAGES_FOLDER, SPLIT)!r}, {len(bboxes)} matching annotations")

In [None]:
# Generate the label map from `labels` only when the file is not on disk yet.
# An existing file is reused as-is, even if LABELS_CSV has changed since.
# Then load it as the {LabelName: id} dict used for the class-label feature.
label_map_exists = os.path.isfile(LABEL_MAP_PATH)
if not label_map_exists:
    create_label_map(LABEL_MAP_PATH, labels)
label_map = label_map_util.get_label_map_dict(LABEL_MAP_PATH)

In [None]:
# Oversample under-represented classes, then shuffle with a fixed seed so the
# row order (and therefore the shard assignment) is reproducible.
# NOTE: this rebinds `examples`; re-run the loading cell before re-running this
# one, otherwise an already balanced frame is balanced again.
balanced_examples = balance_classes(examples, bboxes)
examples = balanced_examples.sample(frac=1, random_state=22).reset_index(drop=True)

In [None]:
# Create the output directory up front (works for local and tf.io.gfile paths).
# NOTE(review): the sharded TFRecord writers are not expected to create missing
# parent directories when ROOT_OUT/SPLIT does not exist yet.
tf.io.gfile.makedirs(os.path.dirname(RECORDS_FILEPATH) or ".")

with contextlib2.ExitStack() as tf_record_close_stack:
    # The helper receives the ExitStack so the shard writers are closed when the
    # `with` block exits.
    output_tfrecords = tf_record_creation_util.open_sharded_output_tfrecords(
        tf_record_close_stack, RECORDS_FILEPATH, NUM_SHARDS)

    for example in examples.itertuples():
        tf_example = tf_example_from_example_data_frame(example, label_map)
        # Distribute rows across shards by row index.
        output_shard_index = example.Index % NUM_SHARDS
        output_tfrecords[output_shard_index].write(tf_example.SerializeToString())