In [None]:
import datetime
import os

import matplotlib.pyplot as plt
import tensorflow as tf
from lime import lime_image
from skimage.segmentation import mark_boundaries

from VCWA import Common, AttentionModels

In [None]:
# Architecture sanity check: print the layer summary of an untrained
# ResNet50V2 built for 32x32 RGB inputs and 10 classes (the model is not kept).
tf.keras.applications.ResNet50V2(
    weights=None,
    classes=10,
    input_shape=(32, 32, 3),
).summary()

## Data Prep

In [None]:
# Experiment configuration shared by every section of the notebook.
input_shape = (224, 224, 3)   # (height, width, channels) fed to every model
classes = 10                  # number of output classes (Imagenette)
epochs = 5
dataset = "imagenette"        # sub-directory name used under logs/ and models/
batch_size = 16
# Root folder containing `datasets/`. Override it with the DATA_ROOT environment
# variable instead of editing the notebook. A trailing slash is enforced because
# later cells build paths by plain concatenation (path + 'datasets/...').
path = (os.environ.get("DATA_ROOT") or "D:/").rstrip("/\\") + "/"

### Imagenette

In [None]:
# Imagenette loaders. Training images get light augmentation. Both splits go
# through ResNet50V2's preprocess_input, which scales pixels to [-1, 1].
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=tf.keras.applications.resnet_v2.preprocess_input,
    rotation_range=20.0,
    shear_range=20.0,
    zoom_range=0.2,
    horizontal_flip=True)

train_generator = train_datagen.flow_from_directory(
    path + 'datasets/imagenette2/train',
    target_size=input_shape[:2],  # (height, width) kept in sync with the model input
    batch_size=batch_size)


test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=tf.keras.applications.resnet_v2.preprocess_input)

test_generator = test_datagen.flow_from_directory(
    path + 'datasets/imagenette2/val',
    target_size=input_shape[:2],
    batch_size=batch_size)  # use the configured batch size for evaluation too

### Alternative: CIFAR-100

In [None]:
# CIFAR-100 as an alternative dataset: numpy train/test splits from Keras.
(x_train, y_train), (x_test, y_test) = (
    tf.keras.datasets.cifar100.load_data()
)

In [None]:
# Inspect the array's shape instead of dumping every training image into the output.
x_train.shape

In [None]:
def cifar100_helper_generator(x, y, batch_size=25, num_classes=100, target_size=(224, 224)):
    """Yield an endless stream of resized, one-hot-labelled CIFAR-100 batches.

    Walks ``x``/``y`` in order (no shuffling) and starts over when it reaches
    the end, so the generator never terminates.

    Args:
        x: Image array indexed along its first axis (e.g. ``x_test`` from
            ``tf.keras.datasets.cifar100.load_data()``).
        y: Integer labels aligned with ``x``, in any shape that
            ``tf.keras.utils.to_categorical`` accepts.
        batch_size: Images per yielded batch. The last batch of each pass may
            be smaller.
        num_classes: Width of the one-hot label vectors.
        target_size: Spatial size passed to ``Common.resize_video``.

    Yields:
        ``(images, labels)``: images resized to ``target_size`` and divided by
        255, labels one-hot encoded to ``num_classes`` columns.
        NOTE(review): the Imagenette models in this notebook are trained on
        ``resnet_v2.preprocess_input`` inputs ([-1, 1]), not [0, 1]. Confirm
        this scaling before evaluating those models on this stream.

    Raises:
        ValueError: if ``x`` is empty (the outer loop would otherwise spin
            forever without yielding), if ``x`` and ``y`` differ in length, or
            if ``batch_size`` is not positive. Raised on the first ``next()``.
    """
    if len(x) == 0:
        raise ValueError("cifar100_helper_generator needs at least one sample")
    if len(x) != len(y):
        raise ValueError(f"x and y length mismatch: {len(x)} != {len(y)}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    while True:
        for start in range(0, len(x), batch_size):
            stop = start + batch_size
            yield (
                Common.resize_video(x[start:stop], target_size) / 255.,
                tf.keras.utils.to_categorical(y[start:stop], num_classes),
            )

## Vanilla ResNet50v2

### Training

In [None]:
# Vanilla baseline: ResNet50V2 trained from scratch (no pretrained weights).
# To reuse a saved copy instead, uncomment:
#resnet50 = tf.keras.models.load_model("models/imagenette2/base_resnet50v2")
# NOTE(review): the save cell writes to "models/" + dataset + "/" + resnet50.name
# with dataset = "imagenette", not "models/imagenette2/..." — confirm the load path.
resnet50 = tf.keras.applications.ResNet50V2(
    weights=None,
    classes=classes,
    input_shape=input_shape,
)

resnet50.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

In [None]:
# TensorBoard logs go to logs/<dataset>/<model name>_<YYYYmmdd-HHMMSS>, with
# weight histograms written every epoch.
resnet50_tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=f"logs/{dataset}/{resnet50.name}_{datetime.datetime.now():%Y%m%d-%H%M%S}",
    histogram_freq=1,
)

