In [16]:
import os
import tensorflow as tf
import numpy as np
import cv2
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt

## Global Variables

In [68]:
# Dataset layout: read_data loads images as DIR_INPUT/<id>.jpg and
# annotations as DIR_OUTPUT/<id>.xml.
DIR_DATA = 'data_small'
DIR_INPUT = os.path.join(DIR_DATA, 'input')
DIR_OUTPUT = os.path.join(DIR_DATA, 'output')

# Class name -> (integer class id, coarse category). Ids are assigned by list
# position, so 'none' (background) is 0 and the 20 object classes are 1..20.
OBJECT_LABELS = {
    class_name: (class_id, category)
    for class_id, (class_name, category) in enumerate([
        ('none', 'Background'),
        ('aeroplane', 'Vehicle'),
        ('bicycle', 'Vehicle'),
        ('bird', 'Animal'),
        ('boat', 'Vehicle'),
        ('bottle', 'Indoor'),
        ('bus', 'Vehicle'),
        ('car', 'Vehicle'),
        ('cat', 'Animal'),
        ('chair', 'Indoor'),
        ('cow', 'Animal'),
        ('diningtable', 'Indoor'),
        ('dog', 'Animal'),
        ('horse', 'Animal'),
        ('motorbike', 'Vehicle'),
        ('person', 'Person'),
        ('pottedplant', 'Indoor'),
        ('sheep', 'Animal'),
        ('sofa', 'Indoor'),
        ('train', 'Vehicle'),
        ('tvmonitor', 'Indoor'),
    ])
}

# Output directory for the TFRecord shards and the maximum number of
# examples written into each shard file.
DIR_TFRECORDS = 'data_small_tfrecords'
NUM_EXAMPLES_PER_TFRECORD = 1



## Process data

In [75]:
def read_data(filename):
    """Load one image and its bounding-box annotations.

    Reference: modified from
    https://github.com/balancap/SSD-Tensorflow/blob/master/datasets/pascalvoc_to_tfrecords.py

    Args:
        filename: example id without extension. The image is read from
            ``DIR_INPUT/<filename>.jpg`` and the annotation from
            ``DIR_OUTPUT/<filename>.xml``.

    Returns:
        (img, bboxes, labels):
            img: the array returned by ``cv2.imread`` with default flags
                (OpenCV loads colour images in BGR channel order).
            bboxes: list of ``(ymin, xmin, ymax, xmax)`` float tuples, one per
                ``<object>``, in the XML's own coordinates (not normalized).
            labels: list of integer class ids taken from ``OBJECT_LABELS``,
                aligned index-by-index with ``bboxes``.

    Raises:
        IOError: if the image file cannot be read.
        KeyError: if an object's ``<name>`` is not a key of ``OBJECT_LABELS``.
        ``ET.parse`` also raises if the annotation is missing or not valid XML.
    """
    # read image
    img_name = os.path.join(DIR_INPUT, filename + '.jpg')
    img = cv2.imread(img_name)
    if img is None:
        # cv2.imread reports a missing/unreadable file by returning None rather
        # than raising; fail here instead of with an opaque AttributeError later.
        raise IOError('Could not read image: %s' % img_name)

    # read annotation
    annotation_name = os.path.join(DIR_OUTPUT, filename + '.xml')
    root = ET.parse(annotation_name).getroot()

    bboxes, labels = [], []
    for obj in root.findall('object'):
        # class label -> integer id (first element of the OBJECT_LABELS entry)
        labels.append(OBJECT_LABELS[obj.find('name').text][0])

        # bbox, stored in (ymin, xmin, ymax, xmax) order
        bndbox = obj.find('bndbox')
        bboxes.append(tuple(float(bndbox.find(tag).text)
                            for tag in ('ymin', 'xmin', 'ymax', 'xmax')))

    return img, bboxes, labels

def get_processed_data(filename):
    """Return ``(img, bboxes_array)`` for one example id.

    Thin wrapper over ``read_data``: the image is returned unchanged and the
    list of ``(ymin, xmin, ymax, xmax)`` tuples is converted with ``np.array``.

    NOTE(review): the class ids returned by ``read_data`` are unpacked but not
    returned, so callers never see them.
    """
    image, boxes, _unused_labels = read_data(filename)

    # TODO: the actual processing step is not implemented yet.
    box_array = np.array(boxes)
    return image, box_array
    

## Write data to TFRecord format

In [72]:
# conversion helpers: wrap raw values in tf.train.Feature protos
def _bytes_feature(value):
    """Return a ``tf.train.Feature`` whose BytesList holds exactly ``value``."""
    single_item_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=single_item_list)

