In [1]:
import tensorflow as tf
import numpy as np
import os
import time

In [2]:
tf.__version__

'2.5.0'

### Prefetching

In [3]:
class Dataset(tf.data.Dataset):
    """Artificial, generator-backed dataset with built-in delays for benchmarking.

    Calling ``Dataset(num_samples)`` does not create an instance of this class:
    ``__new__`` returns the dataset built by ``tf.data.Dataset.from_generator``.
    """

    # Declared static because the generator takes neither ``self`` nor ``cls``
    # and is only ever looked up on the class itself.
    @staticmethod
    def read_file_in_batches(num_samples):
        """Yield the one-element tuples ``(0,)``, ``(1,)``, ... ``(num_samples - 1,)``.

        Sleeps 0.03 s once when the generator starts and a further 0.015 s
        before every sample (presumably standing in for opening a file and
        reading one record from it).

        Args:
            num_samples: number of samples to yield.
        """
        # One-off start-up delay, paid every time the generator is restarted.
        time.sleep(0.03)

        for sample_idx in range(num_samples):
            # Per-sample delay.
            time.sleep(0.015)
            yield (sample_idx,)

    def __new__(cls, num_samples=3):
        """Build the generator-backed dataset.

        Args:
            num_samples: number of samples produced per pass (default 3).

        Returns:
            The dataset created by ``tf.data.Dataset.from_generator``; every
            element is declared as an int64 tensor of shape ``(1,)``.
        """
        return tf.data.Dataset.from_generator(
            cls.read_file_in_batches,
            output_signature=tf.TensorSpec(shape=(1,), dtype=tf.int64),
            args=(num_samples,),
        )

In [4]:
def benchmark(dataset, num_epochs=2, step_seconds=0.01):
    """Iterate over ``dataset`` for several epochs, sleeping after each sample.

    Args:
        dataset: any object that can be iterated once per epoch.
        num_epochs: number of full passes over ``dataset`` (default 2).
        step_seconds: seconds slept for every sample, simulating the work a
            consumer does per element. The default of 0.01 is the value that
            used to be hard-coded, so existing calls behave as before.

    Returns:
        None. The function exists only for the time it takes to run.
    """
    for _ in range(num_epochs):
        for _sample in dataset:
            time.sleep(step_seconds)

In [5]:
%%timeit
# Baseline: benchmark's default number of epochs over the generator-backed dataset, no prefetching.
benchmark(Dataset())

352 ms ± 14.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
%%timeit
# Same run with prefetch(1): up to one element is buffered ahead of the consumer.
benchmark(Dataset().prefetch(1))

296 ms ± 22.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
%%timeit
# Same run, but tf.data.AUTOTUNE lets tf.data tune the prefetch buffer size at runtime.
benchmark(Dataset().prefetch(tf.data.AUTOTUNE))

299 ms ± 18.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Caching

In [8]:
# Build the whole pipeline in one chained expression: 0..4 -> squared -> file-backed cache.
# NOTE: per the tf.data docs a file-backed cache persists across runs, so changes made
# upstream of .cache() have no effect until the cache files are removed or renamed.
dataset = (
    tf.data.Dataset.range(5)
    .map(lambda value: value**2)
    .cache("mycache.txt")
)

# Materialise the elements as NumPy values; as the cell's last expression the list is displayed.
list(dataset.as_numpy_iterator())

[0, 1, 4, 9, 16]

In [13]:
def mapped_function(s):
    """Return ``s`` unchanged after invoking a Python-side 0.3 s sleep.

    The sleep is wrapped in ``tf.py_function`` with no inputs and no outputs,
    so the call contributes nothing to the returned value.

    Args:
        s: the element handed in by ``Dataset.map``.

    Returns:
        ``s`` itself, unmodified.
    """

    def _pause():
        # Plain Python sleep; returns None, matching the empty ``Tout``.
        time.sleep(0.3)

    tf.py_function(func=_pause, inp=[], Tout=())
    return s

In [14]:
%%timeit -r1 -n1
# Single timed run (-r1 -n1): five epochs, with mapped_function applied to every element.
benchmark(Dataset().map(mapped_function), 5)

6.23 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [15]:
%%timeit -r1 -n1
# Same five-epoch run, plus cache() with no filename: per the tf.data docs the elements are
# kept in memory after the first pass, so later epochs do not re-run the map.
benchmark(Dataset().map(mapped_function).cache(), 5)

1.85 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
