In [None]:
import tensorflow as tf
import datetime
import matplotlib.pyplot as plt
from VCWA import Common, AttentionModels

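## Data Prep

Run **only one** of the dataset cells below (Imagenette *or* CIFAR-100). Each one (re)defines `input_shape`, `classes`, `epochs`, `dataset`, `batch_size`, `train_generator` and `test_generator`, and every later section reads those names. Under *Run All*, whichever cell runs last silently wins.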

### Imagenette

In [None]:
# Imagenette: a 10-class image set read from disk with Keras directory generators.
input_shape = (224, 224, 3)
classes = 10
epochs = 100
dataset = "imagenette"
batch_size = 16
# NOTE(review): machine-specific dataset root — adjust for your environment.
path = "D:/"


# Training images get light geometric augmentation on top of the ResNetV2
# preprocessing; the validation generator below applies the preprocessing only.
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=tf.keras.applications.resnet_v2.preprocess_input,
    rotation_range=20.0,
    shear_range=20.0,
    zoom_range=0.2,
    horizontal_flip=True)

# target_size is taken from input_shape so the resized images always match the
# model's input layer, even if input_shape is edited above.
train_generator = train_datagen.flow_from_directory(
    path + 'datasets/imagenette2/train',
    target_size=input_shape[:2],
    batch_size=batch_size)


test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=tf.keras.applications.resnet_v2.preprocess_input)

test_generator = test_datagen.flow_from_directory(
    path + 'datasets/imagenette2/val',
    target_size=input_shape[:2],
    batch_size=batch_size)

# flow_from_directory infers the class count from the sub-folders; a mismatch
# with `classes` means a wrong path or an incomplete download, and would
# otherwise only surface later as a shape error in the model's output layer.
assert train_generator.num_classes == classes, (
    f"expected {classes} class folders in train, found {train_generator.num_classes}")
assert test_generator.num_classes == classes, (
    f"expected {classes} class folders in val, found {test_generator.num_classes}")

### CIFAR-100

In [None]:
# CIFAR-100: 100-class, 32x32 images loaded fully into memory.
input_shape = (32, 32, 3)
classes = 100
epochs = 100
dataset = "cifar-100"
batch_size = 256
path = "D:/"


(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()


def _scale_to_unit_range(images, labels):
    """Cast a batch of uint8 images to float32 in [0, 1]; labels pass through unchanged."""
    return tf.cast(images, tf.float32) / 255.0, labels


# Labels are float32 one-hot vectors with an explicit class count. This matches
# what flow_from_directory(class_mode="categorical") yields on the Imagenette
# path, and is the format categorical_crossentropy expects.
#
# Unlike ImageDataGenerator, tf.data does not shuffle by default, so the
# training examples are reshuffled every epoch. Shuffling happens on the uint8
# data, and rescaling happens after batching, so no float copy of the whole
# dataset is ever materialised.
train_generator = (
    tf.data.Dataset.from_tensor_slices(
        (x_train, tf.keras.utils.to_categorical(y_train, num_classes=classes)))
    .shuffle(buffer_size=len(x_train), reshuffle_each_iteration=True)
    .batch(batch_size)
    .map(_scale_to_unit_range)
)

# Validation data: same scaling, kept in a fixed order.
test_generator = (
    tf.data.Dataset.from_tensor_slices(
        (x_test, tf.keras.utils.to_categorical(y_test, num_classes=classes)))
    .batch(batch_size)
    .map(_scale_to_unit_range)
)

## Vanilla ResNet50v2

### Training

In [None]:
# Set to True to reload the model written by this section's save cell
# (models/<dataset>/<model name>) instead of building a fresh, randomly
# initialised ResNet50V2 (weights=None).
LOAD_SAVED_RESNET50 = False

if LOAD_SAVED_RESNET50:
    # NOTE(review): assumes the Keras default model name "resnet50v2", which the
    # save cell uses as the directory name — confirm against resnet50.name.
    resnet50 = tf.keras.models.load_model("models/" + dataset + "/resnet50v2")
else:
    resnet50 = tf.keras.applications.ResNet50V2(input_shape=input_shape, classes=classes, weights=None)

resnet50.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

In [None]:
# Each run logs to its own timestamped directory: logs/<dataset>/<model name>_<time>.
resnet50_tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir="logs/" + dataset + "/" + resnet50.name + "_" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
    histogram_freq=1)  # weight histograms every epoch

# Keep the History object so per-epoch loss/accuracy can be compared or plotted
# later, instead of only echoing its repr as the cell output.
resnet50_history = resnet50.fit(
    train_generator,
    epochs=epochs,
    validation_data=test_generator,
    callbacks=[resnet50_tensorboard_callback])

In [None]:
# Persist the trained model under models/<dataset>/<model name>.
resnet50.save("/".join(["models", dataset, resnet50.name]))

### Attention

In [None]:
# Class-activation visualisation taken from the last conv layer of the backbone.
# NOTE(review): ResNet50V2 downsamples by 32, so for the 32x32 CIFAR-100 input
# this layer's feature map is likely 1x1 and carries no spatial detail — confirm.
Common.display_attention_batch(
    resnet50,
    test_generator,
    CAM_layer="conv5_block3_3_conv",
)

### LIME

In [None]:
# LIME explanations for test-set images.
Common.display_lime_batch(
    resnet50,
    test_generator,
)

## L2PA

### Training

In [None]:
# TODO: add an option to reload a previously saved L2PA model (see the save
# cell below) instead of always building a new one.
L2PA_model = AttentionModels.create_L2PA_ResNet50v2(
    input_shape=input_shape,
    classes=classes,
)

In [None]:
# Each run logs to its own timestamped directory: logs/<dataset>/<model name>_<time>.
L2PA_tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir="logs/" + dataset + "/" + L2PA_model.name + "_" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
    histogram_freq=1)  # weight histograms every epoch

# Keep the History object so per-epoch loss/accuracy can be compared or plotted
# later, instead of only echoing its repr as the cell output.
L2PA_history = L2PA_model.fit(
    train_generator,
    epochs=epochs,
    validation_data=test_generator,
    callbacks=[L2PA_tensorboard_callback])

In [None]:
# Persist the trained model under models/<dataset>/<model name>.
L2PA_model.save("/".join(["models", dataset, L2PA_model.name]))

### Attention

In [None]:
# Attention visualisation for test-set images; use_attention=True presumably
# reads the model's own attention maps rather than a CAM layer — see VCWA.Common.
Common.display_attention_batch(
    L2PA_model,
    test_generator,
    use_attention=True,
)

### LIME

In [None]:
# LIME explanations for test-set images.
Common.display_lime_batch(
    L2PA_model,
    test_generator,
)

## Attention Gated

### Training

In [None]:
# Set to True to reload the model written by this section's save cell
# (models/<dataset>/<model name>) instead of building a new one.
LOAD_SAVED_GATED = False

if LOAD_SAVED_GATED:
    # NOTE(review): directory name assumes gated_model.name is
    # "AttentionGated_ResNet50v2" (from the earlier hard-coded path) — confirm.
    gated_model = tf.keras.models.load_model("models/" + dataset + "/AttentionGated_ResNet50v2")
    gated_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
else:
    gated_model = AttentionModels.create_AttentionGated_ResNet50v2(input_shape=input_shape, classes=classes)

In [None]:
# Each run logs to its own timestamped directory: logs/<dataset>/<model name>_<time>.
gated_tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir="logs/" + dataset + "/" + gated_model.name + "_" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
    histogram_freq=1)  # weight histograms every epoch

# Keep the History object so per-epoch loss/accuracy can be compared or plotted
# later, instead of only echoing its repr as the cell output.
gated_history = gated_model.fit(
    train_generator,
    epochs=epochs,
    validation_data=test_generator,
    callbacks=[gated_tensorboard_callback])

In [None]:
# Persist the trained model under models/<dataset>/<model name>.
gated_model.save("/".join(["models", dataset, gated_model.name]))

### Attention

In [None]:
# Attention visualisation for test-set images; use_attention=True presumably
# reads the model's own attention maps rather than a CAM layer — see VCWA.Common.
Common.display_attention_batch(
    gated_model,
    test_generator,
    use_attention=True,
)

### LIME

In [None]:
# LIME explanations for test-set images.
Common.display_lime_batch(
    gated_model,
    test_generator,
)

## Attention Gated with Grid Attention

### Training

In [None]:
# Set to True to reload the model written by this section's save cell
# (models/<dataset>/<model name>) instead of building a new one.
LOAD_SAVED_GATEDGRID = False

if LOAD_SAVED_GATEDGRID:
    # NOTE(review): directory name assumes gatedgrid_model.name is
    # "AttentionGatedGrid_ResNet50v2" (from the earlier hard-coded path) — confirm.
    gatedgrid_model = tf.keras.models.load_model("models/" + dataset + "/AttentionGatedGrid_ResNet50v2")
    gatedgrid_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
else:
    gatedgrid_model = AttentionModels.create_AttentionGatedGrid_ResNet50v2(input_shape=input_shape, classes=classes)

In [None]:
# Each run logs to its own timestamped directory: logs/<dataset>/<model name>_<time>.
gatedgrid_tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir="logs/" + dataset + "/" + gatedgrid_model.name + "_" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
    histogram_freq=1)  # weight histograms every epoch

# Keep the History object so per-epoch loss/accuracy can be compared or plotted
# later, instead of only echoing its repr as the cell output.
gatedgrid_history = gatedgrid_model.fit(
    train_generator,
    epochs=epochs,
    validation_data=test_generator,
    callbacks=[gatedgrid_tensorboard_callback])

In [None]:
# Persist the trained model under models/<dataset>/<model name>.
gatedgrid_model.save("/".join(["models", dataset, gatedgrid_model.name]))

### Attention

In [None]:
# Attention visualisation for test-set images; use_attention=True presumably
# reads the model's own attention maps rather than a CAM layer — see VCWA.Common.
Common.display_attention_batch(
    gatedgrid_model,
    test_generator,
    use_attention=True,
)

### LIME

In [None]:
# LIME explanations for test-set images.
Common.display_lime_batch(
    gatedgrid_model,
    test_generator,
)

## Residual Attention Network

### Training

In [None]:
# Set to True to reload the model written by this section's save cell
# (models/<dataset>/<model name>) instead of building a new one.
# (The commented-out compile call this replaces was also missing its closing paren.)
LOAD_SAVED_RESIDUAL_ATTENTION = False

if LOAD_SAVED_RESIDUAL_ATTENTION:
    # NOTE(review): directory name assumes residual_attention_model.name is
    # "ResidualAttentionNet50v2" (from the earlier hard-coded path) — confirm.
    residual_attention_model = tf.keras.models.load_model("models/" + dataset + "/ResidualAttentionNet50v2")
    residual_attention_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
else:
    residual_attention_model = AttentionModels.create_ResidualAttention_ResNet50v2(input_shape=input_shape, classes=classes)

In [None]:
# Each run logs to its own timestamped directory: logs/<dataset>/<model name>_<time>.
residualattention_callback = tf.keras.callbacks.TensorBoard(
    log_dir="logs/" + dataset + "/" + residual_attention_model.name + "_" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
    histogram_freq=1)  # weight histograms every epoch

# Keep the History object so per-epoch loss/accuracy can be compared or plotted
# later, instead of only echoing its repr as the cell output.
residual_attention_history = residual_attention_model.fit(
    train_generator,
    epochs=epochs,
    validation_data=test_generator,
    callbacks=[residualattention_callback])

In [None]:
# Persist the trained model under models/<dataset>/<model name>.
residual_attention_model.save("/".join(["models", dataset, residual_attention_model.name]))

### Attention

In [None]:
# Attention visualisation for test-set images; use_attention=True presumably
# reads the model's own attention maps rather than a CAM layer — see VCWA.Common.
Common.display_attention_batch(
    residual_attention_model,
    test_generator,
    use_attention=True,
)

### LIME

In [None]:
# LIME explanations for test-set images.
Common.display_lime_batch(
    residual_attention_model,
    test_generator,
)

## CBAM

### Training

In [None]:
# Set to True to reload the model written by this section's save cell
# (models/<dataset>/<model name>) instead of building a new one.
LOAD_SAVED_CBAM = False

if LOAD_SAVED_CBAM:
    # NOTE(review): directory name assumes CBAM_model.name is
    # "CBAM_ResNet50v2" (from the earlier hard-coded path) — confirm.
    CBAM_model = tf.keras.models.load_model("models/" + dataset + "/CBAM_ResNet50v2")
    CBAM_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
else:
    CBAM_model = AttentionModels.create_CBAM_ResNet50v2(input_shape=input_shape, classes=classes)

In [None]:
# Each run logs to its own timestamped directory: logs/<dataset>/<model name>_<time>.
cbam_callback = tf.keras.callbacks.TensorBoard(
    log_dir="logs/" + dataset + "/" + CBAM_model.name + "_" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
    histogram_freq=1)  # weight histograms every epoch

# Keep the History object so per-epoch loss/accuracy can be compared or plotted
# later, instead of only echoing its repr as the cell output.
cbam_history = CBAM_model.fit(
    train_generator,
    epochs=epochs,
    validation_data=test_generator,
    callbacks=[cbam_callback])

In [None]:
# Persist the trained model under models/<dataset>/<model name>.
CBAM_model.save("/".join(["models", dataset, CBAM_model.name]))

### Attention

In [None]:
# Class-activation visualisation taken from the backbone's last conv layer
# (rather than use_attention=True as in the other attention models).
# NOTE(review): with the 32x32 CIFAR-100 input this layer is likely spatially
# 1x1 (ResNet50V2 downsamples by 32), so the map carries no spatial detail — confirm.
Common.display_attention_batch(
    CBAM_model,
    test_generator,
    CAM_layer="conv5_block3_3_conv",
)

### LIME

In [None]:
# LIME explanations for test-set images.
Common.display_lime_batch(
    CBAM_model,
    test_generator,
)