# Transfer Learning on CIFAR-10: Complete Application

This notebook demonstrates modern transfer learning practices with PyTorch, showing:
- Training a simple CNN from scratch
- Using pretrained models (ResNet, EfficientNet)
- Data augmentation strategies
- Fine-tuning techniques

**Dataset**: CIFAR-10 (10 classes, 60k 32×32 color images: 50k for training, 10k for testing)

In [None]:
# Installation (if needed). Use %pip so the packages are installed into this kernel's environment.
# %pip install torch torchvision timm matplotlib tqdm

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision import models
import timm  # PyTorch Image Models - modern architectures
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time

# Choose the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

# Seed the PyTorch and NumPy global generators with one shared value
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

## 1. Data Preparation with Modern Augmentation

We use different augmentation strategies:
- **Basic**: Minimal augmentation for baseline
- **Strong**: Modern augmentations (RandomErasing, ColorJitter)

In [None]:
# Class names; entry i is the display name for integer label i
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Per-channel normalization statistics, named once and reused below.
# The 32x32 pipelines use the first pair (presumably CIFAR-10 training-set statistics);
# the 224x224 pipelines use the ImageNet statistics that pretrained backbones were trained with.
basic_mean, basic_std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
imagenet_mean, imagenet_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# --- 32x32 pipelines (baseline CNN) ---------------------------------------
transform_basic_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(basic_mean, basic_std),
])

transform_test_basic = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(basic_mean, basic_std),
])

# --- 224x224 pipelines (pretrained models) --------------------------------
transform_strong_train = transforms.Compose([
    transforms.Resize(224),  # upsample to the 224x224 input size used by the pretrained models
    transforms.RandomCrop(224, padding=28),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    transforms.RandomErasing(p=0.3),  # placed after ToTensor/Normalize, so it erases on the tensor
])

# Evaluation pipeline: deterministic resize + normalize only
transform_test = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Datasets for the pretrained-model experiments (downloaded into ./data when missing)
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_strong_train
)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test
)

