# Image classification with CNN+Attention Hybrid Model

**Author:** RIYAJ ATAR


In [1]:
!pip install -U tensorflow-addons

Collecting tensorflow-addons
  Downloading tensorflow_addons-0.13.0-cp37-cp37m-manylinux2010_x86_64.whl (679 kB)
[?25l[K     |▌                               | 10 kB 34.5 MB/s eta 0:00:01[K     |█                               | 20 kB 25.5 MB/s eta 0:00:01[K     |█▌                              | 30 kB 18.6 MB/s eta 0:00:01[K     |██                              | 40 kB 16.2 MB/s eta 0:00:01[K     |██▍                             | 51 kB 7.0 MB/s eta 0:00:01[K     |███                             | 61 kB 8.2 MB/s eta 0:00:01[K     |███▍                            | 71 kB 7.9 MB/s eta 0:00:01[K     |███▉                            | 81 kB 8.9 MB/s eta 0:00:01[K     |████▍                           | 92 kB 9.4 MB/s eta 0:00:01[K     |████▉                           | 102 kB 7.0 MB/s eta 0:00:01[K     |█████▎                          | 112 kB 7.0 MB/s eta 0:00:01[K     |█████▉                          | 122 kB 7.0 MB/s eta 0:00:01[K     |██████▎                 

## Setup

In [2]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_addons as tfa
import tensorflow as tf
from tensorflow.keras import layers

## Prepare the data

In [3]:
# Dataset constants for CIFAR-100: number of label classes and the shape of a
# single raw image (matches the 32x32x3 arrays printed below).
num_classes = 100
input_shape = (32, 32, 3)
# Side length that the augmentation pipeline's Resizing layer scales images to.
image_size = 72
# Load the train/test splits as (images, labels) pairs.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

# Sanity-check the array shapes of both splits.
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 1)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 1)


## Attention applied on Feature Maps of CNN output

In [4]:

def attention_weight(f, num_heads=8):
  """Re-weight a CNN feature map with self-attention over its spatial positions.

  The spatial grid of `f` is flattened into a sequence with one token per
  position, multi-head self-attention is run over that sequence, and the
  attention output is folded back onto the grid and multiplied element-wise
  with the original feature map.

  Args:
    f: Feature map indexed as (batch, height, width, channels). The height,
      width and channel sizes are read from `f.shape[1:4]` and passed to
      `Reshape`, so they must be statically known.
    num_heads: Number of attention heads.

  Returns:
    The element-wise product of `f` and the reshaped attention output.
  """
  height, width, channels = f.shape[1], f.shape[2], f.shape[3]

  # One token per spatial position: (batch, height * width, channels).
  tokens = layers.Reshape([height * width, channels])(f)

  # Self-attention: the token sequence serves as both query and value.
  attention_output = layers.MultiHeadAttention(
      num_heads=num_heads, key_dim=channels, dropout=0.1
  )(tokens, tokens)

  # Fold the *attention output* back onto the grid. The previous version
  # reshaped the untouched tokens instead, which reduced this function to
  # `f * f` and left the attention layer disconnected from the result.
  weights = layers.Reshape([height, width, channels])(attention_output)

  return layers.Multiply()([f, weights])

## Configure the hyperparameters

In [5]:
# Optimizer settings: passed to the AdamW optimizer in run_experiment.
learning_rate = 0.001
weight_decay = 0.0001
# Training-loop settings: passed to model.fit in run_experiment.
batch_size = 256
num_epochs = 100



## Use data augmentation

In [6]:
# Preprocessing and augmentation stack, assembled one layer at a time.
data_augmentation = keras.Sequential(name="data_augmentation")
# Normalization comes first so it can be reached as `layers[0]` for `adapt`.
data_augmentation.add(layers.experimental.preprocessing.Normalization())
# Upscale every image to image_size x image_size.
data_augmentation.add(
    layers.experimental.preprocessing.Resizing(image_size, image_size)
)
# Random perturbations: horizontal flip, small rotation, zoom.
data_augmentation.add(layers.experimental.preprocessing.RandomFlip("horizontal"))
data_augmentation.add(layers.experimental.preprocessing.RandomRotation(factor=0.02))
data_augmentation.add(
    layers.experimental.preprocessing.RandomZoom(height_factor=0.2, width_factor=0.2)
)

# Fit the normalization layer's mean and variance on the training images.
data_augmentation.layers[0].adapt(x_train)


## Define the Keras model

In [7]:


def _conv_stage(f, filters):
  """Apply one downsampling stage: 3x3 conv, stride-2 3x3 conv, batch norm."""
  f = layers.Conv2D(filters, 3, 1, 'same', activation='relu')(f)
  f = layers.Conv2D(filters, 3, 2, 'same', activation='relu')(f)
  return layers.BatchNormalization()(f)


def model_cifar100():
  """Build the CNN + attention hybrid classifier.

  The network feeds the input through the `data_augmentation` pipeline, then
  through four convolutional stages with 16, 32, 64 and 128 filters. The
  outputs of the third and fourth stages are re-weighted by
  `attention_weight`. The result is flattened, passed through dropout, and
  projected to one logit per class.

  The input shape and the number of output units come from the notebook's
  `input_shape` and `num_classes` settings instead of being repeated here as
  literals.

  Returns:
    An uncompiled `keras.Model` mapping raw images to class logits (the final
    Dense layer has no activation).
  """
  x = layers.Input(shape=input_shape)

  # Normalization, resizing and random augmentation live inside the model.
  inputs = data_augmentation(x)

  f = _conv_stage(inputs, 16)
  f = _conv_stage(f, 32)
  f = _conv_stage(f, 64)
  f = attention_weight(f, num_heads=8)

  f = _conv_stage(f, 128)
  f = attention_weight(f, num_heads=16)

  representation = layers.Flatten()(f)
  representation = layers.Dropout(0.5)(representation)

  logits = layers.Dense(num_classes)(representation)
  return keras.Model(inputs=x, outputs=logits)

# Build the network and print its layer-by-layer summary.
model = model_cifar100()
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 32, 32, 3)]  0                                            
__________________________________________________________________________________________________
data_augmentation (Sequential)  (None, 72, 72, 3)    7           input_1[0][0]                    
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 72, 72, 16)   448         data_augmentation[0][0]          
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 36, 36, 16)   2320        conv2d[0][0]                     
______________________________________________________________________________________________

## Compile, train, and evaluate the model

In [8]:

def run_experiment(model):
    """Compile and train `model`, then report its test-set metrics.

    The model is compiled with AdamW and a sparse categorical cross-entropy
    loss on logits, and trained on `x_train`/`y_train` with 10% of the data
    held out for validation. Whenever validation accuracy improves, the
    weights are saved to a checkpoint. After training, the checkpointed
    weights are loaded back and the model is evaluated on `x_test`/`y_test`.

    Args:
        model: The Keras model to train; it is compiled and fitted in place.

    Returns:
        The history object returned by `model.fit`.
    """
    best_weights_path = "/tmp/checkpoint"

    model.compile(
        optimizer=tfa.optimizers.AdamW(
            learning_rate=learning_rate, weight_decay=weight_decay
        ),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    # Keep only the weights from the epoch with the best validation accuracy.
    keep_best = keras.callbacks.ModelCheckpoint(
        best_weights_path,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    training_log = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[keep_best],
    )

    # Evaluate the checkpointed weights rather than the last epoch's.
    model.load_weights(best_weights_path)
    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    return training_log


# Train and evaluate the model; keep the fit history for later inspection.
history = run_experiment(model)


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78