# TP2: Fine-Tuning for Medical Image Classification

**Day 3 - AI for Sciences Winter School**

**Instructor:** Raphael Cousin

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/racousin/ai_for_sciences/blob/main/day3/tp2.ipynb)

---

## Objectives

In this exercise, you will apply transfer learning to a **real scientific task**: classifying blood cells from microscopy images.

By the end of this practical, you will:

1. **Apply fine-tuning** to a medical imaging dataset (BloodMNIST)
2. **Compare approaches**: frozen backbone vs full fine-tuning
3. **Analyze results**: understand when each approach works best
4. **Evaluate models** with accuracy, confusion matrices, and per-class metrics

---

# Part 1: The Power of Pre-trained Models

## The Problem: Training From Scratch is Hard

Training a neural network from scratch requires:
- **Large datasets**: Often hundreds of thousands to millions of labeled examples
- **Significant compute**: Days or weeks of GPU time
- **Careful tuning**: Learning rate, architecture, regularization

**But what if you only have 1,000 samples?** This is the reality for many scientific applications.

## The Solution: Transfer Learning

```
┌─────────────────────────────────────────────────────────────────┐
│                        PRE-TRAINED MODEL                        │
│      (trained on ImageNet-1k: ~1.28M images, 1000 classes)      │
│                                                                 │
│  ┌──────────┐    ┌──────────┐    ┌──────────┐    ┌────────────┐ │
│  │  Conv 1  │ →  │  Conv 2  │ →  │   ...    │ →  │ Classifier │ │
│  │ (edges)  │    │(textures)│    │(objects) │    │ (1000 cls) │ │
│  └──────────┘    └──────────┘    └──────────┘    └────────────┘ │
│   Generic features ────────────────────────────→ Task-specific  │
└─────────────────────────────────────────────────────────────────┘
                              ↓
                    Replace classifier
                              ↓
┌─────────────────────────────────────────────────────────────────┐
│                           YOUR MODEL                            │
│       (fine-tuned on your data: 1,000 images, 8 classes)        │
│                                                                 │
│  ┌──────────┐    ┌──────────┐    ┌──────────┐    ┌────────────┐ │
│  │  Conv 1  │ →  │  Conv 2  │ →  │   ...    │ →  │ Classifier │ │
│  │ (edges)  │    │(textures)│    │(objects) │    │  (8 cls)   │ │
│  └──────────┘    └──────────┘    └──────────┘    └────────────┘ │
│   Reuse learned features                         New classifier │
└─────────────────────────────────────────────────────────────────┘
```

> **Key insight**: Early layers learn generic features (edges, textures) that transfer across tasks. Only later layers are task-specific.

## Setup

In [None]:
!pip install -q git+https://github.com/racousin/ai_for_sciences.git
!pip install -q medmnist

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import models
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import medmnist
from medmnist import INFO
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print(f"MedMNIST version: {medmnist.__version__}")
print("Setup complete!")

## Load the BloodMNIST Dataset

In [None]:
# Dataset info
data_flag = 'bloodmnist'
info = INFO[data_flag]
n_classes = len(info['label'])
class_names = list(info['label'].values())

print(f"Dataset: {info['description']}")
print(f"Number of classes: {n_classes}")
print(f"\nClass names:")
for i, name in enumerate(class_names):
    print(f"  {i}: {name}")

In [None]:
# Data transforms: upsample the 28x28 MedMNIST images to 224x224, the input size used for ImageNet pre-training
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet stats
])

# Load datasets
DataClass = getattr(medmnist, info['python_class'])

train_dataset = DataClass(split='train', transform=transform, download=True)
val_dataset = DataClass(split='val', transform=transform, download=True)
test_dataset = DataClass(split='test', transform=transform, download=True)

print(f"\nDataset sizes:")
print(f"  Training:   {len(train_dataset):,} samples")
print(f"  Validation: {len(val_dataset):,} samples")
print(f"  Test:       {len(test_dataset):,} samples")
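`transforms.Normalize` computes `(x - mean) / std` per channel, and the `denormalize` helper used for plotting inverts it. A small NumPy sketch of that round trip on a toy array (the function names `normalize_np` and `denormalize_np` are illustrative, not part of the notebook) also shows where black and white pixels land after normalization:

```python
import numpy as np

# ImageNet per-channel statistics, the same values passed to transforms.Normalize
MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)


def normalize_np(x):
    """(x - mean) / std per channel, for a (3, H, W) array with values in [0, 1]."""
    return (x - MEAN) / STD


def denormalize_np(x):
    """Inverse of normalize_np, clipped to the displayable [0, 1] range."""
    return np.clip(x * STD + MEAN, 0.0, 1.0)


rng = np.random.default_rng(0)
img = rng.random((3, 4, 4))                 # toy 3-channel "image" with values in [0, 1)
z = normalize_np(img)
restored = denormalize_np(z)

black = normalize_np(np.zeros((3, 1, 1)))   # where a black pixel lands after normalization
white = normalize_np(np.ones((3, 1, 1)))    # ... and a white pixel
print("value range per channel:", black.ravel().round(3), "to", white.ravel().round(3))
print("round trip exact:", np.allclose(restored, img))
```

Normalized inputs therefore span roughly -2.1 to 2.6 rather than 0 to 1, which is the range the pre-trained filters saw during ImageNet training.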

In [None]:
# Create data loaders
BATCH_SIZE = 64

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Batches per epoch: {len(train_loader)}")

## Visualize the Data

In [None]:
# Helper function to denormalize images for visualization
def denormalize(tensor):
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return (tensor * std + mean).clamp(0, 1)

# Show samples from each class
fig, axes = plt.subplots(2, 4, figsize=(12, 6))

# Get one sample per class
samples_per_class = {i: None for i in range(n_classes)}
for img, label in train_dataset:
    label_idx = label.item()
    if samples_per_class[label_idx] is None:
        samples_per_class[label_idx] = img
    if all(v is not None for v in samples_per_class.values()):
        break

for i, ax in enumerate(axes.flat):
    img = denormalize(samples_per_class[i])
    ax.imshow(img.permute(1, 2, 0).numpy())
    ax.set_title(class_names[i], fontsize=10)
    ax.axis('off')

plt.suptitle('Blood Cell Types', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

### Question 1

Look at the cell images:
1. Can you visually distinguish between the different cell types?
2. Which classes look most similar to each other?
3. Why might this be a challenging classification task?

---

# Part 2: Training Utilities

We'll reuse the training functions from TP1.

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Train for one epoch."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for inputs, labels in tqdm(loader, desc='Training', leave=False):
        inputs = inputs.to(device)
        labels = labels.squeeze(1).long().to(device)  # MedMNIST labels have shape (B, 1); flatten to (B,)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    
    return running_loss / len(loader), 100. * correct / total


def evaluate(model, loader, criterion, device):
    """Evaluate model on a dataset."""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.squeeze(1).long().to(device)  # (B, 1) -> (B,), keeps the batch dim even if B == 1
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    return running_loss / len(loader), 100. * correct / total, all_preds, all_labels


def count_parameters(model):
    """Count trainable parameters."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total
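MedMNIST returns each label as an array of shape `(1,)`, so a batch of labels arrives as `(B, 1)` while `nn.CrossEntropyLoss` expects class indices of shape `(B,)`. A quick check of how `squeeze()` and `squeeze(1)` differ, including the edge case of a final batch with a single sample (toy tensors, not BloodMNIST data):

```python
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
logits = torch.randn(4, 8)                        # batch of 4 samples, 8 classes
batch_labels = torch.tensor([[3], [0], [7], [1]])  # MedMNIST-style labels: shape (4, 1)

targets = batch_labels.squeeze(1).long()          # shape (4,), what CrossEntropyLoss expects
loss = loss_fn(logits, targets)
print(targets.shape, loss.item())

# Edge case: a final batch containing a single sample
single = torch.tensor([[5]])                      # shape (1, 1)
print(single.squeeze().shape)                     # torch.Size([]): batch dimension lost too
print(single.squeeze(1).shape)                    # torch.Size([1]): batch dimension kept
```

Squeezing only dimension 1 removes the label axis without ever touching the batch axis, so the loss and the accuracy bookkeeping (`labels.size(0)`) keep working for any batch size.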

### Question 2

With ~12,000 training images and 10 epochs, we get reasonable accuracy. But what if you only had 500 images? (This is common in specialized scientific applications!)

---

# Part 4: Approach 2 - Feature Extraction (Frozen Backbone)

Now let's use a pre-trained ResNet18 and **freeze** all layers except the final classifier.

```
┌─────────────────────────────────────────────────┐
│             ResNet18 (pre-trained)              │
│                                                 │
│   ┌─────────────────────────┐   ┌────────────┐  │
│   │   Feature Extractor     │   │ Classifier │  │
│   │   (FROZEN - 11M params) │ → │ (TRAINABLE)│  │
│   │   Keep ImageNet weights │   │ 4K params  │  │
│   └─────────────────────────┘   └────────────┘  │
└─────────────────────────────────────────────────┘
```

> **When to use**: Very small dataset, or domain similar to ImageNet
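The "4K params" figure follows from the shape of the new head: a linear layer from ResNet18's 512-dimensional pooled features to 8 classes has 512 × 8 weights plus 8 biases. The arithmetic below assumes torchvision's ResNet18 total of 11,689,512 parameters with its original 1000-class head:

```python
IN_FEATURES = 512                  # ResNet18 feature size entering fc
N_CLASSES = 8                      # BloodMNIST
RESNET18_TOTAL_1000 = 11_689_512   # torchvision resnet18 with the 1000-class ImageNet head


def linear_params(n_in: int, n_out: int) -> int:
    """Parameters of nn.Linear(n_in, n_out): weight matrix plus bias vector."""
    return n_in * n_out + n_out


old_head = linear_params(IN_FEATURES, 1000)
new_head = linear_params(IN_FEATURES, N_CLASSES)
backbone = RESNET18_TOTAL_1000 - old_head
total_params = backbone + new_head

print(f"backbone (frozen):    {backbone:,}")
print(f"new head (trainable): {new_head:,}")
print(f"trainable fraction:   {100 * new_head / total_params:.3f}%")
```

Training about 4,100 of roughly 11.2 million parameters is what makes the frozen approach cheap, and also why it cannot adapt the features themselves to blood-cell images.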

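In [None]:
# Build the frozen model: ImageNet-pretrained ResNet18 with a new head for our 8 classes
model_frozen = models.resnet18(weights='IMAGENET1K_V1')
for param in model_frozen.parameters():
    param.requires_grad = False  # freeze all pre-trained layers
# The new layer is created after freezing, so its parameters remain trainable
model_frozen.fc = nn.Linear(model_frozen.fc.in_features, n_classes)
model_frozen = model_frozen.to(device)

# Train the frozen model
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_frozen.fc.parameters(), lr=0.001)  # Only train classifier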

EPOCHS = 5
history_frozen = {'train_acc': [], 'val_acc': []}

print("Training ResNet18 (frozen backbone)...")
print("="*50)

for epoch in range(EPOCHS):
    train_loss, train_acc = train_one_epoch(model_frozen, train_loader, criterion, optimizer, device)
    val_loss, val_acc, _, _ = evaluate(model_frozen, val_loader, criterion, device)
    
    history_frozen['train_acc'].append(train_acc)
    history_frozen['val_acc'].append(val_acc)
    
    print(f'Epoch {epoch+1}/{EPOCHS} | Train: {train_acc:.1f}% | Val: {val_acc:.1f}%')

# Final test evaluation
_, test_acc_frozen, preds_frozen, labels_frozen = evaluate(model_frozen, test_loader, criterion, device)
print(f'\nTest Accuracy (frozen): {test_acc_frozen:.1f}%')
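Setting `requires_grad = False` stops gradient updates to weights and biases, but it does not stop BatchNorm layers from updating their running mean and variance: `model.train()` puts them in training mode, and every forward pass then refreshes those statistics. The toy sketch below demonstrates this on a single `nn.BatchNorm2d`; the helper `freeze_batchnorm_stats` is a hypothetical name for one common option (keeping BatchNorm in eval mode), not something the notebook's training loop does.

```python
import torch
import torch.nn as nn


def freeze_batchnorm_stats(model: nn.Module) -> None:
    """Put every BatchNorm2d layer in eval mode so its running statistics stop updating."""
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()


torch.manual_seed(0)
batch = torch.randn(16, 3, 8, 8) + 2.0  # activations with mean ~2, far from the stored mean of 0

# Case 1: parameters frozen, layer in train mode (what model.train() does)
bn = nn.BatchNorm2d(3)
for p in bn.parameters():
    p.requires_grad = False
before = bn.running_mean.clone()
bn.train()
bn(batch)
changed_in_train = not torch.equal(before, bn.running_mean)

# Case 2: parameters frozen and BatchNorm switched to eval mode
bn_eval = nn.BatchNorm2d(3)
for p in bn_eval.parameters():
    p.requires_grad = False
freeze_batchnorm_stats(bn_eval)
before = bn_eval.running_mean.clone()
bn_eval(batch)
changed_in_eval = not torch.equal(before, bn_eval.running_mean)

print("running_mean updated in train mode:", changed_in_train)
print("running_mean updated in eval mode: ", changed_in_eval)
```

Applied to the notebook, such a call would have to come after `model.train()` inside each epoch, because `model.train()` switches BatchNorm back to training mode. Whether it changes accuracy on BloodMNIST is something to measure, not a given.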

### Question 3

Compare the frozen ResNet18 to SimpleCNN from scratch:
1. How do the accuracies compare?
2. How many trainable parameters does each have?
3. Why does frozen ResNet18 work well despite training only ~4K parameters?

---

# Part 4: Exercise - Full Fine-Tuning

Now train a ResNet18 with **full fine-tuning**: all layers are trainable.

Remember the key tips:
- Use a **lower learning rate** (roughly 10x lower than when training from scratch)
- All layers will adapt to the new domain

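In [None]:
# Create a fresh ImageNet-pretrained ResNet18; every layer stays trainable
model_finetune = models.resnet18(weights='IMAGENET1K_V1')
model_finetune.fc = nn.Linear(model_finetune.fc.in_features, n_classes)
model_finetune = model_finetune.to(device)

# TODO: Train with full fine-tuning
criterion = nn.CrossEntropyLoss()

# TODO: What learning rate should you use? Hint: lower than 0.001!
optimizer = optim.Adam(model_finetune.parameters(), lr=0.001)  # <-- Modify this!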

history_finetune = {'train_acc': [], 'val_acc': []}

print("Training ResNet18 (full fine-tuning)...")
print("="*50)

for epoch in range(EPOCHS):
    train_loss, train_acc = train_one_epoch(model_finetune, train_loader, criterion, optimizer, device)
    val_loss, val_acc, _, _ = evaluate(model_finetune, val_loader, criterion, device)
    
    history_finetune['train_acc'].append(train_acc)
    history_finetune['val_acc'].append(val_acc)
    
    print(f'Epoch {epoch+1}/{EPOCHS} | Train: {train_acc:.1f}% | Val: {val_acc:.1f}%')

# Final test evaluation
_, test_acc_finetune, preds_finetune, labels_finetune = evaluate(model_finetune, test_loader, criterion, device)
print(f'\nTest Accuracy (fine-tuned): {test_acc_finetune:.1f}%')

---

# Part 5: Compare Results

Let's visualize and compare both approaches.

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

epochs_range = range(1, EPOCHS + 1)

# Training accuracy
axes[0].plot(epochs_range, history_frozen['train_acc'], 'g-o', label='Frozen', linewidth=2)
axes[0].plot(epochs_range, history_finetune['train_acc'], 'r-^', label='Fine-tuned', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Training Accuracy (%)')
axes[0].set_title('Training Accuracy')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Validation accuracy
axes[1].plot(epochs_range, history_frozen['val_acc'], 'g-o', label='Frozen', linewidth=2)
axes[1].plot(epochs_range, history_finetune['val_acc'], 'r-^', label='Fine-tuned', linewidth=2)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Validation Accuracy (%)')
axes[1].set_title('Validation Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Summary
print("\n" + "="*50)
print("TEST ACCURACY COMPARISON")
print("="*50)
print(f"Frozen backbone:  {test_acc_frozen:.1f}%")
print(f"Full fine-tuning: {test_acc_finetune:.1f}%")
print("="*50)

## Confusion Matrix Analysis

Let's see which cell types are most often confused.

In [None]:
# Plot confusion matrices for both models
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for ax, preds, labels, title in [
    (axes[0], preds_frozen, labels_frozen, 'Frozen Backbone'),
    (axes[1], preds_finetune, labels_finetune, 'Full Fine-tuning')
]:
    cm = confusion_matrix(labels, preds)
    # Normalize by row (true labels)
    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    
    sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues',
                xticklabels=[c[:6] for c in class_names],
                yticklabels=[c[:6] for c in class_names],
                ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title(f'{title}\n(normalized by row)')

plt.tight_layout()
plt.show()
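The objectives mention per-class metrics, and `classification_report` is imported in the setup cell. A self-contained sketch on toy labels (three placeholder classes, not BloodMNIST results) shows what it reports and how recall relates to the row-normalized confusion matrix plotted above:

```python
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
names = ["class_a", "class_b", "class_c"]  # placeholder class names

report = classification_report(y_true, y_pred, target_names=names,
                               output_dict=True, zero_division=0)
print(classification_report(y_true, y_pred, target_names=names, zero_division=0))

# Recall per class equals the diagonal of the confusion matrix normalized by row
cm_toy = confusion_matrix(y_true, y_pred)
recall_from_cm = cm_toy.diagonal() / cm_toy.sum(axis=1)
print(recall_from_cm)
```

In the notebook, the same call with `labels_finetune`, `preds_finetune` and `target_names=class_names` gives precision, recall and F1 for each blood cell type.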

In [None]:
# Summary table: trainable parameters (share of total) and test accuracy
def param_summary(model):
    trainable, total = count_parameters(model)
    return f"{trainable:,} ({100 * trainable / total:.2f}%)"

results = [
    ('ResNet18 (frozen)', param_summary(model_frozen), f"{test_acc_frozen:.1f}%"),
    ('ResNet18 (full fine-tune)', param_summary(model_finetune), f"{test_acc_finetune:.1f}%"),
]
# Include the from-scratch SimpleCNN baseline only if it was trained in this session
if 'model_scratch' in globals() and 'test_acc' in globals().get('history_scratch', {}):
    results.insert(0, ('SimpleCNN (scratch)', param_summary(model_scratch),
                       f"{history_scratch['test_acc'][-1]:.1f}%"))

print("\n" + "="*70)
print("SUMMARY")
print("="*70)
print(f"{'Approach':<25} {'Trainable Params':<22} {'Test Accuracy':>15}")
print("-"*70)
for name, params, acc in results:
    print(f"{name:<25} {params:<22} {acc:>15}")
print("="*70)

### Question 4

Looking at the results:
1. Which approach achieves the best accuracy?
2. The frozen backbone trains only ~0.04% of parameters but gets good accuracy. Why?
3. When would you choose frozen backbone over full fine-tuning?

---

# Part 6: Bonus Exercise - Data Augmentation

Data augmentation can improve generalization, especially with limited data. To simulate a limited-data setting, we first draw a random subset of 500 training images.

In [None]:
# Create a small random subset (500 samples) of the training set
from torch.utils.data import Subset

subset_indices = np.random.choice(len(train_dataset), 500, replace=False)
small_trainset = Subset(train_dataset, subset_indices)

small_train_loader = DataLoader(small_trainset, batch_size=32, shuffle=True, num_workers=2)

print(f"Small dataset: {len(small_trainset)} training samples")
print(f"Original dataset: {len(train_dataset)} training samples")
print(f"This simulates having limited labeled medical data!")
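A purely random 500-sample draw can under-represent rare cell types. One alternative is a stratified draw, which keeps the class proportions of the full training set. The sketch uses scikit-learn's `train_test_split` on synthetic, deliberately imbalanced labels; the proportions are hypothetical and do not describe BloodMNIST.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic, imbalanced labels standing in for the real training labels (hypothetical proportions)
rng = np.random.default_rng(42)
toy_labels = rng.choice(8, size=12_000, p=[0.05, 0.20, 0.10, 0.15, 0.05, 0.10, 0.25, 0.10])
indices = np.arange(len(toy_labels))

# Stratified draw of 500 indices: class proportions in the subset follow the full set
small_idx, _ = train_test_split(indices, train_size=500, stratify=toy_labels, random_state=42)

full_frac = np.bincount(toy_labels, minlength=8) / len(toy_labels)
small_frac = np.bincount(toy_labels[small_idx], minlength=8) / len(small_idx)
print(np.round(full_frac, 3))
print(np.round(small_frac, 3))
```

With the notebook's data, the labels would come from the MedMNIST dataset object (assuming its `labels` array of shape `(N, 1)`, e.g. `train_dataset.labels.ravel()`), and `small_idx` could then be passed to `Subset` in place of `subset_indices`.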

### Question 5

1. Which augmentations make sense for microscopy images? (Think about what variations occur naturally)
2. Which augmentations would NOT make sense? (e.g., would vertical flip be appropriate?)
3. How could you validate that augmentation helps?

In [None]:
# TODO: Create a frozen ResNet18 model
# Hint: Copy the code from Part 4

model_small_frozen = models.resnet18(weights='IMAGENET1K_V1')

# TODO: Freeze all layers
# for param in model_small_frozen.parameters():
#     ...  # <-- Modify this!

# TODO: Replace the classifier for 8 classes
# model_small_frozen.fc = ...  # <-- Modify this!

model_small_frozen = model_small_frozen.to(device)

In [None]:
# Available MedMNIST datasets
# (loop variable is ds_info so the BloodMNIST `info` defined earlier is not overwritten)
print("Available MedMNIST datasets:\n")
for flag, ds_info in INFO.items():
    print(f"{flag:20s}: {ds_info['description'][:60]}...")
    print(f"{'':20s}  Classes: {len(ds_info['label'])}, Task: {ds_info['task']}")
    print()

In [None]:
# TODO: Try a different dataset!
# Change data_flag to try another dataset:
# - 'dermamnist': Skin lesion images (7 classes)
# - 'pathmnist': Colon pathology (9 classes)
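# - 'chestmnist': Chest X-ray (14 labels, multi-label; needs BCEWithLogitsLoss instead of
#                 CrossEntropyLoss, and grayscale input: add transforms.Grayscale(num_output_channels=3))
# - 'retinamnist': Retinal fundus images (5-level ordinal grading)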

# new_data_flag = 'dermamnist'  # <-- Modify this!
# new_info = INFO[new_data_flag]
# print(f"Dataset: {new_info['description']}")
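Multi-label datasets such as ChestMNIST change the loss: each of the 14 findings can be present independently, so the usual choice is one sigmoid per label with `nn.BCEWithLogitsLoss` and float multi-hot targets, rather than a softmax over classes. A toy sketch with all-zero logits, which makes the loss exactly ln 2:

```python
import math

import torch
import torch.nn as nn

N_LABELS = 14                                    # ChestMNIST: any subset of 14 findings can be present
logits = torch.zeros(4, N_LABELS)                # model output for a batch of 4 (all zero for a checkable value)
targets = torch.randint(0, 2, (4, N_LABELS)).float()  # multi-hot labels, float as BCE requires

bce_loss_fn = nn.BCEWithLogitsLoss()             # independent sigmoid + binary cross-entropy per label
loss = bce_loss_fn(logits, targets)
print(loss.item(), math.log(2))                  # zero logits mean p = 0.5 for every label

# Predictions: threshold each label's probability instead of taking an argmax
preds = (torch.sigmoid(logits) > 0.5).int()
print(preds.shape)
```

The notebook's `train_one_epoch` and `evaluate` assume single-label targets (`.long()` labels and argmax accuracy), so they would also need adjusting for a multi-label dataset.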

---

# Summary

## What You Learned

1. **Applied transfer learning** to a real medical imaging task

2. **Compared approaches**:
   - Frozen backbone: Fast training, good for limited data
   - Full fine-tuning: Often higher accuracy, since all features adapt to the new domain (at a higher compute cost)

3. **Analyzed results** using confusion matrices and per-class metrics

4. **Key findings**:
   - Pre-trained ImageNet features transfer to medical images!
   - Some cell types are inherently harder to distinguish
   - Fine-tuning typically improves performance, especially on difficult classes

## For Your Research

- **Medical imaging**: Consider fine-tuning from RadImageNet or other medical pre-trained models
- **Microscopy**: Transfer from ImageNet often works surprisingly well
- **Limited data**: Start with frozen backbone, add augmentation
- **Different modalities**: Domain-specific pre-training helps (X-ray, MRI, microscopy, etc.)

---

## Reflection Questions

Before moving to the next practical, think about:

1. **For your research**: Do you have a classification or detection task? How much labeled data do you have? (For comparison, BloodMNIST provides about 12,000 labeled training images.)

2. **Domain similarity**: Are there pre-trained models for your domain? (Medical imaging, microscopy, satellite imagery, etc.)

3. **Data augmentation**: What augmentations make sense for your data? (rotation, flipping, color jitter, etc.)

4. **Evaluation**: What metrics matter for your task? (accuracy, precision, recall, AUC?)

5. **Medical imaging context**: In the blood cell classification task, which cell types were hardest to distinguish? Why might this be?