In [13]:
import tensorflow as tf

In [14]:
# Display the version of the imported TensorFlow package.
tf.__version__

'2.10.0'

In [15]:
class FileDataset(tf.data.Dataset):
    """Toy dataset that imitates reading samples from a file, using artificial delays.

    Calling ``FileDataset(num_samples)`` does not produce a ``FileDataset``
    instance: ``__new__`` returns whatever ``tf.data.Dataset.from_generator``
    builds from the generator below.

    The generator resolves the module-level name ``time`` when it runs, so
    ``time`` must have been imported before the dataset is first iterated.
    """

    def read_files_in_batches(num_samples):
        """Yield ``(0,)``, ``(1,)``, ... one at a time, ``num_samples`` times in total.

        Deliberately takes no ``self``/``cls``: it is only looked up on the
        class and handed to ``from_generator`` as a plain function.
        """
        # Simulated cost of opening the file (paid once per iteration of the dataset).
        time.sleep(0.03)
        for idx in range(num_samples):
            # Simulated cost of reading one sample from the file.
            time.sleep(0.015)
            # A generator produces one value at a time instead of holding
            # every sample in memory at once.
            yield (idx,)

    def __new__(cls, num_samples=3):
        """Build the generator-backed dataset for ``num_samples`` samples (default 3)."""
        # Each yielded 1-tuple is described as an int64 tensor of shape (1,).
        element_spec = tf.TensorSpec(shape=(1,), dtype=tf.int64)
        return tf.data.Dataset.from_generator(
            cls.read_files_in_batches,
            args=(num_samples,),
            output_signature=element_spec,
        )

In [16]:
def benchmark(dataset, num_epochs=2):
    """Iterate over ``dataset`` ``num_epochs`` times, sleeping briefly per element.

    The sleep stands in for the time one training step would take.

    Args:
        dataset: Any object that can be iterated repeatedly (one pass per epoch).
        num_epochs: Number of full passes to make over ``dataset``. Defaults to 2.

    Returns:
        None. The function is only useful for being timed.
    """
    train_step_seconds = 0.01  # simulated training time per sample
    for _ in range(num_epochs):
        for _sample in dataset:
            time.sleep(train_step_seconds)

In [17]:
import time
import timeit

In [18]:
%%timeit
# Baseline: default FileDataset (3 samples) iterated for the default 2 epochs, with no extra pipeline steps.
benchmark(FileDataset())

403 ms ± 12.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


<h3>Prefetch API</h3>

In [19]:
%%timeit
# Same benchmark as the baseline, with prefetch(1) appended to the pipeline.
benchmark(FileDataset().prefetch(1))

377 ms ± 14.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [21]:
%%timeit
# Same benchmark, passing tf.data.AUTOTUNE to prefetch() instead of a fixed value.
benchmark(FileDataset().prefetch(tf.data.AUTOTUNE))

372 ms ± 20.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


<h3>Cache API</h3>

In [24]:
# Dataset of the integers 0 through 4.
dataset = tf.data.Dataset.range(5)

# Print each element as a plain value via .numpy().
for d in dataset:
    print(d.numpy())

0
1
2
3
4


In [25]:
# Square every element.
# NOTE: this rebinds `dataset`, so re-running this cell squares the already-squared values again.
dataset = dataset.map(lambda x: x**2)

list(dataset.as_numpy_iterator())

[0, 1, 4, 9, 16]

In [27]:
dataset = dataset.cache()
# Without cache(), every call to dataset.as_numpy_iterator() has to re-apply the map function to square each number again.
list(dataset.as_numpy_iterator())

[0, 1, 4, 9, 16]

In [30]:
# applying map function to our original FileDataset
def mapped_function(s):
    """Return ``s`` unchanged after triggering an artificial 0.03 s delay.

    The delay is a ``time.sleep`` wrapped in ``tf.py_function`` with no inputs
    and no outputs; its result is not used.
    """
    def _simulate_work():
        # Stand-in for an expensive per-element transformation.
        time.sleep(0.03)

    tf.py_function(func=_simulate_work, inp=[], Tout=())
    return s

In [31]:
%%timeit
# FileDataset with the delayed map function applied, benchmarked for 5 epochs (no caching).
benchmark(FileDataset().map(mapped_function),5)

1.72 s ± 38.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [34]:
%%timeit
benchmark(FileDataset().map(mapped_function).cache(),5)
# With cache() added, the benchmark takes less than half the time of the uncached run.
# The mapping is performed only during the first epoch; later epochs reuse the cached elements instead of re-running the map, which saves time.

626 ms ± 28.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
