# Introduction

In this notebook, I provide and demonstrate the functions needed to integrate Albumentations augmentations into TensorFlow Dataset (tf.data.Dataset) pipelines built from .tfrec shards.

Compared with tf.image's preprocessing functions and Keras's ImageDataGenerator class, the Albumentations library offers a wider variety of augmentation functions, including more complex augmentations.

Albumentations itself runs on the CPU (wrapped here in tf.numpy_function), but images are augmented on the fly as they stream through the pipeline. No augmented copies need to be precomputed or stored, which keeps memory usage low, and prefetching lets this CPU work overlap with GPU training.

To make the script future-proof and customizable, I have included locations where additional augmentation or preprocessing steps can be inserted.

For educational purposes, comments explaining each step are included throughout the script.

Lastly, to follow the data flow, it is advisable to read the functions in reverse order, i.e. from get_train_dataset back to decode_func.

In [None]:
# Import dependencies
import os

# Select the 1st GPU; set this before TensorFlow initializes the GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import re
from collections import Counter
from functools import partial

import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from kaggle_datasets import KaggleDatasets  # only needed for GCS paths (e.g. on TPU)

# Paths/Data
PATH = "../input/hpa-single-cell-image-classification"

# Local Kaggle input path (works on GPU/CPU)
FILENAMES = tf.io.gfile.glob(PATH + "/train_tfrecords/train*.tfrec")

In [None]:
# Initialize variables
AUTOTUNE = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 32
IMAGE_SIZE = [256, 256]
SEED = 42

print('Number of shards:', len(FILENAMES))

## Displaying .tfrec feature dictionary

In [None]:
def list_record_features(tfrecords_path):
    # Dict of extracted feature information
    features = {}
    # Iterate records
    for rec in tf.data.TFRecordDataset([str(tfrecords_path)]):
        # Get record bytes
        example_bytes = rec.numpy()
        # Parse example protobuf message
        example = tf.train.Example()
        example.ParseFromString(example_bytes)
        # Iterate example features
        for key, value in example.features.feature.items():
            # Kind of data in the feature
            kind = value.WhichOneof('kind')
            # Size of data in the feature
            size = len(getattr(value, kind).value)
            # Check if feature was seen before
            if key in features:
                # Check if values match, use None otherwise
                kind2, size2 = features[key]
                if kind != kind2:
                    kind = None
                if size != size2:
                    size = None
            # Save feature data
            features[key] = (kind, size)
    return features

# Print extracted feature information
features = list_record_features(FILENAMES[0])
print(*features.items(), sep='\n')
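
To try list_record_features (and the readers below) without the competition files, a tiny shard with the same three feature names can be written locally. This is a sketch under stated assumptions: the helper names (write_toy_shard, _bytes_feature), the toy file name and the random JPEG contents are all hypothetical; only the feature names image, image_name and target come from this notebook.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def _bytes_feature(value):
    """Wrap raw bytes in a tf.train.Feature of kind bytes_list (hypothetical helper)."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def write_toy_shard(path, n_images=4, size=64):
    """Write n_images random JPEGs with image/image_name/target features."""
    rng = np.random.default_rng(0)
    with tf.io.TFRecordWriter(path) as writer:
        for i in range(n_images):
            pixels = rng.integers(0, 256, size=(size, size, 3), dtype=np.uint8)
            jpeg = tf.io.encode_jpeg(pixels).numpy()
            example = tf.train.Example(features=tf.train.Features(feature={
                "image": _bytes_feature(jpeg),
                "image_name": _bytes_feature(f"toy_{i}".encode()),
                "target": _bytes_feature(str(i % 2).encode()),  # made-up labels
            }))
            writer.write(example.SerializeToString())


# Hypothetical shard name in a temporary directory
toy_path = os.path.join(tempfile.mkdtemp(), "toy00-4.tfrec")
write_toy_shard(toy_path)

# Inspect the first record the same way list_record_features does
first = next(iter(tf.data.TFRecordDataset([toy_path])))
example = tf.train.Example.FromString(first.numpy())
kinds = {k: v.WhichOneof("kind") for k, v in example.features.feature.items()}
print(kinds)
```

Calling list_record_features(toy_path) on this file should then report all three features as bytes_list with one value each.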

In [None]:
def decode_func(image):
    """Decode JPEG bytes into a float32 image in [0, 1] of shape [*IMAGE_SIZE, 3]."""
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.image.resize(image, size=[*IMAGE_SIZE],
                            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    # Fix the static shape so batching and the model see a known size
    image = tf.reshape(image, [*IMAGE_SIZE, 3])
    return image
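
A quick way to check decode_func's steps is to JPEG-encode a random array in memory and push it through the same decode, scale and resize calls. The sketch below copies those steps into a standalone function (decode_jpeg_bytes and TARGET_SIZE are hypothetical names) so it runs on its own; the 300x200 input size is arbitrary.

```python
import numpy as np
import tensorflow as tf

TARGET_SIZE = [256, 256]  # same value as the notebook's IMAGE_SIZE


def decode_jpeg_bytes(jpeg_bytes):
    """Standalone copy of the decode_func steps: decode, scale, resize, reshape."""
    image = tf.image.decode_jpeg(jpeg_bytes, channels=3)
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.image.resize(image, size=TARGET_SIZE,
                            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    return tf.reshape(image, [*TARGET_SIZE, 3])


# A random 300x200 RGB image, JPEG-encoded in memory (no files needed)
pixels = np.random.default_rng(0).integers(0, 256, size=(300, 200, 3), dtype=np.uint8)
jpeg_bytes = tf.io.encode_jpeg(pixels)

decoded = decode_jpeg_bytes(jpeg_bytes)
print(decoded.shape, decoded.dtype)
print(float(tf.reduce_min(decoded)), float(tf.reduce_max(decoded)))
```

Nearest-neighbour resizing copies existing pixel values instead of interpolating between them, so it is cheap and never introduces new intensities; bilinear resizing would give smoother results at a small extra cost.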

In [None]:
# Albumentations pipeline, built once instead of on every call.
# RandomBrightness, RandomContrast and Cutout are deprecated in recent
# Albumentations releases; RandomBrightnessContrast and CoarseDropout
# are their replacements. p=1.0 replaces the deprecated always_apply=True.
TRANSFORMS = A.Compose([
        A.Rotate(limit=40),
        A.RandomBrightnessContrast(),
        A.HueSaturationValue(hue_shift_limit=100,
                             val_shift_limit=0),
        A.CoarseDropout(),
        A.RandomCrop(height=220, width=220),
        A.HorizontalFlip(p=1.0)])


def transform(image, image_size):
    """Customizable augmentation of a single NumPy image using the
    Albumentations library. image_size is passed in through
    tf.numpy_function but is currently unused."""
    # Converting image to float32
    """Albumentations augmentation functions only support
    uint8 and float32 data types, therefore conversion is
    required."""
    image = np.asarray(image, dtype=np.float32)

    # Apply augmentation on the image instance
    """Compose() returns a dictionary of the form
    {'image': image}, and the output keeps the input dtype
    (float32 here). RandomCrop makes training images 220x220,
    smaller than the 256x256 validation images."""
    image = TRANSFORMS(image=image)['image']

    return image.astype(np.float32)

In [None]:
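def augment_func(image, label):
    """Wrap the NumPy-based Albumentations transform in
    tf.numpy_function so it can run inside a tf.data pipeline."""
    image = tf.numpy_function(func=transform,
                              inp=[image, IMAGE_SIZE],
                              Tout=tf.float32)
    # numpy_function loses the static shape; restore rank and channel count
    image.set_shape([None, None, 3])
    return image, label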
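
tf.numpy_function runs arbitrary Python/NumPy code inside a tf.data pipeline, but TensorFlow cannot infer the shape of what that code returns. The sketch below swaps Albumentations for a plain NumPy horizontal flip (flip_np and augment_np are hypothetical names) to show that set_shape restores the rank and channel count in the dataset's element_spec.

```python
import numpy as np
import tensorflow as tf


def flip_np(image):
    """Plain NumPy function: runs eagerly in Python, outside the TF graph."""
    return np.ascontiguousarray(image[:, ::-1, :]).astype(np.float32)


def augment_np(image, label):
    out = tf.numpy_function(func=flip_np, inp=[image], Tout=tf.float32)
    # numpy_function returns a tensor of unknown shape; pin rank and channels
    out.set_shape([None, None, 3])
    return out, label


images = tf.random.uniform([4, 8, 8, 3])
labels = tf.range(4)
ds = tf.data.Dataset.from_tensor_slices((images, labels)).map(augment_np)
print(ds.element_spec)

first_img, first_lbl = next(iter(ds))
```

Without set_shape, element_spec would report an unknown shape, which can break layers that need to know the number of channels when the model is built.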

In [None]:
def read_tfrecord(serialized_example, labelled):
    """Function is fed with one serialized tf.train.Example and returns
    the decoded (not yet augmented) image, plus its label if labelled."""
    # Feature mapping for either labelled or unlabelled datasets
    labelled_map = {"image": tf.io.FixedLenFeature([], tf.string),
                    "image_name": tf.io.FixedLenFeature([], tf.string),
                    "target": tf.io.FixedLenFeature([], tf.string)}
    unlabelled_map = {"image": tf.io.FixedLenFeature([], tf.string),
                      "image_name": tf.io.FixedLenFeature([], tf.string)}
    tfrecord_format = (labelled_map if labelled else unlabelled_map)

    # Parse a single example into a dictionary of tensors
    """parse_single_example takes one serialized example and returns
    a dict keyed by feature name, e.g.
    {'image': image, 'image_name': image_name, 'target': target}"""
    example = tf.io.parse_single_example(serialized=serialized_example,
                                         features=tfrecord_format)

    # Extract the encoded JPEG bytes (a scalar tf.string tensor)
    image = example['image']

    # Decode, scale and resize the image
    image = decode_func(image=image)

    # Return the label together with the image for labelled datasets
    """target is stored as a tf.string. Convert it to integer class
    indices before one-hot encoding, since tf.one_hot requires
    integer indices."""
    if labelled:
        label = example['target']
        return image, label

    return image
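
The sketch below builds one serialized tf.train.Example in memory and parses it with the same labelled feature spec, so the dict returned by parse_single_example can be inspected without any .tfrec files. The image name b"toy_0" and target b"3" are made-up values, and bytes_feature is a hypothetical helper; the final tf.strings.to_number step assumes a target string holding a single integer, which may not match the real shards.

```python
import numpy as np
import tensorflow as tf

# Same feature spec as labelled_map in read_tfrecord
labelled_map = {"image": tf.io.FixedLenFeature([], tf.string),
                "image_name": tf.io.FixedLenFeature([], tf.string),
                "target": tf.io.FixedLenFeature([], tf.string)}


def bytes_feature(value):
    """Wrap raw bytes in a tf.train.Feature (hypothetical helper)."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


pixels = np.zeros((8, 8, 3), dtype=np.uint8)
serialized = tf.train.Example(features=tf.train.Features(feature={
    "image": bytes_feature(tf.io.encode_jpeg(pixels).numpy()),
    "image_name": bytes_feature(b"toy_0"),  # made-up name
    "target": bytes_feature(b"3"),          # made-up label string
})).SerializeToString()

parsed = tf.io.parse_single_example(serialized=serialized, features=labelled_map)
print(sorted(parsed.keys()))
print(parsed["image_name"].numpy(), parsed["target"].numpy())

# A single-integer target string can be turned into a class index for tf.one_hot
label = tf.strings.to_number(parsed["target"], out_type=tf.int32)
print(int(label))
```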

In [None]:
def load_dataset(filenames, labelled=True, ordered=False):
    """Function is fed with a list of (.tfrec) shard filenames
     (e.g. PATH\\train00-1234.tfrec) and outputs a
     tf.data.Dataset of parsed examples."""
    # Initialize dataset options
    options = tf.data.Options()

    if not ordered:
        # Ignore element order to read faster when order doesn't matter
        options.experimental_deterministic = False

    # Automatically interleaves reads from multiple files
    dataset = tf.data.TFRecordDataset(filenames=filenames,
                                      num_parallel_reads=AUTOTUNE)

    # Apply the ordering option (order is kept only if ordered is True)
    dataset = dataset.with_options(options=options)

    # Apply read_tfrecord function to each element of the dataset
    dataset = dataset.map(partial(read_tfrecord, labelled=labelled),
                          num_parallel_calls=AUTOTUNE)

    return dataset
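
load_dataset trades ordering for speed through tf.data.Options. The same trade-off can be seen in isolation with the deterministic argument of Dataset.map, as in the sketch below (square is a hypothetical stand-in for read_tfrecord): with deterministic=False a parallel map may emit elements out of order, but it never drops or duplicates any.

```python
import tensorflow as tf


def square(x):
    """Hypothetical stand-in for a per-element parsing function."""
    return x * x


ds = tf.data.Dataset.range(20)

# deterministic=False lets the parallel map emit elements as soon as they are ready
unordered = ds.map(square, num_parallel_calls=tf.data.AUTOTUNE,
                   deterministic=False)
# deterministic=True preserves the input order
ordered = ds.map(square, num_parallel_calls=tf.data.AUTOTUNE,
                 deterministic=True)

unordered_vals = [int(v) for v in unordered]
ordered_vals = [int(v) for v in ordered]
print(ordered_vals)
print(unordered_vals)
```

This is why get_test_dataset defaults to ordered=True, where predictions must line up with the input order, while training can safely ignore order.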

In [None]:
def get_train_dataset(filenames, labelled=True, ordered=False, 
                      augment=False, label_smoothing=True):
    """Build the training pipeline: load the .tfrec shards with
    load_dataset, optionally one-hot encode and augment, then
    repeat, shuffle, batch and prefetch. Returns a tf.data.Dataset."""
    # Loads dataset from the training set (.tfrec) shard list
    dataset = load_dataset(filenames=filenames, 
                           labelled=labelled, 
                           ordered=ordered)
    
    # One-hot encode labels for label smoothing
    """Set label_smoothing to True only when training with
    CategoricalCrossentropy(label_smoothing=...): tf.keras's
    SparseCategoricalCrossentropy has no label_smoothing argument,
    so integer labels must be one-hot encoded first."""
    if label_smoothing:
        """To find the number of classes inside this function, inspect
        the first 50 examples (examples, not batches, since the dataset
        is not batched yet). Classes missing from this sample are not
        counted, so passing the known number of classes is safer."""
        # Obtaining a subset of the dataset
        subset = dataset.take(50).as_numpy_iterator()
        # Extracting the labels
        label_list = [label for (image, label) in subset]
        # Obtaining the number of unique labels
        classes = len(Counter(label_list).keys())
        # Applying One Hot Encoding to the dataset labels
        """Each dataset element is an (image, label) tuple."""
        def one_hot(image, label):
            return image, tf.one_hot(indices=label, 
                                     depth=classes,
                                     dtype=tf.float32)    
        dataset = dataset.map(map_func=one_hot,
                              num_parallel_calls=AUTOTUNE)
    
    # Apply the Albumentations augmentations (see augment_func)
    if augment==True:
        dataset = dataset.map(map_func=augment_func, 
                              num_parallel_calls=AUTOTUNE)
        
        ##################################################
        ## Insert other forms of advanced augmentations ##
        ##   e.g. Attentive CutMix, CutMix, Cutout,     ##
        ##    MixUp, ShakeDrop, DropBlock, etc.         ##
        ##################################################
        # dataset = dataset.map(map_func=cut_mix,
        #                       num_parallel_calls=AUTOTUNE)

    # Repeat the dataset indefinitely during model training
    dataset = dataset.repeat(count=None)

    # Shuffle using a buffer of 2048 elements (seeded with SEED)
    dataset = dataset.shuffle(buffer_size=2048, seed=SEED)

    # Batch the dataset as per the BATCH_SIZE specified
    dataset = dataset.batch(batch_size=BATCH_SIZE)

    # Prefetch batches so input preparation overlaps with training
    dataset = dataset.prefetch(buffer_size=AUTOTUNE)

    return dataset
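
dataset.shuffle(buffer_size=2048) does not shuffle the whole dataset at once; it keeps a buffer of 2048 elements and samples from it. The pure-Python sketch below (buffered_shuffle is a hypothetical name, not TensorFlow's implementation) follows the same fill, sample and refill idea, which makes it easy to see that a small buffer only mixes nearby elements.

```python
import random


def buffered_shuffle(iterable, buffer_size, seed=None):
    """Yield items in a locally shuffled order using a fixed-size buffer.

    Follows the idea behind tf.data.Dataset.shuffle: fill a buffer, emit a
    random element from it, and refill the freed slot from the input.
    """
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        # Buffer is full: emit a random element and put the new item in its slot
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: drain the remaining buffer in random order
    rng.shuffle(buffer)
    yield from buffer


small = list(buffered_shuffle(range(10), buffer_size=2, seed=42))
large = list(buffered_shuffle(range(10), buffer_size=10, seed=42))
print(small)
print(large)
```

With buffer_size=2, element k can appear no earlier than position k - 1, so the output stays close to the input order; a buffer as large as the dataset gives a full shuffle.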

In [None]:
def get_valid_dataset(filenames, labelled=True, 
                      ordered=False, label_smoothing=True):
    """Build the validation pipeline: load the .tfrec shards with
    load_dataset, optionally one-hot encode, then batch, cache and
    prefetch. Returns a tf.data.Dataset."""
    # Loads dataset from the validation set (.tfrec) shard list
    dataset = load_dataset(filenames=filenames, 
                           labelled=labelled,
                           ordered=ordered)

    # One-hot encode labels for label smoothing
    """Set label_smoothing to True only when training with
    CategoricalCrossentropy(label_smoothing=...): tf.keras's
    SparseCategoricalCrossentropy has no label_smoothing argument,
    so integer labels must be one-hot encoded first."""
    if label_smoothing:
        """To find the number of classes inside this function, inspect
        the first 50 examples (examples, not batches, since the dataset
        is not batched yet). Classes missing from this sample are not
        counted, so passing the known number of classes is safer."""
        # Obtaining a subset of the dataset
        subset = dataset.take(50).as_numpy_iterator()
        # Extracting the labels
        label_list = [label for (image, label) in subset]
        # Obtaining the number of unique labels
        classes = len(Counter(label_list).keys())
        # Applying One Hot Encoding to the dataset labels
        """Each dataset element is an (image, label) tuple."""
        def one_hot(image, label):
            return image, tf.one_hot(indices=label, 
                                     depth=classes,
                                     dtype=tf.float32)          
        dataset = dataset.map(map_func=one_hot,
                              num_parallel_calls=AUTOTUNE)

    # Batching of the dataset as per BATCH_SIZE specified
    dataset = dataset.batch(batch_size=BATCH_SIZE)

    # Cache the batches in memory after the first pass over the data
    dataset = dataset.cache()

    # Prefetch batches so input preparation overlaps with evaluation
    dataset = dataset.prefetch(buffer_size=AUTOTUNE)

    return dataset
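
Once labels are one-hot encoded, CategoricalCrossentropy(label_smoothing=eps) replaces each target vector y with y * (1 - eps) + eps / K, where K is the number of classes. The NumPy sketch below reproduces that arithmetic on two made-up labels (one_hot and smooth_labels are hypothetical helpers) so the effect of eps = 0.1 is visible.

```python
import numpy as np


def one_hot(labels, num_classes):
    """Integer labels -> one-hot rows (the same result tf.one_hot gives for valid indices)."""
    return np.eye(num_classes, dtype=np.float32)[labels]


def smooth_labels(one_hot_labels, label_smoothing):
    """Keras-style label smoothing: y * (1 - eps) + eps / K."""
    num_classes = one_hot_labels.shape[-1]
    return one_hot_labels * (1.0 - label_smoothing) + label_smoothing / num_classes


y = one_hot(np.array([0, 2]), num_classes=4)
y_smooth = smooth_labels(y, label_smoothing=0.1)
print(y_smooth)
```

One caveat for the class-counting step above: tf.one_hot returns an all-zero row for any index outside [0, depth), so if the 50-example sample misses a class, those labels silently become zero vectors (the NumPy indexing in this sketch would raise an IndexError instead).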

In [None]:
def get_test_dataset(filenames, labelled=False, ordered=True):
    """Build the test pipeline: load the .tfrec shards in order with
    load_dataset, then batch and prefetch. Returns a tf.data.Dataset."""
    # Loads dataset from the test set (.tfrec) shard list
    dataset = load_dataset(filenames=filenames, 
                           labelled=labelled,
                           ordered=ordered)

    # Batching of the dataset as per BATCH_SIZE specified
    dataset = dataset.batch(batch_size=BATCH_SIZE)

    # Prefetch batches so input preparation overlaps with inference
    dataset = dataset.prefetch(buffer_size=AUTOTUNE)

    return dataset

In [None]:
def count_data_items(filenames):
    """Obtaining total number of images in dataset from
    the tfrecord shards."""
    n = [int(re.compile(r"-([0-9]*)\.").search(filename).group(1))
         for filename in filenames]
    return np.sum(n)
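
count_data_items relies on each shard name ending in -<count>.tfrec. The sketch below applies the same pattern to hypothetical file names (the counts 552, 553 and 550 are invented); note that the search skips the dashes in the directory name because they are not followed by digits and a dot.

```python
import re

import numpy as np

# Hypothetical shard paths following the "<split><index>-<count>.tfrec" pattern
filenames = [
    "../input/hpa-single-cell-image-classification/train_tfrecords/train00-552.tfrec",
    "../input/hpa-single-cell-image-classification/train_tfrecords/train01-553.tfrec",
    "../input/hpa-single-cell-image-classification/train_tfrecords/train02-550.tfrec",
]

# "-", then one or more digits, then a literal "."
count_pattern = re.compile(r"-([0-9]+)\.")
counts = [int(count_pattern.search(name).group(1)) for name in filenames]
total = int(np.sum(counts))
print(counts, total)  # [552, 553, 550] 1655
```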

## Before Augmentation

In [None]:
%%time

row = 6; col = 5

all_elements = get_train_dataset(FILENAMES,
                                 augment=False,
                                 label_smoothing=False)

# Display the first batch only
for (image, label) in all_elements.take(1):
    plt.figure(figsize=(15, int(15*row/col)))
    for j in range(row*col):
        plt.subplot(row, col, j+1)
        plt.axis('off')
        plt.imshow(image[j,])

    plt.tight_layout()
    plt.show()
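
The plotting loop above is repeated for the augmented batch below, so it can be factored into a small helper. show_batch is a hypothetical name; the sketch draws a random array here, but inside the notebook it would be called on a batch from get_train_dataset, e.g. inside `for image, label in all_elements.take(1):` call `show_batch(image)` followed by `plt.show()`.

```python
import matplotlib.pyplot as plt
import numpy as np


def show_batch(images, row=6, col=5):
    """Plot the first row*col images of a batch in a grid (hypothetical helper)."""
    fig = plt.figure(figsize=(15, int(15 * row / col)))
    for j in range(min(row * col, len(images))):
        ax = fig.add_subplot(row, col, j + 1)
        ax.axis("off")
        ax.imshow(images[j])
    fig.tight_layout()
    return fig


# Random stand-in for a batch of 8 RGB images with values in [0, 1)
fake_batch = np.random.default_rng(0).random((8, 32, 32, 3))
fig = show_batch(fake_batch, row=2, col=4)
print(len(fig.axes))
```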

## After Augmentation

*Note that the sample augmentations were exaggerated to accentuate the changes done to the images.*

In [None]:
%%time

row = 6; col = 5

all_elements = get_train_dataset(FILENAMES,
                                 augment=True,
                                 label_smoothing=False)

# Display the first batch only
for (image, label) in all_elements.take(1):
    plt.figure(figsize=(15, int(15*row/col)))
    for j in range(row*col):
        plt.subplot(row, col, j+1)
        plt.axis('off')
        plt.imshow(image[j,])

    plt.tight_layout()
    plt.show()


### References

How to train a Keras model on TFRecord files: 
https://keras.io/examples/keras_recipes/tfrecord/

Using Albumentations with Tensorflow:
https://albumentations.ai/docs/examples/tensorflow-example/