# 🖼️ Image Classification with CNN on CIFAR-10

A comprehensive deep learning project implementing Convolutional Neural Networks for image classification.

## 📋 Contents
1. [Setup and Imports](#1.-Setup-and-Imports)
2. [Data Exploration](#2.-Data-Exploration)
3. [Data Augmentation](#3.-Data-Augmentation)
4. [Model Architecture](#4.-Model-Architecture)
5. [Training](#5.-Training)
6. [Evaluation](#6.-Evaluation)
7. [Model Comparison: CNN vs ResNet18](#7.-Model-Comparison)
8. [Overfitting Analysis](#8.-Overfitting-Analysis)
9. [Conclusion](#9.-Conclusion)

## 1. Setup and Imports

In [None]:
import sys
sys.path.append('..')

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision import datasets, transforms

# Project imports
from config import *
from src.data_loader import get_data_loaders, get_class_names, show_sample_images
from src.augmentation import get_train_transforms, get_test_transforms, denormalize
from src.models.custom_cnn import CustomCNN, CustomCNNNoRegularization
from src.models.resnet import ResNet18
from src.train import train_model
from src.evaluate import evaluate_and_report, get_confusion_matrix
from src.utils import plot_training_history, plot_confusion_matrix, plot_model_comparison

# Set random seed for reproducibility
set_seed(42)

# Check device
device = get_device()
print(f"\nPyTorch version: {torch.__version__}")
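`set_seed` and `get_device` come from `config`, and its source is not shown here. As a rough idea of what such helpers usually do, here is a minimal sketch. The names `set_seed_sketch` and `get_device_sketch` are hypothetical, and the project's versions may differ (for example, they may leave the cuDNN flags alone).

```python
import random

import numpy as np
import torch


def set_seed_sketch(seed: int = 42) -> None:
    """Hypothetical stand-in for config.set_seed: seed every RNG the notebook touches."""
    random.seed(seed)                  # Python's built-in RNG
    np.random.seed(seed)               # NumPy global RNG
    torch.manual_seed(seed)            # PyTorch RNG (CPU and CUDA devices)
    torch.cuda.manual_seed_all(seed)   # explicit for all GPUs; silently ignored without CUDA
    # Trade some cuDNN speed for deterministic convolution algorithms
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def get_device_sketch() -> torch.device:
    """Hypothetical stand-in for config.get_device: prefer CUDA when available."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    return device


set_seed_sketch(42)
a = torch.rand(3)
set_seed_sketch(42)
b = torch.rand(3)
print(torch.equal(a, b))  # True: same seed, same random draws
```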

## 2. Data Exploration

In [None]:
# Load CIFAR-10 dataset
train_loader, val_loader, test_loader = get_data_loaders(batch_size=BATCH_SIZE)

# Get class names
class_names = get_class_names()
print(f"\nClasses: {class_names}")

In [None]:
# Display sample images
images, labels = next(iter(train_loader))

fig, axes = plt.subplots(2, 8, figsize=(16, 4))
for i, ax in enumerate(axes.flatten()):
    img = denormalize(images[i])
    img = img.permute(1, 2, 0).numpy()
    img = np.clip(img, 0, 1)
    ax.imshow(img)
    ax.set_title(class_names[labels[i]], fontsize=10)
    ax.axis('off')

plt.suptitle('Sample CIFAR-10 Images', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('../results/sample_images.png', dpi=150)
plt.show()

In [None]:
# Dataset statistics
print("Dataset Statistics:")
print("-" * 40)
print(f"Training samples: {len(train_loader.sampler)}")
print(f"Validation samples: {len(val_loader.sampler)}")
print(f"Test samples: {len(test_loader.dataset)}")
print(f"Image shape: {images[0].shape}")
print(f"Number of classes: {len(class_names)}")
print(f"Batch size: {BATCH_SIZE}")

## 3. Data Augmentation

In [None]:
# Demonstrate data augmentation
# (get_train_transforms is already imported in the setup cell)
from src.augmentation import get_heavy_augmentation

# Load raw dataset (PIL images, no transforms applied)
raw_dataset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True)
sample_img, label = raw_dataset[0]

# Build the training pipeline once; its random transforms draw new parameters on every call
transform = get_train_transforms()

# Apply different augmentations
fig, axes = plt.subplots(2, 5, figsize=(15, 6))

for i, ax in enumerate(axes.flatten()):
    if i == 0:
        ax.imshow(sample_img)
        ax.set_title('Original', fontsize=10)
    else:
        aug_img = transform(sample_img)
        aug_img = denormalize(aug_img).permute(1, 2, 0).numpy()
        aug_img = np.clip(aug_img, 0, 1)
        ax.imshow(aug_img)
        ax.set_title(f'Augmented {i}', fontsize=10)
    ax.axis('off')

plt.suptitle('Data Augmentation Examples', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('../results/augmentation_examples.png', dpi=150)
plt.show()

## 4. Model Architecture

In [None]:
# Create Custom CNN model
custom_cnn = CustomCNN(num_classes=10, dropout_rate=0.5)

print("Custom CNN Architecture:")
print("=" * 60)
print(custom_cnn)
print("=" * 60)
print(f"\nTotal Parameters: {custom_cnn.get_num_parameters():,}")
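The printed architecture is defined in `src.models.custom_cnn`. For readers without that file, here is a minimal sketch of the conv, batch-norm, ReLU, pooling pattern such a model commonly follows, with dropout in the classifier and a `get_num_parameters` helper that counts trainable parameters. `TinyCNN` is hypothetical, and its layer sizes are illustrative rather than those of `CustomCNN`.

```python
import torch
import torch.nn as nn


class TinyCNN(nn.Module):
    """Hypothetical conv-BN-ReLU-pool network; not the project's CustomCNN."""

    def __init__(self, num_classes: int = 10, dropout_rate: float = 0.5):
        super().__init__()

        def block(c_in: int, c_out: int) -> nn.Sequential:
            # Two 3x3 convs keep the spatial size (padding=1); max-pooling halves it
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, padding=1), nn.BatchNorm2d(c_out), nn.ReLU(inplace=True),
                nn.Conv2d(c_out, c_out, 3, padding=1), nn.BatchNorm2d(c_out), nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            )

        # Spatial size: 32 -> 16 -> 8 -> 4
        self.features = nn.Sequential(block(3, 32), block(32, 64), block(64, 128))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout_rate),
            nn.Linear(128 * 4 * 4, 256), nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(256, num_classes),  # raw logits; CrossEntropyLoss applies log-softmax
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))

    def get_num_parameters(self) -> int:
        # Count trainable parameters only
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


tiny = TinyCNN()
out = tiny(torch.randn(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 10])
print(f"{tiny.get_num_parameters():,} trainable parameters")
```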

In [None]:
# Create ResNet18 model
resnet18 = ResNet18(num_classes=10)

print("\nResNet18 Architecture (adapted for CIFAR-10):")
print("=" * 60)
print(f"Total Parameters: {resnet18.get_num_parameters():,}")
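ResNet18 is described as "adapted for CIFAR-10". A common adaptation keeps the standard residual blocks but replaces the ImageNet stem (a 7×7 stride-2 convolution followed by max-pooling) with a single 3×3 stride-1 convolution, so 32×32 inputs are not downsampled too early. The sketch below shows that stem and the basic residual block; it is an assumption about how `src.models.resnet` is built, not a copy of it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Residual block as used in ResNet18/34: out = ReLU(F(x) + shortcut(x))."""

    def __init__(self, in_planes: int, planes: int, stride: int = 1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # Identity shortcut when shapes match; 1x1 conv projection when they do not
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, 1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + self.shortcut(x))  # the skip connection


# CIFAR-style stem: 3x3 conv, stride 1, no max-pool, so the 32x32 resolution is kept
stem = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU())
x = stem(torch.randn(2, 3, 32, 32))
y_same = BasicBlock(64, 64)(x)
y_down = BasicBlock(64, 128, stride=2)(x)
print(y_same.shape)  # torch.Size([2, 64, 32, 32])
print(y_down.shape)  # torch.Size([2, 128, 16, 16])
```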

In [None]:
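# Test forward pass (eval mode: dropout off, batch norm uses running statistics,
# so a single-image batch is safe)
test_input = torch.randn(1, 3, 32, 32)

custom_cnn.eval()
resnet18.eval()
with torch.no_grad():
    custom_output = custom_cnn(test_input)
    resnet_output = resnet18(test_input)

print(f"Custom CNN output shape: {custom_output.shape}")
print(f"ResNet18 output shape: {resnet_output.shape}")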
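Both models output a `(1, 10)` tensor of raw logits, one score per class. To turn logits into class probabilities and a predicted label, apply a softmax and take the maximum, as in this small sketch. Random logits stand in for the model outputs, and the class list is the standard CIFAR-10 ordering, assumed to match `get_class_names()`.

```python
import torch
import torch.nn.functional as F

# Standard CIFAR-10 class order (assumed to match get_class_names())
class_names_demo = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                    'dog', 'frog', 'horse', 'ship', 'truck']

logits = torch.randn(1, 10)       # stand-in for custom_output / resnet_output
probs = F.softmax(logits, dim=1)  # each row now sums to 1
conf, pred = probs.max(dim=1)     # top probability and its class index
print(f"{class_names_demo[pred.item()]}: {conf.item():.1%}")
```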

## 5. Training

Both models are trained with the same setup:
- Adam optimizer
- Cross-entropy loss
- Learning-rate scheduler
- Early stopping
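`train_model` hides the training loop. Under the setup listed above, one epoch typically looks like the sketch below: Adam updates, cross-entropy loss, accuracy tracking, and a scheduler step after validation. `run_epoch` is a hypothetical helper, and `ReduceLROnPlateau` is an assumed scheduler choice, since the notebook does not say which scheduler `train_model` uses. A tiny linear model and random data keep it fast.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def run_epoch(model, loader, criterion, device, optimizer=None):
    """One pass over `loader`: trains when an optimizer is given, otherwise evaluates.

    Returns (mean loss, accuracy in %). Hypothetical helper, not src.train.train_model.
    """
    training = optimizer is not None
    model.train(training)  # toggles dropout and batch-norm behaviour
    total_loss, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):  # no autograd graph during evaluation
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            loss = criterion(logits, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * labels.size(0)  # undo the per-batch mean
            correct += (logits.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
    return total_loss / seen, 100.0 * correct / seen


# Tiny synthetic setup so the sketch runs quickly on CPU
demo_device = torch.device("cpu")
demo_data = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))
demo_loader = DataLoader(demo_data, batch_size=16, shuffle=True)
demo_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(demo_device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(demo_model.parameters(), lr=1e-3)
# Assumed scheduler: halve the LR when validation loss stops improving (patience=2)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2)

for epoch in range(3):
    train_loss, train_acc = run_epoch(demo_model, demo_loader, criterion, demo_device, optimizer)
    # The demo reuses the same data for "validation"; a real run would use val_loader
    val_loss, val_acc = run_epoch(demo_model, demo_loader, criterion, demo_device)
    scheduler.step(val_loss)  # ReduceLROnPlateau steps on the monitored metric
    print(f"epoch {epoch + 1}: train {train_acc:.1f}% | val {val_acc:.1f}%")
```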

In [None]:
# Train Custom CNN
print("Training Custom CNN...")
print("=" * 60)

custom_cnn = CustomCNN(num_classes=10, dropout_rate=0.5)

history_cnn = train_model(
    model=custom_cnn,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    device=device,
    use_scheduler=True,
    use_early_stopping=True,
    model_name='custom_cnn'
)
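With `use_early_stopping=True` and `model_name='custom_cnn'`, `train_model` appears to save its best weights, since the evaluation cell later loads `../models/custom_cnn_best.pth`. A minimal early-stopping helper that tracks the best validation loss, optionally writes a checkpoint, and signals a stop after `patience` epochs without improvement could look like this. `EarlyStopping` and its parameters are hypothetical.

```python
import copy
from typing import Optional

import torch
import torch.nn as nn


class EarlyStopping:
    """Stop when validation loss has not improved for `patience` epochs (hypothetical helper)."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0, path: Optional[str] = None):
        self.patience, self.min_delta, self.path = patience, min_delta, path
        self.best_loss = float("inf")
        self.best_state = None
        self.counter = 0

    def step(self, val_loss: float, model: nn.Module) -> bool:
        """Record one epoch; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss, self.counter = val_loss, 0
            self.best_state = copy.deepcopy(model.state_dict())  # keep the best weights in memory
            if self.path is not None:
                torch.save(model.state_dict(), self.path)  # e.g. a "<model_name>_best.pth" file
        else:
            self.counter += 1
        return self.counter >= self.patience


stopper = EarlyStopping(patience=2)
demo_net = nn.Linear(4, 2)
for epoch, loss in enumerate([1.0, 0.8, 0.85, 0.9, 0.7]):
    if stopper.step(loss, demo_net):
        print(f"stopped after epoch {epoch + 1}, best val loss {stopper.best_loss}")
        break
```

With `patience=2`, the losses 1.0, 0.8, 0.85, 0.9 trigger a stop after epoch 4, keeping the best loss of 0.8.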

In [None]:
# Plot training history
plot_training_history(history_cnn, save_path='../results/custom_cnn_training.png')

In [None]:
# Train ResNet18
print("Training ResNet18...")
print("=" * 60)

resnet18 = ResNet18(num_classes=10)

history_resnet = train_model(
    model=resnet18,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    device=device,
    use_scheduler=True,
    use_early_stopping=True,
    model_name='resnet18'
)

In [None]:
# Plot ResNet18 training history
plot_training_history(history_resnet, save_path='../results/resnet18_training.png')

## 6. Evaluation

In [None]:
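# Load the best checkpoints saved during training
# (map_location lets GPU-saved weights load on whichever device is active)
custom_cnn.load_state_dict(torch.load('../models/custom_cnn_best.pth', map_location=device))
resnet18.load_state_dict(torch.load('../models/resnet18_best.pth', map_location=device))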

# Evaluate Custom CNN
print("Evaluating Custom CNN:")
print("=" * 60)
results_cnn = evaluate_and_report(custom_cnn, test_loader, device=device)

In [None]:
# Plot confusion matrix for Custom CNN
plot_confusion_matrix(results_cnn['confusion_matrix'], 
                      save_path='../results/custom_cnn_confusion_matrix.png')
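`evaluate_and_report` returns a `confusion_matrix` and a `per_class_accuracy` entry. Assuming per-class accuracy means the share of each class's test images that are classified correctly (per-class recall), both can be derived from predictions as sketched below. `confusion_matrix_np` is a hypothetical helper, and the toy labels are illustrative.

```python
import numpy as np


def confusion_matrix_np(y_true, y_pred, num_classes: int) -> np.ndarray:
    """cm[i, j] = number of samples with true class i predicted as class j."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (y_true, y_pred), 1)  # unbuffered add, so repeated (i, j) pairs all count
    return cm


y_true = np.array([0, 0, 1, 1, 1, 2])
y_pred = np.array([0, 1, 1, 1, 2, 2])
cm = confusion_matrix_np(y_true, y_pred, num_classes=3)

# Diagonal = correct predictions; row sums = number of samples per true class
per_class_acc = 100.0 * np.diag(cm) / cm.sum(axis=1)
overall_acc = 100.0 * np.trace(cm) / cm.sum()
print(cm)
print(np.round(per_class_acc, 1))  # values 50.0, 66.7, 100.0
print(round(overall_acc, 1))       # 66.7
```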

In [None]:
# Evaluate ResNet18
print("\nEvaluating ResNet18:")
print("=" * 60)
results_resnet = evaluate_and_report(resnet18, test_loader, device=device)

In [None]:
# Plot confusion matrix for ResNet18
plot_confusion_matrix(results_resnet['confusion_matrix'], 
                      save_path='../results/resnet18_confusion_matrix.png')

## 7. Model Comparison

In [None]:
# Compare models
comparison_results = {
    'Custom CNN': results_cnn['accuracy'],
    'ResNet18': results_resnet['accuracy']
}

plot_model_comparison(comparison_results, 
                      save_path='../results/model_comparison.png')

In [None]:
# Per-class accuracy comparison
from src.utils import plot_class_accuracy

fig, axes = plt.subplots(1, 2, figsize=(16, 5))
x = np.arange(len(class_names))

# One bar chart per model, sharing class labels and y-axis range
for ax, (name, results, color) in zip(axes, [
    ('Custom CNN', results_cnn, 'steelblue'),
    ('ResNet18', results_resnet, 'coral'),
]):
    ax.bar(x, results['per_class_accuracy'], color=color)
    ax.set_xticks(x)
    ax.set_xticklabels(class_names, rotation=45, ha='right')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title(f'{name} - Per-Class Accuracy')
    ax.set_ylim(0, 100)

plt.tight_layout()
plt.savefig('../results/per_class_comparison.png', dpi=150)
plt.show()

In [None]:
# Summary table
print("\n" + "=" * 60)
print("MODEL COMPARISON SUMMARY")
print("=" * 60)
print(f"{'Model':<20} {'Parameters':<15} {'Test Accuracy':<15}")
print("-" * 60)
# Pad the parameter count to the column width so the accuracy column lines up
print(f"{'Custom CNN':<20} {custom_cnn.get_num_parameters():<15,} {results_cnn['accuracy']:.2f}%")
print(f"{'ResNet18':<20} {resnet18.get_num_parameters():<15,} {results_resnet['accuracy']:.2f}%")
print("=" * 60)

## 8. Overfitting Analysis

In [None]:
from src.overfitting_analysis import (
    train_for_analysis, 
    plot_overfitting_comparison,
    plot_generalization_gap,
    demonstrate_regularization_techniques
)

# Demonstrate regularization techniques
demonstrate_regularization_techniques()
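The output of `demonstrate_regularization_techniques()` is not reproduced here. As a self-contained illustration of one of those techniques, this sketch shows how `nn.Dropout` behaves: in training mode it zeroes each element with probability `p` and scales the survivors by `1 / (1 - p)` (inverted dropout), while in evaluation mode it passes the input through unchanged. This is one reason models must be switched with `.train()` and `.eval()`.

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
ones = torch.ones(1, 8)

drop.train()
y_train = drop(ones)  # roughly half the entries become 0; survivors become 1 / (1 - 0.5) = 2
drop.eval()
y_eval = drop(ones)   # identity: no zeroing, no scaling

print(y_train)
print(y_eval)
```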

In [None]:
# Train model WITHOUT regularization (for comparison)
print("\nTraining model WITHOUT regularization...")
train_loader_no_aug, val_loader_no_aug, _ = get_data_loaders(
    batch_size=BATCH_SIZE, use_augmentation=False  # same batch size as the regularized run
)

model_no_reg = CustomCNNNoRegularization()
history_no_reg = train_for_analysis(
    model_no_reg, train_loader_no_aug, val_loader_no_aug, epochs=25
)

In [None]:
# Train model WITH regularization
print("\nTraining model WITH regularization...")
model_with_reg = CustomCNN(dropout_rate=0.5)
history_with_reg = train_for_analysis(
    model_with_reg, train_loader, val_loader, epochs=25
)

In [None]:
# Plot overfitting comparison
plot_overfitting_comparison(
    history_no_reg, history_with_reg,
    save_path='../results/overfitting_comparison.png'
)

In [None]:
# Plot generalization gap
plot_generalization_gap(
    history_no_reg, history_with_reg,
    save_path='../results/generalization_gap.png'
)
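The summary cell below measures the generalization gap as training accuracy minus validation accuracy; since both are percentages, the gap is in percentage points. Assuming the history dictionaries store per-epoch accuracies under `train_acc` and `val_acc`, as that cell reads them, the per-epoch gap is a one-liner. `generalization_gap` is a hypothetical helper, and the toy history values are illustrative, not results.

```python
def generalization_gap(history: dict) -> list:
    """Per-epoch train-minus-validation accuracy, in percentage points."""
    return [train - val for train, val in zip(history["train_acc"], history["val_acc"])]


# Illustrative values only (not results from this notebook)
toy_history = {"train_acc": [60.0, 75.0, 90.0], "val_acc": [58.0, 68.0, 70.0]}
print(generalization_gap(toy_history))  # [2.0, 7.0, 20.0]
```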

In [None]:
# Overfitting analysis summary
final_gap_no_reg = history_no_reg['train_acc'][-1] - history_no_reg['val_acc'][-1]
final_gap_with_reg = history_with_reg['train_acc'][-1] - history_with_reg['val_acc'][-1]

print("\n" + "=" * 60)
print("OVERFITTING ANALYSIS SUMMARY")
print("=" * 60)
print("\nWithout Regularization:")
print(f"  Final Train Acc: {history_no_reg['train_acc'][-1]:.2f}%")
print(f"  Final Val Acc: {history_no_reg['val_acc'][-1]:.2f}%")
print(f"  Generalization Gap: {final_gap_no_reg:.2f} percentage points")
print("\nWith Regularization (Dropout + BatchNorm + Augmentation):")
print(f"  Final Train Acc: {history_with_reg['train_acc'][-1]:.2f}%")
print(f"  Final Val Acc: {history_with_reg['val_acc'][-1]:.2f}%")
print(f"  Generalization Gap: {final_gap_with_reg:.2f} percentage points")
print(f"\nGap reduction: {final_gap_no_reg - final_gap_with_reg:.2f} percentage points")
print("=" * 60)

## 9. Conclusion

### Key Findings:

1. **Custom CNN Performance**: The custom CNN reached solid test accuracy on CIFAR-10 (see the summary table in Section 7), showing that a straightforward CNN architecture is already effective for this task.

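2. **ResNet18 Advantage**: ResNet18 outperformed the custom CNN, which we attribute to:
   - Residual (skip) connections, which make a deeper network practical to train
   - Greater depth and capacity for learning hierarchical features
   - Better gradient flow: identity shortcuts give gradients a direct path back to early layers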

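3. **Regularization Impact**: Adding regularization substantially narrowed the gap between training and validation accuracy. Because the unregularized baseline also trains without augmentation, the reduction reflects the combined effect of:
   - Dropout, which reduces co-adaptation of neurons
   - Batch normalization, which stabilizes training
   - Data augmentation, which increases the effective size of the training set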

4. **Data Augmentation**: Crucial for generalization on relatively small datasets such as CIFAR-10, which has only 50,000 training images (5,000 per class).

### Future Improvements:
- Try other architectures (e.g., VGG, DenseNet)
- Try transfer learning with pretrained weights
- Experiment with advanced augmentation (Cutout, MixUp)
- Hyperparameter tuning with grid/random search
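Of the augmentation ideas listed, MixUp changes the training loop rather than the transform pipeline: it blends pairs of images and mixes their losses with the same weight. A minimal sketch, assuming integer class labels and `nn.CrossEntropyLoss`; the helper names `mixup_batch` and `mixup_loss` are hypothetical, and `alpha=0.2` is an illustrative choice.

```python
import numpy as np
import torch
import torch.nn as nn


def mixup_batch(images: torch.Tensor, labels: torch.Tensor, alpha: float = 0.2):
    """Blend each sample with a randomly chosen partner from the same batch."""
    lam = float(np.random.beta(alpha, alpha))  # mixing weight in [0, 1]
    perm = torch.randperm(images.size(0))      # random partner index for every sample
    mixed = lam * images + (1 - lam) * images[perm]
    return mixed, labels, labels[perm], lam


def mixup_loss(criterion, logits, y_a, y_b, lam: float):
    # Mix the two targets' losses with the same weight used for the inputs
    return lam * criterion(logits, y_a) + (1 - lam) * criterion(logits, y_b)


demo_images = torch.randn(8, 3, 32, 32)
demo_labels = torch.randint(0, 10, (8,))
mixed, y_a, y_b, lam = mixup_batch(demo_images, demo_labels)
net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
loss = mixup_loss(nn.CrossEntropyLoss(), net(mixed), y_a, y_b, lam)
loss.backward()  # gradients flow as for any ordinary loss
```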

In [None]:
print("\n" + "=" * 60)
print("Project Completed Successfully! 🎉")
print("=" * 60)
print("\nCheck the 'results/' folder for all generated plots.")
print("Check the 'models/' folder for saved model checkpoints.")