In [2]:
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import numpy as np
import random

from skimage.util import random_noise

In [3]:
# Directory handed to `data_generator` as `train_path`; the image and mask
# sub-folders named in that call are looked up underneath it.
train_dir = 'data/training'
# valid_dir = 'data/validation'

In [4]:
def image_preprocess(image):
    '''
    Scale an image to [0, 1], darken it by a random offset and add Gaussian noise.

    image: numpy array of pixel intensities (assumed to be on a 0-255 scale,
           as delivered by the Keras generator -- TODO confirm).

    Returns an array of the same shape whose values lie in [0, 1].
    '''
    image = image.astype(np.float32) / 255.0

    # Random brightness reduction. Clip straight away: `random_noise` treats an
    # input that contains negative values as a *signed* image and clips its
    # output to [-1, 1], which would let negative intensities through.
    image = np.clip(image - random.uniform(0, 0.6), 0.0, 1.0)

    # The paired image/mask generators rely on identical random streams to stay
    # aligned, so the noise must not consume or reseed NumPy's global RNG.
    # NOTE(review): whether `random_noise` touches the global RNG depends on the
    # installed scikit-image release -- saving/restoring the state is safe either way.
    noise_seed = random.randrange(2 ** 32)
    numpy_state = np.random.get_state()
    try:
        image = random_noise(image, mode='gaussian', seed=noise_seed, clip=True)
    finally:
        np.random.set_state(numpy_state)
    return image

def mask_preprocess(mask):
    '''
    Turn a mask into a binary float32 array.

    The input is copied, divided by 255 and thresholded at 0.5: values above
    the threshold become 1, values at or below it become 0.
    '''
    binary = mask.astype(np.float32) / 255.0
    # Both selections are taken from the scaled values before any write; the
    # pixels set to 1 were never part of the "<= 0.5" selection anyway.
    foreground = binary > 0.5
    background = binary <= 0.5
    binary[foreground] = 1
    binary[background] = 0
    return binary

def data_generator(batch_size,
                   params_dict,
                   train_path,
                   image_folder, mask_folder,
                   save_path = None,
                   image_save_prefix  = "image", mask_save_prefix  = "mask",
                   target_size = (256,256),
                   seed = 1):
    '''
    generates augmented image and mask on the fly at the same time
    use the same seed for image_datagen and mask_datagen to ensure the transformation for image and mask is the same

    batch_size:        number of samples requested per batch
    params_dict:       keyword arguments forwarded to both ImageDataGenerator instances
    train_path:        directory containing image_folder and mask_folder
    image_folder:      sub-folder of train_path holding the input images
    mask_folder:       sub-folder of train_path holding the masks
    save_path:         optional directory where the augmented samples are written;
                       it is created when it does not exist yet
    image_save_prefix: file-name prefix for saved images
    mask_save_prefix:  file-name prefix for saved masks
    target_size:       size every image/mask is resized to on load
    seed:              seed shared by both directory iterators

    returns an iterator of (image_batch, mask_batch) pairs; the Keras directory
    iterators do not stop on their own, so the caller decides when to stop
    '''
    import os

    # Keras writes into save_to_dir without creating it, so a missing folder
    # would only fail once the first batch is produced.
    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)

    image_datagen = ImageDataGenerator(**params_dict, preprocessing_function = image_preprocess)
    mask_datagen = ImageDataGenerator(**params_dict, preprocessing_function = mask_preprocess)

    # Everything except the class folder and the save prefix must be identical
    # for both streams, otherwise images and masks would drift apart.
    shared_flow_args = dict(
        class_mode = None,
        color_mode = 'grayscale',
        target_size = target_size,
        batch_size = batch_size,
        save_to_dir = save_path,
        shuffle = False,
        interpolation = 'bicubic',
        seed = seed)

    image_generator = image_datagen.flow_from_directory(
        train_path,
        classes = [image_folder],
        save_prefix  = image_save_prefix,
        **shared_flow_args)

    mask_generator = mask_datagen.flow_from_directory(
        train_path,
        classes = [mask_folder],
        save_prefix  = mask_save_prefix,
        **shared_flow_args)

    # returning zip(generator_one, generator_two) in turn also creates generator
    return zip(image_generator, mask_generator)



In [5]:
# Output folder: passed to `data_generator` as `save_path`, i.e. the place
# where the augmented images and masks are written.
augmented_dir= 'data/augmented'

In [6]:
# Augmentation settings forwarded to both the image and the mask generator.
data_augmentation_args = {
    'rotation_range': 45,
    'width_shift_range': 10,
    'height_shift_range': 10,
    'shear_range': 0.05,
    'zoom_range': 0.05,
    'horizontal_flip': False,
    'fill_mode': 'nearest',
}

generator = data_generator(
    batch_size=32,
    params_dict=data_augmentation_args,
    train_path=train_dir,
    image_folder='images',
    mask_folder='masks',
    save_path=augmented_dir,
)

# Pull exactly `num_batches` batches; since a save directory was given, each
# pull is what triggers writing the augmented samples to disk.
num_batches = 50
for i in range(num_batches):
    batch = next(generator)

Found 7 images belonging to 1 classes.
Found 7 images belonging to 1 classes.
