In [27]:
import json
import numpy as np
import os
import tensorflow as tf

print('TensorFlow:', tf.__version__)
np.random.seed(25)

TensorFlow: 2.2.0-rc2


In [2]:
class TFrecordWriter:

    def __init__(self, n_samples, n_shards, output_dir='', prefix=''):
        self.n_samples = n_samples
        self.n_shards = n_shards
        self._step_size = self.n_samples // self.n_shards + 1
        self.prefix = prefix
        self.output_dir = output_dir
        self._buffer = []
        self._file_count = 1

    def _make_example(self, image, boxes, classes):
        feature = {
            'image':
                tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
            'xmins':
                tf.train.Feature(float_list=tf.train.FloatList(value=boxes[:, 0])),
            'ymins':
                tf.train.Feature(float_list=tf.train.FloatList(value=boxes[:, 1])),
            'xmaxs':
                tf.train.Feature(float_list=tf.train.FloatList(value=boxes[:, 2])),
            'ymaxs':
                tf.train.Feature(float_list=tf.train.FloatList(value=boxes[:, 3])),
            'classes':
                tf.train.Feature(int64_list=tf.train.Int64List(value=classes))
        }
        return tf.train.Example(features=tf.train.Features(feature=feature))

    def _write_tfrecord(self, tfrecord_path):
        print('writing {} samples in {}'.format(len(self._buffer),
                                                tfrecord_path))
        with tf.io.TFRecordWriter(tfrecord_path) as writer:
            for (image, boxes, classes) in self._buffer:
                example = self._make_example(image, boxes, classes)
                writer.write(example.SerializeToString())

    def push(self, image, boxes, classes):
        self._buffer.append([image, boxes, classes])
        if len(self._buffer) == self._step_size:
            fname = self.prefix + '_000' + str(self._file_count) + '.tfrecord'
            tfrecord_path = os.path.join(self.output_dir, fname)
            self._write_tfrecord(tfrecord_path)
            self._clear_buffer()
            self._file_count += 1

    def flush_last(self):
        if len(self._buffer):
            fname = self.prefix + '_000' + str(self._file_count) + '.tfrecord'
            tfrecord_path = os.path.join(self.output_dir, fname)
            self._write_tfrecord(tfrecord_path)

    def _clear_buffer(self):
        self._buffer = []

In [11]:
shapes_dataset_dir = '../tutorials/data/shapes_dataset/'
with open(shapes_dataset_dir + 'dataset.json', 'r') as fp:
    dataset_json = json.load(fp)

all_image_names = list(dataset_json.keys())
print('Found {} images'.format(len(all_image_names)))

class_map = {'circle': 0, 'rectangle': 1}

Found 12500 images


In [28]:
from tqdm.notebook import tqdm

In [35]:
sum_image = []
for image_path in tqdm(all_image_names):
    image = tf.io.read_file(shapes_dataset_dir + 'images/' + image_path)
    image = tf.image.decode_image(image)
    sum_image.append(image.numpy())

HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))




In [36]:
np_image = np.array(sum_image)
np_image.shape

(12500, 448, 448, 3)

In [38]:
np.mean(np_image, (0, 1, 2))

array([235.16132109, 235.00810276, 235.0938025 ])

In [39]:
np.std(np_image, (0, 1, 2))

array([68.30292656, 68.54383801, 68.40923567])

In [4]:
aindices = np.arange(len(all_image_names))
np.random.shuffle(indices)

train_image_names = all_image_names[:10000]
val_image_names = all_image_names[10000:]

print('Splitting dataset into {} training images and {} validation images'.format(len(train_image_names), len(val_image_names)))

Splitting dataset into 10000 training images and 2500 validation images


In [5]:
n_shards = 8
tf_record_dir = '../tutorials/data/shapes_dataset_tfrecords'

In [7]:
train_tf_record_writer = TFrecordWriter(n_samples=len(train_image_names),
                                        n_shards=n_shards,
                                        output_dir=tf_record_dir,
                                        prefix='train')

for image_name in train_image_names:
    boxes = []
    classes = []

    with tf.io.gfile.GFile(shapes_dataset_dir + 'images/' + image_name, 'rb') as fp:
        image = fp.read()
    
    for obj in dataset_json[image_name]:
        boxes.append(obj['box'])
        classes.append(class_map[obj['category']])
    train_tf_record_writer.push(image, np.array(boxes, dtype=np.float32), np.array(classes, dtype=np.int32))
train_tf_record_writer.flush_last()

writing 1251 samples in ../tutorials/data/shapes_dataset_tfrecords/train_0001.tfrecord
writing 1251 samples in ../tutorials/data/shapes_dataset_tfrecords/train_0002.tfrecord
writing 1251 samples in ../tutorials/data/shapes_dataset_tfrecords/train_0003.tfrecord
writing 1251 samples in ../tutorials/data/shapes_dataset_tfrecords/train_0004.tfrecord
writing 1251 samples in ../tutorials/data/shapes_dataset_tfrecords/train_0005.tfrecord
writing 1251 samples in ../tutorials/data/shapes_dataset_tfrecords/train_0006.tfrecord
writing 1251 samples in ../tutorials/data/shapes_dataset_tfrecords/train_0007.tfrecord
writing 1243 samples in ../tutorials/data/shapes_dataset_tfrecords/train_0008.tfrecord


In [8]:
val_tf_record_writer = TFrecordWriter(n_samples=len(val_image_names),
                                        n_shards=n_shards,
                                        output_dir=tf_record_dir,
                                        prefix='val')

for image_name in val_image_names:
    boxes = []
    classes = []

    with tf.io.gfile.GFile(shapes_dataset_dir + 'images/' + image_name, 'rb') as fp:
        image = fp.read()
    
    for obj in dataset_json[image_name]:
        boxes.append(obj['box'])
        classes.append(class_map[obj['category']])
    val_tf_record_writer.push(image, np.array(boxes, dtype=np.float32), np.array(classes, dtype=np.int32))
val_tf_record_writer.flush_last()

writing 313 samples in ../tutorials/data/shapes_dataset_tfrecords/val_0001.tfrecord
writing 313 samples in ../tutorials/data/shapes_dataset_tfrecords/val_0002.tfrecord
writing 313 samples in ../tutorials/data/shapes_dataset_tfrecords/val_0003.tfrecord
writing 313 samples in ../tutorials/data/shapes_dataset_tfrecords/val_0004.tfrecord
writing 313 samples in ../tutorials/data/shapes_dataset_tfrecords/val_0005.tfrecord
writing 313 samples in ../tutorials/data/shapes_dataset_tfrecords/val_0006.tfrecord
writing 313 samples in ../tutorials/data/shapes_dataset_tfrecords/val_0007.tfrecord
writing 309 samples in ../tutorials/data/shapes_dataset_tfrecords/val_0008.tfrecord