# Loaders: only the training loader shuffles; all other settings are shared
loader_kwargs = dict(batch_size=128, num_workers=2, pin_memory=True)
train_loader = DataLoader(trainset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(testset, shuffle=False, **loader_kwargs)

print(f'Training samples: {len(trainset)}')
print(f'Test samples: {len(testset)}')

In [None]:
# Visualize some samples
def denormalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo per-channel normalization, i.e. return ``tensor * std + mean``.

    Args:
        tensor: Channels-first image tensor, either ``(C, H, W)`` or a batch
            ``(N, C, H, W)``. It may live on any device.
        mean: Per-channel means that were subtracted during normalization.
            Defaults to the ImageNet values used by the 224x224 pipelines.
        std: Per-channel standard deviations that were divided out.
            Defaults to the ImageNet values used by the 224x224 pipelines.

    Returns:
        A new tensor with the normalization reversed; the input is not modified.
    """
    # Build the statistics on the input's device so CUDA tensors work without a .cpu() round trip.
    # view(-1, 1, 1) broadcasts over H and W (and over a leading batch axis) for any channel count.
    mean_t = torch.tensor(mean, dtype=torch.float32, device=tensor.device).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=torch.float32, device=tensor.device).view(-1, 1, 1)
    return tensor * std_t + mean_t

# Draw one (augmented) mini-batch from the training loader
dataiter = iter(train_loader)
images, labels = next(dataiter)

# 2 x 8 grid showing up to the first 16 images of the batch with their class names
fig, axes = plt.subplots(2, 8, figsize=(16, 4))
for idx, ax in enumerate(axes.flat):
    ax.axis('off')
    if idx >= len(images):
        continue  # batch smaller than the grid: leave the panel empty
    # Undo normalization and clip to [0, 1] so imshow receives displayable values
    img = denormalize(images[idx].cpu()).clamp(0, 1)
    ax.imshow(img.permute(1, 2, 0))  # channels-first -> channels-last for imshow
    ax.set_title(classes[labels[idx]])
plt.tight_layout()
plt.show()

## 2. Training Utilities

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Network to optimize; switched to training mode here.
        loader: Iterable yielding ``(inputs, labels)`` mini-batches.
        criterion: Loss callable applied as ``criterion(outputs, labels)``. The epoch
            average assumes it returns the mean loss over the batch (the default
            reduction of ``nn.CrossEntropyLoss``).
        optimizer: Optimizer that is zeroed and stepped once per mini-batch.
        device: Device the inputs and labels are moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, accuracy)``: the loss averaged over all samples seen and
        the top-1 accuracy in percent.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0  # sum of per-sample losses, so a short final batch is weighted correctly
    correct = 0
    total = 0

    pbar = tqdm(loader, desc='Training')
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)
        batch_size = labels.size(0)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # loss is a batch mean; scale by the batch size to accumulate a per-sample sum
        loss_sum += loss.item() * batch_size
        _, predicted = outputs.max(1)
        total += batch_size
        correct += predicted.eq(labels).sum().item()

        # Running averages over the samples processed so far (not over the whole epoch length)
        pbar.set_postfix({'loss': f'{loss_sum/total:.3f}', 'acc': f'{100.*correct/total:.2f}%'})

    return loss_sum / total, 100. * correct / total

def evaluate(model, loader, criterion, device):
    """Measure loss and accuracy of ``model`` on ``loader`` without updating it.

    Args:
        model: Network to evaluate; switched to evaluation mode here.
        loader: Iterable yielding ``(inputs, labels)`` mini-batches.
        criterion: Loss callable applied as ``criterion(outputs, labels)``. The average
            assumes it returns the mean loss over the batch (the default reduction of
            ``nn.CrossEntropyLoss``).
        device: Device the inputs and labels are moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, accuracy)``: the loss averaged over all samples and the
        top-1 accuracy in percent.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0  # sum of per-sample losses, so a short final batch is weighted correctly
    correct = 0
    total = 0

    # Gradients are not needed for evaluation
    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc='Evaluating'):
            inputs, labels = inputs.to(device), labels.to(device)
            batch_size = labels.size(0)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # loss is a batch mean; scale by the batch size to accumulate a per-sample sum
            loss_sum += loss.item() * batch_size
            _, predicted = outputs.max(1)
            total += batch_size
            correct += predicted.eq(labels).sum().item()

    return loss_sum / total, 100. * correct / total

def train_model(model, train_loader, test_loader, epochs, lr, name):
    """Train ``model`` for ``epochs`` epochs and record per-epoch metrics.

    Uses cross-entropy loss, AdamW (weight decay 0.01) and a cosine-annealed learning
    rate stepped once per epoch. Batches are moved to the notebook-level ``device``.

    Args:
        model: Network to train; expected to already be on ``device``.
        train_loader: Loader passed to ``train_one_epoch``.
        test_loader: Loader passed to ``evaluate`` after every epoch.
        epochs: Number of epochs; also the cosine schedule's ``T_max``.
        lr: Initial learning rate for AdamW.
        name: Label used in log output and as the checkpoint prefix; the best weights
            are written to ``f'{name}_best.pth'`` in the working directory.

    Returns:
        Dict with keys ``'train_loss'``, ``'train_acc'``, ``'test_loss'`` and
        ``'test_acc'``, each a list holding one value per epoch.

    Note:
        The "best" checkpoint is chosen by accuracy on ``test_loader``, so the reported
        best accuracy is selected on the same data it is measured on.
    """
    criterion = nn.CrossEntropyLoss()
    # Give the optimizer only parameters that can receive gradients, so a frozen
    # backbone is neither tracked nor weight-decayed by it.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(trainable_params, lr=lr, weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    history = {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': []}
    best_acc = 0

    print(f'\nTraining {name}')
    print('=' * 70)

    for epoch in range(epochs):
        print(f'\nEpoch {epoch+1}/{epochs}')

        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        test_loss, test_acc = evaluate(model, test_loader, criterion, device)

        # Advance the cosine schedule once per epoch, after the optimizer has stepped
        scheduler.step()

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)

        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')

        # Checkpoint whenever test accuracy improves on the best seen so far
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), f'{name}_best.pth')

    print(f'\nBest Test Accuracy: {best_acc:.2f}%')
    return history