def write_example_to_TFRecord(filename, writer):
    """Serialize one example as a ``tf.train.Example`` and write it to ``writer``.

    Args:
        filename: example id without extension, passed to ``get_processed_data``.
        writer: an open TFRecord writer; the serialized Example is passed to
            ``writer.write``.

    Features written:
        'img':       raw bytes of the image array from ``get_processed_data``.
        'img_shape': the image array's shape as int64s. Images differ in size,
                     so this is required to reshape the 'img' bytes when reading.
        'label':     raw bytes of the bounding-box array (one
                     ``(ymin, xmin, ymax, xmax)`` row per object). The key name
                     is kept for existing readers even though it holds boxes.
    """
    # get processed data (second item is the bbox array, not class labels)
    img, bboxes = get_processed_data(filename)

    # create example from this data
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                # tobytes() gives the same bytes as the deprecated tostring().
                'img': _bytes_feature(img.tobytes()),
                'img_shape': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=list(img.shape))),
                'label': _bytes_feature(bboxes.tobytes()),
            }
        )
    )

    writer.write(example.SerializeToString())


def write_data_to_TFRecord():
    """Write every example found in DIR_INPUT to sharded TFRecord files.

    Example ids are the base names of the ``.jpg`` files in DIR_INPUT, in
    sorted order (``read_data`` opens ``<id>.jpg`` in DIR_INPUT and
    ``<id>.xml`` in DIR_OUTPUT). Shards are written to DIR_TFRECORDS as
    ``0.tfrecords``, ``1.tfrecords``, ... with at most
    NUM_EXAMPLES_PER_TFRECORD examples each; DIR_TFRECORDS is created if it
    does not exist.

    Raises:
        ValueError: if NUM_EXAMPLES_PER_TFRECORD < 1 (no example could ever be
            placed in a shard, so the loop would never terminate).
    """
    if NUM_EXAMPLES_PER_TFRECORD < 1:
        raise ValueError('NUM_EXAMPLES_PER_TFRECORD must be >= 1, got %r'
                         % (NUM_EXAMPLES_PER_TFRECORD,))

    # Derive ids with splitext from '.jpg' files only: slicing off 4 characters
    # mangles other names (e.g. '.DS_Store' -> '.DS_'), and read_data only ever
    # looks for '<id>.jpg'.
    filenames = []
    for name in sorted(os.listdir(DIR_INPUT)):
        stem, ext = os.path.splitext(name)
        if ext == '.jpg':
            filenames.append(stem)

    if not os.path.exists(DIR_TFRECORDS):
        os.makedirs(DIR_TFRECORDS)

    # One shard per consecutive slice of NUM_EXAMPLES_PER_TFRECORD ids.
    shard_starts = range(0, len(filenames), NUM_EXAMPLES_PER_TFRECORD)
    for idx_tfrecord, start in enumerate(shard_starts):
        filename_tfrecord = os.path.join(DIR_TFRECORDS, str(idx_tfrecord) + '.tfrecords')
        # NOTE: tf.python_io is the TF1 API (tf.io.TFRecordWriter in TF2).
        with tf.python_io.TFRecordWriter(filename_tfrecord) as writer:
            for filename in filenames[start:start + NUM_EXAMPLES_PER_TFRECORD]:
                write_example_to_TFRecord(filename, writer)

In [77]:
# Convert every example in DIR_INPUT / DIR_OUTPUT into TFRecord shards under DIR_TFRECORDS.
# NOTE(review): the printed filename/dtype/shape output saved below this cell does not
# match the current code (none of the functions print anything); re-run to refresh it.
write_data_to_TFRecord()

2007_000027
uint8
(500, 486, 3)
uint8
(281, 500, 3)
uint8
(366, 500, 3)
uint8
(375, 500, 3)
uint8
(335, 500, 3)
uint8
(333, 500, 3)
uint8
(375, 500, 3)
uint8
(375, 500, 3)
uint8
(375, 500, 3)
uint8
(375, 500, 3)
uint8
(500, 334, 3)
uint8
(375, 500, 3)
uint8
(332, 500, 3)
uint8
(375, 500, 3)
uint8
(375, 500, 3)
uint8
(333, 500, 3)
uint8
(375, 500, 3)
uint8
(343, 500, 3)
uint8
(500, 333, 3)
uint8
(375, 500, 3)
uint8
(333, 500, 3)
uint8
(333, 500, 3)
uint8
(375, 500, 3)
uint8
(333, 500, 3)
uint8
(375, 500, 3)
uint8
(332, 500, 3)
uint8
(375, 500, 3)
uint8
(375, 500, 3)
uint8
(500, 375, 3)
uint8
(375, 500, 3)
uint8
(334, 500, 3)
uint8
(412, 500, 3)
uint8
(375, 500, 3)
uint8
(375, 500, 3)
uint8
(375, 500, 3)
uint8
(500, 375, 3)
uint8
(370, 500, 3)
uint8
(331, 500, 3)
uint8
(375, 500, 3)
uint8
(230, 500, 3)
uint8
(335, 500, 3)
uint8
(375, 500, 3)
uint8
(333, 500, 3)
uint8
(375, 500, 3)
uint8
(500, 422, 3)
uint8
(375, 500, 3)
uint8
(375, 500, 3)
uint8
(336, 500, 3)
uint8
(333, 500, 3)
uint8
(5