In [None]:
%load_ext tensorboard

import tensorflow as tf
import Common
import datetime
import AttentionModels
from lime import lime_image
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries

## Data Prep

### Imagenette

In [None]:
# Augmentation applied to training images only.
augmentation_options = {
    "shear_range": 0.2,
    "zoom_range": 0.2,
    "horizontal_flip": True,
}

# Images go through the ResNetV2 preprocessing function; a plain 1/255
# rescale was the earlier (now disabled) alternative.
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=tf.keras.applications.resnet_v2.preprocess_input,
    **augmentation_options,
)

# Stream 224x224 images in batches of 16 from the Imagenette training folder.
train_generator = train_datagen.flow_from_directory(
    directory='D:/datasets/imagenette2/train',
    target_size=(224, 224),
    batch_size=16,
)

In [None]:
# Validation images get the ResNetV2 preprocessing only -- no augmentation.
test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=tf.keras.applications.resnet_v2.preprocess_input) #rescale=1./255)

# Build the validation generator from test_datagen. Using train_datagen here
# would apply shear/zoom/flip augmentation to the validation images and leave
# test_datagen unused.
test_generator = test_datagen.flow_from_directory(
        'D:/datasets/imagenette2/val',
        target_size=(224, 224),
        batch_size=16)

### Alternative CIFAR-100

In [None]:
# Load the CIFAR-100 train/test splits with the fine-grained label set.
(x_train, y_train), (x_test, y_test)= tf.keras.datasets.cifar100.load_data(label_mode="fine")

