# Image Preprocessing

Almost every image dataset needs some preprocessing and cleanup before you can use it to train a model, and you will usually have to apply some of the same steps to every input image at inference time as well. In this notebook we look at two approaches to this, a manual one and a Keras-based one, and see how well they work on a custom-collected dataset.

The examples below assume that you are scraping the image data yourself. You are encouraged to scrape real data, e.g. from Wikimedia Commons, rather than use a pre-existing dataset: published datasets have usually been cleaned up already, which would make some of the steps below pointless.

In [None]:
import cv2
from matplotlib import pyplot as plt
import os
import tensorflow as tf
from tensorflow import keras
import numpy as np

# Manual approach
First, let's look at a manual approach to preprocessing our image data using OpenCV and NumPy.

## Load a single image

In [None]:
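# OpenCV loads images in BGR channel order (and returns None if the file cannot be read)
original = cv2.imread("../data/Benthall_Hall_A.jpg", cv2.IMREAD_COLOR)

# Matplotlib expects RGB, so convert before displaying
rgb_img = cv2.cvtColor(original, cv2.COLOR_BGR2RGB)
plt.imshow(rgb_img)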
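The `cv2.cvtColor(..., cv2.COLOR_BGR2RGB)` call is needed because OpenCV stores colour channels in blue, green, red order, while Matplotlib assumes red, green, blue. For a three-channel image the conversion amounts to reversing the last axis. The NumPy sketch below demonstrates this on a tiny made-up array rather than on the dataset.

```python
import numpy as np

# A made-up 1x2 "image" in OpenCV's BGR order: one pure-blue and one pure-red pixel
bgr = np.array([[[255, 0, 0], [0, 0, 255]]], dtype=np.uint8)

# Reversing the channel axis swaps B and R, giving RGB order
rgb = bgr[..., ::-1]

print(rgb[0, 0])  # the blue pixel, now written as R=0, G=0, B=255
print(rgb[0, 1])  # the red pixel, now written as R=255, G=0, B=0
```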

## Resize the image

In [None]:
# cv2.resize takes the target size as (width, height)
rgb_img = cv2.resize(rgb_img, (256, 256))
plt.imshow(rgb_img)
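One detail that is easy to trip over: `cv2.resize` takes the target size as `(width, height)`, while NumPy arrays are indexed `(height, width, channels)`. The square `(256, 256)` above hides the difference. The sketch below is a hypothetical `resize_nearest` helper written in plain NumPy that follows the same argument convention. It uses a simple nearest-neighbour mapping, whereas OpenCV's default interpolation is bilinear (`cv2.INTER_LINEAR`), so pixel values will differ from OpenCV's output.

```python
import numpy as np

def resize_nearest(img, dsize):
    """Resize with a simple nearest-neighbour mapping; dsize is (width, height) like cv2.resize."""
    new_w, new_h = dsize
    h, w = img.shape[:2]
    # For each output row/column, pick the source row/column it falls in
    rows = np.arange(new_h) * h // new_h
    cols = np.arange(new_w) * w // new_w
    return img[rows[:, None], cols[None, :]]

img = np.zeros((480, 640, 3), dtype=np.uint8)  # NumPy shape is (height, width, channels)
small = resize_nearest(img, (320, 100))        # request width 320, height 100
print(small.shape)  # (100, 320, 3)
```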

## Flip the image

In [None]:
# flipCode 1 mirrors the image horizontally (left to right)
flipped_img = cv2.flip(rgb_img, 1)
plt.imshow(flipped_img)
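The second argument to `cv2.flip` selects the axis: `1` (or any positive value) flips horizontally, `0` flips vertically, and a negative value flips both. Because the images are plain NumPy arrays, the same results can be obtained with slicing, as this sketch shows on a small made-up array.

```python
import numpy as np

img = np.arange(6).reshape(2, 3)  # stand-in for a 2x3 single-channel image

horizontal = img[:, ::-1]   # same as cv2.flip(img, 1): mirror left-right
vertical = img[::-1, :]     # same as cv2.flip(img, 0): mirror top-bottom
both = img[::-1, ::-1]      # same as cv2.flip(img, -1): rotate by 180 degrees

print(horizontal)
```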

## Rotate the image

In [None]:
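# https://www.pyimagesearch.com/2017/01/02/rotate-images-correctly-with-opencv-and-python/
h, w = rgb_img.shape[:2]
rotated_img = cv2.warpAffine(rgb_img, cv2.getRotationMatrix2D((w / 2, h / 2), 45, 1.0), (w, h))  # 45 degrees about the centre (example angle)
plt.imshow(rotated_img)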
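Rotating with `cv2.warpAffine` onto a canvas of the original size cuts off the corners of the rotated image. The pyimagesearch article linked in the cell above avoids this by enlarging the output canvas to the rotated image's bounding box and shifting the transform to match. The NumPy sketch below computes that adjusted 2x3 matrix and canvas size without calling OpenCV. It builds the matrix with the formula documented for `cv2.getRotationMatrix2D` (positive angles rotate counter-clockwise). The function name `rotation_bound` is our own; the intent is that its result can be passed on as `cv2.warpAffine(img, matrix, (new_w, new_h))`.

```python
import numpy as np

def rotation_bound(w, h, angle_deg, scale=1.0):
    """Return (matrix, (new_w, new_h)) for rotating a w x h image without cropping.

    The 2x3 matrix follows the formula documented for cv2.getRotationMatrix2D,
    then is shifted so the rotated image is centred on the enlarged canvas.
    """
    cx, cy = w / 2, h / 2
    theta = np.deg2rad(angle_deg)
    alpha, beta = scale * np.cos(theta), scale * np.sin(theta)
    matrix = np.array([
        [alpha, beta, (1 - alpha) * cx - beta * cy],
        [-beta, alpha, beta * cx + (1 - alpha) * cy],
    ])

    # Size of the axis-aligned bounding box around the rotated image
    new_w = int(round(h * abs(beta) + w * abs(alpha)))
    new_h = int(round(h * abs(alpha) + w * abs(beta)))

    # Translate so the old centre lands on the centre of the new canvas
    matrix[0, 2] += new_w / 2 - cx
    matrix[1, 2] += new_h / 2 - cy
    return matrix, (new_w, new_h)

matrix, size = rotation_bound(200, 100, 90)
print(size)  # (100, 200): a 90-degree turn swaps width and height
```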

## Traverse a directory for all images

Now that we know how to process a single image, we can apply the same steps to every image in the dataset.

In [None]:
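processed_images = []

for root, dirs, files in os.walk("../data/wikicommons"):
  for file in files:
    _, extension = os.path.splitext(file)

    if extension.lower() == ".jpg":
      fullpath = os.path.join(root, file)

      current = cv2.imread(fullpath, cv2.IMREAD_COLOR)
      if current is None:  # imread returns None for files it cannot decode
        continue

      converted = cv2.cvtColor(current, cv2.COLOR_BGR2RGB)
      converted = cv2.resize(converted, (256, 256))
      processed_images.append(converted)

# Calling plt.imshow inside the loop would only leave the last image visible,
# so keep the results and display one as a sanity check
plt.imshow(processed_images[-1])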
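Scraped files often arrive with other extensions such as `.jpeg`, `.JPG` or `.png`. As a sketch, the directory walk can be pulled into a small generator that accepts a set of extensions and compares them case-insensitively. The function name `find_images` and its default extension set are our own choices, not part of any library.

```python
import os

def find_images(root, extensions=(".jpg", ".jpeg", ".png")):
    """Yield the full path of every file under root with one of the given extensions."""
    wanted = {ext.lower() for ext in extensions}
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in sorted(filenames):
            # Compare extensions case-insensitively so ".JPG" also matches
            if os.path.splitext(name)[1].lower() in wanted:
                yield os.path.join(dirpath, name)
```

The loop body above could then iterate over `find_images("../data/wikicommons")` instead of handling `os.walk` and the extension check itself.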

# Keras-based image preprocessing

Because the operations above are so common in image-based ML, Keras wraps many of them in simple APIs that we can call directly. We will use those below.

## Filter out invalid images

It is often the case, especially when your dataset is scraped automatically from the internet, that some of the data needs cleaning up. That could mean invalid Unicode characters, malformed data structures or, as here, corrupted images. We need to filter these images out, because they cannot be decoded and therefore cannot be fed to our model.

There are a few ways to handle this. If you process images manually as above, you can skip unreadable files as you load them, because `cv2.imread` returns `None` for a file it cannot decode. If you load a larger dataset through `tf.data.Dataset`, as we do below, a corrupted file will typically raise an error partway through iteration, so we need to either repair those images or remove them from the dataset beforehand.

We will remove these below.

In [None]:
num_skipped = 0
for root, dirs, files in os.walk("../data/wikicommons"):
  for file in files:
    _, extension = os.path.splitext(file)

    if extension == ".jpg":
      filepath = os.path.join(root, file)

      # Files written with a JFIF header contain the bytes "JFIF" right after
      # the start-of-image marker; "with" closes the file even if reading fails
      with open(filepath, "rb") as fobj:
        is_jfif = tf.compat.as_bytes("JFIF") in fobj.peek(10)

      if not is_jfif:
        num_skipped += 1
        # Delete the corrupted image (this is permanent)
        os.remove(filepath)

print("Deleted %d images" % num_skipped)
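Be aware that the JFIF test above can also delete valid JPEGs: many camera photos begin with an Exif (APP1) header instead of a JFIF (APP0) header, and such files need not contain the bytes `JFIF` at the start. A looser alternative, sketched below, is to check for the three bytes `FF D8 FF` that JPEG files start with (the start-of-image marker followed by the first byte of the next segment marker). This only catches files that are not JPEGs at all, for example an HTML error page saved under a `.jpg` name; it does not detect a truncated image, for which actually decoding each file is the more thorough check. The function name `looks_like_jpeg` is hypothetical.

```python
def looks_like_jpeg(filepath):
    """Return True if the file starts with the JPEG signature bytes FF D8 FF."""
    with open(filepath, "rb") as fobj:
        # SOI marker (FF D8) followed by the 0xFF that starts the next segment
        return fobj.read(3) == b"\xff\xd8\xff"
```

In the cleanup loop, `looks_like_jpeg(filepath)` would take the place of the `is_jfif` test.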

Next, we load the images directly with Keras' `image_dataset_from_directory`, holding back 10% of them for validation.

In [None]:
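# Change this to match your data. Both calls below must use the same seed and
# validation_split so that the training and validation subsets do not overlap.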
images_path = "../data/wikicommons"
training_dataset = keras.preprocessing.image_dataset_from_directory(
  images_path,
  labels='inferred',
  label_mode='categorical',
  color_mode="rgb",
  batch_size=32,
  image_size=(150, 150),
  subset="training",
  validation_split=0.1,
  seed=1
)

test_dataset = keras.preprocessing.image_dataset_from_directory(
  images_path,
  labels='inferred',
  label_mode='categorical',
  color_mode="rgb",
  batch_size=32,
  image_size=(150, 150),
  subset="validation",
  validation_split=0.1,
  seed=1
)
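With `labels='inferred'`, `image_dataset_from_directory` expects one subdirectory per class (for example `wikicommons/Forest/...`) and takes the class names from those subdirectories in alphanumeric order; `label_mode='categorical'` then encodes each label as a one-hot vector. The sketch below reproduces that mapping with the standard library and NumPy so you can check which index each class will get. It is an illustration, not Keras's own code, and the helper names `inferred_class_names` and `one_hot` are hypothetical. The order Keras actually used is available as `training_dataset.class_names`.

```python
import os
import numpy as np

def inferred_class_names(root):
    """Subdirectory names of root in sorted order, as used for labels='inferred'."""
    return sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    )

def one_hot(index, num_classes):
    """Categorical (one-hot) encoding of a single integer label."""
    vec = np.zeros(num_classes, dtype=np.float32)
    vec[index] = 1.0
    return vec
```

For example, `one_hot(class_names.index("Forest"), len(class_names))` with `class_names = inferred_class_names("../data/wikicommons")` gives the label vector for images in the `Forest` folder.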

## Defining our model

We can define our own model using Keras' layers API, or we can use one of the prebuilt architectures in `keras.applications`.

In [None]:
model = keras.Sequential([
    keras.Input(shape=(150, 150, 3)),
#     keras.layers.RandomRotation(0.1),
    keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
    keras.layers.MaxPooling2D(pool_size=(4, 4)),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.5),
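    # One output unit per class; softmax gives class probabilities for categorical_crossentropy
    keras.layers.Dense(len(training_dataset.class_names), activation="softmax")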
])

model.summary()

# model = keras.applications.Xception(weights=None, input_shape=(150, 150, 3), classes=len(training_dataset.class_names))
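To see where the numbers in `model.summary()` come from, the shapes and parameter counts of the small model above can be worked out by hand. With the default `padding="valid"` and stride 1, a 3x3 convolution trims one pixel from each border; `MaxPooling2D` uses strides equal to `pool_size` by default, so a 4x4 pool divides each spatial size by 4, rounding down. `Dropout` changes neither shape nor parameter count, so it is left out. The pure-Python sketch below (the function name is our own) follows that arithmetic, with `num_classes` as a parameter because the output layer needs one unit per class.

```python
def conv_pool_dense_summary(size=150, channels=3, filters=32, kernel=3, pool=4, num_classes=2):
    """Output shapes and parameter counts for Conv2D -> MaxPooling2D -> Flatten -> Dense."""
    conv_out = size - kernel + 1                                # "valid" padding, stride 1
    conv_params = (kernel * kernel * channels + 1) * filters   # weights + one bias per filter

    pool_out = (conv_out - pool) // pool + 1                    # stride defaults to pool_size
    flat = pool_out * pool_out * filters

    dense_params = flat * num_classes + num_classes             # weights + one bias per unit
    return {
        "conv": ((conv_out, conv_out, filters), conv_params),
        "pool": ((pool_out, pool_out, filters), 0),
        "flatten": ((flat,), 0),
        "dense": ((num_classes,), dense_params),
    }

for layer, (shape, params) in conv_pool_dense_summary().items():
    print(layer, shape, params)
```

With the defaults this gives a 148x148x32 convolution output (896 parameters), a 37x37x32 pooled output, 43,808 flattened features and 87,618 parameters in a two-unit output layer.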

## Compile the model

In [None]:
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
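`categorical_crossentropy` compares the predicted class probabilities with the one-hot labels produced by `label_mode='categorical'`. It expects the output layer to produce a probability distribution over the classes, which is what a `softmax` activation provides; a `sigmoid` output scores each class independently, so its outputs need not sum to 1. The NumPy sketch below computes both quantities for a single made-up example. It mirrors the maths, not Keras's own implementation.

```python
import numpy as np

def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    shifted = logits - np.max(logits)  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()

def categorical_crossentropy(one_hot_label, probs):
    """Negative log-probability assigned to the true class."""
    return -np.sum(one_hot_label * np.log(probs))

probs = softmax(np.array([2.0, 0.5, -1.0]))
loss = categorical_crossentropy(np.array([1.0, 0.0, 0.0]), probs)
print(probs.sum(), loss)
```

A model that assigns equal probability to two classes therefore has a loss of ln 2 (about 0.693) on each example, which is a useful reference point when reading the training log.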

## Train the model

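Training will probably take a while. Keep an eye on the `accuracy` and `val_accuracy` metrics and see how they evolve from epoch to epoch. What insight does that give you? For example, if `accuracy` keeps rising while `val_accuracy` levels off or falls, the model is starting to overfit the training data.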

In [None]:
model.fit(
  training_dataset,
  validation_data=test_dataset,
  epochs=20
)
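`model.fit` returns a `History` object whose `history` attribute is a dict mapping each metric name, such as `accuracy` and `val_accuracy`, to a list with one value per epoch. A small helper like the hypothetical `summarize_history` below can report the best validation epoch and the final gap between training and validation accuracy. The numbers in the example are made up.

```python
def summarize_history(history_dict):
    """Report the best val_accuracy epoch (1-based) and the final train/val accuracy gap."""
    val_acc = history_dict["val_accuracy"]
    best_epoch = max(range(len(val_acc)), key=lambda i: val_acc[i]) + 1
    final_gap = history_dict["accuracy"][-1] - val_acc[-1]
    return best_epoch, final_gap

# Made-up numbers; in the notebook you would pass history.history,
# where history = model.fit(...)
example = {
    "accuracy":     [0.55, 0.70, 0.82, 0.91],
    "val_accuracy": [0.52, 0.64, 0.66, 0.63],
}
best_epoch, gap = summarize_history(example)
print(best_epoch, round(gap, 2))
```

Here validation accuracy peaks at epoch 3 and then drops while training accuracy keeps climbing, which is the overfitting pattern described above.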

In [None]:
model.evaluate(test_dataset)

## Saving the model

After training, we can save the model to disk. We can then load it later to continue training on new data, or load it into a different program and use it to classify images.

In [None]:
model.save("./model/")

We can load a saved model with the `load_model` function. It restores the whole model, both its structure and its trained weights, as it was when we last ran `model.save()`. (On Keras 3, the default from TensorFlow 2.16 onwards, `model.save` expects a path ending in `.keras` rather than a directory.)

In [None]:
model = keras.models.load_model("./model/")

model.summary()

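We can also save checkpoints while training, so that we can stop and later resume from the saved weights. With `save_best_only=True`, the checkpoint is only overwritten when the monitored `val_accuracy` improves, so it holds the best weights seen so far rather than those from the most recent epoch.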

In [None]:
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath="./checkpoint/",
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True
)

# Uncomment to resume training from the weights saved by an earlier run
# model.load_weights("./checkpoint/")

model.fit(
    training_dataset,
    validation_data=test_dataset,
    epochs=20,
    callbacks=[model_checkpoint_callback]
)
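With `monitor='val_accuracy'`, `mode='max'` and `save_best_only=True`, the callback writes weights only at the end of an epoch whose `val_accuracy` is strictly better than the best value seen so far. This pure-Python sketch is a simplification, not the callback's actual code; it lists the epochs at which a checkpoint would be written for a made-up sequence of validation accuracies.

```python
def checkpoint_epochs(val_accuracies):
    """1-based epochs at which a save_best_only, mode='max' checkpoint would be written."""
    best = float("-inf")
    saved = []
    for epoch, value in enumerate(val_accuracies, start=1):
        if value > best:  # only a strict improvement triggers a save
            best = value
            saved.append(epoch)
    return saved

print(checkpoint_epochs([0.50, 0.61, 0.61, 0.70, 0.65]))  # [1, 2, 4]
```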

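## Classify a single image

Finally, we can use the trained model to classify an individual image. The image has to be resized to the model's input size, and `model.predict` expects a batch, so we add a batch dimension before predicting.

In [None]: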
image = keras.preprocessing.image.load_img("../data/wikicommons/Forest/forest.16.jpg", target_size=(150, 150))
input_arr = keras.preprocessing.image.img_to_array(image)
input_arr = np.array([input_arr])  # Convert single image to a batch.
predictions = model.predict(input_arr)
predictions
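`model.predict` returns one row of class scores per input image, so for the single-image batch above `predictions` has shape `(1, number of classes)`. To turn that into a label, take the index of the largest score and look it up in the class names in the order Keras inferred them, which are available as `training_dataset.class_names`. The sketch below uses a hypothetical `top_class` helper with made-up scores and class names to show the lookup.

```python
import numpy as np

def top_class(predictions, class_names):
    """Name and score of the highest-scoring class for the first image in a batch."""
    scores = predictions[0]
    index = int(np.argmax(scores))
    return class_names[index], float(scores[index])

# Made-up values standing in for model.predict(...) and training_dataset.class_names
fake_predictions = np.array([[0.1, 0.7, 0.2]])
fake_class_names = ["Beach", "Forest", "Mountain"]
print(top_class(fake_predictions, fake_class_names))
```

In the notebook itself this would be `top_class(predictions, training_dataset.class_names)`.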