Import Modules

In [None]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
import matplotlib.pyplot as plt
import seaborn as sns

Load and preprocess the CIFAR-10 dataset

In [None]:
# Seed NumPy and TensorFlow before any stochastic work so that a
# "Restart Kernel -> Run All" gives comparable results (weight
# initialisation, dropout and the random augmentation all draw from these
# generators; bit-exact repeatability may still depend on the backend).
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Download (or read from the local Keras cache) the CIFAR-10 arrays.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Normalize pixel values to range [0, 1]
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Hold out the first 20% of the training images as a validation set and
# keep the remaining 80% for training.
val_size = int(0.2 * x_train.shape[0])
x_val, y_val = x_train[:val_size], y_train[:val_size]
x_train, y_train = x_train[val_size:], y_train[val_size:]

# Side length (in pixels) of the square patches fed to the transformer.
PATCH_SIZE = 4  # Patch size for the Swin Transformer

# Function to split an image into patches
def extract_patches(images, patch_size):
    """Split a batch of images into flattened, non-overlapping square patches.

    Args:
        images: 4-D batch laid out as (batch, height, width, channels).
        patch_size: Side length of each square patch, in pixels. Patches are
            taken with a stride equal to ``patch_size`` and 'VALID' padding,
            so any border that does not fill a whole patch is dropped.

    Returns:
        A tensor of shape (batch, num_patches, patch_size * patch_size *
        channels) where each row is one flattened patch.
    """
    batch_size = tf.shape(images)[0]
    patches = tf.image.extract_patches(
        images=images,
        sizes=[1, patch_size, patch_size, 1],
        strides=[1, patch_size, patch_size, 1],
        rates=[1, 1, 1, 1],
        padding='VALID'
    )
    # The last axis of the extracted tensor already holds one flattened
    # patch, so read its length from there instead of assuming 3 channels.
    patch_dim = patches.shape[-1]
    if patch_dim is None:
        # Static shape unknown (e.g. inside a traced graph): use the
        # dynamic shape instead.
        patch_dim = tf.shape(patches)[-1]
    # Collapse the (rows, cols) patch grid into a single sequence axis.
    patches = tf.reshape(patches, [batch_size, -1, patch_dim])
    return patches

# Turn every split (train, validation, test) into its patch-sequence form,
# in that order, using the shared PATCH_SIZE.
train_patches, val_patches, test_patches = (
    extract_patches(image_split, PATCH_SIZE)
    for image_split in (x_train, x_val, x_test)
)

Swin Transformer Block

