# **Multi-Head Attention for Image Classification**


The primary goal of this notebook is to demonstrate the use of Multi-Head Attention within a CNN for image classification. By combining convolutional layers, which extract local features, with an attention mechanism that relates those features to each other, the model aims to classify handwritten digits accurately. Specifically, the notebook aims to:

  * **Implement Multi-Head Attention**: The notebook provides a step-by-step implementation of Multi-Head Attention using TensorFlow and Keras.
  * **Build a CNN with Attention**: Integrate the Multi-Head Attention block into a CNN architecture for image classification.
  * **Train and Evaluate the Model**: Train the model on the MNIST dataset and evaluate its classification accuracy on the test set.
  * **Visualize Predictions**: Display predictions on sample images to showcase the model's classification capabilities.

By working through these steps, the notebook illustrates how Multi-Head Attention can be applied to an image classification task.

## **Multi-Head Attention**

<img src= 'multihead_att.png' width=500>

First, we define our own Multi-Head Attention block. `MultiHeadsAttModel` returns a function that, given query, key and value tensors, performs the following steps:

**1. Input Processing:**

  * The function takes three inputs: q1, k1, and v1, representing the query, key, and value tensors, respectively.
  * Each input is passed through its own dense layer with ReLU activation, which projects it to dv * nv features (enough for nv heads of dv features each). The outputs are named q2, k2, and v2.
  
**2. Reshaping for Multi-Head Attention:**

  * The projected tensors (q2, k2, v2) are reshaped with `tf.keras.layers.Reshape` to create separate heads.
  * The reshape splits the last dimension into nv heads (the number of heads) of dv features each.
  * The resulting tensors q, k, and v therefore hold, for each of the l sequence positions, nv head vectors of dimension dv.

**3. Scaled Dot-Product Attention:**

  * The core attention computation is wrapped in `tf.keras.layers.Lambda` layers so that plain TensorFlow ops can be used inside a functional Keras model.
  * Inside the first Lambda, the query tensor (q) is multiplied by the transposed key tensor (k) with `tf.matmul(..., transpose_b=True)` to obtain raw attention scores.
  * The scores are divided by the square root of the head dimension (dv). Without this scaling, dot products grow with dv and push the softmax toward saturation, where gradients become very small.
  * A softmax over the last axis turns the scaled scores into attention weights that sum to 1.

**4. Weighted Value Aggregation:**

  * A second Lambda layer multiplies the attention weights with the value tensor (v) using `tf.matmul`.
  * The result, a weighted sum of value vectors for each head, is the output of the attention mechanism.

**5. Output Reshaping and Residual Connection:**

  * The heads are merged back into a single tensor of l positions with d features each using `tf.keras.layers.Reshape`.
  * The original query input (q1) is then added to this output. This residual connection helps stabilize training and lets information (and gradients) bypass the attention computation. It requires d to equal dv * nv so that the two tensors have matching shapes.

**6. Output Projection and Normalization:**

  * The sum is passed through a final dense layer with ReLU activation that projects each position to dout features; this is the value the function returns.
  * Layer normalization is not applied inside `MultiHeadsAttModel`. Instead, the model definition applies the `NormL` layer (described below) to the attention block's output. Layer Normalization rescales each feature vector to zero mean and unit variance across its features and then applies a learned per-feature scale and shift, which tends to make training faster and more stable.
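For reference, the steps above follow the standard scaled dot-product formulation of multi-head attention. Written for a single example, with $Q_i$, $K_i$ and $V_i$ denoting the $l \times d_v$ slices of q2, k2 and v2 that belong to head $i$ (notation introduced here, not taken from the code), it reads:

$$
\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{Q_i K_i^{\top}}{\sqrt{d_v}}\right) V_i \in \mathbb{R}^{l \times d_v}, \qquad i = 1, \dots, n_v
$$

$$
\mathrm{out} = \mathrm{ReLU}\!\Big(\big(\left[\mathrm{head}_1, \dots, \mathrm{head}_{n_v}\right] + q_1\big) W_{\mathrm{out}} + b_{\mathrm{out}}\Big)
$$

The softmax is taken row by row, so each of the $l$ query positions receives a probability distribution over the $l$ key positions. Concatenating the heads gives $n_v d_v$ features per position, which must equal $d$ for the residual addition with $q_1$ to be well defined.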



In [None]:
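import tensorflow as tf
import numpy as np

def MultiHeadsAttModel(l=8*8, d=512, dv=64, dout=512, nv=8):
    """Multi-head scaled dot-product attention over l positions.

    q1, k1 and v1 hold l vectors of size d per example; d must equal dv * nv
    so that the residual connection can be added.
    """
    def model(inputs):
        q1, k1, v1 = inputs

        # Project each input to nv heads of dv features (dv * nv features in total)
        v2 = tf.keras.layers.Dense(dv*nv, activation="relu")(v1)
        q2 = tf.keras.layers.Dense(dv*nv, activation="relu")(q1)
        k2 = tf.keras.layers.Dense(dv*nv, activation="relu")(k1)

        # Split the features into heads, then move the head axis in front of the
        # position axis, (batch, l, nv, dv) -> (batch, nv, l, dv), so that each
        # head attends over the l positions
        v = tf.keras.layers.Permute((2, 1, 3))(tf.keras.layers.Reshape([l, nv, dv])(v2))
        q = tf.keras.layers.Permute((2, 1, 3))(tf.keras.layers.Reshape([l, nv, dv])(q2))
        k = tf.keras.layers.Permute((2, 1, 3))(tf.keras.layers.Reshape([l, nv, dv])(k2))

        # Wrap TensorFlow operations within Lambda layers.
        # Scaled scores between every pair of positions, per head: (batch, nv, l, l)
        att = tf.keras.layers.Lambda(
            lambda x: tf.matmul(x[0], x[1], transpose_b=True) / tf.sqrt(tf.cast(dv, tf.float32)))([q, k])
        att = tf.keras.layers.Lambda(lambda x: tf.nn.softmax(x, axis=-1))(att)

        # Weighted sum of values (batch, nv, l, dv), then merge heads into (batch, l, d)
        out = tf.keras.layers.Lambda(lambda x: tf.matmul(x[0], x[1]))([att, v])
        out = tf.keras.layers.Permute((2, 1, 3))(out)
        out = tf.keras.layers.Reshape([l, nv*dv])(out)

        # Residual connection (requires d == dv * nv); q1 is reshaped to (batch, l, d)
        # so that the shapes also match if the caller passes an extra leading axis
        out = tf.keras.layers.Add()([out, tf.keras.layers.Reshape([l, d])(q1)])

        # Final projection to dout features per position
        out = tf.keras.layers.Dense(dout, activation="relu")(out)

        return out

    return model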
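To check the shapes and the attention computation independently of Keras, here is a NumPy sketch of standard multi-head self-attention for a single example, in which each head attends over the l positions. It is a simplified illustration, not the notebook's implementation: it leaves out the ReLU activations, biases, residual connection and output Dense layer, and the helpers `softmax` and `multi_head_self_attention` are hypothetical names introduced here.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(x, w_q, w_k, w_v, nv):
    """Hypothetical helper: plain multi-head self-attention for one example.

    x: (l, d) sequence; w_q, w_k, w_v: (d, nv * dv) projection matrices.
    Returns the merged head outputs (l, nv * dv) and the weights (nv, l, l).
    """
    l = x.shape[0]
    dv = w_q.shape[1] // nv

    def split_heads(t):
        # (l, nv * dv) -> (l, nv, dv) -> (nv, l, dv)
        return t.reshape(l, nv, dv).transpose(1, 0, 2)

    q, k, v = split_heads(x @ w_q), split_heads(x @ w_k), split_heads(x @ w_v)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(dv)  # (nv, l, l)
    weights = softmax(scores, axis=-1)                # each row sums to 1
    heads = weights @ v                               # (nv, l, dv)
    merged = heads.transpose(1, 0, 2).reshape(l, nv * dv)
    return merged, weights

rng = np.random.default_rng(0)
l, d, nv, dv = 36, 192, 8, 24  # the sizes passed to MultiHeadsAttModel in the model below
x = rng.standard_normal((l, d))
w_q, w_k, w_v = (rng.standard_normal((d, nv * dv)) * 0.05 for _ in range(3))
out, weights = multi_head_self_attention(x, w_q, w_k, w_v, nv)
print(out.shape, weights.shape)  # (36, 192) (8, 36, 36)
```

Moving the head axis in front of the position axis before the matrix products is what makes each head produce an l x l weight matrix, i.e. attention between positions rather than between heads.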


### **Layer Normalization**

The `NormL` class implements Layer Normalization, which normalizes each feature vector across its features (the last axis), separately for every example and position, rather than across the batch. This normalization helps stabilize and speed up the training of neural networks.


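**1. Initialization:**

  * The `__init__` method sets up the base Keras layer. The two trainable weights, a (scale) and b (shift), are not created here but in `build`, once the input size is known.

**2. Building the Layer:**

  * Keras calls `build` automatically the first time the layer is applied to an input. It creates a and b with shape (1, number of features), where the number of features is the size of the input's last axis. a starts at ones and b at zeros, so initially the layer performs plain normalization.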

**3. Applying Normalization:**

  * The call method is where the actual normalization happens.
  * It calculates the mean (`mu`) and standard deviation (`sigma`) of the input data (`x`) along the last dimension (axis=-1).
  * It then normalizes the data by subtracting the mean and dividing by the standard deviation. A small value (`eps`) is added to the standard deviation to avoid division by zero.
  * Finally, the normalized data is scaled by a and shifted by b before being returned.

In [None]:
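class NormL(tf.keras.layers.Layer):
    """Layer normalization over the last axis with a learned scale (a) and shift (b)."""

    def __init__(self, eps=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps  # added to the standard deviation to avoid division by zero

    def build(self, input_shape):
        # One scale and one shift per feature (last axis), broadcast over all other axes
        self.a = self.add_weight(name='a', shape=(1, input_shape[-1]),
                                 initializer='ones', trainable=True)
        self.b = self.add_weight(name='b', shape=(1, input_shape[-1]),
                                 initializer='zeros', trainable=True)

    def call(self, x):
        # Mean and standard deviation of each feature vector
        mu = tf.reduce_mean(x, axis=-1, keepdims=True)
        sigma = tf.math.reduce_std(x, axis=-1, keepdims=True)
        ln_out = (x - mu) / (sigma + self.eps)
        return ln_out * self.a + self.b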
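As a quick sanity check of the formula used in `NormL.call`, the NumPy version below applies the same steps (mean and population standard deviation over the last axis, eps added to the standard deviation, then scale a and shift b) to random data. With a at ones and b at zeros, as initialized in `build`, every normalized vector should have a mean close to 0 and a standard deviation close to 1. The `layer_norm` function is a hypothetical helper for this illustration only.

```python
import numpy as np

def layer_norm(x, a, b, eps=1e-6):
    # Normalize each vector over its last axis, mirroring NormL.call
    mu = x.mean(axis=-1, keepdims=True)
    sigma = x.std(axis=-1, keepdims=True)  # population std (ddof=0), like tf.math.reduce_std
    return (x - mu) / (sigma + eps) * a + b

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=5.0, size=(4, 36, 32))  # (batch, positions, features)
a = np.ones((1, 32))   # scale, initialized like NormL.a
b = np.zeros((1, 32))  # shift, initialized like NormL.b
y = layer_norm(x, a, b)

print(np.abs(y.mean(axis=-1)).max() < 1e-6)     # True: zero mean per vector
print(np.abs(y.std(axis=-1) - 1).max() < 1e-4)  # True: unit std per vector
```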


## **Data Preparation**

We load the MNIST dataset, add a channel dimension to the images, scale the pixel values from [0, 255] to [0, 1], and one-hot encode the labels, which is the format the `categorical_crossentropy` loss used later expects. This preprocessing puts the data in the shape and value range the model is built for.

In [None]:
# Load and preprocess data
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

# Add a channel axis (28x28 -> 28x28x1) and scale pixel values from [0, 255] to [0, 1]
X_train = X_train.reshape(60000, 28, 28, 1).astype('float32') / 255
X_test = X_test.reshape(10000, 28, 28, 1).astype('float32') / 255

# One-hot encode the digit labels (0-9) for categorical_crossentropy
Y_train = tf.keras.utils.to_categorical(y_train, 10)
Y_test = tf.keras.utils.to_categorical(y_test, 10)
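To make the two preprocessing steps concrete, here is the same transformation in plain NumPy on two made-up images (not real MNIST data). Indexing rows of an identity matrix with `np.eye` is used here as an illustrative stand-in for `tf.keras.utils.to_categorical`; it produces the same one-hot rows but is not how Keras implements the function.

```python
import numpy as np

# Stand-in data: two fake 28x28 uint8 "images" and their labels (not real MNIST)
images = np.array([np.full((28, 28), 255, dtype=np.uint8),
                   np.zeros((28, 28), dtype=np.uint8)])
labels = np.array([3, 7])

# Add a channel axis, cast to float32 and scale to [0, 1]
x = images.reshape(-1, 28, 28, 1).astype('float32') / 255
# One-hot encoding: row i of the 10x10 identity matrix encodes class i
y = np.eye(10, dtype='float32')[labels]

print(x.shape, x.min(), x.max())  # (2, 28, 28, 1) 0.0 1.0
print(y[0])                       # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
```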


## **Model Building**

We define a Convolutional Neural Network (CNN) with Multi-Head Attention for classifying MNIST handwritten digits. The model starts with an input layer shaped for 28x28 grayscale MNIST images. Convolutional layers extract features and max pooling reduces the spatial resolution, leaving a 6x6 feature map with 192 channels. This map is reshaped into a sequence of 36 positions and passed through the Multi-Head Attention block, which models relationships among the extracted features. The result is reshaped back to a 6x6 grid, layer-normalized, flattened, and fed through two dense layers; the final softmax layer outputs a probability distribution over the 10 digit classes (0-9).

**In simpler terms:**

Imagine the model as a series of steps:

* **See:** It takes an image of a handwritten digit as input.
* **Extract features:** It uses convolutional layers to identify important features like edges and curves.
* **Focus:** Multi-Head Attention helps the model focus on the most relevant features for classification.
* **Process:** It flattens the features and uses dense layers to analyze them.
* **Classify:** Finally, it outputs a probability distribution, predicting the digit in the image.

This model architecture leverages the strengths of both convolutional and attention mechanisms for effective digit classification.

In [None]:

# Define the model
inputs = tf.keras.Input(shape=(28, 28, 1))
x = tf.keras.layers.Conv2D(32, (2, 2), activation='relu', padding='same')(inputs)  # 28x28x32
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)                              # 14x14x32
x = tf.keras.layers.Conv2D(64, (2, 2), activation='relu')(x)                       # 13x13x64
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same')(x)              # 7x7x64
x = tf.keras.layers.Conv2D(64*3, (2, 2), activation='relu')(x)                     # 6x6x192

# Flatten the 6x6 grid into a sequence of 36 positions with 192 features each
# (the -1 resolves to an extra axis of size 1)
x = tf.keras.layers.Reshape([-1, 6*6, 64*3])(x)
# Self-attention: query, key and value are the same sequence; d = dv * nv = 192
att = MultiHeadsAttModel(l=6*6, d=64*3, dv=8*3, dout=32, nv=8)
x = att([x, x, x])
# Restore the 6x6 spatial layout with 32 features per position
x = tf.keras.layers.Reshape([-1, 6, 6, 32])(x)
x = NormL()(x)

x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(256, activation='relu')(x)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)  # class probabilities

model = tf.keras.Model(inputs=inputs, outputs=outputs)
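The hard-coded `6*6` in the Reshape layers comes from the spatial size left by the convolution and pooling stack. The helper below (a hypothetical function written for this illustration) recomputes it with the usual Keras output-size rules: 'same' padding gives ceil(size / stride), 'valid' padding gives floor((size - kernel) / stride) + 1, and `MaxPooling2D` uses a stride equal to its pool size by default.

```python
import math

def conv_out(size, kernel, stride=1, padding='valid'):
    # 'same': ceil(size / stride); 'valid': floor((size - kernel) / stride) + 1
    if padding == 'same':
        return math.ceil(size / stride)
    return (size - kernel) // stride + 1

size = 28
size = conv_out(size, 2, padding='same')            # Conv2D(32, (2, 2), padding='same') -> 28
size = conv_out(size, 2, stride=2)                  # MaxPooling2D((2, 2)) -> 14
size = conv_out(size, 2)                            # Conv2D(64, (2, 2)) -> 13
size = conv_out(size, 2, stride=2, padding='same')  # MaxPooling2D((2, 2), padding='same') -> 7
size = conv_out(size, 2)                            # Conv2D(64*3, (2, 2)) -> 6
print(size, size * size)  # 6 36
```

In the notebook itself, `model.summary()` reports the same shapes. If any kernel size or padding changes, the `6*6` (and `l=6*6`) must change with it; likewise the residual connection needs d = dv * nv, which holds here because 64*3 = 192 = (8*3) * 8.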


## **Compiling & Training**

In [None]:
# Compile the model
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# Train the model

model.fit(X_train, Y_train,
          batch_size=128,
          epochs=5,
          verbose=1,
          validation_data=(X_test, Y_test))


## **Inference**

In [None]:
import matplotlib.pyplot as plt

# Make predictions
y_pred = model.predict(X_test)
y_pred_classes = np.argmax(y_pred, axis=1)
y_true = np.argmax(Y_test, axis=1)

# Plot some images with actual and predicted values
num_images_to_plot = 5
plt.figure(figsize=(10, 5))
for i in range(num_images_to_plot):
    plt.subplot(1, num_images_to_plot, i+1)
    plt.imshow(X_test[i].reshape(28, 28), cmap='gray')
    plt.title(f"True: {y_true[i]}\nPredicted: {y_pred_classes[i]}")
    plt.axis('off')
plt.show()
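The cell above compares predictions with labels for five images only. To summarize the whole test set, the predicted and true class arrays can be turned into an accuracy figure and a confusion matrix. The sketch below is a minimal NumPy version; the `confusion_matrix` helper is hypothetical and is demonstrated on toy labels rather than model output.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes=10):
    # cm[i, j] counts examples whose true class is i and predicted class is j
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)
    return cm

# Toy labels for illustration only; in the notebook, pass y_true and y_pred_classes
t = np.array([0, 1, 2, 2, 3])
p = np.array([0, 1, 2, 3, 3])
cm = confusion_matrix(t, p, num_classes=4)
accuracy = np.trace(cm) / cm.sum()  # correct predictions lie on the diagonal
print(accuracy)  # 0.8
print(cm[2])     # [0 0 1 1]
```

Applied to the notebook's arrays, `confusion_matrix(y_true, y_pred_classes)` would give a 10x10 matrix whose off-diagonal entries show which digits are mistaken for which; its trace divided by its sum equals `np.mean(y_true == y_pred_classes)`.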