In [None]:
def cifar100_helper_generator(x, y, batch_size=25, num_classes=100):
    """Endlessly yield (images, one-hot labels) batches built from x and y.

    Each batch takes ``batch_size`` consecutive samples, resizes the images to
    224x224 via ``Common.resize_video``, divides them by 255 and one-hot
    encodes the labels with ``num_classes`` categories. After the last batch
    the generator starts again from the beginning.

    Args:
        x: Indexable sequence of images (sliced as ``x[start:stop]``).
        y: Indexable sequence of integer labels aligned with ``x``.
        batch_size: Number of samples per yielded batch (default 25).
        num_classes: Width of the one-hot label encoding (default 100).

    Yields:
        Tuple ``(resized_scaled_images, one_hot_labels)``.

    Raises:
        ValueError: On the first ``next()`` call if ``batch_size`` is smaller
            than 1 or ``x`` is empty. Without this check the endless loop
            would spin forever without producing a batch.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    if len(x) == 0:
        raise ValueError("x must contain at least one sample")

    while True:
        for start in range(0, len(x), batch_size):
            stop = start + batch_size
            yield (
                Common.resize_video(x[start:stop], (224, 224))/255. ,
                tf.keras.utils.to_categorical(y[start:stop], num_classes)
            )

## Vanilla ResNet50v2

In [None]:
# Resume from the saved baseline checkpoint. The commented line below builds a
# ResNet50V2 without pretrained weights instead.
resnet50 = tf.keras.models.load_model("models/imagenette2/base_resnet50v2")
#resnet50 = tf.keras.applications.ResNet50V2(classes=10, weights=None)

# Compile with Adam, categorical cross-entropy and accuracy tracking.
resnet50.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

In [None]:
# The TensorBoard log directory name carries the current timestamp.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "logs/fit/resnet50v2_" + run_stamp
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
)

# Train on the Imagenette generator and validate against the val generator.
resnet50.fit(
    train_generator,
    epochs=100,
    validation_data=test_generator,
    callbacks=[tensorboard_callback],
)
#resnet50.fit(cifar100_helper_generator(x_train, y_train), steps_per_epoch=2000, epochs=50)

In [None]:
# Write the model back to the same path it was loaded from.
resnet50.save("models/imagenette2/base_resnet50v2")

### Lime

In [None]:
# Pull the next (X, Y) batch from the validation generator for inspection.
X, Y = next(test_generator)

In [None]:
explainer = lime_image.LimeImageExplainer()

# Explain the batch sample at index 1 using the baseline model's predict().
lime_input = X[1].astype('double')
explanation = explainer.explain_instance(
    lime_input,
    resnet50.predict,
    top_labels=5,
    hide_color=0,
    num_samples=1000,
)

# Keep the 5 features that speak for the top-ranked label; the rest of the
# image stays visible.
top_label = explanation.top_labels[0]
temp, mask = explanation.get_image_and_mask(
    top_label,
    positive_only=True,
    num_features=5,
    hide_rest=False,
)
# temp / 2 + 0.5 presumably maps the preprocessed pixel range back to [0, 1]
# for display -- TODO confirm against the preprocessing function.
plt.imshow(mark_boundaries(temp / 2 + 0.5, mask))

## L2PA

### Training

In [None]:
# Build the L2PA variant for 224x224 RGB inputs and 10 classes.
# NOTE(review): no compile() call is visible between this cell and the fit()
# cell below -- confirm that create_L2PA_ResNet50v2 returns a compiled model.
L2PA_model = AttentionModels.create_L2PA_ResNet50v2(input_shape=(224, 224, 3), num_classes=10)
#fit_model = AttentionModels.create_L2PA_ResNet50v2(input_shape=(224, 224, 3), num_classes=100)

In [None]:
# The TensorBoard log directory name carries the current timestamp.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "logs/fit/l2pa_" + run_stamp
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
)

# Train on the Imagenette generator and validate against the val generator.
L2PA_model.fit(
    train_generator,
    epochs=100,
    validation_data=test_generator,
    callbacks=[tensorboard_callback],
)
#fit_model.fit(cifar100_helper_generator(x_train, y_train), steps_per_epoch=2000, epochs=1)

In [None]:
# Persist the L2PA model to disk.
L2PA_model.save("models/imagenette2/L2PA_resnet50v2")

### Attention

In [None]:
# Wrap the model in an attention extractor; below, its predict() output is
# unpacked into a prediction plus three further arrays (a1, a2, a3).
L2PA_extractor = AttentionModels.get_attention_extractor(L2PA_model)

In [None]:
# Pull the next (X, Y) batch from the validation generator; the commented line
# draws a CIFAR-100 batch instead.
X, Y = next(test_generator)
#X, Y = next(cifar100_helper_generator(x_test, y_test))

In [None]:
# Run the extractor on the batch: one prediction output plus three further
# outputs, presumably the per-stage attention maps.
prediction, a1, a2, a3 = L2PA_extractor.predict(X)

# Index of the batch sample to visualise.
i = 1

# Merge the three maps of sample i and draw them over the input image.
# (The duplicated `overlay = overlay = ...` assignment target was removed.)
overlay = Common.combine_attention([a1[i], a2[i], a3[i]])
combined_image = Common.overlay_attention(X[i], overlay)
Common.display_attention_maps(X[i], [combined_image, a1[i], a2[i], a3[i]])

## Attention Gated

### Training

In [None]:
# Resume from the saved attention-gated checkpoint and compile it with Adam,
# categorical cross-entropy and accuracy tracking.
gated_model = tf.keras.models.load_model("models/imagenette2/AttentionGated_ResNet50v2")
gated_model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

# Alternative: build a fresh model instead of loading the checkpoint.
#gated_model = AttentionModels.create_AttentionGated_ResNet50v2(input_shape=(224, 224, 3), num_classes=10)

In [None]:
# The TensorBoard log directory name carries the current timestamp.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "logs/fit/gated_" + run_stamp
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
)

# Train on the Imagenette generator and validate against the val generator.
gated_model.fit(
    train_generator,
    epochs=100,
    validation_data=test_generator,
    callbacks=[tensorboard_callback],
)

In [None]:
# Write the model back to the same path it was loaded from.
gated_model.save("models/imagenette2/AttentionGated_ResNet50v2")

### Attention

In [None]:
# Wrap the model in an attention extractor; below, its predict() output is
# unpacked into a prediction plus three further arrays (a1, a2, a3).
gated_extractor = AttentionModels.get_attention_extractor(gated_model)

In [None]:
# Pull the next (X, Y) batch from the validation generator; the commented line
# draws a CIFAR-100 batch instead.
X, Y = next(test_generator)
#X, Y = next(cifar100_helper_generator(x_test, y_test))

In [None]:
# Run the extractor on the batch: one prediction output plus three further
# outputs, presumably the per-stage attention maps.
prediction, a1, a2, a3 = gated_extractor.predict(X)

# Index of the batch sample to visualise.
i = 3

# Merge the three maps of sample i and draw them over the input image.
# (The duplicated `overlay = overlay = ...` assignment target was removed.)
overlay = Common.combine_attention([a1[i], a2[i], a3[i]])
combined_image = Common.overlay_attention(X[i], overlay)
Common.display_attention_maps(X[i], [combined_image, a1[i], a2[i], a3[i]])

### Lime

In [None]:
explainer = lime_image.LimeImageExplainer()

# Explain the batch sample at index 1 using the gated model's predict().
lime_input = X[1].astype('double')
explanation = explainer.explain_instance(
    lime_input,
    gated_model.predict,
    top_labels=5,
    hide_color=0,
    num_samples=1000,
)

# Keep only the 5 features that speak for the top-ranked label and hide the
# rest of the image.
top_label = explanation.top_labels[0]
temp, mask = explanation.get_image_and_mask(
    top_label,
    positive_only=True,
    num_features=5,
    hide_rest=True,
)
# temp / 2 + 0.5 presumably maps the preprocessed pixel range back to [0, 1]
# for display -- TODO confirm against the preprocessing function.
plt.imshow(mark_boundaries(temp / 2 + 0.5, mask))

## Attention Gated with Grid Attention

### Training

In [None]:
# Resume from the saved grid-attention checkpoint and compile it with Adam,
# categorical cross-entropy and accuracy tracking.
gatedgrid_model = tf.keras.models.load_model("models/imagenette2/AttentionGatedGrid_ResNet50v2")
gatedgrid_model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

# Alternative: build a fresh model instead of loading the checkpoint.
#gatedgrid_model = AttentionModels.create_AttentionGatedGrid_ResNet50v2(input_shape=(224, 224, 3), num_classes=10)

In [None]:
# The TensorBoard log directory name carries the current timestamp.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "logs/fit/gatedgrid_" + run_stamp
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
)

# Train on the Imagenette generator and validate against the val generator.
gatedgrid_model.fit(
    train_generator,
    epochs=100,
    validation_data=test_generator,
    callbacks=[tensorboard_callback],
)

In [None]:
# Write the model back to the same path it was loaded from.
gatedgrid_model.save("models/imagenette2/AttentionGatedGrid_ResNet50v2")

### Attention

In [None]:
# Wrap the model in an attention extractor; below, its predict() output is
# unpacked into a prediction plus three further arrays (a1, a2, a3).
gatedgrid_extractor = AttentionModels.get_attention_extractor(gatedgrid_model)

In [None]:
# Pull the next (X, Y) batch from the validation generator; the commented line
# draws a CIFAR-100 batch instead.
X, Y = next(test_generator)
#X, Y = next(cifar100_helper_generator(x_test, y_test))

In [None]:
# Run the extractor on the batch: one prediction output plus three further
# outputs, presumably the per-stage attention maps.
prediction, a1, a2, a3 = gatedgrid_extractor.predict(X)

# Index of the batch sample to visualise.
i = 8

# Merge the three maps of sample i and draw them over the input image.
# (The duplicated `overlay = overlay = ...` assignment target was removed.)
overlay = Common.combine_attention([a1[i], a2[i], a3[i]])
combined_image = Common.overlay_attention(X[i], overlay)
Common.display_attention_maps(X[i], [combined_image, a1[i], a2[i], a3[i]])

### Lime

In [None]:
explainer = lime_image.LimeImageExplainer()

# Explain the batch sample at index 1 using the grid-attention model's predict().
lime_input = X[1].astype('double')
explanation = explainer.explain_instance(
    lime_input,
    gatedgrid_model.predict,
    top_labels=5,
    hide_color=0,
    num_samples=1000,
)

# Keep only the 5 features that speak for the top-ranked label and hide the
# rest of the image.
top_label = explanation.top_labels[0]
temp, mask = explanation.get_image_and_mask(
    top_label,
    positive_only=True,
    num_features=5,
    hide_rest=True,
)
# temp / 2 + 0.5 presumably maps the preprocessed pixel range back to [0, 1]
# for display -- TODO confirm against the preprocessing function.
plt.imshow(mark_boundaries(temp / 2 + 0.5, mask))

## Residual Attention Network

### Training

In [None]:
# Alternative: resume from the saved checkpoint instead of building from scratch.
# residual_attention_model = tf.keras.models.load_model("models/imagenette2/ResidualAttentionNet50v2")

# Build the residual-attention variant for 224x224 RGB inputs and 10 classes.
residual_attention_model = AttentionModels.create_ResidualAttention_ResNet50v2(
    input_shape=(224, 224, 3),
    num_classes=10,
)

# Compile with Adam, categorical cross-entropy and accuracy tracking.
residual_attention_model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

In [None]:
# The TensorBoard log directory name carries the current timestamp.
log_dir = "logs/fit/residualattention_" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

# Train on the Imagenette generator and validate against the val generator.
# A stray ")" after validation_data used to close the call early, which made
# this cell a SyntaxError and kept the TensorBoard callback out of fit().
residual_attention_model.fit(
    train_generator,
    epochs=1,
    validation_data=test_generator,
    callbacks=[tensorboard_callback])

In [None]:
# Persist the residual-attention model to disk.
residual_attention_model.save("models/imagenette2/ResidualAttentionNet50v2")

### Attention

In [None]:
# Wrap the model in an attention extractor; below, its output is unpacked into
# a prediction plus three further outputs (a1, a2, a3).
residual_attention_extractor = AttentionModels.get_attention_extractor(residual_attention_model)

In [None]:
# Pull the next (X, Y) batch from the validation generator for inspection.
X, Y = next(test_generator)

In [None]:
# Use predict() like the other attention sections so the outputs come back as
# NumPy arrays; calling the extractor directly returns tensors instead.
# Outputs: one prediction plus three further arrays, presumably the per-stage
# attention maps.
pred, a1, a2, a3 = residual_attention_extractor.predict(X)

# Index of the batch sample to visualise.
i = 1

# Merge the three maps of sample i and draw them over the input image.
# (The duplicated `overlay = overlay = ...` assignment target was removed.)
overlay = Common.combine_attention([a1[i], a2[i], a3[i]])
combined_image = Common.overlay_attention(X[i], overlay)
Common.display_attention_maps(X[i], [combined_image, a1[i], a2[i], a3[i]])

### Lime

In [None]:
explainer = lime_image.LimeImageExplainer()

# Explain the batch sample at index 1 using the residual-attention model's predict().
lime_input = X[1].astype('double')
explanation = explainer.explain_instance(
    lime_input,
    residual_attention_model.predict,
    top_labels=5,
    hide_color=0,
    num_samples=1000,
)

# Keep the 5 features that speak for the top-ranked label; the rest of the
# image stays visible.
top_label = explanation.top_labels[0]
temp, mask = explanation.get_image_and_mask(
    top_label,
    positive_only=True,
    num_features=5,
    hide_rest=False,
)
# temp / 2 + 0.5 presumably maps the preprocessed pixel range back to [0, 1]
# for display -- TODO confirm against the preprocessing function.
plt.imshow(mark_boundaries(temp / 2 + 0.5, mask))