resnet50.fit(
    train_generator,
    validation_data=test_generator,
    epochs=epochs,
    callbacks=[resnet50_tensorboard_callback],
)

In [None]:
# Persist the trained baseline under models/<dataset>/<model name>.
resnet50.save(f"models/{dataset}/{resnet50.name}")

### Attention

In [None]:
# Visualise the baseline on a test batch using layer "conv5_block3_3_conv"
# (presumably a class-activation-map view; see Common.display_attention_batch).
Common.display_attention_batch(
    resnet50,
    test_generator,
    CAM_layer="conv5_block3_3_conv",
)

### Lime

In [None]:
# Pull one evaluation batch (images X, labels Y) for the LIME cell below.
(X, Y) = next(test_generator)

In [None]:
# LIME explanation of test image X[1] under the baseline ResNet50V2.
# Seeding the explainer makes segmentation and perturbation sampling repeatable,
# so re-running the cell (given a deterministic model) gives the same explanation.
explainer = lime_image.LimeImageExplainer(random_state=42)

explanation = explainer.explain_instance(
    X[1].astype('double'),
    resnet50.predict, top_labels=5,
    hide_color=0,
    num_samples=1000)

temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=False)
# Inputs were scaled to [-1, 1] by resnet_v2.preprocess_input; map back to [0, 1] for display.
plt.imshow(mark_boundaries(temp / 2 + 0.5, mask));

## L2PA

### Training

In [None]:
# TODO load
# Build the L2PA attention variant of ResNet50V2 from scratch.
# NOTE(review): the fit cell below never calls compile(), so this factory is
# presumably expected to return a compiled model — confirm in AttentionModels.
L2PA_model = AttentionModels.create_L2PA_ResNet50v2(
    classes=classes,
    input_shape=input_shape,
)

In [None]:
# TensorBoard logs go to logs/<dataset>/<model name>_<YYYYmmdd-HHMMSS>, with
# weight histograms written every epoch.
L2PA_tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=f"logs/{dataset}/{L2PA_model.name}_{datetime.datetime.now():%Y%m%d-%H%M%S}",
    histogram_freq=1,
)

L2PA_model.fit(
    train_generator,
    validation_data=test_generator,
    epochs=epochs,
    callbacks=[L2PA_tensorboard_callback],
)

In [None]:
# Persist the trained L2PA model under models/<dataset>/<model name>.
L2PA_model.save(f"models/{dataset}/{L2PA_model.name}")

### Attention

In [None]:
# Show the model's own attention maps on a test batch (use_attention=True).
Common.display_attention_batch(
    L2PA_model,
    test_generator,
    use_attention=True,
)

## Attention Gated

### Training

In [None]:
# To resume from a saved copy instead of building a fresh model, uncomment:
#gated_model = tf.keras.models.load_model("models/imagenette2/AttentionGated_ResNet50v2")
#gated_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
# NOTE(review): the save cell writes to "models/" + dataset + "/" + <model name>
# with dataset = "imagenette", not "models/imagenette2/..." — confirm the load path.

gated_model = AttentionModels.create_AttentionGated_ResNet50v2(
    classes=classes,
    input_shape=input_shape,
)

In [None]:
# TensorBoard logs go to logs/<dataset>/<model name>_<YYYYmmdd-HHMMSS>, with
# weight histograms written every epoch.
gated_tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=f"logs/{dataset}/{gated_model.name}_{datetime.datetime.now():%Y%m%d-%H%M%S}",
    histogram_freq=1,
)

