# Stork net v2 - multiclass classification of stork nest images
*Module ANN & DL  
May 2020  
Ueli Mauch & Stefan Schmutz*

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sschmutz/stork-net/blob/master/scripts/stork_net_v2/stork_net_v2.ipynb)  

The aim of Stork net is to classify images of a stork nest. The model should be able to predict how many storks are present. Images were collected from a publicly available [webcam](https://www.berner-storch.ch/webcam/) and manually labeled.  
For simple access and version control of the labeled images, a GitHub repository [stork-net-dataset](https://github.com/sschmutz/stork-net-dataset) is provided.

This version of Stork net (v2) can classify up to 3 storks (4 classes). It achieves this by training a relatively simple convolutional neural network (CNN).  
A more complex network, for example one built by transfer learning from [MobileNet V2](https://arxiv.org/pdf/1801.04381.pdf), is left for a later version of Stork net.

In [None]:
# import the following necessary libraries
import tensorflow as tf

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import os
import pathlib
import numpy as np
import matplotlib.pyplot as plt

## Loading images
Labeled images are already split and available on GitHub ([stork-net-dataset](https://github.com/sschmutz/stork-net-dataset)).
The full dataset is downloaded with the following command, which also makes it possible to run this notebook in Google Colab.

In [None]:
# the complete labeled dataset will be downloaded to ~/.keras/datasets/
data_dir = tf.keras.utils.get_file(origin="https://github.com/sschmutz/stork-net-dataset/archive/master.zip", fname="stork-net-dataset-master.zip", extract=True)
# remove the .zip file extension since the data is already extracted
data_dir = pathlib.Path(os.path.splitext(data_dir)[0])

# the training split has been further separated into a "train" and a "validation" part
train_dir = pathlib.Path(data_dir, "2019_train", "train")
validation_dir = pathlib.Path(data_dir, "2019_train", "validation")

test_dir = pathlib.Path(data_dir, "2019_test")
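The path handling above assumes that the archive's top-level folder has the same name as the downloaded file without `.zip`. A small stdlib sketch of that logic (the helper names `split_dirs` and `missing_dirs` are hypothetical) that also reports split folders which do not exist, e.g. after a failed extraction:

```python
import os
import pathlib


def split_dirs(archive_path):
    """Return the expected split folders for a downloaded dataset archive.

    Assumes the archive extracts into a folder named like the archive
    without its ".zip" extension.
    """
    data_dir = pathlib.Path(os.path.splitext(archive_path)[0])
    return {
        "train": data_dir / "2019_train" / "train",
        "validation": data_dir / "2019_train" / "validation",
        "test": data_dir / "2019_test",
    }


def missing_dirs(dirs):
    """Names of the splits whose folder does not exist on disk."""
    return [name for name, path in dirs.items() if not path.is_dir()]


# Usage with a hypothetical cache location
dirs = split_dirs("/tmp/datasets/stork-net-dataset-master.zip")
print(dirs["test"])
print(missing_dirs(dirs))
```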

In [None]:
# get the statistics on how many images there are in each split
total_train = len(list(train_dir.glob("*/*.jpg")))
total_val = len(list(validation_dir.glob("*/*.jpg")))
total_test = len(list(test_dir.glob("*/*.jpg")))

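# the class names are also saved; their order is important since the model's output nodes follow the same order
# sorting makes this order deterministic, because glob() returns entries in arbitrary order
class_names = np.array(sorted(item.name for item in train_dir.glob("*") if item.is_dir()))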
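The model's output order follows `class_names`, so that order should be reproducible; `glob()` makes no ordering guarantee. A stdlib sketch, assuming the `<split>/<class>/*.jpg` layout used above, that sorts the class folders alphanumerically (which is what `flow_from_directory` does when `classes` is not passed) and counts the images per class. The function names are hypothetical:

```python
import pathlib
import tempfile


def sorted_class_names(split_dir):
    """Class folder names in alphanumeric order, ignoring stray files."""
    return sorted(p.name for p in pathlib.Path(split_dir).iterdir() if p.is_dir())


def count_images_per_class(split_dir, pattern="*.jpg"):
    """Count images per class in a <split>/<class>/<image> layout."""
    split_dir = pathlib.Path(split_dir)
    return {name: len(list((split_dir / name).glob(pattern)))
            for name in sorted_class_names(split_dir)}


# Usage on a throw-away toy layout with two hypothetical classes
with tempfile.TemporaryDirectory() as tmp:
    root = pathlib.Path(tmp)
    for name, n_images in [("1", 2), ("0", 3)]:
        (root / name).mkdir()
        for i in range(n_images):
            (root / name / f"img_{i}.jpg").touch()
    (root / "notes.txt").touch()  # stray file, not a class folder
    toy_counts = count_images_per_class(root)

print(toy_counts)  # {'0': 3, '1': 2}
```

Called on `train_dir`, `validation_dir` or `test_dir`, the same function gives per-class counts like those in the class imbalance figure further down.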

The images are loaded in batches whose size is set by the variable `batch_size`.
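Integer division such as `total_train // batch_size` rounds down, so a final, partial batch is not counted. A short sketch of the arithmetic, using the split sizes reported further down and `batch_size = 64` (helper names are hypothetical):

```python
import math


def batches_per_epoch(n_images, batch_size):
    """Batches needed to see every image once (the last one may be partial)."""
    return math.ceil(n_images / batch_size)


def last_batch_size(n_images, batch_size):
    """Size of the final batch; equals batch_size if n_images divides evenly."""
    remainder = n_images % batch_size
    return remainder if remainder else batch_size


for split, n_images in [("train", 2057), ("validation", 442), ("test", 441)]:
    print(split, n_images // 64, batches_per_epoch(n_images, 64), last_batch_size(n_images, 64))
# train 32 33 9
# validation 6 7 58
# test 6 7 57
```

For example, 6 validation steps of 64 images cover 384 of the 442 validation images; a seventh step is needed for the remaining 58. The directory iterators returned by `flow_from_directory` report the rounded-up count through `len()`.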

Data augmentation could be defined inside `ImageDataGenerator()`; see the respective section on the [Keras website](https://keras.io/api/preprocessing/image/).  
We tried data augmentation as described in [this tutorial](https://www.tensorflow.org/tutorials/images/classification) (if used, it should only be applied to the training data). It did not improve the model, so it is not included in this version.
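The rule that augmentation is applied to the training data only can be illustrated without any framework. The sketch below uses a random horizontal flip as an example transform (in Keras, the corresponding option is `horizontal_flip=True` on the training `ImageDataGenerator`); images are plain nested lists here, and the helper names are hypothetical:

```python
import random


def random_horizontal_flip(image, rng, p=0.5):
    """Mirror an image left-right with probability p.

    `image` is a list of rows; reversing each row reverses the pixel order
    (a stand-in for a height x width x channels array).
    """
    if rng.random() < p:
        return [row[::-1] for row in image]
    return image


def augment_batch(images, rng, training):
    """Augment only when building a training batch."""
    if not training:
        return images  # validation and test images stay untouched
    return [random_horizontal_flip(img, rng) for img in images]


rng = random.Random(42)  # seeded for reproducibility
toy_image = [[1, 2, 3], [4, 5, 6]]  # 2 x 3 single-channel toy image
train_batch = augment_batch([toy_image] * 4, rng, training=True)
val_batch = augment_batch([toy_image] * 4, rng, training=False)
```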

In [None]:
# convert the images from uint8 to float32 in range [0,1]
train_image_generator = ImageDataGenerator(rescale=1./255)
validation_image_generator = ImageDataGenerator(rescale=1./255)
test_image_generator = ImageDataGenerator(rescale=1./255)

In [None]:
batch_size = 64 
epochs = 15
img_height = 480
img_width = 640
channels = 3
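These settings determine how much memory the input pipeline needs. Back-of-the-envelope arithmetic for one input batch stored as float32 (4 bytes per value, the data type produced by the generators); activations, gradients and weights come on top of this and are not included. The function name is hypothetical:

```python
def input_batch_bytes(batch_size, height, width, channels, bytes_per_value=4):
    """Memory of one batch of images stored as float32 (4 bytes per value)."""
    return batch_size * height * width * channels * bytes_per_value


n_bytes = input_batch_bytes(64, 480, 640, 3)
print(n_bytes, n_bytes / 2**20)  # 235929600 225.0
```

A single batch of 64 images at 480 x 640 x 3 already takes 225 MiB, which is worth keeping in mind when choosing `batch_size` on a GPU with limited memory.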

In [None]:
train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,
                                                           directory=train_dir,
                                                           shuffle=True,
                                                           target_size=(img_height, img_width),
                                                           class_mode="categorical",
                                                           classes=list(class_names))

val_data_gen = validation_image_generator.flow_from_directory(batch_size=batch_size,
                                                              directory=validation_dir,
                                                              shuffle=True,
                                                              target_size=(img_height, img_width),
                                                              class_mode="categorical",
                                                              classes=list(class_names))

test_data_gen = test_image_generator.flow_from_directory(batch_size=batch_size,
                                                         directory=test_dir,
                                                         shuffle=True,
                                                         target_size=(img_height, img_width),
                                                         class_mode="categorical",
                                                         classes=list(class_names))

In [None]:
# get images and respective labels (one-hot encoded)
training_images, training_labels = next(train_data_gen)
validation_images, validation_labels = next(val_data_gen)
test_images, test_labels = next(test_data_gen)

# decode one-hot encoded labels
training_labels_decoded = tf.argmax(training_labels, axis=1)
validation_labels_decoded = tf.argmax(validation_labels, axis=1)
test_labels_decoded = tf.argmax(test_labels, axis=1)

# example of 25 images with their labels
plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(training_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[training_labels_decoded[i]])
plt.show()

There are 2'940 labeled images in total, split as follows:
- 2'057 training (70%)
- 442 validation (15%)
- 441 testing (15%)
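The percentages are rounded; they can be checked directly from the counts (the helper name is hypothetical):

```python
def split_fractions(counts):
    """Share of each split in the total number of labeled images."""
    total = sum(counts.values())
    return {name: n / total for name, n in counts.items()}


fractions = split_fractions({"train": 2057, "validation": 442, "test": 441})
print({name: f"{share:.1%}" for name, share in fractions.items()})
# {'train': '70.0%', 'validation': '15.0%', 'test': '15.0%'}
```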

Looking at the distribution of the classes, images with 3 storks are underrepresented. This is a consequence of sampling the images randomly rather than per class. The class imbalance could bias the model towards the more frequent classes; its effect is examined in the evaluation below.

<img src="figures/class_imbalance.png" alt="class_imbalance" style="width: 500px;"/>

## Creating and training a CNN model
The relatively simple model consists of 3 convolutional layers, each followed by max pooling, and 2 fully connected layers; dropout (rate 0.1) is applied after the first and the third pooling layer. The 4 output nodes use a softmax activation, so their values can be read as the probabilities of the image belonging to the respective classes.

<img src="figures/CNN_model.png" alt="CNN_model" style="height: 500px;"/>
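The softmax turns the 4 raw output scores into values that are positive and sum to 1. A stdlib sketch of the function (in the model, Keras applies it through `activation="softmax"`); the input scores are made up for illustration:

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities that are positive and sum to 1."""
    shift = max(logits)  # subtracting the maximum avoids overflow in exp()
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical raw scores of the 4 output nodes for one image
probs = softmax([2.0, 1.0, 0.1, -1.0])
predicted_class = max(range(len(probs)), key=probs.__getitem__)
```

Subtracting the maximum does not change the result, because the common factor `exp(-shift)` cancels in the division.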

In [None]:
model = Sequential([
    Conv2D(16, 3, padding="same", activation="relu", input_shape=(img_height, img_width, channels)),
    MaxPooling2D(),
    Dropout(0.1),
    Conv2D(32, 3, padding="same", activation="relu"),
    MaxPooling2D(),
    Conv2D(64, 3, padding="same", activation="relu"),
    MaxPooling2D(),
    Dropout(0.1),
    Flatten(),
    Dense(64, activation="relu"),
    Dense(4, activation="softmax")
])

The popular [Adam](https://arxiv.org/pdf/1412.6980.pdf) optimizer worked well for this problem. Categorical cross-entropy is used as the loss since this is a multiclass classification task with one-hot encoded labels, where the output should be a probability for each class.
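For a one-hot label $y$ and predicted probabilities $p$, categorical cross-entropy $-\sum_i y_i \log p_i$ reduces to the negative log of the probability assigned to the correct class, so confident correct predictions are penalised little and uncertain ones more. A stdlib sketch; clipping the probabilities with a small `eps` to avoid `log(0)` mirrors what Keras does, but the constant `1e-7` is an assumption here, and the example probabilities are made up:

```python
import math


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Cross-entropy between a one-hot label and predicted probabilities.

    Predictions are clipped to [eps, 1 - eps] so that log(0) cannot occur.
    """
    return -sum(t * math.log(min(max(p, eps), 1 - eps))
                for t, p in zip(y_true, y_pred))


label = [0, 1, 0, 0]  # one-hot: the image belongs to the second class
confident = categorical_crossentropy(label, [0.05, 0.90, 0.03, 0.02])  # -log(0.9)
unsure = categorical_crossentropy(label, [0.30, 0.40, 0.20, 0.10])     # -log(0.4)
```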

In [None]:
model.compile(optimizer="adam",
              loss="categorical_crossentropy",
              metrics=["accuracy"])

In [None]:
model.summary()
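A hand calculation of what `model.summary()` reports shows where the parameters are. The sketch below assumes the layer configuration defined above: convolutions use 3 x 3 kernels with `"same"` padding, each `MaxPooling2D()` (default pool size 2) halves height and width, and `Dropout` and `Flatten` have no weights. The helper names are hypothetical:

```python
def conv2d_params(kernel_size, in_channels, filters):
    """Weights (k * k * in * out) plus one bias per filter."""
    return kernel_size * kernel_size * in_channels * filters + filters


def dense_params(n_inputs, n_units):
    """Weights (in * out) plus one bias per unit."""
    return n_inputs * n_units + n_units


height, width, channels_in = 480, 640, 3
conv_total = 0
for filters in (16, 32, 64):
    conv_total += conv2d_params(3, channels_in, filters)  # "same" padding keeps height x width
    height, width, channels_in = height // 2, width // 2, filters  # MaxPooling2D() halves both
flattened = height * width * channels_in  # Flatten(); Dropout adds no parameters
dense_total = dense_params(flattened, 64) + dense_params(64, 4)
print(flattened, conv_total, dense_total, conv_total + dense_total)
# 307200 23584 19661124 19684708
```

Almost all of the roughly 19.7 million parameters sit in the first `Dense` layer (19'660'864), because the flattened feature map still has 307'200 values; the three convolutional layers together have only 23'584.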

In [None]:
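# execute the training; len() of a generator is the number of batches (rounded up),
# so the last partial batch of each split is included
history = model.fit(
    train_data_gen,
    steps_per_epoch=len(train_data_gen),
    epochs=epochs,
    validation_data=val_data_gen,
    validation_steps=len(val_data_gen)
)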

In [None]:
# save the trained model
model.save("stork_net_v2.h5")

# this saved model can be loaded using the following command
# model = tf.keras.models.load_model("stork_net_v2.h5")

## Evaluation
To track the training, accuracy and loss on both the training and the validation set are plotted per epoch. This helps to confirm that the model is neither under- nor overfitting.

In [None]:
acc = history.history["accuracy"]
val_acc = history.history["val_accuracy"]

loss = history.history["loss"]
val_loss = history.history["val_loss"]

epochs_range = range(epochs)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label="Training Accuracy")
plt.plot(epochs_range, val_acc, label="Validation Accuracy")
plt.legend(loc="lower right")
plt.title("Training and Validation Accuracy")

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label="Training Loss")
plt.plot(epochs_range, val_loss, label="Validation Loss")
plt.legend(loc="upper right")
plt.title("Training and Validation Loss")
plt.show()
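The plots are the main check, but the same idea can be expressed as a rule of thumb in code: a validation loss that stops improving while the training loss keeps falling suggests overfitting. A minimal sketch working on lists like those stored in `history.history`; the function name and the `patience` threshold of 3 epochs are arbitrary assumptions, and the toy histories are made up:

```python
def overfitting_report(train_loss, val_loss, patience=3):
    """Return (best_epoch, suspicious) for a Keras-style loss history.

    best_epoch is the 0-based epoch with the lowest validation loss.
    suspicious is True if validation loss has not improved for `patience`
    epochs while training loss kept decreasing.
    """
    best_epoch = min(range(len(val_loss)), key=val_loss.__getitem__)
    epochs_since_best = len(val_loss) - 1 - best_epoch
    suspicious = epochs_since_best >= patience and train_loss[-1] < train_loss[best_epoch]
    return best_epoch, suspicious


# Toy histories (hypothetical numbers, not results of this notebook)
toy_train = [1.00, 0.80, 0.60, 0.50, 0.40, 0.30]
toy_val = [1.00, 0.85, 0.70, 0.75, 0.80, 0.90]
print(overfitting_report(toy_train, toy_val))  # (2, True)
```

In the notebook it would be called as `overfitting_report(history.history["loss"], history.history["val_loss"])`. Keras offers a related mechanism during training through the `tf.keras.callbacks.EarlyStopping` callback (`monitor`, `patience`).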

Finally, the model is evaluated by computing its loss and accuracy on the test set.

In [None]:
# evaluate on all batches of the test set, not only the single batch loaded above
test_loss, test_acc = model.evaluate(test_data_gen, verbose=0)

print("\nTest loss:", test_loss)
print("\nTest accuracy:", test_acc)
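With one-hot labels and softmax outputs, the `"accuracy"` metric in this setup is the share of images whose most probable class matches the label. A framework-free sketch of that computation (helper names and example values are hypothetical):

```python
def argmax(values):
    """Index of the largest value (the first one in case of a tie)."""
    return max(range(len(values)), key=values.__getitem__)


def accuracy(one_hot_labels, probabilities):
    """Share of samples whose most probable class equals the labeled class."""
    hits = sum(argmax(y) == argmax(p) for y, p in zip(one_hot_labels, probabilities))
    return hits / len(one_hot_labels)


labels = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]]
probs = [[0.7, 0.1, 0.1, 0.1],     # correct
         [0.6, 0.3, 0.05, 0.05],   # wrong: class 0 predicted for class 1
         [0.1, 0.1, 0.2, 0.6]]     # correct
print(accuracy(labels, probs))
```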

Looking at the misclassified images across all labeled images, the predicted number of storks never differs from the reference label by more than 1.  
The numbers in the tiles of the following figure represent the number of images, while the color shows the percentage of misclassified images per labeled class.  
This shows that the model has more trouble classifying images with 3 storks than images of the other classes. As suspected, the class imbalance in the labeled images may be the cause.

<img src="figures/misclassification.png" alt="misclassification" style="width: 500px;"/>
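The figure summarises a confusion matrix. A stdlib sketch of how such a summary can be computed from integer labels and predictions (e.g. `test_labels_decoded` and `predictions_decoded` converted to Python ints). It assumes, hypothetically, that class index `i` corresponds to `i` storks, which holds only if the class folders sort in that order; the toy labels are made up:

```python
def confusion_matrix(labels, predictions, n_classes):
    """matrix[r][p] counts images with reference class r predicted as class p."""
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for r, p in zip(labels, predictions):
        matrix[r][p] += 1
    return matrix


def misclassification_rate_per_class(matrix):
    """Share of each reference class that was predicted as another class."""
    rates = []
    for r, row in enumerate(matrix):
        total = sum(row)
        rates.append((total - row[r]) / total if total else 0.0)
    return rates


def max_count_error(labels, predictions):
    """Largest |reference - prediction|, assuming class index == number of storks."""
    return max((abs(r - p) for r, p in zip(labels, predictions)), default=0)


toy_labels = [0, 1, 1, 2, 3, 3]
toy_predictions = [0, 1, 2, 2, 2, 3]
matrix = confusion_matrix(toy_labels, toy_predictions, n_classes=4)
print(misclassification_rate_per_class(matrix))  # [0.0, 0.5, 0.0, 0.5]
print(max_count_error(toy_labels, toy_predictions))  # 1
```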

## Prediction
This section demonstrates how to make predictions with the trained model, applied here to one batch of test images.

In [None]:
predictions = model.predict(test_images)

# pick the class with the highest predicted probability for each image
predictions_decoded = tf.argmax(predictions, axis=1)

In [None]:
# plot the misclassified images of one test batch (at most 25, the size of the 5x5 grid)
plt.figure(figsize=(20,20))
n_misclassified = 0

for i in range(len(test_images)):
    prediction = class_names[predictions_decoded[i]]
    label = class_names[test_labels_decoded[i]]

    if prediction != label and n_misclassified < 25:
        n_misclassified += 1
        plt.subplot(5,5,n_misclassified)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(test_images[i], cmap=plt.cm.binary)
        plt.xlabel("prediction: %s \n label: %s" % (prediction, label))

plt.show()

## Conclusion

Despite using a simple CNN, the number of storks present can be classified relatively well. The trends of accuracy and loss over the epochs do not show strong signs of overfitting.  
However, with an accuracy of around 90%, the performance could (and probably should) be improved further.

## Outlook
In future steps, the underrepresentation of class 3 (and the lack of classes for more than 3 storks) should be tackled; labeling additional images of those classes would be the next task.  
To further improve performance, transfer learning from a more sophisticated network (e.g. [MobileNet V2](https://arxiv.org/pdf/1801.04381.pdf)) could be applied.