In [None]:
import random

import numpy as np
import tensorflow as tf

# Only surface TensorFlow's Python-side log messages at ERROR level or above,
# so training cells are not flooded with INFO/WARNING output.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Seed every RNG this notebook may draw from so Restart & Run All is repeatable:
# TensorFlow (weight initialisation, the unseeded random flips in
# augmentation_layer), plus Python's `random` and NumPy in case ktl shuffles
# or splits with them (presumably — confirm in keras_transfer_learning).
# NOTE: some GPU kernels can still be nondeterministic even with seeds set.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

In [None]:
from keras import layers
from keras.optimizer_v2.adam import Adam

import keras_transfer_learning as ktl

In [None]:
# Load the 'fmnist' dataset (presumably Fashion-MNIST) through the project
# helper. `key` is used later as the source of class labels for the confusion
# matrix.
train_split, test_split, key = ktl.load_dataset('fmnist')
x_train, y_train = train_split
x_test, y_test = test_split

In [None]:
# Classification head stacked on the ResNet50 backbone returned by
# ktl.build_model: average-pool the final feature map, then a 10-way softmax.
top_layers = [layers.GlobalAveragePooling2D(), layers.Dense(10, activation='softmax')]

model, base_model = ktl.build_model(model_name='resnet50', input_shape=(32, 32), top_layers=top_layers)

# First-stage optimizer; fine-tuning later recompiles with a smaller rate.
initial_optimizer = Adam(learning_rate=1e-3)
model.compile(optimizer=initial_optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

In [None]:
# Epoch budget for the two training stages: the first model.fit (before
# ktl.unfreeze_model) and the fine-tuning model.fit (after it).
initial_epochs, fine_tune_epochs = 10, 5

In [None]:
def augmentation_layer(x: tf.Tensor) -> tf.Tensor:
    """Cast ``x`` to float32 and randomly mirror it left-to-right.

    Passed to ``ktl.get_data_generators`` as the augmentation callable.

    Args:
        x: Image tensor. NOTE(review): ``tf.image.random_flip_left_right``
            requires a rank-3 (H, W, C) or rank-4 (N, H, W, C) tensor —
            confirm that ktl feeds images that already carry a channel axis.

    Returns:
        The float32 tensor, flipped horizontally with probability 1/2
        (per image).
    """
    return tf.image.random_flip_left_right(tf.cast(x, dtype=tf.float32))


# Build the training and evaluation inputs for model.fit from the loaded splits.
# NOTE(review): presumably ktl applies augmentation_layer to the training split
# only — confirm in ktl.get_data_generators.
train_dataset, test_dataset = ktl.get_data_generators(
    x_train, y_train, x_test, y_test, augmentation_layer
)

In [None]:
# Stage 1: train the model as configured by ktl.build_model.
# NOTE(review): the backbone is presumably frozen at this point (it is partially
# unfrozen later with ktl.unfreeze_model) — confirm in ktl.build_model.
# NOTE(review): the test split doubles as validation data; if these val metrics
# drive any tuning decisions, the test split is no longer a clean hold-out.
initial_history = model.fit(
    train_dataset,
    validation_data=test_dataset,
    epochs=initial_epochs,
)

In [None]:
# Plot the History returned by the first-stage model.fit
# (presumably loss/accuracy curves per epoch — see ktl.plot_history).
ktl.plot_history(initial_history)


In [None]:
# Confusion matrix after the first training stage, computed on the held-out
# test split so it reflects generalization rather than fit to training data.
# x_test/y_test come from the same ktl.load_dataset call as x_train/y_train,
# so they share the format ktl.plot_confusion_matrix expects.
labels = list(key.values())  # presumably class names, used as axis labels
ktl.plot_confusion_matrix(model, x_test, y_test, labels)

In [None]:
# Stage 2 setup: make part of the backbone trainable, then recompile.
# Keras only picks up changes to `layer.trainable` on compile, and the much
# smaller learning rate (1e-5 vs 1e-3) keeps updates to the newly trainable
# backbone weights small.
# NOTE(review): 20 is presumably the number of backbone layers to unfreeze —
# confirm against ktl.unfreeze_model.
ktl.unfreeze_model(model, base_model, 20)

fine_tune_optimizer = Adam(learning_rate=1e-5)
model.compile(optimizer=fine_tune_optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

In [None]:
# Stage 2: fine-tune for `fine_tune_epochs` more epochs, continuing the epoch
# counter where stage 1 stopped so the two histories line up.
total_epochs = initial_epochs + fine_tune_epochs

# `History.epoch` lists the epoch indices that actually ran; it is empty when
# stage 1 was configured with 0 epochs, and `[-1]` would then raise IndexError.
resume_epoch = initial_history.epoch[-1] + 1 if initial_history.epoch else 0

ft_history = model.fit(
    train_dataset,
    epochs=total_epochs,
    initial_epoch=resume_epoch,
    validation_data=test_dataset,
)

In [None]:
# Plot the History from the fine-tuning model.fit. Because that fit was started
# with `initial_epoch` past stage 1, this covers only the fine-tuning epochs;
# the stage-1 curves live in `initial_history` and are not merged here.
ktl.plot_history(ft_history)

In [None]:
# Confusion matrix after fine-tuning, computed on the held-out test split so it
# reflects generalization rather than fit to training data. x_test/y_test come
# from the same ktl.load_dataset call as x_train/y_train, so they share the
# format ktl.plot_confusion_matrix expects.
labels = list(key.values())  # presumably class names, used as axis labels
ktl.plot_confusion_matrix(model, x_test, y_test, labels)

In [None]:
from pathlib import Path

# Persist the fine-tuned model under ./models (resolved against the current
# working directory), named after the total number of epochs trained.
# `model_dir` avoids rebinding the built-in `dir()` in the notebook namespace,
# and mkdir(exist_ok=True) replaces the separate exists()/mkdir() check.
model_dir = Path('./models/').resolve()
model_dir.mkdir(parents=True, exist_ok=True)
model.save(model_dir / f'model_{total_epochs}epochs')