In [1]:
import numpy as np
import os
import tensorflow as tf
from scipy.misc import imread
from matplotlib import pyplot as plt
import random

In [2]:
# Define the image and annotation directories of the road/building detection dataset.
# The dataset root can be overridden with the ROAD_DATASET_DIR environment variable so the
# notebook is not tied to one machine's home directory; the original path remains the default.
base_dataset_dir = os.environ.get(
    "ROAD_DATASET_DIR",
    "/home/thalles_silva/DataPublic/Road_and_Buildings_detection_dataset/mass_merged",
)
train_dataset_base_dir = os.path.join(base_dataset_dir, "train")
# Sub-folder names are shared by the train and validation splits. The trailing slashes are
# kept because callers may concatenate these directories directly with file names.
images_folder_name = "sat/"
annotations_folder_name = "map/"
train_images_dir = os.path.join(train_dataset_base_dir, images_folder_name)
train_annotations_dir = os.path.join(train_dataset_base_dir, annotations_folder_name)

In [3]:
# Read train_all.txt: one training image base name per line (the reader appends the
# .tiff/.tif extensions). Names are stripped of surrounding whitespace/newlines and blank
# lines are skipped so they cannot turn into bogus file names. The `with` block closes the file.
with open(os.path.join(train_dataset_base_dir, "train_all.txt"), 'r') as names_file:
    images_filename_list = [line.strip() for line in names_file if line.strip()]
number_of_train_examples = len(images_filename_list)
print("number_of_train_examples:", number_of_train_examples)

number_of_train_examples: 137


In [4]:
# Validation split directories (same sub-folder layout as the training split).
val_dataset_base_dir = os.path.join(base_dataset_dir, "valid")
val_images_dir = os.path.join(val_dataset_base_dir, images_folder_name)
val_annotations_dir = os.path.join(val_dataset_base_dir, annotations_folder_name)

# Read val.txt: one validation image base name per line. Names are stripped and blank lines
# skipped, mirroring the training list; the `with` block closes the file.
with open(os.path.join(val_dataset_base_dir, "val.txt"), 'r') as names_file:
    val_images_filename_list = [line.strip() for line in names_file if line.strip()]

In [5]:
# Output directory for the serialized datasets. Despite its name, TRAIN_DATASET_DIR receives
# both the training and the validation TFRecord files.
TRAIN_DATASET_DIR = "../dataset/"
TRAIN_FILE = 'train.tfrecords'
VALIDATION_FILE = 'validation.tfrecords'
# Create the output directory up front so the writers are never opened under a missing path.
os.makedirs(TRAIN_DATASET_DIR, exist_ok=True)
# NOTE(review): tf.python_io is the TF 1.x location; on TF 2.x use tf.io.TFRecordWriter.
train_writer = tf.python_io.TFRecordWriter(os.path.join(TRAIN_DATASET_DIR, TRAIN_FILE))
val_writer = tf.python_io.TFRecordWriter(os.path.join(TRAIN_DATASET_DIR, VALIDATION_FILE))

In [6]:
def _bytes_feature(value):
    """Return a ``tf.train.Feature`` holding ``value`` as a one-element BytesList."""
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)

