This is the Jupyter notebook containing the code related to the blog post "Optimizing a TensorFlow Input Pipeline: Best Practices in 2022" by Stefano Martire.

https://medium.com/@virtualmartire/optimizing-a-tensorflow-input-pipeline-best-practices-in-2022-4ade92ef8736

# Tools

In [None]:
import time
import tensorflow as tf
import numpy as np

#
##
###
##
#

class ArtificialDataset(tf.data.Dataset):
    """Synthetic data source that imitates reading samples from a slow file.

    Calling ``ArtificialDataset(num_samples)`` does not produce an instance of
    this class: ``__new__`` returns the dataset built by
    ``tf.data.Dataset.from_generator`` around :meth:`generator`.
    """

    def __new__(cls, num_samples):
        """Build the generator-backed dataset.

        Args:
            num_samples: number of scalar samples the generator will yield.

        Returns:
            The dataset created by ``tf.data.Dataset.from_generator``, whose
            elements are scalar ``tf.int64`` values.
        """
        # ``__new__`` receives the class itself, hence ``cls`` (not ``self``).
        return tf.data.Dataset.from_generator(
            cls.generator,
            output_signature=tf.TensorSpec(shape=(), dtype=tf.int64),
            args=[num_samples],
        )

    @staticmethod
    def generator(num_samples):
        """Yield the indices ``0 .. num_samples - 1`` with artificial delays.

        A single 0.03 s sleep stands in for opening the file, and a 0.015 s
        sleep before every yielded index stands in for reading one record.
        """
        # Opening the file
        time.sleep(0.03)
        for sample in range(num_samples):
            # Reading data from the file
            time.sleep(0.015)
            yield sample

def TrainOneTime(dataset, num_epochs=10):
    """Simulate a training run over ``dataset`` and time it.

    The dataset is iterated ``num_epochs`` times; every sample costs a
    constant 0.01 s sleep that stands in for one training step.

    Args:
        dataset: any iterable that can be iterated once per epoch.
        num_epochs: how many passes over ``dataset`` to perform.

    Returns:
        Elapsed time in seconds, measured with ``time.perf_counter``.
    """
    began_at = time.perf_counter()
    for _epoch in range(num_epochs):
        for _sample in dataset:
            # One (CONSTANT!) training step.
            time.sleep(0.01)
    finished_at = time.perf_counter()
    return finished_at - began_at

def benchmark(datasource, num_experiments=10):
    """Run the simulated training several times and print timing statistics.

    Args:
        datasource: zero-argument callable returning a fresh instance of the
            dataset to benchmark; it is called once per experiment.
        num_experiments: how many times the experiment is repeated.

    Prints the mean and standard deviation of the measured times; returns
    nothing.
    """
    durations = [TrainOneTime(datasource()) for _ in range(num_experiments)]
    mean_time = np.mean(durations)
    std_time = np.std(durations)
    print("Execution time (mean, std):", mean_time, std_time)

#
##
###
##
#

def fast_TrainOneTime(dataset, num_epochs=10):
    """Iterate ``dataset`` for ``num_epochs`` epochs without any training step.

    Unlike ``TrainOneTime`` there is no per-sample sleep, so the measured time
    is only the cost of pulling the samples out of ``dataset``.

    Returns:
        Elapsed time in seconds, measured with ``time.perf_counter``.
    """
    began_at = time.perf_counter()
    for _epoch in range(num_epochs):
        for _sample in dataset:
            continue
    finished_at = time.perf_counter()
    return finished_at - began_at

def fast_benchmark(datasource, num_experiments=10):
    """Same as ``benchmark`` but timed with ``fast_TrainOneTime``.

    Args:
        datasource: zero-argument callable returning a fresh instance of the
            dataset to benchmark; it is called once per experiment.
        num_experiments: how many times the experiment is repeated.

    Prints the mean and standard deviation of the measured times; returns
    nothing.
    """
    durations = [fast_TrainOneTime(datasource()) for _ in range(num_experiments)]
    mean_time = np.mean(durations)
    std_time = np.std(durations)
    print("Execution time (mean, std):", mean_time, std_time)

# Naive approach

In [None]:
# Baseline: the raw 60-sample dataset with no pipeline transformation.
benchmark(lambda: ArtificialDataset(60))

Execution time (mean, std): 20.352092805800005 1.3831828740299288


# Prefetch

In [None]:
# Prefetch with a hand-picked buffer size of 30.
benchmark(lambda: ArtificialDataset(60).prefetch(30))

Execution time (mean, std): 11.453088595199997 0.33055232552504404


In [None]:
# Prefetch again, this time leaving the buffer size to tf.data.AUTOTUNE.
benchmark(lambda: ArtificialDataset(60).prefetch(tf.data.AUTOTUNE))

Execution time (mean, std): 11.381429644100006 0.24948351336565694


# Shuffle

In [None]:
# Shuffle only (buffer_size=20), without any prefetching.
benchmark(lambda: ArtificialDataset(60).shuffle(buffer_size=20))

Execution time (mean, std): 21.013408302 0.9630890709031646


## Prefetch and shuffle

In [None]:
# Order under test: shuffle first, then prefetch.
benchmark(
    lambda: (
        ArtificialDataset(60)
        .shuffle(buffer_size=20)
        .prefetch(tf.data.AUTOTUNE)
    )
)

