[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/rsanchezgarc/AI-ML-analytics-IE/blob/main/notebooks/4_DL_for_computer_vision/image_classification_CNN.ipynb)

# CIFAR-10 Image Classification with PyTorch

This notebook trains a simple CNN on CIFAR-10 using PyTorch.


In [None]:
# Imports
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Report the library versions in use and whether a CUDA device is visible,
# one "label value" line per entry.
for _label, _value in (
    ("PyTorch version:", torch.__version__),
    ("Torchvision version:", torchvision.__version__),
    ("CUDA available:", torch.cuda.is_available()),
):
    print(_label, _value)

In [None]:
# CIFAR-10 class names; position i is the name shown for integer label i.
# Followed by a helper to show images.
CLASSES = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]

def show_images(images, labels, preds=None, max_images=5):
    """Display a few images with true and (optionally) predicted labels.

    Parameters
    ----------
    images : batch of channel-first images; ``images[i]`` is transposed
        from (C, H, W) to (H, W, C) before being drawn.
    labels : per-image class indices, used to look up names in ``CLASSES``.
    preds : optional per-image predicted class indices; when given, the
        predicted name is added as a second title line.
    max_images : upper bound on how many images are drawn.

    Returns
    -------
    None. When there is nothing to draw (empty batch or ``max_images <= 0``)
    the function returns without creating a figure.
    """
    n_show = min(max_images, images.shape[0])
    if n_show <= 0:
        # Nothing to draw: do not pop up an empty figure.
        return

    plt.figure(figsize=(15, 4))
    for i in range(n_show):
        # Size the grid by the number of images actually drawn, so a batch
        # smaller than max_images does not leave blank columns.
        plt.subplot(1, n_show, i + 1)
        # detach() lets tensors that still track gradients be converted.
        img = images[i].detach().cpu().numpy().transpose(1, 2, 0)
        plt.imshow(img)
        title = f"True: {CLASSES[int(labels[i])]}"
        if preds is not None:
            title += f"\nPred: {CLASSES[int(preds[i])]}"
        plt.title(title)
        plt.axis("off")
    plt.tight_layout()
    plt.show()

In [None]:
# Data preparation

def get_data_loaders(batch_size=64, data_dir="/tmp/cifar10", num_workers=4):
    """Create train and validation loaders for CIFAR-10.

    Parameters
    ----------
    batch_size : number of samples per batch for both loaders.
    data_dir : directory where the dataset is stored; it is downloaded
        there if not already present (``download=True``).
    num_workers : number of DataLoader worker processes for each loader.
        Pass 0 to load in the main process.

    Returns
    -------
    (train_loader, val_loader) : the training loader is shuffled and uses
        random crop / horizontal flip / rotation augmentation; the
        validation loader wraps the ``train=False`` split, unshuffled and
        without augmentation.
    """
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),  # scales from [0,255] to [0.,1.]
    ])

    # No augmentation at evaluation time: only the tensor conversion.
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_ds = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=train_transform
    )
    val_ds = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=eval_transform
    )

    # Pinned host memory only speeds up host-to-GPU copies, so request it
    # only when a CUDA device is actually available.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
    return train_loader, val_loader

In [None]:
# CNN model

