In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt

# Seed Python, NumPy and TensorFlow RNGs in one call so weight initialisation,
# dataset shuffling and random augmentation are repeatable on Restart & Run All.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Load CIFAR-10 dataset (downloaded and cached by Keras on first use).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Sanity-check the layout the model below expects: 32x32 RGB images,
# one integer label per image.
assert x_train.shape[1:] == (32, 32, 3) and x_test.shape[1:] == (32, 32, 3)
assert len(x_train) == len(y_train) and len(x_test) == len(y_test)

# Scale raw 0-255 pixel values to floats in [0, 1]
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
assert 0.0 <= x_train.min() and x_train.max() <= 1.0


In [None]:
# Random image transforms for augmenting training images: horizontal flips,
# small random rotations and small random zooms (factor 0.1 each).
# Per the Keras docs these layers only alter their input in training mode
# (inside `fit`, or when called with `training=True`); otherwise they pass
# images through unchanged.
data_augmentation = tf.keras.Sequential()
data_augmentation.add(layers.RandomFlip(mode="horizontal"))
data_augmentation.add(layers.RandomRotation(factor=0.1))
data_augmentation.add(layers.RandomZoom(height_factor=0.1))


In [None]:
def create_model(input_shape=(32, 32, 3), num_classes=10):
    """Build a small, uncompiled CNN image classifier.

    Architecture: three 3x3 ReLU conv layers (32, 64, 64 filters), the first
    two each followed by 2x2 max pooling, then a 64-unit ReLU dense layer and
    a softmax output layer.

    Args:
        input_shape: Shape of a single input image as (height, width, channels).
            Defaults to (32, 32, 3), the CIFAR-10 image shape used in this notebook.
        num_classes: Number of output classes. Defaults to 10 (CIFAR-10).

    Returns:
        A new, uncompiled Sequential model. Every call constructs fresh layers,
        so models built for separate experiments never share weights.
    """
    model = models.Sequential([
        layers.Input(shape=input_shape),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        # Softmax outputs probabilities, so compile with a probability-based
        # loss (not a `from_logits=True` variant).
        layers.Dense(num_classes, activation='softmax')
    ])
    return model


In [None]:
# Baseline run: train a fresh model on the un-augmented images, tracking
# accuracy on the held-out test images after every epoch.
model_no_aug = create_model()
model_no_aug.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

history_no_aug = model_no_aug.fit(
    x=x_train,
    y=y_train,
    validation_data=(x_test, y_test),
    epochs=10,
)
    

In [None]:
# Apply data augmentation *during* training.
#
# Keras random preprocessing layers (RandomFlip / RandomRotation / RandomZoom)
# only transform their input in training mode. A plain call such as
# `data_augmentation(x_train)` runs in inference mode and returns the images
# unchanged, so the "augmented" model would silently train on the original data.
# Even with `training=True`, augmenting once up front would freeze a single
# random variant of every image for all epochs. Instead, map the augmentation
# over a tf.data pipeline so every batch gets fresh random transforms each epoch.
AUGMENT_BATCH_SIZE = 32          # same as Keras `fit`'s default used by the baseline
AUGMENT_SHUFFLE_BUFFER = 10_000  # bounded buffer keeps memory use modest

augmented_train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(AUGMENT_SHUFFLE_BUFFER)  # reshuffled every epoch by default
    .batch(AUGMENT_BATCH_SIZE)
    # Augment whole batches (vectorised); labels pass through untouched.
    .map(lambda images, labels: (data_augmentation(images, training=True), labels),
         num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

model_with_aug = create_model()
model_with_aug.compile(optimizer='adam',
                       loss='sparse_categorical_crossentropy',
                       metrics=['accuracy'])

# The dataset yields (images, labels) pairs, so no separate `y` is passed.
# Validation uses the same un-augmented test set as the baseline run.
history_with_aug = model_with_aug.fit(augmented_train_ds,
                                      epochs=10,
                                      validation_data=(x_test, y_test))


In [None]:
# Compare learning curves of both runs on one set of axes: one colour per run,
# solid lines for training accuracy and dashed lines for validation accuracy.
fig, ax = plt.subplots(figsize=(12, 6))

runs = {
    'No Augmentation': history_no_aug.history,
    'With Augmentation': history_with_aug.history,
}
for run_name, metrics in runs.items():
    # Keras's training log numbers epochs from 1, so plot 1..N rather than
    # the default 0-based list indices.
    epoch_numbers = range(1, len(metrics['accuracy']) + 1)
    (train_line,) = ax.plot(epoch_numbers, metrics['accuracy'],
                            label=f'Train Accuracy ({run_name})')
    ax.plot(epoch_numbers, metrics['val_accuracy'],
            linestyle='--', color=train_line.get_color(),
            label=f'Validation Accuracy ({run_name})')

ax.set(title='Training and Validation Accuracy', xlabel='Epoch', ylabel='Accuracy')
ax.legend()
ax.grid(True)
plt.show()
