In [None]:
import tensorflow as tf
import datetime
import matplotlib.pyplot as plt
from VCWA import Common, AttentionModels

## Data Prep

### Imagenette

In [None]:
# Imagenette configuration; images are read from one-folder-per-class directories.
input_shape = (224, 224, 3)
classes = 10
epochs = 100
dataset = "imagenette"
batch_size = 64
path = "D:/"
# SGD with per-step inverse-time decay: lr_t = 0.1 / (1 + 0.0001 * t).
# This is the schedule the legacy `decay=0.0001` argument applied. Keras optimizers
# in TF >= 2.11 reject `decay=`, while a LearningRateSchedule works on all TF 2.x.
optimizer = tf.keras.optimizers.SGD(
    learning_rate=tf.keras.optimizers.schedules.InverseTimeDecay(
        initial_learning_rate=0.1, decay_steps=1, decay_rate=0.0001),
    momentum=0.9)


# Training images: ResNetV2 preprocessing plus random geometric augmentation.
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=tf.keras.applications.resnet_v2.preprocess_input,
    rotation_range=20.0,
    shear_range=20.0,
    zoom_range=0.2,
    horizontal_flip=True,
)

train_generator = train_datagen.flow_from_directory(
    path + 'datasets/imagenette2/train',
    target_size=input_shape[:2],  # keep the loader in sync with the model input size
    batch_size=batch_size)


# Evaluation images: preprocessing only, no augmentation.
test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=tf.keras.applications.resnet_v2.preprocess_input)

test_generator = test_datagen.flow_from_directory(
    path + 'datasets/imagenette2/val',
    target_size=input_shape[:2],
    batch_size=batch_size)

# The models are built with `classes` outputs; fail fast if the folder layout disagrees.
assert train_generator.num_classes == classes, f"expected {classes} class folders, found {train_generator.num_classes}"
assert test_generator.num_classes == classes, f"expected {classes} class folders, found {test_generator.num_classes}"

### CIFAR-100

In [None]:
# CIFAR-100 configuration; images are stored as one folder per class and resized
# to the model input size by the loaders.
input_shape = (224, 224, 3)
classes = 100
epochs = 100
dataset = "cifar-100"
batch_size = 64
path = "D:/"
# SGD with per-step inverse-time decay: lr_t = 0.1 / (1 + 0.0001 * t).
# This is the schedule the legacy `decay=0.0001` argument applied. Keras optimizers
# in TF >= 2.11 reject `decay=`, while a LearningRateSchedule works on all TF 2.x.
optimizer = tf.keras.optimizers.SGD(
    learning_rate=tf.keras.optimizers.schedules.InverseTimeDecay(
        initial_learning_rate=0.1, decay_steps=1, decay_rate=0.0001),
    momentum=0.9)


# Training images: ResNetV2 preprocessing, geometric augmentation and channel shifts.
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=tf.keras.applications.resnet_v2.preprocess_input,
    rotation_range=20.0,
    shear_range=20.0,
    zoom_range=0.2,
    channel_shift_range=50.0,
    horizontal_flip=True)

train_generator = train_datagen.flow_from_directory(
    path + 'datasets/cifar100/train',
    target_size=input_shape[:2],  # keep the loader in sync with the model input size
    batch_size=batch_size)


# Evaluation images: preprocessing only, no augmentation.
test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=tf.keras.applications.resnet_v2.preprocess_input)

test_generator = test_datagen.flow_from_directory(
    path + 'datasets/cifar100/test',
    target_size=input_shape[:2],
    batch_size=batch_size)

# The models are built with `classes` outputs; fail fast if the folder layout disagrees.
assert train_generator.num_classes == classes, f"expected {classes} class folders, found {train_generator.num_classes}"
assert test_generator.num_classes == classes, f"expected {classes} class folders, found {test_generator.num_classes}"

### CUB-200-2011

In [None]:
# CUB-200-2011 configuration: a single image folder (one subfolder per class),
# split into training/validation subsets via `validation_split`.
input_shape = (224, 224, 3)
classes = 200
epochs = 100
dataset = "cub-200"
batch_size = 64
path = "D:/"
# SGD with per-step inverse-time decay: lr_t = 0.1 / (1 + 0.0001 * t).
# This is the schedule the legacy `decay=0.0001` argument applied. Keras optimizers
# in TF >= 2.11 reject `decay=`, while a LearningRateSchedule works on all TF 2.x.
optimizer = tf.keras.optimizers.SGD(
    learning_rate=tf.keras.optimizers.schedules.InverseTimeDecay(
        initial_learning_rate=0.1, decay_steps=1, decay_rate=0.0001),
    momentum=0.9)

