In [None]:
import pathlib
import pickle

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn.utils import resample
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras.applications import EfficientNetB0
from keras_preprocessing.image import ImageDataGenerator
from keras.callbacks import LearningRateScheduler, ModelCheckpoint, TensorBoard
from keras.engine import base_layer

from augmentation import RandomColorDistortion
from callbacks import scheduler, TimeStopping

In [None]:
# Raise the default figure resolution to 200 DPI for every plot drawn in this notebook.
plt.rcParams["figure.dpi"] = 200

In [None]:
# Root folder of the image dataset, relative to the notebook's working directory.
# The loading cell globs "*/*" below this folder and uses each entry's parent folder name as its label.
DATASET_PATH = "../../dataset/"

In [None]:
# Target size the generators resize every image to; build_model uses the same two values for its input shape.
IMAGE_SIZE = (128, 128)
# NOTE(review): COLOR_MODE is not referenced by any visible cell (the generators are created
# without a color_mode argument) — confirm whether it should be passed to them.
COLOR_MODE = "rgb"
# Number of images per batch yielded by the training and validation generators.
BATCH_SIZE = 32

In [None]:
# Index every file stored as <DATASET_PATH>/<label folder>/<file>; the parent folder name is the label.
# The paths are sorted so that the row order (and therefore every seeded sample/split further
# down) does not depend on the order in which the filesystem lists directory entries.
data = []
for file in sorted(pathlib.Path(DATASET_PATH).glob("*/*")):
    # Skip nested directories: only regular files can be images.
    if not file.is_file():
        continue
    # Resolve once and reuse for both the absolute POSIX path and the label.
    resolved = file.resolve()
    data.append({"filename": resolved.as_posix(), "label": resolved.parent.name})

In [None]:
# One row per indexed entry, built from the "filename"/"label" records collected above.
df = pd.DataFrame(data)
df.head()

In [None]:
# Take the last "/"-separated component of each path (the file's basename) ...
basenames = df["filename"].str.split("/").str[-1]
# ... and capture its first and third underscore-delimited tokens as "dataset" and "time".
# The second token is matched but not kept; the third token ends at the next underscore or
# at the end of the name, whichever comes first.
df[["dataset", "time"]] = basenames.str.extract(r"^([^_]+)_[^_]+_([^_]+)")
df["time"].unique()

In [None]:
# Preview the frame again, now including the extracted "dataset" and "time" columns.
df.head()

In [None]:
# Number of distinct filenames per label in the raw index.
df.groupby("label")["filename"].nunique()

In [None]:
# Number of files for every (label, time) combination, reported in a "count" column.
df.groupby(["label", "time"]).agg(count = ("filename", "count"))

In [None]:
#df.query("label != 'horse'", inplace = True)
# Fold the "chicken" and "horse" labels into a single catch-all class (modifies df in place).
merged_into_other = df["label"].isin(["chicken", "horse"])
df.loc[merged_into_other, "label"] = "None_of_the_above"

In [None]:
# Distinct filenames per label after "chicken" and "horse" were merged into "None_of_the_above".
df.groupby("label")["filename"].nunique()

In [None]:
# Keep only the named classes; the catch-all rows are left out of balancing and are
# concatenated back onto the balanced frame in a later cell.
df_classes = df.query("label != 'None_of_the_above'")

In [None]:
# Size (distinct filenames) of the smallest named class; balance_dataset caps every class at this many samples.
MIN_COUNT = df_classes.groupby("label")["filename"].nunique().min()  # 3000 — presumably the value observed on the author's dataset; TODO confirm
MIN_COUNT

In [None]:
def balance_dataset(x):
    """Down-sample one label's rows to at most MIN_COUNT rows.

    Sampling is without replacement, seeded with 42, and stratified on the
    "dataset" and "time" columns of ``x``.

    Args:
        x: DataFrame holding the rows of a single label group.

    Returns:
        The resampled subset of ``x``.
    """
    target_size = min(MIN_COUNT, len(x))
    strata = x[["dataset", "time"]]
    return resample(x, replace = False, n_samples = target_size, random_state = 42, stratify = strata)

# Balance every label independently, then shuffle the rows with a fixed seed and renumber them.
df_balanced = (
    df_classes.groupby("label", group_keys = False)
    .apply(balance_dataset)
    .sample(frac = 1, random_state = 42)
    .reset_index(drop = True)
)
df_balanced.head()

In [None]:
# Append the catch-all class (not down-sampled) to the balanced named classes.
# Any "None_of_the_above" rows already present in df_balanced are dropped first, so
# re-running this cell does not append the same rows a second time.
df_balanced = pd.concat([df_balanced.query("label != 'None_of_the_above'"),
                         df.query("label == 'None_of_the_above'")],
                        axis = 0, ignore_index = True)

