# 06. Mini CNN

This notebook implements a minimal Convolutional Neural Network for image classification.

## Experiment Overview
- **Goal**: Image classification using convolutional layers
- **Model**: Simple CNN (3 × [Conv2D → ReLU → MaxPool] → FC → ReLU → Dropout → FC)
- **Features**: CIFAR-10 classification, feature map visualization
- **Learning**: Understanding convolutional neural networks

## What You'll Learn
- Convolutional layers and operations
- Pooling and feature extraction
- CNN architecture design
- Feature map visualization


In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import sys
import os

# Add scripts directory to path
sys.path.append('../scripts')
from utils import load_cifar10_data, plot_training_history, plot_confusion_matrix, get_device, set_seed
from train import train_model
from evaluate import evaluate_model

# Seed the RNGs first so later steps are repeatable
# (set_seed is a project helper — presumably covers torch/numpy/random; confirm).
set_seed(42)

# Select the compute device through the project helper and report it.
device = get_device()
print(f"Using device: {device}")

# Build the train / validation / test loaders for CIFAR-10.
print("Loading CIFAR-10 dataset...")
train_loader, val_loader, test_loader = load_cifar10_data(batch_size=64, test_split=0.2)

# Report how many samples each split holds.
for split_name, split_loader in (("Training", train_loader),
                                 ("Validation", val_loader),
                                 ("Test", test_loader)):
    print(f"{split_name} samples: {len(split_loader.dataset)}")

# Label index -> human-readable class name.
# NOTE(review): assumes load_cifar10_data uses the standard CIFAR-10 label order.
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

# Visualize some training samples.
# Fetch a single batch once: calling next(iter(loader)) per image would rebuild
# the DataLoader iterator every time (expensive with worker processes) and,
# without shuffling, would show the same image ten times.
data, target = next(iter(train_loader))

fig, axes = plt.subplots(2, 5, figsize=(12, 6))
for i, ax in enumerate(axes.flat):  # row-major: first row is images 0-4
    ax.axis('off')
    if i >= len(data):
        continue  # batch smaller than the 2x5 grid
    # Undo normalization for display.
    # NOTE(review): assumes load_cifar10_data normalizes with mean=0.5, std=0.5 — confirm.
    img = (data[i].permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1)  # (C,H,W) -> (H,W,C)
    ax.imshow(img.numpy())
    ax.set_title(class_names[target[i].item()])
plt.tight_layout()
plt.show()


