In [None]:
import tensorflow as tf

## Data prep

### Imagenette

In [None]:
# Training pipeline: ResNetV2 preprocessing plus light random augmentation
# (shear, zoom, horizontal flip).
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    horizontal_flip=True,
    zoom_range=0.2,
    shear_range=0.2,
    preprocessing_function=tf.keras.applications.resnet_v2.preprocess_input,
)

# Stream 224x224 training images from one sub-directory per class, 32 per batch.
# NOTE(review): hard-coded absolute local path — adjust for your machine.
train_generator = train_datagen.flow_from_directory(
    'D:/datasets/imagenette2/train',
    batch_size=32,
    target_size=(224, 224),
)

In [None]:
# Evaluation pipeline: the same ResNetV2 preprocessing as training, but with no
# random augmentation, so validation metrics are computed on unaltered images.
test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=tf.keras.applications.resnet_v2.preprocess_input)

# Must use test_datagen (not train_datagen): the training generator would apply
# shear/zoom/flip augmentation to the validation images.
# NOTE(review): hard-coded absolute local path — adjust for your machine.
test_generator = test_datagen.flow_from_directory(
        'D:/datasets/imagenette2/val',
        target_size=(224, 224),
        batch_size=32)

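### Alternative: CIFAR-100

To train and evaluate on CIFAR-100 instead of Imagenette, swap in the commented-out `cifar100_helper_generator` lines (and `classes=100` / `num_classes=100`) in the cells below.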

In [None]:
# CIFAR-100 with its fine-grained (100-class) labels, loaded fully into memory;
# batches are produced by cifar100_helper_generator.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data(
    label_mode="fine"
)

In [None]:
import Common

def cifar100_helper_generator(x, y, batch_size=25, num_classes=100):
    """Yield (images, one-hot labels) batches from in-memory CIFAR arrays, forever.

    The generator never terminates, so callers must pass an explicit
    ``steps_per_epoch`` / ``steps`` to ``Model.fit`` / ``Model.evaluate``.

    Args:
        x: Image array, sliced along axis 0 into batches.
        y: Integer label array aligned with ``x`` along axis 0.
        batch_size: Samples per batch (default 25, matching steps_per_epoch=2000
            for 50000 training images). The last batch of a pass is smaller when
            ``len(x)`` is not a multiple of ``batch_size``.
        num_classes: Width of the one-hot label vectors.

    Yields:
        Tuple ``(images, labels)``. Images are resized to 224x224 and scaled with
        ``resnet_v2.preprocess_input``. Labels are one-hot via ``to_categorical``.

    Raises:
        ValueError: On the first ``next()`` if ``x`` is empty, ``x``/``y``
            lengths differ, or ``batch_size`` is not positive. Without these
            checks an empty input would spin forever without yielding.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if len(x) == 0:
        raise ValueError("x is empty; nothing to yield")
    if len(x) != len(y):
        raise ValueError(f"x and y lengths differ: {len(x)} != {len(y)}")
    while True:
        for start in range(0, len(x), batch_size):
            stop = start + batch_size
            # NOTE(review): Common.resize_video is assumed to resize a stack of
            # frames to the given (height, width) — confirm in Common.
            images = Common.resize_video(x[start:stop], (224, 224))
            yield (
                # Same scaling as the Imagenette generators, so both datasets feed
                # the model identically and the visualisation cells (X / 2 + 0.5)
                # map the inputs back to [0, 1] correctly.
                tf.keras.applications.resnet_v2.preprocess_input(images),
                tf.keras.utils.to_categorical(y[start:stop], num_classes),
            )

## Vanilla ResNet50v2

In [None]:
# Baseline: ResNet50V2 trained from scratch (no pretrained weights) with a
# 10-way classification head for Imagenette.
resnet50 = tf.keras.applications.ResNet50V2(weights=None, classes=10)
# CIFAR-100 alternative (100-way head):
#resnet50 = tf.keras.applications.ResNet50V2(weights=None, classes=100)

In [None]:
# Categorical cross-entropy matches the one-hot labels from both pipelines
# (flow_from_directory's default class_mode and to_categorical).
resnet50.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

In [None]:
# Train the baseline on the Imagenette training generator.
# NOTE(review): 5 epochs here vs 50 for the attention models below, so the
# resulting accuracies are not directly comparable — align before comparing.
resnet50.fit(train_generator, epochs=5)
# CIFAR-100 alternative (2000 steps x 25 images = one pass over 50000 training images):
#resnet50.fit(cifar100_helper_generator(x_train, y_train), steps_per_epoch=2000, epochs=50)

In [None]:
# Baseline loss and accuracy on the Imagenette validation split.
resnet50.evaluate(x=test_generator)
# CIFAR-100 alternative (400 steps x 25 images = the 10000-image test set):
#resnet50.evaluate(cifar100_helper_generator(x_test, y_test), steps=400)

## Attention

### Training

In [None]:
import AttentionModels

In [None]:
# L2PA attention variant of ResNet50V2 from the project's AttentionModels module,
# for 224x224 RGB input and 10 classes.
fit_model = AttentionModels.create_L2PA_ResNet50v2(
    num_classes=10,
    input_shape=(224, 224, 3),
)
# CIFAR-100 alternative:
#fit_model = AttentionModels.create_L2PA_ResNet50v2(input_shape=(224, 224, 3), num_classes=100)

In [None]:
# compile() is never called on fit_model in this notebook, so the factory
# presumably returns an already-compiled model — TODO confirm in AttentionModels.
fit_model.fit(
    train_generator,
    epochs=50,
)
# CIFAR-100 alternative. NOTE(review): epochs=1 here vs epochs=50 in the
# baseline's CIFAR-100 toggle — confirm this is intentional.
#fit_model.fit(cifar100_helper_generator(x_train, y_train), steps_per_epoch=2000, epochs=1)

In [None]:
# L2PA model loss/metrics on the validation generator.
fit_model.evaluate(
    test_generator,
)
# CIFAR-100 alternative:
#fit_model.evaluate(cifar100_helper_generator(x_test, y_test), steps=400)

In [None]:
# Persist the trained L2PA model under models/.
fit_model.save(
    "models/L2PA_resnet50v2"
)

### Attention maps

In [None]:
import matplotlib.pyplot as plt
from matplotlib import cm

# Build an extractor around the trained model. Its predict() output is unpacked
# as a prediction plus three maps, presumably intermediate attention maps — see
# AttentionModels.get_attention_extractor.
extractor = AttentionModels.get_attention_extractor(fit_model)

In [None]:
# Take one batch of validation images (X) and their labels (Y).
X, Y = next(test_generator)
# CIFAR-100 alternative:
#X, Y = next(cifar100_helper_generator(x_test, y_test))
# Four outputs: `prediction`, then a1..a3 — presumably attention maps from three
# stages of the network (TODO confirm in AttentionModels).
prediction, a1, a2, a3 = extractor.predict(X)

In [None]:
# Index of the sample within the batch to inspect; must be smaller than the batch size.
i = 19
assert i < len(X), f"sample index {i} is out of range for a batch of {len(X)} images"

# One row: the input image followed by the three attention maps.
# A wide, short figure suits a 1x4 grid.
fig, ax = plt.subplots(1, 4, figsize=(20, 5))

# resnet_v2.preprocess_input scales inputs to [-1, 1]; map back to [0, 1] for display.
ax[0].imshow(X[i] / 2 + 0.5)
# Class indices: true label from the one-hot Y, predicted label as the argmax of
# `prediction` (assumes it holds per-class scores — TODO confirm).
ax[0].set_title(f"input (label {Y[i].argmax()}, pred {prediction[i].argmax()})")

for k, attention in enumerate((a1, a2, a3), start=1):
    # squeeze() drops a possible trailing size-1 channel axis, which imshow
    # rejects; it is a no-op for maps that are already 2-D.
    ax[k].imshow(attention[i].squeeze(), interpolation='nearest', cmap=cm.inferno)
    ax[k].set_title(f"attention map {k}")

for axis in ax:
    axis.set_axis_off()

plt.show()

## Attention Gated

### Training

In [None]:
import AttentionModels

In [None]:
# Attention-gated variant of ResNet50V2 from the project's AttentionModels module,
# for 224x224 RGB input and 10 classes (pass num_classes=100 for CIFAR-100).
fit_model = AttentionModels.create_AttentionGated_ResNet50v2(
    num_classes=10,
    input_shape=(224, 224, 3),
)

In [None]:
# compile() is never called on fit_model in this notebook, so the factory
# presumably returns an already-compiled model — TODO confirm in AttentionModels.
fit_model.fit(
    train_generator,
    epochs=50,
)

In [None]:
# Attention-gated model loss/metrics on the validation generator.
fit_model.evaluate(
    test_generator,
)

In [None]:
# Persist the trained attention-gated model under models/.
fit_model.save(
    "models/AttentionGated_ResNet50v2"
)

### Attention maps

In [None]:
import matplotlib.pyplot as plt
from matplotlib import cm

# Build an extractor around the trained model. Its predict() output is unpacked
# as a prediction plus three maps, presumably intermediate attention maps — see
# AttentionModels.get_attention_extractor.
extractor = AttentionModels.get_attention_extractor(fit_model)

In [None]:
# Take one batch of validation images (X) and their labels (Y).
X, Y = next(test_generator)
# CIFAR-100 alternative:
#X, Y = next(cifar100_helper_generator(x_test, y_test))
# Four outputs: `prediction`, then a1..a3 — presumably attention maps from three
# stages of the network (TODO confirm in AttentionModels).
prediction, a1, a2, a3 = extractor.predict(X)

In [None]:
# Index of the sample within the batch to inspect; must be smaller than the batch size.
i = 19
assert i < len(X), f"sample index {i} is out of range for a batch of {len(X)} images"

# One row: the input image followed by the three attention maps.
# A wide, short figure suits a 1x4 grid.
fig, ax = plt.subplots(1, 4, figsize=(20, 5))

# resnet_v2.preprocess_input scales inputs to [-1, 1]; map back to [0, 1] for display.
ax[0].imshow(X[i] / 2 + 0.5)
# Class indices: true label from the one-hot Y, predicted label as the argmax of
# `prediction` (assumes it holds per-class scores — TODO confirm).
ax[0].set_title(f"input (label {Y[i].argmax()}, pred {prediction[i].argmax()})")

for k, attention in enumerate((a1, a2, a3), start=1):
    # squeeze() drops a possible trailing size-1 channel axis, which imshow
    # rejects; it is a no-op for maps that are already 2-D.
    ax[k].imshow(attention[i].squeeze(), interpolation='nearest', cmap=cm.inferno)
    ax[k].set_title(f"attention map {k}")

for axis in ax:
    axis.set_axis_off()

plt.show()