In [None]:
# Distinct filenames per label in the balanced frame (catch-all class included).
df_balanced.groupby("label")["filename"].nunique()

In [None]:
# Non-null count of the remaining columns for every (label, time, dataset) group —
# these are the strata used by the train/validation split below.
df_balanced.groupby(["label", "time", "dataset"]).count()

In [None]:
# 80/20 train/validation split with a fixed seed, stratified jointly on label, source dataset and time.
split_strata = df_balanced[["label", "dataset", "time"]]
df_train, df_valid = train_test_split(
    df_balanced,
    test_size = 0.2,
    shuffle = True,
    random_state = 42,
    stratify = split_strata,
)

In [None]:
# Non-null counts per (label, time, dataset) group in the validation split.
df_valid.groupby(["label", "time", "dataset"]).count()

In [None]:
# Distinct filenames per label in the training split.
df_train.groupby("label")["filename"].nunique()

In [None]:
# Distinct filenames per label in the validation split.
df_valid.groupby("label")["filename"].nunique()

In [None]:
# Settings shared by both generators: absolute paths come from the "filename" column
# (hence directory = None), labels from the "label" column are one-hot encoded
# ("categorical"), and images are resized to IMAGE_SIZE.
flow_options = dict(directory = None,
                    x_col = "filename",
                    y_col = "label",
                    batch_size = BATCH_SIZE,
                    seed = 42,
                    shuffle = True,
                    class_mode = "categorical",
                    target_size = IMAGE_SIZE)

# Neither generator applies any transformation of its own.
datagen = ImageDataGenerator()
valid_datagen = ImageDataGenerator()

train_generator = datagen.flow_from_dataframe(dataframe = df_train, **flow_options)
valid_generator = valid_datagen.flow_from_dataframe(dataframe = df_valid, **flow_options)

In [None]:
# Class names in the iteration order of the training generator's class_indices mapping.
LABELS = [class_name for class_name in train_generator.class_indices]
NUM_CLASSES = len(LABELS)

In [None]:
# Preview: the first image of six consecutive training batches, each titled with its class name.
plt.figure(figsize = (10, 10))
label_names = np.array(LABELS)
for cell in range(1, 7):
    batch_images, batch_labels = next(train_generator)
    first_image = batch_images[0]
    # Boolean mask over the one-hot vector selects the matching class name.
    first_label_mask = batch_labels[0] == 1
    plt.subplot(3, 3, cell)
    plt.imshow(first_image.astype("uint8"))
    plt.title(label_names[first_label_mask][0])
    plt.tight_layout()
    plt.axis("off")

In [None]:
# Data augmentation layers.
data_augmentation = keras.Sequential([
                                      #layers.RandomRotation(factor = (-0.1, 0.1), fill_mode = "wrap"),
                                      #layers.RandomTranslation(height_factor = (-0.1, 0.1), width_factor = (-0.1, 0.1), fill_mode = "wrap"),
                                      layers.RandomFlip(),
                                      #RandomColorDistortion(brightness_max_delta = 0.2, 
                                      #                      saturation_delta = (0.5, 0.9),
                                      #                      hue_max_delta = 0.2, 
                                      #                      contrast_delta = (0.5, 0.9)),
                                     ], name = "data_augmentation")

In [None]:
# Preview: the first image of six consecutive training batches after passing the batch
# through data_augmentation, each titled with its class name.
plt.figure(figsize = (10, 10))
label_names = np.array(LABELS)
for cell in range(1, 7):
    batch_images, batch_labels = next(train_generator)
    augmented_batch = data_augmentation(batch_images)
    # Boolean mask over the one-hot vector selects the matching class name.
    first_label_mask = batch_labels[0] == 1
    plt.subplot(3, 3, cell)
    plt.imshow(augmented_batch[0].numpy().astype("uint8"))
    plt.title(label_names[first_label_mask][0])
    plt.tight_layout()
    plt.axis("off")

In [None]:
def scheduler(epoch, lr):
    """Piecewise-constant learning-rate schedule.

    NOTE: this definition replaces the ``scheduler`` imported from ``callbacks``
    at the top of the notebook.

    Args:
        epoch: Epoch index supplied by the LearningRateScheduler callback.
        lr: Current learning rate; accepted for the callback signature but ignored.

    Returns:
        0.01 while ``epoch <= 200``, 0.001 while ``epoch <= 300``, 0.0001 afterwards.
    """
    # (inclusive upper epoch bound, learning rate) pairs, checked in order.
    for last_epoch, rate in ((200, 0.01), (300, 0.001)):
        if epoch <= last_epoch:
            return rate
    return 0.0001