In [None]:
# Define the Mini CNN model
class MiniCNN(nn.Module):
    """Three conv blocks followed by a two-layer classifier head.

    Each block is a 3x3 padded convolution, ReLU, then 2x2 max-pooling, so a
    32x32 input shrinks 32 -> 16 -> 8 -> 4. That is why ``fc1`` expects
    ``128 * 4 * 4`` features; other spatial sizes will not fit the head.

    Args:
        num_classes: number of output logits produced by ``fc2``.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Feature extractor: channels grow 3 -> 32 -> 64 -> 128; padding=1
        # keeps H and W unchanged through each convolution.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        # One pooling module shared by all three blocks; halves H and W.
        self.pool = nn.MaxPool2d(2, 2)

        # Classifier head over the flattened 128x4x4 feature volume.
        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.fc2 = nn.Linear(256, num_classes)

        # Regularizes the hidden layer during training only.
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))

        features = x.flatten(1)  # keep the batch dim, merge C*H*W
        hidden = self.dropout(F.relu(self.fc1(features)))
        return self.fc2(hidden)

# Create the model and move it to the selected device.
model = MiniCNN().to(device)

# Print model summary.
print("Model Architecture:")
print(model)

# Count parameters once. Size them by their actual dtype (element_size)
# rather than assuming 4-byte float32, so the figure stays right if the
# model is cast (e.g. to half precision).
num_params = sum(p.numel() for p in model.parameters())
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"\nTotal parameters: {num_params:,}")
print(f"Model size: {param_bytes / 1024 / 1024:.2f} MB")

# Sanity-check the forward pass on one random 3x32x32 input.
with torch.no_grad():
    dummy_input = torch.randn(1, 3, 32, 32, device=device)
    output = model(dummy_input)
    print(f"Output shape: {output.shape}")


In [None]:
# Train the model with the project's train_model helper (imported from ../scripts).
print("Starting CNN training...")

# Hyper-parameters and log location kept together so they are easy to tweak.
training_config = dict(
    task_type='classification',
    epochs=20,
    lr=0.001,
    save_dir='../results/logs/mini_cnn',
)
trainer = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    **training_config,
)

# Plot the trainer's recorded history and save the figure to disk.
trainer.plot_training_history(save_path='../results/plots/mini_cnn_training.png')


In [None]:
# Evaluate the model on the test set with the project's evaluate_model helper.
print("Evaluating on test set...")
results = evaluate_model(
    model=model,
    data_loader=test_loader,
    task_type='classification',
    device=device,
    save_dir='../results/plots/mini_cnn'
)

# Show predictions for the first images of one test batch.
model.eval()  # disable dropout for inference
with torch.no_grad():
    data, target = next(iter(test_loader))
    data, target = data.to(device), target.to(device)
    output = model(data)
    pred = output.argmax(dim=1)

fig, axes = plt.subplots(2, 5, figsize=(12, 6))
for i, ax in enumerate(axes.flat):  # row-major: first row is images 0-4
    ax.axis('off')
    if i >= len(data):
        continue  # batch smaller than the 2x5 grid
    # Undo normalization for display.
    # NOTE(review): assumes load_cifar10_data normalizes with mean=0.5, std=0.5 — confirm.
    img = (data[i].cpu().permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1)
    ax.imshow(img.numpy())
    # A real newline (single backslash) puts the predicted label on its own line.
    ax.set_title(f'True: {class_names[target[i].item()]}\nPred: {class_names[pred[i].item()]}')
plt.tight_layout()
plt.show()

print(f"\nFinal Test Accuracy: {results['accuracy']:.4f}")
print(f"Final Test F1-Score: {results['f1_score']:.4f}")

# Feature map visualization
def visualize_feature_maps(model, image, layer_name, device):
    """Plot the first 16 output channels of a named layer for one image.

    A forward hook is attached to the submodule called ``layer_name``. One
    forward pass is run, and that layer's captured output is drawn as a
    4x4 grid of heat maps. For the conv layers of ``MiniCNN``, this is the
    convolution output *before* ReLU and pooling, because those are applied
    in ``forward``.

    Args:
        model: network containing a submodule named ``layer_name``. It is
            switched to eval mode here and left in eval mode.
        image: one unbatched image tensor; a batch dimension is added here
            (presumably shaped (C, H, W) — confirm against the loader).
        layer_name: submodule name as reported by ``model.named_modules()``,
            e.g. 'conv1'.
        device: device the image is moved to for the forward pass.

    Raises:
        ValueError: if ``model`` has no submodule named ``layer_name``, or if
            the layer's per-sample output is not (channels, height, width).
    """
    model.eval()

    # Resolve the layer by name. Any submodule works, and an unknown name
    # fails with a clear error instead of leaving the hook unbound.
    layer = dict(model.named_modules()).get(layer_name) if layer_name else None
    if layer is None:
        raise ValueError(f"model has no submodule named {layer_name!r}")

    # Hook to capture the layer's output during the forward pass.
    feature_maps = {}

    def hook_fn(module, inputs, output):
        feature_maps[layer_name] = output.detach()

    hook = layer.register_forward_hook(hook_fn)
    try:
        with torch.no_grad():
            model(image.unsqueeze(0).to(device))
    finally:
        # Remove the hook even if the forward pass raises, so it does not
        # keep firing on later calls to the model.
        hook.remove()

    fmaps = feature_maps[layer_name][0]  # first (only) sample in the batch
    if fmaps.dim() != 3:
        raise ValueError(
            f"layer {layer_name!r} produced per-sample shape {tuple(fmaps.shape)}; "
            "expected (channels, height, width)"
        )
    fmaps = fmaps.cpu()

    # Visualize up to the first 16 feature maps; unused cells stay blank.
    fig, axes = plt.subplots(4, 4, figsize=(8, 8))
    for i, ax in enumerate(axes.flat):
        if i < fmaps.shape[0]:
            ax.imshow(fmaps[i].numpy(), cmap='viridis')
            ax.set_title(f'Channel {i}')
        ax.axis('off')

    plt.suptitle(f'Feature Maps from {layer_name}')
    plt.tight_layout()
    plt.show()

# Probe each conv layer using the first image of one test batch.
test_images, _ = next(iter(test_loader))
sample_image = test_images[0]

print("Visualizing feature maps...")
for layer in ('conv1', 'conv2', 'conv3'):
    visualize_feature_maps(model, sample_image, layer, device)