gated_model.fit(
    train_generator,
    validation_data=test_generator,
    epochs=epochs,
    callbacks=[gated_tensorboard_callback],
)

In [None]:
# Persist the trained attention-gated model under models/<dataset>/<model name>.
gated_model.save(f"models/{dataset}/{gated_model.name}")

### Attention

In [None]:
# Wrapper around gated_model whose predict() is unpacked further down into the
# prediction plus three attention maps.
gated_extractor = AttentionModels.get_attention_extractor(
    gated_model
)

In [None]:
# Fresh evaluation batch (images X, labels Y) for the attention and LIME cells below.
(X, Y) = next(test_generator)
# CIFAR-100 alternative:
#X, Y = next(cifar100_helper_generator(x_test, y_test))

In [None]:
# Run the extractor on the batch, then show sample `i`: the combined attention
# overlaid on the image, followed by the three individual attention maps.
prediction, a1, a2, a3 = gated_extractor.predict(X)

i = 3  # index of the sample within the batch to visualise (must be < len(X))

overlay = Common.combine_attention([a1[i], a2[i], a3[i]])
combined_image = Common.overlay_attention(X[i], overlay)
Common.display_attention_maps(X[i], [combined_image, a1[i], a2[i], a3[i]])

### Lime

In [None]:
# LIME explanation of test image X[1] under the attention-gated model.
# Seeding the explainer makes segmentation and perturbation sampling repeatable,
# so re-running the cell (given a deterministic model) gives the same explanation.
explainer = lime_image.LimeImageExplainer(random_state=42)

explanation = explainer.explain_instance(
    X[1].astype('double'),
    gated_model.predict, top_labels=5,
    hide_color=0,
    num_samples=1000)

temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=True)
# Inputs were scaled to [-1, 1] by resnet_v2.preprocess_input; map back to [0, 1] for display.
plt.imshow(mark_boundaries(temp / 2 + 0.5, mask));

## Attention Gated with Grid Attention

### Training

In [None]:
# To resume from a saved copy instead of building a fresh model, uncomment:
#gatedgrid_model = tf.keras.models.load_model("models/imagenette2/AttentionGatedGrid_ResNet50v2")
#gatedgrid_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
# NOTE(review): the save cell writes to "models/" + dataset + "/" + <model name>
# with dataset = "imagenette", not "models/imagenette2/..." — confirm the load path.

gatedgrid_model = AttentionModels.create_AttentionGatedGrid_ResNet50v2(
    classes=classes,
    input_shape=input_shape,
)

In [None]:
# TensorBoard logs go to logs/<dataset>/<model name>_<YYYYmmdd-HHMMSS>, with
# weight histograms written every epoch.
gatedgrid_tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=f"logs/{dataset}/{gatedgrid_model.name}_{datetime.datetime.now():%Y%m%d-%H%M%S}",
    histogram_freq=1,
)

gatedgrid_model.fit(
    train_generator,
    validation_data=test_generator,
    epochs=epochs,
    callbacks=[gatedgrid_tensorboard_callback],
)

In [None]:
# Persist the trained grid-attention model under models/<dataset>/<model name>.
gatedgrid_model.save(f"models/{dataset}/{gatedgrid_model.name}")

### Attention

In [None]:
# Show the model's own attention maps on a test batch (use_attention=True).
Common.display_attention_batch(
    gatedgrid_model,
    test_generator,
    use_attention=True,
)

### Lime

In [None]:
# LIME explanation of test image X[1] under the grid-attention gated model.
# Seeding the explainer makes segmentation and perturbation sampling repeatable,
# so re-running the cell (given a deterministic model) gives the same explanation.
explainer = lime_image.LimeImageExplainer(random_state=42)

explanation = explainer.explain_instance(
    X[1].astype('double'),
    gatedgrid_model.predict, top_labels=5,
    hide_color=0,
    num_samples=1000)

temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=True)
# Inputs were scaled to [-1, 1] by resnet_v2.preprocess_input; map back to [0, 1] for display.
plt.imshow(mark_boundaries(temp / 2 + 0.5, mask));

## Residual Attention Network

### Training

In [None]:
# To resume from a saved copy instead of building a fresh model, uncomment:
#residual_attention_model = tf.keras.models.load_model("models/imagenette2/ResidualAttentionNet50v2")
#residual_attention_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
# NOTE(review): the save cell writes to "models/" + dataset + "/" + <model name>
# with dataset = "imagenette", not "models/imagenette2/..." — confirm the load path.

