# Vision Transformers (ViT)

## Brief Recap of Vision Transformers

- Vision Transformers (ViT) apply the Transformer architecture, originally developed for NLP, to computer vision tasks by treating an image as a sequence of patches.

- Introduced by Dosovitskiy et al. in 2020 in their paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", ViT matched or exceeded strong CNN baselines on image classification when pre-trained on large datasets.

Vision Transformers work by:
1. Splitting an image into fixed-size patches (typically 16x16 pixels)
2. Linearly embedding these patches
3. Adding positional embeddings
4. Processing the resulting sequence with a standard Transformer encoder
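
As a rough illustration of steps 1 to 3, the sketch below uses plain NumPy rather than a deep learning framework. The `patchify` helper, the 224x224 input size, the embedding width of 64, and the random weights are all assumptions made for the example. In a real model the projection and the positional embeddings are learned, and step 4 (the Transformer encoder) is left out here.

```python
import numpy as np

def patchify(image, patch_size):
    """Split an (H, W, C) image into flattened, non-overlapping patches."""
    h, w, c = image.shape
    gh, gw = h // patch_size, w // patch_size
    # (gh, p, gw, p, c) -> (gh, gw, p, p, c) -> (gh * gw, p * p * c)
    patches = image.reshape(gh, patch_size, gw, patch_size, c)
    patches = patches.transpose(0, 2, 1, 3, 4)
    return patches.reshape(gh * gw, patch_size * patch_size * c)

rng = np.random.default_rng(0)
image = rng.random((224, 224, 3))           # hypothetical input image

# Step 1: split into 16x16 patches
patches = patchify(image, patch_size=16)    # (196, 768)

# Step 2: linear embedding (random weights stand in for learned ones)
embed_dim = 64                              # hypothetical embedding width
w_embed = rng.normal(scale=0.02, size=(patches.shape[1], embed_dim))
tokens = patches @ w_embed                  # (196, 64)

# Step 3: add one positional embedding per patch
pos_embed = rng.normal(scale=0.02, size=tokens.shape)
sequence = tokens + pos_embed               # (196, 64)

print(patches.shape, sequence.shape)
```

With 16x16 patches, a 224x224 RGB image becomes 196 patches of 768 values each, which is the sequence the encoder would then process.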

Key advantages of Vision Transformers include:
- Global receptive field from the start, unlike CNNs which build it up gradually
- Strong performance when pre-trained on large datasets
- Efficient parallel processing of image patches
- Ability to capture long-range dependencies in images
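
To make the "global receptive field" point concrete, here is a minimal single-head self-attention sketch in NumPy. The helper names, the sizes, and the random weights are illustrative assumptions, not any library's API. The point is only that the attention weight matrix has one row per patch and one column per patch, so every patch can draw on every other patch in a single layer.

```python
import numpy as np

def softmax(z, axis=-1):
    """Numerically stable softmax."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over a (N, D) sequence."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])   # (N, N) pairwise similarities
    weights = softmax(scores, axis=-1)        # each row sums to 1
    return weights @ v, weights

rng = np.random.default_rng(0)
num_patches, dim = 64, 16                     # hypothetical sizes
x = rng.normal(size=(num_patches, dim))
w_q, w_k, w_v = (rng.normal(scale=dim ** -0.5, size=(dim, dim)) for _ in range(3))

out, weights = self_attention(x, w_q, w_k, w_v)
print(out.shape, weights.shape)
```

Because all rows of `weights` are computed with matrix products, the patches are also processed in parallel rather than one after another.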

## Architecture of Vision Transformers

The Vision Transformer architecture consists of several key components:

1. **Patch Embedding**: 
   - Divides the input image into non-overlapping patches
   - Flattens each patch and linearly projects it to the model's embedding dimension
   - Adds a learnable classification token at the start of the sequence

2. **Position Embedding**:
   - Adds learnable position embeddings to provide spatial information
   - Lets the model tell where each patch sits in the image, since self-attention on its own ignores token order

3. **Transformer Encoder**:
   - Multiple layers of multi-head self-attention
   - Feed-forward networks
   - Layer normalization and residual connections

4. **Classification Head**:
   - Uses the transformed classification token for final prediction
   - Typically consists of a simple MLP
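
The role of the classification token can be sketched with NumPy shapes alone. Everything below (sizes, zero-initialized token, random weights, a single linear layer as the head) is an assumption for illustration, and the encoder itself is skipped, so the "output" token here is just the input token.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, num_patches, dim, num_classes = 2, 64, 64, 10   # hypothetical sizes

patch_tokens = rng.normal(size=(batch, num_patches, dim))

# One classification token shared by all images (learnable in a real model)
cls_token = np.zeros((1, 1, dim))
cls = np.broadcast_to(cls_token, (batch, 1, dim))

# Prepend the token, then add positional embeddings for all 65 positions
sequence = np.concatenate([cls, patch_tokens], axis=1)          # (2, 65, 64)
pos_embed = rng.normal(scale=0.02, size=(1, num_patches + 1, dim))
sequence = sequence + pos_embed

# ... the Transformer encoder would process `sequence` here ...

# The head reads only position 0, the classification token
cls_out = sequence[:, 0]                                        # (2, 64)
w_head = rng.normal(scale=0.02, size=(dim, num_classes))
logits = cls_out @ w_head                                       # (2, 10)
print(sequence.shape, logits.shape)
```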

## Setting Up the Environment

Let's start by importing the necessary libraries:

In [None]:
import tensorflow as tf
import numpy as np
from tensorflow.keras import layers
import matplotlib.pyplot as plt

In [None]:
# Set random seeds for reproducibility
tf.random.set_seed(42)
np.random.seed(42)

## Implementing Patch Creation and Embedding

First, we'll implement the patch creation and embedding layer:

In [None]:
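class PatchEmbed(layers.Layer):
    """Split an image into non-overlapping patches and embed each one."""

    def __init__(self, img_size=32, patch_size=4, embed_dim=64, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = (img_size // patch_size) ** 2
        # A convolution whose kernel size and stride both equal patch_size
        # applies one linear projection to each non-overlapping patch
        self.proj = layers.Conv2D(
            filters=embed_dim,
            kernel_size=patch_size,
            strides=patch_size,
            padding='VALID'
        )
        # Collapse the patch grid into a sequence of patch embeddings
        self.flatten = layers.Reshape((-1, embed_dim))

    def call(self, x):
        x = self.proj(x)     # (batch, grid, grid, embed_dim)
        x = self.flatten(x)  # (batch, num_patches, embed_dim)
        return x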

### Patch Embedding Layer Explanation

The PatchEmbed class is a crucial component of the Vision Transformer that transforms input images into a sequence of embedded patches.

**Input Processing**
- Takes an image of size `img_size x img_size` (default 32x32)
- Divides it into non-overlapping patches of size `patch_size x patch_size` (default 4x4)
- The total number of patches is `(img_size // patch_size) ** 2`, which is 64 (an 8x8 grid) with the defaults

**Layer Components**

1. **Convolutional Projection**
- Uses a Conv2D layer with:
  - `embed_dim` output filters (default 64)
  - Kernel size equal to patch_size (4x4)
  - Stride equal to patch_size for non-overlapping patches
  - 'VALID' padding to avoid padding artifacts

2. **Flattening Operation**
- Reshapes the convolution output into a sequence of patch embeddings
- Transforms the 4D tensor `(batch_size, grid, grid, embed_dim)` into a 3D tensor `(batch_size, num_patches, embed_dim)`, i.e. one `(num_patches, embed_dim)` sequence per image

**Workflow**
1. Input image (batch_size, height, width, channels)
2. Conv2D projects patches to embedding dimension
3. Reshape flattens spatial dimensions into sequence
4. Output shape: (batch_size, num_patches, embed_dim)

This layer effectively converts a 2D image into a sequence of embedded patches that can be processed by the transformer architecture.
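
A strided convolution whose kernel size equals its stride is equivalent to flattening each patch and applying one shared linear projection. The NumPy check below is a sketch of that equivalence for a single image, using the default sizes from `PatchEmbed`. The random kernel and bias are stand-ins for learned weights, and the kernel layout `(height, width, in_channels, out_channels)` is assumed to match the usual Keras `Conv2D` convention.

```python
import numpy as np

rng = np.random.default_rng(0)
img_size, p, channels, embed_dim = 32, 4, 3, 64
grid = img_size // p

image = rng.random((img_size, img_size, channels))
kernel = rng.normal(size=(p, p, channels, embed_dim))   # stand-in conv kernel
bias = rng.normal(size=(embed_dim,))

# Route 1: strided convolution, one kernel application per patch
conv_out = np.empty((grid, grid, embed_dim))
for i in range(grid):
    for j in range(grid):
        patch = image[i * p:(i + 1) * p, j * p:(j + 1) * p, :]
        conv_out[i, j] = np.tensordot(patch, kernel, axes=3) + bias
conv_seq = conv_out.reshape(grid * grid, embed_dim)      # (64, 64)

# Route 2: flatten every patch, then a single matrix multiplication
patches = (image.reshape(grid, p, grid, p, channels)
                .transpose(0, 2, 1, 3, 4)
                .reshape(grid * grid, p * p * channels))  # (64, 48)
linear_seq = patches @ kernel.reshape(-1, embed_dim) + bias

print(conv_seq.shape, np.allclose(conv_seq, linear_seq))
```

Note that each 4x4x3 patch holds 48 values, so with `embed_dim=64` the projection here maps patches to a slightly larger vector.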

## Implementing the Vision Transformer

Now we'll implement the complete Vision Transformer model:

In [None]:
class SimpleViT(tf.keras.Model):
    def __init__(self, num_classes=10):
        super(SimpleViT, self).__init__()
        self.patch_embed = PatchEmbed()
        
        # Learnable position embeddings, one vector per patch
        self.pos_embed = self.add_weight(
            shape=(1, 64, 64),  # (1, num_patches, embed_dim)
            initializer="zeros",
            trainable=True,
            name="pos_embed"
        )
        
        self.attention = layers.MultiHeadAttention(num_heads=4, key_dim=16)
        self.layernorm = layers.LayerNormalization(epsilon=1e-6)
        self.mlp = tf.keras.Sequential([
            layers.Dense(128, activation="gelu"),
            layers.Dense(64)
        ])
        
        self.head = layers.Dense(num_classes)

    def call(self, x):
        # Create patches
        x = self.patch_embed(x)
        
        # Add position embeddings directly
        x = x + self.pos_embed
        
        # Transformer block
        attention_output = self.attention(x, x)
        x = self.layernorm(x + attention_output)
        x = x + self.mlp(x)
        
        # Classification head
        x = tf.reduce_mean(x, axis=1)
        return self.head(x)

### SimpleViT Architecture Explanation

The SimpleViT class implements a simplified version of the Vision Transformer architecture with the following key components:

**Initialization Components**
- PatchEmbed layer converts the input image into patches and embeds them
- Position embedding weight of shape (1, 64, 64), i.e. (1, num_patches, embed_dim), initialized to zeros and learned during training to encode spatial information
- Multi-head attention layer with 4 heads and key dimension of 16
- Layer normalization for stabilizing the network
- MLP block with two dense layers (128 -> 64 dimensions)
- Classification head for final prediction
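
For a sense of scale, the trainable parameters of these components can be tallied by hand. The arithmetic below is a sketch that assumes 3-channel input and Keras defaults (bias terms in every dense, convolutional, and attention projection, and a value dimension equal to `key_dim` in `MultiHeadAttention`). If those defaults differ in your version, the totals will too.

```python
embed_dim, num_patches, patch_size, channels = 64, 64, 4, 3
num_heads, key_dim, mlp_hidden, num_classes = 4, 16, 128, 10

# Each of the query/key/value projections maps embed_dim -> num_heads * key_dim
qkv_proj = embed_dim * num_heads * key_dim + num_heads * key_dim
# The output projection maps num_heads * key_dim -> embed_dim
out_proj = num_heads * key_dim * embed_dim + embed_dim

counts = {
    "patch_embed": patch_size * patch_size * channels * embed_dim + embed_dim,
    "pos_embed": num_patches * embed_dim,
    "attention": 3 * qkv_proj + out_proj,
    "layernorm": 2 * embed_dim,  # scale and offset
    "mlp": (embed_dim * mlp_hidden + mlp_hidden) + (mlp_hidden * embed_dim + embed_dim),
    "head": embed_dim * num_classes + num_classes,
}
total = sum(counts.values())

for name, n in counts.items():
    print(f"{name:12s} {n:6d}")
print(f"{'total':12s} {total:6d}")
```

Under these assumptions the model has roughly 41k parameters, most of them in the attention and MLP layers.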

**Forward Pass Flow**

1. **Patch Embedding**
- Input image is divided into patches and embedded using PatchEmbed layer
- Output shape: (batch_size, num_patches, embed_dim)

2. **Position Information**
- Adds learnable position embeddings to provide spatial context
- Position embeddings are added directly to patch embeddings

3. **Transformer Processing**
- Self-attention mechanism processes patch relationships
- Layer normalization and residual connection combine attention output
- MLP further processes the features with residual connection

4. **Classification**
- Global average pooling across patches (reduce_mean)
- Final dense layer produces class logits

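This implementation simplifies the original ViT architecture by:
- Using a single transformer block instead of a stack of layers
- Omitting the class token and mean-pooling the patch outputs for the final classification instead
- Simplifying the MLP structure
- Applying one layer normalization after the attention residual, where the original ViT normalizes before both the attention and the MLP sub-blocks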
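
The residual-plus-normalization step and the mean pooling in the forward pass can be sketched in NumPy as follows. The `layer_norm` helper is a hand-written stand-in for `layers.LayerNormalization` (same `epsilon`, with scale fixed at 1 and offset at 0), and the random `attention_output` is a placeholder for the real attention result.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-6):
    """Normalize each token over its feature dimension, then scale and shift."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
batch, num_patches, dim = 2, 64, 64
x = rng.normal(size=(batch, num_patches, dim))
attention_output = rng.normal(size=x.shape)    # placeholder for self-attention

gamma, beta = np.ones(dim), np.zeros(dim)      # learnable in a real layer
h = layer_norm(x + attention_output, gamma, beta)   # residual, then normalize

# Global average pooling over the patch axis, as reduce_mean(axis=1) does
pooled = h.mean(axis=1)                        # (2, 64)
print(h.shape, pooled.shape)
```

After normalization every token has approximately zero mean and unit variance across its 64 features, which is what keeps activations in a stable range.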

## Data Preprocessing and Loading

Let's set up the data pipeline using the CIFAR-10 dataset:

In [None]:
# Load and preprocess CIFAR-10
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Use only the first 1,000 training images to keep the demo fast,
# with an 80/20 train/validation split (800 / 200 images)
train_size = int(0.8 * len(x_train[:1000]))

# Split the data
train_x = x_train[:train_size]
train_y = y_train[:train_size]
val_x = x_train[train_size:1000]
val_y = y_train[train_size:1000]

# Create datasets
train_ds = tf.data.Dataset.from_tensor_slices((train_x, train_y))
val_ds = tf.data.Dataset.from_tensor_slices((val_x, val_y))

# Configure datasets
train_ds = train_ds.shuffle(1000).batch(32).prefetch(tf.data.AUTOTUNE)
val_ds = val_ds.batch(32).prefetch(tf.data.AUTOTUNE)
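
To make the resulting dataset sizes concrete, the split and batching arithmetic can be checked by hand. This is a small sketch that assumes the 1,000-image subset, the 80/20 split, and the batch size of 32 used in the cell above.

```python
import math

subset_size, train_fraction, batch_size = 1000, 0.8, 32

train_size = int(train_fraction * subset_size)      # images used for training
val_size = subset_size - train_size                 # images held out for validation

# Dataset.batch keeps a final partial batch by default, hence the ceiling
train_batches = math.ceil(train_size / batch_size)
val_batches = math.ceil(val_size / batch_size)
last_val_batch = val_size - (val_batches - 1) * batch_size

print(train_size, val_size, train_batches, val_batches, last_val_batch)
```

So each epoch runs 25 full training batches, while validation uses 7 batches, the last of which holds only 8 images.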

## Model Configuration and Training

Here's how we can configure and train the Vision Transformer:

In [None]:
# Create and compile model
model = SimpleViT()
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"]
)

# Train model with validation data
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=5
)
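
Since the classification head outputs raw logits, the loss is configured with `from_logits=True`. The NumPy function below is a hand-written sketch of what that loss computes (log-softmax followed by the negative log-likelihood of the integer label, averaged over the batch). It is not the Keras implementation, and it takes 1-D labels for simplicity.

```python
import numpy as np

def sparse_cce_from_logits(logits, labels):
    """Mean cross-entropy between integer labels and unnormalized logits."""
    shifted = logits - logits.max(axis=-1, keepdims=True)   # for stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

labels = np.array([0, 3, 7, 9])

# An uninformed model: identical logits for all 10 classes
uniform_logits = np.zeros((4, 10))
uniform_loss = sparse_cce_from_logits(uniform_logits, labels)

# A confident, correct model: a large logit on the true class
confident_logits = np.zeros((4, 10))
confident_logits[np.arange(4), labels] = 20.0
confident_loss = sparse_cce_from_logits(confident_logits, labels)

print(uniform_loss, confident_loss)
```

A model that predicts all 10 classes equally scores ln(10), about 2.30, so a training loss near that value in the first epoch means the model has not yet learned anything useful.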

## Visualizations

In [None]:
plt.figure(figsize=(12, 4))

# Plot accuracy
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

# Plot loss
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.tight_layout()
plt.show()

## Conclusion

Vision Transformers represent a powerful alternative to traditional convolutional neural networks for image classification tasks. Their ability to process images as sequences of patches and capture global dependencies makes them particularly effective for many computer vision applications. While they typically require more data and computational resources than CNNs for training from scratch, they can achieve state-of-the-art performance when properly trained.

Key takeaways from this implementation:
- Vision Transformers can effectively process images by treating them as sequences of patches
- The architecture maintains the core components of the original Transformer while adapting them for image data
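- The model trains end to end on CIFAR-10 with a standard Keras pipeline, although the 1,000-image subset and 5 epochs used here are enough only to check that the code runs, not to reach competitive accuracy
- Attention weights can be visualized to inspect which patches the model attends to, which offers some insight into its behavior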