def _int64_feature(value):
    """Return a ``tf.train.Feature`` holding ``value`` as a one-element Int64List."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)

In [7]:
def read_image_and_annotation(train_images_dir, train_annotations_dir, image_name):
    """Load one input image and its annotation map from disk.

    Args:
        train_images_dir: directory holding the ``<name>.tiff`` input images (used for any
            split, despite the parameter name).
        train_annotations_dir: directory holding the ``<name>.tif`` annotation maps.
        image_name: base file name without extension; surrounding whitespace (e.g. a newline
            left over from reading a list file) is ignored.

    Returns:
        ``(image, annotation)`` exactly as returned by ``imread``.
    """
    # NOTE(review): scipy.misc.imread (imported at the top of the notebook) was removed in
    # SciPy 1.2; this needs an older SciPy or a replacement reader.
    name = image_name.strip()
    # os.path.join works whether or not the directory strings end with a separator; plain
    # string concatenation silently produced wrong paths without the trailing slash.
    image = imread(os.path.join(train_images_dir, name + ".tiff"))
    annotation = imread(os.path.join(train_annotations_dir, name + ".tif"))
    return image, annotation

In [8]:
def random_crop(image_np, annotation_np, crop_size=128):
    """Cut the same random ``crop_size`` x ``crop_size`` window out of an image and its annotation.

    image_np: image array indexed as (H, W, ...), e.g. an RGB image of shape (H, W, 3)
    annotation_np: annotation array expected to share H and W with ``image_np``
        (e.g. shape (H, W) or (H, W, 1))
    crop_size: integer side length of the square window

    Returns ``(image_crop, annotation_crop)``; both crops cover the same pixel window.
    Raises ValueError if ``crop_size`` is larger than the image height or width.
    """
    image_h, image_w = image_np.shape[:2]
    if crop_size > image_h or crop_size > image_w:
        raise ValueError(
            "crop_size=%d does not fit in an image of size %dx%d" % (crop_size, image_h, image_w))

    # Top-left corner of the window; randint's upper bound is exclusive, hence the +1.
    # x (column) is drawn before y (row) so the random stream matches earlier versions.
    random_x = np.random.randint(0, image_w - crop_size + 1)
    random_y = np.random.randint(0, image_h - crop_size + 1)

    offset_x = random_x + crop_size
    offset_y = random_y + crop_size

    # Axis 0 is rows (height, y) and axis 1 is columns (width, x). Slicing axis 0 with the
    # x offset only happened to work for square images.
    return (image_np[random_y:offset_y, random_x:offset_x],
            annotation_np[random_y:offset_y, random_x:offset_x])

In [9]:
def _accept_patch(image_patch, annotation_patch, crop_size):
    """Rejection filter for random crops: return True if the crop should be kept.

    Crops whose annotation is mostly background (zero-valued) are discarded with a probability
    that grows with the background fraction:
        >= 99% background -> discard with probability 0.8
        >= 95% background -> discard with probability 0.65
        >= 90% background -> discard with probability 0.5
    Crops in which 1% or more of the image pixels are pure white (255, 255, 255) are always
    discarded; some dataset images use white for invalid areas.
    """
    total_n_of_pixels = crop_size * crop_size
    # NOTE(review): assumes a single-channel annotation; for multi-channel maps count_nonzero
    # counts every channel while total_n_of_pixels does not — confirm the annotation shape.
    background_pixels = total_n_of_pixels - np.count_nonzero(annotation_patch)

    prob = random.random()
    if background_pixels >= 0.99 * total_n_of_pixels:
        discard_prob = 0.8
    elif background_pixels >= 0.95 * total_n_of_pixels:
        discard_prob = 0.65
    elif background_pixels >= 0.9 * total_n_of_pixels:
        discard_prob = 0.5
    else:
        discard_prob = None
    if discard_prob is not None and prob <= discard_prob:
        return False

    # Assumes an (H, W, 3) image patch.
    n_white_pixels = np.sum(np.all(image_patch == [255, 255, 255], axis=2))
    return n_white_pixels < 0.01 * total_n_of_pixels


def _serialize_example(image_patch, annotation_patch):
    """Encode one (image, annotation) pair as a serialized ``tf.train.Example``.

    The arrays are stored as raw bytes (tobytes() is the non-deprecated equivalent of
    tostring()); only height and width are recorded, so the reader must know the dtypes and
    channel counts.
    """
    example = tf.train.Example(features=tf.train.Features(feature={
        'height': _int64_feature(image_patch.shape[0]),
        'width': _int64_feature(image_patch.shape[1]),
        'image_raw': _bytes_feature(image_patch.tobytes()),
        'annotation_raw': _bytes_feature(annotation_patch.tobytes())}))
    return example.SerializeToString()


def create_tfrecord_dataset(images_dir, annotations_dir, filename_list, writer, total_epochs=1,
                            batch_size=1, crop_size=64, random_cropping=True, close_writer=True):
    """Write (image, annotation) samples for every name in ``filename_list`` to ``writer``.

    Args:
        images_dir, annotations_dir: directories passed to ``read_image_and_annotation``.
        filename_list: image base names to read.
        writer: an open TFRecord writer; each sample becomes a ``tf.train.Example`` with
            ``height``, ``width``, ``image_raw`` and ``annotation_raw`` features.
        total_epochs: number of passes over ``filename_list``.
        batch_size: number of accepted random crops taken per image per pass. Treated as 1
            when ``random_cropping`` is False, since the full image would only be duplicated.
        crop_size: side length of the square random crops.
        random_cropping: if True, write random crops that pass the background/white-pixel
            filter; otherwise write each full image once per pass.
        close_writer: close ``writer`` when finished (also on error) so records are flushed.
    """
    # Decide the per-image sample count once, instead of mutating batch_size mid-loop.
    crops_per_image = batch_size if random_cropping else 1

    print("Total # of example:", crops_per_image * len(filename_list) * total_epochs)
    number_of_written_imgs = 0
    try:
        for epoch_counter in range(total_epochs):
            for image_name in filename_list:
                image_np, annotation_np = read_image_and_annotation(images_dir, annotations_dir, image_name)

                for _ in range(crops_per_image):
                    if random_cropping:
                        # Resample until a crop passes the filter.
                        # NOTE(review): this loops indefinitely if no crop of the image is acceptable.
                        while True:
                            image_patch, annotation_patch = random_crop(image_np, annotation_np, crop_size)
                            if _accept_patch(image_patch, annotation_patch, crop_size):
                                break
                    else:
                        image_patch, annotation_patch = image_np, annotation_np

                    writer.write(_serialize_example(image_patch, annotation_patch))
                    number_of_written_imgs += 1

            print("Image written:", number_of_written_imgs, "End of epoch:", epoch_counter)
    finally:
        if close_writer:
            writer.close()

In [11]:
# Seed both RNGs used while sampling crops (np.random for the crop position, random for the
# background-rejection draw) so the generated training set is reproducible.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)

# Training set: 64 passes over the file list, 64 accepted 64x64 random crops per image per pass.
create_tfrecord_dataset(train_images_dir, train_annotations_dir, images_filename_list, train_writer,
                        crop_size=64, batch_size=64, total_epochs=64)
# Validation set: every full image written once, without cropping.
create_tfrecord_dataset(val_images_dir, val_annotations_dir, val_images_filename_list, val_writer,
                        random_cropping=False)

Total # of example: 4
Image written: 4 End of epoch: 0