class CIFAR10CNN(nn.Module):
    """Small convolutional classifier for 3-channel 32x32 images.

    The feature extractor is three Conv(3x3, padding=1) -> ReLU ->
    MaxPool(2, 2) stages with 32, 64 and 128 output channels. Each pooling
    step halves the spatial size, so a 32x32 input reaches the classifier
    as a 128 x 4 x 4 feature map; the first linear layer is sized for
    exactly that shape.

    Parameters
    ----------
    dropout : dropout probability applied before the final linear layer.
    num_classes : number of output logits (defaults to the 10 CIFAR-10
        classes).
    """

    def __init__(self, dropout=0.5, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            # 128 channels x 4 x 4 spatial positions after three 2x poolings.
            nn.Linear(128 * 4 * 4, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return unnormalized class scores (logits) for a batch of images."""
        x = self.features(x)
        x = self.classifier(x)
        return x

In [None]:
# Training and validation loops

def train_one_epoch(model, loader, optimizer, device):
    """Run one optimization pass over ``loader``.

    The model is put in training mode; every batch is moved to ``device``,
    scored with cross-entropy, back-propagated and followed by one
    optimizer step.

    Returns
    -------
    (mean_loss, accuracy) : both averaged over all samples seen.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for batch_x, batch_y in loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        n_batch = batch_x.size(0)

        optimizer.zero_grad()
        outputs = model(batch_x)
        batch_loss = F.cross_entropy(outputs, batch_y)
        batch_loss.backward()
        optimizer.step()

        # cross_entropy returns the batch mean; re-weight by the batch size
        # so the epoch value is a per-sample average.
        loss_sum += batch_loss.item() * n_batch
        n_correct += (outputs.argmax(dim=1) == batch_y).sum().item()
        n_seen += n_batch

    return loss_sum / n_seen, n_correct / n_seen


@torch.no_grad()
def evaluate(model, loader, device):
    """Score ``model`` on ``loader`` without updating it.

    The model is switched to evaluation mode and gradients are disabled
    for the whole call.

    Returns
    -------
    (mean_loss, accuracy) : cross-entropy loss and fraction of correct
        argmax predictions, both averaged over all samples seen.
    """
    model.eval()
    loss_sum = 0.0
    hits = 0
    seen = 0

    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        count = inputs.size(0)

        scores = model(inputs)
        # Batch-mean loss, re-weighted to a per-sample sum.
        loss_sum += F.cross_entropy(scores, targets).item() * count
        predicted = scores.argmax(dim=1)
        hits += (predicted == targets).sum().item()
        seen += count

    mean_loss = loss_sum / seen
    accuracy = hits / seen
    return mean_loss, accuracy

In [None]:
def train_model(
    epochs=10,
    batch_size=128,
    lr=1e-3,
    dropout=0.5,
):
    """Train the CIFAR-10 CNN and return the model with its metric history.

    Uses CUDA when available, otherwise the CPU. After every epoch the
    model is evaluated on the validation loader; the weights of the epoch
    with the highest validation accuracy are kept and loaded back into the
    model before returning.

    Returns
    -------
    model : nn.Module
    history : dict with keys 'train_loss', 'val_loss', 'train_acc', 'val_acc'
        (one list entry per epoch)
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    train_loader, val_loader = get_data_loaders(batch_size=batch_size)

    model = CIFAR10CNN(dropout=dropout).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    history = {
        name: [] for name in ("train_loss", "val_loss", "train_acc", "val_acc")
    }

    # Best checkpoint so far, tracked by validation accuracy.
    best_val_acc = 0.0
    best_state_dict = None

    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, device)
        val_loss, val_acc = evaluate(model, val_loader, device)

        epoch_metrics = {
            "train_loss": train_loss,
            "val_loss": val_loss,
            "train_acc": train_acc,
            "val_acc": val_acc,
        }
        for name, value in epoch_metrics.items():
            history[name].append(value)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Copy to CPU so later optimizer steps cannot alter the snapshot.
            best_state_dict = {
                name: tensor.cpu().clone()
                for name, tensor in model.state_dict().items()
            }

        print(
            f"Epoch {epoch:02d}/{epochs:02d} | "
            f"train_loss={train_loss:.4f} val_loss={val_loss:.4f} "
            f"train_acc={train_acc:.4f} val_acc={val_acc:.4f}"
        )

    # Restore the best checkpoint, if any epoch improved on the initial 0.0.
    if best_state_dict is not None:
        model.load_state_dict(best_state_dict)
        model.to(device)

    print(f"Best validation accuracy: {best_val_acc:.4f}")
    return model, history

In [None]:
def plot_history(history):
    """Draw loss and accuracy curves (train vs. validation) side by side.

    ``history`` is a dict with per-epoch lists under the keys
    'train_loss', 'val_loss', 'train_acc' and 'val_acc'.
    """
    epochs = range(1, len(history["train_loss"]) + 1)

    # (panel title / y-label, history keys plotted in that panel)
    panels = (
        ("Loss", ("train_loss", "val_loss")),
        ("Accuracy", ("train_acc", "val_acc")),
    )

    plt.figure(figsize=(10, 4))
    for position, (name, keys) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for key in keys:
            plt.plot(epochs, history[key], label=key)
        plt.xlabel("Epoch")
        plt.ylabel(name)
        plt.title(name)
        plt.legend()

    plt.tight_layout()
    plt.show()

In [None]:
@torch.no_grad()
def inference_demo(model, batch_size=8):
    """Show a few CIFAR-10 test images with predictions from the model.

    Takes the first batch of the validation loader, runs the model on it
    on whatever device the model's parameters live on, and displays the
    images with their true and predicted class names.

    Parameters
    ----------
    model : nn.Module whose output has one score per class along dim 1.
    batch_size : number of images fetched from the validation loader.

    The model is run in evaluation mode (so dropout is disabled) and its
    previous training/eval mode is restored afterwards.
    """
    device = next(model.parameters()).device
    _, val_loader = get_data_loaders(batch_size=batch_size)

    images, labels = next(iter(val_loader))
    images = images.to(device)

    # Predictions must not depend on whether the caller left the model in
    # training mode; remember the mode and put it back when done.
    was_training = model.training
    model.eval()
    try:
        logits = model(images)
        preds = logits.argmax(dim=1)
    finally:
        model.train(was_training)

    show_images(images.cpu(), labels.cpu(), preds.cpu())

In [None]:
# Run experiment: train, plot the learning curves, then show sample predictions.

# Hyperparameters
EPOCHS, BATCH_SIZE = 10, 128  # adjust epochs to your patience / GPU; 128 is a typical good batch size for CIFAR-10
LR, DROPOUT = 1e-3, 0.5

model, history = train_model(epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LR, dropout=DROPOUT)

plot_history(history)
inference_demo(model, batch_size=8)