The `prefetch` transformation lets the CPU and the GPU work at the same time, which reduces the total time needed to train the model.

`tf.data.Dataset.prefetch(tf.data.AUTOTUNE)`: instead of using a fixed buffer size, TensorFlow tunes at runtime how many elements (batches, if the dataset is batched) to fetch ahead.

While the GPU trains on the current element, the CPU is already reading and preparing the next ones.
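The overlap described above can be illustrated in plain Python without TensorFlow. In this minimal sketch, a background thread stands in for the CPU and reads elements into a bounded queue. The main loop stands in for the GPU and consumes them. The names `prefetch`, `slow_reader` and `train`, and the sleep durations, are illustrative assumptions. This is an analogy for what a prefetch buffer does, not how tf.data implements it internally.

```python
import queue
import threading
import time

_END = object()  # sentinel marking the end of the input


def prefetch(iterable, buffer_size):
    """Yield items from `iterable`, reading up to `buffer_size` items ahead
    in a background thread (a plain-Python analogy of Dataset.prefetch)."""
    buf = queue.Queue(maxsize=buffer_size)

    def producer():
        for item in iterable:
            buf.put(item)  # blocks while the buffer is full
        buf.put(_END)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buf.get()  # blocks until the producer has read something
        if item is _END:
            return
        yield item


def slow_reader(num_samples):
    time.sleep(0.03)  # simulated file open
    for i in range(num_samples):
        time.sleep(0.015)  # simulated read of one sample (CPU)
        yield i


def train(items):
    seen = []
    for item in items:
        time.sleep(0.01)  # simulated training step (GPU)
        seen.append(item)
    return seen


start = time.perf_counter()
sequential = train(slow_reader(3))
t_seq = time.perf_counter() - start

start = time.perf_counter()
overlapped = train(prefetch(slow_reader(3), buffer_size=2))
t_pre = time.perf_counter() - start

print(sequential, overlapped)
print(f"sequential {t_seq * 1000:.0f} ms, prefetched {t_pre * 1000:.0f} ms")
```

Both runs yield the same elements in the same order. Only the timing changes, because the next read happens while the current training step sleeps.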

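`tf.data.Dataset.cache()`: during the first epoch the pipeline opens and reads the files and applies the `map` functions as usual, and the resulting elements are stored (in memory by default, or in a file if a filename is passed). From the second epoch on, the elements come from the cache, so only the training step runs again.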

In [1]:
import tensorflow as tf
import time

In [18]:
# measuring the performance of prefetch
# filedataset is a dummy dataset that mimics reading samples from a file

class filedataset(tf.data.Dataset):
    @staticmethod
    def read_files_in_batches(num_samples):
        # simulate opening the file
        time.sleep(0.03)
        for sample_idx in range(num_samples):
            time.sleep(0.015) # simulate reading one sample (CPU work)
            yield (sample_idx,)

    def __new__(cls, num_samples=3):
        # __new__ returns a plain Dataset built from the generator
        return tf.data.Dataset.from_generator(
            cls.read_files_in_batches, # generator that produces the samples
            output_signature=tf.TensorSpec(shape=(1,), dtype=tf.int64), # each element is a 1-element int64 tensor
            args=(num_samples,)
        )
        

In [25]:
def benchmark(dataset, num_epochs=4):
    for epoch in range(num_epochs):
        for sample in dataset:
            time.sleep(0.01) # simulate one training step (GPU work)

Here reading and training run sequentially: each training step waits for its read, and the next read waits for the training step.

In [28]:
%%timeit

benchmark(filedataset()) # iterate over the whole dataset for every epoch and time it

717 ms ± 25.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
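As a back-of-the-envelope check, the sleeps in `filedataset` and `benchmark` set a lower bound on each run. When run sequentially, every epoch pays for the open, every read, and every training step. With a prefetch buffer of at least one element, reading and training overlap. Once the first sample has been read, each further step costs only the slower of the two. `benchmark` iterates the dataset anew in each epoch, so the sketch assumes no overlap across epochs. The helper names are hypothetical. The bound ignores all TensorFlow overhead, which is why the measured 717 ms is well above it.

```python
def sequential_epoch(num_samples, open_s, read_s, train_s):
    # open once, then read and train every sample one after the other
    return open_s + num_samples * (read_s + train_s)


def prefetched_epoch(num_samples, open_s, read_s, train_s):
    # after the first read, the slower of read/train paces the pipeline;
    # the last training step cannot overlap with any further read
    return open_s + read_s + (num_samples - 1) * max(read_s, train_s) + train_s


# values taken from the sleeps in filedataset() and benchmark()
params = dict(num_samples=3, open_s=0.03, read_s=0.015, train_s=0.01)
epochs = 4  # benchmark's default num_epochs

seq = epochs * sequential_epoch(**params)
pre = epochs * prefetched_epoch(**params)
print(f"sequential >= {seq:.3f} s, prefetched >= {pre:.3f} s")
# prints: sequential >= 0.420 s, prefetched >= 0.340 s
```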


Now we add `prefetch` to see how much it improves performance.

In [31]:
%%timeit

benchmark(filedataset().prefetch(3)) # we can see it takes less time

633 ms ± 22.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [32]:
%%timeit

benchmark(filedataset().prefetch(tf.data.AUTOTUNE)) # AUTOTUNE is the common choice: tf.data picks the buffer size at runtime

647 ms ± 22 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### cache API
`cache()` runs the open, read and map steps only during the first epoch; later epochs read the stored elements.
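The idea behind `cache()` can be sketched in plain Python. The first complete pass runs the expensive steps and stores every element, and later passes replay the stored elements. The `CachedPipeline` class and the `expensive_map` function are hypothetical illustrations, not part of tf.data. The sketch keeps the cache only after a full pass. This is an assumption chosen to resemble how tf.data discards a partially filled in-memory cache.

```python
class CachedPipeline:
    """Replays the elements produced by `make_iter()` after the first full pass."""

    def __init__(self, make_iter):
        self._make_iter = make_iter  # builds a fresh open/read/map iterator
        self._cache = None  # filled only after one complete pass

    def __iter__(self):
        if self._cache is not None:
            yield from self._cache  # later epochs: no open/read/map work
            return
        filled = []
        for item in self._make_iter():
            filled.append(item)
            yield item
        self._cache = filled  # commit only once the pass has completed


map_calls = 0


def expensive_map(x):
    global map_calls
    map_calls += 1  # count how often the "map" step actually runs
    return x ** 2


pipeline = CachedPipeline(lambda: (expensive_map(x) for x in range(5)))
epochs = [list(pipeline) for _ in range(3)]
print(epochs[0], map_calls)  # map ran 5 times in total, not 15
```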

In [34]:
dataset = tf.data.Dataset.range(5)
for d in dataset:
    print(d.numpy())

0
1
2
3
4


In [35]:
dataset = dataset.map(lambda x: x**2)
for d in dataset:
    print(d.numpy())

0
1
4
9
16


In [37]:
dataset = dataset.cache()

list(dataset.as_numpy_iterator()) # the first full iteration fills the cache; later iterations read from it

[0, 1, 4, 9, 16]

In [38]:
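def mapped_function(s):
    # simulate an expensive preprocessing step (0.03 s per element);
    # map traces this function into a graph, so a bare time.sleep would run
    # only once at trace time; tf.py_function runs it for every element
    tf.py_function(lambda: time.sleep(0.03), [], ())
    return s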

In [39]:
%%timeit -n1 -r1

benchmark(filedataset().map(mapped_function))

1.25 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [41]:
%%timeit -n1 -r1

benchmark(filedataset().map(mapped_function).cache())  # map runs only in the first epoch; later epochs read from the cache

488 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
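The same kind of lower-bound arithmetic explains the gap. Without `cache()`, each of the 4 epochs pays for the open, the reads, the 0.03 s `mapped_function` sleep per sample, and training. With `cache()` after `map`, only the first epoch pays all of that, and the remaining epochs pay for the training step alone. The helper name is hypothetical. The bound ignores TensorFlow overhead, so the measured times above are higher.

```python
def epoch_cost(num_samples, open_s, read_s, map_s, train_s):
    # one uncached epoch: open once, then read, map and train each sample
    return open_s + num_samples * (read_s + map_s + train_s)


# values taken from the sleeps in filedataset(), mapped_function() and benchmark()
n, open_s, read_s, map_s, train_s = 3, 0.03, 0.015, 0.03, 0.01
epochs = 4

no_cache = epochs * epoch_cost(n, open_s, read_s, map_s, train_s)
# first epoch fills the cache; the remaining epochs only run the training step
with_cache = epoch_cost(n, open_s, read_s, map_s, train_s) + (epochs - 1) * n * train_s
print(f"no cache >= {no_cache:.3f} s, cached >= {with_cache:.3f} s")
```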