# Both generators must use the same fraction. flow_from_directory splits
# deterministically, so the same directory + split gives matching, disjoint subsets.
validation_split = 0.2

# Training subset: ResNetV2 preprocessing plus random geometric augmentation.
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=tf.keras.applications.resnet_v2.preprocess_input,
    rotation_range=20.0,
    shear_range=20.0,
    zoom_range=0.2,
    horizontal_flip=True,
    validation_split=validation_split)

# Validation subset: preprocessing only. Held-out images must not be randomly
# augmented, matching the test generators of the other datasets.
val_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=tf.keras.applications.resnet_v2.preprocess_input,
    validation_split=validation_split)

train_generator = datagen.flow_from_directory(
    path + 'datasets/CUB_200_2011/images',
    target_size=input_shape[:2],  # keep the loader in sync with the model input size
    batch_size=batch_size,
    subset="training")

test_generator = val_datagen.flow_from_directory(
    path + 'datasets/CUB_200_2011/images',
    target_size=input_shape[:2],
    batch_size=batch_size,
    subset="validation")

# The models are built with `classes` outputs; fail fast if the folder layout disagrees.
assert train_generator.num_classes == classes, f"expected {classes} class folders, found {train_generator.num_classes}"
assert test_generator.num_classes == classes, f"expected {classes} class folders, found {test_generator.num_classes}"

## Vanilla MobileNetV2

### Training

In [None]:
# To resume from a saved checkpoint, use the load_model line instead of building a new model.
#mobilenet = tf.keras.models.load_model("models/" + dataset + "/mobilenetv2")
# Baseline MobileNetV2 trained from scratch (weights=None) with `classes` outputs.
mobilenet = tf.keras.applications.MobileNetV2(input_shape=input_shape, classes=classes, weights=None)

# Compile with a fresh optimizer rebuilt from the configured `optimizer`'s config.
# Sharing one instance across models would carry its step counter (and therefore
# the decayed learning rate) from one model's training into the next. On newer
# Keras optimizers it can also fail because the instance was built for another
# model's variables.
mobilenet.compile(
    loss="categorical_crossentropy",
    optimizer=type(optimizer).from_config(optimizer.get_config()),
    metrics=["accuracy", tf.keras.metrics.TopKCategoricalAccuracy(5)])

In [None]:
# Each run logs to its own TensorBoard directory: logs/<dataset>/<model name>_<timestamp>.
# (Previously this referenced `resnet50`, which is never defined in this notebook.)
mobilenet_tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir="logs/" + dataset + "/" + mobilenet.name + "_" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
    histogram_freq=1)

mobilenet.fit(
    train_generator,
    epochs=epochs,
    validation_data=test_generator,
    callbacks=[mobilenet_tensorboard_callback])

In [None]:
# Persist the trained baseline under models/<dataset>/mobilenetv2.
mobilenet.save(f"models/{dataset}/mobilenetv2")

### Attention

In [None]:
# First batch of the evaluation generator; the labels are discarded.
x, _ = test_generator[0]

In [None]:
# Attention visualisation for the sample batch. `CAM_layer` names the MobileNetV2
# layer the helper reads from (see VCWA.Common for details).
Common.display_attention_batch(
    mobilenet,
    x,
    CAM_layer="block_12_add",
)

In [None]:
# LIME explanations for the same sample batch.
Common.display_lime_batch(
    mobilenet,
    x,
)

## L2PA

### Training

In [None]:
# NOTE(review): the save cell writes to models/<dataset>/<L2PA_model.name>; confirm this
# hard-coded path matches before using it to resume.
# L2PA_model = tf.keras.models.load_model("models/" + dataset + "/L2PA_MobileNetV2")
L2PA_model = AttentionModels.create_L2PA_ResNet50v2(input_shape=input_shape, classes=classes)

# Compile with a fresh optimizer rebuilt from the configured `optimizer`'s config.
# A shared instance would carry its step counter (and decayed learning rate) and
# built state over from whichever model was trained before.
L2PA_model.compile(
    loss="categorical_crossentropy",
    optimizer=type(optimizer).from_config(optimizer.get_config()),
    metrics=["accuracy", tf.keras.metrics.TopKCategoricalAccuracy(5)])

