In [None]:
import numpy as np
import cv2
import tensorflow as tf

In [None]:
def augment_retinal_image(image, max_angle=5, contrast_range=(0.5, 1.5),
                          max_brightness=50, zoom_range=(0.9, 1.1)):
    """Randomly flip, recolor, rotate and zoom an 8-bit image.

    Random parameters come from NumPy's global RNG (``np.random``), so a call
    is reproducible by seeding or restoring that state beforehand. All five
    random parameters are drawn on every call, in a fixed order, whether or
    not the flip is applied.

    Args:
        image: NumPy array of shape (H, W) or (H, W, C); assumed uint8
            (values 0-255) -- TODO confirm if other dtypes are ever passed.
        max_angle: Rotation angle in degrees, drawn from the integers in
            ``[-max_angle, max_angle]`` (both ends inclusive).
        contrast_range: ``(low, high)`` range for the contrast gain ``alpha``.
        max_brightness: Brightness offset ``beta``, drawn from the integers in
            ``[-max_brightness, max_brightness]`` (both ends inclusive).
        zoom_range: ``(low, high)`` range for the isotropic zoom factor.

    Returns:
        uint8 array with the same height and width as ``image``.

    Note:
        The brightness/contrast step rescales every pixel value, so this
        function is not label-safe for segmentation masks as-is.
    """
    height, width = image.shape[:2]
    # OpenCV places pixel centres at integer coordinates, so the geometric
    # centre of the image is ((w - 1) / 2, (h - 1) / 2), not (w / 2, h / 2).
    center = ((width - 1) / 2, (height - 1) / 2)

    # Draw every random parameter up front so the number of RNG draws per
    # call is fixed (replaying the RNG state reproduces the same transform).
    do_flip = np.random.rand() < 0.5
    # randint's upper bound is exclusive; +1 makes both ranges symmetric.
    angle = np.random.randint(-max_angle, max_angle + 1)
    alpha = np.random.uniform(*contrast_range)
    beta = np.random.randint(-max_brightness, max_brightness + 1)
    zoom = np.random.uniform(*zoom_range)

    # Horizontal flip.
    if do_flip:
        image = cv2.flip(image, 1)

    # Contrast/brightness, saturated to [0, 255]. cv2.convertScaleAbs would
    # take |alpha*x + beta|, turning dark pixels *brighter* when beta < 0.
    # Applied before the warp so the warp's constant fill border is not shifted.
    image = np.clip(
        np.rint(image.astype(np.float32) * alpha + beta), 0, 255
    ).astype(np.uint8)

    # Rotation and zoom about the same centre compose into one affine map, so
    # a single warp (one interpolation pass) avoids blurring the image twice.
    rotate_zoom = cv2.getRotationMatrix2D(center, angle, zoom)
    image = cv2.warpAffine(image, rotate_zoom, (width, height))

    return image

In [None]:
def create_segmentation_dataset(data_dir, augment_fn=None, augment_test=False):
    """Build train/test ``tf.data`` datasets of ``(image, mask)`` pairs.

    Images are read from ``{data_dir}/{split}/*.jpg`` (split is ``train`` or
    ``test``). Each image's mask is the file at the same path with the
    trailing ``.jpg`` replaced by ``.png``.

    Args:
        data_dir: Root directory containing ``train/`` and ``test/`` folders.
        augment_fn: Optional NumPy-level function ``array -> array``. It is run
            through ``tf.numpy_function`` and called once on the image and once
            on the mask, with NumPy's global RNG state replayed in between, so
            an ``np.random``-driven augmentation draws identical parameters for
            both and the pair stays spatially aligned. It must return uint8
            arrays. NOTE: any photometric change or interpolation it performs
            is applied to the mask too, so it should be label-safe.
        augment_test: Also augment the test split. Defaults to False so that
            evaluation runs on unmodified data.

    Returns:
        ``(train_ds, test_ds)``: datasets yielding uint8 ``(image, mask)``
        tensors of shape ``(H, W, 3)`` and ``(H, W, 1)``.

    Raises:
        ValueError: If a split folder contains no ``.jpg`` files.
    """
    import threading

    # The augmentation map runs elements in parallel; the lock keeps the
    # save/replay of NumPy's global RNG state atomic for each (image, mask) pair.
    rng_lock = threading.Lock()

    def augment_pair_np(image, mask):
        # Apply augment_fn to image and mask with the same random draws.
        with rng_lock:
            state = np.random.get_state()
            aug_image = augment_fn(image)
            np.random.set_state(state)  # replay the same draws for the mask
            aug_mask = augment_fn(mask)
        aug_image = np.asarray(aug_image, dtype=np.uint8)
        aug_mask = np.asarray(aug_mask, dtype=np.uint8)
        # Restore the channel axis in case augment_fn returned a 2-D mask.
        if aug_mask.ndim == 2:
            aug_mask = aug_mask[..., np.newaxis]
        return aug_image, aug_mask

    def load_pair(path):
        # Read an image and its mask from a single path element.
        image = tf.image.decode_jpeg(tf.io.read_file(path), channels=3)
        # Swap only the trailing extension: str.replace does not exist on
        # tensors and would also rewrite "jpg" inside directory names.
        mask_path = tf.strings.regex_replace(path, r"\.jpg$", ".png")
        mask = tf.image.decode_png(tf.io.read_file(mask_path), channels=1)
        return image, mask

    def augment_pair(image, mask):
        # Bridge the graph-mode pipeline to the NumPy/OpenCV augmentation.
        aug_image, aug_mask = tf.numpy_function(
            augment_pair_np, [image, mask], [tf.uint8, tf.uint8]
        )
        # numpy_function drops static shape information; restore rank/channels.
        aug_image.set_shape([None, None, 3])
        aug_mask.set_shape([None, None, 1])
        return aug_image, aug_mask

    def build_split(split, augment):
        # Create the (image, mask) dataset for one split folder.
        paths = tf.io.gfile.glob(f"{data_dir}/{split}/*.jpg")
        if not paths:
            raise ValueError(f"No .jpg images found in {data_dir}/{split}")
        # Image and mask come from the same path element, so pairs cannot
        # drift out of alignment; sorting makes the order deterministic.
        ds = tf.data.Dataset.from_tensor_slices(sorted(paths))
        ds = ds.map(load_pair, num_parallel_calls=tf.data.AUTOTUNE)
        if augment and augment_fn is not None:
            ds = ds.map(augment_pair, num_parallel_calls=tf.data.AUTOTUNE)
        return ds

    train_ds = build_split("train", augment=True)
    test_ds = build_split("test", augment=augment_test)
    return train_ds, test_ds





You can use this function to create the train and test datasets, as shown in the final code cell below.



This function takes the path of the data directory and an optional data-augmentation function. It creates a dataset of all image paths in the `train` and `test` folders. It then uses the `tf.data.Dataset.map()` method to read and decode the images, applying the augmentation function if one is provided. It also reads and decodes the corresponding label images, which are stored in PNG format. Finally, it zips the images and labels together into datasets of (image, label) pairs and returns the train and test datasets.


In [None]:
# Root folder containing train/ and test/ subfolders of *.jpg images with
# matching *.png masks.
DATA_DIR = "path/to/data/dir"  # TODO: set to the real dataset location

# Pass the augmentation function directly: wrapping it as
# `lambda x: augment_retinal_image(x)` adds nothing (see PEP 8 on lambdas).
augment_fn = augment_retinal_image
train_ds, test_ds = create_segmentation_dataset(DATA_DIR, augment_fn)