# CNN for CIFAR-10 Image Classification

This notebook implements Convolutional Neural Networks (CNNs) for image classification using the CIFAR-10 dataset.

## Learning Objectives:
- Understand CNN architecture (convolution, pooling, fully connected layers)
- Learn about feature maps and filters
- Implement data augmentation
- Visualize CNN features and training progress
- Compare different CNN architectures

In [1]:
# Cell 1: Import libraries
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import warnings
warnings.filterwarnings('ignore')

# Seed the random number generators once, up front, so weight initialisation,
# shuffling, dropout and augmentation are as repeatable as the hardware allows
# under Restart Kernel -> Run All.
SEED = 42
keras.utils.set_random_seed(SEED)

# Report the runtime so results can be tied to a TensorFlow build and device.
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {tf.config.list_physical_devices('GPU')}")

ModuleNotFoundError: No module named 'tensorflow'

In [None]:
# Cell 2: Load and explore CIFAR-10 dataset
# load_data() hands back the train and test splits as (images, labels) pairs.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# Human-readable class names, indexed by the integer class id stored in the labels.
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

# One print call; sep="\n" puts every fact on its own line.
print(
    f"Training data shape: {x_train.shape}",
    f"Training labels shape: {y_train.shape}",
    f"Test data shape: {x_test.shape}",
    f"Test labels shape: {y_test.shape}",
    f"Number of classes: {len(class_names)}",
    f"Pixel value range: {x_train.min()} to {x_train.max()}",
    sep="\n",
)

In [None]:
# Cell 3: Visualize sample images
# 5x5 grid showing the first 25 training images, each titled with its class name.
plt.figure(figsize=(12, 8))
for i, (image, label) in enumerate(zip(x_train[:25], y_train[:25])):
    plt.subplot(5, 5, i + 1)
    plt.imshow(image)
    # Each label row holds a single class id, hence the [0].
    plt.title(class_names[label[0]])
    plt.axis('off')

plt.suptitle('Sample CIFAR-10 Images', fontsize=16)
plt.tight_layout()
plt.show()

# Show class distribution: count how often each class id occurs in the training labels.
unique, counts = np.unique(y_train, return_counts=True)
plt.figure(figsize=(10, 6))
plt.bar([class_names[class_id] for class_id in unique], counts)
plt.title('CIFAR-10 Training Data Distribution')
plt.xlabel('Classes')
plt.ylabel('Number of Images')
plt.xticks(rotation=45)
plt.show()

In [None]:
# Cell 4: Data preprocessing
# Cast to float32 and rescale pixel values to [0, 1]; the raw arrays stay
# untouched so this cell can be re-run safely.
x_train_norm = x_train.astype(np.float32) / 255.0
x_test_norm = x_test.astype(np.float32) / 255.0

# One-hot encode the integer labels (10 classes) for categorical cross-entropy.
y_train_cat = keras.utils.to_categorical(y_train, num_classes=10)
y_test_cat = keras.utils.to_categorical(y_test, num_classes=10)

# Sanity-check the transformed arrays.
print(
    f"Normalized training data shape: {x_train_norm.shape}",
    f"Categorical labels shape: {y_train_cat.shape}",
    f"Pixel value range after normalization: {x_train_norm.min()} to {x_train_norm.max()}",
    f"Sample label before: {y_train[0]}, after: {y_train_cat[0]}",
    sep="\n",
)

In [None]:
# Cell 5: Simple CNN model
def create_simple_cnn():
    """Build the baseline CNN: three convolution layers and a small dense head.

    Returns:
        An uncompiled ``Sequential`` model that takes 32x32x3 inputs and ends
        in a 10-unit softmax layer.
    """
    model = models.Sequential()

    # Block 1: the only layer that declares the input shape.
    model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
    model.add(layers.MaxPooling2D((2, 2)))

    # Block 2: more filters, followed by another 2x2 pooling step.
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))

    # Block 3: a final convolution with no pooling after it.
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))

    # Classifier head: flatten the feature maps, then two dense layers.
    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(10, activation='softmax'))

    return model

