# Data Loading & Preprocessing for Deep Learning

Training a Deep Learning model on a very large dataset involves loading the data into memory (RAM) and preprocessing it (e.g., type casting, scaling, and augmentation). The problem is that the large datasets used to train Deep Learning models usually do not fit into RAM. Thus, we need to dynamically read batches of data and preprocess them on the fly.


An efficient tool for dynamically loading and preprocessing data is **TensorFlow's Data API (tf.data)**, which works with tf.keras. The tf.data API enables us to build complex input pipelines that aggregate data from multiple files stored on disk, perform per-element transformations (e.g., normalization) and data augmentation (e.g., resizing, rotation, zoom), create batches of data, and so on.

The Data API can read from text files such as CSV files (tf.data.TextLineDataset), binary files with fixed-size records (tf.data.FixedLengthRecordDataset), and binary files that use TensorFlow’s TFRecord format, which supports records of varying sizes (tf.data.TFRecordDataset).

The tf.data API introduces a **tf.data.Dataset** abstraction that represents a sequence of elements, in which each element consists of one or more components. For example, in an image pipeline, an element might be a single training example, with a pair of tensor components representing the image and its label. Because elements are produced lazily as the Dataset is iterated over, a Dataset can represent a very large set of elements.
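To make the idea of a multi-component element concrete, here is a minimal sketch assuming TensorFlow 2.x. The "images" are hypothetical all-zero tensors and the labels are made up; the point is only that slicing a tuple `(images, labels)` yields elements that are `(image, label)` pairs.

```python
import tensorflow as tf

# Hypothetical toy "images" (4 RGB images of 28x28 pixels) and their labels.
images = tf.zeros([4, 28, 28, 3], dtype=tf.float32)
labels = tf.constant([0, 1, 2, 1])

# Slicing happens along the first axis, so each element is an (image, label) pair.
pairs = tf.data.Dataset.from_tensor_slices((images, labels))

# The element specification describes the two components of every element.
print(pairs.element_spec)

for image, label in pairs.take(2):
    print(image.shape, label.numpy())
```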



## Dataset Object for Image Recognition

In practical Deep Learning experiments, data is typically stored on disk. For example, in object recognition problems the images are stored locally on disk. First, we need to **construct** a Dataset object from the local image repository. Then, the Dataset object needs to be **transformed** to load image-label pairs. This involves decoding each encoded image (e.g., a PNG- or JPEG-encoded file) into a Tensor object, type-casting it (e.g., to float32), scaling it, and getting the image label from the stored images (typically from the nested structure of the image directories). Finally, the images have to be put into batches for training the model. These steps are described in a later notebook.  
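As a compressed preview, the following sketch runs those steps end to end on a tiny synthetic image repository that it creates itself. It assumes TensorFlow 2.x; the directory layout `<root>/<class_name>/<file>.png`, the class names, and the helper `load_example` are hypothetical, chosen only for illustration.

```python
import os
import tempfile

import tensorflow as tf

# A tiny, hypothetical image repository laid out as <root>/<class_name>/<file>.png
root = tempfile.mkdtemp()
class_names = ["cat", "dog"]
for label, name in enumerate(class_names):
    os.makedirs(os.path.join(root, name))
    for i in range(2):
        # Solid-color 8x8 RGB placeholder images, encoded as PNG.
        pixels = tf.fill([8, 8, 3], tf.constant(100 * label, dtype=tf.uint8))
        tf.io.write_file(os.path.join(root, name, f"img{i}.png"),
                         tf.io.encode_png(pixels))


def load_example(path):
    """Hypothetical helper: turn a file path into an (image, label) pair."""
    # Decode the encoded PNG bytes into a uint8 tensor of shape (H, W, 3).
    image = tf.io.decode_png(tf.io.read_file(path), channels=3)
    # Cast to float32 and scale pixel values from [0, 255] to [0, 1].
    image = tf.image.convert_image_dtype(image, tf.float32)
    # The label is the name of the parent directory, mapped to its index.
    class_dir = tf.strings.split(path, os.sep)[-2]
    matches = tf.cast(tf.equal(class_dir, tf.constant(class_names)), tf.int32)
    return image, tf.argmax(matches)


# Construct: one element per file path (unshuffled for a reproducible demo).
paths = tf.data.Dataset.list_files(os.path.join(root, "*", "*.png"), shuffle=False)
# Transform: load (image, label) pairs, then batch them for training.
image_ds = paths.map(load_example, num_parallel_calls=tf.data.AUTOTUNE).batch(2)

for images, labels in image_ds:
    print(images.shape, labels.numpy())
```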

In this notebook, we present Dataset techniques for loading and preprocessing a simple data artifact (e.g., a Python list or a small tensor). Specifically, we describe two steps:
- Constructing a Dataset object
- Transforming a Dataset object


## Constructing a Dataset

To construct a Dataset object from data in memory (or, in the case of list_files(), from files on disk), we may use the following methods.

- tf.data.Dataset.from_tensors(): constructs a Dataset with a single element, comprising the given tensors.

- tf.data.Dataset.from_tensor_slices(): constructs a Dataset whose elements are slices of the given tensors.

- tf.data.Dataset.list_files(): constructs a Dataset from input files matching one or more glob patterns.

Alternatively, if the input data is stored in files that use the recommended TFRecord format, we may use the tf.data.TFRecordDataset class.
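The constructors above can be contrasted in a short sketch, assuming TensorFlow 2.x. The temporary TFRecord file and its two raw byte records are hypothetical, created only so that TFRecordDataset and list_files have something to read (real TFRecord files usually hold serialized tf.train.Example protos).

```python
import os
import tempfile

import tensorflow as tf

# from_tensors(): the whole input becomes a single element.
whole = tf.data.Dataset.from_tensors([1, 2, 3])
# from_tensor_slices(): the input is sliced along its first dimension.
sliced = tf.data.Dataset.from_tensor_slices([1, 2, 3])

print(list(whole.as_numpy_iterator()))   # one element holding all three values
print(list(sliced.as_numpy_iterator()))  # three scalar elements

# Write a tiny TFRecord file (hypothetical path and contents) and read it back.
tmp_dir = tempfile.mkdtemp()
record_path = os.path.join(tmp_dir, "toy.tfrecord")
with tf.io.TFRecordWriter(record_path) as writer:
    for record in [b"first", b"second"]:
        writer.write(record)
records = tf.data.TFRecordDataset(record_path)
print(list(records.as_numpy_iterator()))

# list_files() yields the file names that match a glob pattern.
# shuffle=False keeps the order deterministic (it shuffles by default).
files = tf.data.Dataset.list_files(os.path.join(tmp_dir, "*.tfrecord"), shuffle=False)
print(list(files.as_numpy_iterator()))
```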





## Transforming a Dataset 


Once we have a Dataset object, we can transform it into a new Dataset by chaining method calls on the Dataset object. There are generally two types of transformations that we can apply.

- Per-element transformation: using the map() method (e.g., for loading data as image tensor and label pairs, scaling, augmentation, etc.)
- Multi-element transformation: using the batch() method


In addition to these two methods, the following methods are used for preprocessing the dataset or preparing it for training: cache(), shuffle(), repeat(), prefetch(), and interleave().

Below we describe these methods briefly.


- cache(filename)

It stores the elements of the Dataset for future reuse. The first time the Dataset is iterated over, its elements are cached either in the specified file or in memory (the default). Subsequent iterations use the cached elements.

With a small enough dataset, cache() makes training noticeably faster because, after the first epoch, the data is read from memory instead of being reloaded and re-transformed. For larger datasets that do not fit into memory, the data can be cached to a file instead.
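Caching to a file can be observed directly with a small sketch, assuming TensorFlow 2.x. The helper `expensive_square` is a hypothetical stand-in for costly preprocessing, and it is wrapped in tf.py_function only so that a Python counter can show how often the transformation actually runs; the cache files are written under a temporary prefix.

```python
import os
import tempfile

import tensorflow as tf

call_count = {"n": 0}  # counts how often the Python preprocessing runs


def expensive_square(x):
    # Hypothetical stand-in for costly work (decoding, resizing, ...).
    call_count["n"] += 1
    return x * x


# Cache to files with this prefix (in a temporary directory) instead of RAM.
cache_prefix = os.path.join(tempfile.mkdtemp(), "squares_cache")

ds = (tf.data.Dataset.range(5)
      .map(lambda x: tf.py_function(expensive_square, [x], tf.int64))
      .cache(cache_prefix))

first_epoch = list(ds.as_numpy_iterator())   # runs the map and writes the cache
second_epoch = list(ds.as_numpy_iterator())  # reads from the cache files
print(first_epoch, second_epoch, call_count["n"])
```

The counter stays at 5 after the second pass because the transformed elements are read back from the cache rather than recomputed.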


- map(map_func, num_parallel_calls=None)

It applies the transformation function "map_func" to each element of the input Dataset. We can parallelize this process by setting the "num_parallel_calls" parameter, for example to the number of threads that can be used for the transformation. Alternatively, we may use the value tf.data.AUTOTUNE, which lets the tf.data runtime set the number of parallel calls dynamically based on the available CPU.
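A minimal sketch of sequential versus parallel map(), assuming TensorFlow 2.4 or later (where tf.data.AUTOTUNE is available). The function name `to_float_and_scale` is hypothetical. By default, a parallel map still returns elements in their original order.

```python
import tensorflow as tf


def to_float_and_scale(x):
    # Per-element transformation: cast int64 -> float32 and scale to [0, 1).
    return tf.cast(x, tf.float32) / 10.0


ds = tf.data.Dataset.range(10)

# Sequential map: one element at a time.
sequential = ds.map(to_float_and_scale)

# Parallel map: the tf.data runtime chooses the degree of parallelism.
parallel = ds.map(to_float_and_scale, num_parallel_calls=tf.data.AUTOTUNE)

print(list(parallel.as_numpy_iterator()))
```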


- shuffle(buffer_size, seed=None, reshuffle_each_iteration=None)

The shuffle method randomly shuffles the elements of the Dataset. It fills a buffer with buffer_size elements, then randomly samples elements from the buffer, replacing the selected elements with new elements. For perfect shuffling, a buffer_size greater than or equal to the full size of the Dataset is required.

For instance, if the Dataset contains 10,000 elements but buffer_size is set to 1,000, then shuffle will initially select a random element from only the first 1,000 elements in the buffer. Once an element is selected, its space in the buffer is replaced by the next (i.e., 1,001st) element, maintaining the 1,000-element buffer.

By default, the "reshuffle_each_iteration" argument is None, which behaves like True: the Dataset is reshuffled on every iteration (epoch), so when batching follows shuffling, each epoch produces different batches. To produce the same order (and therefore the same batches) in every epoch, set it to False.
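The effect of "reshuffle_each_iteration" can be checked with a short sketch, assuming TensorFlow 2.x; the seed value 42 is arbitrary. Each call to as_numpy_iterator() starts a new pass over the data, playing the role of one epoch.

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10)

# Default (None behaves like True): a new order on every pass (epoch).
reshuffled = ds.shuffle(buffer_size=10, seed=42)

# reshuffle_each_iteration=False: the same order is reused on every pass.
fixed = ds.shuffle(buffer_size=10, seed=42, reshuffle_each_iteration=False)

fixed_epochs = [list(fixed.as_numpy_iterator()) for _ in range(3)]
reshuffled_epochs = [list(reshuffled.as_numpy_iterator()) for _ in range(3)]
print(fixed_epochs)
print(reshuffled_epochs)
```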


- repeat(count) 

It repeats the Dataset "count" times; if "count" is omitted (or None), the Dataset is repeated indefinitely. It is useful when the data runs out but training should continue: after the last element, reading restarts from the beginning of the Dataset. Typically, "count" is set to the number of training epochs.
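A brief sketch of finite and infinite repetition, assuming TensorFlow 2.x. The sizes (10 examples, batch size 3) mirror the toy data used later in this notebook; take() is used only to inspect a bounded prefix of the infinite Dataset.

```python
import tensorflow as tf

num_examples = 10
batch_size = 3

ds = tf.data.Dataset.range(num_examples)

# repeat(count=2): after the last element, reading restarts once more.
twice = ds.repeat(2)

# repeat() with no count repeats forever; take() bounds it for inspection.
forever = ds.repeat()
first_25 = list(forever.take(25).as_numpy_iterator())

# An infinite dataset has no natural epoch end, so the number of batches per
# epoch must be given explicitly; with drop_remainder=True it is floor(N / B).
steps_per_epoch = num_examples // batch_size
print(len(list(twice.as_numpy_iterator())), first_25[-5:], steps_per_epoch)
```

With an infinite, batched dataset, Keras needs this value passed explicitly, e.g. `model.fit(train_ds, epochs=6, steps_per_epoch=steps_per_epoch)`, where `model` and `train_ds` are hypothetical.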


- batch(batch_size, drop_remainder=False, num_parallel_calls=None) 

It combines consecutive elements of the Dataset into batches of the given batch_size. Setting "drop_remainder" to True guarantees that all batches have exactly the same size: the final batch is dropped if it would contain fewer than batch_size elements. This process can also be parallelized by using "num_parallel_calls". Typically, we set it to tf.data.AUTOTUNE, which prompts the tf.data runtime to tune the value dynamically.
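Besides guaranteeing equal-sized batches, "drop_remainder" also changes the static shape that the Dataset advertises, which some models rely on. A minimal sketch, assuming TensorFlow 2.x:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10)

keep_last = ds.batch(3)                       # last batch has 1 element
drop_last = ds.batch(3, drop_remainder=True)  # the partial batch is discarded

# The batch dimension is unknown (None) when a smaller final batch is possible,
# and static (3) when drop_remainder=True guarantees equal-sized batches.
print(keep_last.element_spec.shape)  # (None,)
print(drop_last.element_spec.shape)  # (3,)

print([b.tolist() for b in keep_last.as_numpy_iterator()])
print([b.tolist() for b in drop_last.as_numpy_iterator()])
```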


- prefetch(buffer_size) 

It creates a Dataset that prefetches elements from the given Dataset. Prefetching decouples the time when data is produced from the time when it is consumed: the transformation uses a background thread and an internal buffer to prepare elements from the input Dataset ahead of the time they are requested. This often improves latency and throughput, at the cost of using additional memory to store the prefetched elements.

The number of elements to prefetch should be equal to (or possibly greater than) the number of batches consumed by a single training step. Instead of tuning this value manually, we set it to tf.data.AUTOTUNE, which prompts the tf.data runtime to tune the value dynamically.
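In practice, prefetch() is usually the last call in a pipeline, so that the next batch is prepared while the model consumes the current one. A minimal sketch, assuming TensorFlow 2.4 or later; the scaling step is only a placeholder transformation.

```python
import tensorflow as tf

ds = (tf.data.Dataset.range(10)
      # Placeholder per-element transformation, run in parallel.
      .map(lambda x: tf.cast(x, tf.float32) / 10.0,
           num_parallel_calls=tf.data.AUTOTUNE)
      .batch(3)
      # Prepare upcoming batches in the background while the current one is used.
      .prefetch(tf.data.AUTOTUNE))

for batch in ds.take(2):
    print(batch.numpy())
```

Prefetching does not change the contents or order of the batches; it only overlaps their production with their consumption.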


- interleave(map_func, cycle_length=None, block_length=None, num_parallel_calls=None) 

It is commonly used to read data from multiple files in parallel. It applies "map_func" to each element of the Dataset and interleaves the results.

The cycle_length and block_length arguments control the order in which elements are produced: cycle_length is the number of input elements processed concurrently, and block_length is the number of consecutive elements taken from each of them in turn. In general, this transformation applies map_func to cycle_length input elements, opens iterators on the returned Dataset objects, and cycles through them, producing block_length consecutive elements from each iterator and consuming the next input element each time it reaches the end of an iterator.
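The ordering rules can be seen with two small text files, in a sketch assuming TensorFlow 2.x. The file names and their contents are hypothetical; each input element is a file path, and tf.data.TextLineDataset serves as map_func, turning each path into a Dataset of lines.

```python
import os
import tempfile

import tensorflow as tf

# Two hypothetical text files, each holding three "records", one per line.
tmp_dir = tempfile.mkdtemp()
paths = []
for name, lines in [("a.txt", ["a1", "a2", "a3"]), ("b.txt", ["b1", "b2", "b3"])]:
    path = os.path.join(tmp_dir, name)
    with open(path, "w") as f:
        f.write("\n".join(lines))
    paths.append(path)

file_ds = tf.data.Dataset.from_tensor_slices(paths)

# cycle_length=2: both files are open at once;
# block_length=1: alternate one line at a time between them.
alternating = file_ds.interleave(tf.data.TextLineDataset,
                                 cycle_length=2, block_length=1)

# block_length=2: take two consecutive lines from a file before switching.
blocked = file_ds.interleave(tf.data.TextLineDataset,
                             cycle_length=2, block_length=2)

print(list(alternating.as_numpy_iterator()))
print(list(blocked.as_numpy_iterator()))
```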


In [1]:
import tensorflow as tf

print("TensorFlow Version: ", tf.__version__) 

TensorFlow Version:  2.5.0


## Demo: Dataset from Toy Data


Say that our data source is a small collection X (e.g., a Python list or a tensor) that fits into memory. 

We construct a Dataset object from X by using the **from_tensor_slices()** method. Its elements are all the slices of X (along the first dimension). This dataset is called the TensorSliceDataset.

The Dataset object is a Python iterable. This makes it possible to consume its elements using a for loop.

## Construct a Dataset and Display Its Elements

In [2]:
# Create the toy data X (either option works with from_tensor_slices)
#X = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
X = tf.range(10)

# Create a "TensorSliceDataset" Dataset object
dataset = tf.data.Dataset.from_tensor_slices(X)
print("Dataset Object and its type: ", dataset)

# Print the type specification of an element of this dataset
print("\nElement Specification:\n", dataset.element_spec)


'''
There are two ways to inspect the dataset.
- Technique 1: Print dataset elements directly along with element shapes and types
- Technique 2: Print only the dataset elements
'''

# Technique 1: Print dataset elements directly along with element shapes and types
# This is possible because the Dataset object is a Python iterable
print("\nPrint all elements with their shape and data type information: ")
for i in dataset:
    print(i)

    
# Technique 2: Print only the dataset elements   
'''
Get the content of the dataset by the as_numpy_iterator() method. 
The as_numpy_iterator() method returns an iterator.
The iterator converts all elements of the dataset to numpy.
This method requires that we are running in eager mode and 
the dataset's element_spec contains only TensorSpec components.
We have two options.
- Option 1: Print the list
- Option 2: Print each element independently
'''

# Option 1: Print the list
print("\nOption 1: Print the list of all elements: ")
print(list(dataset.as_numpy_iterator()))

# Option 2: Print each element independently
print("\nOption 2: Print each element independently: ")
for element in dataset.as_numpy_iterator():
    print(element)

Dataset Object and its type:  <TensorSliceDataset shapes: (), types: tf.int32>

Element Specification:
 TensorSpec(shape=(), dtype=tf.int32, name=None)

Print all elements with their shape and data type information: 
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
tf.Tensor(5, shape=(), dtype=int32)
tf.Tensor(6, shape=(), dtype=int32)
tf.Tensor(7, shape=(), dtype=int32)
tf.Tensor(8, shape=(), dtype=int32)
tf.Tensor(9, shape=(), dtype=int32)

Option 1: Print the list of all elements: 
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

Option 2: Print each element independently: 
0
1
2
3
4
5
6
7
8
9


## Transform a Dataset

We will apply various transformations to our toy Dataset. Multiple transformation methods can be applied via method chaining.

For each transformation method, we briefly describe the method and, where methods are chained, the optimal order in which to combine them.


Specifically, we describe the following methods and method combinations.

- cache
- shuffle (small buffer)
- shuffle (large buffer)
- batch
- repeat
- batch --> repeat
- repeat --> batch
- shuffle --> repeat --> batch
- batch --> shuffle --> repeat 

In [3]:
print("\n---------------Cache ---------------------------")

'''
The "cache()" method stores dataset elements in memory (by default) or in a file for future reuse.
- The first time the dataset is iterated over (i.e., the first epoch), 
its elements are cached either in the specified file or in memory. 
- Subsequent iterations (i.e., epochs) use the cached data.
This prevents some operations (e.g., file opening, data reading, parsing, transforming, etc.) 
from being executed during each epoch.

Caching should be used judiciously.
- Small dataset (fits into memory): use the cache method. 
- Large dataset: typically sharded (split into multiple files) and does not fit in memory.
Thus, it should not be cached in memory.
'''

print("\nOriginal Dataset:")
print(list(dataset.as_numpy_iterator()))

'''
Example:
After loading the dataset, we transform its elements by raising them to the power of 2.
Then, we cache the transformed dataset.
'''
dataset_1 = dataset.map(lambda x: x**2)
dataset_1 = dataset_1.cache()

print("\nTransformed Dataset is stored in cache:")
print(list(dataset_1.as_numpy_iterator()))

'''
Subsequent iterations read from the cache.
'''
print("\nSubsequent iterations read the transformed Dataset from cache:")
print(list(dataset_1.as_numpy_iterator()))


---------------Cache ---------------------------

Original Dataset:
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

Transformed Dataset is stored in cache:
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

Subsequent iterations read the transformed Dataset from cache:
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]


In [4]:
print("\n---------------Shuffle ---------------------------")

print("\nOriginal Dataset:")
print(list(dataset.as_numpy_iterator()))

'''
The cache() method will produce exactly the same elements during each iteration (epoch) through the dataset. 
For randomizing the iteration order, we need to call the shuffle() method after calling cache().

The shuffle() method, first, fills a buffer with "buffer_size" elements.
Then, randomly samples elements from this buffer by replacing the selected elements with new elements.

The value of "buffer_size" influences the dataset randomization.
- Small buffer (smaller than the length of dataset)
- Large buffer (greater than or equal to the length of dataset)
'''


print("\n---------------Shuffle (small buffer)---------------------------")

'''
If the buffer size is smaller than the length of the dataset, its elements are not completely randomized.
In this example, the dataset contains 10 elements but buffer_size is set to 2.
Thus, shuffle will initially select a random element from the first 2 elements in the buffer. 
Once an element is selected, its space in the buffer is replaced by the next (i.e., 3rd) element, 
maintaining the 2 element buffer.

Observe from the output of 10 iterations that the order of the dataset elements is not purely random:
no element ever appears more than one position earlier than in the original order.
'''

dataset_2 = dataset.shuffle(buffer_size=2)

print("\nOutput of 10 iterations (epochs): Partial Randomness")
for i in range(10):
    print(list(dataset_2.as_numpy_iterator()))
    
    
print("\n---------------Shuffle (large buffer)---------------------------")

'''
For perfect shuffling, a buffer size greater than or equal to the full size of the dataset is required.
'''
dataset_3 = dataset.shuffle(buffer_size=10)

print("\nOutput of 10 iterations (epochs): Full Randomness")
for i in range(10):
    print(list(dataset_3.as_numpy_iterator()))
    


---------------Shuffle ---------------------------

Original Dataset:
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

---------------Shuffle (small buffer)---------------------------

Output of 10 iterations (epochs): Partial Randomness
[1, 2, 0, 3, 5, 4, 7, 8, 6, 9]
[1, 2, 3, 4, 5, 0, 7, 8, 6, 9]
[1, 0, 2, 4, 5, 3, 7, 6, 9, 8]
[1, 0, 3, 2, 4, 6, 7, 8, 5, 9]
[1, 2, 0, 3, 4, 5, 7, 6, 9, 8]
[0, 2, 1, 3, 4, 5, 7, 6, 8, 9]
[0, 1, 2, 4, 3, 5, 7, 6, 9, 8]
[1, 0, 2, 3, 5, 4, 7, 6, 8, 9]
[1, 2, 3, 0, 5, 6, 4, 7, 9, 8]
[0, 1, 2, 3, 4, 5, 6, 7, 9, 8]

---------------Shuffle (large buffer)---------------------------

Output of 10 iterations (epochs): Full Randomness
[1, 2, 3, 9, 8, 0, 4, 7, 6, 5]
[8, 1, 7, 4, 6, 9, 3, 0, 5, 2]
[7, 3, 9, 6, 0, 2, 8, 5, 4, 1]
[5, 4, 2, 6, 3, 0, 9, 8, 1, 7]
[4, 1, 7, 8, 0, 3, 5, 9, 2, 6]
[6, 9, 4, 5, 0, 7, 3, 2, 1, 8]
[3, 9, 2, 8, 5, 1, 4, 6, 0, 7]
[3, 8, 6, 2, 1, 9, 7, 5, 0, 4]
[0, 2, 8, 3, 7, 5, 4, 1, 9, 6]
[9, 5, 0, 1, 7, 2, 4, 3, 6, 8]


In [5]:
print("\n---------------Batch---------------------------")

'''
The batch() method combines consecutive elements of the dataset into batches.
The size of the batches is determined by the "batch_size" parameter.

The components of the resulting element will have an additional outer dimension, 
which will be batch_size (or N % batch_size for the last element if batch_size 
does not divide the number of input elements N evenly and drop_remainder parameter is False). 
'''
dataset_4 = dataset.batch(batch_size=3, drop_remainder=False)

print(dataset_4)
print("\nThe component of the batch dataset has an additional outer dimension.") 
for i in dataset_4:
    print(i)


print("\nDisplay each batch as a list:")
for element in dataset_4.as_numpy_iterator():
    print(element)
    
    
print("\n---------------Batch (same length)---------------------------")
'''
To create batches with the same outer dimension (i.e., the same length), 
set the "drop_remainder" parameter to True.
This drops the final batch if it would be smaller than batch_size 
(here, the batch [9] is dropped), so every emitted batch has exactly batch_size elements. 
'''
dataset_5 = dataset.batch(batch_size=3, drop_remainder=True)

print("\nDisplay each batch as a list (all batches have the same length):")
for element in dataset_5.as_numpy_iterator():
    print(element)


---------------Batch---------------------------
<BatchDataset shapes: (None,), types: tf.int32>

The component of the batch dataset has an additional outer dimension.
tf.Tensor([0 1 2], shape=(3,), dtype=int32)
tf.Tensor([3 4 5], shape=(3,), dtype=int32)
tf.Tensor([6 7 8], shape=(3,), dtype=int32)
tf.Tensor([9], shape=(1,), dtype=int32)

Display each batch as a list:
[0 1 2]
[3 4 5]
[6 7 8]
[9]

---------------Batch (same length)---------------------------

Display each batch as a list (all batches have the same length):
[0 1 2]
[3 4 5]
[6 7 8]


In [6]:
print("\n--------------------Repeat------------------------------")

'''
The repeat() method is used to repeat the dataset.
By default (if no argument is used), the dataset is repeated indefinitely.
However, if the "count" parameter is set,
then the dataset is repeated "count" number of times.
'''

print("\nOriginal Dataset:")
print(list(dataset.as_numpy_iterator()))

print("\nDataset repeated 2 times:")
dataset_6 = dataset.repeat(count=2)
print(list(dataset_6.as_numpy_iterator()))


print("\n---------------Repeat --> Batch---------------------------")

'''
The repeat() method is used when the data runs out but training should continue.
For example, if we have 10 samples batched for training and we want to train for 6 epochs,
then we need to repeat the dataset 6 times,
because during each epoch the model consumes the whole dataset, batch by batch.
In this example, the batch size is 3 (with drop_remainder=True), so we get 3 batches per epoch.

For training deep learning models, the dataset should be repeated according to the number of epochs.
If the dataset is repeated indefinitely, then we need to set the steps_per_epoch argument of the model's fit() method,
which is typically (dataset length) // (batch size).
'''

print("\nOriginal Dataset:")
print(list(dataset.as_numpy_iterator()))

print("\nNo Repeat (batch size = 3): (runs for 1 epoch)")
dataset_7 = dataset.batch(batch_size=3, drop_remainder=True)
for element in dataset_7.as_numpy_iterator():
    print(element)

    
'''
The repeat() method concatenates the repeated copies of the dataset without signaling the end of one epoch 
and the beginning of the next epoch. 
Because of this, a batch() applied after repeat() yields batches that straddle epoch boundaries.
'''
print("\nRepeat (batch size = 3): (runs up to 7 epochs)")
dataset_8 = dataset.repeat(6).batch(batch_size=3, drop_remainder=True)

iteration = 0
epoch_count = 0
for element in dataset_8.as_numpy_iterator():
    if (iteration%3 == 0):
        epoch_count += 1
        print("\nEpoch: %d" % epoch_count)
    print(element)
    iteration += 1
    
    
print("\n---------------Shuffle --> Batch --> Repeat---------------------------")

'''
For a clear separation of epochs, we need to apply batch() before repeat().
'''
    
print("\nRepeat (batch size = 3): (runs up to 6 epochs): Same batches/per epoch")
dataset_9 = dataset.batch(batch_size=3, drop_remainder=True).repeat(6)

iteration = 0
epoch_count = 0
for element in dataset_9.as_numpy_iterator():
    if (iteration%3 == 0):
        epoch_count += 1
        print("\nEpoch: %d" % epoch_count)
    print(element)
    iteration += 1


--------------------Repeat------------------------------

Original Dataset:
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

Dataset repeated 2 times:
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

---------------Repeat --> Batch---------------------------

Original Dataset:
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

No Repeat (batch size = 3): (runs for 1 epoch)
[0 1 2]
[3 4 5]
[6 7 8]

Repeat (batch size = 3): (runs up to 7 epochs)

Epoch: 1
[0 1 2]
[3 4 5]
[6 7 8]

Epoch: 2
[9 0 1]
[2 3 4]
[5 6 7]

Epoch: 3
[8 9 0]
[1 2 3]
[4 5 6]

Epoch: 4
[7 8 9]
[0 1 2]
[3 4 5]

Epoch: 5
[6 7 8]
[9 0 1]
[2 3 4]

Epoch: 6
[5 6 7]
[8 9 0]
[1 2 3]

Epoch: 7
[4 5 6]
[7 8 9]

---------------Shuffle --> Batch --> Repeat---------------------------

Repeat (batch size = 3): (runs up to 6 epochs): Same batches/per epoch

Epoch: 1
[0 1 2]
[3 4 5]
[6 7 8]

Epoch: 2
[0 1 2]
[3 4 5]
[6 7 8]

Epoch: 3
[0 1 2]
[3 4 5]
[6 7 8]

Epoch: 4
[0 1 2]
[3 4 5]
[6 7 8]

Epoch: 5
[0 1 2]
[3 4 5]
[6 7 8]

Epoch: 6
[0 1 2]
[3 4 5]
[6 7 8]

In [7]:
print("\n---------------Shuffle --> Repeat ---> Batch ---------------------------")
'''
The repeat() method should be applied after shuffle(),
because shuffle() doesn't signal the end of an epoch until the shuffle buffer is empty.
So, to show every element of one epoch before moving to the next,
shuffle() should be placed before repeat().

Placing batch() after repeat() completes the optimal ordering: shuffle --> repeat --> batch.
It ensures that the batches differ from epoch to epoch
(although some batches straddle epoch boundaries and may even contain the same element twice).

On the other hand, the batch --> shuffle --> repeat ordering only shuffles the order of the batches:
every epoch emits batches with exactly the same elements.
'''
    
print("\nOptimal Ordering: shuffle --> repeat --> batch")
dataset_10 = dataset.shuffle(buffer_size=10).repeat(6).batch(batch_size=3, drop_remainder=True)

iteration = 0
epoch_count = 0
for element in dataset_10.as_numpy_iterator():
    if (iteration%3 == 0):
        epoch_count += 1
        print("\nEpoch: %d" % epoch_count)
    print(element)
    iteration += 1
    
    
print("\nNon-optimal Ordering: batch --> shuffle --> repeat")
dataset_11 = dataset.batch(batch_size=3, drop_remainder=True).shuffle(buffer_size=10).repeat(6)

iteration = 0
epoch_count = 0
for element in dataset_11.as_numpy_iterator():
    if (iteration%3 == 0):
        epoch_count += 1
        print("\nEpoch: %d" % epoch_count)
    print(element)
    iteration += 1


---------------Shuffle --> Repeat ---> Batch ---------------------------

Optimal Ordering: shuffle --> repeat --> batch

Epoch: 1
[6 4 2]
[8 3 9]
[0 7 5]

Epoch: 2
[1 9 1]
[5 7 2]
[0 4 3]

Epoch: 3
[6 8 6]
[5 8 9]
[0 2 4]

Epoch: 4
[3 1 7]
[4 6 9]
[5 0 3]

Epoch: 5
[7 2 1]
[8 2 8]
[9 6 5]

Epoch: 6
[3 4 0]
[1 7 4]
[0 7 3]

Epoch: 7
[5 9 6]
[8 2 1]

Non-optimal Ordering: batch --> shuffle --> repeat

Epoch: 1
[0 1 2]
[6 7 8]
[3 4 5]

Epoch: 2
[3 4 5]
[0 1 2]
[6 7 8]

Epoch: 3
[3 4 5]
[0 1 2]
[6 7 8]

Epoch: 4
[3 4 5]
[0 1 2]
[6 7 8]

Epoch: 5
[6 7 8]
[0 1 2]
[3 4 5]

Epoch: 6
[6 7 8]
[3 4 5]
[0 1 2]


## Method Chaining: Template for Deep Learning Pre-processing Pipeline


In [8]:
print("\n------map (scale) --> cache --> shuffle --> repeat --> batch --> map (augment)---------")

'''
In a Deep Learning pre-processing pipeline, 
we typically need to apply some transformations to the Dataset:
- Per-element
- Per-batch

For example, we want to scale each element of the Dataset (dividing by 10.0) 
and augment each batch by raising the batch elements to the power of 2.

Below, we show how to perform these two transformations along with the previously discussed pre-processing
transformations (cache, shuffle, repeat, batch).

The optimal order of these transformations is:
map (scale) --> cache --> shuffle --> repeat --> batch --> map (augment)
'''


# Function for per-element transformation
def scale(x):
    return x/10


# Function for per-batch transformation
def augment(x):
    return x**2


buffer_size = 10 # shuffle buffer
count = 6 # repeat count
batch_size = 3



dataset_12 = (dataset.map(scale)
              .cache()
              .shuffle(buffer_size)
              .repeat(count)
              .batch(batch_size, drop_remainder=True)
              .map(augment))


iteration = 0
epoch_count = 0
for element in dataset_12.as_numpy_iterator():
    if (iteration%3 == 0):
        epoch_count += 1
        print("\nEpoch: %d" % epoch_count)
    print(element)
    iteration += 1


------map (scale) --> cache --> shuffle --> repeat --> batch --> map (augment)---------

Epoch: 1
[0.   0.16 0.01]
[0.04 0.25 0.64]
[0.49 0.81 0.09]

Epoch: 2
[0.36 0.25 0.16]
[0.49 0.64 0.  ]
[0.36 0.09 0.81]

Epoch: 3
[0.04 0.01 0.36]
[0.81 0.09 0.04]
[0.25 0.49 0.64]

Epoch: 4
[0.16 0.01 0.  ]
[0.01 0.49 0.64]
[0.09 0.16 0.04]

Epoch: 5
[0.36 0.25 0.81]
[0.   0.01 0.04]
[0.25 0.09 0.49]

Epoch: 6
[0.16 0.   0.64]
[0.36 0.81 0.49]
[0.25 0.81 0.04]

Epoch: 7
[0.16 0.09 0.36]
[0.64 0.01 0.  ]
