<a href="https://colab.research.google.com/github/vrgeo/ml-tutorials/blob/main/session_02.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Session 2: Salt classification with CNNs using TensorFlow and Keras

In this session, you will learn how to **train a simple CNN using TensorFlow and Keras**, for the purpose of salt classification.


As usual, we start by importing the necessary Python packages. If you participated in Session 1, you should already be familiar with NumPy and matplotlib.

Additionally, in this session we import **TensorFlow**, which is a free and **open-source machine learning library**. TensorFlow ships with a library called **Keras**, a high-level API for **designing and training Deep Neural Networks**. You can learn more about TensorFlow and also find additional tutorials and guides on their [website](https://www.tensorflow.org/overview).

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.metrics import classification_report

## 1. Preparing the Datasets
For this session, we are providing a **dataset of 28x28 pixel image patches**, extracted from the Z3 Netherlands seismic survey. Each image has been **labelled as 'salt' or 'background'**. We split the data into **training and test datasets**; the test dataset will later allow us to evaluate our trained model on patches it has not seen during training.
Download the dataset from our GitHub repository by executing the code cell below.

In [None]:
!wget https://github.com/vrgeo/ml-tutorials/blob/5c17fe8c779b0ab22243ef77753e94e923dd7deb/data/dataset_train.npz?raw=true -nv -O dataset_train.npz
!wget https://github.com/vrgeo/ml-tutorials/blob/5c17fe8c779b0ab22243ef77753e94e923dd7deb/data/dataset_test.npz?raw=true -nv -O dataset_test.npz

You might have noticed that these files have an unusual file extension, `.npz`. These files were created with NumPy's **np.savez()** function, which allows you to write one or more NumPy arrays from your running Python application into a single zip archive.

We can now load the files back into python, using the **numpy.load()** function. There are two arrays in each file, which can be accessed using the keys *'arr_0'* and *'arr_1'*. The first array in each file contains the patches, the second one the corresponding labels.

Run the cell below in order to load the datasets.

In [None]:
training_data = np.load("dataset_train.npz", allow_pickle=True)
training_patches = training_data["arr_0"]
training_labels = training_data["arr_1"]
print(f'Loaded {len(training_patches)} training patches and {len(training_labels)} corresponding labels')

test_data = np.load("dataset_test.npz", allow_pickle=True)
test_patches = test_data["arr_0"]
test_labels = test_data["arr_1"]
print(f'Loaded {len(test_patches)} test patches and {len(test_labels)} corresponding labels')
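
To see where the keys *'arr_0'* and *'arr_1'* come from, here is a minimal sketch with made-up arrays (not the tutorial's data). It assumes the dataset files were written with positional arguments, in which case np.savez() generates these default names. An in-memory buffer stands in for a file on disk.

```python
import io
import numpy as np

# Toy stand-ins for patches and labels (hypothetical values, not the real dataset)
toy_patches = np.zeros((3, 28, 28), dtype=np.float32)
toy_labels = np.array([0, 1, 0], dtype=np.uint8)

# Arrays passed positionally are stored under the default keys 'arr_0', 'arr_1', ...
buffer = io.BytesIO()
np.savez(buffer, toy_patches, toy_labels)

# Rewind the buffer and read the archive back
buffer.seek(0)
loaded = np.load(buffer)
saved_keys = list(loaded.files)
restored_patches = loaded["arr_0"]
restored_labels = loaded["arr_1"]
print(saved_keys, restored_patches.shape, restored_labels)
```

If you save your own datasets, passing keyword arguments instead (for example `np.savez(buffer, patches=toy_patches, labels=toy_labels)`) stores the arrays under those more descriptive names.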

Let us begin by examining the training data. First, we check how many patches we have of each class. In the dataset, patches are labelled either as **'0' - 'background'** or **'1' - 'salt'**.
Using NumPy's **count_nonzero()** function, we count the number of non-zero values in the labels array, which gives us the number of patches labelled as 'salt'.

In [None]:
n_salt = np.count_nonzero(training_labels)
n_background = len(training_labels) - n_salt
print(f'Found {n_salt} patches containing salt and {n_background} patches not containing salt.')
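
As an alternative way of counting, the sketch below uses np.bincount() on a small, made-up label array (the `toy_` names are hypothetical and not part of the notebook). It returns the number of patches per class in one call, which also extends to more than two classes.

```python
import numpy as np

# Hypothetical labels: 0 = background, 1 = salt
toy_labels = np.array([0, 0, 1, 0, 1, 0])

# bincount returns how often each non-negative integer value occurs
toy_counts = np.bincount(toy_labels, minlength=2)
toy_n_background, toy_n_salt = int(toy_counts[0]), int(toy_counts[1])
toy_salt_fraction = toy_n_salt / len(toy_labels)
print(f"background: {toy_n_background}, salt: {toy_n_salt}, salt fraction: {toy_salt_fraction:.2f}")
```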

We have a **2:1 split of background vs salt**. In image classification, it is often advisable to have more examples for the background class, since it is usually more varied than the target class.

In the cell below, we define a small python function, that will help us **visualize patches and corresponding labels**, with the help of **matplotlib**. We are going to use it during the rest of the session, so make sure to run the code cell below once.

In [6]:
def visualize_patches(patches, labels, rows, cols, colormap = 'gray'):
  fig = plt.figure(figsize=(cols*2, rows*2.2))
  for i in range(len(patches)):
    title = labels[i]
    plt.subplot(rows, cols, i+1)
    plt.imshow(patches[i], cmap = colormap)
    plt.axis('off')
    plt.title(title)

By running the cell below, you can **visualize a random selection of 'salt' patches** from the training dataset. We use NumPy to get the indices of all patches whose label is non-zero. We then use NumPy's **random.choice()** function to draw twelve random indices from this set, without repetition. Finally, we pass this subset of our dataset to our **visualize_patches()** function.

In [None]:
salt_indices = np.nonzero(training_labels)[0]
random_salt_indices = np.random.choice(salt_indices, 12, replace=False)
visualize_patches(training_patches[random_salt_indices], training_labels[random_salt_indices], 3, 4)


Similarly, you can run the cell below, in order to **visualize twelve random 'background' patches**. 

Notice how we inverted the labels (1 becomes 0 and vice versa) before passing them to NumPy's nonzero() function. This way, we can use the same method as before to access the background patches.

In [None]:
background_indices = np.nonzero(1-training_labels)[0]
random_background_indices = np.random.choice(background_indices, 12, replace=False)
visualize_patches(training_patches[random_background_indices], training_labels[random_background_indices], 3, 4)
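
The label inversion trick above only works for labels that are exactly 0 and 1. A slightly more explicit variant, sketched here on a made-up label array, compares the labels against the class value and lets np.flatnonzero() return the matching positions.

```python
import numpy as np

# Hypothetical labels: 0 = background, 1 = salt
toy_labels = np.array([1, 0, 0, 1, 0])

# The comparison yields a boolean mask; flatnonzero returns the indices where it is True
toy_background_indices = np.flatnonzero(toy_labels == 0)
toy_salt_indices = np.flatnonzero(toy_labels == 1)
print(toy_background_indices, toy_salt_indices)
```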

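We are now almost ready to design and train our network, but first our **dataset needs to be normalized and reshaped**. Bringing the input values into a small, consistent range generally makes training more stable. Note that `tf.keras.utils.normalize()`, which we use below, does not rescale the data to the range 0 to 1; by default it divides the values by their L2 norm along the given axis. The reshaping adds the trailing channel dimension that Keras expects for image input.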

Run the cell below to normalize and reshape the training data.

In [9]:
def normalize_and_reshape_data(input_patches):
  patches = tf.keras.utils.normalize(input_patches, axis=1)
  patches = np.array(patches).reshape(-1, 28, 28, 1)
  return patches

training_patches_normalized = normalize_and_reshape_data(training_patches)
input_shape =  training_patches_normalized.shape[1:]
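
To make the normalization step more tangible, the NumPy sketch below applies an L2 normalization along axis 1 to a tiny made-up "patch" and compares it with a min-max scaling to the range 0 to 1. Both helper functions are hypothetical illustrations; the first mimics the idea behind `tf.keras.utils.normalize()` with its default settings, not its exact implementation.

```python
import numpy as np

def l2_normalize(x, axis=1, eps=1e-12):
    """Divide by the L2 norm along `axis` (a sketch of the idea, not Keras' exact code)."""
    norm = np.sqrt(np.sum(np.square(x), axis=axis, keepdims=True))
    return x / np.maximum(norm, eps)  # eps avoids division by zero

def min_max_scale(x):
    """Rescale all values linearly so that they span the range [0, 1]."""
    return (x - x.min()) / (x.max() - x.min())

# One hypothetical 2x2 "patch", shape (1, 2, 2), analogous to (n_patches, 28, 28)
toy = np.array([[[3.0, 0.0],
                 [4.0, 2.0]]])

toy_l2 = l2_normalize(toy, axis=1)   # each column of the patch gets unit length
toy_minmax = min_max_scale(toy)      # smallest value becomes 0, largest becomes 1
print(toy_l2)
print(toy_minmax)
```

With `axis=1` on an array of shape (n_patches, 28, 28), the norm is taken over the rows of each patch, so every pixel column is scaled independently. Whether this or a global min-max scaling works better for a given dataset is something you can try out.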

## 2. Defining and training the model

Now we can start right away with **designing our network**. The example model in this session is based on **LeNet-5**, one of the earliest CNNs, designed by [LeCun et al., 1989](https://direct.mit.edu/neco/article/1/4/541/5515/Backpropagation-Applied-to-Handwritten-Zip-Code). It was originally designed to recognize handwritten digits, but has been adapted to perform binary classification for this session.

First, we initialize a **sequential model**. This type of Neural Network model is ordered in **layers**, and the output of each layer will be the input of the next layer. The first layer's input will be our 28x28 seismic patches.

Because we want to create a **Convolutional Neural Network** (CNN), we start with a **convolution layer**. Convolution layers are the defining feature of CNNs. They **allow the network to learn filters** and thus detect certain features, such as shapes or textures. In Keras, we can define the **number of filters** and the **size of the filter kernels** for each convolution layer.

Each convolution layer is usually followed by a **pooling layer**. This layer reduces the size of the output of the previous layer, by a **downsampling** method. In this case, four input values are reduced to a single output value by averaging. By adding pooling layers, we force the model to **encode the information** and focus on the significant parts.
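
The averaging described above can be sketched in plain NumPy. The function below is a hypothetical helper, not the Keras implementation; it averages non-overlapping 2x2 blocks of a small made-up feature map, which corresponds to a pool size of 2 and a stride of 2.

```python
import numpy as np

def average_pool_2x2(feature_map):
    """Average non-overlapping 2x2 blocks of a 2D array with even height and width."""
    h, w = feature_map.shape
    # Split rows and columns into pairs, then average over each pair of axes
    blocks = feature_map.reshape(h // 2, 2, w // 2, 2)
    return blocks.mean(axis=(1, 3))

# Hypothetical 4x4 feature map
toy_map = np.array([[1.0, 3.0, 0.0, 0.0],
                    [5.0, 7.0, 0.0, 4.0],
                    [2.0, 2.0, 1.0, 1.0],
                    [2.0, 2.0, 1.0, 1.0]])

pooled = average_pool_2x2(toy_map)  # shape (2, 2): each value is the mean of four inputs
print(pooled)
```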

We add a second pair of convolution and pooling layers. By **chaining convolution layers** like this, we enable our model to **learn more complex features** by combining features from a previous convolution layer. If a convolution layer has learned to detect edges, then the following convolution layer might combine these detections in order to detect shapes or whole objects.

In [10]:
input_shape =  training_patches_normalized.shape[1:]

model = tf.keras.models.Sequential()  

model.add(tf.keras.layers.Conv2D(filters=6 ,kernel_size=5, strides=1, activation=tf.nn.relu, input_shape=input_shape))
model.add(tf.keras.layers.AveragePooling2D(pool_size=2,strides=2))

model.add(tf.keras.layers.Conv2D(filters=16 ,kernel_size=5, strides=1, activation=tf.nn.relu))
model.add(tf.keras.layers.AveragePooling2D(pool_size=2,strides=2))
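
It can help to track how the patch size shrinks through these four layers. The sketch below uses a hypothetical helper, `output_size()`, and assumes the Keras default of no padding (`padding='valid'`), which is what the layers above use since no padding argument is given.

```python
def output_size(input_size, kernel_size, stride):
    """Spatial output size of a convolution or pooling layer without padding."""
    return (input_size - kernel_size) // stride + 1

# (kernel_size, stride) of the four layers defined above
layers = [(5, 1),  # Conv2D, 6 filters
          (2, 2),  # AveragePooling2D
          (5, 1),  # Conv2D, 16 filters
          (2, 2)]  # AveragePooling2D

sizes = [28]  # width and height of the input patches
for kernel_size, stride in layers:
    sizes.append(output_size(sizes[-1], kernel_size, stride))

# The last layer outputs 16 feature maps of sizes[-1] x sizes[-1] values each
n_features = sizes[-1] * sizes[-1] * 16
print(sizes, n_features)
```

You can compare these numbers with the output shapes listed by `model.summary()`.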

The last convolution and pooling layer is followed by the classification part of the CNN. This is a **Fully Connected Neural Network** that takes the higher-level feature detections from the final pooling layer as a flattened input vector and learns to **classify the feature vectors into 'salt' and 'background'**.

Notice how the **final layer has a single output neuron**. It passes the weighted sum of all the neuron activations of the previous layer through a **sigmoid function**, in order to compute a **'salt likelihood'** from 0 to 1 for a given input.

In [11]:
model.add(tf.keras.layers.Flatten())

model.add(tf.keras.layers.Dense(120, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(84, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(1, activation=tf.nn.sigmoid))
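
What the single output neuron computes can be written down in a few lines of NumPy. The activations, weights and bias below are made-up numbers for illustration; in the real model, the previous layer has 84 activations and the weights are learned during training.

```python
import numpy as np

def sigmoid(z):
    """Squash any real number into the range (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))

# Hypothetical activations of the previous layer and hypothetical learned parameters
activations = np.array([0.5, 0.0, 2.0])
weights = np.array([1.0, -3.0, 0.5])
bias = -1.5

weighted_sum = weights @ activations + bias   # 0.5 + 0.0 + 1.0 - 1.5
salt_likelihood = sigmoid(weighted_sum)
print(weighted_sum, salt_likelihood)
```

A weighted sum of zero maps to a likelihood of exactly 0.5; large positive sums approach 1 and large negative sums approach 0.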

Now we compile and train our model, using the **binary cross-entropy loss** function, which computes a loss based on how close the final 'salt likelihood' output of the model was to the ground truth label.

We train our model for a total of **ten epochs**. During each epoch, the entire training dataset is passed through the network once.

In [17]:
model.compile(optimizer='adam',
              loss='binary_crossentropy',  
              metrics=['accuracy'])

history=model.fit(training_patches_normalized, training_labels, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
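
To get a feeling for the binary cross-entropy loss used above, here is a minimal NumPy sketch with made-up labels and predictions. The function is a hypothetical re-implementation for illustration (including the clipping constant `eps`), not the code Keras runs internally.

```python
import numpy as np

def binary_cross_entropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy between labels (0 or 1) and predicted likelihoods."""
    y_pred = np.clip(y_pred, eps, 1 - eps)  # avoid log(0)
    per_sample = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    return float(np.mean(per_sample))

# Hypothetical ground truth and two sets of hypothetical predictions
toy_truth = np.array([1.0, 0.0, 1.0])
mostly_right = np.array([0.9, 0.1, 0.8])
mostly_wrong = np.array([0.1, 0.9, 0.2])

loss_right = binary_cross_entropy(toy_truth, mostly_right)
loss_wrong = binary_cross_entropy(toy_truth, mostly_wrong)
print(f"{loss_right:.4f} {loss_wrong:.4f}")
```

Predictions close to the true labels give a small loss, while confident but wrong predictions are penalized heavily.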


Since we **saved the training logs in an object** called history, we can now **plot the development of the loss during training**, using matplotlib. From the graph, we can see that the **loss has not reached a plateau yet**, so training for more than ten epochs might improve the model further.

In [None]:
plt.plot(history.history['loss'], label='Binary Cross Entropy Loss')
plt.ylabel('BCE Loss value')
plt.xlabel('Epoch')
plt.legend(loc="upper left")
plt.show()

## 3. Evaluation
Now that we have trained our model, we can **validate its performance** on our test dataset.

As before, we need to **normalize and reshape** the data first. After that, we can do inference by simply calling **model.predict()** on our normalized data. We get back salt likelihood predictions, which we **binarize to 0 and 1** using a threshold of 0.5, in order to compare them to the ground truth labels.

Using a helpful **function from Scikit-Learn, called classification_report()**, we can get a detailed report on the performance of our model.

In [None]:
test_patches_normalized = normalize_and_reshape_data(test_patches)

predictions = model.predict(test_patches_normalized)
prediction_values = predictions.reshape(-1)
prediction_labels = (prediction_values > 0.5).astype(np.uint8)

print(classification_report(test_labels, prediction_labels, digits=5))

In the resulting table, we can see **several metrics** that describe the performance of our model.

**Accuracy** simply describes the fraction of patches that were classified correctly.

For a given class, **precision** describes how many of the detections for that class actually belonged to that class. If this value is low, there are many **false positives**.

**Recall**, on the other hand, describes how many of the true instances of a class were correctly detected. If the recall is low, the model produces a lot of **false negatives**.

You have to consider both precision and recall when evaluating the performance of a model, which is why the **F1-score** is commonly used as a metric: it is the **harmonic mean of precision and recall**.
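
The following sketch computes these metrics by hand for the 'salt' class, using a small set of made-up labels and predictions. It is meant to illustrate the definitions above, not to replace classification_report().

```python
import numpy as np

# Hypothetical ground truth and predicted labels (1 = salt, 0 = background)
toy_truth = np.array([1, 1, 1, 0, 0, 0, 0, 1])
toy_pred = np.array([1, 0, 1, 0, 1, 1, 0, 1])

true_positives = int(np.sum((toy_pred == 1) & (toy_truth == 1)))   # salt found correctly
false_positives = int(np.sum((toy_pred == 1) & (toy_truth == 0)))  # background called salt
false_negatives = int(np.sum((toy_pred == 0) & (toy_truth == 1)))  # salt that was missed

accuracy = float(np.mean(toy_pred == toy_truth))
precision = true_positives / (true_positives + false_positives)
recall = true_positives / (true_positives + false_negatives)
f1_score = 2 * precision * recall / (precision + recall)
print(f"accuracy={accuracy:.3f} precision={precision:.3f} recall={recall:.3f} f1={f1_score:.3f}")
```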

Using the cell below, you can **visualize a random subset of the test patches**, along with the model's **salt likelihood prediction**.

In [None]:
random_test_indices = np.random.choice(np.arange(len(test_patches)), 12, replace=False)
visualize_patches(test_patches[random_test_indices], prediction_values[random_test_indices], 3, 4)

Using the cell below, you can then **view the ground truth labels** of the same patches.

In [None]:
visualize_patches(test_patches[random_test_indices], test_labels[random_test_indices], 3, 4)

In the cell below, we want to **visualize the falsely classified patches** specifically. We achieve this by **computing the residuals**, i.e. the difference between true and predicted labels, and selecting the instances where this value is not zero.

Run the cell below in order to visualize the wrongly classified patches and their corresponding salt likelihood value returned by the model.

In [None]:
residuals = test_labels - prediction_labels
missclassified_indices = np.nonzero(residuals)[0]

# plt.subplot() needs an integer number of rows: round up, and use at least one row
n_rows = max(1, -(-len(missclassified_indices) // 4))
visualize_patches(test_patches[missclassified_indices], prediction_values[missclassified_indices], n_rows, 4)
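
The sign of the residual also tells you which kind of error was made. The sketch below separates false negatives from false positives on made-up labels. It assumes the labels might be stored as unsigned integers (as `prediction_labels` is above) and therefore casts to a signed type before subtracting.

```python
import numpy as np

# Hypothetical ground truth and predicted labels (1 = salt, 0 = background)
toy_truth = np.array([1, 0, 1, 0], dtype=np.uint8)
toy_pred = np.array([1, 1, 0, 0], dtype=np.uint8)

# Cast to a signed type first: with unsigned 8-bit integers, 0 - 1 wraps around to 255
toy_residuals = toy_truth.astype(np.int8) - toy_pred.astype(np.int8)

false_negative_indices = np.flatnonzero(toy_residuals == 1)   # salt predicted as background
false_positive_indices = np.flatnonzero(toy_residuals == -1)  # background predicted as salt
print(false_negative_indices, false_positive_indices)
```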

## Thank you for participating!
This is it for Session 2 of our ML tutorials. We hope you learned something new about training CNN classifiers using TensorFlow and Keras, as well as general machine learning best practices, such as evaluating different metrics, inspecting the loss graph of your model and looking at misclassified instances.

## Where to go from here?
If you want to further improve the model from this session, there are several options. We have seen from the loss curve that training for more epochs could have a beneficial effect.

The model used here is also a rather simple one. Why not try adding additional convolution and pooling layers, or more fully connected layers at the end, and see how that changes the performance? It should be easy to design and test your own CNN models within this notebook.

Finally, you might be able to import your own labelled dataset into this notebook. You do not have to use the same method of loading npz files; the images could also be PNG files from your hard drive (Colab allows you to upload data from your PC) or even slices loaded from a SEG-Y file (see Session 1).