## 3. Baseline: Simple CNN from Scratch

Train a compact CNN to establish baseline performance.

In [None]:
class SimpleCNN(nn.Module):
    """Compact three-stage convolutional classifier.

    Each stage is Conv3x3 -> BatchNorm -> ReLU with channel widths 64, 128 and 256.
    The first two stages end in 2x2 max pooling and the last in global average
    pooling, so the classifier always receives a 256-dimensional vector.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Channel progression: RGB input followed by the width of each stage
        widths = (3, 64, 128, 256)
        final_stage = len(widths) - 2

        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU())
            # Halve the resolution after the early stages; collapse to 1x1 after the last
            if stage == final_stage:
                layers.append(nn.AdaptiveAvgPool2d(1))
            else:
                layers.append(nn.MaxPool2d(2))
        self.features = nn.Sequential(*layers)

        head = [nn.Flatten(), nn.Dropout(0.5), nn.Linear(widths[-1], num_classes)]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for image batch ``x``."""
        pooled = self.features(x)
        logits = self.classifier(pooled)
        return logits

# For baseline, use 32x32 images (the basic transforms keep CIFAR-10's native resolution)
trainset_basic = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_basic_train)
testset_basic = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test_basic)

# pin_memory=True keeps these loaders consistent with train_loader / test_loader,
# which are configured the same way for the pretrained-model experiments
train_loader_basic = DataLoader(trainset_basic, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)
test_loader_basic = DataLoader(testset_basic, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)

model_scratch = SimpleCNN().to(device)
print(f'Parameters: {sum(p.numel() for p in model_scratch.parameters()):,}')

# Train from random initialization to establish the baseline curve
history_scratch = train_model(model_scratch, train_loader_basic, test_loader_basic, epochs=20, lr=0.001, name='SimpleCNN')

## 4. Transfer Learning: ResNet18 (Frozen Backbone)

Use a pretrained ResNet18 with a frozen backbone and train only the new classifier layer.

In [None]:
# Load pretrained ResNet18.
# `pretrained=True` is deprecated in torchvision >= 0.13; IMAGENET1K_V1 is the
# checkpoint that flag used to load, so the starting weights are unchanged.
model_resnet_frozen = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Freeze all layers
for param in model_resnet_frozen.parameters():
    param.requires_grad = False

# Replace final layer (the new nn.Linear is created with requires_grad=True,
# so it is the only trainable part of the network)
num_features = model_resnet_frozen.fc.in_features
model_resnet_frozen.fc = nn.Linear(num_features, 10)
model_resnet_frozen = model_resnet_frozen.to(device)

# NOTE(review): train_one_epoch calls model.train(), so the backbone's BatchNorm
# layers still update their running statistics even though their weights are frozen.
# TODO: decide whether the backbone should be kept in eval mode for a strictly frozen setup.

trainable_params = sum(p.numel() for p in model_resnet_frozen.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model_resnet_frozen.parameters())
print(f'Trainable: {trainable_params:,} / {total_params:,} ({100*trainable_params/total_params:.2f}%)')

history_resnet_frozen = train_model(model_resnet_frozen, train_loader, test_loader, epochs=10, lr=0.001, name='ResNet18_Frozen')

## 5. Transfer Learning: ResNet18 (Full Fine-tuning)

Fine-tune all layers with a lower learning rate.

In [None]:
# Load pretrained ResNet18.
# `pretrained=True` is deprecated in torchvision >= 0.13; IMAGENET1K_V1 is the
# checkpoint that flag used to load, so the starting weights are unchanged.
model_resnet_full = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Replace final layer with a 10-way classifier; every layer stays trainable
model_resnet_full.fc = nn.Linear(model_resnet_full.fc.in_features, 10)
model_resnet_full = model_resnet_full.to(device)

print(f'Total parameters: {sum(p.numel() for p in model_resnet_full.parameters()):,}')

# Learning rate is 10x lower than in the frozen-backbone run (0.0001 vs 0.001)
history_resnet_full = train_model(model_resnet_full, train_loader, test_loader, epochs=15, lr=0.0001, name='ResNet18_FullFinetune')

