# ============================================================
# CIFAR-10 Image Classification Using CNN
# ============================================================


# PROBLEM STATEMENT
CIFAR-10 is a dataset of small color images, each labeled with one of the following 10 classes:

- Airplanes
- Cars (automobiles)
- Birds
- Cats
- Deer
- Dogs
- Frogs
- Horses
- Ships
- Trucks

The name CIFAR stands for the Canadian Institute For Advanced Research.

CIFAR-10 is widely used for machine learning and computer vision applications.

The dataset consists of 60,000 32x32 color images, 6,000 per class, split into 50,000 training images and 10,000 test images.

Images have low resolution (32x32).

Data Source: https://www.cs.toronto.edu/~kriz/cifar.html

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import tensorflow as tf
from tensorflow.keras import datasets, models, layers, optimizers
from tensorflow.keras.utils import to_categorical, array_to_img
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import confusion_matrix, classification_report

# Report which TensorFlow build this notebook is running against
tf_version = tf.__version__
print("TensorFlow version:", tf_version)

# ============================================================
# STEP 1: LOAD CIFAR-10 DATA
# ============================================================

In [None]:
# Load the CIFAR-10 train/test splits; labels are integer class indices
# stored with a trailing axis of length 1 (indexed as y[i][0] below).
train_split, test_split = datasets.cifar10.load_data()
X_train, y_train_raw = train_split
X_test, y_test_raw = test_split

# Human-readable class names, indexed by integer label
class_names = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

print("Train images:", X_train.shape)
print("Test images:", X_test.shape)




# ============================================================
# STEP 2: VISUALIZE SAMPLE IMAGES
# ============================================================

In [None]:
# Show one training image with its class name as the title
sample_idx = 100
plt.figure(figsize=(3, 3))
plt.imshow(X_train[sample_idx])
plt.title(class_names[y_train_raw[sample_idx][0]])
plt.axis("off")
plt.show()

# 4x4 grid of randomly drawn training images (a new random draw on every run)
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
axes = axes.ravel()
for ax in axes:
    idx = np.random.randint(0, X_train.shape[0])
    ax.imshow(X_train[idx])
    ax.set_title(class_names[y_train_raw[idx][0]])
    ax.axis("off")
plt.tight_layout()
plt.show()

# ============================================================
# STEP 3: DATA PREPARATION (NORMALIZATION + ONE-HOT ENCODING)
# ============================================================

In [None]:
# Normalize pixel intensities
X_train = X_train.astype("float32") / 255.0
X_test  = X_test.astype("float32") / 255.0

# One-hot encoding of labels
num_classes = 10
y_train = to_categorical(y_train_raw, num_classes)
y_test  = to_categorical(y_test_raw, num_classes)

input_shape = X_train.shape[1:]
print("Input shape:", input_shape)

# ============================================================
# STEP 4: BUILD CNN MODEL
# ============================================================


In [None]:
# Two convolutional stages (64 then 128 filters); each stage is
# conv -> conv -> 2x2 max-pool -> dropout(0.4).
conv_stages = []
for n_filters in (64, 128):
    conv_stages += [
        layers.Conv2D(n_filters, (3, 3), activation="relu"),
        layers.Conv2D(n_filters, (3, 3), activation="relu"),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.4),
    ]

# Fully connected classifier head ending in a softmax over the classes
classifier_head = [
    layers.Flatten(),
    layers.Dense(1024, activation="relu"),
    layers.Dense(1024, activation="relu"),
    layers.Dense(num_classes, activation="softmax"),
]

# An explicit Input layer declares the image shape instead of passing
# `input_shape` to the first Conv2D.
cnn_model = models.Sequential(
    [layers.Input(shape=input_shape)] + conv_stages + classifier_head
)