In [None]:
def build_model(num_classes):
    """Build and compile an EfficientNetB0-based image classifier.

    The input (IMAGE_SIZE with 3 channels) is passed through ``data_augmentation``
    and then through an ImageNet-initialised EfficientNetB0 without its top. The
    backbone is frozen except for the non-BatchNorm layers among its last three.
    A new head (global average pooling, batch normalisation, dropout, softmax
    dense layer) is added on top.

    Args:
        num_classes: Number of output units of the final softmax layer.

    Returns:
        A compiled ``tf.keras.Model`` named "EfficientNet" (Adam optimizer,
        categorical cross-entropy loss, accuracy metric).
    """
    inputs = layers.Input(shape = (IMAGE_SIZE[0], IMAGE_SIZE[1], 3))

    inputs_augmented = data_augmentation(inputs)

    base_model = EfficientNetB0(include_top = False, input_tensor = inputs_augmented, weights = "imagenet")

    # Freeze the pretrained weights.
    base_model.trainable = False

    # Unfreeze the top layers while leaving BatchNorm layers frozen.
    for layer in base_model.layers[-3:]:
        if not isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.trainable = True

    # Rebuild top.
    x = layers.GlobalAveragePooling2D(name = "avg_pool")(base_model.output)
    x = layers.BatchNormalization()(x)

    top_dropout_rate = 0.5
    x = layers.Dropout(top_dropout_rate, name = "top_dropout")(x)
    # Size the head from the argument rather than the NUM_CLASSES global, so the
    # function honours whatever class count the caller asks for.
    outputs = layers.Dense(num_classes, activation = "softmax", name = "prediction")(x)

    # Compile. The initial learning rate here is overridden each epoch when a
    # LearningRateScheduler callback is passed to fit().
    model = tf.keras.Model(inputs, outputs, name = "EfficientNet")
    optimizer = tf.keras.optimizers.Adam(learning_rate = 1e-2)
    model.compile(optimizer = optimizer, loss = "categorical_crossentropy", metrics = ["accuracy"])

    return model

In [None]:
# Build the classifier for the classes discovered by the training generator and print its layer summary.
model = build_model(NUM_CLASSES)
model.summary()

In [None]:
epochs = 300

# The learning rate for each epoch is taken from `scheduler`.
lr_scheduler = LearningRateScheduler(scheduler)
# Keep only the best weights seen so far in ./checkpoints/weights.h5.
# NOTE(review): with save_best_only the callback compares its monitored metric (the default
# monitor is the validation loss per the Keras docs), which is only present on epochs where
# validation runs — every 5th epoch given validation_freq below; other epochs are skipped.
cp_callback = ModelCheckpoint(filepath = "./checkpoints/weights.h5",
                              save_weights_only = True,
                              save_best_only = True,
                              verbose = 1)
# Write TensorBoard logs to ./logs.
tb_callback = TensorBoard("./logs")
# cp_callback is registered here; it used to be created but never passed to fit(),
# so no checkpoint was ever written during training.
history = model.fit(train_generator,
                    validation_data = valid_generator,
                    epochs = epochs,
                    validation_freq = 5,
                    callbacks = [lr_scheduler, cp_callback, tb_callback])

In [None]:
# Validation only ran every VALIDATION_FREQ epochs, so "val_accuracy" holds fewer points than
# "accuracy". Plot each validation point at the epoch it was measured at instead of at
# positions 0, 1, 2, ..., which would compress the validation curve to the left of the plot.
VALIDATION_FREQ = 5  # keep in sync with the validation_freq passed to model.fit

train_accuracy = history.history["accuracy"]
val_accuracy = history.history["val_accuracy"]

train_epochs = range(len(train_accuracy))
# 0-based indices of the epochs after which validation ran (epochs 5, 10, ... when counted from 1).
val_epochs = range(VALIDATION_FREQ - 1, len(train_accuracy), VALIDATION_FREQ)
if len(val_epochs) != len(val_accuracy):
    raise ValueError("VALIDATION_FREQ does not match the number of recorded val_accuracy values")

plt.plot(train_epochs, train_accuracy, label = "accuracy")
plt.plot(val_epochs, val_accuracy, label = "val_accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.ylim([0, 1])
plt.legend(loc = "best");

In [None]:
# Write the model's current weights (its state after fit() has finished) to weights.h5 in the working directory.
model.save_weights("weights.h5")

In [None]:
# Pickle the class-name list into a file called "labels" in the working directory.
with pathlib.Path("labels").open("wb") as labels_file:
    pickle.dump(LABELS, labels_file)

In [None]:
# Pickle the IMAGE_SIZE tuple into a file called "shape" in the working directory.
with pathlib.Path("shape").open("wb") as shape_file:
    pickle.dump(IMAGE_SIZE, shape_file)