Execution time (mean, std): 12.326305236699955 0.12620411556605377


In [None]:
# Order under test: prefetch first, then shuffle.
benchmark(
    lambda: (
        ArtificialDataset(60)
        .prefetch(tf.data.AUTOTUNE)
        .shuffle(buffer_size=20)
    )
)

Execution time (mean, std): 12.392917123399911 0.1788876373334663


# Sampling from more than one dataset

## Interleave

In [None]:
# Interleave two 30-sample datasets (cycle_length=2) without parallel calls.
benchmark(
    lambda: tf.data.Dataset.range(2).interleave(
        lambda _: ArtificialDataset(30),
        cycle_length=2,
    )
)

Execution time (mean, std): 20.842603973900076 1.8259570877550806


In [None]:
# Same interleave, now with num_parallel_calls left to tf.data.AUTOTUNE.
benchmark(
    lambda: tf.data.Dataset.range(2).interleave(
        lambda _: ArtificialDataset(30),
        cycle_length=2,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
)

Execution time (mean, std): 10.262414164099937 0.9066938742174154


## Sample from datasets

In [None]:
# Sample from two 30-sample datasets with weights 0.4 and 0.6.
datasource = lambda: tf.data.Dataset.sample_from_datasets(
    [ArtificialDataset(30), ArtificialDataset(30)],
    weights=[0.4, 0.6],
)

benchmark(datasource)

Execution time (mean, std): 20.453822995300108 1.1829536582771094


In [None]:
# Same weighted sampling from two datasets, followed by an autotuned prefetch.
datasource = lambda: (
    tf.data.Dataset.sample_from_datasets(
        [ArtificialDataset(30), ArtificialDataset(30)],
        weights=[0.4, 0.6],
    )
    .prefetch(tf.data.AUTOTUNE)
)

benchmark(datasource)

Execution time (mean, std): 11.693418666100115 0.2995515966723602


# Map

In [None]:
def mapped_function(sample):
    """Return ``sample`` unchanged after simulating some hard preprocessing.

    The simulated work is a 0.03 s sleep wrapped in ``tf.py_function`` with no
    inputs and no outputs.
    """
    def _pretend_to_work():
        time.sleep(0.03)

    tf.py_function(_pretend_to_work, [], ())
    return sample

In [None]:
# Apply the slow mapped function without any parallelism options.
benchmark(lambda: ArtificialDataset(60).map(mapped_function))

Execution time (mean, std): 42.89547191830006 2.56696075936222


In [None]:
# Same map, with num_parallel_calls left to tf.data.AUTOTUNE.
benchmark(
    lambda: ArtificialDataset(60).map(
        mapped_function,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
)

Execution time (mean, std): 21.646979339800055 2.196655362368194


In [None]:
# Parallel map again, additionally passing deterministic=False.
benchmark(
    lambda: ArtificialDataset(60).map(
        mapped_function,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=False,
    )
)

Execution time (mean, std): 19.910414557800323 2.06004657360212


## Vectorizing mapped function

In [None]:
def increment(x):
    """Return ``x + 1``."""
    incremented = x + 1
    return incremented

In [None]:
# Map element by element first, then batch.
fast_benchmark(
    lambda: (
        tf.data.Dataset.range(10000)
        .map(increment)
        .batch(32)
    )
)

Execution time (mean, std): 3.1346440906999873 0.6096452887122465


In [None]:
# Batch first, then map once per batch.
fast_benchmark(
    lambda: (
        tf.data.Dataset.range(10000)
        .batch(32)
        # tf.Tensor.__add__ already handles batches
        .map(increment)
    )
)

Execution time (mean, std): 0.7280943599999887 0.057596312125470554


# Cache

In [None]:
# Cache placed after the slow map.
benchmark(
    lambda: (
        ArtificialDataset(60)
        .map(mapped_function)
        .cache()
    )
)

Execution time (mean, std): 13.092854013700002 0.9728409355802924


In [None]:
# Cache directly on the source dataset, with no map.
benchmark(lambda: ArtificialDataset(60).cache())

Execution time (mean, std): 11.065014999600317 1.2006043972694804


# Parallel batch

In [None]:
# Plain batching into groups of 6.
benchmark(lambda: ArtificialDataset(60).batch(6))

Execution time (mean, std): 11.719768421999834 0.33159636527296316


In [None]:
# Batching with num_parallel_calls left to tf.data.AUTOTUNE.
benchmark(
    lambda: ArtificialDataset(60).batch(
        6,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
)

Execution time (mean, std): 10.106076772100005 0.1897054209484061


In [None]:
# Plain batching followed by an autotuned prefetch.
benchmark(
    lambda: (
        ArtificialDataset(60)
        .batch(6)
        .prefetch(tf.data.AUTOTUNE)
    )
)

Execution time (mean, std): 10.090009428100052 0.1256646476816818


In [None]:
# Parallel batching and an autotuned prefetch combined.
benchmark(
    lambda: (
        ArtificialDataset(60)
        .batch(6, num_parallel_calls=tf.data.AUTOTUNE)
        .prefetch(tf.data.AUTOTUNE)
    )
)

Execution time (mean, std): 10.543556925000019 0.3141512506595382