cnn_model.compile(
    optimizer=optimizers.RMSprop(learning_rate=0.001),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

cnn_model.summary()

# ============================================================
# STEP 5: TRAIN MODEL
# ============================================================

In [None]:
# Baseline training run: 10% of the training data is held out for validation
fit_kwargs = dict(
    batch_size=32,
    epochs=5,
    validation_split=0.1,
    shuffle=True,
)
history = cnn_model.fit(X_train, y_train, **fit_kwargs)


# ============================================================
# STEP 6: EVALUATE MODEL + CONFUSION MATRIX
# ============================================================

In [None]:
# Score the baseline model on the test set
test_loss, test_acc = cnn_model.evaluate(X_test, y_test, verbose=0)
print(f"Test Accuracy (Baseline): {test_acc:.4f}")

# Predictions: argmax over the softmax outputs gives the predicted class index
y_pred_prob = cnn_model.predict(X_test)
y_pred_class = np.argmax(y_pred_prob, axis=1)
y_true_class = y_test_raw.flatten()  # 1-D integer labels, aligned with y_pred_class

# Confusion matrix: rows are true classes, columns are predicted classes.
# Passing `labels` pins the row/column order to 0..num_classes-1 so it
# always lines up with `class_names`.
label_ids = np.arange(num_classes)
cm = confusion_matrix(y_true_class, y_pred_class, labels=label_ids)

plt.figure(figsize=(10, 8))
# Annotate each cell with its count and label both axes with class names;
# without tick labels the heatmap only shows bare indices 0-9.
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names,
)
plt.title("Confusion Matrix - Baseline Model")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()

print("Classification Report:")
print(classification_report(y_true_class, y_pred_class,
                            labels=label_ids, target_names=class_names))

# ============================================================
# STEP 7: SAVE BASELINE MODEL
# ============================================================


In [None]:
# Persist the trained baseline as a .keras file under saved_models/
save_dir = "saved_models"
baseline_path = os.path.join(save_dir, "cnn_cifar10_baseline.keras")

os.makedirs(save_dir, exist_ok=True)  # harmless if the folder already exists
cnn_model.save(baseline_path)

print("Baseline model saved to:", baseline_path)

# ============================================================
# STEP 8: DATA AUGMENTATION & IMPROVED TRAINING
# ============================================================

In [None]:
# Start from a fresh, untrained copy of the baseline architecture. Continuing
# to train the already-fitted baseline would mix the effect of augmentation
# with the effect of 5 extra epochs, so the two test accuracies would not be
# comparable. `clone_model` copies the layer configuration and creates new
# weights. The baseline itself was already saved to disk in Step 7.
# `cnn_model` is rebound so that Step 9 saves the augmented model.
cnn_model = models.clone_model(cnn_model)
cnn_model.compile(
    loss="categorical_crossentropy",
    optimizer=optimizers.RMSprop(learning_rate=0.001),  # same settings as the baseline
    metrics=["accuracy"]
)

datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True
)
# NOTE: datagen.fit() is only required for featurewise_center,
# featurewise_std_normalization or zca_whitening, none of which are enabled.

# Preview one augmented batch. shuffle=False keeps the panels aligned with
# X_train[:8].
sample = X_train[:8]
fig = plt.figure(figsize=(12, 2))
batch = next(datagen.flow(sample, batch_size=8, shuffle=False))
for i in range(8):
    ax = fig.add_subplot(1, 8, i + 1)
    ax.imshow(array_to_img(batch[i]))
    ax.axis("off")
plt.suptitle("Example Augmented Images")
plt.show()

# Hold out the last 10% of the training set for validation, mirroring the
# baseline's validation_split=0.1. The test set must not be used for
# validation, or the final test accuracy stops being an independent estimate.
split_at = int(len(X_train) * 0.9)
X_fit, y_fit = X_train[:split_at], y_train[:split_at]
X_val, y_val = X_train[split_at:], y_train[split_at:]

# Train with augmentation
batch_size = 32
train_flow = datagen.flow(X_fit, y_fit, batch_size=batch_size)
history_aug = cnn_model.fit(
    train_flow,
    # len(train_flow) == ceil(len(X_fit) / batch_size): every sample is
    # seen each epoch, including the final partial batch
    steps_per_epoch=len(train_flow),
    epochs=5,
    validation_data=(X_val, y_val)
)

# Evaluate the augmented model on the untouched test set
loss_aug, acc_aug = cnn_model.evaluate(X_test, y_test, verbose=0)
print(f"Test Accuracy (With Augmentation): {acc_aug:.4f}")

# ============================================================
# STEP 9: SAVE AUGMENTED MODEL
# ============================================================

In [None]:
# Persist the augmented model next to the baseline (save_dir is created in Step 7)
aug_path = os.path.join(save_dir, "cnn_cifar10_augmented.keras")
cnn_model.save(filepath=aug_path)
print("Augmented model saved to:", aug_path)
