[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](YOUR_COLAB_LINK_HERE)

# 04 Comparison: CNNs vs Transformers
## Objectives
- Compare inductive biases and trade-offs.
- Run a mini empirical comparison.
- Discuss when to use which model.


## Conceptual Primer: CNNs and Transformers

### What is a CNN?
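A Convolutional Neural Network (CNN) is a neural network designed for grid-structured data such as images. It uses **convolutions**: small learnable filters that slide across the image, applying the same weights at every position. This builds in two key inductive biases: **locality** (each output depends only on a small neighborhood of pixels) and **translation equivariance** (shifting the input shifts the feature maps by the same amount, so a pattern is detected wherever it appears).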
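
Translation equivariance can be checked directly on a toy 1D signal. This is a minimal sketch in plain Python rather than PyTorch; `conv1d` and `shift_right` are hypothetical helpers written for illustration. Like deep-learning libraries, `conv1d` computes cross-correlation (no kernel flip).

```python
def conv1d(signal, kernel):
    """'Valid' 1D convolution computed as cross-correlation (no kernel flip),
    as deep-learning libraries do: output[i] = sum_j kernel[j] * signal[i + j]."""
    k = len(kernel)
    return [sum(w * signal[i + j] for j, w in enumerate(kernel))
            for i in range(len(signal) - k + 1)]

def shift_right(xs, n=1):
    """Shift a sequence right by n positions, filling with zeros on the left."""
    return [0] * n + xs[:len(xs) - n]

edge_kernel = [-1, 1]          # responds where the signal steps up or down
x = [0, 0, 0, 5, 5, 0, 0, 0]   # a short "bump"

y = conv1d(x, edge_kernel)
y_from_shifted = conv1d(shift_right(x), edge_kernel)

print(y)                                  # [0, 0, 5, 0, -5, 0, 0]
print(y_from_shifted)                     # [0, 0, 0, 5, 0, -5, 0]
print(y_from_shifted == shift_right(y))   # True: shifting the input shifts the output
```

The property holds exactly here because the bump stays away from the borders; near the edges, padding and cropping break perfect equivariance.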

**How it works**
- **Convolution layers** learn filters that detect edges, textures, and shapes.
- **Pooling** reduces spatial size and increases robustness to small shifts.
- **Deeper layers** combine simpler patterns into higher-level features.
- A **classifier head** maps features to labels.
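
The "robustness to small shifts" that pooling provides is limited, as a toy example shows. This sketch uses a hypothetical `max_pool1d` helper in plain Python with non-overlapping windows (stride equal to the window size).

```python
def max_pool1d(xs, size=2):
    """Non-overlapping 1D max pooling (stride = window size)."""
    return [max(xs[i:i + size]) for i in range(0, len(xs) - size + 1, size)]

response = [0, 9, 0, 0, 0, 0]        # a strong feature response at position 1
shift_within = [9, 0, 0, 0, 0, 0]    # moved left, still inside window [0, 1]
shift_across = [0, 0, 9, 0, 0, 0]    # moved right, into the next window [2, 3]

print(max_pool1d(response))       # [9, 0, 0]
print(max_pool1d(shift_within))   # [9, 0, 0]  unchanged
print(max_pool1d(shift_across))   # [0, 9, 0]  changed
```

Pooling absorbs shifts that keep a peak inside the same window; a shift across a window boundary still changes the output, which is why stacking several pooling stages gives only approximate invariance.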

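Mathematically, each convolution output is a weighted sum over a local patch, with the same filter weights reused at every position; stacking layers grows the **receptive field**, the region of the input that can influence one output unit. CNNs tend to be data-efficient because these priors match the structure of natural images, so the network does not have to learn them from examples.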
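
Receptive-field growth follows a simple recurrence. This sketch in plain Python uses a hypothetical `receptive_field` helper and assumes ordinary (non-dilated) convolution and pooling layers; padding does not change the result.

```python
def receptive_field(layers):
    """Receptive field (in input pixels, along one axis) of a single output unit
    after a stack of conv/pooling layers.

    layers: list of (kernel_size, stride) pairs, in the order they are applied.
    Each layer widens the field by (kernel_size - 1) * jump, where jump is the
    product of the strides of all earlier layers.
    """
    rf, jump = 1, 1
    for kernel_size, stride in layers:
        rf += (kernel_size - 1) * jump
        jump *= stride
    return rf

print(receptive_field([(3, 1)] * 3))      # three 3x3 convs, stride 1 -> 7
print(receptive_field([(3, 1)] * 10))     # ten 3x3 convs, stride 1 -> 21
print(receptive_field([(3, 2)] * 3))      # three 3x3 convs, stride 2 -> 15
print(receptive_field([(7, 2), (3, 2)]))  # ResNet-style stem: 7x7 conv s2 + 3x3 max-pool s2 -> 11
```

Stride-1 layers grow the field only linearly; strided convolutions and pooling multiply the growth, which is one reason CNNs downsample early. Global context is still reached only after many layers.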

### What is a Transformer (ViT) for vision?
A Vision Transformer treats an image as a **sequence of patch tokens**. Each patch is embedded into a vector, positional information is added, and **self-attention** lets every patch interact with every other patch.

**Architecture overview**
- **Patchify** the image (e.g., 16x16 patches).
- **Linear embedding** of each patch.
- **Positional embeddings** preserve spatial order.
- **Transformer encoder blocks**: LayerNorm → Multi-Head Self-Attention, then LayerNorm → MLP, with a residual connection around each sub-block.
- A learnable **class token** (or mean pooling over the patch tokens) feeds a classifier head.
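
The patch-to-token pipeline above can be sketched in plain Python on a toy image. `patchify`, `matvec`, and `embed_tokens` are illustrative helpers, and the weights and embeddings are made-up numbers standing in for learned parameters. A real ViT does the same with tensors; the linear patch embedding is often implemented as a convolution whose kernel size and stride both equal the patch size.

```python
def patchify(image, p):
    """Split an H x W x C image (nested lists) into non-overlapping p x p patches,
    in row-major order, flattening each patch to a vector of length p * p * C."""
    h, w = len(image), len(image[0])
    assert h % p == 0 and w % p == 0, "image size must be divisible by the patch size"
    patches = []
    for top in range(0, h, p):
        for left in range(0, w, p):
            patches.append([value
                            for row in image[top:top + p]
                            for pixel in row[left:left + p]
                            for value in pixel])
    return patches

def matvec(vec, weight):
    """Row vector (length n) times an n x d weight matrix -> vector of length d."""
    return [sum(v * weight[i][j] for i, v in enumerate(vec)) for j in range(len(weight[0]))]

def embed_tokens(image, p, weight, cls_token, pos_embed):
    """Patchify -> linear embedding -> prepend class token -> add positional embeddings."""
    tokens = [cls_token] + [matvec(patch, weight) for patch in patchify(image, p)]
    assert len(pos_embed) == len(tokens), "need one positional embedding per token"
    return [[t + e for t, e in zip(tok, pos)] for tok, pos in zip(tokens, pos_embed)]

# Toy 4x4 grayscale image (C = 1) with pixel values 0..15, split into 2x2 patches
image = [[[4 * r + c] for c in range(4)] for r in range(4)]
print(patchify(image, 2))
# [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]

# Hypothetical "learned" parameters: patch_dim = 4 -> embedding dim d = 2
weight = [[0.1, 0.0], [0.0, 0.1], [0.1, 0.0], [0.0, 0.1]]
cls_token = [0.0, 0.0]
pos_embed = [[0.01 * i, -0.01 * i] for i in range(5)]   # 1 class token + 4 patches
tokens = embed_tokens(image, 2, weight, cls_token, pos_embed)
print(len(tokens), len(tokens[0]))   # 5 2

# ViT-B/16 at 224x224: (224 // 16) ** 2 patches of 16 * 16 * 3 values each,
# plus one class token -> 197 tokens
print((224 // 16) ** 2, 16 * 16 * 3)   # 196 768
```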

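Self-attention gives every layer, including the first, a global view of the image. Because ViTs lack the locality prior of CNNs, they typically need large-scale pretraining (the original ViT paper used ImageNet-21k and JFT-300M) to match CNNs; fine-tuning a pretrained checkpoint makes them practical on smaller datasets.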
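
A minimal sketch of single-head scaled dot-product self-attention, softmax(QK^T / sqrt(d_k))V, in plain Python. The helper names are hypothetical and the identity projection matrices are a simplifying assumption (a real ViT learns W_Q, W_K, W_V and uses several heads). The point is that the weights form a full N x N matrix, so even the first layer mixes information from every patch.

```python
import math

def matvec(vec, weight):
    """Row vector (length n) times an n x d weight matrix -> vector of length d."""
    return [sum(v * weight[i][j] for i, v in enumerate(vec)) for j in range(len(weight[0]))]

def softmax(scores):
    m = max(scores)                      # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(tokens, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention, computed row by row."""
    q = [matvec(t, w_q) for t in tokens]
    k = [matvec(t, w_k) for t in tokens]
    v = [matvec(t, w_v) for t in tokens]
    d_k = len(k[0])
    # weights[i][j]: how much token i attends to token j (each row sums to 1)
    weights = [softmax([sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d_k) for kj in k])
               for qi in q]
    # each output is a weighted average of ALL value vectors
    outputs = [[sum(w * vj[c] for w, vj in zip(row, v)) for c in range(len(v[0]))]
               for row in weights]
    return outputs, weights

# Four toy 2-d tokens standing in for patch embeddings (hypothetical values)
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5]]
identity = [[1.0, 0.0], [0.0, 1.0]]
outputs, weights = self_attention(tokens, identity, identity, identity)

for row in weights:
    print([round(w, 2) for w in row])   # one row per query token
```

Every attention weight is strictly positive, so each output token depends on every input token after a single layer, in contrast to the receptive-field recurrence for convolutions.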

### Key differences (intuition)
- **Inductive bias**: CNNs assume locality; Transformers are more flexible but data-hungry.
- **Context**: CNNs build global context gradually; Transformers use global attention immediately.
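- **Compute**: convolution cost grows linearly with the number of pixels, while self-attention cost grows quadratically with the number of tokens; doubling the image side length quadruples the token count and multiplies the attention cost by about 16.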
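
A back-of-the-envelope sketch of the compute difference, assuming ViT-B/16-style 16x16 patches plus one class token; the helper names are hypothetical. Convolution cost scales with the number of output positions (linear in pixels), whereas each self-attention layer builds an N x N matrix over N tokens.

```python
def num_tokens(image_size, patch_size=16, cls_token=True):
    """Tokens a ViT sees for a square image: (H/P) * (W/P) patches, plus a class token."""
    return (image_size // patch_size) ** 2 + (1 if cls_token else 0)

def attention_matrix_entries(image_size, patch_size=16):
    """Entries in one head's N x N attention matrix (per layer)."""
    n = num_tokens(image_size, patch_size)
    return n * n

# columns: side length, pixels, tokens, attention entries
for size in (224, 448, 896):
    print(size, size * size, num_tokens(size), attention_matrix_entries(size))
# 224 50176 197 38809
# 448 200704 785 616225
# 896 802816 3137 9840769
```

Each doubling of the side length multiplies the pixel count (and a convolution's cost) by 4, but the attention matrix by roughly 16. This is the motivation for windowed-attention designs such as Swin.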

References: ViT (https://arxiv.org/abs/2010.11929), Attention Is All You Need (https://arxiv.org/abs/1706.03762), ResNet (https://arxiv.org/abs/1512.03385).


In [None]:
!pip -q install torch torchvision transformers matplotlib

## Conceptual comparison

CNNs assume locality and translation equivariance.
Transformers use global attention and scale with data and compute.

In [None]:
import time
import torch
import torchvision
from torchvision import transforms
from transformers import AutoImageProcessor, ViTForImageClassification
import matplotlib.pyplot as plt


In [None]:
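device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Upsample CIFAR-10's 32x32 images to the 224x224 resolution both models expect.
# ToTensor() scales pixels to [0, 1]; each model's own normalization is applied below.
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
dataset = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=False)
images, labels = next(iter(loader))
images = images.to(device)

# ResNet-18 with ImageNet-1k weights (`pretrained=True` is deprecated since torchvision 0.13).
# Prepending ImageNet mean/std normalization lets the model take [0, 1] tensors directly.
cnn = torch.nn.Sequential(
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1),
).to(device).eval()  # eval(): BatchNorm uses its running statistics

# ViT-B/16 pretrained on ImageNet-21k and fine-tuned on ImageNet-1k, so both models predict
# the same 1000 classes. The '-in21k' checkpoint has no fine-tuned classification head, so
# ViTForImageClassification would attach a randomly initialized one.
processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224')
vit = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device).eval()

# Pixels are already in [0, 1], so skip the processor's default 1/255 rescaling
vit_inputs = processor(images=images.cpu(), return_tensors='pt', do_rescale=False)
vit_inputs = {k: v.to(device) for k, v in vit_inputs.items()}

def throughput(forward_fn, batch_size, warmup=2, runs=5):
    """Images/sec for forward_fn(), timed after warm-up passes.
    CUDA kernels run asynchronously, so synchronize before reading the clock."""
    with torch.no_grad():
        for _ in range(warmup):
            forward_fn()
        if device.type == 'cuda':
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(runs):
            forward_fn()
        if device.type == 'cuda':
            torch.cuda.synchronize()
    return batch_size * runs / (time.perf_counter() - start)

cnn_tp = throughput(lambda: cnn(images), images.size(0))
vit_tp = throughput(lambda: vit(**vit_inputs), images.size(0))
cnn_params = sum(p.numel() for p in cnn.parameters()) / 1e6
vit_params = sum(p.numel() for p in vit.parameters()) / 1e6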


In [None]:
plt.figure(figsize=(5, 3))
plt.bar(['ResNet-18', 'ViT'], [cnn_params, vit_params])
plt.title('Parameter Count (M)')
plt.ylabel('Millions')
plt.show()

plt.figure(figsize=(5, 3))
plt.bar(['ResNet-18', 'ViT'], [cnn_tp, vit_tp])
plt.title('Throughput (images/sec)')
plt.ylabel('Images/sec')
plt.show()


In [None]:
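# Both models output ImageNet-1k logits (1000 classes), but CIFAR-10 labels are 0-9, so
# comparing argmax(logits) with CIFAR-10 labels directly is meaningless. Instead, run a
# quick linear probe: keep each model frozen, treat its output logits as features, fit a
# 10-way linear classifier on a fixed CIFAR-10 training subset, and score the same test
# subset for both models. The numbers are rough and only comparable to each other.
train_set = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform)
train_subset = torch.utils.data.Subset(train_set, range(320))   # identical images for both models
train_loader = torch.utils.data.DataLoader(train_subset, batch_size=16, shuffle=False)

def cnn_features(imgs):
    return cnn(imgs.to(device))

def vit_features(imgs):
    # Pixels are already in [0, 1], so skip the processor's default 1/255 rescaling
    inputs = processor(images=imgs, return_tensors='pt', do_rescale=False)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    return vit(**inputs).logits

def extract(feature_fn, data_loader, max_batches=None):
    feats, targets = [], []
    with torch.no_grad():
        for i, (imgs, lbls) in enumerate(data_loader):
            if max_batches is not None and i >= max_batches:
                break
            feats.append(feature_fn(imgs).float().cpu())
            targets.append(lbls)
    return torch.cat(feats), torch.cat(targets)

def probe_accuracy(feature_fn, test_batches=5, steps=300):
    x_train, y_train = extract(feature_fn, train_loader)
    x_test, y_test = extract(feature_fn, loader, max_batches=test_batches)
    # Standardize with training statistics so both feature sets train at a similar scale
    mean, std = x_train.mean(dim=0), x_train.std(dim=0) + 1e-6
    x_train, x_test = (x_train - mean) / std, (x_test - mean) / std
    head = torch.nn.Linear(x_train.size(1), 10)
    optimizer = torch.optim.Adam(head.parameters(), lr=1e-2, weight_decay=1e-3)
    for _ in range(steps):  # full-batch training of the linear head only
        optimizer.zero_grad()
        loss = torch.nn.functional.cross_entropy(head(x_train), y_train)
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        return (head(x_test).argmax(dim=1) == y_test).float().mean().item()

cnn.eval()
vit.eval()
cnn_acc = probe_accuracy(cnn_features)
vit_acc = probe_accuracy(vit_features)

plt.figure(figsize=(5, 3))
plt.bar(['ResNet-18', 'ViT'], [cnn_acc, vit_acc])
plt.title('Linear-probe accuracy (test subset)')
plt.ylabel('Accuracy')
plt.ylim(0, 1)
plt.show()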


In [None]:
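# Find disagreements between the models on one test batch.
# Both models predict ImageNet-1k class indices (same ordering), so compare them directly
# and label them with torchvision's ImageNet category names.
categories = torchvision.models.ResNet18_Weights.IMAGENET1K_V1.meta['categories']
images, labels = next(iter(loader))
with torch.no_grad():
    cnn_logits = cnn(images.to(device))
    vit_inputs = processor(images=images, return_tensors='pt', do_rescale=False)  # pixels already in [0, 1]
    vit_inputs = {k: v.to(device) for k, v in vit_inputs.items()}
    vit_logits = vit(**vit_inputs).logits
cnn_preds = cnn_logits.argmax(dim=1).cpu()
vit_preds = vit_logits.argmax(dim=1).cpu()
disagree_idx = [i for i in range(len(images)) if cnn_preds[i] != vit_preds[i]]
print(f'{len(disagree_idx)} of {len(images)} images get different predictions')
fig, axes = plt.subplots(1, 4, figsize=(10, 3))
for ax_i, ax in enumerate(axes):
    if ax_i < len(disagree_idx):
        idx = disagree_idx[ax_i]
        ax.imshow(images[idx].permute(1, 2, 0))  # CHW -> HWC for matplotlib
        # Truncate long ImageNet class names so the titles fit
        cnn_name = categories[cnn_preds[idx].item()][:14]
        vit_name = categories[vit_preds[idx].item()][:14]
        ax.set_title(f'C: {cnn_name}\nV: {vit_name}', fontsize=8)
    ax.axis('off')
plt.suptitle('CNN vs ViT disagreements')
plt.tight_layout()
plt.show()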


### Scale Up
- Evaluate on larger subsets and log peak GPU memory with `torch.cuda.max_memory_allocated()`, calling `torch.cuda.reset_peak_memory_stats()` before each model.
- Add a small CNN trained on CIFAR-10 for a fairer comparison.


## Pros/Cons

| Model | Pros | Cons |
|---|---|---|
| CNN | Strong inductive bias, data-efficient | Limited global context |
| ViT | Global attention, scales with data | Needs large data, compute-heavy |


## When to use what

- Use CNNs for small datasets, limited compute, or real-time constraints.
- Use ViTs when data and compute are ample, or when you can fine-tune a model pretrained at scale.


## Discussion prompts

1. Why does inductive bias matter for data efficiency?
2. When is global attention worth the cost?
3. How do hybrid CNN+Transformer models bridge trade-offs?


### Summary
- CNNs and ViTs make different assumptions.
- Measured throughput and parameter count give a practical basis for choosing between them.

### Exercises
1. Measure accuracy on the same validation subset.
2. Try a smaller ViT model and compare throughput.
3. Find examples where CNN and ViT disagree.

### Further Reading
- https://arxiv.org/abs/2010.11929 (ViT)
- https://arxiv.org/abs/2012.12877 (DeiT)