# Instantiate the baseline network.
simple_cnn = create_simple_cnn()

# Adam + categorical cross-entropy (labels are one-hot), tracking accuracy.
simple_cnn.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Print the layer-by-layer architecture.
simple_cnn.summary()

In [None]:
# Cell 6: Train the simple CNN
print("Training Simple CNN...")

# 10 epochs in mini-batches of 32; 20% of the training data is held out for validation.
history_simple = simple_cnn.fit(
    x=x_train_norm,
    y=y_train_cat,
    epochs=10,
    batch_size=32,
    validation_split=0.2,
    verbose=1,
)

# Score the trained model on the held-out test set (loss first, then accuracy).
test_loss, test_accuracy = simple_cnn.evaluate(
    x_test_norm,
    y_test_cat,
    verbose=0,
)
print(f"\nSimple CNN Test Accuracy: {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")

In [None]:
# Cell 7: Improved CNN with data augmentation
def create_improved_cnn(input_shape=(32, 32, 3), num_classes=10):
    """Build the deeper CNN with built-in augmentation, batch norm and dropout.

    The model starts with an explicit ``keras.Input``. The augmentation layers
    come before the first convolution, and an ``input_shape`` argument on a
    layer that is not first does not define the model input. Without the
    explicit input the model would stay unbuilt until its first call, so
    ``summary()`` could not report shapes or parameter counts.

    Args:
        input_shape: Shape of one input image as (height, width, channels).
        num_classes: Number of units in the final softmax layer.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    model = models.Sequential([
        # Explicit input so the model is built as soon as it is constructed.
        keras.Input(shape=input_shape),

        # Data augmentation layers
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.1),

        # First convolutional block
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.25),

        # Second convolutional block
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.25),

        # Third convolutional block
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.25),

        # Flatten and dense layers
        layers.Flatten(),
        layers.Dense(512, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation='softmax')
    ])
    return model

# Instantiate the improved network.
improved_cnn = create_improved_cnn()

# Same loss and metric as the baseline; Adam is created explicitly so the
# learning rate is visible and easy to tune.
improved_cnn.compile(
    loss='categorical_crossentropy',
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    metrics=['accuracy'],
)

# Print the layer-by-layer architecture.
improved_cnn.summary()

In [None]:
# Cell 8: Train improved CNN with callbacks
# Early stopping watches validation accuracy (patience 5) and restores the best weights.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_accuracy', patience=5, restore_best_weights=True,
)

# Learning-rate schedule watches validation loss (patience 3): scale by 0.2, floor at 1e-4.
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.2, patience=3, min_lr=0.0001,
)

print("Training Improved CNN with data augmentation...")
# Up to 20 epochs, same batch size and validation split as the baseline run.
history_improved = improved_cnn.fit(
    x=x_train_norm,
    y=y_train_cat,
    epochs=20,
    batch_size=32,
    validation_split=0.2,
    callbacks=[early_stopping, reduce_lr],
    verbose=1,
)

# Score the improved model on the held-out test set (loss first, then accuracy).
test_loss_improved, test_accuracy_improved = improved_cnn.evaluate(
    x_test_norm,
    y_test_cat,
    verbose=0,
)
print(f"\nImproved CNN Test Accuracy: {test_accuracy_improved:.4f} ({test_accuracy_improved*100:.2f}%)")

In [None]:
# Cell 9: Compare training histories
plt.figure(figsize=(15, 5))

# Panels 1 and 2 share one layout: train and validation curves for each model,
# drawn in the order simple-train, simple-val, improved-train, improved-val.
training_runs = (
    ('Simple CNN', history_simple),
    ('Improved CNN', history_improved),
)
curve_panels = (
    (1, 'accuracy', 'Model Accuracy Comparison', 'Accuracy'),
    (2, 'loss', 'Model Loss Comparison', 'Loss'),
)
for position, metric, panel_title, y_label in curve_panels:
    plt.subplot(1, 3, position)
    for run_name, run_history in training_runs:
        plt.plot(run_history.history[metric], label=f'{run_name} Train')
        plt.plot(run_history.history[f'val_{metric}'], label=f'{run_name} Val')
    plt.title(panel_title)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend()
    plt.grid(True)

# Panel 3: final test accuracy of both models as bars.
plt.subplot(1, 3, 3)
models_comparison = ['Simple CNN', 'Improved CNN']
accuracies = [test_accuracy, test_accuracy_improved]
plt.bar(models_comparison, accuracies, color=['lightblue', 'lightgreen'])
plt.title('Test Accuracy Comparison')
plt.ylabel('Accuracy')
plt.ylim(0, 1)
# Write each value just above its bar.
for i, v in enumerate(accuracies):
    plt.text(i, v + 0.01, f'{v:.3f}', ha='center', va='bottom')

plt.tight_layout()
plt.show()

In [None]:
# Cell 10: Visualize CNN filters and feature maps
def visualize_filters(model, layer_name, num_filters=8):
    """Plot the first input-channel slice of a layer's convolution filters.

    Args:
        model: Keras model containing the layer.
        layer_name: Name of the layer whose kernel is drawn.
        num_filters: Maximum number of filters to draw. The grid grows beyond
            the default 2x4 layout when more than eight are requested.

    Returns:
        None. Prints a message and returns early when the layer is missing,
        has no 4-D kernel, or there is nothing to draw.
    """
    # First layer with a matching name, or None.
    layer = next((l for l in model.layers if l.name == layer_name), None)
    if layer is None:
        print(f"Layer {layer_name} not found")
        return

    # Layers without trainable state (e.g. pooling) return an empty weight list.
    layer_weights = layer.get_weights()
    if not layer_weights or np.ndim(layer_weights[0]) != 4:
        print(f"Layer {layer_name} has no 4-D convolution kernel to visualize")
        return
    weights = layer_weights[0]

    # Min-max scale to [0, 1] for display; a constant kernel would divide by zero.
    low, high = weights.min(), weights.max()
    if high > low:
        weights = (weights - low) / (high - low)
    else:
        weights = np.zeros_like(weights)

    # The last kernel axis indexes the filters.
    count = min(num_filters, weights.shape[3])
    if count < 1:
        print(f"No filters to display for {layer_name}")
        return

    # Keep the original 2x4 layout for up to 8 filters; add rows beyond that.
    cols = 4
    rows = max(2, (count + cols - 1) // cols)

    plt.figure(figsize=(12, 4 * rows))
    for i in range(count):
        plt.subplot(rows, cols, i + 1)
        # Only the slice for input channel 0 is drawn.
        plt.imshow(weights[:, :, 0, i], cmap='viridis')
        plt.title(f'Filter {i+1}')
        plt.axis('off')

    plt.suptitle(f'Filters from {layer_name}', fontsize=16)
    plt.tight_layout()
    plt.show()

# Visualize filters from the first conv layer.
# The layer is looked up by type rather than by the literal name 'conv2d':
# auto-generated Keras layer names get a numeric suffix once a model-building
# cell has been run more than once in the same kernel, and the literal would
# then no longer match.
simple_conv_names = [l.name for l in simple_cnn.layers if isinstance(l, layers.Conv2D)]
visualize_filters(simple_cnn, simple_conv_names[0], 8)

In [None]:
# Cell 11: Visualize feature maps
def visualize_feature_maps(model, image, layer_names):
    """Plot up to 16 feature maps per requested layer for a single image.

    Args:
        model: Built Keras model whose intermediate outputs are inspected.
        image: One input image without a batch axis; a batch axis is added
            before prediction.
        layer_names: Names of the layers whose activations are drawn.

    Returns:
        None. Names that do not match any layer are reported and skipped, so
        every figure title matches the activation it shows.
    """
    # Name -> layer lookup; setdefault keeps the first layer for a repeated name.
    layers_by_name = {}
    for layer in model.layers:
        layers_by_name.setdefault(layer.name, layer)

    # Keep only the names that exist, so names and outputs stay aligned.
    found_names = []
    for layer_name in layer_names:
        if layer_name in layers_by_name:
            found_names.append(layer_name)
        else:
            print(f"Layer {layer_name} not found")

    if not found_names:
        return

    # Sub-model from the original input to every requested intermediate output.
    activation_model = models.Model(
        inputs=model.input,
        outputs=[layers_by_name[name].output for name in found_names],
    )

    # Get activations for a batch containing just this image.
    activations = activation_model.predict(image[np.newaxis, ...])
    # A single requested layer may come back as a bare array rather than a list.
    if not isinstance(activations, (list, tuple)):
        activations = [activations]

    # Plot feature maps, one figure per layer.
    for layer_name, activation in zip(found_names, activations):
        # Only (batch, height, width, channels) activations can be drawn as images.
        if np.ndim(activation) != 4:
            print(f"Layer {layer_name} output is not a stack of 2-D feature maps; skipping")
            continue

        plt.figure(figsize=(15, 10))

        # Show first 16 feature maps
        for j in range(min(16, activation.shape[-1])):
            plt.subplot(4, 4, j + 1)
            plt.imshow(activation[0, :, :, j], cmap='viridis')
            plt.title(f'Feature Map {j+1}')
            plt.axis('off')

        plt.suptitle(f'Feature Maps from {layer_name}', fontsize=16)
        plt.tight_layout()
        plt.show()

# Select a test image and show it with its true class.
test_image = x_test_norm[0]
plt.figure(figsize=(4, 4))
plt.imshow(test_image)
plt.title(f'Test Image: {class_names[y_test[0][0]]}')
plt.axis('off')
plt.show()

# Visualize feature maps from the first two conv layers.
# The names are read from the model rather than hard-coded as 'conv2d' and
# 'conv2d_1': auto-generated Keras names get different numeric suffixes once a
# model-building cell has been re-run in the same kernel.
layer_names = [l.name for l in simple_cnn.layers if isinstance(l, layers.Conv2D)][:2]
visualize_feature_maps(simple_cnn, test_image, layer_names)

In [None]:
# Cell 12: Detailed evaluation and confusion matrix
# Per-class scores for every test image, from both models.
y_pred_simple = simple_cnn.predict(x_test_norm)
y_pred_improved = improved_cnn.predict(x_test_norm)

# Reduce score rows and one-hot rows to integer class ids.
y_pred_simple_labels = y_pred_simple.argmax(axis=1)
y_pred_improved_labels = y_pred_improved.argmax(axis=1)
y_test_labels = y_test_cat.argmax(axis=1)

# Confusion matrices: rows are true classes, columns are predicted classes.
cm_simple = confusion_matrix(y_test_labels, y_pred_simple_labels)
cm_improved = confusion_matrix(y_test_labels, y_pred_improved_labels)

# Draw both matrices side by side with a shared layout.
plt.figure(figsize=(15, 6))
heatmap_panels = (
    (1, cm_simple, 'Blues', 'Simple CNN Confusion Matrix'),
    (2, cm_improved, 'Greens', 'Improved CNN Confusion Matrix'),
)
for position, matrix, palette, panel_title in heatmap_panels:
    plt.subplot(1, 2, position)
    sns.heatmap(
        matrix,
        annot=True,
        fmt='d',
        cmap=palette,
        xticklabels=class_names,
        yticklabels=class_names,
    )
    plt.title(panel_title)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')

plt.tight_layout()
plt.show()

# Classification reports: per-class precision, recall and F1 for each model.
print("Simple CNN Classification Report:")
simple_report = classification_report(y_test_labels, y_pred_simple_labels, target_names=class_names)
print(simple_report)

print("\nImproved CNN Classification Report:")
improved_report = classification_report(y_test_labels, y_pred_improved_labels, target_names=class_names)
print(improved_report)

In [None]:
# Cell 13: Analyze misclassified images
def show_misclassified_images(model, x_test, y_test, class_names, num_images=12):
    """Display test images the model got wrong, with true/predicted labels.

    Args:
        model: Trained Keras model; its ``predict`` output is arg-maxed per row.
        x_test: Images to classify and display.
        y_test: One-hot encoded true labels (arg-maxed along axis 1).
        class_names: Sequence mapping a class id to its display name.
        num_images: Maximum number of misclassified images to show. The grid
            grows beyond the default 3x4 layout when more than 12 are requested.

    Returns:
        None. Prints a message instead of drawing an empty figure when there
        is nothing to show.
    """
    predictions = model.predict(x_test)
    predicted_labels = np.argmax(predictions, axis=1)
    true_labels = np.argmax(y_test, axis=1)

    # Positions where the predicted class differs from the true class.
    misclassified_indices = np.flatnonzero(predicted_labels != true_labels)
    shown_indices = misclassified_indices[:max(num_images, 0)]

    if len(shown_indices) == 0:
        print("No misclassified images to display")
        return

    # Keep the original 3x4 layout for up to 12 images; add rows beyond that.
    cols = 4
    rows = max(3, (len(shown_indices) + cols - 1) // cols)

    plt.figure(figsize=(15, 10 * rows / 3))
    for i, idx in enumerate(shown_indices):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(x_test[idx])
        # Conf is the score the model assigned to the (wrong) predicted class.
        plt.title(f'True: {class_names[true_labels[idx]]}\n'
                  f'Pred: {class_names[predicted_labels[idx]]}\n'
                  f'Conf: {predictions[idx][predicted_labels[idx]]:.2f}')
        plt.axis('off')

    plt.suptitle('Misclassified Images', fontsize=16)
    plt.tight_layout()
    plt.show()

# Inspect the improved model's mistakes on the normalised test set.
print("Misclassified images from Improved CNN:")
show_misclassified_images(
    model=improved_cnn,
    x_test=x_test_norm,
    y_test=y_test_cat,
    class_names=class_names,
)

In [None]:
# Cell 14: Model performance summary
# Dataset facts.
print("=== CNN CIFAR-10 Classification Results ===")
print(f"Dataset: {x_train.shape[0]:,} training, {x_test.shape[0]:,} test images")
# Every axis after the sample axis, joined as e.g. HxWxC.
print("Image size: " + "x".join(str(dim) for dim in x_train.shape[1:]))
print(f"Number of classes: {len(class_names)}")
print()

# Test accuracy of both models and the gap between them.
accuracy_gain = test_accuracy_improved - test_accuracy
print("Model Comparison:")
print(f"Simple CNN:   {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")
print(f"Improved CNN: {test_accuracy_improved:.4f} ({test_accuracy_improved*100:.2f}%)")
print(f"Improvement:  {accuracy_gain:.4f} ({accuracy_gain*100:.2f}%)")
print()

# Recap of the concepts covered, one check-marked line each.
print("Key CNN Concepts Learned:")
for concept in (
    "Convolutional layers for feature extraction",
    "Pooling layers for spatial dimension reduction",
    "Data augmentation for improved generalization",
    "Batch normalization and dropout for regularization",
    "Filter and feature map visualization",
    "Model comparison and evaluation",
):
    print(f"✅ {concept}")
print()

# Suggested follow-ups, one dash-prefixed line each.
print("Next Steps:")
for next_step in (
    "Try transfer learning with pre-trained models (VGG, ResNet)",
    "Experiment with different architectures",
    "Move to larger datasets (ImageNet)",
    "Learn about RNNs for sequential data",
):
    print(f"- {next_step}")