In [None]:
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

# Seed TensorFlow's global RNG so TF-side randomness (e.g. the random flips in
# augmentation_layer and unseeded Keras initializers) repeats across
# Restart & Run All. NOTE(review): any NumPy/Python-based shuffling inside ktl
# is not covered by this seed, and some GPU kernels remain nondeterministic.
SEED = 42
tf.random.set_seed(SEED)

In [None]:
import keras
from keras.models import Model, Sequential
import tensorflow.keras.applications as m
from keras import layers
from keras.datasets.cifar10 import load_data
from keras.utils.layer_utils import count_params
from keras.optimizer_v2.adam import Adam
from tensorflow.keras.utils import to_categorical

import numpy as np
import matplotlib.pyplot as plt
from typing import Any, Callable, Tuple

import keras_transfer_learning as ktl

In [None]:
# Load the dataset through the project helper ('fmnist' — presumably Fashion-MNIST).
# `key` is a dict whose values are used further down as the class-label names.
DATASET_NAME = 'fmnist'
(x_train, y_train), (x_test, y_test), key = ktl.load_dataset(DATASET_NAME)

In [None]:
# Classification head placed on top of the pretrained backbone:
# global average pool -> dropout -> 256-unit ReLU -> dropout -> softmax.
NUM_CLASSES = 10     # should equal the number of label names in `key`
DROPOUT_RATE = 0.2
HIDDEN_UNITS = 256

top_layers = [
    layers.GlobalAveragePooling2D(),
    layers.Dropout(DROPOUT_RATE),
    layers.Dense(HIDDEN_UNITS, activation='relu'),
    layers.Dropout(DROPOUT_RATE),
    layers.Dense(NUM_CLASSES, activation='softmax'),
]

# NOTE(review): if 'fmnist' is Fashion-MNIST (28x28 images), confirm that
# ktl resizes the inputs to this (32, 32) shape.
model, base_model = ktl.build_model(
    model_name='resnet50',
    input_shape=(32, 32),
    top_layers=top_layers,
    lr=1e-3,
)
model.summary()

In [None]:
# Epoch budget: the initial training run, then the fine-tuning run after unfreezing.
initial_epochs, fine_tune_epochs = 2, 1

In [None]:
def augmentation_layer(x: tf.Tensor) -> tf.Tensor:
    """Apply training-time augmentation to image data.

    Casts the input to float32 and randomly mirrors it left-to-right.
    Brightness/contrast jitter was tried previously and is disabled.

    Args:
        x: Image tensor. Presumably a single HWC image or an NHWC batch, as
           supplied by ``ktl.get_data_generators``; confirm there.

    Returns:
        A float32 tensor with the same shape as ``x``, possibly flipped horizontally.
    """
    x_float = tf.cast(x, dtype=tf.float32)
    return tf.image.random_flip_left_right(x_float)

# Build the train/test input pipelines from the raw arrays. augmentation_layer
# is handed to the helper; presumably it is applied to training data only (confirm in ktl).
train_dataset, test_dataset = ktl.get_data_generators(
    x_train, y_train,
    x_test, y_test,
    augmentation_layer,
)


In [None]:
# Initial training run; test_dataset doubles as the validation set.
history = model.fit(
    train_dataset,
    validation_data=test_dataset,
    epochs=initial_epochs,
)

In [None]:
# Plot the metrics recorded in `history` during the initial training run.
ktl.plot_history(history)

In [None]:
# Confusion matrix on held-out data. Evaluating on x_train would only show how
# well the model fits examples it was trained on, which overstates accuracy.
labels = list(key.values())
ktl.plot_confusion_matrix(model, x_test, y_test, labels)

In [None]:
FINE_TUNE_LR = 1e-5  # 100x smaller than the initial lr=1e-3 passed to build_model

# Unfreeze part of the backbone for fine-tuning.
# NOTE(review): the meaning of 20 (presumably how many base_model layers become
# trainable) and of the lr argument is defined in ktl.unfreeze_model; confirm there.
ktl.unfreeze_model(model, base_model, 20, FINE_TUNE_LR)

# Recompile so that trainable-flag changes (presumably made by unfreeze_model)
# take effect. The small learning rate limits damage to the pretrained weights.
model.compile(
    loss='categorical_crossentropy',
    metrics=['accuracy'],
    optimizer=Adam(learning_rate=FINE_TUNE_LR),
)
model.summary()

In [None]:
# Resume at the epoch after the last one completed in the initial run.
# history.epoch holds the 0-based indices of completed epochs, so the next
# index is len(history.epoch). Using history.epoch[-1] would repeat an index
# and train fine_tune_epochs + 1 epochs. This also stays correct if the first
# run stopped early.
start_epoch = len(history.epoch)
total_epochs = start_epoch + fine_tune_epochs
history_fine = model.fit(
    train_dataset,
    epochs=total_epochs,
    initial_epoch=start_epoch,
    validation_data=test_dataset,
)

In [None]:
# Plot the fine-tuning run's metrics. Plotting `history` here would only
# redraw the initial run.
ktl.plot_history(history_fine)

In [None]:
# Confusion matrix after fine-tuning, on held-out data. Training-set predictions
# would overstate how well the model generalises.
labels = list(key.values())
ktl.plot_confusion_matrix(model, x_test, y_test, labels)