
# 01 CNN Fundamentals
## Objectives
- Understand convolution, padding, stride, and pooling.
- Train a small CNN and interpret features.
- Visualize filters, feature maps, and saliency.


## Conceptual Primer: CNNs and Transformers

### What is a CNN?
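A Convolutional Neural Network (CNN) is a neural network designed for grid-structured data such as images. It uses **convolutions**: small learnable filters that slide over the input and respond to local patterns. This builds in two key inductive biases: **locality** (nearby pixels are the most strongly related) and **translation equivariance** (shifting the input shifts the feature map the same way, so a filter detects its pattern wherever it appears).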

**How it works**
- **Convolution layers** learn filters that detect edges, textures, and shapes.
- **Pooling** reduces spatial size and increases robustness to small shifts.
- **Deeper layers** combine simpler patterns into higher-level features.
- A **classifier head** maps features to labels.
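
How padding, stride, and pooling change the spatial size can be checked with the standard output-size formula. The sketch below uses plain Python; the helper name `output_size` and the example layer settings are illustrative assumptions, not necessarily the layout of `SimpleCNN`.

```python
def output_size(n, kernel, stride=1, padding=0):
    """Spatial size after a convolution or pooling layer (floor division)."""
    return (n + 2 * padding - kernel) // stride + 1

# Hypothetical settings for a 32x32 CIFAR-10 input.
after_conv = output_size(32, kernel=3, stride=1, padding=1)  # 3x3 conv, padding 1
after_pool = output_size(after_conv, kernel=2, stride=2)     # 2x2 pooling, stride 2
unpadded = output_size(32, kernel=5)                         # 5x5 conv, no padding
print(after_conv, after_pool, unpadded)  # 32 16 28
```

Padding of 1 with a 3x3 kernel keeps the size unchanged, while stride-2 pooling halves it.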

Mathematically, convolution computes weighted sums over local patches; stacking layers grows the receptive field. CNNs are data-efficient because their priors match natural images.
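
To make "weighted sums over local patches" concrete, here is a minimal plain-Python sketch of the operation deep learning libraries call convolution (strictly speaking, cross-correlation). The function name and the toy inputs are assumptions chosen for illustration; real code would use `torch.nn.Conv2d`.

```python
def cross_correlate2d(image, kernel):
    """Slide `kernel` over `image` (no padding, stride 1) and sum the products."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    return [
        [
            sum(image[i + a][j + b] * kernel[a][b]
                for a in range(kh) for b in range(kw))
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]

edge_filter = [[-1, 1]]        # responds where intensity rises from left to right
image = [[0, 1, 1, 1]] * 2     # dark-to-bright edge after column 0
shifted = [[0, 0, 1, 1]] * 2   # the same edge moved one pixel to the right

response = cross_correlate2d(image, edge_filter)
response_shifted = cross_correlate2d(shifted, edge_filter)
print(response)          # [[1, 0, 0], [1, 0, 0]]
print(response_shifted)  # [[0, 1, 0], [0, 1, 0]]
```

The response moves by one pixel when the edge does, which is translation equivariance in miniature.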

### What is a Transformer (ViT) for vision?
A Vision Transformer treats an image as a **sequence of patch tokens**. Each patch is embedded into a vector, positional information is added, and **self-attention** lets every patch interact with every other patch.

**Architecture overview**
- **Patchify** the image (e.g., 16x16 patches).
- **Linear embedding** of each patch.
- **Positional embeddings** preserve spatial order.
- **Transformer encoder blocks**: each block applies LayerNorm → Multi-Head Self-Attention, then LayerNorm → MLP, with a residual connection around each of the two sub-layers.
- **Class token** (or token pooling) feeds a classifier head.
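
The patchify step can be sketched in plain Python to keep the indexing visible. The `patchify` helper, the toy 4x4 grid, and the 224x224 RGB input size are assumptions for illustration; in practice this is done with tensor reshapes or a strided convolution.

```python
def patchify(image, patch):
    """Split an H x W grid (nested lists) into flattened patch x patch tokens."""
    height, width = len(image), len(image[0])
    assert height % patch == 0 and width % patch == 0, "size must divide evenly"
    tokens = []
    for top in range(0, height, patch):
        for left in range(0, width, patch):
            tokens.append([image[top + i][left + j]
                           for i in range(patch) for j in range(patch)])
    return tokens

toy = [[row * 4 + col for col in range(4)] for row in range(4)]
tokens = patchify(toy, patch=2)
print(len(tokens), tokens[0], tokens[-1])  # 4 [0, 1, 4, 5] [10, 11, 14, 15]

# Same bookkeeping for an assumed 224x224 RGB image with 16x16 patches:
num_tokens = (224 // 16) ** 2  # 196 patch tokens (plus one class token)
token_dim = 16 * 16 * 3        # 768 raw values per patch before the linear embedding
```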

Self-attention gives global context from the start. ViTs typically need large-scale data or pretraining, but transfer learning makes them practical for smaller datasets.

### Key differences (intuition)
- **Inductive bias**: CNNs assume locality; Transformers are more flexible but data-hungry.
- **Context**: CNNs build global context gradually; Transformers use global attention immediately.
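- **Compute**: convolution cost scales linearly with the number of pixels; self-attention cost scales quadratically with the number of patch tokens, so it grows quickly with image size.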
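
To put rough numbers on how attention cost grows with image size, the sketch below counts the entries in a single head's attention-score matrix. The helper name and the image sizes are illustrative assumptions, and the class token is ignored for simplicity.

```python
def attention_entries(image_size, patch=16):
    """Return (tokens, tokens squared) for a square image split into square patches."""
    tokens = (image_size // patch) ** 2
    return tokens, tokens * tokens

for side in (224, 448):
    tokens, entries = attention_entries(side)
    print(side, tokens, entries)
# 224 196 38416
# 448 784 614656
```

Doubling the image side quadruples the number of tokens and multiplies the attention matrix size by 16.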

References: ViT (https://arxiv.org/abs/2010.11929), Attention Is All You Need (https://arxiv.org/abs/1706.03762), ResNet (https://arxiv.org/abs/1512.03385).


In [None]:
!pip -q install torch torchvision matplotlib

In [None]:
import os, sys
# Put the repository root (one level up) on sys.path so the `src` imports below resolve.
repo_root = os.path.abspath('..')
sys.path.insert(0, repo_root)


## Conceptual overview

Convolutions assume locality and translation equivariance. Pooling grows the receptive field and adds robustness to small shifts.
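
The effect of pooling on the receptive field can be computed with a short recurrence. This is a minimal sketch in plain Python; the layer stacks are hypothetical examples, not the actual `SimpleCNN` architecture, and dilation is not modeled.

```python
def receptive_field(layers):
    """Receptive field (in input pixels) of one output unit after `layers`.

    Each layer is a (kernel, stride) pair; padding does not change the result.
    """
    field, jump = 1, 1  # jump = distance in input pixels between adjacent units
    for kernel, stride in layers:
        field += (kernel - 1) * jump
        jump *= stride
    return field

conv, pool = (3, 1), (2, 2)  # 3x3 conv with stride 1, 2x2 pooling with stride 2
print(receptive_field([conv, conv]))              # 5
print(receptive_field([conv, pool, conv, pool]))  # 10
```

Two stacked 3x3 convolutions see a 5x5 region, while interleaving stride-2 pooling doubles that to 10x10 with the same two convolutions.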

In [None]:
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
from src.models.cnn import SimpleCNN
from src.trainers.training import train_one_epoch, evaluate
from src.utils.device import get_device
from src.utils.plotting import plot_curves
from src.vision_utils.gradcam import GradCAM


In [None]:
device = get_device()
# ToTensor converts PIL images to float tensors in [0, 1] with shape (C, H, W).
transform = transforms.Compose([transforms.ToTensor()])
train_ds = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform)
test_ds = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform)
# Use small subsets (3,000 training and 500 test images) to keep the run short.
train_loader = DataLoader(Subset(train_ds, range(3000)), batch_size=64, shuffle=True)
test_loader = DataLoader(Subset(test_ds, range(500)), batch_size=64, shuffle=False)


## Train a small CNN

In [None]:
model = SimpleCNN(num_classes=10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
# Two epochs keep the demo quick; the test subset doubles as the validation set here.
for _ in range(2):
    train_stats = train_one_epoch(model, train_loader, optimizer, device)
    val_stats = evaluate(model, test_loader, device)
    history['train_loss'].append(train_stats['loss'])
    history['train_acc'].append(train_stats['acc'])
    history['val_loss'].append(val_stats['loss'])
    history['val_acc'].append(val_stats['acc'])

plot_curves({'train_loss': history['train_loss'], 'val_loss': history['val_loss']}, 'CNN Loss')
plot_curves({'train_acc': history['train_acc'], 'val_acc': history['val_acc']}, 'CNN Accuracy')


## Visualize filters and feature maps

In [None]:
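# First conv layer weights have shape (out_channels, in_channels, kH, kW).
weights = model.features[0].weight.data.clone()
# Min-max normalize to [0, 1] so each 3-channel filter can be displayed as an RGB image.
weights = (weights - weights.min()) / (weights.max() - weights.min() + 1e-8)
# Show the first 8 filters in a 2x4 grid.
fig, axes = plt.subplots(2, 4, figsize=(6, 3))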
for i, ax in enumerate(axes.flat):
    ax.imshow(weights[i].permute(1, 2, 0).cpu())
    ax.axis('off')
plt.suptitle('First-layer filters')
plt.tight_layout()
plt.show()


## Grad-CAM (saliency)

In [None]:
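model.eval()  # inference behavior for dropout/batch norm; gradients are still computed
images, labels = next(iter(test_loader))
images = images.to(device)
# Target layer: features[4]. Grad-CAM weights its activations by the class-score gradients.
cam = GradCAM(model, model.features[4])
# detach() so the heatmap can be plotted even if it is still attached to the autograd graph.
heatmap = cam(images[:1]).detach().cpu()[0, 0]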
plt.figure(figsize=(4, 4))
plt.imshow(images[0].permute(1, 2, 0).cpu())
plt.imshow(heatmap, cmap='jet', alpha=0.5)
plt.title('Grad-CAM')
plt.axis('off')
plt.show()
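
The `GradCAM` helper used above comes from `src.vision_utils.gradcam`, and its implementation is not shown here. As a rough guide to what Grad-CAM computes in general, the sketch below combines feature maps and their gradients in plain Python; the function name and the toy values are assumptions, and the repository's version may differ (for example, by upsampling the map to the image size).

```python
def grad_cam_map(activations, gradients):
    """Combine K feature maps (each H x W) into one normalized heatmap.

    `gradients` holds d(class score)/d(activation) with the same shape.
    """
    height, width = len(activations[0]), len(activations[0][0])
    # Channel weight = spatial mean of that channel's gradient.
    weights = [sum(map(sum, grad)) / (height * width) for grad in gradients]
    # Weighted sum over channels, then ReLU to keep positive evidence only.
    cam = [[max(0.0, sum(w * act[i][j] for w, act in zip(weights, activations)))
            for j in range(width)] for i in range(height)]
    peak = max(map(max, cam))
    return [[v / peak for v in row] for row in cam] if peak > 0 else cam

acts = [[[1.0, 0.0], [0.0, 0.0]],        # channel 0 fires top-left
        [[0.0, 0.0], [0.0, 2.0]]]        # channel 1 fires bottom-right
grads = [[[1.0, 1.0], [1.0, 1.0]],       # channel 0 raises the class score
         [[-1.0, -1.0], [-1.0, -1.0]]]   # channel 1 lowers it
heat = grad_cam_map(acts, grads)
print(heat)  # [[1.0, 0.0], [0.0, 0.0]]
```

Only the region where a positively weighted channel is active survives the ReLU, which is why the heatmap highlights evidence for the predicted class.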


In [None]:
model.eval()
images, labels = next(iter(test_loader))
with torch.no_grad():
    logits = model(images.to(device))
    preds = logits.argmax(dim=1).cpu()
fig, axes = plt.subplots(2, 5, figsize=(8, 4))
for i, ax in enumerate(axes.flat):
    ax.imshow(images[i].permute(1, 2, 0))
    ax.set_title(f'T:{train_ds.classes[labels[i]]} / P:{train_ds.classes[preds[i]]}')
    ax.axis('off')
plt.suptitle('Sample predictions')
plt.tight_layout()
plt.show()


In [None]:
from src.utils.metrics import confusion_matrix
y_true, y_pred = [], []
with torch.no_grad():
    for images, labels in test_loader:
        logits = model(images.to(device))
        preds = logits.argmax(dim=1).cpu().tolist()
        y_pred.extend(preds)
        y_true.extend(labels.tolist())
cm = confusion_matrix(y_true, y_pred, num_classes=10)
plt.figure(figsize=(5, 4))
plt.imshow(cm, cmap='Blues')
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.colorbar()
plt.tight_layout()
plt.show()
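
The `confusion_matrix` helper is imported from `src.utils.metrics`, and its implementation is not shown. A minimal sketch of the usual definition follows, assuming rows index true labels and columns index predictions (consistent with the axis labels in the plot above). The names `confusion_matrix_sketch` and `per_class_recall` are hypothetical.

```python
def confusion_matrix_sketch(y_true, y_pred, num_classes):
    """Count matrix where entry [t][p] is the number of class-t samples predicted as p."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix

def per_class_recall(matrix):
    """Fraction of each true class predicted correctly (None if the class is absent)."""
    return [row[i] / sum(row) if sum(row) else None for i, row in enumerate(matrix)]

toy = confusion_matrix_sketch([0, 0, 1, 2, 2, 2], [0, 1, 1, 2, 2, 0], num_classes=3)
print(toy)  # [[1, 1, 0], [0, 1, 0], [1, 0, 2]]
recall = per_class_recall(toy)
```

Off-diagonal entries show which classes are confused with each other, and the diagonal divided by each row sum gives per-class recall.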


### Scale Up
- Train for 10-30 epochs on the full training set and add data augmentation (e.g., random crops and horizontal flips) for better results.
- Try ResNet-18 for a stronger baseline.


### Summary
- Convolutions encode locality and translation equivariance.
- Pooling increases receptive fields and reduces spatial resolution.
- Saliency methods such as Grad-CAM highlight which image regions most influence a prediction.

### Exercises
1. Change kernel size/stride and compare accuracy.
2. Add batch normalization and compare training stability (e.g., how smooth the loss curves are across a few learning rates).
3. Add augmentation and compare results.

### Further Reading
- https://arxiv.org/abs/1409.1556 (VGG)
- https://arxiv.org/abs/1512.03385 (ResNet)
