In [1]:
from tensorflow import keras
from tensorflow.keras.layers import *
from prepare_data import *
from datetime import datetime
import albumentations as A

In [2]:
# Dataset variant used for this run, plus identifiers derived from it.
dataset = Dataset.eighty
# Use the enum member's name (e.g. "eighty") instead of parsing str(dataset):
# str() of an Enum member is not guaranteed to be "Class.member" (IntEnum /
# StrEnum render as their value on Python 3.11+).
# NOTE(review): assumes Dataset is an enum.Enum — str(dataset) renders as
# "Dataset.eighty" in the run paths of this notebook.
dataset_name = dataset.name
# Timestamp (day-month-year-hourminutesecond) used to make run directories unique.
current_time = datetime.now().strftime("%d%m%Y-%H%M%S")

In [3]:
# Split the raw data into train/validation sets, then wrap both splits in the
# multi-image generators consumed by the recurrent classifier below.
raw_train, raw_val = prep_dataset(dataset, 8)
train_gen, val_gen = prep_multi_img_rnn_dataset(
    dataset,
    raw_train,
    raw_val,
    8,
    3,
)

Found 27831 files belonging to 80 classes.
Using 23657 files for training.
Found 27831 files belonging to 80 classes.
Using 4174 files for validation.


In [4]:
# Reload the previously trained extractor and pull out its Inception backbone,
# frozen so only the recurrent head below is trained.
extractor_model_path = (
    f"../model-saves/extractors/{dataset_name}/"
    f"{dataset_name.upper()}-EXTRACTOR/extractor/savefile.hdf5"
)
feature_extractor = keras.models.load_model(extractor_model_path)
# NOTE(review): assumes layers[0] is a nested model whose last layer is the
# Inception backbone — confirm against the extractor notebook.
inception_model = feature_extractor.layers[0].layers[-1]
inception_model.trainable = False

# Classifier: a variable-length sequence of 299x299 RGB frames is passed frame by
# frame through the frozen backbone, aggregated by a bidirectional LSTM, then
# mapped to one softmax probability per class.
classifier_model = keras.Sequential()
classifier_model.add(InputLayer(input_shape=(None, 299, 299, 3)))
classifier_model.add(TimeDistributed(inception_model))
classifier_model.add(Bidirectional(LSTM(1000)))
classifier_model.add(Dropout(0.5))
classifier_model.add(Dense(1000, activation='relu'))
classifier_model.add(Dense(train_gen.num_classes(), activation='softmax'))

classifier_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.00001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [5]:
# Shared identifier for this run so the TensorBoard logs and the checkpoint
# always land in matching directories (built once instead of twice by hand).
run_id = f"multi_img_rnn_{dataset!s}_{current_time}"

# log callback: writes TensorBoard summaries during training
logdir = f"../logs/unfiltered/{run_id}/classifier"
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)

# model callback: saves the model only when val_accuracy reaches a new maximum
model_path = f"../model-saves/unfiltered/{run_id}/classifier/savefile.hdf5"
model_save_callback = keras.callbacks.ModelCheckpoint(
    filepath=model_path,
    save_best_only=True,
    monitor='val_accuracy',
    mode='max',
    verbose=1,
)

In [6]:
# callback to shuffle the dataset after each epoch
class ShuffleCallback(keras.callbacks.Callback):
    """Keras callback that reshuffles a data generator at the end of every epoch.

    Args:
        generator: object exposing a no-argument ``shuffle()`` method
            (here the generators returned by ``prep_multi_img_rnn_dataset``).
    """

    def __init__(self, generator):
        # Initialise the base Callback first so the state Keras relies on when
        # registering and driving callbacks is set up.
        super().__init__()
        self._generator = generator

    def on_epoch_end(self, epoch, logs=None):
        """Reshuffle the wrapped generator once the epoch has finished."""
        self._generator.shuffle()


train_shuffle_callback = ShuffleCallback(train_gen)
# NOTE(review): reshuffling the validation generator presumably has no effect on
# validation metrics (the whole set is evaluated each epoch); kept as in the original run.
val_shuffle_callback = ShuffleCallback(val_gen)

In [7]:
# Train the classifier; keep the returned History so per-epoch metrics remain
# accessible (e.g. history.history['val_accuracy']) without relying on Out[...].
history = classifier_model.fit(
    train_gen,
    validation_data=val_gen,
    callbacks=[
        tensorboard_callback,
        model_save_callback,
        train_shuffle_callback,
        val_shuffle_callback,
    ],
    epochs=20,
)

Epoch 1/20
Epoch 1: val_accuracy improved from -inf to 0.88742, saving model to ../model-saves/unfiltered/multi_img_rnn_Dataset.eighty_17052022-180758/classifier\savefile.hdf5
Epoch 2/20
Epoch 2: val_accuracy improved from 0.88742 to 0.93966, saving model to ../model-saves/unfiltered/multi_img_rnn_Dataset.eighty_17052022-180758/classifier\savefile.hdf5
Epoch 3/20
Epoch 3: val_accuracy improved from 0.93966 to 0.94996, saving model to ../model-saves/unfiltered/multi_img_rnn_Dataset.eighty_17052022-180758/classifier\savefile.hdf5
Epoch 4/20
Epoch 4: val_accuracy improved from 0.94996 to 0.95732, saving model to ../model-saves/unfiltered/multi_img_rnn_Dataset.eighty_17052022-180758/classifier\savefile.hdf5
Epoch 5/20
Epoch 5: val_accuracy improved from 0.95732 to 0.96026, saving model to ../model-saves/unfiltered/multi_img_rnn_Dataset.eighty_17052022-180758/classifier\savefile.hdf5
Epoch 6/20
Epoch 6: val_accuracy did not improve from 0.96026
Epoch 7/20
Epoch 7: val_accuracy did not impro

<keras.callbacks.History at 0x2d5b048d420>