residual_attention_model = AttentionModels.create_ResidualAttention_ResNet50v2(
    classes=classes,
    input_shape=input_shape,
)

In [None]:
# TensorBoard logs go to logs/<dataset>/<model name>_<YYYYmmdd-HHMMSS>, with
# weight histograms written every epoch.
residualattention_callback = tf.keras.callbacks.TensorBoard(
    log_dir=f"logs/{dataset}/{residual_attention_model.name}_{datetime.datetime.now():%Y%m%d-%H%M%S}",
    histogram_freq=1,
)

residual_attention_model.fit(
    train_generator,
    validation_data=test_generator,
    epochs=epochs,
    callbacks=[residualattention_callback],
)

In [None]:
# Persist the trained residual-attention model under models/<dataset>/<model name>.
residual_attention_model.save(f"models/{dataset}/{residual_attention_model.name}")

### Attention

In [None]:
# Show the model's own attention maps on a test batch (use_attention=True).
Common.display_attention_batch(
    residual_attention_model,
    test_generator,
    use_attention=True,
)

### Lime

In [None]:
# LIME explanation of test image X[3] under the residual attention network.
# Seeding the explainer makes segmentation and perturbation sampling repeatable,
# so re-running the cell (given a deterministic model) gives the same explanation.
# NOTE(review): the other LIME cells explain X[1] — confirm X[3] is intended here.
explainer = lime_image.LimeImageExplainer(random_state=42)

explanation = explainer.explain_instance(
    X[3].astype('double'),
    residual_attention_model.predict, top_labels=5,
    hide_color=0,
    num_samples=1000)

temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=True)
# Inputs were scaled to [-1, 1] by resnet_v2.preprocess_input; map back to [0, 1] for display.
plt.imshow(mark_boundaries(temp / 2 + 0.5, mask));

## CBAM

### Training

In [None]:
# To resume from a saved copy instead of building a fresh model, uncomment:
#CBAM_model = tf.keras.models.load_model("models/imagenette2/CBAM_ResNet50v2")
#CBAM_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
# NOTE(review): the save cell writes to "models/" + dataset + "/" + <model name>
# with dataset = "imagenette", not "models/imagenette2/..." — confirm the load path.

CBAM_model = AttentionModels.create_CBAM_ResNet50v2(
    classes=classes,
    input_shape=input_shape,
)

In [None]:
# TensorBoard logs go to logs/<dataset>/<model name>_<YYYYmmdd-HHMMSS>, with
# weight histograms written every epoch.
cbam_callback = tf.keras.callbacks.TensorBoard(
    log_dir=f"logs/{dataset}/{CBAM_model.name}_{datetime.datetime.now():%Y%m%d-%H%M%S}",
    histogram_freq=1,
)

CBAM_model.fit(
    train_generator,
    validation_data=test_generator,
    epochs=epochs,
    callbacks=[cbam_callback],
)

In [None]:
# Persist the trained CBAM model under models/<dataset>/<model name>.
CBAM_model.save(f"models/{dataset}/{CBAM_model.name}")

### Attention

In [None]:
# Attention view for the CBAM section: visualise CBAM_model, not the vanilla resnet50.
# NOTE(review): assumes the CBAM model keeps ResNet50V2's layer name
# "conv5_block3_3_conv". If create_CBAM_ResNet50v2 exposes attention outputs like
# the other attention models, use `use_attention=True` instead — confirm.
Common.display_attention_batch(CBAM_model, test_generator, CAM_layer="conv5_block3_3_conv")

### Lime

In [None]:
# LIME explanation of test image X[1] under the CBAM model.
# Seeding the explainer makes segmentation and perturbation sampling repeatable,
# so re-running the cell (given a deterministic model) gives the same explanation.
explainer = lime_image.LimeImageExplainer(random_state=42)

explanation = explainer.explain_instance(
    X[1].astype('double'),
    CBAM_model.predict, top_labels=5,
    hide_color=0,
    num_samples=1000)

temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=False)
# Inputs were scaled to [-1, 1] by resnet_v2.preprocess_input; map back to [0, 1] for display.
plt.imshow(mark_boundaries(temp / 2 + 0.5, mask));