Import Required Modules

In [1]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.datasets import cifar10
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.preprocessing import label_binarize
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, auc

Load and preprocess the CIFAR-10 dataset: scale pixels to [0, 1] and hold out the first 20% of the training images as a validation set.


In [2]:
# Load CIFAR-10 (uint8 images and integer labels, per the Keras dataset docs).
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# Scale pixels to [0, 1] as float32. Dividing the uint8 arrays directly by 255.0
# would promote them to float64, doubling memory although Keras trains in float32.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
# Hold out the first 20% of the official training set for validation; the test
# set is kept untouched for the final evaluation.
val_size = int(0.2 * len(x_train))
x_val, y_val = x_train[:val_size], y_train[:val_size]
x_train, y_train = x_train[val_size:], y_train[val_size:]

Parameters

In [3]:
PATCH_SIZE = 4  # 4x4 patches -> (32 // 4) ** 2 = 64 tokens per 32x32 CIFAR image
EMBED_DIM = 128  # Width of each patch embedding (transformer hidden size)
NUM_HEADS = 4  # Attention heads per transformer block
NUM_LAYERS = 3  # Number of stacked transformer blocks
MLP_DIM = 256  # Hidden width of each block's feed-forward sub-layer
DROPOUT_RATE = 0.1
NUM_CLASSES = 10
EPOCHS = 10  # Upper bound; early stopping may end training sooner
SEED = 42  # Fixed seed so weight init, dropout and batch shuffling are reproducible

# Seeds Python's `random`, NumPy and TensorFlow in a single call.
tf.keras.utils.set_random_seed(SEED)

Patch extraction

In [4]:
def extract_patches(images, patch_size):
    """Split a batch of images into flattened, non-overlapping square patches.

    Args:
        images: 4-D batch in the NHWC layout that ``tf.image.extract_patches``
            expects (batch, height, width, channels). NumPy arrays and tensors
            are both accepted.
        patch_size: Side length, in pixels, of each square patch. Edge pixels
            that do not fill a whole patch are dropped (``padding='VALID'``).

    Returns:
        Tensor of shape (batch, num_patches, patch_size * patch_size * channels),
        with the patch grid flattened row by row.
    """
    batch_size = tf.shape(images)[0]
    patches = tf.image.extract_patches(
        images=images,
        sizes=[1, patch_size, patch_size, 1],
        strides=[1, patch_size, patch_size, 1],
        rates=[1, 1, 1, 1],
        padding='VALID'
    )
    # The last axis of the result already holds patch_size * patch_size * channels.
    # Reading it from the output works for any channel count instead of assuming RGB.
    # Prefer the static size so downstream layers keep a known feature dimension.
    patch_dim = patches.shape[-1]
    if patch_dim is None:
        patch_dim = tf.shape(patches)[-1]
    patches = tf.reshape(patches, [batch_size, -1, patch_dim])
    return patches

Positional Encoding