In [None]:
class SwinTransformerBlock(tf.keras.layers.Layer):
    """Pre-norm transformer encoder block: self-attention + MLP, each residual.

    NOTE(review): ``window_size`` and ``shift_size`` are stored (and
    serialised) but are not used by ``call``; attention is computed over the
    whole token sequence, not over shifted local windows.

    Args:
        embed_dim: Width of the token embeddings entering and leaving the
            block. Also used as the per-head ``key_dim`` of the attention.
        num_heads: Number of attention heads.
        window_size: Stored only; see the note above.
        shift_size: Stored only; see the note above.
        dropout_rate: Dropout rate applied inside the MLP.
        **kwargs: Standard Keras ``Layer`` arguments (``name``, ``dtype``,
            ``trainable``, ...). Accepting them lets Keras rebuild the layer
            from its config.
    """

    def __init__(self, embed_dim, num_heads, window_size=2, shift_size=1, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.window_size = window_size
        self.shift_size = shift_size
        self.dropout_rate = dropout_rate
        # Separate normalisation layers for the two sub-layers so that the
        # attention input and the MLP input do not share scale/offset weights.
        self.layernorm = layers.LayerNormalization(epsilon=1e-6)
        self.mlp_layernorm = layers.LayerNormalization(epsilon=1e-6)
        self.msa = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.mlp = tf.keras.Sequential([
            layers.Dense(4 * embed_dim, activation='relu'),
            layers.Dropout(dropout_rate),
            layers.Dense(embed_dim)
        ])

    def call(self, x, training=None):
        """Apply attention and MLP sub-layers with residual connections.

        Args:
            x: Token tensor; the residual additions require its last
                dimension to equal ``embed_dim``.
            training: Forwarded to the attention and MLP sub-layers so that
                dropout is only active while training.

        Returns:
            A tensor with the same shape as ``x``.
        """
        x_norm = self.layernorm(x)
        attn_output = self.msa(x_norm, x_norm, training=training)
        x = x + attn_output
        x_norm = self.mlp_layernorm(x)
        x = x + self.mlp(x_norm, training=training)
        return x

    def get_config(self):
        """Return the constructor arguments so the layer can be serialised."""
        config = super().get_config()
        config.update({
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'window_size': self.window_size,
            'shift_size': self.shift_size,
            'dropout_rate': self.dropout_rate,
        })
        return config

Swin Transformer Model

In [None]:
def build_swin_transformer(input_shape, num_classes, num_blocks, embed_dim, num_heads):
    """Assemble the patch-sequence classifier as a functional Keras model.

    The pipeline is: linear projection of each flattened patch to
    ``embed_dim`` -> ``num_blocks`` stacked ``SwinTransformerBlock`` layers ->
    average over the patch axis -> softmax over ``num_classes``.

    NOTE(review): no positional information is added to the patch tokens in
    this function; combined with average pooling the output looks
    independent of patch order — confirm this is intended.

    Args:
        input_shape: Per-sample input shape, passed to ``layers.Input``
            (called below with ``(num_patches, patch_dim)``).
        num_classes: Number of output classes.
        num_blocks: How many transformer blocks to stack.
        embed_dim: Token embedding width.
        num_heads: Attention heads per block.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    patch_input = layers.Input(shape=input_shape)

    # Linear patch embedding
    tokens = layers.Dense(embed_dim)(patch_input)

    # Stack of transformer blocks
    for _block_index in range(num_blocks):
        block = SwinTransformerBlock(embed_dim, num_heads=num_heads)
        tokens = block(tokens)

    # Pool over patches, then classify
    pooled = layers.GlobalAveragePooling1D()(tokens)
    class_probs = layers.Dense(num_classes, activation='softmax')(pooled)

    return tf.keras.Model(patch_input, class_probs)

Model parameters

In [None]:
# Architecture hyper-parameters
num_classes = 10
num_blocks = 2
num_heads = 4
embed_dim = 64

# Token geometry for CIFAR-10 with 32x32 images: number of patches per
# image and the flattened length of one patch (3 colour values per pixel).
patches_per_side = 32 // PATCH_SIZE
num_patches = patches_per_side * patches_per_side
patch_dim = 3 * PATCH_SIZE * PATCH_SIZE

Build and compile the model

In [None]:
# Instantiate the network from the hyper-parameters defined above
model = build_swin_transformer(
    input_shape=(num_patches, patch_dim),
    num_classes=num_classes,
    num_blocks=num_blocks,
    embed_dim=embed_dim,
    num_heads=num_heads,
)

# Adam with a learning rate that starts at 1e-3 and is multiplied by 0.9
# for every 10000 optimizer steps
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.001,
    decay_steps=10000,
    decay_rate=0.9,
)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

# Labels are used as integer class ids, hence the sparse loss
model.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

Data Augmentation

In [None]:
# Random geometric augmentation applied on the fly to training images:
# rotations up to 15 degrees, shifts up to 10% of width/height, and
# horizontal flips.
datagen = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True
)
# No datagen.fit(x_train) here: fitting only computes dataset statistics for
# featurewise_center / featurewise_std_normalization / zca_whitening, none
# of which are enabled above, so the call would only cost time and memory.

# Wrap data augmentation for patches
def augmented_patch_generator(datagen, x_data, y_data, patch_size, batch_size=64):
    """Yield (patches, labels) batches built from augmented images.

    Args:
        datagen: Object exposing ``flow(x, y, batch_size=...)`` that yields
            ``(image_batch, label_batch)`` pairs (an ``ImageDataGenerator``
            in this notebook).
        x_data: Images to draw batches from.
        y_data: Labels aligned with ``x_data``.
        patch_size: Side length of the square patches passed to
            ``extract_patches``.
        batch_size: Number of images per batch requested from ``datagen``.
            Defaults to 64, the value that was previously hard-coded.

    Yields:
        Tuples ``(extract_patches(image_batch, patch_size), label_batch)``,
        for as long as ``datagen.flow`` keeps producing batches.
    """
    for x_batch, y_batch in datagen.flow(x_data, y_data, batch_size=batch_size):
        yield extract_patches(x_batch, patch_size), y_batch

Train the model

In [None]:
# Training budget. BATCH_SIZE must equal the batch size produced by
# augmented_patch_generator (64 by default) so that steps_per_epoch covers
# roughly one pass over the training images.
EPOCHS = 10
BATCH_SIZE = 64
# Patience is deliberately smaller than EPOCHS; with patience >= EPOCHS the
# callback could never stop the run.
EARLY_STOPPING_PATIENCE = 3

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_accuracy',
    patience=EARLY_STOPPING_PATIENCE,
    restore_best_weights=True
)
history = model.fit(
    augmented_patch_generator(datagen, x_train, y_train, PATCH_SIZE),
    validation_data=(val_patches, y_val),
    # The generator never terminates on its own, so the epoch length has to
    # be stated explicitly.
    steps_per_epoch=len(x_train) // BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stopping]
)

# Save the model
model.save('swin_transformer_model.h5')

Evaluate the model

In [None]:
# model.evaluate returns the loss followed by the compiled metrics
# (accuracy only), measured on the held-out test patches.
evaluation = model.evaluate(test_patches, y_test)
test_loss, test_accuracy = evaluation
print(f"Test Accuracy: {test_accuracy:.2f}")

Confusion Matrix

In [None]:
# Predicted class = index of the highest softmax probability per test image
test_probabilities = model.predict(test_patches)
y_pred = test_probabilities.argmax(axis=1)
conf_matrix = confusion_matrix(y_test, y_pred)

# Heatmap of the matrix (rows labelled 'True', columns 'Predicted')
class_ids = list(range(10))
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    conf_matrix,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_ids,
    yticklabels=class_ids,
    ax=ax,
)
ax.set_title('Confusion Matrix')
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
plt.show()

Classification Report

In [None]:
# Per-class precision / recall / F1, with classes named by their integer id
class_labels = [f"{class_id}" for class_id in range(10)]
report_text = classification_report(y_test, y_pred, target_names=class_labels)
print(report_text)

Plot learning curves

In [None]:
def plot_learning_curves(history):
    """Draw training-vs-validation curves for accuracy and for loss.

    Args:
        history: Object whose ``history`` attribute is a mapping with the
            keys 'accuracy', 'val_accuracy', 'loss' and 'val_loss', each
            holding one value per epoch (as returned by ``model.fit``).

    Two figures are created, one per metric, with the training curve in blue
    and the validation curve in red; both are shown at the end.
    """
    log = history.history
    # Read all four series up front so a missing key fails before plotting.
    panels = (
        (log['accuracy'], log['val_accuracy'],
         'Training accuracy', 'Validation accuracy',
         'Training and Validation Accuracy'),
        (log['loss'], log['val_loss'],
         'Training loss', 'Validation loss',
         'Training and Validation Loss'),
    )
    # The x axis is the zero-based epoch index.
    epoch_axis = range(len(panels[0][0]))

    for train_values, val_values, train_label, val_label, panel_title in panels:
        plt.figure()
        plt.plot(epoch_axis, train_values, 'b', label=train_label)
        plt.plot(epoch_axis, val_values, 'r', label=val_label)
        plt.title(panel_title)
        plt.legend()
    plt.show()

# Plot the curves recorded by the model.fit call in the training cell.
plot_learning_curves(history)

ROC Curve and AUC

In [None]:
# One-vs-rest ROC analysis: each class is scored against all others using
# its softmax probability.
y_test_one_hot = tf.keras.utils.to_categorical(y_test, num_classes)

# Run inference once; the per-class loop below only slices the result
# instead of re-predicting the whole test set for every class.
y_score = model.predict(test_patches)

fpr = {}
tpr = {}
roc_auc = {}

for i in range(num_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test_one_hot[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

plt.figure()
for i in range(num_classes):
    plt.plot(fpr[i], tpr[i], label=f"Class {i} (AUC = {roc_auc[i]:.2f})")

# Diagonal = performance of a random classifier, for reference
plt.plot([0, 1], [0, 1], 'k--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve for Each Class')
plt.legend(loc="lower right")
plt.show()