diff --git a/docs/modules/prepro.rst b/docs/modules/prepro.rst
index d00560e89..0cdea118e 100644
--- a/docs/modules/prepro.rst
+++ b/docs/modules/prepro.rst
@@ -357,6 +357,10 @@ In practice, you may want to use threading method to process a batch of images a
             b_ann[i][0], b_ann[i][1], [], classes, True, save_name='_bbox_vis_%d.png' % i)
 
 
+Image Aug with TF Dataset API
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+- Example code for VOC `here <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_tf_dataset_voc.py>`__.
 
 Coordinate pixel unit to percentage
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/example/tutorial_tf_dataset_voc.py b/example/tutorial_tf_dataset_voc.py
new file mode 100644
index 000000000..068c8e8d3
--- /dev/null
+++ b/example/tutorial_tf_dataset_voc.py
@@ -0,0 +1,95 @@
+#! /usr/bin/python
+# -*- coding: utf8 -*-
+
+# tf import data dataset.map https://www.tensorflow.org/programmers_guide/datasets#applying_arbitrary_python_logic_with_tfpy_func
+# tf.py_func https://www.tensorflow.org/api_docs/python/tf/py_func
+# tl ref: https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_imagenet_inceptionV3_distributed.py
+# cn ref: https://blog.csdn.net/dQCFKyQDXYm3F8rB0/article/details/79342369
+# cn ref: https://zhuanlan.zhihu.com/p/31466173
+
+import numpy as np
+import multiprocessing, random, json, time
+import tensorflow as tf
+import tensorlayer as tl
+
+imgs_file_list, _, _, _, classes, _, _,\
+    _, objs_info_list, _ = tl.files.load_voc_dataset(dataset="2007")
+
+ann_list = []
+for info in objs_info_list:
+    ann = tl.prepro.parse_darknet_ann_str_to_list(info)
+    c, b = tl.prepro.parse_darknet_ann_list_to_cls_box(ann)
+    ann_list.append([c, b])
+
+n_epoch = 10
+batch_size = 64
+im_size = [416, 416]
+jitter = 0.2
+shuffle_buffer_size = 100
+
+
+def generator():
+    """Yield (image file path, darknet annotation string) pairs as UTF-8 bytes."""
+    inputs = imgs_file_list
+    targets = objs_info_list
+    assert len(inputs) == len(targets)
+    for _input, _target in zip(inputs, targets):
+        yield _input.encode('utf-8'), _target.encode('utf-8')
+
+
+def _data_aug_fn(im, ann):
+    """Augment one image + annotation (runs inside tf.py_func, numpy in/out)."""
+    ann = ann.decode()
+    ann = tl.prepro.parse_darknet_ann_str_to_list(ann)
+    clas, coords = tl.prepro.parse_darknet_ann_list_to_cls_box(ann)
+    ## random brightness
+    im = tl.prepro.brightness(im, gamma=0.5, gain=1, is_random=True)
+    # im = tl.prepro.illumination(im, gamma=(0.5, 1.5),
+    #     contrast=(0.5, 1.5), saturation=(0.5, 1.5), is_random=True) # TypeError: Cannot handle this data type
+    ## random horizontal flip
+    im, coords = tl.prepro.obj_box_left_right_flip(im, coords, is_rescale=True, is_center=True, is_random=True)
+    ## random resize and crop
+    tmp0 = random.randint(1, int(im_size[0] * jitter))
+    tmp1 = random.randint(1, int(im_size[1] * jitter))
+    im, coords = tl.prepro.obj_box_imresize(im, coords, [im_size[0] + tmp0, im_size[1] + tmp1], is_rescale=True, interp='bicubic')
+    im, clas, coords = tl.prepro.obj_box_crop(im, clas, coords, wrg=im_size[1], hrg=im_size[0], is_rescale=True, is_center=True, is_random=True)
+    ## NOTE: the image is already float32 in [0, 1] here (convert_image_dtype in _map_fn),
+    ## so do NOT divide by 255 again -- that would double-rescale and darken the images.
+    ## optional: rescale [0, 1] to [-1, 1]
+    # im = im * 2 - 1
+    im = np.array(im, dtype=np.float32)  # important: must match the tf.float32 declared in tf.py_func
+    return im, json.dumps([clas, coords]).encode('utf-8')  # JSON so the consumer can json.loads it back
+
+
+def _map_fn(filename, annotation):
+    """Read + decode one image, then delegate augmentation to _data_aug_fn."""
+    image = tf.read_file(filename)
+    image = tf.image.decode_jpeg(image, channels=3)
+    image = tf.image.convert_image_dtype(image, dtype=tf.float32)  # uint8 [0, 255] -> float32 [0, 1]
+    ## data augmentation via arbitrary python logic
+    image, annotation = tf.py_func(_data_aug_fn, [image, annotation], [tf.float32, tf.string])
+    return image, annotation
+
+
+ds = tf.data.Dataset.from_generator(generator, output_types=(tf.string, tf.string))  # classmethod: Dataset is abstract, do not instantiate
+ds = ds.map(_map_fn, num_parallel_calls=multiprocessing.cpu_count())
+ds = ds.shuffle(shuffle_buffer_size)  # shuffle BEFORE repeat so examples never mix across epoch boundaries
+ds = ds.repeat(n_epoch)
+ds = ds.batch(batch_size)
+value = ds.make_one_shot_iterator().get_next()
+
+sess = tf.InteractiveSession()
+
+## get a batch of images (after data augmentation)
+_, _ = sess.run(value)  # 1st time takes time to compile
+st = time.time()
+im, annbyte = sess.run(value)
+print('took {}s'.format(time.time() - st))
+
+ann = []
+for a in annbyte:
+    a = a.decode()
+    ann.append(json.loads(a))
+
+## save all images
+for i in range(len(im)):
+    tl.vis.draw_boxes_and_labels_to_image(im[i] * 255, ann[i][0], ann[i][1], [], classes, True, save_name='_bbox_vis_%d.png' % i)