In [5]:
def positional_encoding(num_patches, dim):
    """Build a fixed sinusoidal position table ("Attention Is All You Need").

    For column ``c`` the angular rate is ``1 / 10000 ** (2 * (c // 2) / dim)``.
    Even columns hold ``sin(pos * rate)`` and odd columns ``cos(pos * rate)``,
    so each sin/cos pair shares one frequency.

    Args:
        num_patches: Number of positions (rows of the table).
        dim: Encoding width (columns of the table).

    Returns:
        float32 tensor of shape (num_patches, dim).
    """
    pos = np.arange(num_patches)[:, np.newaxis]
    col = np.arange(dim)[np.newaxis, :]
    exponent = (2 * (col // 2)) / np.float32(dim)
    rates = 1 / np.power(10000, exponent)
    angles = pos * rates

    table = np.empty_like(angles)
    table[:, 0::2] = np.sin(angles[:, 0::2])  # even columns -> sine
    table[:, 1::2] = np.cos(angles[:, 1::2])  # odd columns  -> cosine
    return tf.cast(table, dtype=tf.float32)

Transformer block

In [6]:
def transformer_block(x, num_heads, mlp_dim, dropout_rate, key_dim=None):
    """Apply one post-norm transformer encoder block to a token sequence.

    Structure: MultiHeadAttention -> Add & LayerNorm -> Dense(relu) -> Dropout
    -> Dense -> Add & LayerNorm.

    Args:
        x: Symbolic tensor whose last axis is the model width (the leading axes
            are presumably (batch, tokens) — TODO confirm for new callers).
        num_heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward sub-layer.
        dropout_rate: Dropout applied to the attention weights and to the
            feed-forward hidden activations.
        key_dim: Per-head size of the query/key projections. Defaults to
            ``width // num_heads`` so the heads split the model width, as in
            the standard ViT. Pass ``key_dim=x.shape[-1]`` to reproduce the
            earlier behaviour, where every head was as wide as the whole model.

    Returns:
        Tensor with the same shape as ``x``.
    """
    width = x.shape[-1]
    if key_dim is None:
        key_dim = max(1, width // num_heads)
    # Multi-head self-attention: x serves as both query and value.
    attn_output = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=key_dim, dropout=dropout_rate
    )(x, x)
    x = layers.LayerNormalization()(x + attn_output)
    # Position-wise feed-forward network, projected back to the model width
    # so the residual add lines up.
    mlp_output = layers.Dense(mlp_dim, activation='relu')(x)
    mlp_output = layers.Dropout(dropout_rate)(mlp_output)
    mlp_output = layers.Dense(width)(mlp_output)
    return layers.LayerNormalization()(x + mlp_output)

Vision Transformer model

In [7]:
def create_vit(num_patches, patch_dim, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout_rate):
    """Assemble a small Vision Transformer classifier with the Keras functional API.

    Pipeline: pre-extracted patches -> linear patch embedding -> + fixed
    sinusoidal position encoding -> ``num_layers`` transformer blocks -> mean
    over tokens -> softmax classifier.

    Args:
        num_patches: Tokens per image (sequence length of the input).
        patch_dim: Length of each flattened patch vector.
        embed_dim: Width of the patch embedding / transformer blocks.
        num_heads: Attention heads per transformer block.
        mlp_dim: Feed-forward hidden width inside each block.
        num_layers: Number of transformer blocks to stack.
        num_classes: Size of the softmax output.
        dropout_rate: Dropout rate forwarded to every transformer block.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping (batch, num_patches, patch_dim)
        inputs to (batch, num_classes) class probabilities.
    """
    patch_input = layers.Input(shape=(num_patches, patch_dim))

    # Linearly project every flattened patch into the embedding space.
    tokens = layers.Dense(embed_dim)(patch_input)
    # Constant (non-trainable) position table, broadcast across the batch axis.
    tokens = tokens + positional_encoding(num_patches, embed_dim)

    for _block_index in range(num_layers):
        tokens = transformer_block(tokens, num_heads, mlp_dim, dropout_rate)

    # Average the tokens into one vector per image, then classify it.
    pooled = layers.GlobalAveragePooling1D()(tokens)
    probs = layers.Dense(num_classes, activation='softmax')(pooled)
    return tf.keras.Model(inputs=patch_input, outputs=probs)

Create model

In [8]:
# Derive the token grid from the loaded images rather than hard-coding 32x32 RGB,
# so the model's input shape always matches what extract_patches produces
# (VALID padding -> floor(H / P) * floor(W / P) patches, each P * P * C long).
image_height, image_width, num_channels = x_train.shape[1:]
num_patches = (image_height // PATCH_SIZE) * (image_width // PATCH_SIZE)
patch_dim = PATCH_SIZE * PATCH_SIZE * num_channels
model = create_vit(
    num_patches=num_patches,
    patch_dim=patch_dim,
    embed_dim=EMBED_DIM,
    num_heads=NUM_HEADS,
    mlp_dim=MLP_DIM,
    num_layers=NUM_LAYERS,
    num_classes=NUM_CLASSES,
    dropout_rate=DROPOUT_RATE
)

Compile model

In [9]:
# Adam whose learning rate starts at 1e-3 and is scaled by 0.9 per 10,000
# optimizer steps (continuous decay; `staircase` is left at its default).
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.001,
    decay_steps=10000,
    decay_rate=0.9,
)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

# Sparse cross-entropy: labels are integer class ids, not one-hot vectors.
model.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

Train and evaluate the model

In [None]:
# Stop when validation accuracy has not improved for 5 epochs; with
# restore_best_weights=True, Keras restores the best epoch's weights.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_accuracy',
    patience=5,
    restore_best_weights=True,
)

# The model consumes patch sequences, so patchify each split before fitting.
train_patches = extract_patches(x_train, PATCH_SIZE)
val_patches = extract_patches(x_val, PATCH_SIZE)
history = model.fit(
    train_patches,
    y_train,
    validation_data=(val_patches, y_val),
    batch_size=64,
    epochs=EPOCHS,
    callbacks=[early_stopping],
)
# These large tensors are not needed after training; only test_patches is reused below.
del train_patches, val_patches

# Evaluate the model on the held-out test set.
test_patches = extract_patches(x_test, PATCH_SIZE)
test_loss, test_accuracy = model.evaluate(test_patches, y_test, verbose=2)
print(f"Test Accuracy: {test_accuracy:.2f}")

1. Plot learning curves

In [None]:
def plot_learning_curves(history):
    """Plot training vs. validation accuracy and loss for each epoch.

    Args:
        history: The ``History`` object returned by ``model.fit``. Its
            ``history`` dict must contain 'accuracy', 'val_accuracy', 'loss'
            and 'val_loss' lists with one value per completed epoch.
    """
    metrics = history.history
    # Keras numbers epochs from 1 in its training log; label the x-axis the same way.
    epochs_range = range(1, len(metrics['accuracy']) + 1)

    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 6))

    acc_ax.plot(epochs_range, metrics['accuracy'], label='Training Accuracy')
    acc_ax.plot(epochs_range, metrics['val_accuracy'], label='Validation Accuracy')
    acc_ax.set(title='Training and Validation Accuracy', xlabel='Epoch', ylabel='Accuracy')
    acc_ax.legend(loc='lower right')

    loss_ax.plot(epochs_range, metrics['loss'], label='Training Loss')
    loss_ax.plot(epochs_range, metrics['val_loss'], label='Validation Loss')
    loss_ax.set(title='Training and Validation Loss', xlabel='Epoch', ylabel='Loss')
    loss_ax.legend(loc='upper right')

    fig.tight_layout()
    plt.show()