## 6. Modern Architecture: EfficientNet-B0

Use a modern, efficient architecture from the `timm` library.

In [None]:
# Create a pretrained EfficientNet-B0 through timm with a 10-class head and
# place it on the compute device in one step
model_efficientnet = timm.create_model(
    'efficientnet_b0',
    pretrained=True,
    num_classes=10,
).to(device)

# Report the model size
efficientnet_param_count = sum(weight.numel() for weight in model_efficientnet.parameters())
print(f'Total parameters: {efficientnet_param_count:,}')

# All layers are trained, using the 224x224 loaders
history_efficientnet = train_model(
    model_efficientnet,
    train_loader,
    test_loader,
    epochs=15,
    lr=0.0001,
    name='EfficientNet_B0',
)

## 7. Comparison and Analysis

In [None]:
# Pair every training history with the label used in the legend and the summary
histories = [
    (history_scratch, 'SimpleCNN (scratch)'),
    (history_resnet_frozen, 'ResNet18 (frozen)'),
    (history_resnet_full, 'ResNet18 (full finetune)'),
    (history_efficientnet, 'EfficientNet-B0')
]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# One row per panel: (axes, history key, y-axis label, title)
panel_specs = (
    (ax1, 'test_acc', 'Test Accuracy (%)', 'Test Accuracy Comparison'),
    (ax2, 'test_loss', 'Test Loss', 'Test Loss Comparison'),
)
for panel, metric, y_label, panel_title in panel_specs:
    # One curve per model; the x position is the epoch index
    for history, label in histories:
        panel.plot(history[metric], label=label, linewidth=2)
    panel.set_xlabel('Epoch')
    panel.set_ylabel(y_label)
    panel.set_title(panel_title)
    panel.legend()
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Summary table: accuracy after the last epoch of each run
print('\nFinal Test Accuracy:')
print('=' * 50)
for history, label in histories:
    print(f'{label:30s}: {history["test_acc"][-1]:.2f}%')

## 8. Inference Example

In [None]:
# Get a batch of test images
dataiter = iter(test_loader)
images, labels = next(dataiter)
images, labels = images.to(device), labels.to(device)

# Make predictions with EfficientNet.
# Note: the in-memory model holds the weights from its last training epoch, which
# are not necessarily the ones saved to 'EfficientNet_B0_best.pth' by train_model.
model_efficientnet.eval()
with torch.no_grad():
    outputs = model_efficientnet(images)
    _, predicted = outputs.max(1)

# Visualize predictions on a 2 x 8 grid
fig, axes = plt.subplots(2, 8, figsize=(16, 4))
for idx, ax in enumerate(axes.flat):
    # Guard on the actual batch size so a batch with fewer than 16 images
    # leaves the spare panels blank instead of raising IndexError
    if idx < len(images):
        img = denormalize(images[idx].cpu())
        img = torch.clamp(img, 0, 1)
        ax.imshow(img.permute(1, 2, 0))

        true_label = classes[labels[idx]]
        pred_label = classes[predicted[idx]]
        # Green title for a correct prediction, red for a wrong one
        color = 'green' if labels[idx] == predicted[idx] else 'red'

        ax.set_title(f'T: {true_label}\nP: {pred_label}', color=color, fontsize=9)
    ax.axis('off')

plt.tight_layout()
plt.show()

# Calculate accuracy on this batch
correct = (predicted == labels).sum().item()
print(f'Batch Accuracy: {100 * correct / len(labels):.2f}%')

## Key Takeaways

1. **Transfer Learning Benefits**:
   - Pretrained models typically outperform a small CNN trained from scratch by a wide margin
   - Even frozen backbones provide strong features
   - Full fine-tuning usually achieves the best results

2. **Modern Architectures**:
   - EfficientNet provides better accuracy with fewer parameters
   - Use `timm` library for access to latest models

3. **Best Practices**:
   - Use data augmentation for better generalization
   - Start with frozen backbone, then fine-tune if needed
   - Lower learning rates for fine-tuning (10x-100x smaller)
   - Use cosine annealing for the learning-rate schedule