In [None]:
# Per-run TensorBoard directory: logs/<dataset>/<model name>_<YYYYmmdd-HHMMSS>.
L2PA_tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=f"logs/{dataset}/{L2PA_model.name}_{datetime.datetime.now():%Y%m%d-%H%M%S}",
    histogram_freq=1,
)

L2PA_model.fit(
    x=train_generator,
    validation_data=test_generator,
    epochs=epochs,
    callbacks=[L2PA_tensorboard_callback],
)

In [None]:
# Persist under models/<dataset>/<model name>.
L2PA_model.save(f"models/{dataset}/{L2PA_model.name}")

### Attention

In [None]:
# First batch of the evaluation generator; the labels are discarded.
x, _ = test_generator[0]

In [None]:
# Visualise the model's own attention output for the sample batch.
Common.display_attention_batch(
    L2PA_model,
    x,
    use_attention=True,
)

In [None]:
# LIME explanations for the same sample batch.
Common.display_lime_batch(
    L2PA_model,
    x,
)

## Attention Gated

### Training

In [None]:
# NOTE(review): the save cell writes to models/<dataset>/<gated_model.name>; confirm this
# hard-coded path matches before using it to resume.
# gated_model = tf.keras.models.load_model("models/" + dataset + "/AttGated_MobileNetV2")
gated_model = AttentionModels.create_AttentionGated_ResNet50v2(input_shape=input_shape, classes=classes)

# Compile with a fresh optimizer rebuilt from the configured `optimizer`'s config.
# A shared instance would carry its step counter (and decayed learning rate) and
# built state over from whichever model was trained before.
gated_model.compile(
    loss="categorical_crossentropy",
    optimizer=type(optimizer).from_config(optimizer.get_config()),
    metrics=["accuracy", tf.keras.metrics.TopKCategoricalAccuracy(5)])

In [None]:
# Per-run TensorBoard directory: logs/<dataset>/<model name>_<YYYYmmdd-HHMMSS>.
gated_tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=f"logs/{dataset}/{gated_model.name}_{datetime.datetime.now():%Y%m%d-%H%M%S}",
    histogram_freq=1,
)

gated_model.fit(
    x=train_generator,
    validation_data=test_generator,
    epochs=epochs,
    callbacks=[gated_tensorboard_callback],
)

In [None]:
# Persist under models/<dataset>/<model name>.
gated_model.save(f"models/{dataset}/{gated_model.name}")

### Attention

In [None]:
# First batch of the evaluation generator; the labels are discarded.
x, _ = test_generator[0]

In [None]:
# Visualise the model's own attention output for the sample batch.
Common.display_attention_batch(
    gated_model,
    x,
    use_attention=True,
)

In [None]:
# LIME explanations for the same sample batch.
Common.display_lime_batch(
    gated_model,
    x,
)

## Attention Gated with Grid Attention

### Training

In [None]:
# NOTE(review): the save cell writes to models/<dataset>/<gatedgrid_model.name>; confirm
# this hard-coded path matches before using it to resume.
# gatedgrid_model = tf.keras.models.load_model("models/" + dataset + "/AttGatedGrid_MobileNetV2")
gatedgrid_model = AttentionModels.create_AttentionGatedGrid_ResNet50v2(input_shape=input_shape, classes=classes)

# Compile with a fresh optimizer rebuilt from the configured `optimizer`'s config.
# A shared instance would carry its step counter (and decayed learning rate) and
# built state over from whichever model was trained before.
gatedgrid_model.compile(
    loss="categorical_crossentropy",
    optimizer=type(optimizer).from_config(optimizer.get_config()),
    metrics=["accuracy", tf.keras.metrics.TopKCategoricalAccuracy(5)])

In [None]:
# Per-run TensorBoard directory: logs/<dataset>/<model name>_<YYYYmmdd-HHMMSS>.
gatedgrid_tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=f"logs/{dataset}/{gatedgrid_model.name}_{datetime.datetime.now():%Y%m%d-%H%M%S}",
    histogram_freq=1,
)

gatedgrid_model.fit(
    x=train_generator,
    validation_data=test_generator,
    epochs=epochs,
    callbacks=[gatedgrid_tensorboard_callback],
)

In [None]:
# Persist under models/<dataset>/<model name>.
gatedgrid_model.save(f"models/{dataset}/{gatedgrid_model.name}")

### Attention