plot_learning_curves(history)

2. Confusion Matrix

In [None]:
# CIFAR-10 class names, indexed by label id as listed in the Keras dataset docs.
CIFAR10_CLASS_NAMES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                       'dog', 'frog', 'horse', 'ship', 'truck']

y_true = y_test.ravel()  # flatten the (N, 1) label column to 1-D
y_pred = np.argmax(model.predict(test_patches), axis=1)
# Pin the label set so the matrix is always NUM_CLASSES x NUM_CLASSES and lines
# up with the tick labels, even if some class is never predicted.
cm = confusion_matrix(y_true, y_pred, labels=np.arange(NUM_CLASSES))

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    cm, annot=True, fmt="d", cmap="Blues",
    xticklabels=CIFAR10_CLASS_NAMES, yticklabels=CIFAR10_CLASS_NAMES, ax=ax,
)
ax.set(title='Confusion Matrix', xlabel='Predicted', ylabel='True')
plt.show()

3. Classification Report (Precision, Recall, F1-score)

In [None]:
# Pass labels explicitly so target_names always matches the report rows, even
# if a class is missing from both y_test and y_pred. zero_division=0 reports 0
# (the same value sklearn uses by default) without the warning for
# never-predicted classes.
report = classification_report(
    y_test.ravel(),
    y_pred,
    labels=np.arange(NUM_CLASSES),
    target_names=[str(i) for i in range(NUM_CLASSES)],
    zero_division=0,
)
print("Classification Report:\n", report)

4. (Optional) One-vs-rest ROC curves and AUC for each class

In [None]:
# One-vs-rest ROC: turn the integer labels into one indicator column per class,
# then score each column against that class's predicted probability.
class_ids = list(range(NUM_CLASSES))
y_test_binarized = label_binarize(y_test, classes=class_ids)
y_pred_probs = model.predict(test_patches)

fpr, tpr, roc_auc = {}, {}, {}
for class_id in class_ids:
    fpr[class_id], tpr[class_id], _thresholds = roc_curve(
        y_test_binarized[:, class_id], y_pred_probs[:, class_id]
    )
    roc_auc[class_id] = auc(fpr[class_id], tpr[class_id])

# Plotting ROC Curves: one curve per class plus the chance diagonal.
plt.figure(figsize=(12, 6))
for class_id in class_ids:
    plt.plot(
        fpr[class_id], tpr[class_id],
        label=f"Class {class_id} (AUC = {roc_auc[class_id]:.2f})",
    )
plt.plot([0, 1], [0, 1], 'k--', lw=2)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.title('ROC Curve')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend(loc="lower right")
plt.show()