# Deep learning with Python project: plant pathology recognition
- David Coba (student number: 12439665)
- Enrico Erler (student number: 13287214)

https://www.kaggle.com/dcobac/dl-project

# 1. Setup

In [None]:
import os
import time
import numpy as np
import pandas as pd
from IPython.display import Image
import tensorflow as tf 
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib
import pickle
print("All libs loaded")

# 2. Introduction

This project is based on the Kaggle competition [Plant Pathology 2021 - FGVC8](https://www.kaggle.com/c/plant-pathology-2021-fgvc8).
The goal of the competition is to classify leaves of apple trees. Some of them are healthy and others show evidence of disease.
Developing a classifier is useful in this situation because it can reduce the number of trees that have to be checked by an inspector, potentially saving a considerable amount of hours. 

Originally images can be labeled as either:
  - Healthy
  - Having apple `scab`
  - Having apple `rust`
  - Having `frog_eye_leaf_spot`
  - Having a `complex` disease, which can be
    - Just the `complex` label
    - The `complex` label & other disease labels
    
At first we attempted to tackle the problem as it is presented: as a multi-label classification task in which classes can occur simultaneously, with sigmoid activation functions mapping activations to per-label probabilities.
However, after unsuccessful attempts we decided to simplify the problem to just the categories present in the [2020 edition of the competition](https://www.kaggle.com/c/plant-pathology-2020-fgvc7), and treat the `complex` label as a single category. It thus becomes a multi-class problem with 4 mutually exclusive categories. The authors of the competition report that they achieved 97% classification accuracy with "an off the shelf" `resnet50` classifier [(Thapa, Zhang, Snavely, Belongie & Khan, 2020)](https://bsapubs.onlinelibrary.wiley.com/doi/10.1002/aps3.11390), and this is the benchmark we are trying to match.
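
The difference between the two formulations comes down to the output activation. The following is a minimal sketch in plain Python, not part of our pipeline: the logit values are made up, and the helper names `sigmoid` and `softmax` are introduced here only for illustration. Sigmoid scores each label independently, so several labels can be "on" at once, whereas softmax yields a single distribution over mutually exclusive classes.

```python
import math

def sigmoid(logits):
    """Independent per-label probabilities (multi-label setting)."""
    return [1.0 / (1.0 + math.exp(-z)) for z in logits]

def softmax(logits):
    """One probability distribution over mutually exclusive classes."""
    shifted = [z - max(logits) for z in logits]  # subtract the max for numerical stability
    exps = [math.exp(z) for z in shifted]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for the classes [complex, healthy, rust, scab]
logits = [2.0, -1.0, 0.5, 1.5]

multi_label_probs = sigmoid(logits)   # each value lies in (0, 1); they need not sum to 1
multi_class_probs = softmax(logits)   # values sum to 1

print([round(p, 3) for p in multi_label_probs])
print([round(p, 3) for p in multi_class_probs])
```

With sigmoid outputs a threshold per label decides which labels are assigned, while with softmax outputs the predicted class is simply the one with the highest probability.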

- Example healthy image

In [None]:
Image(filename="../input/plant-pathology-2021-fgvc8/train_images/b7bcca8ce84f5046.jpg", width=602, height=476) 

- Example image with a visible disease (in this case `scab`)

In [None]:
Image(filename="../input/plant-pathology-2021-fgvc8/train_images/803e3bd17a16e65c.jpg", width=602, height=476) 

# 3. Data loading

In [None]:
data_df = pd.read_csv("../input/plant-pathology-2021-fgvc8/train.csv")
print("Data shape: ",data_df.shape)
data_df.head()

In [None]:
data_df["labels"] = data_df.labels.apply(lambda x: x.split())
data_df.head()

- Simplify data from 6 categories (including multi-label `complex` category) to 4 mutually exclusive categories to match 2020 competition outcomes:
  - `healthy`
  - `rust`
  - `scab`
  - `complex`

In [None]:
def simplify_outcome(y):
    """
    Function to apply to single dataframe cell. 

    Converts entries to ["complex"], if:
        1) it's a multi-label outcome and,
        2) it contains one of the original 2020 plant pathology competition outcomes.

    Input: single cell entry (list of label strings)

    Output: original outcome (if conditions not met) or ["complex"]
    """

    # `and` (rather than bitwise `&`) is the idiomatic way to combine two booleans
    if len(y) > 1 and any(x in ["healthy","rust","scab"] for x in y):
        return ["complex"]
    else:
        return y
    
# apply function per row in column "labels"
data_df.labels = data_df.labels.apply(simplify_outcome)
# only select original 2020 plant pathology competition outcomes
simple_labels = ["healthy","rust","scab","complex"]
only_simple_labels = data_df.labels.apply(lambda x:
    any(item for item in simple_labels if (item in x) and (len(x) == 1))
)

# .copy() makes the filtered frame independent, which avoids pandas' SettingWithCopyWarning below
data_df = data_df[only_simple_labels].copy()

# Unlist labels: [a] -> a
data_df["labels"] = data_df["labels"].apply(lambda x: x[0])

print("Outcomes:\n",data_df.labels.value_counts())
print("\nData shape: ", data_df.shape)
data_df.head()
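
To make the effect of the two filtering steps concrete, here is a small self-contained sketch that applies the same rules to a handful of raw label strings. The helper `simplify` is a condensed, hypothetical stand-in for the cell above (it is not used elsewhere in the notebook), and the example strings are chosen for illustration.

```python
def simplify(label_string):
    """Return the simplified label, or None if the image is dropped."""
    labels = label_string.split()
    # Multi-label entries that contain a 2020 category collapse to "complex"
    if len(labels) > 1 and any(l in ("healthy", "rust", "scab") for l in labels):
        labels = ["complex"]
    # Keep only single-label entries from the four retained categories
    if len(labels) == 1 and labels[0] in ("healthy", "rust", "scab", "complex"):
        return labels[0]
    return None

examples = [
    "healthy",
    "scab frog_eye_leaf_spot",
    "complex",
    "frog_eye_leaf_spot",
    "frog_eye_leaf_spot complex",
]
for raw in examples:
    print(f"{raw!r:32} -> {simplify(raw)}")
```

Note that entries such as `frog_eye_leaf_spot complex`, which are multi-label but contain none of `healthy`, `rust` or `scab`, are dropped rather than relabelled.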

- Split data into training, validation and testing sets.
- Data was split so that 10% of the data was reserved for testing, ~20% for validation, and ~70% for training.
- Images were loaded with TensorFlow's `ImageDataGenerator` and `flow_from_dataframe` pipeline.
- The images were resized to 256 by 256 pixels, as bigger images could not be stored in the available VRAM (11 GB) without reducing the batch size too much.
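
The split is done in two stages (first the test set, then a validation fraction of what remains), so the fractions have to be multiplied to get the overall proportions. A quick check of the arithmetic, using the `test_size` and `validation_split` values from the cells below:

```python
test_fraction = 0.10      # held out by train_test_split
validation_split = 0.22   # fraction of the REMAINING data used for validation

remaining = 1.0 - test_fraction
valid_fraction = remaining * validation_split
train_fraction = remaining * (1.0 - validation_split)

print(f"train: {train_fraction:.3f}, valid: {valid_fraction:.3f}, test: {test_fraction:.3f}")
# train: 0.702, valid: 0.198, test: 0.100
```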


In [None]:
train_df, test_df = train_test_split(
    data_df, #.iloc[1:100, ] # change to only load a subset of the data
    test_size = 0.1,
    random_state = 123 # set seed for reproducibility
)

In [None]:
img_size = 256

In [None]:
data_generator = ImageDataGenerator(
    rescale=1./255,
    validation_split = 0.22 # results in: ~70% train, ~20% valid & 10% test data
)

test_generator = ImageDataGenerator(rescale=1./255)

In [None]:
train_loader = data_generator.flow_from_dataframe(
    seed = 123,
    dataframe = train_df,
    directory = "../input/plant-pathology-2021-fgvc8/train_images/",
    subset="training",
    x_col = "image",
    y_col = "labels",
    target_size = (img_size, img_size), 
    class_mode = "categorical",
    batch_size = 32,
    shuffle = True
)

valid_loader = data_generator.flow_from_dataframe(
    seed = 123,
    dataframe = train_df,
    directory= "../input/plant-pathology-2021-fgvc8/train_images/",
    subset="validation",
    x_col = "image",
    y_col = "labels",
    target_size = (img_size, img_size), 
    class_mode = "categorical",
    batch_size = 32,
    shuffle = True
)

test_loader = test_generator.flow_from_dataframe(
    seed = 123,
    dataframe = test_df,
    directory= "../input/plant-pathology-2021-fgvc8/train_images/",
    x_col = "image",
    y_col = "labels",
    target_size = (img_size, img_size), 
    class_mode = "categorical",
    batch_size = 32,
    shuffle = False # Makes easier to assess test predictions
)

print("Finished loading data.")

# 4. Resnet & feature extraction

We have used `resnet50` to extract features from the pictures. With 256x256 inputs, the body of the network outputs 2048 feature maps of size 8x8, and flattening these directly made training not viable with the hardware we have access to. We considered either a `MaxPool` pass reducing the feature maps to 4x4 before flattening, or global average pooling. We settled on an extra `MaxPool` pass with a 2x2 kernel, valid padding and a stride of 2, resulting in 2048 feature maps of size 4x4. 
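
To give a feel for the sizes involved, the sketch below compares the number of features each option hands to the classifier, and the resulting parameter count of a first dense layer with 64 units (the width we use later). It is plain arithmetic rather than Keras code; the helper `pooled_size` and the dictionary `options` are illustrative names, and the factor of 32 is the overall downsampling of the ResNet50 body.

```python
def pooled_size(size, pool=2, stride=2):
    """Output side length of a pooling layer with 'valid' padding."""
    return (size - pool) // stride + 1

img_size = 256
channels = 2048           # channels in the final ResNet50 feature map
side = img_size // 32     # ResNet50 downsamples by a factor of 32 -> 8

options = {
    "flatten only": side * side * channels,
    "2x2 max pool + flatten": pooled_size(side) ** 2 * channels,
    "global average pooling": channels,
}

dense_units = 64  # width of the first hidden layer of the classifier
for name, n_features in options.items():
    n_params = n_features * dense_units + dense_units  # weights + biases
    print(f"{name:24} {n_features:>7} features, {n_params:>9} parameters in the first Dense layer")
```

The extra pooling pass cuts the classifier input by a factor of four while keeping some spatial information, which global average pooling discards entirely.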

- Select whether to load the `resnet50` body or not. It can be trainable or not.
  - Unless we are training a `resnet50` model or extracting its features this is not necessary.

In [None]:
LOAD_RESNET = False
TRAINABLE = False

In [None]:
if LOAD_RESNET:
    resnet_base = keras.applications.resnet.ResNet50(
        include_top = False,
        weights = "imagenet",
        input_shape = (img_size, img_size, 3)
    )
    
    resnet_base.trainable = TRAINABLE

    model_resnet_body = keras.Sequential([
        resnet_base,
        keras.layers.MaxPool2D(
            pool_size=(2, 2), 
            strides=2, 
            padding="valid"
        ), 
        layers.Flatten()
        # layers.GlobalAveragePooling2D()
    ])

    model_resnet_body.compile()
    print("resnet loaded")
    

Training models with this data set takes a long time, mainly because of bottlenecks in the IO streams loading pictures into memory.
If we are not interested in fine-tuning the `resnet50` weights, we can pre-extract all features from the images before training the classifier to avoid having to evaluate the `resnet50` body at every iteration.
However, we ultimately abandoned this idea, since it made augmenting the dataset significantly more difficult and we wanted to fine-tune the model weights.

- It works, but without data augmentation.

In [None]:
EXPORT_RESNET_FEATURES = False
IMPORT_RESNET_FEATURES = False

In [None]:
if EXPORT_RESNET_FEATURES:
    def serialize_features(resnet, loader, name):
        """
        Extracts features with `resnet` and pickles them to `{name}.h5`,
        with the matching labels in `{name}_label.h5`.
        """
        # NOTE: loader.filenames and loader.classes are in file order, so the loader
        # must not shuffle (shuffle = False) for labels and features to line up.
        df = pd.DataFrame({"image": loader.filenames, "labels": loader.classes})
        with open(f"../input/resnet_features/{name}_label.h5", "wb") as file:
            pickle.dump(df, file)
        
        time_0 = time.time()
        features = resnet.predict(loader)
        features = tf.convert_to_tensor(features)        
        with open(f"../input/resnet_features/{name}.h5", "wb") as file:
            pickle.dump(features, file)
        time_1 = time.time()

        elapsed = round(time_1 - time_0, ndigits = 1)
        print(f"{name} features extracted in {elapsed}s.")
        return df, features

    # File names match the ones read in the import branch below
    labels_train, features_train = serialize_features(model_resnet_body, train_loader, "train")
    labels_valid, features_valid = serialize_features(model_resnet_body, valid_loader, "valid")
    labels_test, features_test = serialize_features(model_resnet_body, test_loader, "test")

if IMPORT_RESNET_FEATURES: # load pre-saved resnet features
    print("Loading pre-extracted features")

    with open("../input/resnet_features/train.h5", "rb") as file:
        features_train = pickle.load(file)
    with open("../input/resnet_features/train_label.h5", "rb") as file:
        labels_train = pickle.load(file)
    print("Training features loaded as a tensor")

    with open("../input/resnet_features/valid.h5", "rb") as file:
        features_valid = pickle.load(file)
    with open("../input/resnet_features/valid_label.h5", "rb") as file:
        labels_valid = pickle.load(file)
    print("Validation features loaded as a tensor")

    with open("../input/resnet_features/test.h5", "rb") as file:
        features_test = pickle.load(file)
    with open("../input/resnet_features/test_label.h5", "rb") as file:
        labels_test = pickle.load(file)
    print("Testing features loaded as a tensor")

- Sanity check

In [None]:
if EXPORT_RESNET_FEATURES or IMPORT_RESNET_FEATURES:
    if train_loader.n != features_train.shape[0] or \
        valid_loader.n != features_valid.shape[0] or \
        test_loader.n != features_test.shape[0] :
        print(f"Training data: {len(train_loader.filenames)} in the data loader vs {features_train.shape[0]} in the data array.")
        print(f"Validating data: {len(valid_loader.filenames)} in the data loader vs {features_valid.shape[0]} in the data array.")
        print(f"Testing data: {len(test_loader.filenames)} in the data loader vs {features_test.shape[0]} in the data array.")
        raise Exception("The dimensions of the loaded features and the data set do not match.")
    else:
        print("All good")

- Encode labels as one-hot tensors.

In [None]:
if EXPORT_RESNET_FEATURES or IMPORT_RESNET_FEATURES:
    target_train = tf.one_hot(labels_train.labels, len(train_loader.class_indices))
    target_valid = tf.one_hot(labels_valid.labels, len(valid_loader.class_indices))
    target_test  = tf.one_hot(labels_test.labels, len(test_loader.class_indices))

# 5. Model training

- All models we have trained have similar architectures:
  - A set of data augmentation layers
  - A CNN body
    - resnet50 or 
    - resnet50 with trainable weights or
    - a custom CNN defined below
  - A regularized classifier with two hidden layers of 64 & 32 neurons

- Data augmentation for training

In [None]:
data_augmentation_layers = keras.Sequential([
        layers.RandomFlip("horizontal_and_vertical"),
        layers.RandomContrast(0.5),
        layers.RandomRotation(
                factor = 0.3,
                fill_mode = "nearest",
                interpolation = "nearest"
                ),
        layers.RandomTranslation(
                height_factor = 0.2, 
                width_factor = 0.2,
                fill_mode = "nearest",
                interpolation = "nearest"
                ),
])

- Custom CNN 

In [None]:
def make_conv(**kwargs):
    """
    Returns a new convolutional layer with our default settings.
    A fresh layer is created on every call, so no weights are shared between
    positions in the network (adding the same layer instance to a model more
    than once would either tie its weights or be rejected by Keras).
    """
    return layers.Conv2D(
        10, 
        kernel_size = 3, 
        strides = 1, 
        padding = "same",
        kernel_initializer = keras.initializers.GlorotNormal(),
        **kwargs
    )

custom_cnn = keras.Sequential([
    make_conv(input_shape = (img_size, img_size, 3)),
    make_conv(),
    layers.MaxPool2D(),
    make_conv(),
    make_conv(),
    layers.MaxPool2D(),
    layers.Flatten()
])

In [None]:
def train_model(cnn, name):
    """
    Trains a model with a specific CNN body.
    """
    model = keras.Sequential([
        data_augmentation_layers,
        cnn,
        # Classifier:
        layers.BatchNormalization(),
        layers.Dropout(rate = 0.3),
        layers.Dense(64, activation = "relu"),
        layers.Dropout(rate = 0.3),
        layers.Dense(32, activation = "relu"),
        layers.Dropout(rate = 0.3),
        layers.Dense(4, activation = "softmax")
    ])

    model.compile(
        optimizer = "adam",
        loss = "categorical_crossentropy", 
        metrics = ["accuracy"]
    )
    
    all_callbacks = [
    tf.keras.callbacks.EarlyStopping(
        monitor = "val_loss",
        min_delta = 0.005,
        patience=15, 
        restore_best_weights = True,
        ),
    tf.keras.callbacks.ModelCheckpoint(
        filepath = f"../input/models/models/{name}/best_model.h5", # only saving best model
        verbose = 0,
        save_best_only = True,
        save_weights_only = False,
        mode = "auto",
        save_freq = "epoch"
        ),                                
    tf.keras.callbacks.TensorBoard(
        log_dir = f"../input/models/models/{name}/logs",
        histogram_freq = 1
        ),
    ]

    history = model.fit(train_loader,
        validation_data = valid_loader,
        epochs = 50,
        callbacks = all_callbacks
    )

    print(f"Saving {name} model ...")
    model.save(f"../input/models/models/{name}/model.h5"); 
    print("Model saved")
    print("Saving history...")
    history = history.history
    with open(f"../input/models/models/{name}/history.h5", "wb") as file:
        pickle.dump(history, file)
    print("History saved")
    return model, history

In [None]:
def load_model(name):
    """
    Loads a pre-trained model & history.
    """
    print(f"Loading {name} model")
    model = keras.models.load_model(f"../input/models/models/{name}/model.h5");
    with open(f"../input/models/models/{name}/history.h5", "rb") as file:
        history = pickle.load(file)
    print("Last model loaded")
    return model, history


- Select whether to train a model in this session, and which model.

In [None]:
TRAIN_MODEL = False

In [None]:
def get_resnet_name():
    if TRAINABLE:
        return "resnet_trainable"
    else: return "resnet_not_trainable"

- Select which model to train.

In [None]:
# CNN, CNN_NAME = model_resnet_body, get_resnet_name() # or
# CNN, CNN_NAME = custom_cnn, "custom_cnn"

- Train chosen model and load the last version of the others, or load all previous versions

In [None]:
if TRAIN_MODEL:
    if CNN_NAME == "custom_cnn": 
        model_custom_cnn, history_custom_cnn  = train_model(CNN, CNN_NAME)
        model_resnet_trainable, history_resnet_trainable = load_model("resnet_trainable")
        model_resnet_not_trainable, history_resnet_not_trainable = load_model("resnet_not_trainable")
    elif CNN_NAME == "resnet_trainable":
        model_resnet_trainable, history_resnet_trainable = train_model(CNN, CNN_NAME)
        model_custom_cnn, history_custom_cnn = load_model("custom_cnn")
        model_resnet_not_trainable, history_resnet_not_trainable = load_model("resnet_not_trainable")
    elif CNN_NAME == "resnet_not_trainable":
        model_resnet_not_trainable, history_resnet_not_trainable = train_model(CNN, CNN_NAME)
        model_custom_cnn, history_custom_cnn = load_model("custom_cnn")
        model_resnet_trainable, history_resnet_trainable = load_model("resnet_trainable")
else: # Just load all pre-trained models
    model_custom_cnn, history_custom_cnn = load_model("custom_cnn")
    model_resnet_trainable, history_resnet_trainable = load_model("resnet_trainable")
    model_resnet_not_trainable, history_resnet_not_trainable = load_model("resnet_not_trainable")
 

# 6. Model evaluation & discussion

## 6.1 Accuracy

In [None]:
test_accuracy = pd.Series(
    {"custom_cnn" : model_custom_cnn.evaluate(test_loader)[1],
    "resnet_trainable" : model_resnet_trainable.evaluate(test_loader)[1],
    "resnet_not_trainable" : model_resnet_not_trainable.evaluate(test_loader)[1]
    }
)

In [None]:
acc_plot = test_accuracy.plot(
    kind = "bar", 
    title = "Accuracy on the testing set",
    ylabel = "Accuracy (proportion)", # values are in [0, 1], not percentages
    ylim = (0, 1)
)

In [None]:
# fig = acc_plot.get_figure()
# fig.savefig("../output/local/accuracies.svg", format = "svg", dpi = 1200)

- The custom CNN did not perform well at all under these training conditions.
- The `resnet50` with its original `imagenet` weights does not reach high accuracy either (around 50%).
- Fine-tuning its weights during training increases accuracy to around 86%.

## 6.2 Learning slopes

Training these models takes around 20-40 min per epoch, depending on the model and on the hardware it is trained on.
This means that we cannot meaningfully compare training times across models.
Because of the long training times, we have not been able to run models for longer than ~30 epochs, and the performance of the models could potentially improve with more training.
We can, however, check how the loss function & accuracy evolve over training to assess whether more training would help increase out-of-sample prediction accuracy.
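
Besides eyeballing the plots, two simple numbers help summarize a training history: the epoch with the lowest validation loss, and the gap between validation and training loss at the end of training (a growing gap points to overfitting). The sketch below assumes a history stored as a dictionary of per-epoch lists, like the ones we pickle after training. The function name `summarize_history` and the numbers in `toy_history` are made up for illustration and are not results from our runs.

```python
def summarize_history(history):
    """Return the (0-based) epoch with the lowest validation loss and the
    validation/training loss gap at the final epoch."""
    val_loss = history["val_loss"]
    best_epoch = min(range(len(val_loss)), key=val_loss.__getitem__)
    final_gap = val_loss[-1] - history["loss"][-1]
    return best_epoch, final_gap

# Made-up numbers for illustration only
toy_history = {
    "loss":     [1.20, 0.90, 0.70, 0.55, 0.45, 0.38],
    "val_loss": [1.25, 1.00, 0.85, 0.80, 0.86, 0.93],
}

best_epoch, final_gap = summarize_history(toy_history)
print(f"lowest val_loss at epoch {best_epoch}, final gap {final_gap:.2f}")
```

In this toy example the training loss keeps falling while the validation loss bottoms out at epoch 3 and then rises, which is the pattern we describe as overfitting below.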

In [None]:
pd.DataFrame(history_custom_cnn).loc[:, ["loss", "val_loss"]].plot(title = "Custom CNN")
# The history keys follow the metric name passed to compile(), i.e. "accuracy"
pd.DataFrame(history_custom_cnn).loc[:, ["accuracy", "val_accuracy"]].plot(title = "Custom CNN")

In this case it is hard to tell whether further training would improve the performance of the model. Although the loss function and the accuracy metric have slopes that indicate that they would probably keep improving, it is hard to see if that would be the case for their counterparts on the validation set, since the values make big jumps after each iteration of training. Nevertheless, overfitting is clearly visible after only a few epochs.

In [None]:
pd.DataFrame(history_resnet_not_trainable).loc[:, ["loss", "val_loss"]].plot(title = "Non-trainable resnet50")
pd.DataFrame(history_resnet_not_trainable).loc[:, ["accuracy", "val_accuracy"]].plot(title = "Non-trainable resnet50")

The validation-set metrics of the model using `resnet50` with non-trainable weights seem to have flattened out, so it looks like the model would not benefit from additional training. However, this is not entirely clear: these might be temporary plateaus, and the performance of the model could start improving again after further training. Overfitting also sets in after around epoch 20.

In [None]:
pd.DataFrame(history_resnet_trainable).loc[:, ["loss", "val_loss"]].plot()
trainable_history = pd.DataFrame(history_resnet_trainable).loc[:, ["accuracy", "val_accuracy"]].plot()

In [None]:
# fig = trainable_history.get_figure()
# fig.savefig("../output/local/trainable_history.svg", format = "svg", dpi = 1200)

## 6.3 Category predictions

Next, we examine the actual predictions of each model.

In [None]:
test_filenames = pd.DataFrame({"image": test_loader.filenames})
test_labels = pd.DataFrame({"labels": test_loader.classes})
# Rename labels to original strings
test_labels = test_labels.replace({"labels": {v: k for (k, v) in test_loader.class_indices.items()}})

def get_model_predictions(model):
    predicted_probs = pd.DataFrame(model.predict(x = test_loader))
    # Rename columns to class labels
    predicted_probs.rename(columns = {v : k for (k, v) in
                                      test_loader.class_indices.items() },
                           inplace = True)
    y_pred = pd.DataFrame(predicted_probs.iloc[:, :].idxmax(axis = 1)).rename(columns = {0 : "y_pred"})
    predictions = pd.concat([
        test_filenames,
        test_labels,
        y_pred,
        predicted_probs
    ],
                            axis = 1)
    return predictions

In [None]:
predictions_custom_cnn = get_model_predictions(model_custom_cnn)
predictions_resnet_not_trainable = get_model_predictions(model_resnet_not_trainable)
predictions_resnet_trainable = get_model_predictions(model_resnet_trainable)

In [None]:
predictions_custom_cnn

We already showed that the model with a custom CNN did not perform well, and we can see that it is just predicting the `rust` label for every picture.

In [None]:
predictions_resnet_not_trainable

The model with `resnet50` with non-trainable weights does not predict the same label for every picture, but it is not performing very well either since it had ~0.5 accuracy.
However, we cannot detect any unusual pattern in the activations of the last layer either.

In [None]:
predictions_resnet_trainable

The model with `resnet50` with trainable weights predicts almost all pictures correctly.

We can see in more detail which categories the `resnet50` models classify properly and which ones they do not by taking a look at the confusion matrix.
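
As a reminder of how such a matrix is read, here is a minimal from-scratch sketch that follows the same convention as scikit-learn's `confusion_matrix` (rows are true classes, columns are predicted classes). The helper `confusion_counts` and the label lists are made up for illustration.

```python
def confusion_counts(y_true, y_pred, labels):
    """Count matrix where rows are true classes and columns are predicted classes."""
    index = {label: i for i, label in enumerate(labels)}
    cm = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        cm[index[t]][index[p]] += 1
    return cm

labels = ["complex", "healthy", "rust", "scab"]
# Made-up predictions for illustration only
y_true = ["healthy", "healthy", "scab", "rust", "complex", "scab"]
y_pred = ["healthy", "scab",    "scab", "rust", "complex", "scab"]

cm = confusion_counts(y_true, y_pred, labels)
for label, row in zip(labels, cm):
    print(f"{label:8} {row}")

# Correct predictions lie on the diagonal
accuracy = sum(cm[i][i] for i in range(len(labels))) / len(y_true)
print("accuracy:", round(accuracy, 3))
```

Here the single off-diagonal entry sits in the `healthy` row and the `scab` column: one healthy leaf was predicted as `scab`.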

In [None]:
def generate_cm(predictions):
    cm = confusion_matrix(predictions.labels, predictions.y_pred, labels = ["complex", "healthy", "rust", "scab"], normalize = None)
    disp = ConfusionMatrixDisplay(confusion_matrix = cm, display_labels = np.array(["complex","healthy","rust","scab"]))
    disp.plot(cmap = "Blues") 
    print("Accuracy from confusion matrix: ", round(np.diag(cm).sum()/cm.sum(), 4))
    return disp

In [None]:
generate_cm(predictions_resnet_not_trainable)

We see that the classifier with non-trainable weights never predicts the label `rust`, and that it overestimates the probability of a picture belonging to the `scab` category.

In [None]:
cm_plot = generate_cm(predictions_resnet_trainable)

In [None]:
# cm_plot.figure_.savefig("../output/local/cm_plot.svg", format = "svg", dpi = 1200)

Our best-performing model (the `resnet50` with trainable weights) predicts around 86% of the pictures correctly, and most of the values lie on the diagonal of the confusion matrix. Nevertheless, it misclassifies quite a few of the `healthy` leaves as `scab`. We think this is likely because scab often shows as light dots on the leaf in its early stages. Depending on the lighting conditions and sun reflections, healthy leaves can also have light spots (due to the sun shining on them or partially through them). We think this is one of the reasons all models have a higher misclassification rate on `healthy` vs. `scab`.

Here you can see one example of a `healthy` leaf that was misclassified by the untrained resnet as `scab`:

In [None]:
Image(filename="../input/plant-pathology-2021-fgvc8/train_images/8bd27e8d6124a5b3.jpg", width=301, height=238) 

Due to shadows and lighting, the leaf appears to have bright spots. With human experience, one can clearly make out that those are shadows and sun rays hitting the leaf's surface. However, we think that the models could not make this generalization, as they only have the leaf pictures and know nothing about how sun and shadows interact. All they see are bright spots on the leaf, which they take for `scab` because a lot of pictures with bright spots are labelled `scab`. We think this is a prime example of a generalization problem in machine learning/AI.

## 6.4 Mean F1-scores


The Kaggle competition actually uses mean F1-scores to assess the models. To make our results comparable to those of other teams on the competition leaderboard, we also compute mean F1-scores from our predictions.

The F1-score per category is calculated as follows:

$$
  F1 = \frac{TP}{TP+\frac{1}{2}(FP+FN)}
$$ 


where $TP$ is the number of true positives (pictures correctly assigned to the category), $FP$ the number of false positives (pictures wrongly assigned to the category) and $FN$ the number of false negatives (pictures of the category that were assigned to another one).

The mean F1-score is the mean of the individual F1-scores over the $n$ categories:

$$
  F1_{mean}  = \frac{1}{n}  \sum_{k=1}^n F1_{k}
$$ 
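
As a worked example, the sketch below computes per-category F1-scores and their mean from a small confusion matrix, using the standard definition $F1 = TP / (TP + \frac{1}{2}(FP + FN))$. The matrix `toy_cm` and the helper `per_class_f1` are made up for illustration; rows are true categories and columns are predicted categories.

```python
def per_class_f1(cm):
    """cm[i][j] = number of samples with true class i predicted as class j."""
    n = len(cm)
    scores = []
    for k in range(n):
        tp = cm[k][k]
        fp = sum(cm[i][k] for i in range(n)) - tp   # column total minus diagonal
        fn = sum(cm[k]) - tp                        # row total minus diagonal
        denom = tp + 0.5 * (fp + fn)
        scores.append(tp / denom if denom > 0 else 0.0)
    return scores

# Made-up 3-class confusion matrix for illustration only
toy_cm = [
    [8, 2, 0],
    [1, 6, 3],
    [0, 1, 9],
]

f1_scores = per_class_f1(toy_cm)
mean_f1 = sum(f1_scores) / len(f1_scores)
print([round(s, 3) for s in f1_scores], round(mean_f1, 3))
```

For the first category, for instance, $TP = 8$, $FP = 1$ and $FN = 2$, giving $8 / 9.5 \approx 0.842$. scikit-learn's `f1_score(y_true, y_pred, average="macro")` computes the same unweighted mean directly from the label vectors.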

In [None]:
def mean_f1_score(predictions_table):
    '''
    Function to return the mean F1-score.
    
    Input: 
    - prediction table with columns named 'labels' for true values and 'y_pred' for predicted values

    Output:
    - mean F1-score
    '''
    
    # Rows are true categories, columns are predicted categories
    cm = confusion_matrix(predictions_table.labels, predictions_table.y_pred, labels = ["complex", "healthy", "rust", "scab"], normalize = None)

    tp = np.diag(cm)              # correctly classified pictures per category
    fp = cm.sum(axis = 0) - tp    # predicted as the category, but belonging to another one
    fn = cm.sum(axis = 1) - tp    # belonging to the category, but predicted as another one

    # Per-category F1 = TP / (TP + 0.5 * (FP + FN))
    f1 = tp / (tp + 0.5 * (fp + fn))
    return np.mean(f1)
    
    

In [None]:
mean_f1_score(predictions_resnet_trainable)

The mean F1-score in our run was about 0.63 for our best model, which would place us in the mid-field of the 2021 competition leaderboard. The comparison is only indicative, though, since the leaderboard is scored on the original multi-label task and we solve a simplified four-class version of it. We think this is a good result given the short time period we had for this project, the relatively small number of hidden layers and the low number of epochs. It can certainly be improved, however, and we offer some suggestions in the conclusions. 

## 6.5 Conclusions

We have trained 3 different models: a custom CNN, an unaltered `resnet50` with a custom dense network, and a re-trained `resnet50`. It was not surprising to us that the re-trained `resnet50` model performed best. However, it did not reach the 97% accuracy reported by 
[Thapa et al. (2020)](https://bsapubs.onlinelibrary.wiley.com/doi/10.1002/aps3.11390), and it is unclear whether the performance of our model would improve with further training.

The biggest issue we have encountered is dealing with the IO bottleneck, which made testing new approaches and models very time consuming.
Finally, we present a list of possible things we have discussed implementing that could increase the predictive accuracy of our models, but we have not had the time to check.

- Using a subset of (the subset of) the data set. 
    - This would increase how many iterations we can train our models for, but also reduce the data available at each iteration. There is possibly a sweet spot in this trade-off.
- Trying different classifiers, since we have largely stuck to a basic one throughout the project.
- Doing manual data augmentation to pre-extract features with `resnet50`.
    - This would allow us to train the classifier in a fraction of the current time, since we would no longer re-load all images at every iteration.
- Going back to using the original labels. 
    - Maybe our simplification has increased the noise-to-signal ratio in our data.
  