In [None]:
from tensorflow import keras
from tensorflow.keras.layers import *
from prepare_data import *
from datetime import datetime

In [None]:
# Dataset variant this run trains on.
dataset = Dataset.eighty
# Bare member name (e.g. "eighty"); used below to locate the extractor checkpoints.
# Enum.name avoids parsing str(dataset), which breaks when the enum type
# customizes __str__ (IntEnum / StrEnum do so on Python 3.11+).
# NOTE(review): assumes Dataset is an enum.Enum — TODO confirm in prepare_data.
dataset_name = dataset.name
# Run timestamp (ddmmYYYY-HHMMSS) that keeps each run's log/checkpoint dirs distinct.
current_time = datetime.now().strftime("%d%m%Y-%H%M%S")

In [None]:
# Integer argument shared by both prepare_data helpers; presumably the batch
# size — TODO confirm against prepare_data. Kept in one place so the two calls
# cannot drift apart.
BATCH_SIZE = 8

# Raw train/validation splits first, then the generators that are fed to the
# three-input ensemble model (built from those raw splits).
raw_train, raw_val = prep_dataset(dataset, BATCH_SIZE)
train_gen, val_gen = prep_ensemble_aug_dataset(dataset, raw_train, raw_val, BATCH_SIZE)

In [None]:
def extractor_checkpoint_path(name: str, variant: str = "", root: str = "../model-saves/extractors") -> str:
    """Build the checkpoint path of a pretrained extractor.

    Args:
        name: dataset member name (e.g. "eighty"); used as a sub-directory and,
            upper-cased, as the folder prefix.
        variant: extractor flavour tag such as "GRAY" or "BLUR"; empty for the
            standard extractor.
        root: directory holding all extractor checkpoints.

    Returns:
        "<root>/<name>/<NAME>[-<variant>]-EXTRACTOR/extractor/savefile.hdf5"
    """
    prefix = f"{name.upper()}-{variant}" if variant else name.upper()
    return f"{root}/{name}/{prefix}-EXTRACTOR/extractor/savefile.hdf5"


def load_frozen_extractor(model_path: str, layer_suffix: str):
    """Load a saved extractor checkpoint and return its frozen inner sub-model.

    The sub-model is ``load_model(model_path).layers[0].layers[-1]``: the last
    layer of the loaded model's first layer. That layer must itself be a model,
    since its ``.layers`` are iterated. Which part of the training setup this is
    is not visible here — TODO confirm against the extractor training notebook.

    Args:
        model_path: path of the saved .hdf5 checkpoint.
        layer_suffix: string appended to every layer name of the sub-model.
            Keras functional models require unique layer names, and the three
            extractors are combined into a single model later, so each one gets
            a distinct suffix.

    Returns:
        The sub-model with ``trainable = False`` and renamed layers.
    """
    extractor = keras.models.load_model(model_path).layers[0].layers[-1]
    extractor.trainable = False
    for layer in extractor.layers:
        # Private attribute: layer.name is read-only in tf.keras 2.x.
        layer._name += layer_suffix
    return extractor


standard_model_path = extractor_checkpoint_path(dataset_name)
standard_extractor = load_frozen_extractor(standard_model_path, "_1")

gray_model_path = extractor_checkpoint_path(dataset_name, "GRAY")
gray_extractor = load_frozen_extractor(gray_model_path, "_2")

blur_model_path = extractor_checkpoint_path(dataset_name, "BLUR")
blur_extractor = load_frozen_extractor(blur_model_path, "_3")

In [None]:
# Classifier head over the three frozen extractors. Their outputs are
# concatenated, regrouped into a length-3 sequence (one step per extractor),
# and read by a bidirectional GRU.
# NOTE(review): the (3, 1000) layout assumes each extractor emits a 1000-dim
# vector (3000 features after concatenation) — TODO confirm. Reshape raises at
# graph-build time if the total size doesn't match.
N_BRANCHES = 3
BRANCH_FEATURES = 1000

concat_layer = concatenate([standard_extractor.output, gray_extractor.output, blur_extractor.output])
# No `input_shape` here: it only applies to a model's first layer and is ignored
# (newer Keras warns about it) when the layer is called on an existing tensor.
reshape_layer = Reshape((N_BRANCHES, BRANCH_FEATURES))(concat_layer)
rnn_layer = Bidirectional(GRU(1000))(reshape_layer)
dropout_layer = Dropout(0.5)(rnn_layer)
dense_layer = Dense(1000, activation='relu')(dropout_layer)
softmax = Dense(train_gen.num_classes(), activation='softmax')(dense_layer)

In [None]:
# End-to-end model: one input per extractor branch, one softmax output.
# NOTE(review): presumably the generators yield inputs in this same order — TODO confirm.
classifier_model = keras.Model(
    inputs=[
        standard_extractor.input,
        gray_extractor.input,
        blur_extractor.input,
    ],
    outputs=softmax,
)

# sparse_categorical_crossentropy expects integer class labels (not one-hot).
classifier_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [None]:
# Run-specific tag shared by the TensorBoard log dir and the checkpoint path,
# so the logs and weights of one run sit under matching directory names.
# NOTE(review): uses str(dataset) (e.g. "Dataset.eighty") rather than
# dataset_name; kept unchanged so output paths stay the same.
run_tag = f"ensemble_aug_{str(dataset)}_{current_time}"

logdir = f"../logs/unfiltered/{run_tag}/classifier"
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)

# Overwrite the checkpoint only when validation accuracy reaches a new maximum.
model_path = f"../model-saves/unfiltered/{run_tag}/classifier/savefile.hdf5"
model_save_callback = keras.callbacks.ModelCheckpoint(
    filepath=model_path,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    verbose=1,
)

In [None]:
class ShuffleCallback(keras.callbacks.Callback):
    """Keras callback that calls ``generator.shuffle()`` at the end of every epoch.

    Args:
        generator: any object with a no-argument ``shuffle()`` method
            (in this notebook, the train/validation generators).
    """

    def __init__(self, generator):
        # Initialise the base Callback so the attributes Keras' callback
        # machinery relies on (e.g. ``model``) exist before fit() uses them.
        super().__init__()
        self._generator = generator

    def on_epoch_end(self, epoch, logs=None):
        """Reshuffle the wrapped generator; ``epoch`` and ``logs`` are unused."""
        self._generator.shuffle()
    
# One callback per generator, so each is reshuffled after every epoch
# (train first, then validation).
train_shuffle_callback, val_shuffle_callback = (
    ShuffleCallback(generator) for generator in (train_gen, val_gen)
)

In [None]:
# TensorBoard logging, best-val-accuracy checkpointing, and per-epoch
# reshuffling of both generators.
fit_callbacks = [
    tensorboard_callback,
    model_save_callback,
    train_shuffle_callback,
    val_shuffle_callback,
]
classifier_model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=10,
    callbacks=fit_callbacks,
)