In [None]:
# First batch of the evaluation generator; the labels are discarded.
x, _ = test_generator[0]

In [None]:
# Visualise the model's own attention output for the sample batch.
Common.display_attention_batch(
    gatedgrid_model,
    x,
    use_attention=True,
)

In [None]:
# LIME explanations for the same sample batch.
Common.display_lime_batch(
    gatedgrid_model,
    x,
)

## Residual Attention Network

### Training

In [None]:
# NOTE(review): the save cell writes to models/<dataset>/<residual_attention_model.name>;
# confirm this hard-coded path matches before using it to resume.
# residual_attention_model = tf.keras.models.load_model("models/" + dataset + "/ResAttentionMobileNetV2")
residual_attention_model = AttentionModels.create_ResidualAttention_ResNet50v2(input_shape=input_shape, classes=classes)

# Compile with a fresh optimizer rebuilt from the configured `optimizer`'s config.
# A shared instance would carry its step counter (and decayed learning rate) and
# built state over from whichever model was trained before.
residual_attention_model.compile(
    loss="categorical_crossentropy",
    optimizer=type(optimizer).from_config(optimizer.get_config()),
    metrics=["accuracy", tf.keras.metrics.TopKCategoricalAccuracy(5)])

In [None]:
# Per-run TensorBoard directory: logs/<dataset>/<model name>_<YYYYmmdd-HHMMSS>.
residualattention_callback = tf.keras.callbacks.TensorBoard(
    log_dir=f"logs/{dataset}/{residual_attention_model.name}_{datetime.datetime.now():%Y%m%d-%H%M%S}",
    histogram_freq=1,
)

residual_attention_model.fit(
    x=train_generator,
    validation_data=test_generator,
    epochs=epochs,
    callbacks=[residualattention_callback],
)

In [None]:
# Persist under models/<dataset>/<model name>.
residual_attention_model.save(f"models/{dataset}/{residual_attention_model.name}")

### Attention

In [None]:
# First batch of the evaluation generator; the labels are discarded.
x, _ = test_generator[0]

In [None]:
# Visualise the model's own attention output for the sample batch.
Common.display_attention_batch(
    residual_attention_model,
    x,
    use_attention=True,
)

In [None]:
# LIME explanations for the same sample batch.
Common.display_lime_batch(
    residual_attention_model,
    x,
)

## CBAM

### Training

In [None]:
# NOTE(review): the save cell writes to models/<dataset>/<CBAM_model.name>; confirm this
# hard-coded path matches before using it to resume.
# CBAM_model = tf.keras.models.load_model("models/" + dataset + "/CBAM_MobileNetV2")
CBAM_model = AttentionModels.create_CBAM_ResNet50v2(input_shape=input_shape, classes=classes)

# Compile with a fresh optimizer rebuilt from the configured `optimizer`'s config.
# A shared instance would carry its step counter (and decayed learning rate) and
# built state over from whichever model was trained before.
CBAM_model.compile(
    loss="categorical_crossentropy",
    optimizer=type(optimizer).from_config(optimizer.get_config()),
    metrics=["accuracy", tf.keras.metrics.TopKCategoricalAccuracy(5)])

In [None]:
# Per-run TensorBoard directory: logs/<dataset>/<model name>_<YYYYmmdd-HHMMSS>.
cbam_callback = tf.keras.callbacks.TensorBoard(
    log_dir=f"logs/{dataset}/{CBAM_model.name}_{datetime.datetime.now():%Y%m%d-%H%M%S}",
    histogram_freq=1,
)

CBAM_model.fit(
    x=train_generator,
    validation_data=test_generator,
    epochs=epochs,
    callbacks=[cbam_callback],
)

In [None]:
# Persist under models/<dataset>/<model name>.
CBAM_model.save(f"models/{dataset}/{CBAM_model.name}")

### Attention

In [None]:
# First batch of the evaluation generator; the labels are discarded.
x, _ = test_generator[0]

In [None]:
# Attention visualisation for the sample batch, read from the layer named by `CAM_layer`.
# NOTE(review): "block_12_add" is a layer name from Keras' MobileNetV2 (it is used with the
# baseline above); confirm the CBAM ResNet50v2 model actually contains a layer with this name.
Common.display_attention_batch(
    CBAM_model,
    x,
    CAM_layer="block_12_add",
)

In [None]:
# LIME explanations for the same sample batch.
Common.display_lime_batch(
    CBAM_model,
    x,
)